From be401e88850b79a48398c94cfdf9b7fc820c3779 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Thu, 31 Jul 2025 20:11:22 +0800 Subject: [PATCH 1/6] =?UTF-8?q?refactor(database):=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E5=85=B3=E8=81=94=E5=AF=B9=E8=B1=A1?= =?UTF-8?q?=E7=9A=84=E8=BD=BD=E5=85=A5=20(#10)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/database/beatmap.py | 23 ++++++------------ app/database/beatmapset.py | 9 ++++---- app/database/lazer_user.py | 43 +++++++++++++--------------------- app/database/relationship.py | 6 +++-- app/database/score.py | 45 ++++++++---------------------------- app/database/team.py | 8 +++++-- app/dependencies/user.py | 8 +------ app/models/beatmap.py | 11 +++++---- app/router/beatmap.py | 33 +++++--------------------- app/router/beatmapset.py | 11 ++------- app/router/relationship.py | 13 +++-------- app/router/score.py | 12 ++++------ app/router/user.py | 17 +++----------- 13 files changed, 73 insertions(+), 166 deletions(-) diff --git a/app/database/beatmap.py b/app/database/beatmap.py index 751bc5c..2ab5ad0 100644 --- a/app/database/beatmap.py +++ b/app/database/beatmap.py @@ -8,7 +8,6 @@ from app.models.score import MODE_TO_INT, GameMode from .beatmapset import Beatmapset, BeatmapsetResp from sqlalchemy import DECIMAL, Column, DateTime -from sqlalchemy.orm import joinedload from sqlmodel import VARCHAR, Field, Relationship, SQLModel, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -67,7 +66,9 @@ class Beatmap(BeatmapBase, table=True): beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True) beatmap_status: BeatmapRankStatus # optional - beatmapset: Beatmapset = Relationship(back_populates="beatmaps") + beatmapset: Beatmapset = Relationship( + back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"} + ) @property def can_ranked(self) -> bool: @@ -88,13 +89,7 @@ class Beatmap(BeatmapBase, table=True): session.add(beatmap) await session.commit() beatmap = ( - await session.exec( - select(Beatmap) - .options( - joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType] - ) - .where(Beatmap.id == resp.id) - ) + await session.exec(select(Beatmap).where(Beatmap.id == resp.id)) ).first() assert beatmap is not None, "Beatmap should not be None after commit" return beatmap @@ -132,13 +127,9 @@ class Beatmap(BeatmapBase, table=True): ) -> "Beatmap": beatmap = ( await session.exec( - select(Beatmap) - .where( + select(Beatmap).where( Beatmap.id == bid if bid is not None else Beatmap.checksum == md5 ) - .options( - joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType] - ) ) ).first() if not beatmap: @@ -165,7 +156,7 @@ class BeatmapResp(BeatmapBase): url: str = "" @classmethod - def from_db( + async def from_db( cls, beatmap: Beatmap, query_mode: GameMode | None = None, @@ -179,5 +170,5 @@ class BeatmapResp(BeatmapBase): beatmap_["ranked"] = beatmap.beatmap_status.value beatmap_["mode_int"] = MODE_TO_INT[beatmap.mode] if not from_set: - beatmap_["beatmapset"] = BeatmapsetResp.from_db(beatmap.beatmapset) + beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset) return cls.model_validate(beatmap_) diff --git a/app/database/beatmapset.py b/app/database/beatmapset.py index 2ef6280..5a618b7 100644 --- a/app/database/beatmapset.py +++ b/app/database/beatmapset.py @@ -7,6 +7,7 @@ from app.models.score import GameMode from pydantic import BaseModel, model_serializer from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text +from sqlalchemy.ext.asyncio import AsyncAttrs from sqlmodel import Field, Relationship, SQLModel from sqlmodel.ext.asyncio.session import AsyncSession @@ -130,7 +131,7 @@ class BeatmapsetBase(SQLModel, UTCBaseModel): tags: str = Field(default="", sa_column=Column(Text)) -class Beatmapset(BeatmapsetBase, table=True): +class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True): __tablename__ = "beatmapsets" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True, index=True) @@ -200,12 +201,12 @@ class BeatmapsetResp(BeatmapsetBase): nominations: BeatmapNominations | None = None @classmethod - def from_db(cls, beatmapset: Beatmapset) -> "BeatmapsetResp": + async def from_db(cls, beatmapset: Beatmapset) -> "BeatmapsetResp": from .beatmap import BeatmapResp beatmaps = [ - BeatmapResp.from_db(beatmap, from_set=True) - for beatmap in beatmapset.beatmaps + await BeatmapResp.from_db(beatmap, from_set=True) + for beatmap in await beatmapset.awaitable_attrs.beatmaps ] return cls.model_validate( { diff --git a/app/database/lazer_user.py b/app/database/lazer_user.py index 9b98c98..d502ccb 100644 --- a/app/database/lazer_user.py +++ b/app/database/lazer_user.py @@ -12,7 +12,7 @@ from .statistics import UserStatistics, UserStatisticsResp from .team import Team, TeamMember from .user_account_history import UserAccountHistory, UserAccountHistoryResp -from sqlalchemy.orm import joinedload, selectinload +from sqlalchemy.ext.asyncio import AsyncAttrs from sqlmodel import ( JSON, BigInteger, @@ -128,7 +128,7 @@ class UserBase(UTCBaseModel, SQLModel): is_bng: bool = False -class User(UserBase, table=True): +class User(AsyncAttrs, UserBase, table=True): __tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType] id: int | None = Field( @@ -154,17 +154,6 @@ class User(UserBase, table=True): default=None, sa_column=Column(DateTime(timezone=True)), exclude=True ) - @classmethod - def all_select_option(cls): - return ( - selectinload(cls.account_history), # pyright: ignore[reportArgumentType] - selectinload(cls.statistics), # pyright: ignore[reportArgumentType] - selectinload(cls.achievement), # pyright: ignore[reportArgumentType] - joinedload(cls.team_membership).joinedload(TeamMember.team), # pyright: ignore[reportArgumentType] - joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType] - selectinload(cls.monthly_playcounts), # pyright: ignore[reportArgumentType] - ) - class UserResp(UserBase): id: int | None = None @@ -249,13 +238,7 @@ class UserResp(UserBase): await RelationshipResp.from_db(session, r) for r in ( await session.exec( - select(Relationship) - .options( - joinedload(Relationship.target).options( # pyright: ignore[reportArgumentType] - *User.all_select_option() - ) - ) - .where( + select(Relationship).where( Relationship.user_id == obj.id, Relationship.type == RelationshipType.FOLLOW, ) @@ -264,23 +247,26 @@ class UserResp(UserBase): ] if "team" in include: - if obj.team_membership: + if await obj.awaitable_attrs.team_membership: + assert obj.team_membership u.team = obj.team_membership.team if "account_history" in include: u.account_history = [ - UserAccountHistoryResp.from_db(ah) for ah in obj.account_history + UserAccountHistoryResp.from_db(ah) + for ah in await obj.awaitable_attrs.account_history ] if "daily_challenge_user_stats": - if obj.daily_challenge_stats: + if await obj.awaitable_attrs.daily_challenge_stats: + assert obj.daily_challenge_stats u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db( obj.daily_challenge_stats ) if "statistics" in include: current_stattistics = None - for i in obj.statistics: + for i in await obj.awaitable_attrs.statistics: if i.mode == (ruleset or obj.playmode): current_stattistics = i break @@ -292,17 +278,20 @@ class UserResp(UserBase): if "statistics_rulesets" in include: u.statistics_rulesets = { - i.mode.value: UserStatisticsResp.from_db(i) for i in obj.statistics + i.mode.value: UserStatisticsResp.from_db(i) + for i in await obj.awaitable_attrs.statistics } if "monthly_playcounts" in include: u.monthly_playcounts = [ - MonthlyPlaycountsResp.from_db(pc) for pc in obj.monthly_playcounts + MonthlyPlaycountsResp.from_db(pc) + for pc in await obj.awaitable_attrs.monthly_playcounts ] if "achievements" in include: u.user_achievements = [ - UserAchievementResp.from_db(ua) for ua in obj.achievement + UserAchievementResp.from_db(ua) + for ua in await obj.awaitable_attrs.achievement ] return u diff --git a/app/database/relationship.py b/app/database/relationship.py index 07daa25..7a351aa 100644 --- a/app/database/relationship.py +++ b/app/database/relationship.py @@ -42,7 +42,10 @@ class Relationship(SQLModel, table=True): ) type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False) target: User = SQLRelationship( - sa_relationship_kwargs={"foreign_keys": "[Relationship.target_id]"} + sa_relationship_kwargs={ + "foreign_keys": "[Relationship.target_id]", + "lazy": "selectin", + } ) @@ -79,7 +82,6 @@ class RelationshipResp(BaseModel): "daily_challenge_user_stats", "statistics", "statistics_rulesets", - "achievements", ], ), mutual=mutual, diff --git a/app/database/score.py b/app/database/score.py index 32c8cf5..4cee832 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -27,7 +27,7 @@ from app.models.score import ( ) from .beatmap import Beatmap, BeatmapResp -from .beatmapset import Beatmapset, BeatmapsetResp +from .beatmapset import BeatmapsetResp from .best_score import BestScore from .lazer_user import User, UserResp from .monthly_playcounts import MonthlyPlaycounts @@ -35,7 +35,8 @@ from .score_token import ScoreToken from redis import Redis from sqlalchemy import Column, ColumnExpressionArgument, DateTime -from sqlalchemy.orm import aliased, joinedload +from sqlalchemy.ext.asyncio import AsyncAttrs +from sqlalchemy.orm import aliased from sqlmodel import ( JSON, BigInteger, @@ -55,7 +56,7 @@ if TYPE_CHECKING: from app.fetcher import Fetcher -class ScoreBase(SQLModel, UTCBaseModel): +class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel): # 基本字段 accuracy: float map_md5: str = Field(max_length=32, index=True) @@ -114,27 +115,12 @@ class Score(ScoreBase, table=True): # optional beatmap: Beatmap = Relationship() - user: User = Relationship() + user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"}) @property def is_perfect_combo(self) -> bool: return self.max_combo == self.beatmap.max_combo - @staticmethod - def select_clause(with_user: bool = True) -> SelectOfScalar["Score"]: - clause = select(Score).options( - joinedload(Score.beatmap) # pyright: ignore[reportArgumentType] - .joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType] - .selectinload( - Beatmapset.beatmaps # pyright: ignore[reportArgumentType] - ), - ) - if with_user: - return clause.options( - joinedload(Score.user).options(*User.all_select_option()) # pyright: ignore[reportArgumentType] - ) - return clause - @staticmethod def select_clause_unique( *where_clauses: ColumnExpressionArgument[bool] | bool, @@ -148,18 +134,7 @@ class Score(ScoreBase, table=True): ) subq = select(Score, rownum).where(*where_clauses).subquery() best = aliased(Score, subq, adapt_on_names=True) - return ( - select(best) - .where(subq.c.rn == 1) - .options( - joinedload(best.beatmap) # pyright: ignore[reportArgumentType] - .joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType] - .selectinload( - Beatmapset.beatmaps # pyright: ignore[reportArgumentType] - ), - joinedload(best.user).options(*User.all_select_option()), # pyright: ignore[reportArgumentType] - ) - ) + return select(best).where(subq.c.rn == 1) class ScoreResp(ScoreBase): @@ -186,8 +161,9 @@ class ScoreResp(ScoreBase): ) -> "ScoreResp": s = cls.model_validate(score.model_dump()) assert score.id - s.beatmap = BeatmapResp.from_db(score.beatmap) - s.beatmapset = BeatmapsetResp.from_db(score.beatmap.beatmapset) + await score.awaitable_attrs.beatmap + s.beatmap = await BeatmapResp.from_db(score.beatmap) + s.beatmapset = await BeatmapsetResp.from_db(score.beatmap.beatmapset) s.is_perfect_combo = s.max_combo == s.beatmap.max_combo s.legacy_perfect = s.max_combo == s.beatmap.max_combo s.ruleset_id = MODE_TO_INT[score.gamemode] @@ -303,7 +279,6 @@ async def get_leaderboard( query = ( select(Score) .join(Beatmap) - .options(joinedload(Score.user)) # pyright: ignore[reportArgumentType] .where( Score.map_md5 == beatmap_md5, Score.gamemode == mode, @@ -452,7 +427,7 @@ async def get_user_best_score_in_beatmap( ) -> Score | None: return ( await session.exec( - Score.select_clause(False) + select(Score) .where( Score.gamemode == mode if mode is not None else True, Score.beatmap_id == beatmap, diff --git a/app/database/team.py b/app/database/team.py index 146ca9f..562b0c8 100644 --- a/app/database/team.py +++ b/app/database/team.py @@ -34,5 +34,9 @@ class TeamMember(SQLModel, UTCBaseModel, table=True): default_factory=datetime.utcnow, sa_column=Column(DateTime) ) - user: "User" = Relationship(back_populates="team_membership") - team: "Team" = Relationship(back_populates="members") + user: "User" = Relationship( + back_populates="team_membership", sa_relationship_kwargs={"lazy": "joined"} + ) + team: "Team" = Relationship( + back_populates="members", sa_relationship_kwargs={"lazy": "joined"} + ) diff --git a/app/dependencies/user.py b/app/dependencies/user.py index 769247c..5537f4f 100644 --- a/app/dependencies/user.py +++ b/app/dependencies/user.py @@ -30,11 +30,5 @@ async def get_current_user_by_token(token: str, db: AsyncSession) -> User | None token_record = await get_token_by_access_token(db, token) if not token_record: return None - user = ( - await db.exec( - select(User) - .options(*User.all_select_option()) - .where(User.id == token_record.user_id) - ) - ).first() + user = (await db.exec(select(User).where(User.id == token_record.user_id))).first() return user diff --git a/app/models/beatmap.py b/app/models/beatmap.py index 4f12e13..fae18ba 100644 --- a/app/models/beatmap.py +++ b/app/models/beatmap.py @@ -42,11 +42,12 @@ class Language(IntEnum): KOREAN = 6 FRENCH = 7 GERMAN = 8 - ITALIAN = 9 - SPANISH = 10 - RUSSIAN = 11 - POLISH = 12 - OTHER = 13 + SWEDISH = 9 + ITALIAN = 10 + SPANISH = 11 + RUSSIAN = 12 + POLISH = 13 + OTHER = 14 class BeatmapAttributes(BaseModel): diff --git a/app/router/beatmap.py b/app/router/beatmap.py index 4af9c9a..df5f20d 100644 --- a/app/router/beatmap.py +++ b/app/router/beatmap.py @@ -5,7 +5,7 @@ import hashlib import json from app.calculator import calculate_beatmap_attribute -from app.database import Beatmap, BeatmapResp, Beatmapset, User +from app.database import Beatmap, BeatmapResp, User from app.dependencies.database import get_db, get_redis from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user @@ -24,7 +24,6 @@ from httpx import HTTPError, HTTPStatusError from pydantic import BaseModel from redis import Redis import rosu_pp_py as rosu -from sqlalchemy.orm import joinedload from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -51,7 +50,7 @@ async def lookup_beatmap( if beatmap is None: raise HTTPException(status_code=404, detail="Beatmap not found") - return BeatmapResp.from_db(beatmap) + return await BeatmapResp.from_db(beatmap) @router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp) @@ -63,7 +62,7 @@ async def get_beatmap( ): try: beatmap = await Beatmap.get_or_fetch(db, fetcher, bid) - return BeatmapResp.from_db(beatmap) + return await BeatmapResp.from_db(beatmap) except HTTPError: raise HTTPException(status_code=404, detail="Beatmap not found") @@ -83,35 +82,15 @@ async def batch_get_beatmaps( # select 50 beatmaps by last_updated beatmaps = ( await db.exec( - select(Beatmap) - .options( - joinedload( - Beatmap.beatmapset # pyright: ignore[reportArgumentType] - ).selectinload( - Beatmapset.beatmaps # pyright: ignore[reportArgumentType] - ) - ) - .order_by(col(Beatmap.last_updated).desc()) - .limit(50) + select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50) ) ).all() else: beatmaps = ( - await db.exec( - select(Beatmap) - .options( - joinedload( - Beatmap.beatmapset # pyright: ignore[reportArgumentType] - ).selectinload( - Beatmapset.beatmaps # pyright: ignore[reportArgumentType] - ) - ) - .where(col(Beatmap.id).in_(b_ids)) - .limit(50) - ) + await db.exec(select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50)) ).all() - return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps]) + return BatchGetResp(beatmaps=[await BeatmapResp.from_db(bm) for bm in beatmaps]) @router.post( diff --git a/app/router/beatmapset.py b/app/router/beatmapset.py index 80396fd..b82678d 100644 --- a/app/router/beatmapset.py +++ b/app/router/beatmapset.py @@ -15,7 +15,6 @@ from .api_router import router from fastapi import Depends, HTTPException, Query from fastapi.responses import RedirectResponse from httpx import HTTPStatusError -from sqlalchemy.orm import selectinload from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -27,13 +26,7 @@ async def get_beatmapset( db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), ): - beatmapset = ( - await db.exec( - select(Beatmapset) - .options(selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType] - .where(Beatmapset.id == sid) - ) - ).first() + beatmapset = (await db.exec(select(Beatmapset).where(Beatmapset.id == sid))).first() if not beatmapset: try: resp = await fetcher.get_beatmapset(sid) @@ -41,7 +34,7 @@ async def get_beatmapset( except HTTPStatusError: raise HTTPException(status_code=404, detail="Beatmapset not found") else: - resp = BeatmapsetResp.from_db(beatmapset) + resp = await BeatmapsetResp.from_db(beatmapset) return resp diff --git a/app/router/relationship.py b/app/router/relationship.py index 9e39e8b..02292c9 100644 --- a/app/router/relationship.py +++ b/app/router/relationship.py @@ -9,7 +9,6 @@ from .api_router import router from fastapi import Depends, HTTPException, Query, Request from pydantic import BaseModel -from sqlalchemy.orm import joinedload from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -27,14 +26,12 @@ async def get_relationship( else RelationshipType.BLOCK ) relationships = await db.exec( - select(Relationship) - .options(joinedload(Relationship.target).options(*DBUser.all_select_option())) # pyright: ignore[reportArgumentType] - .where( + select(Relationship).where( Relationship.user_id == current_user.id, Relationship.type == relationship_type, ) ) - return [await RelationshipResp.from_db(db, rel) for rel in relationships] + return [await RelationshipResp.from_db(db, rel) for rel in relationships.unique()] class AddFriendResp(BaseModel): @@ -92,14 +89,10 @@ async def add_relationship( if origin_type == RelationshipType.FOLLOW: relationship = ( await db.exec( - select(Relationship) - .where( + select(Relationship).where( Relationship.user_id == current_user_id, Relationship.target_id == target, ) - .options( - joinedload(Relationship.target).options(*DBUser.all_select_option()) # pyright: ignore[reportArgumentType] - ) ) ).first() assert relationship, "Relationship should exist after commit" diff --git a/app/router/score.py b/app/router/score.py index baab3a2..cd0a236 100644 --- a/app/router/score.py +++ b/app/router/score.py @@ -99,7 +99,7 @@ async def get_user_beatmap_score( ) user_score = ( await db.exec( - Score.select_clause(True) + select(Score) .where( Score.gamemode == mode if mode is not None else True, Score.beatmap_id == beatmap, @@ -139,7 +139,7 @@ async def get_user_all_beatmap_scores( ) all_user_scores = ( await db.exec( - Score.select_clause() + select(Score) .where( Score.gamemode == ruleset if ruleset is not None else True, Score.beatmap_id == beatmap, @@ -207,9 +207,7 @@ async def submit_solo_score( if score_token.score_id: score = ( await db.exec( - select(Score) - .options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType] - .where( + select(Score).where( Score.id == score_token.score_id, Score.user_id == current_user.id, ) @@ -243,8 +241,6 @@ async def submit_solo_score( score_id = score.id score_token.score_id = score_id await process_user(db, current_user, score, ranked) - score = ( - await db.exec(Score.select_clause().where(Score.id == score_id)) - ).first() + score = (await db.exec(select(Score).where(Score.id == score_id))).first() assert score is not None return await ScoreResp.from_db(db, score, current_user) diff --git a/app/router/user.py b/app/router/user.py index 3df5a49..649f1d4 100644 --- a/app/router/user.py +++ b/app/router/user.py @@ -28,19 +28,10 @@ async def get_users( ): if user_ids: searched_users = ( - await session.exec( - select(User) - .options(*User.all_select_option()) - .limit(50) - .where(col(User.id).in_(user_ids)) - ) + await session.exec(select(User).limit(50).where(col(User.id).in_(user_ids))) ).all() else: - searched_users = ( - await session.exec( - select(User).options(*User.all_select_option()).limit(50) - ) - ).all() + searched_users = (await session.exec(select(User).limit(50))).all() return BatchUserResponse( users=[ await UserResp.from_db( @@ -63,9 +54,7 @@ async def get_user_info( ): searched_user = ( await session.exec( - select(User) - .options(*User.all_select_option()) - .where( + select(User).where( User.id == int(user) if user.isdigit() else User.username == user.removeprefix("@") From 1635641654860591f069a492f0b4d0b0cd23009b Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Thu, 31 Jul 2025 14:11:42 +0000 Subject: [PATCH 2/6] feat(score): support leaderboard for country/friends/team/selected mods --- app/database/__init__.py | 3 +- app/database/best_score.py | 23 +- app/database/pp_best_score.py | 41 ++++ app/database/score.py | 382 +++++++++++++++++++--------------- app/models/score.py | 2 +- app/router/score.py | 49 ++--- 6 files changed, 284 insertions(+), 216 deletions(-) create mode 100644 app/database/pp_best_score.py diff --git a/app/database/__init__.py b/app/database/__init__.py index 91bc7cc..12fa867 100644 --- a/app/database/__init__.py +++ b/app/database/__init__.py @@ -14,6 +14,7 @@ from .lazer_user import ( User, UserResp, ) +from .pp_best_score import PPBestScore from .relationship import Relationship, RelationshipResp, RelationshipType from .score import ( Score, @@ -35,13 +36,13 @@ from .user_account_history import ( __all__ = [ "Beatmap", - "BeatmapResp", "Beatmapset", "BeatmapsetResp", "BestScore", "DailyChallengeStats", "DailyChallengeStatsResp", "OAuthToken", + "PPBestScore", "Relationship", "RelationshipResp", "RelationshipType", diff --git a/app/database/best_score.py b/app/database/best_score.py index 9993b63..42b0024 100644 --- a/app/database/best_score.py +++ b/app/database/best_score.py @@ -1,14 +1,14 @@ from typing import TYPE_CHECKING -from app.models.score import GameMode +from app.models.score import GameMode, Rank from .lazer_user import User from sqlmodel import ( + JSON, BigInteger, Column, Field, - Float, ForeignKey, Relationship, SQLModel, @@ -20,7 +20,7 @@ if TYPE_CHECKING: class BestScore(SQLModel, table=True): - __tablename__ = "best_scores" # pyright: ignore[reportAssignmentType] + __tablename__ = "total_score_best_scores" # pyright: ignore[reportAssignmentType] user_id: int = Field( sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) ) @@ -29,13 +29,20 @@ class BestScore(SQLModel, table=True): ) beatmap_id: int = Field(foreign_key="beatmaps.id", index=True) gamemode: GameMode = Field(index=True) - pp: float = Field( - sa_column=Column(Float, default=0), + total_score: int = Field( + default=0, sa_column=Column(BigInteger, ForeignKey("scores.total_score")) ) - acc: float = Field( - sa_column=Column(Float, default=0), + mods: list[str] = Field( + default_factory=list, + sa_column=Column(JSON), ) + rank: Rank user: User = Relationship() - score: "Score" = Relationship() + score: "Score" = Relationship( + sa_relationship_kwargs={ + "foreign_keys": "[BestScore.score_id]", + "lazy": "joined", + } + ) beatmap: "Beatmap" = Relationship() diff --git a/app/database/pp_best_score.py b/app/database/pp_best_score.py new file mode 100644 index 0000000..ffc74d3 --- /dev/null +++ b/app/database/pp_best_score.py @@ -0,0 +1,41 @@ +from typing import TYPE_CHECKING + +from app.models.score import GameMode + +from .lazer_user import User + +from sqlmodel import ( + BigInteger, + Column, + Field, + Float, + ForeignKey, + Relationship, + SQLModel, +) + +if TYPE_CHECKING: + from .beatmap import Beatmap + from .score import Score + + +class PPBestScore(SQLModel, table=True): + __tablename__ = "best_scores" # pyright: ignore[reportAssignmentType] + user_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) + ) + score_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True) + ) + beatmap_id: int = Field(foreign_key="beatmaps.id", index=True) + gamemode: GameMode = Field(index=True) + pp: float = Field( + sa_column=Column(Float, default=0), + ) + acc: float = Field( + sa_column=Column(Float, default=0), + ) + + user: User = Relationship() + score: "Score" = Relationship() + beatmap: "Beatmap" = Relationship() diff --git a/app/database/score.py b/app/database/score.py index 4cee832..c5f1a38 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -1,6 +1,7 @@ import asyncio from collections.abc import Sequence from datetime import UTC, date, datetime +import json import math from typing import TYPE_CHECKING @@ -12,7 +13,7 @@ from app.calculator import ( calculate_weighted_pp, clamp, ) -from app.models.beatmap import BeatmapRankStatus +from app.database.team import TeamMember from app.models.model import UTCBaseModel from app.models.mods import APIMod, mods_can_get_pp from app.models.score import ( @@ -31,12 +32,18 @@ from .beatmapset import BeatmapsetResp from .best_score import BestScore from .lazer_user import User, UserResp from .monthly_playcounts import MonthlyPlaycounts +from .pp_best_score import PPBestScore +from .relationship import ( + Relationship as DBRelationship, + RelationshipType, +) from .score_token import ScoreToken from redis import Redis from sqlalchemy import Column, ColumnExpressionArgument, DateTime from sqlalchemy.ext.asyncio import AsyncAttrs from sqlalchemy.orm import aliased +from sqlalchemy.sql.elements import ColumnElement from sqlmodel import ( JSON, BigInteger, @@ -45,9 +52,10 @@ from sqlmodel import ( Relationship, SQLModel, col, - false, func, select, + text, + true, ) from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.sql._expression_select_cls import SelectOfScalar @@ -156,9 +164,7 @@ class ScoreResp(ScoreBase): rank_country: int | None = None @classmethod - async def from_db( - cls, session: AsyncSession, score: Score, user: User | None = None - ) -> "ScoreResp": + async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp": s = cls.model_validate(score.model_dump()) assert score.id await score.awaitable_attrs.beatmap @@ -195,30 +201,30 @@ class ScoreResp(ScoreBase): s.maximum_statistics = { HitResult.GREAT: score.beatmap.max_combo, } - if user: - s.user = await UserResp.from_db( - user, - session, - include=["statistics", "team", "daily_challenge_user_stats"], - ruleset=score.gamemode, - ) + s.user = await UserResp.from_db( + score.user, + session, + include=["statistics", "team", "daily_challenge_user_stats"], + ruleset=score.gamemode, + ) s.rank_global = ( await get_score_position_by_id( session, - score.map_md5, + score.beatmap_id, score.id, mode=score.gamemode, - user=user or score.user, + user=score.user, ) or None ) s.rank_country = ( await get_score_position_by_id( session, - score.map_md5, + score.beatmap_id, score.id, score.gamemode, - user or score.user, + score.user, + type=LeaderboardType.COUNTRY, ) or None ) @@ -228,134 +234,137 @@ class ScoreResp(ScoreBase): async def get_best_id(session: AsyncSession, score_id: int) -> None: rownum = ( func.row_number() - .over(partition_by=col(BestScore.user_id), order_by=col(BestScore.pp).desc()) + .over( + partition_by=col(PPBestScore.user_id), order_by=col(PPBestScore.pp).desc() + ) .label("rn") ) - subq = select(BestScore, rownum).subquery() + subq = select(PPBestScore, rownum).subquery() stmt = select(subq.c.rn).where(subq.c.score_id == score_id) result = await session.exec(stmt) return result.one_or_none() +async def _score_where( + type: LeaderboardType, + beatmap: int, + mode: GameMode, + mods: list[str] | None = None, + user: User | None = None, +) -> list[ColumnElement[bool]] | None: + wheres = [ + col(BestScore.beatmap_id) == beatmap, + col(BestScore.gamemode) == mode, + ] + + if type == LeaderboardType.FRIENDS: + if user and user.is_supporter: + subq = ( + select(DBRelationship.target_id) + .where( + DBRelationship.type == RelationshipType.FOLLOW, + DBRelationship.user_id == user.id, + ) + .subquery() + ) + wheres.append(col(BestScore.user_id).in_(select(subq.c.target_id))) + else: + return None + elif type == LeaderboardType.COUNTRY: + if user and user.is_supporter: + wheres.append( + col(BestScore.user).has(col(User.country_code) == user.country_code) + ) + else: + return None + elif type == LeaderboardType.TEAM: + if user: + team_membership = await user.awaitable_attrs.team_membership + if team_membership: + team_id = team_membership.team_id + wheres.append( + col(BestScore.user).has( + col(User.team_membership).has(TeamMember.team_id == team_id) + ) + ) + if mods: + if user and user.is_supporter: + wheres.append( + text( + "JSON_CONTAINS(total_score_best_scores.mods, :w)" + " AND JSON_CONTAINS(:w, total_score_best_scores.mods)" + ) # pyright: ignore[reportArgumentType] + ) + else: + return None + return wheres + + async def get_leaderboard( session: AsyncSession, - beatmap_md5: str, + beatmap: int, mode: GameMode, type: LeaderboardType = LeaderboardType.GLOBAL, - mods: list[APIMod] | None = None, + mods: list[str] | None = None, user: User | None = None, limit: int = 50, -) -> list[Score]: - scores = [] - if type == LeaderboardType.GLOBAL: - query = ( - select(Score) - .where( - col(Beatmap.beatmap_status).in_( - [ - BeatmapRankStatus.RANKED, - BeatmapRankStatus.LOVED, - BeatmapRankStatus.QUALIFIED, - BeatmapRankStatus.APPROVED, - ] - ), - Score.map_md5 == beatmap_md5, - Score.gamemode == mode, - col(Score.passed).is_(True), - Score.mods == mods if user and user.is_supporter else false(), - ) - .limit(limit) - .order_by( - col(Score.total_score).desc(), - ) - ) - result = await session.exec(query) - scores = list[Score](result.all()) - elif type == LeaderboardType.FRIENDS and user and user.is_supporter: - # TODO - ... - elif type == LeaderboardType.TEAM and user and user.team_membership: - team_id = user.team_membership.team_id - query = ( - select(Score) - .join(Beatmap) - .where( - Score.map_md5 == beatmap_md5, - Score.gamemode == mode, - col(Score.passed).is_(True), - col(Score.user.team_membership).is_not(None), - Score.user.team_membership.team_id == team_id, # pyright: ignore[reportOptionalMemberAccess] - Score.mods == mods if user and user.is_supporter else false(), - ) - .limit(limit) - .order_by( - col(Score.total_score).desc(), - ) - ) - result = await session.exec(query) - scores = list[Score](result.all()) +) -> tuple[list[Score], Score | None]: + wheres = await _score_where(type, beatmap, mode, mods, user) + if wheres is None: + return [], None + query = ( + select(BestScore) + .where(*wheres) + .limit(limit) + .order_by(col(BestScore.total_score).desc()) + ) + if mods: + query = query.params(w=json.dumps(mods)) + scores = [s.score for s in await session.exec(query)] + user_score = None if user: - user_score = ( - await session.exec( - select(Score).where( - Score.map_md5 == beatmap_md5, - Score.gamemode == mode, - Score.user_id == user.id, - col(Score.passed).is_(True), + self_query = ( + select(BestScore) + .where(BestScore.user_id == user.id) + .order_by(col(BestScore.total_score).desc()) + .limit(1) + ) + if mods: + self_query = self_query.where( + text( + "JSON_CONTAINS(total_score_best_scores.mods, :w)" + " AND JSON_CONTAINS(:w, total_score_best_scores.mods)" ) - ) - ).first() + ).params(w=json.dumps(mods)) + user_bs = (await session.exec(self_query)).first() + if user_bs: + user_score = user_bs.score if user_score and user_score not in scores: scores.append(user_score) - return scores + return scores, user_score async def get_score_position_by_user( session: AsyncSession, - beatmap_md5: str, + beatmap: int, user: User, mode: GameMode, type: LeaderboardType = LeaderboardType.GLOBAL, - mods: list[APIMod] | None = None, + mods: list[str] | None = None, ) -> int: - where_clause = [ - Score.map_md5 == beatmap_md5, - Score.gamemode == mode, - col(Score.passed).is_(True), - col(Beatmap.beatmap_status).in_( - [ - BeatmapRankStatus.RANKED, - BeatmapRankStatus.LOVED, - BeatmapRankStatus.QUALIFIED, - BeatmapRankStatus.APPROVED, - ] - ), - ] - if mods and user.is_supporter: - where_clause.append(Score.mods == mods) - else: - where_clause.append(false()) - if type == LeaderboardType.FRIENDS and user.is_supporter: - # TODO - ... - elif type == LeaderboardType.TEAM and user.team_membership: - team_id = user.team_membership.team_id - where_clause.append( - col(Score.user.team_membership).is_not(None), - ) - where_clause.append( - Score.user.team_membership.team_id == team_id, # pyright: ignore[reportOptionalMemberAccess] - ) + wheres = await _score_where(type, beatmap, mode, mods, user=user) + if wheres is None: + return 0 rownum = ( func.row_number() .over( - partition_by=Score.map_md5, - order_by=col(Score.total_score).desc(), + partition_by=col(BestScore.beatmap_id), + order_by=col(BestScore.total_score).desc(), ) .label("row_number") ) - subq = select(Score, rownum).join(Beatmap).where(*where_clause).subquery() - stmt = select(subq.c.row_number).where(subq.c.user == user) + subq = select(BestScore, rownum).join(Beatmap).where(*wheres).subquery() + stmt = select(subq.c.row_number).where(subq.c.user_id == user.id) result = await session.exec(stmt) s = result.one_or_none() return s if s else 0 @@ -363,57 +372,26 @@ async def get_score_position_by_user( async def get_score_position_by_id( session: AsyncSession, - beatmap_md5: str, + beatmap: int, score_id: int, mode: GameMode, user: User | None = None, type: LeaderboardType = LeaderboardType.GLOBAL, - mods: list[APIMod] | None = None, + mods: list[str] | None = None, ) -> int: - where_clause = [ - Score.map_md5 == beatmap_md5, - Score.id == score_id, - Score.gamemode == mode, - col(Score.passed).is_(True), - col(Beatmap.beatmap_status).in_( - [ - BeatmapRankStatus.RANKED, - BeatmapRankStatus.LOVED, - BeatmapRankStatus.QUALIFIED, - BeatmapRankStatus.APPROVED, - ] - ), - ] - if mods and user and user.is_supporter: - where_clause.append(Score.mods == mods) - elif mods: - where_clause.append(false()) + wheres = await _score_where(type, beatmap, mode, mods, user=user) + if wheres is None: + return 0 rownum = ( func.row_number() .over( - partition_by=[col(Score.user_id), col(Score.map_md5)], - order_by=col(Score.total_score).desc(), + partition_by=col(BestScore.beatmap_id), + order_by=col(BestScore.total_score).desc(), ) - .label("rownum") + .label("row_number") ) - subq = ( - select(Score.user_id, Score.id, Score.total_score, rownum) - .join(Beatmap) - .where(*where_clause) - .subquery() - ) - best_scores = aliased(subq) - overall_rank = ( - func.rank().over(order_by=best_scores.c.total_score.desc()).label("global_rank") - ) - final_q = ( - select(best_scores.c.id, overall_rank) - .select_from(best_scores) - .where(best_scores.c.rownum == 1) - .subquery() - ) - - stmt = select(final_q.c.global_rank).where(final_q.c.id == score_id) + subq = select(BestScore, rownum).join(Beatmap).where(*wheres).subquery() + stmt = select(subq.c.row_number).where(subq.c.score_id == score_id) result = await session.exec(stmt) s = result.one_or_none() return s if s else 0 @@ -424,16 +402,38 @@ async def get_user_best_score_in_beatmap( beatmap: int, user: int, mode: GameMode | None = None, -) -> Score | None: +) -> BestScore | None: return ( await session.exec( - select(Score) + select(BestScore) .where( - Score.gamemode == mode if mode is not None else True, - Score.beatmap_id == beatmap, - Score.user_id == user, + BestScore.gamemode == mode if mode is not None else true(), + BestScore.beatmap_id == beatmap, + BestScore.user_id == user, ) - .order_by(col(Score.total_score).desc()) + .order_by(col(BestScore.total_score).desc()) + ) + ).first() + + +# FIXME +async def get_user_best_score_with_mod_in_beatmap( + session: AsyncSession, + beatmap: int, + user: int, + mod: list[str], + mode: GameMode | None = None, +) -> BestScore | None: + return ( + await session.exec( + select(BestScore) + .where( + BestScore.gamemode == mode if mode is not None else True, + BestScore.beatmap_id == beatmap, + BestScore.user_id == user, + # BestScore.mods == mod, + ) + .order_by(col(BestScore.total_score).desc()) ) ).first() @@ -443,13 +443,13 @@ async def get_user_best_pp_in_beatmap( beatmap: int, user: int, mode: GameMode, -) -> BestScore | None: +) -> PPBestScore | None: return ( await session.exec( - select(BestScore).where( - BestScore.beatmap_id == beatmap, - BestScore.user_id == user, - BestScore.gamemode == mode, + select(PPBestScore).where( + PPBestScore.beatmap_id == beatmap, + PPBestScore.user_id == user, + PPBestScore.gamemode == mode, ) ) ).first() @@ -459,12 +459,12 @@ async def get_user_best_pp( session: AsyncSession, user: int, limit: int = 200, -) -> Sequence[BestScore]: +) -> Sequence[PPBestScore]: return ( await session.exec( - select(BestScore) - .where(BestScore.user_id == user) - .order_by(col(BestScore.pp).desc()) + select(PPBestScore) + .where(PPBestScore.user_id == user) + .order_by(col(PPBestScore.pp).desc()) .limit(limit) ) ).all() @@ -474,9 +474,15 @@ async def process_user( session: AsyncSession, user: User, score: Score, ranked: bool = False ): assert user.id + assert score.id + mod_for_save = list({mod["acronym"] for mod in score.mods}) previous_score_best = await get_user_best_score_in_beatmap( session, score.beatmap_id, user.id, score.gamemode ) + previous_score_best_mod = await get_user_best_score_with_mod_in_beatmap( + session, score.beatmap_id, user.id, mod_for_save, score.gamemode + ) + print(previous_score_best, previous_score_best_mod) add_to_db = False mouthly_playcount = ( await session.exec( @@ -493,7 +499,7 @@ async def process_user( ) add_to_db = True statistics = None - for i in user.statistics: + for i in await user.awaitable_attrs.statistics: if i.mode == score.gamemode.value: statistics = i break @@ -506,7 +512,7 @@ async def process_user( statistics.total_score += score.total_score difference = ( score.total_score - previous_score_best.total_score - if previous_score_best and previous_score_best.id != score.id + if previous_score_best else score.total_score ) if difference > 0 and score.passed and ranked: @@ -533,9 +539,41 @@ async def process_user( statistics.grade_sh -= 1 case Rank.A: statistics.grade_a -= 1 + else: + previous_score_best = BestScore( + user_id=user.id, + beatmap_id=score.beatmap_id, + gamemode=score.gamemode, + score_id=score.id, + total_score=score.total_score, + rank=score.rank, + mods=mod_for_save, + ) + session.add(previous_score_best) + statistics.ranked_score += difference statistics.level_current = calculate_score_to_level(statistics.ranked_score) statistics.maximum_combo = max(statistics.maximum_combo, score.max_combo) + if score.passed and ranked: + if previous_score_best_mod is not None: + previous_score_best_mod.mods = mod_for_save + previous_score_best_mod.score_id = score.id + previous_score_best_mod.rank = score.rank + previous_score_best_mod.total_score = score.total_score + elif ( + previous_score_best is not None and previous_score_best.score_id != score.id + ): + session.add( + BestScore( + user_id=user.id, + beatmap_id=score.beatmap_id, + gamemode=score.gamemode, + score_id=score.id, + total_score=score.total_score, + rank=score.rank, + mods=mod_for_save, + ) + ) statistics.play_count += 1 mouthly_playcount.playcount += 1 statistics.play_time += int((score.ended_at - score.started_at).total_seconds()) @@ -623,7 +661,7 @@ async def process_score( ) if previous_pp_best is None or score.pp > previous_pp_best.pp: assert score.id - best_score = BestScore( + best_score = PPBestScore( user_id=user_id, score_id=score.id, beatmap_id=beatmap_id, diff --git a/app/models/score.py b/app/models/score.py index b613ae2..bfc9f53 100644 --- a/app/models/score.py +++ b/app/models/score.py @@ -132,7 +132,7 @@ class HitResultInt(IntEnum): class LeaderboardType(Enum): GLOBAL = "global" - FRIENDS = "friends" + FRIENDS = "friend" COUNTRY = "country" TEAM = "team" diff --git a/app/router/score.py b/app/router/score.py index cd0a236..6c6a475 100644 --- a/app/router/score.py +++ b/app/router/score.py @@ -1,7 +1,7 @@ from __future__ import annotations from app.database import Beatmap, Score, ScoreResp, ScoreToken, ScoreTokenResp, User -from app.database.score import process_score, process_user +from app.database.score import get_leaderboard, process_score, process_user from app.dependencies.database import get_db, get_redis from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user @@ -9,6 +9,7 @@ from app.models.beatmap import BeatmapRankStatus from app.models.score import ( INT_TO_MODE, GameMode, + LeaderboardType, Rank, SoloScoreSubmissionInfo, ) @@ -19,7 +20,7 @@ from fastapi import Depends, Form, HTTPException, Query from pydantic import BaseModel from redis import Redis from sqlalchemy.orm import joinedload -from sqlmodel import col, select, true +from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -33,44 +34,26 @@ class BeatmapScores(BaseModel): ) async def get_beatmap_scores( beatmap: int, + mode: GameMode, legacy_only: bool = Query(None), # TODO:加入对这个参数的查询 - mode: GameMode | None = Query(None), - # mods: List[APIMod] = Query(None), # TODO:加入指定MOD的查询 - type: str = Query(None), + mods: list[str] = Query(default_factory=set, alias="mods[]"), + type: LeaderboardType = Query(LeaderboardType.GLOBAL), current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), + limit: int = Query(50, ge=1, le=200), ): if legacy_only: raise HTTPException( status_code=404, detail="this server only contains lazer scores" ) - all_scores = ( - await db.exec( - Score.select_clause_unique( - Score.beatmap_id == beatmap, - col(Score.passed).is_(True), - Score.gamemode == mode if mode is not None else true(), - ) - ) - ).all() - - user_score = ( - await db.exec( - Score.select_clause_unique( - Score.beatmap_id == beatmap, - Score.user_id == current_user.id, - col(Score.passed).is_(True), - Score.gamemode == mode if mode is not None else true(), - ) - ) - ).first() + all_scores, user_score = await get_leaderboard( + db, beatmap, mode, type=type, user=current_user, limit=limit, mods=mods + ) return BeatmapScores( - scores=[await ScoreResp.from_db(db, score, score.user) for score in all_scores], - userScore=await ScoreResp.from_db(db, user_score, user_score.user) - if user_score - else None, + scores=[await ScoreResp.from_db(db, score) for score in all_scores], + userScore=await ScoreResp.from_db(db, user_score) if user_score else None, ) @@ -116,7 +99,7 @@ async def get_user_beatmap_score( else: return BeatmapUserScore( position=user_score.position if user_score.position is not None else 0, - score=await ScoreResp.from_db(db, user_score, user_score.user), + score=await ScoreResp.from_db(db, user_score), ) @@ -149,9 +132,7 @@ async def get_user_all_beatmap_scores( ) ).all() - return [ - await ScoreResp.from_db(db, score, current_user) for score in all_user_scores - ] + return [await ScoreResp.from_db(db, score) for score in all_user_scores] @router.post( @@ -243,4 +224,4 @@ async def submit_solo_score( await process_user(db, current_user, score, ranked) score = (await db.exec(select(Score).where(Score.id == score_id))).first() assert score is not None - return await ScoreResp.from_db(db, score, current_user) + return await ScoreResp.from_db(db, score) From c5fc6afc189fbe665801412c1cff9cc7a308ccd6 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Thu, 31 Jul 2025 14:38:10 +0000 Subject: [PATCH 3/6] feat(redis): use asyncio --- app/database/score.py | 2 +- app/dependencies/database.py | 11 ++-------- app/dependencies/fetcher.py | 23 ++++++++++----------- app/fetcher/_base.py | 38 +++++++++++++++++------------------ app/fetcher/osu_dot_direct.py | 6 +++--- app/router/beatmap.py | 8 ++++---- app/router/room.py | 22 +++++++++----------- app/router/score.py | 2 +- main.py | 5 +++-- 9 files changed, 53 insertions(+), 64 deletions(-) diff --git a/app/database/score.py b/app/database/score.py index c5f1a38..642eac1 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -39,7 +39,7 @@ from .relationship import ( ) from .score_token import ScoreToken -from redis import Redis +from redis.asyncio import Redis from sqlalchemy import Column, ColumnExpressionArgument, DateTime from sqlalchemy.ext.asyncio import AsyncAttrs from sqlalchemy.orm import aliased diff --git a/app/dependencies/database.py b/app/dependencies/database.py index fe09139..77b15c3 100644 --- a/app/dependencies/database.py +++ b/app/dependencies/database.py @@ -5,15 +5,11 @@ import json from app.config import settings from pydantic import BaseModel +import redis.asyncio as redis from sqlalchemy.ext.asyncio import create_async_engine from sqlmodel import SQLModel from sqlmodel.ext.asyncio.session import AsyncSession -try: - import redis -except ImportError: - redis = None - def json_serializer(value): if isinstance(value, BaseModel | SQLModel): @@ -25,10 +21,7 @@ def json_serializer(value): engine = create_async_engine(settings.DATABASE_URL, json_serializer=json_serializer) # Redis 连接 -if redis: - redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True) -else: - redis_client = None +redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True) # 数据库依赖 diff --git a/app/dependencies/fetcher.py b/app/dependencies/fetcher.py index d3c216a..806eb87 100644 --- a/app/dependencies/fetcher.py +++ b/app/dependencies/fetcher.py @@ -8,7 +8,7 @@ from app.log import logger fetcher: Fetcher | None = None -def get_fetcher() -> Fetcher: +async def get_fetcher() -> Fetcher: global fetcher if fetcher is None: fetcher = Fetcher( @@ -18,15 +18,14 @@ def get_fetcher() -> Fetcher: settings.FETCHER_CALLBACK_URL, ) redis = get_redis() - if redis: - access_token = redis.get(f"fetcher:access_token:{fetcher.client_id}") - if access_token: - fetcher.access_token = str(access_token) - refresh_token = redis.get(f"fetcher:refresh_token:{fetcher.client_id}") - if refresh_token: - fetcher.refresh_token = str(refresh_token) - if not fetcher.access_token or not fetcher.refresh_token: - logger.opt(colors=True).info( - f"Login to initialize fetcher: {fetcher.authorize_url}" - ) + access_token = await redis.get(f"fetcher:access_token:{fetcher.client_id}") + if access_token: + fetcher.access_token = str(access_token) + refresh_token = await redis.get(f"fetcher:refresh_token:{fetcher.client_id}") + if refresh_token: + fetcher.refresh_token = str(refresh_token) + if not fetcher.access_token or not fetcher.refresh_token: + logger.opt(colors=True).info( + f"Login to initialize fetcher: {fetcher.authorize_url}" + ) return fetcher diff --git a/app/fetcher/_base.py b/app/fetcher/_base.py index 08e3508..2717a35 100644 --- a/app/fetcher/_base.py +++ b/app/fetcher/_base.py @@ -59,16 +59,15 @@ class BaseFetcher: self.refresh_token = token_data.get("refresh_token", "") self.token_expiry = int(time.time()) + token_data["expires_in"] redis = get_redis() - if redis: - redis.set( - f"fetcher:access_token:{self.client_id}", - self.access_token, - ex=token_data["expires_in"], - ) - redis.set( - f"fetcher:refresh_token:{self.client_id}", - self.refresh_token, - ) + await redis.set( + f"fetcher:access_token:{self.client_id}", + self.access_token, + ex=token_data["expires_in"], + ) + await redis.set( + f"fetcher:refresh_token:{self.client_id}", + self.refresh_token, + ) async def refresh_access_token(self) -> None: async with AsyncClient() as client: @@ -87,13 +86,12 @@ class BaseFetcher: self.refresh_token = token_data.get("refresh_token", "") self.token_expiry = int(time.time()) + token_data["expires_in"] redis = get_redis() - if redis: - redis.set( - f"fetcher:access_token:{self.client_id}", - self.access_token, - ex=token_data["expires_in"], - ) - redis.set( - f"fetcher:refresh_token:{self.client_id}", - self.refresh_token, - ) + await redis.set( + f"fetcher:access_token:{self.client_id}", + self.access_token, + ex=token_data["expires_in"], + ) + await redis.set( + f"fetcher:refresh_token:{self.client_id}", + self.refresh_token, + ) diff --git a/app/fetcher/osu_dot_direct.py b/app/fetcher/osu_dot_direct.py index 08b8dfc..cb3897f 100644 --- a/app/fetcher/osu_dot_direct.py +++ b/app/fetcher/osu_dot_direct.py @@ -4,7 +4,7 @@ from ._base import BaseFetcher from httpx import AsyncClient from loguru import logger -import redis +import redis.asyncio as redis class OsuDotDirectFetcher(BaseFetcher): @@ -23,7 +23,7 @@ class OsuDotDirectFetcher(BaseFetcher): self, redis: redis.Redis, beatmap_id: int ) -> str: if redis.exists(f"beatmap:{beatmap_id}:raw"): - return redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType] + return await redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType] raw = await self.get_beatmap_raw(beatmap_id) - redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24) + await redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24) return raw diff --git a/app/router/beatmap.py b/app/router/beatmap.py index df5f20d..0a25562 100644 --- a/app/router/beatmap.py +++ b/app/router/beatmap.py @@ -22,7 +22,7 @@ from .api_router import router from fastapi import Depends, HTTPException, Query from httpx import HTTPError, HTTPStatusError from pydantic import BaseModel -from redis import Redis +from redis.asyncio import Redis import rosu_pp_py as rosu from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -127,8 +127,8 @@ async def get_beatmap_attributes( f"beatmap:{beatmap}:{ruleset}:" f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes" ) - if redis.exists(key): - return BeatmapAttributes.model_validate_json(redis.get(key)) # pyright: ignore[reportArgumentType] + if await redis.exists(key): + return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType] try: resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap) @@ -138,7 +138,7 @@ async def get_beatmap_attributes( ) except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue] raise HTTPException(status_code=400, detail=str(e)) - redis.set(key, attr.model_dump_json()) + await redis.set(key, attr.model_dump_json()) return attr except HTTPStatusError: raise HTTPException(status_code=404, detail="Beatmap not found") diff --git a/app/router/room.py b/app/router/room.py index ed540fc..3a65617 100644 --- a/app/router/room.py +++ b/app/router/room.py @@ -6,7 +6,8 @@ from app.models.room import Room from .api_router import router -from fastapi import Depends, HTTPException, Query +from fastapi import Depends, Query +from redis.asyncio import Redis from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -19,17 +20,14 @@ async def get_all_rooms( status: str = Query(None), category: str = Query(None), db: AsyncSession = Depends(get_db), + redis: Redis = Depends(get_redis), ): all_room_ids = (await db.exec(select(RoomIndex).where(True))).all() - redis = get_redis() roomsList: list[Room] = [] - if redis: - for room_index in all_room_ids: - dumped_room = redis.get(str(room_index.id)) - if dumped_room: - actual_room = Room.model_validate_json(str(dumped_room)) - if actual_room.status == status and actual_room.category == category: - roomsList.append(actual_room) - return roomsList - else: - raise HTTPException(status_code=500, detail="Redis Error") + for room_index in all_room_ids: + dumped_room = await redis.get(str(room_index.id)) + if dumped_room: + actual_room = Room.model_validate_json(str(dumped_room)) + if actual_room.status == status and actual_room.category == category: + roomsList.append(actual_room) + return roomsList diff --git a/app/router/score.py b/app/router/score.py index 6c6a475..2f1303e 100644 --- a/app/router/score.py +++ b/app/router/score.py @@ -18,7 +18,7 @@ from .api_router import router from fastapi import Depends, Form, HTTPException, Query from pydantic import BaseModel -from redis import Redis +from redis.asyncio import Redis from sqlalchemy.orm import joinedload from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession diff --git a/main.py b/main.py index 92d4402..f5d20c1 100644 --- a/main.py +++ b/main.py @@ -4,7 +4,7 @@ from contextlib import asynccontextmanager from datetime import datetime from app.config import settings -from app.dependencies.database import create_tables, engine +from app.dependencies.database import create_tables, engine, redis_client from app.dependencies.fetcher import get_fetcher from app.router import api_router, auth_router, fetcher_router, signalr_router @@ -15,10 +15,11 @@ from fastapi import FastAPI async def lifespan(app: FastAPI): # on startup await create_tables() - get_fetcher() # 初始化 fetcher + await get_fetcher() # 初始化 fetcher # on shutdown yield await engine.dispose() + await redis_client.aclose() app = FastAPI(title="osu! API 模拟服务器", version="1.0.0", lifespan=lifespan) From 86a6fd1b69b962692ead1146ff6bc1addf1455ec Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Fri, 1 Aug 2025 02:49:49 +0000 Subject: [PATCH 4/6] feat(user): support `online` & `last_visit` --- app/database/lazer_user.py | 5 ++++- app/database/score.py | 1 - app/fetcher/osu_dot_direct.py | 2 +- app/signalr/hub/metadata.py | 18 +++++++++++++++++- 4 files changed, 22 insertions(+), 4 deletions(-) diff --git a/app/database/lazer_user.py b/app/database/lazer_user.py index d502ccb..1337cc2 100644 --- a/app/database/lazer_user.py +++ b/app/database/lazer_user.py @@ -1,6 +1,7 @@ from datetime import UTC, datetime from typing import TYPE_CHECKING, NotRequired, TypedDict +from app.dependencies.database import get_redis from app.models.model import UTCBaseModel from app.models.score import GameMode from app.models.user import Country, Page, RankHistory @@ -157,7 +158,7 @@ class User(AsyncAttrs, UserBase, table=True): class UserResp(UserBase): id: int | None = None - is_online: bool = True # TODO + is_online: bool = False groups: list = [] # TODO country: Country = Field(default_factory=lambda: Country(code="CN", name="China")) favourite_beatmapset_count: int = 0 # TODO @@ -225,6 +226,8 @@ class UserResp(UserBase): .limit(200) ) ).one() + redis = get_redis() + u.is_online = await redis.exists(f"metadata:online:{obj.id}") u.cover_url = ( obj.cover.get( "url", "https://assets.ppy.sh/user-profile-covers/default.jpeg" diff --git a/app/database/score.py b/app/database/score.py index 642eac1..32ddb6c 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -482,7 +482,6 @@ async def process_user( previous_score_best_mod = await get_user_best_score_with_mod_in_beatmap( session, score.beatmap_id, user.id, mod_for_save, score.gamemode ) - print(previous_score_best, previous_score_best_mod) add_to_db = False mouthly_playcount = ( await session.exec( diff --git a/app/fetcher/osu_dot_direct.py b/app/fetcher/osu_dot_direct.py index cb3897f..6e18435 100644 --- a/app/fetcher/osu_dot_direct.py +++ b/app/fetcher/osu_dot_direct.py @@ -22,7 +22,7 @@ class OsuDotDirectFetcher(BaseFetcher): async def get_or_fetch_beatmap_raw( self, redis: redis.Redis, beatmap_id: int ) -> str: - if redis.exists(f"beatmap:{beatmap_id}:raw"): + if await redis.exists(f"beatmap:{beatmap_id}:raw"): return await redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType] raw = await self.get_beatmap_raw(beatmap_id) await redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24) diff --git a/app/signalr/hub/metadata.py b/app/signalr/hub/metadata.py index 2712883..227cf7b 100644 --- a/app/signalr/hub/metadata.py +++ b/app/signalr/hub/metadata.py @@ -2,10 +2,12 @@ from __future__ import annotations import asyncio from collections.abc import Coroutine +from datetime import UTC, datetime from typing import override from app.database import Relationship, RelationshipType -from app.dependencies.database import engine +from app.database.lazer_user import User +from app.dependencies.database import engine, get_redis from app.models.metadata_hub import MetadataClientState, OnlineStatus, UserActivity from .hub import Client, Hub @@ -54,6 +56,18 @@ class MetadataHub(Hub[MetadataClientState]): async def _clean_state(self, state: MetadataClientState) -> None: if state.pushable: await asyncio.gather(*self.broadcast_tasks(int(state.connection_id), None)) + redis = get_redis() + if await redis.exists(f"metadata:online:{state.connection_id}"): + await redis.delete(f"metadata:online:{state.connection_id}") + async with AsyncSession(engine) as session: + async with session.begin(): + user = ( + await session.exec( + select(User).where(User.id == int(state.connection_id)) + ) + ).one() + user.last_visit = datetime.now(UTC) + await session.commit() @override def create_state(self, client: Client) -> MetadataClientState: @@ -93,6 +107,8 @@ class MetadataHub(Hub[MetadataClientState]): ) ) await asyncio.gather(*tasks) + redis = get_redis() + await redis.set(f"metadata:online:{user_id}", "") async def UpdateStatus(self, client: Client, status: int) -> None: status_ = OnlineStatus(status) From d938998239c0b445a01261c492137856f88a9683 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Fri, 1 Aug 2025 04:22:17 +0000 Subject: [PATCH 5/6] feat(beatmapset): support post favoutite to beatmapset --- app/database/__init__.py | 2 + app/database/beatmap.py | 8 +- app/database/beatmapset.py | 119 ++++++++++++------ app/database/favourite_beatmapset.py | 53 ++++++++ app/database/lazer_user.py | 15 ++- app/database/score.py | 6 +- app/router/beatmap.py | 11 +- app/router/beatmapset.py | 45 +++++-- ...8ebf_beatmapset_support_favourite_count.py | 40 ++++++ 9 files changed, 249 insertions(+), 50 deletions(-) create mode 100644 app/database/favourite_beatmapset.py create mode 100644 migrations/versions/1178d0758ebf_beatmapset_support_favourite_count.py diff --git a/app/database/__init__.py b/app/database/__init__.py index 12fa867..6e2e8c5 100644 --- a/app/database/__init__.py +++ b/app/database/__init__.py @@ -10,6 +10,7 @@ from .beatmapset import ( ) from .best_score import BestScore from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp +from .favourite_beatmapset import FavouriteBeatmapset from .lazer_user import ( User, UserResp, @@ -41,6 +42,7 @@ __all__ = [ "BestScore", "DailyChallengeStats", "DailyChallengeStatsResp", + "FavouriteBeatmapset", "OAuthToken", "PPBestScore", "Relationship", diff --git a/app/database/beatmap.py b/app/database/beatmap.py index 2ab5ad0..c55643a 100644 --- a/app/database/beatmap.py +++ b/app/database/beatmap.py @@ -14,6 +14,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: from app.fetcher import Fetcher + from .lazer_user import User + class BeatmapOwner(SQLModel): id: int @@ -161,6 +163,8 @@ class BeatmapResp(BeatmapBase): beatmap: Beatmap, query_mode: GameMode | None = None, from_set: bool = False, + session: AsyncSession | None = None, + user: "User | None" = None, ) -> "BeatmapResp": beatmap_ = beatmap.model_dump() if query_mode is not None and beatmap.mode != query_mode: @@ -170,5 +174,7 @@ class BeatmapResp(BeatmapBase): beatmap_["ranked"] = beatmap.beatmap_status.value beatmap_["mode_int"] = MODE_TO_INT[beatmap.mode] if not from_set: - beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset) + beatmap_["beatmapset"] = await BeatmapsetResp.from_db( + beatmap.beatmapset, session=session, user=user + ) return cls.model_validate(beatmap_) diff --git a/app/database/beatmapset.py b/app/database/beatmapset.py index 5a618b7..49313b2 100644 --- a/app/database/beatmapset.py +++ b/app/database/beatmapset.py @@ -5,14 +5,17 @@ from app.models.beatmap import BeatmapRankStatus, Genre, Language from app.models.model import UTCBaseModel from app.models.score import GameMode +from .lazer_user import BASE_INCLUDES, User, UserResp + from pydantic import BaseModel, model_serializer from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text from sqlalchemy.ext.asyncio import AsyncAttrs -from sqlmodel import Field, Relationship, SQLModel +from sqlmodel import Field, Relationship, SQLModel, col, func, select from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: from .beatmap import Beatmap, BeatmapResp + from .favourite_beatmapset import FavouriteBeatmapset class BeatmapCovers(SQLModel): @@ -90,7 +93,6 @@ class BeatmapsetBase(SQLModel, UTCBaseModel): artist_unicode: str = Field(index=True) covers: BeatmapCovers | None = Field(sa_column=Column(JSON)) creator: str - favourite_count: int nsfw: bool = Field(default=False) play_count: int preview_url: str @@ -114,11 +116,9 @@ class BeatmapsetBase(SQLModel, UTCBaseModel): pack_tags: list[str] = Field(default=[], sa_column=Column(JSON)) ratings: list[int] = Field(default=None, sa_column=Column(JSON)) - # TODO: recent_favourites: Optional[list[User]] = None # TODO: related_users: Optional[list[User]] = None # TODO: user: Optional[User] = Field(default=None) track_id: int | None = Field(default=None) # feature artist? - # TODO: has_favourited # BeatmapsetExtended bpm: float = Field(default=0.0, sa_column=Column(DECIMAL(10, 2))) @@ -152,6 +152,7 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True): hype_required: int = Field(default=0) availability_info: str | None = Field(default=None) download_disabled: bool = Field(default=False) + favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset") @classmethod async def from_resp( @@ -199,40 +200,88 @@ class BeatmapsetResp(BeatmapsetBase): genre: BeatmapTranslationText | None = None language: BeatmapTranslationText | None = None nominations: BeatmapNominations | None = None + has_favourited: bool = False + favourite_count: int = 0 + recent_favourites: list[UserResp] = Field(default_factory=list) @classmethod - async def from_db(cls, beatmapset: Beatmapset) -> "BeatmapsetResp": + async def from_db( + cls, + beatmapset: Beatmapset, + include: list[str] = [], + session: AsyncSession | None = None, + user: User | None = None, + ) -> "BeatmapsetResp": from .beatmap import BeatmapResp + from .favourite_beatmapset import FavouriteBeatmapset - beatmaps = [ - await BeatmapResp.from_db(beatmap, from_set=True) - for beatmap in await beatmapset.awaitable_attrs.beatmaps - ] + update = { + "beatmaps": [ + await BeatmapResp.from_db(beatmap, from_set=True) + for beatmap in await beatmapset.awaitable_attrs.beatmaps + ], + "hype": BeatmapHype( + current=beatmapset.hype_current, required=beatmapset.hype_required + ), + "availability": BeatmapAvailability( + more_information=beatmapset.availability_info, + download_disabled=beatmapset.download_disabled, + ), + "genre": BeatmapTranslationText( + name=beatmapset.beatmap_genre.name, + id=beatmapset.beatmap_genre.value, + ), + "language": BeatmapTranslationText( + name=beatmapset.beatmap_language.name, + id=beatmapset.beatmap_language.value, + ), + "nominations": BeatmapNominations( + required=beatmapset.nominations_required, + current=beatmapset.nominations_current, + ), + "status": beatmapset.beatmap_status.name.lower(), + "ranked": beatmapset.beatmap_status.value, + "is_scoreable": beatmapset.beatmap_status > BeatmapRankStatus.PENDING, + **beatmapset.model_dump(), + } + if session and user: + existing_favourite = ( + await session.exec( + select(FavouriteBeatmapset).where( + FavouriteBeatmapset.beatmapset_id == beatmapset.id + ) + ) + ).first() + update["has_favourited"] = existing_favourite is not None + + if session and "recent_favourites" in include: + recent_favourites = ( + await session.exec( + select(FavouriteBeatmapset) + .where( + FavouriteBeatmapset.beatmapset_id == beatmapset.id, + ) + .order_by(col(FavouriteBeatmapset.date).desc()) + .limit(50) + ) + ).all() + update["recent_favourites"] = [ + await UserResp.from_db( + await favourite.awaitable_attrs.user, + session=session, + include=BASE_INCLUDES, + ) + for favourite in recent_favourites + ] + + if session: + update["favourite_count"] = ( + await session.exec( + select(func.count()) + .select_from(FavouriteBeatmapset) + .where(FavouriteBeatmapset.beatmapset_id == beatmapset.id) + ) + ).one() return cls.model_validate( - { - "beatmaps": beatmaps, - "hype": BeatmapHype( - current=beatmapset.hype_current, required=beatmapset.hype_required - ), - "availability": BeatmapAvailability( - more_information=beatmapset.availability_info, - download_disabled=beatmapset.download_disabled, - ), - "genre": BeatmapTranslationText( - name=beatmapset.beatmap_genre.name, - id=beatmapset.beatmap_genre.value, - ), - "language": BeatmapTranslationText( - name=beatmapset.beatmap_language.name, - id=beatmapset.beatmap_language.value, - ), - "nominations": BeatmapNominations( - required=beatmapset.nominations_required, - current=beatmapset.nominations_current, - ), - "status": beatmapset.beatmap_status.name.lower(), - "ranked": beatmapset.beatmap_status.value, - "is_scoreable": beatmapset.beatmap_status > BeatmapRankStatus.PENDING, - **beatmapset.model_dump(), - } + update, ) diff --git a/app/database/favourite_beatmapset.py b/app/database/favourite_beatmapset.py new file mode 100644 index 0000000..51bd578 --- /dev/null +++ b/app/database/favourite_beatmapset.py @@ -0,0 +1,53 @@ +import datetime + +from app.database.beatmapset import Beatmapset +from app.database.lazer_user import User + +from sqlalchemy.ext.asyncio import AsyncAttrs +from sqlmodel import ( + BigInteger, + Column, + DateTime, + Field, + ForeignKey, + Relationship, + SQLModel, +) + + +class FavouriteBeatmapset(AsyncAttrs, SQLModel, table=True): + __tablename__ = "favourite_beatmapset" # pyright: ignore[reportAssignmentType] + id: int | None = Field( + default=None, + sa_column=Column(BigInteger, autoincrement=True, primary_key=True), + exclude=True, + ) + user_id: int = Field( + default=None, + sa_column=Column( + BigInteger, + ForeignKey("lazer_users.id"), + index=True, + ), + ) + beatmapset_id: int = Field( + default=None, + sa_column=Column( + ForeignKey("beatmapsets.id"), + index=True, + ), + ) + date: datetime.datetime = Field( + default=datetime.datetime.now(datetime.UTC), + sa_column=Column( + DateTime, + ), + ) + + user: User = Relationship(back_populates="favourite_beatmapsets") + beatmapset: Beatmapset = Relationship( + sa_relationship_kwargs={ + "lazy": "selectin", + }, + back_populates="favourites", + ) diff --git a/app/database/lazer_user.py b/app/database/lazer_user.py index 1337cc2..3bd751b 100644 --- a/app/database/lazer_user.py +++ b/app/database/lazer_user.py @@ -1,7 +1,6 @@ from datetime import UTC, datetime from typing import TYPE_CHECKING, NotRequired, TypedDict -from app.dependencies.database import get_redis from app.models.model import UTCBaseModel from app.models.score import GameMode from app.models.user import Country, Page, RankHistory @@ -28,7 +27,8 @@ from sqlmodel import ( from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: - from app.database.relationship import RelationshipResp + from .favourite_beatmapset import FavouriteBeatmapset + from .relationship import RelationshipResp class Kudosu(TypedDict): @@ -144,6 +144,9 @@ class User(AsyncAttrs, UserBase, table=True): back_populates="user" ) monthly_playcounts: list[MonthlyPlaycounts] = Relationship(back_populates="user") + favourite_beatmapsets: list["FavouriteBeatmapset"] = Relationship( + back_populates="user" + ) email: str = Field(max_length=254, unique=True, index=True, exclude=True) priv: int = Field(default=1, exclude=True) @@ -201,6 +204,8 @@ class UserResp(UserBase): include: list[str] = [], ruleset: GameMode | None = None, ) -> "UserResp": + from app.dependencies.database import get_redis + from .best_score import BestScore from .relationship import Relationship, RelationshipResp, RelationshipType @@ -320,3 +325,9 @@ SEARCH_INCLUDED = [ "achievements", "monthly_playcounts", ] + +BASE_INCLUDES = [ + "team", + "daily_challenge_user_stats", + "statistics", +] diff --git a/app/database/score.py b/app/database/score.py index 32ddb6c..79cb005 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -169,7 +169,9 @@ class ScoreResp(ScoreBase): assert score.id await score.awaitable_attrs.beatmap s.beatmap = await BeatmapResp.from_db(score.beatmap) - s.beatmapset = await BeatmapsetResp.from_db(score.beatmap.beatmapset) + s.beatmapset = await BeatmapsetResp.from_db( + score.beatmap.beatmapset, session=session, user=score.user + ) s.is_perfect_combo = s.max_combo == s.beatmap.max_combo s.legacy_perfect = s.max_combo == s.beatmap.max_combo s.ruleset_id = MODE_TO_INT[score.gamemode] @@ -669,7 +671,7 @@ async def process_score( acc=score.accuracy, ) session.add(best_score) - session.delete(previous_pp_best) if previous_pp_best else None + await session.delete(previous_pp_best) if previous_pp_best else None await session.commit() await session.refresh(score) await session.refresh(score_token) diff --git a/app/router/beatmap.py b/app/router/beatmap.py index 0a25562..9574bdb 100644 --- a/app/router/beatmap.py +++ b/app/router/beatmap.py @@ -50,7 +50,7 @@ async def lookup_beatmap( if beatmap is None: raise HTTPException(status_code=404, detail="Beatmap not found") - return await BeatmapResp.from_db(beatmap) + return await BeatmapResp.from_db(beatmap, session=db, user=current_user) @router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp) @@ -62,7 +62,7 @@ async def get_beatmap( ): try: beatmap = await Beatmap.get_or_fetch(db, fetcher, bid) - return await BeatmapResp.from_db(beatmap) + return await BeatmapResp.from_db(beatmap, session=db, user=current_user) except HTTPError: raise HTTPException(status_code=404, detail="Beatmap not found") @@ -90,7 +90,12 @@ async def batch_get_beatmaps( await db.exec(select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50)) ).all() - return BatchGetResp(beatmaps=[await BeatmapResp.from_db(bm) for bm in beatmaps]) + return BatchGetResp( + beatmaps=[ + await BeatmapResp.from_db(bm, session=db, user=current_user) + for bm in beatmaps + ] + ) @router.post( diff --git a/app/router/beatmapset.py b/app/router/beatmapset.py index b82678d..b4d2e4c 100644 --- a/app/router/beatmapset.py +++ b/app/router/beatmapset.py @@ -1,10 +1,8 @@ from __future__ import annotations -from app.database import ( - Beatmapset, - BeatmapsetResp, - User, -) +from typing import Literal + +from app.database import Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User from app.dependencies.database import get_db from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user @@ -12,7 +10,7 @@ from app.fetcher import Fetcher from .api_router import router -from fastapi import Depends, HTTPException, Query +from fastapi import Depends, Form, HTTPException, Query from fastapi.responses import RedirectResponse from httpx import HTTPStatusError from sqlmodel import select @@ -34,7 +32,9 @@ async def get_beatmapset( except HTTPStatusError: raise HTTPException(status_code=404, detail="Beatmapset not found") else: - resp = await BeatmapsetResp.from_db(beatmapset) + resp = await BeatmapsetResp.from_db( + beatmapset, session=db, include=["recent_favourites"], user=current_user + ) return resp @@ -53,3 +53,34 @@ async def download_beatmapset( return RedirectResponse( f"https://api.nerinyan.moe/d/{beatmapset}?noVideo={no_video}" ) + + +@router.post("/beatmapsets/{beatmapset}/favourites", tags=["beatmapset"]) +async def favourite_beatmapset( + beatmapset: int, + action: Literal["favourite", "unfavourite"] = Form(), + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + existing_favourite = ( + await db.exec( + select(FavouriteBeatmapset).where( + FavouriteBeatmapset.user_id == current_user.id, + FavouriteBeatmapset.beatmapset_id == beatmapset, + ) + ) + ).first() + + if action == "favourite" and existing_favourite: + raise HTTPException(status_code=400, detail="Already favourited") + elif action == "unfavourite" and not existing_favourite: + raise HTTPException(status_code=400, detail="Not favourited") + + if action == "favourite": + favourite = FavouriteBeatmapset( + user_id=current_user.id, beatmapset_id=beatmapset + ) + db.add(favourite) + else: + await db.delete(existing_favourite) + await db.commit() diff --git a/migrations/versions/1178d0758ebf_beatmapset_support_favourite_count.py b/migrations/versions/1178d0758ebf_beatmapset_support_favourite_count.py new file mode 100644 index 0000000..84bae15 --- /dev/null +++ b/migrations/versions/1178d0758ebf_beatmapset_support_favourite_count.py @@ -0,0 +1,40 @@ +"""beatmapset: support favourite count + +Revision ID: 1178d0758ebf +Revises: +Create Date: 2025-08-01 04:05:09.882800 + +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + +# revision identifiers, used by Alembic. +revision: str = "1178d0758ebf" +down_revision: str | Sequence[str] | None = None +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("beatmapsets", "favourite_count") + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "beatmapsets", + sa.Column( + "favourite_count", mysql.INTEGER(), autoincrement=False, nullable=False + ), + ) + # ### end Alembic commands ### From 74e4b1cb530a67e8dcd85c29073562a5343719f0 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Fri, 1 Aug 2025 04:27:44 +0000 Subject: [PATCH 6/6] fix(relationship): fix unique relationship --- app/database/relationship.py | 7 ++- ...02_relationship_fix_unique_relationship.py | 54 +++++++++++++++++++ 2 files changed, 59 insertions(+), 2 deletions(-) create mode 100644 migrations/versions/58a11441d302_relationship_fix_unique_relationship.py diff --git a/app/database/relationship.py b/app/database/relationship.py index 7a351aa..b941c28 100644 --- a/app/database/relationship.py +++ b/app/database/relationship.py @@ -22,12 +22,16 @@ class RelationshipType(str, Enum): class Relationship(SQLModel, table=True): __tablename__ = "relationship" # pyright: ignore[reportAssignmentType] + id: int | None = Field( + default=None, + sa_column=Column(BigInteger, autoincrement=True, primary_key=True), + exclude=True, + ) user_id: int = Field( default=None, sa_column=Column( BigInteger, ForeignKey("lazer_users.id"), - primary_key=True, index=True, ), ) @@ -36,7 +40,6 @@ class Relationship(SQLModel, table=True): sa_column=Column( BigInteger, ForeignKey("lazer_users.id"), - primary_key=True, index=True, ), ) diff --git a/migrations/versions/58a11441d302_relationship_fix_unique_relationship.py b/migrations/versions/58a11441d302_relationship_fix_unique_relationship.py new file mode 100644 index 0000000..e383621 --- /dev/null +++ b/migrations/versions/58a11441d302_relationship_fix_unique_relationship.py @@ -0,0 +1,54 @@ +"""relationship: fix unique relationship + +Revision ID: 58a11441d302 +Revises: 1178d0758ebf +Create Date: 2025-08-01 04:23:02.498166 + +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + +# revision identifiers, used by Alembic. +revision: str = "58a11441d302" +down_revision: str | Sequence[str] | None = "1178d0758ebf" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "relationship", + sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False), + ) + op.drop_constraint("PRIMARY", "relationship", type_="primary") + op.create_primary_key("pk_relationship", "relationship", ["id"]) + op.alter_column( + "relationship", "user_id", existing_type=mysql.BIGINT(), nullable=True + ) + op.alter_column( + "relationship", "target_id", existing_type=mysql.BIGINT(), nullable=True + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint("pk_relationship", "relationship", type_="primary") + op.create_primary_key("PRIMARY", "relationship", ["user_id", "target_id"]) + op.alter_column( + "relationship", "target_id", existing_type=mysql.BIGINT(), nullable=False + ) + op.alter_column( + "relationship", "user_id", existing_type=mysql.BIGINT(), nullable=False + ) + op.drop_column("relationship", "id") + # ### end Alembic commands ###