279 lines
10 KiB
Python
279 lines
10 KiB
Python
from datetime import datetime
|
||
import hashlib
|
||
from typing import TYPE_CHECKING
|
||
|
||
from app.calculator import get_calculator
|
||
from app.config import settings
|
||
from app.database.beatmap_tags import BeatmapTagVote
|
||
from app.database.failtime import FailTime, FailTimeResp
|
||
from app.models.beatmap import BeatmapRankStatus
|
||
from app.models.mods import APIMod
|
||
from app.models.performance import DifficultyAttributesUnion
|
||
from app.models.score import GameMode
|
||
|
||
from .beatmap_playcounts import BeatmapPlaycounts
|
||
from .beatmapset import Beatmapset, BeatmapsetResp
|
||
|
||
from pydantic import BaseModel, TypeAdapter
|
||
from redis.asyncio import Redis
|
||
from sqlalchemy import Column, DateTime
|
||
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, col, exists, func, select
|
||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||
|
||
if TYPE_CHECKING:
|
||
from app.fetcher import Fetcher
|
||
|
||
from .user import User
|
||
|
||
|
||
class BeatmapOwner(SQLModel):
    """Minimal reference to a user who owns (maps) a beatmap.

    NOTE(review): not used by the code in this chunk; presumably intended for
    the "owners" part of the TODO in BeatmapBase — confirm before relying on it.
    """

    # User id of the owner.
    id: int
    # Display name of the owner.
    username: str
|
||
|
||
|
||
class BeatmapBase(SQLModel):
    """Fields shared by the ``beatmaps`` table model (``Beatmap``) and the API
    response model (``BeatmapResp``).

    The section comments suggest the fields mirror the osu! API ``Beatmap`` and
    ``BeatmapExtended`` objects (assumption based on naming — confirm against
    the API docs if exact parity matters).
    """

    # Beatmap
    url: str
    mode: GameMode
    # Parent beatmapset; indexed for "all difficulties of a set" lookups.
    beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
    difficulty_rating: float = Field(default=0.0, index=True)
    total_length: int
    user_id: int = Field(index=True)
    # Difficulty name.
    version: str = Field(index=True)

    # optional
    # 32-char checksum; Beatmap.get_or_fetch looks beatmaps up by comparing
    # this column with an md5 string (32 chars fits an MD5 hex digest).
    checksum: str = Field(sa_column=Column(VARCHAR(32), index=True))
    current_user_playcount: int = Field(default=0)
    max_combo: int | None = Field(default=0)
    # TODO: failtimes, owners

    # BeatmapExtended
    ar: float = Field(default=0.0)
    cs: float = Field(default=0.0)
    drain: float = Field(default=0.0)  # hp
    accuracy: float = Field(default=0.0)  # od
    bpm: float = Field(default=0.0)
    count_circles: int = Field(default=0)
    count_sliders: int = Field(default=0)
    count_spinners: int = Field(default=0)
    # A non-None value is reported as ``is_deleted=True`` by BeatmapResp.from_db.
    deleted_at: datetime | None = Field(default=None, sa_column=Column(DateTime))
    hit_length: int = Field(default=0)
    last_updated: datetime = Field(sa_column=Column(DateTime, index=True))
|
||
|
||
|
||
class Beatmap(BeatmapBase, table=True):
    """Persisted beatmap row (table ``beatmaps``).

    Adds the primary key, the rank status, and eagerly joined relationships to
    the parent beatmapset and the fail-time statistics on top of BeatmapBase.
    """

    __tablename__: str = "beatmaps"
    id: int = Field(primary_key=True, index=True)
    beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
    beatmap_status: BeatmapRankStatus = Field(index=True)
    # optional
    beatmapset: Beatmapset = Relationship(back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"})
    failtimes: FailTime | None = Relationship(back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"})

    @classmethod
    async def from_resp_no_save(cls, _session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
        """Build a transient ``Beatmap`` from an API response without touching the DB.

        Args:
            _session: Unused; kept so the signature matches ``from_resp``.
            resp: API response to convert.

        Returns:
            A new, not-yet-added ``Beatmap`` whose ``beatmap_status`` is derived
            from ``resp.ranked``.
        """
        d = resp.model_dump()
        # The nested beatmapset response is not a column of this table; only
        # ``beatmapset_id`` is kept.
        del d["beatmapset"]
        return cls.model_validate(
            {
                **d,
                "beatmapset_id": resp.beatmapset_id,
                "id": resp.id,
                "beatmap_status": BeatmapRankStatus(resp.ranked),
            }
        )

    @classmethod
    async def from_resp(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
        """Insert the beatmap described by ``resp`` if it is not stored yet.

        An already stored row is NOT updated from ``resp``.

        Returns:
            The row as loaded from the database after the (optional) insert.
        """
        beatmap = await cls.from_resp_no_save(session, resp)
        # Canonical EXISTS form: SELECT EXISTS (SELECT * FROM beatmaps WHERE id = ?)
        already_stored = (await session.exec(select(exists().where(Beatmap.id == resp.id)))).first()
        if not already_stored:
            session.add(beatmap)
            await session.commit()
        return (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one()

    @classmethod
    async def from_resp_batch(cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0) -> list["Beatmap"]:
        """Insert every beatmap in ``inp`` that is not stored yet, then commit.

        Args:
            session: Database session; committed once at the end.
            inp: API responses to convert.
            from_: Beatmap id to skip entirely (neither inserted nor returned).

        Returns:
            The transient ``Beatmap`` objects built from ``inp`` (minus ``from_``),
            including those that already existed in the database and were
            therefore not added.
        """
        candidates = [resp for resp in inp if resp.id != from_]
        wanted_ids = [resp.id for resp in candidates]

        # One IN query instead of one EXISTS query per response.
        stored_ids: set[int] = set()
        if wanted_ids:
            stored_ids = set((await session.exec(select(Beatmap.id).where(col(Beatmap.id).in_(wanted_ids)))).all())

        beatmaps: list[Beatmap] = []
        for resp in candidates:
            beatmap = await cls.from_resp_no_save(session, resp)
            if resp.id not in stored_ids:
                session.add(beatmap)
                # Track in-batch additions so a repeated id in ``inp`` is not
                # added twice (previously covered by autoflush before each
                # per-item EXISTS query).
                stored_ids.add(resp.id)
            beatmaps.append(beatmap)
        await session.commit()
        return beatmaps

    @classmethod
    async def get_or_fetch(
        cls,
        session: AsyncSession,
        fetcher: "Fetcher",
        bid: int | None = None,
        md5: str | None = None,
    ) -> "Beatmap":
        """Return a stored beatmap, fetching and storing it first if it is missing.

        Args:
            session: Database session.
            fetcher: Used to fetch the beatmap (and its beatmapset) when not stored.
            bid: Beatmap id; takes precedence over ``md5`` when both are given.
            md5: Beatmap checksum, compared with the ``checksum`` column.

        Raises:
            ValueError: If neither ``bid`` nor ``md5`` is provided.
        """
        if bid is not None:
            stmt = select(Beatmap).where(Beatmap.id == bid)
        elif md5 is not None:
            stmt = select(Beatmap).where(Beatmap.checksum == md5)
        else:
            raise ValueError("Either bid or md5 must be provided")

        beatmap = (await session.exec(stmt)).first()
        if beatmap:
            return beatmap

        resp = await fetcher.get_beatmap(bid, md5)
        set_stored = (await session.exec(select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id))).first()
        if not set_stored:
            # Store the parent set first; ``from_=resp.id`` presumably makes the
            # set import skip this beatmap, which is inserted just below.
            set_resp = await fetcher.get_beatmapset(resp.beatmapset_id)
            await Beatmapset.from_resp(session, set_resp, from_=resp.id)
        return await Beatmap.from_resp(session, resp)
|
||
|
||
|
||
class APIBeatmapTag(BaseModel):
    """A tag attached to a beatmap together with its vote count.

    BeatmapResp.from_db fills ``count`` from the number of BeatmapTagVote rows
    for the tag on that beatmap.
    """

    # Tag identifier (BeatmapTagVote.tag_id).
    tag_id: int
    # Number of votes the tag received on the beatmap.
    count: int
|
||
|
||
|
||
class BeatmapResp(BeatmapBase):
    """API response model for a beatmap, built from a ``Beatmap`` row via ``from_db``."""

    id: int
    beatmapset_id: int
    beatmapset: BeatmapsetResp | None = None
    convert: bool = False
    is_scoreable: bool
    status: str
    mode_int: int
    ranked: int
    url: str = ""
    playcount: int = 0
    passcount: int = 0
    failtimes: FailTimeResp | None = None
    top_tag_ids: list[APIBeatmapTag] | None = None
    current_user_tag_ids: list[int] | None = None
    is_deleted: bool = False

    @classmethod
    async def from_db(
        cls,
        beatmap: Beatmap,
        query_mode: GameMode | None = None,
        from_set: bool = False,
        session: AsyncSession | None = None,
        user: "User | None" = None,
    ) -> "BeatmapResp":
        """Build the API response for ``beatmap``.

        Args:
            beatmap: Stored beatmap row.
            query_mode: Ruleset being requested; if it differs from the beatmap's
                own mode, the response is flagged as ``convert``.
            from_set: When True, the parent ``beatmapset`` is not embedded
                (presumably because the caller is already building the set).
            session: When given, playcount, passcount, top tags and (if ``user``
                is given) the user's own tag votes are loaded from the database.
            user: Viewing user; forwarded to ``BeatmapsetResp.from_db`` and used
                for ``current_user_tag_ids``.

        Notes:
            With ``settings.enable_all_beatmap_leaderboard`` enabled, beatmaps
            without a leaderboard are reported as APPROVED in ``ranked`` and
            ``status``; ``is_scoreable`` still reflects the real status.
        """
        # Local import, presumably to avoid a circular import with the score module.
        from .score import Score

        beatmap_ = beatmap.model_dump()
        beatmap_status = beatmap.beatmap_status
        if query_mode is not None and beatmap.mode != query_mode:
            beatmap_["convert"] = True
        beatmap_["is_scoreable"] = beatmap_status.has_leaderboard()
        if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
            beatmap_["ranked"] = BeatmapRankStatus.APPROVED.value
            beatmap_["status"] = BeatmapRankStatus.APPROVED.name.lower()
        else:
            beatmap_["status"] = beatmap_status.name.lower()
            beatmap_["ranked"] = beatmap_status.value
        beatmap_["mode_int"] = int(beatmap.mode)
        beatmap_["is_deleted"] = beatmap.deleted_at is not None
        if not from_set:
            beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset, session=session, user=user)
        if beatmap.failtimes is not None:
            beatmap_["failtimes"] = FailTimeResp.from_db(beatmap.failtimes)
        else:
            beatmap_["failtimes"] = FailTimeResp()

        if session is not None:
            # Total plays across all BeatmapPlaycounts rows (SUM is NULL when none).
            beatmap_["playcount"] = (
                await session.exec(
                    select(func.sum(BeatmapPlaycounts.playcount)).where(BeatmapPlaycounts.beatmap_id == beatmap.id)
                )
            ).first() or 0
            beatmap_["passcount"] = (
                await session.exec(
                    select(func.count())
                    .select_from(Score)
                    .where(
                        Score.beatmap_id == beatmap.id,
                        col(Score.passed).is_(True),
                    )
                )
            ).one()

            # Only tags with strictly more than ``beatmap_tag_top_count`` votes.
            all_votes = (
                await session.exec(
                    select(BeatmapTagVote.tag_id, func.count().label("vote_count"))
                    .where(BeatmapTagVote.beatmap_id == beatmap.id)
                    .group_by(col(BeatmapTagVote.tag_id))
                    .having(func.count() > settings.beatmap_tag_top_count)
                )
            ).all()
            # Most-voted first; sorted() is stable like list.sort().
            beatmap_["top_tag_ids"] = sorted(
                ({"tag_id": tag_id, "count": votes} for tag_id, votes in all_votes),
                key=lambda tag: tag["count"],
                reverse=True,
            )

            if user is not None:
                beatmap_["current_user_tag_ids"] = (
                    await session.exec(
                        select(BeatmapTagVote.tag_id)
                        .where(BeatmapTagVote.beatmap_id == beatmap.id)
                        .where(BeatmapTagVote.user_id == user.id)
                    )
                ).all()
            else:
                beatmap_["current_user_tag_ids"] = []
        return cls.model_validate(beatmap_)
|
||
|
||
|
||
class BannedBeatmaps(SQLModel, table=True):
    """Table ``banned_beatmaps``: one row per banned beatmap id."""

    __tablename__: str = "banned_beatmaps"
    # Surrogate key; None until the row is inserted.
    id: int | None = Field(primary_key=True, index=True, default=None)
    # Presumably references beatmaps.id — note no foreign key is declared,
    # so ids of beatmaps not (yet) stored locally can be banned. Confirm intent.
    beatmap_id: int = Field(index=True)
|
||
|
||
|
||
async def calculate_beatmap_attributes(
    beatmap_id: int,
    ruleset: GameMode,
    mods_: list[APIMod],
    redis: Redis,
    fetcher: "Fetcher",
) -> DifficultyAttributesUnion:
    """Return difficulty attributes for a beatmap/ruleset/mods combination.

    Results are cached in Redis (no expiry is set here) under a key containing
    a SHA-256 of ``str(mods_)``; mod lists that differ only in order or
    formatting therefore get separate cache entries.

    Args:
        beatmap_id: Beatmap to evaluate.
        ruleset: Ruleset to calculate for.
        mods_: Mods applied to the calculation.
        redis: Cache store; also passed to the fetcher for the raw beatmap.
        fetcher: Provides the raw beatmap data on a cache miss.
    """
    key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.sha256(str(mods_).encode()).hexdigest()}:attributes"
    # A single GET instead of EXISTS + GET: one round-trip fewer, and no window
    # in which the key can expire/be evicted between the two calls (which
    # would have passed None to validate_json).
    cached = await redis.get(key)
    if cached is not None:
        return TypeAdapter(DifficultyAttributesUnion).validate_json(cached)

    resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
    attr = await get_calculator().calculate_difficulty(resp, mods_, ruleset)
    await redis.set(key, attr.model_dump_json())
    return attr
|
||
|
||
|
||
async def clear_cached_beatmap_raws(redis: Redis, beatmaps: list[int] | None = None):
    """Delete cached raw beatmap data (keys ``beatmap:{id}:raw``) without blocking Redis.

    Args:
        redis: Redis client holding the cache.
        beatmaps: Beatmap ids to clear. When None or empty, every key matching
            ``beatmap:*:raw`` is cleared.

    Keys are removed in batches of 50 with UNLINK (memory is reclaimed in the
    background); if UNLINK fails (e.g. a server without the command), the batch
    is retried with DEL.
    """
    batch_size = 50

    async def _drop(keys: list) -> None:
        """Remove one batch of keys, preferring UNLINK over DEL."""
        if not keys:
            return
        try:
            await redis.unlink(*keys)
        except Exception:
            # Fallback for servers where UNLINK is unavailable.
            await redis.delete(*keys)

    if beatmaps:
        # Split into batches so a single command does not carry too many keys.
        for start in range(0, len(beatmaps), batch_size):
            await _drop([f"beatmap:{bid}:raw" for bid in beatmaps[start : start + batch_size]])
        return

    # DEL does not expand glob patterns, so the matching keys have to be
    # enumerated with SCAN (incremental, unlike KEYS) and deleted explicitly.
    pending: list = []
    async for key in redis.scan_iter(match="beatmap:*:raw", count=500):
        pending.append(key)
        if len(pending) >= batch_size:
            await _drop(pending)
            pending = []
    await _drop(pending)
|