Add beatmapsets Download load balancing
This commit is contained in:
8
app/dependencies/beatmap_download.py
Normal file
8
app/dependencies/beatmap_download.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from app.service.beatmap_download_service import download_service
|
||||||
|
|
||||||
|
|
||||||
|
def get_beatmap_download_service():
|
||||||
|
"""获取谱面下载服务实例"""
|
||||||
|
return download_service
|
||||||
@@ -6,11 +6,14 @@ from urllib.parse import parse_qs
|
|||||||
|
|
||||||
from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
|
from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
|
||||||
from app.database.beatmapset import SearchBeatmapsetsResp
|
from app.database.beatmapset import SearchBeatmapsetsResp
|
||||||
|
from app.dependencies.beatmap_download import get_beatmap_download_service
|
||||||
from app.dependencies.database import engine, get_db
|
from app.dependencies.database import engine, get_db
|
||||||
from app.dependencies.fetcher import get_fetcher
|
from app.dependencies.fetcher import get_fetcher
|
||||||
|
from app.dependencies.geoip import get_client_ip, get_geoip_helper
|
||||||
from app.dependencies.user import get_client_user, get_current_user
|
from app.dependencies.user import get_client_user, get_current_user
|
||||||
from app.fetcher import Fetcher
|
from app.fetcher import Fetcher
|
||||||
from app.models.beatmap import SearchQueryModel
|
from app.models.beatmap import SearchQueryModel
|
||||||
|
from app.service.beatmap_download_service import BeatmapDownloadService
|
||||||
|
|
||||||
from .router import router
|
from .router import router
|
||||||
|
|
||||||
@@ -123,22 +126,43 @@ async def get_beatmapset(
|
|||||||
"/beatmapsets/{beatmapset_id}/download",
|
"/beatmapsets/{beatmapset_id}/download",
|
||||||
tags=["谱面集"],
|
tags=["谱面集"],
|
||||||
name="下载谱面集",
|
name="下载谱面集",
|
||||||
description="**客户端专属**\n下载谱面集文件。若用户国家为 CN 则跳转国内镜像。",
|
description="**客户端专属**\n下载谱面集文件。基于请求IP地理位置智能分流,支持负载均衡和自动故障转移。中国IP使用Sayobot镜像,其他地区使用Nerinyan和OsuDirect镜像。",
|
||||||
)
|
)
|
||||||
async def download_beatmapset(
|
async def download_beatmapset(
|
||||||
|
request: Request,
|
||||||
beatmapset_id: int = Path(..., description="谱面集 ID"),
|
beatmapset_id: int = Path(..., description="谱面集 ID"),
|
||||||
no_video: bool = Query(True, alias="noVideo", description="是否下载无视频版本"),
|
no_video: bool = Query(True, alias="noVideo", description="是否下载无视频版本"),
|
||||||
current_user: User = Security(get_client_user),
|
current_user: User = Security(get_client_user),
|
||||||
|
download_service: BeatmapDownloadService = Depends(get_beatmap_download_service),
|
||||||
):
|
):
|
||||||
if current_user.country_code == "CN":
|
client_ip = get_client_ip(request)
|
||||||
return RedirectResponse(
|
|
||||||
f"https://txy1.sayobot.cn/beatmaps/download/"
|
geoip_helper = get_geoip_helper()
|
||||||
f"{'novideo' if no_video else 'full'}/{beatmapset_id}?server=auto"
|
geo_info = geoip_helper.lookup(client_ip)
|
||||||
)
|
country_code = geo_info.get("country_iso", "")
|
||||||
else:
|
|
||||||
return RedirectResponse(
|
# 优先使用IP地理位置判断,如果获取失败则回退到用户账户的国家代码
|
||||||
f"https://api.nerinyan.moe/d/{beatmapset_id}?noVideo={no_video}"
|
is_china = country_code == "CN" or (
|
||||||
|
not country_code and current_user.country_code == "CN"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 使用负载均衡服务获取下载URL
|
||||||
|
download_url = download_service.get_download_url(
|
||||||
|
beatmapset_id=beatmapset_id, no_video=no_video, is_china=is_china
|
||||||
)
|
)
|
||||||
|
return RedirectResponse(download_url)
|
||||||
|
except HTTPException:
|
||||||
|
# 如果负载均衡服务失败,回退到原有逻辑
|
||||||
|
if is_china:
|
||||||
|
return RedirectResponse(
|
||||||
|
f"https://dl.sayobot.cn/beatmaps/download/"
|
||||||
|
f"{'novideo' if no_video else 'full'}/{beatmapset_id}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return RedirectResponse(
|
||||||
|
f"https://api.nerinyan.moe/d/{beatmapset_id}?noVideo={no_video}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
@@ -178,3 +202,17 @@ async def favourite_beatmapset(
|
|||||||
else:
|
else:
|
||||||
await db.delete(existing_favourite)
|
await db.delete(existing_favourite)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/beatmapsets/download-status",
|
||||||
|
tags=["谱面集"],
|
||||||
|
name="下载服务状态",
|
||||||
|
description="获取谱面下载服务的健康状态和负载均衡信息。",
|
||||||
|
)
|
||||||
|
async def get_download_service_status(
|
||||||
|
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||||
|
download_service: BeatmapDownloadService = Depends(get_beatmap_download_service),
|
||||||
|
):
|
||||||
|
"""获取下载服务状态"""
|
||||||
|
return download_service.get_service_status()
|
||||||
|
|||||||
255
app/service/beatmap_download_service.py
Normal file
255
app/service/beatmap_download_service.py
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DownloadEndpoint:
|
||||||
|
"""下载端点配置"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
base_url: str
|
||||||
|
health_check_url: str
|
||||||
|
url_template: str # 下载URL模板,使用{sid}和{type}占位符
|
||||||
|
is_china: bool = False
|
||||||
|
priority: int = 0 # 优先级,数字越小优先级越高
|
||||||
|
timeout: int = 10 # 健康检查超时时间(秒)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EndpointStatus:
|
||||||
|
"""端点状态"""
|
||||||
|
|
||||||
|
endpoint: DownloadEndpoint
|
||||||
|
is_healthy: bool = True
|
||||||
|
last_check: datetime | None = None
|
||||||
|
consecutive_failures: int = 0
|
||||||
|
last_error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class BeatmapDownloadService:
|
||||||
|
"""谱面下载服务 - 负载均衡和健康检查"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# 中国区域端点
|
||||||
|
self.china_endpoints = [
|
||||||
|
DownloadEndpoint(
|
||||||
|
name="Sayobot",
|
||||||
|
base_url="https://dl.sayobot.cn",
|
||||||
|
health_check_url="https://dl.sayobot.cn/",
|
||||||
|
url_template="https://dl.sayobot.cn/beatmaps/download/{type}/{sid}",
|
||||||
|
is_china=True,
|
||||||
|
priority=0,
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
# 国外区域端点
|
||||||
|
self.international_endpoints = [
|
||||||
|
DownloadEndpoint(
|
||||||
|
name="Nerinyan",
|
||||||
|
base_url="https://api.nerinyan.moe",
|
||||||
|
health_check_url="https://api.nerinyan.moe/health",
|
||||||
|
url_template="https://api.nerinyan.moe/d/{sid}?noVideo={no_video}",
|
||||||
|
is_china=False,
|
||||||
|
priority=0,
|
||||||
|
timeout=10,
|
||||||
|
),
|
||||||
|
DownloadEndpoint(
|
||||||
|
name="OsuDirect",
|
||||||
|
base_url="https://osu.direct",
|
||||||
|
health_check_url="https://osu.direct/api/status",
|
||||||
|
url_template="https://osu.direct/api/d/{sid}",
|
||||||
|
is_china=False,
|
||||||
|
priority=1,
|
||||||
|
timeout=10,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# 端点状态跟踪
|
||||||
|
self.endpoint_status: dict[str, EndpointStatus] = {}
|
||||||
|
self._initialize_status()
|
||||||
|
|
||||||
|
# 健康检查配置
|
||||||
|
self.health_check_interval = 600 # 健康检查间隔(秒)
|
||||||
|
self.max_consecutive_failures = 3 # 最大连续失败次数
|
||||||
|
self.health_check_running = False
|
||||||
|
self.health_check_task = None # 存储健康检查任务引用
|
||||||
|
|
||||||
|
# HTTP客户端
|
||||||
|
self.http_client = httpx.AsyncClient(timeout=30)
|
||||||
|
|
||||||
|
def _initialize_status(self):
|
||||||
|
"""初始化端点状态"""
|
||||||
|
all_endpoints = self.china_endpoints + self.international_endpoints
|
||||||
|
for endpoint in all_endpoints:
|
||||||
|
self.endpoint_status[endpoint.name] = EndpointStatus(endpoint=endpoint)
|
||||||
|
|
||||||
|
async def start_health_check(self):
|
||||||
|
"""启动健康检查任务"""
|
||||||
|
if self.health_check_running:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.health_check_running = True
|
||||||
|
self.health_check_task = asyncio.create_task(self._health_check_loop())
|
||||||
|
logger.info("Beatmap download service health check started")
|
||||||
|
|
||||||
|
async def stop_health_check(self):
|
||||||
|
"""停止健康检查任务"""
|
||||||
|
self.health_check_running = False
|
||||||
|
await self.http_client.aclose()
|
||||||
|
logger.info("Beatmap download service health check stopped")
|
||||||
|
|
||||||
|
async def _health_check_loop(self):
|
||||||
|
"""健康检查循环"""
|
||||||
|
while self.health_check_running:
|
||||||
|
try:
|
||||||
|
await self._check_all_endpoints()
|
||||||
|
await asyncio.sleep(self.health_check_interval)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Health check loop error: {e}")
|
||||||
|
await asyncio.sleep(5) # 错误时短暂等待
|
||||||
|
|
||||||
|
async def _check_all_endpoints(self):
|
||||||
|
"""检查所有端点的健康状态"""
|
||||||
|
all_endpoints = self.china_endpoints + self.international_endpoints
|
||||||
|
|
||||||
|
# 并发检查所有端点
|
||||||
|
tasks = []
|
||||||
|
for endpoint in all_endpoints:
|
||||||
|
task = asyncio.create_task(self._check_endpoint_health(endpoint))
|
||||||
|
tasks.append(task)
|
||||||
|
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
async def _check_endpoint_health(self, endpoint: DownloadEndpoint):
|
||||||
|
"""检查单个端点的健康状态"""
|
||||||
|
status = self.endpoint_status[endpoint.name]
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=endpoint.timeout) as client:
|
||||||
|
response = await client.get(endpoint.health_check_url)
|
||||||
|
|
||||||
|
# 根据不同端点类型判断健康状态
|
||||||
|
is_healthy = False
|
||||||
|
if endpoint.name == "Sayobot":
|
||||||
|
# Sayobot 端点返回 304 (Not Modified) 表示正常
|
||||||
|
is_healthy = response.status_code in [200, 304]
|
||||||
|
else:
|
||||||
|
# 其他端点返回 200 表示正常
|
||||||
|
is_healthy = response.status_code == 200
|
||||||
|
|
||||||
|
if is_healthy:
|
||||||
|
# 健康检查成功
|
||||||
|
if not status.is_healthy:
|
||||||
|
logger.info(f"Endpoint {endpoint.name} is now healthy")
|
||||||
|
|
||||||
|
status.is_healthy = True
|
||||||
|
status.consecutive_failures = 0
|
||||||
|
status.last_error = None
|
||||||
|
else:
|
||||||
|
raise httpx.HTTPStatusError(
|
||||||
|
f"Health check failed with status {response.status_code}",
|
||||||
|
request=response.request,
|
||||||
|
response=response,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# 健康检查失败
|
||||||
|
status.consecutive_failures += 1
|
||||||
|
status.last_error = str(e)
|
||||||
|
|
||||||
|
if status.consecutive_failures >= self.max_consecutive_failures:
|
||||||
|
if status.is_healthy:
|
||||||
|
logger.warning(
|
||||||
|
f"Endpoint {endpoint.name} marked as unhealthy after "
|
||||||
|
f"{status.consecutive_failures} consecutive failures: {e}"
|
||||||
|
)
|
||||||
|
status.is_healthy = False
|
||||||
|
|
||||||
|
finally:
|
||||||
|
status.last_check = datetime.now()
|
||||||
|
|
||||||
|
def get_healthy_endpoints(self, is_china: bool) -> list[DownloadEndpoint]:
|
||||||
|
"""获取健康的端点列表"""
|
||||||
|
endpoints = self.china_endpoints if is_china else self.international_endpoints
|
||||||
|
|
||||||
|
healthy_endpoints = []
|
||||||
|
for endpoint in endpoints:
|
||||||
|
status = self.endpoint_status[endpoint.name]
|
||||||
|
if status.is_healthy:
|
||||||
|
healthy_endpoints.append(endpoint)
|
||||||
|
|
||||||
|
# 按优先级排序
|
||||||
|
healthy_endpoints.sort(key=lambda x: x.priority)
|
||||||
|
return healthy_endpoints
|
||||||
|
|
||||||
|
def get_download_url(
|
||||||
|
self, beatmapset_id: int, no_video: bool, is_china: bool
|
||||||
|
) -> str:
|
||||||
|
"""获取下载URL,带负载均衡和故障转移"""
|
||||||
|
healthy_endpoints = self.get_healthy_endpoints(is_china)
|
||||||
|
|
||||||
|
if not healthy_endpoints:
|
||||||
|
# 如果没有健康的端点,记录错误并回退到所有端点中优先级最高的
|
||||||
|
logger.error(f"No healthy endpoints available for is_china={is_china}")
|
||||||
|
endpoints = (
|
||||||
|
self.china_endpoints if is_china else self.international_endpoints
|
||||||
|
)
|
||||||
|
if not endpoints:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=503, detail="No download endpoints available"
|
||||||
|
)
|
||||||
|
endpoint = min(endpoints, key=lambda x: x.priority)
|
||||||
|
else:
|
||||||
|
# 使用第一个健康的端点(已按优先级排序)
|
||||||
|
endpoint = healthy_endpoints[0]
|
||||||
|
|
||||||
|
# 根据端点类型生成URL
|
||||||
|
if endpoint.name == "Sayobot":
|
||||||
|
video_type = "novideo" if no_video else "full"
|
||||||
|
return endpoint.url_template.format(type=video_type, sid=beatmapset_id)
|
||||||
|
elif endpoint.name == "Nerinyan":
|
||||||
|
return endpoint.url_template.format(
|
||||||
|
sid=beatmapset_id, no_video="true" if no_video else "false"
|
||||||
|
)
|
||||||
|
elif endpoint.name == "OsuDirect":
|
||||||
|
# osu.direct 似乎没有no_video参数,直接使用基础URL
|
||||||
|
return endpoint.url_template.format(sid=beatmapset_id)
|
||||||
|
else:
|
||||||
|
# 默认处理
|
||||||
|
return endpoint.url_template.format(sid=beatmapset_id)
|
||||||
|
|
||||||
|
def get_service_status(self) -> dict:
|
||||||
|
"""获取服务状态信息"""
|
||||||
|
status_info = {
|
||||||
|
"service_running": self.health_check_running,
|
||||||
|
"last_update": datetime.now().isoformat(),
|
||||||
|
"endpoints": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, status in self.endpoint_status.items():
|
||||||
|
status_info["endpoints"][name] = {
|
||||||
|
"healthy": status.is_healthy,
|
||||||
|
"last_check": status.last_check.isoformat()
|
||||||
|
if status.last_check
|
||||||
|
else None,
|
||||||
|
"consecutive_failures": status.consecutive_failures,
|
||||||
|
"last_error": status.last_error,
|
||||||
|
"priority": status.endpoint.priority,
|
||||||
|
"is_china": status.endpoint.is_china,
|
||||||
|
}
|
||||||
|
|
||||||
|
return status_info
|
||||||
|
|
||||||
|
|
||||||
|
# 全局服务实例
|
||||||
|
download_service = BeatmapDownloadService()
|
||||||
3
main.py
3
main.py
@@ -21,6 +21,7 @@ from app.router import (
|
|||||||
signalr_router,
|
signalr_router,
|
||||||
)
|
)
|
||||||
from app.router.redirect import redirect_router
|
from app.router.redirect import redirect_router
|
||||||
|
from app.service.beatmap_download_service import download_service
|
||||||
from app.service.calculate_all_user_rank import calculate_user_rank
|
from app.service.calculate_all_user_rank import calculate_user_rank
|
||||||
from app.service.create_banchobot import create_banchobot
|
from app.service.create_banchobot import create_banchobot
|
||||||
from app.service.daily_challenge import daily_challenge_job
|
from app.service.daily_challenge import daily_challenge_job
|
||||||
@@ -72,9 +73,11 @@ async def lifespan(app: FastAPI):
|
|||||||
schedule_geoip_updates() # 调度 GeoIP 定时更新任务
|
schedule_geoip_updates() # 调度 GeoIP 定时更新任务
|
||||||
await daily_challenge_job()
|
await daily_challenge_job()
|
||||||
await create_banchobot()
|
await create_banchobot()
|
||||||
|
await download_service.start_health_check() # 启动下载服务健康检查
|
||||||
# on shutdown
|
# on shutdown
|
||||||
yield
|
yield
|
||||||
stop_scheduler()
|
stop_scheduler()
|
||||||
|
await download_service.stop_health_check() # 停止下载服务健康检查
|
||||||
await engine.dispose()
|
await engine.dispose()
|
||||||
await redis_client.aclose()
|
await redis_client.aclose()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user