feat(fetcher): optimize the process of getting beatmap raw to boost recalculate
This commit is contained in:
@@ -260,8 +260,19 @@ async def calculate_beatmap_attributes(
|
|||||||
|
|
||||||
|
|
||||||
async def clear_cached_beatmap_raws(redis: Redis, beatmaps: list[int] = []):
|
async def clear_cached_beatmap_raws(redis: Redis, beatmaps: list[int] = []):
|
||||||
|
"""清理缓存的 beatmap 原始数据,使用非阻塞方式"""
|
||||||
if beatmaps:
|
if beatmaps:
|
||||||
keys = [f"beatmap:{bid}:raw" for bid in beatmaps]
|
# 分批删除,避免一次删除太多 key 导致阻塞
|
||||||
await redis.delete(*keys)
|
batch_size = 50
|
||||||
|
for i in range(0, len(beatmaps), batch_size):
|
||||||
|
batch = beatmaps[i : i + batch_size]
|
||||||
|
keys = [f"beatmap:{bid}:raw" for bid in batch]
|
||||||
|
# 使用 unlink 而不是 delete(非阻塞,更快)
|
||||||
|
try:
|
||||||
|
await redis.unlink(*keys)
|
||||||
|
except Exception:
|
||||||
|
# 如果 unlink 不支持,回退到 delete
|
||||||
|
await redis.delete(*keys)
|
||||||
return
|
return
|
||||||
|
|
||||||
await redis.delete("beatmap:*:raw")
|
await redis.delete("beatmap:*:raw")
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
from app.log import fetcher_logger
|
from app.log import fetcher_logger
|
||||||
|
|
||||||
from ._base import BaseFetcher
|
from ._base import BaseFetcher
|
||||||
|
|
||||||
from httpx import AsyncClient, HTTPError
|
from httpx import AsyncClient, HTTPError, Limits
|
||||||
from httpx._models import Response
|
|
||||||
import redis.asyncio as redis
|
import redis.asyncio as redis
|
||||||
|
|
||||||
urls = [
|
urls = [
|
||||||
@@ -16,24 +17,117 @@ logger = fetcher_logger("BeatmapRawFetcher")
|
|||||||
|
|
||||||
|
|
||||||
class BeatmapRawFetcher(BaseFetcher):
|
class BeatmapRawFetcher(BaseFetcher):
|
||||||
async def get_beatmap_raw(self, beatmap_id: int) -> str:
|
def __init__(self, client_id: str = "", client_secret: str = "", **kwargs):
|
||||||
for url in urls:
|
# BeatmapRawFetcher 不需要 OAuth,传递空值给父类
|
||||||
req_url = url.format(beatmap_id=beatmap_id)
|
super().__init__(client_id, client_secret, **kwargs)
|
||||||
logger.opt(colors=True).debug(f"get_beatmap_raw: <y>{req_url}</y>")
|
# 使用共享的 HTTP 客户端和连接池
|
||||||
resp = await self._request(req_url)
|
self._client: AsyncClient | None = None
|
||||||
if resp.status_code >= 400:
|
# 用于并发请求去重
|
||||||
continue
|
self._pending_requests: dict[int, asyncio.Future[str]] = {}
|
||||||
if not resp.text:
|
self._request_lock = asyncio.Lock()
|
||||||
continue
|
|
||||||
return resp.text
|
|
||||||
raise HTTPError("Failed to fetch beatmap")
|
|
||||||
|
|
||||||
async def _request(self, url: str) -> Response:
|
async def _get_client(self) -> AsyncClient:
|
||||||
async with AsyncClient(timeout=15) as client:
|
"""获取或创建共享的 HTTP 客户端"""
|
||||||
response = await client.get(
|
if self._client is None:
|
||||||
url,
|
# 配置连接池限制
|
||||||
|
limits = Limits(
|
||||||
|
max_keepalive_connections=20,
|
||||||
|
max_connections=50,
|
||||||
|
keepalive_expiry=30.0,
|
||||||
)
|
)
|
||||||
return response
|
self._client = AsyncClient(
|
||||||
|
timeout=10.0, # 单个请求超时 10 秒
|
||||||
|
limits=limits,
|
||||||
|
follow_redirects=True,
|
||||||
|
)
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""关闭 HTTP 客户端"""
|
||||||
|
if self._client is not None:
|
||||||
|
await self._client.aclose()
|
||||||
|
self._client = None
|
||||||
|
|
||||||
|
async def get_beatmap_raw(self, beatmap_id: int) -> str:
|
||||||
|
future: asyncio.Future[str] | None = None
|
||||||
|
|
||||||
|
# 检查是否已有正在进行的请求
|
||||||
|
async with self._request_lock:
|
||||||
|
if beatmap_id in self._pending_requests:
|
||||||
|
logger.debug(f"Beatmap {beatmap_id} request already in progress, waiting...")
|
||||||
|
future = self._pending_requests[beatmap_id]
|
||||||
|
|
||||||
|
# 如果有正在进行的请求,等待它
|
||||||
|
if future is not None:
|
||||||
|
try:
|
||||||
|
return await future
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Waiting for beatmap {beatmap_id} failed: {e}")
|
||||||
|
# 如果等待失败,继续自己发起请求
|
||||||
|
future = None
|
||||||
|
|
||||||
|
# 创建新的请求 Future
|
||||||
|
async with self._request_lock:
|
||||||
|
if beatmap_id in self._pending_requests:
|
||||||
|
# 双重检查,可能在等待锁时已经有其他协程创建了
|
||||||
|
future = self._pending_requests[beatmap_id]
|
||||||
|
if future is not None:
|
||||||
|
try:
|
||||||
|
return await future
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Concurrent request for beatmap {beatmap_id} failed: {e}")
|
||||||
|
# 继续创建新请求
|
||||||
|
|
||||||
|
# 创建新的 Future
|
||||||
|
future = asyncio.get_event_loop().create_future()
|
||||||
|
self._pending_requests[beatmap_id] = future
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 实际执行请求
|
||||||
|
result = await self._fetch_beatmap_raw(beatmap_id)
|
||||||
|
future.set_result(result)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
future.set_exception(e)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
# 清理
|
||||||
|
async with self._request_lock:
|
||||||
|
self._pending_requests.pop(beatmap_id, None)
|
||||||
|
|
||||||
|
async def _fetch_beatmap_raw(self, beatmap_id: int) -> str:
|
||||||
|
client = await self._get_client()
|
||||||
|
last_error = None
|
||||||
|
|
||||||
|
for url_template in urls:
|
||||||
|
req_url = url_template.format(beatmap_id=beatmap_id)
|
||||||
|
try:
|
||||||
|
logger.opt(colors=True).debug(f"get_beatmap_raw: <y>{req_url}</y>")
|
||||||
|
resp = await client.get(req_url)
|
||||||
|
|
||||||
|
if resp.status_code >= 400:
|
||||||
|
logger.warning(f"Beatmap {beatmap_id} from {req_url}: HTTP {resp.status_code}")
|
||||||
|
last_error = HTTPError(f"HTTP {resp.status_code}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not resp.text:
|
||||||
|
logger.warning(f"Beatmap {beatmap_id} from {req_url}: empty response")
|
||||||
|
last_error = HTTPError("Empty response")
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.debug(f"Successfully fetched beatmap {beatmap_id} from {req_url}")
|
||||||
|
return resp.text
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error fetching beatmap {beatmap_id} from {req_url}: {e}")
|
||||||
|
last_error = e
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 所有 URL 都失败了
|
||||||
|
error_msg = f"Failed to fetch beatmap {beatmap_id} from all sources"
|
||||||
|
if last_error:
|
||||||
|
raise HTTPError(error_msg) from last_error
|
||||||
|
raise HTTPError(error_msg)
|
||||||
|
|
||||||
async def get_or_fetch_beatmap_raw(self, redis: redis.Redis, beatmap_id: int) -> str:
|
async def get_or_fetch_beatmap_raw(self, redis: redis.Redis, beatmap_id: int) -> str:
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Awaitable, Sequence
|
from collections.abc import Awaitable, Sequence
|
||||||
|
import contextlib
|
||||||
import csv
|
import csv
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
@@ -41,7 +42,12 @@ warnings.filterwarnings("ignore")
|
|||||||
|
|
||||||
|
|
||||||
class BeatmapCacheManager:
|
class BeatmapCacheManager:
|
||||||
"""管理beatmap缓存,确保不超过指定数量"""
|
"""管理beatmap缓存,确保不超过指定数量
|
||||||
|
|
||||||
|
优化:
|
||||||
|
1. 将清理操作移到锁外,减少持锁时间
|
||||||
|
2. 使用 LRU 策略,已存在的 beatmap 移到最后
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, max_count: int, additional_count: int, redis: Redis):
|
def __init__(self, max_count: int, additional_count: int, redis: Redis):
|
||||||
self.max_count = max_count
|
self.max_count = max_count
|
||||||
@@ -52,13 +58,18 @@ class BeatmapCacheManager:
|
|||||||
self.lock = asyncio.Lock()
|
self.lock = asyncio.Lock()
|
||||||
|
|
||||||
async def add_beatmap(self, beatmap_id: int) -> None:
|
async def add_beatmap(self, beatmap_id: int) -> None:
|
||||||
"""添加beatmap到缓存跟踪列表"""
|
"""添加beatmap到缓存跟踪列表(LRU策略)"""
|
||||||
if self.max_count <= 0: # 不限制
|
if self.max_count <= 0: # 不限制
|
||||||
return
|
return
|
||||||
|
|
||||||
|
to_remove: list[int] = []
|
||||||
|
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
# 如果已经存在,不重复添加
|
# 如果已经存在,更新其位置(移到最后,表示最近使用)
|
||||||
if beatmap_id in self.beatmap_id_set:
|
if beatmap_id in self.beatmap_id_set:
|
||||||
|
with contextlib.suppress(ValueError):
|
||||||
|
self.beatmap_ids.remove(beatmap_id)
|
||||||
|
self.beatmap_ids.append(beatmap_id)
|
||||||
return
|
return
|
||||||
|
|
||||||
self.beatmap_ids.append(beatmap_id)
|
self.beatmap_ids.append(beatmap_id)
|
||||||
@@ -69,24 +80,28 @@ class BeatmapCacheManager:
|
|||||||
if len(self.beatmap_ids) > threshold:
|
if len(self.beatmap_ids) > threshold:
|
||||||
# 计算需要删除的数量
|
# 计算需要删除的数量
|
||||||
to_remove_count = max(1, self.additional_count)
|
to_remove_count = max(1, self.additional_count)
|
||||||
await self._cleanup(to_remove_count)
|
# 获取要删除的 beatmap ids(最旧的)
|
||||||
|
to_remove = self.beatmap_ids[:to_remove_count]
|
||||||
|
self.beatmap_ids = self.beatmap_ids[to_remove_count:]
|
||||||
|
# 从 set 中移除
|
||||||
|
for bid in to_remove:
|
||||||
|
self.beatmap_id_set.discard(bid)
|
||||||
|
|
||||||
async def _cleanup(self, count: int) -> None:
|
# 在锁外执行清理(避免阻塞其他协程)
|
||||||
"""清理最早的count个beatmap缓存"""
|
if to_remove:
|
||||||
if count <= 0 or not self.beatmap_ids:
|
await self._cleanup_async(to_remove)
|
||||||
|
|
||||||
|
async def _cleanup_async(self, to_remove: list[int]) -> None:
|
||||||
|
"""异步清理 beatmap 缓存(在锁外执行)"""
|
||||||
|
if not to_remove:
|
||||||
return
|
return
|
||||||
|
|
||||||
# 获取要删除的beatmap ids
|
try:
|
||||||
to_remove = self.beatmap_ids[:count]
|
# 从 Redis 中删除缓存
|
||||||
self.beatmap_ids = self.beatmap_ids[count:]
|
await clear_cached_beatmap_raws(self.redis, to_remove)
|
||||||
|
logger.info(f"Cleaned up {len(to_remove)} beatmap caches (remaining: {len(self.beatmap_ids)})")
|
||||||
# 从set中移除
|
except Exception as e:
|
||||||
for bid in to_remove:
|
logger.warning(f"Failed to cleanup {len(to_remove)} beatmap caches: {e}")
|
||||||
self.beatmap_id_set.discard(bid)
|
|
||||||
|
|
||||||
# 从Redis中删除缓存
|
|
||||||
await clear_cached_beatmap_raws(self.redis, to_remove)
|
|
||||||
logger.info(f"Cleaned up {len(to_remove)} beatmap caches (total: {len(self.beatmap_ids)})")
|
|
||||||
|
|
||||||
def get_stats(self) -> dict:
|
def get_stats(self) -> dict:
|
||||||
"""获取统计信息"""
|
"""获取统计信息"""
|
||||||
@@ -1071,7 +1086,15 @@ async def recalculate_beatmap_rating(
|
|||||||
while attempts > 0:
|
while attempts > 0:
|
||||||
try:
|
try:
|
||||||
ruleset = GameMode(beatmap.mode) if isinstance(beatmap.mode, int) else beatmap.mode
|
ruleset = GameMode(beatmap.mode) if isinstance(beatmap.mode, int) else beatmap.mode
|
||||||
attributes = await calculate_beatmap_attributes(beatmap_id, ruleset, [], redis, fetcher)
|
# 添加整体超时保护(30秒),防止单个请求卡死
|
||||||
|
try:
|
||||||
|
attributes = await asyncio.wait_for(
|
||||||
|
calculate_beatmap_attributes(beatmap_id, ruleset, [], redis, fetcher), timeout=30.0
|
||||||
|
)
|
||||||
|
except TimeoutError:
|
||||||
|
logger.error(f"Timeout calculating attributes for beatmap {beatmap_id} after 30s")
|
||||||
|
return
|
||||||
|
|
||||||
# 记录使用的beatmap
|
# 记录使用的beatmap
|
||||||
if cache_manager:
|
if cache_manager:
|
||||||
await cache_manager.add_beatmap(beatmap_id)
|
await cache_manager.add_beatmap(beatmap_id)
|
||||||
|
|||||||
Reference in New Issue
Block a user