"""Beatmapset API routes: search, lookup, detail, download, favourites and download status."""

from __future__ import annotations

import re
from typing import Annotated, Literal
from urllib.parse import parse_qs

from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
from app.database.beatmapset import SearchBeatmapsetsResp
from app.dependencies.beatmap_download import get_beatmap_download_service
from app.dependencies.database import Database, get_redis, with_db
from app.dependencies.fetcher import get_fetcher
from app.dependencies.geoip import get_client_ip, get_geoip_helper
from app.dependencies.user import get_client_user, get_current_user
from app.fetcher import Fetcher
from app.models.beatmap import SearchQueryModel
from app.service.beatmap_download_service import BeatmapDownloadService

from .router import router

from fastapi import (
    BackgroundTasks,
    Depends,
    Form,
    HTTPException,
    Path,
    Query,
    Request,
    Security,
)
from fastapi.responses import RedirectResponse
from httpx import HTTPError
from sqlmodel import exists, select

# Matches query-string keys of the form ``cursor[field]``; compiled once
# instead of on every loop iteration of every request.
_CURSOR_KEY_RE = re.compile(r"cursor\[(\w+)\]")

# Cursor fields that are only ever tried as integers (never as floats).
_INTEGER_CURSOR_FIELDS = frozenset({"approved_date", "id"})


def _coerce_cursor_value(field_name: str, raw: str) -> int | float | str:
    """Convert one raw cursor value to the most specific type that parses.

    Fields in ``_INTEGER_CURSOR_FIELDS`` are tried as ``int`` only; every
    other field is tried as ``int`` and then ``float``. If no conversion
    succeeds, the raw string is returned unchanged.
    """
    converters = (int,) if field_name in _INTEGER_CURSOR_FIELDS else (int, float)
    for convert in converters:
        try:
            return convert(raw)
        except ValueError:
            continue
    return raw


def _parse_cursor(params: dict[str, list[str]]) -> dict[str, int | float | str]:
    """Collect ``cursor[field]`` entries from parsed query parameters.

    Args:
        params: Mapping as returned by ``urllib.parse.parse_qs``.

    Returns:
        A dict keyed by the field name inside the brackets, holding the first
        value of each matching parameter after type coercion. Keys that do
        not start with ``cursor[field]`` and parameters with an empty value
        list are skipped.
    """
    cursor: dict[str, int | float | str] = {}
    for key, values in params.items():
        matched = _CURSOR_KEY_RE.match(key)
        if not matched or not values:
            continue
        field_name = matched.group(1)
        cursor[field_name] = _coerce_cursor_value(field_name, values[0])
    return cursor


async def _save_to_db(sets: SearchBeatmapsetsResp):
    """Pass search results that have no matching ``Beatmapset`` row to ``Beatmapset.from_resp``.

    Opens its own session through ``with_db()`` rather than reusing a
    request-scoped one; it is scheduled as a background task by
    ``search_beatmapset``.
    """
    async with with_db() as session:
        for s in sets.beatmapsets:
            already_stored = (await session.exec(select(exists()).where(Beatmapset.id == s.id))).first()
            if not already_stored:
                await Beatmapset.from_resp(session, s)


# NOTE: this handler is documented with comments instead of a docstring on
# purpose. The route declares no ``description=``, and FastAPI would publish a
# docstring as the OpenAPI description, changing the public API docs.
#
# Searches beatmapsets through the fetcher. ``cursor[field]`` query parameters
# are parsed from the raw query string and forwarded to the fetcher. Filters
# that are not handled here yet (recommended / follows / rank filters / played
# / mine / favourites) short-circuit to an empty result. Successful results
# are handed to ``_save_to_db`` in a background task. An ``httpx.HTTPError``
# from the fetcher is reported as HTTP 500.
@router.get(
    "/beatmapsets/search",
    name="搜索谱面集",
    tags=["谱面集"],
    response_model=SearchBeatmapsetsResp,
)
async def search_beatmapset(
    db: Database,
    query: Annotated[SearchQueryModel, Query(...)],
    request: Request,
    background_tasks: BackgroundTasks,
    current_user: User = Security(get_current_user, scopes=["public"]),
    fetcher: Fetcher = Depends(get_fetcher),
    redis=Depends(get_redis),
):
    # Parse parameters of the form cursor[field] from the raw query string.
    params = parse_qs(qs=request.url.query, keep_blank_values=True)
    cursor = _parse_cursor(params)

    if (
        "recommended" in query.c
        or len(query.r) > 0
        or query.played
        or "follows" in query.c
        or "mine" in query.s
        or "favourites" in query.s
    ):
        # TODO: search locally
        return SearchBeatmapsetsResp(total=0, beatmapsets=[])

    try:
        sets = await fetcher.search_beatmapset(query, cursor, redis)
    except HTTPError as e:
        # Chain the cause so the upstream failure stays visible in tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e

    background_tasks.add_task(_save_to_db, sets)
    return sets


@router.get(
    "/beatmapsets/lookup",
    tags=["谱面集"],
    name="查询谱面集 (通过谱面 ID)",
    response_model=BeatmapsetResp,
    description=("通过谱面 ID 查询所属谱面集。"),
)
async def lookup_beatmapset(
    db: Database,
    beatmap_id: int = Query(description="谱面 ID"),
    current_user: User = Security(get_current_user, scopes=["public"]),
    fetcher: Fetcher = Depends(get_fetcher),
):
    """Return the beatmapset that owns the beatmap with the given ID."""
    beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
    resp = await BeatmapsetResp.from_db(beatmap.beatmapset, session=db, user=current_user)
    return resp


# IMPORTANT: this fixed path must be registered BEFORE the parameterised
# "/beatmapsets/{beatmapset_id}" GET route below. FastAPI matches routes in
# registration order; registered later, "download-status" is captured as
# ``beatmapset_id``, fails integer validation, and this endpoint can never be
# reached.
@router.get(
    "/beatmapsets/download-status",
    tags=["谱面集"],
    name="下载服务状态",
    description="获取谱面下载服务的健康状态和负载均衡信息。",
)
async def get_download_service_status(
    current_user: User = Security(get_current_user, scopes=["public"]),
    download_service: BeatmapDownloadService = Depends(get_beatmap_download_service),
):
    """Return whatever ``download_service.get_service_status()`` reports."""
    return download_service.get_service_status()


@router.get(
    "/beatmapsets/{beatmapset_id}",
    tags=["谱面集"],
    name="获取谱面集详情",
    response_model=BeatmapsetResp,
    description="获取单个谱面集详情。",
)
async def get_beatmapset(
    db: Database,
    beatmapset_id: int = Path(..., description="谱面集 ID"),
    current_user: User = Security(get_current_user, scopes=["public"]),
    fetcher: Fetcher = Depends(get_fetcher),
):
    """Return one beatmapset, including its recent favourites.

    Raises:
        HTTPException: 404 when an ``httpx.HTTPError`` is raised while
            loading the beatmapset or building the response.
    """
    try:
        beatmapset = await Beatmapset.get_or_fetch(db, fetcher, beatmapset_id)
        return await BeatmapsetResp.from_db(beatmapset, session=db, include=["recent_favourites"], user=current_user)
    except HTTPError as e:
        raise HTTPException(status_code=404, detail="Beatmapset not found") from e


@router.get(
    "/beatmapsets/{beatmapset_id}/download",
    tags=["谱面集"],
    name="下载谱面集",
    description="**客户端专属**\n下载谱面集文件。基于请求IP地理位置智能分流,支持负载均衡和自动故障转移。中国IP使用Sayobot镜像,其他地区使用Nerinyan和OsuDirect镜像。",
)
async def download_beatmapset(
    request: Request,
    beatmapset_id: int = Path(..., description="谱面集 ID"),
    no_video: bool = Query(True, alias="noVideo", description="是否下载无视频版本"),
    current_user: User = Security(get_client_user),
    download_service: BeatmapDownloadService = Depends(get_beatmap_download_service),
):
    """Redirect the client to a download URL for the beatmapset.

    The mirror is chosen from the country of the request IP; when the GeoIP
    lookup yields no country code, the user's account country is used
    instead. If the download service raises ``HTTPException``, a hard-coded
    mirror URL is used as a fallback.
    """
    client_ip = get_client_ip(request)
    geoip_helper = get_geoip_helper()
    geo_info = geoip_helper.lookup(client_ip)
    country_code = geo_info.get("country_iso", "")

    # Prefer the IP geolocation; fall back to the account's country code only
    # when the lookup produced no country.
    is_china = country_code == "CN" or (not country_code and current_user.country_code == "CN")

    try:
        # Ask the load-balancing download service for the URL.
        download_url = download_service.get_download_url(
            beatmapset_id=beatmapset_id, no_video=no_video, is_china=is_china
        )
        return RedirectResponse(download_url)
    except HTTPException:
        # The download service failed: fall back to the fixed mirrors.
        if is_china:
            return RedirectResponse(
                f"https://dl.sayobot.cn/beatmaps/download/{'novideo' if no_video else 'full'}/{beatmapset_id}"
            )
        else:
            return RedirectResponse(f"https://api.nerinyan.moe/d/{beatmapset_id}?noVideo={no_video}")


@router.post(
    "/beatmapsets/{beatmapset_id}/favourites",
    tags=["谱面集"],
    name="收藏或取消收藏谱面集",
    description="**客户端专属**\n收藏或取消收藏指定谱面集。",
)
async def favourite_beatmapset(
    db: Database,
    beatmapset_id: int = Path(..., description="谱面集 ID"),
    action: Literal["favourite", "unfavourite"] = Form(description="操作类型:favourite 收藏 / unfavourite 取消收藏"),
    current_user: User = Security(get_client_user),
):
    """Add or remove the current user's favourite for a beatmapset.

    Requests that would not change anything (favouriting an existing
    favourite, or unfavouriting a missing one) return without touching the
    database.
    """
    existing_favourite = (
        await db.exec(
            select(FavouriteBeatmapset).where(
                FavouriteBeatmapset.user_id == current_user.id,
                FavouriteBeatmapset.beatmapset_id == beatmapset_id,
            )
        )
    ).first()

    # Nothing to do when the stored state already matches the requested one.
    if (action == "favourite" and existing_favourite) or (action == "unfavourite" and not existing_favourite):
        return

    if action == "favourite":
        favourite = FavouriteBeatmapset(user_id=current_user.id, beatmapset_id=beatmapset_id)
        db.add(favourite)
    else:
        await db.delete(existing_favourite)
    await db.commit()