Optimization of score calculation
@@ -86,26 +86,34 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
            f"Error checking if beatmap {score.beatmap_id} is suspicious"
        )

    map = rosu.Beatmap(content=beatmap)
    mods = deepcopy(score.mods.copy())
    parse_enum_to_str(int(score.gamemode), mods)
    map.convert(score.gamemode.to_rosu(), mods)  # pyright: ignore[reportArgumentType]
    perf = rosu.Performance(
        mods=mods,
        lazer=True,
        accuracy=clamp(score.accuracy * 100, 0, 100),
        combo=score.max_combo,
        large_tick_hits=score.nlarge_tick_hit or 0,
        slider_end_hits=score.nslider_tail_hit or 0,
        small_tick_hits=score.nsmall_tick_hit or 0,
        n_geki=score.ngeki,
        n_katu=score.nkatu,
        n300=score.n300,
        n100=score.n100,
        n50=score.n50,
        misses=score.nmiss,
    )
    attrs = perf.calculate(map)
    # Run the compute-intensive work in a thread pool so it does not block the event loop
    import asyncio
    loop = asyncio.get_event_loop()

    def _calculate_pp_sync():
        map = rosu.Beatmap(content=beatmap)
        mods = deepcopy(score.mods.copy())
        parse_enum_to_str(int(score.gamemode), mods)
        map.convert(score.gamemode.to_rosu(), mods)  # pyright: ignore[reportArgumentType]
        perf = rosu.Performance(
            mods=mods,
            lazer=True,
            accuracy=clamp(score.accuracy * 100, 0, 100),
            combo=score.max_combo,
            large_tick_hits=score.nlarge_tick_hit or 0,
            slider_end_hits=score.nslider_tail_hit or 0,
            small_tick_hits=score.nsmall_tick_hit or 0,
            n_geki=score.ngeki,
            n_katu=score.nkatu,
            n300=score.n300,
            n100=score.n100,
            n50=score.n50,
            misses=score.nmiss,
        )
        return perf.calculate(map)

    # Execute the calculation in the thread pool
    attrs = await loop.run_in_executor(None, _calculate_pp_sync)
    pp = attrs.pp

    # mrekk bp1: 2048pp; ppy-sb top1 rxbp1: 2198pp
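Not part of the commit: a minimal sketch of how the thread offload could honor the max_concurrent_pp_calculations and enable_pp_calculation_threading settings introduced further down in this diff. The run_pp_calculation helper and the module-level executor are assumptions, not existing code.

import asyncio
from concurrent.futures import ThreadPoolExecutor

from app.config import settings

# Hypothetical bounded executor; the commit itself passes None (the default executor).
_pp_executor = ThreadPoolExecutor(max_workers=settings.max_concurrent_pp_calculations)


async def run_pp_calculation(calc):
    """Run a synchronous PP calculation without blocking the event loop."""
    if not settings.enable_pp_calculation_threading:
        return calc()  # fall back to inline execution
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(_pp_executor, calc)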
@@ -122,6 +130,132 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
    return pp


async def pre_fetch_and_calculate_pp(
    score: "Score",
    beatmap_id: int,
    session: AsyncSession,
    redis,
    fetcher
) -> float:
    """
    Optimized PP calculation: pre-fetch the beatmap file and use the existing cache
    """
    import asyncio

    from app.database.beatmap import BannedBeatmaps

    # Quick check whether the beatmap is banned
    if settings.suspicious_score_check:
        beatmap_banned = (
            await session.exec(
                select(exists()).where(
                    col(BannedBeatmaps.beatmap_id) == beatmap_id
                )
            )
        ).first()
        if beatmap_banned:
            return 0

    # Fetch the raw beatmap file asynchronously, reusing the existing Redis cache
    try:
        beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
    except Exception as e:
        logger.error(f"Failed to fetch beatmap {beatmap_id}: {e}")
        return 0

    # With the file fetched, also check whether the beatmap is suspicious
    if settings.suspicious_score_check:
        try:
            # Move the suspicious check into the thread pool as well
            def _check_suspicious():
                return is_suspicious_beatmap(beatmap_raw)

            loop = asyncio.get_event_loop()
            is_sus = await loop.run_in_executor(None, _check_suspicious)
            if is_sus:
                session.add(BannedBeatmaps(beatmap_id=beatmap_id))
                logger.warning(f"Beatmap {beatmap_id} is suspicious, banned")
                return 0
        except Exception:
            logger.exception(f"Error checking if beatmap {beatmap_id} is suspicious")

    # Delegate to the already-optimized PP calculation
    return await calculate_pp(score, beatmap_raw, session)


async def batch_calculate_pp(
    scores_data: list[tuple["Score", int]],
    session: AsyncSession,
    redis,
    fetcher
) -> list[float]:
    """
    Batch PP calculation, intended for recalculation or bulk-processing scenarios
    Args:
        scores_data: a list of (score, beatmap_id) tuples
    Returns:
        the corresponding list of PP values
    """
    import asyncio

    from app.database.beatmap import BannedBeatmaps

    if not scores_data:
        return []

    # Collect all unique beatmap_ids
    unique_beatmap_ids = list({beatmap_id for _, beatmap_id in scores_data})

    # Batch-check for banned beatmaps
    banned_beatmaps = set()
    if settings.suspicious_score_check:
        banned_results = await session.exec(
            select(BannedBeatmaps.beatmap_id).where(
                col(BannedBeatmaps.beatmap_id).in_(unique_beatmap_ids)
            )
        )
        banned_beatmaps = set(banned_results.all())

    # Fetch all required raw beatmap files concurrently
    async def fetch_beatmap_safe(beatmap_id: int) -> tuple[int, str | None]:
        if beatmap_id in banned_beatmaps:
            return beatmap_id, None
        try:
            content = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
            return beatmap_id, content
        except Exception as e:
            logger.error(f"Failed to fetch beatmap {beatmap_id}: {e}")
            return beatmap_id, None

    # Fetch all beatmap files concurrently
    fetch_tasks = [fetch_beatmap_safe(bid) for bid in unique_beatmap_ids]
    fetch_results = await asyncio.gather(*fetch_tasks, return_exceptions=True)

    # Build a beatmap_id -> content mapping
    beatmap_contents = {}
    for result in fetch_results:
        if isinstance(result, tuple):
            beatmap_id, content = result
            beatmap_contents[beatmap_id] = content

    # Calculate PP for each score
    pp_results = []
    for score, beatmap_id in scores_data:
        beatmap_content = beatmap_contents.get(beatmap_id)
        if beatmap_content is None:
            pp_results.append(0.0)
            continue

        try:
            pp = await calculate_pp(score, beatmap_content, session)
            pp_results.append(pp)
        except Exception as e:
            logger.error(f"Failed to calculate PP for score {score.id}: {e}")
            pp_results.append(0.0)

    return pp_results


# https://osu.ppy.sh/wiki/Gameplay/Score/Total_score
def calculate_level_to_score(n: int) -> float:
    if n <= 100:
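Illustrative only, not part of the diff: a hedged sketch of how batch_calculate_pp might be driven from a recalculation job. The recalculate_recent_scores helper and its query are assumptions; only Score, batch_calculate_pp, and the session/redis/fetcher objects come from this codebase.

from sqlmodel import select

from app.calculator import batch_calculate_pp
from app.database.score import Score


async def recalculate_recent_scores(session, redis, fetcher, limit: int = 500):
    """Hypothetical recalculation pass over the most recent scores."""
    scores = (await session.exec(select(Score).limit(limit))).all()
    scores_data = [(score, score.beatmap_id) for score in scores]

    pp_values = await batch_calculate_pp(scores_data, session, redis, fetcher)

    for (score, _beatmap_id), pp in zip(scores_data, pp_values):
        score.pp = pp
        session.add(score)
    await session.commit()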
@@ -137,6 +137,13 @@ class Settings(BaseSettings):
    enable_supporter_for_all_users: bool = False
    enable_all_beatmap_leaderboard: bool = False
    enable_all_beatmap_pp: bool = False
    # Performance tuning settings
    enable_beatmap_preload: bool = True
    beatmap_cache_expire_hours: int = 24
    max_concurrent_pp_calculations: int = 10
    enable_pp_calculation_threading: bool = True

    # Anti-cheat settings
    suspicious_score_check: bool = True
    seasonal_backgrounds: Annotated[list[str], BeforeValidator(_parse_list)] = []
    banned_name: list[str] = [
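Not part of the commit: assuming Settings is the pydantic BaseSettings class in app.config and that it uses the default environment mapping (no env_prefix), the new knobs could be overridden like this in a test or deployment script.

import os

# Hypothetical overrides; the variable names assume pydantic's default case-insensitive mapping.
os.environ["ENABLE_BEATMAP_PRELOAD"] = "false"      # disable startup preloading
os.environ["BEATMAP_CACHE_EXPIRE_HOURS"] = "48"     # keep raw .osu files cached longer
os.environ["MAX_CONCURRENT_PP_CALCULATIONS"] = "4"  # smaller PP thread budget

from app.config import Settings

settings = Settings()  # picks up the overridden values
assert settings.beatmap_cache_expire_hours == 48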
@@ -5,7 +5,6 @@ import math
from typing import TYPE_CHECKING, Any

from app.calculator import (
    calculate_pp,
    calculate_pp_weight,
    calculate_score_to_level,
    calculate_weighted_acc,
@@ -772,8 +771,10 @@ async def process_score(
        maximum_statistics=info.maximum_statistics,
    )
    if can_get_pp:
        beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
        pp = await calculate_pp(score, beatmap_raw, session)
        from app.calculator import pre_fetch_and_calculate_pp
        pp = await pre_fetch_and_calculate_pp(
            score, beatmap_id, session, redis, fetcher
        )
        score.pp = pp
    session.add(score)
    user_id = user.id
@@ -799,5 +800,5 @@ async def process_score(
    await session.refresh(score)
    await session.refresh(score_token)
    await session.refresh(user)
    await redis.publish("score:processed", score.id)
    await redis.publish("score:processed", str(score.id or 0))
    return score
@@ -37,8 +37,20 @@ class BeatmapRawFetcher(BaseFetcher):
    async def get_or_fetch_beatmap_raw(
        self, redis: redis.Redis, beatmap_id: int
    ) -> str:
        if await redis.exists(f"beatmap:{beatmap_id}:raw"):
            return await redis.get(f"beatmap:{beatmap_id}:raw")  # pyright: ignore[reportReturnType]
        from app.config import settings

        cache_key = f"beatmap:{beatmap_id}:raw"
        cache_expire = settings.beatmap_cache_expire_hours * 60 * 60

        # Check the cache
        if await redis.exists(cache_key):
            content = await redis.get(cache_key)
            if content:
                # Extend the cache lifetime
                await redis.expire(cache_key, cache_expire)
                return content  # pyright: ignore[reportReturnType]

        # Fetch and cache
        raw = await self.get_beatmap_raw(beatmap_id)
        await redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24)
        await redis.set(cache_key, raw, ex=cache_expire)
        return raw
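Not part of the commit: the EXISTS/GET/EXPIRE sequence above costs several round trips; assuming Redis 6.2+ and a redis-py version that exposes getex, the same sliding-expiration read could be a single call. This is a sketch, not the fetcher's actual API.

    async def get_or_fetch_beatmap_raw_sketch(self, redis, beatmap_id: int) -> str:
        """Sketch: sliding-expiration cache read using one GETEX call."""
        from app.config import settings

        cache_key = f"beatmap:{beatmap_id}:raw"
        cache_expire = settings.beatmap_cache_expire_hours * 60 * 60

        # GETEX returns the value and refreshes the TTL atomically (Redis >= 6.2).
        content = await redis.getex(cache_key, ex=cache_expire)
        if content:
            return content

        raw = await self.get_beatmap_raw(beatmap_id)
        await redis.set(cache_key, raw, ex=cache_expire)
        return raw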
@@ -32,8 +32,6 @@ class BeatmapsetFetcher(BaseFetcher):
            q="",
            s="leaderboard",
            sort=sort,  # type: ignore
            # Leave nsfw and m unset so they keep their default values;
            # they are then dropped when exclude_defaults=True
        )
        homepage_queries.append((query, {}))

@@ -37,6 +37,7 @@ from app.dependencies.fetcher import get_fetcher
from app.dependencies.storage import get_storage_service
from app.dependencies.user import get_client_user, get_current_user
from app.fetcher import Fetcher
from app.log import logger
from app.models.room import RoomCategory
from app.models.score import (
    GameMode,
@@ -95,6 +96,14 @@ async def submit_score(
    if not score:
        raise HTTPException(status_code=404, detail="Score not found")
    else:
        # Smart beatmap cache prefetch (runs asynchronously without blocking the main flow)
        try:
            from app.service.beatmap_cache_service import get_beatmap_cache_service
            cache_service = get_beatmap_cache_service(redis, fetcher)
            await cache_service.smart_preload_for_score(beatmap)
        except Exception as e:
            logger.debug(f"Beatmap preload failed for {beatmap}: {e}")

        try:
            db_beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap)
        except HTTPError:
app/service/beatmap_cache_service.py (new file, 174 lines)
@@ -0,0 +1,174 @@
"""
Beatmap cache preloading service
Caches popular beatmaps ahead of time to reduce fetch latency during score calculation
"""
from __future__ import annotations

import asyncio
from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING

from app.config import settings
from app.log import logger

from redis.asyncio import Redis
from sqlmodel import col, func, select
from sqlmodel.ext.asyncio.session import AsyncSession

if TYPE_CHECKING:
    from app.fetcher import Fetcher


class BeatmapCacheService:
    def __init__(self, redis: Redis, fetcher: "Fetcher"):
        self.redis = redis
        self.fetcher = fetcher
        self._preloading = False
        self._background_tasks: set = set()

    async def preload_popular_beatmaps(self, session: AsyncSession, limit: int = 100):
        """
        Preload popular beatmaps into the Redis cache
        """
        if self._preloading:
            logger.info("Beatmap preloading already in progress")
            return

        self._preloading = True
        try:
            logger.info(f"Starting preload of top {limit} popular beatmaps")

            # Find the most played beatmaps of the past 24 hours
            recent_time = datetime.now(UTC) - timedelta(hours=24)

            from app.database.score import Score

            popular_beatmaps = (
                await session.exec(
                    select(Score.beatmap_id, func.count().label("play_count"))
                    .where(col(Score.ended_at) >= recent_time)
                    .group_by(col(Score.beatmap_id))
                    .order_by(col("play_count").desc())
                    .limit(limit)
                )
            ).all()

            # Prefetch these beatmaps concurrently
            preload_tasks = []
            for beatmap_id, _ in popular_beatmaps:
                task = self._preload_single_beatmap(beatmap_id)
                preload_tasks.append(task)

            if preload_tasks:
                results = await asyncio.gather(*preload_tasks, return_exceptions=True)
                success_count = sum(1 for r in results if r is True)
                logger.info(
                    f"Preloaded {success_count}/{len(preload_tasks)} "
                    f"beatmaps successfully"
                )

        except Exception as e:
            logger.error(f"Error during beatmap preloading: {e}")
        finally:
            self._preloading = False

    async def _preload_single_beatmap(self, beatmap_id: int) -> bool:
        """
        Preload a single beatmap
        """
        try:
            cache_key = f"beatmap:{beatmap_id}:raw"
            if await self.redis.exists(cache_key):
                # Already cached; extend the expiration
                await self.redis.expire(cache_key, 60 * 60 * 24)
                return True

            # Fetch and cache the beatmap
            content = await self.fetcher.get_beatmap_raw(beatmap_id)
            await self.redis.set(cache_key, content, ex=60 * 60 * 24)
            return True

        except Exception as e:
            logger.debug(f"Failed to preload beatmap {beatmap_id}: {e}")
            return False

    async def smart_preload_for_score(self, beatmap_id: int):
        """
        Smart preload: warm the cache for a score that is about to be submitted
        """
        task = asyncio.create_task(self._preload_single_beatmap(beatmap_id))
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def get_cache_stats(self) -> dict:
        """
        Return cache statistics
        """
        try:
            keys = await self.redis.keys("beatmap:*:raw")
            total_size = 0

            for key in keys[:100]:  # limit how many keys are inspected to avoid performance issues
                try:
                    size = await self.redis.memory_usage(key)
                    if size:
                        total_size += size
                except Exception:
                    continue

            return {
                "cached_beatmaps": len(keys),
                "estimated_total_size_mb": (
                    round(total_size / 1024 / 1024, 2) if total_size > 0 else 0
                ),
                "preloading": self._preloading,
            }
        except Exception as e:
            logger.error(f"Error getting cache stats: {e}")
            return {"error": str(e)}

    async def cleanup_old_cache(self, max_age_hours: int = 48):
        """
        Clean up expired cache entries
        """
        try:
            logger.info(f"Cleaning up beatmap cache older than {max_age_hours} hours")
            # Redis evicts expired keys on its own; this mainly logs the current state
            keys = await self.redis.keys("beatmap:*:raw")
            logger.info(f"Current cache contains {len(keys)} beatmaps")
        except Exception as e:
            logger.error(f"Error during cache cleanup: {e}")


# Global cache service instance
_cache_service: BeatmapCacheService | None = None


def get_beatmap_cache_service(redis: Redis, fetcher: "Fetcher") -> BeatmapCacheService:
    """
    Return the beatmap cache service instance
    """
    global _cache_service
    if _cache_service is None:
        _cache_service = BeatmapCacheService(redis, fetcher)
    return _cache_service


async def schedule_preload_task(
    session: AsyncSession,
    redis: Redis,
    fetcher: "Fetcher"
):
    """
    Scheduled preload task
    """
    # Preloading is enabled by default unless explicitly disabled
    enable_preload = getattr(settings, "enable_beatmap_preload", True)
    if not enable_preload:
        return

    cache_service = get_beatmap_cache_service(redis, fetcher)
    try:
        await cache_service.preload_popular_beatmaps(session, limit=200)
    except Exception as e:
        logger.error(f"Scheduled preload task failed: {e}")
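Illustrative only: one way schedule_preload_task could be driven periodically at application startup. The preload_loop helper, the interval, and the session_factory name are assumptions, not part of this commit.

import asyncio

from app.service.beatmap_cache_service import schedule_preload_task


async def preload_loop(session_factory, redis, fetcher, interval_hours: int = 6):
    """Hypothetical background loop that refreshes the popular-beatmap cache."""
    while True:
        async with session_factory() as session:  # assumed async session factory
            await schedule_preload_task(session, redis, fetcher)
        await asyncio.sleep(interval_hours * 60 * 60)


# e.g. during startup (names hypothetical):
# asyncio.create_task(preload_loop(async_session_factory, redis, fetcher))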
@@ -222,6 +222,10 @@ class SpectatorHub(Hub[StoreClientState]):
                )
            )
        logger.info(f"[SpectatorHub] {client.user_id} began playing {state.beatmap_id}")

        # Pre-cache the beatmap file to speed up the upcoming PP calculation
        await self._preload_beatmap_for_pp_calculation(state.beatmap_id)

        await self.broadcast_group_call(
            self.group_id(user_id),
            "UserBeganPlaying",
@@ -446,3 +450,50 @@ class SpectatorHub(Hub[StoreClientState]):
        if (target_client := self.get_client_by_id(str(target_id))) is not None:
            await self.call_noblock(target_client, "UserEndedWatching", user_id)
        logger.info(f"[SpectatorHub] {user_id} ended watching {target_id}")

    async def _preload_beatmap_for_pp_calculation(self, beatmap_id: int) -> None:
        """
        Pre-cache the beatmap file to speed up PP calculation.
        When a player begins playing, the raw beatmap file is asynchronously preloaded into the Redis cache.
        """
        # Check whether beatmap preloading is enabled
        if not settings.enable_beatmap_preload:
            return

        try:
            # Obtain the fetcher and the Redis connection
            from app.dependencies.database import get_redis
            from app.dependencies.fetcher import get_fetcher

            fetcher = get_fetcher()
            redis = get_redis()

            # Skip beatmaps that are already cached to avoid downloading them again
            cache_key = f"beatmap:raw:{beatmap_id}"
            if await redis.exists(cache_key):
                logger.debug(f"Beatmap {beatmap_id} already cached, skipping preload")
                return

            # Pre-cache the beatmap file in the background; the task reference is meant to keep it from being garbage-collected
            task = asyncio.create_task(
                self._fetch_beatmap_background(fetcher, redis, beatmap_id)
            )
            # Clean up automatically once the task finishes to avoid leaking memory
            task.add_done_callback(lambda t: None)

        except Exception as e:
            # A failed preload must not affect normal gameplay
            logger.warning(f"Failed to preload beatmap {beatmap_id}: {e}")

    async def _fetch_beatmap_background(self, fetcher, redis, beatmap_id: int) -> None:
        """
        Fetch the beatmap file in the background
        """
        try:
            # Use the fetcher's get_or_fetch_beatmap_raw method to pre-cache
            await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
            logger.debug(
                f"Successfully preloaded beatmap {beatmap_id} for PP calculation"
            )
        except Exception as e:
            logger.debug(f"Failed to preload beatmap {beatmap_id}: {e}")
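Not part of the commit: the comment above intends to keep the task reference so the task is not garbage-collected, but the reference only lives in a local variable. A hedged sketch of the set-based pattern BeatmapCacheService already uses; spawn_preload is hypothetical, and the key shown matches the beatmap:{beatmap_id}:raw format used by BeatmapRawFetcher's cache.

import asyncio

# Hypothetical module-level registry mirroring BeatmapCacheService._background_tasks.
_preload_tasks: set[asyncio.Task] = set()


def spawn_preload(coro) -> asyncio.Task:
    """Create a background task and hold a strong reference until it finishes."""
    task = asyncio.create_task(coro)
    _preload_tasks.add(task)
    task.add_done_callback(_preload_tasks.discard)
    return task


# Usage inside _preload_beatmap_for_pp_calculation (sketch):
# spawn_preload(self._fetch_beatmap_background(fetcher, redis, beatmap_id))
# cache_key = f"beatmap:{beatmap_id}:raw"  # same key format as the fetcher cache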