refactor(database): use a new 'On-Demand' design (#86)

Technical Details: https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/
This commit is contained in:
MingxuanGame
2025-11-23 21:41:02 +08:00
committed by GitHub
parent 42f1d53d3e
commit 40da994ae8
46 changed files with 4396 additions and 2354 deletions

View File

@@ -9,14 +9,17 @@ git clone https://github.com/GooGuTeam/g0v0-server.git
此外,您还需要:
- clone 旁观服务器到 g0v0-server 的文件夹。
```bash
git clone https://github.com/GooGuTeam/osu-server-spectator.git spectator-server
```
- clone 表现分计算器到 g0v0-server 的文件夹。
```bash
git clone https://github.com/GooGuTeam/osu-performance-server.git performance-server
```
- 下载并放置自定义规则集 DLL 到 `rulesets/` 目录(如果需要)。
## 开发环境
@@ -94,27 +97,82 @@ uv sync
所有的数据库模型定义在 `app.database` 里,并且在 `__init__.py` 中导出。
如果这个模型的数据表结构和响应不完全相同,遵循 `Base` - `Table` - `Resp` 结构:
本项目使用一种“按需返回”的设计模式,遵循 `Dict` - `Model` - `Table` 结构。详细设计思路请参考[这篇文章](https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/)。
#### 基本结构
1. **Dict**: 定义模型转换后的字典结构,用于类型检查。必须在 Model 之前定义。
2. **Model**: 继承自 `DatabaseModel[Dict]`,定义字段和计算属性。
3. **Table**: 继承自 `Model`,定义数据库表结构。
```python
class ModelBase(SQLModel):
# 定义共有内容
from typing import TypedDict, NotRequired
from app.database._base import DatabaseModel, OnDemand, included, ondemand
from sqlmodel import Field
# 1. 定义 Dict
class UserDict(TypedDict):
id: int
username: str
email: NotRequired[str] # 可选字段
followers_count: int # 计算属性
# 2. 定义 Model
class UserModel(DatabaseModel[UserDict]):
id: int = Field(primary_key=True)
username: str
email: OnDemand[str] # 使用 OnDemand 标记可选字段
# 普通计算属性 (总是返回)
@included
@staticmethod
async def followers_count(session: AsyncSession, instance: "User") -> int:
return await session.scalar(select(func.count()).where(Follower.followed_id == instance.id))
# 可选计算属性 (仅在 includes 中指定时返回)
@ondemand
@staticmethod
async def some_optional_property(session: AsyncSession, instance: "User") -> str:
...
class Model(ModelBase, table=True):
# 定义数据库表内容
# 3. 定义 Table
class User(UserModel, table=True):
password: str # 仅在数据库中存在的字段
...
```
#### 字段类型
class ModelResp(ModelBase):
# 定义响应内容
...
@classmethod
def from_db(cls, db: Model) -> "ModelResp":
# 从数据库模型转换
- **普通属性**: 直接定义在 Model 中,总是返回。
- **可选属性**: 使用 `OnDemand[T]` 标记,仅在 `includes` 中指定时返回。
- **普通计算属性**: 使用 `@included` 装饰的静态方法,总是返回。
- **可选计算属性**: 使用 `@ondemand` 装饰的静态方法,仅在 `includes` 中指定时返回。
#### 使用方法
**转换模型**:
使用 `Model.transform` 方法将数据库实例转换为字典:
```python
user = await session.get(User, 1)
user_dict = await UserModel.transform(
user,
includes=["email"], # 指定需要返回的可选字段
some_context="foo-bar", # 如果计算属性需要上下文,可以传入额外参数
session=session # 可选传入自己的 session
)
```
**API 文档**:
在 FastAPI 路由中,使用 `Model.generate_typeddict` 生成准确的响应文档:
```python
@router.get("/users/{id}", response_model=UserModel.generate_typeddict(includes=("email",)))
async def get_user(id: int) -> dict:
...
return await UserModel.transform(user, includes=["email"])
```
数据库模块名应与表名相同,定义了多个模型的除外。
@@ -227,16 +285,16 @@ pre-commit 不提供 pyright 的 hook您需要手动运行 `pyright` 检查
**类型** 必须是以下之一:
* **feat**:新功能
* **fix**:错误修复
* **docs**:仅文档更改
* **style**:不影响代码含义的更改(空格、格式、缺少分号等)
* **refactor**:代码重构
* **perf**:改善性能的代码更改
* **test**:添加缺失的测试或修正现有测试
* **chore**:对构建过程或辅助工具和库(如文档生成)的更改
* **ci**:持续集成相关的更改
* **deploy**: 部署相关的更改
- **feat**:新功能
- **fix**:错误修复
- **docs**:仅文档更改
- **style**:不影响代码含义的更改(空格、格式、缺少分号等)
- **refactor**:代码重构
- **perf**:改善性能的代码更改
- **test**:添加缺失的测试或修正现有测试
- **chore**:对构建过程或辅助工具和库(如文档生成)的更改
- **ci**:持续集成相关的更改
- **deploy**: 部署相关的更改
**范围** 可以是任何指定提交更改位置的内容。例如 `api``db``auth` 等等。对整个项目的更改使用 `project`

View File

@@ -2,23 +2,31 @@ from .achievement import UserAchievement, UserAchievementResp
from .auth import OAuthClient, OAuthToken, TotpKeys, V1APIKeys
from .beatmap import (
Beatmap,
BeatmapResp,
BeatmapDict,
BeatmapModel,
)
from .beatmap_playcounts import (
BeatmapPlaycounts,
BeatmapPlaycountsDict,
BeatmapPlaycountsModel,
)
from .beatmap_playcounts import BeatmapPlaycounts, BeatmapPlaycountsResp
from .beatmap_sync import BeatmapSync
from .beatmap_tags import BeatmapTagVote
from .beatmapset import (
Beatmapset,
BeatmapsetResp,
BeatmapsetDict,
BeatmapsetModel,
)
from .beatmapset_ratings import BeatmapRating
from .best_scores import BestScore
from .chat import (
ChannelType,
ChatChannel,
ChatChannelResp,
ChatChannelDict,
ChatChannelModel,
ChatMessage,
ChatMessageResp,
ChatMessageDict,
ChatMessageModel,
)
from .counts import (
CountResp,
@@ -30,8 +38,8 @@ from .events import Event
from .favourite_beatmapset import FavouriteBeatmapset
from .item_attempts_count import (
ItemAttemptsCount,
ItemAttemptsResp,
PlaylistAggregateScore,
ItemAttemptsCountDict,
ItemAttemptsCountModel,
)
from .matchmaking import (
MatchmakingPool,
@@ -42,30 +50,32 @@ from .multiplayer_event import MultiplayerEvent, MultiplayerEventResp
from .notification import Notification, UserNotification
from .password_reset import PasswordReset
from .playlist_best_score import PlaylistBestScore
from .playlists import Playlist, PlaylistResp
from .playlists import Playlist, PlaylistDict, PlaylistModel
from .rank_history import RankHistory, RankHistoryResp, RankTop
from .relationship import Relationship, RelationshipResp, RelationshipType
from .room import APIUploadedRoom, Room, RoomResp
from .relationship import Relationship, RelationshipDict, RelationshipModel, RelationshipType
from .room import APIUploadedRoom, Room, RoomDict, RoomModel
from .room_participated_user import RoomParticipatedUser
from .score import (
MultiplayerScores,
Score,
ScoreAround,
ScoreBase,
ScoreResp,
ScoreDict,
ScoreModel,
ScoreStatistics,
)
from .score_token import ScoreToken, ScoreTokenResp
from .search_beatmapset import SearchBeatmapsetsResp
from .statistics import (
UserStatistics,
UserStatisticsResp,
UserStatisticsDict,
UserStatisticsModel,
)
from .team import Team, TeamMember, TeamRequest, TeamResp
from .total_score_best_scores import TotalScoreBestScore
from .user import (
MeResp,
User,
UserResp,
UserDict,
UserModel,
)
from .user_account_history import (
UserAccountHistory,
@@ -79,20 +89,25 @@ from .verification import EmailVerification, LoginSession, LoginSessionResp, Tru
__all__ = [
"APIUploadedRoom",
"Beatmap",
"BeatmapDict",
"BeatmapModel",
"BeatmapPlaycounts",
"BeatmapPlaycountsResp",
"BeatmapPlaycountsDict",
"BeatmapPlaycountsModel",
"BeatmapRating",
"BeatmapResp",
"BeatmapSync",
"BeatmapTagVote",
"Beatmapset",
"BeatmapsetResp",
"BeatmapsetDict",
"BeatmapsetModel",
"BestScore",
"ChannelType",
"ChatChannel",
"ChatChannelResp",
"ChatChannelDict",
"ChatChannelModel",
"ChatMessage",
"ChatMessageResp",
"ChatMessageDict",
"ChatMessageModel",
"CountResp",
"DailyChallengeStats",
"DailyChallengeStatsResp",
@@ -100,13 +115,13 @@ __all__ = [
"Event",
"FavouriteBeatmapset",
"ItemAttemptsCount",
"ItemAttemptsResp",
"ItemAttemptsCountDict",
"ItemAttemptsCountModel",
"LoginSession",
"LoginSessionResp",
"MatchmakingPool",
"MatchmakingPoolBeatmap",
"MatchmakingUserStats",
"MeResp",
"MonthlyPlaycounts",
"MultiplayerEvent",
"MultiplayerEventResp",
@@ -116,26 +131,29 @@ __all__ = [
"OAuthToken",
"PasswordReset",
"Playlist",
"PlaylistAggregateScore",
"PlaylistBestScore",
"PlaylistResp",
"PlaylistDict",
"PlaylistModel",
"RankHistory",
"RankHistoryResp",
"RankTop",
"Relationship",
"RelationshipResp",
"RelationshipDict",
"RelationshipModel",
"RelationshipType",
"ReplayWatchedCount",
"Room",
"RoomDict",
"RoomModel",
"RoomParticipatedUser",
"RoomResp",
"Score",
"ScoreAround",
"ScoreBase",
"ScoreResp",
"ScoreDict",
"ScoreModel",
"ScoreStatistics",
"ScoreToken",
"ScoreTokenResp",
"SearchBeatmapsetsResp",
"Team",
"TeamMember",
"TeamRequest",
@@ -149,17 +167,18 @@ __all__ = [
"UserAccountHistoryResp",
"UserAccountHistoryType",
"UserAchievement",
"UserAchievement",
"UserAchievementResp",
"UserDict",
"UserLoginLog",
"UserModel",
"UserNotification",
"UserPreference",
"UserResp",
"UserStatistics",
"UserStatisticsResp",
"UserStatisticsDict",
"UserStatisticsModel",
"V1APIKeys",
]
for i in __all__:
if i.endswith("Resp"):
globals()[i].model_rebuild() # type: ignore[call-arg]
if i.endswith("Model") or i.endswith("Resp"):
globals()[i].model_rebuild()

499
app/database/_base.py Normal file
View File

@@ -0,0 +1,499 @@
from collections.abc import Awaitable, Callable, Sequence
from functools import lru_cache, wraps
import inspect
import sys
from types import NoneType, get_original_bases
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Concatenate,
ForwardRef,
ParamSpec,
TypedDict,
cast,
get_args,
get_origin,
overload,
)
from app.models.model import UTCBaseModel
from app.utils import type_is_optional
from sqlalchemy.ext.asyncio import async_object_session
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.main import SQLModelMetaclass
_dict_to_model: dict[type, type["DatabaseModel"]] = {}
def _safe_evaluate_forwardref(type_: str | ForwardRef, module_name: str) -> Any:
"""Safely evaluate a ForwardRef, with fallback to app.database module"""
if isinstance(type_, str):
type_ = ForwardRef(type_)
try:
return evaluate_forwardref(
type_,
globalns=vars(sys.modules[module_name]),
localns={},
)
except (NameError, AttributeError, KeyError):
# Fallback to app.database module
try:
import app.database
return evaluate_forwardref(
type_,
globalns=vars(app.database),
localns={},
)
except (NameError, AttributeError, KeyError):
return None
class OnDemand[T]:
if TYPE_CHECKING:
def __get__(self, instance: object | None, owner: Any) -> T: ...
def __set__(self, instance: Any, value: T) -> None: ...
def __delete__(self, instance: Any) -> None: ...
class Exclude[T]:
if TYPE_CHECKING:
def __get__(self, instance: object | None, owner: Any) -> T: ...
def __set__(self, instance: Any, value: T) -> None: ...
def __delete__(self, instance: Any) -> None: ...
# https://github.com/fastapi/sqlmodel/blob/main/sqlmodel/_compat.py#L126-L140
def _get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]:
raw_annotations: dict[str, Any] = class_dict.get("__annotations__", {})
if sys.version_info >= (3, 14) and "__annotations__" not in class_dict:
# See https://github.com/pydantic/pydantic/pull/11991
from annotationlib import (
Format,
call_annotate_function,
get_annotate_from_class_namespace,
)
if annotate := get_annotate_from_class_namespace(class_dict):
raw_annotations = call_annotate_function(annotate, format=Format.FORWARDREF)
return raw_annotations
# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L58-L77
if sys.version_info < (3, 12, 4):
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
# Even though it is the right signature for python 3.9, mypy complains with
# `error: Too many arguments for "_evaluate" of "ForwardRef"` hence the cast...
# Python 3.13/3.12.4+ made `recursive_guard` a kwarg, so name it explicitly to avoid:
# TypeError: ForwardRef._evaluate() missing 1 required keyword-only argument: 'recursive_guard'
return cast(Any, type_)._evaluate(globalns, localns, recursive_guard=set())
else:
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
# Pydantic 1.x will not support PEP 695 syntax, but provide `type_params` to avoid
# warnings:
return cast(Any, type_)._evaluate(globalns, localns, type_params=(), recursive_guard=set())
class DatabaseModelMetaclass(SQLModelMetaclass):
def __new__(
cls,
name: str,
bases: tuple[type, ...],
namespace: dict[str, Any],
**kwargs: Any,
) -> "DatabaseModelMetaclass":
original_annotations = _get_annotations(namespace)
new_annotations = {}
ondemands = []
excludes = []
for k, v in original_annotations.items():
if get_origin(v) is OnDemand:
inner_type = v.__args__[0]
new_annotations[k] = inner_type
ondemands.append(k)
elif get_origin(v) is Exclude:
inner_type = v.__args__[0]
new_annotations[k] = inner_type
excludes.append(k)
else:
new_annotations[k] = v
new_class = super().__new__(
cls,
name,
bases,
{
**namespace,
"__annotations__": new_annotations,
},
**kwargs,
)
new_class._CALCULATED_FIELDS = dict(getattr(new_class, "_CALCULATED_FIELDS", {}))
new_class._ONDEMAND_DATABASE_FIELDS = list(getattr(new_class, "_ONDEMAND_DATABASE_FIELDS", [])) + list(
ondemands
)
new_class._ONDEMAND_CALCULATED_FIELDS = dict(getattr(new_class, "_ONDEMAND_CALCULATED_FIELDS", {}))
new_class._EXCLUDED_DATABASE_FIELDS = list(getattr(new_class, "_EXCLUDED_DATABASE_FIELDS", [])) + list(excludes)
for attr_name, attr_value in namespace.items():
target = _get_callable_target(attr_value)
if target is None:
continue
if getattr(target, "__included__", False):
new_class._CALCULATED_FIELDS[attr_name] = _get_return_type(target)
_pre_calculate_context_params(target, attr_value)
if getattr(target, "__calculated_ondemand__", False):
new_class._ONDEMAND_CALCULATED_FIELDS[attr_name] = _get_return_type(target)
_pre_calculate_context_params(target, attr_value)
# Register TDict to DatabaseModel mapping
for base in get_original_bases(new_class):
cls_name = base.__name__
if "DatabaseModel" in cls_name and "[" in cls_name and "]" in cls_name:
generic_type_name = cls_name[cls_name.index("[") : cls_name.rindex("]") + 1]
generic_type = evaluate_forwardref(
ForwardRef(generic_type_name),
globalns=vars(sys.modules[new_class.__module__]),
localns={},
)
_dict_to_model[generic_type[0]] = new_class
return new_class
def _pre_calculate_context_params(target: Callable, attr_value: Any) -> None:
if hasattr(target, "__context_params__"):
return
sig = inspect.signature(target)
params = list(sig.parameters.keys())
start_index = 2
if isinstance(attr_value, classmethod):
start_index = 3
context_params = [] if len(params) < start_index else params[start_index:]
setattr(target, "__context_params__", context_params)
def _get_callable_target(value: Any) -> Callable | None:
if isinstance(value, (staticmethod, classmethod)):
return value.__func__
if inspect.isfunction(value):
return value
if inspect.ismethod(value):
return value.__func__
return None
def _mark_callable(value: Any, flag: str) -> Callable | None:
target = _get_callable_target(value)
if target is None:
return None
setattr(target, flag, True)
return target
def _get_return_type(func: Callable) -> type:
sig = inspect.get_annotations(func)
return sig.get("return", Any)
P = ParamSpec("P")
CalculatedField = Callable[Concatenate[AsyncSession, Any, P], Awaitable[Any]]
DecoratorTarget = CalculatedField | staticmethod | classmethod
def included(func: DecoratorTarget) -> DecoratorTarget:
marker = _mark_callable(func, "__included__")
if marker is None:
raise RuntimeError("@included is only usable on callables.")
@wraps(marker)
async def wrapper(*args, **kwargs):
return await marker(*args, **kwargs)
if isinstance(func, staticmethod):
return staticmethod(wrapper)
if isinstance(func, classmethod):
return classmethod(wrapper)
return wrapper
def ondemand(func: DecoratorTarget) -> DecoratorTarget:
marker = _mark_callable(func, "__calculated_ondemand__")
if marker is None:
raise RuntimeError("@ondemand is only usable on callables.")
@wraps(marker)
async def wrapper(*args, **kwargs):
return await marker(*args, **kwargs)
if isinstance(func, staticmethod):
return staticmethod(wrapper)
if isinstance(func, classmethod):
return classmethod(wrapper)
return wrapper
async def call_awaitable_with_context(
func: CalculatedField,
session: AsyncSession,
instance: Any,
context: dict[str, Any],
) -> Any:
context_params: list[str] | None = getattr(func, "__context_params__", None)
if context_params is None:
# Fallback if not pre-calculated
sig = inspect.signature(func)
if len(sig.parameters) == 2:
return await func(session, instance)
else:
call_params = {}
for param in sig.parameters.values():
if param.name in context:
call_params[param.name] = context[param.name]
return await func(session, instance, **call_params)
if not context_params:
return await func(session, instance)
call_params = {}
for name in context_params:
if name in context:
call_params[name] = context[name]
return await func(session, instance, **call_params)
class DatabaseModel[TDict](SQLModel, UTCBaseModel, metaclass=DatabaseModelMetaclass):
_CALCULATED_FIELDS: ClassVar[dict[str, type]] = {}
_ONDEMAND_DATABASE_FIELDS: ClassVar[list[str]] = []
_ONDEMAND_CALCULATED_FIELDS: ClassVar[dict[str, type]] = {}
_EXCLUDED_DATABASE_FIELDS: ClassVar[list[str]] = []
@overload
@classmethod
async def transform(
cls,
db_instance: "DatabaseModel",
*,
session: AsyncSession,
includes: list[str] | None = None,
**context: Any,
) -> TDict: ...
@overload
@classmethod
async def transform(
cls,
db_instance: "DatabaseModel",
*,
includes: list[str] | None = None,
**context: Any,
) -> TDict: ...
@classmethod
async def transform(
cls,
db_instance: "DatabaseModel",
*,
session: AsyncSession | None = None,
includes: list[str] | None = None,
**context: Any,
) -> TDict:
includes = includes.copy() if includes is not None else []
session = cast(AsyncSession | None, async_object_session(db_instance)) if session is None else session
if session is None:
raise RuntimeError("DatabaseModel.transform requires a session-bound instance.")
resp_obj = cls.model_validate(db_instance.model_dump())
data = resp_obj.model_dump()
for field in cls._CALCULATED_FIELDS:
func = getattr(cls, field)
value = await call_awaitable_with_context(func, session, db_instance, context)
data[field] = value
sub_include_map: dict[str, list[str]] = {}
for include in [i for i in includes if "." in i]:
parent, sub_include = include.split(".", 1)
if parent not in sub_include_map:
sub_include_map[parent] = []
sub_include_map[parent].append(sub_include)
includes.remove(include) # pyright: ignore[reportOptionalMemberAccess]
for field, sub_includes in sub_include_map.items():
if field in cls._ONDEMAND_CALCULATED_FIELDS:
func = getattr(cls, field)
value = await call_awaitable_with_context(
func, session, db_instance, {**context, "includes": sub_includes}
)
data[field] = value
for include in includes:
if include in data:
continue
if include in cls._ONDEMAND_CALCULATED_FIELDS:
func = getattr(cls, include)
value = await call_awaitable_with_context(func, session, db_instance, context)
data[include] = value
for field in cls._ONDEMAND_DATABASE_FIELDS:
if field not in includes:
del data[field]
for field in cls._EXCLUDED_DATABASE_FIELDS:
if field in data:
del data[field]
return cast(TDict, data)
@classmethod
async def transform_many(
cls,
db_instances: Sequence["DatabaseModel"],
*,
session: AsyncSession | None = None,
includes: list[str] | None = None,
**context: Any,
) -> list[TDict]:
if not db_instances:
return []
# SQLAlchemy AsyncSession is not concurrency-safe, so we cannot use asyncio.gather here
# if the transform method performs any database operations using the shared session.
# Since we don't know if the transform method (or its calculated fields) will use the DB,
# we must execute them serially to be safe.
results = []
for instance in db_instances:
results.append(await cls.transform(instance, session=session, includes=includes, **context))
return results
@classmethod
@lru_cache
def generate_typeddict(cls, includes: tuple[str, ...] | None = None) -> type[TypedDict]: # pyright: ignore[reportInvalidTypeForm]
def _evaluate_type(field_type: Any, *, resolve_database_model: bool = False, field_name: str = "") -> Any:
# Evaluate ForwardRef if present
if isinstance(field_type, (str, ForwardRef)):
resolved = _safe_evaluate_forwardref(field_type, cls.__module__)
if resolved is not None:
field_type = resolved
origin_type = get_origin(field_type)
inner_type = field_type
args = get_args(field_type)
is_optional = type_is_optional(field_type) # pyright: ignore[reportArgumentType]
if is_optional:
inner_type = next((arg for arg in args if arg is not NoneType), field_type)
is_list = False
if origin_type is list:
is_list = True
inner_type = args[0]
# Evaluate ForwardRef in inner_type if present
if isinstance(inner_type, (str, ForwardRef)):
resolved = _safe_evaluate_forwardref(inner_type, cls.__module__)
if resolved is not None:
inner_type = resolved
if not resolve_database_model:
if is_optional:
return inner_type | None # pyright: ignore[reportOperatorIssue]
elif is_list:
return list[inner_type]
return inner_type
model_class = None
# First check if inner_type is directly a DatabaseModel subclass
try:
if inspect.isclass(inner_type) and issubclass(inner_type, DatabaseModel): # type: ignore
model_class = inner_type
except TypeError:
pass
# If not found, look up in _dict_to_model
if model_class is None:
model_class = _dict_to_model.get(inner_type) # type: ignore
if model_class is not None:
nested_dict = model_class.generate_typeddict(tuple(sub_include_map.get(field_name, ())))
resolved_type = list[nested_dict] if is_list else nested_dict # type: ignore
if is_optional:
resolved_type = resolved_type | None # type: ignore
return resolved_type
# Fallback: use the resolved inner_type
resolved_type = list[inner_type] if is_list else inner_type # type: ignore
if is_optional:
resolved_type = resolved_type | None # type: ignore
return resolved_type
if includes is None:
includes = ()
# Parse nested includes
direct_includes = []
sub_include_map: dict[str, list[str]] = {}
for include in includes:
if "." in include:
parent, sub_include = include.split(".", 1)
if parent not in sub_include_map:
sub_include_map[parent] = []
sub_include_map[parent].append(sub_include)
if parent not in direct_includes:
direct_includes.append(parent)
else:
direct_includes.append(include)
fields = {}
# Process model fields
for field_name, field_info in cls.model_fields.items():
field_type = field_info.annotation or Any
field_type = _evaluate_type(field_type, field_name=field_name)
if field_name in cls._ONDEMAND_DATABASE_FIELDS and field_name not in direct_includes:
continue
else:
fields[field_name] = field_type
# Process calculated fields
for field_name, field_type in cls._CALCULATED_FIELDS.items():
field_type = _evaluate_type(field_type, resolve_database_model=True, field_name=field_name)
fields[field_name] = field_type
# Process ondemand calculated fields
for field_name, field_type in cls._ONDEMAND_CALCULATED_FIELDS.items():
if field_name not in direct_includes:
continue
field_type = _evaluate_type(field_type, resolve_database_model=True, field_name=field_name)
fields[field_name] = field_type
return TypedDict(f"{cls.__name__}Dict[{', '.join(includes)}]" if includes else f"{cls.__name__}Dict", fields) # pyright: ignore[reportArgumentType]

View File

@@ -1,117 +1,339 @@
from datetime import datetime
import hashlib
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, ClassVar, NotRequired, TypedDict
from app.calculator import get_calculator
from app.config import settings
from app.database.beatmap_tags import BeatmapTagVote
from app.database.failtime import FailTime, FailTimeResp
from app.models.beatmap import BeatmapRankStatus
from app.models.mods import APIMod
from app.models.performance import DifficultyAttributesUnion
from app.models.score import GameMode
from ._base import DatabaseModel, OnDemand, included, ondemand
from .beatmap_playcounts import BeatmapPlaycounts
from .beatmapset import Beatmapset, BeatmapsetResp
from .beatmap_tags import BeatmapTagVote
from .beatmapset import Beatmapset, BeatmapsetDict, BeatmapsetModel
from .failtime import FailTime, FailTimeResp
from .user import User, UserDict, UserModel
from pydantic import BaseModel, TypeAdapter
from redis.asyncio import Redis
from sqlalchemy import Column, DateTime
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, col, exists, func, select
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.fetcher import Fetcher
from .user import User
class BeatmapOwner(SQLModel):
id: int
username: str
class BeatmapBase(SQLModel):
# Beatmap
url: str
class BeatmapDict(TypedDict):
beatmapset_id: int
difficulty_rating: float
id: int
mode: GameMode
total_length: int
user_id: int
version: str
url: str
checksum: NotRequired[str]
max_combo: NotRequired[int | None]
ar: NotRequired[float]
cs: NotRequired[float]
drain: NotRequired[float]
accuracy: NotRequired[float]
bpm: NotRequired[float]
count_circles: NotRequired[int]
count_sliders: NotRequired[int]
count_spinners: NotRequired[int]
deleted_at: NotRequired[datetime | None]
hit_length: NotRequired[int]
last_updated: NotRequired[datetime]
status: NotRequired[str]
beatmapset: NotRequired[BeatmapsetDict]
current_user_playcount: NotRequired[int]
current_user_tag_ids: NotRequired[list[int]]
failtimes: NotRequired[FailTimeResp]
top_tag_ids: NotRequired[list[dict[str, int]]]
user: NotRequired[UserDict]
convert: NotRequired[bool]
is_scoreable: NotRequired[bool]
mode_int: NotRequired[int]
ranked: NotRequired[int]
playcount: NotRequired[int]
passcount: NotRequired[int]
class BeatmapModel(DatabaseModel[BeatmapDict]):
BEATMAP_TRANSFORMER_INCLUDES: ClassVar[list[str]] = [
"checksum",
"accuracy",
"ar",
"bpm",
"convert",
"count_circles",
"count_sliders",
"count_spinners",
"cs",
"deleted_at",
"drain",
"hit_length",
"is_scoreable",
"last_updated",
"mode_int",
"passcount",
"playcount",
"ranked",
"url",
]
DEFAULT_API_INCLUDES: ClassVar[list[str]] = [
"beatmapset.ratings",
"current_user_playcount",
"failtimes",
"max_combo",
"owners",
]
TRANSFORMER_INCLUDES: ClassVar[list[str]] = [*DEFAULT_API_INCLUDES, *BEATMAP_TRANSFORMER_INCLUDES]
# Beatmap
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
difficulty_rating: float = Field(default=0.0, index=True)
id: int = Field(primary_key=True, index=True)
mode: GameMode
total_length: int
user_id: int = Field(index=True)
version: str = Field(index=True)
url: OnDemand[str]
# optional
checksum: str = Field(sa_column=Column(VARCHAR(32), index=True))
current_user_playcount: int = Field(default=0)
max_combo: int | None = Field(default=0)
# TODO: failtimes, owners
checksum: OnDemand[str] = Field(sa_column=Column(VARCHAR(32), index=True))
max_combo: OnDemand[int | None] = Field(default=0)
# TODO: owners
# BeatmapExtended
ar: float = Field(default=0.0)
cs: float = Field(default=0.0)
drain: float = Field(default=0.0) # hp
accuracy: float = Field(default=0.0) # od
bpm: float = Field(default=0.0)
count_circles: int = Field(default=0)
count_sliders: int = Field(default=0)
count_spinners: int = Field(default=0)
deleted_at: datetime | None = Field(default=None, sa_column=Column(DateTime))
hit_length: int = Field(default=0)
last_updated: datetime = Field(sa_column=Column(DateTime, index=True))
ar: OnDemand[float] = Field(default=0.0)
cs: OnDemand[float] = Field(default=0.0)
drain: OnDemand[float] = Field(default=0.0) # hp
accuracy: OnDemand[float] = Field(default=0.0) # od
bpm: OnDemand[float] = Field(default=0.0)
count_circles: OnDemand[int] = Field(default=0)
count_sliders: OnDemand[int] = Field(default=0)
count_spinners: OnDemand[int] = Field(default=0)
deleted_at: OnDemand[datetime | None] = Field(default=None, sa_column=Column(DateTime))
hit_length: OnDemand[int] = Field(default=0)
last_updated: OnDemand[datetime] = Field(sa_column=Column(DateTime, index=True))
@included
@staticmethod
async def status(_session: AsyncSession, beatmap: "Beatmap") -> str:
if settings.enable_all_beatmap_leaderboard and not beatmap.beatmap_status.has_leaderboard():
return BeatmapRankStatus.APPROVED.name.lower()
return beatmap.beatmap_status.name.lower()
@ondemand
@staticmethod
async def beatmapset(
_session: AsyncSession,
beatmap: "Beatmap",
includes: list[str] | None = None,
) -> BeatmapsetDict | None:
if beatmap.beatmapset is not None:
return await BeatmapsetModel.transform(
beatmap.beatmapset, includes=(includes or []) + Beatmapset.BEATMAPSET_TRANSFORMER_INCLUDES
)
@ondemand
@staticmethod
async def current_user_playcount(_session: AsyncSession, beatmap: "Beatmap", user: "User") -> int:
playcount = (
await _session.exec(
select(BeatmapPlaycounts.playcount).where(
BeatmapPlaycounts.beatmap_id == beatmap.id, BeatmapPlaycounts.user_id == user.id
)
)
).first()
return int(playcount or 0)
@ondemand
@staticmethod
async def current_user_tag_ids(_session: AsyncSession, beatmap: "Beatmap", user: "User | None" = None) -> list[int]:
if user is None:
return []
tag_ids = (
await _session.exec(
select(BeatmapTagVote.tag_id).where(
BeatmapTagVote.beatmap_id == beatmap.id,
BeatmapTagVote.user_id == user.id,
)
)
).all()
return list(tag_ids)
@ondemand
@staticmethod
async def failtimes(_session: AsyncSession, beatmap: "Beatmap") -> FailTimeResp:
if beatmap.failtimes is not None:
return FailTimeResp.from_db(beatmap.failtimes)
return FailTimeResp()
@ondemand
@staticmethod
async def top_tag_ids(_session: AsyncSession, beatmap: "Beatmap") -> list[dict[str, int]]:
all_votes = (
await _session.exec(
select(BeatmapTagVote.tag_id, func.count().label("vote_count"))
.where(BeatmapTagVote.beatmap_id == beatmap.id)
.group_by(col(BeatmapTagVote.tag_id))
.having(func.count() > settings.beatmap_tag_top_count)
)
).all()
top_tag_ids: list[dict[str, int]] = []
for id, votes in all_votes:
top_tag_ids.append({"tag_id": id, "count": votes})
top_tag_ids.sort(key=lambda x: x["count"], reverse=True)
return top_tag_ids
@ondemand
@staticmethod
async def user(
_session: AsyncSession,
beatmap: "Beatmap",
includes: list[str] | None = None,
) -> UserDict | None:
from .user import User
user = await _session.get(User, beatmap.user_id)
if user is None:
return None
return await UserModel.transform(user, includes=includes)
@ondemand
@staticmethod
async def convert(_session: AsyncSession, _beatmap: "Beatmap") -> bool:
return False
@ondemand
@staticmethod
async def is_scoreable(_session: AsyncSession, beatmap: "Beatmap") -> bool:
beatmap_status = beatmap.beatmap_status
if settings.enable_all_beatmap_leaderboard:
return True
return beatmap_status.has_leaderboard()
@ondemand
@staticmethod
async def mode_int(_session: AsyncSession, beatmap: "Beatmap") -> int:
return int(beatmap.mode)
@ondemand
@staticmethod
async def ranked(_session: AsyncSession, beatmap: "Beatmap") -> int:
beatmap_status = beatmap.beatmap_status
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
return BeatmapRankStatus.APPROVED.value
return beatmap_status.value
@ondemand
@staticmethod
async def playcount(_session: AsyncSession, beatmap: "Beatmap") -> int:
result = (
await _session.exec(
select(func.sum(BeatmapPlaycounts.playcount)).where(BeatmapPlaycounts.beatmap_id == beatmap.id)
)
).first()
return int(result or 0)
@ondemand
@staticmethod
async def passcount(_session: AsyncSession, beatmap: "Beatmap") -> int:
from .score import Score
return (
await _session.exec(
select(func.count())
.select_from(Score)
.where(
Score.beatmap_id == beatmap.id,
col(Score.passed).is_(True),
)
)
).one()
class Beatmap(BeatmapBase, table=True):
class Beatmap(AsyncAttrs, BeatmapModel, table=True):
__tablename__: str = "beatmaps"
id: int = Field(primary_key=True, index=True)
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
beatmap_status: BeatmapRankStatus = Field(index=True)
# optional
beatmapset: Beatmapset = Relationship(back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"})
beatmapset: "Beatmapset" = Relationship(back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"})
failtimes: FailTime | None = Relationship(back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"})
@classmethod
async def from_resp_no_save(cls, _session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
d = resp.model_dump()
del d["beatmapset"]
async def from_resp_no_save(cls, _session: AsyncSession, resp: BeatmapDict) -> "Beatmap":
d = {k: v for k, v in resp.items() if k != "beatmapset"}
beatmapset_id = resp.get("beatmapset_id")
bid = resp.get("id")
ranked = resp.get("ranked")
if beatmapset_id is None or bid is None or ranked is None:
raise ValueError("beatmapset_id, id and ranked are required")
beatmap = cls.model_validate(
{
**d,
"beatmapset_id": resp.beatmapset_id,
"id": resp.id,
"beatmap_status": BeatmapRankStatus(resp.ranked),
"beatmapset_id": beatmapset_id,
"id": bid,
"beatmap_status": BeatmapRankStatus(ranked),
}
)
return beatmap
@classmethod
async def from_resp(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
async def from_resp(cls, session: AsyncSession, resp: BeatmapDict) -> "Beatmap":
beatmap = await cls.from_resp_no_save(session, resp)
if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
resp_id = resp.get("id")
if resp_id is None:
raise ValueError("id is required")
if not (await session.exec(select(exists()).where(Beatmap.id == resp_id))).first():
session.add(beatmap)
await session.commit()
return (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one()
return (await session.exec(select(Beatmap).where(Beatmap.id == resp_id))).one()
@classmethod
async def from_resp_batch(cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0) -> list["Beatmap"]:
async def from_resp_batch(cls, session: AsyncSession, inp: list[BeatmapDict], from_: int = 0) -> list["Beatmap"]:
beatmaps = []
for resp in inp:
if resp.id == from_:
for resp_dict in inp:
bid = resp_dict.get("id")
if bid == from_ or bid is None:
continue
d = resp.model_dump()
del d["beatmapset"]
beatmapset_id = resp_dict.get("beatmapset_id")
ranked = resp_dict.get("ranked")
if beatmapset_id is None or ranked is None:
continue
# 创建 beatmap 字典,移除 beatmapset
d = {k: v for k, v in resp_dict.items() if k != "beatmapset"}
beatmap = Beatmap.model_validate(
{
**d,
"beatmapset_id": resp.beatmapset_id,
"id": resp.id,
"beatmap_status": BeatmapRankStatus(resp.ranked),
"beatmapset_id": beatmapset_id,
"id": bid,
"beatmap_status": BeatmapRankStatus(ranked),
}
)
if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
if not (await session.exec(select(exists()).where(Beatmap.id == bid))).first():
session.add(beatmap)
beatmaps.append(beatmap)
await session.commit()
for beatmap in beatmaps:
await session.refresh(beatmap)
return beatmaps
@classmethod
@@ -132,10 +354,14 @@ class Beatmap(BeatmapBase, table=True):
beatmap = (await session.exec(stmt)).first()
if not beatmap:
resp = await fetcher.get_beatmap(bid, md5)
r = await session.exec(select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id))
beatmapset_id = resp.get("beatmapset_id")
if beatmapset_id is None:
raise ValueError("beatmapset_id is required")
r = await session.exec(select(Beatmapset.id).where(Beatmapset.id == beatmapset_id))
if not r.first():
set_resp = await fetcher.get_beatmapset(resp.beatmapset_id)
await Beatmapset.from_resp(session, set_resp, from_=resp.id)
set_resp = await fetcher.get_beatmapset(beatmapset_id)
resp_id = resp.get("id")
await Beatmapset.from_resp(session, set_resp, from_=resp_id or 0)
return await Beatmap.from_resp(session, resp)
return beatmap
@@ -145,97 +371,6 @@ class APIBeatmapTag(BaseModel):
count: int
class BeatmapResp(BeatmapBase):
id: int
beatmapset_id: int
beatmapset: BeatmapsetResp | None = None
convert: bool = False
is_scoreable: bool
status: str
mode_int: int
ranked: int
url: str = ""
playcount: int = 0
passcount: int = 0
failtimes: FailTimeResp | None = None
top_tag_ids: list[APIBeatmapTag] | None = None
current_user_tag_ids: list[int] | None = None
is_deleted: bool = False
@classmethod
async def from_db(
cls,
beatmap: Beatmap,
query_mode: GameMode | None = None,
from_set: bool = False,
session: AsyncSession | None = None,
user: "User | None" = None,
) -> "BeatmapResp":
from .score import Score
beatmap_ = beatmap.model_dump()
beatmap_status = beatmap.beatmap_status
if query_mode is not None and beatmap.mode != query_mode:
beatmap_["convert"] = True
beatmap_["is_scoreable"] = beatmap_status.has_leaderboard()
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
beatmap_["ranked"] = BeatmapRankStatus.APPROVED.value
beatmap_["status"] = BeatmapRankStatus.APPROVED.name.lower()
else:
beatmap_["status"] = beatmap_status.name.lower()
beatmap_["ranked"] = beatmap_status.value
beatmap_["mode_int"] = int(beatmap.mode)
beatmap_["is_deleted"] = beatmap.deleted_at is not None
if not from_set:
beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset, session=session, user=user)
if beatmap.failtimes is not None:
beatmap_["failtimes"] = FailTimeResp.from_db(beatmap.failtimes)
else:
beatmap_["failtimes"] = FailTimeResp()
if session:
beatmap_["playcount"] = (
await session.exec(
select(func.sum(BeatmapPlaycounts.playcount)).where(BeatmapPlaycounts.beatmap_id == beatmap.id)
)
).first() or 0
beatmap_["passcount"] = (
await session.exec(
select(func.count())
.select_from(Score)
.where(
Score.beatmap_id == beatmap.id,
col(Score.passed).is_(True),
)
)
).one()
all_votes = (
await session.exec(
select(BeatmapTagVote.tag_id, func.count().label("vote_count"))
.where(BeatmapTagVote.beatmap_id == beatmap.id)
.group_by(col(BeatmapTagVote.tag_id))
.having(func.count() > settings.beatmap_tag_top_count)
)
).all()
top_tag_ids: list[dict[str, int]] = []
for id, votes in all_votes:
top_tag_ids.append({"tag_id": id, "count": votes})
top_tag_ids.sort(key=lambda x: x["count"], reverse=True)
beatmap_["top_tag_ids"] = top_tag_ids
if user is not None:
beatmap_["current_user_tag_ids"] = (
await session.exec(
select(BeatmapTagVote.tag_id)
.where(BeatmapTagVote.beatmap_id == beatmap.id)
.where(BeatmapTagVote.user_id == user.id)
)
).all()
else:
beatmap_["current_user_tag_ids"] = []
return cls.model_validate(beatmap_)
class BannedBeatmaps(SQLModel, table=True):
__tablename__: str = "banned_beatmaps"
id: int | None = Field(primary_key=True, index=True, default=None)

View File

@@ -1,10 +1,11 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, NotRequired, TypedDict
from app.config import settings
from app.database.events import Event, EventType
from app.utils import utcnow
from pydantic import BaseModel
from ._base import DatabaseModel, included
from .events import Event, EventType
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import (
BigInteger,
@@ -12,52 +13,65 @@ from sqlmodel import (
Field,
ForeignKey,
Relationship,
SQLModel,
select,
)
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from .beatmap import Beatmap, BeatmapResp
from .beatmapset import BeatmapsetResp
from .beatmap import Beatmap, BeatmapDict
from .beatmapset import BeatmapsetDict
from .user import User
class BeatmapPlaycounts(AsyncAttrs, SQLModel, table=True):
class BeatmapPlaycountsDict(TypedDict):
user_id: int
beatmap_id: int
count: NotRequired[int]
beatmap: NotRequired["BeatmapDict"]
beatmapset: NotRequired["BeatmapsetDict"]
class BeatmapPlaycountsModel(AsyncAttrs, DatabaseModel[BeatmapPlaycountsDict]):
__tablename__: str = "beatmap_playcounts"
id: int | None = Field(
default=None,
sa_column=Column(BigInteger, primary_key=True, autoincrement=True),
default=None, sa_column=Column(BigInteger, primary_key=True, autoincrement=True), exclude=True
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
playcount: int = Field(default=0)
playcount: int = Field(default=0, exclude=True)
@included
@staticmethod
async def count(_session: AsyncSession, obj: "BeatmapPlaycounts") -> int:
return obj.playcount
@included
@staticmethod
async def beatmap(
_session: AsyncSession, obj: "BeatmapPlaycounts", includes: list[str] | None = None
) -> "BeatmapDict":
from .beatmap import BeatmapModel
await obj.awaitable_attrs.beatmap
return await BeatmapModel.transform(obj.beatmap, includes=includes)
@included
@staticmethod
async def beatmapset(
_session: AsyncSession, obj: "BeatmapPlaycounts", includes: list[str] | None = None
) -> "BeatmapsetDict":
from .beatmap import BeatmapsetModel
await obj.awaitable_attrs.beatmap
return await BeatmapsetModel.transform(obj.beatmap.beatmapset, includes=includes)
class BeatmapPlaycounts(BeatmapPlaycountsModel, table=True):
user: "User" = Relationship()
beatmap: "Beatmap" = Relationship()
class BeatmapPlaycountsResp(BaseModel):
beatmap_id: int
beatmap: "BeatmapResp | None" = None
beatmapset: "BeatmapsetResp | None" = None
count: int
@classmethod
async def from_db(cls, db_model: BeatmapPlaycounts) -> "BeatmapPlaycountsResp":
from .beatmap import BeatmapResp
from .beatmapset import BeatmapsetResp
await db_model.awaitable_attrs.beatmap
return cls(
beatmap_id=db_model.beatmap_id,
count=db_model.playcount,
beatmap=await BeatmapResp.from_db(db_model.beatmap),
beatmapset=await BeatmapsetResp.from_db(db_model.beatmap.beatmapset),
)
async def process_beatmap_playcount(session: AsyncSession, user_id: int, beatmap_id: int) -> None:
existing_playcount = (
await session.exec(

View File

@@ -1,14 +1,15 @@
from datetime import datetime
from typing import TYPE_CHECKING, NotRequired, Self, TypedDict
from typing import TYPE_CHECKING, Any, ClassVar, NotRequired, TypedDict
from app.config import settings
from app.database.beatmap_playcounts import BeatmapPlaycounts
from app.models.beatmap import BeatmapRankStatus, Genre, Language
from app.models.score import GameMode
from .user import BASE_INCLUDES, User, UserResp
from ._base import DatabaseModel, OnDemand, included, ondemand
from .beatmap_playcounts import BeatmapPlaycounts
from .user import User, UserDict
from pydantic import BaseModel, field_validator, model_validator
from pydantic import BaseModel
from sqlalchemy import JSON, Boolean, Column, DateTime, Text
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import Field, Relationship, SQLModel, col, exists, func, select
@@ -17,7 +18,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.fetcher import Fetcher
from .beatmap import Beatmap, BeatmapResp
from .beatmap import Beatmap, BeatmapDict
from .favourite_beatmapset import FavouriteBeatmapset
@@ -68,8 +69,99 @@ class BeatmapTranslationText(BaseModel):
id: int | None = None
class BeatmapsetBase(SQLModel):
class BeatmapsetDict(TypedDict):
id: int
artist: str
artist_unicode: str
covers: BeatmapCovers | None
creator: str
nsfw: bool
preview_url: str
source: str
spotlight: bool
title: str
title_unicode: str
track_id: int | None
user_id: int
video: bool
current_nominations: list[BeatmapNomination] | None
description: BeatmapDescription | None
pack_tags: list[str]
bpm: NotRequired[float]
can_be_hyped: NotRequired[bool]
discussion_locked: NotRequired[bool]
last_updated: NotRequired[datetime]
ranked_date: NotRequired[datetime | None]
storyboard: NotRequired[bool]
submitted_date: NotRequired[datetime]
tags: NotRequired[str]
discussion_enabled: NotRequired[bool]
legacy_thread_url: NotRequired[str | None]
status: NotRequired[str]
ranked: NotRequired[int]
is_scoreable: NotRequired[bool]
favourite_count: NotRequired[int]
genre_id: NotRequired[int]
hype: NotRequired[BeatmapHype]
language_id: NotRequired[int]
play_count: NotRequired[int]
availability: NotRequired[BeatmapAvailability]
beatmaps: NotRequired[list["BeatmapDict"]]
has_favourited: NotRequired[bool]
recent_favourites: NotRequired[list[UserDict]]
genre: NotRequired[BeatmapTranslationText]
language: NotRequired[BeatmapTranslationText]
nominations: NotRequired["BeatmapNominations"]
ratings: NotRequired[list[int]]
class BeatmapsetModel(DatabaseModel[BeatmapsetDict]):
BEATMAPSET_TRANSFORMER_INCLUDES: ClassVar[list[str]] = [
"availability",
"has_favourited",
"bpm",
"deleted_atcan_be_hyped",
"discussion_locked",
"is_scoreable",
"last_updated",
"legacy_thread_url",
"ranked",
"ranked_date",
"submitted_date",
"tags",
"rating",
"storyboard",
]
API_INCLUDES: ClassVar[list[str]] = [
*BEATMAPSET_TRANSFORMER_INCLUDES,
"beatmaps.current_user_playcount",
"beatmaps.current_user_tag_ids",
"beatmaps.max_combo",
"current_nominations",
"current_user_attributes",
"description",
"genre",
"language",
"pack_tags",
"ratings",
"recent_favourites",
"related_tags",
"related_users",
"user",
"version_count",
*[
f"beatmaps.{inc}"
for inc in {
"failtimes",
"owners",
"top_tag_ids",
}
],
]
# Beatmapset
id: int = Field(default=None, primary_key=True, index=True)
artist: str = Field(index=True)
artist_unicode: str = Field(index=True)
covers: BeatmapCovers | None = Field(sa_column=Column(JSON))
@@ -77,41 +169,285 @@ class BeatmapsetBase(SQLModel):
nsfw: bool = Field(default=False, sa_column=Column(Boolean))
preview_url: str
source: str = Field(default="")
spotlight: bool = Field(default=False, sa_column=Column(Boolean))
title: str = Field(index=True)
title_unicode: str = Field(index=True)
track_id: int | None = Field(default=None, index=True) # feature artist?
user_id: int = Field(index=True)
video: bool = Field(sa_column=Column(Boolean, index=True))
# optional
# converts: list[Beatmap] = Relationship(back_populates="beatmapset")
current_nominations: list[BeatmapNomination] | None = Field(None, sa_column=Column(JSON))
description: BeatmapDescription | None = Field(default=None, sa_column=Column(JSON))
current_nominations: OnDemand[list[BeatmapNomination] | None] = Field(None, sa_column=Column(JSON))
description: OnDemand[BeatmapDescription | None] = Field(default=None, sa_column=Column(JSON))
# TODO: discussions: list[BeatmapsetDiscussion] = None
# TODO: current_user_attributes: Optional[CurrentUserAttributes] = None
# TODO: events: Optional[list[BeatmapsetEvent]] = None
pack_tags: list[str] = Field(default=[], sa_column=Column(JSON))
pack_tags: OnDemand[list[str]] = Field(default=[], sa_column=Column(JSON))
# TODO: related_users: Optional[list[User]] = None
# TODO: user: Optional[User] = Field(default=None)
track_id: int | None = Field(default=None, index=True) # feature artist?
# BeatmapsetExtended
bpm: float = Field(default=0.0)
can_be_hyped: bool = Field(default=False, sa_column=Column(Boolean))
discussion_locked: bool = Field(default=False, sa_column=Column(Boolean))
last_updated: datetime = Field(sa_column=Column(DateTime, index=True))
ranked_date: datetime | None = Field(default=None, sa_column=Column(DateTime, index=True))
storyboard: bool = Field(default=False, sa_column=Column(Boolean, index=True))
submitted_date: datetime = Field(sa_column=Column(DateTime, index=True))
tags: str = Field(default="", sa_column=Column(Text))
bpm: OnDemand[float] = Field(default=0.0)
can_be_hyped: OnDemand[bool] = Field(default=False, sa_column=Column(Boolean))
discussion_locked: OnDemand[bool] = Field(default=False, sa_column=Column(Boolean))
last_updated: OnDemand[datetime] = Field(sa_column=Column(DateTime, index=True))
ranked_date: OnDemand[datetime | None] = Field(default=None, sa_column=Column(DateTime, index=True))
storyboard: OnDemand[bool] = Field(default=False, sa_column=Column(Boolean, index=True))
submitted_date: OnDemand[datetime] = Field(sa_column=Column(DateTime, index=True))
tags: OnDemand[str] = Field(default="", sa_column=Column(Text))
@ondemand
@staticmethod
async def legacy_thread_url(
_session: AsyncSession,
_beatmapset: "Beatmapset",
) -> str | None:
return None
@included
@staticmethod
async def discussion_enabled(
_session: AsyncSession,
_beatmapset: "Beatmapset",
) -> bool:
return True
@included
@staticmethod
async def status(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> str:
beatmap_status = beatmapset.beatmap_status
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
return BeatmapRankStatus.APPROVED.name.lower()
return beatmap_status.name.lower()
@included
@staticmethod
async def ranked(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> int:
beatmap_status = beatmapset.beatmap_status
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
return BeatmapRankStatus.APPROVED.value
return beatmap_status.value
@ondemand
@staticmethod
async def is_scoreable(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> bool:
beatmap_status = beatmapset.beatmap_status
if settings.enable_all_beatmap_leaderboard:
return True
return beatmap_status.has_leaderboard()
@included
@staticmethod
async def favourite_count(
session: AsyncSession,
beatmapset: "Beatmapset",
) -> int:
from .favourite_beatmapset import FavouriteBeatmapset
count = await session.exec(
select(func.count())
.select_from(FavouriteBeatmapset)
.where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
)
return count.one()
@included
@staticmethod
async def genre_id(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> int:
return beatmapset.beatmap_genre.value
@ondemand
@staticmethod
async def hype(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> BeatmapHype:
return BeatmapHype(current=beatmapset.hype_current, required=beatmapset.hype_required)
@included
@staticmethod
async def language_id(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> int:
return beatmapset.beatmap_language.value
@included
@staticmethod
async def play_count(
session: AsyncSession,
beatmapset: "Beatmapset",
) -> int:
from .beatmap import Beatmap
playcount = await session.exec(
select(func.sum(BeatmapPlaycounts.playcount)).where(
col(BeatmapPlaycounts.beatmap).has(col(Beatmap.beatmapset_id) == beatmapset.id)
)
)
return int(playcount.first() or 0)
@ondemand
@staticmethod
async def availability(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> BeatmapAvailability:
return BeatmapAvailability(
more_information=beatmapset.availability_info,
download_disabled=beatmapset.download_disabled,
)
@ondemand
@staticmethod
async def beatmaps(
_session: AsyncSession,
beatmapset: "Beatmapset",
includes: list[str] | None = None,
user: "User | None" = None,
) -> list["BeatmapDict"]:
from .beatmap import BeatmapModel
return [
await BeatmapModel.transform(
beatmap, includes=(includes or []) + BeatmapModel.BEATMAP_TRANSFORMER_INCLUDES, user=user
)
for beatmap in await beatmapset.awaitable_attrs.beatmaps
]
# @ondemand
# @staticmethod
# async def current_nominations(
# _session: AsyncSession,
# beatmapset: "Beatmapset",
# ) -> list[BeatmapNomination] | None:
# return beatmapset.current_nominations or []
@ondemand
@staticmethod
async def has_favourited(
session: AsyncSession,
beatmapset: "Beatmapset",
user: User | None = None,
) -> bool:
from .favourite_beatmapset import FavouriteBeatmapset
if session is None:
return False
query = select(FavouriteBeatmapset).where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
if user is not None:
query = query.where(FavouriteBeatmapset.user_id == user.id)
existing = (await session.exec(query)).first()
return existing is not None
@ondemand
@staticmethod
async def recent_favourites(
session: AsyncSession,
beatmapset: "Beatmapset",
includes: list[str] | None = None,
) -> list[UserDict]:
from .favourite_beatmapset import FavouriteBeatmapset
recent_favourites = (
await session.exec(
select(FavouriteBeatmapset)
.where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
.order_by(col(FavouriteBeatmapset.date).desc())
.limit(50)
)
).all()
return [
await User.transform(
(await favourite.awaitable_attrs.user),
includes=includes,
)
for favourite in recent_favourites
]
@ondemand
@staticmethod
async def genre(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> BeatmapTranslationText:
return BeatmapTranslationText(
name=beatmapset.beatmap_genre.name,
id=beatmapset.beatmap_genre.value,
)
@ondemand
@staticmethod
async def language(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> BeatmapTranslationText:
return BeatmapTranslationText(
name=beatmapset.beatmap_language.name,
id=beatmapset.beatmap_language.value,
)
@ondemand
@staticmethod
async def nominations(
_session: AsyncSession,
beatmapset: "Beatmapset",
) -> BeatmapNominations:
return BeatmapNominations(
required=beatmapset.nominations_required,
current=beatmapset.nominations_current,
)
# @ondemand
# @staticmethod
# async def user(
# session: AsyncSession,
# beatmapset: Beatmapset,
# includes: list[str] | None = None,
# ) -> dict[str, Any] | None:
# db_user = await session.get(User, beatmapset.user_id)
# if not db_user:
# return None
# return await UserResp.transform(db_user, includes=includes)
@ondemand
@staticmethod
async def ratings(
session: AsyncSession,
beatmapset: "Beatmapset",
) -> list[int]:
# Provide a stable default shape if no session is available
if session is None:
return []
from .beatmapset_ratings import BeatmapRating
beatmapset_all_ratings = (
await session.exec(select(BeatmapRating).where(BeatmapRating.beatmapset_id == beatmapset.id))
).all()
ratings_list = [0] * 11
for rating in beatmapset_all_ratings:
ratings_list[rating.rating] += 1
return ratings_list
class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
class Beatmapset(AsyncAttrs, BeatmapsetModel, table=True):
__tablename__: str = "beatmapsets"
id: int = Field(default=None, primary_key=True, index=True)
# Beatmapset
beatmap_status: BeatmapRankStatus = Field(default=BeatmapRankStatus.GRAVEYARD, index=True)
@@ -130,29 +466,45 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset")
@classmethod
async def from_resp_no_save(cls, resp: "BeatmapsetResp") -> "Beatmapset":
d = resp.model_dump()
if resp.nominations:
d["nominations_required"] = resp.nominations.required
d["nominations_current"] = resp.nominations.current
if resp.hype:
d["hype_current"] = resp.hype.current
d["hype_required"] = resp.hype.required
if resp.genre_id:
d["beatmap_genre"] = Genre(resp.genre_id)
elif resp.genre:
d["beatmap_genre"] = Genre(resp.genre.id)
if resp.language_id:
d["beatmap_language"] = Language(resp.language_id)
elif resp.language:
d["beatmap_language"] = Language(resp.language.id)
async def from_resp_no_save(cls, resp: BeatmapsetDict) -> "Beatmapset":
# make a shallow copy so we can mutate safely
d: dict[str, Any] = dict(resp)
# nominations = resp.get("nominations")
# if nominations is not None:
# d["nominations_required"] = nominations.required
# d["nominations_current"] = nominations.current
hype = resp.get("hype")
if hype is not None:
d["hype_current"] = hype.current
d["hype_required"] = hype.required
genre_id = resp.get("genre_id")
genre = resp.get("genre")
if genre_id is not None:
d["beatmap_genre"] = Genre(genre_id)
elif genre is not None:
d["beatmap_genre"] = Genre(genre.id)
language_id = resp.get("language_id")
language = resp.get("language")
if language_id is not None:
d["beatmap_language"] = Language(language_id)
elif language is not None:
d["beatmap_language"] = Language(language.id)
availability = resp.get("availability")
ranked = resp.get("ranked")
if ranked is None:
raise ValueError("ranked field is required")
beatmapset = Beatmapset.model_validate(
{
**d,
"id": resp.id,
"beatmap_status": BeatmapRankStatus(resp.ranked),
"availability_info": resp.availability.more_information,
"download_disabled": resp.availability.download_disabled or False,
"beatmap_status": BeatmapRankStatus(ranked),
"availability_info": availability.more_information if availability is not None else None,
"download_disabled": bool(availability.download_disabled) if availability is not None else False,
}
)
return beatmapset
@@ -161,17 +513,19 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
async def from_resp(
cls,
session: AsyncSession,
resp: "BeatmapsetResp",
resp: BeatmapsetDict,
from_: int = 0,
) -> "Beatmapset":
from .beatmap import Beatmap
beatmapset_id = resp["id"]
beatmapset = await cls.from_resp_no_save(resp)
if not (await session.exec(select(exists()).where(Beatmapset.id == resp.id))).first():
if not (await session.exec(select(exists()).where(Beatmapset.id == beatmapset_id))).first():
session.add(beatmapset)
await session.commit()
await Beatmap.from_resp_batch(session, resp.beatmaps, from_=from_)
beatmapset = (await session.exec(select(Beatmapset).where(Beatmapset.id == resp.id))).one()
beatmaps = resp.get("beatmaps", [])
await Beatmap.from_resp_batch(session, beatmaps, from_=from_)
beatmapset = (await session.exec(select(Beatmapset).where(Beatmapset.id == beatmapset_id))).one()
return beatmapset
@classmethod
@@ -183,170 +537,5 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
resp = await fetcher.get_beatmapset(sid)
beatmapset = await cls.from_resp(session, resp)
await get_beatmapset_update_service().add(resp)
await session.refresh(beatmapset)
return beatmapset
class BeatmapsetResp(BeatmapsetBase):
id: int
beatmaps: list["BeatmapResp"] = Field(default_factory=list)
discussion_enabled: bool = True
status: str
ranked: int
legacy_thread_url: str | None = ""
is_scoreable: bool
hype: BeatmapHype | None = None
availability: BeatmapAvailability
genre: BeatmapTranslationText | None = None
genre_id: int
language: BeatmapTranslationText | None = None
language_id: int
nominations: BeatmapNominations | None = None
has_favourited: bool = False
favourite_count: int = 0
recent_favourites: list[UserResp] = Field(default_factory=list)
play_count: int = 0
@field_validator(
"nsfw",
"spotlight",
"video",
"can_be_hyped",
"discussion_locked",
"storyboard",
"discussion_enabled",
"is_scoreable",
"has_favourited",
mode="before",
)
@classmethod
def validate_bool_fields(cls, v):
"""将整数 0/1 转换为布尔值,处理数据库中的布尔字段"""
if isinstance(v, int):
return bool(v)
return v
@model_validator(mode="after")
def fix_genre_language(self) -> Self:
if self.genre is None:
self.genre = BeatmapTranslationText(name=Genre(self.genre_id).name, id=self.genre_id)
if self.language is None:
self.language = BeatmapTranslationText(name=Language(self.language_id).name, id=self.language_id)
return self
@classmethod
async def from_db(
cls,
beatmapset: Beatmapset,
include: list[str] = [],
session: AsyncSession | None = None,
user: User | None = None,
) -> "BeatmapsetResp":
from .beatmap import Beatmap, BeatmapResp
from .favourite_beatmapset import FavouriteBeatmapset
update = {
"beatmaps": [
await BeatmapResp.from_db(beatmap, from_set=True, session=session)
for beatmap in await beatmapset.awaitable_attrs.beatmaps
],
"hype": BeatmapHype(current=beatmapset.hype_current, required=beatmapset.hype_required),
"availability": BeatmapAvailability(
more_information=beatmapset.availability_info,
download_disabled=beatmapset.download_disabled,
),
"genre": BeatmapTranslationText(
name=beatmapset.beatmap_genre.name,
id=beatmapset.beatmap_genre.value,
),
"language": BeatmapTranslationText(
name=beatmapset.beatmap_language.name,
id=beatmapset.beatmap_language.value,
),
"genre_id": beatmapset.beatmap_genre.value,
"language_id": beatmapset.beatmap_language.value,
"nominations": BeatmapNominations(
required=beatmapset.nominations_required,
current=beatmapset.nominations_current,
),
"is_scoreable": beatmapset.beatmap_status.has_leaderboard(),
**beatmapset.model_dump(),
}
if session is not None:
# 从数据库读取对应谱面集的评分
from .beatmapset_ratings import BeatmapRating
beatmapset_all_ratings = (
await session.exec(select(BeatmapRating).where(BeatmapRating.beatmapset_id == beatmapset.id))
).all()
ratings_list = [0] * 11
for rating in beatmapset_all_ratings:
ratings_list[rating.rating] += 1
update["ratings"] = ratings_list
else:
# 返回非空值避免客户端崩溃
if update.get("ratings") is None:
update["ratings"] = []
beatmap_status = beatmapset.beatmap_status
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
update["status"] = BeatmapRankStatus.APPROVED.name.lower()
update["ranked"] = BeatmapRankStatus.APPROVED.value
else:
update["status"] = beatmap_status.name.lower()
update["ranked"] = beatmap_status.value
if session and user:
existing_favourite = (
await session.exec(
select(FavouriteBeatmapset).where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
)
).first()
update["has_favourited"] = existing_favourite is not None
if session and "recent_favourites" in include:
recent_favourites = (
await session.exec(
select(FavouriteBeatmapset)
.where(
FavouriteBeatmapset.beatmapset_id == beatmapset.id,
)
.order_by(col(FavouriteBeatmapset.date).desc())
.limit(50)
)
).all()
update["recent_favourites"] = [
await UserResp.from_db(
await favourite.awaitable_attrs.user,
session=session,
include=BASE_INCLUDES,
)
for favourite in recent_favourites
]
if session:
update["favourite_count"] = (
await session.exec(
select(func.count())
.select_from(FavouriteBeatmapset)
.where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
)
).one()
update["play_count"] = (
await session.exec(
select(func.sum(BeatmapPlaycounts.playcount)).where(
col(BeatmapPlaycounts.beatmap).has(col(Beatmap.beatmapset_id) == beatmapset.id)
)
)
).first() or 0
return cls.model_validate(
update,
)
class SearchBeatmapsetsResp(SQLModel):
beatmapsets: list[BeatmapsetResp]
total: int
cursor: dict[str, int | float | str] | None = None
cursor_string: str | None = None

View File

@@ -1,5 +1,5 @@
from app.database.beatmapset import Beatmapset
from app.database.user import User
from .beatmapset import Beatmapset
from .user import User
from sqlmodel import BigInteger, Column, Field, ForeignKey, Relationship, SQLModel

View File

@@ -1,8 +1,8 @@
from typing import TYPE_CHECKING
from app.database.statistics import UserStatistics
from app.models.score import GameMode
from .statistics import UserStatistics
from .user import User
from sqlmodel import (

View File

@@ -1,13 +1,14 @@
from datetime import datetime
from enum import Enum
from typing import Self
from typing import TYPE_CHECKING, ClassVar, NotRequired, TypedDict
from app.database.user import RANKING_INCLUDES, User, UserResp
from app.models.model import UTCBaseModel
from app.utils import utcnow
from ._base import DatabaseModel, included, ondemand
from .user import User, UserDict, UserModel
from pydantic import BaseModel
from redis.asyncio import Redis
from sqlmodel import (
VARCHAR,
BigInteger,
@@ -22,6 +23,8 @@ from sqlmodel import (
)
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.router.notification.server import ChatServer
# ChatChannel
@@ -44,16 +47,168 @@ class ChannelType(str, Enum):
TEAM = "TEAM"
class ChatChannelBase(SQLModel):
name: str = Field(sa_column=Column(VARCHAR(50), index=True))
class MessageType(str, Enum):
ACTION = "action"
MARKDOWN = "markdown"
PLAIN = "plain"
class ChatChannelDict(TypedDict):
channel_id: int
description: str
name: str
icon: str | None
type: ChannelType
uuid: NotRequired[str | None]
message_length_limit: NotRequired[int]
moderated: NotRequired[bool]
current_user_attributes: NotRequired[ChatUserAttributes]
last_read_id: NotRequired[int | None]
last_message_id: NotRequired[int | None]
recent_messages: NotRequired[list["ChatMessageDict"]]
users: NotRequired[list[int]]
class ChatChannelModel(DatabaseModel[ChatChannelDict]):
CONVERSATION_INCLUDES: ClassVar[list[str]] = [
"last_message_id",
"users",
]
LISTING_INCLUDES: ClassVar[list[str]] = [
*CONVERSATION_INCLUDES,
"current_user_attributes",
"last_read_id",
]
channel_id: int = Field(primary_key=True, index=True, default=None)
description: str = Field(sa_column=Column(VARCHAR(255), index=True))
icon: str | None = Field(default=None)
type: ChannelType = Field(index=True)
@included
@staticmethod
async def name(session: AsyncSession, channel: "ChatChannel", user: User, server: "ChatServer") -> str:
users = server.channels.get(channel.channel_id, [])
if channel.type == ChannelType.PM and users and len(users) == 2:
target_user_id = next(u for u in users if u != user.id)
target_name = await session.exec(select(User.username).where(User.id == target_user_id))
return target_name.one()
return channel.name
class ChatChannel(ChatChannelBase, table=True):
@included
@staticmethod
async def moderated(session: AsyncSession, channel: "ChatChannel", user: User) -> bool:
silence = (
await session.exec(
select(SilenceUser).where(
SilenceUser.channel_id == channel.channel_id,
SilenceUser.user_id == user.id,
)
)
).first()
return silence is not None
@ondemand
@staticmethod
async def current_user_attributes(
session: AsyncSession,
channel: "ChatChannel",
user: User,
) -> ChatUserAttributes:
from app.dependencies.database import get_redis
silence = (
await session.exec(
select(SilenceUser).where(
SilenceUser.channel_id == channel.channel_id,
SilenceUser.user_id == user.id,
)
)
).first()
can_message = silence is None
can_message_error = "You are silenced in this channel" if not can_message else None
redis = get_redis()
last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
last_read_id = int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else (last_msg or 0)
return ChatUserAttributes(
can_message=can_message,
can_message_error=can_message_error,
last_read_id=last_read_id,
)
@ondemand
@staticmethod
async def last_read_id(_session: AsyncSession, channel: "ChatChannel", user: User) -> int | None:
from app.dependencies.database import get_redis
redis = get_redis()
last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
return int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else last_msg
@ondemand
@staticmethod
async def last_message_id(_session: AsyncSession, channel: "ChatChannel") -> int | None:
from app.dependencies.database import get_redis
redis = get_redis()
last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
return int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
@ondemand
@staticmethod
async def recent_messages(
session: AsyncSession,
channel: "ChatChannel",
) -> list["ChatMessageDict"]:
messages = (
await session.exec(
select(ChatMessage)
.where(ChatMessage.channel_id == channel.channel_id)
.order_by(col(ChatMessage.message_id).desc())
.limit(50)
)
).all()
result = [
await ChatMessageModel.transform(
msg,
)
for msg in reversed(messages)
]
return result
@ondemand
@staticmethod
async def users(
_session: AsyncSession,
channel: "ChatChannel",
server: "ChatServer",
user: User,
) -> list[int]:
if channel.type == ChannelType.PUBLIC:
return []
users = server.channels.get(channel.channel_id, []).copy()
if channel.type == ChannelType.PM and users and len(users) == 2:
target_user_id = next(u for u in users if u != user.id)
users = [target_user_id, user.id]
return users
@included
@staticmethod
async def message_length_limit(_session: AsyncSession, _channel: "ChatChannel") -> int:
return 1000
class ChatChannel(ChatChannelModel, table=True):
__tablename__: str = "chat_channels"
channel_id: int = Field(primary_key=True, index=True, default=None)
name: str = Field(sa_column=Column(VARCHAR(50), index=True))
@classmethod
async def get(cls, channel: str | int, session: AsyncSession) -> "ChatChannel | None":
@@ -74,93 +229,20 @@ class ChatChannel(ChatChannelBase, table=True):
return channel
class ChatChannelResp(ChatChannelBase):
channel_id: int
moderated: bool = False
uuid: str | None = None
current_user_attributes: ChatUserAttributes | None = None
last_read_id: int | None = None
last_message_id: int | None = None
recent_messages: list["ChatMessageResp"] = Field(default_factory=list)
users: list[int] = Field(default_factory=list)
message_length_limit: int = 1000
@classmethod
async def from_db(
cls,
channel: ChatChannel,
session: AsyncSession,
user: User,
redis: Redis,
users: list[int] | None = None,
include_recent_messages: bool = False,
) -> Self:
c = cls.model_validate(channel)
silence = (
await session.exec(
select(SilenceUser).where(
SilenceUser.channel_id == channel.channel_id,
SilenceUser.user_id == user.id,
)
)
).first()
last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
last_read_id = int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else last_msg
if silence is not None:
attribute = ChatUserAttributes(
can_message=False,
can_message_error=silence.reason or "You are muted in this channel.",
last_read_id=last_read_id or 0,
)
c.moderated = True
else:
attribute = ChatUserAttributes(
can_message=True,
last_read_id=last_read_id or 0,
)
c.moderated = False
c.current_user_attributes = attribute
if c.type != ChannelType.PUBLIC and users is not None:
c.users = users
c.last_message_id = last_msg
c.last_read_id = last_read_id
if include_recent_messages:
messages = (
await session.exec(
select(ChatMessage)
.where(ChatMessage.channel_id == channel.channel_id)
.order_by(col(ChatMessage.timestamp).desc())
.limit(10)
)
).all()
c.recent_messages = [await ChatMessageResp.from_db(msg, session, user) for msg in messages]
c.recent_messages.reverse()
if c.type == ChannelType.PM and users and len(users) == 2:
target_user_id = next(u for u in users if u != user.id)
target_name = await session.exec(select(User.username).where(User.id == target_user_id))
c.name = target_name.one()
c.users = [target_user_id, user.id]
return c
# ChatMessage
class ChatMessageDict(TypedDict):
channel_id: int
content: str
message_id: int
sender_id: int
timestamp: datetime
type: MessageType
uuid: str | None
is_action: NotRequired[bool]
sender: NotRequired[UserDict]
class MessageType(str, Enum):
ACTION = "action"
MARKDOWN = "markdown"
PLAIN = "plain"
class ChatMessageBase(UTCBaseModel, SQLModel):
class ChatMessageModel(DatabaseModel[ChatMessageDict]):
channel_id: int = Field(index=True, foreign_key="chat_channels.channel_id")
content: str = Field(sa_column=Column(VARCHAR(1000)))
message_id: int = Field(index=True, primary_key=True, default=None)
@@ -169,31 +251,21 @@ class ChatMessageBase(UTCBaseModel, SQLModel):
type: MessageType = Field(default=MessageType.PLAIN, index=True, exclude=True)
uuid: str | None = Field(default=None)
@included
@staticmethod
async def is_action(_session: AsyncSession, db_message: "ChatMessage") -> bool:
return db_message.type == MessageType.ACTION
class ChatMessage(ChatMessageBase, table=True):
@ondemand
@staticmethod
async def sender(_session: AsyncSession, db_message: "ChatMessage") -> UserDict:
return await UserModel.transform(db_message.user)
class ChatMessage(ChatMessageModel, table=True):
__tablename__: str = "chat_messages"
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
channel: ChatChannel = Relationship()
class ChatMessageResp(ChatMessageBase):
sender: UserResp | None = None
is_action: bool = False
@classmethod
async def from_db(
cls, db_message: ChatMessage, session: AsyncSession, user: User | None = None
) -> "ChatMessageResp":
m = cls.model_validate(db_message.model_dump())
m.is_action = db_message.type == MessageType.ACTION
if user:
m.sender = await UserResp.from_db(user, session, RANKING_INCLUDES)
else:
m.sender = await UserResp.from_db(db_message.user, session, RANKING_INCLUDES)
return m
# SilenceUser
channel: "ChatChannel" = Relationship()
class SilenceUser(UTCBaseModel, SQLModel, table=True):

View File

@@ -1,7 +1,7 @@
import datetime
from app.database.beatmapset import Beatmapset
from app.database.user import User
from .beatmapset import Beatmapset
from .user import User
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import (

View File

@@ -1,7 +1,9 @@
from .playlist_best_score import PlaylistBestScore
from .user import User, UserResp
from typing import Any, NotRequired, TypedDict
from ._base import DatabaseModel, ondemand
from .playlist_best_score import PlaylistBestScore
from .user import User, UserDict, UserModel
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import (
BigInteger,
@@ -9,7 +11,6 @@ from sqlmodel import (
Field,
ForeignKey,
Relationship,
SQLModel,
col,
func,
select,
@@ -17,17 +18,66 @@ from sqlmodel import (
from sqlmodel.ext.asyncio.session import AsyncSession
class ItemAttemptsCountBase(SQLModel):
room_id: int = Field(foreign_key="rooms.id", index=True)
class ItemAttemptsCountDict(TypedDict):
accuracy: float
attempts: int
completed: int
pp: float
room_id: int
total_score: int
user_id: int
user: NotRequired[UserDict]
position: NotRequired[int]
playlist_item_attempts: NotRequired[list[dict[str, Any]]]
class ItemAttemptsCountModel(DatabaseModel[ItemAttemptsCountDict]):
accuracy: float = 0.0
attempts: int = Field(default=0)
completed: int = Field(default=0)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
accuracy: float = 0.0
pp: float = 0
room_id: int = Field(foreign_key="rooms.id", index=True)
total_score: int = 0
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
@ondemand
@staticmethod
async def user(_session: AsyncSession, item_attempts: "ItemAttemptsCount") -> UserDict:
user_instance = await item_attempts.awaitable_attrs.user
return await UserModel.transform(user_instance)
@ondemand
@staticmethod
async def position(session: AsyncSession, item_attempts: "ItemAttemptsCount") -> int:
return await item_attempts.get_position(session)
@ondemand
@staticmethod
async def playlist_item_attempts(
session: AsyncSession,
item_attempts: "ItemAttemptsCount",
) -> list[dict[str, Any]]:
playlist_scores = (
await session.exec(
select(PlaylistBestScore).where(
PlaylistBestScore.room_id == item_attempts.room_id,
PlaylistBestScore.user_id == item_attempts.user_id,
)
)
).all()
result: list[dict[str, Any]] = []
for score in playlist_scores:
result.append(
{
"id": score.playlist_id,
"attempts": score.attempts,
"passed": score.score.passed,
}
)
return result
class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True):
class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountModel, table=True):
__tablename__: str = "item_attempts_count"
id: int | None = Field(default=None, primary_key=True)
@@ -37,15 +87,15 @@ class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True):
rownum = (
func.row_number()
.over(
partition_by=col(ItemAttemptsCountBase.room_id),
order_by=col(ItemAttemptsCountBase.total_score).desc(),
partition_by=col(ItemAttemptsCount.room_id),
order_by=col(ItemAttemptsCount.total_score).desc(),
)
.label("rn")
)
subq = select(ItemAttemptsCountBase, rownum).subquery()
subq = select(ItemAttemptsCount, rownum).subquery()
stmt = select(subq.c.rn).where(subq.c.user_id == self.user_id)
result = await session.exec(stmt)
return result.one()
return result.first() or 0
async def update(self, session: AsyncSession):
playlist_scores = (
@@ -88,62 +138,3 @@ class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True):
await session.refresh(item_attempts)
await item_attempts.update(session)
return item_attempts
class ItemAttemptsResp(ItemAttemptsCountBase):
user: UserResp | None = None
position: int | None = None
@classmethod
async def from_db(
cls,
item_attempts: ItemAttemptsCount,
session: AsyncSession,
include: list[str] = [],
) -> "ItemAttemptsResp":
resp = cls.model_validate(item_attempts.model_dump())
resp.user = await UserResp.from_db(
await item_attempts.awaitable_attrs.user,
session=session,
include=["statistics", "team", "daily_challenge_user_stats"],
)
if "position" in include:
resp.position = await item_attempts.get_position(session)
# resp.accuracy *= 100
return resp
class ItemAttemptsCountForItem(BaseModel):
id: int
attempts: int
passed: bool
class PlaylistAggregateScore(BaseModel):
playlist_item_attempts: list[ItemAttemptsCountForItem] = Field(default_factory=list)
@classmethod
async def from_db(
cls,
room_id: int,
user_id: int,
session: AsyncSession,
) -> "PlaylistAggregateScore":
playlist_scores = (
await session.exec(
select(PlaylistBestScore).where(
PlaylistBestScore.room_id == room_id,
PlaylistBestScore.user_id == user_id,
)
)
).all()
playlist_item_attempts = []
for score in playlist_scores:
playlist_item_attempts.append(
ItemAttemptsCountForItem(
id=score.playlist_id,
attempts=score.attempts,
passed=score.score.passed,
)
)
return cls(playlist_item_attempts=playlist_item_attempts)

View File

@@ -1,11 +1,11 @@
from datetime import datetime
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, NotRequired, TypedDict
from app.models.model import UTCBaseModel
from app.models.mods import APIMod
from app.models.playlist import PlaylistItem
from .beatmap import Beatmap, BeatmapResp
from ._base import DatabaseModel, ondemand
from .beatmap import Beatmap, BeatmapDict, BeatmapModel
from sqlmodel import (
JSON,
@@ -15,7 +15,6 @@ from sqlmodel import (
Field,
ForeignKey,
Relationship,
SQLModel,
func,
select,
)
@@ -23,18 +22,34 @@ from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from .room import Room
from .score import ScoreDict
class PlaylistBase(SQLModel, UTCBaseModel):
id: int = Field(index=True)
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
class PlaylistDict(TypedDict):
id: int
room_id: int
beatmap_id: int
created_at: datetime | None
ruleset_id: int
expired: bool = Field(default=False)
playlist_order: int = Field(default=0)
played_at: datetime | None = Field(
sa_column=Column(DateTime(timezone=True)),
default=None,
allowed_mods: list[APIMod]
required_mods: list[APIMod]
freestyle: bool
expired: bool
owner_id: int
playlist_order: int
played_at: datetime | None
beatmap: NotRequired["BeatmapDict"]
scores: NotRequired[list[dict[str, Any]]]
class PlaylistModel(DatabaseModel[PlaylistDict]):
id: int = Field(index=True)
room_id: int = Field(foreign_key="rooms.id")
beatmap_id: int = Field(
foreign_key="beatmaps.id",
)
created_at: datetime | None = Field(default=None, sa_column_kwargs={"server_default": func.now()})
ruleset_id: int
allowed_mods: list[APIMod] = Field(
default_factory=list,
sa_column=Column(JSON),
@@ -43,16 +58,46 @@ class PlaylistBase(SQLModel, UTCBaseModel):
default_factory=list,
sa_column=Column(JSON),
)
beatmap_id: int = Field(
foreign_key="beatmaps.id",
)
freestyle: bool = Field(default=False)
expired: bool = Field(default=False)
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
playlist_order: int = Field(default=0)
played_at: datetime | None = Field(
sa_column=Column(DateTime(timezone=True)),
default=None,
)
@ondemand
@staticmethod
async def beatmap(_session: AsyncSession, playlist: "Playlist", includes: list[str] | None = None) -> BeatmapDict:
return await BeatmapModel.transform(playlist.beatmap, includes=includes)
@ondemand
@staticmethod
async def scores(session: AsyncSession, playlist: "Playlist") -> list["ScoreDict"]:
from .score import Score, ScoreModel
scores = (
await session.exec(
select(Score).where(
Score.playlist_item_id == playlist.id,
Score.room_id == playlist.room_id,
)
)
).all()
result: list[ScoreDict] = []
for score in scores:
result.append(
await ScoreModel.transform(
score,
)
)
return result
class Playlist(PlaylistBase, table=True):
class Playlist(PlaylistModel, table=True):
__tablename__: str = "room_playlists"
db_id: int = Field(default=None, primary_key=True, index=True, exclude=True)
room_id: int = Field(foreign_key="rooms.id", exclude=True)
beatmap: Beatmap = Relationship(
sa_relationship_kwargs={
@@ -60,7 +105,6 @@ class Playlist(PlaylistBase, table=True):
}
)
room: "Room" = Relationship()
created_at: datetime | None = Field(default=None, sa_column_kwargs={"server_default": func.now()})
updated_at: datetime | None = Field(
default=None, sa_column_kwargs={"server_default": func.now(), "onupdate": func.now()}
)
@@ -121,15 +165,3 @@ class Playlist(PlaylistBase, table=True):
raise ValueError("Playlist item not found")
await session.delete(db_playlist)
await session.commit()
class PlaylistResp(PlaylistBase):
beatmap: BeatmapResp | None = None
@classmethod
async def from_db(cls, playlist: Playlist, include: list[str] = []) -> "PlaylistResp":
data = playlist.model_dump()
if "beatmap" in include:
data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap)
resp = cls.model_validate(data)
return resp

View File

@@ -1,26 +1,39 @@
from enum import Enum
from typing import TYPE_CHECKING, NotRequired, TypedDict
from .user import User, UserResp
from app.models.score import GameMode
from ._base import DatabaseModel, included, ondemand
from pydantic import BaseModel
from sqlmodel import (
BigInteger,
Column,
Field,
ForeignKey,
Relationship as SQLRelationship,
SQLModel,
select,
)
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from .user import User, UserDict
class RelationshipType(str, Enum):
FOLLOW = "Friend"
BLOCK = "Block"
FOLLOW = "friend"
BLOCK = "block"
class Relationship(SQLModel, table=True):
class RelationshipDict(TypedDict):
target_id: int | None
type: RelationshipType
id: NotRequired[int | None]
user_id: NotRequired[int | None]
mutual: NotRequired[bool]
target: NotRequired["UserDict"]
class RelationshipModel(DatabaseModel[RelationshipDict]):
__tablename__: str = "relationship"
id: int | None = Field(
default=None,
@@ -34,6 +47,7 @@ class Relationship(SQLModel, table=True):
ForeignKey("lazer_users.id"),
index=True,
),
exclude=True,
)
target_id: int = Field(
default=None,
@@ -44,22 +58,10 @@ class Relationship(SQLModel, table=True):
),
)
type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False)
target: User = SQLRelationship(
sa_relationship_kwargs={
"foreign_keys": "[Relationship.target_id]",
"lazy": "selectin",
}
)
class RelationshipResp(BaseModel):
target_id: int
target: UserResp
mutual: bool = False
type: RelationshipType
@classmethod
async def from_db(cls, session: AsyncSession, relationship: Relationship) -> "RelationshipResp":
@included
@staticmethod
async def mutual(session: AsyncSession, relationship: "Relationship") -> bool:
target_relationship = (
await session.exec(
select(Relationship).where(
@@ -68,23 +70,29 @@ class RelationshipResp(BaseModel):
)
)
).first()
mutual = bool(
return bool(
target_relationship is not None
and relationship.type == RelationshipType.FOLLOW
and target_relationship.type == RelationshipType.FOLLOW
)
return cls(
target_id=relationship.target_id,
target=await UserResp.from_db(
relationship.target,
session,
include=[
"team",
"daily_challenge_user_stats",
"statistics",
"statistics_rulesets",
],
),
mutual=mutual,
type=relationship.type,
@ondemand
@staticmethod
async def target(
_session: AsyncSession,
relationship: "Relationship",
ruleset: GameMode | None = None,
includes: list[str] | None = None,
) -> "UserDict":
from .user import UserModel
return await UserModel.transform(relationship.target, ruleset=ruleset, includes=includes)
class Relationship(RelationshipModel, table=True):
target: "User" = SQLRelationship(
sa_relationship_kwargs={
"foreign_keys": "[Relationship.target_id]",
"lazy": "selectin",
}
)

View File

@@ -1,8 +1,6 @@
from datetime import datetime
from typing import ClassVar, NotRequired, TypedDict
from app.database.item_attempts_count import PlaylistAggregateScore
from app.database.room_participated_user import RoomParticipatedUser
from app.models.model import UTCBaseModel
from app.models.room import (
MatchType,
QueueMode,
@@ -13,28 +11,58 @@ from app.models.room import (
)
from app.utils import utcnow
from .playlists import Playlist, PlaylistResp
from .user import User, UserResp
from ._base import DatabaseModel, included, ondemand
from .item_attempts_count import ItemAttemptsCount, ItemAttemptsCountDict, ItemAttemptsCountModel
from .playlists import Playlist, PlaylistDict, PlaylistModel
from .room_participated_user import RoomParticipatedUser
from .user import User, UserDict, UserModel
from pydantic import field_validator
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import (
BigInteger,
Column,
DateTime,
Field,
ForeignKey,
Relationship,
SQLModel,
col,
func,
select,
)
from sqlmodel import BigInteger, Column, DateTime, Field, ForeignKey, Relationship, SQLModel, col, select
from sqlmodel.ext.asyncio.session import AsyncSession
class RoomBase(SQLModel, UTCBaseModel):
class RoomDict(TypedDict):
id: int
name: str
category: RoomCategory
status: RoomStatus
type: MatchType
duration: int | None
starts_at: datetime | None
ends_at: datetime | None
max_attempts: int | None
participant_count: int
channel_id: int
queue_mode: QueueMode
auto_skip: bool
auto_start_duration: int
has_password: NotRequired[bool]
current_playlist_item: NotRequired["PlaylistDict | None"]
playlist: NotRequired[list["PlaylistDict"]]
playlist_item_stats: NotRequired[RoomPlaylistItemStats]
difficulty_range: NotRequired[RoomDifficultyRange]
host: NotRequired[UserDict]
recent_participants: NotRequired[list[UserDict]]
current_user_score: NotRequired["ItemAttemptsCountDict | None"]
class RoomModel(DatabaseModel[RoomDict]):
SHOW_RESPONSE_INCLUDES: ClassVar[list[str]] = [
"current_user_score.playlist_item_attempts",
"host.country",
"playlist.beatmap.beatmapset",
"playlist.beatmap.checksum",
"playlist.beatmap.max_combo",
"recent_participants",
]
id: int = Field(default=None, primary_key=True, index=True)
name: str = Field(index=True)
category: RoomCategory = Field(default=RoomCategory.NORMAL, index=True)
status: RoomStatus
type: MatchType
duration: int | None = Field(default=None) # minutes
starts_at: datetime | None = Field(
sa_column=Column(
@@ -48,76 +76,88 @@ class RoomBase(SQLModel, UTCBaseModel):
),
default=None,
)
participant_count: int = Field(default=0)
max_attempts: int | None = Field(default=None) # playlists
type: MatchType
participant_count: int = Field(default=0)
channel_id: int = 0
queue_mode: QueueMode
auto_skip: bool
auto_start_duration: int
status: RoomStatus
channel_id: int | None = None
class Room(AsyncAttrs, RoomBase, table=True):
__tablename__: str = "rooms"
id: int = Field(default=None, primary_key=True, index=True)
host_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
password: str | None = Field(default=None)
host: User = Relationship()
playlist: list[Playlist] = Relationship(
sa_relationship_kwargs={
"lazy": "selectin",
"cascade": "all, delete-orphan",
"overlaps": "room",
}
)
class RoomResp(RoomBase):
id: int
has_password: bool = False
host: UserResp | None = None
playlist: list[PlaylistResp] = []
playlist_item_stats: RoomPlaylistItemStats | None = None
difficulty_range: RoomDifficultyRange | None = None
current_playlist_item: PlaylistResp | None = None
current_user_score: PlaylistAggregateScore | None = None
recent_participants: list[UserResp] = Field(default_factory=list)
channel_id: int = 0
@field_validator("channel_id", mode="before")
@classmethod
async def from_db(
cls,
room: Room,
session: AsyncSession,
include: list[str] = [],
user: User | None = None,
) -> "RoomResp":
d = room.model_dump()
d["channel_id"] = d.get("channel_id", 0) or 0
d["has_password"] = bool(room.password)
resp = cls.model_validate(d)
def validate_channel_id(cls, v):
"""将 None 转换为 0"""
if v is None:
return 0
return v
stats = RoomPlaylistItemStats(count_active=0, count_total=0)
difficulty_range = RoomDifficultyRange(
min=0,
max=0,
)
rulesets = set()
for playlist in room.playlist:
@included
@staticmethod
async def has_password(_session: AsyncSession, room: "Room") -> bool:
return bool(room.password)
@ondemand
@staticmethod
async def current_playlist_item(
_session: AsyncSession, room: "Room", includes: list[str] | None = None
) -> "PlaylistDict | None":
playlists = await room.awaitable_attrs.playlist
if not playlists:
return None
return await PlaylistModel.transform(playlists[-1], includes=includes)
@ondemand
@staticmethod
async def playlist(_session: AsyncSession, room: "Room", includes: list[str] | None = None) -> list["PlaylistDict"]:
playlists = await room.awaitable_attrs.playlist
result: list[PlaylistDict] = []
for playlist_item in playlists:
result.append(await PlaylistModel.transform(playlist_item, includes=includes))
return result
@ondemand
@staticmethod
async def playlist_item_stats(_session: AsyncSession, room: "Room") -> RoomPlaylistItemStats:
playlists = await room.awaitable_attrs.playlist
stats = RoomPlaylistItemStats(count_active=0, count_total=0, ruleset_ids=[])
rulesets: set[int] = set()
for playlist in playlists:
stats.count_total += 1
if not playlist.expired:
stats.count_active += 1
rulesets.add(playlist.ruleset_id)
difficulty_range.min = min(difficulty_range.min, playlist.beatmap.difficulty_rating)
difficulty_range.max = max(difficulty_range.max, playlist.beatmap.difficulty_rating)
resp.playlist.append(await PlaylistResp.from_db(playlist, ["beatmap"]))
stats.ruleset_ids = list(rulesets)
resp.playlist_item_stats = stats
resp.difficulty_range = difficulty_range
resp.current_playlist_item = resp.playlist[-1] if resp.playlist else None
resp.recent_participants = []
return stats
@ondemand
@staticmethod
async def difficulty_range(_session: AsyncSession, room: "Room") -> RoomDifficultyRange:
playlists = await room.awaitable_attrs.playlist
if not playlists:
return RoomDifficultyRange(min=0.0, max=0.0)
min_diff = float("inf")
max_diff = float("-inf")
for playlist in playlists:
rating = playlist.beatmap.difficulty_rating
min_diff = min(min_diff, rating)
max_diff = max(max_diff, rating)
if min_diff == float("inf"):
min_diff = 0.0
if max_diff == float("-inf"):
max_diff = 0.0
return RoomDifficultyRange(min=min_diff, max=max_diff)
@ondemand
@staticmethod
async def host(_session: AsyncSession, room: "Room", includes: list[str] | None = None) -> UserDict:
host_user = await room.awaitable_attrs.host
return await UserModel.transform(host_user, includes=includes)
@ondemand
@staticmethod
async def recent_participants(session: AsyncSession, room: "Room") -> list[UserDict]:
participants: list[UserDict] = []
if room.category == RoomCategory.REALTIME:
query = (
select(RoomParticipatedUser)
@@ -137,39 +177,67 @@ class RoomResp(RoomBase):
.limit(8)
.order_by(col(RoomParticipatedUser.joined_at).desc())
)
resp.participant_count = (
await session.exec(
select(func.count())
.select_from(RoomParticipatedUser)
.where(
RoomParticipatedUser.room_id == room.id,
)
)
).first() or 0
for recent_participant in await session.exec(query):
resp.recent_participants.append(
await UserResp.from_db(
await recent_participant.awaitable_attrs.user,
session,
include=["statistics"],
user_instance = await recent_participant.awaitable_attrs.user
participants.append(await UserModel.transform(user_instance))
return participants
@ondemand
@staticmethod
async def current_user_score(
session: AsyncSession, room: "Room", includes: list[str] | None = None
) -> "ItemAttemptsCountDict | None":
item_attempt = (
await session.exec(
select(ItemAttemptsCount).where(
ItemAttemptsCount.room_id == room.id,
)
)
resp.host = await UserResp.from_db(await room.awaitable_attrs.host, session, include=["statistics"])
if "current_user_score" in include and user:
resp.current_user_score = await PlaylistAggregateScore.from_db(room.id, user.id, session)
return resp
).first()
if item_attempt is None:
return None
return await ItemAttemptsCountModel.transform(item_attempt, includes=includes)
class APIUploadedRoom(RoomBase):
def to_room(self) -> Room:
"""
将 APIUploadedRoom 转换为 Room 对象playlist 字段需单独处理。
"""
room_dict = self.model_dump()
room_dict.pop("playlist", None)
# host_id 已在字段中
return Room(**room_dict)
class Room(AsyncAttrs, RoomModel, table=True):
__tablename__: str = "rooms"
id: int | None
host_id: int | None = None
host_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
password: str | None = Field(default=None)
host: User = Relationship()
playlist: list[Playlist] = Relationship(
sa_relationship_kwargs={
"lazy": "selectin",
"cascade": "all, delete-orphan",
"overlaps": "room",
}
)
class APIUploadedRoom(SQLModel):
name: str = Field(index=True)
category: RoomCategory = Field(default=RoomCategory.NORMAL, index=True)
status: RoomStatus
type: MatchType
duration: int | None = Field(default=None) # minutes
starts_at: datetime | None = Field(
sa_column=Column(
DateTime(timezone=True),
),
default_factory=utcnow,
)
ends_at: datetime | None = Field(
sa_column=Column(
DateTime(timezone=True),
),
default=None,
)
max_attempts: int | None = Field(default=None) # playlists
participant_count: int = Field(default=0)
channel_id: int = 0
queue_mode: QueueMode
auto_skip: bool
auto_start_duration: int
playlist: list[Playlist] = Field(default_factory=list)

View File

@@ -3,7 +3,7 @@ from datetime import date, datetime
import json
import math
import sys
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, ClassVar, NotRequired, TypedDict
from app.calculator import (
calculate_pp_weight,
@@ -15,8 +15,6 @@ from app.calculator import (
pre_fetch_and_calculate_pp,
)
from app.config import settings
from app.database.beatmap_playcounts import BeatmapPlaycounts
from app.database.team import TeamMember
from app.dependencies.database import get_redis
from app.log import log
from app.models.beatmap import BeatmapRankStatus
@@ -39,8 +37,10 @@ from app.models.scoring_mode import ScoringMode
from app.storage import StorageService
from app.utils import utcnow
from .beatmap import Beatmap, BeatmapResp
from .beatmapset import BeatmapsetResp
from ._base import DatabaseModel, OnDemand, included, ondemand
from .beatmap import Beatmap, BeatmapDict, BeatmapModel
from .beatmap_playcounts import BeatmapPlaycounts
from .beatmapset import BeatmapsetDict, BeatmapsetModel
from .best_scores import BestScore
from .counts import MonthlyPlaycounts
from .events import Event, EventType
@@ -50,8 +50,9 @@ from .relationship import (
RelationshipType,
)
from .score_token import ScoreToken
from .team import TeamMember
from .total_score_best_scores import TotalScoreBestScore
from .user import User, UserResp
from .user import User, UserDict, UserModel
from pydantic import BaseModel, field_serializer, field_validator
from redis.asyncio import Redis
@@ -80,30 +81,290 @@ if TYPE_CHECKING:
logger = log("Score")
class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
# 基本字段
class ScoreDict(TypedDict):
beatmap_id: int
id: int
rank: Rank
type: str
user_id: int
accuracy: float
build_id: int | None
ended_at: datetime
has_replay: bool
max_combo: int
passed: bool
pp: float
started_at: datetime
total_score: int
maximum_statistics: ScoreStatistics
mods: list[APIMod]
classic_total_score: int | None
preserve: bool
processed: bool
ranked: bool
playlist_item_id: NotRequired[int | None]
room_id: NotRequired[int | None]
best_id: NotRequired[int | None]
legacy_perfect: NotRequired[bool]
is_perfect_combo: NotRequired[bool]
ruleset_id: NotRequired[int]
statistics: NotRequired[ScoreStatistics]
beatmapset: NotRequired[BeatmapsetDict]
beatmap: NotRequired[BeatmapDict]
current_user_attributes: NotRequired[CurrentUserAttributes]
position: NotRequired[int | None]
scores_around: NotRequired["ScoreAround | None"]
rank_country: NotRequired[int | None]
rank_global: NotRequired[int | None]
user: NotRequired[UserDict]
weight: NotRequired[float | None]
# ScoreResp 字段
legacy_total_score: NotRequired[int]
class ScoreModel(AsyncAttrs, DatabaseModel[ScoreDict]):
# https://github.com/ppy/osu-web/blob/master/app/Transformers/ScoreTransformer.php#L72-L84
MULTIPLAYER_SCORE_INCLUDE: ClassVar[list[str]] = ["playlist_item_id", "room_id", "solo_score_id"]
MULTIPLAYER_BASE_INCLUDES: ClassVar[list[str]] = [
"user.country",
"user.cover",
"user.team",
*MULTIPLAYER_SCORE_INCLUDE,
]
USER_PROFILE_INCLUDES: ClassVar[list[str]] = ["beatmap", "beatmapset", "user"]
# 基本字段
beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
id: int = Field(default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True))
rank: Rank
type: str
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("lazer_users.id"),
index=True,
),
)
accuracy: float
map_md5: str = Field(max_length=32, index=True)
build_id: int | None = Field(default=None)
classic_total_score: int | None = Field(default=0, sa_column=Column(BigInteger)) # solo_score
ended_at: datetime = Field(sa_column=Column(DateTime))
has_replay: bool = Field(sa_column=Column(Boolean))
max_combo: int
mods: list[APIMod] = Field(sa_column=Column(JSON))
passed: bool = Field(sa_column=Column(Boolean))
playlist_item_id: int | None = Field(default=None) # multiplayer
pp: float = Field(default=0.0)
preserve: bool = Field(default=True, sa_column=Column(Boolean))
rank: Rank
room_id: int | None = Field(default=None) # multiplayer
started_at: datetime = Field(sa_column=Column(DateTime))
total_score: int = Field(default=0, sa_column=Column(BigInteger))
total_score_without_mods: int = Field(default=0, sa_column=Column(BigInteger), exclude=True)
type: str
beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
maximum_statistics: ScoreStatistics = Field(sa_column=Column(JSON), default_factory=dict)
processed: bool = False # solo_score
ranked: bool = False
mods: list[APIMod] = Field(sa_column=Column(JSON))
total_score_without_mods: int = Field(default=0, sa_column=Column(BigInteger), exclude=True)
# solo
classic_total_score: int | None = Field(default=0, sa_column=Column(BigInteger))
preserve: bool = Field(default=True, sa_column=Column(Boolean))
processed: bool = Field(default=False)
ranked: bool = Field(default=False)
# multiplayer
playlist_item_id: OnDemand[int | None] = Field(default=None)
room_id: OnDemand[int | None] = Field(default=None)
@included
@staticmethod
async def best_id(
session: AsyncSession,
score: "Score",
) -> int | None:
return await get_best_id(session, score.id)
@included
@staticmethod
async def legacy_perfect(
_session: AsyncSession,
score: "Score",
) -> bool:
await score.awaitable_attrs.beatmap
return score.max_combo == score.beatmap.max_combo
@included
@staticmethod
async def is_perfect_combo(
_session: AsyncSession,
score: "Score",
) -> bool:
await score.awaitable_attrs.beatmap
return score.max_combo == score.beatmap.max_combo
@included
@staticmethod
async def ruleset_id(
_session: AsyncSession,
score: "Score",
) -> int:
return int(score.gamemode)
@included
@staticmethod
async def statistics(
_session: AsyncSession,
score: "Score",
) -> ScoreStatistics:
stats = {
HitResult.MISS: score.nmiss,
HitResult.MEH: score.n50,
HitResult.OK: score.n100,
HitResult.GREAT: score.n300,
HitResult.PERFECT: score.ngeki,
HitResult.GOOD: score.nkatu,
}
if score.nlarge_tick_miss is not None:
stats[HitResult.LARGE_TICK_MISS] = score.nlarge_tick_miss
if score.nslider_tail_hit is not None:
stats[HitResult.SLIDER_TAIL_HIT] = score.nslider_tail_hit
if score.nsmall_tick_hit is not None:
stats[HitResult.SMALL_TICK_HIT] = score.nsmall_tick_hit
if score.nlarge_tick_hit is not None:
stats[HitResult.LARGE_TICK_HIT] = score.nlarge_tick_hit
return stats
@ondemand
@staticmethod
async def beatmapset(
_session: AsyncSession,
score: "Score",
includes: list[str] | None = None,
) -> BeatmapsetDict:
await score.awaitable_attrs.beatmap
return await BeatmapsetModel.transform(score.beatmap.beatmapset, includes=includes)
# reorder beatmapset and beatmap
# https://github.com/ppy/osu/blob/d8900defd34690de92be3406003fb3839fc0df1d/osu.Game/Online/API/Requests/Responses/SoloScoreInfo.cs#L111-L112
@ondemand
@staticmethod
async def beatmap(
_session: AsyncSession,
score: "Score",
includes: list[str] | None = None,
) -> BeatmapDict:
await score.awaitable_attrs.beatmap
return await BeatmapModel.transform(score.beatmap, includes=includes)
@ondemand
@staticmethod
async def current_user_attributes(
_session: AsyncSession,
score: "Score",
) -> CurrentUserAttributes:
return CurrentUserAttributes(pin=PinAttributes(is_pinned=bool(score.pinned_order), score_id=score.id))
@ondemand
@staticmethod
async def position(
session: AsyncSession,
score: "Score",
) -> int | None:
return await get_score_position_by_id(
session,
score.beatmap_id,
score.id,
mode=score.gamemode,
user=score.user,
)
@ondemand
@staticmethod
async def scores_around(
session: AsyncSession, _score: "Score", playlist_id: int, room_id: int, is_playlist: bool
) -> "ScoreAround | None":
scores = (
await session.exec(
select(PlaylistBestScore).where(
PlaylistBestScore.playlist_id == playlist_id,
PlaylistBestScore.room_id == room_id,
~User.is_restricted_query(col(PlaylistBestScore.user_id)),
col(PlaylistBestScore.score).has(col(Score.passed).is_(True)) if not is_playlist else True,
)
)
).all()
higher_scores = []
lower_scores = []
for score in scores:
total_score = score.score.total_score
resp = await ScoreModel.transform(score.score, includes=ScoreModel.MULTIPLAYER_BASE_INCLUDES)
if score.total_score > total_score:
higher_scores.append(resp)
elif score.total_score < total_score:
lower_scores.append(resp)
return ScoreAround(
higher=MultiplayerScores(scores=higher_scores),
lower=MultiplayerScores(scores=lower_scores),
)
@ondemand
@staticmethod
async def rank_country(
session: AsyncSession,
score: "Score",
) -> int | None:
return (
await get_score_position_by_id(
session,
score.beatmap_id,
score.id,
score.gamemode,
score.user,
type=LeaderboardType.COUNTRY,
)
or None
)
@ondemand
@staticmethod
async def rank_global(
session: AsyncSession,
score: "Score",
) -> int | None:
return (
await get_score_position_by_id(
session,
score.beatmap_id,
score.id,
mode=score.gamemode,
user=score.user,
)
or None
)
@ondemand
@staticmethod
async def user(
_session: AsyncSession,
score: "Score",
includes: list[str] | None = None,
) -> UserDict:
return await UserModel.transform(score.user, ruleset=score.gamemode, includes=includes or [])
@ondemand
@staticmethod
async def weight(
session: AsyncSession,
score: "Score",
) -> float | None:
best_id = await get_best_id(session, score.id)
if best_id:
return calculate_pp_weight(best_id - 1)
return None
@ondemand
@staticmethod
async def legacy_total_score(
_session: AsyncSession,
_score: "Score",
) -> int:
return 0
@field_validator("maximum_statistics", mode="before")
@classmethod
@@ -151,17 +412,9 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
# TODO: current_user_attributes
class Score(ScoreBase, table=True):
class Score(ScoreModel, table=True):
__tablename__: str = "scores"
id: int = Field(default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True))
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("lazer_users.id"),
index=True,
),
)
# ScoreStatistics
n300: int = Field(exclude=True)
n100: int = Field(exclude=True)
@@ -175,6 +428,7 @@ class Score(ScoreBase, table=True):
nsmall_tick_hit: int | None = Field(default=None, exclude=True)
gamemode: GameMode = Field(index=True)
pinned_order: int = Field(default=0, exclude=True)
map_md5: str = Field(max_length=32, index=True, exclude=True)
@field_validator("gamemode", mode="before")
@classmethod
@@ -245,9 +499,11 @@ class Score(ScoreBase, table=True):
maximum_statistics=self.maximum_statistics,
)
async def to_resp(self, session: AsyncSession, api_version: int) -> "ScoreResp | LegacyScoreResp":
async def to_resp(
self, session: AsyncSession, api_version: int, includes: list[str] = []
) -> "ScoreDict | LegacyScoreResp":
if api_version >= 20220705:
return await ScoreResp.from_db(session, self)
return await ScoreModel.transform(self, includes=includes)
return await LegacyScoreResp.from_db(session, self)
async def delete(
@@ -270,141 +526,7 @@ class Score(ScoreBase, table=True):
await session.delete(self)
class ScoreResp(ScoreBase):
id: int
user_id: int
is_perfect_combo: bool = False
legacy_perfect: bool = False
legacy_total_score: int = 0 # FIXME
weight: float = 0.0
best_id: int | None = None
ruleset_id: int | None = None
beatmap: BeatmapResp | None = None
beatmapset: BeatmapsetResp | None = None
user: UserResp | None = None
statistics: ScoreStatistics | None = None
maximum_statistics: ScoreStatistics | None = None
rank_global: int | None = None
rank_country: int | None = None
position: int | None = None
scores_around: "ScoreAround | None" = None
current_user_attributes: CurrentUserAttributes | None = None
@field_validator(
"has_replay",
"passed",
"preserve",
"is_perfect_combo",
"legacy_perfect",
"processed",
"ranked",
mode="before",
)
@classmethod
def validate_bool_fields(cls, v):
"""将整数 0/1 转换为布尔值,处理数据库中的布尔字段"""
if isinstance(v, int):
return bool(v)
return v
@field_validator("statistics", "maximum_statistics", mode="before")
@classmethod
def validate_statistics_fields(cls, v):
"""处理统计字段中的字符串键,转换为 HitResult 枚举"""
if isinstance(v, dict):
converted = {}
for key, value in v.items():
if isinstance(key, str):
try:
# 尝试将字符串转换为 HitResult 枚举
enum_key = HitResult(key)
converted[enum_key] = value
except ValueError:
# 如果转换失败,跳过这个键值对
continue
else:
converted[key] = value
return converted
return v
@field_serializer("statistics", when_used="json")
def serialize_statistics_fields(self, v):
"""序列化统计字段,确保枚举值正确转换为字符串"""
if isinstance(v, dict):
serialized = {}
for key, value in v.items():
if hasattr(key, "value"):
# 如果是枚举,使用其值
serialized[key.value] = value
else:
# 否则直接使用键
serialized[str(key)] = value
return serialized
return v
@classmethod
async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp":
# 确保 score 对象完全加载,避免懒加载问题
await session.refresh(score)
s = cls.model_validate(score.model_dump())
await score.awaitable_attrs.beatmap
s.beatmap = await BeatmapResp.from_db(score.beatmap)
s.beatmapset = await BeatmapsetResp.from_db(score.beatmap.beatmapset, session=session, user=score.user)
s.is_perfect_combo = s.max_combo == s.beatmap.max_combo
s.legacy_perfect = s.max_combo == s.beatmap.max_combo
s.ruleset_id = int(score.gamemode)
best_id = await get_best_id(session, score.id)
if best_id:
s.best_id = best_id
s.weight = calculate_pp_weight(best_id - 1)
s.statistics = {
HitResult.MISS: score.nmiss,
HitResult.MEH: score.n50,
HitResult.OK: score.n100,
HitResult.GREAT: score.n300,
HitResult.PERFECT: score.ngeki,
HitResult.GOOD: score.nkatu,
}
if score.nlarge_tick_miss is not None:
s.statistics[HitResult.LARGE_TICK_MISS] = score.nlarge_tick_miss
if score.nslider_tail_hit is not None:
s.statistics[HitResult.SLIDER_TAIL_HIT] = score.nslider_tail_hit
if score.nsmall_tick_hit is not None:
s.statistics[HitResult.SMALL_TICK_HIT] = score.nsmall_tick_hit
if score.nlarge_tick_hit is not None:
s.statistics[HitResult.LARGE_TICK_HIT] = score.nlarge_tick_hit
s.user = await UserResp.from_db(
score.user,
session,
include=["statistics", "team", "daily_challenge_user_stats"],
ruleset=score.gamemode,
)
s.rank_global = (
await get_score_position_by_id(
session,
score.beatmap_id,
score.id,
mode=score.gamemode,
user=score.user,
)
or None
)
s.rank_country = (
await get_score_position_by_id(
session,
score.beatmap_id,
score.id,
score.gamemode,
score.user,
type=LeaderboardType.COUNTRY,
)
or None
)
s.current_user_attributes = CurrentUserAttributes(
pin=PinAttributes(is_pinned=bool(score.pinned_order), score_id=score.id)
)
return s
MultiplayScoreDict = ScoreModel.generate_typeddict(tuple(Score.MULTIPLAYER_BASE_INCLUDES)) # pyright: ignore[reportGeneralTypeIssues]
class LegacyStatistics(BaseModel):
@@ -417,31 +539,25 @@ class LegacyStatistics(BaseModel):
class LegacyScoreResp(UTCBaseModel):
accuracy: float
best_id: int
created_at: datetime
id: int
max_combo: int
mode: GameMode
mode_int: int
best_id: int
user_id: int
accuracy: float
mods: list[str] # acronym
passed: bool
score: int
max_combo: int
perfect: bool = False
statistics: LegacyStatistics
passed: bool
pp: float
rank: Rank
created_at: datetime
mode: GameMode
mode_int: int
replay: bool
score: int
statistics: LegacyStatistics
type: str
user_id: int
current_user_attributes: CurrentUserAttributes
user: UserResp
beatmap: BeatmapResp
rank_global: int | None = Field(default=None, exclude=True)
@classmethod
async def from_db(cls, session: AsyncSession, score: Score) -> "LegacyScoreResp":
await session.refresh(score)
async def from_db(cls, session: AsyncSession, score: "Score") -> "LegacyScoreResp":
await score.awaitable_attrs.beatmap
return cls(
accuracy=score.accuracy,
@@ -465,34 +581,13 @@ class LegacyScoreResp(UTCBaseModel):
count_geki=score.ngeki or 0,
count_katu=score.nkatu or 0,
),
type=score.type,
user_id=score.user_id,
current_user_attributes=CurrentUserAttributes(
pin=PinAttributes(is_pinned=bool(score.pinned_order), score_id=score.id)
),
user=await UserResp.from_db(
score.user,
session,
include=["statistics", "team", "daily_challenge_user_stats"],
ruleset=score.gamemode,
),
beatmap=await BeatmapResp.from_db(score.beatmap),
perfect=score.is_perfect_combo,
rank_global=(
await get_score_position_by_id(
session,
score.beatmap_id,
score.id,
mode=score.gamemode,
user=score.user,
)
or None
),
)
class MultiplayerScores(RespWithCursor):
scores: list[ScoreResp] = Field(default_factory=list)
scores: list[MultiplayScoreDict] = Field(default_factory=list) # pyright: ignore[reportInvalidTypeForm]
params: dict[str, Any] = Field(default_factory=dict)
@@ -842,13 +937,13 @@ async def get_user_best_pp(
# https://github.com/ppy/osu-queue-score-statistics/blob/master/osu.Server.Queues.ScoreStatisticsProcessor/Helpers/PlayValidityHelper.cs
def get_play_length(score: Score, beatmap_length: int):
def get_play_length(score: "Score", beatmap_length: int):
speed_rate = get_speed_rate(score.mods)
length = beatmap_length / speed_rate
return int(min(length, (score.ended_at - score.started_at).total_seconds()))
def calculate_playtime(score: Score, beatmap_length: int) -> tuple[int, bool]:
def calculate_playtime(score: "Score", beatmap_length: int) -> tuple[int, bool]:
total_length = get_play_length(score, beatmap_length)
total_obj_hited = (
score.n300
@@ -937,7 +1032,7 @@ async def process_score(
return score
async def _process_score_pp(score: Score, session: AsyncSession, redis: Redis, fetcher: "Fetcher"):
async def _process_score_pp(score: "Score", session: AsyncSession, redis: Redis, fetcher: "Fetcher"):
if score.pp != 0:
logger.debug(
"Skipping PP calculation for score {score_id} | already set {pp:.2f}",
@@ -984,7 +1079,7 @@ async def _process_score_pp(score: Score, session: AsyncSession, redis: Redis, f
)
async def _process_score_events(score: Score, session: AsyncSession):
async def _process_score_events(score: "Score", session: AsyncSession):
total_users = (await session.exec(select(func.count()).select_from(User))).one()
rank_global = await get_score_position_by_id(
session,
@@ -1088,7 +1183,7 @@ async def _process_statistics(
session: AsyncSession,
redis: Redis,
user: User,
score: Score,
score: "Score",
score_token: int,
beatmap_length: int,
beatmap_status: BeatmapRankStatus,
@@ -1318,7 +1413,7 @@ async def process_user(
redis: Redis,
fetcher: "Fetcher",
user: User,
score: Score,
score: "Score",
score_token: int,
beatmap_length: int,
beatmap_status: BeatmapRankStatus,

View File

@@ -0,0 +1,13 @@
from . import beatmap # noqa: F401
from .beatmapset import BeatmapsetModel
from sqlmodel import SQLModel
SearchBeatmapset = BeatmapsetModel.generate_typeddict(("beatmaps.max_combo", "pack_tags"))
class SearchBeatmapsetsResp(SQLModel):
beatmapsets: list[SearchBeatmapset] # pyright: ignore[reportInvalidTypeForm]
total: int
cursor: dict[str, int | float | str] | None = None
cursor_string: str | None = None

View File

@@ -1,10 +1,11 @@
from datetime import timedelta
import math
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, ClassVar, NotRequired, TypedDict
from app.models.score import GameMode
from app.utils import utcnow
from ._base import DatabaseModel, included, ondemand
from .rank_history import RankHistory
from pydantic import field_validator
@@ -15,7 +16,6 @@ from sqlmodel import (
Field,
ForeignKey,
Relationship,
SQLModel,
col,
func,
select,
@@ -23,10 +23,40 @@ from sqlmodel import (
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from .user import User, UserResp
from .user import User, UserDict
class UserStatisticsBase(SQLModel):
class UserStatisticsDict(TypedDict):
mode: GameMode
count_100: int
count_300: int
count_50: int
count_miss: int
pp: float
ranked_score: int
hit_accuracy: float
total_score: int
total_hits: int
maximum_combo: int
play_count: int
play_time: int
replays_watched_by_others: int
is_ranked: bool
level: NotRequired[dict[str, int]]
global_rank: NotRequired[int | None]
grade_counts: NotRequired[dict[str, int]]
rank_change_since_30_days: NotRequired[int]
country_rank: NotRequired[int | None]
user: NotRequired["UserDict"]
class UserStatisticsModel(DatabaseModel[UserStatisticsDict]):
RANKING_INCLUDES: ClassVar[list[str]] = [
"user.country",
"user.cover",
"user.team",
]
mode: GameMode = Field(index=True)
count_100: int = Field(default=0, sa_column=Column(BigInteger))
count_300: int = Field(default=0, sa_column=Column(BigInteger))
@@ -57,8 +87,63 @@ class UserStatisticsBase(SQLModel):
return GameMode.OSU
return v
@included
@staticmethod
async def level(_session: AsyncSession, statistics: "UserStatistics") -> dict[str, int]:
return {
"current": int(statistics.level_current),
"progress": int(math.fmod(statistics.level_current, 1) * 100),
}
class UserStatistics(AsyncAttrs, UserStatisticsBase, table=True):
@included
@staticmethod
async def global_rank(session: AsyncSession, statistics: "UserStatistics") -> int | None:
return await get_rank(session, statistics)
@included
@staticmethod
async def grade_counts(_session: AsyncSession, statistics: "UserStatistics") -> dict[str, int]:
return {
"ss": statistics.grade_ss,
"ssh": statistics.grade_ssh,
"s": statistics.grade_s,
"sh": statistics.grade_sh,
"a": statistics.grade_a,
}
@ondemand
@staticmethod
async def rank_change_since_30_days(session: AsyncSession, statistics: "UserStatistics") -> int:
global_rank = await get_rank(session, statistics)
rank_best = (
await session.exec(
select(func.max(RankHistory.rank)).where(
RankHistory.date > utcnow() - timedelta(days=30),
RankHistory.user_id == statistics.user_id,
)
)
).first()
if rank_best is None or global_rank is None:
return 0
return rank_best - global_rank
@ondemand
@staticmethod
async def country_rank(
session: AsyncSession, statistics: "UserStatistics", user_country: str | None = None
) -> int | None:
return await get_rank(session, statistics, user_country)
@ondemand
@staticmethod
async def user(_session: AsyncSession, statistics: "UserStatistics") -> "UserDict":
from .user import UserModel
user_instance = await statistics.awaitable_attrs.user
return await UserModel.transform(user_instance)
class UserStatistics(AsyncAttrs, UserStatisticsModel, table=True):
__tablename__: str = "lazer_user_statistics"
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(
@@ -80,74 +165,6 @@ class UserStatistics(AsyncAttrs, UserStatisticsBase, table=True):
user: "User" = Relationship(back_populates="statistics")
class UserStatisticsResp(UserStatisticsBase):
user: "UserResp | None" = None
rank_change_since_30_days: int | None = 0
global_rank: int | None = Field(default=None)
country_rank: int | None = Field(default=None)
grade_counts: dict[str, int] = Field(
default_factory=lambda: {
"ss": 0,
"ssh": 0,
"s": 0,
"sh": 0,
"a": 0,
}
)
level: dict[str, int] = Field(
default_factory=lambda: {
"current": 1,
"progress": 0,
}
)
@classmethod
async def from_db(
cls,
obj: UserStatistics,
session: AsyncSession,
user_country: str | None = None,
include: list[str] = [],
) -> "UserStatisticsResp":
s = cls.model_validate(obj.model_dump())
s.grade_counts = {
"ss": obj.grade_ss,
"ssh": obj.grade_ssh,
"s": obj.grade_s,
"sh": obj.grade_sh,
"a": obj.grade_a,
}
s.level = {
"current": int(obj.level_current),
"progress": int(math.fmod(obj.level_current, 1) * 100),
}
if "user" in include:
from .user import RANKING_INCLUDES, UserResp
user = await UserResp.from_db(await obj.awaitable_attrs.user, session, include=RANKING_INCLUDES)
s.user = user
user_country = user.country_code
s.global_rank = await get_rank(session, obj)
s.country_rank = await get_rank(session, obj, user_country)
if "rank_change_since_30_days" in include:
rank_best = (
await session.exec(
select(func.max(RankHistory.rank)).where(
RankHistory.date > utcnow() - timedelta(days=30),
RankHistory.user_id == obj.user_id,
)
)
).first()
if rank_best is None or s.global_rank is None:
s.rank_change_since_30_days = 0
else:
s.rank_change_since_30_days = rank_best - s.global_rank
return s
async def get_rank(session: AsyncSession, statistics: UserStatistics, country: str | None = None) -> int | None:
from .user import User
@@ -164,7 +181,6 @@ async def get_rank(session: AsyncSession, statistics: UserStatistics, country: s
query = query.join(User).where(User.country_code == country)
subq = query.subquery()
result = await session.exec(select(subq.c.rank).where(subq.c.user_id == statistics.user_id))
rank = result.first()

View File

@@ -1,9 +1,9 @@
from typing import TYPE_CHECKING
from app.calculator import calculate_score_to_level
from app.database.statistics import UserStatistics
from app.models.score import GameMode, Rank
from .statistics import UserStatistics
from .user import User
from sqlmodel import (

File diff suppressed because it is too large Load Diff

View File

@@ -1,13 +1,40 @@
from app.database.beatmap import BeatmapResp
from app.database.beatmap import BeatmapDict, BeatmapModel
from app.log import fetcher_logger
from ._base import BaseFetcher
from pydantic import TypeAdapter
logger = fetcher_logger("BeatmapFetcher")
adapter = TypeAdapter(
BeatmapModel.generate_typeddict(
(
"checksum",
"accuracy",
"ar",
"bpm",
"convert",
"count_circles",
"count_sliders",
"count_spinners",
"cs",
"deleted_at",
"drain",
"hit_length",
"is_scoreable",
"last_updated",
"mode_int",
"ranked",
"url",
"max_combo",
"beatmapset",
)
)
)
class BeatmapFetcher(BaseFetcher):
async def get_beatmap(self, beatmap_id: int | None = None, beatmap_checksum: str | None = None) -> BeatmapResp:
async def get_beatmap(self, beatmap_id: int | None = None, beatmap_checksum: str | None = None) -> BeatmapDict:
if beatmap_id:
params = {"id": beatmap_id}
elif beatmap_checksum:
@@ -16,7 +43,7 @@ class BeatmapFetcher(BaseFetcher):
raise ValueError("Either beatmap_id or beatmap_checksum must be provided.")
logger.opt(colors=True).debug(f"get_beatmap: <y>{params}</y>")
return BeatmapResp.model_validate(
return adapter.validate_python( # pyright: ignore[reportReturnType]
await self.request_api(
"https://osu.ppy.sh/api/v2/beatmaps/lookup",
params=params,

View File

@@ -3,7 +3,7 @@ import base64
import hashlib
import json
from app.database.beatmapset import BeatmapsetResp, SearchBeatmapsetsResp
from app.database import BeatmapsetDict, BeatmapsetModel, SearchBeatmapsetsResp
from app.helpers.rate_limiter import osu_api_rate_limiter
from app.log import fetcher_logger
from app.models.beatmap import SearchQueryModel
@@ -13,6 +13,7 @@ from app.utils import bg_tasks
from ._base import BaseFetcher
from httpx import AsyncClient
from pydantic import TypeAdapter
import redis.asyncio as redis
@@ -26,6 +27,46 @@ logger = fetcher_logger("BeatmapsetFetcher")
MAX_RETRY_ATTEMPTS = 3
adapter = TypeAdapter(
BeatmapsetModel.generate_typeddict(
(
"availability",
"bpm",
"last_updated",
"ranked",
"ranked_date",
"submitted_date",
"tags",
"storyboard",
"description",
"genre",
"language",
*[
f"beatmaps.{inc}"
for inc in (
"checksum",
"accuracy",
"ar",
"bpm",
"convert",
"count_circles",
"count_sliders",
"count_spinners",
"cs",
"deleted_at",
"drain",
"hit_length",
"is_scoreable",
"last_updated",
"mode_int",
"ranked",
"url",
"max_combo",
)
],
)
)
)
class BeatmapsetFetcher(BaseFetcher):
@@ -139,10 +180,9 @@ class BeatmapsetFetcher(BaseFetcher):
except Exception:
return {}
async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetResp:
async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetDict:
logger.opt(colors=True).debug(f"get_beatmapset: <y>{beatmap_set_id}</y>")
return BeatmapsetResp.model_validate(
return adapter.validate_python( # pyright: ignore[reportReturnType]
await self.request_api(f"https://osu.ppy.sh/api/v2/beatmapsets/{beatmap_set_id}")
)

View File

@@ -1,8 +1,6 @@
from typing import Any
from pydantic import BaseModel
from typing import Any, TypedDict
class ChatEvent(BaseModel):
class ChatEvent(TypedDict):
event: str
data: dict[str, Any] | None = None
data: dict[str, Any] | None

View File

@@ -3,16 +3,16 @@ from datetime import UTC, datetime
from app.models.score import GameMode
from pydantic import BaseModel, field_serializer
from pydantic import BaseModel, FieldSerializationInfo, field_serializer
class UTCBaseModel(BaseModel):
@field_serializer("*", when_used="json")
def serialize_datetime(self, v, _info):
@field_serializer("*", when_used="always")
def serialize_datetime(self, v, _info: FieldSerializationInfo):
if isinstance(v, datetime):
if v.tzinfo is None:
v = v.replace(tzinfo=UTC)
return v.astimezone(UTC).isoformat().replace("+00:00", "Z")
return v.astimezone(UTC)
return v

View File

@@ -447,7 +447,7 @@ async def create_multiplayer_room(
# 让房主加入频道
host_user = await db.get(User, host_user_id)
if host_user:
await server.batch_join_channel([host_user], channel, db)
await server.batch_join_channel([host_user], channel)
# Add playlist items
await _add_playlist_items(db, room_id, room_data, host_user_id)

View File

@@ -3,11 +3,14 @@ from collections.abc import Awaitable, Callable
from math import ceil
import random
import shlex
from typing import TYPE_CHECKING
from app.calculator import calculate_weighted_pp
from app.const import BANCHOBOT_ID
from app.database import ChatMessageResp
from app.database.chat import ChannelType, ChatChannel, ChatMessage, MessageType
if TYPE_CHECKING:
pass
from app.database.chat import ChannelType, ChatChannel, ChatMessage, ChatMessageModel, MessageType
from app.database.score import Score, get_best_id
from app.database.statistics import UserStatistics, get_rank
from app.database.user import User
@@ -95,7 +98,7 @@ class Bot:
await session.commit()
await session.refresh(msg)
await session.refresh(bot)
resp = await ChatMessageResp.from_db(msg, session, bot)
resp = await ChatMessageModel.transform(msg, includes=["sender"])
await server.send_message_to_channel(resp)
async def _ensure_pm_channel(self, user: User, session: AsyncSession) -> ChatChannel | None:
@@ -119,7 +122,7 @@ class Bot:
await session.refresh(channel)
await session.refresh(user)
await session.refresh(bot)
await server.batch_join_channel([user, bot], channel, session)
await server.batch_join_channel([user, bot], channel)
return channel
async def _send_reply(

View File

@@ -1,37 +1,40 @@
from typing import Annotated, Any, Literal, Self
from typing import Annotated, Literal, Self
from app.database.chat import (
ChannelType,
ChatChannel,
ChatChannelResp,
ChatChannelModel,
ChatMessage,
SilenceUser,
UserSilenceResp,
)
from app.database.user import User, UserResp
from app.database.user import User, UserModel
from app.dependencies.database import Database, Redis
from app.dependencies.param import BodyOrForm
from app.dependencies.user import get_current_user
from app.router.v2 import api_v2_router as router
from app.utils import api_doc
from .server import server
from fastapi import Depends, HTTPException, Path, Query, Security
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, model_validator
from sqlmodel import col, select
class UpdateResponse(BaseModel):
presence: list[ChatChannelResp] = Field(default_factory=list)
silences: list[Any] = Field(default_factory=list)
@router.get(
"/chat/updates",
response_model=UpdateResponse,
name="获取更新",
description="获取当前用户所在频道的最新的禁言情况。",
tags=["聊天"],
responses={
200: api_doc(
"获取更新响应。",
{"presence": list[ChatChannelModel], "silences": list[UserSilenceResp]},
ChatChannel.LISTING_INCLUDES,
name="UpdateResponse",
)
},
)
async def get_update(
session: Database,
@@ -44,45 +47,44 @@ async def get_update(
Query(alias="includes[]", description="要包含的更新类型"),
] = ["presence", "silences"],
):
resp = UpdateResponse()
resp = {
"presence": [],
"silences": [],
}
if "presence" in includes:
channel_ids = server.get_user_joined_channel(current_user.id)
for channel_id in channel_ids:
# 使用明确的查询避免延迟加载
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
if db_channel:
# 提取必要的属性避免惰性加载
channel_type = db_channel.type
resp.presence.append(
await ChatChannelResp.from_db(
resp["presence"].append(
await ChatChannelModel.transform(
db_channel,
session,
current_user,
redis,
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
user=current_user,
server=server,
includes=ChatChannel.LISTING_INCLUDES,
)
)
if "silences" in includes:
if history_since:
silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.id) > history_since))).all()
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
resp["silences"].extend([UserSilenceResp.from_db(silence) for silence in silences])
elif since:
msg = await session.get(ChatMessage, since)
if msg:
silences = (
await session.exec(select(SilenceUser).where(col(SilenceUser.banned_at) > msg.timestamp))
).all()
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
resp["silences"].extend([UserSilenceResp.from_db(silence) for silence in silences])
return resp
@router.put(
"/chat/channels/{channel}/users/{user}",
response_model=ChatChannelResp,
name="加入频道",
description="加入指定的公开/房间频道。",
tags=["聊天"],
responses={200: api_doc("加入的频道", ChatChannelModel, ChatChannel.LISTING_INCLUDES)},
)
async def join_channel(
session: Database,
@@ -101,7 +103,7 @@ async def join_channel(
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
return await server.join_channel(current_user, db_channel, session)
return await server.join_channel(current_user, db_channel)
@router.delete(
@@ -128,13 +130,13 @@ async def leave_channel(
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
await server.leave_channel(current_user, db_channel, session)
await server.leave_channel(current_user, db_channel)
return
@router.get(
"/chat/channels",
response_model=list[ChatChannelResp],
responses={200: api_doc("加入的频道", list[ChatChannelModel])},
name="获取频道列表",
description="获取所有公开频道。",
tags=["聊天"],
@@ -142,35 +144,30 @@ async def leave_channel(
async def get_channel_list(
session: Database,
current_user: Annotated[User, Security(get_current_user, scopes=["chat.read"])],
redis: Redis,
):
channels = (await session.exec(select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC))).all()
results = []
for channel in channels:
# 提取必要的属性避免惰性加载
channel_id = channel.channel_id
channel_type = channel.type
results = await ChatChannelModel.transform_many(
channels,
user=current_user,
server=server,
)
results.append(
await ChatChannelResp.from_db(
channel,
session,
current_user,
redis,
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
)
)
return results
class GetChannelResp(BaseModel):
channel: ChatChannelResp
users: list[UserResp] = Field(default_factory=list)
@router.get(
"/chat/channels/{channel}",
response_model=GetChannelResp,
responses={
200: api_doc(
"频道详细信息",
{
"channel": ChatChannelModel,
"users": list[UserModel],
},
ChatChannel.LISTING_INCLUDES + User.CARD_INCLUDES,
name="GetChannelResponse",
)
},
name="获取频道信息",
description="获取指定频道的信息。",
tags=["聊天"],
@@ -191,7 +188,6 @@ async def get_channel(
raise HTTPException(status_code=404, detail="Channel not found")
# 立即提取需要的属性
channel_id = db_channel.channel_id
channel_type = db_channel.type
channel_name = db_channel.name
@@ -209,15 +205,15 @@ async def get_channel(
users.extend([target_user, current_user])
break
return GetChannelResp(
channel=await ChatChannelResp.from_db(
return {
"channel": await ChatChannelModel.transform(
db_channel,
session,
current_user,
redis,
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
)
)
user=current_user,
server=server,
includes=ChatChannel.LISTING_INCLUDES,
),
"users": await UserModel.transform_many(users, includes=User.CARD_INCLUDES),
}
class CreateChannelReq(BaseModel):
@@ -244,7 +240,7 @@ class CreateChannelReq(BaseModel):
@router.post(
"/chat/channels",
response_model=ChatChannelResp,
responses={200: api_doc("创建的频道", ChatChannelModel, ["recent_messages.sender"])},
name="创建频道",
description="创建一个新的私聊/通知频道。如果存在私聊频道则重新加入。",
tags=["聊天"],
@@ -289,21 +285,13 @@ async def create_channel(
await session.refresh(current_user)
if req.type == "PM":
await session.refresh(target) # pyright: ignore[reportPossiblyUnboundVariable]
await server.batch_join_channel([target, current_user], channel, session) # pyright: ignore[reportPossiblyUnboundVariable]
await server.batch_join_channel([target, current_user], channel) # pyright: ignore[reportPossiblyUnboundVariable]
else:
target_users = await session.exec(select(User).where(col(User.id).in_(req.target_ids or [])))
await server.batch_join_channel([*target_users, current_user], channel, session)
await server.batch_join_channel([*target_users, current_user], channel)
await server.join_channel(current_user, channel, session)
await server.join_channel(current_user, channel)
# 提取必要的属性避免惰性加载
channel_id = channel.channel_id
return await ChatChannelResp.from_db(
channel,
session,
current_user,
redis,
server.channels.get(channel_id, []),
include_recent_messages=True,
return await ChatChannelModel.transform(
channel, user=current_user, server=server, includes=["recent_messages.sender"]
)

View File

@@ -1,11 +1,11 @@
from typing import Annotated
from app.database import ChatMessageResp
from app.database import ChatChannelModel
from app.database.chat import (
ChannelType,
ChatChannel,
ChatChannelResp,
ChatMessage,
ChatMessageModel,
MessageType,
SilenceUser,
UserSilenceResp,
@@ -18,6 +18,7 @@ from app.log import log
from app.models.notification import ChannelMessage, ChannelMessageTeam
from app.router.v2 import api_v2_router as router
from app.service.redis_message_system import redis_message_system
from app.utils import api_doc
from .banchobot import bot
from .server import server
@@ -68,7 +69,7 @@ class MessageReq(BaseModel):
@router.post(
"/chat/channels/{channel}/messages",
response_model=ChatMessageResp,
responses={200: api_doc("发送的消息", ChatMessageModel, ["sender", "is_action"])},
name="发送消息",
description="发送消息到指定频道。",
tags=["聊天"],
@@ -130,7 +131,7 @@ async def send_message(
# 为通知系统创建临时 ChatMessage 对象(仅适用于私聊和团队频道)
if channel_type in [ChannelType.PM, ChannelType.TEAM]:
temp_msg = ChatMessage(
message_id=resp.message_id, # 使用 Redis 系统生成的ID
message_id=resp["message_id"], # 使用 Redis 系统生成的ID
channel_id=channel_id,
content=req.message,
sender_id=user_id,
@@ -151,7 +152,7 @@ async def send_message(
@router.get(
"/chat/channels/{channel}/messages",
response_model=list[ChatMessageResp],
responses={200: api_doc("获取的消息", list[ChatMessageModel], ["sender"])},
name="获取消息",
description="获取指定频道的消息列表(统一按时间正序返回)。",
tags=["聊天"],
@@ -177,7 +178,7 @@ async def get_message(
try:
messages = await redis_message_system.get_messages(channel_id, limit, since)
if len(messages) >= 2 and messages[0].message_id > messages[-1].message_id:
if len(messages) >= 2 and messages[0]["message_id"] > messages[-1]["message_id"]:
messages.reverse()
return messages
except Exception as e:
@@ -189,7 +190,7 @@ async def get_message(
# 向前加载新消息 → 直接 ASC
query = base.where(col(ChatMessage.message_id) > since).order_by(col(ChatMessage.message_id).asc()).limit(limit)
rows = (await session.exec(query)).all()
resp = [await ChatMessageResp.from_db(m, session) for m in rows]
resp = await ChatMessageModel.transform_many(rows, includes=["sender"])
# 已经 ASC无需反转
return resp
@@ -202,15 +203,14 @@ async def get_message(
rows = (await session.exec(query)).all()
rows = list(rows)
rows.reverse() # 反转为 ASC
resp = [await ChatMessageResp.from_db(m, session) for m in rows]
resp = await ChatMessageModel.transform_many(rows, includes=["sender"])
return resp
query = base.order_by(col(ChatMessage.message_id).desc()).limit(limit)
rows = (await session.exec(query)).all()
rows = list(rows)
rows.reverse() # 反转为 ASC
resp = [await ChatMessageResp.from_db(m, session) for m in rows]
return resp
resp = await ChatMessageModel.transform_many(rows, includes=["sender"])
return resp
@@ -248,17 +248,23 @@ class PMReq(BaseModel):
uuid: str | None = None
class NewPMResp(BaseModel):
channel: ChatChannelResp
message: ChatMessageResp
new_channel_id: int
@router.post(
"/chat/new",
name="创建私聊频道",
description="创建一个新的私聊频道。",
tags=["聊天"],
responses={
200: api_doc(
"创建私聊频道响应",
{
"channel": ChatChannelModel,
"message": ChatMessageModel,
"new_channel_id": int,
},
["recent_messages.sender", "sender"],
name="NewPMResponse",
)
},
)
async def create_new_pm(
session: Database,
@@ -290,9 +296,9 @@ async def create_new_pm(
await session.refresh(target)
await session.refresh(current_user)
await server.batch_join_channel([target, current_user], channel, session)
channel_resp = await ChatChannelResp.from_db(
channel, session, current_user, redis, server.channels[channel.channel_id]
await server.batch_join_channel([target, current_user], channel)
channel_resp = await ChatChannelModel.transform(
channel, user=current_user, server=server, includes=["recent_messages.sender"]
)
msg = ChatMessage(
channel_id=channel.channel_id,
@@ -306,10 +312,10 @@ async def create_new_pm(
await session.refresh(msg)
await session.refresh(current_user)
await session.refresh(channel)
message_resp = await ChatMessageResp.from_db(msg, session, current_user)
message_resp = await ChatMessageModel.transform(msg, user=current_user, includes=["sender"])
await server.send_message_to_channel(message_resp)
return NewPMResp(
channel=channel_resp,
message=message_resp,
new_channel_id=channel_resp.channel_id,
)
return {
"channel": channel_resp,
"message": message_resp,
"new_channel_id": channel_resp["channel_id"],
}

View File

@@ -1,7 +1,8 @@
import asyncio
from typing import Annotated, overload
from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMessageResp
from app.database import ChatMessageDict
from app.database.chat import ChannelType, ChatChannel, ChatChannelDict, ChatChannelModel
from app.database.notification import UserNotification, insert_notification
from app.database.user import User
from app.dependencies.database import (
@@ -16,7 +17,7 @@ from app.log import log
from app.models.chat import ChatEvent
from app.models.notification import NotificationDetail
from app.service.subscribers.chat import ChatSubscriber
from app.utils import bg_tasks
from app.utils import bg_tasks, safe_json_dumps
from fastapi import APIRouter, Depends, Header, Query, WebSocket, WebSocketDisconnect
from fastapi.security import SecurityScopes
@@ -65,7 +66,7 @@ class ChatServer:
await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))
).first()
if db_channel:
await self.leave_channel(user, db_channel, session)
await self.leave_channel(user, db_channel)
@overload
async def send_event(self, client: int, event: ChatEvent): ...
@@ -80,7 +81,7 @@ class ChatServer:
return
client = client_
if client.client_state == WebSocketState.CONNECTED:
await client.send_text(event.model_dump_json())
await client.send_text(safe_json_dumps(event))
async def broadcast(self, channel_id: int, event: ChatEvent):
users_in_channel = self.channels.get(channel_id, [])
@@ -107,38 +108,38 @@ class ChatServer:
async def mark_as_read(self, channel_id: int, user_id: int, message_id: int):
await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id)
async def send_message_to_channel(self, message: ChatMessageResp, is_bot_command: bool = False):
async def send_message_to_channel(self, message: ChatMessageDict, is_bot_command: bool = False):
logger.info(
f"Sending message to channel {message.channel_id}, message_id: "
f"{message.message_id}, is_bot_command: {is_bot_command}"
f"Sending message to channel {message['channel_id']}, message_id: "
f"{message['message_id']}, is_bot_command: {is_bot_command}"
)
event = ChatEvent(
event="chat.message.new",
data={"messages": [message], "users": [message.sender]},
data={"messages": [message], "users": [message["sender"]]}, # pyright: ignore[reportTypedDictNotRequiredAccess]
)
if is_bot_command:
logger.info(f"Sending bot command to user {message.sender_id}")
bg_tasks.add_task(self.send_event, message.sender_id, event)
logger.info(f"Sending bot command to user {message['sender_id']}")
bg_tasks.add_task(self.send_event, message["sender_id"], event)
else:
# 总是广播消息无论是临时ID还是真实ID
logger.info(f"Broadcasting message to all users in channel {message.channel_id}")
logger.info(f"Broadcasting message to all users in channel {message['channel_id']}")
bg_tasks.add_task(
self.broadcast,
message.channel_id,
message["channel_id"],
event,
)
# 只有真实消息 ID正数且非零才进行标记已读和设置最后消息
# Redis 消息系统生成的ID都是正数所以这里应该都能正常处理
if message.message_id and message.message_id > 0:
await self.mark_as_read(message.channel_id, message.sender_id, message.message_id)
await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
logger.info(f"Updated last message ID for channel {message.channel_id} to {message.message_id}")
if message["message_id"] and message["message_id"] > 0:
await self.mark_as_read(message["channel_id"], message["sender_id"], message["message_id"])
await self.redis.set(f"chat:{message['channel_id']}:last_msg", message["message_id"])
logger.info(f"Updated last message ID for channel {message['channel_id']} to {message['message_id']}")
else:
logger.debug(f"Skipping last message update for message ID: {message.message_id}")
logger.debug(f"Skipping last message update for message ID: {message['message_id']}")
async def batch_join_channel(self, users: list[User], channel: ChatChannel, session: AsyncSession):
async def batch_join_channel(self, users: list[User], channel: ChatChannel):
channel_id = channel.channel_id
not_joined = []
@@ -151,22 +152,18 @@ class ChatServer:
not_joined.append(user)
for user in not_joined:
channel_resp = await ChatChannelResp.from_db(
channel,
session,
user,
self.redis,
self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
channel_resp = await ChatChannelModel.transform(
channel, user=user, server=server, includes=ChatChannel.LISTING_INCLUDES
)
await self.send_event(
user.id,
ChatEvent(
event="chat.channel.join",
data=channel_resp.model_dump(),
data=channel_resp, # pyright: ignore[reportArgumentType]
),
)
async def join_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> ChatChannelResp:
async def join_channel(self, user: User, channel: ChatChannel) -> ChatChannelDict:
user_id = user.id
channel_id = channel.channel_id
@@ -175,25 +172,21 @@ class ChatServer:
if user_id not in self.channels[channel_id]:
self.channels[channel_id].append(user_id)
channel_resp = await ChatChannelResp.from_db(
channel,
session,
user,
self.redis,
self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
channel_resp: ChatChannelDict = await ChatChannelModel.transform(
channel, user=user, server=server, includes=ChatChannel.LISTING_INCLUDES
)
await self.send_event(
user_id,
ChatEvent(
event="chat.channel.join",
data=channel_resp.model_dump(),
data=channel_resp, # pyright: ignore[reportArgumentType]
),
)
return channel_resp
async def leave_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> None:
async def leave_channel(self, user: User, channel: ChatChannel) -> None:
user_id = user.id
channel_id = channel.channel_id
@@ -203,18 +196,14 @@ class ChatServer:
if (c := self.channels.get(channel_id)) is not None and not c:
del self.channels[channel_id]
channel_resp = await ChatChannelResp.from_db(
channel,
session,
user,
self.redis,
self.channels.get(channel_id) if channel.type != ChannelType.PUBLIC else None,
channel_resp = await ChatChannelModel.transform(
channel, user=user, server=server, includes=ChatChannel.LISTING_INCLUDES
)
await self.send_event(
user_id,
ChatEvent(
event="chat.channel.part",
data=channel_resp.model_dump(),
data=channel_resp, # pyright: ignore[reportArgumentType]
),
)
@@ -232,7 +221,7 @@ class ChatServer:
return
logger.info(f"User {user_id} joining channel {channel_id} (type: {db_channel.type.value})")
await self.join_channel(user, db_channel, session)
await self.join_channel(user, db_channel)
async def leave_room_channel(self, channel_id: int, user_id: int):
async with with_db() as session:
@@ -248,7 +237,7 @@ class ChatServer:
return
logger.info(f"User {user_id} leaving channel {channel_id} (type: {db_channel.type.value})")
await self.leave_channel(user, db_channel, session)
await self.leave_channel(user, db_channel)
async def new_private_notification(self, detail: NotificationDetail):
async with with_db() as session:
@@ -336,6 +325,6 @@ async def chat_websocket(
# 使用明确的查询避免延迟加载
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))).first()
if db_channel is not None:
await server.join_channel(user, db_channel, session)
await server.join_channel(user, db_channel)
await _listen_stop(websocket, user_id, factory)

View File

@@ -2,7 +2,7 @@ import hashlib
from typing import Annotated
from app.database.team import Team, TeamMember, TeamRequest, TeamResp
from app.database.user import BASE_INCLUDES, User, UserResp
from app.database.user import User, UserModel
from app.dependencies.database import Database, Redis
from app.dependencies.storage import StorageService
from app.dependencies.user import ClientUser
@@ -14,12 +14,11 @@ from app.models.notification import (
from app.models.score import GameMode
from app.router.notification import server
from app.service.ranking_cache_service import get_ranking_cache_service
from app.utils import check_image, utcnow
from app.utils import api_doc, check_image, utcnow
from .router import router
from fastapi import File, Form, HTTPException, Path, Query, Request
from pydantic import BaseModel
from sqlmodel import col, exists, select
@@ -214,12 +213,22 @@ async def delete_team(
await cache_service.invalidate_team_cache()
class TeamQueryResp(BaseModel):
team: TeamResp
members: list[UserResp]
@router.get("/team/{team_id}", name="查询战队", response_model=TeamQueryResp, tags=["战队", "g0v0 API"])
@router.get(
"/team/{team_id}",
name="查询战队",
tags=["战队", "g0v0 API"],
responses={
200: api_doc(
"战队信息",
{
"team": TeamResp,
"members": list[UserModel],
},
["statistics", "country"],
name="TeamQueryResp",
)
},
)
async def get_team(
session: Database,
team_id: Annotated[int, Path(..., description="战队 ID")],
@@ -233,10 +242,10 @@ async def get_team(
)
)
).all()
return TeamQueryResp(
team=await TeamResp.from_db(members[0].team, session, gamemode),
members=[await UserResp.from_db(m.user, session, include=BASE_INCLUDES) for m in members],
)
return {
"team": await TeamResp.from_db(members[0].team, session, gamemode),
"members": await UserModel.transform_many([m.user for m in members], includes=["statistics", "country"]),
}
@router.post("/team/{team_id}/request", name="请求加入战队", status_code=204, tags=["战队", "g0v0 API"])

View File

@@ -1,7 +1,7 @@
from datetime import datetime
from typing import Annotated, Literal
from app.database.statistics import UserStatistics, UserStatisticsResp
from app.database.statistics import UserStatistics, UserStatisticsModel
from app.database.user import User
from app.dependencies.database import Database, get_redis
from app.log import logger
@@ -46,7 +46,7 @@ class V1User(AllStrModel):
return f"v1_user:{user_id}"
@classmethod
async def from_db(cls, session: Database, db_user: User, ruleset: GameMode | None = None) -> "V1User":
async def from_db(cls, db_user: User, ruleset: GameMode | None = None) -> "V1User":
ruleset = ruleset or db_user.playmode
current_statistics: UserStatistics | None = None
for i in await db_user.awaitable_attrs.statistics:
@@ -54,31 +54,33 @@ class V1User(AllStrModel):
current_statistics = i
break
if current_statistics:
statistics = await UserStatisticsResp.from_db(current_statistics, session, db_user.country_code)
statistics = await UserStatisticsModel.transform(
current_statistics, country_code=db_user.country_code, includes=["country_rank"]
)
else:
statistics = None
return cls(
user_id=db_user.id,
username=db_user.username,
join_date=db_user.join_date,
count300=statistics.count_300 if statistics else 0,
count100=statistics.count_100 if statistics else 0,
count50=statistics.count_50 if statistics else 0,
playcount=statistics.play_count if statistics else 0,
ranked_score=statistics.ranked_score if statistics else 0,
total_score=statistics.total_score if statistics else 0,
pp_rank=statistics.global_rank if statistics and statistics.global_rank else 0,
count300=current_statistics.count_300 if current_statistics else 0,
count100=current_statistics.count_100 if current_statistics else 0,
count50=current_statistics.count_50 if current_statistics else 0,
playcount=current_statistics.play_count if current_statistics else 0,
ranked_score=current_statistics.ranked_score if current_statistics else 0,
total_score=current_statistics.total_score if current_statistics else 0,
pp_rank=statistics.get("global_rank") or 0 if statistics else 0,
level=current_statistics.level_current if current_statistics else 0,
pp_raw=statistics.pp if statistics else 0.0,
accuracy=statistics.hit_accuracy if statistics else 0,
pp_raw=current_statistics.pp if current_statistics else 0.0,
accuracy=current_statistics.hit_accuracy if current_statistics else 0,
count_rank_ss=current_statistics.grade_ss if current_statistics else 0,
count_rank_ssh=current_statistics.grade_ssh if current_statistics else 0,
count_rank_s=current_statistics.grade_s if current_statistics else 0,
count_rank_sh=current_statistics.grade_sh if current_statistics else 0,
count_rank_a=current_statistics.grade_a if current_statistics else 0,
country=db_user.country_code,
total_seconds_played=statistics.play_time if statistics else 0,
pp_country_rank=statistics.country_rank if statistics and statistics.country_rank else 0,
total_seconds_played=current_statistics.play_time if current_statistics else 0,
pp_country_rank=statistics.get("country_rank") or 0 if statistics else 0,
events=[], # TODO
)
@@ -134,7 +136,7 @@ async def get_user(
try:
# 生成用户数据
v1_user = await V1User.from_db(session, db_user, ruleset)
v1_user = await V1User.from_db(db_user, ruleset)
# 异步缓存结果如果有用户ID
if db_user.id is not None:

View File

@@ -5,7 +5,11 @@ from typing import Annotated
from app.calculator import get_calculator
from app.calculators.performance import ConvertError
from app.database import Beatmap, BeatmapResp, User
from app.database import (
Beatmap,
BeatmapModel,
User,
)
from app.database.beatmap import calculate_beatmap_attributes
from app.dependencies.database import Database, Redis
from app.dependencies.fetcher import Fetcher
@@ -19,29 +23,20 @@ from app.models.performance import (
from app.models.score import (
GameMode,
)
from app.utils import api_doc
from .router import router
from fastapi import HTTPException, Path, Query, Security
from httpx import HTTPError, HTTPStatusError
from pydantic import BaseModel
from sqlmodel import col, select
class BatchGetResp(BaseModel):
"""批量获取谱面返回模型。
返回字段说明:
- beatmaps: 谱面详细信息列表。"""
beatmaps: list[BeatmapResp]
@router.get(
"/beatmaps/lookup",
tags=["谱面"],
name="查询单个谱面",
response_model=BeatmapResp,
responses={200: api_doc("单个谱面详细信息。", BeatmapModel, BeatmapModel.TRANSFORMER_INCLUDES)},
description=("根据谱面 ID / MD5 / 文件名 查询单个谱面。至少提供 id / checksum / filename 之一。"),
)
@asset_proxy_response
@@ -67,14 +62,14 @@ async def lookup_beatmap(
raise HTTPException(status_code=404, detail="Beatmap not found")
await db.refresh(current_user)
return await BeatmapResp.from_db(beatmap, session=db, user=current_user)
return await BeatmapModel.transform(beatmap, user=current_user, includes=BeatmapModel.TRANSFORMER_INCLUDES)
@router.get(
"/beatmaps/{beatmap_id}",
tags=["谱面"],
name="获取谱面详情",
response_model=BeatmapResp,
responses={200: api_doc("单个谱面详细信息。", BeatmapModel, BeatmapModel.TRANSFORMER_INCLUDES)},
description="获取单个谱面详情。",
)
@asset_proxy_response
@@ -86,7 +81,12 @@ async def get_beatmap(
):
try:
beatmap = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
return await BeatmapResp.from_db(beatmap, session=db, user=current_user)
await db.refresh(current_user)
return await BeatmapModel.transform(
beatmap,
user=current_user,
includes=BeatmapModel.TRANSFORMER_INCLUDES,
)
except HTTPError:
raise HTTPException(status_code=404, detail="Beatmap not found")
@@ -95,7 +95,11 @@ async def get_beatmap(
"/beatmaps/",
tags=["谱面"],
name="批量获取谱面",
response_model=BatchGetResp,
responses={
200: api_doc(
"谱面列表", {"beatmaps": list[BeatmapModel]}, BeatmapModel.TRANSFORMER_INCLUDES, name="BatchBeatmapResponse"
)
},
description=("批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。为空时按最近更新时间返回。"),
)
@asset_proxy_response
@@ -124,7 +128,12 @@ async def batch_get_beatmaps(
for beatmap in beatmaps:
await db.refresh(beatmap)
await db.refresh(current_user)
return BatchGetResp(beatmaps=[await BeatmapResp.from_db(bm, session=db, user=current_user) for bm in beatmaps])
return {
"beatmaps": [
await BeatmapModel.transform(bm, user=current_user, includes=BeatmapModel.TRANSFORMER_INCLUDES)
for bm in beatmaps
]
}
@router.post(

View File

@@ -2,17 +2,24 @@ import re
from typing import Annotated, Literal
from urllib.parse import parse_qs
from app.database import Beatmap, Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
from app.database.beatmapset import SearchBeatmapsetsResp
from app.database import (
Beatmap,
Beatmapset,
BeatmapsetModel,
FavouriteBeatmapset,
SearchBeatmapsetsResp,
User,
)
from app.dependencies.beatmap_download import DownloadService
from app.dependencies.cache import BeatmapsetCacheService, UserCacheService
from app.dependencies.database import Database, Redis, with_db
from app.dependencies.database import Database, Redis
from app.dependencies.fetcher import Fetcher
from app.dependencies.geoip import IPAddress, get_geoip_helper
from app.dependencies.user import ClientUser, get_current_user
from app.helpers.asset_proxy_helper import asset_proxy_response
from app.models.beatmap import SearchQueryModel
from app.service.beatmapset_cache_service import generate_hash
from app.utils import api_doc
from .router import router
@@ -27,14 +34,7 @@ from fastapi import (
)
from fastapi.responses import RedirectResponse
from httpx import HTTPError
from sqlmodel import exists, select
async def _save_to_db(sets: SearchBeatmapsetsResp):
async with with_db() as session:
for s in sets.beatmapsets:
if not (await session.exec(select(exists()).where(Beatmapset.id == s.id))).first():
await Beatmapset.from_resp(session, s)
from sqlmodel import select
@router.get(
@@ -105,7 +105,6 @@ async def search_beatmapset(
try:
sets = await fetcher.search_beatmapset(query, cursor, redis)
background_tasks.add_task(_save_to_db, sets)
# 缓存搜索结果
await cache_service.cache_search_result(query_hash, cursor_hash, sets.model_dump())
@@ -117,8 +116,8 @@ async def search_beatmapset(
@router.get(
"/beatmapsets/lookup",
tags=["谱面集"],
responses={200: api_doc("谱面集详细信息", BeatmapsetModel, BeatmapsetModel.BEATMAPSET_TRANSFORMER_INCLUDES)},
name="查询谱面集 (通过谱面 ID)",
response_model=BeatmapsetResp,
description=("通过谱面 ID 查询所属谱面集。"),
)
@asset_proxy_response
@@ -137,7 +136,10 @@ async def lookup_beatmapset(
try:
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
resp = await BeatmapsetResp.from_db(beatmap.beatmapset, session=db, user=current_user)
resp = await BeatmapsetModel.transform(
beatmap.beatmapset, user=current_user, includes=BeatmapsetModel.API_INCLUDES
)
# 缓存结果
await cache_service.cache_beatmap_lookup(beatmap_id, resp)
@@ -149,8 +151,8 @@ async def lookup_beatmapset(
@router.get(
"/beatmapsets/{beatmapset_id}",
tags=["谱面集"],
responses={200: api_doc("谱面集详细信息", BeatmapsetModel, BeatmapsetModel.BEATMAPSET_TRANSFORMER_INCLUDES)},
name="获取谱面集详情",
response_model=BeatmapsetResp,
description="获取单个谱面集详情。",
)
@asset_proxy_response
@@ -169,7 +171,8 @@ async def get_beatmapset(
try:
beatmapset = await Beatmapset.get_or_fetch(db, fetcher, beatmapset_id)
resp = await BeatmapsetResp.from_db(beatmapset, session=db, include=["recent_favourites"], user=current_user)
await db.refresh(current_user)
resp = await BeatmapsetModel.transform(beatmapset, includes=BeatmapsetModel.API_INCLUDES, user=current_user)
# 缓存结果
await cache_service.cache_beatmapset(resp)

View File

@@ -1,20 +1,29 @@
from typing import Annotated
from app.database import FavouriteBeatmapset, MeResp, User
from app.database import FavouriteBeatmapset, User
from app.database.user import UserModel
from app.dependencies.database import Database
from app.dependencies.user import UserAndToken, get_current_user, get_current_user_and_token
from app.models.score import GameMode
from app.utils import api_doc
from .router import router
from fastapi import Path, Security
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from sqlmodel import select
ME_INCLUDES = [*User.USER_INCLUDES, "session_verified", "session_verification_method"]
class BeatmapsetIds(BaseModel):
beatmapset_ids: list[int]
@router.get(
"/me/beatmapset-favourites",
response_model=list[int],
response_model=BeatmapsetIds,
name="获取当前用户收藏的谱面集 ID 列表",
description="获取当前登录用户收藏的谱面集 ID 列表。",
tags=["用户", "谱面集"],
@@ -26,37 +35,39 @@ async def get_user_beatmapset_favourites(
beatmapset_ids = await session.exec(
select(FavouriteBeatmapset.beatmapset_id).where(FavouriteBeatmapset.user_id == current_user.id)
)
return beatmapset_ids.all()
return BeatmapsetIds(beatmapset_ids=list(beatmapset_ids.all()))
@router.get(
"/me/{ruleset}",
response_model=MeResp,
responses={200: api_doc("当前用户信息(含指定 ruleset 统计)", UserModel, ME_INCLUDES)},
name="获取当前用户信息 (指定 ruleset)",
description="获取当前登录用户信息 (含指定 ruleset 统计)。",
tags=["用户"],
)
async def get_user_info_with_ruleset(
session: Database,
ruleset: Annotated[GameMode, Path(description="指定 ruleset")],
user_and_token: Annotated[UserAndToken, Security(get_current_user_and_token, scopes=["identify"])],
):
user_resp = await MeResp.from_db(user_and_token[0], session, ruleset, token_id=user_and_token[1].id)
user_resp = await UserModel.transform(
user_and_token[0], ruleset=ruleset, token_id=user_and_token[1].id, includes=ME_INCLUDES
)
return user_resp
@router.get(
"/me/",
response_model=MeResp,
responses={200: api_doc("当前用户信息", UserModel, ME_INCLUDES)},
name="获取当前用户信息",
description="获取当前登录用户信息。",
tags=["用户"],
)
async def get_user_info_default(
session: Database,
user_and_token: Annotated[UserAndToken, Security(get_current_user_and_token, scopes=["identify"])],
):
user_resp = await MeResp.from_db(user_and_token[0], session, None, token_id=user_and_token[1].id)
user_resp = await UserModel.transform(
user_and_token[0], ruleset=None, token_id=user_and_token[1].id, includes=ME_INCLUDES
)
return user_resp

View File

@@ -1,11 +1,13 @@
from typing import Annotated, Literal
from app.config import settings
from app.database import Team, TeamMember, User, UserStatistics, UserStatisticsResp
from app.database import Team, TeamMember, User, UserStatistics
from app.database.statistics import UserStatisticsModel
from app.dependencies.database import Database, get_redis
from app.dependencies.user import get_current_user
from app.models.score import GameMode
from app.service.ranking_cache_service import get_ranking_cache_service
from app.utils import api_doc
from .router import router
@@ -308,14 +310,16 @@ async def get_country_ranking(
return response
class TopUsersResponse(BaseModel):
ranking: list[UserStatisticsResp]
total: int
@router.get(
"/rankings/{ruleset}/{sort}",
response_model=TopUsersResponse,
responses={
200: api_doc(
"用户排行榜",
{"ranking": list[UserStatisticsModel], "total": int},
["user.country", "user.cover"],
name="TopUsersResponse",
)
},
name="获取用户排行榜",
description="获取在指定模式下的用户排行榜",
tags=["排行榜"],
@@ -339,10 +343,10 @@ async def get_user_ranking(
if cached_data and cached_stats:
# 从缓存返回数据
return TopUsersResponse(
ranking=[UserStatisticsResp.model_validate(item) for item in cached_data],
total=cached_stats.get("total", 0),
)
return {
"ranking": cached_data,
"total": cached_stats.get("total", 0),
}
# 缓存未命中,从数据库查询
wheres = [
@@ -350,7 +354,7 @@ async def get_user_ranking(
col(UserStatistics.pp) > 0,
col(UserStatistics.is_ranked),
]
include = ["user"]
include = UserStatistics.RANKING_INCLUDES.copy()
if sort == "performance":
order_by = col(UserStatistics.pp).desc()
include.append("rank_change_since_30_days")
@@ -358,6 +362,7 @@ async def get_user_ranking(
order_by = col(UserStatistics.ranked_score).desc()
if country:
wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
include.append("country_rank")
# 查询总数
count_query = select(func.count()).select_from(UserStatistics).where(*wheres)
@@ -378,12 +383,14 @@ async def get_user_ranking(
# 转换为响应格式
ranking_data = []
for statistics in statistics_list:
user_stats_resp = await UserStatisticsResp.from_db(statistics, session, None, include)
user_stats_resp = await UserStatisticsModel.transform(
statistics, includes=include, user_country=current_user.country_code
)
ranking_data.append(user_stats_resp)
# 异步缓存数据(不等待完成)
# 使用配置文件中的TTL设置
cache_data = [item.model_dump() for item in ranking_data]
cache_data = ranking_data
stats_data = {"total": total_count}
# 创建后台任务来缓存数据
@@ -407,5 +414,7 @@ async def get_user_ranking(
ttl=settings.ranking_cache_expire_minutes * 60,
)
resp = TopUsersResponse(ranking=ranking_data, total=total_count)
return resp
return {
"ranking": ranking_data,
"total": total_count,
}

View File

@@ -1,15 +1,16 @@
from typing import Annotated
from typing import Annotated, Any
from app.database import Relationship, RelationshipResp, RelationshipType, User
from app.database.user import UserResp
from app.database import Relationship, RelationshipType, User
from app.database.relationship import RelationshipModel
from app.database.user import UserModel
from app.dependencies.api_version import APIVersion
from app.dependencies.database import Database
from app.dependencies.user import ClientUser, get_current_user
from app.utils import api_doc
from .router import router
from fastapi import HTTPException, Path, Query, Request, Security
from pydantic import BaseModel
from sqlmodel import col, exists, select
@@ -17,38 +18,19 @@ from sqlmodel import col, exists, select
"/friends",
tags=["用户关系"],
responses={
200: {
"description": "好友列表",
"content": {
"application/json": {
"schema": {
"oneOf": [
{
"type": "array",
"items": {"$ref": "#/components/schemas/RelationshipResp"},
"description": "好友列表",
},
{
"type": "array",
"items": {"$ref": "#/components/schemas/UserResp"},
"description": "好友列表 (`x-api-version < 20241022`)",
},
]
}
}
},
}
200: api_doc(
"好友列表\n\n如果 `x-api-version < 20241022`,返回值为 `User` 列表,否则为 `Relationship` 列表",
list[RelationshipModel] | list[UserModel],
[f"target.{inc}" for inc in User.LIST_INCLUDES],
)
},
name="获取好友列表",
description=(
"获取当前用户的好友列表。\n\n"
"如果 `x-api-version < 20241022`,返回值为 `UserResp` 列表,否则为 `RelationshipResp` 列表。"
),
description="获取当前用户的好友列表。",
)
@router.get(
"/blocks",
tags=["用户关系"],
response_model=list[RelationshipResp],
response_model=list[dict[str, Any]],
name="获取屏蔽列表",
description="获取当前用户的屏蔽用户列表。",
)
@@ -67,35 +49,29 @@ async def get_relationship(
)
)
if api_version >= 20241022 or relationship_type == RelationshipType.BLOCK:
return [await RelationshipResp.from_db(db, rel) for rel in relationships.unique()]
return [
await RelationshipModel.transform(
rel,
includes=[f"target.{inc}" for inc in User.LIST_INCLUDES],
ruleset=current_user.playmode,
)
for rel in relationships.unique()
]
else:
return [
await UserResp.from_db(
await UserModel.transform(
rel.target,
db,
include=[
"team",
"daily_challenge_user_stats",
"statistics",
"statistics_rulesets",
],
ruleset=current_user.playmode,
includes=User.LIST_INCLUDES,
)
for rel in relationships.unique()
]
class AddFriendResp(BaseModel):
"""添加好友/屏蔽 返回模型。
- user_relation: 新的或更新后的关系对象。"""
user_relation: RelationshipResp
@router.post(
"/friends",
tags=["用户关系"],
response_model=AddFriendResp,
responses={200: api_doc("好友关系", {"user_relation": RelationshipModel}, name="UserRelationshipResponse")},
name="添加或更新好友关系",
description="\n添加或更新与目标用户的好友关系。",
)
@@ -163,7 +139,13 @@ async def add_relationship(
)
)
).one()
return AddFriendResp(user_relation=await RelationshipResp.from_db(db, relationship))
return {
"user_relation": await RelationshipModel.transform(
relationship,
includes=[],
ruleset=current_user.playmode,
)
}
@router.delete(

View File

@@ -1,25 +1,27 @@
from datetime import UTC
from typing import Annotated, Literal
from app.database.beatmap import Beatmap, BeatmapResp
from app.database.beatmapset import BeatmapsetResp
from app.database.item_attempts_count import ItemAttemptsCount, ItemAttemptsResp
from app.database.beatmap import (
Beatmap,
BeatmapModel,
)
from app.database.beatmapset import BeatmapsetModel
from app.database.item_attempts_count import ItemAttemptsCount, ItemAttemptsCountModel
from app.database.multiplayer_event import MultiplayerEvent, MultiplayerEventResp
from app.database.playlists import Playlist, PlaylistResp
from app.database.room import APIUploadedRoom, Room, RoomResp
from app.database.playlists import Playlist, PlaylistModel
from app.database.room import APIUploadedRoom, Room, RoomModel
from app.database.room_participated_user import RoomParticipatedUser
from app.database.score import Score
from app.database.user import User, UserResp
from app.database.user import User, UserModel
from app.dependencies.database import Database, Redis
from app.dependencies.user import ClientUser, get_current_user
from app.models.room import MatchType, RoomCategory, RoomStatus
from app.service.room import create_playlist_room_from_api
from app.utils import utcnow
from app.utils import api_doc, utcnow
from .router import router
from fastapi import HTTPException, Path, Query, Security
from pydantic import BaseModel, Field
from sqlalchemy.sql.elements import ColumnElement
from sqlmodel import col, exists, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -28,7 +30,19 @@ from sqlmodel.ext.asyncio.session import AsyncSession
@router.get(
"/rooms",
tags=["房间"],
response_model=list[RoomResp],
responses={
200: api_doc(
"房间列表",
list[RoomModel],
[
"current_playlist_item.beatmap.beatmapset",
"difficulty_range",
"host.country",
"playlist_item_stats",
"recent_participants",
],
)
},
name="获取房间列表",
description="获取房间列表。支持按状态/模式筛选",
)
@@ -49,7 +63,7 @@ async def get_all_rooms(
] = RoomCategory.NORMAL,
status: Annotated[RoomStatus | None, Query(description="房间状态(可选)")] = None,
):
resp_list: list[RoomResp] = []
resp_list = []
where_clauses: list[ColumnElement[bool]] = [col(Room.category) == category, col(Room.type) != MatchType.MATCHMAKING]
now = utcnow()
@@ -90,22 +104,24 @@ async def get_all_rooms(
.all()
)
for room in db_rooms:
resp = await RoomResp.from_db(room, db)
resp = await RoomModel.transform(
room,
includes=[
"current_playlist_item.beatmap.beatmapset",
"difficulty_range",
"host.country",
"playlist_item_stats",
"recent_participants",
],
)
if category == RoomCategory.REALTIME:
resp.category = RoomCategory.NORMAL
resp["category"] = RoomCategory.NORMAL
resp_list.append(resp)
return resp_list
class APICreatedRoom(RoomResp):
"""创建房间返回模型,继承 RoomResp。额外字段:
- error: 错误信息(为空表示成功)。"""
error: str = ""
async def _participate_room(room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis):
participated_user = (
await session.exec(
@@ -133,9 +149,15 @@ async def _participate_room(room_id: int, user_id: int, db_room: Room, session:
@router.post(
"/rooms",
tags=["房间"],
response_model=APICreatedRoom,
name="创建房间",
description="\n创建一个新的房间。",
responses={
200: api_doc(
"创建的房间信息",
RoomModel,
Room.SHOW_RESPONSE_INCLUDES,
)
},
)
async def create_room(
db: Database,
@@ -145,23 +167,27 @@ async def create_room(
):
if await current_user.is_restricted(db):
raise HTTPException(status_code=403, detail="Your account is restricted from multiplayer.")
user_id = current_user.id
db_room = await create_playlist_room_from_api(db, room, user_id)
await _participate_room(db_room.id, user_id, db_room, db, redis)
await db.commit()
await db.refresh(db_room)
created_room = APICreatedRoom.model_validate(await RoomResp.from_db(db_room, db))
created_room.error = ""
created_room = await RoomModel.transform(db_room, includes=Room.SHOW_RESPONSE_INCLUDES)
return created_room
@router.get(
"/rooms/{room_id}",
tags=["房间"],
response_model=RoomResp,
responses={
200: api_doc(
"房间详细信息",
RoomModel,
Room.SHOW_RESPONSE_INCLUDES,
)
},
name="获取房间详情",
description="获取单个房间详情。",
description="获取指定房间详情。",
)
async def get_room(
db: Database,
@@ -177,7 +203,7 @@ async def get_room(
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is None:
raise HTTPException(404, "Room not found")
resp = await RoomResp.from_db(db_room, include=["current_user_score"], session=db, user=current_user)
resp = await RoomModel.transform(db_room, includes=Room.SHOW_RESPONSE_INCLUDES, user=current_user)
return resp
@@ -225,10 +251,10 @@ async def add_user_to_room(
await _participate_room(room_id, user_id, db_room, db, redis)
await db.commit()
await db.refresh(db_room)
resp = await RoomResp.from_db(db_room, db)
resp = await RoomModel.transform(db_room, includes=Room.SHOW_RESPONSE_INCLUDES)
return resp
else:
raise HTTPException(404, "room not found0")
raise HTTPException(404, "room not found")
@router.delete(
@@ -268,21 +294,22 @@ async def remove_user_from_room(
raise HTTPException(404, "Room not found")
class APILeaderboard(BaseModel):
"""房间全局排行榜返回模型。
- leaderboard: 用户游玩统计(尝试次数/分数等)。
- user_score: 当前用户对应统计。"""
leaderboard: list[ItemAttemptsResp] = Field(default_factory=list)
user_score: ItemAttemptsResp | None = None
@router.get(
"/rooms/{room_id}/leaderboard",
tags=["房间"],
response_model=APILeaderboard,
name="获取房间排行榜",
description="获取房间内累计得分排行榜。",
responses={
200: api_doc(
"房间排行榜",
{
"leaderboard": list[ItemAttemptsCountModel],
"user_score": ItemAttemptsCountModel | None,
},
["user.country", "position"],
name="RoomLeaderboardResponse",
)
},
)
async def get_room_leaderboard(
db: Database,
@@ -300,45 +327,43 @@ async def get_room_leaderboard(
aggs_resp = []
user_agg = None
for i, agg in enumerate(aggs):
resp = await ItemAttemptsResp.from_db(agg, db)
resp.position = i + 1
includes = ["user.country"]
if agg.user_id == current_user.id:
includes.append("position")
resp = await ItemAttemptsCountModel.transform(agg, includes=includes)
aggs_resp.append(resp)
if agg.user_id == current_user.id:
user_agg = resp
return APILeaderboard(
leaderboard=aggs_resp,
user_score=user_agg,
)
class RoomEvents(BaseModel):
"""房间事件流返回模型。
- beatmaps: 本次结果涉及的谱面列表。
- beatmapsets: 谱面集映射。
- current_playlist_item_id: 当前游玩列表(项目)项 ID。
- events: 事件列表。
- first_event_id / last_event_id: 事件范围。
- playlist_items: 房间游玩列表(项目)详情。
- room: 房间详情。
- user: 关联用户列表。"""
beatmaps: list[BeatmapResp] = Field(default_factory=list)
beatmapsets: dict[int, BeatmapsetResp] = Field(default_factory=dict)
current_playlist_item_id: int = 0
events: list[MultiplayerEventResp] = Field(default_factory=list)
first_event_id: int = 0
last_event_id: int = 0
playlist_items: list[PlaylistResp] = Field(default_factory=list)
room: RoomResp
user: list[UserResp] = Field(default_factory=list)
return {
"leaderboard": aggs_resp,
"user_score": user_agg,
}
@router.get(
"/rooms/{room_id}/events",
response_model=RoomEvents,
tags=["房间"],
name="获取房间事件",
description="获取房间事件列表 (倒序,可按 after / before 进行范围截取)。",
responses={
200: api_doc(
"房间事件",
{
"beatmaps": list[BeatmapModel],
"beatmapsets": list[BeatmapsetModel],
"current_playlist_item_id": int,
"events": list[MultiplayerEventResp],
"first_event_id": int,
"last_event_id": int,
"playlist_items": list[PlaylistModel],
"room": RoomModel,
"user": list[UserModel],
},
["country", "details", "scores"],
name="RoomEventsResponse",
)
},
)
async def get_room_events(
db: Database,
@@ -402,28 +427,44 @@ async def get_room_events(
room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if room is None:
raise HTTPException(404, "Room not found")
room_resp = await RoomResp.from_db(room, db)
if room.category == RoomCategory.REALTIME and room_resp.current_playlist_item:
current_playlist_item_id = room_resp.current_playlist_item.id
room_resp = await RoomModel.transform(room, includes=["current_playlist_item"])
if room.category == RoomCategory.REALTIME:
current_playlist_item_id = (await Room.current_playlist_item(db, room))["id"]
users = await db.exec(select(User).where(col(User.id).in_(user_ids)))
user_resps = [await UserResp.from_db(user, db) for user in users]
user_resps = [await UserModel.transform(user, includes=["country"]) for user in users]
beatmaps = await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)))
beatmap_resps = [await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps]
beatmapset_resps = {}
for beatmap_resp in beatmap_resps:
beatmapset_resps[beatmap_resp.beatmapset_id] = beatmap_resp.beatmapset
playlist_items_resps = [await PlaylistResp.from_db(item) for item in playlist_items.values()]
return RoomEvents(
beatmaps=beatmap_resps,
beatmapsets=beatmapset_resps,
current_playlist_item_id=current_playlist_item_id,
events=event_resps,
first_event_id=first_event_id,
last_event_id=last_event_id,
playlist_items=playlist_items_resps,
room=room_resp,
user=user_resps,
beatmap_resps = [
await BeatmapModel.transform(
beatmap,
)
for beatmap in beatmaps
]
beatmapsets = []
for beatmap in beatmaps:
if beatmap.beatmapset_id not in beatmapsets:
beatmapsets.append(beatmap.beatmapset)
beatmapset_resps = [
await BeatmapsetModel.transform(
beatmapset,
)
for beatmapset in beatmapsets
]
playlist_items_resps = [
await PlaylistModel.transform(item, includes=["details", "scores"]) for item in playlist_items.values()
]
return {
"beatmaps": beatmap_resps,
"beatmapsets": beatmapset_resps,
"current_playlist_item_id": current_playlist_item_id,
"events": event_resps,
"first_event_id": first_event_id,
"last_event_id": last_event_id,
"playlist_items": playlist_items_resps,
"room": room_resp,
"user": user_resps,
}

View File

@@ -9,7 +9,6 @@ from app.database import (
Playlist,
Room,
Score,
ScoreResp,
ScoreToken,
ScoreTokenResp,
User,
@@ -27,8 +26,10 @@ from app.database.relationship import Relationship, RelationshipType
from app.database.score import (
LegacyScoreResp,
MultiplayerScores,
ScoreAround,
MultiplayScoreDict,
ScoreModel,
get_leaderboard,
get_score_position_by_id,
process_score,
process_user,
)
@@ -49,7 +50,7 @@ from app.models.score import (
)
from app.service.beatmap_cache_service import get_beatmap_cache_service
from app.service.user_cache_service import refresh_user_cache_background
from app.utils import utcnow
from app.utils import api_doc, utcnow
from .router import router
@@ -72,6 +73,7 @@ from sqlmodel import col, exists, func, select
from sqlmodel.ext.asyncio.session import AsyncSession
READ_SCORE_TIMEOUT = 10
DEFAULT_SCORE_INCLUDES = ["user", "user.country", "user.cover", "user.team"]
logger = log("Score")
@@ -180,13 +182,15 @@ async def submit_score(
await db.refresh(score)
background_task.add_task(_process_user, score_id, user_id, redis, fetcher)
resp: ScoreResp = await ScoreResp.from_db(db, score)
resp = await ScoreModel.transform(
score,
)
score_gamemode = score.gamemode
await db.commit()
if user_id is not None:
background_task.add_task(refresh_user_cache_background, redis, user_id, score_gamemode)
background_task.add_task(_process_user_achievement, resp.id)
background_task.add_task(_process_user_achievement, resp["id"])
return resp
@@ -218,27 +222,36 @@ async def _preload_beatmap_for_pp_calculation(beatmap_id: int) -> None:
logger.warning(f"Failed to preload beatmap {beatmap_id}: {e}")
class BeatmapUserScore[T: ScoreResp | LegacyScoreResp](BaseModel):
LeaderboardScoreType = ScoreModel.generate_typeddict(tuple(DEFAULT_SCORE_INCLUDES)) | LegacyScoreResp
class BeatmapUserScore(BaseModel):
position: int
score: T
score: LeaderboardScoreType # pyright: ignore[reportInvalidTypeForm]
class BeatmapScores[T: ScoreResp | LegacyScoreResp](BaseModel):
scores: list[T]
user_score: BeatmapUserScore[T] | None = None
class BeatmapScores(BaseModel):
scores: list[LeaderboardScoreType] # pyright: ignore[reportInvalidTypeForm]
user_score: BeatmapUserScore | None = None
score_count: int = 0
@router.get(
"/beatmaps/{beatmap_id}/scores",
tags=["成绩"],
response_model=BeatmapScores[ScoreResp] | BeatmapScores[LegacyScoreResp],
name="获取谱面排行榜",
description=(
"获取指定谱面在特定条件下的排行榜及当前用户成绩。\n\n"
"如果 `x-api-version >= 20220705`,返回值为 `BeatmapScores[ScoreResp]`"
responses={
200: {
"model": BeatmapScores,
"description": (
"排行榜及当前用户成绩。\n\n"
f"如果 `x-api-version >= 20220705`,返回值为 `BeatmapScores[Score]`"
f" (包含:{', '.join([f'`{inc}`' for inc in DEFAULT_SCORE_INCLUDES])}"
"否则为 `BeatmapScores[LegacyScoreResp]`。"
),
}
},
name="获取谱面排行榜",
description="获取指定谱面在特定条件下的排行榜及当前用户成绩。",
)
async def get_beatmap_scores(
db: Database,
@@ -266,27 +279,46 @@ async def get_beatmap_scores(
mods=sorted(mods),
)
user_score_resp = await user_score.to_resp(db, api_version) if user_score else None
resp = BeatmapScores(
scores=[await score.to_resp(db, api_version) for score in all_scores],
user_score=BeatmapUserScore(score=user_score_resp, position=user_score_resp.rank_global or 0)
if user_score_resp
else None,
score_count=count,
user_score_resp = await user_score.to_resp(db, api_version, includes=DEFAULT_SCORE_INCLUDES) if user_score else None
return {
"scores": [await score.to_resp(db, api_version, includes=DEFAULT_SCORE_INCLUDES) for score in all_scores],
"user_score": (
{
"score": user_score_resp,
"position": (
await get_score_position_by_id(
db,
user_score.beatmap_id,
user_score.id,
mode=user_score.gamemode,
user=user_score.user,
)
return resp
or 0
),
}
if user_score and user_score_resp
else None
),
"score_count": count,
}
@router.get(
"/beatmaps/{beatmap_id}/scores/users/{user_id}",
tags=["成绩"],
response_model=BeatmapUserScore[ScoreResp] | BeatmapUserScore[LegacyScoreResp],
name="获取用户谱面最高成绩",
description=(
"获取指定用户在指定谱面上的最高成绩。\n\n"
"如果 `x-api-version >= 20220705`,返回值为 `BeatmapUserScore[ScoreResp]`"
responses={
200: {
"model": BeatmapUserScore,
"description": (
"指定用户在指定谱面上的最高成绩\n\n"
"如果 `x-api-version >= 20220705`,返回值为 `BeatmapUserScore[Score]`"
f" (包含:{', '.join([f'`{inc}`' for inc in DEFAULT_SCORE_INCLUDES])}"
"否则为 `BeatmapUserScore[LegacyScoreResp]`。"
),
}
},
name="获取用户谱面最高成绩",
description="获取指定用户在指定谱面上的最高成绩。",
)
async def get_user_beatmap_score(
db: Database,
@@ -318,23 +350,38 @@ async def get_user_beatmap_score(
detail=f"Cannot find user {user_id}'s score on this beatmap",
)
else:
resp = await user_score.to_resp(db, api_version=api_version)
return BeatmapUserScore(
position=resp.rank_global or 0,
score=resp,
resp = await user_score.to_resp(db, api_version=api_version, includes=DEFAULT_SCORE_INCLUDES)
return {
"position": (
await get_score_position_by_id(
db,
user_score.beatmap_id,
user_score.id,
mode=user_score.gamemode,
user=user_score.user,
)
or 0
),
"score": resp,
}
@router.get(
"/beatmaps/{beatmap_id}/scores/users/{user_id}/all",
tags=["成绩"],
response_model=list[ScoreResp] | list[LegacyScoreResp],
name="获取用户谱面全部成绩",
description=(
"获取指定用户在指定谱面上的全部成绩列表。\n\n"
"如果 `x-api-version >= 20220705`,返回值为 `ScoreResp`列表,"
responses={
200: api_doc(
(
"用户谱面全部成绩\n\n"
"如果 `x-api-version >= 20220705`,返回值为 `Score`列表,"
"否则为 `LegacyScoreResp`列表。"
),
list[ScoreModel] | list[LegacyScoreResp],
DEFAULT_SCORE_INCLUDES,
)
},
name="获取用户谱面全部成绩",
description="获取指定用户在指定谱面上的全部成绩列表。",
)
async def get_user_all_beatmap_scores(
db: Database,
@@ -359,7 +406,7 @@ async def get_user_all_beatmap_scores(
)
).all()
return [await score.to_resp(db, api_version) for score in all_user_scores]
return [await score.to_resp(db, api_version, includes=DEFAULT_SCORE_INCLUDES) for score in all_user_scores]
@router.post(
@@ -413,9 +460,9 @@ async def create_solo_score(
@router.put(
"/beatmaps/{beatmap_id}/solo/scores/{token}",
tags=["游玩"],
response_model=ScoreResp,
name="提交单曲成绩",
description="\n使用令牌提交单曲成绩。",
responses={200: api_doc("单曲成绩提交结果。", ScoreModel)},
)
async def submit_solo_score(
background_task: BackgroundTasks,
@@ -520,6 +567,7 @@ async def create_playlist_score(
tags=["游玩"],
name="提交房间项目成绩",
description="\n提交房间游玩项目成绩。",
responses={200: api_doc("单曲成绩提交结果。", ScoreModel)},
)
async def submit_playlist_score(
background_task: BackgroundTasks,
@@ -560,13 +608,13 @@ async def submit_playlist_score(
room_id,
playlist_id,
user_id,
score_resp.id,
score_resp.total_score,
score_resp["id"],
score_resp["total_score"],
session,
redis,
)
await session.commit()
if room_category == RoomCategory.DAILY_CHALLENGE and score_resp.passed:
if room_category == RoomCategory.DAILY_CHALLENGE and score_resp["passed"]:
await process_daily_challenge_score(session, user_id, room_id)
await ItemAttemptsCount.get_or_create(room_id, user_id, session)
await session.commit()
@@ -575,15 +623,23 @@ async def submit_playlist_score(
class IndexedScoreResp(MultiplayerScores):
total: int
user_score: ScoreResp | None = None
user_score: MultiplayScoreDict | None = None # pyright: ignore[reportInvalidTypeForm]
@router.get(
"/rooms/{room_id}/playlist/{playlist_id}/scores",
response_model=IndexedScoreResp,
# response_model=IndexedScoreResp,
name="获取房间项目排行榜",
description="获取房间游玩项目排行榜。",
tags=["成绩"],
responses={
200: {
"description": (
f"房间项目排行榜。\n\n包含:{', '.join([f'`{inc}`' for inc in Score.MULTIPLAYER_BASE_INCLUDES])}"
),
"model": IndexedScoreResp,
}
},
)
async def index_playlist_scores(
session: Database,
@@ -620,16 +676,14 @@ async def index_playlist_scores(
scores = scores[:-1]
user_score = None
score_resp = [await ScoreResp.from_db(session, score.score) for score in scores]
score_resp = [await ScoreModel.transform(score.score, includes=Score.MULTIPLAYER_BASE_INCLUDES) for score in scores]
for score in score_resp:
score.position = await get_position(room_id, playlist_id, score.id, session)
if score.user_id == user_id:
if (room.category == RoomCategory.DAILY_CHALLENGE and score["user_id"] == user_id and score["passed"]) or score[
"user_id"
] == user_id:
user_score = score
if room.category == RoomCategory.DAILY_CHALLENGE:
score_resp = [s for s in score_resp if s.passed]
if user_score and not user_score.passed:
user_score = None
user_score["position"] = await get_position(room_id, playlist_id, score["id"], session)
break
resp = IndexedScoreResp(
scores=score_resp,
@@ -648,10 +702,16 @@ async def index_playlist_scores(
@router.get(
"/rooms/{room_id}/playlist/{playlist_id}/scores/{score_id}",
response_model=ScoreResp,
name="获取房间项目单个成绩",
description="获取指定房间游玩项目中单个成绩详情。",
tags=["成绩"],
responses={
200: api_doc(
"房间项目单个成绩详情。",
ScoreModel,
[*Score.MULTIPLAYER_BASE_INCLUDES, "position", "scores_around"],
)
},
)
async def show_playlist_score(
session: Database,
@@ -687,39 +747,25 @@ async def show_playlist_score(
break
if not score_record:
raise HTTPException(status_code=404, detail="Score not found")
resp = await ScoreResp.from_db(session, score_record.score)
resp.position = await get_position(room_id, playlist_id, score_id, session)
includes = [
*Score.MULTIPLAYER_BASE_INCLUDES,
"position",
]
if completed:
scores = (
await session.exec(
select(PlaylistBestScore).where(
PlaylistBestScore.playlist_id == playlist_id,
PlaylistBestScore.room_id == room_id,
~User.is_restricted_query(col(PlaylistBestScore.user_id)),
)
)
).all()
higher_scores = []
lower_scores = []
for score in scores:
resp = await ScoreResp.from_db(session, score.score)
if is_playlist and not resp.passed:
continue
if score.total_score > resp.total_score:
higher_scores.append(resp)
elif score.total_score < resp.total_score:
lower_scores.append(resp)
resp.scores_around = ScoreAround(
higher=MultiplayerScores(scores=higher_scores),
lower=MultiplayerScores(scores=lower_scores),
)
includes.append("scores_around")
resp = await ScoreModel.transform(score_record.score, includes=includes)
return resp
@router.get(
"rooms/{room_id}/playlist/{playlist_id}/scores/users/{user_id}",
response_model=ScoreResp,
responses={
200: api_doc(
"房间项目单个成绩详情。",
ScoreModel,
[*Score.MULTIPLAYER_BASE_INCLUDES, "position", "scores_around"],
)
},
name="获取房间项目用户成绩",
description="获取指定用户在房间游玩项目中的成绩。",
tags=["成绩"],
@@ -749,8 +795,14 @@ async def get_user_playlist_score(
if not score_record:
raise HTTPException(status_code=404, detail="Score not found")
resp = await ScoreResp.from_db(session, score_record.score)
resp.position = await get_position(room_id, playlist_id, score_record.score_id, session)
resp = await ScoreModel.transform(
score_record.score,
includes=[
*Score.MULTIPLAYER_BASE_INCLUDES,
"position",
"scores_around",
],
)
return resp

View File

@@ -5,17 +5,16 @@ from app.config import settings
from app.const import BANCHOBOT_ID
from app.database import (
Beatmap,
BeatmapModel,
BeatmapPlaycounts,
BeatmapPlaycountsResp,
BeatmapResp,
BeatmapsetResp,
BeatmapsetModel,
User,
UserResp,
)
from app.database.beatmap_playcounts import BeatmapPlaycountsModel
from app.database.best_scores import BestScore
from app.database.events import Event
from app.database.score import LegacyScoreResp, Score, ScoreResp, get_user_first_scores
from app.database.user import ALL_INCLUDED, SEARCH_INCLUDED
from app.database.score import Score, get_user_first_scores
from app.database.user import UserModel
from app.dependencies.api_version import APIVersion
from app.dependencies.cache import UserCacheService
from app.dependencies.database import Database, get_redis
@@ -26,24 +25,15 @@ from app.models.mods import API_MODS
from app.models.score import GameMode
from app.models.user import BeatmapsetType
from app.service.user_cache_service import get_user_cache_service
from app.utils import utcnow
from app.utils import api_doc, utcnow
from .router import router
from fastapi import BackgroundTasks, HTTPException, Path, Query, Request, Security
from pydantic import BaseModel
from sqlmodel import exists, false, select
from sqlmodel.sql.expression import col
class BatchUserResponse(BaseModel):
users: list[UserResp]
class BeatmapsPassedResponse(BaseModel):
beatmaps_passed: list[BeatmapResp]
def _get_difficulty_reduction_mods() -> set[str]:
mods: set[str] = set()
for ruleset_mods in API_MODS.values():
@@ -63,13 +53,15 @@ async def visible_to_current_user(user: User, current_user: User | None, session
@router.get(
"/users/",
response_model=BatchUserResponse,
responses={
200: api_doc("批量获取用户信息", {"users": list[UserModel]}, User.CARD_INCLUDES, name="UsersLookupResponse")
},
name="批量获取用户信息",
description="通过用户 ID 列表批量获取用户信息。",
tags=["用户"],
)
@router.get("/users/lookup", response_model=BatchUserResponse, include_in_schema=False)
@router.get("/users/lookup/", response_model=BatchUserResponse, include_in_schema=False)
@router.get("/users/lookup", include_in_schema=False)
@router.get("/users/lookup/", include_in_schema=False)
@asset_proxy_response
async def get_users(
session: Database,
@@ -108,16 +100,15 @@ async def get_users(
# 将查询到的用户添加到缓存并返回
for searched_user in searched_users:
if searched_user.id != BANCHOBOT_ID:
user_resp = await UserResp.from_db(
user_resp = await UserModel.transform(
searched_user,
session,
include=SEARCH_INCLUDED,
includes=User.CARD_INCLUDES,
)
cached_users.append(user_resp)
# 异步缓存,不阻塞响应
background_task.add_task(cache_service.cache_user, user_resp)
response = BatchUserResponse(users=cached_users)
response = {"users": cached_users}
return response
else:
searched_users = (
@@ -127,16 +118,15 @@ async def get_users(
for searched_user in searched_users:
if searched_user.id == BANCHOBOT_ID:
continue
user_resp = await UserResp.from_db(
user_resp = await UserModel.transform(
searched_user,
session,
include=SEARCH_INCLUDED,
includes=User.CARD_INCLUDES,
)
users.append(user_resp)
# 异步缓存
background_task.add_task(cache_service.cache_user, user_resp)
response = BatchUserResponse(users=users)
response = {"users": users}
return response
@@ -200,10 +190,12 @@ async def get_user_kudosu(
@router.get(
"/users/{user_id}/beatmaps-passed",
response_model=BeatmapsPassedResponse,
name="获取用户已通过谱面",
description="获取指定用户在给定谱面集中的已通过谱面列表。",
tags=["用户"],
responses={
200: api_doc("用户已通过谱面列表", {"beatmaps_passed": list[BeatmapModel]}, name="BeatmapsPassedResponse")
},
)
@asset_proxy_response
async def get_user_beatmaps_passed(
@@ -226,7 +218,7 @@ async def get_user_beatmaps_passed(
no_diff_reduction: Annotated[bool, Query(description="是否排除减难 MOD 成绩")] = True,
):
if not beatmapset_ids:
return BeatmapsPassedResponse(beatmaps_passed=[])
return {"beatmaps_passed": []}
if len(beatmapset_ids) > 50:
raise HTTPException(status_code=413, detail="beatmapset_ids cannot exceed 50 items")
@@ -255,7 +247,7 @@ async def get_user_beatmaps_passed(
scores = (await session.exec(score_query)).all()
if not scores:
return BeatmapsPassedResponse(beatmaps_passed=[])
return {"beatmaps_passed": []}
difficulty_reduction_mods = _get_difficulty_reduction_mods() if no_diff_reduction else set()
passed_beatmap_ids: set[int] = set()
@@ -269,7 +261,7 @@ async def get_user_beatmaps_passed(
continue
passed_beatmap_ids.add(beatmap_id)
if not passed_beatmap_ids:
return BeatmapsPassedResponse(beatmaps_passed=[])
return {"beatmaps_passed": []}
beatmaps = (
await session.exec(
@@ -279,19 +271,24 @@ async def get_user_beatmaps_passed(
)
).all()
return BeatmapsPassedResponse(
beatmaps_passed=[
await BeatmapResp.from_db(beatmap, allowed_mode, session=session, user=user) for beatmap in beatmaps
]
return {
"beatmaps_passed": [
await BeatmapModel.transform(
beatmap,
)
for beatmap in beatmaps
]
}
@router.get(
"/users/{user_id}/{ruleset}",
response_model=UserResp,
name="获取用户信息(指定ruleset)",
description="通过用户 ID 或用户名获取单个用户的详细信息,并指定特定 ruleset。",
tags=["用户"],
responses={
200: api_doc("用户信息", UserModel, User.USER_INCLUDES),
},
)
@asset_proxy_response
async def get_user_info_ruleset(
@@ -325,29 +322,26 @@ async def get_user_info_ruleset(
if should_not_show:
raise HTTPException(404, detail="User not found")
include = SEARCH_INCLUDED
if searched_is_self:
include = ALL_INCLUDED
user_resp = await UserResp.from_db(
user_resp = await UserModel.transform(
searched_user,
session,
include=include,
includes=User.USER_INCLUDES,
ruleset=ruleset,
)
# 异步缓存结果
background_task.add_task(cache_service.cache_user, user_resp, ruleset)
return user_resp
@router.get("/users/{user_id}/", response_model=UserResp, include_in_schema=False)
@router.get("/users/{user_id}/", include_in_schema=False)
@router.get(
"/users/{user_id}",
response_model=UserResp,
name="获取用户信息",
description="通过用户 ID 或用户名获取单个用户的详细信息。",
tags=["用户"],
responses={
200: api_doc("用户信息", UserModel, User.USER_INCLUDES),
},
)
@asset_proxy_response
async def get_user_info(
@@ -381,27 +375,31 @@ async def get_user_info(
if should_not_show:
raise HTTPException(404, detail="User not found")
include = SEARCH_INCLUDED
if searched_is_self:
include = ALL_INCLUDED
user_resp = await UserResp.from_db(
user_resp = await UserModel.transform(
searched_user,
session,
include=include,
includes=User.USER_INCLUDES,
)
# 异步缓存结果
background_task.add_task(cache_service.cache_user, user_resp)
return user_resp
beatmapset_includes = [*BeatmapsetModel.BEATMAPSET_TRANSFORMER_INCLUDES, "beatmaps"]
@router.get(
"/users/{user_id}/beatmapsets/{type}",
response_model=list[BeatmapsetResp | BeatmapPlaycountsResp],
name="获取用户谱面集列表",
description="获取指定用户特定类型的谱面集列表,如最常游玩、收藏等。",
tags=["用户"],
responses={
200: api_doc(
"当类型为 `most_played` 时返回 `list[BeatmapPlaycountsModel]`,其他为 `list[BeatmapsetModel]`",
list[BeatmapsetModel] | list[BeatmapPlaycountsModel],
beatmapset_includes,
)
},
)
@asset_proxy_response
async def get_user_beatmapsets(
@@ -417,11 +415,7 @@ async def get_user_beatmapsets(
# 先尝试从缓存获取
cached_result = await cache_service.get_user_beatmapsets_from_cache(user_id, type.value, limit, offset)
if cached_result is not None:
# 根据类型恢复对象
if type == BeatmapsetType.MOST_PLAYED:
return [BeatmapPlaycountsResp(**item) for item in cached_result]
else:
return [BeatmapsetResp(**item) for item in cached_result]
return cached_result
user = await session.get(User, user_id)
if not user or user.id == BANCHOBOT_ID:
@@ -444,7 +438,10 @@ async def get_user_beatmapsets(
raise HTTPException(404, detail="User not found")
favourites = await user.awaitable_attrs.favourite_beatmapsets
resp = [
await BeatmapsetResp.from_db(favourite.beatmapset, session=session, user=user) for favourite in favourites
await BeatmapsetModel.transform(
favourite.beatmapset, session=session, user=user, includes=beatmapset_includes
)
for favourite in favourites
]
elif type == BeatmapsetType.MOST_PLAYED:
@@ -459,7 +456,10 @@ async def get_user_beatmapsets(
.limit(limit)
.offset(offset)
)
resp = [await BeatmapPlaycountsResp.from_db(most_played_beatmap) for most_played_beatmap in most_played]
resp = [
await BeatmapPlaycountsModel.transform(most_played_beatmap, user=user, includes=beatmapset_includes)
for most_played_beatmap in most_played
]
else:
raise HTTPException(400, detail="Invalid beatmapset type")
@@ -477,7 +477,6 @@ async def get_user_beatmapsets(
@router.get(
"/users/{user_id}/scores/{type}",
response_model=list[ScoreResp] | list[LegacyScoreResp],
name="获取用户成绩列表",
description=(
"获取用户特定类型的成绩列表,如最好成绩、最近成绩等。\n\n"
@@ -523,6 +522,7 @@ async def get_user_scores(
gamemode = mode or db_user.playmode
order_by = None
where_clause = (col(Score.user_id) == db_user.id) & (col(Score.gamemode) == gamemode)
includes = Score.USER_PROFILE_INCLUDES.copy()
if not include_fails:
where_clause &= col(Score.passed).is_(True)
if type == "pinned":
@@ -531,6 +531,7 @@ async def get_user_scores(
elif type == "best":
where_clause &= exists().where(col(BestScore.score_id) == Score.id)
order_by = col(Score.pp).desc()
includes.append("weight")
elif type == "recent":
where_clause &= Score.ended_at > utcnow() - timedelta(hours=24)
order_by = col(Score.ended_at).desc()
@@ -551,6 +552,7 @@ async def get_user_scores(
await score.to_resp(
session,
api_version,
includes=includes,
)
for score in scores
]

View File

@@ -3,14 +3,14 @@ Beatmapset缓存服务
用于缓存beatmapset数据减少数据库查询频率
"""
from datetime import datetime
import hashlib
import json
from typing import TYPE_CHECKING
from app.config import settings
from app.database.beatmapset import BeatmapsetResp
from app.database import BeatmapsetDict
from app.log import logger
from app.utils import safe_json_dumps
from redis.asyncio import Redis
@@ -18,20 +18,6 @@ if TYPE_CHECKING:
pass
class DateTimeEncoder(json.JSONEncoder):
"""处理datetime序列化的JSON编码器"""
def default(self, obj):
if isinstance(obj, datetime):
return obj.isoformat()
return super().default(obj)
def safe_json_dumps(data) -> str:
"""安全的JSON序列化处理datetime对象"""
return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False)
def generate_hash(data) -> str:
"""生成数据的MD5哈希值"""
content = data if isinstance(data, str) else safe_json_dumps(data)
@@ -57,15 +43,14 @@ class BeatmapsetCacheService:
"""生成搜索结果缓存键"""
return f"beatmapset_search:{query_hash}:{cursor_hash}"
async def get_beatmapset_from_cache(self, beatmapset_id: int) -> BeatmapsetResp | None:
async def get_beatmapset_from_cache(self, beatmapset_id: int) -> BeatmapsetDict | None:
"""从缓存获取beatmapset信息"""
try:
cache_key = self._get_beatmapset_cache_key(beatmapset_id)
cached_data = await self.redis.get(cache_key)
if cached_data:
logger.debug(f"Beatmapset cache hit for {beatmapset_id}")
data = json.loads(cached_data)
return BeatmapsetResp(**data)
return json.loads(cached_data)
return None
except (ValueError, TypeError, AttributeError) as e:
logger.error(f"Error getting beatmapset from cache: {e}")
@@ -73,24 +58,21 @@ class BeatmapsetCacheService:
async def cache_beatmapset(
self,
beatmapset_resp: BeatmapsetResp,
beatmapset_resp: BeatmapsetDict,
expire_seconds: int | None = None,
):
"""缓存beatmapset信息"""
try:
if expire_seconds is None:
expire_seconds = self._default_ttl
if beatmapset_resp.id is None:
logger.warning("Cannot cache beatmapset with None id")
return
cache_key = self._get_beatmapset_cache_key(beatmapset_resp.id)
cached_data = beatmapset_resp.model_dump_json()
cache_key = self._get_beatmapset_cache_key(beatmapset_resp["id"])
cached_data = safe_json_dumps(beatmapset_resp)
await self.redis.setex(cache_key, expire_seconds, cached_data) # type: ignore
logger.debug(f"Cached beatmapset {beatmapset_resp.id} for {expire_seconds}s")
logger.debug(f"Cached beatmapset {beatmapset_resp['id']} for {expire_seconds}s")
except (ValueError, TypeError, AttributeError) as e:
logger.error(f"Error caching beatmapset: {e}")
async def get_beatmap_lookup_from_cache(self, beatmap_id: int) -> BeatmapsetResp | None:
async def get_beatmap_lookup_from_cache(self, beatmap_id: int) -> BeatmapsetDict | None:
"""从缓存获取通过beatmap ID查找的beatmapset信息"""
try:
cache_key = self._get_beatmap_lookup_cache_key(beatmap_id)
@@ -98,7 +80,7 @@ class BeatmapsetCacheService:
if cached_data:
logger.debug(f"Beatmap lookup cache hit for {beatmap_id}")
data = json.loads(cached_data)
return BeatmapsetResp(**data)
return data
return None
except (ValueError, TypeError, AttributeError) as e:
logger.error(f"Error getting beatmap lookup from cache: {e}")
@@ -107,7 +89,7 @@ class BeatmapsetCacheService:
async def cache_beatmap_lookup(
self,
beatmap_id: int,
beatmapset_resp: BeatmapsetResp,
beatmapset_resp: BeatmapsetDict,
expire_seconds: int | None = None,
):
"""缓存通过beatmap ID查找的beatmapset信息"""
@@ -115,7 +97,7 @@ class BeatmapsetCacheService:
if expire_seconds is None:
expire_seconds = self._default_ttl
cache_key = self._get_beatmap_lookup_cache_key(beatmap_id)
cached_data = beatmapset_resp.model_dump_json()
cached_data = safe_json_dumps(beatmapset_resp)
await self.redis.setex(cache_key, expire_seconds, cached_data) # type: ignore
logger.debug(f"Cached beatmap lookup {beatmap_id} for {expire_seconds}s")
except (ValueError, TypeError, AttributeError) as e:

View File

@@ -3,12 +3,12 @@ from datetime import timedelta
from enum import Enum
import math
import random
from typing import TYPE_CHECKING, NamedTuple
from typing import TYPE_CHECKING, NamedTuple, cast
from app.config import OldScoreProcessingMode, settings
from app.database.beatmap import Beatmap, BeatmapResp
from app.database.beatmap import Beatmap, BeatmapDict
from app.database.beatmap_sync import BeatmapSync, SavedBeatmapMeta
from app.database.beatmapset import Beatmapset, BeatmapsetResp
from app.database.beatmapset import Beatmapset, BeatmapsetDict
from app.database.score import Score
from app.dependencies.database import get_redis, with_db
from app.dependencies.storage import get_storage_service
@@ -62,10 +62,23 @@ STATUS_FACTOR: dict[BeatmapRankStatus, float] = {
SCHEDULER_INTERVAL_MINUTES = 2
class EnsuredBeatmap(BeatmapDict):
checksum: str
ranked: int
class EnsuredBeatmapset(BeatmapsetDict):
ranked: int
ranked_date: datetime.datetime
last_updated: datetime.datetime
play_count: int
beatmaps: list[EnsuredBeatmap]
class ProcessingBeatmapset:
def __init__(self, beatmapset: BeatmapsetResp, record: BeatmapSync) -> None:
def __init__(self, beatmapset: EnsuredBeatmapset, record: BeatmapSync) -> None:
self.beatmapset = beatmapset
self.status = BeatmapRankStatus(self.beatmapset.ranked)
self.status = BeatmapRankStatus(self.beatmapset["ranked"])
self.record = record
def calculate_next_sync_time(
@@ -76,19 +89,19 @@ class ProcessingBeatmapset:
now = utcnow()
if self.status == BeatmapRankStatus.QUALIFIED:
assert self.beatmapset.ranked_date is not None, "ranked_date should not be None for qualified maps"
time_to_ranked = (self.beatmapset.ranked_date + timedelta(days=7) - now).total_seconds()
assert self.beatmapset["ranked_date"] is not None, "ranked_date should not be None for qualified maps"
time_to_ranked = (self.beatmapset["ranked_date"] + timedelta(days=7) - now).total_seconds()
baseline = max(MIN_DELTA, time_to_ranked / 2)
next_delta = max(MIN_DELTA, baseline)
elif self.status in {BeatmapRankStatus.WIP, BeatmapRankStatus.PENDING}:
seconds_since_update = (now - self.beatmapset.last_updated).total_seconds()
seconds_since_update = (now - self.beatmapset["last_updated"]).total_seconds()
factor_update = max(1.0, seconds_since_update / TAU)
factor_play = 1.0 + math.log(1.0 + self.beatmapset.play_count)
factor_play = 1.0 + math.log(1.0 + self.beatmapset["play_count"])
status_factor = STATUS_FACTOR[self.status]
baseline = BASE * factor_play / factor_update * status_factor
next_delta = max(MIN_DELTA, baseline * (GROWTH ** (self.record.consecutive_no_change + 1)))
elif self.status == BeatmapRankStatus.GRAVEYARD:
days_since_update = (now - self.beatmapset.last_updated).days
days_since_update = (now - self.beatmapset["last_updated"]).days
doubling_periods = days_since_update / GRAVEYARD_DOUBLING_PERIOD_DAYS
delta = MIN_DELTA * (2**doubling_periods)
max_seconds = GRAVEYARD_MAX_DAYS * 86400
@@ -105,21 +118,24 @@ class ProcessingBeatmapset:
@property
def beatmapset_changed(self) -> bool:
return self.record.beatmap_status != BeatmapRankStatus(self.beatmapset.ranked)
return self.record.beatmap_status != BeatmapRankStatus(self.beatmapset["ranked"])
@property
def changed_beatmaps(self) -> list[ChangedBeatmap]:
changed_beatmaps = []
for bm in self.beatmapset.beatmaps:
saved = next((s for s in self.record.beatmaps if s["beatmap_id"] == bm.id), None)
for bm in self.beatmapset["beatmaps"]:
saved = next((s for s in self.record.beatmaps if s["beatmap_id"] == bm["id"]), None)
if not saved or saved["is_deleted"]:
changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_ADDED))
elif saved["md5"] != bm.checksum:
changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_UPDATED))
elif saved["beatmap_status"] != BeatmapRankStatus(bm.ranked):
changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.STATUS_CHANGED))
changed_beatmaps.append(ChangedBeatmap(bm["id"], BeatmapChangeType.MAP_ADDED))
elif saved["md5"] != bm["checksum"]:
changed_beatmaps.append(ChangedBeatmap(bm["id"], BeatmapChangeType.MAP_UPDATED))
elif saved["beatmap_status"] != BeatmapRankStatus(bm["ranked"]):
changed_beatmaps.append(ChangedBeatmap(bm["id"], BeatmapChangeType.STATUS_CHANGED))
for saved in self.record.beatmaps:
if not any(bm.id == saved["beatmap_id"] for bm in self.beatmapset.beatmaps) and not saved["is_deleted"]:
if (
not any(bm["id"] == saved["beatmap_id"] for bm in self.beatmapset["beatmaps"])
and not saved["is_deleted"]
):
changed_beatmaps.append(ChangedBeatmap(saved["beatmap_id"], BeatmapChangeType.MAP_DELETED))
return changed_beatmaps
@@ -132,7 +148,7 @@ class BeatmapsetUpdateService:
async def add_missing_beatmapset(self, beatmapset_id: int, immediate: bool = False) -> bool:
beatmapset = await self.fetcher.get_beatmapset(beatmapset_id)
if immediate:
await self._sync_immediately(beatmapset)
await self._sync_immediately(cast(EnsuredBeatmapset, beatmapset))
logger.debug(f"triggered immediate sync for beatmapset {beatmapset_id} ")
return True
await self.add(beatmapset)
@@ -172,7 +188,7 @@ class BeatmapsetUpdateService:
BeatmapSync(
beatmapset_id=missing,
beatmap_status=BeatmapRankStatus.GRAVEYARD,
next_sync_time=datetime.datetime.max,
next_sync_time=datetime.datetime(year=6000, month=1, day=1),
beatmaps=[],
)
)
@@ -185,11 +201,13 @@ class BeatmapsetUpdateService:
await session.commit()
self._adding_missing = False
async def add(self, beatmapset: BeatmapsetResp, calculate_next_sync: bool = True):
async def add(self, set: BeatmapsetDict, calculate_next_sync: bool = True):
beatmapset = cast(EnsuredBeatmapset, set)
async with with_db() as session:
sync_record = await session.get(BeatmapSync, beatmapset.id)
beatmapset_id = beatmapset["id"]
sync_record = await session.get(BeatmapSync, beatmapset_id)
if not sync_record:
database_beatmapset = await session.get(Beatmapset, beatmapset.id)
database_beatmapset = await session.get(Beatmapset, beatmapset_id)
if database_beatmapset:
status = BeatmapRankStatus(database_beatmapset.beatmap_status)
await database_beatmapset.awaitable_attrs.beatmaps
@@ -203,19 +221,29 @@ class BeatmapsetUpdateService:
for bm in database_beatmapset.beatmaps
]
else:
status = BeatmapRankStatus(beatmapset.ranked)
beatmaps = [
ranked = beatmapset.get("ranked")
if ranked is None:
raise ValueError("ranked field is required")
status = BeatmapRankStatus(ranked)
beatmap_list = beatmapset.get("beatmaps", [])
beatmaps = []
for bm in beatmap_list:
bm_id = bm.get("id")
checksum = bm.get("checksum")
ranked = bm.get("ranked")
if bm_id is None or checksum is None or ranked is None:
continue
beatmaps.append(
SavedBeatmapMeta(
beatmap_id=bm.id,
md5=bm.checksum,
beatmap_id=bm_id,
md5=checksum,
is_deleted=False,
beatmap_status=BeatmapRankStatus(bm.ranked),
beatmap_status=BeatmapRankStatus(ranked),
)
)
for bm in beatmapset.beatmaps
]
sync_record = BeatmapSync(
beatmapset_id=beatmapset.id,
beatmapset_id=beatmapset_id,
beatmaps=beatmaps,
beatmap_status=status,
)
@@ -223,13 +251,27 @@ class BeatmapsetUpdateService:
await session.commit()
await session.refresh(sync_record)
else:
sync_record.beatmaps = [
ranked = beatmapset.get("ranked")
if ranked is None:
raise ValueError("ranked field is required")
beatmap_list = beatmapset.get("beatmaps", [])
beatmaps = []
for bm in beatmap_list:
bm_id = bm.get("id")
checksum = bm.get("checksum")
bm_ranked = bm.get("ranked")
if bm_id is None or checksum is None or bm_ranked is None:
continue
beatmaps.append(
SavedBeatmapMeta(
beatmap_id=bm.id, md5=bm.checksum, is_deleted=False, beatmap_status=BeatmapRankStatus(bm.ranked)
beatmap_id=bm_id,
md5=checksum,
is_deleted=False,
beatmap_status=BeatmapRankStatus(bm_ranked),
)
for bm in beatmapset.beatmaps
]
sync_record.beatmap_status = BeatmapRankStatus(beatmapset.ranked)
)
sync_record.beatmaps = beatmaps
sync_record.beatmap_status = BeatmapRankStatus(ranked)
if calculate_next_sync:
processing = ProcessingBeatmapset(beatmapset, sync_record)
next_time_delta = processing.calculate_next_sync_time()
@@ -238,17 +280,19 @@ class BeatmapsetUpdateService:
await BeatmapsetUpdateService._sync_immediately(self, beatmapset)
return
sync_record.next_sync_time = utcnow() + next_time_delta
logger.opt(colors=True).info(f"<g>[{beatmapset.id}]</g> next sync at {sync_record.next_sync_time}")
beatmapset_id = beatmapset.get("id")
if beatmapset_id:
logger.opt(colors=True).debug(f"<g>[{beatmapset_id}]</g> next sync at {sync_record.next_sync_time}")
await session.commit()
async def _sync_immediately(self, beatmapset: BeatmapsetResp) -> None:
async def _sync_immediately(self, beatmapset: EnsuredBeatmapset) -> None:
async with with_db() as session:
record = await session.get(BeatmapSync, beatmapset.id)
record = await session.get(BeatmapSync, beatmapset["id"])
if not record:
record = BeatmapSync(
beatmapset_id=beatmapset.id,
beatmapset_id=beatmapset["id"],
beatmaps=[],
beatmap_status=BeatmapRankStatus(beatmapset.ranked),
beatmap_status=BeatmapRankStatus(beatmapset["ranked"]),
)
session.add(record)
await session.commit()
@@ -261,19 +305,18 @@ class BeatmapsetUpdateService:
record: BeatmapSync,
session: AsyncSession,
*,
beatmapset: BeatmapsetResp | None = None,
beatmapset: EnsuredBeatmapset | None = None,
):
logger.opt(colors=True).info(f"<g>[{record.beatmapset_id}]</g> syncing...")
logger.opt(colors=True).debug(f"<g>[{record.beatmapset_id}]</g> syncing...")
if beatmapset is None:
try:
beatmapset = await self.fetcher.get_beatmapset(record.beatmapset_id)
beatmapset = cast(EnsuredBeatmapset, await self.fetcher.get_beatmapset(record.beatmapset_id))
except Exception as e:
if isinstance(e, HTTPStatusError) and e.response.status_code == 404:
logger.opt(colors=True).warning(
f"<g>[{record.beatmapset_id}]</g> beatmapset not found (404), removing from sync list"
)
await session.delete(record)
await session.commit()
return
if isinstance(e, HTTPError):
logger.opt(colors=True).warning(
@@ -292,20 +335,20 @@ class BeatmapsetUpdateService:
if changed:
record.beatmaps = [
SavedBeatmapMeta(
beatmap_id=bm.id,
md5=bm.checksum,
beatmap_id=bm["id"],
md5=bm["checksum"],
is_deleted=False,
beatmap_status=BeatmapRankStatus(bm.ranked),
beatmap_status=BeatmapRankStatus(bm["ranked"]),
)
for bm in beatmapset.beatmaps
for bm in beatmapset["beatmaps"]
]
record.beatmap_status = BeatmapRankStatus(beatmapset.ranked)
record.beatmap_status = BeatmapRankStatus(beatmapset["ranked"])
record.consecutive_no_change = 0
bg_tasks.add_task(
self._process_changed_beatmaps,
changed_beatmaps,
beatmapset.beatmaps,
beatmapset["beatmaps"],
)
bg_tasks.add_task(
self._process_changed_beatmapset,
@@ -317,13 +360,13 @@ class BeatmapsetUpdateService:
next_time_delta = processing.calculate_next_sync_time()
if not next_time_delta:
logger.opt(colors=True).info(
f"<yellow>[{beatmapset.id}]</yellow> beatmapset has transformed to ranked or loved,"
f"<yellow>[{beatmapset['id']}]</yellow> beatmapset has transformed to ranked or loved,"
f" removing from sync list"
)
await session.delete(record)
else:
record.next_sync_time = utcnow() + next_time_delta
logger.opt(colors=True).info(f"<g>[{record.beatmapset_id}]</g> next sync at {record.next_sync_time}")
logger.opt(colors=True).debug(f"<g>[{record.beatmapset_id}]</g> next sync at {record.next_sync_time}")
async def _update_beatmaps(self):
async with with_db() as session:
@@ -338,18 +381,18 @@ class BeatmapsetUpdateService:
await self.sync(record, session)
await session.commit()
async def _process_changed_beatmapset(self, beatmapset: BeatmapsetResp):
async def _process_changed_beatmapset(self, beatmapset: EnsuredBeatmapset):
async with with_db() as session:
db_beatmapset = await session.get(Beatmapset, beatmapset.id)
new_beatmapset = await Beatmapset.from_resp_no_save(beatmapset)
db_beatmapset = await session.get(Beatmapset, beatmapset["id"])
new_beatmapset = await Beatmapset.from_resp_no_save(beatmapset) # pyright: ignore[reportArgumentType]
if db_beatmapset:
await session.merge(new_beatmapset)
await get_beatmapset_cache_service(get_redis()).invalidate_beatmapset_cache(beatmapset.id)
await get_beatmapset_cache_service(get_redis()).invalidate_beatmapset_cache(beatmapset["id"])
await session.commit()
async def _process_changed_beatmaps(self, changed: list[ChangedBeatmap], beatmaps_list: list[BeatmapResp]):
async def _process_changed_beatmaps(self, changed: list[ChangedBeatmap], beatmaps_list: list[EnsuredBeatmap]):
storage_service = get_storage_service()
beatmaps = {bm.id: bm for bm in beatmaps_list}
beatmaps = {bm["id"]: bm for bm in beatmaps_list}
async with with_db() as session:
@@ -380,9 +423,9 @@ class BeatmapsetUpdateService:
)
continue
logger.opt(colors=True).info(
f"<g>[{beatmap.beatmapset_id}]</g> adding beatmap <blue>{beatmap.id}</blue>"
f"<g>[{beatmap['beatmapset_id']}]</g> adding beatmap <blue>{beatmap['id']}</blue>"
)
await Beatmap.from_resp_no_save(session, beatmap)
await Beatmap.from_resp_no_save(session, beatmap) # pyright: ignore[reportArgumentType]
else:
beatmap = beatmaps.get(change.beatmap_id)
if not beatmap:
@@ -391,10 +434,10 @@ class BeatmapsetUpdateService:
)
continue
logger.opt(colors=True).info(
f"<g>[{beatmap.beatmapset_id}]</g> processing beatmap <blue>{beatmap.id}</blue> "
f"<g>[{beatmap['beatmapset_id']}]</g> processing beatmap <blue>{beatmap['id']}</blue> "
f"change <cyan>{change.type}</cyan>"
)
new_db_beatmap = await Beatmap.from_resp_no_save(session, beatmap)
new_db_beatmap = await Beatmap.from_resp_no_save(session, beatmap) # pyright: ignore[reportArgumentType]
existing_beatmap = await session.get(Beatmap, change.beatmap_id)
if existing_beatmap:
await session.merge(new_db_beatmap)

View File

@@ -4,16 +4,15 @@
"""
import asyncio
from datetime import datetime
import json
from typing import TYPE_CHECKING, Literal
from app.config import settings
from app.database.statistics import UserStatistics, UserStatisticsResp
from app.database.statistics import UserStatistics, UserStatisticsModel
from app.helpers.asset_proxy_helper import replace_asset_urls
from app.log import logger
from app.models.score import GameMode
from app.utils import utcnow
from app.utils import safe_json_dumps, utcnow
from redis.asyncio import Redis
from sqlmodel import col, select
@@ -23,20 +22,6 @@ if TYPE_CHECKING:
pass
class DateTimeEncoder(json.JSONEncoder):
"""自定义 JSON 编码器,支持 datetime 序列化"""
def default(self, obj):
if isinstance(obj, datetime):
return obj.isoformat()
return super().default(obj)
def safe_json_dumps(data) -> str:
"""安全的 JSON 序列化,支持 datetime 对象"""
return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":"))
class RankingCacheService:
"""用户排行榜缓存服务"""
@@ -311,7 +296,7 @@ class RankingCacheService:
col(UserStatistics.pp) > 0,
col(UserStatistics.is_ranked).is_(True),
]
include = ["user"]
include = UserStatistics.RANKING_INCLUDES.copy()
if type == "performance":
order_by = col(UserStatistics.pp).desc()
@@ -321,6 +306,7 @@ class RankingCacheService:
if country:
wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
include.append("country_rank")
# 获取总用户数用于统计
total_users_query = select(UserStatistics).where(*wheres)
@@ -353,9 +339,9 @@ class RankingCacheService:
# 转换为响应格式并确保正确序列化
ranking_data = []
for statistics in statistics_data:
user_stats_resp = await UserStatisticsResp.from_db(statistics, session, None, include)
user_stats_resp = await UserStatisticsModel.transform(statistics, includes=include)
user_dict = user_stats_resp.model_dump()
user_dict = user_stats_resp
# 应用资源代理处理
if settings.enable_asset_proxy:

View File

@@ -8,14 +8,14 @@
import asyncio
from datetime import datetime
import json
import time
from typing import Any
from app.database.chat import ChatMessage, ChatMessageResp, MessageType
from app.database.user import RANKING_INCLUDES, User, UserResp
from app.database import ChatMessageDict
from app.database.chat import ChatMessage, ChatMessageModel, MessageType
from app.database.user import User, UserModel
from app.dependencies.database import get_redis_message, with_db
from app.log import logger
from app.utils import bg_tasks
from app.utils import bg_tasks, safe_json_dumps
class RedisMessageSystem:
@@ -35,7 +35,7 @@ class RedisMessageSystem:
content: str,
is_action: bool = False,
user_uuid: str | None = None,
) -> ChatMessageResp:
) -> "ChatMessageDict":
"""
发送消息 - 立即存储到 Redis 并返回
@@ -47,7 +47,7 @@ class RedisMessageSystem:
user_uuid: 用户UUID
Returns:
ChatMessageResp: 消息响应对象
ChatMessage: 消息响应对象
"""
# 生成消息ID和时间戳
message_id = await self._generate_message_id(channel_id)
@@ -57,28 +57,16 @@ class RedisMessageSystem:
if not user.id:
raise ValueError("User ID is required")
# 获取频道类型以判断是否需要存储到数据库
async with with_db() as session:
from app.database.chat import ChannelType, ChatChannel
from sqlmodel import select
channel_result = await session.exec(select(ChatChannel.type).where(ChatChannel.channel_id == channel_id))
channel_type = channel_result.first()
is_multiplayer = channel_type == ChannelType.MULTIPLAYER
# 准备消息数据
message_data = {
message_data: "ChatMessageDict" = {
"message_id": message_id,
"channel_id": channel_id,
"sender_id": user.id,
"content": content,
"timestamp": timestamp.isoformat(),
"type": MessageType.ACTION.value if is_action else MessageType.PLAIN.value,
"timestamp": timestamp,
"type": MessageType.ACTION if is_action else MessageType.PLAIN,
"uuid": user_uuid or "",
"status": "cached", # Redis 缓存状态
"created_at": time.time(),
"is_multiplayer": is_multiplayer, # 标记是否为多人房间消息
"is_action": is_action,
}
# 立即存储到 Redis
@@ -86,51 +74,13 @@ class RedisMessageSystem:
# 创建响应对象
async with with_db() as session:
user_resp = await UserResp.from_db(user, session, RANKING_INCLUDES)
user_resp = await UserModel.transform(user, session=session, includes=User.LIST_INCLUDES)
message_data["sender"] = user_resp
# 确保 statistics 不为空
if user_resp.statistics is None:
from app.database.statistics import UserStatisticsResp
user_resp.statistics = UserStatisticsResp(
mode=user.playmode,
global_rank=0,
country_rank=0,
pp=0.0,
ranked_score=0,
hit_accuracy=0.0,
play_count=0,
play_time=0,
total_score=0,
total_hits=0,
maximum_combo=0,
replays_watched_by_others=0,
is_ranked=False,
grade_counts={"ssh": 0, "ss": 0, "sh": 0, "s": 0, "a": 0},
level={"current": 1, "progress": 0},
)
response = ChatMessageResp(
message_id=message_id,
channel_id=channel_id,
content=content,
timestamp=timestamp,
sender_id=user.id,
sender=user_resp,
is_action=is_action,
uuid=user_uuid,
)
if is_multiplayer:
logger.info(
f"Multiplayer message {message_id} sent to Redis cache for channel {channel_id},"
" will not be persisted to database"
)
else:
logger.info(f"Message {message_id} sent to Redis cache for channel {channel_id}")
return response
return message_data
async def get_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageResp]:
async def get_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageDict]:
"""
获取频道消息 - 优先从 Redis 获取最新消息
@@ -140,9 +90,9 @@ class RedisMessageSystem:
since: 起始消息ID
Returns:
List[ChatMessageResp]: 消息列表
List[ChatMessageDict]: 消息列表
"""
messages = []
messages: list["ChatMessageDict"] = []
try:
# 从 Redis 获取最新消息
@@ -154,45 +104,21 @@ class RedisMessageSystem:
# 获取发送者信息
sender = await session.get(User, msg_data["sender_id"])
if sender:
user_resp = await UserResp.from_db(sender, session, RANKING_INCLUDES)
user_resp = await UserModel.transform(sender, includes=User.LIST_INCLUDES)
if user_resp.statistics is None:
from app.database.statistics import UserStatisticsResp
from app.database.chat import ChatMessageDict
user_resp.statistics = UserStatisticsResp(
mode=sender.playmode,
global_rank=0,
country_rank=0,
pp=0.0,
ranked_score=0,
hit_accuracy=0.0,
play_count=0,
play_time=0,
total_score=0,
total_hits=0,
maximum_combo=0,
replays_watched_by_others=0,
is_ranked=False,
grade_counts={
"ssh": 0,
"ss": 0,
"sh": 0,
"s": 0,
"a": 0,
},
level={"current": 1, "progress": 0},
)
message_resp = ChatMessageResp(
message_id=msg_data["message_id"],
channel_id=msg_data["channel_id"],
content=msg_data["content"],
timestamp=datetime.fromisoformat(msg_data["timestamp"]),
sender_id=msg_data["sender_id"],
sender=user_resp,
is_action=msg_data["type"] == MessageType.ACTION.value,
uuid=msg_data.get("uuid") or None,
)
message_resp: ChatMessageDict = {
"message_id": msg_data["message_id"],
"channel_id": msg_data["channel_id"],
"content": msg_data["content"],
"timestamp": datetime.fromisoformat(msg_data["timestamp"]), # pyright: ignore[reportArgumentType]
"sender_id": msg_data["sender_id"],
"sender": user_resp,
"is_action": msg_data["type"] == MessageType.ACTION.value,
"uuid": msg_data.get("uuid") or None,
"type": MessageType(msg_data["type"]),
}
messages.append(message_resp)
# 如果 Redis 消息不够,从数据库补充
@@ -216,66 +142,30 @@ class RedisMessageSystem:
return message_id
async def _store_to_redis(self, message_id: int, channel_id: int, message_data: dict[str, Any]):
async def _store_to_redis(self, message_id: int, channel_id: int, message_data: ChatMessageDict):
"""存储消息到 Redis"""
try:
# 检查是否是多人房间消息
is_multiplayer = message_data.get("is_multiplayer", False)
# 存储消息数据
await self.redis.hset(
# 存储消息数据为 JSON 字符串
await self.redis.set(
f"msg:{channel_id}:{message_id}",
mapping={k: json.dumps(v) if isinstance(v, (dict, list)) else str(v) for k, v in message_data.items()},
safe_json_dumps(message_data),
ex=604800, # 7天过期
)
# 设置消息过期时间7天
await self.redis.expire(f"msg:{channel_id}:{message_id}", 604800)
# 清理可能存在的错误类型键,然后添加到频道消息列表(按时间排序)
# 添加到频道消息列表(按时间排序
channel_messages_key = f"channel:{channel_id}:messages"
# 更健壮的键类型检查清理
# 检查清理错误类型的键
try:
key_type = await self.redis.type(channel_messages_key)
if key_type == "none":
# 键不存在,这是正常的
pass
elif key_type != "zset":
# 键类型错误,需要清理
if key_type not in ("none", "zset"):
logger.warning(f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}")
await self.redis.delete(channel_messages_key)
# 验证删除是否成功
verify_type = await self.redis.type(channel_messages_key)
if verify_type != "none":
logger.error(
f"Failed to delete problematic key {channel_messages_key}, type is still {verify_type}"
)
# 强制删除
await self.redis.unlink(channel_messages_key)
except Exception as type_check_error:
logger.warning(f"Failed to check key type for {channel_messages_key}: {type_check_error}")
# 如果检查失败,尝试强制删除键以确保清理
try:
await self.redis.delete(channel_messages_key)
except Exception:
# 最后的努力使用unlink
try:
await self.redis.unlink(channel_messages_key)
except Exception as final_error:
logger.error(f"Critical: Unable to clear problematic key {channel_messages_key}: {final_error}")
# 添加到频道消息列表sorted set
try:
await self.redis.zadd(
channel_messages_key,
mapping={f"msg:{channel_id}:{message_id}": message_id},
)
except Exception as zadd_error:
logger.error(f"Failed to add message to sorted set {channel_messages_key}: {zadd_error}")
# 如果添加失败,再次尝试清理并重试
await self.redis.delete(channel_messages_key)
await self.redis.zadd(
channel_messages_key,
mapping={f"msg:{channel_id}:{message_id}": message_id},
@@ -284,18 +174,14 @@ class RedisMessageSystem:
# 保持频道消息列表大小最多1000条
await self.redis.zremrangebyrank(channel_messages_key, 0, -1001)
# 只有非多人房间消息才添加到待持久化队列
if not is_multiplayer:
await self.redis.lpush("pending_messages", f"{channel_id}:{message_id}")
logger.debug(f"Message {message_id} added to persistence queue")
else:
logger.debug(f"Message {message_id} in multiplayer room, skipped persistence queue")
except Exception as e:
logger.error(f"Failed to store message to Redis: {e}")
raise
async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict[str, Any]]:
async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageDict]:
"""从 Redis 获取消息"""
try:
# 获取消息键列表按消息ID排序
@@ -314,28 +200,16 @@ class RedisMessageSystem:
messages = []
for key in message_keys:
# 获取消息数据
raw_data = await self.redis.hgetall(key)
# 获取消息数据JSON 字符串)
raw_data = await self.redis.get(key)
if raw_data:
# 解码数据
message_data: dict[str, Any] = {}
for k, v in raw_data.items():
# 尝试解析 JSON
try:
if k in ["grade_counts", "level"] or v.startswith(("{", "[")):
message_data[k] = json.loads(v)
elif k in ["message_id", "channel_id", "sender_id"]:
message_data[k] = int(v)
elif k == "is_multiplayer":
message_data[k] = v == "True"
elif k == "created_at":
message_data[k] = float(v)
else:
message_data[k] = v
except (json.JSONDecodeError, ValueError):
message_data[k] = v
# 解析 JSON 字符串为字典
message_data = json.loads(raw_data)
messages.append(message_data)
except json.JSONDecodeError as e:
logger.error(f"Failed to decode message JSON from {key}: {e}")
continue
# 确保消息按ID正序排序时间顺序
messages.sort(key=lambda x: x.get("message_id", 0))
@@ -350,15 +224,15 @@ class RedisMessageSystem:
logger.error(f"Failed to get messages from Redis: {e}")
return []
async def _backfill_from_database(self, channel_id: int, existing_messages: list[ChatMessageResp], limit: int):
async def _backfill_from_database(self, channel_id: int, existing_messages: list[ChatMessageDict], limit: int):
"""从数据库补充历史消息"""
try:
# 找到最小的消息ID
min_id = float("inf")
if existing_messages:
for msg in existing_messages:
if msg.message_id is not None and msg.message_id < min_id:
min_id = msg.message_id
if msg["message_id"] is not None and msg["message_id"] < min_id:
min_id = msg["message_id"]
needed = limit - len(existing_messages)
@@ -378,13 +252,13 @@ class RedisMessageSystem:
db_messages = (await session.exec(query)).all()
for msg in reversed(db_messages): # 按时间正序插入
msg_resp = await ChatMessageResp.from_db(msg, session)
msg_resp = await ChatMessageModel.transform(msg, includes=["sender"])
existing_messages.insert(0, msg_resp)
except Exception as e:
logger.error(f"Failed to backfill from database: {e}")
async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> list[ChatMessageResp]:
async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> list[ChatMessageDict]:
"""仅从数据库获取消息(回退方案)"""
try:
async with with_db() as session:
@@ -402,7 +276,7 @@ class RedisMessageSystem:
messages = (await session.exec(query)).all()
results = [await ChatMessageResp.from_db(msg, session) for msg in messages]
results = await ChatMessageModel.transform_many(messages, includes=["sender"])
# 如果是 since > 0保持正序否则反转为时间正序
if since == 0:
@@ -450,27 +324,17 @@ class RedisMessageSystem:
# 解析频道ID和消息ID
channel_id, message_id = map(int, key.split(":"))
# 从 Redis 获取消息数据
raw_data = await self.redis.hgetall(f"msg:{channel_id}:{message_id}")
# 从 Redis 获取消息数据JSON 字符串)
raw_data = await self.redis.get(f"msg:{channel_id}:{message_id}")
if not raw_data:
continue
# 解码数据
message_data = {}
for k, v in raw_data.items():
message_data[k] = v
# 检查是否是多人房间消息,如果是则跳过数据库存储
is_multiplayer = message_data.get("is_multiplayer", "False") == "True"
if is_multiplayer:
# 多人房间消息不存储到数据库,直接标记为已跳过
await self.redis.hset(
f"msg:{channel_id}:{message_id}",
"status",
"skipped_multiplayer",
)
logger.debug(f"Message {message_id} in multiplayer room skipped from database storage")
# 解析 JSON 字符串为字典
try:
message_data = json.loads(raw_data)
except json.JSONDecodeError as e:
logger.error(f"Failed to decode message JSON for {channel_id}:{message_id}: {e}")
continue
# 检查消息是否已存在于数据库
@@ -491,13 +355,6 @@ class RedisMessageSystem:
session.add(db_message)
# 更新 Redis 中的状态
await self.redis.hset(
f"msg:{channel_id}:{message_id}",
"status",
"persisted",
)
logger.debug(f"Message {message_id} persisted to database")
except Exception as e:

View File

@@ -14,8 +14,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
async def create_playlist_room_from_api(session: AsyncSession, room: APIUploadedRoom, host_id: int) -> Room:
db_room = room.to_room()
db_room.host_id = host_id
db_room = Room.model_validate({"host_id": host_id, **room.model_dump(exclude={"playlist"})})
db_room.starts_at = utcnow()
db_room.ends_at = db_room.starts_at + timedelta(minutes=db_room.duration if db_room.duration is not None else 0)
session.add(db_room)

View File

@@ -3,19 +3,19 @@
用于缓存用户信息,提供热缓存和实时刷新功能
"""
from datetime import datetime
import json
from typing import TYPE_CHECKING, Any
from app.config import settings
from app.const import BANCHOBOT_ID
from app.database import User, UserResp
from app.database.score import LegacyScoreResp, ScoreResp
from app.database.user import SEARCH_INCLUDED
from app.database import User
from app.database.score import LegacyScoreResp
from app.database.user import UserDict, UserModel
from app.dependencies.database import with_db
from app.helpers.asset_proxy_helper import replace_asset_urls
from app.log import logger
from app.models.score import GameMode
from app.utils import safe_json_dumps
from redis.asyncio import Redis
from sqlmodel import col, select
@@ -25,20 +25,6 @@ if TYPE_CHECKING:
pass
class DateTimeEncoder(json.JSONEncoder):
"""自定义 JSON 编码器,支持 datetime 序列化"""
def default(self, obj):
if isinstance(obj, datetime):
return obj.isoformat()
return super().default(obj)
def safe_json_dumps(data: Any) -> str:
"""安全的 JSON 序列化,支持 datetime 对象"""
return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False)
class UserCacheService:
"""用户缓存服务"""
@@ -125,7 +111,7 @@ class UserCacheService:
"""生成用户谱面集缓存键"""
return f"user:{user_id}:beatmapsets:{beatmapset_type}:limit:{limit}:offset:{offset}"
async def get_user_from_cache(self, user_id: int, ruleset: GameMode | None = None) -> UserResp | None:
async def get_user_from_cache(self, user_id: int, ruleset: GameMode | None = None) -> UserDict | None:
"""从缓存获取用户信息"""
try:
cache_key = self._get_user_cache_key(user_id, ruleset)
@@ -133,7 +119,7 @@ class UserCacheService:
if cached_data:
logger.debug(f"User cache hit for user {user_id}")
data = json.loads(cached_data)
return UserResp(**data)
return data
return None
except Exception as e:
logger.error(f"Error getting user from cache: {e}")
@@ -141,7 +127,7 @@ class UserCacheService:
async def cache_user(
self,
user_resp: UserResp,
user_resp: UserDict,
ruleset: GameMode | None = None,
expire_seconds: int | None = None,
):
@@ -149,13 +135,10 @@ class UserCacheService:
try:
if expire_seconds is None:
expire_seconds = settings.user_cache_expire_seconds
if user_resp.id is None:
logger.warning("Cannot cache user with None id")
return
cache_key = self._get_user_cache_key(user_resp.id, ruleset)
cached_data = user_resp.model_dump_json()
cache_key = self._get_user_cache_key(user_resp["id"], ruleset)
cached_data = safe_json_dumps(user_resp)
await self.redis.setex(cache_key, expire_seconds, cached_data)
logger.debug(f"Cached user {user_resp.id} for {expire_seconds}s")
logger.debug(f"Cached user {user_resp['id']} for {expire_seconds}s")
except Exception as e:
logger.error(f"Error caching user: {e}")
@@ -168,10 +151,9 @@ class UserCacheService:
limit: int = 100,
offset: int = 0,
is_legacy: bool = False,
) -> list[ScoreResp] | list[LegacyScoreResp] | None:
) -> list[UserDict] | list[LegacyScoreResp] | None:
"""从缓存获取用户成绩"""
try:
model = LegacyScoreResp if is_legacy else ScoreResp
cache_key = self._get_user_scores_cache_key(
user_id, score_type, include_fail, mode, limit, offset, is_legacy
)
@@ -179,7 +161,7 @@ class UserCacheService:
if cached_data:
logger.debug(f"User scores cache hit for user {user_id}, type {score_type}")
data = json.loads(cached_data)
return [model(**score_data) for score_data in data] # pyright: ignore[reportReturnType]
return [LegacyScoreResp(**score_data) for score_data in data] if is_legacy else data
return None
except Exception as e:
logger.error(f"Error getting user scores from cache: {e}")
@@ -189,7 +171,7 @@ class UserCacheService:
self,
user_id: int,
score_type: str,
scores: list[ScoreResp] | list[LegacyScoreResp],
scores: list[UserDict] | list[LegacyScoreResp],
include_fail: bool,
mode: GameMode | None = None,
limit: int = 100,
@@ -204,8 +186,12 @@ class UserCacheService:
cache_key = self._get_user_scores_cache_key(
user_id, score_type, include_fail, mode, limit, offset, is_legacy
)
# 使用 model_dump_json() 而不是 model_dump() + json.dumps()
scores_json_list = [score.model_dump_json() for score in scores]
if len(scores) == 0:
return
if isinstance(scores[0], dict):
scores_json_list = [safe_json_dumps(score) for score in scores]
else:
scores_json_list = [score.model_dump_json() for score in scores] # pyright: ignore[reportAttributeAccessIssue]
cached_data = f"[{','.join(scores_json_list)}]"
await self.redis.setex(cache_key, expire_seconds, cached_data)
logger.debug(f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s")
@@ -308,7 +294,7 @@ class UserCacheService:
for user in users:
if user.id != BANCHOBOT_ID:
try:
await self._cache_single_user(user, session)
await self._cache_single_user(user)
cached_count += 1
except Exception as e:
logger.error(f"Failed to cache user {user.id}: {e}")
@@ -320,10 +306,10 @@ class UserCacheService:
finally:
self._refreshing = False
async def _cache_single_user(self, user: User, session: AsyncSession):
async def _cache_single_user(self, user: User):
"""缓存单个用户"""
try:
user_resp = await UserResp.from_db(user, session, include=SEARCH_INCLUDED)
user_resp = await UserModel.transform(user, includes=User.USER_INCLUDES)
# 应用资源代理处理
if settings.enable_asset_proxy:
@@ -347,7 +333,7 @@ class UserCacheService:
# 立即重新加载用户信息
user = await session.get(User, user_id)
if user and user.id != BANCHOBOT_ID:
await self._cache_single_user(user, session)
await self._cache_single_user(user)
logger.info(f"Refreshed cache for user {user_id} after score submit")
except Exception as e:
logger.error(f"Error refreshing user cache on score submit: {e}")

View File

@@ -4,10 +4,13 @@ from datetime import UTC, datetime
import functools
import inspect
from io import BytesIO
import json
import re
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
from types import NoneType, UnionType
from typing import TYPE_CHECKING, Any, ParamSpec, TypedDict, TypeVar, Union, get_args, get_origin
from fastapi import HTTPException
from fastapi.encoders import jsonable_encoder
from PIL import Image
if TYPE_CHECKING:
@@ -299,3 +302,51 @@ def hex_to_hue(hex_color: str) -> int:
hue = (60 * ((r - g) / delta) + 240) % 360
return int(hue)
def safe_json_dumps(data) -> str:
return json.dumps(jsonable_encoder(data), ensure_ascii=False)
def type_is_optional(typ: type):
origin_type = get_origin(typ)
args = get_args(typ)
return (origin_type is UnionType or origin_type is Union) and len(args) == 2 and NoneType in args
def _get_type(typ: type, includes: tuple[str, ...]) -> Any:
from app.database._base import DatabaseModel
origin = get_origin(typ)
if issubclass(typ, DatabaseModel):
return typ.generate_typeddict(includes)
elif origin is list:
item_type = typ.__args__[0]
return list[_get_type(item_type, includes)] # pyright: ignore[reportArgumentType, reportGeneralTypeIssues]
elif origin is dict:
key_type, value_type = typ.__args__
return dict[key_type, _get_type(value_type, includes)] # pyright: ignore[reportArgumentType, reportGeneralTypeIssues]
elif type_is_optional(typ):
inner_type = next(arg for arg in get_args(typ) if arg is not NoneType)
return Union[_get_type(inner_type, includes), None] # pyright: ignore[reportArgumentType, reportGeneralTypeIssues] # noqa: UP007
elif origin is UnionType or origin is Union:
new_types = []
for arg in get_args(typ):
new_types.append(_get_type(arg, includes)) # pyright: ignore[reportArgumentType, reportGeneralTypeIssues]
return Union[tuple(new_types)] # pyright: ignore[reportArgumentType, reportGeneralTypeIssues] # noqa: UP007
else:
return typ
def api_doc(desc: str, model: Any, includes: list[str] = [], *, name: str = "APIDict"):
if includes:
includes_str = ", ".join(f"`{inc}`" for inc in includes)
desc += f"\n\n包含:{includes_str}"
if isinstance(model, dict):
fields = {}
for k, v in model.items():
fields[k] = _get_type(v, tuple(includes))
typed_dict = TypedDict(name, fields) # pyright: ignore[reportArgumentType, reportGeneralTypeIssues]
else:
typed_dict = _get_type(model, tuple(includes))
return {"description": desc, "model": typed_dict}

View File

@@ -0,0 +1,498 @@
"""project: remove unused fields in database
Revision ID: 23707640303c
Revises: 3f0f22f38c3d
Create Date: 2025-11-23 08:14:05.284238
"""
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
revision: str = "23707640303c"
down_revision: str | Sequence[str] | None = "3f0f22f38c3d"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"beatmaps",
"mode",
existing_type=mysql.ENUM(
"OSU",
"TAIKO",
"FRUITS",
"MANIA",
"OSURX",
"OSUAP",
"TAIKORX",
"FRUITSRX",
"SENTAKKI",
"TAU",
"RUSH",
"HISHIGATA",
"SOYOKAZE",
),
nullable=False,
)
op.drop_column("beatmaps", "current_user_playcount")
op.alter_column("beatmapsync", "updated_at", existing_type=mysql.DATETIME(), nullable=True)
op.alter_column(
"best_scores",
"gamemode",
existing_type=mysql.ENUM(
"OSU",
"TAIKO",
"FRUITS",
"MANIA",
"OSURX",
"OSUAP",
"TAIKORX",
"FRUITSRX",
"SENTAKKI",
"TAU",
"RUSH",
"HISHIGATA",
"SOYOKAZE",
),
nullable=False,
)
op.alter_column(
"lazer_user_statistics",
"mode",
existing_type=mysql.ENUM(
"OSU",
"TAIKO",
"FRUITS",
"MANIA",
"OSURX",
"OSUAP",
"TAIKORX",
"FRUITSRX",
"SENTAKKI",
"TAU",
"RUSH",
"HISHIGATA",
"SOYOKAZE",
),
nullable=False,
)
op.alter_column("lazer_user_statistics", "ranked_score", existing_type=mysql.BIGINT(), nullable=True)
op.alter_column(
"lazer_users",
"playmode",
existing_type=mysql.ENUM(
"OSU",
"TAIKO",
"FRUITS",
"MANIA",
"OSURX",
"OSUAP",
"TAIKORX",
"FRUITSRX",
"SENTAKKI",
"TAU",
"RUSH",
"HISHIGATA",
"SOYOKAZE",
),
nullable=False,
)
op.alter_column(
"lazer_users",
"g0v0_playmode",
existing_type=mysql.ENUM(
"OSU",
"TAIKO",
"FRUITS",
"MANIA",
"OSURX",
"OSUAP",
"TAIKORX",
"FRUITSRX",
"SENTAKKI",
"TAU",
"RUSH",
"HISHIGATA",
"SOYOKAZE",
),
nullable=False,
)
op.drop_column("lazer_users", "beatmap_playcounts_count")
op.alter_column("login_sessions", "is_new_device", existing_type=mysql.TINYINT(display_width=1), nullable=False)
op.alter_column(
"rank_history",
"mode",
existing_type=mysql.ENUM(
"OSU",
"TAIKO",
"FRUITS",
"MANIA",
"OSURX",
"OSUAP",
"TAIKORX",
"FRUITSRX",
"SENTAKKI",
"TAU",
"RUSH",
"HISHIGATA",
"SOYOKAZE",
),
nullable=False,
)
op.alter_column(
"rank_top",
"mode",
existing_type=mysql.ENUM(
"OSU",
"TAIKO",
"FRUITS",
"MANIA",
"OSURX",
"OSUAP",
"TAIKORX",
"FRUITSRX",
"SENTAKKI",
"TAU",
"RUSH",
"HISHIGATA",
"SOYOKAZE",
),
nullable=False,
)
op.alter_column(
"rooms",
"type",
existing_type=mysql.ENUM("PLAYLISTS", "HEAD_TO_HEAD", "TEAM_VERSUS", "MATCHMAKING"),
nullable=False,
)
op.execute("UPDATE rooms SET channel_id = 0 WHERE channel_id IS NULL")
op.alter_column("rooms", "channel_id", existing_type=mysql.INTEGER(), nullable=False)
op.alter_column(
"score_tokens",
"ruleset_id",
existing_type=mysql.ENUM(
"OSU",
"TAIKO",
"FRUITS",
"MANIA",
"OSURX",
"OSUAP",
"TAIKORX",
"FRUITSRX",
"SENTAKKI",
"TAU",
"RUSH",
"HISHIGATA",
"SOYOKAZE",
),
nullable=False,
)
op.alter_column(
"scores",
"gamemode",
existing_type=mysql.ENUM(
"OSU",
"TAIKO",
"FRUITS",
"MANIA",
"OSURX",
"OSUAP",
"TAIKORX",
"FRUITSRX",
"SENTAKKI",
"TAU",
"RUSH",
"HISHIGATA",
"SOYOKAZE",
),
nullable=False,
)
op.alter_column(
"teams",
"playmode",
existing_type=mysql.ENUM(
"OSU",
"TAIKO",
"FRUITS",
"MANIA",
"OSURX",
"OSUAP",
"TAIKORX",
"FRUITSRX",
"SENTAKKI",
"TAU",
"RUSH",
"HISHIGATA",
"SOYOKAZE",
),
nullable=False,
)
op.alter_column(
"total_score_best_scores",
"gamemode",
existing_type=mysql.ENUM(
"OSU",
"TAIKO",
"FRUITS",
"MANIA",
"OSURX",
"OSUAP",
"TAIKORX",
"FRUITSRX",
"SENTAKKI",
"TAU",
"RUSH",
"HISHIGATA",
"SOYOKAZE",
),
nullable=False,
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"total_score_best_scores",
"gamemode",
existing_type=mysql.ENUM(
"OSU",
"TAIKO",
"FRUITS",
"MANIA",
"OSURX",
"OSUAP",
"TAIKORX",
"FRUITSRX",
"SENTAKKI",
"TAU",
"RUSH",
"HISHIGATA",
"SOYOKAZE",
),
nullable=True,
)
op.alter_column(
"teams",
"playmode",
existing_type=mysql.ENUM(
"OSU",
"TAIKO",
"FRUITS",
"MANIA",
"OSURX",
"OSUAP",
"TAIKORX",
"FRUITSRX",
"SENTAKKI",
"TAU",
"RUSH",
"HISHIGATA",
"SOYOKAZE",
),
nullable=True,
)
op.alter_column(
"scores",
"gamemode",
existing_type=mysql.ENUM(
"OSU",
"TAIKO",
"FRUITS",
"MANIA",
"OSURX",
"OSUAP",
"TAIKORX",
"FRUITSRX",
"SENTAKKI",
"TAU",
"RUSH",
"HISHIGATA",
"SOYOKAZE",
),
nullable=True,
)
op.alter_column(
"score_tokens",
"ruleset_id",
existing_type=mysql.ENUM(
"OSU",
"TAIKO",
"FRUITS",
"MANIA",
"OSURX",
"OSUAP",
"TAIKORX",
"FRUITSRX",
"SENTAKKI",
"TAU",
"RUSH",
"HISHIGATA",
"SOYOKAZE",
),
nullable=True,
)
op.alter_column("rooms", "channel_id", existing_type=mysql.INTEGER(), nullable=True)
op.alter_column(
"rooms",
"type",
existing_type=mysql.ENUM("PLAYLISTS", "HEAD_TO_HEAD", "TEAM_VERSUS", "MATCHMAKING"),
nullable=True,
)
op.alter_column(
"rank_top",
"mode",
existing_type=mysql.ENUM(
"OSU",
"TAIKO",
"FRUITS",
"MANIA",
"OSURX",
"OSUAP",
"TAIKORX",
"FRUITSRX",
"SENTAKKI",
"TAU",
"RUSH",
"HISHIGATA",
"SOYOKAZE",
),
nullable=True,
)
op.alter_column(
"rank_history",
"mode",
existing_type=mysql.ENUM(
"OSU",
"TAIKO",
"FRUITS",
"MANIA",
"OSURX",
"OSUAP",
"TAIKORX",
"FRUITSRX",
"SENTAKKI",
"TAU",
"RUSH",
"HISHIGATA",
"SOYOKAZE",
),
nullable=True,
)
op.alter_column("login_sessions", "is_new_device", existing_type=mysql.TINYINT(display_width=1), nullable=True)
op.add_column(
"lazer_users", sa.Column("beatmap_playcounts_count", mysql.INTEGER(), autoincrement=False, nullable=False)
)
op.alter_column(
"lazer_users",
"g0v0_playmode",
existing_type=mysql.ENUM(
"OSU",
"TAIKO",
"FRUITS",
"MANIA",
"OSURX",
"OSUAP",
"TAIKORX",
"FRUITSRX",
"SENTAKKI",
"TAU",
"RUSH",
"HISHIGATA",
"SOYOKAZE",
),
nullable=True,
)
op.alter_column(
"lazer_users",
"playmode",
existing_type=mysql.ENUM(
"OSU",
"TAIKO",
"FRUITS",
"MANIA",
"OSURX",
"OSUAP",
"TAIKORX",
"FRUITSRX",
"SENTAKKI",
"TAU",
"RUSH",
"HISHIGATA",
"SOYOKAZE",
),
nullable=True,
)
op.alter_column("lazer_user_statistics", "ranked_score", existing_type=mysql.BIGINT(), nullable=False)
op.alter_column(
"lazer_user_statistics",
"mode",
existing_type=mysql.ENUM(
"OSU",
"TAIKO",
"FRUITS",
"MANIA",
"OSURX",
"OSUAP",
"TAIKORX",
"FRUITSRX",
"SENTAKKI",
"TAU",
"RUSH",
"HISHIGATA",
"SOYOKAZE",
),
nullable=True,
)
op.alter_column(
"best_scores",
"gamemode",
existing_type=mysql.ENUM(
"OSU",
"TAIKO",
"FRUITS",
"MANIA",
"OSURX",
"OSUAP",
"TAIKORX",
"FRUITSRX",
"SENTAKKI",
"TAU",
"RUSH",
"HISHIGATA",
"SOYOKAZE",
),
nullable=True,
)
op.alter_column("beatmapsync", "updated_at", existing_type=mysql.DATETIME(), nullable=False)
op.add_column("beatmaps", sa.Column("current_user_playcount", mysql.INTEGER(), autoincrement=False, nullable=False))
op.alter_column(
"beatmaps",
"mode",
existing_type=mysql.ENUM(
"OSU",
"TAIKO",
"FRUITS",
"MANIA",
"OSURX",
"OSUAP",
"TAIKORX",
"FRUITSRX",
"SENTAKKI",
"TAU",
"RUSH",
"HISHIGATA",
"SOYOKAZE",
),
nullable=True,
)
# ### end Alembic commands ###