Optimization of score calculation

This commit is contained in:
咕谷酱
2025-08-18 17:16:44 +08:00
parent e5f0cd1fd6
commit 7f512cec6e
8 changed files with 415 additions and 29 deletions

View File

@@ -86,26 +86,34 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
f"Error checking if beatmap {score.beatmap_id} is suspicious"
)
map = rosu.Beatmap(content=beatmap)
mods = deepcopy(score.mods.copy())
parse_enum_to_str(int(score.gamemode), mods)
map.convert(score.gamemode.to_rosu(), mods) # pyright: ignore[reportArgumentType]
perf = rosu.Performance(
mods=mods,
lazer=True,
accuracy=clamp(score.accuracy * 100, 0, 100),
combo=score.max_combo,
large_tick_hits=score.nlarge_tick_hit or 0,
slider_end_hits=score.nslider_tail_hit or 0,
small_tick_hits=score.nsmall_tick_hit or 0,
n_geki=score.ngeki,
n_katu=score.nkatu,
n300=score.n300,
n100=score.n100,
n50=score.n50,
misses=score.nmiss,
)
attrs = perf.calculate(map)
# 使用线程池执行计算密集型操作以避免阻塞事件循环
import asyncio
loop = asyncio.get_event_loop()
def _calculate_pp_sync():
map = rosu.Beatmap(content=beatmap)
mods = deepcopy(score.mods.copy())
parse_enum_to_str(int(score.gamemode), mods)
map.convert(score.gamemode.to_rosu(), mods) # pyright: ignore[reportArgumentType]
perf = rosu.Performance(
mods=mods,
lazer=True,
accuracy=clamp(score.accuracy * 100, 0, 100),
combo=score.max_combo,
large_tick_hits=score.nlarge_tick_hit or 0,
slider_end_hits=score.nslider_tail_hit or 0,
small_tick_hits=score.nsmall_tick_hit or 0,
n_geki=score.ngeki,
n_katu=score.nkatu,
n300=score.n300,
n100=score.n100,
n50=score.n50,
misses=score.nmiss,
)
return perf.calculate(map)
# 在线程池中执行计算
attrs = await loop.run_in_executor(None, _calculate_pp_sync)
pp = attrs.pp
# mrekk bp1: 2048pp; ppy-sb top1 rxbp1: 2198pp
@@ -122,6 +130,132 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
return pp
async def pre_fetch_and_calculate_pp(
    score: "Score",
    beatmap_id: int,
    session: AsyncSession,
    redis,
    fetcher,
) -> float:
    """Optimized PP calculation: pre-fetch the beatmap file and reuse caches.

    Returns 0 when the beatmap is banned, cannot be fetched, or is detected
    as suspicious; otherwise delegates to ``calculate_pp``.
    """
    import asyncio

    from app.database.beatmap import BannedBeatmaps

    # Fast path: skip the fetch entirely if the beatmap is already banned.
    if settings.suspicious_score_check:
        beatmap_banned = (
            await session.exec(
                select(exists()).where(col(BannedBeatmaps.beatmap_id) == beatmap_id)
            )
        ).first()
        if beatmap_banned:
            return 0

    # Fetch the raw .osu file, reusing the existing Redis cache layer.
    try:
        beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
    except Exception as e:
        logger.error(f"Failed to fetch beatmap {beatmap_id}: {e}")
        return 0

    # Suspicious-map detection is CPU-bound; run it off the event loop.
    if settings.suspicious_score_check:
        try:
            # asyncio.to_thread replaces the deprecated get_event_loop() +
            # run_in_executor(None, ...) pattern with identical semantics.
            is_sus = await asyncio.to_thread(is_suspicious_beatmap, beatmap_raw)
            if is_sus:
                session.add(BannedBeatmaps(beatmap_id=beatmap_id))
                logger.warning(f"Beatmap {beatmap_id} is suspicious, banned")
                return 0
        except Exception:
            logger.exception(f"Error checking if beatmap {beatmap_id} is suspicious")

    # Delegate to the (already thread-pooled) PP calculator.
    return await calculate_pp(score, beatmap_raw, session)
async def batch_calculate_pp(
    scores_data: list[tuple["Score", int]],
    session: AsyncSession,
    redis,
    fetcher,
) -> list[float]:
    """Batch PP calculation for recalculation / bulk-processing scenarios.

    Args:
        scores_data: list of ``(score, beatmap_id)`` tuples.

    Returns:
        PP values in the same order as ``scores_data``; 0.0 for scores whose
        beatmap is banned, missing, or whose calculation failed.
    """
    import asyncio

    from app.database.beatmap import BannedBeatmaps

    if not scores_data:
        return []

    # De-duplicate beatmap ids so each file is fetched at most once.
    unique_beatmap_ids = list({beatmap_id for _, beatmap_id in scores_data})

    # Bulk-load the banned set in a single query.
    banned_beatmaps: set[int] = set()
    if settings.suspicious_score_check:
        banned_results = await session.exec(
            select(BannedBeatmaps.beatmap_id).where(
                col(BannedBeatmaps.beatmap_id).in_(unique_beatmap_ids)
            )
        )
        banned_beatmaps = set(banned_results.all())

    # Bound fetch concurrency with the configured limit so a huge batch does
    # not overwhelm the upstream API or Redis.
    semaphore = asyncio.Semaphore(settings.max_concurrent_pp_calculations)

    async def fetch_beatmap_safe(beatmap_id: int) -> tuple[int, str | None]:
        """Fetch one beatmap; never raises — failures map to ``None``."""
        if beatmap_id in banned_beatmaps:
            return beatmap_id, None
        try:
            async with semaphore:
                content = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
            return beatmap_id, content
        except Exception as e:
            logger.error(f"Failed to fetch beatmap {beatmap_id}: {e}")
            return beatmap_id, None

    # fetch_beatmap_safe catches its own errors, so no return_exceptions /
    # isinstance filtering is needed on the gather results.
    fetch_results = await asyncio.gather(
        *(fetch_beatmap_safe(bid) for bid in unique_beatmap_ids)
    )
    beatmap_contents: dict[int, str | None] = dict(fetch_results)

    # Calculate PP per score; failures degrade to 0.0 instead of aborting.
    pp_results: list[float] = []
    for score, beatmap_id in scores_data:
        beatmap_content = beatmap_contents.get(beatmap_id)
        if beatmap_content is None:
            pp_results.append(0.0)
            continue
        try:
            pp_results.append(await calculate_pp(score, beatmap_content, session))
        except Exception as e:
            logger.error(f"Failed to calculate PP for score {score.id}: {e}")
            pp_results.append(0.0)
    return pp_results
# https://osu.ppy.sh/wiki/Gameplay/Score/Total_score
def calculate_level_to_score(n: int) -> float:
if n <= 100:

View File

@@ -137,6 +137,13 @@ class Settings(BaseSettings):
enable_supporter_for_all_users: bool = False
enable_all_beatmap_leaderboard: bool = False
enable_all_beatmap_pp: bool = False
# 性能优化设置
enable_beatmap_preload: bool = True
beatmap_cache_expire_hours: int = 24
max_concurrent_pp_calculations: int = 10
enable_pp_calculation_threading: bool = True
# 反作弊设置
suspicious_score_check: bool = True
seasonal_backgrounds: Annotated[list[str], BeforeValidator(_parse_list)] = []
banned_name: list[str] = [

View File

@@ -5,7 +5,6 @@ import math
from typing import TYPE_CHECKING, Any
from app.calculator import (
calculate_pp,
calculate_pp_weight,
calculate_score_to_level,
calculate_weighted_acc,
@@ -772,8 +771,10 @@ async def process_score(
maximum_statistics=info.maximum_statistics,
)
if can_get_pp:
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
pp = await calculate_pp(score, beatmap_raw, session)
from app.calculator import pre_fetch_and_calculate_pp
pp = await pre_fetch_and_calculate_pp(
score, beatmap_id, session, redis, fetcher
)
score.pp = pp
session.add(score)
user_id = user.id
@@ -799,5 +800,5 @@ async def process_score(
await session.refresh(score)
await session.refresh(score_token)
await session.refresh(user)
await redis.publish("score:processed", score.id)
await redis.publish("score:processed", str(score.id or 0))
return score

View File

@@ -37,8 +37,20 @@ class BeatmapRawFetcher(BaseFetcher):
async def get_or_fetch_beatmap_raw(
    self, redis: redis.Redis, beatmap_id: int
) -> str:
    """Return the raw .osu file for *beatmap_id*, caching it in Redis.

    Cache hits refresh the TTL (sliding expiration); misses fetch from the
    upstream source and store with the configured expiry. This is the
    cleaned-up post-change implementation — the pre-change duplicate
    exists/get lines from the diff residue are removed.
    """
    from app.config import settings

    cache_key = f"beatmap:{beatmap_id}:raw"
    cache_expire = settings.beatmap_cache_expire_hours * 60 * 60

    # Cache hit: extend the TTL so frequently played maps stay warm.
    if await redis.exists(cache_key):
        content = await redis.get(cache_key)
        if content:
            await redis.expire(cache_key, cache_expire)
            return content  # pyright: ignore[reportReturnType]

    # Cache miss (or empty value): fetch from upstream and store.
    raw = await self.get_beatmap_raw(beatmap_id)
    await redis.set(cache_key, raw, ex=cache_expire)
    return raw

View File

@@ -32,8 +32,6 @@ class BeatmapsetFetcher(BaseFetcher):
q="",
s="leaderboard",
sort=sort, # type: ignore
# 不设置 nsfw 和 m让它们使用默认值
# 这样 exclude_defaults=True 时它们会被排除
)
homepage_queries.append((query, {}))

View File

@@ -37,6 +37,7 @@ from app.dependencies.fetcher import get_fetcher
from app.dependencies.storage import get_storage_service
from app.dependencies.user import get_client_user, get_current_user
from app.fetcher import Fetcher
from app.log import logger
from app.models.room import RoomCategory
from app.models.score import (
GameMode,
@@ -95,6 +96,14 @@ async def submit_score(
if not score:
raise HTTPException(status_code=404, detail="Score not found")
else:
# 智能预取beatmap缓存异步进行不阻塞主流程
try:
from app.service.beatmap_cache_service import get_beatmap_cache_service
cache_service = get_beatmap_cache_service(redis, fetcher)
await cache_service.smart_preload_for_score(beatmap)
except Exception as e:
logger.debug(f"Beatmap preload failed for {beatmap}: {e}")
try:
db_beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap)
except HTTPError:

View File

@@ -0,0 +1,174 @@
"""
Beatmap缓存预取服务
用于提前缓存热门beatmap减少成绩计算时的获取延迟
"""
from __future__ import annotations
import asyncio
from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING
from app.config import settings
from app.log import logger
from redis.asyncio import Redis
from sqlmodel import col, func, select
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.fetcher import Fetcher
class BeatmapCacheService:
    """Beatmap cache preloading service.

    Pre-caches popular beatmaps in Redis ahead of time to reduce fetch
    latency during score/PP processing.
    """

    # TTL for preloaded raw beatmap files (24 hours).
    _CACHE_TTL = 60 * 60 * 24

    def __init__(self, redis: Redis, fetcher: "Fetcher"):
        self.redis = redis
        self.fetcher = fetcher
        # Guard so only one preload sweep runs at a time.
        self._preloading = False
        # Strong references to fire-and-forget tasks (prevents GC of
        # still-running asyncio tasks).
        self._background_tasks: set[asyncio.Task] = set()

    async def preload_popular_beatmaps(self, session: AsyncSession, limit: int = 100):
        """Preload the most-played beatmaps of the last 24h into Redis."""
        if self._preloading:
            logger.info("Beatmap preloading already in progress")
            return
        self._preloading = True
        try:
            logger.info(f"Starting preload of top {limit} popular beatmaps")
            # Most popular beatmaps over the past 24 hours, by play count.
            recent_time = datetime.now(UTC) - timedelta(hours=24)
            from app.database.score import Score

            popular_beatmaps = (
                await session.exec(
                    select(Score.beatmap_id, func.count().label("play_count"))
                    .where(col(Score.ended_at) >= recent_time)
                    .group_by(col(Score.beatmap_id))
                    .order_by(col("play_count").desc())
                    .limit(limit)
                )
            ).all()

            # Fetch all candidates concurrently; each helper reports bool.
            preload_tasks = [
                self._preload_single_beatmap(beatmap_id)
                for beatmap_id, _ in popular_beatmaps
            ]
            if preload_tasks:
                results = await asyncio.gather(*preload_tasks, return_exceptions=True)
                success_count = sum(1 for r in results if r is True)
                logger.info(
                    f"Preloaded {success_count}/{len(preload_tasks)} "
                    f"beatmaps successfully"
                )
        except Exception as e:
            logger.error(f"Error during beatmap preloading: {e}")
        finally:
            self._preloading = False

    async def _preload_single_beatmap(self, beatmap_id: int) -> bool:
        """Preload one beatmap; returns True on success (or already cached)."""
        try:
            cache_key = f"beatmap:{beatmap_id}:raw"
            if await self.redis.exists(cache_key):
                # Already cached — just refresh the expiry.
                await self.redis.expire(cache_key, self._CACHE_TTL)
                return True
            # Fetch from upstream and store.
            content = await self.fetcher.get_beatmap_raw(beatmap_id)
            await self.redis.set(cache_key, content, ex=self._CACHE_TTL)
            return True
        except Exception as e:
            logger.debug(f"Failed to preload beatmap {beatmap_id}: {e}")
            return False

    async def smart_preload_for_score(self, beatmap_id: int):
        """Kick off a background preload for an imminent score submission."""
        task = asyncio.create_task(self._preload_single_beatmap(beatmap_id))
        # Hold a reference until done, then drop it (avoids GC + leak).
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def get_cache_stats(self) -> dict:
        """Return cache statistics: count, estimated size, preload state."""
        try:
            # SCAN instead of KEYS: KEYS is O(N) and blocks Redis on large
            # keyspaces, which is unsafe in production.
            keys = [
                key async for key in self.redis.scan_iter(match="beatmap:*:raw")
            ]
            total_size = 0
            for key in keys[:100]:  # cap MEMORY USAGE calls for performance
                try:
                    size = await self.redis.memory_usage(key)
                    if size:
                        total_size += size
                except Exception:
                    continue
            return {
                "cached_beatmaps": len(keys),
                "estimated_total_size_mb": (
                    round(total_size / 1024 / 1024, 2) if total_size > 0 else 0
                ),
                "preloading": self._preloading,
            }
        except Exception as e:
            logger.error(f"Error getting cache stats: {e}")
            return {"error": str(e)}

    async def cleanup_old_cache(self, max_age_hours: int = 48):
        """Log current cache size; Redis expires old keys automatically."""
        try:
            logger.info(f"Cleaning up beatmap cache older than {max_age_hours} hours")
            # Redis handles expiry itself; this only reports the key count.
            # SCAN (not KEYS) to avoid blocking the Redis event loop.
            count = 0
            async for _ in self.redis.scan_iter(match="beatmap:*:raw"):
                count += 1
            logger.info(f"Current cache contains {count} beatmaps")
        except Exception as e:
            logger.error(f"Error during cache cleanup: {e}")
# Process-wide singleton instance of the cache service.
_cache_service: BeatmapCacheService | None = None


def get_beatmap_cache_service(redis: Redis, fetcher: "Fetcher") -> BeatmapCacheService:
    """Return the shared BeatmapCacheService, constructing it lazily.

    Subsequent calls reuse the first instance; the redis/fetcher arguments
    of later calls are ignored.
    """
    global _cache_service
    service = _cache_service
    if service is None:
        service = BeatmapCacheService(redis, fetcher)
        _cache_service = service
    return service
async def schedule_preload_task(
    session: AsyncSession,
    redis: Redis,
    fetcher: "Fetcher",
):
    """Scheduled task: preload the top popular beatmaps into the cache.

    No-op when preloading is disabled via settings; errors are logged and
    swallowed so the scheduler keeps running.
    """
    # Settings declares enable_beatmap_preload directly, so the previous
    # getattr(..., True) fallback was dead code.
    if not settings.enable_beatmap_preload:
        return
    cache_service = get_beatmap_cache_service(redis, fetcher)
    try:
        await cache_service.preload_popular_beatmaps(session, limit=200)
    except Exception as e:
        logger.error(f"Scheduled preload task failed: {e}")

View File

@@ -222,6 +222,10 @@ class SpectatorHub(Hub[StoreClientState]):
)
)
logger.info(f"[SpectatorHub] {client.user_id} began playing {state.beatmap_id}")
# 预缓存beatmap文件以加速后续PP计算
await self._preload_beatmap_for_pp_calculation(state.beatmap_id)
await self.broadcast_group_call(
self.group_id(user_id),
"UserBeganPlaying",
@@ -446,3 +450,50 @@ class SpectatorHub(Hub[StoreClientState]):
if (target_client := self.get_client_by_id(str(target_id))) is not None:
await self.call_noblock(target_client, "UserEndedWatching", user_id)
logger.info(f"[SpectatorHub] {user_id} ended watching {target_id}")
async def _preload_beatmap_for_pp_calculation(self, beatmap_id: int) -> None:
    """Pre-cache the raw beatmap file to speed up later PP calculation.

    Called when a player begins playing. Never raises — a preload failure
    must not affect the gameplay flow.
    """
    # Respect the feature flag.
    if not settings.enable_beatmap_preload:
        return
    try:
        from app.dependencies.database import get_redis
        from app.dependencies.fetcher import get_fetcher

        fetcher = get_fetcher()
        redis = get_redis()

        # BUG FIX: use the same key layout as BeatmapRawFetcher
        # ("beatmap:{id}:raw"). The previous "beatmap:raw:{id}" key never
        # matched the real cache, so this check always missed and every
        # play triggered a redundant download.
        cache_key = f"beatmap:{beatmap_id}:raw"
        if await redis.exists(cache_key):
            logger.debug(f"Beatmap {beatmap_id} already cached, skipping preload")
            return

        # Keep a strong reference to the task: a bare add_done_callback
        # does NOT prevent a fire-and-forget task from being GC'd.
        task = asyncio.create_task(
            self._fetch_beatmap_background(fetcher, redis, beatmap_id)
        )
        pending: set = self.__dict__.setdefault("_preload_tasks", set())
        pending.add(task)
        # Drop the reference once the task finishes (avoids a leak).
        task.add_done_callback(pending.discard)
    except Exception as e:
        # Preload failure must never break the play session.
        logger.warning(f"Failed to preload beatmap {beatmap_id}: {e}")
async def _fetch_beatmap_background(self, fetcher, redis, beatmap_id: int) -> None:
    """Background worker: warm the Redis cache with one raw beatmap file."""
    try:
        # Delegates caching to the fetcher's get_or_fetch path.
        await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
    except Exception as e:
        logger.debug(f"Failed to preload beatmap {beatmap_id}: {e}")
    else:
        logger.debug(
            f"Successfully preloaded beatmap {beatmap_id} for PP calculation"
        )