feat(fetcher): optimize beatmap raw fetching to speed up recalculation

Author: MingxuanGame
Date: 2025-11-08 19:42:47 +00:00
parent 6753843158
commit 05540d44d0
3 changed files with 167 additions and 39 deletions

@@ -260,8 +260,19 @@ async def calculate_beatmap_attributes(
async def clear_cached_beatmap_raws(redis: Redis, beatmaps: list[int] = []):
    """Clear cached beatmap raw data, using a non-blocking approach."""
    if beatmaps:
        # Delete in batches so that removing too many keys at once does not block Redis
        batch_size = 50
        for i in range(0, len(beatmaps), batch_size):
            batch = beatmaps[i : i + batch_size]
            keys = [f"beatmap:{bid}:raw" for bid in batch]
            # Use UNLINK instead of DELETE: non-blocking and faster
            try:
                await redis.unlink(*keys)
            except Exception:
                # Fall back to DELETE if UNLINK is not supported
                await redis.delete(*keys)
        return
    await redis.delete("beatmap:*:raw")
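For reference, a minimal usage sketch of the batched clearing path above; the local Redis address and the beatmap ids are illustrative assumptions, not part of the commit:

# Usage sketch (assumptions: Redis reachable at localhost:6379, illustrative beatmap ids)
import asyncio

import redis.asyncio as redis


async def main() -> None:
    r = redis.Redis(host="localhost", port=6379)
    try:
        # Removes beatmap:<id>:raw keys in batches of 50 via UNLINK, falling back to DELETE
        await clear_cached_beatmap_raws(r, beatmaps=[53520, 129891, 1001682])
    finally:
        await r.aclose()  # redis-py 5.x; older versions use close()


asyncio.run(main())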

@@ -1,9 +1,10 @@
import asyncio

from app.log import fetcher_logger
from ._base import BaseFetcher
from httpx import AsyncClient, HTTPError, Limits
import redis.asyncio as redis

urls = [
@@ -16,24 +17,117 @@ logger = fetcher_logger("BeatmapRawFetcher")
class BeatmapRawFetcher(BaseFetcher):
    def __init__(self, client_id: str = "", client_secret: str = "", **kwargs):
        # BeatmapRawFetcher does not need OAuth; pass empty values to the parent class
        super().__init__(client_id, client_secret, **kwargs)
        # Shared HTTP client and connection pool
        self._client: AsyncClient | None = None
        # Deduplication of concurrent requests
        self._pending_requests: dict[int, asyncio.Future[str]] = {}
        self._request_lock = asyncio.Lock()

    async def _get_client(self) -> AsyncClient:
        """Get or create the shared HTTP client."""
        if self._client is None:
            # Configure connection pool limits
            limits = Limits(
                max_keepalive_connections=20,
                max_connections=50,
                keepalive_expiry=30.0,
            )
            self._client = AsyncClient(
                timeout=10.0,  # 10-second timeout per request
                limits=limits,
                follow_redirects=True,
            )
        return self._client

    async def close(self):
        """Close the HTTP client."""
        if self._client is not None:
            await self._client.aclose()
            self._client = None

    async def get_beatmap_raw(self, beatmap_id: int) -> str:
        future: asyncio.Future[str] | None = None

        # Check whether a request for this beatmap is already in flight
        async with self._request_lock:
            if beatmap_id in self._pending_requests:
                logger.debug(f"Beatmap {beatmap_id} request already in progress, waiting...")
                future = self._pending_requests[beatmap_id]

        # If a request is already in flight, wait for it
        if future is not None:
            try:
                return await future
            except Exception as e:
                logger.warning(f"Waiting for beatmap {beatmap_id} failed: {e}")
                # If waiting failed, fall through and issue our own request
                future = None

        # Create a new request Future
        async with self._request_lock:
            if beatmap_id in self._pending_requests:
                # Double-check: another coroutine may have created it while we waited for the lock
                future = self._pending_requests[beatmap_id]
            if future is not None:
                try:
                    return await future
                except Exception as e:
                    logger.debug(f"Concurrent request for beatmap {beatmap_id} failed: {e}")
                    # Fall through and create a new request
            # Create the new Future
            future = asyncio.get_event_loop().create_future()
            self._pending_requests[beatmap_id] = future

        try:
            # Perform the actual request
            result = await self._fetch_beatmap_raw(beatmap_id)
            future.set_result(result)
            return result
        except Exception as e:
            future.set_exception(e)
            raise
        finally:
            # Clean up
            async with self._request_lock:
                self._pending_requests.pop(beatmap_id, None)

    async def _fetch_beatmap_raw(self, beatmap_id: int) -> str:
        client = await self._get_client()
        last_error = None
        for url_template in urls:
            req_url = url_template.format(beatmap_id=beatmap_id)
            try:
                logger.opt(colors=True).debug(f"get_beatmap_raw: <y>{req_url}</y>")
                resp = await client.get(req_url)
                if resp.status_code >= 400:
                    logger.warning(f"Beatmap {beatmap_id} from {req_url}: HTTP {resp.status_code}")
                    last_error = HTTPError(f"HTTP {resp.status_code}")
                    continue
                if not resp.text:
                    logger.warning(f"Beatmap {beatmap_id} from {req_url}: empty response")
                    last_error = HTTPError("Empty response")
                    continue
                logger.debug(f"Successfully fetched beatmap {beatmap_id} from {req_url}")
                return resp.text
            except Exception as e:
                logger.warning(f"Error fetching beatmap {beatmap_id} from {req_url}: {e}")
                last_error = e
                continue
        # Every URL failed
        error_msg = f"Failed to fetch beatmap {beatmap_id} from all sources"
        if last_error:
            raise HTTPError(error_msg) from last_error
        raise HTTPError(error_msg)

    async def get_or_fetch_beatmap_raw(self, redis: redis.Redis, beatmap_id: int) -> str:
        from app.config import settings
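As a rough illustration of the request deduplication in get_beatmap_raw above, concurrent callers asking for the same beatmap share a single in-flight HTTP request; the beatmap id and the no-argument constructor call are illustrative assumptions:

# Sketch (assumption: beatmap id 129891 is only an example)
import asyncio


async def main() -> None:
    fetcher = BeatmapRawFetcher()
    try:
        # Ten concurrent callers; one HTTP request is issued, the others await the shared Future
        raws = await asyncio.gather(*(fetcher.get_beatmap_raw(129891) for _ in range(10)))
        assert all(raw == raws[0] for raw in raws)
    finally:
        # Release the shared AsyncClient and its connection pool
        await fetcher.close()


asyncio.run(main())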

@@ -3,6 +3,7 @@ from __future__ import annotations
import argparse
import asyncio
from collections.abc import Awaitable, Sequence
import contextlib
import csv
from dataclasses import dataclass
from datetime import UTC, datetime
@@ -41,7 +42,12 @@ warnings.filterwarnings("ignore")
class BeatmapCacheManager:
    """Manage the beatmap cache and make sure it never exceeds the configured size.

    Optimizations:
    1. The cleanup work is moved outside the lock to shorten the time the lock is held.
    2. An LRU policy is used: a beatmap that already exists is moved to the end of the list.
    """
    def __init__(self, max_count: int, additional_count: int, redis: Redis):
        self.max_count = max_count
@@ -52,13 +58,18 @@ class BeatmapCacheManager:
        self.lock = asyncio.Lock()

    async def add_beatmap(self, beatmap_id: int) -> None:
        """Add a beatmap to the cache tracking list (LRU policy)."""
        if self.max_count <= 0:  # no limit
            return

        to_remove: list[int] = []
        async with self.lock:
            # If it already exists, refresh its position (move it to the end = most recently used)
            if beatmap_id in self.beatmap_id_set:
                with contextlib.suppress(ValueError):
                    self.beatmap_ids.remove(beatmap_id)
                self.beatmap_ids.append(beatmap_id)
                return

            self.beatmap_ids.append(beatmap_id)
@@ -69,24 +80,28 @@ class BeatmapCacheManager:
            if len(self.beatmap_ids) > threshold:
                # Work out how many entries to remove
                to_remove_count = max(1, self.additional_count)
                # Collect the beatmap ids to remove (the oldest ones)
                to_remove = self.beatmap_ids[:to_remove_count]
                self.beatmap_ids = self.beatmap_ids[to_remove_count:]
                # Drop them from the set as well
                for bid in to_remove:
                    self.beatmap_id_set.discard(bid)

        # Run the cleanup outside the lock (so other coroutines are not blocked)
        if to_remove:
            await self._cleanup_async(to_remove)

    async def _cleanup_async(self, to_remove: list[int]) -> None:
        """Clean up beatmap caches asynchronously (runs outside the lock)."""
        if not to_remove:
            return
        try:
            # Delete the cached entries from Redis
            await clear_cached_beatmap_raws(self.redis, to_remove)
            logger.info(f"Cleaned up {len(to_remove)} beatmap caches (remaining: {len(self.beatmap_ids)})")
        except Exception as e:
            logger.warning(f"Failed to cleanup {len(to_remove)} beatmap caches: {e}")
    def get_stats(self) -> dict:
        """Get statistics."""
@@ -1071,7 +1086,15 @@ async def recalculate_beatmap_rating(
    while attempts > 0:
        try:
            ruleset = GameMode(beatmap.mode) if isinstance(beatmap.mode, int) else beatmap.mode
            # Overall timeout guard (30 seconds) so a single request cannot hang forever
            try:
                attributes = await asyncio.wait_for(
                    calculate_beatmap_attributes(beatmap_id, ruleset, [], redis, fetcher), timeout=30.0
                )
            except TimeoutError:
                logger.error(f"Timeout calculating attributes for beatmap {beatmap_id} after 30s")
                return

            # Record the beatmap that was used
            if cache_manager:
                await cache_manager.add_beatmap(beatmap_id)
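The guard is the standard asyncio.wait_for pattern; a standalone sketch of the same idea, where slow_calculation is a made-up stand-in for calculate_beatmap_attributes:

# Standalone sketch of the 30-second timeout guard (slow_calculation is hypothetical)
import asyncio


async def slow_calculation() -> str:
    await asyncio.sleep(60)  # simulates a request that hangs
    return "attributes"


async def main() -> None:
    try:
        result = await asyncio.wait_for(slow_calculation(), timeout=30.0)
    except TimeoutError:  # on Python 3.11+ asyncio.TimeoutError is an alias of TimeoutError
        print("calculation timed out after 30s")
        return
    print(result)


asyncio.run(main())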