diff --git a/app/dependencies/beatmap_download.py b/app/dependencies/beatmap_download.py new file mode 100644 index 0000000..ffed3a0 --- /dev/null +++ b/app/dependencies/beatmap_download.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +from app.service.beatmap_download_service import download_service + + +def get_beatmap_download_service(): + """获取谱面下载服务实例""" + return download_service diff --git a/app/router/v2/beatmapset.py b/app/router/v2/beatmapset.py index ebd8ecd..d843208 100644 --- a/app/router/v2/beatmapset.py +++ b/app/router/v2/beatmapset.py @@ -6,11 +6,14 @@ from urllib.parse import parse_qs from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User from app.database.beatmapset import SearchBeatmapsetsResp +from app.dependencies.beatmap_download import get_beatmap_download_service from app.dependencies.database import engine, get_db from app.dependencies.fetcher import get_fetcher +from app.dependencies.geoip import get_client_ip, get_geoip_helper from app.dependencies.user import get_client_user, get_current_user from app.fetcher import Fetcher from app.models.beatmap import SearchQueryModel +from app.service.beatmap_download_service import BeatmapDownloadService from .router import router @@ -123,22 +126,43 @@ async def get_beatmapset( "/beatmapsets/{beatmapset_id}/download", tags=["谱面集"], name="下载谱面集", - description="**客户端专属**\n下载谱面集文件。若用户国家为 CN 则跳转国内镜像。", + description="**客户端专属**\n下载谱面集文件。基于请求IP地理位置智能分流,支持负载均衡和自动故障转移。中国IP使用Sayobot镜像,其他地区使用Nerinyan和OsuDirect镜像。", ) async def download_beatmapset( + request: Request, beatmapset_id: int = Path(..., description="谱面集 ID"), no_video: bool = Query(True, alias="noVideo", description="是否下载无视频版本"), current_user: User = Security(get_client_user), + download_service: BeatmapDownloadService = Depends(get_beatmap_download_service), ): - if current_user.country_code == "CN": - return RedirectResponse( - f"https://txy1.sayobot.cn/beatmaps/download/" - f"{'novideo' if no_video else 'full'}/{beatmapset_id}?server=auto" - ) - else: - return RedirectResponse( - f"https://api.nerinyan.moe/d/{beatmapset_id}?noVideo={no_video}" + client_ip = get_client_ip(request) + + geoip_helper = get_geoip_helper() + geo_info = geoip_helper.lookup(client_ip) + country_code = geo_info.get("country_iso", "") + + # 优先使用IP地理位置判断,如果获取失败则回退到用户账户的国家代码 + is_china = country_code == "CN" or ( + not country_code and current_user.country_code == "CN" + ) + + try: + # 使用负载均衡服务获取下载URL + download_url = download_service.get_download_url( + beatmapset_id=beatmapset_id, no_video=no_video, is_china=is_china ) + return RedirectResponse(download_url) + except HTTPException: + # 如果负载均衡服务失败,回退到原有逻辑 + if is_china: + return RedirectResponse( + f"https://dl.sayobot.cn/beatmaps/download/" + f"{'novideo' if no_video else 'full'}/{beatmapset_id}" + ) + else: + return RedirectResponse( + f"https://api.nerinyan.moe/d/{beatmapset_id}?noVideo={no_video}" + ) @router.post( @@ -178,3 +202,17 @@ async def favourite_beatmapset( else: await db.delete(existing_favourite) await db.commit() + + +@router.get( + "/beatmapsets/download-status", + tags=["谱面集"], + name="下载服务状态", + description="获取谱面下载服务的健康状态和负载均衡信息。", +) +async def get_download_service_status( + current_user: User = Security(get_current_user, scopes=["public"]), + download_service: BeatmapDownloadService = Depends(get_beatmap_download_service), +): + """获取下载服务状态""" + return download_service.get_service_status() diff --git a/app/service/beatmap_download_service.py b/app/service/beatmap_download_service.py new file mode 100644 index 0000000..d90856a --- /dev/null +++ b/app/service/beatmap_download_service.py @@ -0,0 +1,255 @@ +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from datetime import datetime +import logging + +from fastapi import HTTPException +import httpx + +logger = logging.getLogger(__name__) + + +@dataclass +class DownloadEndpoint: + """下载端点配置""" + + name: str + base_url: str + health_check_url: str + url_template: str # 下载URL模板,使用{sid}和{type}占位符 + is_china: bool = False + priority: int = 0 # 优先级,数字越小优先级越高 + timeout: int = 10 # 健康检查超时时间(秒) + + +@dataclass +class EndpointStatus: + """端点状态""" + + endpoint: DownloadEndpoint + is_healthy: bool = True + last_check: datetime | None = None + consecutive_failures: int = 0 + last_error: str | None = None + + +class BeatmapDownloadService: + """谱面下载服务 - 负载均衡和健康检查""" + + def __init__(self): + # 中国区域端点 + self.china_endpoints = [ + DownloadEndpoint( + name="Sayobot", + base_url="https://dl.sayobot.cn", + health_check_url="https://dl.sayobot.cn/", + url_template="https://dl.sayobot.cn/beatmaps/download/{type}/{sid}", + is_china=True, + priority=0, + timeout=5, + ) + ] + + # 国外区域端点 + self.international_endpoints = [ + DownloadEndpoint( + name="Nerinyan", + base_url="https://api.nerinyan.moe", + health_check_url="https://api.nerinyan.moe/health", + url_template="https://api.nerinyan.moe/d/{sid}?noVideo={no_video}", + is_china=False, + priority=0, + timeout=10, + ), + DownloadEndpoint( + name="OsuDirect", + base_url="https://osu.direct", + health_check_url="https://osu.direct/api/status", + url_template="https://osu.direct/api/d/{sid}", + is_china=False, + priority=1, + timeout=10, + ), + ] + + # 端点状态跟踪 + self.endpoint_status: dict[str, EndpointStatus] = {} + self._initialize_status() + + # 健康检查配置 + self.health_check_interval = 600 # 健康检查间隔(秒) + self.max_consecutive_failures = 3 # 最大连续失败次数 + self.health_check_running = False + self.health_check_task = None # 存储健康检查任务引用 + + # HTTP客户端 + self.http_client = httpx.AsyncClient(timeout=30) + + def _initialize_status(self): + """初始化端点状态""" + all_endpoints = self.china_endpoints + self.international_endpoints + for endpoint in all_endpoints: + self.endpoint_status[endpoint.name] = EndpointStatus(endpoint=endpoint) + + async def start_health_check(self): + """启动健康检查任务""" + if self.health_check_running: + return + + self.health_check_running = True + self.health_check_task = asyncio.create_task(self._health_check_loop()) + logger.info("Beatmap download service health check started") + + async def stop_health_check(self): + """停止健康检查任务""" + self.health_check_running = False + await self.http_client.aclose() + logger.info("Beatmap download service health check stopped") + + async def _health_check_loop(self): + """健康检查循环""" + while self.health_check_running: + try: + await self._check_all_endpoints() + await asyncio.sleep(self.health_check_interval) + except Exception as e: + logger.error(f"Health check loop error: {e}") + await asyncio.sleep(5) # 错误时短暂等待 + + async def _check_all_endpoints(self): + """检查所有端点的健康状态""" + all_endpoints = self.china_endpoints + self.international_endpoints + + # 并发检查所有端点 + tasks = [] + for endpoint in all_endpoints: + task = asyncio.create_task(self._check_endpoint_health(endpoint)) + tasks.append(task) + + await asyncio.gather(*tasks, return_exceptions=True) + + async def _check_endpoint_health(self, endpoint: DownloadEndpoint): + """检查单个端点的健康状态""" + status = self.endpoint_status[endpoint.name] + + try: + async with httpx.AsyncClient(timeout=endpoint.timeout) as client: + response = await client.get(endpoint.health_check_url) + + # 根据不同端点类型判断健康状态 + is_healthy = False + if endpoint.name == "Sayobot": + # Sayobot 端点返回 304 (Not Modified) 表示正常 + is_healthy = response.status_code in [200, 304] + else: + # 其他端点返回 200 表示正常 + is_healthy = response.status_code == 200 + + if is_healthy: + # 健康检查成功 + if not status.is_healthy: + logger.info(f"Endpoint {endpoint.name} is now healthy") + + status.is_healthy = True + status.consecutive_failures = 0 + status.last_error = None + else: + raise httpx.HTTPStatusError( + f"Health check failed with status {response.status_code}", + request=response.request, + response=response, + ) + + except Exception as e: + # 健康检查失败 + status.consecutive_failures += 1 + status.last_error = str(e) + + if status.consecutive_failures >= self.max_consecutive_failures: + if status.is_healthy: + logger.warning( + f"Endpoint {endpoint.name} marked as unhealthy after " + f"{status.consecutive_failures} consecutive failures: {e}" + ) + status.is_healthy = False + + finally: + status.last_check = datetime.now() + + def get_healthy_endpoints(self, is_china: bool) -> list[DownloadEndpoint]: + """获取健康的端点列表""" + endpoints = self.china_endpoints if is_china else self.international_endpoints + + healthy_endpoints = [] + for endpoint in endpoints: + status = self.endpoint_status[endpoint.name] + if status.is_healthy: + healthy_endpoints.append(endpoint) + + # 按优先级排序 + healthy_endpoints.sort(key=lambda x: x.priority) + return healthy_endpoints + + def get_download_url( + self, beatmapset_id: int, no_video: bool, is_china: bool + ) -> str: + """获取下载URL,带负载均衡和故障转移""" + healthy_endpoints = self.get_healthy_endpoints(is_china) + + if not healthy_endpoints: + # 如果没有健康的端点,记录错误并回退到所有端点中优先级最高的 + logger.error(f"No healthy endpoints available for is_china={is_china}") + endpoints = ( + self.china_endpoints if is_china else self.international_endpoints + ) + if not endpoints: + raise HTTPException( + status_code=503, detail="No download endpoints available" + ) + endpoint = min(endpoints, key=lambda x: x.priority) + else: + # 使用第一个健康的端点(已按优先级排序) + endpoint = healthy_endpoints[0] + + # 根据端点类型生成URL + if endpoint.name == "Sayobot": + video_type = "novideo" if no_video else "full" + return endpoint.url_template.format(type=video_type, sid=beatmapset_id) + elif endpoint.name == "Nerinyan": + return endpoint.url_template.format( + sid=beatmapset_id, no_video="true" if no_video else "false" + ) + elif endpoint.name == "OsuDirect": + # osu.direct 似乎没有no_video参数,直接使用基础URL + return endpoint.url_template.format(sid=beatmapset_id) + else: + # 默认处理 + return endpoint.url_template.format(sid=beatmapset_id) + + def get_service_status(self) -> dict: + """获取服务状态信息""" + status_info = { + "service_running": self.health_check_running, + "last_update": datetime.now().isoformat(), + "endpoints": {}, + } + + for name, status in self.endpoint_status.items(): + status_info["endpoints"][name] = { + "healthy": status.is_healthy, + "last_check": status.last_check.isoformat() + if status.last_check + else None, + "consecutive_failures": status.consecutive_failures, + "last_error": status.last_error, + "priority": status.endpoint.priority, + "is_china": status.endpoint.is_china, + } + + return status_info + + +# 全局服务实例 +download_service = BeatmapDownloadService() diff --git a/main.py b/main.py index 28b0830..f8cc17f 100644 --- a/main.py +++ b/main.py @@ -21,6 +21,7 @@ from app.router import ( signalr_router, ) from app.router.redirect import redirect_router +from app.service.beatmap_download_service import download_service from app.service.calculate_all_user_rank import calculate_user_rank from app.service.create_banchobot import create_banchobot from app.service.daily_challenge import daily_challenge_job @@ -72,9 +73,11 @@ async def lifespan(app: FastAPI): schedule_geoip_updates() # 调度 GeoIP 定时更新任务 await daily_challenge_job() await create_banchobot() + await download_service.start_health_check() # 启动下载服务健康检查 # on shutdown yield stop_scheduler() + await download_service.stop_health_check() # 停止下载服务健康检查 await engine.dispose() await redis_client.aclose()