refactor(database): use a new 'On-Demand' design (#86)
Technical Details: https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/
This commit is contained in:
@@ -2,23 +2,31 @@ from .achievement import UserAchievement, UserAchievementResp
|
||||
from .auth import OAuthClient, OAuthToken, TotpKeys, V1APIKeys
|
||||
from .beatmap import (
|
||||
Beatmap,
|
||||
BeatmapResp,
|
||||
BeatmapDict,
|
||||
BeatmapModel,
|
||||
)
|
||||
from .beatmap_playcounts import (
|
||||
BeatmapPlaycounts,
|
||||
BeatmapPlaycountsDict,
|
||||
BeatmapPlaycountsModel,
|
||||
)
|
||||
from .beatmap_playcounts import BeatmapPlaycounts, BeatmapPlaycountsResp
|
||||
from .beatmap_sync import BeatmapSync
|
||||
from .beatmap_tags import BeatmapTagVote
|
||||
from .beatmapset import (
|
||||
Beatmapset,
|
||||
BeatmapsetResp,
|
||||
BeatmapsetDict,
|
||||
BeatmapsetModel,
|
||||
)
|
||||
from .beatmapset_ratings import BeatmapRating
|
||||
from .best_scores import BestScore
|
||||
from .chat import (
|
||||
ChannelType,
|
||||
ChatChannel,
|
||||
ChatChannelResp,
|
||||
ChatChannelDict,
|
||||
ChatChannelModel,
|
||||
ChatMessage,
|
||||
ChatMessageResp,
|
||||
ChatMessageDict,
|
||||
ChatMessageModel,
|
||||
)
|
||||
from .counts import (
|
||||
CountResp,
|
||||
@@ -30,8 +38,8 @@ from .events import Event
|
||||
from .favourite_beatmapset import FavouriteBeatmapset
|
||||
from .item_attempts_count import (
|
||||
ItemAttemptsCount,
|
||||
ItemAttemptsResp,
|
||||
PlaylistAggregateScore,
|
||||
ItemAttemptsCountDict,
|
||||
ItemAttemptsCountModel,
|
||||
)
|
||||
from .matchmaking import (
|
||||
MatchmakingPool,
|
||||
@@ -42,30 +50,32 @@ from .multiplayer_event import MultiplayerEvent, MultiplayerEventResp
|
||||
from .notification import Notification, UserNotification
|
||||
from .password_reset import PasswordReset
|
||||
from .playlist_best_score import PlaylistBestScore
|
||||
from .playlists import Playlist, PlaylistResp
|
||||
from .playlists import Playlist, PlaylistDict, PlaylistModel
|
||||
from .rank_history import RankHistory, RankHistoryResp, RankTop
|
||||
from .relationship import Relationship, RelationshipResp, RelationshipType
|
||||
from .room import APIUploadedRoom, Room, RoomResp
|
||||
from .relationship import Relationship, RelationshipDict, RelationshipModel, RelationshipType
|
||||
from .room import APIUploadedRoom, Room, RoomDict, RoomModel
|
||||
from .room_participated_user import RoomParticipatedUser
|
||||
from .score import (
|
||||
MultiplayerScores,
|
||||
Score,
|
||||
ScoreAround,
|
||||
ScoreBase,
|
||||
ScoreResp,
|
||||
ScoreDict,
|
||||
ScoreModel,
|
||||
ScoreStatistics,
|
||||
)
|
||||
from .score_token import ScoreToken, ScoreTokenResp
|
||||
from .search_beatmapset import SearchBeatmapsetsResp
|
||||
from .statistics import (
|
||||
UserStatistics,
|
||||
UserStatisticsResp,
|
||||
UserStatisticsDict,
|
||||
UserStatisticsModel,
|
||||
)
|
||||
from .team import Team, TeamMember, TeamRequest, TeamResp
|
||||
from .total_score_best_scores import TotalScoreBestScore
|
||||
from .user import (
|
||||
MeResp,
|
||||
User,
|
||||
UserResp,
|
||||
UserDict,
|
||||
UserModel,
|
||||
)
|
||||
from .user_account_history import (
|
||||
UserAccountHistory,
|
||||
@@ -79,20 +89,25 @@ from .verification import EmailVerification, LoginSession, LoginSessionResp, Tru
|
||||
__all__ = [
|
||||
"APIUploadedRoom",
|
||||
"Beatmap",
|
||||
"BeatmapDict",
|
||||
"BeatmapModel",
|
||||
"BeatmapPlaycounts",
|
||||
"BeatmapPlaycountsResp",
|
||||
"BeatmapPlaycountsDict",
|
||||
"BeatmapPlaycountsModel",
|
||||
"BeatmapRating",
|
||||
"BeatmapResp",
|
||||
"BeatmapSync",
|
||||
"BeatmapTagVote",
|
||||
"Beatmapset",
|
||||
"BeatmapsetResp",
|
||||
"BeatmapsetDict",
|
||||
"BeatmapsetModel",
|
||||
"BestScore",
|
||||
"ChannelType",
|
||||
"ChatChannel",
|
||||
"ChatChannelResp",
|
||||
"ChatChannelDict",
|
||||
"ChatChannelModel",
|
||||
"ChatMessage",
|
||||
"ChatMessageResp",
|
||||
"ChatMessageDict",
|
||||
"ChatMessageModel",
|
||||
"CountResp",
|
||||
"DailyChallengeStats",
|
||||
"DailyChallengeStatsResp",
|
||||
@@ -100,13 +115,13 @@ __all__ = [
|
||||
"Event",
|
||||
"FavouriteBeatmapset",
|
||||
"ItemAttemptsCount",
|
||||
"ItemAttemptsResp",
|
||||
"ItemAttemptsCountDict",
|
||||
"ItemAttemptsCountModel",
|
||||
"LoginSession",
|
||||
"LoginSessionResp",
|
||||
"MatchmakingPool",
|
||||
"MatchmakingPoolBeatmap",
|
||||
"MatchmakingUserStats",
|
||||
"MeResp",
|
||||
"MonthlyPlaycounts",
|
||||
"MultiplayerEvent",
|
||||
"MultiplayerEventResp",
|
||||
@@ -116,26 +131,29 @@ __all__ = [
|
||||
"OAuthToken",
|
||||
"PasswordReset",
|
||||
"Playlist",
|
||||
"PlaylistAggregateScore",
|
||||
"PlaylistBestScore",
|
||||
"PlaylistResp",
|
||||
"PlaylistDict",
|
||||
"PlaylistModel",
|
||||
"RankHistory",
|
||||
"RankHistoryResp",
|
||||
"RankTop",
|
||||
"Relationship",
|
||||
"RelationshipResp",
|
||||
"RelationshipDict",
|
||||
"RelationshipModel",
|
||||
"RelationshipType",
|
||||
"ReplayWatchedCount",
|
||||
"Room",
|
||||
"RoomDict",
|
||||
"RoomModel",
|
||||
"RoomParticipatedUser",
|
||||
"RoomResp",
|
||||
"Score",
|
||||
"ScoreAround",
|
||||
"ScoreBase",
|
||||
"ScoreResp",
|
||||
"ScoreDict",
|
||||
"ScoreModel",
|
||||
"ScoreStatistics",
|
||||
"ScoreToken",
|
||||
"ScoreTokenResp",
|
||||
"SearchBeatmapsetsResp",
|
||||
"Team",
|
||||
"TeamMember",
|
||||
"TeamRequest",
|
||||
@@ -149,17 +167,18 @@ __all__ = [
|
||||
"UserAccountHistoryResp",
|
||||
"UserAccountHistoryType",
|
||||
"UserAchievement",
|
||||
"UserAchievement",
|
||||
"UserAchievementResp",
|
||||
"UserDict",
|
||||
"UserLoginLog",
|
||||
"UserModel",
|
||||
"UserNotification",
|
||||
"UserPreference",
|
||||
"UserResp",
|
||||
"UserStatistics",
|
||||
"UserStatisticsResp",
|
||||
"UserStatisticsDict",
|
||||
"UserStatisticsModel",
|
||||
"V1APIKeys",
|
||||
]
|
||||
|
||||
for i in __all__:
|
||||
if i.endswith("Resp"):
|
||||
globals()[i].model_rebuild() # type: ignore[call-arg]
|
||||
if i.endswith("Model") or i.endswith("Resp"):
|
||||
globals()[i].model_rebuild()
|
||||
|
||||
499
app/database/_base.py
Normal file
499
app/database/_base.py
Normal file
@@ -0,0 +1,499 @@
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from functools import lru_cache, wraps
|
||||
import inspect
|
||||
import sys
|
||||
from types import NoneType, get_original_bases
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
ClassVar,
|
||||
Concatenate,
|
||||
ForwardRef,
|
||||
ParamSpec,
|
||||
TypedDict,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
overload,
|
||||
)
|
||||
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.utils import type_is_optional
|
||||
|
||||
from sqlalchemy.ext.asyncio import async_object_session
|
||||
from sqlmodel import SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.main import SQLModelMetaclass
|
||||
|
||||
_dict_to_model: dict[type, type["DatabaseModel"]] = {}
|
||||
|
||||
|
||||
def _safe_evaluate_forwardref(type_: str | ForwardRef, module_name: str) -> Any:
|
||||
"""Safely evaluate a ForwardRef, with fallback to app.database module"""
|
||||
if isinstance(type_, str):
|
||||
type_ = ForwardRef(type_)
|
||||
|
||||
try:
|
||||
return evaluate_forwardref(
|
||||
type_,
|
||||
globalns=vars(sys.modules[module_name]),
|
||||
localns={},
|
||||
)
|
||||
except (NameError, AttributeError, KeyError):
|
||||
# Fallback to app.database module
|
||||
try:
|
||||
import app.database
|
||||
|
||||
return evaluate_forwardref(
|
||||
type_,
|
||||
globalns=vars(app.database),
|
||||
localns={},
|
||||
)
|
||||
except (NameError, AttributeError, KeyError):
|
||||
return None
|
||||
|
||||
|
||||
class OnDemand[T]:
|
||||
if TYPE_CHECKING:
|
||||
|
||||
def __get__(self, instance: object | None, owner: Any) -> T: ...
|
||||
|
||||
def __set__(self, instance: Any, value: T) -> None: ...
|
||||
|
||||
def __delete__(self, instance: Any) -> None: ...
|
||||
|
||||
|
||||
class Exclude[T]:
|
||||
if TYPE_CHECKING:
|
||||
|
||||
def __get__(self, instance: object | None, owner: Any) -> T: ...
|
||||
|
||||
def __set__(self, instance: Any, value: T) -> None: ...
|
||||
|
||||
def __delete__(self, instance: Any) -> None: ...
|
||||
|
||||
|
||||
# https://github.com/fastapi/sqlmodel/blob/main/sqlmodel/_compat.py#L126-L140
|
||||
def _get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
raw_annotations: dict[str, Any] = class_dict.get("__annotations__", {})
|
||||
if sys.version_info >= (3, 14) and "__annotations__" not in class_dict:
|
||||
# See https://github.com/pydantic/pydantic/pull/11991
|
||||
from annotationlib import (
|
||||
Format,
|
||||
call_annotate_function,
|
||||
get_annotate_from_class_namespace,
|
||||
)
|
||||
|
||||
if annotate := get_annotate_from_class_namespace(class_dict):
|
||||
raw_annotations = call_annotate_function(annotate, format=Format.FORWARDREF)
|
||||
return raw_annotations
|
||||
|
||||
|
||||
# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L58-L77
|
||||
if sys.version_info < (3, 12, 4):
|
||||
|
||||
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
|
||||
# Even though it is the right signature for python 3.9, mypy complains with
|
||||
# `error: Too many arguments for "_evaluate" of "ForwardRef"` hence the cast...
|
||||
# Python 3.13/3.12.4+ made `recursive_guard` a kwarg, so name it explicitly to avoid:
|
||||
# TypeError: ForwardRef._evaluate() missing 1 required keyword-only argument: 'recursive_guard'
|
||||
return cast(Any, type_)._evaluate(globalns, localns, recursive_guard=set())
|
||||
|
||||
else:
|
||||
|
||||
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
|
||||
# Pydantic 1.x will not support PEP 695 syntax, but provide `type_params` to avoid
|
||||
# warnings:
|
||||
return cast(Any, type_)._evaluate(globalns, localns, type_params=(), recursive_guard=set())
|
||||
|
||||
|
||||
class DatabaseModelMetaclass(SQLModelMetaclass):
|
||||
def __new__(
|
||||
cls,
|
||||
name: str,
|
||||
bases: tuple[type, ...],
|
||||
namespace: dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> "DatabaseModelMetaclass":
|
||||
original_annotations = _get_annotations(namespace)
|
||||
new_annotations = {}
|
||||
ondemands = []
|
||||
excludes = []
|
||||
|
||||
for k, v in original_annotations.items():
|
||||
if get_origin(v) is OnDemand:
|
||||
inner_type = v.__args__[0]
|
||||
new_annotations[k] = inner_type
|
||||
ondemands.append(k)
|
||||
elif get_origin(v) is Exclude:
|
||||
inner_type = v.__args__[0]
|
||||
new_annotations[k] = inner_type
|
||||
excludes.append(k)
|
||||
else:
|
||||
new_annotations[k] = v
|
||||
|
||||
new_class = super().__new__(
|
||||
cls,
|
||||
name,
|
||||
bases,
|
||||
{
|
||||
**namespace,
|
||||
"__annotations__": new_annotations,
|
||||
},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
new_class._CALCULATED_FIELDS = dict(getattr(new_class, "_CALCULATED_FIELDS", {}))
|
||||
new_class._ONDEMAND_DATABASE_FIELDS = list(getattr(new_class, "_ONDEMAND_DATABASE_FIELDS", [])) + list(
|
||||
ondemands
|
||||
)
|
||||
new_class._ONDEMAND_CALCULATED_FIELDS = dict(getattr(new_class, "_ONDEMAND_CALCULATED_FIELDS", {}))
|
||||
new_class._EXCLUDED_DATABASE_FIELDS = list(getattr(new_class, "_EXCLUDED_DATABASE_FIELDS", [])) + list(excludes)
|
||||
|
||||
for attr_name, attr_value in namespace.items():
|
||||
target = _get_callable_target(attr_value)
|
||||
if target is None:
|
||||
continue
|
||||
|
||||
if getattr(target, "__included__", False):
|
||||
new_class._CALCULATED_FIELDS[attr_name] = _get_return_type(target)
|
||||
_pre_calculate_context_params(target, attr_value)
|
||||
|
||||
if getattr(target, "__calculated_ondemand__", False):
|
||||
new_class._ONDEMAND_CALCULATED_FIELDS[attr_name] = _get_return_type(target)
|
||||
_pre_calculate_context_params(target, attr_value)
|
||||
|
||||
# Register TDict to DatabaseModel mapping
|
||||
for base in get_original_bases(new_class):
|
||||
cls_name = base.__name__
|
||||
if "DatabaseModel" in cls_name and "[" in cls_name and "]" in cls_name:
|
||||
generic_type_name = cls_name[cls_name.index("[") : cls_name.rindex("]") + 1]
|
||||
generic_type = evaluate_forwardref(
|
||||
ForwardRef(generic_type_name),
|
||||
globalns=vars(sys.modules[new_class.__module__]),
|
||||
localns={},
|
||||
)
|
||||
_dict_to_model[generic_type[0]] = new_class
|
||||
|
||||
return new_class
|
||||
|
||||
|
||||
def _pre_calculate_context_params(target: Callable, attr_value: Any) -> None:
|
||||
if hasattr(target, "__context_params__"):
|
||||
return
|
||||
|
||||
sig = inspect.signature(target)
|
||||
params = list(sig.parameters.keys())
|
||||
|
||||
start_index = 2
|
||||
if isinstance(attr_value, classmethod):
|
||||
start_index = 3
|
||||
|
||||
context_params = [] if len(params) < start_index else params[start_index:]
|
||||
|
||||
setattr(target, "__context_params__", context_params)
|
||||
|
||||
|
||||
def _get_callable_target(value: Any) -> Callable | None:
|
||||
if isinstance(value, (staticmethod, classmethod)):
|
||||
return value.__func__
|
||||
if inspect.isfunction(value):
|
||||
return value
|
||||
if inspect.ismethod(value):
|
||||
return value.__func__
|
||||
return None
|
||||
|
||||
|
||||
def _mark_callable(value: Any, flag: str) -> Callable | None:
|
||||
target = _get_callable_target(value)
|
||||
if target is None:
|
||||
return None
|
||||
setattr(target, flag, True)
|
||||
return target
|
||||
|
||||
|
||||
def _get_return_type(func: Callable) -> type:
|
||||
sig = inspect.get_annotations(func)
|
||||
return sig.get("return", Any)
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
CalculatedField = Callable[Concatenate[AsyncSession, Any, P], Awaitable[Any]]
|
||||
DecoratorTarget = CalculatedField | staticmethod | classmethod
|
||||
|
||||
|
||||
def included(func: DecoratorTarget) -> DecoratorTarget:
|
||||
marker = _mark_callable(func, "__included__")
|
||||
if marker is None:
|
||||
raise RuntimeError("@included is only usable on callables.")
|
||||
|
||||
@wraps(marker)
|
||||
async def wrapper(*args, **kwargs):
|
||||
return await marker(*args, **kwargs)
|
||||
|
||||
if isinstance(func, staticmethod):
|
||||
return staticmethod(wrapper)
|
||||
if isinstance(func, classmethod):
|
||||
return classmethod(wrapper)
|
||||
return wrapper
|
||||
|
||||
|
||||
def ondemand(func: DecoratorTarget) -> DecoratorTarget:
|
||||
marker = _mark_callable(func, "__calculated_ondemand__")
|
||||
if marker is None:
|
||||
raise RuntimeError("@ondemand is only usable on callables.")
|
||||
|
||||
@wraps(marker)
|
||||
async def wrapper(*args, **kwargs):
|
||||
return await marker(*args, **kwargs)
|
||||
|
||||
if isinstance(func, staticmethod):
|
||||
return staticmethod(wrapper)
|
||||
if isinstance(func, classmethod):
|
||||
return classmethod(wrapper)
|
||||
return wrapper
|
||||
|
||||
|
||||
async def call_awaitable_with_context(
|
||||
func: CalculatedField,
|
||||
session: AsyncSession,
|
||||
instance: Any,
|
||||
context: dict[str, Any],
|
||||
) -> Any:
|
||||
context_params: list[str] | None = getattr(func, "__context_params__", None)
|
||||
|
||||
if context_params is None:
|
||||
# Fallback if not pre-calculated
|
||||
sig = inspect.signature(func)
|
||||
if len(sig.parameters) == 2:
|
||||
return await func(session, instance)
|
||||
else:
|
||||
call_params = {}
|
||||
for param in sig.parameters.values():
|
||||
if param.name in context:
|
||||
call_params[param.name] = context[param.name]
|
||||
return await func(session, instance, **call_params)
|
||||
|
||||
if not context_params:
|
||||
return await func(session, instance)
|
||||
|
||||
call_params = {}
|
||||
for name in context_params:
|
||||
if name in context:
|
||||
call_params[name] = context[name]
|
||||
return await func(session, instance, **call_params)
|
||||
|
||||
|
||||
class DatabaseModel[TDict](SQLModel, UTCBaseModel, metaclass=DatabaseModelMetaclass):
|
||||
_CALCULATED_FIELDS: ClassVar[dict[str, type]] = {}
|
||||
|
||||
_ONDEMAND_DATABASE_FIELDS: ClassVar[list[str]] = []
|
||||
_ONDEMAND_CALCULATED_FIELDS: ClassVar[dict[str, type]] = {}
|
||||
|
||||
_EXCLUDED_DATABASE_FIELDS: ClassVar[list[str]] = []
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def transform(
|
||||
cls,
|
||||
db_instance: "DatabaseModel",
|
||||
*,
|
||||
session: AsyncSession,
|
||||
includes: list[str] | None = None,
|
||||
**context: Any,
|
||||
) -> TDict: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def transform(
|
||||
cls,
|
||||
db_instance: "DatabaseModel",
|
||||
*,
|
||||
includes: list[str] | None = None,
|
||||
**context: Any,
|
||||
) -> TDict: ...
|
||||
|
||||
@classmethod
|
||||
async def transform(
|
||||
cls,
|
||||
db_instance: "DatabaseModel",
|
||||
*,
|
||||
session: AsyncSession | None = None,
|
||||
includes: list[str] | None = None,
|
||||
**context: Any,
|
||||
) -> TDict:
|
||||
includes = includes.copy() if includes is not None else []
|
||||
session = cast(AsyncSession | None, async_object_session(db_instance)) if session is None else session
|
||||
if session is None:
|
||||
raise RuntimeError("DatabaseModel.transform requires a session-bound instance.")
|
||||
resp_obj = cls.model_validate(db_instance.model_dump())
|
||||
data = resp_obj.model_dump()
|
||||
|
||||
for field in cls._CALCULATED_FIELDS:
|
||||
func = getattr(cls, field)
|
||||
value = await call_awaitable_with_context(func, session, db_instance, context)
|
||||
data[field] = value
|
||||
|
||||
sub_include_map: dict[str, list[str]] = {}
|
||||
for include in [i for i in includes if "." in i]:
|
||||
parent, sub_include = include.split(".", 1)
|
||||
if parent not in sub_include_map:
|
||||
sub_include_map[parent] = []
|
||||
sub_include_map[parent].append(sub_include)
|
||||
includes.remove(include) # pyright: ignore[reportOptionalMemberAccess]
|
||||
|
||||
for field, sub_includes in sub_include_map.items():
|
||||
if field in cls._ONDEMAND_CALCULATED_FIELDS:
|
||||
func = getattr(cls, field)
|
||||
value = await call_awaitable_with_context(
|
||||
func, session, db_instance, {**context, "includes": sub_includes}
|
||||
)
|
||||
data[field] = value
|
||||
|
||||
for include in includes:
|
||||
if include in data:
|
||||
continue
|
||||
|
||||
if include in cls._ONDEMAND_CALCULATED_FIELDS:
|
||||
func = getattr(cls, include)
|
||||
value = await call_awaitable_with_context(func, session, db_instance, context)
|
||||
data[include] = value
|
||||
|
||||
for field in cls._ONDEMAND_DATABASE_FIELDS:
|
||||
if field not in includes:
|
||||
del data[field]
|
||||
|
||||
for field in cls._EXCLUDED_DATABASE_FIELDS:
|
||||
if field in data:
|
||||
del data[field]
|
||||
|
||||
return cast(TDict, data)
|
||||
|
||||
@classmethod
|
||||
async def transform_many(
|
||||
cls,
|
||||
db_instances: Sequence["DatabaseModel"],
|
||||
*,
|
||||
session: AsyncSession | None = None,
|
||||
includes: list[str] | None = None,
|
||||
**context: Any,
|
||||
) -> list[TDict]:
|
||||
if not db_instances:
|
||||
return []
|
||||
|
||||
# SQLAlchemy AsyncSession is not concurrency-safe, so we cannot use asyncio.gather here
|
||||
# if the transform method performs any database operations using the shared session.
|
||||
# Since we don't know if the transform method (or its calculated fields) will use the DB,
|
||||
# we must execute them serially to be safe.
|
||||
results = []
|
||||
for instance in db_instances:
|
||||
results.append(await cls.transform(instance, session=session, includes=includes, **context))
|
||||
return results
|
||||
|
||||
@classmethod
|
||||
@lru_cache
|
||||
def generate_typeddict(cls, includes: tuple[str, ...] | None = None) -> type[TypedDict]: # pyright: ignore[reportInvalidTypeForm]
|
||||
def _evaluate_type(field_type: Any, *, resolve_database_model: bool = False, field_name: str = "") -> Any:
|
||||
# Evaluate ForwardRef if present
|
||||
if isinstance(field_type, (str, ForwardRef)):
|
||||
resolved = _safe_evaluate_forwardref(field_type, cls.__module__)
|
||||
if resolved is not None:
|
||||
field_type = resolved
|
||||
|
||||
origin_type = get_origin(field_type)
|
||||
inner_type = field_type
|
||||
args = get_args(field_type)
|
||||
|
||||
is_optional = type_is_optional(field_type) # pyright: ignore[reportArgumentType]
|
||||
if is_optional:
|
||||
inner_type = next((arg for arg in args if arg is not NoneType), field_type)
|
||||
|
||||
is_list = False
|
||||
if origin_type is list:
|
||||
is_list = True
|
||||
inner_type = args[0]
|
||||
|
||||
# Evaluate ForwardRef in inner_type if present
|
||||
if isinstance(inner_type, (str, ForwardRef)):
|
||||
resolved = _safe_evaluate_forwardref(inner_type, cls.__module__)
|
||||
if resolved is not None:
|
||||
inner_type = resolved
|
||||
|
||||
if not resolve_database_model:
|
||||
if is_optional:
|
||||
return inner_type | None # pyright: ignore[reportOperatorIssue]
|
||||
elif is_list:
|
||||
return list[inner_type]
|
||||
return inner_type
|
||||
|
||||
model_class = None
|
||||
|
||||
# First check if inner_type is directly a DatabaseModel subclass
|
||||
try:
|
||||
if inspect.isclass(inner_type) and issubclass(inner_type, DatabaseModel): # type: ignore
|
||||
model_class = inner_type
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
# If not found, look up in _dict_to_model
|
||||
if model_class is None:
|
||||
model_class = _dict_to_model.get(inner_type) # type: ignore
|
||||
|
||||
if model_class is not None:
|
||||
nested_dict = model_class.generate_typeddict(tuple(sub_include_map.get(field_name, ())))
|
||||
resolved_type = list[nested_dict] if is_list else nested_dict # type: ignore
|
||||
|
||||
if is_optional:
|
||||
resolved_type = resolved_type | None # type: ignore
|
||||
|
||||
return resolved_type
|
||||
|
||||
# Fallback: use the resolved inner_type
|
||||
resolved_type = list[inner_type] if is_list else inner_type # type: ignore
|
||||
if is_optional:
|
||||
resolved_type = resolved_type | None # type: ignore
|
||||
return resolved_type
|
||||
|
||||
if includes is None:
|
||||
includes = ()
|
||||
|
||||
# Parse nested includes
|
||||
direct_includes = []
|
||||
sub_include_map: dict[str, list[str]] = {}
|
||||
for include in includes:
|
||||
if "." in include:
|
||||
parent, sub_include = include.split(".", 1)
|
||||
if parent not in sub_include_map:
|
||||
sub_include_map[parent] = []
|
||||
sub_include_map[parent].append(sub_include)
|
||||
if parent not in direct_includes:
|
||||
direct_includes.append(parent)
|
||||
else:
|
||||
direct_includes.append(include)
|
||||
|
||||
fields = {}
|
||||
|
||||
# Process model fields
|
||||
for field_name, field_info in cls.model_fields.items():
|
||||
field_type = field_info.annotation or Any
|
||||
field_type = _evaluate_type(field_type, field_name=field_name)
|
||||
|
||||
if field_name in cls._ONDEMAND_DATABASE_FIELDS and field_name not in direct_includes:
|
||||
continue
|
||||
else:
|
||||
fields[field_name] = field_type
|
||||
|
||||
# Process calculated fields
|
||||
for field_name, field_type in cls._CALCULATED_FIELDS.items():
|
||||
field_type = _evaluate_type(field_type, resolve_database_model=True, field_name=field_name)
|
||||
fields[field_name] = field_type
|
||||
|
||||
# Process ondemand calculated fields
|
||||
for field_name, field_type in cls._ONDEMAND_CALCULATED_FIELDS.items():
|
||||
if field_name not in direct_includes:
|
||||
continue
|
||||
|
||||
field_type = _evaluate_type(field_type, resolve_database_model=True, field_name=field_name)
|
||||
fields[field_name] = field_type
|
||||
|
||||
return TypedDict(f"{cls.__name__}Dict[{', '.join(includes)}]" if includes else f"{cls.__name__}Dict", fields) # pyright: ignore[reportArgumentType]
|
||||
@@ -1,117 +1,339 @@
|
||||
from datetime import datetime
|
||||
import hashlib
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, ClassVar, NotRequired, TypedDict
|
||||
|
||||
from app.calculator import get_calculator
|
||||
from app.config import settings
|
||||
from app.database.beatmap_tags import BeatmapTagVote
|
||||
from app.database.failtime import FailTime, FailTimeResp
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.models.mods import APIMod
|
||||
from app.models.performance import DifficultyAttributesUnion
|
||||
from app.models.score import GameMode
|
||||
|
||||
from ._base import DatabaseModel, OnDemand, included, ondemand
|
||||
from .beatmap_playcounts import BeatmapPlaycounts
|
||||
from .beatmapset import Beatmapset, BeatmapsetResp
|
||||
from .beatmap_tags import BeatmapTagVote
|
||||
from .beatmapset import Beatmapset, BeatmapsetDict, BeatmapsetModel
|
||||
from .failtime import FailTime, FailTimeResp
|
||||
from .user import User, UserDict, UserModel
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy import Column, DateTime
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, col, exists, func, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.fetcher import Fetcher
|
||||
|
||||
from .user import User
|
||||
|
||||
|
||||
class BeatmapOwner(SQLModel):
|
||||
id: int
|
||||
username: str
|
||||
|
||||
|
||||
class BeatmapBase(SQLModel):
|
||||
# Beatmap
|
||||
url: str
|
||||
class BeatmapDict(TypedDict):
|
||||
beatmapset_id: int
|
||||
difficulty_rating: float
|
||||
id: int
|
||||
mode: GameMode
|
||||
total_length: int
|
||||
user_id: int
|
||||
version: str
|
||||
url: str
|
||||
|
||||
checksum: NotRequired[str]
|
||||
max_combo: NotRequired[int | None]
|
||||
ar: NotRequired[float]
|
||||
cs: NotRequired[float]
|
||||
drain: NotRequired[float]
|
||||
accuracy: NotRequired[float]
|
||||
bpm: NotRequired[float]
|
||||
count_circles: NotRequired[int]
|
||||
count_sliders: NotRequired[int]
|
||||
count_spinners: NotRequired[int]
|
||||
deleted_at: NotRequired[datetime | None]
|
||||
hit_length: NotRequired[int]
|
||||
last_updated: NotRequired[datetime]
|
||||
|
||||
status: NotRequired[str]
|
||||
beatmapset: NotRequired[BeatmapsetDict]
|
||||
current_user_playcount: NotRequired[int]
|
||||
current_user_tag_ids: NotRequired[list[int]]
|
||||
failtimes: NotRequired[FailTimeResp]
|
||||
top_tag_ids: NotRequired[list[dict[str, int]]]
|
||||
user: NotRequired[UserDict]
|
||||
convert: NotRequired[bool]
|
||||
is_scoreable: NotRequired[bool]
|
||||
mode_int: NotRequired[int]
|
||||
ranked: NotRequired[int]
|
||||
playcount: NotRequired[int]
|
||||
passcount: NotRequired[int]
|
||||
|
||||
|
||||
class BeatmapModel(DatabaseModel[BeatmapDict]):
|
||||
BEATMAP_TRANSFORMER_INCLUDES: ClassVar[list[str]] = [
|
||||
"checksum",
|
||||
"accuracy",
|
||||
"ar",
|
||||
"bpm",
|
||||
"convert",
|
||||
"count_circles",
|
||||
"count_sliders",
|
||||
"count_spinners",
|
||||
"cs",
|
||||
"deleted_at",
|
||||
"drain",
|
||||
"hit_length",
|
||||
"is_scoreable",
|
||||
"last_updated",
|
||||
"mode_int",
|
||||
"passcount",
|
||||
"playcount",
|
||||
"ranked",
|
||||
"url",
|
||||
]
|
||||
DEFAULT_API_INCLUDES: ClassVar[list[str]] = [
|
||||
"beatmapset.ratings",
|
||||
"current_user_playcount",
|
||||
"failtimes",
|
||||
"max_combo",
|
||||
"owners",
|
||||
]
|
||||
TRANSFORMER_INCLUDES: ClassVar[list[str]] = [*DEFAULT_API_INCLUDES, *BEATMAP_TRANSFORMER_INCLUDES]
|
||||
|
||||
# Beatmap
|
||||
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
|
||||
difficulty_rating: float = Field(default=0.0, index=True)
|
||||
id: int = Field(primary_key=True, index=True)
|
||||
mode: GameMode
|
||||
total_length: int
|
||||
user_id: int = Field(index=True)
|
||||
version: str = Field(index=True)
|
||||
|
||||
url: OnDemand[str]
|
||||
# optional
|
||||
checksum: str = Field(sa_column=Column(VARCHAR(32), index=True))
|
||||
current_user_playcount: int = Field(default=0)
|
||||
max_combo: int | None = Field(default=0)
|
||||
# TODO: failtimes, owners
|
||||
checksum: OnDemand[str] = Field(sa_column=Column(VARCHAR(32), index=True))
|
||||
max_combo: OnDemand[int | None] = Field(default=0)
|
||||
# TODO: owners
|
||||
|
||||
# BeatmapExtended
|
||||
ar: float = Field(default=0.0)
|
||||
cs: float = Field(default=0.0)
|
||||
drain: float = Field(default=0.0) # hp
|
||||
accuracy: float = Field(default=0.0) # od
|
||||
bpm: float = Field(default=0.0)
|
||||
count_circles: int = Field(default=0)
|
||||
count_sliders: int = Field(default=0)
|
||||
count_spinners: int = Field(default=0)
|
||||
deleted_at: datetime | None = Field(default=None, sa_column=Column(DateTime))
|
||||
hit_length: int = Field(default=0)
|
||||
last_updated: datetime = Field(sa_column=Column(DateTime, index=True))
|
||||
ar: OnDemand[float] = Field(default=0.0)
|
||||
cs: OnDemand[float] = Field(default=0.0)
|
||||
drain: OnDemand[float] = Field(default=0.0) # hp
|
||||
accuracy: OnDemand[float] = Field(default=0.0) # od
|
||||
bpm: OnDemand[float] = Field(default=0.0)
|
||||
count_circles: OnDemand[int] = Field(default=0)
|
||||
count_sliders: OnDemand[int] = Field(default=0)
|
||||
count_spinners: OnDemand[int] = Field(default=0)
|
||||
deleted_at: OnDemand[datetime | None] = Field(default=None, sa_column=Column(DateTime))
|
||||
hit_length: OnDemand[int] = Field(default=0)
|
||||
last_updated: OnDemand[datetime] = Field(sa_column=Column(DateTime, index=True))
|
||||
|
||||
@included
|
||||
@staticmethod
|
||||
async def status(_session: AsyncSession, beatmap: "Beatmap") -> str:
|
||||
if settings.enable_all_beatmap_leaderboard and not beatmap.beatmap_status.has_leaderboard():
|
||||
return BeatmapRankStatus.APPROVED.name.lower()
|
||||
return beatmap.beatmap_status.name.lower()
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def beatmapset(
|
||||
_session: AsyncSession,
|
||||
beatmap: "Beatmap",
|
||||
includes: list[str] | None = None,
|
||||
) -> BeatmapsetDict | None:
|
||||
if beatmap.beatmapset is not None:
|
||||
return await BeatmapsetModel.transform(
|
||||
beatmap.beatmapset, includes=(includes or []) + Beatmapset.BEATMAPSET_TRANSFORMER_INCLUDES
|
||||
)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def current_user_playcount(_session: AsyncSession, beatmap: "Beatmap", user: "User") -> int:
|
||||
playcount = (
|
||||
await _session.exec(
|
||||
select(BeatmapPlaycounts.playcount).where(
|
||||
BeatmapPlaycounts.beatmap_id == beatmap.id, BeatmapPlaycounts.user_id == user.id
|
||||
)
|
||||
)
|
||||
).first()
|
||||
return int(playcount or 0)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def current_user_tag_ids(_session: AsyncSession, beatmap: "Beatmap", user: "User | None" = None) -> list[int]:
|
||||
if user is None:
|
||||
return []
|
||||
tag_ids = (
|
||||
await _session.exec(
|
||||
select(BeatmapTagVote.tag_id).where(
|
||||
BeatmapTagVote.beatmap_id == beatmap.id,
|
||||
BeatmapTagVote.user_id == user.id,
|
||||
)
|
||||
)
|
||||
).all()
|
||||
return list(tag_ids)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def failtimes(_session: AsyncSession, beatmap: "Beatmap") -> FailTimeResp:
|
||||
if beatmap.failtimes is not None:
|
||||
return FailTimeResp.from_db(beatmap.failtimes)
|
||||
return FailTimeResp()
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def top_tag_ids(_session: AsyncSession, beatmap: "Beatmap") -> list[dict[str, int]]:
|
||||
all_votes = (
|
||||
await _session.exec(
|
||||
select(BeatmapTagVote.tag_id, func.count().label("vote_count"))
|
||||
.where(BeatmapTagVote.beatmap_id == beatmap.id)
|
||||
.group_by(col(BeatmapTagVote.tag_id))
|
||||
.having(func.count() > settings.beatmap_tag_top_count)
|
||||
)
|
||||
).all()
|
||||
top_tag_ids: list[dict[str, int]] = []
|
||||
for id, votes in all_votes:
|
||||
top_tag_ids.append({"tag_id": id, "count": votes})
|
||||
top_tag_ids.sort(key=lambda x: x["count"], reverse=True)
|
||||
return top_tag_ids
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def user(
|
||||
_session: AsyncSession,
|
||||
beatmap: "Beatmap",
|
||||
includes: list[str] | None = None,
|
||||
) -> UserDict | None:
|
||||
from .user import User
|
||||
|
||||
user = await _session.get(User, beatmap.user_id)
|
||||
if user is None:
|
||||
return None
|
||||
return await UserModel.transform(user, includes=includes)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def convert(_session: AsyncSession, _beatmap: "Beatmap") -> bool:
|
||||
return False
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def is_scoreable(_session: AsyncSession, beatmap: "Beatmap") -> bool:
|
||||
beatmap_status = beatmap.beatmap_status
|
||||
if settings.enable_all_beatmap_leaderboard:
|
||||
return True
|
||||
return beatmap_status.has_leaderboard()
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def mode_int(_session: AsyncSession, beatmap: "Beatmap") -> int:
|
||||
return int(beatmap.mode)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def ranked(_session: AsyncSession, beatmap: "Beatmap") -> int:
|
||||
beatmap_status = beatmap.beatmap_status
|
||||
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
|
||||
return BeatmapRankStatus.APPROVED.value
|
||||
return beatmap_status.value
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def playcount(_session: AsyncSession, beatmap: "Beatmap") -> int:
|
||||
result = (
|
||||
await _session.exec(
|
||||
select(func.sum(BeatmapPlaycounts.playcount)).where(BeatmapPlaycounts.beatmap_id == beatmap.id)
|
||||
)
|
||||
).first()
|
||||
return int(result or 0)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def passcount(_session: AsyncSession, beatmap: "Beatmap") -> int:
|
||||
from .score import Score
|
||||
|
||||
return (
|
||||
await _session.exec(
|
||||
select(func.count())
|
||||
.select_from(Score)
|
||||
.where(
|
||||
Score.beatmap_id == beatmap.id,
|
||||
col(Score.passed).is_(True),
|
||||
)
|
||||
)
|
||||
).one()
|
||||
|
||||
|
||||
class Beatmap(BeatmapBase, table=True):
|
||||
class Beatmap(AsyncAttrs, BeatmapModel, table=True):
|
||||
__tablename__: str = "beatmaps"
|
||||
id: int = Field(primary_key=True, index=True)
|
||||
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
|
||||
|
||||
beatmap_status: BeatmapRankStatus = Field(index=True)
|
||||
# optional
|
||||
beatmapset: Beatmapset = Relationship(back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"})
|
||||
beatmapset: "Beatmapset" = Relationship(back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"})
|
||||
failtimes: FailTime | None = Relationship(back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"})
|
||||
|
||||
@classmethod
|
||||
async def from_resp_no_save(cls, _session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
|
||||
d = resp.model_dump()
|
||||
del d["beatmapset"]
|
||||
async def from_resp_no_save(cls, _session: AsyncSession, resp: BeatmapDict) -> "Beatmap":
|
||||
d = {k: v for k, v in resp.items() if k != "beatmapset"}
|
||||
beatmapset_id = resp.get("beatmapset_id")
|
||||
bid = resp.get("id")
|
||||
ranked = resp.get("ranked")
|
||||
if beatmapset_id is None or bid is None or ranked is None:
|
||||
raise ValueError("beatmapset_id, id and ranked are required")
|
||||
beatmap = cls.model_validate(
|
||||
{
|
||||
**d,
|
||||
"beatmapset_id": resp.beatmapset_id,
|
||||
"id": resp.id,
|
||||
"beatmap_status": BeatmapRankStatus(resp.ranked),
|
||||
"beatmapset_id": beatmapset_id,
|
||||
"id": bid,
|
||||
"beatmap_status": BeatmapRankStatus(ranked),
|
||||
}
|
||||
)
|
||||
return beatmap
|
||||
|
||||
@classmethod
|
||||
async def from_resp(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
|
||||
async def from_resp(cls, session: AsyncSession, resp: BeatmapDict) -> "Beatmap":
|
||||
beatmap = await cls.from_resp_no_save(session, resp)
|
||||
if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
|
||||
resp_id = resp.get("id")
|
||||
if resp_id is None:
|
||||
raise ValueError("id is required")
|
||||
if not (await session.exec(select(exists()).where(Beatmap.id == resp_id))).first():
|
||||
session.add(beatmap)
|
||||
await session.commit()
|
||||
return (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one()
|
||||
return (await session.exec(select(Beatmap).where(Beatmap.id == resp_id))).one()
|
||||
|
||||
@classmethod
|
||||
async def from_resp_batch(cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0) -> list["Beatmap"]:
|
||||
async def from_resp_batch(cls, session: AsyncSession, inp: list[BeatmapDict], from_: int = 0) -> list["Beatmap"]:
|
||||
beatmaps = []
|
||||
for resp in inp:
|
||||
if resp.id == from_:
|
||||
for resp_dict in inp:
|
||||
bid = resp_dict.get("id")
|
||||
if bid == from_ or bid is None:
|
||||
continue
|
||||
d = resp.model_dump()
|
||||
del d["beatmapset"]
|
||||
|
||||
beatmapset_id = resp_dict.get("beatmapset_id")
|
||||
ranked = resp_dict.get("ranked")
|
||||
if beatmapset_id is None or ranked is None:
|
||||
continue
|
||||
|
||||
# 创建 beatmap 字典,移除 beatmapset
|
||||
d = {k: v for k, v in resp_dict.items() if k != "beatmapset"}
|
||||
|
||||
beatmap = Beatmap.model_validate(
|
||||
{
|
||||
**d,
|
||||
"beatmapset_id": resp.beatmapset_id,
|
||||
"id": resp.id,
|
||||
"beatmap_status": BeatmapRankStatus(resp.ranked),
|
||||
"beatmapset_id": beatmapset_id,
|
||||
"id": bid,
|
||||
"beatmap_status": BeatmapRankStatus(ranked),
|
||||
}
|
||||
)
|
||||
if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
|
||||
if not (await session.exec(select(exists()).where(Beatmap.id == bid))).first():
|
||||
session.add(beatmap)
|
||||
beatmaps.append(beatmap)
|
||||
await session.commit()
|
||||
for beatmap in beatmaps:
|
||||
await session.refresh(beatmap)
|
||||
return beatmaps
|
||||
|
||||
@classmethod
|
||||
@@ -132,10 +354,14 @@ class Beatmap(BeatmapBase, table=True):
|
||||
beatmap = (await session.exec(stmt)).first()
|
||||
if not beatmap:
|
||||
resp = await fetcher.get_beatmap(bid, md5)
|
||||
r = await session.exec(select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id))
|
||||
beatmapset_id = resp.get("beatmapset_id")
|
||||
if beatmapset_id is None:
|
||||
raise ValueError("beatmapset_id is required")
|
||||
r = await session.exec(select(Beatmapset.id).where(Beatmapset.id == beatmapset_id))
|
||||
if not r.first():
|
||||
set_resp = await fetcher.get_beatmapset(resp.beatmapset_id)
|
||||
await Beatmapset.from_resp(session, set_resp, from_=resp.id)
|
||||
set_resp = await fetcher.get_beatmapset(beatmapset_id)
|
||||
resp_id = resp.get("id")
|
||||
await Beatmapset.from_resp(session, set_resp, from_=resp_id or 0)
|
||||
return await Beatmap.from_resp(session, resp)
|
||||
return beatmap
|
||||
|
||||
@@ -145,97 +371,6 @@ class APIBeatmapTag(BaseModel):
|
||||
count: int
|
||||
|
||||
|
||||
class BeatmapResp(BeatmapBase):
|
||||
id: int
|
||||
beatmapset_id: int
|
||||
beatmapset: BeatmapsetResp | None = None
|
||||
convert: bool = False
|
||||
is_scoreable: bool
|
||||
status: str
|
||||
mode_int: int
|
||||
ranked: int
|
||||
url: str = ""
|
||||
playcount: int = 0
|
||||
passcount: int = 0
|
||||
failtimes: FailTimeResp | None = None
|
||||
top_tag_ids: list[APIBeatmapTag] | None = None
|
||||
current_user_tag_ids: list[int] | None = None
|
||||
is_deleted: bool = False
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls,
|
||||
beatmap: Beatmap,
|
||||
query_mode: GameMode | None = None,
|
||||
from_set: bool = False,
|
||||
session: AsyncSession | None = None,
|
||||
user: "User | None" = None,
|
||||
) -> "BeatmapResp":
|
||||
from .score import Score
|
||||
|
||||
beatmap_ = beatmap.model_dump()
|
||||
beatmap_status = beatmap.beatmap_status
|
||||
if query_mode is not None and beatmap.mode != query_mode:
|
||||
beatmap_["convert"] = True
|
||||
beatmap_["is_scoreable"] = beatmap_status.has_leaderboard()
|
||||
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
|
||||
beatmap_["ranked"] = BeatmapRankStatus.APPROVED.value
|
||||
beatmap_["status"] = BeatmapRankStatus.APPROVED.name.lower()
|
||||
else:
|
||||
beatmap_["status"] = beatmap_status.name.lower()
|
||||
beatmap_["ranked"] = beatmap_status.value
|
||||
beatmap_["mode_int"] = int(beatmap.mode)
|
||||
beatmap_["is_deleted"] = beatmap.deleted_at is not None
|
||||
if not from_set:
|
||||
beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset, session=session, user=user)
|
||||
if beatmap.failtimes is not None:
|
||||
beatmap_["failtimes"] = FailTimeResp.from_db(beatmap.failtimes)
|
||||
else:
|
||||
beatmap_["failtimes"] = FailTimeResp()
|
||||
if session:
|
||||
beatmap_["playcount"] = (
|
||||
await session.exec(
|
||||
select(func.sum(BeatmapPlaycounts.playcount)).where(BeatmapPlaycounts.beatmap_id == beatmap.id)
|
||||
)
|
||||
).first() or 0
|
||||
beatmap_["passcount"] = (
|
||||
await session.exec(
|
||||
select(func.count())
|
||||
.select_from(Score)
|
||||
.where(
|
||||
Score.beatmap_id == beatmap.id,
|
||||
col(Score.passed).is_(True),
|
||||
)
|
||||
)
|
||||
).one()
|
||||
|
||||
all_votes = (
|
||||
await session.exec(
|
||||
select(BeatmapTagVote.tag_id, func.count().label("vote_count"))
|
||||
.where(BeatmapTagVote.beatmap_id == beatmap.id)
|
||||
.group_by(col(BeatmapTagVote.tag_id))
|
||||
.having(func.count() > settings.beatmap_tag_top_count)
|
||||
)
|
||||
).all()
|
||||
top_tag_ids: list[dict[str, int]] = []
|
||||
for id, votes in all_votes:
|
||||
top_tag_ids.append({"tag_id": id, "count": votes})
|
||||
top_tag_ids.sort(key=lambda x: x["count"], reverse=True)
|
||||
beatmap_["top_tag_ids"] = top_tag_ids
|
||||
|
||||
if user is not None:
|
||||
beatmap_["current_user_tag_ids"] = (
|
||||
await session.exec(
|
||||
select(BeatmapTagVote.tag_id)
|
||||
.where(BeatmapTagVote.beatmap_id == beatmap.id)
|
||||
.where(BeatmapTagVote.user_id == user.id)
|
||||
)
|
||||
).all()
|
||||
else:
|
||||
beatmap_["current_user_tag_ids"] = []
|
||||
return cls.model_validate(beatmap_)
|
||||
|
||||
|
||||
class BannedBeatmaps(SQLModel, table=True):
|
||||
__tablename__: str = "banned_beatmaps"
|
||||
id: int | None = Field(primary_key=True, index=True, default=None)
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, NotRequired, TypedDict
|
||||
|
||||
from app.config import settings
|
||||
from app.database.events import Event, EventType
|
||||
from app.utils import utcnow
|
||||
|
||||
from pydantic import BaseModel
|
||||
from ._base import DatabaseModel, included
|
||||
from .events import Event, EventType
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import (
|
||||
BigInteger,
|
||||
@@ -12,52 +13,65 @@ from sqlmodel import (
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
select,
|
||||
)
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .beatmap import Beatmap, BeatmapResp
|
||||
from .beatmapset import BeatmapsetResp
|
||||
from .beatmap import Beatmap, BeatmapDict
|
||||
from .beatmapset import BeatmapsetDict
|
||||
from .user import User
|
||||
|
||||
|
||||
class BeatmapPlaycounts(AsyncAttrs, SQLModel, table=True):
|
||||
class BeatmapPlaycountsDict(TypedDict):
|
||||
user_id: int
|
||||
beatmap_id: int
|
||||
count: NotRequired[int]
|
||||
beatmap: NotRequired["BeatmapDict"]
|
||||
beatmapset: NotRequired["BeatmapsetDict"]
|
||||
|
||||
|
||||
class BeatmapPlaycountsModel(AsyncAttrs, DatabaseModel[BeatmapPlaycountsDict]):
|
||||
__tablename__: str = "beatmap_playcounts"
|
||||
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, primary_key=True, autoincrement=True),
|
||||
default=None, sa_column=Column(BigInteger, primary_key=True, autoincrement=True), exclude=True
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
|
||||
playcount: int = Field(default=0)
|
||||
playcount: int = Field(default=0, exclude=True)
|
||||
|
||||
@included
|
||||
@staticmethod
|
||||
async def count(_session: AsyncSession, obj: "BeatmapPlaycounts") -> int:
|
||||
return obj.playcount
|
||||
|
||||
@included
|
||||
@staticmethod
|
||||
async def beatmap(
|
||||
_session: AsyncSession, obj: "BeatmapPlaycounts", includes: list[str] | None = None
|
||||
) -> "BeatmapDict":
|
||||
from .beatmap import BeatmapModel
|
||||
|
||||
await obj.awaitable_attrs.beatmap
|
||||
return await BeatmapModel.transform(obj.beatmap, includes=includes)
|
||||
|
||||
@included
|
||||
@staticmethod
|
||||
async def beatmapset(
|
||||
_session: AsyncSession, obj: "BeatmapPlaycounts", includes: list[str] | None = None
|
||||
) -> "BeatmapsetDict":
|
||||
from .beatmap import BeatmapsetModel
|
||||
|
||||
await obj.awaitable_attrs.beatmap
|
||||
return await BeatmapsetModel.transform(obj.beatmap.beatmapset, includes=includes)
|
||||
|
||||
|
||||
class BeatmapPlaycounts(BeatmapPlaycountsModel, table=True):
|
||||
user: "User" = Relationship()
|
||||
beatmap: "Beatmap" = Relationship()
|
||||
|
||||
|
||||
class BeatmapPlaycountsResp(BaseModel):
|
||||
beatmap_id: int
|
||||
beatmap: "BeatmapResp | None" = None
|
||||
beatmapset: "BeatmapsetResp | None" = None
|
||||
count: int
|
||||
|
||||
@classmethod
|
||||
async def from_db(cls, db_model: BeatmapPlaycounts) -> "BeatmapPlaycountsResp":
|
||||
from .beatmap import BeatmapResp
|
||||
from .beatmapset import BeatmapsetResp
|
||||
|
||||
await db_model.awaitable_attrs.beatmap
|
||||
return cls(
|
||||
beatmap_id=db_model.beatmap_id,
|
||||
count=db_model.playcount,
|
||||
beatmap=await BeatmapResp.from_db(db_model.beatmap),
|
||||
beatmapset=await BeatmapsetResp.from_db(db_model.beatmap.beatmapset),
|
||||
)
|
||||
|
||||
|
||||
async def process_beatmap_playcount(session: AsyncSession, user_id: int, beatmap_id: int) -> None:
|
||||
existing_playcount = (
|
||||
await session.exec(
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, NotRequired, Self, TypedDict
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, NotRequired, TypedDict
|
||||
|
||||
from app.config import settings
|
||||
from app.database.beatmap_playcounts import BeatmapPlaycounts
|
||||
from app.models.beatmap import BeatmapRankStatus, Genre, Language
|
||||
from app.models.score import GameMode
|
||||
|
||||
from .user import BASE_INCLUDES, User, UserResp
|
||||
from ._base import DatabaseModel, OnDemand, included, ondemand
|
||||
from .beatmap_playcounts import BeatmapPlaycounts
|
||||
from .user import User, UserDict
|
||||
|
||||
from pydantic import BaseModel, field_validator, model_validator
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import JSON, Boolean, Column, DateTime, Text
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import Field, Relationship, SQLModel, col, exists, func, select
|
||||
@@ -17,7 +18,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
if TYPE_CHECKING:
|
||||
from app.fetcher import Fetcher
|
||||
|
||||
from .beatmap import Beatmap, BeatmapResp
|
||||
from .beatmap import Beatmap, BeatmapDict
|
||||
from .favourite_beatmapset import FavouriteBeatmapset
|
||||
|
||||
|
||||
@@ -68,8 +69,99 @@ class BeatmapTranslationText(BaseModel):
|
||||
id: int | None = None
|
||||
|
||||
|
||||
class BeatmapsetBase(SQLModel):
|
||||
class BeatmapsetDict(TypedDict):
|
||||
id: int
|
||||
artist: str
|
||||
artist_unicode: str
|
||||
covers: BeatmapCovers | None
|
||||
creator: str
|
||||
nsfw: bool
|
||||
preview_url: str
|
||||
source: str
|
||||
spotlight: bool
|
||||
title: str
|
||||
title_unicode: str
|
||||
track_id: int | None
|
||||
user_id: int
|
||||
video: bool
|
||||
current_nominations: list[BeatmapNomination] | None
|
||||
description: BeatmapDescription | None
|
||||
pack_tags: list[str]
|
||||
|
||||
bpm: NotRequired[float]
|
||||
can_be_hyped: NotRequired[bool]
|
||||
discussion_locked: NotRequired[bool]
|
||||
last_updated: NotRequired[datetime]
|
||||
ranked_date: NotRequired[datetime | None]
|
||||
storyboard: NotRequired[bool]
|
||||
submitted_date: NotRequired[datetime]
|
||||
tags: NotRequired[str]
|
||||
discussion_enabled: NotRequired[bool]
|
||||
legacy_thread_url: NotRequired[str | None]
|
||||
status: NotRequired[str]
|
||||
ranked: NotRequired[int]
|
||||
is_scoreable: NotRequired[bool]
|
||||
favourite_count: NotRequired[int]
|
||||
genre_id: NotRequired[int]
|
||||
hype: NotRequired[BeatmapHype]
|
||||
language_id: NotRequired[int]
|
||||
play_count: NotRequired[int]
|
||||
availability: NotRequired[BeatmapAvailability]
|
||||
beatmaps: NotRequired[list["BeatmapDict"]]
|
||||
has_favourited: NotRequired[bool]
|
||||
recent_favourites: NotRequired[list[UserDict]]
|
||||
genre: NotRequired[BeatmapTranslationText]
|
||||
language: NotRequired[BeatmapTranslationText]
|
||||
nominations: NotRequired["BeatmapNominations"]
|
||||
ratings: NotRequired[list[int]]
|
||||
|
||||
|
||||
class BeatmapsetModel(DatabaseModel[BeatmapsetDict]):
|
||||
BEATMAPSET_TRANSFORMER_INCLUDES: ClassVar[list[str]] = [
|
||||
"availability",
|
||||
"has_favourited",
|
||||
"bpm",
|
||||
"deleted_atcan_be_hyped",
|
||||
"discussion_locked",
|
||||
"is_scoreable",
|
||||
"last_updated",
|
||||
"legacy_thread_url",
|
||||
"ranked",
|
||||
"ranked_date",
|
||||
"submitted_date",
|
||||
"tags",
|
||||
"rating",
|
||||
"storyboard",
|
||||
]
|
||||
API_INCLUDES: ClassVar[list[str]] = [
|
||||
*BEATMAPSET_TRANSFORMER_INCLUDES,
|
||||
"beatmaps.current_user_playcount",
|
||||
"beatmaps.current_user_tag_ids",
|
||||
"beatmaps.max_combo",
|
||||
"current_nominations",
|
||||
"current_user_attributes",
|
||||
"description",
|
||||
"genre",
|
||||
"language",
|
||||
"pack_tags",
|
||||
"ratings",
|
||||
"recent_favourites",
|
||||
"related_tags",
|
||||
"related_users",
|
||||
"user",
|
||||
"version_count",
|
||||
*[
|
||||
f"beatmaps.{inc}"
|
||||
for inc in {
|
||||
"failtimes",
|
||||
"owners",
|
||||
"top_tag_ids",
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
# Beatmapset
|
||||
id: int = Field(default=None, primary_key=True, index=True)
|
||||
artist: str = Field(index=True)
|
||||
artist_unicode: str = Field(index=True)
|
||||
covers: BeatmapCovers | None = Field(sa_column=Column(JSON))
|
||||
@@ -77,41 +169,285 @@ class BeatmapsetBase(SQLModel):
|
||||
nsfw: bool = Field(default=False, sa_column=Column(Boolean))
|
||||
preview_url: str
|
||||
source: str = Field(default="")
|
||||
|
||||
spotlight: bool = Field(default=False, sa_column=Column(Boolean))
|
||||
title: str = Field(index=True)
|
||||
title_unicode: str = Field(index=True)
|
||||
track_id: int | None = Field(default=None, index=True) # feature artist?
|
||||
user_id: int = Field(index=True)
|
||||
video: bool = Field(sa_column=Column(Boolean, index=True))
|
||||
|
||||
# optional
|
||||
# converts: list[Beatmap] = Relationship(back_populates="beatmapset")
|
||||
current_nominations: list[BeatmapNomination] | None = Field(None, sa_column=Column(JSON))
|
||||
description: BeatmapDescription | None = Field(default=None, sa_column=Column(JSON))
|
||||
current_nominations: OnDemand[list[BeatmapNomination] | None] = Field(None, sa_column=Column(JSON))
|
||||
description: OnDemand[BeatmapDescription | None] = Field(default=None, sa_column=Column(JSON))
|
||||
# TODO: discussions: list[BeatmapsetDiscussion] = None
|
||||
# TODO: current_user_attributes: Optional[CurrentUserAttributes] = None
|
||||
# TODO: events: Optional[list[BeatmapsetEvent]] = None
|
||||
|
||||
pack_tags: list[str] = Field(default=[], sa_column=Column(JSON))
|
||||
pack_tags: OnDemand[list[str]] = Field(default=[], sa_column=Column(JSON))
|
||||
# TODO: related_users: Optional[list[User]] = None
|
||||
# TODO: user: Optional[User] = Field(default=None)
|
||||
track_id: int | None = Field(default=None, index=True) # feature artist?
|
||||
|
||||
# BeatmapsetExtended
|
||||
bpm: float = Field(default=0.0)
|
||||
can_be_hyped: bool = Field(default=False, sa_column=Column(Boolean))
|
||||
discussion_locked: bool = Field(default=False, sa_column=Column(Boolean))
|
||||
last_updated: datetime = Field(sa_column=Column(DateTime, index=True))
|
||||
ranked_date: datetime | None = Field(default=None, sa_column=Column(DateTime, index=True))
|
||||
storyboard: bool = Field(default=False, sa_column=Column(Boolean, index=True))
|
||||
submitted_date: datetime = Field(sa_column=Column(DateTime, index=True))
|
||||
tags: str = Field(default="", sa_column=Column(Text))
|
||||
bpm: OnDemand[float] = Field(default=0.0)
|
||||
can_be_hyped: OnDemand[bool] = Field(default=False, sa_column=Column(Boolean))
|
||||
discussion_locked: OnDemand[bool] = Field(default=False, sa_column=Column(Boolean))
|
||||
last_updated: OnDemand[datetime] = Field(sa_column=Column(DateTime, index=True))
|
||||
ranked_date: OnDemand[datetime | None] = Field(default=None, sa_column=Column(DateTime, index=True))
|
||||
storyboard: OnDemand[bool] = Field(default=False, sa_column=Column(Boolean, index=True))
|
||||
submitted_date: OnDemand[datetime] = Field(sa_column=Column(DateTime, index=True))
|
||||
tags: OnDemand[str] = Field(default="", sa_column=Column(Text))
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def legacy_thread_url(
|
||||
_session: AsyncSession,
|
||||
_beatmapset: "Beatmapset",
|
||||
) -> str | None:
|
||||
return None
|
||||
|
||||
@included
|
||||
@staticmethod
|
||||
async def discussion_enabled(
|
||||
_session: AsyncSession,
|
||||
_beatmapset: "Beatmapset",
|
||||
) -> bool:
|
||||
return True
|
||||
|
||||
@included
|
||||
@staticmethod
|
||||
async def status(
|
||||
_session: AsyncSession,
|
||||
beatmapset: "Beatmapset",
|
||||
) -> str:
|
||||
beatmap_status = beatmapset.beatmap_status
|
||||
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
|
||||
return BeatmapRankStatus.APPROVED.name.lower()
|
||||
return beatmap_status.name.lower()
|
||||
|
||||
@included
|
||||
@staticmethod
|
||||
async def ranked(
|
||||
_session: AsyncSession,
|
||||
beatmapset: "Beatmapset",
|
||||
) -> int:
|
||||
beatmap_status = beatmapset.beatmap_status
|
||||
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
|
||||
return BeatmapRankStatus.APPROVED.value
|
||||
return beatmap_status.value
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def is_scoreable(
|
||||
_session: AsyncSession,
|
||||
beatmapset: "Beatmapset",
|
||||
) -> bool:
|
||||
beatmap_status = beatmapset.beatmap_status
|
||||
if settings.enable_all_beatmap_leaderboard:
|
||||
return True
|
||||
return beatmap_status.has_leaderboard()
|
||||
|
||||
@included
|
||||
@staticmethod
|
||||
async def favourite_count(
|
||||
session: AsyncSession,
|
||||
beatmapset: "Beatmapset",
|
||||
) -> int:
|
||||
from .favourite_beatmapset import FavouriteBeatmapset
|
||||
|
||||
count = await session.exec(
|
||||
select(func.count())
|
||||
.select_from(FavouriteBeatmapset)
|
||||
.where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
|
||||
)
|
||||
return count.one()
|
||||
|
||||
@included
|
||||
@staticmethod
|
||||
async def genre_id(
|
||||
_session: AsyncSession,
|
||||
beatmapset: "Beatmapset",
|
||||
) -> int:
|
||||
return beatmapset.beatmap_genre.value
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def hype(
|
||||
_session: AsyncSession,
|
||||
beatmapset: "Beatmapset",
|
||||
) -> BeatmapHype:
|
||||
return BeatmapHype(current=beatmapset.hype_current, required=beatmapset.hype_required)
|
||||
|
||||
@included
|
||||
@staticmethod
|
||||
async def language_id(
|
||||
_session: AsyncSession,
|
||||
beatmapset: "Beatmapset",
|
||||
) -> int:
|
||||
return beatmapset.beatmap_language.value
|
||||
|
||||
@included
|
||||
@staticmethod
|
||||
async def play_count(
|
||||
session: AsyncSession,
|
||||
beatmapset: "Beatmapset",
|
||||
) -> int:
|
||||
from .beatmap import Beatmap
|
||||
|
||||
playcount = await session.exec(
|
||||
select(func.sum(BeatmapPlaycounts.playcount)).where(
|
||||
col(BeatmapPlaycounts.beatmap).has(col(Beatmap.beatmapset_id) == beatmapset.id)
|
||||
)
|
||||
)
|
||||
return int(playcount.first() or 0)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def availability(
|
||||
_session: AsyncSession,
|
||||
beatmapset: "Beatmapset",
|
||||
) -> BeatmapAvailability:
|
||||
return BeatmapAvailability(
|
||||
more_information=beatmapset.availability_info,
|
||||
download_disabled=beatmapset.download_disabled,
|
||||
)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def beatmaps(
|
||||
_session: AsyncSession,
|
||||
beatmapset: "Beatmapset",
|
||||
includes: list[str] | None = None,
|
||||
user: "User | None" = None,
|
||||
) -> list["BeatmapDict"]:
|
||||
from .beatmap import BeatmapModel
|
||||
|
||||
return [
|
||||
await BeatmapModel.transform(
|
||||
beatmap, includes=(includes or []) + BeatmapModel.BEATMAP_TRANSFORMER_INCLUDES, user=user
|
||||
)
|
||||
for beatmap in await beatmapset.awaitable_attrs.beatmaps
|
||||
]
|
||||
|
||||
# @ondemand
|
||||
# @staticmethod
|
||||
# async def current_nominations(
|
||||
# _session: AsyncSession,
|
||||
# beatmapset: "Beatmapset",
|
||||
# ) -> list[BeatmapNomination] | None:
|
||||
# return beatmapset.current_nominations or []
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def has_favourited(
|
||||
session: AsyncSession,
|
||||
beatmapset: "Beatmapset",
|
||||
user: User | None = None,
|
||||
) -> bool:
|
||||
from .favourite_beatmapset import FavouriteBeatmapset
|
||||
|
||||
if session is None:
|
||||
return False
|
||||
query = select(FavouriteBeatmapset).where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
|
||||
if user is not None:
|
||||
query = query.where(FavouriteBeatmapset.user_id == user.id)
|
||||
existing = (await session.exec(query)).first()
|
||||
return existing is not None
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def recent_favourites(
|
||||
session: AsyncSession,
|
||||
beatmapset: "Beatmapset",
|
||||
includes: list[str] | None = None,
|
||||
) -> list[UserDict]:
|
||||
from .favourite_beatmapset import FavouriteBeatmapset
|
||||
|
||||
recent_favourites = (
|
||||
await session.exec(
|
||||
select(FavouriteBeatmapset)
|
||||
.where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
|
||||
.order_by(col(FavouriteBeatmapset.date).desc())
|
||||
.limit(50)
|
||||
)
|
||||
).all()
|
||||
return [
|
||||
await User.transform(
|
||||
(await favourite.awaitable_attrs.user),
|
||||
includes=includes,
|
||||
)
|
||||
for favourite in recent_favourites
|
||||
]
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def genre(
|
||||
_session: AsyncSession,
|
||||
beatmapset: "Beatmapset",
|
||||
) -> BeatmapTranslationText:
|
||||
return BeatmapTranslationText(
|
||||
name=beatmapset.beatmap_genre.name,
|
||||
id=beatmapset.beatmap_genre.value,
|
||||
)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def language(
|
||||
_session: AsyncSession,
|
||||
beatmapset: "Beatmapset",
|
||||
) -> BeatmapTranslationText:
|
||||
return BeatmapTranslationText(
|
||||
name=beatmapset.beatmap_language.name,
|
||||
id=beatmapset.beatmap_language.value,
|
||||
)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def nominations(
|
||||
_session: AsyncSession,
|
||||
beatmapset: "Beatmapset",
|
||||
) -> BeatmapNominations:
|
||||
return BeatmapNominations(
|
||||
required=beatmapset.nominations_required,
|
||||
current=beatmapset.nominations_current,
|
||||
)
|
||||
|
||||
# @ondemand
|
||||
# @staticmethod
|
||||
# async def user(
|
||||
# session: AsyncSession,
|
||||
# beatmapset: Beatmapset,
|
||||
# includes: list[str] | None = None,
|
||||
# ) -> dict[str, Any] | None:
|
||||
# db_user = await session.get(User, beatmapset.user_id)
|
||||
# if not db_user:
|
||||
# return None
|
||||
# return await UserResp.transform(db_user, includes=includes)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def ratings(
|
||||
session: AsyncSession,
|
||||
beatmapset: "Beatmapset",
|
||||
) -> list[int]:
|
||||
# Provide a stable default shape if no session is available
|
||||
if session is None:
|
||||
return []
|
||||
|
||||
from .beatmapset_ratings import BeatmapRating
|
||||
|
||||
beatmapset_all_ratings = (
|
||||
await session.exec(select(BeatmapRating).where(BeatmapRating.beatmapset_id == beatmapset.id))
|
||||
).all()
|
||||
ratings_list = [0] * 11
|
||||
for rating in beatmapset_all_ratings:
|
||||
ratings_list[rating.rating] += 1
|
||||
return ratings_list
|
||||
|
||||
|
||||
class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
|
||||
class Beatmapset(AsyncAttrs, BeatmapsetModel, table=True):
|
||||
__tablename__: str = "beatmapsets"
|
||||
|
||||
id: int = Field(default=None, primary_key=True, index=True)
|
||||
# Beatmapset
|
||||
beatmap_status: BeatmapRankStatus = Field(default=BeatmapRankStatus.GRAVEYARD, index=True)
|
||||
|
||||
@@ -130,29 +466,45 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
|
||||
favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset")
|
||||
|
||||
@classmethod
|
||||
async def from_resp_no_save(cls, resp: "BeatmapsetResp") -> "Beatmapset":
|
||||
d = resp.model_dump()
|
||||
if resp.nominations:
|
||||
d["nominations_required"] = resp.nominations.required
|
||||
d["nominations_current"] = resp.nominations.current
|
||||
if resp.hype:
|
||||
d["hype_current"] = resp.hype.current
|
||||
d["hype_required"] = resp.hype.required
|
||||
if resp.genre_id:
|
||||
d["beatmap_genre"] = Genre(resp.genre_id)
|
||||
elif resp.genre:
|
||||
d["beatmap_genre"] = Genre(resp.genre.id)
|
||||
if resp.language_id:
|
||||
d["beatmap_language"] = Language(resp.language_id)
|
||||
elif resp.language:
|
||||
d["beatmap_language"] = Language(resp.language.id)
|
||||
async def from_resp_no_save(cls, resp: BeatmapsetDict) -> "Beatmapset":
|
||||
# make a shallow copy so we can mutate safely
|
||||
d: dict[str, Any] = dict(resp)
|
||||
|
||||
# nominations = resp.get("nominations")
|
||||
# if nominations is not None:
|
||||
# d["nominations_required"] = nominations.required
|
||||
# d["nominations_current"] = nominations.current
|
||||
|
||||
hype = resp.get("hype")
|
||||
if hype is not None:
|
||||
d["hype_current"] = hype.current
|
||||
d["hype_required"] = hype.required
|
||||
|
||||
genre_id = resp.get("genre_id")
|
||||
genre = resp.get("genre")
|
||||
if genre_id is not None:
|
||||
d["beatmap_genre"] = Genre(genre_id)
|
||||
elif genre is not None:
|
||||
d["beatmap_genre"] = Genre(genre.id)
|
||||
|
||||
language_id = resp.get("language_id")
|
||||
language = resp.get("language")
|
||||
if language_id is not None:
|
||||
d["beatmap_language"] = Language(language_id)
|
||||
elif language is not None:
|
||||
d["beatmap_language"] = Language(language.id)
|
||||
|
||||
availability = resp.get("availability")
|
||||
ranked = resp.get("ranked")
|
||||
if ranked is None:
|
||||
raise ValueError("ranked field is required")
|
||||
|
||||
beatmapset = Beatmapset.model_validate(
|
||||
{
|
||||
**d,
|
||||
"id": resp.id,
|
||||
"beatmap_status": BeatmapRankStatus(resp.ranked),
|
||||
"availability_info": resp.availability.more_information,
|
||||
"download_disabled": resp.availability.download_disabled or False,
|
||||
"beatmap_status": BeatmapRankStatus(ranked),
|
||||
"availability_info": availability.more_information if availability is not None else None,
|
||||
"download_disabled": bool(availability.download_disabled) if availability is not None else False,
|
||||
}
|
||||
)
|
||||
return beatmapset
|
||||
@@ -161,17 +513,19 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
|
||||
async def from_resp(
|
||||
cls,
|
||||
session: AsyncSession,
|
||||
resp: "BeatmapsetResp",
|
||||
resp: BeatmapsetDict,
|
||||
from_: int = 0,
|
||||
) -> "Beatmapset":
|
||||
from .beatmap import Beatmap
|
||||
|
||||
beatmapset_id = resp["id"]
|
||||
beatmapset = await cls.from_resp_no_save(resp)
|
||||
if not (await session.exec(select(exists()).where(Beatmapset.id == resp.id))).first():
|
||||
if not (await session.exec(select(exists()).where(Beatmapset.id == beatmapset_id))).first():
|
||||
session.add(beatmapset)
|
||||
await session.commit()
|
||||
await Beatmap.from_resp_batch(session, resp.beatmaps, from_=from_)
|
||||
beatmapset = (await session.exec(select(Beatmapset).where(Beatmapset.id == resp.id))).one()
|
||||
beatmaps = resp.get("beatmaps", [])
|
||||
await Beatmap.from_resp_batch(session, beatmaps, from_=from_)
|
||||
beatmapset = (await session.exec(select(Beatmapset).where(Beatmapset.id == beatmapset_id))).one()
|
||||
return beatmapset
|
||||
|
||||
@classmethod
|
||||
@@ -183,170 +537,5 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
|
||||
resp = await fetcher.get_beatmapset(sid)
|
||||
beatmapset = await cls.from_resp(session, resp)
|
||||
await get_beatmapset_update_service().add(resp)
|
||||
await session.refresh(beatmapset)
|
||||
return beatmapset
|
||||
|
||||
|
||||
class BeatmapsetResp(BeatmapsetBase):
|
||||
id: int
|
||||
beatmaps: list["BeatmapResp"] = Field(default_factory=list)
|
||||
discussion_enabled: bool = True
|
||||
status: str
|
||||
ranked: int
|
||||
legacy_thread_url: str | None = ""
|
||||
is_scoreable: bool
|
||||
hype: BeatmapHype | None = None
|
||||
availability: BeatmapAvailability
|
||||
genre: BeatmapTranslationText | None = None
|
||||
genre_id: int
|
||||
language: BeatmapTranslationText | None = None
|
||||
language_id: int
|
||||
nominations: BeatmapNominations | None = None
|
||||
has_favourited: bool = False
|
||||
favourite_count: int = 0
|
||||
recent_favourites: list[UserResp] = Field(default_factory=list)
|
||||
play_count: int = 0
|
||||
|
||||
@field_validator(
|
||||
"nsfw",
|
||||
"spotlight",
|
||||
"video",
|
||||
"can_be_hyped",
|
||||
"discussion_locked",
|
||||
"storyboard",
|
||||
"discussion_enabled",
|
||||
"is_scoreable",
|
||||
"has_favourited",
|
||||
mode="before",
|
||||
)
|
||||
@classmethod
|
||||
def validate_bool_fields(cls, v):
|
||||
"""将整数 0/1 转换为布尔值,处理数据库中的布尔字段"""
|
||||
if isinstance(v, int):
|
||||
return bool(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def fix_genre_language(self) -> Self:
|
||||
if self.genre is None:
|
||||
self.genre = BeatmapTranslationText(name=Genre(self.genre_id).name, id=self.genre_id)
|
||||
if self.language is None:
|
||||
self.language = BeatmapTranslationText(name=Language(self.language_id).name, id=self.language_id)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls,
|
||||
beatmapset: Beatmapset,
|
||||
include: list[str] = [],
|
||||
session: AsyncSession | None = None,
|
||||
user: User | None = None,
|
||||
) -> "BeatmapsetResp":
|
||||
from .beatmap import Beatmap, BeatmapResp
|
||||
from .favourite_beatmapset import FavouriteBeatmapset
|
||||
|
||||
update = {
|
||||
"beatmaps": [
|
||||
await BeatmapResp.from_db(beatmap, from_set=True, session=session)
|
||||
for beatmap in await beatmapset.awaitable_attrs.beatmaps
|
||||
],
|
||||
"hype": BeatmapHype(current=beatmapset.hype_current, required=beatmapset.hype_required),
|
||||
"availability": BeatmapAvailability(
|
||||
more_information=beatmapset.availability_info,
|
||||
download_disabled=beatmapset.download_disabled,
|
||||
),
|
||||
"genre": BeatmapTranslationText(
|
||||
name=beatmapset.beatmap_genre.name,
|
||||
id=beatmapset.beatmap_genre.value,
|
||||
),
|
||||
"language": BeatmapTranslationText(
|
||||
name=beatmapset.beatmap_language.name,
|
||||
id=beatmapset.beatmap_language.value,
|
||||
),
|
||||
"genre_id": beatmapset.beatmap_genre.value,
|
||||
"language_id": beatmapset.beatmap_language.value,
|
||||
"nominations": BeatmapNominations(
|
||||
required=beatmapset.nominations_required,
|
||||
current=beatmapset.nominations_current,
|
||||
),
|
||||
"is_scoreable": beatmapset.beatmap_status.has_leaderboard(),
|
||||
**beatmapset.model_dump(),
|
||||
}
|
||||
|
||||
if session is not None:
|
||||
# 从数据库读取对应谱面集的评分
|
||||
from .beatmapset_ratings import BeatmapRating
|
||||
|
||||
beatmapset_all_ratings = (
|
||||
await session.exec(select(BeatmapRating).where(BeatmapRating.beatmapset_id == beatmapset.id))
|
||||
).all()
|
||||
ratings_list = [0] * 11
|
||||
for rating in beatmapset_all_ratings:
|
||||
ratings_list[rating.rating] += 1
|
||||
update["ratings"] = ratings_list
|
||||
else:
|
||||
# 返回非空值避免客户端崩溃
|
||||
if update.get("ratings") is None:
|
||||
update["ratings"] = []
|
||||
|
||||
beatmap_status = beatmapset.beatmap_status
|
||||
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
|
||||
update["status"] = BeatmapRankStatus.APPROVED.name.lower()
|
||||
update["ranked"] = BeatmapRankStatus.APPROVED.value
|
||||
else:
|
||||
update["status"] = beatmap_status.name.lower()
|
||||
update["ranked"] = beatmap_status.value
|
||||
|
||||
if session and user:
|
||||
existing_favourite = (
|
||||
await session.exec(
|
||||
select(FavouriteBeatmapset).where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
|
||||
)
|
||||
).first()
|
||||
update["has_favourited"] = existing_favourite is not None
|
||||
|
||||
if session and "recent_favourites" in include:
|
||||
recent_favourites = (
|
||||
await session.exec(
|
||||
select(FavouriteBeatmapset)
|
||||
.where(
|
||||
FavouriteBeatmapset.beatmapset_id == beatmapset.id,
|
||||
)
|
||||
.order_by(col(FavouriteBeatmapset.date).desc())
|
||||
.limit(50)
|
||||
)
|
||||
).all()
|
||||
update["recent_favourites"] = [
|
||||
await UserResp.from_db(
|
||||
await favourite.awaitable_attrs.user,
|
||||
session=session,
|
||||
include=BASE_INCLUDES,
|
||||
)
|
||||
for favourite in recent_favourites
|
||||
]
|
||||
|
||||
if session:
|
||||
update["favourite_count"] = (
|
||||
await session.exec(
|
||||
select(func.count())
|
||||
.select_from(FavouriteBeatmapset)
|
||||
.where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
|
||||
)
|
||||
).one()
|
||||
|
||||
update["play_count"] = (
|
||||
await session.exec(
|
||||
select(func.sum(BeatmapPlaycounts.playcount)).where(
|
||||
col(BeatmapPlaycounts.beatmap).has(col(Beatmap.beatmapset_id) == beatmapset.id)
|
||||
)
|
||||
)
|
||||
).first() or 0
|
||||
return cls.model_validate(
|
||||
update,
|
||||
)
|
||||
|
||||
|
||||
class SearchBeatmapsetsResp(SQLModel):
|
||||
beatmapsets: list[BeatmapsetResp]
|
||||
total: int
|
||||
cursor: dict[str, int | float | str] | None = None
|
||||
cursor_string: str | None = None
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from app.database.beatmapset import Beatmapset
|
||||
from app.database.user import User
|
||||
from .beatmapset import Beatmapset
|
||||
from .user import User
|
||||
|
||||
from sqlmodel import BigInteger, Column, Field, ForeignKey, Relationship, SQLModel
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.database.statistics import UserStatistics
|
||||
from app.models.score import GameMode
|
||||
|
||||
from .statistics import UserStatistics
|
||||
from .user import User
|
||||
|
||||
from sqlmodel import (
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Self
|
||||
from typing import TYPE_CHECKING, ClassVar, NotRequired, TypedDict
|
||||
|
||||
from app.database.user import RANKING_INCLUDES, User, UserResp
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.utils import utcnow
|
||||
|
||||
from ._base import DatabaseModel, included, ondemand
|
||||
from .user import User, UserDict, UserModel
|
||||
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import (
|
||||
VARCHAR,
|
||||
BigInteger,
|
||||
@@ -22,6 +23,8 @@ from sqlmodel import (
|
||||
)
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.router.notification.server import ChatServer
|
||||
# ChatChannel
|
||||
|
||||
|
||||
@@ -44,16 +47,168 @@ class ChannelType(str, Enum):
|
||||
TEAM = "TEAM"
|
||||
|
||||
|
||||
class ChatChannelBase(SQLModel):
|
||||
name: str = Field(sa_column=Column(VARCHAR(50), index=True))
|
||||
class MessageType(str, Enum):
|
||||
ACTION = "action"
|
||||
MARKDOWN = "markdown"
|
||||
PLAIN = "plain"
|
||||
|
||||
|
||||
class ChatChannelDict(TypedDict):
|
||||
channel_id: int
|
||||
description: str
|
||||
name: str
|
||||
icon: str | None
|
||||
type: ChannelType
|
||||
uuid: NotRequired[str | None]
|
||||
message_length_limit: NotRequired[int]
|
||||
moderated: NotRequired[bool]
|
||||
current_user_attributes: NotRequired[ChatUserAttributes]
|
||||
last_read_id: NotRequired[int | None]
|
||||
last_message_id: NotRequired[int | None]
|
||||
recent_messages: NotRequired[list["ChatMessageDict"]]
|
||||
users: NotRequired[list[int]]
|
||||
|
||||
|
||||
class ChatChannelModel(DatabaseModel[ChatChannelDict]):
|
||||
CONVERSATION_INCLUDES: ClassVar[list[str]] = [
|
||||
"last_message_id",
|
||||
"users",
|
||||
]
|
||||
LISTING_INCLUDES: ClassVar[list[str]] = [
|
||||
*CONVERSATION_INCLUDES,
|
||||
"current_user_attributes",
|
||||
"last_read_id",
|
||||
]
|
||||
|
||||
channel_id: int = Field(primary_key=True, index=True, default=None)
|
||||
description: str = Field(sa_column=Column(VARCHAR(255), index=True))
|
||||
icon: str | None = Field(default=None)
|
||||
type: ChannelType = Field(index=True)
|
||||
|
||||
@included
|
||||
@staticmethod
|
||||
async def name(session: AsyncSession, channel: "ChatChannel", user: User, server: "ChatServer") -> str:
|
||||
users = server.channels.get(channel.channel_id, [])
|
||||
if channel.type == ChannelType.PM and users and len(users) == 2:
|
||||
target_user_id = next(u for u in users if u != user.id)
|
||||
target_name = await session.exec(select(User.username).where(User.id == target_user_id))
|
||||
return target_name.one()
|
||||
return channel.name
|
||||
|
||||
class ChatChannel(ChatChannelBase, table=True):
|
||||
@included
|
||||
@staticmethod
|
||||
async def moderated(session: AsyncSession, channel: "ChatChannel", user: User) -> bool:
|
||||
silence = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(
|
||||
SilenceUser.channel_id == channel.channel_id,
|
||||
SilenceUser.user_id == user.id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
|
||||
return silence is not None
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def current_user_attributes(
|
||||
session: AsyncSession,
|
||||
channel: "ChatChannel",
|
||||
user: User,
|
||||
) -> ChatUserAttributes:
|
||||
from app.dependencies.database import get_redis
|
||||
|
||||
silence = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(
|
||||
SilenceUser.channel_id == channel.channel_id,
|
||||
SilenceUser.user_id == user.id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
can_message = silence is None
|
||||
can_message_error = "You are silenced in this channel" if not can_message else None
|
||||
|
||||
redis = get_redis()
|
||||
last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
|
||||
last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
|
||||
last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
|
||||
last_read_id = int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else (last_msg or 0)
|
||||
|
||||
return ChatUserAttributes(
|
||||
can_message=can_message,
|
||||
can_message_error=can_message_error,
|
||||
last_read_id=last_read_id,
|
||||
)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def last_read_id(_session: AsyncSession, channel: "ChatChannel", user: User) -> int | None:
|
||||
from app.dependencies.database import get_redis
|
||||
|
||||
redis = get_redis()
|
||||
last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
|
||||
last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
|
||||
last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
|
||||
return int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else last_msg
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def last_message_id(_session: AsyncSession, channel: "ChatChannel") -> int | None:
|
||||
from app.dependencies.database import get_redis
|
||||
|
||||
redis = get_redis()
|
||||
last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
|
||||
return int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def recent_messages(
|
||||
session: AsyncSession,
|
||||
channel: "ChatChannel",
|
||||
) -> list["ChatMessageDict"]:
|
||||
messages = (
|
||||
await session.exec(
|
||||
select(ChatMessage)
|
||||
.where(ChatMessage.channel_id == channel.channel_id)
|
||||
.order_by(col(ChatMessage.message_id).desc())
|
||||
.limit(50)
|
||||
)
|
||||
).all()
|
||||
result = [
|
||||
await ChatMessageModel.transform(
|
||||
msg,
|
||||
)
|
||||
for msg in reversed(messages)
|
||||
]
|
||||
return result
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def users(
|
||||
_session: AsyncSession,
|
||||
channel: "ChatChannel",
|
||||
server: "ChatServer",
|
||||
user: User,
|
||||
) -> list[int]:
|
||||
if channel.type == ChannelType.PUBLIC:
|
||||
return []
|
||||
users = server.channels.get(channel.channel_id, []).copy()
|
||||
if channel.type == ChannelType.PM and users and len(users) == 2:
|
||||
target_user_id = next(u for u in users if u != user.id)
|
||||
users = [target_user_id, user.id]
|
||||
return users
|
||||
|
||||
@included
|
||||
@staticmethod
|
||||
async def message_length_limit(_session: AsyncSession, _channel: "ChatChannel") -> int:
|
||||
return 1000
|
||||
|
||||
|
||||
class ChatChannel(ChatChannelModel, table=True):
|
||||
__tablename__: str = "chat_channels"
|
||||
channel_id: int = Field(primary_key=True, index=True, default=None)
|
||||
|
||||
name: str = Field(sa_column=Column(VARCHAR(50), index=True))
|
||||
|
||||
@classmethod
|
||||
async def get(cls, channel: str | int, session: AsyncSession) -> "ChatChannel | None":
|
||||
@@ -74,93 +229,20 @@ class ChatChannel(ChatChannelBase, table=True):
|
||||
return channel
|
||||
|
||||
|
||||
class ChatChannelResp(ChatChannelBase):
|
||||
channel_id: int
|
||||
moderated: bool = False
|
||||
uuid: str | None = None
|
||||
current_user_attributes: ChatUserAttributes | None = None
|
||||
last_read_id: int | None = None
|
||||
last_message_id: int | None = None
|
||||
recent_messages: list["ChatMessageResp"] = Field(default_factory=list)
|
||||
users: list[int] = Field(default_factory=list)
|
||||
message_length_limit: int = 1000
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls,
|
||||
channel: ChatChannel,
|
||||
session: AsyncSession,
|
||||
user: User,
|
||||
redis: Redis,
|
||||
users: list[int] | None = None,
|
||||
include_recent_messages: bool = False,
|
||||
) -> Self:
|
||||
c = cls.model_validate(channel)
|
||||
silence = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(
|
||||
SilenceUser.channel_id == channel.channel_id,
|
||||
SilenceUser.user_id == user.id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
|
||||
last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
|
||||
last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
|
||||
|
||||
last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
|
||||
last_read_id = int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else last_msg
|
||||
|
||||
if silence is not None:
|
||||
attribute = ChatUserAttributes(
|
||||
can_message=False,
|
||||
can_message_error=silence.reason or "You are muted in this channel.",
|
||||
last_read_id=last_read_id or 0,
|
||||
)
|
||||
c.moderated = True
|
||||
else:
|
||||
attribute = ChatUserAttributes(
|
||||
can_message=True,
|
||||
last_read_id=last_read_id or 0,
|
||||
)
|
||||
c.moderated = False
|
||||
|
||||
c.current_user_attributes = attribute
|
||||
if c.type != ChannelType.PUBLIC and users is not None:
|
||||
c.users = users
|
||||
c.last_message_id = last_msg
|
||||
c.last_read_id = last_read_id
|
||||
|
||||
if include_recent_messages:
|
||||
messages = (
|
||||
await session.exec(
|
||||
select(ChatMessage)
|
||||
.where(ChatMessage.channel_id == channel.channel_id)
|
||||
.order_by(col(ChatMessage.timestamp).desc())
|
||||
.limit(10)
|
||||
)
|
||||
).all()
|
||||
c.recent_messages = [await ChatMessageResp.from_db(msg, session, user) for msg in messages]
|
||||
c.recent_messages.reverse()
|
||||
|
||||
if c.type == ChannelType.PM and users and len(users) == 2:
|
||||
target_user_id = next(u for u in users if u != user.id)
|
||||
target_name = await session.exec(select(User.username).where(User.id == target_user_id))
|
||||
c.name = target_name.one()
|
||||
c.users = [target_user_id, user.id]
|
||||
return c
|
||||
|
||||
|
||||
# ChatMessage
|
||||
class ChatMessageDict(TypedDict):
|
||||
channel_id: int
|
||||
content: str
|
||||
message_id: int
|
||||
sender_id: int
|
||||
timestamp: datetime
|
||||
type: MessageType
|
||||
uuid: str | None
|
||||
is_action: NotRequired[bool]
|
||||
sender: NotRequired[UserDict]
|
||||
|
||||
|
||||
class MessageType(str, Enum):
|
||||
ACTION = "action"
|
||||
MARKDOWN = "markdown"
|
||||
PLAIN = "plain"
|
||||
|
||||
|
||||
class ChatMessageBase(UTCBaseModel, SQLModel):
|
||||
class ChatMessageModel(DatabaseModel[ChatMessageDict]):
|
||||
channel_id: int = Field(index=True, foreign_key="chat_channels.channel_id")
|
||||
content: str = Field(sa_column=Column(VARCHAR(1000)))
|
||||
message_id: int = Field(index=True, primary_key=True, default=None)
|
||||
@@ -169,31 +251,21 @@ class ChatMessageBase(UTCBaseModel, SQLModel):
|
||||
type: MessageType = Field(default=MessageType.PLAIN, index=True, exclude=True)
|
||||
uuid: str | None = Field(default=None)
|
||||
|
||||
@included
|
||||
@staticmethod
|
||||
async def is_action(_session: AsyncSession, db_message: "ChatMessage") -> bool:
|
||||
return db_message.type == MessageType.ACTION
|
||||
|
||||
class ChatMessage(ChatMessageBase, table=True):
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def sender(_session: AsyncSession, db_message: "ChatMessage") -> UserDict:
|
||||
return await UserModel.transform(db_message.user)
|
||||
|
||||
|
||||
class ChatMessage(ChatMessageModel, table=True):
|
||||
__tablename__: str = "chat_messages"
|
||||
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
|
||||
channel: ChatChannel = Relationship()
|
||||
|
||||
|
||||
class ChatMessageResp(ChatMessageBase):
|
||||
sender: UserResp | None = None
|
||||
is_action: bool = False
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls, db_message: ChatMessage, session: AsyncSession, user: User | None = None
|
||||
) -> "ChatMessageResp":
|
||||
m = cls.model_validate(db_message.model_dump())
|
||||
m.is_action = db_message.type == MessageType.ACTION
|
||||
if user:
|
||||
m.sender = await UserResp.from_db(user, session, RANKING_INCLUDES)
|
||||
else:
|
||||
m.sender = await UserResp.from_db(db_message.user, session, RANKING_INCLUDES)
|
||||
return m
|
||||
|
||||
|
||||
# SilenceUser
|
||||
channel: "ChatChannel" = Relationship()
|
||||
|
||||
|
||||
class SilenceUser(UTCBaseModel, SQLModel, table=True):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import datetime
|
||||
|
||||
from app.database.beatmapset import Beatmapset
|
||||
from app.database.user import User
|
||||
from .beatmapset import Beatmapset
|
||||
from .user import User
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import (
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from .playlist_best_score import PlaylistBestScore
|
||||
from .user import User, UserResp
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
|
||||
from ._base import DatabaseModel, ondemand
|
||||
from .playlist_best_score import PlaylistBestScore
|
||||
from .user import User, UserDict, UserModel
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import (
|
||||
BigInteger,
|
||||
@@ -9,7 +11,6 @@ from sqlmodel import (
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
col,
|
||||
func,
|
||||
select,
|
||||
@@ -17,17 +18,66 @@ from sqlmodel import (
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class ItemAttemptsCountBase(SQLModel):
|
||||
room_id: int = Field(foreign_key="rooms.id", index=True)
|
||||
class ItemAttemptsCountDict(TypedDict):
|
||||
accuracy: float
|
||||
attempts: int
|
||||
completed: int
|
||||
pp: float
|
||||
room_id: int
|
||||
total_score: int
|
||||
user_id: int
|
||||
user: NotRequired[UserDict]
|
||||
position: NotRequired[int]
|
||||
playlist_item_attempts: NotRequired[list[dict[str, Any]]]
|
||||
|
||||
|
||||
class ItemAttemptsCountModel(DatabaseModel[ItemAttemptsCountDict]):
|
||||
accuracy: float = 0.0
|
||||
attempts: int = Field(default=0)
|
||||
completed: int = Field(default=0)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
accuracy: float = 0.0
|
||||
pp: float = 0
|
||||
room_id: int = Field(foreign_key="rooms.id", index=True)
|
||||
total_score: int = 0
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def user(_session: AsyncSession, item_attempts: "ItemAttemptsCount") -> UserDict:
|
||||
user_instance = await item_attempts.awaitable_attrs.user
|
||||
return await UserModel.transform(user_instance)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def position(session: AsyncSession, item_attempts: "ItemAttemptsCount") -> int:
|
||||
return await item_attempts.get_position(session)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def playlist_item_attempts(
|
||||
session: AsyncSession,
|
||||
item_attempts: "ItemAttemptsCount",
|
||||
) -> list[dict[str, Any]]:
|
||||
playlist_scores = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore).where(
|
||||
PlaylistBestScore.room_id == item_attempts.room_id,
|
||||
PlaylistBestScore.user_id == item_attempts.user_id,
|
||||
)
|
||||
)
|
||||
).all()
|
||||
result: list[dict[str, Any]] = []
|
||||
for score in playlist_scores:
|
||||
result.append(
|
||||
{
|
||||
"id": score.playlist_id,
|
||||
"attempts": score.attempts,
|
||||
"passed": score.score.passed,
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True):
|
||||
class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountModel, table=True):
|
||||
__tablename__: str = "item_attempts_count"
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
|
||||
@@ -37,15 +87,15 @@ class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True):
|
||||
rownum = (
|
||||
func.row_number()
|
||||
.over(
|
||||
partition_by=col(ItemAttemptsCountBase.room_id),
|
||||
order_by=col(ItemAttemptsCountBase.total_score).desc(),
|
||||
partition_by=col(ItemAttemptsCount.room_id),
|
||||
order_by=col(ItemAttemptsCount.total_score).desc(),
|
||||
)
|
||||
.label("rn")
|
||||
)
|
||||
subq = select(ItemAttemptsCountBase, rownum).subquery()
|
||||
subq = select(ItemAttemptsCount, rownum).subquery()
|
||||
stmt = select(subq.c.rn).where(subq.c.user_id == self.user_id)
|
||||
result = await session.exec(stmt)
|
||||
return result.one()
|
||||
return result.first() or 0
|
||||
|
||||
async def update(self, session: AsyncSession):
|
||||
playlist_scores = (
|
||||
@@ -88,62 +138,3 @@ class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True):
|
||||
await session.refresh(item_attempts)
|
||||
await item_attempts.update(session)
|
||||
return item_attempts
|
||||
|
||||
|
||||
class ItemAttemptsResp(ItemAttemptsCountBase):
|
||||
user: UserResp | None = None
|
||||
position: int | None = None
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls,
|
||||
item_attempts: ItemAttemptsCount,
|
||||
session: AsyncSession,
|
||||
include: list[str] = [],
|
||||
) -> "ItemAttemptsResp":
|
||||
resp = cls.model_validate(item_attempts.model_dump())
|
||||
resp.user = await UserResp.from_db(
|
||||
await item_attempts.awaitable_attrs.user,
|
||||
session=session,
|
||||
include=["statistics", "team", "daily_challenge_user_stats"],
|
||||
)
|
||||
if "position" in include:
|
||||
resp.position = await item_attempts.get_position(session)
|
||||
# resp.accuracy *= 100
|
||||
return resp
|
||||
|
||||
|
||||
class ItemAttemptsCountForItem(BaseModel):
|
||||
id: int
|
||||
attempts: int
|
||||
passed: bool
|
||||
|
||||
|
||||
class PlaylistAggregateScore(BaseModel):
|
||||
playlist_item_attempts: list[ItemAttemptsCountForItem] = Field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls,
|
||||
room_id: int,
|
||||
user_id: int,
|
||||
session: AsyncSession,
|
||||
) -> "PlaylistAggregateScore":
|
||||
playlist_scores = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore).where(
|
||||
PlaylistBestScore.room_id == room_id,
|
||||
PlaylistBestScore.user_id == user_id,
|
||||
)
|
||||
)
|
||||
).all()
|
||||
playlist_item_attempts = []
|
||||
for score in playlist_scores:
|
||||
playlist_item_attempts.append(
|
||||
ItemAttemptsCountForItem(
|
||||
id=score.playlist_id,
|
||||
attempts=score.attempts,
|
||||
passed=score.score.passed,
|
||||
)
|
||||
)
|
||||
return cls(playlist_item_attempts=playlist_item_attempts)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, NotRequired, TypedDict
|
||||
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.models.mods import APIMod
|
||||
from app.models.playlist import PlaylistItem
|
||||
|
||||
from .beatmap import Beatmap, BeatmapResp
|
||||
from ._base import DatabaseModel, ondemand
|
||||
from .beatmap import Beatmap, BeatmapDict, BeatmapModel
|
||||
|
||||
from sqlmodel import (
|
||||
JSON,
|
||||
@@ -15,7 +15,6 @@ from sqlmodel import (
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
func,
|
||||
select,
|
||||
)
|
||||
@@ -23,18 +22,34 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .room import Room
|
||||
from .score import ScoreDict
|
||||
|
||||
|
||||
class PlaylistBase(SQLModel, UTCBaseModel):
|
||||
id: int = Field(index=True)
|
||||
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
|
||||
class PlaylistDict(TypedDict):
|
||||
id: int
|
||||
room_id: int
|
||||
beatmap_id: int
|
||||
created_at: datetime | None
|
||||
ruleset_id: int
|
||||
expired: bool = Field(default=False)
|
||||
playlist_order: int = Field(default=0)
|
||||
played_at: datetime | None = Field(
|
||||
sa_column=Column(DateTime(timezone=True)),
|
||||
default=None,
|
||||
allowed_mods: list[APIMod]
|
||||
required_mods: list[APIMod]
|
||||
freestyle: bool
|
||||
expired: bool
|
||||
owner_id: int
|
||||
playlist_order: int
|
||||
played_at: datetime | None
|
||||
beatmap: NotRequired["BeatmapDict"]
|
||||
scores: NotRequired[list[dict[str, Any]]]
|
||||
|
||||
|
||||
class PlaylistModel(DatabaseModel[PlaylistDict]):
|
||||
id: int = Field(index=True)
|
||||
room_id: int = Field(foreign_key="rooms.id")
|
||||
beatmap_id: int = Field(
|
||||
foreign_key="beatmaps.id",
|
||||
)
|
||||
created_at: datetime | None = Field(default=None, sa_column_kwargs={"server_default": func.now()})
|
||||
ruleset_id: int
|
||||
allowed_mods: list[APIMod] = Field(
|
||||
default_factory=list,
|
||||
sa_column=Column(JSON),
|
||||
@@ -43,16 +58,46 @@ class PlaylistBase(SQLModel, UTCBaseModel):
|
||||
default_factory=list,
|
||||
sa_column=Column(JSON),
|
||||
)
|
||||
beatmap_id: int = Field(
|
||||
foreign_key="beatmaps.id",
|
||||
)
|
||||
freestyle: bool = Field(default=False)
|
||||
expired: bool = Field(default=False)
|
||||
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
|
||||
playlist_order: int = Field(default=0)
|
||||
played_at: datetime | None = Field(
|
||||
sa_column=Column(DateTime(timezone=True)),
|
||||
default=None,
|
||||
)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def beatmap(_session: AsyncSession, playlist: "Playlist", includes: list[str] | None = None) -> BeatmapDict:
|
||||
return await BeatmapModel.transform(playlist.beatmap, includes=includes)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def scores(session: AsyncSession, playlist: "Playlist") -> list["ScoreDict"]:
|
||||
from .score import Score, ScoreModel
|
||||
|
||||
scores = (
|
||||
await session.exec(
|
||||
select(Score).where(
|
||||
Score.playlist_item_id == playlist.id,
|
||||
Score.room_id == playlist.room_id,
|
||||
)
|
||||
)
|
||||
).all()
|
||||
result: list[ScoreDict] = []
|
||||
for score in scores:
|
||||
result.append(
|
||||
await ScoreModel.transform(
|
||||
score,
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class Playlist(PlaylistBase, table=True):
|
||||
class Playlist(PlaylistModel, table=True):
|
||||
__tablename__: str = "room_playlists"
|
||||
db_id: int = Field(default=None, primary_key=True, index=True, exclude=True)
|
||||
room_id: int = Field(foreign_key="rooms.id", exclude=True)
|
||||
|
||||
beatmap: Beatmap = Relationship(
|
||||
sa_relationship_kwargs={
|
||||
@@ -60,7 +105,6 @@ class Playlist(PlaylistBase, table=True):
|
||||
}
|
||||
)
|
||||
room: "Room" = Relationship()
|
||||
created_at: datetime | None = Field(default=None, sa_column_kwargs={"server_default": func.now()})
|
||||
updated_at: datetime | None = Field(
|
||||
default=None, sa_column_kwargs={"server_default": func.now(), "onupdate": func.now()}
|
||||
)
|
||||
@@ -121,15 +165,3 @@ class Playlist(PlaylistBase, table=True):
|
||||
raise ValueError("Playlist item not found")
|
||||
await session.delete(db_playlist)
|
||||
await session.commit()
|
||||
|
||||
|
||||
class PlaylistResp(PlaylistBase):
|
||||
beatmap: BeatmapResp | None = None
|
||||
|
||||
@classmethod
|
||||
async def from_db(cls, playlist: Playlist, include: list[str] = []) -> "PlaylistResp":
|
||||
data = playlist.model_dump()
|
||||
if "beatmap" in include:
|
||||
data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap)
|
||||
resp = cls.model_validate(data)
|
||||
return resp
|
||||
|
||||
@@ -1,26 +1,39 @@
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, NotRequired, TypedDict
|
||||
|
||||
from .user import User, UserResp
|
||||
from app.models.score import GameMode
|
||||
|
||||
from ._base import DatabaseModel, included, ondemand
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import (
|
||||
BigInteger,
|
||||
Column,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship as SQLRelationship,
|
||||
SQLModel,
|
||||
select,
|
||||
)
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User, UserDict
|
||||
|
||||
|
||||
class RelationshipType(str, Enum):
|
||||
FOLLOW = "Friend"
|
||||
BLOCK = "Block"
|
||||
FOLLOW = "friend"
|
||||
BLOCK = "block"
|
||||
|
||||
|
||||
class Relationship(SQLModel, table=True):
|
||||
class RelationshipDict(TypedDict):
|
||||
target_id: int | None
|
||||
type: RelationshipType
|
||||
id: NotRequired[int | None]
|
||||
user_id: NotRequired[int | None]
|
||||
mutual: NotRequired[bool]
|
||||
target: NotRequired["UserDict"]
|
||||
|
||||
|
||||
class RelationshipModel(DatabaseModel[RelationshipDict]):
|
||||
__tablename__: str = "relationship"
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
@@ -34,6 +47,7 @@ class Relationship(SQLModel, table=True):
|
||||
ForeignKey("lazer_users.id"),
|
||||
index=True,
|
||||
),
|
||||
exclude=True,
|
||||
)
|
||||
target_id: int = Field(
|
||||
default=None,
|
||||
@@ -44,22 +58,10 @@ class Relationship(SQLModel, table=True):
|
||||
),
|
||||
)
|
||||
type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False)
|
||||
target: User = SQLRelationship(
|
||||
sa_relationship_kwargs={
|
||||
"foreign_keys": "[Relationship.target_id]",
|
||||
"lazy": "selectin",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class RelationshipResp(BaseModel):
|
||||
target_id: int
|
||||
target: UserResp
|
||||
mutual: bool = False
|
||||
type: RelationshipType
|
||||
|
||||
@classmethod
|
||||
async def from_db(cls, session: AsyncSession, relationship: Relationship) -> "RelationshipResp":
|
||||
@included
|
||||
@staticmethod
|
||||
async def mutual(session: AsyncSession, relationship: "Relationship") -> bool:
|
||||
target_relationship = (
|
||||
await session.exec(
|
||||
select(Relationship).where(
|
||||
@@ -68,23 +70,29 @@ class RelationshipResp(BaseModel):
|
||||
)
|
||||
)
|
||||
).first()
|
||||
mutual = bool(
|
||||
return bool(
|
||||
target_relationship is not None
|
||||
and relationship.type == RelationshipType.FOLLOW
|
||||
and target_relationship.type == RelationshipType.FOLLOW
|
||||
)
|
||||
return cls(
|
||||
target_id=relationship.target_id,
|
||||
target=await UserResp.from_db(
|
||||
relationship.target,
|
||||
session,
|
||||
include=[
|
||||
"team",
|
||||
"daily_challenge_user_stats",
|
||||
"statistics",
|
||||
"statistics_rulesets",
|
||||
],
|
||||
),
|
||||
mutual=mutual,
|
||||
type=relationship.type,
|
||||
)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def target(
|
||||
_session: AsyncSession,
|
||||
relationship: "Relationship",
|
||||
ruleset: GameMode | None = None,
|
||||
includes: list[str] | None = None,
|
||||
) -> "UserDict":
|
||||
from .user import UserModel
|
||||
|
||||
return await UserModel.transform(relationship.target, ruleset=ruleset, includes=includes)
|
||||
|
||||
|
||||
class Relationship(RelationshipModel, table=True):
|
||||
target: "User" = SQLRelationship(
|
||||
sa_relationship_kwargs={
|
||||
"foreign_keys": "[Relationship.target_id]",
|
||||
"lazy": "selectin",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from datetime import datetime
|
||||
from typing import ClassVar, NotRequired, TypedDict
|
||||
|
||||
from app.database.item_attempts_count import PlaylistAggregateScore
|
||||
from app.database.room_participated_user import RoomParticipatedUser
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.models.room import (
|
||||
MatchType,
|
||||
QueueMode,
|
||||
@@ -13,28 +11,58 @@ from app.models.room import (
|
||||
)
|
||||
from app.utils import utcnow
|
||||
|
||||
from .playlists import Playlist, PlaylistResp
|
||||
from .user import User, UserResp
|
||||
from ._base import DatabaseModel, included, ondemand
|
||||
from .item_attempts_count import ItemAttemptsCount, ItemAttemptsCountDict, ItemAttemptsCountModel
|
||||
from .playlists import Playlist, PlaylistDict, PlaylistModel
|
||||
from .room_participated_user import RoomParticipatedUser
|
||||
from .user import User, UserDict, UserModel
|
||||
|
||||
from pydantic import field_validator
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import (
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
col,
|
||||
func,
|
||||
select,
|
||||
)
|
||||
from sqlmodel import BigInteger, Column, DateTime, Field, ForeignKey, Relationship, SQLModel, col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class RoomBase(SQLModel, UTCBaseModel):
|
||||
class RoomDict(TypedDict):
|
||||
id: int
|
||||
name: str
|
||||
category: RoomCategory
|
||||
status: RoomStatus
|
||||
type: MatchType
|
||||
duration: int | None
|
||||
starts_at: datetime | None
|
||||
ends_at: datetime | None
|
||||
max_attempts: int | None
|
||||
participant_count: int
|
||||
channel_id: int
|
||||
queue_mode: QueueMode
|
||||
auto_skip: bool
|
||||
auto_start_duration: int
|
||||
has_password: NotRequired[bool]
|
||||
current_playlist_item: NotRequired["PlaylistDict | None"]
|
||||
playlist: NotRequired[list["PlaylistDict"]]
|
||||
playlist_item_stats: NotRequired[RoomPlaylistItemStats]
|
||||
difficulty_range: NotRequired[RoomDifficultyRange]
|
||||
host: NotRequired[UserDict]
|
||||
recent_participants: NotRequired[list[UserDict]]
|
||||
current_user_score: NotRequired["ItemAttemptsCountDict | None"]
|
||||
|
||||
|
||||
class RoomModel(DatabaseModel[RoomDict]):
|
||||
SHOW_RESPONSE_INCLUDES: ClassVar[list[str]] = [
|
||||
"current_user_score.playlist_item_attempts",
|
||||
"host.country",
|
||||
"playlist.beatmap.beatmapset",
|
||||
"playlist.beatmap.checksum",
|
||||
"playlist.beatmap.max_combo",
|
||||
"recent_participants",
|
||||
]
|
||||
|
||||
id: int = Field(default=None, primary_key=True, index=True)
|
||||
name: str = Field(index=True)
|
||||
category: RoomCategory = Field(default=RoomCategory.NORMAL, index=True)
|
||||
status: RoomStatus
|
||||
type: MatchType
|
||||
duration: int | None = Field(default=None) # minutes
|
||||
starts_at: datetime | None = Field(
|
||||
sa_column=Column(
|
||||
@@ -48,76 +76,88 @@ class RoomBase(SQLModel, UTCBaseModel):
|
||||
),
|
||||
default=None,
|
||||
)
|
||||
participant_count: int = Field(default=0)
|
||||
max_attempts: int | None = Field(default=None) # playlists
|
||||
type: MatchType
|
||||
participant_count: int = Field(default=0)
|
||||
channel_id: int = 0
|
||||
queue_mode: QueueMode
|
||||
auto_skip: bool
|
||||
|
||||
auto_start_duration: int
|
||||
status: RoomStatus
|
||||
channel_id: int | None = None
|
||||
|
||||
|
||||
class Room(AsyncAttrs, RoomBase, table=True):
|
||||
__tablename__: str = "rooms"
|
||||
id: int = Field(default=None, primary_key=True, index=True)
|
||||
host_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
password: str | None = Field(default=None)
|
||||
|
||||
host: User = Relationship()
|
||||
playlist: list[Playlist] = Relationship(
|
||||
sa_relationship_kwargs={
|
||||
"lazy": "selectin",
|
||||
"cascade": "all, delete-orphan",
|
||||
"overlaps": "room",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class RoomResp(RoomBase):
|
||||
id: int
|
||||
has_password: bool = False
|
||||
host: UserResp | None = None
|
||||
playlist: list[PlaylistResp] = []
|
||||
playlist_item_stats: RoomPlaylistItemStats | None = None
|
||||
difficulty_range: RoomDifficultyRange | None = None
|
||||
current_playlist_item: PlaylistResp | None = None
|
||||
current_user_score: PlaylistAggregateScore | None = None
|
||||
recent_participants: list[UserResp] = Field(default_factory=list)
|
||||
channel_id: int = 0
|
||||
|
||||
@field_validator("channel_id", mode="before")
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls,
|
||||
room: Room,
|
||||
session: AsyncSession,
|
||||
include: list[str] = [],
|
||||
user: User | None = None,
|
||||
) -> "RoomResp":
|
||||
d = room.model_dump()
|
||||
d["channel_id"] = d.get("channel_id", 0) or 0
|
||||
d["has_password"] = bool(room.password)
|
||||
resp = cls.model_validate(d)
|
||||
def validate_channel_id(cls, v):
|
||||
"""将 None 转换为 0"""
|
||||
if v is None:
|
||||
return 0
|
||||
return v
|
||||
|
||||
stats = RoomPlaylistItemStats(count_active=0, count_total=0)
|
||||
difficulty_range = RoomDifficultyRange(
|
||||
min=0,
|
||||
max=0,
|
||||
)
|
||||
rulesets = set()
|
||||
for playlist in room.playlist:
|
||||
@included
|
||||
@staticmethod
|
||||
async def has_password(_session: AsyncSession, room: "Room") -> bool:
|
||||
return bool(room.password)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def current_playlist_item(
|
||||
_session: AsyncSession, room: "Room", includes: list[str] | None = None
|
||||
) -> "PlaylistDict | None":
|
||||
playlists = await room.awaitable_attrs.playlist
|
||||
if not playlists:
|
||||
return None
|
||||
return await PlaylistModel.transform(playlists[-1], includes=includes)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def playlist(_session: AsyncSession, room: "Room", includes: list[str] | None = None) -> list["PlaylistDict"]:
|
||||
playlists = await room.awaitable_attrs.playlist
|
||||
result: list[PlaylistDict] = []
|
||||
for playlist_item in playlists:
|
||||
result.append(await PlaylistModel.transform(playlist_item, includes=includes))
|
||||
return result
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def playlist_item_stats(_session: AsyncSession, room: "Room") -> RoomPlaylistItemStats:
|
||||
playlists = await room.awaitable_attrs.playlist
|
||||
stats = RoomPlaylistItemStats(count_active=0, count_total=0, ruleset_ids=[])
|
||||
rulesets: set[int] = set()
|
||||
for playlist in playlists:
|
||||
stats.count_total += 1
|
||||
if not playlist.expired:
|
||||
stats.count_active += 1
|
||||
rulesets.add(playlist.ruleset_id)
|
||||
difficulty_range.min = min(difficulty_range.min, playlist.beatmap.difficulty_rating)
|
||||
difficulty_range.max = max(difficulty_range.max, playlist.beatmap.difficulty_rating)
|
||||
resp.playlist.append(await PlaylistResp.from_db(playlist, ["beatmap"]))
|
||||
stats.ruleset_ids = list(rulesets)
|
||||
resp.playlist_item_stats = stats
|
||||
resp.difficulty_range = difficulty_range
|
||||
resp.current_playlist_item = resp.playlist[-1] if resp.playlist else None
|
||||
resp.recent_participants = []
|
||||
return stats
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def difficulty_range(_session: AsyncSession, room: "Room") -> RoomDifficultyRange:
|
||||
playlists = await room.awaitable_attrs.playlist
|
||||
if not playlists:
|
||||
return RoomDifficultyRange(min=0.0, max=0.0)
|
||||
min_diff = float("inf")
|
||||
max_diff = float("-inf")
|
||||
for playlist in playlists:
|
||||
rating = playlist.beatmap.difficulty_rating
|
||||
min_diff = min(min_diff, rating)
|
||||
max_diff = max(max_diff, rating)
|
||||
if min_diff == float("inf"):
|
||||
min_diff = 0.0
|
||||
if max_diff == float("-inf"):
|
||||
max_diff = 0.0
|
||||
return RoomDifficultyRange(min=min_diff, max=max_diff)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def host(_session: AsyncSession, room: "Room", includes: list[str] | None = None) -> UserDict:
|
||||
host_user = await room.awaitable_attrs.host
|
||||
return await UserModel.transform(host_user, includes=includes)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def recent_participants(session: AsyncSession, room: "Room") -> list[UserDict]:
|
||||
participants: list[UserDict] = []
|
||||
if room.category == RoomCategory.REALTIME:
|
||||
query = (
|
||||
select(RoomParticipatedUser)
|
||||
@@ -137,39 +177,67 @@ class RoomResp(RoomBase):
|
||||
.limit(8)
|
||||
.order_by(col(RoomParticipatedUser.joined_at).desc())
|
||||
)
|
||||
resp.participant_count = (
|
||||
await session.exec(
|
||||
select(func.count())
|
||||
.select_from(RoomParticipatedUser)
|
||||
.where(
|
||||
RoomParticipatedUser.room_id == room.id,
|
||||
)
|
||||
)
|
||||
).first() or 0
|
||||
for recent_participant in await session.exec(query):
|
||||
resp.recent_participants.append(
|
||||
await UserResp.from_db(
|
||||
await recent_participant.awaitable_attrs.user,
|
||||
session,
|
||||
include=["statistics"],
|
||||
user_instance = await recent_participant.awaitable_attrs.user
|
||||
participants.append(await UserModel.transform(user_instance))
|
||||
return participants
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def current_user_score(
|
||||
session: AsyncSession, room: "Room", includes: list[str] | None = None
|
||||
) -> "ItemAttemptsCountDict | None":
|
||||
item_attempt = (
|
||||
await session.exec(
|
||||
select(ItemAttemptsCount).where(
|
||||
ItemAttemptsCount.room_id == room.id,
|
||||
)
|
||||
)
|
||||
resp.host = await UserResp.from_db(await room.awaitable_attrs.host, session, include=["statistics"])
|
||||
if "current_user_score" in include and user:
|
||||
resp.current_user_score = await PlaylistAggregateScore.from_db(room.id, user.id, session)
|
||||
return resp
|
||||
).first()
|
||||
if item_attempt is None:
|
||||
return None
|
||||
|
||||
return await ItemAttemptsCountModel.transform(item_attempt, includes=includes)
|
||||
|
||||
|
||||
class APIUploadedRoom(RoomBase):
|
||||
def to_room(self) -> Room:
|
||||
"""
|
||||
将 APIUploadedRoom 转换为 Room 对象,playlist 字段需单独处理。
|
||||
"""
|
||||
room_dict = self.model_dump()
|
||||
room_dict.pop("playlist", None)
|
||||
# host_id 已在字段中
|
||||
return Room(**room_dict)
|
||||
class Room(AsyncAttrs, RoomModel, table=True):
|
||||
__tablename__: str = "rooms"
|
||||
|
||||
id: int | None
|
||||
host_id: int | None = None
|
||||
host_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
password: str | None = Field(default=None)
|
||||
|
||||
host: User = Relationship()
|
||||
playlist: list[Playlist] = Relationship(
|
||||
sa_relationship_kwargs={
|
||||
"lazy": "selectin",
|
||||
"cascade": "all, delete-orphan",
|
||||
"overlaps": "room",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class APIUploadedRoom(SQLModel):
|
||||
name: str = Field(index=True)
|
||||
category: RoomCategory = Field(default=RoomCategory.NORMAL, index=True)
|
||||
status: RoomStatus
|
||||
type: MatchType
|
||||
duration: int | None = Field(default=None) # minutes
|
||||
starts_at: datetime | None = Field(
|
||||
sa_column=Column(
|
||||
DateTime(timezone=True),
|
||||
),
|
||||
default_factory=utcnow,
|
||||
)
|
||||
ends_at: datetime | None = Field(
|
||||
sa_column=Column(
|
||||
DateTime(timezone=True),
|
||||
),
|
||||
default=None,
|
||||
)
|
||||
max_attempts: int | None = Field(default=None) # playlists
|
||||
participant_count: int = Field(default=0)
|
||||
channel_id: int = 0
|
||||
queue_mode: QueueMode
|
||||
auto_skip: bool
|
||||
auto_start_duration: int
|
||||
playlist: list[Playlist] = Field(default_factory=list)
|
||||
|
||||
@@ -3,7 +3,7 @@ from datetime import date, datetime
|
||||
import json
|
||||
import math
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, NotRequired, TypedDict
|
||||
|
||||
from app.calculator import (
|
||||
calculate_pp_weight,
|
||||
@@ -15,8 +15,6 @@ from app.calculator import (
|
||||
pre_fetch_and_calculate_pp,
|
||||
)
|
||||
from app.config import settings
|
||||
from app.database.beatmap_playcounts import BeatmapPlaycounts
|
||||
from app.database.team import TeamMember
|
||||
from app.dependencies.database import get_redis
|
||||
from app.log import log
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
@@ -39,8 +37,10 @@ from app.models.scoring_mode import ScoringMode
|
||||
from app.storage import StorageService
|
||||
from app.utils import utcnow
|
||||
|
||||
from .beatmap import Beatmap, BeatmapResp
|
||||
from .beatmapset import BeatmapsetResp
|
||||
from ._base import DatabaseModel, OnDemand, included, ondemand
|
||||
from .beatmap import Beatmap, BeatmapDict, BeatmapModel
|
||||
from .beatmap_playcounts import BeatmapPlaycounts
|
||||
from .beatmapset import BeatmapsetDict, BeatmapsetModel
|
||||
from .best_scores import BestScore
|
||||
from .counts import MonthlyPlaycounts
|
||||
from .events import Event, EventType
|
||||
@@ -50,8 +50,9 @@ from .relationship import (
|
||||
RelationshipType,
|
||||
)
|
||||
from .score_token import ScoreToken
|
||||
from .team import TeamMember
|
||||
from .total_score_best_scores import TotalScoreBestScore
|
||||
from .user import User, UserResp
|
||||
from .user import User, UserDict, UserModel
|
||||
|
||||
from pydantic import BaseModel, field_serializer, field_validator
|
||||
from redis.asyncio import Redis
|
||||
@@ -80,30 +81,290 @@ if TYPE_CHECKING:
|
||||
logger = log("Score")
|
||||
|
||||
|
||||
class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
|
||||
# 基本字段
|
||||
class ScoreDict(TypedDict):
|
||||
beatmap_id: int
|
||||
id: int
|
||||
rank: Rank
|
||||
type: str
|
||||
user_id: int
|
||||
accuracy: float
|
||||
build_id: int | None
|
||||
ended_at: datetime
|
||||
has_replay: bool
|
||||
max_combo: int
|
||||
passed: bool
|
||||
pp: float
|
||||
started_at: datetime
|
||||
total_score: int
|
||||
maximum_statistics: ScoreStatistics
|
||||
mods: list[APIMod]
|
||||
classic_total_score: int | None
|
||||
preserve: bool
|
||||
processed: bool
|
||||
ranked: bool
|
||||
playlist_item_id: NotRequired[int | None]
|
||||
room_id: NotRequired[int | None]
|
||||
best_id: NotRequired[int | None]
|
||||
legacy_perfect: NotRequired[bool]
|
||||
is_perfect_combo: NotRequired[bool]
|
||||
ruleset_id: NotRequired[int]
|
||||
statistics: NotRequired[ScoreStatistics]
|
||||
beatmapset: NotRequired[BeatmapsetDict]
|
||||
beatmap: NotRequired[BeatmapDict]
|
||||
current_user_attributes: NotRequired[CurrentUserAttributes]
|
||||
position: NotRequired[int | None]
|
||||
scores_around: NotRequired["ScoreAround | None"]
|
||||
rank_country: NotRequired[int | None]
|
||||
rank_global: NotRequired[int | None]
|
||||
user: NotRequired[UserDict]
|
||||
weight: NotRequired[float | None]
|
||||
|
||||
# ScoreResp 字段
|
||||
legacy_total_score: NotRequired[int]
|
||||
|
||||
|
||||
class ScoreModel(AsyncAttrs, DatabaseModel[ScoreDict]):
|
||||
# https://github.com/ppy/osu-web/blob/master/app/Transformers/ScoreTransformer.php#L72-L84
|
||||
MULTIPLAYER_SCORE_INCLUDE: ClassVar[list[str]] = ["playlist_item_id", "room_id", "solo_score_id"]
|
||||
MULTIPLAYER_BASE_INCLUDES: ClassVar[list[str]] = [
|
||||
"user.country",
|
||||
"user.cover",
|
||||
"user.team",
|
||||
*MULTIPLAYER_SCORE_INCLUDE,
|
||||
]
|
||||
USER_PROFILE_INCLUDES: ClassVar[list[str]] = ["beatmap", "beatmapset", "user"]
|
||||
|
||||
# 基本字段
|
||||
beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
|
||||
id: int = Field(default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True))
|
||||
rank: Rank
|
||||
type: str
|
||||
user_id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
BigInteger,
|
||||
ForeignKey("lazer_users.id"),
|
||||
index=True,
|
||||
),
|
||||
)
|
||||
accuracy: float
|
||||
map_md5: str = Field(max_length=32, index=True)
|
||||
build_id: int | None = Field(default=None)
|
||||
classic_total_score: int | None = Field(default=0, sa_column=Column(BigInteger)) # solo_score
|
||||
ended_at: datetime = Field(sa_column=Column(DateTime))
|
||||
has_replay: bool = Field(sa_column=Column(Boolean))
|
||||
max_combo: int
|
||||
mods: list[APIMod] = Field(sa_column=Column(JSON))
|
||||
passed: bool = Field(sa_column=Column(Boolean))
|
||||
playlist_item_id: int | None = Field(default=None) # multiplayer
|
||||
pp: float = Field(default=0.0)
|
||||
preserve: bool = Field(default=True, sa_column=Column(Boolean))
|
||||
rank: Rank
|
||||
room_id: int | None = Field(default=None) # multiplayer
|
||||
started_at: datetime = Field(sa_column=Column(DateTime))
|
||||
total_score: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
total_score_without_mods: int = Field(default=0, sa_column=Column(BigInteger), exclude=True)
|
||||
type: str
|
||||
beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
|
||||
maximum_statistics: ScoreStatistics = Field(sa_column=Column(JSON), default_factory=dict)
|
||||
processed: bool = False # solo_score
|
||||
ranked: bool = False
|
||||
mods: list[APIMod] = Field(sa_column=Column(JSON))
|
||||
total_score_without_mods: int = Field(default=0, sa_column=Column(BigInteger), exclude=True)
|
||||
|
||||
# solo
|
||||
classic_total_score: int | None = Field(default=0, sa_column=Column(BigInteger))
|
||||
preserve: bool = Field(default=True, sa_column=Column(Boolean))
|
||||
processed: bool = Field(default=False)
|
||||
ranked: bool = Field(default=False)
|
||||
|
||||
# multiplayer
|
||||
playlist_item_id: OnDemand[int | None] = Field(default=None)
|
||||
room_id: OnDemand[int | None] = Field(default=None)
|
||||
|
||||
@included
|
||||
@staticmethod
|
||||
async def best_id(
|
||||
session: AsyncSession,
|
||||
score: "Score",
|
||||
) -> int | None:
|
||||
return await get_best_id(session, score.id)
|
||||
|
||||
@included
|
||||
@staticmethod
|
||||
async def legacy_perfect(
|
||||
_session: AsyncSession,
|
||||
score: "Score",
|
||||
) -> bool:
|
||||
await score.awaitable_attrs.beatmap
|
||||
return score.max_combo == score.beatmap.max_combo
|
||||
|
||||
@included
|
||||
@staticmethod
|
||||
async def is_perfect_combo(
|
||||
_session: AsyncSession,
|
||||
score: "Score",
|
||||
) -> bool:
|
||||
await score.awaitable_attrs.beatmap
|
||||
return score.max_combo == score.beatmap.max_combo
|
||||
|
||||
@included
|
||||
@staticmethod
|
||||
async def ruleset_id(
|
||||
_session: AsyncSession,
|
||||
score: "Score",
|
||||
) -> int:
|
||||
return int(score.gamemode)
|
||||
|
||||
@included
|
||||
@staticmethod
|
||||
async def statistics(
|
||||
_session: AsyncSession,
|
||||
score: "Score",
|
||||
) -> ScoreStatistics:
|
||||
stats = {
|
||||
HitResult.MISS: score.nmiss,
|
||||
HitResult.MEH: score.n50,
|
||||
HitResult.OK: score.n100,
|
||||
HitResult.GREAT: score.n300,
|
||||
HitResult.PERFECT: score.ngeki,
|
||||
HitResult.GOOD: score.nkatu,
|
||||
}
|
||||
if score.nlarge_tick_miss is not None:
|
||||
stats[HitResult.LARGE_TICK_MISS] = score.nlarge_tick_miss
|
||||
if score.nslider_tail_hit is not None:
|
||||
stats[HitResult.SLIDER_TAIL_HIT] = score.nslider_tail_hit
|
||||
if score.nsmall_tick_hit is not None:
|
||||
stats[HitResult.SMALL_TICK_HIT] = score.nsmall_tick_hit
|
||||
if score.nlarge_tick_hit is not None:
|
||||
stats[HitResult.LARGE_TICK_HIT] = score.nlarge_tick_hit
|
||||
return stats
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def beatmapset(
|
||||
_session: AsyncSession,
|
||||
score: "Score",
|
||||
includes: list[str] | None = None,
|
||||
) -> BeatmapsetDict:
|
||||
await score.awaitable_attrs.beatmap
|
||||
return await BeatmapsetModel.transform(score.beatmap.beatmapset, includes=includes)
|
||||
|
||||
# reorder beatmapset and beatmap
|
||||
# https://github.com/ppy/osu/blob/d8900defd34690de92be3406003fb3839fc0df1d/osu.Game/Online/API/Requests/Responses/SoloScoreInfo.cs#L111-L112
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def beatmap(
|
||||
_session: AsyncSession,
|
||||
score: "Score",
|
||||
includes: list[str] | None = None,
|
||||
) -> BeatmapDict:
|
||||
await score.awaitable_attrs.beatmap
|
||||
return await BeatmapModel.transform(score.beatmap, includes=includes)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def current_user_attributes(
|
||||
_session: AsyncSession,
|
||||
score: "Score",
|
||||
) -> CurrentUserAttributes:
|
||||
return CurrentUserAttributes(pin=PinAttributes(is_pinned=bool(score.pinned_order), score_id=score.id))
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def position(
|
||||
session: AsyncSession,
|
||||
score: "Score",
|
||||
) -> int | None:
|
||||
return await get_score_position_by_id(
|
||||
session,
|
||||
score.beatmap_id,
|
||||
score.id,
|
||||
mode=score.gamemode,
|
||||
user=score.user,
|
||||
)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def scores_around(
|
||||
session: AsyncSession, _score: "Score", playlist_id: int, room_id: int, is_playlist: bool
|
||||
) -> "ScoreAround | None":
|
||||
scores = (
|
||||
await session.exec(
|
||||
select(PlaylistBestScore).where(
|
||||
PlaylistBestScore.playlist_id == playlist_id,
|
||||
PlaylistBestScore.room_id == room_id,
|
||||
~User.is_restricted_query(col(PlaylistBestScore.user_id)),
|
||||
col(PlaylistBestScore.score).has(col(Score.passed).is_(True)) if not is_playlist else True,
|
||||
)
|
||||
)
|
||||
).all()
|
||||
|
||||
higher_scores = []
|
||||
lower_scores = []
|
||||
for score in scores:
|
||||
total_score = score.score.total_score
|
||||
resp = await ScoreModel.transform(score.score, includes=ScoreModel.MULTIPLAYER_BASE_INCLUDES)
|
||||
if score.total_score > total_score:
|
||||
higher_scores.append(resp)
|
||||
elif score.total_score < total_score:
|
||||
lower_scores.append(resp)
|
||||
|
||||
return ScoreAround(
|
||||
higher=MultiplayerScores(scores=higher_scores),
|
||||
lower=MultiplayerScores(scores=lower_scores),
|
||||
)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def rank_country(
|
||||
session: AsyncSession,
|
||||
score: "Score",
|
||||
) -> int | None:
|
||||
return (
|
||||
await get_score_position_by_id(
|
||||
session,
|
||||
score.beatmap_id,
|
||||
score.id,
|
||||
score.gamemode,
|
||||
score.user,
|
||||
type=LeaderboardType.COUNTRY,
|
||||
)
|
||||
or None
|
||||
)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def rank_global(
|
||||
session: AsyncSession,
|
||||
score: "Score",
|
||||
) -> int | None:
|
||||
return (
|
||||
await get_score_position_by_id(
|
||||
session,
|
||||
score.beatmap_id,
|
||||
score.id,
|
||||
mode=score.gamemode,
|
||||
user=score.user,
|
||||
)
|
||||
or None
|
||||
)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def user(
|
||||
_session: AsyncSession,
|
||||
score: "Score",
|
||||
includes: list[str] | None = None,
|
||||
) -> UserDict:
|
||||
return await UserModel.transform(score.user, ruleset=score.gamemode, includes=includes or [])
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def weight(
|
||||
session: AsyncSession,
|
||||
score: "Score",
|
||||
) -> float | None:
|
||||
best_id = await get_best_id(session, score.id)
|
||||
if best_id:
|
||||
return calculate_pp_weight(best_id - 1)
|
||||
return None
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def legacy_total_score(
|
||||
_session: AsyncSession,
|
||||
_score: "Score",
|
||||
) -> int:
|
||||
return 0
|
||||
|
||||
@field_validator("maximum_statistics", mode="before")
|
||||
@classmethod
|
||||
@@ -151,17 +412,9 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
|
||||
# TODO: current_user_attributes
|
||||
|
||||
|
||||
class Score(ScoreBase, table=True):
|
||||
class Score(ScoreModel, table=True):
|
||||
__tablename__: str = "scores"
|
||||
id: int = Field(default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True))
|
||||
user_id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
BigInteger,
|
||||
ForeignKey("lazer_users.id"),
|
||||
index=True,
|
||||
),
|
||||
)
|
||||
|
||||
# ScoreStatistics
|
||||
n300: int = Field(exclude=True)
|
||||
n100: int = Field(exclude=True)
|
||||
@@ -175,6 +428,7 @@ class Score(ScoreBase, table=True):
|
||||
nsmall_tick_hit: int | None = Field(default=None, exclude=True)
|
||||
gamemode: GameMode = Field(index=True)
|
||||
pinned_order: int = Field(default=0, exclude=True)
|
||||
map_md5: str = Field(max_length=32, index=True, exclude=True)
|
||||
|
||||
@field_validator("gamemode", mode="before")
|
||||
@classmethod
|
||||
@@ -245,9 +499,11 @@ class Score(ScoreBase, table=True):
|
||||
maximum_statistics=self.maximum_statistics,
|
||||
)
|
||||
|
||||
async def to_resp(self, session: AsyncSession, api_version: int) -> "ScoreResp | LegacyScoreResp":
|
||||
async def to_resp(
|
||||
self, session: AsyncSession, api_version: int, includes: list[str] = []
|
||||
) -> "ScoreDict | LegacyScoreResp":
|
||||
if api_version >= 20220705:
|
||||
return await ScoreResp.from_db(session, self)
|
||||
return await ScoreModel.transform(self, includes=includes)
|
||||
return await LegacyScoreResp.from_db(session, self)
|
||||
|
||||
async def delete(
|
||||
@@ -270,141 +526,7 @@ class Score(ScoreBase, table=True):
|
||||
await session.delete(self)
|
||||
|
||||
|
||||
class ScoreResp(ScoreBase):
|
||||
id: int
|
||||
user_id: int
|
||||
is_perfect_combo: bool = False
|
||||
legacy_perfect: bool = False
|
||||
legacy_total_score: int = 0 # FIXME
|
||||
weight: float = 0.0
|
||||
best_id: int | None = None
|
||||
ruleset_id: int | None = None
|
||||
beatmap: BeatmapResp | None = None
|
||||
beatmapset: BeatmapsetResp | None = None
|
||||
user: UserResp | None = None
|
||||
statistics: ScoreStatistics | None = None
|
||||
maximum_statistics: ScoreStatistics | None = None
|
||||
rank_global: int | None = None
|
||||
rank_country: int | None = None
|
||||
position: int | None = None
|
||||
scores_around: "ScoreAround | None" = None
|
||||
current_user_attributes: CurrentUserAttributes | None = None
|
||||
|
||||
@field_validator(
|
||||
"has_replay",
|
||||
"passed",
|
||||
"preserve",
|
||||
"is_perfect_combo",
|
||||
"legacy_perfect",
|
||||
"processed",
|
||||
"ranked",
|
||||
mode="before",
|
||||
)
|
||||
@classmethod
|
||||
def validate_bool_fields(cls, v):
|
||||
"""将整数 0/1 转换为布尔值,处理数据库中的布尔字段"""
|
||||
if isinstance(v, int):
|
||||
return bool(v)
|
||||
return v
|
||||
|
||||
@field_validator("statistics", "maximum_statistics", mode="before")
|
||||
@classmethod
|
||||
def validate_statistics_fields(cls, v):
|
||||
"""处理统计字段中的字符串键,转换为 HitResult 枚举"""
|
||||
if isinstance(v, dict):
|
||||
converted = {}
|
||||
for key, value in v.items():
|
||||
if isinstance(key, str):
|
||||
try:
|
||||
# 尝试将字符串转换为 HitResult 枚举
|
||||
enum_key = HitResult(key)
|
||||
converted[enum_key] = value
|
||||
except ValueError:
|
||||
# 如果转换失败,跳过这个键值对
|
||||
continue
|
||||
else:
|
||||
converted[key] = value
|
||||
return converted
|
||||
return v
|
||||
|
||||
@field_serializer("statistics", when_used="json")
|
||||
def serialize_statistics_fields(self, v):
|
||||
"""序列化统计字段,确保枚举值正确转换为字符串"""
|
||||
if isinstance(v, dict):
|
||||
serialized = {}
|
||||
for key, value in v.items():
|
||||
if hasattr(key, "value"):
|
||||
# 如果是枚举,使用其值
|
||||
serialized[key.value] = value
|
||||
else:
|
||||
# 否则直接使用键
|
||||
serialized[str(key)] = value
|
||||
return serialized
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp":
|
||||
# 确保 score 对象完全加载,避免懒加载问题
|
||||
await session.refresh(score)
|
||||
|
||||
s = cls.model_validate(score.model_dump())
|
||||
await score.awaitable_attrs.beatmap
|
||||
s.beatmap = await BeatmapResp.from_db(score.beatmap)
|
||||
s.beatmapset = await BeatmapsetResp.from_db(score.beatmap.beatmapset, session=session, user=score.user)
|
||||
s.is_perfect_combo = s.max_combo == s.beatmap.max_combo
|
||||
s.legacy_perfect = s.max_combo == s.beatmap.max_combo
|
||||
s.ruleset_id = int(score.gamemode)
|
||||
best_id = await get_best_id(session, score.id)
|
||||
if best_id:
|
||||
s.best_id = best_id
|
||||
s.weight = calculate_pp_weight(best_id - 1)
|
||||
s.statistics = {
|
||||
HitResult.MISS: score.nmiss,
|
||||
HitResult.MEH: score.n50,
|
||||
HitResult.OK: score.n100,
|
||||
HitResult.GREAT: score.n300,
|
||||
HitResult.PERFECT: score.ngeki,
|
||||
HitResult.GOOD: score.nkatu,
|
||||
}
|
||||
if score.nlarge_tick_miss is not None:
|
||||
s.statistics[HitResult.LARGE_TICK_MISS] = score.nlarge_tick_miss
|
||||
if score.nslider_tail_hit is not None:
|
||||
s.statistics[HitResult.SLIDER_TAIL_HIT] = score.nslider_tail_hit
|
||||
if score.nsmall_tick_hit is not None:
|
||||
s.statistics[HitResult.SMALL_TICK_HIT] = score.nsmall_tick_hit
|
||||
if score.nlarge_tick_hit is not None:
|
||||
s.statistics[HitResult.LARGE_TICK_HIT] = score.nlarge_tick_hit
|
||||
s.user = await UserResp.from_db(
|
||||
score.user,
|
||||
session,
|
||||
include=["statistics", "team", "daily_challenge_user_stats"],
|
||||
ruleset=score.gamemode,
|
||||
)
|
||||
s.rank_global = (
|
||||
await get_score_position_by_id(
|
||||
session,
|
||||
score.beatmap_id,
|
||||
score.id,
|
||||
mode=score.gamemode,
|
||||
user=score.user,
|
||||
)
|
||||
or None
|
||||
)
|
||||
s.rank_country = (
|
||||
await get_score_position_by_id(
|
||||
session,
|
||||
score.beatmap_id,
|
||||
score.id,
|
||||
score.gamemode,
|
||||
score.user,
|
||||
type=LeaderboardType.COUNTRY,
|
||||
)
|
||||
or None
|
||||
)
|
||||
s.current_user_attributes = CurrentUserAttributes(
|
||||
pin=PinAttributes(is_pinned=bool(score.pinned_order), score_id=score.id)
|
||||
)
|
||||
return s
|
||||
MultiplayScoreDict = ScoreModel.generate_typeddict(tuple(Score.MULTIPLAYER_BASE_INCLUDES)) # pyright: ignore[reportGeneralTypeIssues]
|
||||
|
||||
|
||||
class LegacyStatistics(BaseModel):
|
||||
@@ -417,31 +539,25 @@ class LegacyStatistics(BaseModel):
|
||||
|
||||
|
||||
class LegacyScoreResp(UTCBaseModel):
|
||||
accuracy: float
|
||||
best_id: int
|
||||
created_at: datetime
|
||||
id: int
|
||||
max_combo: int
|
||||
mode: GameMode
|
||||
mode_int: int
|
||||
best_id: int
|
||||
user_id: int
|
||||
accuracy: float
|
||||
mods: list[str] # acronym
|
||||
passed: bool
|
||||
score: int
|
||||
max_combo: int
|
||||
perfect: bool = False
|
||||
statistics: LegacyStatistics
|
||||
passed: bool
|
||||
pp: float
|
||||
rank: Rank
|
||||
created_at: datetime
|
||||
mode: GameMode
|
||||
mode_int: int
|
||||
replay: bool
|
||||
score: int
|
||||
statistics: LegacyStatistics
|
||||
type: str
|
||||
user_id: int
|
||||
current_user_attributes: CurrentUserAttributes
|
||||
user: UserResp
|
||||
beatmap: BeatmapResp
|
||||
rank_global: int | None = Field(default=None, exclude=True)
|
||||
|
||||
@classmethod
|
||||
async def from_db(cls, session: AsyncSession, score: Score) -> "LegacyScoreResp":
|
||||
await session.refresh(score)
|
||||
async def from_db(cls, session: AsyncSession, score: "Score") -> "LegacyScoreResp":
|
||||
await score.awaitable_attrs.beatmap
|
||||
return cls(
|
||||
accuracy=score.accuracy,
|
||||
@@ -465,34 +581,13 @@ class LegacyScoreResp(UTCBaseModel):
|
||||
count_geki=score.ngeki or 0,
|
||||
count_katu=score.nkatu or 0,
|
||||
),
|
||||
type=score.type,
|
||||
user_id=score.user_id,
|
||||
current_user_attributes=CurrentUserAttributes(
|
||||
pin=PinAttributes(is_pinned=bool(score.pinned_order), score_id=score.id)
|
||||
),
|
||||
user=await UserResp.from_db(
|
||||
score.user,
|
||||
session,
|
||||
include=["statistics", "team", "daily_challenge_user_stats"],
|
||||
ruleset=score.gamemode,
|
||||
),
|
||||
beatmap=await BeatmapResp.from_db(score.beatmap),
|
||||
perfect=score.is_perfect_combo,
|
||||
rank_global=(
|
||||
await get_score_position_by_id(
|
||||
session,
|
||||
score.beatmap_id,
|
||||
score.id,
|
||||
mode=score.gamemode,
|
||||
user=score.user,
|
||||
)
|
||||
or None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class MultiplayerScores(RespWithCursor):
|
||||
scores: list[ScoreResp] = Field(default_factory=list)
|
||||
scores: list[MultiplayScoreDict] = Field(default_factory=list) # pyright: ignore[reportInvalidTypeForm]
|
||||
params: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@@ -842,13 +937,13 @@ async def get_user_best_pp(
|
||||
|
||||
|
||||
# https://github.com/ppy/osu-queue-score-statistics/blob/master/osu.Server.Queues.ScoreStatisticsProcessor/Helpers/PlayValidityHelper.cs
|
||||
def get_play_length(score: Score, beatmap_length: int):
|
||||
def get_play_length(score: "Score", beatmap_length: int):
|
||||
speed_rate = get_speed_rate(score.mods)
|
||||
length = beatmap_length / speed_rate
|
||||
return int(min(length, (score.ended_at - score.started_at).total_seconds()))
|
||||
|
||||
|
||||
def calculate_playtime(score: Score, beatmap_length: int) -> tuple[int, bool]:
|
||||
def calculate_playtime(score: "Score", beatmap_length: int) -> tuple[int, bool]:
|
||||
total_length = get_play_length(score, beatmap_length)
|
||||
total_obj_hited = (
|
||||
score.n300
|
||||
@@ -937,7 +1032,7 @@ async def process_score(
|
||||
return score
|
||||
|
||||
|
||||
async def _process_score_pp(score: Score, session: AsyncSession, redis: Redis, fetcher: "Fetcher"):
|
||||
async def _process_score_pp(score: "Score", session: AsyncSession, redis: Redis, fetcher: "Fetcher"):
|
||||
if score.pp != 0:
|
||||
logger.debug(
|
||||
"Skipping PP calculation for score {score_id} | already set {pp:.2f}",
|
||||
@@ -984,7 +1079,7 @@ async def _process_score_pp(score: Score, session: AsyncSession, redis: Redis, f
|
||||
)
|
||||
|
||||
|
||||
async def _process_score_events(score: Score, session: AsyncSession):
|
||||
async def _process_score_events(score: "Score", session: AsyncSession):
|
||||
total_users = (await session.exec(select(func.count()).select_from(User))).one()
|
||||
rank_global = await get_score_position_by_id(
|
||||
session,
|
||||
@@ -1088,7 +1183,7 @@ async def _process_statistics(
|
||||
session: AsyncSession,
|
||||
redis: Redis,
|
||||
user: User,
|
||||
score: Score,
|
||||
score: "Score",
|
||||
score_token: int,
|
||||
beatmap_length: int,
|
||||
beatmap_status: BeatmapRankStatus,
|
||||
@@ -1318,7 +1413,7 @@ async def process_user(
|
||||
redis: Redis,
|
||||
fetcher: "Fetcher",
|
||||
user: User,
|
||||
score: Score,
|
||||
score: "Score",
|
||||
score_token: int,
|
||||
beatmap_length: int,
|
||||
beatmap_status: BeatmapRankStatus,
|
||||
|
||||
13
app/database/search_beatmapset.py
Normal file
13
app/database/search_beatmapset.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from . import beatmap # noqa: F401
|
||||
from .beatmapset import BeatmapsetModel
|
||||
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
SearchBeatmapset = BeatmapsetModel.generate_typeddict(("beatmaps.max_combo", "pack_tags"))
|
||||
|
||||
|
||||
class SearchBeatmapsetsResp(SQLModel):
|
||||
beatmapsets: list[SearchBeatmapset] # pyright: ignore[reportInvalidTypeForm]
|
||||
total: int
|
||||
cursor: dict[str, int | float | str] | None = None
|
||||
cursor_string: str | None = None
|
||||
@@ -1,10 +1,11 @@
|
||||
from datetime import timedelta
|
||||
import math
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, ClassVar, NotRequired, TypedDict
|
||||
|
||||
from app.models.score import GameMode
|
||||
from app.utils import utcnow
|
||||
|
||||
from ._base import DatabaseModel, included, ondemand
|
||||
from .rank_history import RankHistory
|
||||
|
||||
from pydantic import field_validator
|
||||
@@ -15,7 +16,6 @@ from sqlmodel import (
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
col,
|
||||
func,
|
||||
select,
|
||||
@@ -23,10 +23,40 @@ from sqlmodel import (
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User, UserResp
|
||||
from .user import User, UserDict
|
||||
|
||||
|
||||
class UserStatisticsBase(SQLModel):
|
||||
class UserStatisticsDict(TypedDict):
|
||||
mode: GameMode
|
||||
count_100: int
|
||||
count_300: int
|
||||
count_50: int
|
||||
count_miss: int
|
||||
pp: float
|
||||
ranked_score: int
|
||||
hit_accuracy: float
|
||||
total_score: int
|
||||
total_hits: int
|
||||
maximum_combo: int
|
||||
play_count: int
|
||||
play_time: int
|
||||
replays_watched_by_others: int
|
||||
is_ranked: bool
|
||||
level: NotRequired[dict[str, int]]
|
||||
global_rank: NotRequired[int | None]
|
||||
grade_counts: NotRequired[dict[str, int]]
|
||||
rank_change_since_30_days: NotRequired[int]
|
||||
country_rank: NotRequired[int | None]
|
||||
user: NotRequired["UserDict"]
|
||||
|
||||
|
||||
class UserStatisticsModel(DatabaseModel[UserStatisticsDict]):
|
||||
RANKING_INCLUDES: ClassVar[list[str]] = [
|
||||
"user.country",
|
||||
"user.cover",
|
||||
"user.team",
|
||||
]
|
||||
|
||||
mode: GameMode = Field(index=True)
|
||||
count_100: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
count_300: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
@@ -57,8 +87,63 @@ class UserStatisticsBase(SQLModel):
|
||||
return GameMode.OSU
|
||||
return v
|
||||
|
||||
@included
|
||||
@staticmethod
|
||||
async def level(_session: AsyncSession, statistics: "UserStatistics") -> dict[str, int]:
|
||||
return {
|
||||
"current": int(statistics.level_current),
|
||||
"progress": int(math.fmod(statistics.level_current, 1) * 100),
|
||||
}
|
||||
|
||||
class UserStatistics(AsyncAttrs, UserStatisticsBase, table=True):
|
||||
@included
|
||||
@staticmethod
|
||||
async def global_rank(session: AsyncSession, statistics: "UserStatistics") -> int | None:
|
||||
return await get_rank(session, statistics)
|
||||
|
||||
@included
|
||||
@staticmethod
|
||||
async def grade_counts(_session: AsyncSession, statistics: "UserStatistics") -> dict[str, int]:
|
||||
return {
|
||||
"ss": statistics.grade_ss,
|
||||
"ssh": statistics.grade_ssh,
|
||||
"s": statistics.grade_s,
|
||||
"sh": statistics.grade_sh,
|
||||
"a": statistics.grade_a,
|
||||
}
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def rank_change_since_30_days(session: AsyncSession, statistics: "UserStatistics") -> int:
|
||||
global_rank = await get_rank(session, statistics)
|
||||
rank_best = (
|
||||
await session.exec(
|
||||
select(func.max(RankHistory.rank)).where(
|
||||
RankHistory.date > utcnow() - timedelta(days=30),
|
||||
RankHistory.user_id == statistics.user_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if rank_best is None or global_rank is None:
|
||||
return 0
|
||||
return rank_best - global_rank
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def country_rank(
|
||||
session: AsyncSession, statistics: "UserStatistics", user_country: str | None = None
|
||||
) -> int | None:
|
||||
return await get_rank(session, statistics, user_country)
|
||||
|
||||
@ondemand
|
||||
@staticmethod
|
||||
async def user(_session: AsyncSession, statistics: "UserStatistics") -> "UserDict":
|
||||
from .user import UserModel
|
||||
|
||||
user_instance = await statistics.awaitable_attrs.user
|
||||
return await UserModel.transform(user_instance)
|
||||
|
||||
|
||||
class UserStatistics(AsyncAttrs, UserStatisticsModel, table=True):
|
||||
__tablename__: str = "lazer_user_statistics"
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(
|
||||
@@ -80,74 +165,6 @@ class UserStatistics(AsyncAttrs, UserStatisticsBase, table=True):
|
||||
user: "User" = Relationship(back_populates="statistics")
|
||||
|
||||
|
||||
class UserStatisticsResp(UserStatisticsBase):
|
||||
user: "UserResp | None" = None
|
||||
rank_change_since_30_days: int | None = 0
|
||||
global_rank: int | None = Field(default=None)
|
||||
country_rank: int | None = Field(default=None)
|
||||
grade_counts: dict[str, int] = Field(
|
||||
default_factory=lambda: {
|
||||
"ss": 0,
|
||||
"ssh": 0,
|
||||
"s": 0,
|
||||
"sh": 0,
|
||||
"a": 0,
|
||||
}
|
||||
)
|
||||
level: dict[str, int] = Field(
|
||||
default_factory=lambda: {
|
||||
"current": 1,
|
||||
"progress": 0,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls,
|
||||
obj: UserStatistics,
|
||||
session: AsyncSession,
|
||||
user_country: str | None = None,
|
||||
include: list[str] = [],
|
||||
) -> "UserStatisticsResp":
|
||||
s = cls.model_validate(obj.model_dump())
|
||||
s.grade_counts = {
|
||||
"ss": obj.grade_ss,
|
||||
"ssh": obj.grade_ssh,
|
||||
"s": obj.grade_s,
|
||||
"sh": obj.grade_sh,
|
||||
"a": obj.grade_a,
|
||||
}
|
||||
s.level = {
|
||||
"current": int(obj.level_current),
|
||||
"progress": int(math.fmod(obj.level_current, 1) * 100),
|
||||
}
|
||||
if "user" in include:
|
||||
from .user import RANKING_INCLUDES, UserResp
|
||||
|
||||
user = await UserResp.from_db(await obj.awaitable_attrs.user, session, include=RANKING_INCLUDES)
|
||||
s.user = user
|
||||
user_country = user.country_code
|
||||
|
||||
s.global_rank = await get_rank(session, obj)
|
||||
s.country_rank = await get_rank(session, obj, user_country)
|
||||
|
||||
if "rank_change_since_30_days" in include:
|
||||
rank_best = (
|
||||
await session.exec(
|
||||
select(func.max(RankHistory.rank)).where(
|
||||
RankHistory.date > utcnow() - timedelta(days=30),
|
||||
RankHistory.user_id == obj.user_id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if rank_best is None or s.global_rank is None:
|
||||
s.rank_change_since_30_days = 0
|
||||
else:
|
||||
s.rank_change_since_30_days = rank_best - s.global_rank
|
||||
|
||||
return s
|
||||
|
||||
|
||||
async def get_rank(session: AsyncSession, statistics: UserStatistics, country: str | None = None) -> int | None:
|
||||
from .user import User
|
||||
|
||||
@@ -164,7 +181,6 @@ async def get_rank(session: AsyncSession, statistics: UserStatistics, country: s
|
||||
query = query.join(User).where(User.country_code == country)
|
||||
|
||||
subq = query.subquery()
|
||||
|
||||
result = await session.exec(select(subq.c.rank).where(subq.c.user_id == statistics.user_id))
|
||||
|
||||
rank = result.first()
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.calculator import calculate_score_to_level
|
||||
from app.database.statistics import UserStatistics
|
||||
from app.models.score import GameMode, Rank
|
||||
|
||||
from .statistics import UserStatistics
|
||||
from .user import User
|
||||
|
||||
from sqlmodel import (
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user