Optimization of score calculation
This commit is contained in:
@@ -86,26 +86,34 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
|
|||||||
f"Error checking if beatmap {score.beatmap_id} is suspicious"
|
f"Error checking if beatmap {score.beatmap_id} is suspicious"
|
||||||
)
|
)
|
||||||
|
|
||||||
map = rosu.Beatmap(content=beatmap)
|
# 使用线程池执行计算密集型操作以避免阻塞事件循环
|
||||||
mods = deepcopy(score.mods.copy())
|
import asyncio
|
||||||
parse_enum_to_str(int(score.gamemode), mods)
|
loop = asyncio.get_event_loop()
|
||||||
map.convert(score.gamemode.to_rosu(), mods) # pyright: ignore[reportArgumentType]
|
|
||||||
perf = rosu.Performance(
|
def _calculate_pp_sync():
|
||||||
mods=mods,
|
map = rosu.Beatmap(content=beatmap)
|
||||||
lazer=True,
|
mods = deepcopy(score.mods.copy())
|
||||||
accuracy=clamp(score.accuracy * 100, 0, 100),
|
parse_enum_to_str(int(score.gamemode), mods)
|
||||||
combo=score.max_combo,
|
map.convert(score.gamemode.to_rosu(), mods) # pyright: ignore[reportArgumentType]
|
||||||
large_tick_hits=score.nlarge_tick_hit or 0,
|
perf = rosu.Performance(
|
||||||
slider_end_hits=score.nslider_tail_hit or 0,
|
mods=mods,
|
||||||
small_tick_hits=score.nsmall_tick_hit or 0,
|
lazer=True,
|
||||||
n_geki=score.ngeki,
|
accuracy=clamp(score.accuracy * 100, 0, 100),
|
||||||
n_katu=score.nkatu,
|
combo=score.max_combo,
|
||||||
n300=score.n300,
|
large_tick_hits=score.nlarge_tick_hit or 0,
|
||||||
n100=score.n100,
|
slider_end_hits=score.nslider_tail_hit or 0,
|
||||||
n50=score.n50,
|
small_tick_hits=score.nsmall_tick_hit or 0,
|
||||||
misses=score.nmiss,
|
n_geki=score.ngeki,
|
||||||
)
|
n_katu=score.nkatu,
|
||||||
attrs = perf.calculate(map)
|
n300=score.n300,
|
||||||
|
n100=score.n100,
|
||||||
|
n50=score.n50,
|
||||||
|
misses=score.nmiss,
|
||||||
|
)
|
||||||
|
return perf.calculate(map)
|
||||||
|
|
||||||
|
# 在线程池中执行计算
|
||||||
|
attrs = await loop.run_in_executor(None, _calculate_pp_sync)
|
||||||
pp = attrs.pp
|
pp = attrs.pp
|
||||||
|
|
||||||
# mrekk bp1: 2048pp; ppy-sb top1 rxbp1: 2198pp
|
# mrekk bp1: 2048pp; ppy-sb top1 rxbp1: 2198pp
|
||||||
@@ -122,6 +130,132 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
|
|||||||
return pp
|
return pp
|
||||||
|
|
||||||
|
|
||||||
|
async def pre_fetch_and_calculate_pp(
|
||||||
|
score: "Score",
|
||||||
|
beatmap_id: int,
|
||||||
|
session: AsyncSession,
|
||||||
|
redis,
|
||||||
|
fetcher
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
优化版PP计算:预先获取beatmap文件并使用缓存
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from app.database.beatmap import BannedBeatmaps
|
||||||
|
|
||||||
|
# 快速检查是否被封禁
|
||||||
|
if settings.suspicious_score_check:
|
||||||
|
beatmap_banned = (
|
||||||
|
await session.exec(
|
||||||
|
select(exists()).where(
|
||||||
|
col(BannedBeatmaps.beatmap_id) == beatmap_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).first()
|
||||||
|
if beatmap_banned:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# 异步获取beatmap原始文件,利用已有的Redis缓存机制
|
||||||
|
try:
|
||||||
|
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to fetch beatmap {beatmap_id}: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# 在获取文件的同时,可以检查可疑beatmap
|
||||||
|
if settings.suspicious_score_check:
|
||||||
|
try:
|
||||||
|
# 将可疑检查也移到线程池中执行
|
||||||
|
def _check_suspicious():
|
||||||
|
return is_suspicious_beatmap(beatmap_raw)
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
is_sus = await loop.run_in_executor(None, _check_suspicious)
|
||||||
|
if is_sus:
|
||||||
|
session.add(BannedBeatmaps(beatmap_id=beatmap_id))
|
||||||
|
logger.warning(f"Beatmap {beatmap_id} is suspicious, banned")
|
||||||
|
return 0
|
||||||
|
except Exception:
|
||||||
|
logger.exception(f"Error checking if beatmap {beatmap_id} is suspicious")
|
||||||
|
|
||||||
|
# 调用已优化的PP计算函数
|
||||||
|
return await calculate_pp(score, beatmap_raw, session)
|
||||||
|
|
||||||
|
|
||||||
|
async def batch_calculate_pp(
|
||||||
|
scores_data: list[tuple["Score", int]],
|
||||||
|
session: AsyncSession,
|
||||||
|
redis,
|
||||||
|
fetcher
|
||||||
|
) -> list[float]:
|
||||||
|
"""
|
||||||
|
批量计算PP:适用于重新计算或批量处理场景
|
||||||
|
Args:
|
||||||
|
scores_data: [(score, beatmap_id), ...] 的列表
|
||||||
|
Returns:
|
||||||
|
对应的PP值列表
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from app.database.beatmap import BannedBeatmaps
|
||||||
|
|
||||||
|
if not scores_data:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 提取所有唯一的beatmap_id
|
||||||
|
unique_beatmap_ids = list({beatmap_id for _, beatmap_id in scores_data})
|
||||||
|
|
||||||
|
# 批量检查被封禁的beatmap
|
||||||
|
banned_beatmaps = set()
|
||||||
|
if settings.suspicious_score_check:
|
||||||
|
banned_results = await session.exec(
|
||||||
|
select(BannedBeatmaps.beatmap_id).where(
|
||||||
|
col(BannedBeatmaps.beatmap_id).in_(unique_beatmap_ids)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
banned_beatmaps = set(banned_results.all())
|
||||||
|
|
||||||
|
# 并发获取所有需要的beatmap原始文件
|
||||||
|
async def fetch_beatmap_safe(beatmap_id: int) -> tuple[int, str | None]:
|
||||||
|
if beatmap_id in banned_beatmaps:
|
||||||
|
return beatmap_id, None
|
||||||
|
try:
|
||||||
|
content = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
|
||||||
|
return beatmap_id, content
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to fetch beatmap {beatmap_id}: {e}")
|
||||||
|
return beatmap_id, None
|
||||||
|
|
||||||
|
# 并发获取所有beatmap文件
|
||||||
|
fetch_tasks = [fetch_beatmap_safe(bid) for bid in unique_beatmap_ids]
|
||||||
|
fetch_results = await asyncio.gather(*fetch_tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
# 构建beatmap_id -> content的映射
|
||||||
|
beatmap_contents = {}
|
||||||
|
for result in fetch_results:
|
||||||
|
if isinstance(result, tuple):
|
||||||
|
beatmap_id, content = result
|
||||||
|
beatmap_contents[beatmap_id] = content
|
||||||
|
|
||||||
|
# 为每个score计算PP
|
||||||
|
pp_results = []
|
||||||
|
for score, beatmap_id in scores_data:
|
||||||
|
beatmap_content = beatmap_contents.get(beatmap_id)
|
||||||
|
if beatmap_content is None:
|
||||||
|
pp_results.append(0.0)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
pp = await calculate_pp(score, beatmap_content, session)
|
||||||
|
pp_results.append(pp)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to calculate PP for score {score.id}: {e}")
|
||||||
|
pp_results.append(0.0)
|
||||||
|
|
||||||
|
return pp_results
|
||||||
|
|
||||||
|
|
||||||
# https://osu.ppy.sh/wiki/Gameplay/Score/Total_score
|
# https://osu.ppy.sh/wiki/Gameplay/Score/Total_score
|
||||||
def calculate_level_to_score(n: int) -> float:
|
def calculate_level_to_score(n: int) -> float:
|
||||||
if n <= 100:
|
if n <= 100:
|
||||||
|
|||||||
@@ -137,6 +137,13 @@ class Settings(BaseSettings):
|
|||||||
enable_supporter_for_all_users: bool = False
|
enable_supporter_for_all_users: bool = False
|
||||||
enable_all_beatmap_leaderboard: bool = False
|
enable_all_beatmap_leaderboard: bool = False
|
||||||
enable_all_beatmap_pp: bool = False
|
enable_all_beatmap_pp: bool = False
|
||||||
|
# 性能优化设置
|
||||||
|
enable_beatmap_preload: bool = True
|
||||||
|
beatmap_cache_expire_hours: int = 24
|
||||||
|
max_concurrent_pp_calculations: int = 10
|
||||||
|
enable_pp_calculation_threading: bool = True
|
||||||
|
|
||||||
|
# 反作弊设置
|
||||||
suspicious_score_check: bool = True
|
suspicious_score_check: bool = True
|
||||||
seasonal_backgrounds: Annotated[list[str], BeforeValidator(_parse_list)] = []
|
seasonal_backgrounds: Annotated[list[str], BeforeValidator(_parse_list)] = []
|
||||||
banned_name: list[str] = [
|
banned_name: list[str] = [
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import math
|
|||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from app.calculator import (
|
from app.calculator import (
|
||||||
calculate_pp,
|
|
||||||
calculate_pp_weight,
|
calculate_pp_weight,
|
||||||
calculate_score_to_level,
|
calculate_score_to_level,
|
||||||
calculate_weighted_acc,
|
calculate_weighted_acc,
|
||||||
@@ -772,8 +771,10 @@ async def process_score(
|
|||||||
maximum_statistics=info.maximum_statistics,
|
maximum_statistics=info.maximum_statistics,
|
||||||
)
|
)
|
||||||
if can_get_pp:
|
if can_get_pp:
|
||||||
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
|
from app.calculator import pre_fetch_and_calculate_pp
|
||||||
pp = await calculate_pp(score, beatmap_raw, session)
|
pp = await pre_fetch_and_calculate_pp(
|
||||||
|
score, beatmap_id, session, redis, fetcher
|
||||||
|
)
|
||||||
score.pp = pp
|
score.pp = pp
|
||||||
session.add(score)
|
session.add(score)
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
@@ -799,5 +800,5 @@ async def process_score(
|
|||||||
await session.refresh(score)
|
await session.refresh(score)
|
||||||
await session.refresh(score_token)
|
await session.refresh(score_token)
|
||||||
await session.refresh(user)
|
await session.refresh(user)
|
||||||
await redis.publish("score:processed", score.id)
|
await redis.publish("score:processed", str(score.id or 0))
|
||||||
return score
|
return score
|
||||||
|
|||||||
@@ -37,8 +37,20 @@ class BeatmapRawFetcher(BaseFetcher):
|
|||||||
async def get_or_fetch_beatmap_raw(
|
async def get_or_fetch_beatmap_raw(
|
||||||
self, redis: redis.Redis, beatmap_id: int
|
self, redis: redis.Redis, beatmap_id: int
|
||||||
) -> str:
|
) -> str:
|
||||||
if await redis.exists(f"beatmap:{beatmap_id}:raw"):
|
from app.config import settings
|
||||||
return await redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType]
|
|
||||||
|
cache_key = f"beatmap:{beatmap_id}:raw"
|
||||||
|
cache_expire = settings.beatmap_cache_expire_hours * 60 * 60
|
||||||
|
|
||||||
|
# 检查缓存
|
||||||
|
if await redis.exists(cache_key):
|
||||||
|
content = await redis.get(cache_key)
|
||||||
|
if content:
|
||||||
|
# 延长缓存时间
|
||||||
|
await redis.expire(cache_key, cache_expire)
|
||||||
|
return content # pyright: ignore[reportReturnType]
|
||||||
|
|
||||||
|
# 获取并缓存
|
||||||
raw = await self.get_beatmap_raw(beatmap_id)
|
raw = await self.get_beatmap_raw(beatmap_id)
|
||||||
await redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24)
|
await redis.set(cache_key, raw, ex=cache_expire)
|
||||||
return raw
|
return raw
|
||||||
|
|||||||
@@ -32,8 +32,6 @@ class BeatmapsetFetcher(BaseFetcher):
|
|||||||
q="",
|
q="",
|
||||||
s="leaderboard",
|
s="leaderboard",
|
||||||
sort=sort, # type: ignore
|
sort=sort, # type: ignore
|
||||||
# 不设置 nsfw 和 m,让它们使用默认值
|
|
||||||
# 这样 exclude_defaults=True 时它们会被排除
|
|
||||||
)
|
)
|
||||||
homepage_queries.append((query, {}))
|
homepage_queries.append((query, {}))
|
||||||
|
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from app.dependencies.fetcher import get_fetcher
|
|||||||
from app.dependencies.storage import get_storage_service
|
from app.dependencies.storage import get_storage_service
|
||||||
from app.dependencies.user import get_client_user, get_current_user
|
from app.dependencies.user import get_client_user, get_current_user
|
||||||
from app.fetcher import Fetcher
|
from app.fetcher import Fetcher
|
||||||
|
from app.log import logger
|
||||||
from app.models.room import RoomCategory
|
from app.models.room import RoomCategory
|
||||||
from app.models.score import (
|
from app.models.score import (
|
||||||
GameMode,
|
GameMode,
|
||||||
@@ -95,6 +96,14 @@ async def submit_score(
|
|||||||
if not score:
|
if not score:
|
||||||
raise HTTPException(status_code=404, detail="Score not found")
|
raise HTTPException(status_code=404, detail="Score not found")
|
||||||
else:
|
else:
|
||||||
|
# 智能预取beatmap缓存(异步进行,不阻塞主流程)
|
||||||
|
try:
|
||||||
|
from app.service.beatmap_cache_service import get_beatmap_cache_service
|
||||||
|
cache_service = get_beatmap_cache_service(redis, fetcher)
|
||||||
|
await cache_service.smart_preload_for_score(beatmap)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Beatmap preload failed for {beatmap}: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db_beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap)
|
db_beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap)
|
||||||
except HTTPError:
|
except HTTPError:
|
||||||
|
|||||||
174
app/service/beatmap_cache_service.py
Normal file
174
app/service/beatmap_cache_service.py
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
"""
|
||||||
|
Beatmap缓存预取服务
|
||||||
|
用于提前缓存热门beatmap,减少成绩计算时的获取延迟
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
from app.log import logger
|
||||||
|
|
||||||
|
from redis.asyncio import Redis
|
||||||
|
from sqlmodel import col, func, select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.fetcher import Fetcher
|
||||||
|
|
||||||
|
|
||||||
|
class BeatmapCacheService:
|
||||||
|
def __init__(self, redis: Redis, fetcher: "Fetcher"):
|
||||||
|
self.redis = redis
|
||||||
|
self.fetcher = fetcher
|
||||||
|
self._preloading = False
|
||||||
|
self._background_tasks: set = set()
|
||||||
|
|
||||||
|
async def preload_popular_beatmaps(self, session: AsyncSession, limit: int = 100):
|
||||||
|
"""
|
||||||
|
预加载热门beatmap到Redis缓存
|
||||||
|
"""
|
||||||
|
if self._preloading:
|
||||||
|
logger.info("Beatmap preloading already in progress")
|
||||||
|
return
|
||||||
|
|
||||||
|
self._preloading = True
|
||||||
|
try:
|
||||||
|
logger.info(f"Starting preload of top {limit} popular beatmaps")
|
||||||
|
|
||||||
|
# 获取过去24小时内最热门的beatmap
|
||||||
|
recent_time = datetime.now(UTC) - timedelta(hours=24)
|
||||||
|
|
||||||
|
from app.database.score import Score
|
||||||
|
|
||||||
|
popular_beatmaps = (
|
||||||
|
await session.exec(
|
||||||
|
select(Score.beatmap_id, func.count().label("play_count"))
|
||||||
|
.where(col(Score.ended_at) >= recent_time)
|
||||||
|
.group_by(col(Score.beatmap_id))
|
||||||
|
.order_by(col("play_count").desc())
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
|
||||||
|
# 并发预取这些beatmap
|
||||||
|
preload_tasks = []
|
||||||
|
for beatmap_id, _ in popular_beatmaps:
|
||||||
|
task = self._preload_single_beatmap(beatmap_id)
|
||||||
|
preload_tasks.append(task)
|
||||||
|
|
||||||
|
if preload_tasks:
|
||||||
|
results = await asyncio.gather(*preload_tasks, return_exceptions=True)
|
||||||
|
success_count = sum(1 for r in results if r is True)
|
||||||
|
logger.info(
|
||||||
|
f"Preloaded {success_count}/{len(preload_tasks)} "
|
||||||
|
f"beatmaps successfully"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during beatmap preloading: {e}")
|
||||||
|
finally:
|
||||||
|
self._preloading = False
|
||||||
|
|
||||||
|
async def _preload_single_beatmap(self, beatmap_id: int) -> bool:
|
||||||
|
"""
|
||||||
|
预加载单个beatmap
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
cache_key = f"beatmap:{beatmap_id}:raw"
|
||||||
|
if await self.redis.exists(cache_key):
|
||||||
|
# 已经在缓存中,延长过期时间
|
||||||
|
await self.redis.expire(cache_key, 60 * 60 * 24)
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 获取并缓存beatmap
|
||||||
|
content = await self.fetcher.get_beatmap_raw(beatmap_id)
|
||||||
|
await self.redis.set(cache_key, content, ex=60 * 60 * 24)
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Failed to preload beatmap {beatmap_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def smart_preload_for_score(self, beatmap_id: int):
|
||||||
|
"""
|
||||||
|
智能预加载:为即将提交的成绩预加载beatmap
|
||||||
|
"""
|
||||||
|
task = asyncio.create_task(self._preload_single_beatmap(beatmap_id))
|
||||||
|
self._background_tasks.add(task)
|
||||||
|
task.add_done_callback(self._background_tasks.discard)
|
||||||
|
|
||||||
|
async def get_cache_stats(self) -> dict:
|
||||||
|
"""
|
||||||
|
获取缓存统计信息
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
keys = await self.redis.keys("beatmap:*:raw")
|
||||||
|
total_size = 0
|
||||||
|
|
||||||
|
for key in keys[:100]: # 限制检查数量以避免性能问题
|
||||||
|
try:
|
||||||
|
size = await self.redis.memory_usage(key)
|
||||||
|
if size:
|
||||||
|
total_size += size
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return {
|
||||||
|
"cached_beatmaps": len(keys),
|
||||||
|
"estimated_total_size_mb": (
|
||||||
|
round(total_size / 1024 / 1024, 2) if total_size > 0 else 0
|
||||||
|
),
|
||||||
|
"preloading": self._preloading,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting cache stats: {e}")
|
||||||
|
return {"error": str(e)}
|
||||||
|
|
||||||
|
async def cleanup_old_cache(self, max_age_hours: int = 48):
|
||||||
|
"""
|
||||||
|
清理过期的缓存
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.info(f"Cleaning up beatmap cache older than {max_age_hours} hours")
|
||||||
|
# Redis会自动清理过期的key,这里主要是记录日志
|
||||||
|
keys = await self.redis.keys("beatmap:*:raw")
|
||||||
|
logger.info(f"Current cache contains {len(keys)} beatmaps")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during cache cleanup: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
# 全局缓存服务实例
|
||||||
|
_cache_service: BeatmapCacheService | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_beatmap_cache_service(redis: Redis, fetcher: "Fetcher") -> BeatmapCacheService:
|
||||||
|
"""
|
||||||
|
获取beatmap缓存服务实例
|
||||||
|
"""
|
||||||
|
global _cache_service
|
||||||
|
if _cache_service is None:
|
||||||
|
_cache_service = BeatmapCacheService(redis, fetcher)
|
||||||
|
return _cache_service
|
||||||
|
|
||||||
|
|
||||||
|
async def schedule_preload_task(
|
||||||
|
session: AsyncSession,
|
||||||
|
redis: Redis,
|
||||||
|
fetcher: "Fetcher"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
定时预加载任务
|
||||||
|
"""
|
||||||
|
# 默认启用预加载,除非明确禁用
|
||||||
|
enable_preload = getattr(settings, "enable_beatmap_preload", True)
|
||||||
|
if not enable_preload:
|
||||||
|
return
|
||||||
|
|
||||||
|
cache_service = get_beatmap_cache_service(redis, fetcher)
|
||||||
|
try:
|
||||||
|
await cache_service.preload_popular_beatmaps(session, limit=200)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Scheduled preload task failed: {e}")
|
||||||
@@ -222,6 +222,10 @@ class SpectatorHub(Hub[StoreClientState]):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
logger.info(f"[SpectatorHub] {client.user_id} began playing {state.beatmap_id}")
|
logger.info(f"[SpectatorHub] {client.user_id} began playing {state.beatmap_id}")
|
||||||
|
|
||||||
|
# 预缓存beatmap文件以加速后续PP计算
|
||||||
|
await self._preload_beatmap_for_pp_calculation(state.beatmap_id)
|
||||||
|
|
||||||
await self.broadcast_group_call(
|
await self.broadcast_group_call(
|
||||||
self.group_id(user_id),
|
self.group_id(user_id),
|
||||||
"UserBeganPlaying",
|
"UserBeganPlaying",
|
||||||
@@ -446,3 +450,50 @@ class SpectatorHub(Hub[StoreClientState]):
|
|||||||
if (target_client := self.get_client_by_id(str(target_id))) is not None:
|
if (target_client := self.get_client_by_id(str(target_id))) is not None:
|
||||||
await self.call_noblock(target_client, "UserEndedWatching", user_id)
|
await self.call_noblock(target_client, "UserEndedWatching", user_id)
|
||||||
logger.info(f"[SpectatorHub] {user_id} ended watching {target_id}")
|
logger.info(f"[SpectatorHub] {user_id} ended watching {target_id}")
|
||||||
|
|
||||||
|
async def _preload_beatmap_for_pp_calculation(self, beatmap_id: int) -> None:
|
||||||
|
"""
|
||||||
|
预缓存beatmap文件以加速PP计算
|
||||||
|
当玩家开始游玩时异步预加载beatmap原始文件到Redis缓存
|
||||||
|
"""
|
||||||
|
# 检查是否启用了beatmap预加载功能
|
||||||
|
if not settings.enable_beatmap_preload:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 异步获取fetcher和redis连接
|
||||||
|
from app.dependencies.database import get_redis
|
||||||
|
from app.dependencies.fetcher import get_fetcher
|
||||||
|
|
||||||
|
fetcher = get_fetcher()
|
||||||
|
redis = get_redis()
|
||||||
|
|
||||||
|
# 检查是否已经缓存,避免重复下载
|
||||||
|
cache_key = f"beatmap:raw:{beatmap_id}"
|
||||||
|
if await redis.exists(cache_key):
|
||||||
|
logger.debug(f"Beatmap {beatmap_id} already cached, skipping preload")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 在后台异步预缓存beatmap文件,存储任务引用防止被回收
|
||||||
|
task = asyncio.create_task(
|
||||||
|
self._fetch_beatmap_background(fetcher, redis, beatmap_id)
|
||||||
|
)
|
||||||
|
# 任务完成后自动清理,避免内存泄漏
|
||||||
|
task.add_done_callback(lambda t: None)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# 预缓存失败不应该影响正常游戏流程
|
||||||
|
logger.warning(f"Failed to preload beatmap {beatmap_id}: {e}")
|
||||||
|
|
||||||
|
async def _fetch_beatmap_background(self, fetcher, redis, beatmap_id: int) -> None:
|
||||||
|
"""
|
||||||
|
后台获取beatmap文件
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 使用fetcher的get_or_fetch_beatmap_raw方法预缓存
|
||||||
|
await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
|
||||||
|
logger.debug(
|
||||||
|
f"Successfully preloaded beatmap {beatmap_id} for PP calculation"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Failed to preload beatmap {beatmap_id}: {e}")
|
||||||
|
|||||||
Reference in New Issue
Block a user