Files
g0v0-server/app/database/beatmap.py

279 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from datetime import datetime
import hashlib
from typing import TYPE_CHECKING
from app.calculator import get_calculator
from app.config import settings
from app.database.beatmap_tags import BeatmapTagVote
from app.database.failtime import FailTime, FailTimeResp
from app.models.beatmap import BeatmapRankStatus
from app.models.mods import APIMod
from app.models.performance import DifficultyAttributesUnion
from app.models.score import GameMode
from .beatmap_playcounts import BeatmapPlaycounts
from .beatmapset import Beatmapset, BeatmapsetResp
from pydantic import BaseModel, TypeAdapter
from redis.asyncio import Redis
from sqlalchemy import Column, DateTime
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, col, exists, func, select
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.fetcher import Fetcher
from .user import User
class BeatmapOwner(SQLModel):
    # Small non-table model pairing a numeric id with a username.
    # NOTE(review): presumably describes the user who owns a beatmap
    # difficulty; nothing else in this file uses it yet -- confirm.
    id: int
    username: str
class BeatmapBase(SQLModel):
    # Fields shared by the `Beatmap` table model and the `BeatmapResp`
    # response model, both of which subclass this base.

    # Beatmap
    url: str
    mode: GameMode
    beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
    difficulty_rating: float = Field(default=0.0, index=True)
    total_length: int
    user_id: int = Field(index=True)
    version: str = Field(index=True)
    # optional
    # 32-character indexed column; Beatmap.get_or_fetch matches its `md5`
    # argument against this value.
    checksum: str = Field(sa_column=Column(VARCHAR(32), index=True))
    current_user_playcount: int = Field(default=0)
    max_combo: int | None = Field(default=0)
    # TODO: failtimes, owners
    # BeatmapExtended
    ar: float = Field(default=0.0)
    cs: float = Field(default=0.0)
    drain: float = Field(default=0.0)  # hp
    accuracy: float = Field(default=0.0)  # od
    bpm: float = Field(default=0.0)
    count_circles: int = Field(default=0)
    count_sliders: int = Field(default=0)
    count_spinners: int = Field(default=0)
    # A non-None value is what BeatmapResp.from_db reports as `is_deleted`.
    deleted_at: datetime | None = Field(default=None, sa_column=Column(DateTime))
    hit_length: int = Field(default=0)
    last_updated: datetime = Field(sa_column=Column(DateTime, index=True))
class Beatmap(BeatmapBase, table=True):
    # ORM model for the "beatmaps" table. Both relationships are loaded with
    # lazy="joined", i.e. together with the beatmap row itself.
    __tablename__: str = "beatmaps"
    id: int = Field(primary_key=True, index=True)
    beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
    beatmap_status: BeatmapRankStatus = Field(index=True)
    # optional
    beatmapset: Beatmapset = Relationship(back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"})
    failtimes: FailTime | None = Relationship(back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"})

    @classmethod
    async def from_resp_no_save(cls, _session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
        """Build a ``Beatmap`` from an API response without touching the database.

        Args:
            _session: Unused; kept so the signature mirrors ``from_resp``.
            resp: Response object to convert.

        Returns:
            A new, unsaved ``Beatmap`` whose ``beatmap_status`` is derived
            from ``resp.ranked``.
        """
        payload = resp.model_dump()
        # The nested beatmapset response is not a column of this table.
        del payload["beatmapset"]
        return cls.model_validate(
            {
                **payload,
                "beatmapset_id": resp.beatmapset_id,
                "id": resp.id,
                "beatmap_status": BeatmapRankStatus(resp.ranked),
            }
        )

    @classmethod
    async def from_resp(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
        """Persist ``resp`` if it is not stored yet and return the stored row.

        An already existing row with the same id is left untouched. The
        session is committed in either case and the row is then re-selected,
        so the returned object is always the database-backed instance.
        """
        beatmap = await cls.from_resp_no_save(session, resp)
        if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
            session.add(beatmap)
        await session.commit()
        return (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one()

    @classmethod
    async def from_resp_batch(cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0) -> list["Beatmap"]:
        """Persist every not-yet-stored beatmap from ``inp`` in one commit.

        Args:
            session: Session used for the lookup, the inserts and the commit.
            inp: Responses to convert and store.
            from_: Beatmap id to skip entirely (``get_or_fetch`` passes the id
                of the beatmap it saves separately afterwards).

        Returns:
            The ``Beatmap`` objects that were newly added to the session.
            Responses whose id already exists in the table are not returned.
        """
        candidates = [resp for resp in inp if resp.id != from_]

        # One IN query instead of one EXISTS query per beatmap.
        known_ids: set[int] = set()
        if candidates:
            wanted_ids = [resp.id for resp in candidates]
            known_ids.update((await session.exec(select(Beatmap.id).where(col(Beatmap.id).in_(wanted_ids)))).all())

        beatmaps: list[Beatmap] = []
        for resp in candidates:
            if resp.id in known_ids:
                continue
            # Share the conversion with from_resp so the two paths cannot drift.
            beatmap = await cls.from_resp_no_save(session, resp)
            session.add(beatmap)
            beatmaps.append(beatmap)
            # Guards against the same id appearing twice in `inp`.
            known_ids.add(resp.id)

        await session.commit()
        return beatmaps

    @classmethod
    async def get_or_fetch(
        cls,
        session: AsyncSession,
        fetcher: "Fetcher",
        bid: int | None = None,
        md5: str | None = None,
    ) -> "Beatmap":
        """Return a beatmap from the database, fetching and storing it on a miss.

        The lookup uses ``bid`` when given, otherwise ``md5`` (matched against
        ``checksum``).

        Args:
            session: Session used for lookups and inserts.
            fetcher: Source of beatmap / beatmapset data when nothing is stored.
            bid: Beatmap id to look up.
            md5: Beatmap checksum to look up; only used when ``bid`` is None.

        Returns:
            The stored ``Beatmap``.

        Raises:
            ValueError: If neither ``bid`` nor ``md5`` is provided.
        """
        stmt = select(Beatmap)
        if bid is not None:
            stmt = stmt.where(Beatmap.id == bid)
        elif md5 is not None:
            stmt = stmt.where(Beatmap.checksum == md5)
        else:
            raise ValueError("Either bid or md5 must be provided")

        beatmap = (await session.exec(stmt)).first()
        if beatmap is not None:
            return beatmap

        resp = await fetcher.get_beatmap(bid, md5)
        # The beatmaps table has a foreign key to beatmapsets, so make sure
        # the parent set is stored before inserting the beatmap itself.
        set_row = await session.exec(select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id))
        if not set_row.first():
            set_resp = await fetcher.get_beatmapset(resp.beatmapset_id)
            # NOTE(review): from_ presumably tells Beatmapset.from_resp to skip
            # this beatmap because it is saved right below -- confirm.
            await Beatmapset.from_resp(session, set_resp, from_=resp.id)
        return await Beatmap.from_resp(session, resp)
class APIBeatmapTag(BaseModel):
    # One entry of BeatmapResp.top_tag_ids: a tag id and its vote count.
    tag_id: int
    count: int
class BeatmapResp(BeatmapBase):
    # API representation of a beatmap: the shared BeatmapBase fields plus
    # values derived in from_db (status, counters, tags, ...).
    id: int
    beatmapset_id: int
    beatmapset: BeatmapsetResp | None = None
    convert: bool = False
    is_scoreable: bool
    status: str
    mode_int: int
    ranked: int
    url: str = ""
    playcount: int = 0
    passcount: int = 0
    failtimes: FailTimeResp | None = None
    top_tag_ids: list[APIBeatmapTag] | None = None
    current_user_tag_ids: list[int] | None = None
    is_deleted: bool = False

    @classmethod
    async def from_db(
        cls,
        beatmap: Beatmap,
        query_mode: GameMode | None = None,
        from_set: bool = False,
        session: AsyncSession | None = None,
        user: "User | None" = None,
    ) -> "BeatmapResp":
        """Build the API response for a stored beatmap.

        Args:
            beatmap: Database row to convert.
            query_mode: When given and different from the beatmap's own mode,
                the response is flagged with ``convert=True``.
            from_set: When true, the nested ``beatmapset`` is left out.
            session: When given, play/pass counts and tag votes are queried;
                without it those fields keep their defaults.
            user: When given together with ``session``, the tag ids this user
                voted for are included.

        Returns:
            The validated ``BeatmapResp``.
        """
        # Imported here rather than at module level.
        # NOTE(review): presumably to avoid a circular import -- confirm.
        from .score import Score

        beatmap_ = beatmap.model_dump()
        beatmap_status = beatmap.beatmap_status
        if query_mode is not None and beatmap.mode != query_mode:
            beatmap_["convert"] = True

        # is_scoreable always reflects the real status, even when the
        # reported status is overridden below.
        beatmap_["is_scoreable"] = beatmap_status.has_leaderboard()
        if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
            # Report maps without a leaderboard as APPROVED.
            beatmap_["ranked"] = BeatmapRankStatus.APPROVED.value
            beatmap_["status"] = BeatmapRankStatus.APPROVED.name.lower()
        else:
            beatmap_["status"] = beatmap_status.name.lower()
            beatmap_["ranked"] = beatmap_status.value
        beatmap_["mode_int"] = int(beatmap.mode)
        beatmap_["is_deleted"] = beatmap.deleted_at is not None

        if not from_set:
            beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset, session=session, user=user)
        if beatmap.failtimes is not None:
            beatmap_["failtimes"] = FailTimeResp.from_db(beatmap.failtimes)
        else:
            beatmap_["failtimes"] = FailTimeResp()

        if session is not None:
            # SUM() yields NULL when there are no rows, hence the `or 0`.
            beatmap_["playcount"] = (
                await session.exec(
                    select(func.sum(BeatmapPlaycounts.playcount)).where(BeatmapPlaycounts.beatmap_id == beatmap.id)
                )
            ).first() or 0
            beatmap_["passcount"] = (
                await session.exec(
                    select(func.count())
                    .select_from(Score)
                    .where(
                        Score.beatmap_id == beatmap.id,
                        col(Score.passed).is_(True),
                    )
                )
            ).one()

            # Per-tag vote totals; only tags with more votes than
            # settings.beatmap_tag_top_count are kept.
            tag_votes = (
                await session.exec(
                    select(BeatmapTagVote.tag_id, func.count().label("vote_count"))
                    .where(BeatmapTagVote.beatmap_id == beatmap.id)
                    .group_by(col(BeatmapTagVote.tag_id))
                    .having(func.count() > settings.beatmap_tag_top_count)
                )
            ).all()
            # Most-voted tags first; ties keep the order returned by the query.
            beatmap_["top_tag_ids"] = sorted(
                ({"tag_id": tag_id, "count": vote_count} for tag_id, vote_count in tag_votes),
                key=lambda entry: entry["count"],
                reverse=True,
            )

            if user is not None:
                beatmap_["current_user_tag_ids"] = (
                    await session.exec(
                        select(BeatmapTagVote.tag_id)
                        .where(BeatmapTagVote.beatmap_id == beatmap.id)
                        .where(BeatmapTagVote.user_id == user.id)
                    )
                ).all()
            else:
                beatmap_["current_user_tag_ids"] = []
        return cls.model_validate(beatmap_)
class BannedBeatmaps(SQLModel, table=True):
    # ORM model for the "banned_beatmaps" table: one row per listed beatmap id.
    # NOTE(review): beatmap_id is indexed but declares no foreign key to
    # beatmaps.id -- confirm that is intentional.
    __tablename__: str = "banned_beatmaps"
    # Defaults to None so the database can assign the key on insert.
    id: int | None = Field(primary_key=True, index=True, default=None)
    beatmap_id: int = Field(index=True)
async def calculate_beatmap_attributes(
    beatmap_id: int,
    ruleset: GameMode,
    mods_: list[APIMod],
    redis: Redis,
    fetcher: "Fetcher",
) -> DifficultyAttributesUnion:
    """Return difficulty attributes for a beatmap, cached in Redis.

    The cache key is built from the beatmap id, the ruleset and a SHA-256
    digest of ``str(mods_)``.

    Args:
        beatmap_id: Beatmap whose attributes are requested.
        ruleset: Ruleset to calculate for.
        mods_: Mods to apply; part of the cache key.
        redis: Client used for the attribute cache (and handed to the fetcher).
        fetcher: Provides the raw beatmap on a cache miss.

    Returns:
        The cached attributes when present, otherwise freshly calculated ones
        (which are then stored under the cache key).
    """
    key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.sha256(str(mods_).encode()).hexdigest()}:attributes"

    # A single GET instead of EXISTS + GET: one round trip, and no window in
    # which the key can vanish between the check and the read.
    cached = await redis.get(key)
    if cached is not None:
        return TypeAdapter(DifficultyAttributesUnion).validate_json(cached)

    resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
    attr = await get_calculator().calculate_difficulty(resp, mods_, ruleset)
    await redis.set(key, attr.model_dump_json())
    return attr
async def clear_cached_beatmap_raws(redis: Redis, beatmaps: list[int] = []):
"""清理缓存的 beatmap 原始数据,使用非阻塞方式"""
if beatmaps:
# 分批删除,避免一次删除太多 key 导致阻塞
batch_size = 50
for i in range(0, len(beatmaps), batch_size):
batch = beatmaps[i : i + batch_size]
keys = [f"beatmap:{bid}:raw" for bid in batch]
# 使用 unlink 而不是 delete非阻塞更快
try:
await redis.unlink(*keys)
except Exception:
# 如果 unlink 不支持,回退到 delete
await redis.delete(*keys)
return
await redis.delete("beatmap:*:raw")