Files
g0v0-server/app/router/v2/beatmapset.py
2025-08-18 16:37:30 +00:00

242 lines
8.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import re
from typing import Annotated, Literal
from urllib.parse import parse_qs
from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
from app.database.beatmapset import SearchBeatmapsetsResp
from app.dependencies.beatmap_download import get_beatmap_download_service
from app.dependencies.database import Database, get_redis, with_db
from app.dependencies.fetcher import get_fetcher
from app.dependencies.geoip import get_client_ip, get_geoip_helper
from app.dependencies.user import get_client_user, get_current_user
from app.fetcher import Fetcher
from app.models.beatmap import SearchQueryModel
from app.service.beatmap_download_service import BeatmapDownloadService
from .router import router
from fastapi import (
BackgroundTasks,
Depends,
Form,
HTTPException,
Path,
Query,
Request,
Security,
)
from fastapi.responses import RedirectResponse
from httpx import HTTPError
from sqlmodel import exists, select
async def _save_to_db(sets: SearchBeatmapsetsResp):
    """Persist every beatmapset of a search response that is not stored yet.

    Opens its own session through ``with_db()`` and, for each entry in
    ``sets.beatmapsets``, creates a row via ``Beatmapset.from_resp`` only
    when no ``Beatmapset`` with the same ID already exists.
    """
    async with with_db() as session:
        for fetched in sets.beatmapsets:
            presence_query = select(exists()).where(Beatmapset.id == fetched.id)
            already_stored = (await session.exec(presence_query)).first()
            if already_stored:
                continue
            await Beatmapset.from_resp(session, fetched)
# Matches a *complete* query key of the form ``cursor[field]``.
_CURSOR_KEY_RE = re.compile(r"cursor\[(\w+)\]")
# Cursor fields that are coerced to int only (never to float).
_INT_ONLY_CURSOR_FIELDS = frozenset({"approved_date", "id"})


def _coerce_cursor_value(field_name: str, raw: str) -> int | float | str:
    """Convert a raw cursor value to int, then float, else keep the string."""
    try:
        return int(raw)
    except ValueError:
        pass
    if field_name in _INT_ONLY_CURSOR_FIELDS:
        # These fields never fall back to float; keep the raw text instead.
        return raw
    try:
        return float(raw)
    except ValueError:
        return raw


def _parse_cursor_params(query_string: str) -> dict[str, int | float | str]:
    """Collect ``cursor[field]=value`` pairs from a raw URL query string."""
    cursor: dict[str, int | float | str] = {}
    parsed = parse_qs(qs=query_string, keep_blank_values=True)
    for key, values in parsed.items():
        # fullmatch (not match) so keys like ``cursor[id]junk`` are rejected.
        matched = _CURSOR_KEY_RE.fullmatch(key)
        if matched is None or not values:
            continue
        field_name = matched.group(1)
        # Only the first occurrence of a repeated key is used.
        cursor[field_name] = _coerce_cursor_value(field_name, values[0])
    return cursor


@router.get(
    "/beatmapsets/search",
    name="搜索谱面集",
    tags=["谱面集"],
    response_model=SearchBeatmapsetsResp,
)
async def search_beatmapset(
    db: Database,
    query: Annotated[SearchQueryModel, Query(...)],
    request: Request,
    background_tasks: BackgroundTasks,
    current_user: User = Security(get_current_user, scopes=["public"]),
    fetcher: Fetcher = Depends(get_fetcher),
    redis=Depends(get_redis),
):
    # Search beatmapsets through the fetcher and return its response.
    #
    # NOTE: documented with comments rather than a docstring on purpose: the
    # decorator sets no ``description=``, so a docstring would become the
    # description published in the OpenAPI schema.
    #
    # Pagination arrives as ``cursor[field]=value`` query parameters, which
    # are read from the raw query string instead of the query model.
    cursor = _parse_cursor_params(request.url.query)

    # These filters are answered with an empty result instead of being
    # forwarded to the fetcher.
    needs_local_search = (
        "recommended" in query.c
        or len(query.r) > 0
        or query.played
        or "follows" in query.c
        or "mine" in query.s
        or "favourites" in query.s
    )
    if needs_local_search:
        # TODO: search locally
        return SearchBeatmapsetsResp(total=0, beatmapsets=[])

    try:
        sets = await fetcher.search_beatmapset(query, cursor, redis)
    except HTTPError as e:
        # Chain the cause so the original fetch failure stays in the traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e
    # Store newly seen beatmapsets after the response has been sent.
    background_tasks.add_task(_save_to_db, sets)
    return sets
@router.get(
    "/beatmapsets/lookup",
    tags=["谱面集"],
    name="查询谱面集 (通过谱面 ID)",
    response_model=BeatmapsetResp,
    description="通过谱面 ID 查询所属谱面集。",
)
async def lookup_beatmapset(
    db: Database,
    beatmap_id: int = Query(description="谱面 ID"),
    current_user: User = Security(get_current_user, scopes=["public"]),
    fetcher: Fetcher = Depends(get_fetcher),
):
    """Return the beatmapset that the beatmap with ``beatmap_id`` belongs to.

    The beatmap is resolved with ``Beatmap.get_or_fetch`` and its
    ``beatmapset`` is serialised for ``current_user``.
    """
    found_beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
    parent_set = found_beatmap.beatmapset
    return await BeatmapsetResp.from_db(parent_set, session=db, user=current_user)
@router.get(
    # The ``:int`` convertor restricts this route to numeric IDs. Without it
    # the pattern also captures static sibling paths that are registered
    # later in this module (``/beatmapsets/download-status``), which then
    # fail integer validation and can never reach their own handler.
    # Non-numeric IDs therefore fall through to later routes (404 when none
    # matches) instead of being rejected here.
    "/beatmapsets/{beatmapset_id:int}",
    tags=["谱面集"],
    name="获取谱面集详情",
    response_model=BeatmapsetResp,
    description="获取单个谱面集详情。",
)
async def get_beatmapset(
    db: Database,
    beatmapset_id: int = Path(..., description="谱面集 ID"),
    current_user: User = Security(get_current_user, scopes=["public"]),
    fetcher: Fetcher = Depends(get_fetcher),
):
    """Return the details of a single beatmapset.

    The set is resolved with ``Beatmapset.get_or_fetch`` and serialised for
    ``current_user`` with ``recent_favourites`` included.

    Raises:
        HTTPException: 404 when an ``httpx.HTTPError`` occurs while loading
            or serialising the beatmapset.
    """
    try:
        beatmapset = await Beatmapset.get_or_fetch(db, fetcher, beatmapset_id)
        return await BeatmapsetResp.from_db(
            beatmapset, session=db, include=["recent_favourites"], user=current_user
        )
    except HTTPError as e:
        # Chain the cause so the underlying HTTP failure is not lost.
        raise HTTPException(status_code=404, detail="Beatmapset not found") from e
@router.get(
    "/beatmapsets/{beatmapset_id}/download",
    tags=["谱面集"],
    name="下载谱面集",
    description="**客户端专属**\n下载谱面集文件。基于请求IP地理位置智能分流支持负载均衡和自动故障转移。中国IP使用Sayobot镜像其他地区使用Nerinyan和OsuDirect镜像。",
)
async def download_beatmapset(
    request: Request,
    beatmapset_id: int = Path(..., description="谱面集 ID"),
    no_video: bool = Query(True, alias="noVideo", description="是否下载无视频版本"),
    current_user: User = Security(get_client_user),
    download_service: BeatmapDownloadService = Depends(get_beatmap_download_service),
):
    """Redirect the client to a mirror URL for the requested beatmapset.

    The mirror is chosen from the country resolved for the request IP; when
    the lookup yields no country code, the user's account country is used.
    If the download service raises ``HTTPException``, a hard-coded mirror
    URL is used instead.
    """
    requester_ip = get_client_ip(request)
    location = get_geoip_helper().lookup(requester_ip)
    iso_code = location.get("country_iso", "")

    # Prefer the IP-derived country; fall back to the account country only
    # when the GeoIP lookup produced nothing.
    if iso_code:
        is_china = iso_code == "CN"
    else:
        is_china = current_user.country_code == "CN"

    try:
        # Ask the load-balancing service for the download URL first.
        target_url = download_service.get_download_url(
            beatmapset_id=beatmapset_id, no_video=no_video, is_china=is_china
        )
    except HTTPException:
        # The load-balancing service failed: use the fixed mirrors.
        if is_china:
            target_url = (
                f"https://dl.sayobot.cn/beatmaps/download/"
                f"{'novideo' if no_video else 'full'}/{beatmapset_id}"
            )
        else:
            target_url = (
                f"https://api.nerinyan.moe/d/{beatmapset_id}?noVideo={no_video}"
            )
    return RedirectResponse(target_url)
@router.post(
    "/beatmapsets/{beatmapset_id}/favourites",
    tags=["谱面集"],
    name="收藏或取消收藏谱面集",
    description="**客户端专属**\n收藏或取消收藏指定谱面集。",
)
async def favourite_beatmapset(
    db: Database,
    beatmapset_id: int = Path(..., description="谱面集 ID"),
    action: Literal["favourite", "unfavourite"] = Form(
        description="操作类型favourite 收藏 / unfavourite 取消收藏"
    ),
    current_user: User = Security(get_client_user),
):
    """Add or remove the current user's favourite for a beatmapset.

    The request is a no-op (nothing is committed) when the favourite is
    already in the requested state.
    """
    assert current_user.id is not None

    favourite_query = select(FavouriteBeatmapset).where(
        FavouriteBeatmapset.user_id == current_user.id,
        FavouriteBeatmapset.beatmapset_id == beatmapset_id,
    )
    current_row = (await db.exec(favourite_query)).first()

    if action == "favourite":
        if current_row:
            # Already favourited.
            return
        db.add(
            FavouriteBeatmapset(user_id=current_user.id, beatmapset_id=beatmapset_id)
        )
    else:
        if action == "unfavourite" and not current_row:
            # Nothing to remove.
            return
        await db.delete(current_row)
    await db.commit()
# NOTE(review): this static path is registered after the parameterised
# "/beatmapsets/{beatmapset_id}" GET route in this module - confirm that
# requests for it actually reach this handler.
@router.get(
    "/beatmapsets/download-status",
    tags=["谱面集"],
    name="下载服务状态",
    description="获取谱面下载服务的健康状态和负载均衡信息。",
)
async def get_download_service_status(
    current_user: User = Security(get_current_user, scopes=["public"]),
    download_service: BeatmapDownloadService = Depends(get_beatmap_download_service),
):
    """Return the status reported by the beatmap download service."""
    service_status = download_service.get_service_status()
    return service_status