refactor(database): use a new 'On-Demand' design (#86)

Technical Details: https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/
This commit is contained in:
MingxuanGame
2025-11-23 21:41:02 +08:00
committed by GitHub
parent 42f1d53d3e
commit 40da994ae8
46 changed files with 4396 additions and 2354 deletions

View File

@@ -3,14 +3,14 @@ Beatmapset缓存服务
用于缓存beatmapset数据减少数据库查询频率
"""
from datetime import datetime
import hashlib
import json
from typing import TYPE_CHECKING
from app.config import settings
from app.database.beatmapset import BeatmapsetResp
from app.database import BeatmapsetDict
from app.log import logger
from app.utils import safe_json_dumps
from redis.asyncio import Redis
@@ -18,20 +18,6 @@ if TYPE_CHECKING:
pass
class DateTimeEncoder(json.JSONEncoder):
"""处理datetime序列化的JSON编码器"""
def default(self, obj):
if isinstance(obj, datetime):
return obj.isoformat()
return super().default(obj)
def safe_json_dumps(data) -> str:
"""安全的JSON序列化处理datetime对象"""
return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False)
def generate_hash(data) -> str:
"""生成数据的MD5哈希值"""
content = data if isinstance(data, str) else safe_json_dumps(data)
@@ -57,15 +43,14 @@ class BeatmapsetCacheService:
"""生成搜索结果缓存键"""
return f"beatmapset_search:{query_hash}:{cursor_hash}"
async def get_beatmapset_from_cache(self, beatmapset_id: int) -> BeatmapsetResp | None:
async def get_beatmapset_from_cache(self, beatmapset_id: int) -> BeatmapsetDict | None:
"""从缓存获取beatmapset信息"""
try:
cache_key = self._get_beatmapset_cache_key(beatmapset_id)
cached_data = await self.redis.get(cache_key)
if cached_data:
logger.debug(f"Beatmapset cache hit for {beatmapset_id}")
data = json.loads(cached_data)
return BeatmapsetResp(**data)
return json.loads(cached_data)
return None
except (ValueError, TypeError, AttributeError) as e:
logger.error(f"Error getting beatmapset from cache: {e}")
@@ -73,24 +58,21 @@ class BeatmapsetCacheService:
async def cache_beatmapset(
self,
beatmapset_resp: BeatmapsetResp,
beatmapset_resp: BeatmapsetDict,
expire_seconds: int | None = None,
):
"""缓存beatmapset信息"""
try:
if expire_seconds is None:
expire_seconds = self._default_ttl
if beatmapset_resp.id is None:
logger.warning("Cannot cache beatmapset with None id")
return
cache_key = self._get_beatmapset_cache_key(beatmapset_resp.id)
cached_data = beatmapset_resp.model_dump_json()
cache_key = self._get_beatmapset_cache_key(beatmapset_resp["id"])
cached_data = safe_json_dumps(beatmapset_resp)
await self.redis.setex(cache_key, expire_seconds, cached_data) # type: ignore
logger.debug(f"Cached beatmapset {beatmapset_resp.id} for {expire_seconds}s")
logger.debug(f"Cached beatmapset {beatmapset_resp['id']} for {expire_seconds}s")
except (ValueError, TypeError, AttributeError) as e:
logger.error(f"Error caching beatmapset: {e}")
async def get_beatmap_lookup_from_cache(self, beatmap_id: int) -> BeatmapsetResp | None:
async def get_beatmap_lookup_from_cache(self, beatmap_id: int) -> BeatmapsetDict | None:
"""从缓存获取通过beatmap ID查找的beatmapset信息"""
try:
cache_key = self._get_beatmap_lookup_cache_key(beatmap_id)
@@ -98,7 +80,7 @@ class BeatmapsetCacheService:
if cached_data:
logger.debug(f"Beatmap lookup cache hit for {beatmap_id}")
data = json.loads(cached_data)
return BeatmapsetResp(**data)
return data
return None
except (ValueError, TypeError, AttributeError) as e:
logger.error(f"Error getting beatmap lookup from cache: {e}")
@@ -107,7 +89,7 @@ class BeatmapsetCacheService:
async def cache_beatmap_lookup(
self,
beatmap_id: int,
beatmapset_resp: BeatmapsetResp,
beatmapset_resp: BeatmapsetDict,
expire_seconds: int | None = None,
):
"""缓存通过beatmap ID查找的beatmapset信息"""
@@ -115,7 +97,7 @@ class BeatmapsetCacheService:
if expire_seconds is None:
expire_seconds = self._default_ttl
cache_key = self._get_beatmap_lookup_cache_key(beatmap_id)
cached_data = beatmapset_resp.model_dump_json()
cached_data = safe_json_dumps(beatmapset_resp)
await self.redis.setex(cache_key, expire_seconds, cached_data) # type: ignore
logger.debug(f"Cached beatmap lookup {beatmap_id} for {expire_seconds}s")
except (ValueError, TypeError, AttributeError) as e:

View File

@@ -3,12 +3,12 @@ from datetime import timedelta
from enum import Enum
import math
import random
from typing import TYPE_CHECKING, NamedTuple
from typing import TYPE_CHECKING, NamedTuple, cast
from app.config import OldScoreProcessingMode, settings
from app.database.beatmap import Beatmap, BeatmapResp
from app.database.beatmap import Beatmap, BeatmapDict
from app.database.beatmap_sync import BeatmapSync, SavedBeatmapMeta
from app.database.beatmapset import Beatmapset, BeatmapsetResp
from app.database.beatmapset import Beatmapset, BeatmapsetDict
from app.database.score import Score
from app.dependencies.database import get_redis, with_db
from app.dependencies.storage import get_storage_service
@@ -62,10 +62,23 @@ STATUS_FACTOR: dict[BeatmapRankStatus, float] = {
SCHEDULER_INTERVAL_MINUTES = 2
class EnsuredBeatmap(BeatmapDict):
checksum: str
ranked: int
class EnsuredBeatmapset(BeatmapsetDict):
ranked: int
ranked_date: datetime.datetime
last_updated: datetime.datetime
play_count: int
beatmaps: list[EnsuredBeatmap]
class ProcessingBeatmapset:
def __init__(self, beatmapset: BeatmapsetResp, record: BeatmapSync) -> None:
def __init__(self, beatmapset: EnsuredBeatmapset, record: BeatmapSync) -> None:
self.beatmapset = beatmapset
self.status = BeatmapRankStatus(self.beatmapset.ranked)
self.status = BeatmapRankStatus(self.beatmapset["ranked"])
self.record = record
def calculate_next_sync_time(
@@ -76,19 +89,19 @@ class ProcessingBeatmapset:
now = utcnow()
if self.status == BeatmapRankStatus.QUALIFIED:
assert self.beatmapset.ranked_date is not None, "ranked_date should not be None for qualified maps"
time_to_ranked = (self.beatmapset.ranked_date + timedelta(days=7) - now).total_seconds()
assert self.beatmapset["ranked_date"] is not None, "ranked_date should not be None for qualified maps"
time_to_ranked = (self.beatmapset["ranked_date"] + timedelta(days=7) - now).total_seconds()
baseline = max(MIN_DELTA, time_to_ranked / 2)
next_delta = max(MIN_DELTA, baseline)
elif self.status in {BeatmapRankStatus.WIP, BeatmapRankStatus.PENDING}:
seconds_since_update = (now - self.beatmapset.last_updated).total_seconds()
seconds_since_update = (now - self.beatmapset["last_updated"]).total_seconds()
factor_update = max(1.0, seconds_since_update / TAU)
factor_play = 1.0 + math.log(1.0 + self.beatmapset.play_count)
factor_play = 1.0 + math.log(1.0 + self.beatmapset["play_count"])
status_factor = STATUS_FACTOR[self.status]
baseline = BASE * factor_play / factor_update * status_factor
next_delta = max(MIN_DELTA, baseline * (GROWTH ** (self.record.consecutive_no_change + 1)))
elif self.status == BeatmapRankStatus.GRAVEYARD:
days_since_update = (now - self.beatmapset.last_updated).days
days_since_update = (now - self.beatmapset["last_updated"]).days
doubling_periods = days_since_update / GRAVEYARD_DOUBLING_PERIOD_DAYS
delta = MIN_DELTA * (2**doubling_periods)
max_seconds = GRAVEYARD_MAX_DAYS * 86400
@@ -105,21 +118,24 @@ class ProcessingBeatmapset:
@property
def beatmapset_changed(self) -> bool:
return self.record.beatmap_status != BeatmapRankStatus(self.beatmapset.ranked)
return self.record.beatmap_status != BeatmapRankStatus(self.beatmapset["ranked"])
@property
def changed_beatmaps(self) -> list[ChangedBeatmap]:
changed_beatmaps = []
for bm in self.beatmapset.beatmaps:
saved = next((s for s in self.record.beatmaps if s["beatmap_id"] == bm.id), None)
for bm in self.beatmapset["beatmaps"]:
saved = next((s for s in self.record.beatmaps if s["beatmap_id"] == bm["id"]), None)
if not saved or saved["is_deleted"]:
changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_ADDED))
elif saved["md5"] != bm.checksum:
changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_UPDATED))
elif saved["beatmap_status"] != BeatmapRankStatus(bm.ranked):
changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.STATUS_CHANGED))
changed_beatmaps.append(ChangedBeatmap(bm["id"], BeatmapChangeType.MAP_ADDED))
elif saved["md5"] != bm["checksum"]:
changed_beatmaps.append(ChangedBeatmap(bm["id"], BeatmapChangeType.MAP_UPDATED))
elif saved["beatmap_status"] != BeatmapRankStatus(bm["ranked"]):
changed_beatmaps.append(ChangedBeatmap(bm["id"], BeatmapChangeType.STATUS_CHANGED))
for saved in self.record.beatmaps:
if not any(bm.id == saved["beatmap_id"] for bm in self.beatmapset.beatmaps) and not saved["is_deleted"]:
if (
not any(bm["id"] == saved["beatmap_id"] for bm in self.beatmapset["beatmaps"])
and not saved["is_deleted"]
):
changed_beatmaps.append(ChangedBeatmap(saved["beatmap_id"], BeatmapChangeType.MAP_DELETED))
return changed_beatmaps
@@ -132,7 +148,7 @@ class BeatmapsetUpdateService:
async def add_missing_beatmapset(self, beatmapset_id: int, immediate: bool = False) -> bool:
beatmapset = await self.fetcher.get_beatmapset(beatmapset_id)
if immediate:
await self._sync_immediately(beatmapset)
await self._sync_immediately(cast(EnsuredBeatmapset, beatmapset))
logger.debug(f"triggered immediate sync for beatmapset {beatmapset_id} ")
return True
await self.add(beatmapset)
@@ -172,7 +188,7 @@ class BeatmapsetUpdateService:
BeatmapSync(
beatmapset_id=missing,
beatmap_status=BeatmapRankStatus.GRAVEYARD,
next_sync_time=datetime.datetime.max,
next_sync_time=datetime.datetime(year=6000, month=1, day=1),
beatmaps=[],
)
)
@@ -185,11 +201,13 @@ class BeatmapsetUpdateService:
await session.commit()
self._adding_missing = False
async def add(self, beatmapset: BeatmapsetResp, calculate_next_sync: bool = True):
async def add(self, set: BeatmapsetDict, calculate_next_sync: bool = True):
beatmapset = cast(EnsuredBeatmapset, set)
async with with_db() as session:
sync_record = await session.get(BeatmapSync, beatmapset.id)
beatmapset_id = beatmapset["id"]
sync_record = await session.get(BeatmapSync, beatmapset_id)
if not sync_record:
database_beatmapset = await session.get(Beatmapset, beatmapset.id)
database_beatmapset = await session.get(Beatmapset, beatmapset_id)
if database_beatmapset:
status = BeatmapRankStatus(database_beatmapset.beatmap_status)
await database_beatmapset.awaitable_attrs.beatmaps
@@ -203,19 +221,29 @@ class BeatmapsetUpdateService:
for bm in database_beatmapset.beatmaps
]
else:
status = BeatmapRankStatus(beatmapset.ranked)
beatmaps = [
SavedBeatmapMeta(
beatmap_id=bm.id,
md5=bm.checksum,
is_deleted=False,
beatmap_status=BeatmapRankStatus(bm.ranked),
ranked = beatmapset.get("ranked")
if ranked is None:
raise ValueError("ranked field is required")
status = BeatmapRankStatus(ranked)
beatmap_list = beatmapset.get("beatmaps", [])
beatmaps = []
for bm in beatmap_list:
bm_id = bm.get("id")
checksum = bm.get("checksum")
ranked = bm.get("ranked")
if bm_id is None or checksum is None or ranked is None:
continue
beatmaps.append(
SavedBeatmapMeta(
beatmap_id=bm_id,
md5=checksum,
is_deleted=False,
beatmap_status=BeatmapRankStatus(ranked),
)
)
for bm in beatmapset.beatmaps
]
sync_record = BeatmapSync(
beatmapset_id=beatmapset.id,
beatmapset_id=beatmapset_id,
beatmaps=beatmaps,
beatmap_status=status,
)
@@ -223,13 +251,27 @@ class BeatmapsetUpdateService:
await session.commit()
await session.refresh(sync_record)
else:
sync_record.beatmaps = [
SavedBeatmapMeta(
beatmap_id=bm.id, md5=bm.checksum, is_deleted=False, beatmap_status=BeatmapRankStatus(bm.ranked)
ranked = beatmapset.get("ranked")
if ranked is None:
raise ValueError("ranked field is required")
beatmap_list = beatmapset.get("beatmaps", [])
beatmaps = []
for bm in beatmap_list:
bm_id = bm.get("id")
checksum = bm.get("checksum")
bm_ranked = bm.get("ranked")
if bm_id is None or checksum is None or bm_ranked is None:
continue
beatmaps.append(
SavedBeatmapMeta(
beatmap_id=bm_id,
md5=checksum,
is_deleted=False,
beatmap_status=BeatmapRankStatus(bm_ranked),
)
)
for bm in beatmapset.beatmaps
]
sync_record.beatmap_status = BeatmapRankStatus(beatmapset.ranked)
sync_record.beatmaps = beatmaps
sync_record.beatmap_status = BeatmapRankStatus(ranked)
if calculate_next_sync:
processing = ProcessingBeatmapset(beatmapset, sync_record)
next_time_delta = processing.calculate_next_sync_time()
@@ -238,17 +280,19 @@ class BeatmapsetUpdateService:
await BeatmapsetUpdateService._sync_immediately(self, beatmapset)
return
sync_record.next_sync_time = utcnow() + next_time_delta
logger.opt(colors=True).info(f"<g>[{beatmapset.id}]</g> next sync at {sync_record.next_sync_time}")
beatmapset_id = beatmapset.get("id")
if beatmapset_id:
logger.opt(colors=True).debug(f"<g>[{beatmapset_id}]</g> next sync at {sync_record.next_sync_time}")
await session.commit()
async def _sync_immediately(self, beatmapset: BeatmapsetResp) -> None:
async def _sync_immediately(self, beatmapset: EnsuredBeatmapset) -> None:
async with with_db() as session:
record = await session.get(BeatmapSync, beatmapset.id)
record = await session.get(BeatmapSync, beatmapset["id"])
if not record:
record = BeatmapSync(
beatmapset_id=beatmapset.id,
beatmapset_id=beatmapset["id"],
beatmaps=[],
beatmap_status=BeatmapRankStatus(beatmapset.ranked),
beatmap_status=BeatmapRankStatus(beatmapset["ranked"]),
)
session.add(record)
await session.commit()
@@ -261,19 +305,18 @@ class BeatmapsetUpdateService:
record: BeatmapSync,
session: AsyncSession,
*,
beatmapset: BeatmapsetResp | None = None,
beatmapset: EnsuredBeatmapset | None = None,
):
logger.opt(colors=True).info(f"<g>[{record.beatmapset_id}]</g> syncing...")
logger.opt(colors=True).debug(f"<g>[{record.beatmapset_id}]</g> syncing...")
if beatmapset is None:
try:
beatmapset = await self.fetcher.get_beatmapset(record.beatmapset_id)
beatmapset = cast(EnsuredBeatmapset, await self.fetcher.get_beatmapset(record.beatmapset_id))
except Exception as e:
if isinstance(e, HTTPStatusError) and e.response.status_code == 404:
logger.opt(colors=True).warning(
f"<g>[{record.beatmapset_id}]</g> beatmapset not found (404), removing from sync list"
)
await session.delete(record)
await session.commit()
return
if isinstance(e, HTTPError):
logger.opt(colors=True).warning(
@@ -292,20 +335,20 @@ class BeatmapsetUpdateService:
if changed:
record.beatmaps = [
SavedBeatmapMeta(
beatmap_id=bm.id,
md5=bm.checksum,
beatmap_id=bm["id"],
md5=bm["checksum"],
is_deleted=False,
beatmap_status=BeatmapRankStatus(bm.ranked),
beatmap_status=BeatmapRankStatus(bm["ranked"]),
)
for bm in beatmapset.beatmaps
for bm in beatmapset["beatmaps"]
]
record.beatmap_status = BeatmapRankStatus(beatmapset.ranked)
record.beatmap_status = BeatmapRankStatus(beatmapset["ranked"])
record.consecutive_no_change = 0
bg_tasks.add_task(
self._process_changed_beatmaps,
changed_beatmaps,
beatmapset.beatmaps,
beatmapset["beatmaps"],
)
bg_tasks.add_task(
self._process_changed_beatmapset,
@@ -317,13 +360,13 @@ class BeatmapsetUpdateService:
next_time_delta = processing.calculate_next_sync_time()
if not next_time_delta:
logger.opt(colors=True).info(
f"<yellow>[{beatmapset.id}]</yellow> beatmapset has transformed to ranked or loved,"
f"<yellow>[{beatmapset['id']}]</yellow> beatmapset has transformed to ranked or loved,"
f" removing from sync list"
)
await session.delete(record)
else:
record.next_sync_time = utcnow() + next_time_delta
logger.opt(colors=True).info(f"<g>[{record.beatmapset_id}]</g> next sync at {record.next_sync_time}")
logger.opt(colors=True).debug(f"<g>[{record.beatmapset_id}]</g> next sync at {record.next_sync_time}")
async def _update_beatmaps(self):
async with with_db() as session:
@@ -338,18 +381,18 @@ class BeatmapsetUpdateService:
await self.sync(record, session)
await session.commit()
async def _process_changed_beatmapset(self, beatmapset: BeatmapsetResp):
async def _process_changed_beatmapset(self, beatmapset: EnsuredBeatmapset):
async with with_db() as session:
db_beatmapset = await session.get(Beatmapset, beatmapset.id)
new_beatmapset = await Beatmapset.from_resp_no_save(beatmapset)
db_beatmapset = await session.get(Beatmapset, beatmapset["id"])
new_beatmapset = await Beatmapset.from_resp_no_save(beatmapset) # pyright: ignore[reportArgumentType]
if db_beatmapset:
await session.merge(new_beatmapset)
await get_beatmapset_cache_service(get_redis()).invalidate_beatmapset_cache(beatmapset.id)
await get_beatmapset_cache_service(get_redis()).invalidate_beatmapset_cache(beatmapset["id"])
await session.commit()
async def _process_changed_beatmaps(self, changed: list[ChangedBeatmap], beatmaps_list: list[BeatmapResp]):
async def _process_changed_beatmaps(self, changed: list[ChangedBeatmap], beatmaps_list: list[EnsuredBeatmap]):
storage_service = get_storage_service()
beatmaps = {bm.id: bm for bm in beatmaps_list}
beatmaps = {bm["id"]: bm for bm in beatmaps_list}
async with with_db() as session:
@@ -380,9 +423,9 @@ class BeatmapsetUpdateService:
)
continue
logger.opt(colors=True).info(
f"<g>[{beatmap.beatmapset_id}]</g> adding beatmap <blue>{beatmap.id}</blue>"
f"<g>[{beatmap['beatmapset_id']}]</g> adding beatmap <blue>{beatmap['id']}</blue>"
)
await Beatmap.from_resp_no_save(session, beatmap)
await Beatmap.from_resp_no_save(session, beatmap) # pyright: ignore[reportArgumentType]
else:
beatmap = beatmaps.get(change.beatmap_id)
if not beatmap:
@@ -391,10 +434,10 @@ class BeatmapsetUpdateService:
)
continue
logger.opt(colors=True).info(
f"<g>[{beatmap.beatmapset_id}]</g> processing beatmap <blue>{beatmap.id}</blue> "
f"<g>[{beatmap['beatmapset_id']}]</g> processing beatmap <blue>{beatmap['id']}</blue> "
f"change <cyan>{change.type}</cyan>"
)
new_db_beatmap = await Beatmap.from_resp_no_save(session, beatmap)
new_db_beatmap = await Beatmap.from_resp_no_save(session, beatmap) # pyright: ignore[reportArgumentType]
existing_beatmap = await session.get(Beatmap, change.beatmap_id)
if existing_beatmap:
await session.merge(new_db_beatmap)

View File

@@ -4,16 +4,15 @@
"""
import asyncio
from datetime import datetime
import json
from typing import TYPE_CHECKING, Literal
from app.config import settings
from app.database.statistics import UserStatistics, UserStatisticsResp
from app.database.statistics import UserStatistics, UserStatisticsModel
from app.helpers.asset_proxy_helper import replace_asset_urls
from app.log import logger
from app.models.score import GameMode
from app.utils import utcnow
from app.utils import safe_json_dumps, utcnow
from redis.asyncio import Redis
from sqlmodel import col, select
@@ -23,20 +22,6 @@ if TYPE_CHECKING:
pass
class DateTimeEncoder(json.JSONEncoder):
"""自定义 JSON 编码器,支持 datetime 序列化"""
def default(self, obj):
if isinstance(obj, datetime):
return obj.isoformat()
return super().default(obj)
def safe_json_dumps(data) -> str:
"""安全的 JSON 序列化,支持 datetime 对象"""
return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":"))
class RankingCacheService:
"""用户排行榜缓存服务"""
@@ -311,7 +296,7 @@ class RankingCacheService:
col(UserStatistics.pp) > 0,
col(UserStatistics.is_ranked).is_(True),
]
include = ["user"]
include = UserStatistics.RANKING_INCLUDES.copy()
if type == "performance":
order_by = col(UserStatistics.pp).desc()
@@ -321,6 +306,7 @@ class RankingCacheService:
if country:
wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
include.append("country_rank")
# 获取总用户数用于统计
total_users_query = select(UserStatistics).where(*wheres)
@@ -353,9 +339,9 @@ class RankingCacheService:
# 转换为响应格式并确保正确序列化
ranking_data = []
for statistics in statistics_data:
user_stats_resp = await UserStatisticsResp.from_db(statistics, session, None, include)
user_stats_resp = await UserStatisticsModel.transform(statistics, includes=include)
user_dict = user_stats_resp.model_dump()
user_dict = user_stats_resp
# 应用资源代理处理
if settings.enable_asset_proxy:

View File

@@ -8,14 +8,14 @@
import asyncio
from datetime import datetime
import json
import time
from typing import Any
from app.database.chat import ChatMessage, ChatMessageResp, MessageType
from app.database.user import RANKING_INCLUDES, User, UserResp
from app.database import ChatMessageDict
from app.database.chat import ChatMessage, ChatMessageModel, MessageType
from app.database.user import User, UserModel
from app.dependencies.database import get_redis_message, with_db
from app.log import logger
from app.utils import bg_tasks
from app.utils import bg_tasks, safe_json_dumps
class RedisMessageSystem:
@@ -35,7 +35,7 @@ class RedisMessageSystem:
content: str,
is_action: bool = False,
user_uuid: str | None = None,
) -> ChatMessageResp:
) -> "ChatMessageDict":
"""
发送消息 - 立即存储到 Redis 并返回
@@ -47,7 +47,7 @@ class RedisMessageSystem:
user_uuid: 用户UUID
Returns:
ChatMessageResp: 消息响应对象
ChatMessage: 消息响应对象
"""
# 生成消息ID和时间戳
message_id = await self._generate_message_id(channel_id)
@@ -57,28 +57,16 @@ class RedisMessageSystem:
if not user.id:
raise ValueError("User ID is required")
# 获取频道类型以判断是否需要存储到数据库
async with with_db() as session:
from app.database.chat import ChannelType, ChatChannel
from sqlmodel import select
channel_result = await session.exec(select(ChatChannel.type).where(ChatChannel.channel_id == channel_id))
channel_type = channel_result.first()
is_multiplayer = channel_type == ChannelType.MULTIPLAYER
# 准备消息数据
message_data = {
message_data: "ChatMessageDict" = {
"message_id": message_id,
"channel_id": channel_id,
"sender_id": user.id,
"content": content,
"timestamp": timestamp.isoformat(),
"type": MessageType.ACTION.value if is_action else MessageType.PLAIN.value,
"timestamp": timestamp,
"type": MessageType.ACTION if is_action else MessageType.PLAIN,
"uuid": user_uuid or "",
"status": "cached", # Redis 缓存状态
"created_at": time.time(),
"is_multiplayer": is_multiplayer, # 标记是否为多人房间消息
"is_action": is_action,
}
# 立即存储到 Redis
@@ -86,51 +74,13 @@ class RedisMessageSystem:
# 创建响应对象
async with with_db() as session:
user_resp = await UserResp.from_db(user, session, RANKING_INCLUDES)
user_resp = await UserModel.transform(user, session=session, includes=User.LIST_INCLUDES)
message_data["sender"] = user_resp
# 确保 statistics 不为空
if user_resp.statistics is None:
from app.database.statistics import UserStatisticsResp
logger.info(f"Message {message_id} sent to Redis cache for channel {channel_id}")
return message_data
user_resp.statistics = UserStatisticsResp(
mode=user.playmode,
global_rank=0,
country_rank=0,
pp=0.0,
ranked_score=0,
hit_accuracy=0.0,
play_count=0,
play_time=0,
total_score=0,
total_hits=0,
maximum_combo=0,
replays_watched_by_others=0,
is_ranked=False,
grade_counts={"ssh": 0, "ss": 0, "sh": 0, "s": 0, "a": 0},
level={"current": 1, "progress": 0},
)
response = ChatMessageResp(
message_id=message_id,
channel_id=channel_id,
content=content,
timestamp=timestamp,
sender_id=user.id,
sender=user_resp,
is_action=is_action,
uuid=user_uuid,
)
if is_multiplayer:
logger.info(
f"Multiplayer message {message_id} sent to Redis cache for channel {channel_id},"
" will not be persisted to database"
)
else:
logger.info(f"Message {message_id} sent to Redis cache for channel {channel_id}")
return response
async def get_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageResp]:
async def get_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageDict]:
"""
获取频道消息 - 优先从 Redis 获取最新消息
@@ -140,9 +90,9 @@ class RedisMessageSystem:
since: 起始消息ID
Returns:
List[ChatMessageResp]: 消息列表
List[ChatMessageDict]: 消息列表
"""
messages = []
messages: list["ChatMessageDict"] = []
try:
# 从 Redis 获取最新消息
@@ -154,45 +104,21 @@ class RedisMessageSystem:
# 获取发送者信息
sender = await session.get(User, msg_data["sender_id"])
if sender:
user_resp = await UserResp.from_db(sender, session, RANKING_INCLUDES)
user_resp = await UserModel.transform(sender, includes=User.LIST_INCLUDES)
if user_resp.statistics is None:
from app.database.statistics import UserStatisticsResp
from app.database.chat import ChatMessageDict
user_resp.statistics = UserStatisticsResp(
mode=sender.playmode,
global_rank=0,
country_rank=0,
pp=0.0,
ranked_score=0,
hit_accuracy=0.0,
play_count=0,
play_time=0,
total_score=0,
total_hits=0,
maximum_combo=0,
replays_watched_by_others=0,
is_ranked=False,
grade_counts={
"ssh": 0,
"ss": 0,
"sh": 0,
"s": 0,
"a": 0,
},
level={"current": 1, "progress": 0},
)
message_resp = ChatMessageResp(
message_id=msg_data["message_id"],
channel_id=msg_data["channel_id"],
content=msg_data["content"],
timestamp=datetime.fromisoformat(msg_data["timestamp"]),
sender_id=msg_data["sender_id"],
sender=user_resp,
is_action=msg_data["type"] == MessageType.ACTION.value,
uuid=msg_data.get("uuid") or None,
)
message_resp: ChatMessageDict = {
"message_id": msg_data["message_id"],
"channel_id": msg_data["channel_id"],
"content": msg_data["content"],
"timestamp": datetime.fromisoformat(msg_data["timestamp"]), # pyright: ignore[reportArgumentType]
"sender_id": msg_data["sender_id"],
"sender": user_resp,
"is_action": msg_data["type"] == MessageType.ACTION.value,
"uuid": msg_data.get("uuid") or None,
"type": MessageType(msg_data["type"]),
}
messages.append(message_resp)
# 如果 Redis 消息不够,从数据库补充
@@ -216,86 +142,46 @@ class RedisMessageSystem:
return message_id
async def _store_to_redis(self, message_id: int, channel_id: int, message_data: dict[str, Any]):
async def _store_to_redis(self, message_id: int, channel_id: int, message_data: ChatMessageDict):
"""存储消息到 Redis"""
try:
# 检查是否是多人房间消息
is_multiplayer = message_data.get("is_multiplayer", False)
# 存储消息数据
await self.redis.hset(
# 存储消息数据为 JSON 字符串
await self.redis.set(
f"msg:{channel_id}:{message_id}",
mapping={k: json.dumps(v) if isinstance(v, (dict, list)) else str(v) for k, v in message_data.items()},
safe_json_dumps(message_data),
ex=604800, # 7天过期
)
# 设置消息过期时间7天
await self.redis.expire(f"msg:{channel_id}:{message_id}", 604800)
# 清理可能存在的错误类型键,然后添加到频道消息列表(按时间排序)
# 添加到频道消息列表(按时间排序
channel_messages_key = f"channel:{channel_id}:messages"
# 更健壮的键类型检查清理
# 检查清理错误类型的键
try:
key_type = await self.redis.type(channel_messages_key)
if key_type == "none":
# 键不存在,这是正常的
pass
elif key_type != "zset":
# 键类型错误,需要清理
if key_type not in ("none", "zset"):
logger.warning(f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}")
await self.redis.delete(channel_messages_key)
# 验证删除是否成功
verify_type = await self.redis.type(channel_messages_key)
if verify_type != "none":
logger.error(
f"Failed to delete problematic key {channel_messages_key}, type is still {verify_type}"
)
# 强制删除
await self.redis.unlink(channel_messages_key)
except Exception as type_check_error:
logger.warning(f"Failed to check key type for {channel_messages_key}: {type_check_error}")
# 如果检查失败,尝试强制删除键以确保清理
try:
await self.redis.delete(channel_messages_key)
except Exception:
# 最后的努力使用unlink
try:
await self.redis.unlink(channel_messages_key)
except Exception as final_error:
logger.error(f"Critical: Unable to clear problematic key {channel_messages_key}: {final_error}")
await self.redis.delete(channel_messages_key)
# 添加到频道消息列表sorted set
try:
await self.redis.zadd(
channel_messages_key,
mapping={f"msg:{channel_id}:{message_id}": message_id},
)
except Exception as zadd_error:
logger.error(f"Failed to add message to sorted set {channel_messages_key}: {zadd_error}")
# 如果添加失败,再次尝试清理并重试
await self.redis.delete(channel_messages_key)
await self.redis.zadd(
channel_messages_key,
mapping={f"msg:{channel_id}:{message_id}": message_id},
)
await self.redis.zadd(
channel_messages_key,
mapping={f"msg:{channel_id}:{message_id}": message_id},
)
# 保持频道消息列表大小最多1000条
await self.redis.zremrangebyrank(channel_messages_key, 0, -1001)
# 只有非多人房间消息才添加到待持久化队列
if not is_multiplayer:
await self.redis.lpush("pending_messages", f"{channel_id}:{message_id}")
logger.debug(f"Message {message_id} added to persistence queue")
else:
logger.debug(f"Message {message_id} in multiplayer room, skipped persistence queue")
await self.redis.lpush("pending_messages", f"{channel_id}:{message_id}")
logger.debug(f"Message {message_id} added to persistence queue")
except Exception as e:
logger.error(f"Failed to store message to Redis: {e}")
raise
async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict[str, Any]]:
async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageDict]:
"""从 Redis 获取消息"""
try:
# 获取消息键列表按消息ID排序
@@ -314,28 +200,16 @@ class RedisMessageSystem:
messages = []
for key in message_keys:
# 获取消息数据
raw_data = await self.redis.hgetall(key)
# 获取消息数据JSON 字符串)
raw_data = await self.redis.get(key)
if raw_data:
# 解码数据
message_data: dict[str, Any] = {}
for k, v in raw_data.items():
# 尝试解析 JSON
try:
if k in ["grade_counts", "level"] or v.startswith(("{", "[")):
message_data[k] = json.loads(v)
elif k in ["message_id", "channel_id", "sender_id"]:
message_data[k] = int(v)
elif k == "is_multiplayer":
message_data[k] = v == "True"
elif k == "created_at":
message_data[k] = float(v)
else:
message_data[k] = v
except (json.JSONDecodeError, ValueError):
message_data[k] = v
messages.append(message_data)
try:
# 解析 JSON 字符串为字典
message_data = json.loads(raw_data)
messages.append(message_data)
except json.JSONDecodeError as e:
logger.error(f"Failed to decode message JSON from {key}: {e}")
continue
# 确保消息按ID正序排序时间顺序
messages.sort(key=lambda x: x.get("message_id", 0))
@@ -350,15 +224,15 @@ class RedisMessageSystem:
logger.error(f"Failed to get messages from Redis: {e}")
return []
async def _backfill_from_database(self, channel_id: int, existing_messages: list[ChatMessageResp], limit: int):
async def _backfill_from_database(self, channel_id: int, existing_messages: list[ChatMessageDict], limit: int):
"""从数据库补充历史消息"""
try:
# 找到最小的消息ID
min_id = float("inf")
if existing_messages:
for msg in existing_messages:
if msg.message_id is not None and msg.message_id < min_id:
min_id = msg.message_id
if msg["message_id"] is not None and msg["message_id"] < min_id:
min_id = msg["message_id"]
needed = limit - len(existing_messages)
@@ -378,13 +252,13 @@ class RedisMessageSystem:
db_messages = (await session.exec(query)).all()
for msg in reversed(db_messages): # 按时间正序插入
msg_resp = await ChatMessageResp.from_db(msg, session)
msg_resp = await ChatMessageModel.transform(msg, includes=["sender"])
existing_messages.insert(0, msg_resp)
except Exception as e:
logger.error(f"Failed to backfill from database: {e}")
async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> list[ChatMessageResp]:
async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> list[ChatMessageDict]:
"""仅从数据库获取消息(回退方案)"""
try:
async with with_db() as session:
@@ -402,7 +276,7 @@ class RedisMessageSystem:
messages = (await session.exec(query)).all()
results = [await ChatMessageResp.from_db(msg, session) for msg in messages]
results = await ChatMessageModel.transform_many(messages, includes=["sender"])
# 如果是 since > 0保持正序否则反转为时间正序
if since == 0:
@@ -450,27 +324,17 @@ class RedisMessageSystem:
# 解析频道ID和消息ID
channel_id, message_id = map(int, key.split(":"))
# 从 Redis 获取消息数据
raw_data = await self.redis.hgetall(f"msg:{channel_id}:{message_id}")
# 从 Redis 获取消息数据JSON 字符串)
raw_data = await self.redis.get(f"msg:{channel_id}:{message_id}")
if not raw_data:
continue
# 解码数据
message_data = {}
for k, v in raw_data.items():
message_data[k] = v
# 检查是否是多人房间消息,如果是则跳过数据库存储
is_multiplayer = message_data.get("is_multiplayer", "False") == "True"
if is_multiplayer:
# 多人房间消息不存储到数据库,直接标记为已跳过
await self.redis.hset(
f"msg:{channel_id}:{message_id}",
"status",
"skipped_multiplayer",
)
logger.debug(f"Message {message_id} in multiplayer room skipped from database storage")
# 解析 JSON 字符串为字典
try:
message_data = json.loads(raw_data)
except json.JSONDecodeError as e:
logger.error(f"Failed to decode message JSON for {channel_id}:{message_id}: {e}")
continue
# 检查消息是否已存在于数据库
@@ -491,13 +355,6 @@ class RedisMessageSystem:
session.add(db_message)
# 更新 Redis 中的状态
await self.redis.hset(
f"msg:{channel_id}:{message_id}",
"status",
"persisted",
)
logger.debug(f"Message {message_id} persisted to database")
except Exception as e:

View File

@@ -14,8 +14,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
async def create_playlist_room_from_api(session: AsyncSession, room: APIUploadedRoom, host_id: int) -> Room:
db_room = room.to_room()
db_room.host_id = host_id
db_room = Room.model_validate({"host_id": host_id, **room.model_dump(exclude={"playlist"})})
db_room.starts_at = utcnow()
db_room.ends_at = db_room.starts_at + timedelta(minutes=db_room.duration if db_room.duration is not None else 0)
session.add(db_room)

View File

@@ -3,19 +3,19 @@
用于缓存用户信息,提供热缓存和实时刷新功能
"""
from datetime import datetime
import json
from typing import TYPE_CHECKING, Any
from app.config import settings
from app.const import BANCHOBOT_ID
from app.database import User, UserResp
from app.database.score import LegacyScoreResp, ScoreResp
from app.database.user import SEARCH_INCLUDED
from app.database import User
from app.database.score import LegacyScoreResp
from app.database.user import UserDict, UserModel
from app.dependencies.database import with_db
from app.helpers.asset_proxy_helper import replace_asset_urls
from app.log import logger
from app.models.score import GameMode
from app.utils import safe_json_dumps
from redis.asyncio import Redis
from sqlmodel import col, select
@@ -25,20 +25,6 @@ if TYPE_CHECKING:
pass
class DateTimeEncoder(json.JSONEncoder):
"""自定义 JSON 编码器,支持 datetime 序列化"""
def default(self, obj):
if isinstance(obj, datetime):
return obj.isoformat()
return super().default(obj)
def safe_json_dumps(data: Any) -> str:
"""安全的 JSON 序列化,支持 datetime 对象"""
return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False)
class UserCacheService:
"""用户缓存服务"""
@@ -125,7 +111,7 @@ class UserCacheService:
"""生成用户谱面集缓存键"""
return f"user:{user_id}:beatmapsets:{beatmapset_type}:limit:{limit}:offset:{offset}"
async def get_user_from_cache(self, user_id: int, ruleset: GameMode | None = None) -> UserResp | None:
async def get_user_from_cache(self, user_id: int, ruleset: GameMode | None = None) -> UserDict | None:
"""从缓存获取用户信息"""
try:
cache_key = self._get_user_cache_key(user_id, ruleset)
@@ -133,7 +119,7 @@ class UserCacheService:
if cached_data:
logger.debug(f"User cache hit for user {user_id}")
data = json.loads(cached_data)
return UserResp(**data)
return data
return None
except Exception as e:
logger.error(f"Error getting user from cache: {e}")
@@ -141,7 +127,7 @@ class UserCacheService:
async def cache_user(
self,
user_resp: UserResp,
user_resp: UserDict,
ruleset: GameMode | None = None,
expire_seconds: int | None = None,
):
@@ -149,13 +135,10 @@ class UserCacheService:
try:
if expire_seconds is None:
expire_seconds = settings.user_cache_expire_seconds
if user_resp.id is None:
logger.warning("Cannot cache user with None id")
return
cache_key = self._get_user_cache_key(user_resp.id, ruleset)
cached_data = user_resp.model_dump_json()
cache_key = self._get_user_cache_key(user_resp["id"], ruleset)
cached_data = safe_json_dumps(user_resp)
await self.redis.setex(cache_key, expire_seconds, cached_data)
logger.debug(f"Cached user {user_resp.id} for {expire_seconds}s")
logger.debug(f"Cached user {user_resp['id']} for {expire_seconds}s")
except Exception as e:
logger.error(f"Error caching user: {e}")
@@ -168,10 +151,9 @@ class UserCacheService:
limit: int = 100,
offset: int = 0,
is_legacy: bool = False,
) -> list[ScoreResp] | list[LegacyScoreResp] | None:
) -> list[UserDict] | list[LegacyScoreResp] | None:
"""从缓存获取用户成绩"""
try:
model = LegacyScoreResp if is_legacy else ScoreResp
cache_key = self._get_user_scores_cache_key(
user_id, score_type, include_fail, mode, limit, offset, is_legacy
)
@@ -179,7 +161,7 @@ class UserCacheService:
if cached_data:
logger.debug(f"User scores cache hit for user {user_id}, type {score_type}")
data = json.loads(cached_data)
return [model(**score_data) for score_data in data] # pyright: ignore[reportReturnType]
return [LegacyScoreResp(**score_data) for score_data in data] if is_legacy else data
return None
except Exception as e:
logger.error(f"Error getting user scores from cache: {e}")
@@ -189,7 +171,7 @@ class UserCacheService:
self,
user_id: int,
score_type: str,
scores: list[ScoreResp] | list[LegacyScoreResp],
scores: list[UserDict] | list[LegacyScoreResp],
include_fail: bool,
mode: GameMode | None = None,
limit: int = 100,
@@ -204,8 +186,12 @@ class UserCacheService:
cache_key = self._get_user_scores_cache_key(
user_id, score_type, include_fail, mode, limit, offset, is_legacy
)
# 使用 model_dump_json() 而不是 model_dump() + json.dumps()
scores_json_list = [score.model_dump_json() for score in scores]
if len(scores) == 0:
return
if isinstance(scores[0], dict):
scores_json_list = [safe_json_dumps(score) for score in scores]
else:
scores_json_list = [score.model_dump_json() for score in scores] # pyright: ignore[reportAttributeAccessIssue]
cached_data = f"[{','.join(scores_json_list)}]"
await self.redis.setex(cache_key, expire_seconds, cached_data)
logger.debug(f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s")
@@ -308,7 +294,7 @@ class UserCacheService:
for user in users:
if user.id != BANCHOBOT_ID:
try:
await self._cache_single_user(user, session)
await self._cache_single_user(user)
cached_count += 1
except Exception as e:
logger.error(f"Failed to cache user {user.id}: {e}")
@@ -320,10 +306,10 @@ class UserCacheService:
finally:
self._refreshing = False
async def _cache_single_user(self, user: User, session: AsyncSession):
async def _cache_single_user(self, user: User):
"""缓存单个用户"""
try:
user_resp = await UserResp.from_db(user, session, include=SEARCH_INCLUDED)
user_resp = await UserModel.transform(user, includes=User.USER_INCLUDES)
# 应用资源代理处理
if settings.enable_asset_proxy:
@@ -347,7 +333,7 @@ class UserCacheService:
# 立即重新加载用户信息
user = await session.get(User, user_id)
if user and user.id != BANCHOBOT_ID:
await self._cache_single_user(user, session)
await self._cache_single_user(user)
logger.info(f"Refreshed cache for user {user_id} after score submit")
except Exception as e:
logger.error(f"Error refreshing user cache on score submit: {e}")