refactor(database): use a new 'On-Demand' design (#86)

Technical Details: https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/
This commit is contained in:
MingxuanGame
2025-11-23 21:41:02 +08:00
committed by GitHub
parent 42f1d53d3e
commit 40da994ae8
46 changed files with 4396 additions and 2354 deletions

View File

@@ -2,23 +2,31 @@ from .achievement import UserAchievement, UserAchievementResp
from .auth import OAuthClient, OAuthToken, TotpKeys, V1APIKeys
from .beatmap import (
Beatmap,
BeatmapResp,
BeatmapDict,
BeatmapModel,
)
from .beatmap_playcounts import (
BeatmapPlaycounts,
BeatmapPlaycountsDict,
BeatmapPlaycountsModel,
)
from .beatmap_playcounts import BeatmapPlaycounts, BeatmapPlaycountsResp
from .beatmap_sync import BeatmapSync
from .beatmap_tags import BeatmapTagVote
from .beatmapset import (
Beatmapset,
BeatmapsetResp,
BeatmapsetDict,
BeatmapsetModel,
)
from .beatmapset_ratings import BeatmapRating
from .best_scores import BestScore
from .chat import (
ChannelType,
ChatChannel,
ChatChannelResp,
ChatChannelDict,
ChatChannelModel,
ChatMessage,
ChatMessageResp,
ChatMessageDict,
ChatMessageModel,
)
from .counts import (
CountResp,
@@ -30,8 +38,8 @@ from .events import Event
from .favourite_beatmapset import FavouriteBeatmapset
from .item_attempts_count import (
ItemAttemptsCount,
ItemAttemptsResp,
PlaylistAggregateScore,
ItemAttemptsCountDict,
ItemAttemptsCountModel,
)
from .matchmaking import (
MatchmakingPool,
@@ -42,30 +50,32 @@ from .multiplayer_event import MultiplayerEvent, MultiplayerEventResp
from .notification import Notification, UserNotification
from .password_reset import PasswordReset
from .playlist_best_score import PlaylistBestScore
from .playlists import Playlist, PlaylistResp
from .playlists import Playlist, PlaylistDict, PlaylistModel
from .rank_history import RankHistory, RankHistoryResp, RankTop
from .relationship import Relationship, RelationshipResp, RelationshipType
from .room import APIUploadedRoom, Room, RoomResp
from .relationship import Relationship, RelationshipDict, RelationshipModel, RelationshipType
from .room import APIUploadedRoom, Room, RoomDict, RoomModel
from .room_participated_user import RoomParticipatedUser
from .score import (
MultiplayerScores,
Score,
ScoreAround,
ScoreBase,
ScoreResp,
ScoreDict,
ScoreModel,
ScoreStatistics,
)
from .score_token import ScoreToken, ScoreTokenResp
from .search_beatmapset import SearchBeatmapsetsResp
from .statistics import (
UserStatistics,
UserStatisticsResp,
UserStatisticsDict,
UserStatisticsModel,
)
from .team import Team, TeamMember, TeamRequest, TeamResp
from .total_score_best_scores import TotalScoreBestScore
from .user import (
MeResp,
User,
UserResp,
UserDict,
UserModel,
)
from .user_account_history import (
UserAccountHistory,
@@ -79,20 +89,25 @@ from .verification import EmailVerification, LoginSession, LoginSessionResp, Tru
__all__ = [
"APIUploadedRoom",
"Beatmap",
"BeatmapDict",
"BeatmapModel",
"BeatmapPlaycounts",
"BeatmapPlaycountsResp",
"BeatmapPlaycountsDict",
"BeatmapPlaycountsModel",
"BeatmapRating",
"BeatmapResp",
"BeatmapSync",
"BeatmapTagVote",
"Beatmapset",
"BeatmapsetResp",
"BeatmapsetDict",
"BeatmapsetModel",
"BestScore",
"ChannelType",
"ChatChannel",
"ChatChannelResp",
"ChatChannelDict",
"ChatChannelModel",
"ChatMessage",
"ChatMessageResp",
"ChatMessageDict",
"ChatMessageModel",
"CountResp",
"DailyChallengeStats",
"DailyChallengeStatsResp",
@@ -100,13 +115,13 @@ __all__ = [
"Event",
"FavouriteBeatmapset",
"ItemAttemptsCount",
"ItemAttemptsResp",
"ItemAttemptsCountDict",
"ItemAttemptsCountModel",
"LoginSession",
"LoginSessionResp",
"MatchmakingPool",
"MatchmakingPoolBeatmap",
"MatchmakingUserStats",
"MeResp",
"MonthlyPlaycounts",
"MultiplayerEvent",
"MultiplayerEventResp",
@@ -116,26 +131,29 @@ __all__ = [
"OAuthToken",
"PasswordReset",
"Playlist",
"PlaylistAggregateScore",
"PlaylistBestScore",
"PlaylistResp",
"PlaylistDict",
"PlaylistModel",
"RankHistory",
"RankHistoryResp",
"RankTop",
"Relationship",
"RelationshipResp",
"RelationshipDict",
"RelationshipModel",
"RelationshipType",
"ReplayWatchedCount",
"Room",
"RoomDict",
"RoomModel",
"RoomParticipatedUser",
"RoomResp",
"Score",
"ScoreAround",
"ScoreBase",
"ScoreResp",
"ScoreDict",
"ScoreModel",
"ScoreStatistics",
"ScoreToken",
"ScoreTokenResp",
"SearchBeatmapsetsResp",
"Team",
"TeamMember",
"TeamRequest",
@@ -149,17 +167,18 @@ __all__ = [
"UserAccountHistoryResp",
"UserAccountHistoryType",
"UserAchievement",
"UserAchievement",
"UserAchievementResp",
"UserDict",
"UserLoginLog",
"UserModel",
"UserNotification",
"UserPreference",
"UserResp",
"UserStatistics",
"UserStatisticsResp",
"UserStatisticsDict",
"UserStatisticsModel",
"V1APIKeys",
]
for i in __all__:
if i.endswith("Resp"):
globals()[i].model_rebuild() # type: ignore[call-arg]
if i.endswith("Model") or i.endswith("Resp"):
globals()[i].model_rebuild()

499
app/database/_base.py Normal file
View File

@@ -0,0 +1,499 @@
from collections.abc import Awaitable, Callable, Sequence
from functools import lru_cache, wraps
import inspect
import sys
from types import NoneType, get_original_bases
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Concatenate,
ForwardRef,
ParamSpec,
TypedDict,
cast,
get_args,
get_origin,
overload,
)
from app.models.model import UTCBaseModel
from app.utils import type_is_optional
from sqlalchemy.ext.asyncio import async_object_session
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.main import SQLModelMetaclass
_dict_to_model: dict[type, type["DatabaseModel"]] = {}
def _safe_evaluate_forwardref(type_: str | ForwardRef, module_name: str) -> Any:
"""Safely evaluate a ForwardRef, with fallback to app.database module"""
if isinstance(type_, str):
type_ = ForwardRef(type_)
try:
return evaluate_forwardref(
type_,
globalns=vars(sys.modules[module_name]),
localns={},
)
except (NameError, AttributeError, KeyError):
# Fallback to app.database module
try:
import app.database
return evaluate_forwardref(
type_,
globalns=vars(app.database),
localns={},
)
except (NameError, AttributeError, KeyError):
return None
class OnDemand[T]:
if TYPE_CHECKING:
def __get__(self, instance: object | None, owner: Any) -> T: ...
def __set__(self, instance: Any, value: T) -> None: ...
def __delete__(self, instance: Any) -> None: ...
class Exclude[T]:
if TYPE_CHECKING:
def __get__(self, instance: object | None, owner: Any) -> T: ...
def __set__(self, instance: Any, value: T) -> None: ...
def __delete__(self, instance: Any) -> None: ...
# https://github.com/fastapi/sqlmodel/blob/main/sqlmodel/_compat.py#L126-L140
def _get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]:
raw_annotations: dict[str, Any] = class_dict.get("__annotations__", {})
if sys.version_info >= (3, 14) and "__annotations__" not in class_dict:
# See https://github.com/pydantic/pydantic/pull/11991
from annotationlib import (
Format,
call_annotate_function,
get_annotate_from_class_namespace,
)
if annotate := get_annotate_from_class_namespace(class_dict):
raw_annotations = call_annotate_function(annotate, format=Format.FORWARDREF)
return raw_annotations
# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L58-L77
if sys.version_info < (3, 12, 4):
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
# Even though it is the right signature for python 3.9, mypy complains with
# `error: Too many arguments for "_evaluate" of "ForwardRef"` hence the cast...
# Python 3.13/3.12.4+ made `recursive_guard` a kwarg, so name it explicitly to avoid:
# TypeError: ForwardRef._evaluate() missing 1 required keyword-only argument: 'recursive_guard'
return cast(Any, type_)._evaluate(globalns, localns, recursive_guard=set())
else:
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
# Pydantic 1.x will not support PEP 695 syntax, but provide `type_params` to avoid
# warnings:
return cast(Any, type_)._evaluate(globalns, localns, type_params=(), recursive_guard=set())
class DatabaseModelMetaclass(SQLModelMetaclass):
def __new__(
cls,
name: str,
bases: tuple[type, ...],
namespace: dict[str, Any],
**kwargs: Any,
) -> "DatabaseModelMetaclass":
original_annotations = _get_annotations(namespace)
new_annotations = {}
ondemands = []
excludes = []
for k, v in original_annotations.items():
if get_origin(v) is OnDemand:
inner_type = v.__args__[0]
new_annotations[k] = inner_type
ondemands.append(k)
elif get_origin(v) is Exclude:
inner_type = v.__args__[0]
new_annotations[k] = inner_type
excludes.append(k)
else:
new_annotations[k] = v
new_class = super().__new__(
cls,
name,
bases,
{
**namespace,
"__annotations__": new_annotations,
},
**kwargs,
)
new_class._CALCULATED_FIELDS = dict(getattr(new_class, "_CALCULATED_FIELDS", {}))
new_class._ONDEMAND_DATABASE_FIELDS = list(getattr(new_class, "_ONDEMAND_DATABASE_FIELDS", [])) + list(
ondemands
)
new_class._ONDEMAND_CALCULATED_FIELDS = dict(getattr(new_class, "_ONDEMAND_CALCULATED_FIELDS", {}))
new_class._EXCLUDED_DATABASE_FIELDS = list(getattr(new_class, "_EXCLUDED_DATABASE_FIELDS", [])) + list(excludes)
for attr_name, attr_value in namespace.items():
target = _get_callable_target(attr_value)
if target is None:
continue
if getattr(target, "__included__", False):
new_class._CALCULATED_FIELDS[attr_name] = _get_return_type(target)
_pre_calculate_context_params(target, attr_value)
if getattr(target, "__calculated_ondemand__", False):
new_class._ONDEMAND_CALCULATED_FIELDS[attr_name] = _get_return_type(target)
_pre_calculate_context_params(target, attr_value)
# Register TDict to DatabaseModel mapping
for base in get_original_bases(new_class):
cls_name = base.__name__
if "DatabaseModel" in cls_name and "[" in cls_name and "]" in cls_name:
generic_type_name = cls_name[cls_name.index("[") : cls_name.rindex("]") + 1]
generic_type = evaluate_forwardref(
ForwardRef(generic_type_name),
globalns=vars(sys.modules[new_class.__module__]),
localns={},
)
_dict_to_model[generic_type[0]] = new_class
return new_class
def _pre_calculate_context_params(target: Callable, attr_value: Any) -> None:
if hasattr(target, "__context_params__"):
return
sig = inspect.signature(target)
params = list(sig.parameters.keys())
start_index = 2
if isinstance(attr_value, classmethod):
start_index = 3
context_params = [] if len(params) < start_index else params[start_index:]
setattr(target, "__context_params__", context_params)
def _get_callable_target(value: Any) -> Callable | None:
if isinstance(value, (staticmethod, classmethod)):
return value.__func__
if inspect.isfunction(value):
return value
if inspect.ismethod(value):
return value.__func__
return None
def _mark_callable(value: Any, flag: str) -> Callable | None:
target = _get_callable_target(value)
if target is None:
return None
setattr(target, flag, True)
return target
def _get_return_type(func: Callable) -> type:
sig = inspect.get_annotations(func)
return sig.get("return", Any)
P = ParamSpec("P")
CalculatedField = Callable[Concatenate[AsyncSession, Any, P], Awaitable[Any]]
DecoratorTarget = CalculatedField | staticmethod | classmethod
def included(func: DecoratorTarget) -> DecoratorTarget:
marker = _mark_callable(func, "__included__")
if marker is None:
raise RuntimeError("@included is only usable on callables.")
@wraps(marker)
async def wrapper(*args, **kwargs):
return await marker(*args, **kwargs)
if isinstance(func, staticmethod):
return staticmethod(wrapper)
if isinstance(func, classmethod):
return classmethod(wrapper)
return wrapper
def ondemand(func: DecoratorTarget) -> DecoratorTarget:
marker = _mark_callable(func, "__calculated_ondemand__")
if marker is None:
raise RuntimeError("@ondemand is only usable on callables.")
@wraps(marker)
async def wrapper(*args, **kwargs):
return await marker(*args, **kwargs)
if isinstance(func, staticmethod):
return staticmethod(wrapper)
if isinstance(func, classmethod):
return classmethod(wrapper)
return wrapper
async def call_awaitable_with_context(
func: CalculatedField,
session: AsyncSession,
instance: Any,
context: dict[str, Any],
) -> Any:
context_params: list[str] | None = getattr(func, "__context_params__", None)
if context_params is None:
# Fallback if not pre-calculated
sig = inspect.signature(func)
if len(sig.parameters) == 2:
return await func(session, instance)
else:
call_params = {}
for param in sig.parameters.values():
if param.name in context:
call_params[param.name] = context[param.name]
return await func(session, instance, **call_params)
if not context_params:
return await func(session, instance)
call_params = {}
for name in context_params:
if name in context:
call_params[name] = context[name]
return await func(session, instance, **call_params)
class DatabaseModel[TDict](SQLModel, UTCBaseModel, metaclass=DatabaseModelMetaclass):
_CALCULATED_FIELDS: ClassVar[dict[str, type]] = {}
_ONDEMAND_DATABASE_FIELDS: ClassVar[list[str]] = []
_ONDEMAND_CALCULATED_FIELDS: ClassVar[dict[str, type]] = {}
_EXCLUDED_DATABASE_FIELDS: ClassVar[list[str]] = []
@overload
@classmethod
async def transform(
cls,
db_instance: "DatabaseModel",
*,
session: AsyncSession,
includes: list[str] | None = None,
**context: Any,
) -> TDict: ...
@overload
@classmethod
async def transform(
cls,
db_instance: "DatabaseModel",
*,
includes: list[str] | None = None,
**context: Any,
) -> TDict: ...
@classmethod
async def transform(
cls,
db_instance: "DatabaseModel",
*,
session: AsyncSession | None = None,
includes: list[str] | None = None,
**context: Any,
) -> TDict:
includes = includes.copy() if includes is not None else []
session = cast(AsyncSession | None, async_object_session(db_instance)) if session is None else session
if session is None:
raise RuntimeError("DatabaseModel.transform requires a session-bound instance.")
resp_obj = cls.model_validate(db_instance.model_dump())
data = resp_obj.model_dump()
for field in cls._CALCULATED_FIELDS:
func = getattr(cls, field)
value = await call_awaitable_with_context(func, session, db_instance, context)
data[field] = value
sub_include_map: dict[str, list[str]] = {}
for include in [i for i in includes if "." in i]:
parent, sub_include = include.split(".", 1)
if parent not in sub_include_map:
sub_include_map[parent] = []
sub_include_map[parent].append(sub_include)
includes.remove(include) # pyright: ignore[reportOptionalMemberAccess]
for field, sub_includes in sub_include_map.items():
if field in cls._ONDEMAND_CALCULATED_FIELDS:
func = getattr(cls, field)
value = await call_awaitable_with_context(
func, session, db_instance, {**context, "includes": sub_includes}
)
data[field] = value
for include in includes:
if include in data:
continue
if include in cls._ONDEMAND_CALCULATED_FIELDS:
func = getattr(cls, include)
value = await call_awaitable_with_context(func, session, db_instance, context)
data[include] = value
for field in cls._ONDEMAND_DATABASE_FIELDS:
if field not in includes:
del data[field]
for field in cls._EXCLUDED_DATABASE_FIELDS:
if field in data:
del data[field]
return cast(TDict, data)
@classmethod
async def transform_many(
cls,
db_instances: Sequence["DatabaseModel"],
*,
session: AsyncSession | None = None,
includes: list[str] | None = None,
**context: Any,
) -> list[TDict]:
if not db_instances:
return []
# SQLAlchemy AsyncSession is not concurrency-safe, so we cannot use asyncio.gather here
# if the transform method performs any database operations using the shared session.
# Since we don't know if the transform method (or its calculated fields) will use the DB,
# we must execute them serially to be safe.
results = []
for instance in db_instances:
results.append(await cls.transform(instance, session=session, includes=includes, **context))
return results
@classmethod
@lru_cache
def generate_typeddict(cls, includes: tuple[str, ...] | None = None) -> type[TypedDict]: # pyright: ignore[reportInvalidTypeForm]
def _evaluate_type(field_type: Any, *, resolve_database_model: bool = False, field_name: str = "") -> Any:
# Evaluate ForwardRef if present
if isinstance(field_type, (str, ForwardRef)):
resolved = _safe_evaluate_forwardref(field_type, cls.__module__)
if resolved is not None:
field_type = resolved
origin_type = get_origin(field_type)
inner_type = field_type
args = get_args(field_type)
is_optional = type_is_optional(field_type) # pyright: ignore[reportArgumentType]
if is_optional:
inner_type = next((arg for arg in args if arg is not NoneType), field_type)
is_list = False
if origin_type is list:
is_list = True
inner_type = args[0]
# Evaluate ForwardRef in inner_type if present
if isinstance(inner_type, (str, ForwardRef)):
resolved = _safe_evaluate_forwardref(inner_type, cls.__module__)
if resolved is not None:
inner_type = resolved
if not resolve_database_model:
if is_optional:
return inner_type | None # pyright: ignore[reportOperatorIssue]
elif is_list:
return list[inner_type]
return inner_type
model_class = None
# First check if inner_type is directly a DatabaseModel subclass
try:
if inspect.isclass(inner_type) and issubclass(inner_type, DatabaseModel): # type: ignore
model_class = inner_type
except TypeError:
pass
# If not found, look up in _dict_to_model
if model_class is None:
model_class = _dict_to_model.get(inner_type) # type: ignore
if model_class is not None:
nested_dict = model_class.generate_typeddict(tuple(sub_include_map.get(field_name, ())))
resolved_type = list[nested_dict] if is_list else nested_dict # type: ignore
if is_optional:
resolved_type = resolved_type | None # type: ignore
return resolved_type
# Fallback: use the resolved inner_type
resolved_type = list[inner_type] if is_list else inner_type # type: ignore
if is_optional:
resolved_type = resolved_type | None # type: ignore
return resolved_type
if includes is None:
includes = ()
# Parse nested includes
direct_includes = []
sub_include_map: dict[str, list[str]] = {}
for include in includes:
if "." in include:
parent, sub_include = include.split(".", 1)
if parent not in sub_include_map:
sub_include_map[parent] = []
sub_include_map[parent].append(sub_include)
if parent not in direct_includes:
direct_includes.append(parent)
else:
direct_includes.append(include)
fields = {}
# Process model fields
for field_name, field_info in cls.model_fields.items():
field_type = field_info.annotation or Any
field_type = _evaluate_type(field_type, field_name=field_name)
if field_name in cls._ONDEMAND_DATABASE_FIELDS and field_name not in direct_includes:
continue
else:
fields[field_name] = field_type
# Process calculated fields
for field_name, field_type in cls._CALCULATED_FIELDS.items():
field_type = _evaluate_type(field_type, resolve_database_model=True, field_name=field_name)
fields[field_name] = field_type
# Process ondemand calculated fields
for field_name, field_type in cls._ONDEMAND_CALCULATED_FIELDS.items():
if field_name not in direct_includes:
continue
field_type = _evaluate_type(field_type, resolve_database_model=True, field_name=field_name)
fields[field_name] = field_type
return TypedDict(f"{cls.__name__}Dict[{', '.join(includes)}]" if includes else f"{cls.__name__}Dict", fields) # pyright: ignore[reportArgumentType]

View File

@@ -1,117 +1,339 @@
from datetime import datetime
import hashlib
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, ClassVar, NotRequired, TypedDict
from app.calculator import get_calculator
from app.config import settings
from app.database.beatmap_tags import BeatmapTagVote
from app.database.failtime import FailTime, FailTimeResp
from app.models.beatmap import BeatmapRankStatus
from app.models.mods import APIMod
from app.models.performance import DifficultyAttributesUnion
from app.models.score import GameMode
from ._base import DatabaseModel, OnDemand, included, ondemand
from .beatmap_playcounts import BeatmapPlaycounts
from .beatmapset import Beatmapset, BeatmapsetResp
from .beatmap_tags import BeatmapTagVote
from .beatmapset import Beatmapset, BeatmapsetDict, BeatmapsetModel
from .failtime import FailTime, FailTimeResp
from .user import User, UserDict, UserModel
from pydantic import BaseModel, TypeAdapter
from redis.asyncio import Redis
from sqlalchemy import Column, DateTime
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, col, exists, func, select
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.fetcher import Fetcher
from .user import User
class BeatmapOwner(SQLModel):
id: int
username: str
class BeatmapBase(SQLModel):
# Beatmap
url: str
class BeatmapDict(TypedDict):
beatmapset_id: int
difficulty_rating: float
id: int
mode: GameMode
total_length: int
user_id: int
version: str
url: str
checksum: NotRequired[str]
max_combo: NotRequired[int | None]
ar: NotRequired[float]
cs: NotRequired[float]
drain: NotRequired[float]
accuracy: NotRequired[float]
bpm: NotRequired[float]
count_circles: NotRequired[int]
count_sliders: NotRequired[int]
count_spinners: NotRequired[int]
deleted_at: NotRequired[datetime | None]
hit_length: NotRequired[int]
last_updated: NotRequired[datetime]
status: NotRequired[str]
beatmapset: NotRequired[BeatmapsetDict]
current_user_playcount: NotRequired[int]
current_user_tag_ids: NotRequired[list[int]]
failtimes: NotRequired[FailTimeResp]
top_tag_ids: NotRequired[list[dict[str, int]]]
user: NotRequired[UserDict]
convert: NotRequired[bool]
is_scoreable: NotRequired[bool]
mode_int: NotRequired[int]
ranked: NotRequired[int]
playcount: NotRequired[int]
passcount: NotRequired[int]
class BeatmapModel(DatabaseModel[BeatmapDict]):
BEATMAP_TRANSFORMER_INCLUDES: ClassVar[list[str]] = [
"checksum",
"accuracy",
"ar",
"bpm",
"convert",
"count_circles",
"count_sliders",
"count_spinners",
"cs",
"deleted_at",
"drain",
"hit_length",
"is_scoreable",
"last_updated",
"mode_int",
"passcount",
"playcount",
"ranked",
"url",
]
DEFAULT_API_INCLUDES: ClassVar[list[str]] = [
"beatmapset.ratings",
"current_user_playcount",
"failtimes",
"max_combo",
"owners",
]
TRANSFORMER_INCLUDES: ClassVar[list[str]] = [*DEFAULT_API_INCLUDES, *BEATMAP_TRANSFORMER_INCLUDES]
# Beatmap
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
difficulty_rating: float = Field(default=0.0, index=True)
id: int = Field(primary_key=True, index=True)
mode: GameMode
total_length: int
user_id: int = Field(index=True)
version: str = Field(index=True)
url: OnDemand[str]
# optional
checksum: str = Field(sa_column=Column(VARCHAR(32), index=True))
current_user_playcount: int = Field(default=0)
max_combo: int | None = Field(default=0)
# TODO: failtimes, owners
checksum: OnDemand[str] = Field(sa_column=Column(VARCHAR(32), index=True))
max_combo: OnDemand[int | None] = Field(default=0)
# TODO: owners
# BeatmapExtended
ar: float = Field(default=0.0)
cs: float = Field(default=0.0)
drain: float = Field(default=0.0) # hp
accuracy: float = Field(default=0.0) # od
bpm: float = Field(default=0.0)
count_circles: int = Field(default=0)
count_sliders: int = Field(default=0)
count_spinners: int = Field(default=0)
deleted_at: datetime | None = Field(default=None, sa_column=Column(DateTime))
hit_length: int = Field(default=0)
last_updated: datetime = Field(sa_column=Column(DateTime, index=True))
ar: OnDemand[float] = Field(default=0.0)
cs: OnDemand[float] = Field(default=0.0)
drain: OnDemand[float] = Field(default=0.0) # hp
accuracy: OnDemand[float] = Field(default=0.0) # od
bpm: OnDemand[float] = Field(default=0.0)
count_circles: OnDemand[int] = Field(default=0)
count_sliders: OnDemand[int] = Field(default=0)
count_spinners: OnDemand[int] = Field(default=0)
deleted_at: OnDemand[datetime | None] = Field(default=None, sa_column=Column(DateTime))
hit_length: OnDemand[int] = Field(default=0)
last_updated: OnDemand[datetime] = Field(sa_column=Column(DateTime, index=True))
@included
@staticmethod
async def status(_session: AsyncSession, beatmap: "Beatmap") -> str:
if settings.enable_all_beatmap_leaderboard and not beatmap.beatmap_status.has_leaderboard():
return BeatmapRankStatus.APPROVED.name.lower()
return beatmap.beatmap_status.name.lower()
@ondemand
@staticmethod
async def beatmapset(
_session: AsyncSession,
beatmap: "Beatmap",
includes: list[str] | None = None,
) -> BeatmapsetDict | None:
if beatmap.beatmapset is not None:
return await BeatmapsetModel.transform(
beatmap.beatmapset, includes=(includes or []) + Beatmapset.BEATMAPSET_TRANSFORMER_INCLUDES
)
@ondemand
@staticmethod
async def current_user_playcount(_session: AsyncSession, beatmap: "Beatmap", user: "User") -> int:
playcount = (
await _session.exec(
select(BeatmapPlaycounts.playcount).where(
BeatmapPlaycounts.beatmap_id == beatmap.id, BeatmapPlaycounts.user_id == user.id
)
)
).first()
return int(playcount or 0)
@ondemand
@staticmethod
async def current_user_tag_ids(_session: AsyncSession, beatmap: "Beatmap", user: "User | None" = None) -> list[int]:
if user is None:
return []
tag_ids = (
await _session.exec(
select(BeatmapTagVote.tag_id).where(
BeatmapTagVote.beatmap_id == beatmap.id,
BeatmapTagVote.user_id == user.id,
)
)
).all()
return list(tag_ids)
@ondemand
@staticmethod
async def failtimes(_session: AsyncSession, beatmap: "Beatmap") -> FailTimeResp:
if beatmap.failtimes is not None:
return FailTimeResp.from_db(beatmap.failtimes)
return FailTimeResp()
@ondemand
@staticmethod
async def top_tag_ids(_session: AsyncSession, beatmap: "Beatmap") -> list[dict[str, int]]:
all_votes = (
await _session.exec(
select(BeatmapTagVote.tag_id, func.count().label("vote_count"))
.where(BeatmapTagVote.beatmap_id == beatmap.id)
.group_by(col(BeatmapTagVote.tag_id))
.having(func.count() > settings.beatmap_tag_top_count)
)
).all()
top_tag_ids: list[dict[str, int]] = []
for id, votes in all_votes:
top_tag_ids.append({"tag_id": id, "count": votes})
top_tag_ids.sort(key=lambda x: x["count"], reverse=True)
return top_tag_ids
@ondemand
@staticmethod
async def user(
_session: AsyncSession,
beatmap: "Beatmap",
includes: list[str] | None = None,
) -> UserDict | None:
from .user import User
user = await _session.get(User, beatmap.user_id)
if user is None:
return None
return await UserModel.transform(user, includes=includes)
@ondemand
@staticmethod
async def convert(_session: AsyncSession, _beatmap: "Beatmap") -> bool:
return False
@ondemand
@staticmethod
async def is_scoreable(_session: AsyncSession, beatmap: "Beatmap") -> bool:
beatmap_status = beatmap.beatmap_status
if settings.enable_all_beatmap_leaderboard:
return True
return beatmap_status.has_leaderboard()
@ondemand
@staticmethod
async def mode_int(_session: AsyncSession, beatmap: "Beatmap") -> int:
return int(beatmap.mode)
@ondemand
@staticmethod
async def ranked(_session: AsyncSession, beatmap: "Beatmap") -> int:
beatmap_status = beatmap.beatmap_status
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
return BeatmapRankStatus.APPROVED.value
return beatmap_status.value
@ondemand
@staticmethod
async def playcount(_session: AsyncSession, beatmap: "Beatmap") -> int:
result = (
await _session.exec(
select(func.sum(BeatmapPlaycounts.playcount)).where(BeatmapPlaycounts.beatmap_id == beatmap.id)
)
).first()
return int(result or 0)
@ondemand
@staticmethod
async def passcount(_session: AsyncSession, beatmap: "Beatmap") -> int:
from .score import Score
return (
await _session.exec(
select(func.count())
.select_from(Score)
.where(
Score.beatmap_id == beatmap.id,
col(Score.passed).is_(True),
)
)
).one()
class Beatmap(BeatmapBase, table=True):
class Beatmap(AsyncAttrs, BeatmapModel, table=True):
__tablename__: str = "beatmaps"
id: int = Field(primary_key=True, index=True)
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
beatmap_status: BeatmapRankStatus = Field(index=True)
# optional
beatmapset: Beatmapset = Relationship(back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"})
beatmapset: "Beatmapset" = Relationship(back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"})
failtimes: FailTime | None = Relationship(back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"})
@classmethod
async def from_resp_no_save(cls, _session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
d = resp.model_dump()
del d["beatmapset"]
async def from_resp_no_save(cls, _session: AsyncSession, resp: BeatmapDict) -> "Beatmap":
d = {k: v for k, v in resp.items() if k != "beatmapset"}
beatmapset_id = resp.get("beatmapset_id")
bid = resp.get("id")
ranked = resp.get("ranked")
if beatmapset_id is None or bid is None or ranked is None:
raise ValueError("beatmapset_id, id and ranked are required")
beatmap = cls.model_validate(
{
**d,
"beatmapset_id": resp.beatmapset_id,
"id": resp.id,
"beatmap_status": BeatmapRankStatus(resp.ranked),
"beatmapset_id": beatmapset_id,
"id": bid,
"beatmap_status": BeatmapRankStatus(ranked),
}
)
return beatmap
@classmethod
async def from_resp(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
async def from_resp(cls, session: AsyncSession, resp: BeatmapDict) -> "Beatmap":
beatmap = await cls.from_resp_no_save(session, resp)
if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
resp_id = resp.get("id")
if resp_id is None:
raise ValueError("id is required")
if not (await session.exec(select(exists()).where(Beatmap.id == resp_id))).first():
session.add(beatmap)
await session.commit()
return (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one()
return (await session.exec(select(Beatmap).where(Beatmap.id == resp_id))).one()
@classmethod
async def from_resp_batch(cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0) -> list["Beatmap"]:
async def from_resp_batch(cls, session: AsyncSession, inp: list[BeatmapDict], from_: int = 0) -> list["Beatmap"]:
beatmaps = []
for resp in inp:
if resp.id == from_:
for resp_dict in inp:
bid = resp_dict.get("id")
if bid == from_ or bid is None:
continue
d = resp.model_dump()
del d["beatmapset"]
beatmapset_id = resp_dict.get("beatmapset_id")
ranked = resp_dict.get("ranked")
if beatmapset_id is None or ranked is None:
continue
# 创建 beatmap 字典,移除 beatmapset
d = {k: v for k, v in resp_dict.items() if k != "beatmapset"}
beatmap = Beatmap.model_validate(
{
**d,
"beatmapset_id": resp.beatmapset_id,
"id": resp.id,
"beatmap_status": BeatmapRankStatus(resp.ranked),
"beatmapset_id": beatmapset_id,
"id": bid,
"beatmap_status": BeatmapRankStatus(ranked),
}
)
if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
if not (await session.exec(select(exists()).where(Beatmap.id == bid))).first():
session.add(beatmap)
beatmaps.append(beatmap)
await session.commit()
for beatmap in beatmaps:
await session.refresh(beatmap)
return beatmaps
@classmethod
@@ -132,10 +354,14 @@ class Beatmap(BeatmapBase, table=True):
beatmap = (await session.exec(stmt)).first()
if not beatmap:
resp = await fetcher.get_beatmap(bid, md5)
r = await session.exec(select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id))
beatmapset_id = resp.get("beatmapset_id")
if beatmapset_id is None:
raise ValueError("beatmapset_id is required")
r = await session.exec(select(Beatmapset.id).where(Beatmapset.id == beatmapset_id))
if not r.first():
set_resp = await fetcher.get_beatmapset(resp.beatmapset_id)
await Beatmapset.from_resp(session, set_resp, from_=resp.id)
set_resp = await fetcher.get_beatmapset(beatmapset_id)
resp_id = resp.get("id")
await Beatmapset.from_resp(session, set_resp, from_=resp_id or 0)
return await Beatmap.from_resp(session, resp)
return beatmap
@@ -145,97 +371,6 @@ class APIBeatmapTag(BaseModel):
count: int
class BeatmapResp(BeatmapBase):
id: int
beatmapset_id: int
beatmapset: BeatmapsetResp | None = None
convert: bool = False
is_scoreable: bool
status: str
mode_int: int
ranked: int
url: str = ""
playcount: int = 0
passcount: int = 0
failtimes: FailTimeResp | None = None
top_tag_ids: list[APIBeatmapTag] | None = None
current_user_tag_ids: list[int] | None = None
is_deleted: bool = False
@classmethod
async def from_db(
cls,
beatmap: Beatmap,
query_mode: GameMode | None = None,
from_set: bool = False,
session: AsyncSession | None = None,
user: "User | None" = None,
) -> "BeatmapResp":
from .score import Score
beatmap_ = beatmap.model_dump()
beatmap_status = beatmap.beatmap_status
if query_mode is not None and beatmap.mode != query_mode:
beatmap_["convert"] = True
beatmap_["is_scoreable"] = beatmap_status.has_leaderboard()
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
beatmap_["ranked"] = BeatmapRankStatus.APPROVED.value
beatmap_["status"] = BeatmapRankStatus.APPROVED.name.lower()
else:
beatmap_["status"] = beatmap_status.name.lower()
beatmap_["ranked"] = beatmap_status.value
beatmap_["mode_int"] = int(beatmap.mode)
beatmap_["is_deleted"] = beatmap.deleted_at is not None
if not from_set:
beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset, session=session, user=user)
if beatmap.failtimes is not None:
beatmap_["failtimes"] = FailTimeResp.from_db(beatmap.failtimes)
else:
beatmap_["failtimes"] = FailTimeResp()
if session:
beatmap_["playcount"] = (
await session.exec(
select(func.sum(BeatmapPlaycounts.playcount)).where(BeatmapPlaycounts.beatmap_id == beatmap.id)
)
).first() or 0
beatmap_["passcount"] = (
await session.exec(
select(func.count())
.select_from(Score)
.where(
Score.beatmap_id == beatmap.id,
col(Score.passed).is_(True),
)
)
).one()
all_votes = (
await session.exec(
select(BeatmapTagVote.tag_id, func.count().label("vote_count"))
.where(BeatmapTagVote.beatmap_id == beatmap.id)
.group_by(col(BeatmapTagVote.tag_id))
.having(func.count() > settings.beatmap_tag_top_count)
)
).all()
top_tag_ids: list[dict[str, int]] = []
for id, votes in all_votes:
top_tag_ids.append({"tag_id": id, "count": votes})
top_tag_ids.sort(key=lambda x: x["count"], reverse=True)
beatmap_["top_tag_ids"] = top_tag_ids
if user is not None:
beatmap_["current_user_tag_ids"] = (
await session.exec(
select(BeatmapTagVote.tag_id)
.where(BeatmapTagVote.beatmap_id == beatmap.id)
.where(BeatmapTagVote.user_id == user.id)
)
).all()
else:
beatmap_["current_user_tag_ids"] = []
return cls.model_validate(beatmap_)
class BannedBeatmaps(SQLModel, table=True):
__tablename__: str = "banned_beatmaps"
id: int | None = Field(primary_key=True, index=True, default=None)

View File

@@ -1,10 +1,11 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, NotRequired, TypedDict
from app.config import settings
from app.database.events import Event, EventType
from app.utils import utcnow
from pydantic import BaseModel
from ._base import DatabaseModel, included
from .events import Event, EventType
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import (
BigInteger,
@@ -12,52 +13,65 @@ from sqlmodel import (
Field,
ForeignKey,
Relationship,
SQLModel,
select,
)
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from .beatmap import Beatmap, BeatmapResp
from .beatmapset import BeatmapsetResp
from .beatmap import Beatmap, BeatmapDict
from .beatmapset import BeatmapsetDict
from .user import User
class BeatmapPlaycounts(AsyncAttrs, SQLModel, table=True):
class BeatmapPlaycountsDict(TypedDict):
user_id: int
beatmap_id: int
count: NotRequired[int]
beatmap: NotRequired["BeatmapDict"]
beatmapset: NotRequired["BeatmapsetDict"]
class BeatmapPlaycountsModel(AsyncAttrs, DatabaseModel[BeatmapPlaycountsDict]):
__tablename__: str = "beatmap_playcounts"
id: int | None = Field(
default=None,
sa_column=Column(BigInteger, primary_key=True, autoincrement=True),
default=None, sa_column=Column(BigInteger, primary_key=True, autoincrement=True), exclude=True
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
playcount: int = Field(default=0)
playcount: int = Field(default=0, exclude=True)
@included
@staticmethod
async def count(_session: AsyncSession, obj: "BeatmapPlaycounts") -> int:
return obj.playcount
@included
@staticmethod
async def beatmap(
_session: AsyncSession, obj: "BeatmapPlaycounts", includes: list[str] | None = None
) -> "BeatmapDict":
from .beatmap import BeatmapModel
await obj.awaitable_attrs.beatmap
return await BeatmapModel.transform(obj.beatmap, includes=includes)
@included
@staticmethod
async def beatmapset(
_session: AsyncSession, obj: "BeatmapPlaycounts", includes: list[str] | None = None
) -> "BeatmapsetDict":
from .beatmap import BeatmapsetModel
await obj.awaitable_attrs.beatmap
return await BeatmapsetModel.transform(obj.beatmap.beatmapset, includes=includes)
class BeatmapPlaycounts(BeatmapPlaycountsModel, table=True):
user: "User" = Relationship()
beatmap: "Beatmap" = Relationship()
class BeatmapPlaycountsResp(BaseModel):
beatmap_id: int
beatmap: "BeatmapResp | None" = None
beatmapset: "BeatmapsetResp | None" = None
count: int
@classmethod
async def from_db(cls, db_model: BeatmapPlaycounts) -> "BeatmapPlaycountsResp":
from .beatmap import BeatmapResp
from .beatmapset import BeatmapsetResp
await db_model.awaitable_attrs.beatmap
return cls(
beatmap_id=db_model.beatmap_id,
count=db_model.playcount,
beatmap=await BeatmapResp.from_db(db_model.beatmap),
beatmapset=await BeatmapsetResp.from_db(db_model.beatmap.beatmapset),
)
async def process_beatmap_playcount(session: AsyncSession, user_id: int, beatmap_id: int) -> None:
existing_playcount = (
await session.exec(

View File

@@ -1,14 +1,15 @@
from datetime import datetime
from typing import TYPE_CHECKING, NotRequired, Self, TypedDict
from typing import TYPE_CHECKING, Any, ClassVar, NotRequired, TypedDict
from app.config import settings
from app.database.beatmap_playcounts import BeatmapPlaycounts
from app.models.beatmap import BeatmapRankStatus, Genre, Language
from app.models.score import GameMode
from .user import BASE_INCLUDES, User, UserResp
from ._base import DatabaseModel, OnDemand, included, ondemand
from .beatmap_playcounts import BeatmapPlaycounts
from .user import User, UserDict
from pydantic import BaseModel, field_validator, model_validator
from pydantic import BaseModel
from sqlalchemy import JSON, Boolean, Column, DateTime, Text
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import Field, Relationship, SQLModel, col, exists, func, select
@@ -17,7 +18,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.fetcher import Fetcher
from .beatmap import Beatmap, BeatmapResp
from .beatmap import Beatmap, BeatmapDict
from .favourite_beatmapset import FavouriteBeatmapset
@@ -68,8 +69,99 @@ class BeatmapTranslationText(BaseModel):
id: int | None = None
class BeatmapsetBase(SQLModel):
class BeatmapsetDict(TypedDict):
id: int
artist: str
artist_unicode: str
covers: BeatmapCovers | None
creator: str
nsfw: bool
preview_url: str
source: str
spotlight: bool
title: str
title_unicode: str
track_id: int | None
user_id: int
video: bool
current_nominations: list[BeatmapNomination] | None
description: BeatmapDescription | None
pack_tags: list[str]
bpm: NotRequired[float]
can_be_hyped: NotRequired[bool]
discussion_locked: NotRequired[bool]
last_updated: NotRequired[datetime]
ranked_date: NotRequired[datetime | None]
storyboard: NotRequired[bool]
submitted_date: NotRequired[datetime]
tags: NotRequired[str]
discussion_enabled: NotRequired[bool]
legacy_thread_url: NotRequired[str | None]
status: NotRequired[str]
ranked: NotRequired[int]
is_scoreable: NotRequired[bool]
favourite_count: NotRequired[int]
genre_id: NotRequired[int]
hype: NotRequired[BeatmapHype]
language_id: NotRequired[int]
play_count: NotRequired[int]
availability: NotRequired[BeatmapAvailability]
beatmaps: NotRequired[list["BeatmapDict"]]
has_favourited: NotRequired[bool]
recent_favourites: NotRequired[list[UserDict]]
genre: NotRequired[BeatmapTranslationText]
language: NotRequired[BeatmapTranslationText]
nominations: NotRequired["BeatmapNominations"]
ratings: NotRequired[list[int]]
class BeatmapsetModel(DatabaseModel[BeatmapsetDict]):
BEATMAPSET_TRANSFORMER_INCLUDES: ClassVar[list[str]] = [
"availability",
"has_favourited",
"bpm",
"deleted_atcan_be_hyped",
"discussion_locked",
"is_scoreable",
"last_updated",
"legacy_thread_url",
"ranked",
"ranked_date",
"submitted_date",
"tags",
"rating",
"storyboard",
]
API_INCLUDES: ClassVar[list[str]] = [
*BEATMAPSET_TRANSFORMER_INCLUDES,
"beatmaps.current_user_playcount",
"beatmaps.current_user_tag_ids",
"beatmaps.max_combo",
"current_nominations",
"current_user_attributes",
"description",
"genre",
"language",
"pack_tags",
"ratings",
"recent_favourites",
"related_tags",
"related_users",
"user",
"version_count",
*[
f"beatmaps.{inc}"
for inc in {
"failtimes",
"owners",
"top_tag_ids",
}
],
]
# Beatmapset
id: int = Field(default=None, primary_key=True, index=True)
artist: str = Field(index=True)
artist_unicode: str = Field(index=True)
covers: BeatmapCovers | None = Field(sa_column=Column(JSON))
@@ -77,41 +169,285 @@ class BeatmapsetBase(SQLModel):
nsfw: bool = Field(default=False, sa_column=Column(Boolean))
preview_url: str
source: str = Field(default="")
spotlight: bool = Field(default=False, sa_column=Column(Boolean))
title: str = Field(index=True)
title_unicode: str = Field(index=True)
track_id: int | None = Field(default=None, index=True) # feature artist?
user_id: int = Field(index=True)
video: bool = Field(sa_column=Column(Boolean, index=True))
# optional
# converts: list[Beatmap] = Relationship(back_populates="beatmapset")
current_nominations: list[BeatmapNomination] | None = Field(None, sa_column=Column(JSON))
description: BeatmapDescription | None = Field(default=None, sa_column=Column(JSON))
current_nominations: OnDemand[list[BeatmapNomination] | None] = Field(None, sa_column=Column(JSON))
description: OnDemand[BeatmapDescription | None] = Field(default=None, sa_column=Column(JSON))
# TODO: discussions: list[BeatmapsetDiscussion] = None
# TODO: current_user_attributes: Optional[CurrentUserAttributes] = None
# TODO: events: Optional[list[BeatmapsetEvent]] = None
pack_tags: list[str] = Field(default=[], sa_column=Column(JSON))
pack_tags: OnDemand[list[str]] = Field(default=[], sa_column=Column(JSON))
# TODO: related_users: Optional[list[User]] = None
# TODO: user: Optional[User] = Field(default=None)
track_id: int | None = Field(default=None, index=True) # feature artist?
# BeatmapsetExtended
bpm: float = Field(default=0.0)
can_be_hyped: bool = Field(default=False, sa_column=Column(Boolean))
discussion_locked: bool = Field(default=False, sa_column=Column(Boolean))
last_updated: datetime = Field(sa_column=Column(DateTime, index=True))
ranked_date: datetime | None = Field(default=None, sa_column=Column(DateTime, index=True))
storyboard: bool = Field(default=False, sa_column=Column(Boolean, index=True))
submitted_date: datetime = Field(sa_column=Column(DateTime, index=True))
tags: str = Field(default="", sa_column=Column(Text))
bpm: OnDemand[float] = Field(default=0.0)
can_be_hyped: OnDemand[bool] = Field(default=False, sa_column=Column(Boolean))
discussion_locked: OnDemand[bool] = Field(default=False, sa_column=Column(Boolean))
last_updated: OnDemand[datetime] = Field(sa_column=Column(DateTime, index=True))
ranked_date: OnDemand[datetime | None] = Field(default=None, sa_column=Column(DateTime, index=True))
storyboard: OnDemand[bool] = Field(default=False, sa_column=Column(Boolean, index=True))
submitted_date: OnDemand[datetime] = Field(sa_column=Column(DateTime, index=True))
tags: OnDemand[str] = Field(default="", sa_column=Column(Text))
@ondemand
@staticmethod
async def legacy_thread_url(
_session: AsyncSession,
_beatmapset: "Beatmapset",
) -> str | None:
return None
@included
@staticmethod
async def discussion_enabled(
_session: AsyncSession,
_beatmapset: "Beatmapset",
) -> bool:
return True
@included
@staticmethod
async def status(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> str:
beatmap_status = beatmapset.beatmap_status
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
return BeatmapRankStatus.APPROVED.name.lower()
return beatmap_status.name.lower()
@included
@staticmethod
async def ranked(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> int:
beatmap_status = beatmapset.beatmap_status
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
return BeatmapRankStatus.APPROVED.value
return beatmap_status.value
@ondemand
@staticmethod
async def is_scoreable(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> bool:
beatmap_status = beatmapset.beatmap_status
if settings.enable_all_beatmap_leaderboard:
return True
return beatmap_status.has_leaderboard()
@included
@staticmethod
async def favourite_count(
session: AsyncSession,
beatmapset: "Beatmapset",
) -> int:
from .favourite_beatmapset import FavouriteBeatmapset
count = await session.exec(
select(func.count())
.select_from(FavouriteBeatmapset)
.where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
)
return count.one()
@included
@staticmethod
async def genre_id(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> int:
return beatmapset.beatmap_genre.value
@ondemand
@staticmethod
async def hype(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> BeatmapHype:
return BeatmapHype(current=beatmapset.hype_current, required=beatmapset.hype_required)
@included
@staticmethod
async def language_id(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> int:
return beatmapset.beatmap_language.value
@included
@staticmethod
async def play_count(
session: AsyncSession,
beatmapset: "Beatmapset",
) -> int:
from .beatmap import Beatmap
playcount = await session.exec(
select(func.sum(BeatmapPlaycounts.playcount)).where(
col(BeatmapPlaycounts.beatmap).has(col(Beatmap.beatmapset_id) == beatmapset.id)
)
)
return int(playcount.first() or 0)
@ondemand
@staticmethod
async def availability(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> BeatmapAvailability:
return BeatmapAvailability(
more_information=beatmapset.availability_info,
download_disabled=beatmapset.download_disabled,
)
@ondemand
@staticmethod
async def beatmaps(
_session: AsyncSession,
beatmapset: "Beatmapset",
includes: list[str] | None = None,
user: "User | None" = None,
) -> list["BeatmapDict"]:
from .beatmap import BeatmapModel
return [
await BeatmapModel.transform(
beatmap, includes=(includes or []) + BeatmapModel.BEATMAP_TRANSFORMER_INCLUDES, user=user
)
for beatmap in await beatmapset.awaitable_attrs.beatmaps
]
# @ondemand
# @staticmethod
# async def current_nominations(
# _session: AsyncSession,
# beatmapset: "Beatmapset",
# ) -> list[BeatmapNomination] | None:
# return beatmapset.current_nominations or []
@ondemand
@staticmethod
async def has_favourited(
session: AsyncSession,
beatmapset: "Beatmapset",
user: User | None = None,
) -> bool:
from .favourite_beatmapset import FavouriteBeatmapset
if session is None:
return False
query = select(FavouriteBeatmapset).where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
if user is not None:
query = query.where(FavouriteBeatmapset.user_id == user.id)
existing = (await session.exec(query)).first()
return existing is not None
@ondemand
@staticmethod
async def recent_favourites(
session: AsyncSession,
beatmapset: "Beatmapset",
includes: list[str] | None = None,
) -> list[UserDict]:
from .favourite_beatmapset import FavouriteBeatmapset
recent_favourites = (
await session.exec(
select(FavouriteBeatmapset)
.where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
.order_by(col(FavouriteBeatmapset.date).desc())
.limit(50)
)
).all()
return [
await User.transform(
(await favourite.awaitable_attrs.user),
includes=includes,
)
for favourite in recent_favourites
]
@ondemand
@staticmethod
async def genre(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> BeatmapTranslationText:
return BeatmapTranslationText(
name=beatmapset.beatmap_genre.name,
id=beatmapset.beatmap_genre.value,
)
@ondemand
@staticmethod
async def language(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> BeatmapTranslationText:
return BeatmapTranslationText(
name=beatmapset.beatmap_language.name,
id=beatmapset.beatmap_language.value,
)
@ondemand
@staticmethod
async def nominations(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> BeatmapNominations:
return BeatmapNominations(
required=beatmapset.nominations_required,
current=beatmapset.nominations_current,
)
# @ondemand
# @staticmethod
# async def user(
# session: AsyncSession,
# beatmapset: Beatmapset,
# includes: list[str] | None = None,
# ) -> dict[str, Any] | None:
# db_user = await session.get(User, beatmapset.user_id)
# if not db_user:
# return None
# return await UserResp.transform(db_user, includes=includes)
@ondemand
@staticmethod
async def ratings(
session: AsyncSession,
beatmapset: "Beatmapset",
) -> list[int]:
# Provide a stable default shape if no session is available
if session is None:
return []
from .beatmapset_ratings import BeatmapRating
beatmapset_all_ratings = (
await session.exec(select(BeatmapRating).where(BeatmapRating.beatmapset_id == beatmapset.id))
).all()
ratings_list = [0] * 11
for rating in beatmapset_all_ratings:
ratings_list[rating.rating] += 1
return ratings_list
class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
class Beatmapset(AsyncAttrs, BeatmapsetModel, table=True):
__tablename__: str = "beatmapsets"
id: int = Field(default=None, primary_key=True, index=True)
# Beatmapset
beatmap_status: BeatmapRankStatus = Field(default=BeatmapRankStatus.GRAVEYARD, index=True)
@@ -130,29 +466,45 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset")
@classmethod
async def from_resp_no_save(cls, resp: "BeatmapsetResp") -> "Beatmapset":
d = resp.model_dump()
if resp.nominations:
d["nominations_required"] = resp.nominations.required
d["nominations_current"] = resp.nominations.current
if resp.hype:
d["hype_current"] = resp.hype.current
d["hype_required"] = resp.hype.required
if resp.genre_id:
d["beatmap_genre"] = Genre(resp.genre_id)
elif resp.genre:
d["beatmap_genre"] = Genre(resp.genre.id)
if resp.language_id:
d["beatmap_language"] = Language(resp.language_id)
elif resp.language:
d["beatmap_language"] = Language(resp.language.id)
async def from_resp_no_save(cls, resp: BeatmapsetDict) -> "Beatmapset":
# make a shallow copy so we can mutate safely
d: dict[str, Any] = dict(resp)
# nominations = resp.get("nominations")
# if nominations is not None:
# d["nominations_required"] = nominations.required
# d["nominations_current"] = nominations.current
hype = resp.get("hype")
if hype is not None:
d["hype_current"] = hype.current
d["hype_required"] = hype.required
genre_id = resp.get("genre_id")
genre = resp.get("genre")
if genre_id is not None:
d["beatmap_genre"] = Genre(genre_id)
elif genre is not None:
d["beatmap_genre"] = Genre(genre.id)
language_id = resp.get("language_id")
language = resp.get("language")
if language_id is not None:
d["beatmap_language"] = Language(language_id)
elif language is not None:
d["beatmap_language"] = Language(language.id)
availability = resp.get("availability")
ranked = resp.get("ranked")
if ranked is None:
raise ValueError("ranked field is required")
beatmapset = Beatmapset.model_validate(
{
**d,
"id": resp.id,
"beatmap_status": BeatmapRankStatus(resp.ranked),
"availability_info": resp.availability.more_information,
"download_disabled": resp.availability.download_disabled or False,
"beatmap_status": BeatmapRankStatus(ranked),
"availability_info": availability.more_information if availability is not None else None,
"download_disabled": bool(availability.download_disabled) if availability is not None else False,
}
)
return beatmapset
@@ -161,17 +513,19 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
async def from_resp(
cls,
session: AsyncSession,
resp: "BeatmapsetResp",
resp: BeatmapsetDict,
from_: int = 0,
) -> "Beatmapset":
from .beatmap import Beatmap
beatmapset_id = resp["id"]
beatmapset = await cls.from_resp_no_save(resp)
if not (await session.exec(select(exists()).where(Beatmapset.id == resp.id))).first():
if not (await session.exec(select(exists()).where(Beatmapset.id == beatmapset_id))).first():
session.add(beatmapset)
await session.commit()
await Beatmap.from_resp_batch(session, resp.beatmaps, from_=from_)
beatmapset = (await session.exec(select(Beatmapset).where(Beatmapset.id == resp.id))).one()
beatmaps = resp.get("beatmaps", [])
await Beatmap.from_resp_batch(session, beatmaps, from_=from_)
beatmapset = (await session.exec(select(Beatmapset).where(Beatmapset.id == beatmapset_id))).one()
return beatmapset
@classmethod
@@ -183,170 +537,5 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
resp = await fetcher.get_beatmapset(sid)
beatmapset = await cls.from_resp(session, resp)
await get_beatmapset_update_service().add(resp)
await session.refresh(beatmapset)
return beatmapset
class BeatmapsetResp(BeatmapsetBase):
id: int
beatmaps: list["BeatmapResp"] = Field(default_factory=list)
discussion_enabled: bool = True
status: str
ranked: int
legacy_thread_url: str | None = ""
is_scoreable: bool
hype: BeatmapHype | None = None
availability: BeatmapAvailability
genre: BeatmapTranslationText | None = None
genre_id: int
language: BeatmapTranslationText | None = None
language_id: int
nominations: BeatmapNominations | None = None
has_favourited: bool = False
favourite_count: int = 0
recent_favourites: list[UserResp] = Field(default_factory=list)
play_count: int = 0
@field_validator(
"nsfw",
"spotlight",
"video",
"can_be_hyped",
"discussion_locked",
"storyboard",
"discussion_enabled",
"is_scoreable",
"has_favourited",
mode="before",
)
@classmethod
def validate_bool_fields(cls, v):
"""将整数 0/1 转换为布尔值,处理数据库中的布尔字段"""
if isinstance(v, int):
return bool(v)
return v
@model_validator(mode="after")
def fix_genre_language(self) -> Self:
if self.genre is None:
self.genre = BeatmapTranslationText(name=Genre(self.genre_id).name, id=self.genre_id)
if self.language is None:
self.language = BeatmapTranslationText(name=Language(self.language_id).name, id=self.language_id)
return self
@classmethod
async def from_db(
cls,
beatmapset: Beatmapset,
include: list[str] = [],
session: AsyncSession | None = None,
user: User | None = None,
) -> "BeatmapsetResp":
from .beatmap import Beatmap, BeatmapResp
from .favourite_beatmapset import FavouriteBeatmapset
update = {
"beatmaps": [
await BeatmapResp.from_db(beatmap, from_set=True, session=session)
for beatmap in await beatmapset.awaitable_attrs.beatmaps
],
"hype": BeatmapHype(current=beatmapset.hype_current, required=beatmapset.hype_required),
"availability": BeatmapAvailability(
more_information=beatmapset.availability_info,
download_disabled=beatmapset.download_disabled,
),
"genre": BeatmapTranslationText(
name=beatmapset.beatmap_genre.name,
id=beatmapset.beatmap_genre.value,
),
"language": BeatmapTranslationText(
name=beatmapset.beatmap_language.name,
id=beatmapset.beatmap_language.value,
),
"genre_id": beatmapset.beatmap_genre.value,
"language_id": beatmapset.beatmap_language.value,
"nominations": BeatmapNominations(
required=beatmapset.nominations_required,
current=beatmapset.nominations_current,
),
"is_scoreable": beatmapset.beatmap_status.has_leaderboard(),
**beatmapset.model_dump(),
}
if session is not None:
# 从数据库读取对应谱面集的评分
from .beatmapset_ratings import BeatmapRating
beatmapset_all_ratings = (
await session.exec(select(BeatmapRating).where(BeatmapRating.beatmapset_id == beatmapset.id))
).all()
ratings_list = [0] * 11
for rating in beatmapset_all_ratings:
ratings_list[rating.rating] += 1
update["ratings"] = ratings_list
else:
# 返回非空值避免客户端崩溃
if update.get("ratings") is None:
update["ratings"] = []
beatmap_status = beatmapset.beatmap_status
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
update["status"] = BeatmapRankStatus.APPROVED.name.lower()
update["ranked"] = BeatmapRankStatus.APPROVED.value
else:
update["status"] = beatmap_status.name.lower()
update["ranked"] = beatmap_status.value
if session and user:
existing_favourite = (
await session.exec(
select(FavouriteBeatmapset).where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
)
).first()
update["has_favourited"] = existing_favourite is not None
if session and "recent_favourites" in include:
recent_favourites = (
await session.exec(
select(FavouriteBeatmapset)
.where(
FavouriteBeatmapset.beatmapset_id == beatmapset.id,
)
.order_by(col(FavouriteBeatmapset.date).desc())
.limit(50)
)
).all()
update["recent_favourites"] = [
await UserResp.from_db(
await favourite.awaitable_attrs.user,
session=session,
include=BASE_INCLUDES,
)
for favourite in recent_favourites
]
if session:
update["favourite_count"] = (
await session.exec(
select(func.count())
.select_from(FavouriteBeatmapset)
.where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
)
).one()
update["play_count"] = (
await session.exec(
select(func.sum(BeatmapPlaycounts.playcount)).where(
col(BeatmapPlaycounts.beatmap).has(col(Beatmap.beatmapset_id) == beatmapset.id)
)
)
).first() or 0
return cls.model_validate(
update,
)
class SearchBeatmapsetsResp(SQLModel):
beatmapsets: list[BeatmapsetResp]
total: int
cursor: dict[str, int | float | str] | None = None
cursor_string: str | None = None

View File

@@ -1,5 +1,5 @@
from app.database.beatmapset import Beatmapset
from app.database.user import User
from .beatmapset import Beatmapset
from .user import User
from sqlmodel import BigInteger, Column, Field, ForeignKey, Relationship, SQLModel

View File

@@ -1,8 +1,8 @@
from typing import TYPE_CHECKING
from app.database.statistics import UserStatistics
from app.models.score import GameMode
from .statistics import UserStatistics
from .user import User
from sqlmodel import (

View File

@@ -1,13 +1,14 @@
from datetime import datetime
from enum import Enum
from typing import Self
from typing import TYPE_CHECKING, ClassVar, NotRequired, TypedDict
from app.database.user import RANKING_INCLUDES, User, UserResp
from app.models.model import UTCBaseModel
from app.utils import utcnow
from ._base import DatabaseModel, included, ondemand
from .user import User, UserDict, UserModel
from pydantic import BaseModel
from redis.asyncio import Redis
from sqlmodel import (
VARCHAR,
BigInteger,
@@ -22,6 +23,8 @@ from sqlmodel import (
)
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.router.notification.server import ChatServer
# ChatChannel
@@ -44,16 +47,168 @@ class ChannelType(str, Enum):
TEAM = "TEAM"
class ChatChannelBase(SQLModel):
name: str = Field(sa_column=Column(VARCHAR(50), index=True))
class MessageType(str, Enum):
ACTION = "action"
MARKDOWN = "markdown"
PLAIN = "plain"
class ChatChannelDict(TypedDict):
channel_id: int
description: str
name: str
icon: str | None
type: ChannelType
uuid: NotRequired[str | None]
message_length_limit: NotRequired[int]
moderated: NotRequired[bool]
current_user_attributes: NotRequired[ChatUserAttributes]
last_read_id: NotRequired[int | None]
last_message_id: NotRequired[int | None]
recent_messages: NotRequired[list["ChatMessageDict"]]
users: NotRequired[list[int]]
class ChatChannelModel(DatabaseModel[ChatChannelDict]):
CONVERSATION_INCLUDES: ClassVar[list[str]] = [
"last_message_id",
"users",
]
LISTING_INCLUDES: ClassVar[list[str]] = [
*CONVERSATION_INCLUDES,
"current_user_attributes",
"last_read_id",
]
channel_id: int = Field(primary_key=True, index=True, default=None)
description: str = Field(sa_column=Column(VARCHAR(255), index=True))
icon: str | None = Field(default=None)
type: ChannelType = Field(index=True)
@included
@staticmethod
async def name(session: AsyncSession, channel: "ChatChannel", user: User, server: "ChatServer") -> str:
users = server.channels.get(channel.channel_id, [])
if channel.type == ChannelType.PM and users and len(users) == 2:
target_user_id = next(u for u in users if u != user.id)
target_name = await session.exec(select(User.username).where(User.id == target_user_id))
return target_name.one()
return channel.name
class ChatChannel(ChatChannelBase, table=True):
@included
@staticmethod
async def moderated(session: AsyncSession, channel: "ChatChannel", user: User) -> bool:
silence = (
await session.exec(
select(SilenceUser).where(
SilenceUser.channel_id == channel.channel_id,
SilenceUser.user_id == user.id,
)
)
).first()
return silence is not None
@ondemand
@staticmethod
async def current_user_attributes(
session: AsyncSession,
channel: "ChatChannel",
user: User,
) -> ChatUserAttributes:
from app.dependencies.database import get_redis
silence = (
await session.exec(
select(SilenceUser).where(
SilenceUser.channel_id == channel.channel_id,
SilenceUser.user_id == user.id,
)
)
).first()
can_message = silence is None
can_message_error = "You are silenced in this channel" if not can_message else None
redis = get_redis()
last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
last_read_id = int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else (last_msg or 0)
return ChatUserAttributes(
can_message=can_message,
can_message_error=can_message_error,
last_read_id=last_read_id,
)
@ondemand
@staticmethod
async def last_read_id(_session: AsyncSession, channel: "ChatChannel", user: User) -> int | None:
from app.dependencies.database import get_redis
redis = get_redis()
last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
return int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else last_msg
@ondemand
@staticmethod
async def last_message_id(_session: AsyncSession, channel: "ChatChannel") -> int | None:
from app.dependencies.database import get_redis
redis = get_redis()
last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
return int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
@ondemand
@staticmethod
async def recent_messages(
session: AsyncSession,
channel: "ChatChannel",
) -> list["ChatMessageDict"]:
messages = (
await session.exec(
select(ChatMessage)
.where(ChatMessage.channel_id == channel.channel_id)
.order_by(col(ChatMessage.message_id).desc())
.limit(50)
)
).all()
result = [
await ChatMessageModel.transform(
msg,
)
for msg in reversed(messages)
]
return result
@ondemand
@staticmethod
async def users(
_session: AsyncSession,
channel: "ChatChannel",
server: "ChatServer",
user: User,
) -> list[int]:
if channel.type == ChannelType.PUBLIC:
return []
users = server.channels.get(channel.channel_id, []).copy()
if channel.type == ChannelType.PM and users and len(users) == 2:
target_user_id = next(u for u in users if u != user.id)
users = [target_user_id, user.id]
return users
@included
@staticmethod
async def message_length_limit(_session: AsyncSession, _channel: "ChatChannel") -> int:
return 1000
class ChatChannel(ChatChannelModel, table=True):
__tablename__: str = "chat_channels"
channel_id: int = Field(primary_key=True, index=True, default=None)
name: str = Field(sa_column=Column(VARCHAR(50), index=True))
@classmethod
async def get(cls, channel: str | int, session: AsyncSession) -> "ChatChannel | None":
@@ -74,93 +229,20 @@ class ChatChannel(ChatChannelBase, table=True):
return channel
class ChatChannelResp(ChatChannelBase):
channel_id: int
moderated: bool = False
uuid: str | None = None
current_user_attributes: ChatUserAttributes | None = None
last_read_id: int | None = None
last_message_id: int | None = None
recent_messages: list["ChatMessageResp"] = Field(default_factory=list)
users: list[int] = Field(default_factory=list)
message_length_limit: int = 1000
@classmethod
async def from_db(
cls,
channel: ChatChannel,
session: AsyncSession,
user: User,
redis: Redis,
users: list[int] | None = None,
include_recent_messages: bool = False,
) -> Self:
c = cls.model_validate(channel)
silence = (
await session.exec(
select(SilenceUser).where(
SilenceUser.channel_id == channel.channel_id,
SilenceUser.user_id == user.id,
)
)
).first()
last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
last_read_id = int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else last_msg
if silence is not None:
attribute = ChatUserAttributes(
can_message=False,
can_message_error=silence.reason or "You are muted in this channel.",
last_read_id=last_read_id or 0,
)
c.moderated = True
else:
attribute = ChatUserAttributes(
can_message=True,
last_read_id=last_read_id or 0,
)
c.moderated = False
c.current_user_attributes = attribute
if c.type != ChannelType.PUBLIC and users is not None:
c.users = users
c.last_message_id = last_msg
c.last_read_id = last_read_id
if include_recent_messages:
messages = (
await session.exec(
select(ChatMessage)
.where(ChatMessage.channel_id == channel.channel_id)
.order_by(col(ChatMessage.timestamp).desc())
.limit(10)
)
).all()
c.recent_messages = [await ChatMessageResp.from_db(msg, session, user) for msg in messages]
c.recent_messages.reverse()
if c.type == ChannelType.PM and users and len(users) == 2:
target_user_id = next(u for u in users if u != user.id)
target_name = await session.exec(select(User.username).where(User.id == target_user_id))
c.name = target_name.one()
c.users = [target_user_id, user.id]
return c
# ChatMessage
class ChatMessageDict(TypedDict):
channel_id: int
content: str
message_id: int
sender_id: int
timestamp: datetime
type: MessageType
uuid: str | None
is_action: NotRequired[bool]
sender: NotRequired[UserDict]
class MessageType(str, Enum):
ACTION = "action"
MARKDOWN = "markdown"
PLAIN = "plain"
class ChatMessageBase(UTCBaseModel, SQLModel):
class ChatMessageModel(DatabaseModel[ChatMessageDict]):
channel_id: int = Field(index=True, foreign_key="chat_channels.channel_id")
content: str = Field(sa_column=Column(VARCHAR(1000)))
message_id: int = Field(index=True, primary_key=True, default=None)
@@ -169,31 +251,21 @@ class ChatMessageBase(UTCBaseModel, SQLModel):
type: MessageType = Field(default=MessageType.PLAIN, index=True, exclude=True)
uuid: str | None = Field(default=None)
@included
@staticmethod
async def is_action(_session: AsyncSession, db_message: "ChatMessage") -> bool:
return db_message.type == MessageType.ACTION
class ChatMessage(ChatMessageBase, table=True):
@ondemand
@staticmethod
async def sender(_session: AsyncSession, db_message: "ChatMessage") -> UserDict:
return await UserModel.transform(db_message.user)
class ChatMessage(ChatMessageModel, table=True):
__tablename__: str = "chat_messages"
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
channel: ChatChannel = Relationship()
class ChatMessageResp(ChatMessageBase):
sender: UserResp | None = None
is_action: bool = False
@classmethod
async def from_db(
cls, db_message: ChatMessage, session: AsyncSession, user: User | None = None
) -> "ChatMessageResp":
m = cls.model_validate(db_message.model_dump())
m.is_action = db_message.type == MessageType.ACTION
if user:
m.sender = await UserResp.from_db(user, session, RANKING_INCLUDES)
else:
m.sender = await UserResp.from_db(db_message.user, session, RANKING_INCLUDES)
return m
# SilenceUser
channel: "ChatChannel" = Relationship()
class SilenceUser(UTCBaseModel, SQLModel, table=True):

View File

@@ -1,7 +1,7 @@
import datetime
from app.database.beatmapset import Beatmapset
from app.database.user import User
from .beatmapset import Beatmapset
from .user import User
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import (

View File

@@ -1,7 +1,9 @@
from .playlist_best_score import PlaylistBestScore
from .user import User, UserResp
from typing import Any, NotRequired, TypedDict
from ._base import DatabaseModel, ondemand
from .playlist_best_score import PlaylistBestScore
from .user import User, UserDict, UserModel
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import (
BigInteger,
@@ -9,7 +11,6 @@ from sqlmodel import (
Field,
ForeignKey,
Relationship,
SQLModel,
col,
func,
select,
@@ -17,17 +18,66 @@ from sqlmodel import (
from sqlmodel.ext.asyncio.session import AsyncSession
class ItemAttemptsCountBase(SQLModel):
room_id: int = Field(foreign_key="rooms.id", index=True)
class ItemAttemptsCountDict(TypedDict):
accuracy: float
attempts: int
completed: int
pp: float
room_id: int
total_score: int
user_id: int
user: NotRequired[UserDict]
position: NotRequired[int]
playlist_item_attempts: NotRequired[list[dict[str, Any]]]
class ItemAttemptsCountModel(DatabaseModel[ItemAttemptsCountDict]):
accuracy: float = 0.0
attempts: int = Field(default=0)
completed: int = Field(default=0)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
accuracy: float = 0.0
pp: float = 0
room_id: int = Field(foreign_key="rooms.id", index=True)
total_score: int = 0
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
@ondemand
@staticmethod
async def user(_session: AsyncSession, item_attempts: "ItemAttemptsCount") -> UserDict:
user_instance = await item_attempts.awaitable_attrs.user
return await UserModel.transform(user_instance)
@ondemand
@staticmethod
async def position(session: AsyncSession, item_attempts: "ItemAttemptsCount") -> int:
return await item_attempts.get_position(session)
@ondemand
@staticmethod
async def playlist_item_attempts(
session: AsyncSession,
item_attempts: "ItemAttemptsCount",
) -> list[dict[str, Any]]:
playlist_scores = (
await session.exec(
select(PlaylistBestScore).where(
PlaylistBestScore.room_id == item_attempts.room_id,
PlaylistBestScore.user_id == item_attempts.user_id,
)
)
).all()
result: list[dict[str, Any]] = []
for score in playlist_scores:
result.append(
{
"id": score.playlist_id,
"attempts": score.attempts,
"passed": score.score.passed,
}
)
return result
class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True):
class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountModel, table=True):
__tablename__: str = "item_attempts_count"
id: int | None = Field(default=None, primary_key=True)
@@ -37,15 +87,15 @@ class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True):
rownum = (
func.row_number()
.over(
partition_by=col(ItemAttemptsCountBase.room_id),
order_by=col(ItemAttemptsCountBase.total_score).desc(),
partition_by=col(ItemAttemptsCount.room_id),
order_by=col(ItemAttemptsCount.total_score).desc(),
)
.label("rn")
)
subq = select(ItemAttemptsCountBase, rownum).subquery()
subq = select(ItemAttemptsCount, rownum).subquery()
stmt = select(subq.c.rn).where(subq.c.user_id == self.user_id)
result = await session.exec(stmt)
return result.one()
return result.first() or 0
async def update(self, session: AsyncSession):
playlist_scores = (
@@ -88,62 +138,3 @@ class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True):
await session.refresh(item_attempts)
await item_attempts.update(session)
return item_attempts
class ItemAttemptsResp(ItemAttemptsCountBase):
user: UserResp | None = None
position: int | None = None
@classmethod
async def from_db(
cls,
item_attempts: ItemAttemptsCount,
session: AsyncSession,
include: list[str] = [],
) -> "ItemAttemptsResp":
resp = cls.model_validate(item_attempts.model_dump())
resp.user = await UserResp.from_db(
await item_attempts.awaitable_attrs.user,
session=session,
include=["statistics", "team", "daily_challenge_user_stats"],
)
if "position" in include:
resp.position = await item_attempts.get_position(session)
# resp.accuracy *= 100
return resp
class ItemAttemptsCountForItem(BaseModel):
id: int
attempts: int
passed: bool
class PlaylistAggregateScore(BaseModel):
playlist_item_attempts: list[ItemAttemptsCountForItem] = Field(default_factory=list)
@classmethod
async def from_db(
cls,
room_id: int,
user_id: int,
session: AsyncSession,
) -> "PlaylistAggregateScore":
playlist_scores = (
await session.exec(
select(PlaylistBestScore).where(
PlaylistBestScore.room_id == room_id,
PlaylistBestScore.user_id == user_id,
)
)
).all()
playlist_item_attempts = []
for score in playlist_scores:
playlist_item_attempts.append(
ItemAttemptsCountForItem(
id=score.playlist_id,
attempts=score.attempts,
passed=score.score.passed,
)
)
return cls(playlist_item_attempts=playlist_item_attempts)

View File

@@ -1,11 +1,11 @@
from datetime import datetime
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, NotRequired, TypedDict
from app.models.model import UTCBaseModel
from app.models.mods import APIMod
from app.models.playlist import PlaylistItem
from .beatmap import Beatmap, BeatmapResp
from ._base import DatabaseModel, ondemand
from .beatmap import Beatmap, BeatmapDict, BeatmapModel
from sqlmodel import (
JSON,
@@ -15,7 +15,6 @@ from sqlmodel import (
Field,
ForeignKey,
Relationship,
SQLModel,
func,
select,
)
@@ -23,18 +22,34 @@ from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from .room import Room
from .score import ScoreDict
class PlaylistBase(SQLModel, UTCBaseModel):
id: int = Field(index=True)
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
class PlaylistDict(TypedDict):
id: int
room_id: int
beatmap_id: int
created_at: datetime | None
ruleset_id: int
expired: bool = Field(default=False)
playlist_order: int = Field(default=0)
played_at: datetime | None = Field(
sa_column=Column(DateTime(timezone=True)),
default=None,
allowed_mods: list[APIMod]
required_mods: list[APIMod]
freestyle: bool
expired: bool
owner_id: int
playlist_order: int
played_at: datetime | None
beatmap: NotRequired["BeatmapDict"]
scores: NotRequired[list[dict[str, Any]]]
class PlaylistModel(DatabaseModel[PlaylistDict]):
id: int = Field(index=True)
room_id: int = Field(foreign_key="rooms.id")
beatmap_id: int = Field(
foreign_key="beatmaps.id",
)
created_at: datetime | None = Field(default=None, sa_column_kwargs={"server_default": func.now()})
ruleset_id: int
allowed_mods: list[APIMod] = Field(
default_factory=list,
sa_column=Column(JSON),
@@ -43,16 +58,46 @@ class PlaylistBase(SQLModel, UTCBaseModel):
default_factory=list,
sa_column=Column(JSON),
)
beatmap_id: int = Field(
foreign_key="beatmaps.id",
)
freestyle: bool = Field(default=False)
expired: bool = Field(default=False)
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
playlist_order: int = Field(default=0)
played_at: datetime | None = Field(
sa_column=Column(DateTime(timezone=True)),
default=None,
)
@ondemand
@staticmethod
async def beatmap(_session: AsyncSession, playlist: "Playlist", includes: list[str] | None = None) -> BeatmapDict:
return await BeatmapModel.transform(playlist.beatmap, includes=includes)
@ondemand
@staticmethod
async def scores(session: AsyncSession, playlist: "Playlist") -> list["ScoreDict"]:
from .score import Score, ScoreModel
scores = (
await session.exec(
select(Score).where(
Score.playlist_item_id == playlist.id,
Score.room_id == playlist.room_id,
)
)
).all()
result: list[ScoreDict] = []
for score in scores:
result.append(
await ScoreModel.transform(
score,
)
)
return result
class Playlist(PlaylistBase, table=True):
class Playlist(PlaylistModel, table=True):
__tablename__: str = "room_playlists"
db_id: int = Field(default=None, primary_key=True, index=True, exclude=True)
room_id: int = Field(foreign_key="rooms.id", exclude=True)
beatmap: Beatmap = Relationship(
sa_relationship_kwargs={
@@ -60,7 +105,6 @@ class Playlist(PlaylistBase, table=True):
}
)
room: "Room" = Relationship()
created_at: datetime | None = Field(default=None, sa_column_kwargs={"server_default": func.now()})
updated_at: datetime | None = Field(
default=None, sa_column_kwargs={"server_default": func.now(), "onupdate": func.now()}
)
@@ -121,15 +165,3 @@ class Playlist(PlaylistBase, table=True):
raise ValueError("Playlist item not found")
await session.delete(db_playlist)
await session.commit()
class PlaylistResp(PlaylistBase):
beatmap: BeatmapResp | None = None
@classmethod
async def from_db(cls, playlist: Playlist, include: list[str] = []) -> "PlaylistResp":
data = playlist.model_dump()
if "beatmap" in include:
data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap)
resp = cls.model_validate(data)
return resp

View File

@@ -1,26 +1,39 @@
from enum import Enum
from typing import TYPE_CHECKING, NotRequired, TypedDict
from .user import User, UserResp
from app.models.score import GameMode
from ._base import DatabaseModel, included, ondemand
from pydantic import BaseModel
from sqlmodel import (
BigInteger,
Column,
Field,
ForeignKey,
Relationship as SQLRelationship,
SQLModel,
select,
)
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from .user import User, UserDict
class RelationshipType(str, Enum):
FOLLOW = "Friend"
BLOCK = "Block"
FOLLOW = "friend"
BLOCK = "block"
class Relationship(SQLModel, table=True):
class RelationshipDict(TypedDict):
target_id: int | None
type: RelationshipType
id: NotRequired[int | None]
user_id: NotRequired[int | None]
mutual: NotRequired[bool]
target: NotRequired["UserDict"]
class RelationshipModel(DatabaseModel[RelationshipDict]):
__tablename__: str = "relationship"
id: int | None = Field(
default=None,
@@ -34,6 +47,7 @@ class Relationship(SQLModel, table=True):
ForeignKey("lazer_users.id"),
index=True,
),
exclude=True,
)
target_id: int = Field(
default=None,
@@ -44,22 +58,10 @@ class Relationship(SQLModel, table=True):
),
)
type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False)
target: User = SQLRelationship(
sa_relationship_kwargs={
"foreign_keys": "[Relationship.target_id]",
"lazy": "selectin",
}
)
class RelationshipResp(BaseModel):
target_id: int
target: UserResp
mutual: bool = False
type: RelationshipType
@classmethod
async def from_db(cls, session: AsyncSession, relationship: Relationship) -> "RelationshipResp":
@included
@staticmethod
async def mutual(session: AsyncSession, relationship: "Relationship") -> bool:
target_relationship = (
await session.exec(
select(Relationship).where(
@@ -68,23 +70,29 @@ class RelationshipResp(BaseModel):
)
)
).first()
mutual = bool(
return bool(
target_relationship is not None
and relationship.type == RelationshipType.FOLLOW
and target_relationship.type == RelationshipType.FOLLOW
)
return cls(
target_id=relationship.target_id,
target=await UserResp.from_db(
relationship.target,
session,
include=[
"team",
"daily_challenge_user_stats",
"statistics",
"statistics_rulesets",
],
),
mutual=mutual,
type=relationship.type,
)
@ondemand
@staticmethod
async def target(
_session: AsyncSession,
relationship: "Relationship",
ruleset: GameMode | None = None,
includes: list[str] | None = None,
) -> "UserDict":
from .user import UserModel
return await UserModel.transform(relationship.target, ruleset=ruleset, includes=includes)
class Relationship(RelationshipModel, table=True):
target: "User" = SQLRelationship(
sa_relationship_kwargs={
"foreign_keys": "[Relationship.target_id]",
"lazy": "selectin",
}
)

View File

@@ -1,8 +1,6 @@
from datetime import datetime
from typing import ClassVar, NotRequired, TypedDict
from app.database.item_attempts_count import PlaylistAggregateScore
from app.database.room_participated_user import RoomParticipatedUser
from app.models.model import UTCBaseModel
from app.models.room import (
MatchType,
QueueMode,
@@ -13,28 +11,58 @@ from app.models.room import (
)
from app.utils import utcnow
from .playlists import Playlist, PlaylistResp
from .user import User, UserResp
from ._base import DatabaseModel, included, ondemand
from .item_attempts_count import ItemAttemptsCount, ItemAttemptsCountDict, ItemAttemptsCountModel
from .playlists import Playlist, PlaylistDict, PlaylistModel
from .room_participated_user import RoomParticipatedUser
from .user import User, UserDict, UserModel
from pydantic import field_validator
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import (
BigInteger,
Column,
DateTime,
Field,
ForeignKey,
Relationship,
SQLModel,
col,
func,
select,
)
from sqlmodel import BigInteger, Column, DateTime, Field, ForeignKey, Relationship, SQLModel, col, select
from sqlmodel.ext.asyncio.session import AsyncSession
class RoomBase(SQLModel, UTCBaseModel):
class RoomDict(TypedDict):
id: int
name: str
category: RoomCategory
status: RoomStatus
type: MatchType
duration: int | None
starts_at: datetime | None
ends_at: datetime | None
max_attempts: int | None
participant_count: int
channel_id: int
queue_mode: QueueMode
auto_skip: bool
auto_start_duration: int
has_password: NotRequired[bool]
current_playlist_item: NotRequired["PlaylistDict | None"]
playlist: NotRequired[list["PlaylistDict"]]
playlist_item_stats: NotRequired[RoomPlaylistItemStats]
difficulty_range: NotRequired[RoomDifficultyRange]
host: NotRequired[UserDict]
recent_participants: NotRequired[list[UserDict]]
current_user_score: NotRequired["ItemAttemptsCountDict | None"]
class RoomModel(DatabaseModel[RoomDict]):
SHOW_RESPONSE_INCLUDES: ClassVar[list[str]] = [
"current_user_score.playlist_item_attempts",
"host.country",
"playlist.beatmap.beatmapset",
"playlist.beatmap.checksum",
"playlist.beatmap.max_combo",
"recent_participants",
]
id: int = Field(default=None, primary_key=True, index=True)
name: str = Field(index=True)
category: RoomCategory = Field(default=RoomCategory.NORMAL, index=True)
status: RoomStatus
type: MatchType
duration: int | None = Field(default=None) # minutes
starts_at: datetime | None = Field(
sa_column=Column(
@@ -48,76 +76,88 @@ class RoomBase(SQLModel, UTCBaseModel):
),
default=None,
)
participant_count: int = Field(default=0)
max_attempts: int | None = Field(default=None) # playlists
type: MatchType
participant_count: int = Field(default=0)
channel_id: int = 0
queue_mode: QueueMode
auto_skip: bool
auto_start_duration: int
status: RoomStatus
channel_id: int | None = None
class Room(AsyncAttrs, RoomBase, table=True):
__tablename__: str = "rooms"
id: int = Field(default=None, primary_key=True, index=True)
host_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
password: str | None = Field(default=None)
host: User = Relationship()
playlist: list[Playlist] = Relationship(
sa_relationship_kwargs={
"lazy": "selectin",
"cascade": "all, delete-orphan",
"overlaps": "room",
}
)
class RoomResp(RoomBase):
id: int
has_password: bool = False
host: UserResp | None = None
playlist: list[PlaylistResp] = []
playlist_item_stats: RoomPlaylistItemStats | None = None
difficulty_range: RoomDifficultyRange | None = None
current_playlist_item: PlaylistResp | None = None
current_user_score: PlaylistAggregateScore | None = None
recent_participants: list[UserResp] = Field(default_factory=list)
channel_id: int = 0
@field_validator("channel_id", mode="before")
@classmethod
async def from_db(
cls,
room: Room,
session: AsyncSession,
include: list[str] = [],
user: User | None = None,
) -> "RoomResp":
d = room.model_dump()
d["channel_id"] = d.get("channel_id", 0) or 0
d["has_password"] = bool(room.password)
resp = cls.model_validate(d)
def validate_channel_id(cls, v):
"""将 None 转换为 0"""
if v is None:
return 0
return v
stats = RoomPlaylistItemStats(count_active=0, count_total=0)
difficulty_range = RoomDifficultyRange(
min=0,
max=0,
)
rulesets = set()
for playlist in room.playlist:
@included
@staticmethod
async def has_password(_session: AsyncSession, room: "Room") -> bool:
return bool(room.password)
@ondemand
@staticmethod
async def current_playlist_item(
_session: AsyncSession, room: "Room", includes: list[str] | None = None
) -> "PlaylistDict | None":
playlists = await room.awaitable_attrs.playlist
if not playlists:
return None
return await PlaylistModel.transform(playlists[-1], includes=includes)
@ondemand
@staticmethod
async def playlist(_session: AsyncSession, room: "Room", includes: list[str] | None = None) -> list["PlaylistDict"]:
playlists = await room.awaitable_attrs.playlist
result: list[PlaylistDict] = []
for playlist_item in playlists:
result.append(await PlaylistModel.transform(playlist_item, includes=includes))
return result
@ondemand
@staticmethod
async def playlist_item_stats(_session: AsyncSession, room: "Room") -> RoomPlaylistItemStats:
playlists = await room.awaitable_attrs.playlist
stats = RoomPlaylistItemStats(count_active=0, count_total=0, ruleset_ids=[])
rulesets: set[int] = set()
for playlist in playlists:
stats.count_total += 1
if not playlist.expired:
stats.count_active += 1
rulesets.add(playlist.ruleset_id)
difficulty_range.min = min(difficulty_range.min, playlist.beatmap.difficulty_rating)
difficulty_range.max = max(difficulty_range.max, playlist.beatmap.difficulty_rating)
resp.playlist.append(await PlaylistResp.from_db(playlist, ["beatmap"]))
stats.ruleset_ids = list(rulesets)
resp.playlist_item_stats = stats
resp.difficulty_range = difficulty_range
resp.current_playlist_item = resp.playlist[-1] if resp.playlist else None
resp.recent_participants = []
return stats
@ondemand
@staticmethod
async def difficulty_range(_session: AsyncSession, room: "Room") -> RoomDifficultyRange:
playlists = await room.awaitable_attrs.playlist
if not playlists:
return RoomDifficultyRange(min=0.0, max=0.0)
min_diff = float("inf")
max_diff = float("-inf")
for playlist in playlists:
rating = playlist.beatmap.difficulty_rating
min_diff = min(min_diff, rating)
max_diff = max(max_diff, rating)
if min_diff == float("inf"):
min_diff = 0.0
if max_diff == float("-inf"):
max_diff = 0.0
return RoomDifficultyRange(min=min_diff, max=max_diff)
@ondemand
@staticmethod
async def host(_session: AsyncSession, room: "Room", includes: list[str] | None = None) -> UserDict:
host_user = await room.awaitable_attrs.host
return await UserModel.transform(host_user, includes=includes)
@ondemand
@staticmethod
async def recent_participants(session: AsyncSession, room: "Room") -> list[UserDict]:
participants: list[UserDict] = []
if room.category == RoomCategory.REALTIME:
query = (
select(RoomParticipatedUser)
@@ -137,39 +177,67 @@ class RoomResp(RoomBase):
.limit(8)
.order_by(col(RoomParticipatedUser.joined_at).desc())
)
resp.participant_count = (
await session.exec(
select(func.count())
.select_from(RoomParticipatedUser)
.where(
RoomParticipatedUser.room_id == room.id,
)
)
).first() or 0
for recent_participant in await session.exec(query):
resp.recent_participants.append(
await UserResp.from_db(
await recent_participant.awaitable_attrs.user,
session,
include=["statistics"],
user_instance = await recent_participant.awaitable_attrs.user
participants.append(await UserModel.transform(user_instance))
return participants
@ondemand
@staticmethod
async def current_user_score(
session: AsyncSession, room: "Room", includes: list[str] | None = None
) -> "ItemAttemptsCountDict | None":
item_attempt = (
await session.exec(
select(ItemAttemptsCount).where(
ItemAttemptsCount.room_id == room.id,
)
)
resp.host = await UserResp.from_db(await room.awaitable_attrs.host, session, include=["statistics"])
if "current_user_score" in include and user:
resp.current_user_score = await PlaylistAggregateScore.from_db(room.id, user.id, session)
return resp
).first()
if item_attempt is None:
return None
return await ItemAttemptsCountModel.transform(item_attempt, includes=includes)
class APIUploadedRoom(RoomBase):
def to_room(self) -> Room:
"""
将 APIUploadedRoom 转换为 Room 对象playlist 字段需单独处理。
"""
room_dict = self.model_dump()
room_dict.pop("playlist", None)
# host_id 已在字段中
return Room(**room_dict)
class Room(AsyncAttrs, RoomModel, table=True):
__tablename__: str = "rooms"
id: int | None
host_id: int | None = None
host_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
password: str | None = Field(default=None)
host: User = Relationship()
playlist: list[Playlist] = Relationship(
sa_relationship_kwargs={
"lazy": "selectin",
"cascade": "all, delete-orphan",
"overlaps": "room",
}
)
class APIUploadedRoom(SQLModel):
name: str = Field(index=True)
category: RoomCategory = Field(default=RoomCategory.NORMAL, index=True)
status: RoomStatus
type: MatchType
duration: int | None = Field(default=None) # minutes
starts_at: datetime | None = Field(
sa_column=Column(
DateTime(timezone=True),
),
default_factory=utcnow,
)
ends_at: datetime | None = Field(
sa_column=Column(
DateTime(timezone=True),
),
default=None,
)
max_attempts: int | None = Field(default=None) # playlists
participant_count: int = Field(default=0)
channel_id: int = 0
queue_mode: QueueMode
auto_skip: bool
auto_start_duration: int
playlist: list[Playlist] = Field(default_factory=list)

View File

@@ -3,7 +3,7 @@ from datetime import date, datetime
import json
import math
import sys
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, ClassVar, NotRequired, TypedDict
from app.calculator import (
calculate_pp_weight,
@@ -15,8 +15,6 @@ from app.calculator import (
pre_fetch_and_calculate_pp,
)
from app.config import settings
from app.database.beatmap_playcounts import BeatmapPlaycounts
from app.database.team import TeamMember
from app.dependencies.database import get_redis
from app.log import log
from app.models.beatmap import BeatmapRankStatus
@@ -39,8 +37,10 @@ from app.models.scoring_mode import ScoringMode
from app.storage import StorageService
from app.utils import utcnow
from .beatmap import Beatmap, BeatmapResp
from .beatmapset import BeatmapsetResp
from ._base import DatabaseModel, OnDemand, included, ondemand
from .beatmap import Beatmap, BeatmapDict, BeatmapModel
from .beatmap_playcounts import BeatmapPlaycounts
from .beatmapset import BeatmapsetDict, BeatmapsetModel
from .best_scores import BestScore
from .counts import MonthlyPlaycounts
from .events import Event, EventType
@@ -50,8 +50,9 @@ from .relationship import (
RelationshipType,
)
from .score_token import ScoreToken
from .team import TeamMember
from .total_score_best_scores import TotalScoreBestScore
from .user import User, UserResp
from .user import User, UserDict, UserModel
from pydantic import BaseModel, field_serializer, field_validator
from redis.asyncio import Redis
@@ -80,30 +81,290 @@ if TYPE_CHECKING:
logger = log("Score")
class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
# 基本字段
class ScoreDict(TypedDict):
beatmap_id: int
id: int
rank: Rank
type: str
user_id: int
accuracy: float
build_id: int | None
ended_at: datetime
has_replay: bool
max_combo: int
passed: bool
pp: float
started_at: datetime
total_score: int
maximum_statistics: ScoreStatistics
mods: list[APIMod]
classic_total_score: int | None
preserve: bool
processed: bool
ranked: bool
playlist_item_id: NotRequired[int | None]
room_id: NotRequired[int | None]
best_id: NotRequired[int | None]
legacy_perfect: NotRequired[bool]
is_perfect_combo: NotRequired[bool]
ruleset_id: NotRequired[int]
statistics: NotRequired[ScoreStatistics]
beatmapset: NotRequired[BeatmapsetDict]
beatmap: NotRequired[BeatmapDict]
current_user_attributes: NotRequired[CurrentUserAttributes]
position: NotRequired[int | None]
scores_around: NotRequired["ScoreAround | None"]
rank_country: NotRequired[int | None]
rank_global: NotRequired[int | None]
user: NotRequired[UserDict]
weight: NotRequired[float | None]
# ScoreResp 字段
legacy_total_score: NotRequired[int]
class ScoreModel(AsyncAttrs, DatabaseModel[ScoreDict]):
# https://github.com/ppy/osu-web/blob/master/app/Transformers/ScoreTransformer.php#L72-L84
MULTIPLAYER_SCORE_INCLUDE: ClassVar[list[str]] = ["playlist_item_id", "room_id", "solo_score_id"]
MULTIPLAYER_BASE_INCLUDES: ClassVar[list[str]] = [
"user.country",
"user.cover",
"user.team",
*MULTIPLAYER_SCORE_INCLUDE,
]
USER_PROFILE_INCLUDES: ClassVar[list[str]] = ["beatmap", "beatmapset", "user"]
# 基本字段
beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
id: int = Field(default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True))
rank: Rank
type: str
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("lazer_users.id"),
index=True,
),
)
accuracy: float
map_md5: str = Field(max_length=32, index=True)
build_id: int | None = Field(default=None)
classic_total_score: int | None = Field(default=0, sa_column=Column(BigInteger)) # solo_score
ended_at: datetime = Field(sa_column=Column(DateTime))
has_replay: bool = Field(sa_column=Column(Boolean))
max_combo: int
mods: list[APIMod] = Field(sa_column=Column(JSON))
passed: bool = Field(sa_column=Column(Boolean))
playlist_item_id: int | None = Field(default=None) # multiplayer
pp: float = Field(default=0.0)
preserve: bool = Field(default=True, sa_column=Column(Boolean))
rank: Rank
room_id: int | None = Field(default=None) # multiplayer
started_at: datetime = Field(sa_column=Column(DateTime))
total_score: int = Field(default=0, sa_column=Column(BigInteger))
total_score_without_mods: int = Field(default=0, sa_column=Column(BigInteger), exclude=True)
type: str
beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
maximum_statistics: ScoreStatistics = Field(sa_column=Column(JSON), default_factory=dict)
processed: bool = False # solo_score
ranked: bool = False
mods: list[APIMod] = Field(sa_column=Column(JSON))
total_score_without_mods: int = Field(default=0, sa_column=Column(BigInteger), exclude=True)
# solo
classic_total_score: int | None = Field(default=0, sa_column=Column(BigInteger))
preserve: bool = Field(default=True, sa_column=Column(Boolean))
processed: bool = Field(default=False)
ranked: bool = Field(default=False)
# multiplayer
playlist_item_id: OnDemand[int | None] = Field(default=None)
room_id: OnDemand[int | None] = Field(default=None)
@included
@staticmethod
async def best_id(
session: AsyncSession,
score: "Score",
) -> int | None:
return await get_best_id(session, score.id)
@included
@staticmethod
async def legacy_perfect(
_session: AsyncSession,
score: "Score",
) -> bool:
await score.awaitable_attrs.beatmap
return score.max_combo == score.beatmap.max_combo
@included
@staticmethod
async def is_perfect_combo(
_session: AsyncSession,
score: "Score",
) -> bool:
await score.awaitable_attrs.beatmap
return score.max_combo == score.beatmap.max_combo
@included
@staticmethod
async def ruleset_id(
_session: AsyncSession,
score: "Score",
) -> int:
return int(score.gamemode)
@included
@staticmethod
async def statistics(
_session: AsyncSession,
score: "Score",
) -> ScoreStatistics:
stats = {
HitResult.MISS: score.nmiss,
HitResult.MEH: score.n50,
HitResult.OK: score.n100,
HitResult.GREAT: score.n300,
HitResult.PERFECT: score.ngeki,
HitResult.GOOD: score.nkatu,
}
if score.nlarge_tick_miss is not None:
stats[HitResult.LARGE_TICK_MISS] = score.nlarge_tick_miss
if score.nslider_tail_hit is not None:
stats[HitResult.SLIDER_TAIL_HIT] = score.nslider_tail_hit
if score.nsmall_tick_hit is not None:
stats[HitResult.SMALL_TICK_HIT] = score.nsmall_tick_hit
if score.nlarge_tick_hit is not None:
stats[HitResult.LARGE_TICK_HIT] = score.nlarge_tick_hit
return stats
@ondemand
@staticmethod
async def beatmapset(
_session: AsyncSession,
score: "Score",
includes: list[str] | None = None,
) -> BeatmapsetDict:
await score.awaitable_attrs.beatmap
return await BeatmapsetModel.transform(score.beatmap.beatmapset, includes=includes)
# reorder beatmapset and beatmap
# https://github.com/ppy/osu/blob/d8900defd34690de92be3406003fb3839fc0df1d/osu.Game/Online/API/Requests/Responses/SoloScoreInfo.cs#L111-L112
@ondemand
@staticmethod
async def beatmap(
_session: AsyncSession,
score: "Score",
includes: list[str] | None = None,
) -> BeatmapDict:
await score.awaitable_attrs.beatmap
return await BeatmapModel.transform(score.beatmap, includes=includes)
@ondemand
@staticmethod
async def current_user_attributes(
_session: AsyncSession,
score: "Score",
) -> CurrentUserAttributes:
return CurrentUserAttributes(pin=PinAttributes(is_pinned=bool(score.pinned_order), score_id=score.id))
@ondemand
@staticmethod
async def position(
session: AsyncSession,
score: "Score",
) -> int | None:
return await get_score_position_by_id(
session,
score.beatmap_id,
score.id,
mode=score.gamemode,
user=score.user,
)
@ondemand
@staticmethod
async def scores_around(
session: AsyncSession, _score: "Score", playlist_id: int, room_id: int, is_playlist: bool
) -> "ScoreAround | None":
scores = (
await session.exec(
select(PlaylistBestScore).where(
PlaylistBestScore.playlist_id == playlist_id,
PlaylistBestScore.room_id == room_id,
~User.is_restricted_query(col(PlaylistBestScore.user_id)),
col(PlaylistBestScore.score).has(col(Score.passed).is_(True)) if not is_playlist else True,
)
)
).all()
higher_scores = []
lower_scores = []
for score in scores:
total_score = score.score.total_score
resp = await ScoreModel.transform(score.score, includes=ScoreModel.MULTIPLAYER_BASE_INCLUDES)
if score.total_score > total_score:
higher_scores.append(resp)
elif score.total_score < total_score:
lower_scores.append(resp)
return ScoreAround(
higher=MultiplayerScores(scores=higher_scores),
lower=MultiplayerScores(scores=lower_scores),
)
@ondemand
@staticmethod
async def rank_country(
session: AsyncSession,
score: "Score",
) -> int | None:
return (
await get_score_position_by_id(
session,
score.beatmap_id,
score.id,
score.gamemode,
score.user,
type=LeaderboardType.COUNTRY,
)
or None
)
@ondemand
@staticmethod
async def rank_global(
session: AsyncSession,
score: "Score",
) -> int | None:
return (
await get_score_position_by_id(
session,
score.beatmap_id,
score.id,
mode=score.gamemode,
user=score.user,
)
or None
)
@ondemand
@staticmethod
async def user(
_session: AsyncSession,
score: "Score",
includes: list[str] | None = None,
) -> UserDict:
return await UserModel.transform(score.user, ruleset=score.gamemode, includes=includes or [])
@ondemand
@staticmethod
async def weight(
session: AsyncSession,
score: "Score",
) -> float | None:
best_id = await get_best_id(session, score.id)
if best_id:
return calculate_pp_weight(best_id - 1)
return None
@ondemand
@staticmethod
async def legacy_total_score(
_session: AsyncSession,
_score: "Score",
) -> int:
return 0
@field_validator("maximum_statistics", mode="before")
@classmethod
@@ -151,17 +412,9 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
# TODO: current_user_attributes
class Score(ScoreBase, table=True):
class Score(ScoreModel, table=True):
__tablename__: str = "scores"
id: int = Field(default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True))
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("lazer_users.id"),
index=True,
),
)
# ScoreStatistics
n300: int = Field(exclude=True)
n100: int = Field(exclude=True)
@@ -175,6 +428,7 @@ class Score(ScoreBase, table=True):
nsmall_tick_hit: int | None = Field(default=None, exclude=True)
gamemode: GameMode = Field(index=True)
pinned_order: int = Field(default=0, exclude=True)
map_md5: str = Field(max_length=32, index=True, exclude=True)
@field_validator("gamemode", mode="before")
@classmethod
@@ -245,9 +499,11 @@ class Score(ScoreBase, table=True):
maximum_statistics=self.maximum_statistics,
)
async def to_resp(self, session: AsyncSession, api_version: int) -> "ScoreResp | LegacyScoreResp":
async def to_resp(
self, session: AsyncSession, api_version: int, includes: list[str] = []
) -> "ScoreDict | LegacyScoreResp":
if api_version >= 20220705:
return await ScoreResp.from_db(session, self)
return await ScoreModel.transform(self, includes=includes)
return await LegacyScoreResp.from_db(session, self)
async def delete(
@@ -270,141 +526,7 @@ class Score(ScoreBase, table=True):
await session.delete(self)
class ScoreResp(ScoreBase):
id: int
user_id: int
is_perfect_combo: bool = False
legacy_perfect: bool = False
legacy_total_score: int = 0 # FIXME
weight: float = 0.0
best_id: int | None = None
ruleset_id: int | None = None
beatmap: BeatmapResp | None = None
beatmapset: BeatmapsetResp | None = None
user: UserResp | None = None
statistics: ScoreStatistics | None = None
maximum_statistics: ScoreStatistics | None = None
rank_global: int | None = None
rank_country: int | None = None
position: int | None = None
scores_around: "ScoreAround | None" = None
current_user_attributes: CurrentUserAttributes | None = None
@field_validator(
"has_replay",
"passed",
"preserve",
"is_perfect_combo",
"legacy_perfect",
"processed",
"ranked",
mode="before",
)
@classmethod
def validate_bool_fields(cls, v):
"""将整数 0/1 转换为布尔值,处理数据库中的布尔字段"""
if isinstance(v, int):
return bool(v)
return v
@field_validator("statistics", "maximum_statistics", mode="before")
@classmethod
def validate_statistics_fields(cls, v):
"""处理统计字段中的字符串键,转换为 HitResult 枚举"""
if isinstance(v, dict):
converted = {}
for key, value in v.items():
if isinstance(key, str):
try:
# 尝试将字符串转换为 HitResult 枚举
enum_key = HitResult(key)
converted[enum_key] = value
except ValueError:
# 如果转换失败,跳过这个键值对
continue
else:
converted[key] = value
return converted
return v
@field_serializer("statistics", when_used="json")
def serialize_statistics_fields(self, v):
"""序列化统计字段,确保枚举值正确转换为字符串"""
if isinstance(v, dict):
serialized = {}
for key, value in v.items():
if hasattr(key, "value"):
# 如果是枚举,使用其值
serialized[key.value] = value
else:
# 否则直接使用键
serialized[str(key)] = value
return serialized
return v
@classmethod
async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp":
# 确保 score 对象完全加载,避免懒加载问题
await session.refresh(score)
s = cls.model_validate(score.model_dump())
await score.awaitable_attrs.beatmap
s.beatmap = await BeatmapResp.from_db(score.beatmap)
s.beatmapset = await BeatmapsetResp.from_db(score.beatmap.beatmapset, session=session, user=score.user)
s.is_perfect_combo = s.max_combo == s.beatmap.max_combo
s.legacy_perfect = s.max_combo == s.beatmap.max_combo
s.ruleset_id = int(score.gamemode)
best_id = await get_best_id(session, score.id)
if best_id:
s.best_id = best_id
s.weight = calculate_pp_weight(best_id - 1)
s.statistics = {
HitResult.MISS: score.nmiss,
HitResult.MEH: score.n50,
HitResult.OK: score.n100,
HitResult.GREAT: score.n300,
HitResult.PERFECT: score.ngeki,
HitResult.GOOD: score.nkatu,
}
if score.nlarge_tick_miss is not None:
s.statistics[HitResult.LARGE_TICK_MISS] = score.nlarge_tick_miss
if score.nslider_tail_hit is not None:
s.statistics[HitResult.SLIDER_TAIL_HIT] = score.nslider_tail_hit
if score.nsmall_tick_hit is not None:
s.statistics[HitResult.SMALL_TICK_HIT] = score.nsmall_tick_hit
if score.nlarge_tick_hit is not None:
s.statistics[HitResult.LARGE_TICK_HIT] = score.nlarge_tick_hit
s.user = await UserResp.from_db(
score.user,
session,
include=["statistics", "team", "daily_challenge_user_stats"],
ruleset=score.gamemode,
)
s.rank_global = (
await get_score_position_by_id(
session,
score.beatmap_id,
score.id,
mode=score.gamemode,
user=score.user,
)
or None
)
s.rank_country = (
await get_score_position_by_id(
session,
score.beatmap_id,
score.id,
score.gamemode,
score.user,
type=LeaderboardType.COUNTRY,
)
or None
)
s.current_user_attributes = CurrentUserAttributes(
pin=PinAttributes(is_pinned=bool(score.pinned_order), score_id=score.id)
)
return s
MultiplayScoreDict = ScoreModel.generate_typeddict(tuple(Score.MULTIPLAYER_BASE_INCLUDES)) # pyright: ignore[reportGeneralTypeIssues]
class LegacyStatistics(BaseModel):
@@ -417,31 +539,25 @@ class LegacyStatistics(BaseModel):
class LegacyScoreResp(UTCBaseModel):
accuracy: float
best_id: int
created_at: datetime
id: int
max_combo: int
mode: GameMode
mode_int: int
best_id: int
user_id: int
accuracy: float
mods: list[str] # acronym
passed: bool
score: int
max_combo: int
perfect: bool = False
statistics: LegacyStatistics
passed: bool
pp: float
rank: Rank
created_at: datetime
mode: GameMode
mode_int: int
replay: bool
score: int
statistics: LegacyStatistics
type: str
user_id: int
current_user_attributes: CurrentUserAttributes
user: UserResp
beatmap: BeatmapResp
rank_global: int | None = Field(default=None, exclude=True)
@classmethod
async def from_db(cls, session: AsyncSession, score: Score) -> "LegacyScoreResp":
await session.refresh(score)
async def from_db(cls, session: AsyncSession, score: "Score") -> "LegacyScoreResp":
await score.awaitable_attrs.beatmap
return cls(
accuracy=score.accuracy,
@@ -465,34 +581,13 @@ class LegacyScoreResp(UTCBaseModel):
count_geki=score.ngeki or 0,
count_katu=score.nkatu or 0,
),
type=score.type,
user_id=score.user_id,
current_user_attributes=CurrentUserAttributes(
pin=PinAttributes(is_pinned=bool(score.pinned_order), score_id=score.id)
),
user=await UserResp.from_db(
score.user,
session,
include=["statistics", "team", "daily_challenge_user_stats"],
ruleset=score.gamemode,
),
beatmap=await BeatmapResp.from_db(score.beatmap),
perfect=score.is_perfect_combo,
rank_global=(
await get_score_position_by_id(
session,
score.beatmap_id,
score.id,
mode=score.gamemode,
user=score.user,
)
or None
),
)
class MultiplayerScores(RespWithCursor):
scores: list[ScoreResp] = Field(default_factory=list)
scores: list[MultiplayScoreDict] = Field(default_factory=list) # pyright: ignore[reportInvalidTypeForm]
params: dict[str, Any] = Field(default_factory=dict)
@@ -842,13 +937,13 @@ async def get_user_best_pp(
# https://github.com/ppy/osu-queue-score-statistics/blob/master/osu.Server.Queues.ScoreStatisticsProcessor/Helpers/PlayValidityHelper.cs
def get_play_length(score: Score, beatmap_length: int):
def get_play_length(score: "Score", beatmap_length: int):
speed_rate = get_speed_rate(score.mods)
length = beatmap_length / speed_rate
return int(min(length, (score.ended_at - score.started_at).total_seconds()))
def calculate_playtime(score: Score, beatmap_length: int) -> tuple[int, bool]:
def calculate_playtime(score: "Score", beatmap_length: int) -> tuple[int, bool]:
total_length = get_play_length(score, beatmap_length)
total_obj_hited = (
score.n300
@@ -937,7 +1032,7 @@ async def process_score(
return score
async def _process_score_pp(score: Score, session: AsyncSession, redis: Redis, fetcher: "Fetcher"):
async def _process_score_pp(score: "Score", session: AsyncSession, redis: Redis, fetcher: "Fetcher"):
if score.pp != 0:
logger.debug(
"Skipping PP calculation for score {score_id} | already set {pp:.2f}",
@@ -984,7 +1079,7 @@ async def _process_score_pp(score: Score, session: AsyncSession, redis: Redis, f
)
async def _process_score_events(score: Score, session: AsyncSession):
async def _process_score_events(score: "Score", session: AsyncSession):
total_users = (await session.exec(select(func.count()).select_from(User))).one()
rank_global = await get_score_position_by_id(
session,
@@ -1088,7 +1183,7 @@ async def _process_statistics(
session: AsyncSession,
redis: Redis,
user: User,
score: Score,
score: "Score",
score_token: int,
beatmap_length: int,
beatmap_status: BeatmapRankStatus,
@@ -1318,7 +1413,7 @@ async def process_user(
redis: Redis,
fetcher: "Fetcher",
user: User,
score: Score,
score: "Score",
score_token: int,
beatmap_length: int,
beatmap_status: BeatmapRankStatus,

View File

@@ -0,0 +1,13 @@
from . import beatmap # noqa: F401
from .beatmapset import BeatmapsetModel
from sqlmodel import SQLModel
SearchBeatmapset = BeatmapsetModel.generate_typeddict(("beatmaps.max_combo", "pack_tags"))
class SearchBeatmapsetsResp(SQLModel):
beatmapsets: list[SearchBeatmapset] # pyright: ignore[reportInvalidTypeForm]
total: int
cursor: dict[str, int | float | str] | None = None
cursor_string: str | None = None

View File

@@ -1,10 +1,11 @@
from datetime import timedelta
import math
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, ClassVar, NotRequired, TypedDict
from app.models.score import GameMode
from app.utils import utcnow
from ._base import DatabaseModel, included, ondemand
from .rank_history import RankHistory
from pydantic import field_validator
@@ -15,7 +16,6 @@ from sqlmodel import (
Field,
ForeignKey,
Relationship,
SQLModel,
col,
func,
select,
@@ -23,10 +23,40 @@ from sqlmodel import (
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from .user import User, UserResp
from .user import User, UserDict
class UserStatisticsBase(SQLModel):
class UserStatisticsDict(TypedDict):
mode: GameMode
count_100: int
count_300: int
count_50: int
count_miss: int
pp: float
ranked_score: int
hit_accuracy: float
total_score: int
total_hits: int
maximum_combo: int
play_count: int
play_time: int
replays_watched_by_others: int
is_ranked: bool
level: NotRequired[dict[str, int]]
global_rank: NotRequired[int | None]
grade_counts: NotRequired[dict[str, int]]
rank_change_since_30_days: NotRequired[int]
country_rank: NotRequired[int | None]
user: NotRequired["UserDict"]
class UserStatisticsModel(DatabaseModel[UserStatisticsDict]):
RANKING_INCLUDES: ClassVar[list[str]] = [
"user.country",
"user.cover",
"user.team",
]
mode: GameMode = Field(index=True)
count_100: int = Field(default=0, sa_column=Column(BigInteger))
count_300: int = Field(default=0, sa_column=Column(BigInteger))
@@ -57,8 +87,63 @@ class UserStatisticsBase(SQLModel):
return GameMode.OSU
return v
@included
@staticmethod
async def level(_session: AsyncSession, statistics: "UserStatistics") -> dict[str, int]:
return {
"current": int(statistics.level_current),
"progress": int(math.fmod(statistics.level_current, 1) * 100),
}
class UserStatistics(AsyncAttrs, UserStatisticsBase, table=True):
@included
@staticmethod
async def global_rank(session: AsyncSession, statistics: "UserStatistics") -> int | None:
return await get_rank(session, statistics)
@included
@staticmethod
async def grade_counts(_session: AsyncSession, statistics: "UserStatistics") -> dict[str, int]:
return {
"ss": statistics.grade_ss,
"ssh": statistics.grade_ssh,
"s": statistics.grade_s,
"sh": statistics.grade_sh,
"a": statistics.grade_a,
}
@ondemand
@staticmethod
async def rank_change_since_30_days(session: AsyncSession, statistics: "UserStatistics") -> int:
global_rank = await get_rank(session, statistics)
rank_best = (
await session.exec(
select(func.max(RankHistory.rank)).where(
RankHistory.date > utcnow() - timedelta(days=30),
RankHistory.user_id == statistics.user_id,
)
)
).first()
if rank_best is None or global_rank is None:
return 0
return rank_best - global_rank
@ondemand
@staticmethod
async def country_rank(
session: AsyncSession, statistics: "UserStatistics", user_country: str | None = None
) -> int | None:
return await get_rank(session, statistics, user_country)
@ondemand
@staticmethod
async def user(_session: AsyncSession, statistics: "UserStatistics") -> "UserDict":
from .user import UserModel
user_instance = await statistics.awaitable_attrs.user
return await UserModel.transform(user_instance)
class UserStatistics(AsyncAttrs, UserStatisticsModel, table=True):
__tablename__: str = "lazer_user_statistics"
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(
@@ -80,74 +165,6 @@ class UserStatistics(AsyncAttrs, UserStatisticsBase, table=True):
user: "User" = Relationship(back_populates="statistics")
class UserStatisticsResp(UserStatisticsBase):
user: "UserResp | None" = None
rank_change_since_30_days: int | None = 0
global_rank: int | None = Field(default=None)
country_rank: int | None = Field(default=None)
grade_counts: dict[str, int] = Field(
default_factory=lambda: {
"ss": 0,
"ssh": 0,
"s": 0,
"sh": 0,
"a": 0,
}
)
level: dict[str, int] = Field(
default_factory=lambda: {
"current": 1,
"progress": 0,
}
)
@classmethod
async def from_db(
cls,
obj: UserStatistics,
session: AsyncSession,
user_country: str | None = None,
include: list[str] = [],
) -> "UserStatisticsResp":
s = cls.model_validate(obj.model_dump())
s.grade_counts = {
"ss": obj.grade_ss,
"ssh": obj.grade_ssh,
"s": obj.grade_s,
"sh": obj.grade_sh,
"a": obj.grade_a,
}
s.level = {
"current": int(obj.level_current),
"progress": int(math.fmod(obj.level_current, 1) * 100),
}
if "user" in include:
from .user import RANKING_INCLUDES, UserResp
user = await UserResp.from_db(await obj.awaitable_attrs.user, session, include=RANKING_INCLUDES)
s.user = user
user_country = user.country_code
s.global_rank = await get_rank(session, obj)
s.country_rank = await get_rank(session, obj, user_country)
if "rank_change_since_30_days" in include:
rank_best = (
await session.exec(
select(func.max(RankHistory.rank)).where(
RankHistory.date > utcnow() - timedelta(days=30),
RankHistory.user_id == obj.user_id,
)
)
).first()
if rank_best is None or s.global_rank is None:
s.rank_change_since_30_days = 0
else:
s.rank_change_since_30_days = rank_best - s.global_rank
return s
async def get_rank(session: AsyncSession, statistics: UserStatistics, country: str | None = None) -> int | None:
from .user import User
@@ -164,7 +181,6 @@ async def get_rank(session: AsyncSession, statistics: UserStatistics, country: s
query = query.join(User).where(User.country_code == country)
subq = query.subquery()
result = await session.exec(select(subq.c.rank).where(subq.c.user_id == statistics.user_id))
rank = result.first()

View File

@@ -1,9 +1,9 @@
from typing import TYPE_CHECKING
from app.calculator import calculate_score_to_level
from app.database.statistics import UserStatistics
from app.models.score import GameMode, Rank
from .statistics import UserStatistics
from .user import User
from sqlmodel import (

File diff suppressed because it is too large Load Diff