refactor(database): use a new 'On-Demand' design (#86)
Technical Details: https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/
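The hunks below replace the Pydantic *Resp response models with TypedDict-based *Dict payloads that are built on demand by *Model.transform(...) helpers. A minimal sketch of the pattern, with hypothetical field names and a simplified transform() signature inferred from the call sites in this diff:

    from typing import NotRequired, TypedDict


    class BeatmapsetDict(TypedDict):
        # Plain-dict payload; optional fields are attached only when requested.
        id: int
        ranked: NotRequired[int]
        play_count: NotRequired[int]


    class BeatmapsetModel:
        @classmethod
        async def transform(cls, db_obj, *, includes: list[str] | None = None) -> BeatmapsetDict:
            # Start from the always-present fields, then attach only the
            # includes the caller asked for - no full-model validation pass.
            data: BeatmapsetDict = {"id": db_obj.id}
            for field in includes or ():
                data[field] = getattr(db_obj, field)  # type: ignore[literal-required]
            return data

Because the payload is already a plain dict, it can go straight through safe_json_dumps into Redis and come back from json.loads without re-validation, which is what the cache services below switch to.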
@@ -3,14 +3,14 @@ Beatmapset cache service
Caches beatmapset data to reduce database query frequency
"""

from datetime import datetime
import hashlib
import json
from typing import TYPE_CHECKING

from app.config import settings
from app.database.beatmapset import BeatmapsetResp
from app.database import BeatmapsetDict
from app.log import logger
from app.utils import safe_json_dumps

from redis.asyncio import Redis

@@ -18,20 +18,6 @@ if TYPE_CHECKING:
    pass


class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder that handles datetime serialization"""

    def default(self, obj):
        if isinstance(obj, datetime):
            return obj.isoformat()
        return super().default(obj)


def safe_json_dumps(data) -> str:
    """Safe JSON serialization that handles datetime objects"""
    return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False)


def generate_hash(data) -> str:
    """Generate the MD5 hash of the data"""
    content = data if isinstance(data, str) else safe_json_dumps(data)
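All three per-module copies of DateTimeEncoder and safe_json_dumps deleted by this commit give way to the single helper imported from app.utils; presumably it mirrors the removed local copies, roughly:

    import json
    from datetime import datetime


    class DateTimeEncoder(json.JSONEncoder):
        # Serialize datetime objects as ISO 8601 strings.
        def default(self, obj):
            if isinstance(obj, datetime):
                return obj.isoformat()
            return super().default(obj)


    def safe_json_dumps(data) -> str:
        # JSON-serialize data, tolerating datetime values.
        return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False)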
@@ -57,15 +43,14 @@ class BeatmapsetCacheService:
        """Generate the cache key for search results"""
        return f"beatmapset_search:{query_hash}:{cursor_hash}"

    async def get_beatmapset_from_cache(self, beatmapset_id: int) -> BeatmapsetResp | None:
    async def get_beatmapset_from_cache(self, beatmapset_id: int) -> BeatmapsetDict | None:
        """Get beatmapset info from the cache"""
        try:
            cache_key = self._get_beatmapset_cache_key(beatmapset_id)
            cached_data = await self.redis.get(cache_key)
            if cached_data:
                logger.debug(f"Beatmapset cache hit for {beatmapset_id}")
                data = json.loads(cached_data)
                return BeatmapsetResp(**data)
                return json.loads(cached_data)
            return None
        except (ValueError, TypeError, AttributeError) as e:
            logger.error(f"Error getting beatmapset from cache: {e}")
@@ -73,24 +58,21 @@ class BeatmapsetCacheService:

    async def cache_beatmapset(
        self,
        beatmapset_resp: BeatmapsetResp,
        beatmapset_resp: BeatmapsetDict,
        expire_seconds: int | None = None,
    ):
        """Cache beatmapset info"""
        try:
            if expire_seconds is None:
                expire_seconds = self._default_ttl
            if beatmapset_resp.id is None:
                logger.warning("Cannot cache beatmapset with None id")
                return
            cache_key = self._get_beatmapset_cache_key(beatmapset_resp.id)
            cached_data = beatmapset_resp.model_dump_json()
            cache_key = self._get_beatmapset_cache_key(beatmapset_resp["id"])
            cached_data = safe_json_dumps(beatmapset_resp)
            await self.redis.setex(cache_key, expire_seconds, cached_data)  # type: ignore
            logger.debug(f"Cached beatmapset {beatmapset_resp.id} for {expire_seconds}s")
            logger.debug(f"Cached beatmapset {beatmapset_resp['id']} for {expire_seconds}s")
        except (ValueError, TypeError, AttributeError) as e:
            logger.error(f"Error caching beatmapset: {e}")

    async def get_beatmap_lookup_from_cache(self, beatmap_id: int) -> BeatmapsetResp | None:
    async def get_beatmap_lookup_from_cache(self, beatmap_id: int) -> BeatmapsetDict | None:
        """Get the beatmapset looked up by beatmap ID from the cache"""
        try:
            cache_key = self._get_beatmap_lookup_cache_key(beatmap_id)
@@ -98,7 +80,7 @@ class BeatmapsetCacheService:
            if cached_data:
                logger.debug(f"Beatmap lookup cache hit for {beatmap_id}")
                data = json.loads(cached_data)
                return BeatmapsetResp(**data)
                return data
            return None
        except (ValueError, TypeError, AttributeError) as e:
            logger.error(f"Error getting beatmap lookup from cache: {e}")
@@ -107,7 +89,7 @@ class BeatmapsetCacheService:
    async def cache_beatmap_lookup(
        self,
        beatmap_id: int,
        beatmapset_resp: BeatmapsetResp,
        beatmapset_resp: BeatmapsetDict,
        expire_seconds: int | None = None,
    ):
        """Cache the beatmapset looked up by beatmap ID"""
@@ -115,7 +97,7 @@ class BeatmapsetCacheService:
            if expire_seconds is None:
                expire_seconds = self._default_ttl
            cache_key = self._get_beatmap_lookup_cache_key(beatmap_id)
            cached_data = beatmapset_resp.model_dump_json()
            cached_data = safe_json_dumps(beatmapset_resp)
            await self.redis.setex(cache_key, expire_seconds, cached_data)  # type: ignore
            logger.debug(f"Cached beatmap lookup {beatmap_id} for {expire_seconds}s")
        except (ValueError, TypeError, AttributeError) as e:

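The read path above no longer rehydrates a BeatmapsetResp: whatever safe_json_dumps wrote is handed back to callers as a plain dict. A small round-trip sketch (values hypothetical; datetimes come back as ISO 8601 strings that callers parse where needed):

    import json
    from datetime import datetime, timezone

    from app.utils import safe_json_dumps  # the shared helper used above

    payload = {"id": 1, "ranked": 1, "last_updated": datetime.now(timezone.utc)}

    cached = safe_json_dumps(payload)  # what cache_beatmapset stores via SETEX
    restored = json.loads(cached)      # what get_beatmapset_from_cache returns

    assert isinstance(restored["last_updated"], str)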
@@ -3,12 +3,12 @@ from datetime import timedelta
from enum import Enum
import math
import random
from typing import TYPE_CHECKING, NamedTuple
from typing import TYPE_CHECKING, NamedTuple, cast

from app.config import OldScoreProcessingMode, settings
from app.database.beatmap import Beatmap, BeatmapResp
from app.database.beatmap import Beatmap, BeatmapDict
from app.database.beatmap_sync import BeatmapSync, SavedBeatmapMeta
from app.database.beatmapset import Beatmapset, BeatmapsetResp
from app.database.beatmapset import Beatmapset, BeatmapsetDict
from app.database.score import Score
from app.dependencies.database import get_redis, with_db
from app.dependencies.storage import get_storage_service
@@ -62,10 +62,23 @@ STATUS_FACTOR: dict[BeatmapRankStatus, float] = {
SCHEDULER_INTERVAL_MINUTES = 2


class EnsuredBeatmap(BeatmapDict):
    checksum: str
    ranked: int


class EnsuredBeatmapset(BeatmapsetDict):
    ranked: int
    ranked_date: datetime.datetime
    last_updated: datetime.datetime
    play_count: int
    beatmaps: list[EnsuredBeatmap]


class ProcessingBeatmapset:
    def __init__(self, beatmapset: BeatmapsetResp, record: BeatmapSync) -> None:
    def __init__(self, beatmapset: EnsuredBeatmapset, record: BeatmapSync) -> None:
        self.beatmapset = beatmapset
        self.status = BeatmapRankStatus(self.beatmapset.ranked)
        self.status = BeatmapRankStatus(self.beatmapset["ranked"])
        self.record = record

    def calculate_next_sync_time(
@@ -76,19 +89,19 @@ class ProcessingBeatmapset:

        now = utcnow()
        if self.status == BeatmapRankStatus.QUALIFIED:
            assert self.beatmapset.ranked_date is not None, "ranked_date should not be None for qualified maps"
            time_to_ranked = (self.beatmapset.ranked_date + timedelta(days=7) - now).total_seconds()
            assert self.beatmapset["ranked_date"] is not None, "ranked_date should not be None for qualified maps"
            time_to_ranked = (self.beatmapset["ranked_date"] + timedelta(days=7) - now).total_seconds()
            baseline = max(MIN_DELTA, time_to_ranked / 2)
            next_delta = max(MIN_DELTA, baseline)
        elif self.status in {BeatmapRankStatus.WIP, BeatmapRankStatus.PENDING}:
            seconds_since_update = (now - self.beatmapset.last_updated).total_seconds()
            seconds_since_update = (now - self.beatmapset["last_updated"]).total_seconds()
            factor_update = max(1.0, seconds_since_update / TAU)
            factor_play = 1.0 + math.log(1.0 + self.beatmapset.play_count)
            factor_play = 1.0 + math.log(1.0 + self.beatmapset["play_count"])
            status_factor = STATUS_FACTOR[self.status]
            baseline = BASE * factor_play / factor_update * status_factor
            next_delta = max(MIN_DELTA, baseline * (GROWTH ** (self.record.consecutive_no_change + 1)))
        elif self.status == BeatmapRankStatus.GRAVEYARD:
            days_since_update = (now - self.beatmapset.last_updated).days
            days_since_update = (now - self.beatmapset["last_updated"]).days
            doubling_periods = days_since_update / GRAVEYARD_DOUBLING_PERIOD_DAYS
            delta = MIN_DELTA * (2**doubling_periods)
            max_seconds = GRAVEYARD_MAX_DAYS * 86400
@@ -105,21 +118,24 @@ class ProcessingBeatmapset:

    @property
    def beatmapset_changed(self) -> bool:
        return self.record.beatmap_status != BeatmapRankStatus(self.beatmapset.ranked)
        return self.record.beatmap_status != BeatmapRankStatus(self.beatmapset["ranked"])

    @property
    def changed_beatmaps(self) -> list[ChangedBeatmap]:
        changed_beatmaps = []
        for bm in self.beatmapset.beatmaps:
            saved = next((s for s in self.record.beatmaps if s["beatmap_id"] == bm.id), None)
        for bm in self.beatmapset["beatmaps"]:
            saved = next((s for s in self.record.beatmaps if s["beatmap_id"] == bm["id"]), None)
            if not saved or saved["is_deleted"]:
                changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_ADDED))
            elif saved["md5"] != bm.checksum:
                changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_UPDATED))
            elif saved["beatmap_status"] != BeatmapRankStatus(bm.ranked):
                changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.STATUS_CHANGED))
                changed_beatmaps.append(ChangedBeatmap(bm["id"], BeatmapChangeType.MAP_ADDED))
            elif saved["md5"] != bm["checksum"]:
                changed_beatmaps.append(ChangedBeatmap(bm["id"], BeatmapChangeType.MAP_UPDATED))
            elif saved["beatmap_status"] != BeatmapRankStatus(bm["ranked"]):
                changed_beatmaps.append(ChangedBeatmap(bm["id"], BeatmapChangeType.STATUS_CHANGED))
        for saved in self.record.beatmaps:
            if not any(bm.id == saved["beatmap_id"] for bm in self.beatmapset.beatmaps) and not saved["is_deleted"]:
            if (
                not any(bm["id"] == saved["beatmap_id"] for bm in self.beatmapset["beatmaps"])
                and not saved["is_deleted"]
            ):
                changed_beatmaps.append(ChangedBeatmap(saved["beatmap_id"], BeatmapChangeType.MAP_DELETED))
        return changed_beatmaps

@@ -132,7 +148,7 @@ class BeatmapsetUpdateService:
    async def add_missing_beatmapset(self, beatmapset_id: int, immediate: bool = False) -> bool:
        beatmapset = await self.fetcher.get_beatmapset(beatmapset_id)
        if immediate:
            await self._sync_immediately(beatmapset)
            await self._sync_immediately(cast(EnsuredBeatmapset, beatmapset))
            logger.debug(f"triggered immediate sync for beatmapset {beatmapset_id} ")
            return True
        await self.add(beatmapset)
@@ -172,7 +188,7 @@ class BeatmapsetUpdateService:
                BeatmapSync(
                    beatmapset_id=missing,
                    beatmap_status=BeatmapRankStatus.GRAVEYARD,
                    next_sync_time=datetime.datetime.max,
                    next_sync_time=datetime.datetime(year=6000, month=1, day=1),
                    beatmaps=[],
                )
            )
@@ -185,11 +201,13 @@ class BeatmapsetUpdateService:
            await session.commit()
        self._adding_missing = False

    async def add(self, beatmapset: BeatmapsetResp, calculate_next_sync: bool = True):
    async def add(self, set: BeatmapsetDict, calculate_next_sync: bool = True):
        beatmapset = cast(EnsuredBeatmapset, set)
        async with with_db() as session:
            sync_record = await session.get(BeatmapSync, beatmapset.id)
            beatmapset_id = beatmapset["id"]
            sync_record = await session.get(BeatmapSync, beatmapset_id)
            if not sync_record:
                database_beatmapset = await session.get(Beatmapset, beatmapset.id)
                database_beatmapset = await session.get(Beatmapset, beatmapset_id)
                if database_beatmapset:
                    status = BeatmapRankStatus(database_beatmapset.beatmap_status)
                    await database_beatmapset.awaitable_attrs.beatmaps
@@ -203,19 +221,29 @@ class BeatmapsetUpdateService:
                        for bm in database_beatmapset.beatmaps
                    ]
                else:
                    status = BeatmapRankStatus(beatmapset.ranked)
                    beatmaps = [
                        SavedBeatmapMeta(
                            beatmap_id=bm.id,
                            md5=bm.checksum,
                            is_deleted=False,
                            beatmap_status=BeatmapRankStatus(bm.ranked),
                    ranked = beatmapset.get("ranked")
                    if ranked is None:
                        raise ValueError("ranked field is required")
                    status = BeatmapRankStatus(ranked)
                    beatmap_list = beatmapset.get("beatmaps", [])
                    beatmaps = []
                    for bm in beatmap_list:
                        bm_id = bm.get("id")
                        checksum = bm.get("checksum")
                        ranked = bm.get("ranked")
                        if bm_id is None or checksum is None or ranked is None:
                            continue
                        beatmaps.append(
                            SavedBeatmapMeta(
                                beatmap_id=bm_id,
                                md5=checksum,
                                is_deleted=False,
                                beatmap_status=BeatmapRankStatus(ranked),
                            )
                        )
                        for bm in beatmapset.beatmaps
                    ]

                sync_record = BeatmapSync(
                    beatmapset_id=beatmapset.id,
                    beatmapset_id=beatmapset_id,
                    beatmaps=beatmaps,
                    beatmap_status=status,
                )
@@ -223,13 +251,27 @@ class BeatmapsetUpdateService:
                await session.commit()
                await session.refresh(sync_record)
            else:
                sync_record.beatmaps = [
                    SavedBeatmapMeta(
                        beatmap_id=bm.id, md5=bm.checksum, is_deleted=False, beatmap_status=BeatmapRankStatus(bm.ranked)
                ranked = beatmapset.get("ranked")
                if ranked is None:
                    raise ValueError("ranked field is required")
                beatmap_list = beatmapset.get("beatmaps", [])
                beatmaps = []
                for bm in beatmap_list:
                    bm_id = bm.get("id")
                    checksum = bm.get("checksum")
                    bm_ranked = bm.get("ranked")
                    if bm_id is None or checksum is None or bm_ranked is None:
                        continue
                    beatmaps.append(
                        SavedBeatmapMeta(
                            beatmap_id=bm_id,
                            md5=checksum,
                            is_deleted=False,
                            beatmap_status=BeatmapRankStatus(bm_ranked),
                        )
                    )
                    for bm in beatmapset.beatmaps
                ]
                sync_record.beatmap_status = BeatmapRankStatus(beatmapset.ranked)
                sync_record.beatmaps = beatmaps
                sync_record.beatmap_status = BeatmapRankStatus(ranked)
            if calculate_next_sync:
                processing = ProcessingBeatmapset(beatmapset, sync_record)
                next_time_delta = processing.calculate_next_sync_time()
@@ -238,17 +280,19 @@ class BeatmapsetUpdateService:
                    await BeatmapsetUpdateService._sync_immediately(self, beatmapset)
                    return
                sync_record.next_sync_time = utcnow() + next_time_delta
                logger.opt(colors=True).info(f"<g>[{beatmapset.id}]</g> next sync at {sync_record.next_sync_time}")
                beatmapset_id = beatmapset.get("id")
                if beatmapset_id:
                    logger.opt(colors=True).debug(f"<g>[{beatmapset_id}]</g> next sync at {sync_record.next_sync_time}")
            await session.commit()

    async def _sync_immediately(self, beatmapset: BeatmapsetResp) -> None:
    async def _sync_immediately(self, beatmapset: EnsuredBeatmapset) -> None:
        async with with_db() as session:
            record = await session.get(BeatmapSync, beatmapset.id)
            record = await session.get(BeatmapSync, beatmapset["id"])
            if not record:
                record = BeatmapSync(
                    beatmapset_id=beatmapset.id,
                    beatmapset_id=beatmapset["id"],
                    beatmaps=[],
                    beatmap_status=BeatmapRankStatus(beatmapset.ranked),
                    beatmap_status=BeatmapRankStatus(beatmapset["ranked"]),
                )
                session.add(record)
                await session.commit()
@@ -261,19 +305,18 @@ class BeatmapsetUpdateService:
        record: BeatmapSync,
        session: AsyncSession,
        *,
        beatmapset: BeatmapsetResp | None = None,
        beatmapset: EnsuredBeatmapset | None = None,
    ):
        logger.opt(colors=True).info(f"<g>[{record.beatmapset_id}]</g> syncing...")
        logger.opt(colors=True).debug(f"<g>[{record.beatmapset_id}]</g> syncing...")
        if beatmapset is None:
            try:
                beatmapset = await self.fetcher.get_beatmapset(record.beatmapset_id)
                beatmapset = cast(EnsuredBeatmapset, await self.fetcher.get_beatmapset(record.beatmapset_id))
            except Exception as e:
                if isinstance(e, HTTPStatusError) and e.response.status_code == 404:
                    logger.opt(colors=True).warning(
                        f"<g>[{record.beatmapset_id}]</g> beatmapset not found (404), removing from sync list"
                    )
                    await session.delete(record)
                    await session.commit()
                    return
                if isinstance(e, HTTPError):
                    logger.opt(colors=True).warning(
@@ -292,20 +335,20 @@ class BeatmapsetUpdateService:
        if changed:
            record.beatmaps = [
                SavedBeatmapMeta(
                    beatmap_id=bm.id,
                    md5=bm.checksum,
                    beatmap_id=bm["id"],
                    md5=bm["checksum"],
                    is_deleted=False,
                    beatmap_status=BeatmapRankStatus(bm.ranked),
                    beatmap_status=BeatmapRankStatus(bm["ranked"]),
                )
                for bm in beatmapset.beatmaps
                for bm in beatmapset["beatmaps"]
            ]
            record.beatmap_status = BeatmapRankStatus(beatmapset.ranked)
            record.beatmap_status = BeatmapRankStatus(beatmapset["ranked"])
            record.consecutive_no_change = 0

            bg_tasks.add_task(
                self._process_changed_beatmaps,
                changed_beatmaps,
                beatmapset.beatmaps,
                beatmapset["beatmaps"],
            )
            bg_tasks.add_task(
                self._process_changed_beatmapset,
@@ -317,13 +360,13 @@ class BeatmapsetUpdateService:
            next_time_delta = processing.calculate_next_sync_time()
            if not next_time_delta:
                logger.opt(colors=True).info(
                    f"<yellow>[{beatmapset.id}]</yellow> beatmapset has transformed to ranked or loved,"
                    f"<yellow>[{beatmapset['id']}]</yellow> beatmapset has transformed to ranked or loved,"
                    f" removing from sync list"
                )
                await session.delete(record)
            else:
                record.next_sync_time = utcnow() + next_time_delta
                logger.opt(colors=True).info(f"<g>[{record.beatmapset_id}]</g> next sync at {record.next_sync_time}")
                logger.opt(colors=True).debug(f"<g>[{record.beatmapset_id}]</g> next sync at {record.next_sync_time}")

    async def _update_beatmaps(self):
        async with with_db() as session:
@@ -338,18 +381,18 @@ class BeatmapsetUpdateService:
            await self.sync(record, session)
        await session.commit()

    async def _process_changed_beatmapset(self, beatmapset: BeatmapsetResp):
    async def _process_changed_beatmapset(self, beatmapset: EnsuredBeatmapset):
        async with with_db() as session:
            db_beatmapset = await session.get(Beatmapset, beatmapset.id)
            new_beatmapset = await Beatmapset.from_resp_no_save(beatmapset)
            db_beatmapset = await session.get(Beatmapset, beatmapset["id"])
            new_beatmapset = await Beatmapset.from_resp_no_save(beatmapset)  # pyright: ignore[reportArgumentType]
            if db_beatmapset:
                await session.merge(new_beatmapset)
            await get_beatmapset_cache_service(get_redis()).invalidate_beatmapset_cache(beatmapset.id)
            await get_beatmapset_cache_service(get_redis()).invalidate_beatmapset_cache(beatmapset["id"])
            await session.commit()

    async def _process_changed_beatmaps(self, changed: list[ChangedBeatmap], beatmaps_list: list[BeatmapResp]):
    async def _process_changed_beatmaps(self, changed: list[ChangedBeatmap], beatmaps_list: list[EnsuredBeatmap]):
        storage_service = get_storage_service()
        beatmaps = {bm.id: bm for bm in beatmaps_list}
        beatmaps = {bm["id"]: bm for bm in beatmaps_list}

        async with with_db() as session:

@@ -380,9 +423,9 @@ class BeatmapsetUpdateService:
                        )
                        continue
                    logger.opt(colors=True).info(
                        f"<g>[{beatmap.beatmapset_id}]</g> adding beatmap <blue>{beatmap.id}</blue>"
                        f"<g>[{beatmap['beatmapset_id']}]</g> adding beatmap <blue>{beatmap['id']}</blue>"
                    )
                    await Beatmap.from_resp_no_save(session, beatmap)
                    await Beatmap.from_resp_no_save(session, beatmap)  # pyright: ignore[reportArgumentType]
                else:
                    beatmap = beatmaps.get(change.beatmap_id)
                    if not beatmap:
@@ -391,10 +434,10 @@ class BeatmapsetUpdateService:
                        )
                        continue
                    logger.opt(colors=True).info(
                        f"<g>[{beatmap.beatmapset_id}]</g> processing beatmap <blue>{beatmap.id}</blue> "
                        f"<g>[{beatmap['beatmapset_id']}]</g> processing beatmap <blue>{beatmap['id']}</blue> "
                        f"change <cyan>{change.type}</cyan>"
                    )
                    new_db_beatmap = await Beatmap.from_resp_no_save(session, beatmap)
                    new_db_beatmap = await Beatmap.from_resp_no_save(session, beatmap)  # pyright: ignore[reportArgumentType]
                    existing_beatmap = await session.get(Beatmap, change.beatmap_id)
                    if existing_beatmap:
                        await session.merge(new_db_beatmap)

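The sync service above leans on EnsuredBeatmap/EnsuredBeatmapset: TypedDict views in which fields the fetcher is known to supply become required, with typing.cast applied at the boundary. A minimal sketch of the idea; field names are hypothetical, the required keys live in a shared base here (the diff re-declares them on a subclass instead), and note that cast validates nothing at runtime:

    import datetime
    from typing import TypedDict, cast


    class _BeatmapsetRequired(TypedDict):
        id: int


    class BeatmapsetDict(_BeatmapsetRequired, total=False):
        ranked: int
        ranked_date: datetime.datetime


    class EnsuredBeatmapset(_BeatmapsetRequired):
        ranked: int
        ranked_date: datetime.datetime


    def on_fetched(raw: BeatmapsetDict) -> None:
        beatmapset = cast(EnsuredBeatmapset, raw)  # static-only narrowing
        status_code = beatmapset["ranked"]         # indexable without .get() fallbacks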
@@ -4,16 +4,15 @@
"""

import asyncio
from datetime import datetime
import json
from typing import TYPE_CHECKING, Literal

from app.config import settings
from app.database.statistics import UserStatistics, UserStatisticsResp
from app.database.statistics import UserStatistics, UserStatisticsModel
from app.helpers.asset_proxy_helper import replace_asset_urls
from app.log import logger
from app.models.score import GameMode
from app.utils import utcnow
from app.utils import safe_json_dumps, utcnow

from redis.asyncio import Redis
from sqlmodel import col, select
@@ -23,20 +22,6 @@ if TYPE_CHECKING:
    pass


class DateTimeEncoder(json.JSONEncoder):
    """Custom JSON encoder that supports datetime serialization"""

    def default(self, obj):
        if isinstance(obj, datetime):
            return obj.isoformat()
        return super().default(obj)


def safe_json_dumps(data) -> str:
    """Safe JSON serialization with datetime support"""
    return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":"))


class RankingCacheService:
    """User leaderboard cache service"""

@@ -311,7 +296,7 @@ class RankingCacheService:
            col(UserStatistics.pp) > 0,
            col(UserStatistics.is_ranked).is_(True),
        ]
        include = ["user"]
        include = UserStatistics.RANKING_INCLUDES.copy()

        if type == "performance":
            order_by = col(UserStatistics.pp).desc()
@@ -321,6 +306,7 @@ class RankingCacheService:

        if country:
            wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
            include.append("country_rank")

        # Get the total user count for statistics
        total_users_query = select(UserStatistics).where(*wheres)
@@ -353,9 +339,9 @@ class RankingCacheService:
        # Convert to response format and ensure correct serialization
        ranking_data = []
        for statistics in statistics_data:
            user_stats_resp = await UserStatisticsResp.from_db(statistics, session, None, include)
            user_stats_resp = await UserStatisticsModel.transform(statistics, includes=include)

            user_dict = user_stats_resp.model_dump()
            user_dict = user_stats_resp

            # Apply asset proxy processing
            if settings.enable_asset_proxy:

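One detail worth noting in the ranking hunk above: RANKING_INCLUDES is copied before the per-request include.append("country_rank"), so the shared class-level default never accumulates entries across requests. In miniature:

    RANKING_INCLUDES = ["user"]          # shared, class-level default

    include = RANKING_INCLUDES.copy()    # per-request working copy
    include.append("country_rank")

    assert RANKING_INCLUDES == ["user"]  # the shared default stays clean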
@@ -8,14 +8,14 @@
import asyncio
from datetime import datetime
import json
import time
from typing import Any

from app.database.chat import ChatMessage, ChatMessageResp, MessageType
from app.database.user import RANKING_INCLUDES, User, UserResp
from app.database import ChatMessageDict
from app.database.chat import ChatMessage, ChatMessageModel, MessageType
from app.database.user import User, UserModel
from app.dependencies.database import get_redis_message, with_db
from app.log import logger
from app.utils import bg_tasks
from app.utils import bg_tasks, safe_json_dumps


class RedisMessageSystem:
@@ -35,7 +35,7 @@ class RedisMessageSystem:
        content: str,
        is_action: bool = False,
        user_uuid: str | None = None,
    ) -> ChatMessageResp:
    ) -> "ChatMessageDict":
        """
        Send a message - store it to Redis immediately and return

@@ -47,7 +47,7 @@ class RedisMessageSystem:
            user_uuid: the user's UUID

        Returns:
            ChatMessageResp: the message response object
            ChatMessageDict: the message payload
        """
        # Generate the message ID and timestamp
        message_id = await self._generate_message_id(channel_id)
@@ -57,28 +57,16 @@ class RedisMessageSystem:
        if not user.id:
            raise ValueError("User ID is required")

        # Get the channel type to decide whether the message needs to be stored in the database
        async with with_db() as session:
            from app.database.chat import ChannelType, ChatChannel

            from sqlmodel import select

            channel_result = await session.exec(select(ChatChannel.type).where(ChatChannel.channel_id == channel_id))
            channel_type = channel_result.first()
            is_multiplayer = channel_type == ChannelType.MULTIPLAYER

        # Prepare the message data
        message_data = {
        message_data: "ChatMessageDict" = {
            "message_id": message_id,
            "channel_id": channel_id,
            "sender_id": user.id,
            "content": content,
            "timestamp": timestamp.isoformat(),
            "type": MessageType.ACTION.value if is_action else MessageType.PLAIN.value,
            "timestamp": timestamp,
            "type": MessageType.ACTION if is_action else MessageType.PLAIN,
            "uuid": user_uuid or "",
            "status": "cached",  # Redis cache status
            "created_at": time.time(),
            "is_multiplayer": is_multiplayer,  # marks whether this is a multiplayer room message
            "is_action": is_action,
        }

        # Store to Redis immediately
@@ -86,51 +74,13 @@ class RedisMessageSystem:

        # Build the response object
        async with with_db() as session:
            user_resp = await UserResp.from_db(user, session, RANKING_INCLUDES)
            user_resp = await UserModel.transform(user, session=session, includes=User.LIST_INCLUDES)
            message_data["sender"] = user_resp

            # Ensure statistics is not empty
            if user_resp.statistics is None:
                from app.database.statistics import UserStatisticsResp
        logger.info(f"Message {message_id} sent to Redis cache for channel {channel_id}")
        return message_data

                user_resp.statistics = UserStatisticsResp(
                    mode=user.playmode,
                    global_rank=0,
                    country_rank=0,
                    pp=0.0,
                    ranked_score=0,
                    hit_accuracy=0.0,
                    play_count=0,
                    play_time=0,
                    total_score=0,
                    total_hits=0,
                    maximum_combo=0,
                    replays_watched_by_others=0,
                    is_ranked=False,
                    grade_counts={"ssh": 0, "ss": 0, "sh": 0, "s": 0, "a": 0},
                    level={"current": 1, "progress": 0},
                )

        response = ChatMessageResp(
            message_id=message_id,
            channel_id=channel_id,
            content=content,
            timestamp=timestamp,
            sender_id=user.id,
            sender=user_resp,
            is_action=is_action,
            uuid=user_uuid,
        )

        if is_multiplayer:
            logger.info(
                f"Multiplayer message {message_id} sent to Redis cache for channel {channel_id},"
                " will not be persisted to database"
            )
        else:
            logger.info(f"Message {message_id} sent to Redis cache for channel {channel_id}")
        return response

    async def get_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageResp]:
    async def get_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageDict]:
        """
        Get channel messages - prefer the latest messages from Redis

@@ -140,9 +90,9 @@ class RedisMessageSystem:
            since: starting message ID

        Returns:
            List[ChatMessageResp]: the message list
            List[ChatMessageDict]: the message list
        """
        messages = []
        messages: list["ChatMessageDict"] = []

        try:
            # Fetch the latest messages from Redis
@@ -154,45 +104,21 @@ class RedisMessageSystem:
                # Fetch sender info
                sender = await session.get(User, msg_data["sender_id"])
                if sender:
                    user_resp = await UserResp.from_db(sender, session, RANKING_INCLUDES)
                    user_resp = await UserModel.transform(sender, includes=User.LIST_INCLUDES)

                    if user_resp.statistics is None:
                        from app.database.statistics import UserStatisticsResp
                    from app.database.chat import ChatMessageDict

                        user_resp.statistics = UserStatisticsResp(
                            mode=sender.playmode,
                            global_rank=0,
                            country_rank=0,
                            pp=0.0,
                            ranked_score=0,
                            hit_accuracy=0.0,
                            play_count=0,
                            play_time=0,
                            total_score=0,
                            total_hits=0,
                            maximum_combo=0,
                            replays_watched_by_others=0,
                            is_ranked=False,
                            grade_counts={
                                "ssh": 0,
                                "ss": 0,
                                "sh": 0,
                                "s": 0,
                                "a": 0,
                            },
                            level={"current": 1, "progress": 0},
                        )

                    message_resp = ChatMessageResp(
                        message_id=msg_data["message_id"],
                        channel_id=msg_data["channel_id"],
                        content=msg_data["content"],
                        timestamp=datetime.fromisoformat(msg_data["timestamp"]),
                        sender_id=msg_data["sender_id"],
                        sender=user_resp,
                        is_action=msg_data["type"] == MessageType.ACTION.value,
                        uuid=msg_data.get("uuid") or None,
                    )
                    message_resp: ChatMessageDict = {
                        "message_id": msg_data["message_id"],
                        "channel_id": msg_data["channel_id"],
                        "content": msg_data["content"],
                        "timestamp": datetime.fromisoformat(msg_data["timestamp"]),  # pyright: ignore[reportArgumentType]
                        "sender_id": msg_data["sender_id"],
                        "sender": user_resp,
                        "is_action": msg_data["type"] == MessageType.ACTION.value,
                        "uuid": msg_data.get("uuid") or None,
                        "type": MessageType(msg_data["type"]),
                    }
                    messages.append(message_resp)

            # If Redis does not have enough messages, backfill from the database
@@ -216,86 +142,46 @@ class RedisMessageSystem:

        return message_id

    async def _store_to_redis(self, message_id: int, channel_id: int, message_data: dict[str, Any]):
    async def _store_to_redis(self, message_id: int, channel_id: int, message_data: ChatMessageDict):
        """Store a message to Redis"""
        try:
            # Check whether this is a multiplayer room message
            is_multiplayer = message_data.get("is_multiplayer", False)

            # Store the message data
            await self.redis.hset(
            # Store the message data as a JSON string
            await self.redis.set(
                f"msg:{channel_id}:{message_id}",
                mapping={k: json.dumps(v) if isinstance(v, (dict, list)) else str(v) for k, v in message_data.items()},
                safe_json_dumps(message_data),
                ex=604800,  # expires in 7 days
            )

            # Set the message expiry (7 days)
            await self.redis.expire(f"msg:{channel_id}:{message_id}", 604800)

            # Clean up any wrongly-typed key, then add to the channel message list (ordered by time)
            # Add to the channel message list (ordered by time)
            channel_messages_key = f"channel:{channel_id}:messages"

            # More robust key type checking and cleanup
            # Check for and clean up wrongly-typed keys
            try:
                key_type = await self.redis.type(channel_messages_key)
                if key_type == "none":
                    # The key does not exist, which is fine
                    pass
                elif key_type != "zset":
                    # Wrong key type, needs cleanup
                if key_type not in ("none", "zset"):
                    logger.warning(f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}")
                    await self.redis.delete(channel_messages_key)

                    # Verify the deletion succeeded
                    verify_type = await self.redis.type(channel_messages_key)
                    if verify_type != "none":
                        logger.error(
                            f"Failed to delete problematic key {channel_messages_key}, type is still {verify_type}"
                        )
                        # Force delete
                        await self.redis.unlink(channel_messages_key)

            except Exception as type_check_error:
                logger.warning(f"Failed to check key type for {channel_messages_key}: {type_check_error}")
                # If the check fails, try force-deleting the key to ensure cleanup
                try:
                    await self.redis.delete(channel_messages_key)
                except Exception:
                    # Last resort: use unlink
                    try:
                        await self.redis.unlink(channel_messages_key)
                    except Exception as final_error:
                        logger.error(f"Critical: Unable to clear problematic key {channel_messages_key}: {final_error}")
                await self.redis.delete(channel_messages_key)

            # Add to the channel message list (sorted set)
            try:
                await self.redis.zadd(
                    channel_messages_key,
                    mapping={f"msg:{channel_id}:{message_id}": message_id},
                )
            except Exception as zadd_error:
                logger.error(f"Failed to add message to sorted set {channel_messages_key}: {zadd_error}")
                # If the add fails, clean up again and retry
                await self.redis.delete(channel_messages_key)
                await self.redis.zadd(
                    channel_messages_key,
                    mapping={f"msg:{channel_id}:{message_id}": message_id},
                )
            await self.redis.zadd(
                channel_messages_key,
                mapping={f"msg:{channel_id}:{message_id}": message_id},
            )

            # Cap the channel message list size (at most 1000 entries)
            await self.redis.zremrangebyrank(channel_messages_key, 0, -1001)

            # Only non-multiplayer messages are added to the persistence queue
            if not is_multiplayer:
                await self.redis.lpush("pending_messages", f"{channel_id}:{message_id}")
                logger.debug(f"Message {message_id} added to persistence queue")
            else:
                logger.debug(f"Message {message_id} in multiplayer room, skipped persistence queue")
            await self.redis.lpush("pending_messages", f"{channel_id}:{message_id}")
            logger.debug(f"Message {message_id} added to persistence queue")

        except Exception as e:
            logger.error(f"Failed to store message to Redis: {e}")
            raise

    async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict[str, Any]]:
    async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageDict]:
        """Fetch messages from Redis"""
        try:
            # Get the message key list, sorted by message ID
@@ -314,28 +200,16 @@ class RedisMessageSystem:

            messages = []
            for key in message_keys:
                # Fetch the message data
                raw_data = await self.redis.hgetall(key)
                # Fetch the message data (JSON string)
                raw_data = await self.redis.get(key)
                if raw_data:
                    # Decode the data
                    message_data: dict[str, Any] = {}
                    for k, v in raw_data.items():
                        # Try to parse JSON
                        try:
                            if k in ["grade_counts", "level"] or v.startswith(("{", "[")):
                                message_data[k] = json.loads(v)
                            elif k in ["message_id", "channel_id", "sender_id"]:
                                message_data[k] = int(v)
                            elif k == "is_multiplayer":
                                message_data[k] = v == "True"
                            elif k == "created_at":
                                message_data[k] = float(v)
                            else:
                                message_data[k] = v
                        except (json.JSONDecodeError, ValueError):
                            message_data[k] = v

                    messages.append(message_data)
                    try:
                        # Parse the JSON string into a dict
                        message_data = json.loads(raw_data)
                        messages.append(message_data)
                    except json.JSONDecodeError as e:
                        logger.error(f"Failed to decode message JSON from {key}: {e}")
                        continue

            # Ensure messages are sorted by ID ascending (chronological order)
            messages.sort(key=lambda x: x.get("message_id", 0))
@@ -350,15 +224,15 @@ class RedisMessageSystem:
            logger.error(f"Failed to get messages from Redis: {e}")
            return []

    async def _backfill_from_database(self, channel_id: int, existing_messages: list[ChatMessageResp], limit: int):
    async def _backfill_from_database(self, channel_id: int, existing_messages: list[ChatMessageDict], limit: int):
        """Backfill historical messages from the database"""
        try:
            # Find the smallest message ID
            min_id = float("inf")
            if existing_messages:
                for msg in existing_messages:
                    if msg.message_id is not None and msg.message_id < min_id:
                        min_id = msg.message_id
                    if msg["message_id"] is not None and msg["message_id"] < min_id:
                        min_id = msg["message_id"]

            needed = limit - len(existing_messages)

@@ -378,13 +252,13 @@ class RedisMessageSystem:
                db_messages = (await session.exec(query)).all()

                for msg in reversed(db_messages):  # insert in chronological order
                    msg_resp = await ChatMessageResp.from_db(msg, session)
                    msg_resp = await ChatMessageModel.transform(msg, includes=["sender"])
                    existing_messages.insert(0, msg_resp)

        except Exception as e:
            logger.error(f"Failed to backfill from database: {e}")

    async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> list[ChatMessageResp]:
    async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> list[ChatMessageDict]:
        """Fetch messages from the database only (fallback)"""
        try:
            async with with_db() as session:
@@ -402,7 +276,7 @@ class RedisMessageSystem:

                messages = (await session.exec(query)).all()

                results = [await ChatMessageResp.from_db(msg, session) for msg in messages]
                results = await ChatMessageModel.transform_many(messages, includes=["sender"])

                # If since > 0, keep ascending order; otherwise reverse into chronological order
                if since == 0:
@@ -450,27 +324,17 @@ class RedisMessageSystem:
                    # Parse the channel ID and message ID
                    channel_id, message_id = map(int, key.split(":"))

                    # Fetch the message data from Redis
                    raw_data = await self.redis.hgetall(f"msg:{channel_id}:{message_id}")
                    # Fetch the message data from Redis (JSON string)
                    raw_data = await self.redis.get(f"msg:{channel_id}:{message_id}")

                    if not raw_data:
                        continue

                    # Decode the data
                    message_data = {}
                    for k, v in raw_data.items():
                        message_data[k] = v

                    # Check whether this is a multiplayer room message; skip database storage if so
                    is_multiplayer = message_data.get("is_multiplayer", "False") == "True"
                    if is_multiplayer:
                        # Multiplayer messages are not stored in the database; mark them as skipped
                        await self.redis.hset(
                            f"msg:{channel_id}:{message_id}",
                            "status",
                            "skipped_multiplayer",
                        )
                        logger.debug(f"Message {message_id} in multiplayer room skipped from database storage")
                    # Parse the JSON string into a dict
                    try:
                        message_data = json.loads(raw_data)
                    except json.JSONDecodeError as e:
                        logger.error(f"Failed to decode message JSON for {channel_id}:{message_id}: {e}")
                        continue

                    # Check whether the message already exists in the database
@@ -491,13 +355,6 @@ class RedisMessageSystem:

                    session.add(db_message)

                    # Update the status in Redis
                    await self.redis.hset(
                        f"msg:{channel_id}:{message_id}",
                        "status",
                        "persisted",
                    )

                    logger.debug(f"Message {message_id} persisted to database")

                except Exception as e:

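The chat store above moves from one Redis hash per message, with every field stringified, to a single JSON document per key. A self-contained sketch of the two encodings (message fields hypothetical):

    import json
    from datetime import datetime, timezone

    message = {
        "message_id": 1,
        "channel_id": 42,
        "content": "hello",
        "timestamp": datetime.now(timezone.utc),
    }

    # Old scheme: flatten to string pairs for HSET, losing types on the way in.
    as_hash = {k: v if isinstance(v, str) else str(v) for k, v in message.items()}

    # New scheme: one SET with a JSON document; types are restored by json.loads
    # plus targeted parsing (datetime.fromisoformat) on the way out.
    as_json = json.dumps(message, default=lambda o: o.isoformat())
    restored = json.loads(as_json)
    restored["timestamp"] = datetime.fromisoformat(restored["timestamp"])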
@@ -14,8 +14,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession


async def create_playlist_room_from_api(session: AsyncSession, room: APIUploadedRoom, host_id: int) -> Room:
    db_room = room.to_room()
    db_room.host_id = host_id
    db_room = Room.model_validate({"host_id": host_id, **room.model_dump(exclude={"playlist"})})
    db_room.starts_at = utcnow()
    db_room.ends_at = db_room.starts_at + timedelta(minutes=db_room.duration if db_room.duration is not None else 0)
    session.add(db_room)

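The room helper above now builds the database row by validating a merged dict instead of a dedicated to_room() conversion. A generic SQLModel sketch of the idiom, with hypothetical models:

    from sqlmodel import Field, SQLModel


    class RoomIn(SQLModel):
        name: str
        duration: int | None = None


    class Room(SQLModel, table=True):
        id: int | None = Field(default=None, primary_key=True)
        name: str
        duration: int | None = None
        host_id: int


    room_in = RoomIn(name="daily", duration=30)
    # Extra keys (host_id) are merged in; model_validate builds the table row.
    db_room = Room.model_validate({"host_id": 7, **room_in.model_dump()})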
@@ -3,19 +3,19 @@
Caches user info, providing hot caching and real-time refresh
"""

from datetime import datetime
import json
from typing import TYPE_CHECKING, Any

from app.config import settings
from app.const import BANCHOBOT_ID
from app.database import User, UserResp
from app.database.score import LegacyScoreResp, ScoreResp
from app.database.user import SEARCH_INCLUDED
from app.database import User
from app.database.score import LegacyScoreResp
from app.database.user import UserDict, UserModel
from app.dependencies.database import with_db
from app.helpers.asset_proxy_helper import replace_asset_urls
from app.log import logger
from app.models.score import GameMode
from app.utils import safe_json_dumps

from redis.asyncio import Redis
from sqlmodel import col, select
@@ -25,20 +25,6 @@ if TYPE_CHECKING:
    pass


class DateTimeEncoder(json.JSONEncoder):
    """Custom JSON encoder that supports datetime serialization"""

    def default(self, obj):
        if isinstance(obj, datetime):
            return obj.isoformat()
        return super().default(obj)


def safe_json_dumps(data: Any) -> str:
    """Safe JSON serialization with datetime support"""
    return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False)


class UserCacheService:
    """User cache service"""

@@ -125,7 +111,7 @@ class UserCacheService:
        """Generate the user beatmapsets cache key"""
        return f"user:{user_id}:beatmapsets:{beatmapset_type}:limit:{limit}:offset:{offset}"

    async def get_user_from_cache(self, user_id: int, ruleset: GameMode | None = None) -> UserResp | None:
    async def get_user_from_cache(self, user_id: int, ruleset: GameMode | None = None) -> UserDict | None:
        """Get user info from the cache"""
        try:
            cache_key = self._get_user_cache_key(user_id, ruleset)
@@ -133,7 +119,7 @@ class UserCacheService:
            if cached_data:
                logger.debug(f"User cache hit for user {user_id}")
                data = json.loads(cached_data)
                return UserResp(**data)
                return data
            return None
        except Exception as e:
            logger.error(f"Error getting user from cache: {e}")
@@ -141,7 +127,7 @@ class UserCacheService:

    async def cache_user(
        self,
        user_resp: UserResp,
        user_resp: UserDict,
        ruleset: GameMode | None = None,
        expire_seconds: int | None = None,
    ):
@@ -149,13 +135,10 @@ class UserCacheService:
        try:
            if expire_seconds is None:
                expire_seconds = settings.user_cache_expire_seconds
            if user_resp.id is None:
                logger.warning("Cannot cache user with None id")
                return
            cache_key = self._get_user_cache_key(user_resp.id, ruleset)
            cached_data = user_resp.model_dump_json()
            cache_key = self._get_user_cache_key(user_resp["id"], ruleset)
            cached_data = safe_json_dumps(user_resp)
            await self.redis.setex(cache_key, expire_seconds, cached_data)
            logger.debug(f"Cached user {user_resp.id} for {expire_seconds}s")
            logger.debug(f"Cached user {user_resp['id']} for {expire_seconds}s")
        except Exception as e:
            logger.error(f"Error caching user: {e}")

@@ -168,10 +151,9 @@ class UserCacheService:
        limit: int = 100,
        offset: int = 0,
        is_legacy: bool = False,
    ) -> list[ScoreResp] | list[LegacyScoreResp] | None:
    ) -> list[UserDict] | list[LegacyScoreResp] | None:
        """Get user scores from the cache"""
        try:
            model = LegacyScoreResp if is_legacy else ScoreResp
            cache_key = self._get_user_scores_cache_key(
                user_id, score_type, include_fail, mode, limit, offset, is_legacy
            )
@@ -179,7 +161,7 @@ class UserCacheService:
            if cached_data:
                logger.debug(f"User scores cache hit for user {user_id}, type {score_type}")
                data = json.loads(cached_data)
                return [model(**score_data) for score_data in data]  # pyright: ignore[reportReturnType]
                return [LegacyScoreResp(**score_data) for score_data in data] if is_legacy else data
            return None
        except Exception as e:
            logger.error(f"Error getting user scores from cache: {e}")
@@ -189,7 +171,7 @@ class UserCacheService:
        self,
        user_id: int,
        score_type: str,
        scores: list[ScoreResp] | list[LegacyScoreResp],
        scores: list[UserDict] | list[LegacyScoreResp],
        include_fail: bool,
        mode: GameMode | None = None,
        limit: int = 100,
@@ -204,8 +186,12 @@ class UserCacheService:
            cache_key = self._get_user_scores_cache_key(
                user_id, score_type, include_fail, mode, limit, offset, is_legacy
            )
            # Use model_dump_json() instead of model_dump() + json.dumps()
            scores_json_list = [score.model_dump_json() for score in scores]
            if len(scores) == 0:
                return
            if isinstance(scores[0], dict):
                scores_json_list = [safe_json_dumps(score) for score in scores]
            else:
                scores_json_list = [score.model_dump_json() for score in scores]  # pyright: ignore[reportAttributeAccessIssue]
            cached_data = f"[{','.join(scores_json_list)}]"
            await self.redis.setex(cache_key, expire_seconds, cached_data)
            logger.debug(f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s")
@@ -308,7 +294,7 @@ class UserCacheService:
            for user in users:
                if user.id != BANCHOBOT_ID:
                    try:
                        await self._cache_single_user(user, session)
                        await self._cache_single_user(user)
                        cached_count += 1
                    except Exception as e:
                        logger.error(f"Failed to cache user {user.id}: {e}")
@@ -320,10 +306,10 @@ class UserCacheService:
        finally:
            self._refreshing = False

    async def _cache_single_user(self, user: User, session: AsyncSession):
    async def _cache_single_user(self, user: User):
        """Cache a single user"""
        try:
            user_resp = await UserResp.from_db(user, session, include=SEARCH_INCLUDED)
            user_resp = await UserModel.transform(user, includes=User.USER_INCLUDES)

            # Apply asset proxy processing
            if settings.enable_asset_proxy:
@@ -347,7 +333,7 @@ class UserCacheService:
            # Immediately reload the user info
            user = await session.get(User, user_id)
            if user and user.id != BANCHOBOT_ID:
                await self._cache_single_user(user, session)
                await self._cache_single_user(user)
                logger.info(f"Refreshed cache for user {user_id} after score submit")
        except Exception as e:
            logger.error(f"Error refreshing user cache on score submit: {e}")