add search redis

咕谷酱
2025-08-18 16:20:29 +08:00
parent a246393ff7
commit 71c961cafd
9 changed files with 542 additions and 12 deletions

View File

@@ -315,3 +315,5 @@ class BeatmapsetResp(BeatmapsetBase):
class SearchBeatmapsetsResp(SQLModel):
    beatmapsets: list[BeatmapsetResp]
    total: int
    cursor: dict[str, int | float] | None = None
    cursor_string: str | None = None
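For orientation, a minimal sketch (not part of the commit) of how the two new fields relate: cursor_string is the base64-encoded compact-JSON form of cursor, matching BeatmapsetFetcher._encode_cursor in the fetcher below. The values are illustrative.

import base64
import json

# Hypothetical cursor values, just to show the encoding
cursor = {"approved_date": 1723958429, "id": 123456}
cursor_string = base64.b64encode(
    json.dumps(cursor, separators=(",", ":")).encode()
).decode()
# Clients may echo cursor_string back instead of separate cursor[...] params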

View File

@@ -1,14 +1,83 @@
from __future__ import annotations

import asyncio
import base64
import hashlib
import json

import redis.asyncio as redis

from app.database.beatmapset import BeatmapsetResp, SearchBeatmapsetsResp
from app.helpers.rate_limiter import osu_api_rate_limiter
from app.log import logger
from app.models.beatmap import SearchQueryModel
from app.models.model import Cursor
from ._base import BaseFetcher

class BeatmapsetFetcher(BaseFetcher):
    @staticmethod
    def _get_homepage_queries() -> list[tuple[SearchQueryModel, Cursor]]:
        """Build the list of queries to pre-cache for the homepage."""
        # Common homepage query combinations
        homepage_queries = []
        # Primary sort orders
        sorts = ["ranked_desc", "updated_desc", "favourites_desc", "plays_desc"]
        for sort in sorts:
            # First page only
            query = SearchQueryModel(
                q="",
                s="leaderboard",
                sort=sort,  # type: ignore
                nsfw=False,
                m=0,
            )
            homepage_queries.append((query, {}))
        return homepage_queries

    async def request_api(self, url: str, method: str = "GET", **kwargs) -> dict:
        """Override the base method to apply rate limiting."""
        # Acquire a rate-limit permit before sending the request
        await osu_api_rate_limiter.acquire()
        # Delegate to the base class request method
        return await super().request_api(url, method, **kwargs)

    @staticmethod
    def _generate_cache_key(query: SearchQueryModel, cursor: Cursor) -> str:
        """Generate the cache key for a search."""
        # Build a dict containing the query parameters and the cursor
        cache_data = {
            **query.model_dump(
                exclude_none=True, exclude_unset=True, exclude_defaults=True
            ),
            "cursor": cursor,
        }
        # Serialize to JSON and take the MD5 hash
        cache_json = json.dumps(cache_data, sort_keys=True, separators=(",", ":"))
        cache_hash = hashlib.md5(cache_json.encode()).hexdigest()
        return f"beatmapset:search:{cache_hash}"

    @staticmethod
    def _encode_cursor(cursor_dict: dict[str, int | float]) -> str:
        """Encode a cursor dict as a base64 string."""
        cursor_json = json.dumps(cursor_dict, separators=(",", ":"))
        return base64.b64encode(cursor_json.encode()).decode()

    @staticmethod
    def _decode_cursor(cursor_string: str) -> dict[str, int | float]:
        """Decode a base64 string back into a cursor dict."""
        try:
            cursor_json = base64.b64decode(cursor_string).decode()
            return json.loads(cursor_json)
        except Exception:
            return {}

    async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetResp:
        logger.opt(colors=True).debug(
            f"<blue>[BeatmapsetFetcher]</blue> get_beatmapset: <y>{beatmap_set_id}</y>"
@@ -21,21 +90,236 @@ class BeatmapsetFetcher(BaseFetcher):
        )

    async def search_beatmapset(
-        self, query: SearchQueryModel, cursor: Cursor
+        self, query: SearchQueryModel, cursor: Cursor, redis_client: redis.Redis
    ) -> SearchBeatmapsetsResp:
        logger.opt(colors=True).debug(
            f"<blue>[BeatmapsetFetcher]</blue> search_beatmapset: <y>{query}</y>"
        )
        # Generate the cache key
        cache_key = self._generate_cache_key(query, cursor)
        # Try to serve the result from the cache
        cached_result = await redis_client.get(cache_key)
        if cached_result:
            logger.opt(colors=True).debug(
                f"<green>[BeatmapsetFetcher]</green> Cache hit for key: <y>{cache_key}</y>"
            )
            try:
                cached_data = json.loads(cached_result)
                return SearchBeatmapsetsResp.model_validate(cached_data)
            except Exception as e:
                logger.opt(colors=True).warning(
                    f"<yellow>[BeatmapsetFetcher]</yellow> Cache data invalid, fetching from API: {e}"
                )
        # Cache miss: fetch from the API
        logger.opt(colors=True).debug(
            "<blue>[BeatmapsetFetcher]</blue> Cache miss, fetching from API"
        )
        params = query.model_dump(
            exclude_none=True, exclude_unset=True, exclude_defaults=True
        )
-        for k, v in cursor.items():
-            params[f"cursor[{k}]"] = v
-        resp = SearchBeatmapsetsResp.model_validate(
-            await self.request_api(
-                "https://osu.ppy.sh/api/v2/beatmapsets/search",
-                params=params,
-            )
-        )
+        if query.cursor_string:
+            params["cursor_string"] = query.cursor_string
+        else:
+            for k, v in cursor.items():
+                params[f"cursor[{k}]"] = v
+        api_response = await self.request_api(
+            "https://osu.ppy.sh/api/v2/beatmapsets/search",
+            params=params,
+        )
        # Handle the cursor info in the response
        if api_response.get("cursor"):
            cursor_dict = api_response["cursor"]
            api_response["cursor_string"] = self._encode_cursor(cursor_dict)
        # Cache the result for 15 minutes
        cache_ttl = 15 * 60  # 15 minutes
        await redis_client.set(
            cache_key,
            json.dumps(api_response, separators=(",", ":")),
            ex=cache_ttl,
        )
        logger.opt(colors=True).debug(
            f"<green>[BeatmapsetFetcher]</green> Cached result for key: "
            f"<y>{cache_key}</y> (TTL: {cache_ttl}s)"
        )
        resp = SearchBeatmapsetsResp.model_validate(api_response)
        # Smart prefetch: only prefetch when the user is explicitly searching,
        # i.e. there is a search term, a non-default status, or a cursor, so
        # plain homepage browsing does not trigger extra API requests
        if (api_response.get("cursor") and
                (query.q or query.s != "leaderboard" or cursor)):
            # Prefetch just the next page, and only after a delay rather than
            # scheduling the task immediately
            async def delayed_prefetch():
                await asyncio.sleep(3.0)  # 3-second delay
                await self.prefetch_next_pages(
                    query, api_response["cursor"], redis_client, pages=1
                )

            # Create the delayed prefetch task
            task = asyncio.create_task(delayed_prefetch())
            # Keep a reference in a background-task set so it is not garbage-collected
            if not hasattr(self, "_background_tasks"):
                self._background_tasks = set()
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)
        return resp

    async def prefetch_next_pages(
        self, query: SearchQueryModel, current_cursor: Cursor,
        redis_client: redis.Redis, pages: int = 3
    ) -> None:
        """Prefetch the next few pages of results."""
        if not current_cursor:
            return
        try:
            cursor = current_cursor.copy()
            for page in range(1, pages + 1):
                # Request the next page using the current cursor
                next_query = query.model_copy()
                logger.opt(colors=True).debug(
                    f"<cyan>[BeatmapsetFetcher]</cyan> Prefetching page {page + 1}"
                )
                # Generate the cache key for the next page
                next_cache_key = self._generate_cache_key(next_query, cursor)
                # Skip the request if it is already cached
                if await redis_client.exists(next_cache_key):
                    logger.opt(colors=True).debug(
                        f"<cyan>[BeatmapsetFetcher]</cyan> Page {page + 1} already cached"
                    )
                    # Try to read the cursor from the cached page and keep prefetching
                    cached_data = await redis_client.get(next_cache_key)
                    if cached_data:
                        try:
                            data = json.loads(cached_data)
                            if data.get("cursor"):
                                cursor = data["cursor"]
                                continue
                        except Exception:
                            pass
                    break
                # Delay between prefetched pages to avoid request bursts
                if page > 1:
                    await asyncio.sleep(1.5)  # 1.5-second delay
                # Fetch the next page
                params = next_query.model_dump(
                    exclude_none=True, exclude_unset=True, exclude_defaults=True
                )
                for k, v in cursor.items():
                    params[f"cursor[{k}]"] = v
                api_response = await self.request_api(
                    "https://osu.ppy.sh/api/v2/beatmapsets/search",
                    params=params,
                )
                # Handle the cursor info in the response
                if api_response.get("cursor"):
                    cursor_dict = api_response["cursor"]
                    api_response["cursor_string"] = self._encode_cursor(cursor_dict)
                    cursor = cursor_dict  # advance the cursor for the next page
                else:
                    # No more pages
                    break
                # Cache the result with a shorter TTL for prefetched pages
                prefetch_ttl = 10 * 60  # 10 minutes
                await redis_client.set(
                    next_cache_key,
                    json.dumps(api_response, separators=(",", ":")),
                    ex=prefetch_ttl,
                )
                logger.opt(colors=True).debug(
                    f"<cyan>[BeatmapsetFetcher]</cyan> Prefetched page {page + 1} "
                    f"(TTL: {prefetch_ttl}s)"
                )
        except Exception as e:
            logger.opt(colors=True).warning(
                f"<yellow>[BeatmapsetFetcher]</yellow> Prefetch failed: {e}"
            )

    async def warmup_homepage_cache(self, redis_client: redis.Redis) -> None:
        """Warm up the homepage cache."""
        homepage_queries = self._get_homepage_queries()
        logger.opt(colors=True).info(
            f"<magenta>[BeatmapsetFetcher]</magenta> Starting homepage cache warmup "
            f"({len(homepage_queries)} queries)"
        )
        for i, (query, cursor) in enumerate(homepage_queries):
            try:
                # Delay between requests to avoid bursts
                if i > 0:
                    await asyncio.sleep(2.0)  # 2-second delay
                cache_key = self._generate_cache_key(query, cursor)
                # Skip if already cached
                if await redis_client.exists(cache_key):
                    logger.opt(colors=True).debug(
                        f"<magenta>[BeatmapsetFetcher]</magenta> "
                        f"Query {query.sort} already cached"
                    )
                    continue
                # Fetch and cache
                params = query.model_dump(
                    exclude_none=True, exclude_unset=True, exclude_defaults=True
                )
                api_response = await self.request_api(
                    "https://osu.ppy.sh/api/v2/beatmapsets/search",
                    params=params,
                )
                # Handle the cursor info in the response
                if api_response.get("cursor"):
                    cursor_dict = api_response["cursor"]
                    api_response["cursor_string"] = self._encode_cursor(cursor_dict)
                # Cache the result
                cache_ttl = 20 * 60  # 20 minutes
                await redis_client.set(
                    cache_key,
                    json.dumps(api_response, separators=(",", ":")),
                    ex=cache_ttl,
                )
                logger.opt(colors=True).info(
                    f"<magenta>[BeatmapsetFetcher]</magenta> "
                    f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)"
                )
                # Prefetch the first two follow-up pages (also rate-limited)
                if api_response.get("cursor"):
                    await self.prefetch_next_pages(
                        query, api_response["cursor"], redis_client, pages=2
                    )
            except Exception as e:
                logger.opt(colors=True).error(
                    f"<red>[BeatmapsetFetcher]</red> "
                    f"Failed to warmup cache for {query.sort}: {e}"
                )
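A standalone sketch of the cache-key scheme above: because json.dumps is called with sort_keys=True, the same query and cursor always map to the same Redis key regardless of dict insertion order. The parameter values here are made up.

import hashlib
import json

def cache_key(query_params: dict, cursor: dict) -> str:
    # Mirrors BeatmapsetFetcher._generate_cache_key
    payload = {**query_params, "cursor": cursor}
    blob = json.dumps(payload, sort_keys=True, separators=(",", ":"))
    return f"beatmapset:search:{hashlib.md5(blob.encode()).hexdigest()}"

assert cache_key({"q": "candy", "m": 0}, {"id": 1}) == cache_key(
    {"m": 0, "q": "candy"}, {"id": 1}
)  # insertion order does not affect the hash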

app/helpers/rate_limiter.py Normal file
View File

@@ -0,0 +1,122 @@
"""
Rate limiter for osu! API requests to avoid abuse detection.
根据 osu! API v2 的速率限制设计:
- 默认:每分钟最多 1200 次请求
- 突发:短时间内最多 200 次额外请求
- 建议:每分钟不超过 60 次请求以避免滥用检测
"""
from __future__ import annotations
import asyncio
from collections import deque
import time
from app.log import logger
class RateLimiter:
    """Rate limiter for the osu! API."""

    def __init__(
        self,
        max_requests_per_minute: int = 60,  # conservative limit
        burst_limit: int = 10,  # burst limit over a short window
        burst_window: float = 10.0,  # burst window in seconds
    ):
        self.max_requests_per_minute = max_requests_per_minute
        self.burst_limit = burst_limit
        self.burst_window = burst_window
        # Tracked request timestamps
        self.request_times: deque[float] = deque()
        self.burst_times: deque[float] = deque()
        # Lock to keep acquire() safe under concurrent coroutines
        self._lock = asyncio.Lock()

    async def acquire(self) -> None:
        """Acquire permission to send a request, waiting if over the limit."""
        async with self._lock:
            current_time = time.time()
            # Drop expired request records
            self._cleanup_old_requests(current_time)
            # Work out whether we need to wait
            wait_time = self._calculate_wait_time(current_time)
            if wait_time > 0:
                logger.opt(colors=True).info(
                    f"<yellow>[RateLimiter]</yellow> Rate limit reached, "
                    f"waiting {wait_time:.2f}s"
                )
                await asyncio.sleep(wait_time)
                current_time = time.time()
                self._cleanup_old_requests(current_time)
            # Record the current request
            self.request_times.append(current_time)
            self.burst_times.append(current_time)
            logger.opt(colors=True).debug(
                f"<green>[RateLimiter]</green> Request granted. "
                f"Recent requests: {len(self.request_times)}/min, "
                f"{len(self.burst_times)}/{self.burst_window}s"
            )

    def _cleanup_old_requests(self, current_time: float) -> None:
        """Drop request records that have aged out."""
        # Drop requests older than one minute
        minute_ago = current_time - 60.0
        while self.request_times and self.request_times[0] < minute_ago:
            self.request_times.popleft()
        # Drop requests outside the burst window
        burst_window_ago = current_time - self.burst_window
        while self.burst_times and self.burst_times[0] < burst_window_ago:
            self.burst_times.popleft()

    def _calculate_wait_time(self, current_time: float) -> float:
        """Compute how long the caller needs to wait."""
        # Check the per-minute limit
        if len(self.request_times) >= self.max_requests_per_minute:
            # Wait until the oldest request is more than a minute old
            oldest_request = self.request_times[0]
            wait_for_minute_limit = oldest_request + 60.0 - current_time
        else:
            wait_for_minute_limit = 0.0
        # Check the burst limit
        if len(self.burst_times) >= self.burst_limit:
            # Wait until the oldest burst request falls outside the burst window
            oldest_burst = self.burst_times[0]
            wait_for_burst_limit = oldest_burst + self.burst_window - current_time
        else:
            wait_for_burst_limit = 0.0
        return max(wait_for_minute_limit, wait_for_burst_limit, 0.0)

    def get_status(self) -> dict[str, int | float]:
        """Return the current rate-limit status."""
        current_time = time.time()
        self._cleanup_old_requests(current_time)
        return {
            "requests_this_minute": len(self.request_times),
            "max_requests_per_minute": self.max_requests_per_minute,
            "burst_requests": len(self.burst_times),
            "burst_limit": self.burst_limit,
            "next_reset_in_seconds": (
                60.0 - (current_time - self.request_times[0])
                if self.request_times
                else 0.0
            ),
        }


# Global rate limiter instance
osu_api_rate_limiter = RateLimiter(
    max_requests_per_minute=50,  # conservative: below the recommended 60
    burst_limit=8,  # at most 8 requests in a short window
    burst_window=10.0,  # 10-second window
)
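A minimal usage sketch, assuming this module path; do_request is a hypothetical stand-in for a real HTTP call. Every coroutine that shares the global instance is paced by the same windows:

import asyncio

from app.helpers.rate_limiter import osu_api_rate_limiter

async def do_request(i: int) -> None:
    await osu_api_rate_limiter.acquire()  # waits whenever a limit would be exceeded
    print(f"request {i} sent")  # the real API call would go here

async def main() -> None:
    # 20 concurrent callers: the burst limit (8 per 10 s) spreads them out
    await asyncio.gather(*(do_request(i) for i in range(20)))

asyncio.run(main())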

View File

@@ -202,3 +202,7 @@ class SearchQueryModel(BaseModel):
        default=False,
        description="Explicit content",
    )
    cursor_string: str | None = Field(
        default=None,
        description="Cursor string used for pagination",
    )
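Illustratively, a search request can now paginate either way; the values below are made up and the base64 string is truncated:

# Nested cursor fields, as before:
params_nested = {"q": "candy", "cursor[approved_date]": 1723958429, "cursor[id]": 123}
# Or the single opaque cursor_string returned in SearchBeatmapsetsResp:
params_string = {"q": "candy", "cursor_string": "eyJhcHByb3ZlZF9kYXRlIjo..."}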

View File

@@ -17,7 +17,7 @@ class UTCBaseModel(BaseModel):
        return v


-Cursor = dict[str, int]
+Cursor = dict[str, int | float]


class RespWithCursor(BaseModel):

View File

@@ -7,7 +7,7 @@ from urllib.parse import parse_qs
from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
from app.database.beatmapset import SearchBeatmapsetsResp
from app.dependencies.beatmap_download import get_beatmap_download_service
-from app.dependencies.database import engine, get_db
+from app.dependencies.database import engine, get_db, get_redis
from app.dependencies.fetcher import get_fetcher
from app.dependencies.geoip import get_client_ip, get_geoip_helper
from app.dependencies.user import get_client_user, get_current_user
@@ -55,13 +55,37 @@ async def search_beatmapset(
    current_user: User = Security(get_current_user, scopes=["public"]),
    db: AsyncSession = Depends(get_db),
    fetcher: Fetcher = Depends(get_fetcher),
    redis=Depends(get_redis),
):
    params = parse_qs(qs=request.url.query, keep_blank_values=True)
    cursor = {}
    # Parse parameters in the cursor[field] format
    for k, v in params.items():
        match = re.match(r"cursor\[(\w+)\]", k)
        if match:
-            cursor[match.group(1)] = v[0] if v else None
+            field_name = match.group(1)
+            field_value = v[0] if v else None
+            if field_value is not None:
+                # Convert to an appropriate type
+                try:
+                    if field_name in ["approved_date", "id"]:
+                        cursor[field_name] = int(field_value)
+                    else:
+                        # Try numeric conversion: int first
+                        try:
+                            cursor[field_name] = int(field_value)
+                        except ValueError:
+                            try:
+                                # Then float
+                                cursor[field_name] = float(field_value)
+                            except ValueError:
+                                # Finally keep the raw string
+                                cursor[field_name] = field_value
+                except ValueError:
+                    cursor[field_name] = field_value
    if (
        "recommended" in query.c
        or len(query.r) > 0
@@ -73,7 +97,7 @@ async def search_beatmapset(
        # TODO: search locally
        return SearchBeatmapsetsResp(total=0, beatmapsets=[])
    try:
-        sets = await fetcher.search_beatmapset(query, cursor)
+        sets = await fetcher.search_beatmapset(query, cursor, redis)
        background_tasks.add_task(_save_to_db, sets)
    except HTTPError as e:
        raise HTTPException(status_code=500, detail=str(e))
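The nested try/except above is easier to follow pulled out on its own; here is a compact, self-contained re-implementation of the general branch (int, then float, then raw string) for illustration — parse_cursor is a hypothetical helper name, not part of the commit:

import re
from urllib.parse import parse_qs

def parse_cursor(qs: str) -> dict:
    cursor: dict = {}
    for k, v in parse_qs(qs, keep_blank_values=True).items():
        match = re.match(r"cursor\[(\w+)\]", k)
        if not match or not v:
            continue
        raw = v[0]
        for cast in (int, float):
            try:
                cursor[match.group(1)] = cast(raw)
                break
            except ValueError:
                continue
        else:
            cursor[match.group(1)] = raw  # fall back to the raw string
    return cursor

assert parse_cursor("cursor[id]=42&cursor[_score]=712.5") == {"id": 42, "_score": 712.5}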

View File

@@ -0,0 +1,6 @@
"""缓存调度器模块"""
from __future__ import annotations
from .cache_scheduler import start_cache_scheduler, stop_cache_scheduler
__all__ = ["start_cache_scheduler", "stop_cache_scheduler"]

View File

@@ -0,0 +1,84 @@
from __future__ import annotations

import asyncio

from app.dependencies.database import get_redis
from app.dependencies.fetcher import get_fetcher
from app.log import logger


class BeatmapsetCacheScheduler:
    """Beatmapset cache scheduler."""

    def __init__(self):
        self.running = False
        self.task = None

    async def start(self):
        """Start the scheduler."""
        if self.running:
            return
        self.running = True
        self.task = asyncio.create_task(self._run_scheduler())
        logger.info("BeatmapsetCacheScheduler started")

    async def stop(self):
        """Stop the scheduler."""
        self.running = False
        if self.task:
            self.task.cancel()
            try:
                await self.task
            except asyncio.CancelledError:
                pass
        logger.info("BeatmapsetCacheScheduler stopped")

    async def _run_scheduler(self):
        """Main scheduler loop."""
        # Run one warmup immediately on startup
        await self._warmup_cache()
        while self.running:
            try:
                # Warm the cache every 30 minutes
                await asyncio.sleep(30 * 60)  # 30 minutes
                if self.running:
                    await self._warmup_cache()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Cache scheduler error: {e}")
                await asyncio.sleep(60)  # wait one minute after an error before retrying

    async def _warmup_cache(self):
        """Run a cache warmup pass."""
        try:
            logger.info("Starting cache warmup...")
            fetcher = await get_fetcher()
            redis = get_redis()
            # Warm the homepage cache
            await fetcher.warmup_homepage_cache(redis)
            logger.info("Cache warmup completed successfully")
        except Exception as e:
            logger.error(f"Cache warmup failed: {e}")


# Global scheduler instance
cache_scheduler = BeatmapsetCacheScheduler()


async def start_cache_scheduler():
    """Start the cache scheduler."""
    await cache_scheduler.start()


async def stop_cache_scheduler():
    """Stop the cache scheduler."""
    await cache_scheduler.stop()

View File

@@ -21,6 +21,7 @@ from app.router import (
signalr_router,
)
from app.router.redirect import redirect_router
from app.scheduler.cache_scheduler import start_cache_scheduler, stop_cache_scheduler
from app.service.beatmap_download_service import download_service
from app.service.calculate_all_user_rank import calculate_user_rank
from app.service.create_banchobot import create_banchobot
@@ -74,9 +75,11 @@ async def lifespan(app: FastAPI):
    await daily_challenge_job()
    await create_banchobot()
    await download_service.start_health_check()  # start download-service health checks
    await start_cache_scheduler()  # start the cache scheduler
    # on shutdown
    yield
    stop_scheduler()
    await stop_cache_scheduler()  # stop the cache scheduler
    await download_service.stop_health_check()  # stop download-service health checks
    await engine.dispose()
    await redis_client.aclose()
@@ -111,6 +114,7 @@ app.include_router(fetcher_router)
app.include_router(file_router)
app.include_router(auth_router)
app.include_router(private_router)

# CORS configuration
origins = []
for url in [*settings.cors_urls, settings.server_url]: