from datetime import datetime
import hashlib
from typing import TYPE_CHECKING

from app.calculator import get_calculator
from app.config import settings
from app.database.beatmap_tags import BeatmapTagVote
from app.database.failtime import FailTime, FailTimeResp
from app.models.beatmap import BeatmapRankStatus
from app.models.mods import APIMod
from app.models.performance import DifficultyAttributesUnion
from app.models.score import GameMode

from .beatmap_playcounts import BeatmapPlaycounts
from .beatmapset import Beatmapset, BeatmapsetResp

from pydantic import BaseModel, TypeAdapter
from redis.asyncio import Redis
from redis.exceptions import ResponseError
from sqlalchemy import Column, DateTime
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, col, exists, func, select
from sqlmodel.ext.asyncio.session import AsyncSession

if TYPE_CHECKING:
    from app.fetcher import Fetcher

    from .user import User


class BeatmapOwner(SQLModel):
    """Minimal (id, username) pair identifying a beatmap owner."""

    id: int
    username: str


class BeatmapBase(SQLModel):
    """Fields shared by the ``beatmaps`` table model and the ``BeatmapResp`` API model."""

    # Beatmap
    url: str
    mode: GameMode
    beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
    difficulty_rating: float = Field(default=0.0, index=True)
    total_length: int
    user_id: int = Field(index=True)
    version: str = Field(index=True)

    # optional
    checksum: str = Field(sa_column=Column(VARCHAR(32), index=True))
    current_user_playcount: int = Field(default=0)
    max_combo: int | None = Field(default=0)
    # TODO: failtimes, owners

    # BeatmapExtended
    ar: float = Field(default=0.0)
    cs: float = Field(default=0.0)
    drain: float = Field(default=0.0)  # hp
    accuracy: float = Field(default=0.0)  # od
    bpm: float = Field(default=0.0)
    count_circles: int = Field(default=0)
    count_sliders: int = Field(default=0)
    count_spinners: int = Field(default=0)
    deleted_at: datetime | None = Field(default=None, sa_column=Column(DateTime))
    hit_length: int = Field(default=0)
    last_updated: datetime = Field(sa_column=Column(DateTime, index=True))


class Beatmap(BeatmapBase, table=True):
    """Row of the ``beatmaps`` table, with its beatmapset and failtimes eagerly joined."""

    __tablename__: str = "beatmaps"
    id: int = Field(primary_key=True, index=True)
    beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
    beatmap_status: BeatmapRankStatus = Field(index=True)
    # optional
    beatmapset: Beatmapset = Relationship(back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"})
    failtimes: FailTime | None = Relationship(back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"})

    @classmethod
    async def from_resp_no_save(cls, _session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
        """Build an unsaved ``Beatmap`` from an API response.

        The nested ``beatmapset`` object is dropped (only ``beatmapset_id`` is kept), and
        ``resp.ranked`` is converted into ``beatmap_status``. Nothing is added to the session.
        """
        d = resp.model_dump()
        del d["beatmapset"]
        return cls.model_validate(
            {
                **d,
                "beatmapset_id": resp.beatmapset_id,
                "id": resp.id,
                "beatmap_status": BeatmapRankStatus(resp.ranked),
            }
        )

    @classmethod
    async def from_resp(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
        """Insert the beatmap described by ``resp`` unless its id already exists, then load it.

        Always commits the session and returns the row as stored in the database.
        """
        beatmap = await cls.from_resp_no_save(session, resp)
        if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
            session.add(beatmap)
            await session.commit()
        return (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one()

    @classmethod
    async def from_resp_batch(cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0) -> list["Beatmap"]:
        """Insert every beatmap in ``inp`` whose id is not stored yet, then commit.

        Args:
            session: Database session; it is committed even when nothing new is added.
            inp: API responses to persist.
            from_: Beatmap id to skip. The caller is presumably persisting that one separately.

        Returns:
            The model built for each non-skipped response, in input order. Beatmaps that already
            existed are returned as freshly built, unsaved models, not as the stored rows.
        """
        wanted_ids = [resp.id for resp in inp if resp.id != from_]
        known_ids: set[int] = set()
        if wanted_ids:
            # One IN query instead of one EXISTS round trip per beatmap.
            known_ids = set((await session.exec(select(Beatmap.id).where(col(Beatmap.id).in_(wanted_ids)))).all())

        beatmaps = []
        for resp in inp:
            if resp.id == from_:
                continue
            beatmap = await cls.from_resp_no_save(session, resp)
            if resp.id not in known_ids:
                session.add(beatmap)
                # Guard against the same id appearing twice in ``inp``; a second add would violate the PK.
                known_ids.add(resp.id)
            beatmaps.append(beatmap)
        await session.commit()
        return beatmaps

    @classmethod
    async def get_or_fetch(
        cls,
        session: AsyncSession,
        fetcher: "Fetcher",
        bid: int | None = None,
        md5: str | None = None,
    ) -> "Beatmap":
        """Load a beatmap by id (preferred) or md5 checksum, fetching it through ``fetcher`` on a miss.

        On a miss, the parent beatmapset is fetched and stored first if it is not in the database yet.

        Raises:
            ValueError: if neither ``bid`` nor ``md5`` is given.
        """
        stmt = select(Beatmap)
        if bid is not None:
            stmt = stmt.where(Beatmap.id == bid)
        elif md5 is not None:
            stmt = stmt.where(Beatmap.checksum == md5)
        else:
            raise ValueError("Either bid or md5 must be provided")
        beatmap = (await session.exec(stmt)).first()
        if beatmap:
            return beatmap

        resp = await fetcher.get_beatmap(bid, md5)
        r = await session.exec(select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id))
        if not r.first():
            set_resp = await fetcher.get_beatmapset(resp.beatmapset_id)
            # Skip this beatmap in the set import; it is inserted by from_resp just below.
            await Beatmapset.from_resp(session, set_resp, from_=resp.id)
        return await Beatmap.from_resp(session, resp)


class APIBeatmapTag(BaseModel):
    """A tag id and the number of votes it received on a beatmap."""

    tag_id: int
    count: int


class BeatmapResp(BeatmapBase):
    """API representation of a beatmap."""

    id: int
    beatmapset_id: int
    beatmapset: BeatmapsetResp | None = None
    convert: bool = False
    is_scoreable: bool
    status: str
    mode_int: int
    ranked: int
    url: str = ""
    playcount: int = 0
    passcount: int = 0
    failtimes: FailTimeResp | None = None
    top_tag_ids: list[APIBeatmapTag] | None = None
    current_user_tag_ids: list[int] | None = None
    is_deleted: bool = False

    @classmethod
    async def from_db(
        cls,
        beatmap: Beatmap,
        query_mode: GameMode | None = None,
        from_set: bool = False,
        session: AsyncSession | None = None,
        user: "User | None" = None,
    ) -> "BeatmapResp":
        """Build the API response for a stored beatmap.

        Args:
            beatmap: The stored row; its ``beatmapset`` and ``failtimes`` relationships are read.
            query_mode: Requested ruleset. If it differs from ``beatmap.mode``, the response is
                marked ``convert=True``.
            from_set: When True, the nested ``beatmapset`` response is omitted. Presumably this is
                used when the caller is already rendering the set.
            session: When given, playcount, passcount and tag statistics are queried.
            user: When given together with ``session``, this user's tag votes are included.
        """
        from .score import Score

        beatmap_ = beatmap.model_dump()
        beatmap_status = beatmap.beatmap_status
        if query_mode is not None and beatmap.mode != query_mode:
            beatmap_["convert"] = True
        beatmap_["is_scoreable"] = beatmap_status.has_leaderboard()
        if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
            # Server-wide setting: report leaderboard-less beatmaps as APPROVED.
            beatmap_["ranked"] = BeatmapRankStatus.APPROVED.value
            beatmap_["status"] = BeatmapRankStatus.APPROVED.name.lower()
        else:
            beatmap_["status"] = beatmap_status.name.lower()
            beatmap_["ranked"] = beatmap_status.value
        beatmap_["mode_int"] = int(beatmap.mode)
        beatmap_["is_deleted"] = beatmap.deleted_at is not None
        if not from_set:
            beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset, session=session, user=user)
        if beatmap.failtimes is not None:
            beatmap_["failtimes"] = FailTimeResp.from_db(beatmap.failtimes)
        else:
            beatmap_["failtimes"] = FailTimeResp()

        if session is not None:
            beatmap_["playcount"] = (
                await session.exec(
                    select(func.sum(BeatmapPlaycounts.playcount)).where(BeatmapPlaycounts.beatmap_id == beatmap.id)
                )
            ).first() or 0
            beatmap_["passcount"] = (
                await session.exec(
                    select(func.count())
                    .select_from(Score)
                    .where(
                        Score.beatmap_id == beatmap.id,
                        col(Score.passed).is_(True),
                    )
                )
            ).one()

            # Only tags whose vote count exceeds settings.beatmap_tag_top_count are reported.
            all_votes = (
                await session.exec(
                    select(BeatmapTagVote.tag_id, func.count().label("vote_count"))
                    .where(BeatmapTagVote.beatmap_id == beatmap.id)
                    .group_by(col(BeatmapTagVote.tag_id))
                    .having(func.count() > settings.beatmap_tag_top_count)
                )
            ).all()
            beatmap_["top_tag_ids"] = sorted(
                ({"tag_id": tag_id, "count": votes} for tag_id, votes in all_votes),
                key=lambda tag: tag["count"],
                reverse=True,
            )

            if user is not None:
                beatmap_["current_user_tag_ids"] = (
                    await session.exec(
                        select(BeatmapTagVote.tag_id)
                        .where(BeatmapTagVote.beatmap_id == beatmap.id)
                        .where(BeatmapTagVote.user_id == user.id)
                    )
                ).all()
            else:
                beatmap_["current_user_tag_ids"] = []
        return cls.model_validate(beatmap_)


class BannedBeatmaps(SQLModel, table=True):
    """Row of the ``banned_beatmaps`` table, listing beatmap ids."""

    __tablename__: str = "banned_beatmaps"
    id: int | None = Field(primary_key=True, index=True, default=None)
    beatmap_id: int = Field(index=True)


# Building a TypeAdapter is comparatively expensive, so build it once rather than on every cache hit.
_DIFFICULTY_ATTRIBUTES_ADAPTER = TypeAdapter(DifficultyAttributesUnion)


async def calculate_beatmap_attributes(
    beatmap_id: int,
    ruleset: GameMode,
    mods_: list[APIMod],
    redis: Redis,
    fetcher: "Fetcher",
) -> DifficultyAttributesUnion:
    """Return difficulty attributes for a beatmap/ruleset/mods combination, cached in Redis.

    The cache key is ``beatmap:{id}:{ruleset}:{sha256(str(mods_))}:attributes``. On a miss, the raw
    beatmap is obtained through ``fetcher`` and the configured calculator computes the attributes.
    The result is stored with no expiry set here.
    """
    key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.sha256(str(mods_).encode()).hexdigest()}:attributes"
    # Single GET instead of EXISTS + GET: one round trip, and no failure if the key
    # disappears between the two commands.
    cached = await redis.get(key)
    if cached is not None:
        return _DIFFICULTY_ATTRIBUTES_ADAPTER.validate_json(cached)
    resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
    attr = await get_calculator().calculate_difficulty(resp, mods_, ruleset)
    await redis.set(key, attr.model_dump_json())
    return attr


_RAW_CACHE_DELETE_BATCH_SIZE = 50


async def _unlink_or_delete(redis: Redis, keys: list) -> int:
    """Remove ``keys`` with UNLINK, falling back to DEL if the server rejects UNLINK; return the count removed."""
    if not keys:
        return 0
    try:
        # UNLINK reclaims memory in the background, so it does not block the server like DEL can.
        return await redis.unlink(*keys)
    except ResponseError:
        # Servers that do not support UNLINK answer with an error reply; use plain DEL instead.
        return await redis.delete(*keys)


async def clear_cached_beatmap_raws(redis: Redis, beatmaps: list[int] | None = None) -> int:
    """Clear cached raw beatmap data (``beatmap:{id}:raw`` keys) without blocking Redis.

    Args:
        redis: Redis client.
        beatmaps: Beatmap ids whose raw cache should be dropped. When omitted or empty, every
            ``beatmap:*:raw`` key is removed.

    Returns:
        The number of keys removed.
    """
    removed = 0
    if beatmaps:
        # Delete in batches so no single command has to remove too many keys at once.
        for start in range(0, len(beatmaps), _RAW_CACHE_DELETE_BATCH_SIZE):
            batch = beatmaps[start : start + _RAW_CACHE_DELETE_BATCH_SIZE]
            removed += await _unlink_or_delete(redis, [f"beatmap:{bid}:raw" for bid in batch])
        return removed

    # DEL/UNLINK take literal key names and do not expand globs, so enumerate the matching keys
    # incrementally with SCAN and remove them batch by batch.
    pending: list = []
    async for key in redis.scan_iter(match="beatmap:*:raw", count=_RAW_CACHE_DELETE_BATCH_SIZE):
        pending.append(key)
        if len(pending) >= _RAW_CACHE_DELETE_BATCH_SIZE:
            removed += await _unlink_or_delete(redis, pending)
            pending = []
    removed += await _unlink_or_delete(redis, pending)
    return removed