diff --git a/app/auth.py b/app/auth.py index 4c690f8..4762662 100644 --- a/app/auth.py +++ b/app/auth.py @@ -8,7 +8,7 @@ import string from app.config import settings from app.database import ( OAuthToken, - User as DBUser, + User, ) from app.log import logger @@ -74,7 +74,7 @@ def get_password_hash(password: str) -> str: async def authenticate_user_legacy( db: AsyncSession, name: str, password: str -) -> DBUser | None: +) -> User | None: """ 验证用户身份 - 使用类似 from_login 的逻辑 """ @@ -82,7 +82,7 @@ async def authenticate_user_legacy( pw_md5 = hashlib.md5(password.encode()).hexdigest() # 2. 根据用户名查找用户 - statement = select(DBUser).where(DBUser.name == name) + statement = select(User).where(User.username == name) user = (await db.exec(statement)).first() if not user: return None @@ -113,7 +113,7 @@ async def authenticate_user_legacy( async def authenticate_user( db: AsyncSession, username: str, password: str -) -> DBUser | None: +) -> User | None: """验证用户身份""" return await authenticate_user_legacy(db, username, password) diff --git a/app/database/__init__.py b/app/database/__init__.py index 191a193..6e2e8c5 100644 --- a/app/database/__init__.py +++ b/app/database/__init__.py @@ -1,3 +1,4 @@ +from .achievement import UserAchievement, UserAchievementResp from .auth import OAuthToken from .beatmap import ( Beatmap as Beatmap, @@ -8,7 +9,13 @@ from .beatmapset import ( BeatmapsetResp as BeatmapsetResp, ) from .best_score import BestScore -from .legacy import LegacyOAuthToken, LegacyUserStatistics +from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp +from .favourite_beatmapset import FavouriteBeatmapset +from .lazer_user import ( + User, + UserResp, +) +from .pp_best_score import PPBestScore from .relationship import Relationship, RelationshipResp, RelationshipType from .score import ( Score, @@ -17,52 +24,27 @@ from .score import ( ScoreStatistics, ) from .score_token import ScoreToken, ScoreTokenResp +from .statistics import ( + UserStatistics, + UserStatisticsResp, +) from .team import Team, TeamMember -from .user import ( - DailyChallengeStats, - LazerUserAchievement, - LazerUserBadge, - LazerUserBanners, - LazerUserCountry, - LazerUserCounts, - LazerUserKudosu, - LazerUserMonthlyPlaycounts, - LazerUserPreviousUsername, - LazerUserProfile, - LazerUserProfileSections, - LazerUserReplaysWatched, - LazerUserStatistics, - RankHistory, - User, - UserAchievement, - UserAvatar, +from .user_account_history import ( + UserAccountHistory, + UserAccountHistoryResp, + UserAccountHistoryType, ) -BeatmapsetResp.model_rebuild() -BeatmapResp.model_rebuild() __all__ = [ "Beatmap", - "BeatmapResp", "Beatmapset", "BeatmapsetResp", "BestScore", "DailyChallengeStats", - "LazerUserAchievement", - "LazerUserBadge", - "LazerUserBanners", - "LazerUserCountry", - "LazerUserCounts", - "LazerUserKudosu", - "LazerUserMonthlyPlaycounts", - "LazerUserPreviousUsername", - "LazerUserProfile", - "LazerUserProfileSections", - "LazerUserReplaysWatched", - "LazerUserStatistics", - "LegacyOAuthToken", - "LegacyUserStatistics", + "DailyChallengeStatsResp", + "FavouriteBeatmapset", "OAuthToken", - "RankHistory", + "PPBestScore", "Relationship", "RelationshipResp", "RelationshipType", @@ -75,6 +57,17 @@ __all__ = [ "Team", "TeamMember", "User", + "UserAccountHistory", + "UserAccountHistoryResp", + "UserAccountHistoryType", "UserAchievement", - "UserAvatar", + "UserAchievement", + "UserAchievementResp", + "UserResp", + "UserStatistics", + "UserStatisticsResp", ] + +for i in __all__: + if i.endswith("Resp"): + globals()[i].model_rebuild() # type: ignore[call-arg] diff --git a/app/database/auth.py b/app/database/auth.py index ae49676..554dced 100644 --- a/app/database/auth.py +++ b/app/database/auth.py @@ -1,19 +1,21 @@ from datetime import datetime from typing import TYPE_CHECKING +from app.models.model import UTCBaseModel + from sqlalchemy import Column, DateTime from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel if TYPE_CHECKING: - from .user import User + from .lazer_user import User -class OAuthToken(SQLModel, table=True): +class OAuthToken(UTCBaseModel, SQLModel, table=True): __tablename__ = "oauth_tokens" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True, index=True) user_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("users.id"), index=True) + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) ) access_token: str = Field(max_length=500, unique=True) refresh_token: str = Field(max_length=500, unique=True) diff --git a/app/database/beatmap.py b/app/database/beatmap.py index b28473e..62a0120 100644 --- a/app/database/beatmap.py +++ b/app/database/beatmap.py @@ -7,13 +7,14 @@ from app.models.score import MODE_TO_INT, GameMode from .beatmapset import Beatmapset, BeatmapsetResp from sqlalchemy import DECIMAL, Column, DateTime -from sqlalchemy.orm import joinedload from sqlmodel import VARCHAR, Field, Relationship, SQLModel, select from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: from app.fetcher import Fetcher + from .lazer_user import User + class BeatmapOwner(SQLModel): id: int @@ -66,7 +67,9 @@ class Beatmap(BeatmapBase, table=True): beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True) beatmap_status: BeatmapRankStatus # optional - beatmapset: Beatmapset = Relationship(back_populates="beatmaps") + beatmapset: Beatmapset = Relationship( + back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"} + ) @property def can_ranked(self) -> bool: @@ -87,13 +90,7 @@ class Beatmap(BeatmapBase, table=True): session.add(beatmap) await session.commit() beatmap = ( - await session.exec( - select(Beatmap) - .options( - joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType] - ) - .where(Beatmap.id == resp.id) - ) + await session.exec(select(Beatmap).where(Beatmap.id == resp.id)) ).first() assert beatmap is not None, "Beatmap should not be None after commit" return beatmap @@ -131,13 +128,9 @@ class Beatmap(BeatmapBase, table=True): ) -> "Beatmap": beatmap = ( await session.exec( - select(Beatmap) - .where( + select(Beatmap).where( Beatmap.id == bid if bid is not None else Beatmap.checksum == md5 ) - .options( - joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType] - ) ) ).first() if not beatmap: @@ -164,11 +157,13 @@ class BeatmapResp(BeatmapBase): url: str = "" @classmethod - def from_db( + async def from_db( cls, beatmap: Beatmap, query_mode: GameMode | None = None, from_set: bool = False, + session: AsyncSession | None = None, + user: "User | None" = None, ) -> "BeatmapResp": beatmap_ = beatmap.model_dump() if query_mode is not None and beatmap.mode != query_mode: @@ -178,5 +173,7 @@ class BeatmapResp(BeatmapBase): beatmap_["ranked"] = beatmap.beatmap_status.value beatmap_["mode_int"] = MODE_TO_INT[beatmap.mode] if not from_set: - beatmap_["beatmapset"] = BeatmapsetResp.from_db(beatmap.beatmapset) + beatmap_["beatmapset"] = await BeatmapsetResp.from_db( + beatmap.beatmapset, session=session, user=user + ) return cls.model_validate(beatmap_) diff --git a/app/database/beatmapset.py b/app/database/beatmapset.py index 1e6ba27..3bad7e9 100644 --- a/app/database/beatmapset.py +++ b/app/database/beatmapset.py @@ -4,13 +4,17 @@ from typing import TYPE_CHECKING, TypedDict, cast from app.models.beatmap import BeatmapRankStatus, Genre, Language from app.models.score import GameMode +from .lazer_user import BASE_INCLUDES, User, UserResp + from pydantic import BaseModel, model_serializer from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text -from sqlmodel import Field, Relationship, SQLModel +from sqlalchemy.ext.asyncio import AsyncAttrs +from sqlmodel import Field, Relationship, SQLModel, col, func, select from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: from .beatmap import Beatmap, BeatmapResp + from .favourite_beatmapset import FavouriteBeatmapset class BeatmapCovers(SQLModel): @@ -88,7 +92,6 @@ class BeatmapsetBase(SQLModel): artist_unicode: str = Field(index=True) covers: BeatmapCovers | None = Field(sa_column=Column(JSON)) creator: str - favourite_count: int nsfw: bool = Field(default=False) play_count: int preview_url: str @@ -112,11 +115,9 @@ class BeatmapsetBase(SQLModel): pack_tags: list[str] = Field(default=[], sa_column=Column(JSON)) ratings: list[int] = Field(default=None, sa_column=Column(JSON)) - # TODO: recent_favourites: Optional[list[User]] = None # TODO: related_users: Optional[list[User]] = None # TODO: user: Optional[User] = Field(default=None) track_id: int | None = Field(default=None) # feature artist? - # TODO: has_favourited # BeatmapsetExtended bpm: float = Field(default=0.0, sa_column=Column(DECIMAL(10, 2))) @@ -129,7 +130,7 @@ class BeatmapsetBase(SQLModel): tags: str = Field(default="", sa_column=Column(Text)) -class Beatmapset(BeatmapsetBase, table=True): +class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True): __tablename__ = "beatmapsets" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True, index=True) @@ -150,6 +151,7 @@ class Beatmapset(BeatmapsetBase, table=True): hype_required: int = Field(default=0) availability_info: str | None = Field(default=None) download_disabled: bool = Field(default=False) + favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset") @classmethod async def from_resp( @@ -197,40 +199,88 @@ class BeatmapsetResp(BeatmapsetBase): genre: BeatmapTranslationText | None = None language: BeatmapTranslationText | None = None nominations: BeatmapNominations | None = None + has_favourited: bool = False + favourite_count: int = 0 + recent_favourites: list[UserResp] = Field(default_factory=list) @classmethod - def from_db(cls, beatmapset: Beatmapset) -> "BeatmapsetResp": + async def from_db( + cls, + beatmapset: Beatmapset, + include: list[str] = [], + session: AsyncSession | None = None, + user: User | None = None, + ) -> "BeatmapsetResp": from .beatmap import BeatmapResp + from .favourite_beatmapset import FavouriteBeatmapset - beatmaps = [ - BeatmapResp.from_db(beatmap, from_set=True) - for beatmap in beatmapset.beatmaps - ] + update = { + "beatmaps": [ + await BeatmapResp.from_db(beatmap, from_set=True) + for beatmap in await beatmapset.awaitable_attrs.beatmaps + ], + "hype": BeatmapHype( + current=beatmapset.hype_current, required=beatmapset.hype_required + ), + "availability": BeatmapAvailability( + more_information=beatmapset.availability_info, + download_disabled=beatmapset.download_disabled, + ), + "genre": BeatmapTranslationText( + name=beatmapset.beatmap_genre.name, + id=beatmapset.beatmap_genre.value, + ), + "language": BeatmapTranslationText( + name=beatmapset.beatmap_language.name, + id=beatmapset.beatmap_language.value, + ), + "nominations": BeatmapNominations( + required=beatmapset.nominations_required, + current=beatmapset.nominations_current, + ), + "status": beatmapset.beatmap_status.name.lower(), + "ranked": beatmapset.beatmap_status.value, + "is_scoreable": beatmapset.beatmap_status > BeatmapRankStatus.PENDING, + **beatmapset.model_dump(), + } + if session and user: + existing_favourite = ( + await session.exec( + select(FavouriteBeatmapset).where( + FavouriteBeatmapset.beatmapset_id == beatmapset.id + ) + ) + ).first() + update["has_favourited"] = existing_favourite is not None + + if session and "recent_favourites" in include: + recent_favourites = ( + await session.exec( + select(FavouriteBeatmapset) + .where( + FavouriteBeatmapset.beatmapset_id == beatmapset.id, + ) + .order_by(col(FavouriteBeatmapset.date).desc()) + .limit(50) + ) + ).all() + update["recent_favourites"] = [ + await UserResp.from_db( + await favourite.awaitable_attrs.user, + session=session, + include=BASE_INCLUDES, + ) + for favourite in recent_favourites + ] + + if session: + update["favourite_count"] = ( + await session.exec( + select(func.count()) + .select_from(FavouriteBeatmapset) + .where(FavouriteBeatmapset.beatmapset_id == beatmapset.id) + ) + ).one() return cls.model_validate( - { - "beatmaps": beatmaps, - "hype": BeatmapHype( - current=beatmapset.hype_current, required=beatmapset.hype_required - ), - "availability": BeatmapAvailability( - more_information=beatmapset.availability_info, - download_disabled=beatmapset.download_disabled, - ), - "genre": BeatmapTranslationText( - name=beatmapset.beatmap_genre.name, - id=beatmapset.beatmap_genre.value, - ), - "language": BeatmapTranslationText( - name=beatmapset.beatmap_language.name, - id=beatmapset.beatmap_language.value, - ), - "nominations": BeatmapNominations( - required=beatmapset.nominations_required, - current=beatmapset.nominations_current, - ), - "status": beatmapset.beatmap_status.name.lower(), - "ranked": beatmapset.beatmap_status.value, - "is_scoreable": beatmapset.beatmap_status > BeatmapRankStatus.PENDING, - **beatmapset.model_dump(), - } + update, ) diff --git a/app/database/best_score.py b/app/database/best_score.py index 313da3e..42b0024 100644 --- a/app/database/best_score.py +++ b/app/database/best_score.py @@ -1,14 +1,14 @@ from typing import TYPE_CHECKING -from app.models.score import GameMode +from app.models.score import GameMode, Rank -from .user import User +from .lazer_user import User from sqlmodel import ( + JSON, BigInteger, Column, Field, - Float, ForeignKey, Relationship, SQLModel, @@ -20,22 +20,29 @@ if TYPE_CHECKING: class BestScore(SQLModel, table=True): - __tablename__ = "best_scores" # pyright: ignore[reportAssignmentType] + __tablename__ = "total_score_best_scores" # pyright: ignore[reportAssignmentType] user_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("users.id"), index=True) + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) ) score_id: int = Field( sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True) ) beatmap_id: int = Field(foreign_key="beatmaps.id", index=True) gamemode: GameMode = Field(index=True) - pp: float = Field( - sa_column=Column(Float, default=0), + total_score: int = Field( + default=0, sa_column=Column(BigInteger, ForeignKey("scores.total_score")) ) - acc: float = Field( - sa_column=Column(Float, default=0), + mods: list[str] = Field( + default_factory=list, + sa_column=Column(JSON), ) + rank: Rank user: User = Relationship() - score: "Score" = Relationship() + score: "Score" = Relationship( + sa_relationship_kwargs={ + "foreign_keys": "[BestScore.score_id]", + "lazy": "joined", + } + ) beatmap: "Beatmap" = Relationship() diff --git a/app/database/favourite_beatmapset.py b/app/database/favourite_beatmapset.py new file mode 100644 index 0000000..51bd578 --- /dev/null +++ b/app/database/favourite_beatmapset.py @@ -0,0 +1,53 @@ +import datetime + +from app.database.beatmapset import Beatmapset +from app.database.lazer_user import User + +from sqlalchemy.ext.asyncio import AsyncAttrs +from sqlmodel import ( + BigInteger, + Column, + DateTime, + Field, + ForeignKey, + Relationship, + SQLModel, +) + + +class FavouriteBeatmapset(AsyncAttrs, SQLModel, table=True): + __tablename__ = "favourite_beatmapset" # pyright: ignore[reportAssignmentType] + id: int | None = Field( + default=None, + sa_column=Column(BigInteger, autoincrement=True, primary_key=True), + exclude=True, + ) + user_id: int = Field( + default=None, + sa_column=Column( + BigInteger, + ForeignKey("lazer_users.id"), + index=True, + ), + ) + beatmapset_id: int = Field( + default=None, + sa_column=Column( + ForeignKey("beatmapsets.id"), + index=True, + ), + ) + date: datetime.datetime = Field( + default=datetime.datetime.now(datetime.UTC), + sa_column=Column( + DateTime, + ), + ) + + user: User = Relationship(back_populates="favourite_beatmapsets") + beatmapset: Beatmapset = Relationship( + sa_relationship_kwargs={ + "lazy": "selectin", + }, + back_populates="favourites", + ) diff --git a/app/database/lazer_user.py b/app/database/lazer_user.py index 9b98c98..3bd751b 100644 --- a/app/database/lazer_user.py +++ b/app/database/lazer_user.py @@ -12,7 +12,7 @@ from .statistics import UserStatistics, UserStatisticsResp from .team import Team, TeamMember from .user_account_history import UserAccountHistory, UserAccountHistoryResp -from sqlalchemy.orm import joinedload, selectinload +from sqlalchemy.ext.asyncio import AsyncAttrs from sqlmodel import ( JSON, BigInteger, @@ -27,7 +27,8 @@ from sqlmodel import ( from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: - from app.database.relationship import RelationshipResp + from .favourite_beatmapset import FavouriteBeatmapset + from .relationship import RelationshipResp class Kudosu(TypedDict): @@ -128,7 +129,7 @@ class UserBase(UTCBaseModel, SQLModel): is_bng: bool = False -class User(UserBase, table=True): +class User(AsyncAttrs, UserBase, table=True): __tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType] id: int | None = Field( @@ -143,6 +144,9 @@ class User(UserBase, table=True): back_populates="user" ) monthly_playcounts: list[MonthlyPlaycounts] = Relationship(back_populates="user") + favourite_beatmapsets: list["FavouriteBeatmapset"] = Relationship( + back_populates="user" + ) email: str = Field(max_length=254, unique=True, index=True, exclude=True) priv: int = Field(default=1, exclude=True) @@ -154,21 +158,10 @@ class User(UserBase, table=True): default=None, sa_column=Column(DateTime(timezone=True)), exclude=True ) - @classmethod - def all_select_option(cls): - return ( - selectinload(cls.account_history), # pyright: ignore[reportArgumentType] - selectinload(cls.statistics), # pyright: ignore[reportArgumentType] - selectinload(cls.achievement), # pyright: ignore[reportArgumentType] - joinedload(cls.team_membership).joinedload(TeamMember.team), # pyright: ignore[reportArgumentType] - joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType] - selectinload(cls.monthly_playcounts), # pyright: ignore[reportArgumentType] - ) - class UserResp(UserBase): id: int | None = None - is_online: bool = True # TODO + is_online: bool = False groups: list = [] # TODO country: Country = Field(default_factory=lambda: Country(code="CN", name="China")) favourite_beatmapset_count: int = 0 # TODO @@ -211,6 +204,8 @@ class UserResp(UserBase): include: list[str] = [], ruleset: GameMode | None = None, ) -> "UserResp": + from app.dependencies.database import get_redis + from .best_score import BestScore from .relationship import Relationship, RelationshipResp, RelationshipType @@ -236,6 +231,8 @@ class UserResp(UserBase): .limit(200) ) ).one() + redis = get_redis() + u.is_online = await redis.exists(f"metadata:online:{obj.id}") u.cover_url = ( obj.cover.get( "url", "https://assets.ppy.sh/user-profile-covers/default.jpeg" @@ -249,13 +246,7 @@ class UserResp(UserBase): await RelationshipResp.from_db(session, r) for r in ( await session.exec( - select(Relationship) - .options( - joinedload(Relationship.target).options( # pyright: ignore[reportArgumentType] - *User.all_select_option() - ) - ) - .where( + select(Relationship).where( Relationship.user_id == obj.id, Relationship.type == RelationshipType.FOLLOW, ) @@ -264,23 +255,26 @@ class UserResp(UserBase): ] if "team" in include: - if obj.team_membership: + if await obj.awaitable_attrs.team_membership: + assert obj.team_membership u.team = obj.team_membership.team if "account_history" in include: u.account_history = [ - UserAccountHistoryResp.from_db(ah) for ah in obj.account_history + UserAccountHistoryResp.from_db(ah) + for ah in await obj.awaitable_attrs.account_history ] if "daily_challenge_user_stats": - if obj.daily_challenge_stats: + if await obj.awaitable_attrs.daily_challenge_stats: + assert obj.daily_challenge_stats u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db( obj.daily_challenge_stats ) if "statistics" in include: current_stattistics = None - for i in obj.statistics: + for i in await obj.awaitable_attrs.statistics: if i.mode == (ruleset or obj.playmode): current_stattistics = i break @@ -292,17 +286,20 @@ class UserResp(UserBase): if "statistics_rulesets" in include: u.statistics_rulesets = { - i.mode.value: UserStatisticsResp.from_db(i) for i in obj.statistics + i.mode.value: UserStatisticsResp.from_db(i) + for i in await obj.awaitable_attrs.statistics } if "monthly_playcounts" in include: u.monthly_playcounts = [ - MonthlyPlaycountsResp.from_db(pc) for pc in obj.monthly_playcounts + MonthlyPlaycountsResp.from_db(pc) + for pc in await obj.awaitable_attrs.monthly_playcounts ] if "achievements" in include: u.user_achievements = [ - UserAchievementResp.from_db(ua) for ua in obj.achievement + UserAchievementResp.from_db(ua) + for ua in await obj.awaitable_attrs.achievement ] return u @@ -328,3 +325,9 @@ SEARCH_INCLUDED = [ "achievements", "monthly_playcounts", ] + +BASE_INCLUDES = [ + "team", + "daily_challenge_user_stats", + "statistics", +] diff --git a/app/database/legacy.py b/app/database/legacy.py deleted file mode 100644 index ff1e957..0000000 --- a/app/database/legacy.py +++ /dev/null @@ -1,94 +0,0 @@ -from datetime import datetime -from typing import TYPE_CHECKING - -from sqlalchemy import JSON, Column, DateTime -from sqlalchemy.orm import Mapped -from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel - -if TYPE_CHECKING: - from .user import User -# ============================================ -# 旧的兼容性表模型(保留以便向后兼容) -# ============================================ - - -class LegacyUserStatistics(SQLModel, table=True): - __tablename__ = "user_statistics" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - mode: str = Field(max_length=10) # osu, taiko, fruits, mania - - # 基本统计 - count_100: int = Field(default=0) - count_300: int = Field(default=0) - count_50: int = Field(default=0) - count_miss: int = Field(default=0) - - # 等级信息 - level_current: int = Field(default=1) - level_progress: int = Field(default=0) - - # 排名信息 - global_rank: int | None = Field(default=None) - global_rank_exp: int | None = Field(default=None) - country_rank: int | None = Field(default=None) - - # PP 和分数 - pp: float = Field(default=0.0) - pp_exp: float = Field(default=0.0) - ranked_score: int = Field(default=0) - hit_accuracy: float = Field(default=0.0) - total_score: int = Field(default=0) - total_hits: int = Field(default=0) - maximum_combo: int = Field(default=0) - - # 游戏统计 - play_count: int = Field(default=0) - play_time: int = Field(default=0) - replays_watched_by_others: int = Field(default=0) - is_ranked: bool = Field(default=False) - - # 成绩等级计数 - grade_ss: int = Field(default=0) - grade_ssh: int = Field(default=0) - grade_s: int = Field(default=0) - grade_sh: int = Field(default=0) - grade_a: int = Field(default=0) - - # 最高排名记录 - rank_highest: int | None = Field(default=None) - rank_highest_updated_at: datetime | None = Field( - default=None, sa_column=Column(DateTime) - ) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - # 关联关系 - user: Mapped["User"] = Relationship(back_populates="statistics") - - -class LegacyOAuthToken(SQLModel, table=True): - __tablename__ = "legacy_oauth_tokens" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - access_token: str = Field(max_length=255, index=True) - refresh_token: str = Field(max_length=255, index=True) - expires_at: datetime = Field(sa_column=Column(DateTime)) - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - previous_usernames: list = Field(default_factory=list, sa_column=Column(JSON)) - replays_watched_counts: list = Field(default_factory=list, sa_column=Column(JSON)) - - # 用户关系 - user: "User" = Relationship() diff --git a/app/database/pp_best_score.py b/app/database/pp_best_score.py new file mode 100644 index 0000000..ffc74d3 --- /dev/null +++ b/app/database/pp_best_score.py @@ -0,0 +1,41 @@ +from typing import TYPE_CHECKING + +from app.models.score import GameMode + +from .lazer_user import User + +from sqlmodel import ( + BigInteger, + Column, + Field, + Float, + ForeignKey, + Relationship, + SQLModel, +) + +if TYPE_CHECKING: + from .beatmap import Beatmap + from .score import Score + + +class PPBestScore(SQLModel, table=True): + __tablename__ = "best_scores" # pyright: ignore[reportAssignmentType] + user_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) + ) + score_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True) + ) + beatmap_id: int = Field(foreign_key="beatmaps.id", index=True) + gamemode: GameMode = Field(index=True) + pp: float = Field( + sa_column=Column(Float, default=0), + ) + acc: float = Field( + sa_column=Column(Float, default=0), + ) + + user: User = Relationship() + score: "Score" = Relationship() + beatmap: "Beatmap" = Relationship() diff --git a/app/database/relationship.py b/app/database/relationship.py index 61dc109..b941c28 100644 --- a/app/database/relationship.py +++ b/app/database/relationship.py @@ -1,8 +1,6 @@ from enum import Enum -from app.models.user import User as APIUser - -from .user import User as DBUser +from .lazer_user import User, UserResp from pydantic import BaseModel from sqlmodel import ( @@ -24,12 +22,16 @@ class RelationshipType(str, Enum): class Relationship(SQLModel, table=True): __tablename__ = "relationship" # pyright: ignore[reportAssignmentType] + id: int | None = Field( + default=None, + sa_column=Column(BigInteger, autoincrement=True, primary_key=True), + exclude=True, + ) user_id: int = Field( default=None, sa_column=Column( BigInteger, - ForeignKey("users.id"), - primary_key=True, + ForeignKey("lazer_users.id"), index=True, ), ) @@ -37,20 +39,22 @@ class Relationship(SQLModel, table=True): default=None, sa_column=Column( BigInteger, - ForeignKey("users.id"), - primary_key=True, + ForeignKey("lazer_users.id"), index=True, ), ) type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False) - target: DBUser = SQLRelationship( - sa_relationship_kwargs={"foreign_keys": "[Relationship.target_id]"} + target: User = SQLRelationship( + sa_relationship_kwargs={ + "foreign_keys": "[Relationship.target_id]", + "lazy": "selectin", + } ) class RelationshipResp(BaseModel): target_id: int - target: APIUser + target: UserResp mutual: bool = False type: RelationshipType @@ -58,8 +62,6 @@ class RelationshipResp(BaseModel): async def from_db( cls, session: AsyncSession, relationship: Relationship ) -> "RelationshipResp": - from app.utils import convert_db_user_to_api_user - target_relationship = ( await session.exec( select(Relationship).where( @@ -75,7 +77,16 @@ class RelationshipResp(BaseModel): ) return cls( target_id=relationship.target_id, - target=await convert_db_user_to_api_user(relationship.target), + target=await UserResp.from_db( + relationship.target, + session, + include=[ + "team", + "daily_challenge_user_stats", + "statistics", + "statistics_rulesets", + ], + ), mutual=mutual, type=relationship.type, ) diff --git a/app/database/score.py b/app/database/score.py index 694395b..79cb005 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -1,6 +1,7 @@ import asyncio from collections.abc import Sequence -from datetime import UTC, datetime +from datetime import UTC, date, datetime +import json import math from typing import TYPE_CHECKING @@ -12,9 +13,8 @@ from app.calculator import ( calculate_weighted_pp, clamp, ) -from app.database.score_token import ScoreToken -from app.database.user import LazerUserStatistics, User -from app.models.beatmap import BeatmapRankStatus +from app.database.team import TeamMember +from app.models.model import UTCBaseModel from app.models.mods import APIMod, mods_can_get_pp from app.models.score import ( INT_TO_MODE, @@ -26,15 +26,24 @@ from app.models.score import ( ScoreStatistics, SoloScoreSubmissionInfo, ) -from app.models.user import User as APIUser from .beatmap import Beatmap, BeatmapResp -from .beatmapset import Beatmapset, BeatmapsetResp +from .beatmapset import BeatmapsetResp from .best_score import BestScore +from .lazer_user import User, UserResp +from .monthly_playcounts import MonthlyPlaycounts +from .pp_best_score import PPBestScore +from .relationship import ( + Relationship as DBRelationship, + RelationshipType, +) +from .score_token import ScoreToken -from redis import Redis +from redis.asyncio import Redis from sqlalchemy import Column, ColumnExpressionArgument, DateTime -from sqlalchemy.orm import aliased, joinedload +from sqlalchemy.ext.asyncio import AsyncAttrs +from sqlalchemy.orm import aliased +from sqlalchemy.sql.elements import ColumnElement from sqlmodel import ( JSON, BigInteger, @@ -43,9 +52,10 @@ from sqlmodel import ( Relationship, SQLModel, col, - false, func, select, + text, + true, ) from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.sql._expression_select_cls import SelectOfScalar @@ -54,7 +64,7 @@ if TYPE_CHECKING: from app.fetcher import Fetcher -class ScoreBase(SQLModel): +class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel): # 基本字段 accuracy: float map_md5: str = Field(max_length=32, index=True) @@ -94,7 +104,7 @@ class Score(ScoreBase, table=True): default=None, sa_column=Column( BigInteger, - ForeignKey("users.id"), + ForeignKey("lazer_users.id"), index=True, ), ) @@ -112,28 +122,13 @@ class Score(ScoreBase, table=True): gamemode: GameMode = Field(index=True) # optional - beatmap: "Beatmap" = Relationship() - user: "User" = Relationship() + beatmap: Beatmap = Relationship() + user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"}) @property def is_perfect_combo(self) -> bool: return self.max_combo == self.beatmap.max_combo - @staticmethod - def select_clause(with_user: bool = True) -> SelectOfScalar["Score"]: - clause = select(Score).options( - joinedload(Score.beatmap) # pyright: ignore[reportArgumentType] - .joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType] - .selectinload( - Beatmapset.beatmaps # pyright: ignore[reportArgumentType] - ), - ) - if with_user: - return clause.options( - joinedload(Score.user).options(*User.all_select_option()) # pyright: ignore[reportArgumentType] - ) - return clause - @staticmethod def select_clause_unique( *where_clauses: ColumnExpressionArgument[bool] | bool, @@ -147,18 +142,7 @@ class Score(ScoreBase, table=True): ) subq = select(Score, rownum).where(*where_clauses).subquery() best = aliased(Score, subq, adapt_on_names=True) - return ( - select(best) - .where(subq.c.rn == 1) - .options( - joinedload(best.beatmap) # pyright: ignore[reportArgumentType] - .joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType] - .selectinload( - Beatmapset.beatmaps # pyright: ignore[reportArgumentType] - ), - joinedload(best.user).options(*User.all_select_option()), # pyright: ignore[reportArgumentType] - ) - ) + return select(best).where(subq.c.rn == 1) class ScoreResp(ScoreBase): @@ -173,22 +157,21 @@ class ScoreResp(ScoreBase): ruleset_id: int | None = None beatmap: BeatmapResp | None = None beatmapset: BeatmapsetResp | None = None - user: APIUser | None = None + user: UserResp | None = None statistics: ScoreStatistics | None = None maximum_statistics: ScoreStatistics | None = None rank_global: int | None = None rank_country: int | None = None @classmethod - async def from_db( - cls, session: AsyncSession, score: Score, user: User | None = None - ) -> "ScoreResp": - from app.utils import convert_db_user_to_api_user - + async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp": s = cls.model_validate(score.model_dump()) assert score.id - s.beatmap = BeatmapResp.from_db(score.beatmap) - s.beatmapset = BeatmapsetResp.from_db(score.beatmap.beatmapset) + await score.awaitable_attrs.beatmap + s.beatmap = await BeatmapResp.from_db(score.beatmap) + s.beatmapset = await BeatmapsetResp.from_db( + score.beatmap.beatmapset, session=session, user=score.user + ) s.is_perfect_combo = s.max_combo == s.beatmap.max_combo s.legacy_perfect = s.max_combo == s.beatmap.max_combo s.ruleset_id = MODE_TO_INT[score.gamemode] @@ -220,25 +203,30 @@ class ScoreResp(ScoreBase): s.maximum_statistics = { HitResult.GREAT: score.beatmap.max_combo, } - if user: - s.user = await convert_db_user_to_api_user(user) + s.user = await UserResp.from_db( + score.user, + session, + include=["statistics", "team", "daily_challenge_user_stats"], + ruleset=score.gamemode, + ) s.rank_global = ( await get_score_position_by_id( session, - score.map_md5, + score.beatmap_id, score.id, mode=score.gamemode, - user=user or score.user, + user=score.user, ) or None ) s.rank_country = ( await get_score_position_by_id( session, - score.map_md5, + score.beatmap_id, score.id, score.gamemode, - user or score.user, + score.user, + type=LeaderboardType.COUNTRY, ) or None ) @@ -248,135 +236,137 @@ class ScoreResp(ScoreBase): async def get_best_id(session: AsyncSession, score_id: int) -> None: rownum = ( func.row_number() - .over(partition_by=col(BestScore.user_id), order_by=col(BestScore.pp).desc()) + .over( + partition_by=col(PPBestScore.user_id), order_by=col(PPBestScore.pp).desc() + ) .label("rn") ) - subq = select(BestScore, rownum).subquery() + subq = select(PPBestScore, rownum).subquery() stmt = select(subq.c.rn).where(subq.c.score_id == score_id) result = await session.exec(stmt) return result.one_or_none() +async def _score_where( + type: LeaderboardType, + beatmap: int, + mode: GameMode, + mods: list[str] | None = None, + user: User | None = None, +) -> list[ColumnElement[bool]] | None: + wheres = [ + col(BestScore.beatmap_id) == beatmap, + col(BestScore.gamemode) == mode, + ] + + if type == LeaderboardType.FRIENDS: + if user and user.is_supporter: + subq = ( + select(DBRelationship.target_id) + .where( + DBRelationship.type == RelationshipType.FOLLOW, + DBRelationship.user_id == user.id, + ) + .subquery() + ) + wheres.append(col(BestScore.user_id).in_(select(subq.c.target_id))) + else: + return None + elif type == LeaderboardType.COUNTRY: + if user and user.is_supporter: + wheres.append( + col(BestScore.user).has(col(User.country_code) == user.country_code) + ) + else: + return None + elif type == LeaderboardType.TEAM: + if user: + team_membership = await user.awaitable_attrs.team_membership + if team_membership: + team_id = team_membership.team_id + wheres.append( + col(BestScore.user).has( + col(User.team_membership).has(TeamMember.team_id == team_id) + ) + ) + if mods: + if user and user.is_supporter: + wheres.append( + text( + "JSON_CONTAINS(total_score_best_scores.mods, :w)" + " AND JSON_CONTAINS(:w, total_score_best_scores.mods)" + ) # pyright: ignore[reportArgumentType] + ) + else: + return None + return wheres + + async def get_leaderboard( session: AsyncSession, - beatmap_md5: str, + beatmap: int, mode: GameMode, type: LeaderboardType = LeaderboardType.GLOBAL, - mods: list[APIMod] | None = None, + mods: list[str] | None = None, user: User | None = None, limit: int = 50, -) -> list[Score]: - scores = [] - if type == LeaderboardType.GLOBAL: - query = ( - select(Score) - .where( - col(Beatmap.beatmap_status).in_( - [ - BeatmapRankStatus.RANKED, - BeatmapRankStatus.LOVED, - BeatmapRankStatus.QUALIFIED, - BeatmapRankStatus.APPROVED, - ] - ), - Score.map_md5 == beatmap_md5, - Score.gamemode == mode, - col(Score.passed).is_(True), - Score.mods == mods if user and user.is_supporter else false(), - ) - .limit(limit) - .order_by( - col(Score.total_score).desc(), - ) - ) - result = await session.exec(query) - scores = list[Score](result.all()) - elif type == LeaderboardType.FRIENDS and user and user.is_supporter: - # TODO - ... - elif type == LeaderboardType.TEAM and user and user.team_membership: - team_id = user.team_membership.team_id - query = ( - select(Score) - .join(Beatmap) - .options(joinedload(Score.user)) # pyright: ignore[reportArgumentType] - .where( - Score.map_md5 == beatmap_md5, - Score.gamemode == mode, - col(Score.passed).is_(True), - col(Score.user.team_membership).is_not(None), - Score.user.team_membership.team_id == team_id, # pyright: ignore[reportOptionalMemberAccess] - Score.mods == mods if user and user.is_supporter else false(), - ) - .limit(limit) - .order_by( - col(Score.total_score).desc(), - ) - ) - result = await session.exec(query) - scores = list[Score](result.all()) +) -> tuple[list[Score], Score | None]: + wheres = await _score_where(type, beatmap, mode, mods, user) + if wheres is None: + return [], None + query = ( + select(BestScore) + .where(*wheres) + .limit(limit) + .order_by(col(BestScore.total_score).desc()) + ) + if mods: + query = query.params(w=json.dumps(mods)) + scores = [s.score for s in await session.exec(query)] + user_score = None if user: - user_score = ( - await session.exec( - select(Score).where( - Score.map_md5 == beatmap_md5, - Score.gamemode == mode, - Score.user_id == user.id, - col(Score.passed).is_(True), + self_query = ( + select(BestScore) + .where(BestScore.user_id == user.id) + .order_by(col(BestScore.total_score).desc()) + .limit(1) + ) + if mods: + self_query = self_query.where( + text( + "JSON_CONTAINS(total_score_best_scores.mods, :w)" + " AND JSON_CONTAINS(:w, total_score_best_scores.mods)" ) - ) - ).first() + ).params(w=json.dumps(mods)) + user_bs = (await session.exec(self_query)).first() + if user_bs: + user_score = user_bs.score if user_score and user_score not in scores: scores.append(user_score) - return scores + return scores, user_score async def get_score_position_by_user( session: AsyncSession, - beatmap_md5: str, + beatmap: int, user: User, mode: GameMode, type: LeaderboardType = LeaderboardType.GLOBAL, - mods: list[APIMod] | None = None, + mods: list[str] | None = None, ) -> int: - where_clause = [ - Score.map_md5 == beatmap_md5, - Score.gamemode == mode, - col(Score.passed).is_(True), - col(Beatmap.beatmap_status).in_( - [ - BeatmapRankStatus.RANKED, - BeatmapRankStatus.LOVED, - BeatmapRankStatus.QUALIFIED, - BeatmapRankStatus.APPROVED, - ] - ), - ] - if mods and user.is_supporter: - where_clause.append(Score.mods == mods) - else: - where_clause.append(false()) - if type == LeaderboardType.FRIENDS and user.is_supporter: - # TODO - ... - elif type == LeaderboardType.TEAM and user.team_membership: - team_id = user.team_membership.team_id - where_clause.append( - col(Score.user.team_membership).is_not(None), - ) - where_clause.append( - Score.user.team_membership.team_id == team_id, # pyright: ignore[reportOptionalMemberAccess] - ) + wheres = await _score_where(type, beatmap, mode, mods, user=user) + if wheres is None: + return 0 rownum = ( func.row_number() .over( - partition_by=Score.map_md5, - order_by=col(Score.total_score).desc(), + partition_by=col(BestScore.beatmap_id), + order_by=col(BestScore.total_score).desc(), ) .label("row_number") ) - subq = select(Score, rownum).join(Beatmap).where(*where_clause).subquery() - stmt = select(subq.c.row_number).where(subq.c.user == user) + subq = select(BestScore, rownum).join(Beatmap).where(*wheres).subquery() + stmt = select(subq.c.row_number).where(subq.c.user_id == user.id) result = await session.exec(stmt) s = result.one_or_none() return s if s else 0 @@ -384,57 +374,26 @@ async def get_score_position_by_user( async def get_score_position_by_id( session: AsyncSession, - beatmap_md5: str, + beatmap: int, score_id: int, mode: GameMode, user: User | None = None, type: LeaderboardType = LeaderboardType.GLOBAL, - mods: list[APIMod] | None = None, + mods: list[str] | None = None, ) -> int: - where_clause = [ - Score.map_md5 == beatmap_md5, - Score.id == score_id, - Score.gamemode == mode, - col(Score.passed).is_(True), - col(Beatmap.beatmap_status).in_( - [ - BeatmapRankStatus.RANKED, - BeatmapRankStatus.LOVED, - BeatmapRankStatus.QUALIFIED, - BeatmapRankStatus.APPROVED, - ] - ), - ] - if mods and user and user.is_supporter: - where_clause.append(Score.mods == mods) - elif mods: - where_clause.append(false()) + wheres = await _score_where(type, beatmap, mode, mods, user=user) + if wheres is None: + return 0 rownum = ( func.row_number() .over( - partition_by=[col(Score.user_id), col(Score.map_md5)], - order_by=col(Score.total_score).desc(), + partition_by=col(BestScore.beatmap_id), + order_by=col(BestScore.total_score).desc(), ) - .label("rownum") + .label("row_number") ) - subq = ( - select(Score.user_id, Score.id, Score.total_score, rownum) - .join(Beatmap) - .where(*where_clause) - .subquery() - ) - best_scores = aliased(subq) - overall_rank = ( - func.rank().over(order_by=best_scores.c.total_score.desc()).label("global_rank") - ) - final_q = ( - select(best_scores.c.id, overall_rank) - .select_from(best_scores) - .where(best_scores.c.rownum == 1) - .subquery() - ) - - stmt = select(final_q.c.global_rank).where(final_q.c.id == score_id) + subq = select(BestScore, rownum).join(Beatmap).where(*wheres).subquery() + stmt = select(subq.c.row_number).where(subq.c.score_id == score_id) result = await session.exec(stmt) s = result.one_or_none() return s if s else 0 @@ -445,16 +404,38 @@ async def get_user_best_score_in_beatmap( beatmap: int, user: int, mode: GameMode | None = None, -) -> Score | None: +) -> BestScore | None: return ( await session.exec( - Score.select_clause(False) + select(BestScore) .where( - Score.gamemode == mode if mode is not None else True, - Score.beatmap_id == beatmap, - Score.user_id == user, + BestScore.gamemode == mode if mode is not None else true(), + BestScore.beatmap_id == beatmap, + BestScore.user_id == user, ) - .order_by(col(Score.total_score).desc()) + .order_by(col(BestScore.total_score).desc()) + ) + ).first() + + +# FIXME +async def get_user_best_score_with_mod_in_beatmap( + session: AsyncSession, + beatmap: int, + user: int, + mod: list[str], + mode: GameMode | None = None, +) -> BestScore | None: + return ( + await session.exec( + select(BestScore) + .where( + BestScore.gamemode == mode if mode is not None else True, + BestScore.beatmap_id == beatmap, + BestScore.user_id == user, + # BestScore.mods == mod, + ) + .order_by(col(BestScore.total_score).desc()) ) ).first() @@ -464,13 +445,13 @@ async def get_user_best_pp_in_beatmap( beatmap: int, user: int, mode: GameMode, -) -> BestScore | None: +) -> PPBestScore | None: return ( await session.exec( - select(BestScore).where( - BestScore.beatmap_id == beatmap, - BestScore.user_id == user, - BestScore.gamemode == mode, + select(PPBestScore).where( + PPBestScore.beatmap_id == beatmap, + PPBestScore.user_id == user, + PPBestScore.gamemode == mode, ) ) ).first() @@ -480,12 +461,12 @@ async def get_user_best_pp( session: AsyncSession, user: int, limit: int = 200, -) -> Sequence[BestScore]: +) -> Sequence[PPBestScore]: return ( await session.exec( - select(BestScore) - .where(BestScore.user_id == user) - .order_by(col(BestScore.pp).desc()) + select(PPBestScore) + .where(PPBestScore.user_id == user) + .order_by(col(PPBestScore.pp).desc()) .limit(limit) ) ).all() @@ -494,27 +475,45 @@ async def get_user_best_pp( async def process_user( session: AsyncSession, user: User, score: Score, ranked: bool = False ): + assert user.id + assert score.id + mod_for_save = list({mod["acronym"] for mod in score.mods}) previous_score_best = await get_user_best_score_in_beatmap( session, score.beatmap_id, user.id, score.gamemode ) - statistics = None + previous_score_best_mod = await get_user_best_score_with_mod_in_beatmap( + session, score.beatmap_id, user.id, mod_for_save, score.gamemode + ) add_to_db = False - for i in user.lazer_statistics: + mouthly_playcount = ( + await session.exec( + select(MonthlyPlaycounts).where( + MonthlyPlaycounts.user_id == user.id, + MonthlyPlaycounts.year == date.today().year, + MonthlyPlaycounts.month == date.today().month, + ) + ) + ).first() + if mouthly_playcount is None: + mouthly_playcount = MonthlyPlaycounts( + user_id=user.id, year=date.today().year, month=date.today().month + ) + add_to_db = True + statistics = None + for i in await user.awaitable_attrs.statistics: if i.mode == score.gamemode.value: statistics = i break if statistics is None: - statistics = LazerUserStatistics( - mode=score.gamemode.value, - user_id=user.id, + raise ValueError( + f"User {user.id} does not have statistics for mode {score.gamemode.value}" ) - add_to_db = True # pc, pt, tth, tts statistics.total_score += score.total_score difference = ( score.total_score - previous_score_best.total_score - if previous_score_best and previous_score_best.id != score.id + if previous_score_best else score.total_score ) if difference > 0 and score.passed and ranked: @@ -541,11 +540,48 @@ async def process_user( statistics.grade_sh -= 1 case Rank.A: statistics.grade_a -= 1 + else: + previous_score_best = BestScore( + user_id=user.id, + beatmap_id=score.beatmap_id, + gamemode=score.gamemode, + score_id=score.id, + total_score=score.total_score, + rank=score.rank, + mods=mod_for_save, + ) + session.add(previous_score_best) + statistics.ranked_score += difference statistics.level_current = calculate_score_to_level(statistics.ranked_score) statistics.maximum_combo = max(statistics.maximum_combo, score.max_combo) + if score.passed and ranked: + if previous_score_best_mod is not None: + previous_score_best_mod.mods = mod_for_save + previous_score_best_mod.score_id = score.id + previous_score_best_mod.rank = score.rank + previous_score_best_mod.total_score = score.total_score + elif ( + previous_score_best is not None and previous_score_best.score_id != score.id + ): + session.add( + BestScore( + user_id=user.id, + beatmap_id=score.beatmap_id, + gamemode=score.gamemode, + score_id=score.id, + total_score=score.total_score, + rank=score.rank, + mods=mod_for_save, + ) + ) statistics.play_count += 1 + mouthly_playcount.playcount += 1 statistics.play_time += int((score.ended_at - score.started_at).total_seconds()) + statistics.count_100 += score.n100 + score.nkatu + statistics.count_300 += score.n300 + score.ngeki + statistics.count_50 += score.n50 + statistics.count_miss += score.nmiss statistics.total_hits += ( score.n300 + score.n100 + score.n50 + score.ngeki + score.nkatu ) @@ -563,11 +599,8 @@ async def process_user( acc_sum = clamp(acc_sum, 0.0, 100.0) statistics.pp = pp_sum statistics.hit_accuracy = acc_sum - - statistics.updated_at = datetime.now(UTC) - if add_to_db: - session.add(statistics) + session.add(mouthly_playcount) await session.commit() await session.refresh(user) @@ -582,6 +615,8 @@ async def process_score( session: AsyncSession, redis: Redis, ) -> Score: + assert user.id + can_get_pp = info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods) score = Score( accuracy=info.accuracy, max_combo=info.max_combo, @@ -611,7 +646,7 @@ async def process_score( nlarge_tick_hit=info.statistics.get(HitResult.LARGE_TICK_HIT, 0), nslider_tail_hit=info.statistics.get(HitResult.SLIDER_TAIL_HIT, 0), ) - if info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods): + if can_get_pp: beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id) pp = await asyncio.get_event_loop().run_in_executor( None, calculate_pp, score, beatmap_raw @@ -621,13 +656,13 @@ async def process_score( user_id = user.id await session.commit() await session.refresh(score) - if score.passed and ranked: + if can_get_pp: previous_pp_best = await get_user_best_pp_in_beatmap( session, beatmap_id, user_id, score.gamemode ) if previous_pp_best is None or score.pp > previous_pp_best.pp: assert score.id - best_score = BestScore( + best_score = PPBestScore( user_id=user_id, score_id=score.id, beatmap_id=beatmap_id, @@ -636,7 +671,7 @@ async def process_score( acc=score.accuracy, ) session.add(best_score) - session.delete(previous_pp_best) if previous_pp_best else None + await session.delete(previous_pp_best) if previous_pp_best else None await session.commit() await session.refresh(score) await session.refresh(score_token) diff --git a/app/database/score_token.py b/app/database/score_token.py index 6a6edb3..4467b8b 100644 --- a/app/database/score_token.py +++ b/app/database/score_token.py @@ -1,15 +1,16 @@ from datetime import datetime +from app.models.model import UTCBaseModel from app.models.score import GameMode from .beatmap import Beatmap -from .user import User +from .lazer_user import User from sqlalchemy import Column, DateTime, Index from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel -class ScoreTokenBase(SQLModel): +class ScoreTokenBase(SQLModel, UTCBaseModel): score_id: int | None = Field(sa_column=Column(BigInteger), default=None) ruleset_id: GameMode playlist_item_id: int | None = Field(default=None) # playlist @@ -34,10 +35,10 @@ class ScoreToken(ScoreTokenBase, table=True): autoincrement=True, ), ) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"))) beatmap_id: int = Field(foreign_key="beatmaps.id") - user: "User" = Relationship() - beatmap: "Beatmap" = Relationship() + user: User = Relationship() + beatmap: Beatmap = Relationship() class ScoreTokenResp(ScoreTokenBase): diff --git a/app/database/team.py b/app/database/team.py index 360e805..562b0c8 100644 --- a/app/database/team.py +++ b/app/database/team.py @@ -1,14 +1,16 @@ from datetime import datetime from typing import TYPE_CHECKING +from app.models.model import UTCBaseModel + from sqlalchemy import Column, DateTime from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel if TYPE_CHECKING: - from .user import User + from .lazer_user import User -class Team(SQLModel, table=True): +class Team(SQLModel, UTCBaseModel, table=True): __tablename__ = "teams" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True, index=True) @@ -22,15 +24,19 @@ class Team(SQLModel, table=True): members: list["TeamMember"] = Relationship(back_populates="team") -class TeamMember(SQLModel, table=True): +class TeamMember(SQLModel, UTCBaseModel, table=True): __tablename__ = "team_members" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"))) team_id: int = Field(foreign_key="teams.id") joined_at: datetime = Field( default_factory=datetime.utcnow, sa_column=Column(DateTime) ) - user: "User" = Relationship(back_populates="team_membership") - team: "Team" = Relationship(back_populates="members") + user: "User" = Relationship( + back_populates="team_membership", sa_relationship_kwargs={"lazy": "joined"} + ) + team: "Team" = Relationship( + back_populates="members", sa_relationship_kwargs={"lazy": "joined"} + ) diff --git a/app/database/user.py b/app/database/user.py deleted file mode 100644 index a188497..0000000 --- a/app/database/user.py +++ /dev/null @@ -1,527 +0,0 @@ -from dataclasses import dataclass -from datetime import datetime -from typing import Optional - -from .legacy import LegacyUserStatistics -from .team import TeamMember - -from sqlalchemy import DECIMAL, JSON, Column, Date, DateTime, Text -from sqlalchemy.dialects.mysql import VARCHAR -from sqlalchemy.orm import joinedload, selectinload -from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel, select - - -class User(SQLModel, table=True): - __tablename__ = "users" # pyright: ignore[reportAssignmentType] - - # 主键 - id: int = Field( - default=None, sa_column=Column(BigInteger, primary_key=True, index=True) - ) - - # 基本信息(匹配 migrations_old 中的结构) - name: str = Field(max_length=32, unique=True, index=True) # 用户名 - safe_name: str = Field(max_length=32, unique=True, index=True) # 安全用户名 - email: str = Field(max_length=254, unique=True, index=True) - priv: int = Field(default=1) # 权限 - pw_bcrypt: str = Field(max_length=60) # bcrypt 哈希密码 - country: str = Field(default="CN", max_length=2) # 国家代码 - - # 状态和时间 - silence_end: int = Field(default=0) - donor_end: int = Field(default=0) - creation_time: int = Field(default=0) # Unix 时间戳 - latest_activity: int = Field(default=0) # Unix 时间戳 - - # 游戏相关 - preferred_mode: int = Field(default=0) # 偏好游戏模式 - play_style: int = Field(default=0) # 游戏风格 - - # 扩展信息 - clan_id: int = Field(default=0) - clan_priv: int = Field(default=0) - custom_badge_name: str | None = Field(default=None, max_length=16) - custom_badge_icon: str | None = Field(default=None, max_length=64) - userpage_content: str | None = Field(default=None, max_length=2048) - api_key: str | None = Field(default=None, max_length=36, unique=True) - - # 虚拟字段用于兼容性 - @property - def username(self): - return self.name - - @property - def country_code(self): - return self.country - - @property - def join_date(self): - creation_time = getattr(self, "creation_time", 0) - return ( - datetime.fromtimestamp(creation_time) - if creation_time > 0 - else datetime.utcnow() - ) - - @property - def last_visit(self): - latest_activity = getattr(self, "latest_activity", 0) - return datetime.fromtimestamp(latest_activity) if latest_activity > 0 else None - - @property - def is_supporter(self): - return self.lazer_profile.is_supporter if self.lazer_profile else False - - # 关联关系 - lazer_profile: Optional["LazerUserProfile"] = Relationship(back_populates="user") - lazer_statistics: list["LazerUserStatistics"] = Relationship(back_populates="user") - lazer_counts: Optional["LazerUserCounts"] = Relationship(back_populates="user") - lazer_achievements: list["LazerUserAchievement"] = Relationship( - back_populates="user" - ) - lazer_profile_sections: list["LazerUserProfileSections"] = Relationship( - back_populates="user" - ) - statistics: list["LegacyUserStatistics"] = Relationship(back_populates="user") - team_membership: Optional["TeamMember"] = Relationship(back_populates="user") - daily_challenge_stats: Optional["DailyChallengeStats"] = Relationship( - back_populates="user" - ) - rank_history: list["RankHistory"] = Relationship(back_populates="user") - avatar: Optional["UserAvatar"] = Relationship(back_populates="user") - active_banners: list["LazerUserBanners"] = Relationship(back_populates="user") - lazer_badges: list["LazerUserBadge"] = Relationship(back_populates="user") - lazer_monthly_playcounts: list["LazerUserMonthlyPlaycounts"] = Relationship( - back_populates="user" - ) - lazer_previous_usernames: list["LazerUserPreviousUsername"] = Relationship( - back_populates="user" - ) - lazer_replays_watched: list["LazerUserReplaysWatched"] = Relationship( - back_populates="user" - ) - - @classmethod - def all_select_option(cls): - return ( - joinedload(cls.lazer_profile), # pyright: ignore[reportArgumentType] - joinedload(cls.lazer_counts), # pyright: ignore[reportArgumentType] - joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType] - joinedload(cls.avatar), # pyright: ignore[reportArgumentType] - selectinload(cls.lazer_statistics), # pyright: ignore[reportArgumentType] - selectinload(cls.lazer_achievements), # pyright: ignore[reportArgumentType] - selectinload(cls.lazer_profile_sections), # pyright: ignore[reportArgumentType] - selectinload(cls.statistics), # pyright: ignore[reportArgumentType] - joinedload(cls.team_membership), # pyright: ignore[reportArgumentType] - selectinload(cls.rank_history), # pyright: ignore[reportArgumentType] - selectinload(cls.active_banners), # pyright: ignore[reportArgumentType] - selectinload(cls.lazer_badges), # pyright: ignore[reportArgumentType] - selectinload(cls.lazer_monthly_playcounts), # pyright: ignore[reportArgumentType] - selectinload(cls.lazer_previous_usernames), # pyright: ignore[reportArgumentType] - selectinload(cls.lazer_replays_watched), # pyright: ignore[reportArgumentType] - ) - - @classmethod - def all_select_clause(cls): - return select(cls).options(*cls.all_select_option()) - - -# ============================================ -# Lazer API 专用表模型 -# ============================================ - - -class LazerUserProfile(SQLModel, table=True): - __tablename__ = "lazer_user_profiles" # pyright: ignore[reportAssignmentType] - - user_id: int = Field( - default=None, - sa_column=Column( - BigInteger, - ForeignKey("users.id"), - primary_key=True, - ), - ) - - # 基本状态字段 - is_active: bool = Field(default=True) - is_bot: bool = Field(default=False) - is_deleted: bool = Field(default=False) - is_online: bool = Field(default=True) - is_supporter: bool = Field(default=False) - is_restricted: bool = Field(default=False) - session_verified: bool = Field(default=False) - has_supported: bool = Field(default=False) - pm_friends_only: bool = Field(default=False) - - # 基本资料字段 - default_group: str = Field(default="default", max_length=50) - last_visit: datetime | None = Field(default=None, sa_column=Column(DateTime)) - join_date: datetime | None = Field(default=None, sa_column=Column(DateTime)) - profile_colour: str | None = Field(default=None, max_length=7) - profile_hue: int | None = Field(default=None) - - # 社交媒体和个人资料字段 - avatar_url: str | None = Field(default=None, max_length=500) - cover_url: str | None = Field(default=None, max_length=500) - discord: str | None = Field(default=None, max_length=100) - twitter: str | None = Field(default=None, max_length=100) - website: str | None = Field(default=None, max_length=500) - title: str | None = Field(default=None, max_length=100) - title_url: str | None = Field(default=None, max_length=500) - interests: str | None = Field(default=None, sa_column=Column(Text)) - location: str | None = Field(default=None, max_length=100) - - occupation: str | None = Field(default=None) # 职业字段,默认为 None - - # 游戏相关字段 - playmode: str = Field(default="osu", max_length=10) - support_level: int = Field(default=0) - max_blocks: int = Field(default=100) - max_friends: int = Field(default=500) - post_count: int = Field(default=0) - - # 页面内容 - page_html: str | None = Field(default=None, sa_column=Column(Text)) - page_raw: str | None = Field(default=None, sa_column=Column(Text)) - - profile_order: str = Field( - default="me,recent_activity,top_ranks,medals,historical,beatmaps,kudosu" - ) - - # 关联关系 - user: "User" = Relationship(back_populates="lazer_profile") - - -class LazerUserProfileSections(SQLModel, table=True): - __tablename__ = "lazer_user_profile_sections" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - section_name: str = Field(sa_column=Column(VARCHAR(50))) - display_order: int | None = Field(default=None) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - user: "User" = Relationship(back_populates="lazer_profile_sections") - - -class LazerUserCountry(SQLModel, table=True): - __tablename__ = "lazer_user_countries" # pyright: ignore[reportAssignmentType] - - user_id: int = Field( - default=None, - sa_column=Column( - BigInteger, - ForeignKey("users.id"), - primary_key=True, - ), - ) - code: str = Field(max_length=2) - name: str = Field(max_length=100) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - -class LazerUserKudosu(SQLModel, table=True): - __tablename__ = "lazer_user_kudosu" # pyright: ignore[reportAssignmentType] - - user_id: int = Field( - default=None, - sa_column=Column( - BigInteger, - ForeignKey("users.id"), - primary_key=True, - ), - ) - available: int = Field(default=0) - total: int = Field(default=0) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - -class LazerUserCounts(SQLModel, table=True): - __tablename__ = "lazer_user_counts" # pyright: ignore[reportAssignmentType] - - user_id: int = Field( - default=None, - sa_column=Column( - BigInteger, - ForeignKey("users.id"), - primary_key=True, - ), - ) - - # 统计计数字段 - beatmap_playcounts_count: int = Field(default=0) - comments_count: int = Field(default=0) - favourite_beatmapset_count: int = Field(default=0) - follower_count: int = Field(default=0) - graveyard_beatmapset_count: int = Field(default=0) - guest_beatmapset_count: int = Field(default=0) - loved_beatmapset_count: int = Field(default=0) - mapping_follower_count: int = Field(default=0) - nominated_beatmapset_count: int = Field(default=0) - pending_beatmapset_count: int = Field(default=0) - ranked_beatmapset_count: int = Field(default=0) - ranked_and_approved_beatmapset_count: int = Field(default=0) - unranked_beatmapset_count: int = Field(default=0) - scores_best_count: int = Field(default=0) - scores_first_count: int = Field(default=0) - scores_pinned_count: int = Field(default=0) - scores_recent_count: int = Field(default=0) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - # 关联关系 - user: "User" = Relationship(back_populates="lazer_counts") - - -class LazerUserStatistics(SQLModel, table=True): - __tablename__ = "lazer_user_statistics" # pyright: ignore[reportAssignmentType] - - user_id: int = Field( - default=None, - sa_column=Column( - BigInteger, - ForeignKey("users.id"), - primary_key=True, - ), - ) - mode: str = Field(default="osu", max_length=10, primary_key=True) - - # 基本命中统计 - count_100: int = Field(default=0) - count_300: int = Field(default=0) - count_50: int = Field(default=0) - count_miss: int = Field(default=0) - - # 等级信息 - level_current: int = Field(default=1) - level_progress: int = Field(default=0) - - # 排名信息 - global_rank: int | None = Field(default=None) - global_rank_exp: int | None = Field(default=None) - country_rank: int | None = Field(default=None) - - # PP 和分数 - pp: float = Field(default=0.00, sa_column=Column(DECIMAL(10, 2))) - pp_exp: float = Field(default=0.00, sa_column=Column(DECIMAL(10, 2))) - ranked_score: int = Field(default=0, sa_column=Column(BigInteger)) - hit_accuracy: float = Field(default=0.00, sa_column=Column(DECIMAL(5, 2))) - total_score: int = Field(default=0, sa_column=Column(BigInteger)) - total_hits: int = Field(default=0, sa_column=Column(BigInteger)) - maximum_combo: int = Field(default=0) - - # 游戏统计 - play_count: int = Field(default=0) - play_time: int = Field(default=0) # 秒 - replays_watched_by_others: int = Field(default=0) - is_ranked: bool = Field(default=False) - - # 成绩等级计数 - grade_ss: int = Field(default=0) - grade_ssh: int = Field(default=0) - grade_s: int = Field(default=0) - grade_sh: int = Field(default=0) - grade_a: int = Field(default=0) - - # 最高排名记录 - rank_highest: int | None = Field(default=None) - rank_highest_updated_at: datetime | None = Field( - default=None, sa_column=Column(DateTime) - ) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - # 关联关系 - user: "User" = Relationship(back_populates="lazer_statistics") - - -class LazerUserBanners(SQLModel, table=True): - __tablename__ = "lazer_user_tournament_banners" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - tournament_id: int - image_url: str = Field(sa_column=Column(VARCHAR(500))) - is_active: bool | None = Field(default=None) - - # 修正user关系的back_populates值 - user: "User" = Relationship(back_populates="active_banners") - - -class LazerUserAchievement(SQLModel, table=True): - __tablename__ = "lazer_user_achievements" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - achievement_id: int - achieved_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - user: "User" = Relationship(back_populates="lazer_achievements") - - -class LazerUserBadge(SQLModel, table=True): - __tablename__ = "lazer_user_badges" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - badge_id: int - awarded_at: datetime | None = Field(default=None, sa_column=Column(DateTime)) - description: str | None = Field(default=None, sa_column=Column(Text)) - image_url: str | None = Field(default=None, max_length=500) - url: str | None = Field(default=None, max_length=500) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - user: "User" = Relationship(back_populates="lazer_badges") - - -class LazerUserMonthlyPlaycounts(SQLModel, table=True): - __tablename__ = "lazer_user_monthly_playcounts" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - start_date: datetime = Field(sa_column=Column(Date)) - play_count: int = Field(default=0) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - user: "User" = Relationship(back_populates="lazer_monthly_playcounts") - - -class LazerUserPreviousUsername(SQLModel, table=True): - __tablename__ = "lazer_user_previous_usernames" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - username: str = Field(max_length=32) - changed_at: datetime = Field(sa_column=Column(DateTime)) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - user: "User" = Relationship(back_populates="lazer_previous_usernames") - - -class LazerUserReplaysWatched(SQLModel, table=True): - __tablename__ = "lazer_user_replays_watched" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - start_date: datetime = Field(sa_column=Column(Date)) - count: int = Field(default=0) - - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - user: "User" = Relationship(back_populates="lazer_replays_watched") - - -# 类型转换用的 UserAchievement(不是 SQLAlchemy 模型) -@dataclass -class UserAchievement: - achieved_at: datetime - achievement_id: int - - -class DailyChallengeStats(SQLModel, table=True): - __tablename__ = "daily_challenge_stats" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("users.id"), unique=True) - ) - - daily_streak_best: int = Field(default=0) - daily_streak_current: int = Field(default=0) - last_update: datetime | None = Field(default=None, sa_column=Column(DateTime)) - last_weekly_streak: datetime | None = Field( - default=None, sa_column=Column(DateTime) - ) - playcount: int = Field(default=0) - top_10p_placements: int = Field(default=0) - top_50p_placements: int = Field(default=0) - weekly_streak_best: int = Field(default=0) - weekly_streak_current: int = Field(default=0) - - user: "User" = Relationship(back_populates="daily_challenge_stats") - - -class RankHistory(SQLModel, table=True): - __tablename__ = "rank_history" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - mode: str = Field(max_length=10) - rank_data: list = Field(sa_column=Column(JSON)) # Array of ranks - date_recorded: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - - user: "User" = Relationship(back_populates="rank_history") - - -class UserAvatar(SQLModel, table=True): - __tablename__ = "user_avatars" # pyright: ignore[reportAssignmentType] - - id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - filename: str = Field(max_length=255) - original_filename: str = Field(max_length=255) - file_size: int - mime_type: str = Field(max_length=100) - is_active: bool = Field(default=True) - created_at: int = Field(default_factory=lambda: int(datetime.now().timestamp())) - updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp())) - r2_original_url: str | None = Field(default=None, max_length=500) - r2_game_url: str | None = Field(default=None, max_length=500) - - user: "User" = Relationship(back_populates="avatar") diff --git a/app/dependencies/database.py b/app/dependencies/database.py index fe09139..77b15c3 100644 --- a/app/dependencies/database.py +++ b/app/dependencies/database.py @@ -5,15 +5,11 @@ import json from app.config import settings from pydantic import BaseModel +import redis.asyncio as redis from sqlalchemy.ext.asyncio import create_async_engine from sqlmodel import SQLModel from sqlmodel.ext.asyncio.session import AsyncSession -try: - import redis -except ImportError: - redis = None - def json_serializer(value): if isinstance(value, BaseModel | SQLModel): @@ -25,10 +21,7 @@ def json_serializer(value): engine = create_async_engine(settings.DATABASE_URL, json_serializer=json_serializer) # Redis 连接 -if redis: - redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True) -else: - redis_client = None +redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True) # 数据库依赖 diff --git a/app/dependencies/fetcher.py b/app/dependencies/fetcher.py index d3c216a..806eb87 100644 --- a/app/dependencies/fetcher.py +++ b/app/dependencies/fetcher.py @@ -8,7 +8,7 @@ from app.log import logger fetcher: Fetcher | None = None -def get_fetcher() -> Fetcher: +async def get_fetcher() -> Fetcher: global fetcher if fetcher is None: fetcher = Fetcher( @@ -18,15 +18,14 @@ def get_fetcher() -> Fetcher: settings.FETCHER_CALLBACK_URL, ) redis = get_redis() - if redis: - access_token = redis.get(f"fetcher:access_token:{fetcher.client_id}") - if access_token: - fetcher.access_token = str(access_token) - refresh_token = redis.get(f"fetcher:refresh_token:{fetcher.client_id}") - if refresh_token: - fetcher.refresh_token = str(refresh_token) - if not fetcher.access_token or not fetcher.refresh_token: - logger.opt(colors=True).info( - f"Login to initialize fetcher: {fetcher.authorize_url}" - ) + access_token = await redis.get(f"fetcher:access_token:{fetcher.client_id}") + if access_token: + fetcher.access_token = str(access_token) + refresh_token = await redis.get(f"fetcher:refresh_token:{fetcher.client_id}") + if refresh_token: + fetcher.refresh_token = str(refresh_token) + if not fetcher.access_token or not fetcher.refresh_token: + logger.opt(colors=True).info( + f"Login to initialize fetcher: {fetcher.authorize_url}" + ) return fetcher diff --git a/app/dependencies/user.py b/app/dependencies/user.py index 0c8f8bc..5537f4f 100644 --- a/app/dependencies/user.py +++ b/app/dependencies/user.py @@ -1,14 +1,13 @@ from __future__ import annotations from app.auth import get_token_by_access_token -from app.database import ( - User as DBUser, -) +from app.database import User from .database import get_db from fastapi import Depends, HTTPException from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession security = HTTPBearer() @@ -17,7 +16,7 @@ security = HTTPBearer() async def get_current_user( credentials: HTTPAuthorizationCredentials = Depends(security), db: AsyncSession = Depends(get_db), -) -> DBUser: +) -> User: """获取当前认证用户""" token = credentials.credentials @@ -27,13 +26,9 @@ async def get_current_user( return user -async def get_current_user_by_token(token: str, db: AsyncSession) -> DBUser | None: +async def get_current_user_by_token(token: str, db: AsyncSession) -> User | None: token_record = await get_token_by_access_token(db, token) if not token_record: return None - user = ( - await db.exec( - DBUser.all_select_clause().where(DBUser.id == token_record.user_id) - ) - ).first() + user = (await db.exec(select(User).where(User.id == token_record.user_id))).first() return user diff --git a/app/fetcher/_base.py b/app/fetcher/_base.py index 08e3508..2717a35 100644 --- a/app/fetcher/_base.py +++ b/app/fetcher/_base.py @@ -59,16 +59,15 @@ class BaseFetcher: self.refresh_token = token_data.get("refresh_token", "") self.token_expiry = int(time.time()) + token_data["expires_in"] redis = get_redis() - if redis: - redis.set( - f"fetcher:access_token:{self.client_id}", - self.access_token, - ex=token_data["expires_in"], - ) - redis.set( - f"fetcher:refresh_token:{self.client_id}", - self.refresh_token, - ) + await redis.set( + f"fetcher:access_token:{self.client_id}", + self.access_token, + ex=token_data["expires_in"], + ) + await redis.set( + f"fetcher:refresh_token:{self.client_id}", + self.refresh_token, + ) async def refresh_access_token(self) -> None: async with AsyncClient() as client: @@ -87,13 +86,12 @@ class BaseFetcher: self.refresh_token = token_data.get("refresh_token", "") self.token_expiry = int(time.time()) + token_data["expires_in"] redis = get_redis() - if redis: - redis.set( - f"fetcher:access_token:{self.client_id}", - self.access_token, - ex=token_data["expires_in"], - ) - redis.set( - f"fetcher:refresh_token:{self.client_id}", - self.refresh_token, - ) + await redis.set( + f"fetcher:access_token:{self.client_id}", + self.access_token, + ex=token_data["expires_in"], + ) + await redis.set( + f"fetcher:refresh_token:{self.client_id}", + self.refresh_token, + ) diff --git a/app/fetcher/osu_dot_direct.py b/app/fetcher/osu_dot_direct.py index 08b8dfc..6e18435 100644 --- a/app/fetcher/osu_dot_direct.py +++ b/app/fetcher/osu_dot_direct.py @@ -4,7 +4,7 @@ from ._base import BaseFetcher from httpx import AsyncClient from loguru import logger -import redis +import redis.asyncio as redis class OsuDotDirectFetcher(BaseFetcher): @@ -22,8 +22,8 @@ class OsuDotDirectFetcher(BaseFetcher): async def get_or_fetch_beatmap_raw( self, redis: redis.Redis, beatmap_id: int ) -> str: - if redis.exists(f"beatmap:{beatmap_id}:raw"): - return redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType] + if await redis.exists(f"beatmap:{beatmap_id}:raw"): + return await redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType] raw = await self.get_beatmap_raw(beatmap_id) - redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24) + await redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24) return raw diff --git a/app/models/beatmap.py b/app/models/beatmap.py index 4f12e13..fae18ba 100644 --- a/app/models/beatmap.py +++ b/app/models/beatmap.py @@ -42,11 +42,12 @@ class Language(IntEnum): KOREAN = 6 FRENCH = 7 GERMAN = 8 - ITALIAN = 9 - SPANISH = 10 - RUSSIAN = 11 - POLISH = 12 - OTHER = 13 + SWEDISH = 9 + ITALIAN = 10 + SPANISH = 11 + RUSSIAN = 12 + POLISH = 13 + OTHER = 14 class BeatmapAttributes(BaseModel): diff --git a/app/models/oauth.py b/app/models/oauth.py index 22fcf63..6665965 100644 --- a/app/models/oauth.py +++ b/app/models/oauth.py @@ -1,7 +1,6 @@ # OAuth 相关模型 from __future__ import annotations -from typing import List from pydantic import BaseModel @@ -39,18 +38,21 @@ class OAuthErrorResponse(BaseModel): class RegistrationErrorResponse(BaseModel): """注册错误响应模型""" + form_error: dict class UserRegistrationErrors(BaseModel): """用户注册错误模型""" - username: List[str] = [] - user_email: List[str] = [] - password: List[str] = [] + + username: list[str] = [] + user_email: list[str] = [] + password: list[str] = [] class RegistrationRequestErrors(BaseModel): """注册请求错误模型""" + message: str | None = None redirect: str | None = None user: UserRegistrationErrors | None = None diff --git a/app/models/score.py b/app/models/score.py index b613ae2..bfc9f53 100644 --- a/app/models/score.py +++ b/app/models/score.py @@ -132,7 +132,7 @@ class HitResultInt(IntEnum): class LeaderboardType(Enum): GLOBAL = "global" - FRIENDS = "friends" + FRIENDS = "friend" COUNTRY = "country" TEAM = "team" diff --git a/app/models/user.py b/app/models/user.py index 4eb0911..3052eef 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -2,15 +2,11 @@ from __future__ import annotations from datetime import datetime from enum import Enum -from typing import TYPE_CHECKING -from .score import GameMode +from .model import UTCBaseModel from pydantic import BaseModel -if TYPE_CHECKING: - from app.database import LazerUserAchievement, Team - class PlayStyle(str, Enum): MOUSE = "mouse" @@ -77,24 +73,7 @@ class MonthlyPlaycount(BaseModel): count: int -class UserAchievement(BaseModel): - achieved_at: datetime - achievement_id: int - - # 添加数据库模型转换方法 - def to_db_model(self, user_id: int) -> "LazerUserAchievement": - from app.database import ( - LazerUserAchievement, - ) - - return LazerUserAchievement( - user_id=user_id, - achievement_id=self.achievement_id, - achieved_at=self.achieved_at, - ) - - -class RankHighest(BaseModel): +class RankHighest(UTCBaseModel): rank: int updated_at: datetime @@ -104,115 +83,6 @@ class RankHistory(BaseModel): data: list[int] -class DailyChallengeStats(BaseModel): - daily_streak_best: int = 0 - daily_streak_current: int = 0 - last_update: datetime | None = None - last_weekly_streak: datetime | None = None - playcount: int = 0 - top_10p_placements: int = 0 - top_50p_placements: int = 0 - user_id: int - weekly_streak_best: int = 0 - weekly_streak_current: int = 0 - - class Page(BaseModel): html: str = "" raw: str = "" - - -class User(BaseModel): - # 基本信息 - id: int - username: str - avatar_url: str - country_code: str - default_group: str = "default" - is_active: bool = True - is_bot: bool = False - is_deleted: bool = False - is_online: bool = True - is_supporter: bool = False - is_restricted: bool = False - last_visit: datetime | None = None - pm_friends_only: bool = False - profile_colour: str | None = None - - # 个人资料 - cover_url: str | None = None - discord: str | None = None - has_supported: bool = False - interests: str | None = None - join_date: datetime - location: str | None = None - max_blocks: int = 100 - max_friends: int = 500 - occupation: str | None = None - playmode: GameMode = GameMode.OSU - playstyle: list[PlayStyle] = [] - post_count: int = 0 - profile_hue: int | None = None - profile_order: list[str] = [ - "me", - "recent_activity", - "top_ranks", - "medals", - "historical", - "beatmaps", - "kudosu", - ] - title: str | None = None - title_url: str | None = None - twitter: str | None = None - website: str | None = None - session_verified: bool = False - support_level: int = 0 - - # 关联对象 - country: Country - cover: Cover - kudosu: Kudosu - statistics: Statistics - statistics_rulesets: dict[str, Statistics] - - # 计数信息 - beatmap_playcounts_count: int = 0 - comments_count: int = 0 - favourite_beatmapset_count: int = 0 - follower_count: int = 0 - graveyard_beatmapset_count: int = 0 - guest_beatmapset_count: int = 0 - loved_beatmapset_count: int = 0 - mapping_follower_count: int = 0 - nominated_beatmapset_count: int = 0 - pending_beatmapset_count: int = 0 - ranked_beatmapset_count: int = 0 - ranked_and_approved_beatmapset_count: int = 0 - unranked_beatmapset_count: int = 0 - scores_best_count: int = 0 - scores_first_count: int = 0 - scores_pinned_count: int = 0 - scores_recent_count: int = 0 - - # 历史数据 - account_history: list[dict] = [] - active_tournament_banner: dict | None = None - active_tournament_banners: list[dict] = [] - badges: list[dict] = [] - current_season_stats: dict | None = None - daily_challenge_user_stats: DailyChallengeStats | None = None - groups: list[dict] = [] - monthly_playcounts: list[MonthlyPlaycount] = [] - page: Page = Page() - previous_usernames: list[str] = [] - rank_highest: RankHighest | None = None - rank_history: RankHistory | None = None - rankHistory: RankHistory | None = None # 兼容性别名 - replays_watched_counts: list[dict] = [] - team: "Team | None" = None - user_achievements: list[UserAchievement] = [] - - -class APIUser(BaseModel): - id: int diff --git a/app/router/auth.py b/app/router/auth.py index 0f41b32..7a2a14d 100644 --- a/app/router/auth.py +++ b/app/router/auth.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import timedelta +from datetime import UTC, datetime, timedelta import re from app.auth import ( @@ -12,17 +12,21 @@ from app.auth import ( store_token, ) from app.config import settings -from app.database import User as DBUser +from app.database import DailyChallengeStats, User +from app.database.statistics import UserStatistics from app.dependencies import get_db +from app.log import logger from app.models.oauth import ( OAuthErrorResponse, RegistrationRequestErrors, TokenResponse, UserRegistrationErrors, ) +from app.models.score import GameMode from fastapi import APIRouter, Depends, Form from fastapi.responses import JSONResponse +from sqlalchemy import text from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -110,12 +114,12 @@ async def register_user( email_errors = validate_email(user_email) password_errors = validate_password(user_password) - result = await db.exec(select(DBUser).where(DBUser.name == user_username)) + result = await db.exec(select(User).where(User.username == user_username)) existing_user = result.first() if existing_user: username_errors.append("Username is already taken") - result = await db.exec(select(DBUser).where(DBUser.email == user_email)) + result = await db.exec(select(User).where(User.email == user_email)) existing_email = result.first() if existing_email: email_errors.append("Email is already taken") @@ -135,119 +139,41 @@ async def register_user( try: # 创建新用户 - from datetime import datetime - import time + # 确保 AUTO_INCREMENT 值从3开始(ID=1是BanchoBot,ID=2预留给ppy) + result = await db.execute( # pyright: ignore[reportDeprecated] + text( + "SELECT AUTO_INCREMENT FROM information_schema.TABLES " + "WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'lazer_users'" + ) + ) + next_id = result.one()[0] + if next_id <= 2: + await db.execute(text("ALTER TABLE lazer_users AUTO_INCREMENT = 3")) + await db.commit() - new_user = DBUser( - name=user_username, - safe_name=user_username.lower(), # 安全用户名(小写) + new_user = User( + username=user_username, email=user_email, pw_bcrypt=get_password_hash(user_password), priv=1, # 普通用户权限 - country="CN", # 默认国家 - creation_time=int(time.time()), - latest_activity=int(time.time()), - preferred_mode=0, # 默认模式 - play_style=0, # 默认游戏风格 + country_code="CN", # 默认国家 + join_date=datetime.now(UTC), + last_visit=datetime.now(UTC), ) - db.add(new_user) await db.commit() await db.refresh(new_user) - - # 保存用户ID,因为会话可能会关闭 - user_id = new_user.id - - if user_id <= 2: - await db.rollback() - try: - from sqlalchemy import text - - # 确保 AUTO_INCREMENT 值从3开始(ID=1是BanchoBot,ID=2预留给ppy) - await db.execute(text("ALTER TABLE users AUTO_INCREMENT = 3")) - await db.commit() - - # 重新创建用户 - new_user = DBUser( - name=user_username, - safe_name=user_username.lower(), - email=user_email, - pw_bcrypt=get_password_hash(user_password), - priv=1, - country="CN", - creation_time=int(time.time()), - latest_activity=int(time.time()), - preferred_mode=0, - play_style=0, - ) - - db.add(new_user) - await db.commit() - await db.refresh(new_user) - user_id = new_user.id - - # 最终检查ID是否有效 - if user_id <= 2: - await db.rollback() - errors = RegistrationRequestErrors( - message=( - "Failed to create account with valid ID. " - "Please contact support." - ) - ) - return JSONResponse( - status_code=500, content={"form_error": errors.model_dump()} - ) - - except Exception as fix_error: - await db.rollback() - print(f"Failed to fix AUTO_INCREMENT: {fix_error}") - errors = RegistrationRequestErrors( - message="Failed to create account with valid ID. Please try again." - ) - return JSONResponse( - status_code=500, content={"form_error": errors.model_dump()} - ) - - # 创建默认的 lazer_profile - from app.database.user import LazerUserProfile - - lazer_profile = LazerUserProfile( - user_id=user_id, - is_active=True, - is_bot=False, - is_deleted=False, - is_online=True, - is_supporter=False, - is_restricted=False, - session_verified=False, - has_supported=False, - pm_friends_only=False, - default_group="default", - join_date=datetime.utcnow(), - playmode="osu", - support_level=0, - max_blocks=50, - max_friends=250, - post_count=0, - ) - - db.add(lazer_profile) + assert new_user.id is not None, "New user ID should not be None" + for i in GameMode: + statistics = UserStatistics(mode=i, user_id=new_user.id) + db.add(statistics) + daily_challenge_user_stats = DailyChallengeStats(user_id=new_user.id) + db.add(daily_challenge_user_stats) await db.commit() - - # 返回成功响应 - return JSONResponse( - status_code=201, - content={"message": "Account created successfully", "user_id": user_id}, - ) - - except Exception as e: + except Exception: await db.rollback() # 打印详细错误信息用于调试 - print(f"Registration error: {e}") - import traceback - - traceback.print_exc() + logger.exception(f"Registration error for user {user_username}") # 返回通用错误 errors = RegistrationRequestErrors( @@ -323,6 +249,7 @@ async def oauth_token( refresh_token_str = generate_refresh_token() # 存储令牌 + assert user.id await store_token( db, user.id, diff --git a/app/router/beatmap.py b/app/router/beatmap.py index 71d554f..9574bdb 100644 --- a/app/router/beatmap.py +++ b/app/router/beatmap.py @@ -5,12 +5,7 @@ import hashlib import json from app.calculator import calculate_beatmap_attribute -from app.database import ( - Beatmap, - BeatmapResp, - User as DBUser, -) -from app.database.beatmapset import Beatmapset +from app.database import Beatmap, BeatmapResp, User from app.dependencies.database import get_db, get_redis from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user @@ -27,9 +22,8 @@ from .api_router import router from fastapi import Depends, HTTPException, Query from httpx import HTTPError, HTTPStatusError from pydantic import BaseModel -from redis import Redis +from redis.asyncio import Redis import rosu_pp_py as rosu -from sqlalchemy.orm import joinedload from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -39,7 +33,7 @@ async def lookup_beatmap( id: int | None = Query(default=None, alias="id"), md5: str | None = Query(default=None, alias="checksum"), filename: str | None = Query(default=None, alias="filename"), - current_user: DBUser = Depends(get_current_user), + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), ): @@ -56,19 +50,19 @@ async def lookup_beatmap( if beatmap is None: raise HTTPException(status_code=404, detail="Beatmap not found") - return BeatmapResp.from_db(beatmap) + return await BeatmapResp.from_db(beatmap, session=db, user=current_user) @router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp) async def get_beatmap( bid: int, - current_user: DBUser = Depends(get_current_user), + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), ): try: beatmap = await Beatmap.get_or_fetch(db, fetcher, bid) - return BeatmapResp.from_db(beatmap) + return await BeatmapResp.from_db(beatmap, session=db, user=current_user) except HTTPError: raise HTTPException(status_code=404, detail="Beatmap not found") @@ -81,42 +75,27 @@ class BatchGetResp(BaseModel): @router.get("/beatmaps/", tags=["beatmap"], response_model=BatchGetResp) async def batch_get_beatmaps( b_ids: list[int] = Query(alias="id", default_factory=list), - current_user: DBUser = Depends(get_current_user), + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): if not b_ids: # select 50 beatmaps by last_updated beatmaps = ( await db.exec( - select(Beatmap) - .options( - joinedload( - Beatmap.beatmapset # pyright: ignore[reportArgumentType] - ).selectinload( - Beatmapset.beatmaps # pyright: ignore[reportArgumentType] - ) - ) - .order_by(col(Beatmap.last_updated).desc()) - .limit(50) + select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50) ) ).all() else: beatmaps = ( - await db.exec( - select(Beatmap) - .options( - joinedload( - Beatmap.beatmapset # pyright: ignore[reportArgumentType] - ).selectinload( - Beatmapset.beatmaps # pyright: ignore[reportArgumentType] - ) - ) - .where(col(Beatmap.id).in_(b_ids)) - .limit(50) - ) + await db.exec(select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50)) ).all() - return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps]) + return BatchGetResp( + beatmaps=[ + await BeatmapResp.from_db(bm, session=db, user=current_user) + for bm in beatmaps + ] + ) @router.post( @@ -126,7 +105,7 @@ async def batch_get_beatmaps( ) async def get_beatmap_attributes( beatmap: int, - current_user: DBUser = Depends(get_current_user), + current_user: User = Depends(get_current_user), mods: list[str] = Query(default_factory=list), ruleset: GameMode | None = Query(default=None), ruleset_id: int | None = Query(default=None), @@ -153,8 +132,8 @@ async def get_beatmap_attributes( f"beatmap:{beatmap}:{ruleset}:" f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes" ) - if redis.exists(key): - return BeatmapAttributes.model_validate_json(redis.get(key)) # pyright: ignore[reportArgumentType] + if await redis.exists(key): + return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType] try: resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap) @@ -164,7 +143,7 @@ async def get_beatmap_attributes( ) except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue] raise HTTPException(status_code=400, detail=str(e)) - redis.set(key, attr.model_dump_json()) + await redis.set(key, attr.model_dump_json()) return attr except HTTPStatusError: raise HTTPException(status_code=404, detail="Beatmap not found") diff --git a/app/router/beatmapset.py b/app/router/beatmapset.py index db2dd77..b4d2e4c 100644 --- a/app/router/beatmapset.py +++ b/app/router/beatmapset.py @@ -1,10 +1,8 @@ from __future__ import annotations -from app.database import ( - Beatmapset, - BeatmapsetResp, - User as DBUser, -) +from typing import Literal + +from app.database import Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User from app.dependencies.database import get_db from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user @@ -12,9 +10,9 @@ from app.fetcher import Fetcher from .api_router import router -from fastapi import Depends, HTTPException +from fastapi import Depends, Form, HTTPException, Query +from fastapi.responses import RedirectResponse from httpx import HTTPStatusError -from sqlalchemy.orm import selectinload from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -22,17 +20,11 @@ from sqlmodel.ext.asyncio.session import AsyncSession @router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp) async def get_beatmapset( sid: int, - current_user: DBUser = Depends(get_current_user), + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), ): - beatmapset = ( - await db.exec( - select(Beatmapset) - .options(selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType] - .where(Beatmapset.id == sid) - ) - ).first() + beatmapset = (await db.exec(select(Beatmapset).where(Beatmapset.id == sid))).first() if not beatmapset: try: resp = await fetcher.get_beatmapset(sid) @@ -40,5 +32,55 @@ async def get_beatmapset( except HTTPStatusError: raise HTTPException(status_code=404, detail="Beatmapset not found") else: - resp = BeatmapsetResp.from_db(beatmapset) + resp = await BeatmapsetResp.from_db( + beatmapset, session=db, include=["recent_favourites"], user=current_user + ) return resp + + +@router.get("/beatmapsets/{beatmapset}/download", tags=["beatmapset"]) +async def download_beatmapset( + beatmapset: int, + no_video: bool = Query(True, alias="noVideo"), + current_user: User = Depends(get_current_user), +): + if current_user.country_code == "CN": + return RedirectResponse( + f"https://txy1.sayobot.cn/beatmaps/download/" + f"{'novideo' if no_video else 'full'}/{beatmapset}?server=auto" + ) + else: + return RedirectResponse( + f"https://api.nerinyan.moe/d/{beatmapset}?noVideo={no_video}" + ) + + +@router.post("/beatmapsets/{beatmapset}/favourites", tags=["beatmapset"]) +async def favourite_beatmapset( + beatmapset: int, + action: Literal["favourite", "unfavourite"] = Form(), + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + existing_favourite = ( + await db.exec( + select(FavouriteBeatmapset).where( + FavouriteBeatmapset.user_id == current_user.id, + FavouriteBeatmapset.beatmapset_id == beatmapset, + ) + ) + ).first() + + if action == "favourite" and existing_favourite: + raise HTTPException(status_code=400, detail="Already favourited") + elif action == "unfavourite" and not existing_favourite: + raise HTTPException(status_code=400, detail="Not favourited") + + if action == "favourite": + favourite = FavouriteBeatmapset( + user_id=current_user.id, beatmapset_id=beatmapset + ) + db.add(favourite) + else: + await db.delete(existing_favourite) + await db.commit() diff --git a/app/router/me.py b/app/router/me.py index 93dcbdc..b6d7d26 100644 --- a/app/router/me.py +++ b/app/router/me.py @@ -1,28 +1,27 @@ from __future__ import annotations -from typing import Literal - -from app.database import ( - User as DBUser, -) +from app.database import User, UserResp +from app.database.lazer_user import ALL_INCLUDED from app.dependencies import get_current_user -from app.models.user import ( - User as ApiUser, -) -from app.utils import convert_db_user_to_api_user +from app.dependencies.database import get_db +from app.models.score import GameMode from .api_router import router from fastapi import Depends +from sqlmodel.ext.asyncio.session import AsyncSession -@router.get("/me/{ruleset}", response_model=ApiUser) -@router.get("/me/", response_model=ApiUser) +@router.get("/me/{ruleset}", response_model=UserResp) +@router.get("/me/", response_model=UserResp) async def get_user_info_default( - ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu", - current_user: DBUser = Depends(get_current_user), + ruleset: GameMode | None = None, + current_user: User = Depends(get_current_user), + session: AsyncSession = Depends(get_db), ): - """获取当前用户信息(默认使用osu模式)""" - # 默认使用osu模式 - api_user = await convert_db_user_to_api_user(current_user, ruleset) - return api_user + return await UserResp.from_db( + current_user, + session, + ALL_INCLUDED, + ruleset, + ) diff --git a/app/router/relationship.py b/app/router/relationship.py index 9ed5b0f..0832d09 100644 --- a/app/router/relationship.py +++ b/app/router/relationship.py @@ -8,7 +8,7 @@ from app.dependencies.user import get_current_user from .api_router import router from fastapi import Depends, HTTPException, Query, Request -from sqlalchemy.orm import joinedload +from pydantic import BaseModel from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -26,17 +26,19 @@ async def get_relationship( else RelationshipType.BLOCK ) relationships = await db.exec( - select(Relationship) - .options(joinedload(Relationship.target).options(*DBUser.all_select_option())) # pyright: ignore[reportArgumentType] - .where( + select(Relationship).where( Relationship.user_id == current_user.id, Relationship.type == relationship_type, ) ) - return [await RelationshipResp.from_db(db, rel) for rel in relationships] + return [await RelationshipResp.from_db(db, rel) for rel in relationships.unique()] -@router.post("/friends", tags=["relationship"], response_model=RelationshipResp) +class AddFriendResp(BaseModel): + user_relation: RelationshipResp + + +@router.post("/friends", tags=["relationship"], response_model=AddFriendResp) @router.post("/blocks", tags=["relationship"]) async def add_relationship( request: Request, @@ -87,14 +89,10 @@ async def add_relationship( if origin_type == RelationshipType.FOLLOW: relationship = ( await db.exec( - select(Relationship) - .where( + select(Relationship).where( Relationship.user_id == current_user_id, Relationship.target_id == target, ) - .options( - joinedload(Relationship.target).options(*DBUser.all_select_option()) # pyright: ignore[reportArgumentType] - ) ) ).first() assert relationship, "Relationship should exist after commit" diff --git a/app/router/room.py b/app/router/room.py index 6e6f4a1..a2347ec 100644 --- a/app/router/room.py +++ b/app/router/room.py @@ -6,8 +6,10 @@ from app.dependencies.fetcher import get_fetcher from app.fetcher import Fetcher from app.models.room import MultiplayerRoom, MultiplayerRoomState, Room -from api_router import router +from .api_router import router + from fastapi import Depends, HTTPException, Query +from redis.asyncio import Redis from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -21,6 +23,7 @@ async def get_all_rooms( ), # TODO: 对房间根据分类进行筛选(真的有人用这功能吗) db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), + redis: Redis = Depends(get_redis), ): all_roomID = (await db.exec(select(RoomIndex))).all() redis = get_redis() diff --git a/app/router/score.py b/app/router/score.py index cc38dcc..2f1303e 100644 --- a/app/router/score.py +++ b/app/router/score.py @@ -1,11 +1,7 @@ from __future__ import annotations -from app.database import ( - User as DBUser, -) -from app.database.beatmap import Beatmap -from app.database.score import Score, ScoreResp, process_score, process_user -from app.database.score_token import ScoreToken, ScoreTokenResp +from app.database import Beatmap, Score, ScoreResp, ScoreToken, ScoreTokenResp, User +from app.database.score import get_leaderboard, process_score, process_user from app.dependencies.database import get_db, get_redis from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user @@ -13,6 +9,7 @@ from app.models.beatmap import BeatmapRankStatus from app.models.score import ( INT_TO_MODE, GameMode, + LeaderboardType, Rank, SoloScoreSubmissionInfo, ) @@ -21,9 +18,9 @@ from .api_router import router from fastapi import Depends, Form, HTTPException, Query from pydantic import BaseModel -from redis import Redis +from redis.asyncio import Redis from sqlalchemy.orm import joinedload -from sqlmodel import col, select, true +from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -37,44 +34,26 @@ class BeatmapScores(BaseModel): ) async def get_beatmap_scores( beatmap: int, + mode: GameMode, legacy_only: bool = Query(None), # TODO:加入对这个参数的查询 - mode: GameMode | None = Query(None), - # mods: List[APIMod] = Query(None), # TODO:加入指定MOD的查询 - type: str = Query(None), - current_user: DBUser = Depends(get_current_user), + mods: list[str] = Query(default_factory=set, alias="mods[]"), + type: LeaderboardType = Query(LeaderboardType.GLOBAL), + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), + limit: int = Query(50, ge=1, le=200), ): if legacy_only: raise HTTPException( status_code=404, detail="this server only contains lazer scores" ) - all_scores = ( - await db.exec( - Score.select_clause_unique( - Score.beatmap_id == beatmap, - col(Score.passed).is_(True), - Score.gamemode == mode if mode is not None else true(), - ) - ) - ).all() - - user_score = ( - await db.exec( - Score.select_clause_unique( - Score.beatmap_id == beatmap, - Score.user_id == current_user.id, - col(Score.passed).is_(True), - Score.gamemode == mode if mode is not None else true(), - ) - ) - ).first() + all_scores, user_score = await get_leaderboard( + db, beatmap, mode, type=type, user=current_user, limit=limit, mods=mods + ) return BeatmapScores( - scores=[await ScoreResp.from_db(db, score, score.user) for score in all_scores], - userScore=await ScoreResp.from_db(db, user_score, user_score.user) - if user_score - else None, + scores=[await ScoreResp.from_db(db, score) for score in all_scores], + userScore=await ScoreResp.from_db(db, user_score) if user_score else None, ) @@ -94,7 +73,7 @@ async def get_user_beatmap_score( legacy_only: bool = Query(None), mode: str = Query(None), mods: str = Query(None), # TODO:添加mods筛选 - current_user: DBUser = Depends(get_current_user), + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): if legacy_only: @@ -103,7 +82,7 @@ async def get_user_beatmap_score( ) user_score = ( await db.exec( - Score.select_clause(True) + select(Score) .where( Score.gamemode == mode if mode is not None else True, Score.beatmap_id == beatmap, @@ -120,7 +99,7 @@ async def get_user_beatmap_score( else: return BeatmapUserScore( position=user_score.position if user_score.position is not None else 0, - score=await ScoreResp.from_db(db, user_score, user_score.user), + score=await ScoreResp.from_db(db, user_score), ) @@ -134,7 +113,7 @@ async def get_user_all_beatmap_scores( user: int, legacy_only: bool = Query(None), ruleset: str = Query(None), - current_user: DBUser = Depends(get_current_user), + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): if legacy_only: @@ -143,7 +122,7 @@ async def get_user_all_beatmap_scores( ) all_user_scores = ( await db.exec( - Score.select_clause() + select(Score) .where( Score.gamemode == ruleset if ruleset is not None else True, Score.beatmap_id == beatmap, @@ -153,9 +132,7 @@ async def get_user_all_beatmap_scores( ) ).all() - return [ - await ScoreResp.from_db(db, score, current_user) for score in all_user_scores - ] + return [await ScoreResp.from_db(db, score) for score in all_user_scores] @router.post( @@ -166,9 +143,10 @@ async def create_solo_score( version_hash: str = Form(""), beatmap_hash: str = Form(), ruleset_id: int = Form(..., ge=0, le=3), - current_user: DBUser = Depends(get_current_user), + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): + assert current_user.id async with db: score_token = ScoreToken( user_id=current_user.id, @@ -190,7 +168,7 @@ async def submit_solo_score( beatmap: int, token: int, info: SoloScoreSubmissionInfo, - current_user: DBUser = Depends(get_current_user), + current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), redis: Redis = Depends(get_redis), fetcher=Depends(get_fetcher), @@ -210,9 +188,7 @@ async def submit_solo_score( if score_token.score_id: score = ( await db.exec( - select(Score) - .options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType] - .where( + select(Score).where( Score.id == score_token.score_id, Score.user_id == current_user.id, ) @@ -246,8 +222,6 @@ async def submit_solo_score( score_id = score.id score_token.score_id = score_id await process_user(db, current_user, score, ranked) - score = ( - await db.exec(Score.select_clause().where(Score.id == score_id)) - ).first() + score = (await db.exec(select(Score).where(Score.id == score_id))).first() assert score is not None - return await ScoreResp.from_db(db, score, current_user) + return await ScoreResp.from_db(db, score) diff --git a/app/router/user.py b/app/router/user.py index 6e169c3..649f1d4 100644 --- a/app/router/user.py +++ b/app/router/user.py @@ -1,12 +1,9 @@ from __future__ import annotations -from typing import Literal - -from app.database import User as DBUser +from app.database import User, UserResp +from app.database.lazer_user import SEARCH_INCLUDED from app.dependencies.database import get_db -from app.models.score import INT_TO_MODE -from app.models.user import User as ApiUser -from app.utils import convert_db_user_to_api_user +from app.models.score import GameMode from .api_router import router @@ -17,28 +14,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.sql.expression import col -# ---------- Shared Utility ---------- -async def get_user_by_lookup( - db: AsyncSession, lookup: str, key: str = "id" -) -> DBUser | None: - """根据查找方式获取用户""" - if key == "id": - try: - user_id = int(lookup) - result = await db.exec(select(DBUser).where(DBUser.id == user_id)) - return result.first() - except ValueError: - return None - elif key == "username": - result = await db.exec(select(DBUser).where(DBUser.name == lookup)) - return result.first() - else: - return None - - -# ---------- Batch Users ---------- class BatchUserResponse(BaseModel): - users: list[ApiUser] + users: list[UserResp] @router.get("/users", response_model=BatchUserResponse) @@ -51,75 +28,44 @@ async def get_users( ): if user_ids: searched_users = ( - await session.exec( - DBUser.all_select_clause().limit(50).where(col(DBUser.id).in_(user_ids)) - ) + await session.exec(select(User).limit(50).where(col(User.id).in_(user_ids))) ).all() else: - searched_users = ( - await session.exec(DBUser.all_select_clause().limit(50)) - ).all() + searched_users = (await session.exec(select(User).limit(50))).all() return BatchUserResponse( users=[ - await convert_db_user_to_api_user( - searched_user, ruleset=INT_TO_MODE[searched_user.preferred_mode].value + await UserResp.from_db( + searched_user, + session, + include=SEARCH_INCLUDED, ) for searched_user in searched_users ] ) -# # ---------- Individual User ---------- -# @router.get("/users/{user_lookup}/{mode}", response_model=ApiUser) -# @router.get("/users/{user_lookup}/{mode}/", response_model=ApiUser) -# async def get_user_with_mode( -# user_lookup: str, -# mode: Literal["osu", "taiko", "fruits", "mania"], -# key: Literal["id", "username"] = Query("id"), -# current_user: DBUser = Depends(get_current_user), -# db: AsyncSession = Depends(get_db), -# ): -# """获取指定游戏模式的用户信息""" -# user = await get_user_by_lookup(db, user_lookup, key) -# if not user: -# raise HTTPException(status_code=404, detail="User not found") - -# return await convert_db_user_to_api_user(user, mode) - - -# @router.get("/users/{user_lookup}", response_model=ApiUser) -# @router.get("/users/{user_lookup}/", response_model=ApiUser) -# async def get_user_default( -# user_lookup: str, -# key: Literal["id", "username"] = Query("id"), -# current_user: DBUser = Depends(get_current_user), -# db: AsyncSession = Depends(get_db), -# ): -# """获取用户信息(默认使用osu模式,但包含所有模式的统计信息)""" -# user = await get_user_by_lookup(db, user_lookup, key) -# if not user: -# raise HTTPException(status_code=404, detail="User not found") - -# return await convert_db_user_to_api_user(user, "osu") - - -@router.get("/users/{user}/{ruleset}", response_model=ApiUser) -@router.get("/users/{user}/", response_model=ApiUser) -@router.get("/users/{user}", response_model=ApiUser) +@router.get("/users/{user}/{ruleset}", response_model=UserResp) +@router.get("/users/{user}/", response_model=UserResp) +@router.get("/users/{user}", response_model=UserResp) async def get_user_info( user: str, - ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu", + ruleset: GameMode | None = None, session: AsyncSession = Depends(get_db), ): searched_user = ( await session.exec( - DBUser.all_select_clause().where( - DBUser.id == int(user) + select(User).where( + User.id == int(user) if user.isdigit() - else DBUser.name == user.removeprefix("@") + else User.username == user.removeprefix("@") ) ) ).first() if not searched_user: raise HTTPException(404, detail="User not found") - return await convert_db_user_to_api_user(searched_user, ruleset=ruleset) + return await UserResp.from_db( + searched_user, + session, + include=SEARCH_INCLUDED, + ruleset=ruleset, + ) diff --git a/app/signalr/hub/metadata.py b/app/signalr/hub/metadata.py index 821d831..909d503 100644 --- a/app/signalr/hub/metadata.py +++ b/app/signalr/hub/metadata.py @@ -2,10 +2,11 @@ from __future__ import annotations import asyncio from collections.abc import Coroutine +from datetime import UTC, datetime from typing import override -from app.database.relationship import Relationship, RelationshipType -from app.dependencies.database import engine +from app.database import Relationship, RelationshipType, User +from app.dependencies.database import engine, get_redis from app.models.metadata_hub import MetadataClientState, OnlineStatus, UserActivity from .hub import Client, Hub @@ -54,6 +55,18 @@ class MetadataHub(Hub[MetadataClientState]): async def _clean_state(self, state: MetadataClientState) -> None: if state.pushable: await asyncio.gather(*self.broadcast_tasks(int(state.connection_id), None)) + redis = get_redis() + if await redis.exists(f"metadata:online:{state.connection_id}"): + await redis.delete(f"metadata:online:{state.connection_id}") + async with AsyncSession(engine) as session: + async with session.begin(): + user = ( + await session.exec( + select(User).where(User.id == int(state.connection_id)) + ) + ).one() + user.last_visit = datetime.now(UTC) + await session.commit() @override def create_state(self, client: Client) -> MetadataClientState: @@ -93,6 +106,8 @@ class MetadataHub(Hub[MetadataClientState]): ) ) await asyncio.gather(*tasks) + redis = get_redis() + await redis.set(f"metadata:online:{user_id}", "") async def UpdateStatus(self, client: Client, status: int) -> None: status_ = OnlineStatus(status) diff --git a/app/signalr/hub/spectator.py b/app/signalr/hub/spectator.py index 7efcc89..bd311ec 100644 --- a/app/signalr/hub/spectator.py +++ b/app/signalr/hub/spectator.py @@ -7,10 +7,9 @@ import struct import time from typing import override -from app.database import Beatmap +from app.database import Beatmap, User from app.database.score import Score from app.database.score_token import ScoreToken -from app.database.user import User from app.dependencies.database import engine from app.models.beatmap import BeatmapRankStatus from app.models.mods import mods_to_int @@ -197,7 +196,7 @@ class SpectatorHub(Hub[StoreClientState]): ).first() if not user: return - name = user.name + name = user.username store.state = state store.beatmap_status = beatmap.beatmap_status store.checksum = beatmap.checksum @@ -241,65 +240,17 @@ class SpectatorHub(Hub[StoreClientState]): user_id = int(client.connection_id) store = self.get_or_create_state(client) score = store.score + assert store.beatmap_status is not None + assert store.state is not None + assert store.score is not None if not score or not store.score_token: return - - assert store.beatmap_status is not None - - async def _save_replay(): - assert store.checksum is not None - assert store.ruleset_id is not None - assert store.state is not None - assert store.score is not None - async with AsyncSession(engine) as session: - async with session: - start_time = time.time() - score_record = None - while time.time() - start_time < READ_SCORE_TIMEOUT: - sub_query = select(ScoreToken.score_id).where( - ScoreToken.id == store.score_token, - ) - result = await session.exec( - select(Score) - .options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType] - .where( - Score.id == sub_query, - Score.user_id == user_id, - ) - ) - score_record = result.first() - if score_record: - break - if not score_record: - return - if not score_record.passed: - return - score_record.has_replay = True - await session.commit() - await session.refresh(score_record) - save_replay( - ruleset_id=store.ruleset_id, - md5=store.checksum, - username=store.score.score_info.user.name, - score=score_record, - statistics=score.score_info.statistics, - maximum_statistics=score.score_info.maximum_statistics, - frames=score.replay_frames, - ) - if ( - ( - BeatmapRankStatus.PENDING - < store.beatmap_status - <= BeatmapRankStatus.LOVED - ) - and any( - k.is_hit() and v > 0 for k, v in score.score_info.statistics.items() - ) - and state.state != SpectatedUserState.Failed + BeatmapRankStatus.PENDING < store.beatmap_status <= BeatmapRankStatus.LOVED + ) and any( + k.is_hit() and v > 0 for k, v in store.score.score_info.statistics.items() ): - # save replay - await _save_replay() + await self._process_score(store, client) store.state = None store.beatmap_status = None store.checksum = None @@ -308,6 +259,56 @@ class SpectatorHub(Hub[StoreClientState]): store.score = None await self._end_session(user_id, state) + async def _process_score(self, store: StoreClientState, client: Client) -> None: + user_id = int(client.connection_id) + assert store.state is not None + assert store.score_token is not None + assert store.checksum is not None + assert store.ruleset_id is not None + assert store.score is not None + async with AsyncSession(engine) as session: + async with session: + start_time = time.time() + score_record = None + while time.time() - start_time < READ_SCORE_TIMEOUT: + sub_query = select(ScoreToken.score_id).where( + ScoreToken.id == store.score_token, + ) + result = await session.exec( + select(Score) + .options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType] + .where( + Score.id == sub_query, + Score.user_id == user_id, + ) + ) + score_record = result.first() + if score_record: + break + if not score_record: + return + if not score_record.passed: + return + await self.call_noblock( + client, + "UserScoreProcessed", + user_id, + score_record.id, + ) + # save replay + score_record.has_replay = True + await session.commit() + await session.refresh(score_record) + save_replay( + ruleset_id=store.ruleset_id, + md5=store.checksum, + username=store.score.score_info.user.name, + score=score_record, + statistics=store.score.score_info.statistics, + maximum_statistics=store.score.score_info.maximum_statistics, + frames=store.score.replay_frames, + ) + async def _end_session(self, user_id: int, state: SpectatorState) -> None: if state.state == SpectatedUserState.Playing: state.state = SpectatedUserState.Quit @@ -336,7 +337,7 @@ class SpectatorHub(Hub[StoreClientState]): async with AsyncSession(engine) as session: async with session.begin(): username = ( - await session.exec(select(User.name).where(User.id == user_id)) + await session.exec(select(User.username).where(User.id == user_id)) ).first() if not username: return diff --git a/app/utils.py b/app/utils.py index 9008706..09e8fdc 100644 --- a/app/utils.py +++ b/app/utils.py @@ -1,465 +1,6 @@ from __future__ import annotations -from datetime import UTC, datetime - -from app.database import ( - LazerUserCounts, - LazerUserProfile, - LazerUserStatistics, - User as DBUser, -) -from app.models.user import ( - Country, - Cover, - DailyChallengeStats, - GradeCounts, - Kudosu, - Level, - Page, - RankHighest, - RankHistory, - Statistics, - User, - UserAchievement, -) - def unix_timestamp_to_windows(timestamp: int) -> int: """Convert a Unix timestamp to a Windows timestamp.""" return (timestamp + 62135596800) * 10_000_000 - - -async def convert_db_user_to_api_user(db_user: DBUser, ruleset: str = "osu") -> User: - """将数据库用户模型转换为API用户模型(使用 Lazer 表)""" - - # 从db_user获取基本字段值 - user_id = getattr(db_user, "id") - user_name = getattr(db_user, "name") - user_country = getattr(db_user, "country") - user_country_code = user_country # 在User模型中,country字段就是country_code - - # 获取 Lazer 用户资料 - profile = db_user.lazer_profile - if not profile: - # 如果没有 lazer 资料,使用默认值 - profile = LazerUserProfile( - user_id=user_id, - ) - - # 获取 Lazer 用户计数 - 使用正确的 lazer_counts 关系 - lzrcnt = db_user.lazer_counts - - if not lzrcnt: - # 如果没有 lazer 计数,使用默认值 - lzrcnt = LazerUserCounts(user_id=user_id) - - # 获取指定模式的统计信息 - user_stats = None - if db_user.lazer_statistics: - for stat in db_user.lazer_statistics: - if stat.mode == ruleset: - user_stats = stat - break - - if not user_stats: - # 如果没有找到指定模式的统计,创建默认统计 - user_stats = LazerUserStatistics(user_id=user_id) - - # 获取国家信息 - country_code = db_user.country_code if db_user.country_code is not None else "XX" - - country = Country(code=str(country_code), name=get_country_name(str(country_code))) - - # 获取 Kudosu 信息 - kudosu = Kudosu(available=0, total=0) - - # 获取计数信息 - # counts = LazerUserCounts(user_id=user_id) - - # 转换统计信息 - statistics = Statistics( - count_100=user_stats.count_100, - count_300=user_stats.count_300, - count_50=user_stats.count_50, - count_miss=user_stats.count_miss, - level=Level( - current=user_stats.level_current, progress=user_stats.level_progress - ), - global_rank=user_stats.global_rank, - global_rank_exp=user_stats.global_rank_exp, - pp=float(user_stats.pp) if user_stats.pp else 0.0, - pp_exp=float(user_stats.pp_exp) if user_stats.pp_exp else 0.0, - ranked_score=user_stats.ranked_score, - hit_accuracy=float(user_stats.hit_accuracy) if user_stats.hit_accuracy else 0.0, - play_count=user_stats.play_count, - play_time=user_stats.play_time, - total_score=user_stats.total_score, - total_hits=user_stats.total_hits, - maximum_combo=user_stats.maximum_combo, - replays_watched_by_others=user_stats.replays_watched_by_others, - is_ranked=user_stats.is_ranked, - grade_counts=GradeCounts( - ss=user_stats.grade_ss, - ssh=user_stats.grade_ssh, - s=user_stats.grade_s, - sh=user_stats.grade_sh, - a=user_stats.grade_a, - ), - country_rank=user_stats.country_rank, - rank={"country": user_stats.country_rank} if user_stats.country_rank else None, - ) - - # 转换所有模式的统计信息 - statistics_rulesets = {} - if db_user.lazer_statistics: - for stat in db_user.lazer_statistics: - statistics_rulesets[stat.mode] = Statistics( - count_100=stat.count_100, - count_300=stat.count_300, - count_50=stat.count_50, - count_miss=stat.count_miss, - level=Level(current=stat.level_current, progress=stat.level_progress), - global_rank=stat.global_rank, - global_rank_exp=stat.global_rank_exp, - pp=float(stat.pp) if stat.pp else 0.0, - pp_exp=float(stat.pp_exp) if stat.pp_exp else 0.0, - ranked_score=stat.ranked_score, - hit_accuracy=float(stat.hit_accuracy) if stat.hit_accuracy else 0.0, - play_count=stat.play_count, - play_time=stat.play_time, - total_score=stat.total_score, - total_hits=stat.total_hits, - maximum_combo=stat.maximum_combo, - replays_watched_by_others=stat.replays_watched_by_others, - is_ranked=stat.is_ranked, - grade_counts=GradeCounts( - ss=stat.grade_ss, - ssh=stat.grade_ssh, - s=stat.grade_s, - sh=stat.grade_sh, - a=stat.grade_a, - ), - country_rank=stat.country_rank, - rank={"country": stat.country_rank} if stat.country_rank else None, - ) - - # 转换国家信息 - country = Country(code=user_country_code, name=get_country_name(user_country_code)) - - # 转换封面信息 - cover_url = ( - profile.cover_url - if profile and profile.cover_url - else "https://assets.ppy.sh/user-profile-covers/default.jpeg" - ) - cover = Cover( - custom_url=profile.cover_url if profile else None, url=str(cover_url), id=None - ) - - # 转换 Kudosu 信息 - kudosu = Kudosu(available=0, total=0) - - # 转换成就信息 - user_achievements = [] - if db_user.lazer_achievements: - for achievement in db_user.lazer_achievements: - user_achievements.append( - UserAchievement( - achieved_at=achievement.achieved_at, - achievement_id=achievement.achievement_id, - ) - ) - - # 转换排名历史 - rank_history = None - rank_history_data = None - for rh in db_user.rank_history: - if rh.mode == ruleset: - rank_history_data = rh.rank_data - break - - if rank_history_data: - rank_history = RankHistory(mode=ruleset, data=rank_history_data) - - # 转换每日挑战统计 - # daily_challenge_stats = None - # if db_user.daily_challenge_stats: - # dcs = db_user.daily_challenge_stats - # daily_challenge_stats = DailyChallengeStats( - # daily_streak_best=dcs.daily_streak_best, - # daily_streak_current=dcs.daily_streak_current, - # last_update=dcs.last_update, - # last_weekly_streak=dcs.last_weekly_streak, - # playcount=dcs.playcount, - # top_10p_placements=dcs.top_10p_placements, - # top_50p_placements=dcs.top_50p_placements, - # user_id=dcs.user_id, - # weekly_streak_best=dcs.weekly_streak_best, - # weekly_streak_current=dcs.weekly_streak_current, - # ) - - # 转换最高排名 - rank_highest = None - if user_stats.rank_highest: - rank_highest = RankHighest( - rank=user_stats.rank_highest, - updated_at=user_stats.rank_highest_updated_at or datetime.utcnow(), - ) - - # 转换团队信息 - team = None - if db_user.team_membership: - team_member = db_user.team_membership # 假设用户只属于一个团队 - team = team_member.team - - # 创建用户对象 - # 从db_user获取基本字段值 - user_id = getattr(db_user, "id") - user_name = getattr(db_user, "name") - user_country = getattr(db_user, "country") - - # 获取用户头像URL - avatar_url = None - - # 首先检查 profile 中的 avatar_url - if profile and hasattr(profile, "avatar_url") and profile.avatar_url: - avatar_url = str(profile.avatar_url) - - # 然后检查是否有关联的头像记录 - if avatar_url is None and hasattr(db_user, "avatar") and db_user.avatar is not None: - if db_user.avatar.r2_game_url: - # 优先使用游戏用的头像URL - avatar_url = str(db_user.avatar.r2_game_url) - elif db_user.avatar.r2_original_url: - # 其次使用原始头像URL - avatar_url = str(db_user.avatar.r2_original_url) - - # 如果还是没有找到,通过查询获取 - # if db_session and avatar_url is None: - # try: - # # 导入UserAvatar模型 - - # # 尝试查找用户的头像记录 - # statement = select(UserAvatar).where( - # UserAvatar.user_id == user_id, UserAvatar.is_active == True - # ) - # avatar_record = db_session.exec(statement).first() - # if avatar_record is not None: - # if avatar_record.r2_game_url is not None: - # # 优先使用游戏用的头像URL - # avatar_url = str(avatar_record.r2_game_url) - # elif avatar_record.r2_original_url is not None: - # # 其次使用原始头像URL - # avatar_url = str(avatar_record.r2_original_url) - # except Exception as e: - # print(f"获取用户头像时出错: {e}") - # print(f"最终头像URL: {avatar_url}") - # 如果仍然没有找到头像URL,则使用默认URL - if avatar_url is None: - avatar_url = "https://a.gu-osu.gmoe.cc/api/users/avatar/1" - - # 处理 profile_order 列表排序 - profile_order = [ - "me", - "recent_activity", - "top_ranks", - "medals", - "historical", - "beatmaps", - "kudosu", - ] - if profile and profile.profile_order: - profile_order = profile.profile_order.split(",") - - # 在convert_db_user_to_api_user函数中添加active_tournament_banners处理 - active_tournament_banners = [] - if db_user.active_banners: - for banner in db_user.active_banners: - active_tournament_banners.append( - { - "tournament_id": banner.tournament_id, - "image_url": banner.image_url, - "is_active": banner.is_active, - } - ) - - # 在convert_db_user_to_api_user函数中添加badges处理 - badges = [] - if db_user.lazer_badges: - for badge in db_user.lazer_badges: - badges.append( - { - "badge_id": badge.badge_id, - "awarded_at": badge.awarded_at, - "description": badge.description, - "image_url": badge.image_url, - "url": badge.url, - } - ) - - # 在convert_db_user_to_api_user函数中添加monthly_playcounts处理 - monthly_playcounts = [] - if db_user.lazer_monthly_playcounts: - for playcount in db_user.lazer_monthly_playcounts: - monthly_playcounts.append( - { - "start_date": playcount.start_date.isoformat() - if playcount.start_date - else None, - "play_count": playcount.play_count, - } - ) - - # 在convert_db_user_to_api_user函数中添加previous_usernames处理 - previous_usernames = [] - if db_user.lazer_previous_usernames: - for username in db_user.lazer_previous_usernames: - previous_usernames.append( - { - "username": username.username, - "changed_at": username.changed_at.isoformat() - if username.changed_at - else None, - } - ) - - # 在convert_db_user_to_api_user函数中添加replays_watched_counts处理 - replays_watched_counts = [] - if hasattr(db_user, "lazer_replays_watched") and db_user.lazer_replays_watched: - for replay in db_user.lazer_replays_watched: - replays_watched_counts.append( - { - "start_date": replay.start_date.isoformat() - if replay.start_date - else None, - "count": replay.count, - } - ) - - # 创建用户对象 - user = User( - id=user_id, - username=user_name, - avatar_url=avatar_url, - country_code=str(country_code), - default_group=profile.default_group if profile else "default", - is_active=profile.is_active, - is_bot=profile.is_bot, - is_deleted=profile.is_deleted, - is_online=profile.is_online, - is_supporter=profile.is_supporter, - is_restricted=profile.is_restricted, - last_visit=db_user.last_visit, - pm_friends_only=profile.pm_friends_only, - profile_colour=profile.profile_colour, - cover_url=profile.cover_url - if profile and profile.cover_url - else "https://assets.ppy.sh/user-profile-covers/default.jpeg", - discord=profile.discord if profile else None, - has_supported=profile.has_supported if profile else False, - interests=profile.interests if profile else None, - join_date=profile.join_date if profile.join_date else datetime.now(UTC), - location=profile.location if profile else None, - max_blocks=profile.max_blocks if profile and profile.max_blocks else 100, - max_friends=profile.max_friends if profile and profile.max_friends else 500, - post_count=profile.post_count if profile and profile.post_count else 0, - profile_hue=profile.profile_hue if profile and profile.profile_hue else None, - profile_order=profile_order, # 使用排序后的 profile_order - title=profile.title if profile else None, - title_url=profile.title_url if profile else None, - twitter=profile.twitter if profile else None, - website=profile.website if profile else None, - session_verified=True, - support_level=profile.support_level if profile else 0, - country=country, - cover=cover, - kudosu=kudosu, - statistics=statistics, - statistics_rulesets=statistics_rulesets, - beatmap_playcounts_count=lzrcnt.beatmap_playcounts_count if lzrcnt else 0, - comments_count=lzrcnt.comments_count if lzrcnt else 0, - favourite_beatmapset_count=lzrcnt.favourite_beatmapset_count if lzrcnt else 0, - follower_count=lzrcnt.follower_count if lzrcnt else 0, - graveyard_beatmapset_count=lzrcnt.graveyard_beatmapset_count if lzrcnt else 0, - guest_beatmapset_count=lzrcnt.guest_beatmapset_count if lzrcnt else 0, - loved_beatmapset_count=lzrcnt.loved_beatmapset_count if lzrcnt else 0, - mapping_follower_count=lzrcnt.mapping_follower_count if lzrcnt else 0, - nominated_beatmapset_count=lzrcnt.nominated_beatmapset_count if lzrcnt else 0, - pending_beatmapset_count=lzrcnt.pending_beatmapset_count if lzrcnt else 0, - ranked_beatmapset_count=lzrcnt.ranked_beatmapset_count if lzrcnt else 0, - ranked_and_approved_beatmapset_count=lzrcnt.ranked_and_approved_beatmapset_count - if lzrcnt - else 0, - unranked_beatmapset_count=lzrcnt.unranked_beatmapset_count if lzrcnt else 0, - scores_best_count=lzrcnt.scores_best_count if lzrcnt else 0, - scores_first_count=lzrcnt.scores_first_count if lzrcnt else 0, - scores_pinned_count=lzrcnt.scores_pinned_count, - scores_recent_count=lzrcnt.scores_recent_count if lzrcnt else 0, - account_history=[], # TODO: 获取用户历史账户信息 - # active_tournament_banner=len(active_tournament_banners), - active_tournament_banners=active_tournament_banners, - badges=badges, - current_season_stats=None, - daily_challenge_user_stats=DailyChallengeStats( - user_id=user_id, - daily_streak_best=db_user.daily_challenge_stats.daily_streak_best - if db_user.daily_challenge_stats - else 0, - daily_streak_current=db_user.daily_challenge_stats.daily_streak_current - if db_user.daily_challenge_stats - else 0, - last_update=db_user.daily_challenge_stats.last_update - if db_user.daily_challenge_stats - else None, - last_weekly_streak=db_user.daily_challenge_stats.last_weekly_streak - if db_user.daily_challenge_stats - else None, - playcount=db_user.daily_challenge_stats.playcount - if db_user.daily_challenge_stats - else 0, - top_10p_placements=db_user.daily_challenge_stats.top_10p_placements - if db_user.daily_challenge_stats - else 0, - top_50p_placements=db_user.daily_challenge_stats.top_50p_placements - if db_user.daily_challenge_stats - else 0, - weekly_streak_best=db_user.daily_challenge_stats.weekly_streak_best - if db_user.daily_challenge_stats - else 0, - weekly_streak_current=db_user.daily_challenge_stats.weekly_streak_current - if db_user.daily_challenge_stats - else 0, - ), - groups=[], - monthly_playcounts=monthly_playcounts, - page=Page(html=profile.page_html or "", raw=profile.page_raw or "") - if profile.page_html or profile.page_raw - else Page(), - previous_usernames=previous_usernames, - rank_highest=rank_highest, - rank_history=rank_history, - rankHistory=rank_history, - replays_watched_counts=replays_watched_counts, - team=team, - user_achievements=user_achievements, - ) - - return user - - -def get_country_name(country_code: str) -> str: - """根据国家代码获取国家名称""" - country_names = { - "CN": "China", - "JP": "Japan", - "US": "United States", - "GB": "United Kingdom", - "DE": "Germany", - "FR": "France", - "KR": "South Korea", - "CA": "Canada", - "AU": "Australia", - "BR": "Brazil", - # 可以添加更多国家 - } - return country_names.get(country_code, "Unknown") diff --git a/main.py b/main.py index 526d593..72444ef 100644 --- a/main.py +++ b/main.py @@ -4,25 +4,22 @@ from contextlib import asynccontextmanager from datetime import datetime from app.config import settings -from app.database import Team # noqa: F401 -from app.dependencies.database import create_tables, engine +from app.dependencies.database import create_tables, engine, redis_client from app.dependencies.fetcher import get_fetcher -from app.models.user import User from app.router import api_router, auth_router, fetcher_router, signalr_router from fastapi import FastAPI -User.model_rebuild() - @asynccontextmanager async def lifespan(app: FastAPI): # on startup await create_tables() - get_fetcher() # 初始化 fetcher + await get_fetcher() # 初始化 fetcher # on shutdown yield await engine.dispose() + await redis_client.aclose() app = FastAPI(title="osu! API 模拟服务器", version="1.0.0", lifespan=lifespan) @@ -44,104 +41,6 @@ async def health_check(): return {"status": "ok", "timestamp": datetime.utcnow().isoformat()} -# @app.get("/api/v2/friends") -# async def get_friends(): -# return JSONResponse( -# content=[ -# { -# "id": 123456, -# "username": "BestFriend", -# "is_online": True, -# "is_supporter": False, -# "country": {"code": "US", "name": "United States"}, -# } -# ] -# ) - - -# @app.get("/api/v2/notifications") -# async def get_notifications(): -# return JSONResponse(content={"notifications": [], "unread_count": 0}) - - -# @app.post("/api/v2/chat/ack") -# async def chat_ack(): -# return JSONResponse(content={"status": "ok"}) - - -# @app.get("/api/v2/users/{user_id}/{mode}") -# async def get_user_mode(user_id: int, mode: str): -# return JSONResponse( -# content={ -# "id": user_id, -# "username": "测试测试测", -# "statistics": { -# "level": {"current": 97, "progress": 96}, -# "pp": 114514, -# "global_rank": 666, -# "country_rank": 1, -# "hit_accuracy": 100, -# }, -# "country": {"code": "JP", "name": "Japan"}, -# } -# ) - - -# @app.get("/api/v2/me") -# async def get_me(): -# return JSONResponse( -# content={ -# "id": 15651670, -# "username": "Googujiang", -# "is_online": True, -# "country": {"code": "JP", "name": "Japan"}, -# "statistics": { -# "level": {"current": 97, "progress": 96}, -# "pp": 2826.26, -# "global_rank": 298026, -# "country_rank": 11220, -# "hit_accuracy": 95.7168, -# }, -# } -# ) - - -# @app.post("/signalr/metadata/negotiate") -# async def metadata_negotiate(negotiateVersion: int = 1): -# return JSONResponse( -# content={ -# "connectionId": "abc123", -# "availableTransports": [ -# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]} -# ], -# } -# ) - - -# @app.post("/signalr/spectator/negotiate") -# async def spectator_negotiate(negotiateVersion: int = 1): -# return JSONResponse( -# content={ -# "connectionId": "spec456", -# "availableTransports": [ -# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]} -# ], -# } -# ) - - -# @app.post("/signalr/multiplayer/negotiate") -# async def multiplayer_negotiate(negotiateVersion: int = 1): -# return JSONResponse( -# content={ -# "connectionId": "multi789", -# "availableTransports": [ -# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]} -# ], -# } -# ) - - if __name__ == "__main__": from app.log import logger # noqa: F401 diff --git a/migrations/versions/78be13c71791_score_remove_best_id_in_database.py b/migrations/versions/1178d0758ebf_beatmapset_support_favourite_count.py similarity index 62% rename from migrations/versions/78be13c71791_score_remove_best_id_in_database.py rename to migrations/versions/1178d0758ebf_beatmapset_support_favourite_count.py index d0cab2b..84bae15 100644 --- a/migrations/versions/78be13c71791_score_remove_best_id_in_database.py +++ b/migrations/versions/1178d0758ebf_beatmapset_support_favourite_count.py @@ -1,8 +1,8 @@ -"""score: remove best_id in database +"""beatmapset: support favourite count -Revision ID: 78be13c71791 -Revises: dc4d25c428c7 -Create Date: 2025-07-29 07:57:33.764517 +Revision ID: 1178d0758ebf +Revises: +Create Date: 2025-08-01 04:05:09.882800 """ @@ -15,8 +15,8 @@ import sqlalchemy as sa from sqlalchemy.dialects import mysql # revision identifiers, used by Alembic. -revision: str = "78be13c71791" -down_revision: str | Sequence[str] | None = "dc4d25c428c7" +revision: str = "1178d0758ebf" +down_revision: str | Sequence[str] | None = None branch_labels: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None @@ -24,7 +24,7 @@ depends_on: str | Sequence[str] | None = None def upgrade() -> None: """Upgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### - op.drop_column("scores", "best_id") + op.drop_column("beatmapsets", "favourite_count") # ### end Alembic commands ### @@ -32,7 +32,9 @@ def downgrade() -> None: """Downgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### op.add_column( - "scores", - sa.Column("best_id", mysql.INTEGER(), autoincrement=False, nullable=True), + "beatmapsets", + sa.Column( + "favourite_count", mysql.INTEGER(), autoincrement=False, nullable=False + ), ) # ### end Alembic commands ### diff --git a/migrations/versions/58a11441d302_relationship_fix_unique_relationship.py b/migrations/versions/58a11441d302_relationship_fix_unique_relationship.py new file mode 100644 index 0000000..e383621 --- /dev/null +++ b/migrations/versions/58a11441d302_relationship_fix_unique_relationship.py @@ -0,0 +1,54 @@ +"""relationship: fix unique relationship + +Revision ID: 58a11441d302 +Revises: 1178d0758ebf +Create Date: 2025-08-01 04:23:02.498166 + +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + +# revision identifiers, used by Alembic. +revision: str = "58a11441d302" +down_revision: str | Sequence[str] | None = "1178d0758ebf" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "relationship", + sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False), + ) + op.drop_constraint("PRIMARY", "relationship", type_="primary") + op.create_primary_key("pk_relationship", "relationship", ["id"]) + op.alter_column( + "relationship", "user_id", existing_type=mysql.BIGINT(), nullable=True + ) + op.alter_column( + "relationship", "target_id", existing_type=mysql.BIGINT(), nullable=True + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint("pk_relationship", "relationship", type_="primary") + op.create_primary_key("PRIMARY", "relationship", ["user_id", "target_id"]) + op.alter_column( + "relationship", "target_id", existing_type=mysql.BIGINT(), nullable=False + ) + op.alter_column( + "relationship", "user_id", existing_type=mysql.BIGINT(), nullable=False + ) + op.drop_column("relationship", "id") + # ### end Alembic commands ### diff --git a/migrations/versions/dc4d25c428c7_score_add_nlarge_tick_hit_nsmall_tick_.py b/migrations/versions/dc4d25c428c7_score_add_nlarge_tick_hit_nsmall_tick_.py deleted file mode 100644 index d90ec3d..0000000 --- a/migrations/versions/dc4d25c428c7_score_add_nlarge_tick_hit_nsmall_tick_.py +++ /dev/null @@ -1,36 +0,0 @@ -"""score: add nlarge_tick_hit & nsmall_tick_hit for pp calculator - -Revision ID: dc4d25c428c7 -Revises: -Create Date: 2025-07-29 01:43:40.221070 - -""" - -from __future__ import annotations - -from collections.abc import Sequence - -from alembic import op -import sqlalchemy as sa - -# revision identifiers, used by Alembic. -revision: str = "dc4d25c428c7" -down_revision: str | Sequence[str] | None = None -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -def upgrade() -> None: - """Upgrade schema.""" - # ### commands auto generated by Alembic - please adjust! ### - op.add_column("scores", sa.Column("nlarge_tick_hit", sa.Integer(), nullable=True)) - op.add_column("scores", sa.Column("nsmall_tick_hit", sa.Integer(), nullable=True)) - # ### end Alembic commands ### - - -def downgrade() -> None: - """Downgrade schema.""" - # ### commands auto generated by Alembic - please adjust! ### - op.drop_column("scores", "nsmall_tick_hit") - op.drop_column("scores", "nlarge_tick_hit") - # ### end Alembic commands ###