Add beatmapsets Download load balancing

This commit is contained in:
咕谷酱
2025-08-18 02:58:40 +08:00
parent 944c3e4931
commit 041e2a0781
4 changed files with 313 additions and 9 deletions

View File

@@ -0,0 +1,8 @@
from __future__ import annotations
from app.service.beatmap_download_service import download_service
def get_beatmap_download_service():
"""获取谱面下载服务实例"""
return download_service

View File

@@ -6,11 +6,14 @@ from urllib.parse import parse_qs
from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
from app.database.beatmapset import SearchBeatmapsetsResp
from app.dependencies.beatmap_download import get_beatmap_download_service
from app.dependencies.database import engine, get_db
from app.dependencies.fetcher import get_fetcher
from app.dependencies.geoip import get_client_ip, get_geoip_helper
from app.dependencies.user import get_client_user, get_current_user
from app.fetcher import Fetcher
from app.models.beatmap import SearchQueryModel
from app.service.beatmap_download_service import BeatmapDownloadService
from .router import router
@@ -123,22 +126,43 @@ async def get_beatmapset(
"/beatmapsets/{beatmapset_id}/download",
tags=["谱面集"],
name="下载谱面集",
description="**客户端专属**\n下载谱面集文件。若用户国家为 CN 则跳转国内镜像。",
description="**客户端专属**\n下载谱面集文件。基于请求IP地理位置智能分流支持负载均衡和自动故障转移。中国IP使用Sayobot镜像其他地区使用Nerinyan和OsuDirect镜像。",
)
async def download_beatmapset(
request: Request,
beatmapset_id: int = Path(..., description="谱面集 ID"),
no_video: bool = Query(True, alias="noVideo", description="是否下载无视频版本"),
current_user: User = Security(get_client_user),
download_service: BeatmapDownloadService = Depends(get_beatmap_download_service),
):
if current_user.country_code == "CN":
return RedirectResponse(
f"https://txy1.sayobot.cn/beatmaps/download/"
f"{'novideo' if no_video else 'full'}/{beatmapset_id}?server=auto"
)
else:
return RedirectResponse(
f"https://api.nerinyan.moe/d/{beatmapset_id}?noVideo={no_video}"
client_ip = get_client_ip(request)
geoip_helper = get_geoip_helper()
geo_info = geoip_helper.lookup(client_ip)
country_code = geo_info.get("country_iso", "")
# 优先使用IP地理位置判断如果获取失败则回退到用户账户的国家代码
is_china = country_code == "CN" or (
not country_code and current_user.country_code == "CN"
)
try:
# 使用负载均衡服务获取下载URL
download_url = download_service.get_download_url(
beatmapset_id=beatmapset_id, no_video=no_video, is_china=is_china
)
return RedirectResponse(download_url)
except HTTPException:
# 如果负载均衡服务失败,回退到原有逻辑
if is_china:
return RedirectResponse(
f"https://dl.sayobot.cn/beatmaps/download/"
f"{'novideo' if no_video else 'full'}/{beatmapset_id}"
)
else:
return RedirectResponse(
f"https://api.nerinyan.moe/d/{beatmapset_id}?noVideo={no_video}"
)
@router.post(
@@ -178,3 +202,17 @@ async def favourite_beatmapset(
else:
await db.delete(existing_favourite)
await db.commit()
@router.get(
"/beatmapsets/download-status",
tags=["谱面集"],
name="下载服务状态",
description="获取谱面下载服务的健康状态和负载均衡信息。",
)
async def get_download_service_status(
current_user: User = Security(get_current_user, scopes=["public"]),
download_service: BeatmapDownloadService = Depends(get_beatmap_download_service),
):
"""获取下载服务状态"""
return download_service.get_service_status()

View File

@@ -0,0 +1,255 @@
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from datetime import datetime
import logging
from fastapi import HTTPException
import httpx
logger = logging.getLogger(__name__)
@dataclass
class DownloadEndpoint:
"""下载端点配置"""
name: str
base_url: str
health_check_url: str
url_template: str # 下载URL模板使用{sid}和{type}占位符
is_china: bool = False
priority: int = 0 # 优先级,数字越小优先级越高
timeout: int = 10 # 健康检查超时时间(秒)
@dataclass
class EndpointStatus:
"""端点状态"""
endpoint: DownloadEndpoint
is_healthy: bool = True
last_check: datetime | None = None
consecutive_failures: int = 0
last_error: str | None = None
class BeatmapDownloadService:
"""谱面下载服务 - 负载均衡和健康检查"""
def __init__(self):
# 中国区域端点
self.china_endpoints = [
DownloadEndpoint(
name="Sayobot",
base_url="https://dl.sayobot.cn",
health_check_url="https://dl.sayobot.cn/",
url_template="https://dl.sayobot.cn/beatmaps/download/{type}/{sid}",
is_china=True,
priority=0,
timeout=5,
)
]
# 国外区域端点
self.international_endpoints = [
DownloadEndpoint(
name="Nerinyan",
base_url="https://api.nerinyan.moe",
health_check_url="https://api.nerinyan.moe/health",
url_template="https://api.nerinyan.moe/d/{sid}?noVideo={no_video}",
is_china=False,
priority=0,
timeout=10,
),
DownloadEndpoint(
name="OsuDirect",
base_url="https://osu.direct",
health_check_url="https://osu.direct/api/status",
url_template="https://osu.direct/api/d/{sid}",
is_china=False,
priority=1,
timeout=10,
),
]
# 端点状态跟踪
self.endpoint_status: dict[str, EndpointStatus] = {}
self._initialize_status()
# 健康检查配置
self.health_check_interval = 600 # 健康检查间隔(秒)
self.max_consecutive_failures = 3 # 最大连续失败次数
self.health_check_running = False
self.health_check_task = None # 存储健康检查任务引用
# HTTP客户端
self.http_client = httpx.AsyncClient(timeout=30)
def _initialize_status(self):
"""初始化端点状态"""
all_endpoints = self.china_endpoints + self.international_endpoints
for endpoint in all_endpoints:
self.endpoint_status[endpoint.name] = EndpointStatus(endpoint=endpoint)
async def start_health_check(self):
"""启动健康检查任务"""
if self.health_check_running:
return
self.health_check_running = True
self.health_check_task = asyncio.create_task(self._health_check_loop())
logger.info("Beatmap download service health check started")
async def stop_health_check(self):
"""停止健康检查任务"""
self.health_check_running = False
await self.http_client.aclose()
logger.info("Beatmap download service health check stopped")
async def _health_check_loop(self):
"""健康检查循环"""
while self.health_check_running:
try:
await self._check_all_endpoints()
await asyncio.sleep(self.health_check_interval)
except Exception as e:
logger.error(f"Health check loop error: {e}")
await asyncio.sleep(5) # 错误时短暂等待
async def _check_all_endpoints(self):
"""检查所有端点的健康状态"""
all_endpoints = self.china_endpoints + self.international_endpoints
# 并发检查所有端点
tasks = []
for endpoint in all_endpoints:
task = asyncio.create_task(self._check_endpoint_health(endpoint))
tasks.append(task)
await asyncio.gather(*tasks, return_exceptions=True)
async def _check_endpoint_health(self, endpoint: DownloadEndpoint):
"""检查单个端点的健康状态"""
status = self.endpoint_status[endpoint.name]
try:
async with httpx.AsyncClient(timeout=endpoint.timeout) as client:
response = await client.get(endpoint.health_check_url)
# 根据不同端点类型判断健康状态
is_healthy = False
if endpoint.name == "Sayobot":
# Sayobot 端点返回 304 (Not Modified) 表示正常
is_healthy = response.status_code in [200, 304]
else:
# 其他端点返回 200 表示正常
is_healthy = response.status_code == 200
if is_healthy:
# 健康检查成功
if not status.is_healthy:
logger.info(f"Endpoint {endpoint.name} is now healthy")
status.is_healthy = True
status.consecutive_failures = 0
status.last_error = None
else:
raise httpx.HTTPStatusError(
f"Health check failed with status {response.status_code}",
request=response.request,
response=response,
)
except Exception as e:
# 健康检查失败
status.consecutive_failures += 1
status.last_error = str(e)
if status.consecutive_failures >= self.max_consecutive_failures:
if status.is_healthy:
logger.warning(
f"Endpoint {endpoint.name} marked as unhealthy after "
f"{status.consecutive_failures} consecutive failures: {e}"
)
status.is_healthy = False
finally:
status.last_check = datetime.now()
def get_healthy_endpoints(self, is_china: bool) -> list[DownloadEndpoint]:
"""获取健康的端点列表"""
endpoints = self.china_endpoints if is_china else self.international_endpoints
healthy_endpoints = []
for endpoint in endpoints:
status = self.endpoint_status[endpoint.name]
if status.is_healthy:
healthy_endpoints.append(endpoint)
# 按优先级排序
healthy_endpoints.sort(key=lambda x: x.priority)
return healthy_endpoints
def get_download_url(
self, beatmapset_id: int, no_video: bool, is_china: bool
) -> str:
"""获取下载URL带负载均衡和故障转移"""
healthy_endpoints = self.get_healthy_endpoints(is_china)
if not healthy_endpoints:
# 如果没有健康的端点,记录错误并回退到所有端点中优先级最高的
logger.error(f"No healthy endpoints available for is_china={is_china}")
endpoints = (
self.china_endpoints if is_china else self.international_endpoints
)
if not endpoints:
raise HTTPException(
status_code=503, detail="No download endpoints available"
)
endpoint = min(endpoints, key=lambda x: x.priority)
else:
# 使用第一个健康的端点(已按优先级排序)
endpoint = healthy_endpoints[0]
# 根据端点类型生成URL
if endpoint.name == "Sayobot":
video_type = "novideo" if no_video else "full"
return endpoint.url_template.format(type=video_type, sid=beatmapset_id)
elif endpoint.name == "Nerinyan":
return endpoint.url_template.format(
sid=beatmapset_id, no_video="true" if no_video else "false"
)
elif endpoint.name == "OsuDirect":
# osu.direct 似乎没有no_video参数直接使用基础URL
return endpoint.url_template.format(sid=beatmapset_id)
else:
# 默认处理
return endpoint.url_template.format(sid=beatmapset_id)
def get_service_status(self) -> dict:
"""获取服务状态信息"""
status_info = {
"service_running": self.health_check_running,
"last_update": datetime.now().isoformat(),
"endpoints": {},
}
for name, status in self.endpoint_status.items():
status_info["endpoints"][name] = {
"healthy": status.is_healthy,
"last_check": status.last_check.isoformat()
if status.last_check
else None,
"consecutive_failures": status.consecutive_failures,
"last_error": status.last_error,
"priority": status.endpoint.priority,
"is_china": status.endpoint.is_china,
}
return status_info
# 全局服务实例
download_service = BeatmapDownloadService()

View File

@@ -21,6 +21,7 @@ from app.router import (
signalr_router,
)
from app.router.redirect import redirect_router
from app.service.beatmap_download_service import download_service
from app.service.calculate_all_user_rank import calculate_user_rank
from app.service.create_banchobot import create_banchobot
from app.service.daily_challenge import daily_challenge_job
@@ -72,9 +73,11 @@ async def lifespan(app: FastAPI):
schedule_geoip_updates() # 调度 GeoIP 定时更新任务
await daily_challenge_job()
await create_banchobot()
await download_service.start_health_check() # 启动下载服务健康检查
# on shutdown
yield
stop_scheduler()
await download_service.stop_health_check() # 停止下载服务健康检查
await engine.dispose()
await redis_client.aclose()