From 05540d44d04c8490c1bb5e2377410fb314edf806 Mon Sep 17 00:00:00 2001
From: MingxuanGame
Date: Sat, 8 Nov 2025 19:42:47 +0000
Subject: [PATCH] feat(fetcher): optimize raw beatmap fetching to speed up
 recalculation

---
 app/database/beatmap.py    |  15 ++++-
 app/fetcher/beatmap_raw.py | 130 ++++++++++++++++++++++++++++++++-----
 tools/recalculate.py       |  61 +++++++++++------
 3 files changed, 167 insertions(+), 39 deletions(-)

diff --git a/app/database/beatmap.py b/app/database/beatmap.py
index 71a4df3..695e8a7 100644
--- a/app/database/beatmap.py
+++ b/app/database/beatmap.py
@@ -260,8 +260,19 @@ async def calculate_beatmap_attributes(
 
 
 async def clear_cached_beatmap_raws(redis: Redis, beatmaps: list[int] = []):
+    """Clear cached raw beatmap data using a non-blocking approach."""
     if beatmaps:
-        keys = [f"beatmap:{bid}:raw" for bid in beatmaps]
-        await redis.delete(*keys)
+        # Delete in batches so one call does not block Redis with too many keys
+        batch_size = 50
+        for i in range(0, len(beatmaps), batch_size):
+            batch = beatmaps[i : i + batch_size]
+            keys = [f"beatmap:{bid}:raw" for bid in batch]
+            # Prefer UNLINK over DEL (non-blocking, reclaims memory lazily)
+            try:
+                await redis.unlink(*keys)
+            except Exception:
+                # Fall back to DEL if UNLINK is not supported
+                await redis.delete(*keys)
         return
+
     await redis.delete("beatmap:*:raw")
diff --git a/app/fetcher/beatmap_raw.py b/app/fetcher/beatmap_raw.py
index 896e415..b358baa 100644
--- a/app/fetcher/beatmap_raw.py
+++ b/app/fetcher/beatmap_raw.py
@@ -1,9 +1,10 @@
+import asyncio
+
 from app.log import fetcher_logger
 
 from ._base import BaseFetcher
-from httpx import AsyncClient, HTTPError
-from httpx._models import Response
+from httpx import AsyncClient, HTTPError, Limits
 import redis.asyncio as redis
 
 urls = [
@@ -16,24 +17,117 @@ logger = fetcher_logger("BeatmapRawFetcher")
 
 
 class BeatmapRawFetcher(BaseFetcher):
-    async def get_beatmap_raw(self, beatmap_id: int) -> str:
-        for url in urls:
-            req_url = url.format(beatmap_id=beatmap_id)
-            logger.opt(colors=True).debug(f"get_beatmap_raw: {req_url}")
-            resp = await self._request(req_url)
-            if resp.status_code >= 400:
-                continue
-            if not resp.text:
-                continue
-            return resp.text
-        raise HTTPError("Failed to fetch beatmap")
+    def __init__(self, client_id: str = "", client_secret: str = "", **kwargs):
+        # BeatmapRawFetcher does not need OAuth; pass empty credentials to the parent class
+        super().__init__(client_id, client_secret, **kwargs)
+        # Shared HTTP client backed by a connection pool
+        self._client: AsyncClient | None = None
+        # Deduplication of concurrent requests for the same beatmap
+        self._pending_requests: dict[int, asyncio.Future[str]] = {}
+        self._request_lock = asyncio.Lock()
 
-    async def _request(self, url: str) -> Response:
-        async with AsyncClient(timeout=15) as client:
-            response = await client.get(
-                url,
+    async def _get_client(self) -> AsyncClient:
+        """Get or lazily create the shared HTTP client."""
+        if self._client is None:
+            # Configure connection pool limits
+            limits = Limits(
+                max_keepalive_connections=20,
+                max_connections=50,
+                keepalive_expiry=30.0,
             )
-        return response
+            self._client = AsyncClient(
+                timeout=10.0,  # 10-second timeout per request
+                limits=limits,
+                follow_redirects=True,
+            )
+        return self._client
+
+    async def close(self):
+        """Close the shared HTTP client."""
+        if self._client is not None:
+            await self._client.aclose()
+            self._client = None
+
+    async def get_beatmap_raw(self, beatmap_id: int) -> str:
+        future: asyncio.Future[str] | None = None
+
+        # Check whether a request for this beatmap is already in flight
+        async with self._request_lock:
+            if beatmap_id in self._pending_requests:
+                logger.debug(f"Beatmap {beatmap_id} request already in progress, waiting...")
+                future = self._pending_requests[beatmap_id]
+
+        # If a request is already in flight, wait for it
+        if future is not None:
+            try:
+                return await future
+            except Exception as e:
+                logger.warning(f"Waiting for beatmap {beatmap_id} failed: {e}")
+                # If waiting failed, fall through and issue our own request
+                future = None
+
+        # Create a new request Future
+        async with self._request_lock:
+            if beatmap_id in self._pending_requests:
+                # Double-check: another coroutine may have created one while we waited for the lock
+                future = self._pending_requests[beatmap_id]
+                if future is not None:
+                    try:
+                        return await future
+                    except Exception as e:
+                        logger.debug(f"Concurrent request for beatmap {beatmap_id} failed: {e}")
+                        # Fall through and create a new request
+
+            # Create a new Future for this beatmap
+            future = asyncio.get_running_loop().create_future()
+            self._pending_requests[beatmap_id] = future
+
+        try:
+            # Perform the actual request
+            result = await self._fetch_beatmap_raw(beatmap_id)
+            future.set_result(result)
+            return result
+        except Exception as e:
+            future.set_exception(e)
+            raise
+        finally:
+            # Remove the pending-request entry
+            async with self._request_lock:
+                self._pending_requests.pop(beatmap_id, None)
+
+    async def _fetch_beatmap_raw(self, beatmap_id: int) -> str:
+        client = await self._get_client()
+        last_error = None
+
+        for url_template in urls:
+            req_url = url_template.format(beatmap_id=beatmap_id)
+            try:
+                logger.opt(colors=True).debug(f"get_beatmap_raw: {req_url}")
+                resp = await client.get(req_url)
+
+                if resp.status_code >= 400:
+                    logger.warning(f"Beatmap {beatmap_id} from {req_url}: HTTP {resp.status_code}")
+                    last_error = HTTPError(f"HTTP {resp.status_code}")
+                    continue
+
+                if not resp.text:
+                    logger.warning(f"Beatmap {beatmap_id} from {req_url}: empty response")
+                    last_error = HTTPError("Empty response")
+                    continue
+
+                logger.debug(f"Successfully fetched beatmap {beatmap_id} from {req_url}")
+                return resp.text
+
+            except Exception as e:
+                logger.warning(f"Error fetching beatmap {beatmap_id} from {req_url}: {e}")
+                last_error = e
+                continue
+
+        # All URLs failed
+        error_msg = f"Failed to fetch beatmap {beatmap_id} from all sources"
+        if last_error:
+            raise HTTPError(error_msg) from last_error
+        raise HTTPError(error_msg)
 
     async def get_or_fetch_beatmap_raw(self, redis: redis.Redis, beatmap_id: int) -> str:
         from app.config import settings
diff --git a/tools/recalculate.py b/tools/recalculate.py
index 2fe42c5..9773c0d 100644
--- a/tools/recalculate.py
+++ b/tools/recalculate.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 import argparse
 import asyncio
 from collections.abc import Awaitable, Sequence
+import contextlib
 import csv
 from dataclasses import dataclass
 from datetime import UTC, datetime
@@ -41,7 +42,12 @@ warnings.filterwarnings("ignore")
 
 
 class BeatmapCacheManager:
-    """管理beatmap缓存,确保不超过指定数量"""
+    """Manage the beatmap cache and keep it below the configured size.
+
+    Optimizations:
+    1. Cleanup runs outside the lock to shorten the time the lock is held.
+    2. An LRU policy is used: beatmaps that are seen again move to the end of the list.
+    """
 
     def __init__(self, max_count: int, additional_count: int, redis: Redis):
         self.max_count = max_count
@@ -52,13 +58,18 @@ class BeatmapCacheManager:
         self.lock = asyncio.Lock()
 
     async def add_beatmap(self, beatmap_id: int) -> None:
-        """添加beatmap到缓存跟踪列表"""
+        """Add a beatmap to the cache tracking list (LRU policy)."""
         if self.max_count <= 0:
            # 不限制
            return
 
+        to_remove: list[int] = []
+
         async with self.lock:
-            # 如果已经存在,不重复添加
+            # If it already exists, move it to the end (most recently used)
             if beatmap_id in self.beatmap_id_set:
+                with contextlib.suppress(ValueError):
+                    self.beatmap_ids.remove(beatmap_id)
+                self.beatmap_ids.append(beatmap_id)
                 return
 
             self.beatmap_ids.append(beatmap_id)
@@ -69,24 +80,28 @@ class BeatmapCacheManager:
             if len(self.beatmap_ids) > threshold:
                 # 计算需要删除的数量
                 to_remove_count = max(1, self.additional_count)
-                await self._cleanup(to_remove_count)
+                # Collect the oldest beatmap ids to evict
+                to_remove = self.beatmap_ids[:to_remove_count]
+                self.beatmap_ids = self.beatmap_ids[to_remove_count:]
+                # Remove them from the tracking set
+                for bid in to_remove:
+                    self.beatmap_id_set.discard(bid)
 
-    async def _cleanup(self, count: int) -> None:
-        """清理最早的count个beatmap缓存"""
-        if count <= 0 or not self.beatmap_ids:
+        # Run the cleanup outside the lock so other coroutines are not blocked
+        if to_remove:
+            await self._cleanup_async(to_remove)
+
+    async def _cleanup_async(self, to_remove: list[int]) -> None:
+        """Asynchronously clean up beatmap caches (runs outside the lock)."""
+        if not to_remove:
             return
 
-        # 获取要删除的beatmap ids
-        to_remove = self.beatmap_ids[:count]
-        self.beatmap_ids = self.beatmap_ids[count:]
-
-        # 从set中移除
-        for bid in to_remove:
-            self.beatmap_id_set.discard(bid)
-
-        # 从Redis中删除缓存
-        await clear_cached_beatmap_raws(self.redis, to_remove)
-        logger.info(f"Cleaned up {len(to_remove)} beatmap caches (total: {len(self.beatmap_ids)})")
+        try:
+            # Delete the cached entries from Redis
+            await clear_cached_beatmap_raws(self.redis, to_remove)
+            logger.info(f"Cleaned up {len(to_remove)} beatmap caches (remaining: {len(self.beatmap_ids)})")
+        except Exception as e:
+            logger.warning(f"Failed to clean up {len(to_remove)} beatmap caches: {e}")
 
     def get_stats(self) -> dict:
         """获取统计信息"""
@@ -1071,7 +1086,15 @@ async def recalculate_beatmap_rating(
     while attempts > 0:
         try:
             ruleset = GameMode(beatmap.mode) if isinstance(beatmap.mode, int) else beatmap.mode
-            attributes = await calculate_beatmap_attributes(beatmap_id, ruleset, [], redis, fetcher)
+            # Overall 30-second timeout guard so a single stuck request cannot hang the run
+            try:
+                attributes = await asyncio.wait_for(
+                    calculate_beatmap_attributes(beatmap_id, ruleset, [], redis, fetcher), timeout=30.0
+                )
+            except TimeoutError:
+                logger.error(f"Timeout calculating attributes for beatmap {beatmap_id} after 30s")
+                return
+
             # 记录使用的beatmap
             if cache_manager:
                 await cache_manager.add_beatmap(beatmap_id)
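
Reviewer note, not part of the patch: the new get_beatmap_raw is essentially a "single-flight" fetch, where concurrent callers asking for the same beatmap share one in-flight HTTP request through a per-beatmap asyncio.Future guarded by a lock. Below is a minimal, self-contained sketch of that idea; the SingleFlight class, fetch callback, and beatmap id are illustrative names only, not part of the codebase.

import asyncio


class SingleFlight:
    """Share one in-flight fetch among concurrent callers asking for the same key."""

    def __init__(self) -> None:
        self._pending: dict[int, asyncio.Future[str]] = {}
        self._lock = asyncio.Lock()

    async def get(self, key: int, fetch) -> str:
        async with self._lock:
            future = self._pending.get(key)
            if future is not None:
                owner = False  # another caller is already fetching this key
            else:
                owner = True
                future = asyncio.get_running_loop().create_future()
                self._pending[key] = future
        if not owner:
            return await future  # wait for the owner's result instead of refetching
        try:
            result = await fetch(key)
            future.set_result(result)
            return result
        except Exception as exc:
            future.set_exception(exc)
            raise
        finally:
            async with self._lock:
                self._pending.pop(key, None)


async def main() -> None:
    flight = SingleFlight()
    calls = 0

    async def fetch(key: int) -> str:
        nonlocal calls
        calls += 1
        await asyncio.sleep(0.1)  # simulate network latency
        return f"raw .osu content for beatmap {key}"

    # Ten concurrent callers for the same beatmap id trigger exactly one fetch.
    results = await asyncio.gather(*(flight.get(1001, fetch) for _ in range(10)))
    assert calls == 1
    assert all(r == results[0] for r in results)


asyncio.run(main())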