chore(merge): merge branch 'main' into feat/multiplayer-api

This commit is contained in:
MingxuanGame
2025-08-01 05:24:12 +00:00
39 changed files with 971 additions and 2191 deletions

View File

@@ -8,7 +8,7 @@ import string
from app.config import settings from app.config import settings
from app.database import ( from app.database import (
OAuthToken, OAuthToken,
User as DBUser, User,
) )
from app.log import logger from app.log import logger
@@ -74,7 +74,7 @@ def get_password_hash(password: str) -> str:
async def authenticate_user_legacy( async def authenticate_user_legacy(
db: AsyncSession, name: str, password: str db: AsyncSession, name: str, password: str
) -> DBUser | None: ) -> User | None:
""" """
验证用户身份 - 使用类似 from_login 的逻辑 验证用户身份 - 使用类似 from_login 的逻辑
""" """
@@ -82,7 +82,7 @@ async def authenticate_user_legacy(
pw_md5 = hashlib.md5(password.encode()).hexdigest() pw_md5 = hashlib.md5(password.encode()).hexdigest()
# 2. 根据用户名查找用户 # 2. 根据用户名查找用户
statement = select(DBUser).where(DBUser.name == name) statement = select(User).where(User.username == name)
user = (await db.exec(statement)).first() user = (await db.exec(statement)).first()
if not user: if not user:
return None return None
@@ -113,7 +113,7 @@ async def authenticate_user_legacy(
async def authenticate_user( async def authenticate_user(
db: AsyncSession, username: str, password: str db: AsyncSession, username: str, password: str
) -> DBUser | None: ) -> User | None:
"""验证用户身份""" """验证用户身份"""
return await authenticate_user_legacy(db, username, password) return await authenticate_user_legacy(db, username, password)

View File

@@ -1,3 +1,4 @@
from .achievement import UserAchievement, UserAchievementResp
from .auth import OAuthToken from .auth import OAuthToken
from .beatmap import ( from .beatmap import (
Beatmap as Beatmap, Beatmap as Beatmap,
@@ -8,7 +9,13 @@ from .beatmapset import (
BeatmapsetResp as BeatmapsetResp, BeatmapsetResp as BeatmapsetResp,
) )
from .best_score import BestScore from .best_score import BestScore
from .legacy import LegacyOAuthToken, LegacyUserStatistics from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp
from .favourite_beatmapset import FavouriteBeatmapset
from .lazer_user import (
User,
UserResp,
)
from .pp_best_score import PPBestScore
from .relationship import Relationship, RelationshipResp, RelationshipType from .relationship import Relationship, RelationshipResp, RelationshipType
from .score import ( from .score import (
Score, Score,
@@ -17,52 +24,27 @@ from .score import (
ScoreStatistics, ScoreStatistics,
) )
from .score_token import ScoreToken, ScoreTokenResp from .score_token import ScoreToken, ScoreTokenResp
from .statistics import (
UserStatistics,
UserStatisticsResp,
)
from .team import Team, TeamMember from .team import Team, TeamMember
from .user import ( from .user_account_history import (
DailyChallengeStats, UserAccountHistory,
LazerUserAchievement, UserAccountHistoryResp,
LazerUserBadge, UserAccountHistoryType,
LazerUserBanners,
LazerUserCountry,
LazerUserCounts,
LazerUserKudosu,
LazerUserMonthlyPlaycounts,
LazerUserPreviousUsername,
LazerUserProfile,
LazerUserProfileSections,
LazerUserReplaysWatched,
LazerUserStatistics,
RankHistory,
User,
UserAchievement,
UserAvatar,
) )
BeatmapsetResp.model_rebuild()
BeatmapResp.model_rebuild()
__all__ = [ __all__ = [
"Beatmap", "Beatmap",
"BeatmapResp",
"Beatmapset", "Beatmapset",
"BeatmapsetResp", "BeatmapsetResp",
"BestScore", "BestScore",
"DailyChallengeStats", "DailyChallengeStats",
"LazerUserAchievement", "DailyChallengeStatsResp",
"LazerUserBadge", "FavouriteBeatmapset",
"LazerUserBanners",
"LazerUserCountry",
"LazerUserCounts",
"LazerUserKudosu",
"LazerUserMonthlyPlaycounts",
"LazerUserPreviousUsername",
"LazerUserProfile",
"LazerUserProfileSections",
"LazerUserReplaysWatched",
"LazerUserStatistics",
"LegacyOAuthToken",
"LegacyUserStatistics",
"OAuthToken", "OAuthToken",
"RankHistory", "PPBestScore",
"Relationship", "Relationship",
"RelationshipResp", "RelationshipResp",
"RelationshipType", "RelationshipType",
@@ -75,6 +57,17 @@ __all__ = [
"Team", "Team",
"TeamMember", "TeamMember",
"User", "User",
"UserAccountHistory",
"UserAccountHistoryResp",
"UserAccountHistoryType",
"UserAchievement", "UserAchievement",
"UserAvatar", "UserAchievement",
"UserAchievementResp",
"UserResp",
"UserStatistics",
"UserStatisticsResp",
] ]
for i in __all__:
if i.endswith("Resp"):
globals()[i].model_rebuild() # type: ignore[call-arg]

View File

@@ -1,19 +1,21 @@
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from app.models.model import UTCBaseModel
from sqlalchemy import Column, DateTime from sqlalchemy import Column, DateTime
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
if TYPE_CHECKING: if TYPE_CHECKING:
from .user import User from .lazer_user import User
class OAuthToken(SQLModel, table=True): class OAuthToken(UTCBaseModel, SQLModel, table=True):
__tablename__ = "oauth_tokens" # pyright: ignore[reportAssignmentType] __tablename__ = "oauth_tokens" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True) id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field( user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("users.id"), index=True) sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
) )
access_token: str = Field(max_length=500, unique=True) access_token: str = Field(max_length=500, unique=True)
refresh_token: str = Field(max_length=500, unique=True) refresh_token: str = Field(max_length=500, unique=True)

View File

@@ -7,13 +7,14 @@ from app.models.score import MODE_TO_INT, GameMode
from .beatmapset import Beatmapset, BeatmapsetResp from .beatmapset import Beatmapset, BeatmapsetResp
from sqlalchemy import DECIMAL, Column, DateTime from sqlalchemy import DECIMAL, Column, DateTime
from sqlalchemy.orm import joinedload
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, select from sqlmodel import VARCHAR, Field, Relationship, SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING: if TYPE_CHECKING:
from app.fetcher import Fetcher from app.fetcher import Fetcher
from .lazer_user import User
class BeatmapOwner(SQLModel): class BeatmapOwner(SQLModel):
id: int id: int
@@ -66,7 +67,9 @@ class Beatmap(BeatmapBase, table=True):
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True) beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
beatmap_status: BeatmapRankStatus beatmap_status: BeatmapRankStatus
# optional # optional
beatmapset: Beatmapset = Relationship(back_populates="beatmaps") beatmapset: Beatmapset = Relationship(
back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"}
)
@property @property
def can_ranked(self) -> bool: def can_ranked(self) -> bool:
@@ -87,13 +90,7 @@ class Beatmap(BeatmapBase, table=True):
session.add(beatmap) session.add(beatmap)
await session.commit() await session.commit()
beatmap = ( beatmap = (
await session.exec( await session.exec(select(Beatmap).where(Beatmap.id == resp.id))
select(Beatmap)
.options(
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
)
.where(Beatmap.id == resp.id)
)
).first() ).first()
assert beatmap is not None, "Beatmap should not be None after commit" assert beatmap is not None, "Beatmap should not be None after commit"
return beatmap return beatmap
@@ -131,13 +128,9 @@ class Beatmap(BeatmapBase, table=True):
) -> "Beatmap": ) -> "Beatmap":
beatmap = ( beatmap = (
await session.exec( await session.exec(
select(Beatmap) select(Beatmap).where(
.where(
Beatmap.id == bid if bid is not None else Beatmap.checksum == md5 Beatmap.id == bid if bid is not None else Beatmap.checksum == md5
) )
.options(
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
)
) )
).first() ).first()
if not beatmap: if not beatmap:
@@ -164,11 +157,13 @@ class BeatmapResp(BeatmapBase):
url: str = "" url: str = ""
@classmethod @classmethod
def from_db( async def from_db(
cls, cls,
beatmap: Beatmap, beatmap: Beatmap,
query_mode: GameMode | None = None, query_mode: GameMode | None = None,
from_set: bool = False, from_set: bool = False,
session: AsyncSession | None = None,
user: "User | None" = None,
) -> "BeatmapResp": ) -> "BeatmapResp":
beatmap_ = beatmap.model_dump() beatmap_ = beatmap.model_dump()
if query_mode is not None and beatmap.mode != query_mode: if query_mode is not None and beatmap.mode != query_mode:
@@ -178,5 +173,7 @@ class BeatmapResp(BeatmapBase):
beatmap_["ranked"] = beatmap.beatmap_status.value beatmap_["ranked"] = beatmap.beatmap_status.value
beatmap_["mode_int"] = MODE_TO_INT[beatmap.mode] beatmap_["mode_int"] = MODE_TO_INT[beatmap.mode]
if not from_set: if not from_set:
beatmap_["beatmapset"] = BeatmapsetResp.from_db(beatmap.beatmapset) beatmap_["beatmapset"] = await BeatmapsetResp.from_db(
beatmap.beatmapset, session=session, user=user
)
return cls.model_validate(beatmap_) return cls.model_validate(beatmap_)

View File

@@ -4,13 +4,17 @@ from typing import TYPE_CHECKING, TypedDict, cast
from app.models.beatmap import BeatmapRankStatus, Genre, Language from app.models.beatmap import BeatmapRankStatus, Genre, Language
from app.models.score import GameMode from app.models.score import GameMode
from .lazer_user import BASE_INCLUDES, User, UserResp
from pydantic import BaseModel, model_serializer from pydantic import BaseModel, model_serializer
from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text
from sqlmodel import Field, Relationship, SQLModel from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import Field, Relationship, SQLModel, col, func, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING: if TYPE_CHECKING:
from .beatmap import Beatmap, BeatmapResp from .beatmap import Beatmap, BeatmapResp
from .favourite_beatmapset import FavouriteBeatmapset
class BeatmapCovers(SQLModel): class BeatmapCovers(SQLModel):
@@ -88,7 +92,6 @@ class BeatmapsetBase(SQLModel):
artist_unicode: str = Field(index=True) artist_unicode: str = Field(index=True)
covers: BeatmapCovers | None = Field(sa_column=Column(JSON)) covers: BeatmapCovers | None = Field(sa_column=Column(JSON))
creator: str creator: str
favourite_count: int
nsfw: bool = Field(default=False) nsfw: bool = Field(default=False)
play_count: int play_count: int
preview_url: str preview_url: str
@@ -112,11 +115,9 @@ class BeatmapsetBase(SQLModel):
pack_tags: list[str] = Field(default=[], sa_column=Column(JSON)) pack_tags: list[str] = Field(default=[], sa_column=Column(JSON))
ratings: list[int] = Field(default=None, sa_column=Column(JSON)) ratings: list[int] = Field(default=None, sa_column=Column(JSON))
# TODO: recent_favourites: Optional[list[User]] = None
# TODO: related_users: Optional[list[User]] = None # TODO: related_users: Optional[list[User]] = None
# TODO: user: Optional[User] = Field(default=None) # TODO: user: Optional[User] = Field(default=None)
track_id: int | None = Field(default=None) # feature artist? track_id: int | None = Field(default=None) # feature artist?
# TODO: has_favourited
# BeatmapsetExtended # BeatmapsetExtended
bpm: float = Field(default=0.0, sa_column=Column(DECIMAL(10, 2))) bpm: float = Field(default=0.0, sa_column=Column(DECIMAL(10, 2)))
@@ -129,7 +130,7 @@ class BeatmapsetBase(SQLModel):
tags: str = Field(default="", sa_column=Column(Text)) tags: str = Field(default="", sa_column=Column(Text))
class Beatmapset(BeatmapsetBase, table=True): class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
__tablename__ = "beatmapsets" # pyright: ignore[reportAssignmentType] __tablename__ = "beatmapsets" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True) id: int | None = Field(default=None, primary_key=True, index=True)
@@ -150,6 +151,7 @@ class Beatmapset(BeatmapsetBase, table=True):
hype_required: int = Field(default=0) hype_required: int = Field(default=0)
availability_info: str | None = Field(default=None) availability_info: str | None = Field(default=None)
download_disabled: bool = Field(default=False) download_disabled: bool = Field(default=False)
favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset")
@classmethod @classmethod
async def from_resp( async def from_resp(
@@ -197,40 +199,88 @@ class BeatmapsetResp(BeatmapsetBase):
genre: BeatmapTranslationText | None = None genre: BeatmapTranslationText | None = None
language: BeatmapTranslationText | None = None language: BeatmapTranslationText | None = None
nominations: BeatmapNominations | None = None nominations: BeatmapNominations | None = None
has_favourited: bool = False
favourite_count: int = 0
recent_favourites: list[UserResp] = Field(default_factory=list)
@classmethod @classmethod
def from_db(cls, beatmapset: Beatmapset) -> "BeatmapsetResp": async def from_db(
cls,
beatmapset: Beatmapset,
include: list[str] = [],
session: AsyncSession | None = None,
user: User | None = None,
) -> "BeatmapsetResp":
from .beatmap import BeatmapResp from .beatmap import BeatmapResp
from .favourite_beatmapset import FavouriteBeatmapset
beatmaps = [ update = {
BeatmapResp.from_db(beatmap, from_set=True) "beatmaps": [
for beatmap in beatmapset.beatmaps await BeatmapResp.from_db(beatmap, from_set=True)
] for beatmap in await beatmapset.awaitable_attrs.beatmaps
],
"hype": BeatmapHype(
current=beatmapset.hype_current, required=beatmapset.hype_required
),
"availability": BeatmapAvailability(
more_information=beatmapset.availability_info,
download_disabled=beatmapset.download_disabled,
),
"genre": BeatmapTranslationText(
name=beatmapset.beatmap_genre.name,
id=beatmapset.beatmap_genre.value,
),
"language": BeatmapTranslationText(
name=beatmapset.beatmap_language.name,
id=beatmapset.beatmap_language.value,
),
"nominations": BeatmapNominations(
required=beatmapset.nominations_required,
current=beatmapset.nominations_current,
),
"status": beatmapset.beatmap_status.name.lower(),
"ranked": beatmapset.beatmap_status.value,
"is_scoreable": beatmapset.beatmap_status > BeatmapRankStatus.PENDING,
**beatmapset.model_dump(),
}
if session and user:
existing_favourite = (
await session.exec(
select(FavouriteBeatmapset).where(
FavouriteBeatmapset.beatmapset_id == beatmapset.id
)
)
).first()
update["has_favourited"] = existing_favourite is not None
if session and "recent_favourites" in include:
recent_favourites = (
await session.exec(
select(FavouriteBeatmapset)
.where(
FavouriteBeatmapset.beatmapset_id == beatmapset.id,
)
.order_by(col(FavouriteBeatmapset.date).desc())
.limit(50)
)
).all()
update["recent_favourites"] = [
await UserResp.from_db(
await favourite.awaitable_attrs.user,
session=session,
include=BASE_INCLUDES,
)
for favourite in recent_favourites
]
if session:
update["favourite_count"] = (
await session.exec(
select(func.count())
.select_from(FavouriteBeatmapset)
.where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
)
).one()
return cls.model_validate( return cls.model_validate(
{ update,
"beatmaps": beatmaps,
"hype": BeatmapHype(
current=beatmapset.hype_current, required=beatmapset.hype_required
),
"availability": BeatmapAvailability(
more_information=beatmapset.availability_info,
download_disabled=beatmapset.download_disabled,
),
"genre": BeatmapTranslationText(
name=beatmapset.beatmap_genre.name,
id=beatmapset.beatmap_genre.value,
),
"language": BeatmapTranslationText(
name=beatmapset.beatmap_language.name,
id=beatmapset.beatmap_language.value,
),
"nominations": BeatmapNominations(
required=beatmapset.nominations_required,
current=beatmapset.nominations_current,
),
"status": beatmapset.beatmap_status.name.lower(),
"ranked": beatmapset.beatmap_status.value,
"is_scoreable": beatmapset.beatmap_status > BeatmapRankStatus.PENDING,
**beatmapset.model_dump(),
}
) )

View File

@@ -1,14 +1,14 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from app.models.score import GameMode from app.models.score import GameMode, Rank
from .user import User from .lazer_user import User
from sqlmodel import ( from sqlmodel import (
JSON,
BigInteger, BigInteger,
Column, Column,
Field, Field,
Float,
ForeignKey, ForeignKey,
Relationship, Relationship,
SQLModel, SQLModel,
@@ -20,22 +20,29 @@ if TYPE_CHECKING:
class BestScore(SQLModel, table=True): class BestScore(SQLModel, table=True):
__tablename__ = "best_scores" # pyright: ignore[reportAssignmentType] __tablename__ = "total_score_best_scores" # pyright: ignore[reportAssignmentType]
user_id: int = Field( user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("users.id"), index=True) sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
) )
score_id: int = Field( score_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True) sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
) )
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True) beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
gamemode: GameMode = Field(index=True) gamemode: GameMode = Field(index=True)
pp: float = Field( total_score: int = Field(
sa_column=Column(Float, default=0), default=0, sa_column=Column(BigInteger, ForeignKey("scores.total_score"))
) )
acc: float = Field( mods: list[str] = Field(
sa_column=Column(Float, default=0), default_factory=list,
sa_column=Column(JSON),
) )
rank: Rank
user: User = Relationship() user: User = Relationship()
score: "Score" = Relationship() score: "Score" = Relationship(
sa_relationship_kwargs={
"foreign_keys": "[BestScore.score_id]",
"lazy": "joined",
}
)
beatmap: "Beatmap" = Relationship() beatmap: "Beatmap" = Relationship()

View File

@@ -0,0 +1,53 @@
import datetime
from app.database.beatmapset import Beatmapset
from app.database.lazer_user import User
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import (
BigInteger,
Column,
DateTime,
Field,
ForeignKey,
Relationship,
SQLModel,
)
class FavouriteBeatmapset(AsyncAttrs, SQLModel, table=True):
__tablename__ = "favourite_beatmapset" # pyright: ignore[reportAssignmentType]
id: int | None = Field(
default=None,
sa_column=Column(BigInteger, autoincrement=True, primary_key=True),
exclude=True,
)
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("lazer_users.id"),
index=True,
),
)
beatmapset_id: int = Field(
default=None,
sa_column=Column(
ForeignKey("beatmapsets.id"),
index=True,
),
)
date: datetime.datetime = Field(
default=datetime.datetime.now(datetime.UTC),
sa_column=Column(
DateTime,
),
)
user: User = Relationship(back_populates="favourite_beatmapsets")
beatmapset: Beatmapset = Relationship(
sa_relationship_kwargs={
"lazy": "selectin",
},
back_populates="favourites",
)

View File

@@ -12,7 +12,7 @@ from .statistics import UserStatistics, UserStatisticsResp
from .team import Team, TeamMember from .team import Team, TeamMember
from .user_account_history import UserAccountHistory, UserAccountHistoryResp from .user_account_history import UserAccountHistory, UserAccountHistoryResp
from sqlalchemy.orm import joinedload, selectinload from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import ( from sqlmodel import (
JSON, JSON,
BigInteger, BigInteger,
@@ -27,7 +27,8 @@ from sqlmodel import (
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING: if TYPE_CHECKING:
from app.database.relationship import RelationshipResp from .favourite_beatmapset import FavouriteBeatmapset
from .relationship import RelationshipResp
class Kudosu(TypedDict): class Kudosu(TypedDict):
@@ -128,7 +129,7 @@ class UserBase(UTCBaseModel, SQLModel):
is_bng: bool = False is_bng: bool = False
class User(UserBase, table=True): class User(AsyncAttrs, UserBase, table=True):
__tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType] __tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType]
id: int | None = Field( id: int | None = Field(
@@ -143,6 +144,9 @@ class User(UserBase, table=True):
back_populates="user" back_populates="user"
) )
monthly_playcounts: list[MonthlyPlaycounts] = Relationship(back_populates="user") monthly_playcounts: list[MonthlyPlaycounts] = Relationship(back_populates="user")
favourite_beatmapsets: list["FavouriteBeatmapset"] = Relationship(
back_populates="user"
)
email: str = Field(max_length=254, unique=True, index=True, exclude=True) email: str = Field(max_length=254, unique=True, index=True, exclude=True)
priv: int = Field(default=1, exclude=True) priv: int = Field(default=1, exclude=True)
@@ -154,21 +158,10 @@ class User(UserBase, table=True):
default=None, sa_column=Column(DateTime(timezone=True)), exclude=True default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
) )
@classmethod
def all_select_option(cls):
return (
selectinload(cls.account_history), # pyright: ignore[reportArgumentType]
selectinload(cls.statistics), # pyright: ignore[reportArgumentType]
selectinload(cls.achievement), # pyright: ignore[reportArgumentType]
joinedload(cls.team_membership).joinedload(TeamMember.team), # pyright: ignore[reportArgumentType]
joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType]
selectinload(cls.monthly_playcounts), # pyright: ignore[reportArgumentType]
)
class UserResp(UserBase): class UserResp(UserBase):
id: int | None = None id: int | None = None
is_online: bool = True # TODO is_online: bool = False
groups: list = [] # TODO groups: list = [] # TODO
country: Country = Field(default_factory=lambda: Country(code="CN", name="China")) country: Country = Field(default_factory=lambda: Country(code="CN", name="China"))
favourite_beatmapset_count: int = 0 # TODO favourite_beatmapset_count: int = 0 # TODO
@@ -211,6 +204,8 @@ class UserResp(UserBase):
include: list[str] = [], include: list[str] = [],
ruleset: GameMode | None = None, ruleset: GameMode | None = None,
) -> "UserResp": ) -> "UserResp":
from app.dependencies.database import get_redis
from .best_score import BestScore from .best_score import BestScore
from .relationship import Relationship, RelationshipResp, RelationshipType from .relationship import Relationship, RelationshipResp, RelationshipType
@@ -236,6 +231,8 @@ class UserResp(UserBase):
.limit(200) .limit(200)
) )
).one() ).one()
redis = get_redis()
u.is_online = await redis.exists(f"metadata:online:{obj.id}")
u.cover_url = ( u.cover_url = (
obj.cover.get( obj.cover.get(
"url", "https://assets.ppy.sh/user-profile-covers/default.jpeg" "url", "https://assets.ppy.sh/user-profile-covers/default.jpeg"
@@ -249,13 +246,7 @@ class UserResp(UserBase):
await RelationshipResp.from_db(session, r) await RelationshipResp.from_db(session, r)
for r in ( for r in (
await session.exec( await session.exec(
select(Relationship) select(Relationship).where(
.options(
joinedload(Relationship.target).options( # pyright: ignore[reportArgumentType]
*User.all_select_option()
)
)
.where(
Relationship.user_id == obj.id, Relationship.user_id == obj.id,
Relationship.type == RelationshipType.FOLLOW, Relationship.type == RelationshipType.FOLLOW,
) )
@@ -264,23 +255,26 @@ class UserResp(UserBase):
] ]
if "team" in include: if "team" in include:
if obj.team_membership: if await obj.awaitable_attrs.team_membership:
assert obj.team_membership
u.team = obj.team_membership.team u.team = obj.team_membership.team
if "account_history" in include: if "account_history" in include:
u.account_history = [ u.account_history = [
UserAccountHistoryResp.from_db(ah) for ah in obj.account_history UserAccountHistoryResp.from_db(ah)
for ah in await obj.awaitable_attrs.account_history
] ]
if "daily_challenge_user_stats": if "daily_challenge_user_stats":
if obj.daily_challenge_stats: if await obj.awaitable_attrs.daily_challenge_stats:
assert obj.daily_challenge_stats
u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db( u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(
obj.daily_challenge_stats obj.daily_challenge_stats
) )
if "statistics" in include: if "statistics" in include:
current_stattistics = None current_stattistics = None
for i in obj.statistics: for i in await obj.awaitable_attrs.statistics:
if i.mode == (ruleset or obj.playmode): if i.mode == (ruleset or obj.playmode):
current_stattistics = i current_stattistics = i
break break
@@ -292,17 +286,20 @@ class UserResp(UserBase):
if "statistics_rulesets" in include: if "statistics_rulesets" in include:
u.statistics_rulesets = { u.statistics_rulesets = {
i.mode.value: UserStatisticsResp.from_db(i) for i in obj.statistics i.mode.value: UserStatisticsResp.from_db(i)
for i in await obj.awaitable_attrs.statistics
} }
if "monthly_playcounts" in include: if "monthly_playcounts" in include:
u.monthly_playcounts = [ u.monthly_playcounts = [
MonthlyPlaycountsResp.from_db(pc) for pc in obj.monthly_playcounts MonthlyPlaycountsResp.from_db(pc)
for pc in await obj.awaitable_attrs.monthly_playcounts
] ]
if "achievements" in include: if "achievements" in include:
u.user_achievements = [ u.user_achievements = [
UserAchievementResp.from_db(ua) for ua in obj.achievement UserAchievementResp.from_db(ua)
for ua in await obj.awaitable_attrs.achievement
] ]
return u return u
@@ -328,3 +325,9 @@ SEARCH_INCLUDED = [
"achievements", "achievements",
"monthly_playcounts", "monthly_playcounts",
] ]
BASE_INCLUDES = [
"team",
"daily_challenge_user_stats",
"statistics",
]

View File

@@ -1,94 +0,0 @@
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import JSON, Column, DateTime
from sqlalchemy.orm import Mapped
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
if TYPE_CHECKING:
from .user import User
# ============================================
# 旧的兼容性表模型(保留以便向后兼容)
# ============================================
class LegacyUserStatistics(SQLModel, table=True):
__tablename__ = "user_statistics" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
mode: str = Field(max_length=10) # osu, taiko, fruits, mania
# 基本统计
count_100: int = Field(default=0)
count_300: int = Field(default=0)
count_50: int = Field(default=0)
count_miss: int = Field(default=0)
# 等级信息
level_current: int = Field(default=1)
level_progress: int = Field(default=0)
# 排名信息
global_rank: int | None = Field(default=None)
global_rank_exp: int | None = Field(default=None)
country_rank: int | None = Field(default=None)
# PP 和分数
pp: float = Field(default=0.0)
pp_exp: float = Field(default=0.0)
ranked_score: int = Field(default=0)
hit_accuracy: float = Field(default=0.0)
total_score: int = Field(default=0)
total_hits: int = Field(default=0)
maximum_combo: int = Field(default=0)
# 游戏统计
play_count: int = Field(default=0)
play_time: int = Field(default=0)
replays_watched_by_others: int = Field(default=0)
is_ranked: bool = Field(default=False)
# 成绩等级计数
grade_ss: int = Field(default=0)
grade_ssh: int = Field(default=0)
grade_s: int = Field(default=0)
grade_sh: int = Field(default=0)
grade_a: int = Field(default=0)
# 最高排名记录
rank_highest: int | None = Field(default=None)
rank_highest_updated_at: datetime | None = Field(
default=None, sa_column=Column(DateTime)
)
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
# 关联关系
user: Mapped["User"] = Relationship(back_populates="statistics")
class LegacyOAuthToken(SQLModel, table=True):
__tablename__ = "legacy_oauth_tokens" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
access_token: str = Field(max_length=255, index=True)
refresh_token: str = Field(max_length=255, index=True)
expires_at: datetime = Field(sa_column=Column(DateTime))
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
previous_usernames: list = Field(default_factory=list, sa_column=Column(JSON))
replays_watched_counts: list = Field(default_factory=list, sa_column=Column(JSON))
# 用户关系
user: "User" = Relationship()

View File

@@ -0,0 +1,41 @@
from typing import TYPE_CHECKING
from app.models.score import GameMode
from .lazer_user import User
from sqlmodel import (
BigInteger,
Column,
Field,
Float,
ForeignKey,
Relationship,
SQLModel,
)
if TYPE_CHECKING:
from .beatmap import Beatmap
from .score import Score
class PPBestScore(SQLModel, table=True):
__tablename__ = "best_scores" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
score_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
)
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
gamemode: GameMode = Field(index=True)
pp: float = Field(
sa_column=Column(Float, default=0),
)
acc: float = Field(
sa_column=Column(Float, default=0),
)
user: User = Relationship()
score: "Score" = Relationship()
beatmap: "Beatmap" = Relationship()

View File

@@ -1,8 +1,6 @@
from enum import Enum from enum import Enum
from app.models.user import User as APIUser from .lazer_user import User, UserResp
from .user import User as DBUser
from pydantic import BaseModel from pydantic import BaseModel
from sqlmodel import ( from sqlmodel import (
@@ -24,12 +22,16 @@ class RelationshipType(str, Enum):
class Relationship(SQLModel, table=True): class Relationship(SQLModel, table=True):
__tablename__ = "relationship" # pyright: ignore[reportAssignmentType] __tablename__ = "relationship" # pyright: ignore[reportAssignmentType]
id: int | None = Field(
default=None,
sa_column=Column(BigInteger, autoincrement=True, primary_key=True),
exclude=True,
)
user_id: int = Field( user_id: int = Field(
default=None, default=None,
sa_column=Column( sa_column=Column(
BigInteger, BigInteger,
ForeignKey("users.id"), ForeignKey("lazer_users.id"),
primary_key=True,
index=True, index=True,
), ),
) )
@@ -37,20 +39,22 @@ class Relationship(SQLModel, table=True):
default=None, default=None,
sa_column=Column( sa_column=Column(
BigInteger, BigInteger,
ForeignKey("users.id"), ForeignKey("lazer_users.id"),
primary_key=True,
index=True, index=True,
), ),
) )
type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False) type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False)
target: DBUser = SQLRelationship( target: User = SQLRelationship(
sa_relationship_kwargs={"foreign_keys": "[Relationship.target_id]"} sa_relationship_kwargs={
"foreign_keys": "[Relationship.target_id]",
"lazy": "selectin",
}
) )
class RelationshipResp(BaseModel): class RelationshipResp(BaseModel):
target_id: int target_id: int
target: APIUser target: UserResp
mutual: bool = False mutual: bool = False
type: RelationshipType type: RelationshipType
@@ -58,8 +62,6 @@ class RelationshipResp(BaseModel):
async def from_db( async def from_db(
cls, session: AsyncSession, relationship: Relationship cls, session: AsyncSession, relationship: Relationship
) -> "RelationshipResp": ) -> "RelationshipResp":
from app.utils import convert_db_user_to_api_user
target_relationship = ( target_relationship = (
await session.exec( await session.exec(
select(Relationship).where( select(Relationship).where(
@@ -75,7 +77,16 @@ class RelationshipResp(BaseModel):
) )
return cls( return cls(
target_id=relationship.target_id, target_id=relationship.target_id,
target=await convert_db_user_to_api_user(relationship.target), target=await UserResp.from_db(
relationship.target,
session,
include=[
"team",
"daily_challenge_user_stats",
"statistics",
"statistics_rulesets",
],
),
mutual=mutual, mutual=mutual,
type=relationship.type, type=relationship.type,
) )

View File

@@ -1,6 +1,7 @@
import asyncio import asyncio
from collections.abc import Sequence from collections.abc import Sequence
from datetime import UTC, datetime from datetime import UTC, date, datetime
import json
import math import math
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@@ -12,9 +13,8 @@ from app.calculator import (
calculate_weighted_pp, calculate_weighted_pp,
clamp, clamp,
) )
from app.database.score_token import ScoreToken from app.database.team import TeamMember
from app.database.user import LazerUserStatistics, User from app.models.model import UTCBaseModel
from app.models.beatmap import BeatmapRankStatus
from app.models.mods import APIMod, mods_can_get_pp from app.models.mods import APIMod, mods_can_get_pp
from app.models.score import ( from app.models.score import (
INT_TO_MODE, INT_TO_MODE,
@@ -26,15 +26,24 @@ from app.models.score import (
ScoreStatistics, ScoreStatistics,
SoloScoreSubmissionInfo, SoloScoreSubmissionInfo,
) )
from app.models.user import User as APIUser
from .beatmap import Beatmap, BeatmapResp from .beatmap import Beatmap, BeatmapResp
from .beatmapset import Beatmapset, BeatmapsetResp from .beatmapset import BeatmapsetResp
from .best_score import BestScore from .best_score import BestScore
from .lazer_user import User, UserResp
from .monthly_playcounts import MonthlyPlaycounts
from .pp_best_score import PPBestScore
from .relationship import (
Relationship as DBRelationship,
RelationshipType,
)
from .score_token import ScoreToken
from redis import Redis from redis.asyncio import Redis
from sqlalchemy import Column, ColumnExpressionArgument, DateTime from sqlalchemy import Column, ColumnExpressionArgument, DateTime
from sqlalchemy.orm import aliased, joinedload from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import aliased
from sqlalchemy.sql.elements import ColumnElement
from sqlmodel import ( from sqlmodel import (
JSON, JSON,
BigInteger, BigInteger,
@@ -43,9 +52,10 @@ from sqlmodel import (
Relationship, Relationship,
SQLModel, SQLModel,
col, col,
false,
func, func,
select, select,
text,
true,
) )
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql._expression_select_cls import SelectOfScalar from sqlmodel.sql._expression_select_cls import SelectOfScalar
@@ -54,7 +64,7 @@ if TYPE_CHECKING:
from app.fetcher import Fetcher from app.fetcher import Fetcher
class ScoreBase(SQLModel): class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
# 基本字段 # 基本字段
accuracy: float accuracy: float
map_md5: str = Field(max_length=32, index=True) map_md5: str = Field(max_length=32, index=True)
@@ -94,7 +104,7 @@ class Score(ScoreBase, table=True):
default=None, default=None,
sa_column=Column( sa_column=Column(
BigInteger, BigInteger,
ForeignKey("users.id"), ForeignKey("lazer_users.id"),
index=True, index=True,
), ),
) )
@@ -112,28 +122,13 @@ class Score(ScoreBase, table=True):
gamemode: GameMode = Field(index=True) gamemode: GameMode = Field(index=True)
# optional # optional
beatmap: "Beatmap" = Relationship() beatmap: Beatmap = Relationship()
user: "User" = Relationship() user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
@property @property
def is_perfect_combo(self) -> bool: def is_perfect_combo(self) -> bool:
return self.max_combo == self.beatmap.max_combo return self.max_combo == self.beatmap.max_combo
@staticmethod
def select_clause(with_user: bool = True) -> SelectOfScalar["Score"]:
clause = select(Score).options(
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
.selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
),
)
if with_user:
return clause.options(
joinedload(Score.user).options(*User.all_select_option()) # pyright: ignore[reportArgumentType]
)
return clause
@staticmethod @staticmethod
def select_clause_unique( def select_clause_unique(
*where_clauses: ColumnExpressionArgument[bool] | bool, *where_clauses: ColumnExpressionArgument[bool] | bool,
@@ -147,18 +142,7 @@ class Score(ScoreBase, table=True):
) )
subq = select(Score, rownum).where(*where_clauses).subquery() subq = select(Score, rownum).where(*where_clauses).subquery()
best = aliased(Score, subq, adapt_on_names=True) best = aliased(Score, subq, adapt_on_names=True)
return ( return select(best).where(subq.c.rn == 1)
select(best)
.where(subq.c.rn == 1)
.options(
joinedload(best.beatmap) # pyright: ignore[reportArgumentType]
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
.selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
),
joinedload(best.user).options(*User.all_select_option()), # pyright: ignore[reportArgumentType]
)
)
class ScoreResp(ScoreBase): class ScoreResp(ScoreBase):
@@ -173,22 +157,21 @@ class ScoreResp(ScoreBase):
ruleset_id: int | None = None ruleset_id: int | None = None
beatmap: BeatmapResp | None = None beatmap: BeatmapResp | None = None
beatmapset: BeatmapsetResp | None = None beatmapset: BeatmapsetResp | None = None
user: APIUser | None = None user: UserResp | None = None
statistics: ScoreStatistics | None = None statistics: ScoreStatistics | None = None
maximum_statistics: ScoreStatistics | None = None maximum_statistics: ScoreStatistics | None = None
rank_global: int | None = None rank_global: int | None = None
rank_country: int | None = None rank_country: int | None = None
@classmethod @classmethod
async def from_db( async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp":
cls, session: AsyncSession, score: Score, user: User | None = None
) -> "ScoreResp":
from app.utils import convert_db_user_to_api_user
s = cls.model_validate(score.model_dump()) s = cls.model_validate(score.model_dump())
assert score.id assert score.id
s.beatmap = BeatmapResp.from_db(score.beatmap) await score.awaitable_attrs.beatmap
s.beatmapset = BeatmapsetResp.from_db(score.beatmap.beatmapset) s.beatmap = await BeatmapResp.from_db(score.beatmap)
s.beatmapset = await BeatmapsetResp.from_db(
score.beatmap.beatmapset, session=session, user=score.user
)
s.is_perfect_combo = s.max_combo == s.beatmap.max_combo s.is_perfect_combo = s.max_combo == s.beatmap.max_combo
s.legacy_perfect = s.max_combo == s.beatmap.max_combo s.legacy_perfect = s.max_combo == s.beatmap.max_combo
s.ruleset_id = MODE_TO_INT[score.gamemode] s.ruleset_id = MODE_TO_INT[score.gamemode]
@@ -220,25 +203,30 @@ class ScoreResp(ScoreBase):
s.maximum_statistics = { s.maximum_statistics = {
HitResult.GREAT: score.beatmap.max_combo, HitResult.GREAT: score.beatmap.max_combo,
} }
if user: s.user = await UserResp.from_db(
s.user = await convert_db_user_to_api_user(user) score.user,
session,
include=["statistics", "team", "daily_challenge_user_stats"],
ruleset=score.gamemode,
)
s.rank_global = ( s.rank_global = (
await get_score_position_by_id( await get_score_position_by_id(
session, session,
score.map_md5, score.beatmap_id,
score.id, score.id,
mode=score.gamemode, mode=score.gamemode,
user=user or score.user, user=score.user,
) )
or None or None
) )
s.rank_country = ( s.rank_country = (
await get_score_position_by_id( await get_score_position_by_id(
session, session,
score.map_md5, score.beatmap_id,
score.id, score.id,
score.gamemode, score.gamemode,
user or score.user, score.user,
type=LeaderboardType.COUNTRY,
) )
or None or None
) )
@@ -248,135 +236,137 @@ class ScoreResp(ScoreBase):
async def get_best_id(session: AsyncSession, score_id: int) -> None: async def get_best_id(session: AsyncSession, score_id: int) -> None:
rownum = ( rownum = (
func.row_number() func.row_number()
.over(partition_by=col(BestScore.user_id), order_by=col(BestScore.pp).desc()) .over(
partition_by=col(PPBestScore.user_id), order_by=col(PPBestScore.pp).desc()
)
.label("rn") .label("rn")
) )
subq = select(BestScore, rownum).subquery() subq = select(PPBestScore, rownum).subquery()
stmt = select(subq.c.rn).where(subq.c.score_id == score_id) stmt = select(subq.c.rn).where(subq.c.score_id == score_id)
result = await session.exec(stmt) result = await session.exec(stmt)
return result.one_or_none() return result.one_or_none()
async def _score_where(
type: LeaderboardType,
beatmap: int,
mode: GameMode,
mods: list[str] | None = None,
user: User | None = None,
) -> list[ColumnElement[bool]] | None:
wheres = [
col(BestScore.beatmap_id) == beatmap,
col(BestScore.gamemode) == mode,
]
if type == LeaderboardType.FRIENDS:
if user and user.is_supporter:
subq = (
select(DBRelationship.target_id)
.where(
DBRelationship.type == RelationshipType.FOLLOW,
DBRelationship.user_id == user.id,
)
.subquery()
)
wheres.append(col(BestScore.user_id).in_(select(subq.c.target_id)))
else:
return None
elif type == LeaderboardType.COUNTRY:
if user and user.is_supporter:
wheres.append(
col(BestScore.user).has(col(User.country_code) == user.country_code)
)
else:
return None
elif type == LeaderboardType.TEAM:
if user:
team_membership = await user.awaitable_attrs.team_membership
if team_membership:
team_id = team_membership.team_id
wheres.append(
col(BestScore.user).has(
col(User.team_membership).has(TeamMember.team_id == team_id)
)
)
if mods:
if user and user.is_supporter:
wheres.append(
text(
"JSON_CONTAINS(total_score_best_scores.mods, :w)"
" AND JSON_CONTAINS(:w, total_score_best_scores.mods)"
) # pyright: ignore[reportArgumentType]
)
else:
return None
return wheres
async def get_leaderboard( async def get_leaderboard(
session: AsyncSession, session: AsyncSession,
beatmap_md5: str, beatmap: int,
mode: GameMode, mode: GameMode,
type: LeaderboardType = LeaderboardType.GLOBAL, type: LeaderboardType = LeaderboardType.GLOBAL,
mods: list[APIMod] | None = None, mods: list[str] | None = None,
user: User | None = None, user: User | None = None,
limit: int = 50, limit: int = 50,
) -> list[Score]: ) -> tuple[list[Score], Score | None]:
scores = [] wheres = await _score_where(type, beatmap, mode, mods, user)
if type == LeaderboardType.GLOBAL: if wheres is None:
query = ( return [], None
select(Score) query = (
.where( select(BestScore)
col(Beatmap.beatmap_status).in_( .where(*wheres)
[ .limit(limit)
BeatmapRankStatus.RANKED, .order_by(col(BestScore.total_score).desc())
BeatmapRankStatus.LOVED, )
BeatmapRankStatus.QUALIFIED, if mods:
BeatmapRankStatus.APPROVED, query = query.params(w=json.dumps(mods))
] scores = [s.score for s in await session.exec(query)]
), user_score = None
Score.map_md5 == beatmap_md5,
Score.gamemode == mode,
col(Score.passed).is_(True),
Score.mods == mods if user and user.is_supporter else false(),
)
.limit(limit)
.order_by(
col(Score.total_score).desc(),
)
)
result = await session.exec(query)
scores = list[Score](result.all())
elif type == LeaderboardType.FRIENDS and user and user.is_supporter:
# TODO
...
elif type == LeaderboardType.TEAM and user and user.team_membership:
team_id = user.team_membership.team_id
query = (
select(Score)
.join(Beatmap)
.options(joinedload(Score.user)) # pyright: ignore[reportArgumentType]
.where(
Score.map_md5 == beatmap_md5,
Score.gamemode == mode,
col(Score.passed).is_(True),
col(Score.user.team_membership).is_not(None),
Score.user.team_membership.team_id == team_id, # pyright: ignore[reportOptionalMemberAccess]
Score.mods == mods if user and user.is_supporter else false(),
)
.limit(limit)
.order_by(
col(Score.total_score).desc(),
)
)
result = await session.exec(query)
scores = list[Score](result.all())
if user: if user:
user_score = ( self_query = (
await session.exec( select(BestScore)
select(Score).where( .where(BestScore.user_id == user.id)
Score.map_md5 == beatmap_md5, .order_by(col(BestScore.total_score).desc())
Score.gamemode == mode, .limit(1)
Score.user_id == user.id, )
col(Score.passed).is_(True), if mods:
self_query = self_query.where(
text(
"JSON_CONTAINS(total_score_best_scores.mods, :w)"
" AND JSON_CONTAINS(:w, total_score_best_scores.mods)"
) )
) ).params(w=json.dumps(mods))
).first() user_bs = (await session.exec(self_query)).first()
if user_bs:
user_score = user_bs.score
if user_score and user_score not in scores: if user_score and user_score not in scores:
scores.append(user_score) scores.append(user_score)
return scores return scores, user_score
async def get_score_position_by_user( async def get_score_position_by_user(
session: AsyncSession, session: AsyncSession,
beatmap_md5: str, beatmap: int,
user: User, user: User,
mode: GameMode, mode: GameMode,
type: LeaderboardType = LeaderboardType.GLOBAL, type: LeaderboardType = LeaderboardType.GLOBAL,
mods: list[APIMod] | None = None, mods: list[str] | None = None,
) -> int: ) -> int:
where_clause = [ wheres = await _score_where(type, beatmap, mode, mods, user=user)
Score.map_md5 == beatmap_md5, if wheres is None:
Score.gamemode == mode, return 0
col(Score.passed).is_(True),
col(Beatmap.beatmap_status).in_(
[
BeatmapRankStatus.RANKED,
BeatmapRankStatus.LOVED,
BeatmapRankStatus.QUALIFIED,
BeatmapRankStatus.APPROVED,
]
),
]
if mods and user.is_supporter:
where_clause.append(Score.mods == mods)
else:
where_clause.append(false())
if type == LeaderboardType.FRIENDS and user.is_supporter:
# TODO
...
elif type == LeaderboardType.TEAM and user.team_membership:
team_id = user.team_membership.team_id
where_clause.append(
col(Score.user.team_membership).is_not(None),
)
where_clause.append(
Score.user.team_membership.team_id == team_id, # pyright: ignore[reportOptionalMemberAccess]
)
rownum = ( rownum = (
func.row_number() func.row_number()
.over( .over(
partition_by=Score.map_md5, partition_by=col(BestScore.beatmap_id),
order_by=col(Score.total_score).desc(), order_by=col(BestScore.total_score).desc(),
) )
.label("row_number") .label("row_number")
) )
subq = select(Score, rownum).join(Beatmap).where(*where_clause).subquery() subq = select(BestScore, rownum).join(Beatmap).where(*wheres).subquery()
stmt = select(subq.c.row_number).where(subq.c.user == user) stmt = select(subq.c.row_number).where(subq.c.user_id == user.id)
result = await session.exec(stmt) result = await session.exec(stmt)
s = result.one_or_none() s = result.one_or_none()
return s if s else 0 return s if s else 0
@@ -384,57 +374,26 @@ async def get_score_position_by_user(
async def get_score_position_by_id( async def get_score_position_by_id(
session: AsyncSession, session: AsyncSession,
beatmap_md5: str, beatmap: int,
score_id: int, score_id: int,
mode: GameMode, mode: GameMode,
user: User | None = None, user: User | None = None,
type: LeaderboardType = LeaderboardType.GLOBAL, type: LeaderboardType = LeaderboardType.GLOBAL,
mods: list[APIMod] | None = None, mods: list[str] | None = None,
) -> int: ) -> int:
where_clause = [ wheres = await _score_where(type, beatmap, mode, mods, user=user)
Score.map_md5 == beatmap_md5, if wheres is None:
Score.id == score_id, return 0
Score.gamemode == mode,
col(Score.passed).is_(True),
col(Beatmap.beatmap_status).in_(
[
BeatmapRankStatus.RANKED,
BeatmapRankStatus.LOVED,
BeatmapRankStatus.QUALIFIED,
BeatmapRankStatus.APPROVED,
]
),
]
if mods and user and user.is_supporter:
where_clause.append(Score.mods == mods)
elif mods:
where_clause.append(false())
rownum = ( rownum = (
func.row_number() func.row_number()
.over( .over(
partition_by=[col(Score.user_id), col(Score.map_md5)], partition_by=col(BestScore.beatmap_id),
order_by=col(Score.total_score).desc(), order_by=col(BestScore.total_score).desc(),
) )
.label("rownum") .label("row_number")
) )
subq = ( subq = select(BestScore, rownum).join(Beatmap).where(*wheres).subquery()
select(Score.user_id, Score.id, Score.total_score, rownum) stmt = select(subq.c.row_number).where(subq.c.score_id == score_id)
.join(Beatmap)
.where(*where_clause)
.subquery()
)
best_scores = aliased(subq)
overall_rank = (
func.rank().over(order_by=best_scores.c.total_score.desc()).label("global_rank")
)
final_q = (
select(best_scores.c.id, overall_rank)
.select_from(best_scores)
.where(best_scores.c.rownum == 1)
.subquery()
)
stmt = select(final_q.c.global_rank).where(final_q.c.id == score_id)
result = await session.exec(stmt) result = await session.exec(stmt)
s = result.one_or_none() s = result.one_or_none()
return s if s else 0 return s if s else 0
@@ -445,16 +404,38 @@ async def get_user_best_score_in_beatmap(
beatmap: int, beatmap: int,
user: int, user: int,
mode: GameMode | None = None, mode: GameMode | None = None,
) -> Score | None: ) -> BestScore | None:
return ( return (
await session.exec( await session.exec(
Score.select_clause(False) select(BestScore)
.where( .where(
Score.gamemode == mode if mode is not None else True, BestScore.gamemode == mode if mode is not None else true(),
Score.beatmap_id == beatmap, BestScore.beatmap_id == beatmap,
Score.user_id == user, BestScore.user_id == user,
) )
.order_by(col(Score.total_score).desc()) .order_by(col(BestScore.total_score).desc())
)
).first()
# FIXME
async def get_user_best_score_with_mod_in_beatmap(
session: AsyncSession,
beatmap: int,
user: int,
mod: list[str],
mode: GameMode | None = None,
) -> BestScore | None:
return (
await session.exec(
select(BestScore)
.where(
BestScore.gamemode == mode if mode is not None else True,
BestScore.beatmap_id == beatmap,
BestScore.user_id == user,
# BestScore.mods == mod,
)
.order_by(col(BestScore.total_score).desc())
) )
).first() ).first()
@@ -464,13 +445,13 @@ async def get_user_best_pp_in_beatmap(
beatmap: int, beatmap: int,
user: int, user: int,
mode: GameMode, mode: GameMode,
) -> BestScore | None: ) -> PPBestScore | None:
return ( return (
await session.exec( await session.exec(
select(BestScore).where( select(PPBestScore).where(
BestScore.beatmap_id == beatmap, PPBestScore.beatmap_id == beatmap,
BestScore.user_id == user, PPBestScore.user_id == user,
BestScore.gamemode == mode, PPBestScore.gamemode == mode,
) )
) )
).first() ).first()
@@ -480,12 +461,12 @@ async def get_user_best_pp(
session: AsyncSession, session: AsyncSession,
user: int, user: int,
limit: int = 200, limit: int = 200,
) -> Sequence[BestScore]: ) -> Sequence[PPBestScore]:
return ( return (
await session.exec( await session.exec(
select(BestScore) select(PPBestScore)
.where(BestScore.user_id == user) .where(PPBestScore.user_id == user)
.order_by(col(BestScore.pp).desc()) .order_by(col(PPBestScore.pp).desc())
.limit(limit) .limit(limit)
) )
).all() ).all()
@@ -494,27 +475,45 @@ async def get_user_best_pp(
async def process_user( async def process_user(
session: AsyncSession, user: User, score: Score, ranked: bool = False session: AsyncSession, user: User, score: Score, ranked: bool = False
): ):
assert user.id
assert score.id
mod_for_save = list({mod["acronym"] for mod in score.mods})
previous_score_best = await get_user_best_score_in_beatmap( previous_score_best = await get_user_best_score_in_beatmap(
session, score.beatmap_id, user.id, score.gamemode session, score.beatmap_id, user.id, score.gamemode
) )
statistics = None previous_score_best_mod = await get_user_best_score_with_mod_in_beatmap(
session, score.beatmap_id, user.id, mod_for_save, score.gamemode
)
add_to_db = False add_to_db = False
for i in user.lazer_statistics: mouthly_playcount = (
await session.exec(
select(MonthlyPlaycounts).where(
MonthlyPlaycounts.user_id == user.id,
MonthlyPlaycounts.year == date.today().year,
MonthlyPlaycounts.month == date.today().month,
)
)
).first()
if mouthly_playcount is None:
mouthly_playcount = MonthlyPlaycounts(
user_id=user.id, year=date.today().year, month=date.today().month
)
add_to_db = True
statistics = None
for i in await user.awaitable_attrs.statistics:
if i.mode == score.gamemode.value: if i.mode == score.gamemode.value:
statistics = i statistics = i
break break
if statistics is None: if statistics is None:
statistics = LazerUserStatistics( raise ValueError(
mode=score.gamemode.value, f"User {user.id} does not have statistics for mode {score.gamemode.value}"
user_id=user.id,
) )
add_to_db = True
# pc, pt, tth, tts # pc, pt, tth, tts
statistics.total_score += score.total_score statistics.total_score += score.total_score
difference = ( difference = (
score.total_score - previous_score_best.total_score score.total_score - previous_score_best.total_score
if previous_score_best and previous_score_best.id != score.id if previous_score_best
else score.total_score else score.total_score
) )
if difference > 0 and score.passed and ranked: if difference > 0 and score.passed and ranked:
@@ -541,11 +540,48 @@ async def process_user(
statistics.grade_sh -= 1 statistics.grade_sh -= 1
case Rank.A: case Rank.A:
statistics.grade_a -= 1 statistics.grade_a -= 1
else:
previous_score_best = BestScore(
user_id=user.id,
beatmap_id=score.beatmap_id,
gamemode=score.gamemode,
score_id=score.id,
total_score=score.total_score,
rank=score.rank,
mods=mod_for_save,
)
session.add(previous_score_best)
statistics.ranked_score += difference statistics.ranked_score += difference
statistics.level_current = calculate_score_to_level(statistics.ranked_score) statistics.level_current = calculate_score_to_level(statistics.ranked_score)
statistics.maximum_combo = max(statistics.maximum_combo, score.max_combo) statistics.maximum_combo = max(statistics.maximum_combo, score.max_combo)
if score.passed and ranked:
if previous_score_best_mod is not None:
previous_score_best_mod.mods = mod_for_save
previous_score_best_mod.score_id = score.id
previous_score_best_mod.rank = score.rank
previous_score_best_mod.total_score = score.total_score
elif (
previous_score_best is not None and previous_score_best.score_id != score.id
):
session.add(
BestScore(
user_id=user.id,
beatmap_id=score.beatmap_id,
gamemode=score.gamemode,
score_id=score.id,
total_score=score.total_score,
rank=score.rank,
mods=mod_for_save,
)
)
statistics.play_count += 1 statistics.play_count += 1
mouthly_playcount.playcount += 1
statistics.play_time += int((score.ended_at - score.started_at).total_seconds()) statistics.play_time += int((score.ended_at - score.started_at).total_seconds())
statistics.count_100 += score.n100 + score.nkatu
statistics.count_300 += score.n300 + score.ngeki
statistics.count_50 += score.n50
statistics.count_miss += score.nmiss
statistics.total_hits += ( statistics.total_hits += (
score.n300 + score.n100 + score.n50 + score.ngeki + score.nkatu score.n300 + score.n100 + score.n50 + score.ngeki + score.nkatu
) )
@@ -563,11 +599,8 @@ async def process_user(
acc_sum = clamp(acc_sum, 0.0, 100.0) acc_sum = clamp(acc_sum, 0.0, 100.0)
statistics.pp = pp_sum statistics.pp = pp_sum
statistics.hit_accuracy = acc_sum statistics.hit_accuracy = acc_sum
statistics.updated_at = datetime.now(UTC)
if add_to_db: if add_to_db:
session.add(statistics) session.add(mouthly_playcount)
await session.commit() await session.commit()
await session.refresh(user) await session.refresh(user)
@@ -582,6 +615,8 @@ async def process_score(
session: AsyncSession, session: AsyncSession,
redis: Redis, redis: Redis,
) -> Score: ) -> Score:
assert user.id
can_get_pp = info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods)
score = Score( score = Score(
accuracy=info.accuracy, accuracy=info.accuracy,
max_combo=info.max_combo, max_combo=info.max_combo,
@@ -611,7 +646,7 @@ async def process_score(
nlarge_tick_hit=info.statistics.get(HitResult.LARGE_TICK_HIT, 0), nlarge_tick_hit=info.statistics.get(HitResult.LARGE_TICK_HIT, 0),
nslider_tail_hit=info.statistics.get(HitResult.SLIDER_TAIL_HIT, 0), nslider_tail_hit=info.statistics.get(HitResult.SLIDER_TAIL_HIT, 0),
) )
if info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods): if can_get_pp:
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id) beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
pp = await asyncio.get_event_loop().run_in_executor( pp = await asyncio.get_event_loop().run_in_executor(
None, calculate_pp, score, beatmap_raw None, calculate_pp, score, beatmap_raw
@@ -621,13 +656,13 @@ async def process_score(
user_id = user.id user_id = user.id
await session.commit() await session.commit()
await session.refresh(score) await session.refresh(score)
if score.passed and ranked: if can_get_pp:
previous_pp_best = await get_user_best_pp_in_beatmap( previous_pp_best = await get_user_best_pp_in_beatmap(
session, beatmap_id, user_id, score.gamemode session, beatmap_id, user_id, score.gamemode
) )
if previous_pp_best is None or score.pp > previous_pp_best.pp: if previous_pp_best is None or score.pp > previous_pp_best.pp:
assert score.id assert score.id
best_score = BestScore( best_score = PPBestScore(
user_id=user_id, user_id=user_id,
score_id=score.id, score_id=score.id,
beatmap_id=beatmap_id, beatmap_id=beatmap_id,
@@ -636,7 +671,7 @@ async def process_score(
acc=score.accuracy, acc=score.accuracy,
) )
session.add(best_score) session.add(best_score)
session.delete(previous_pp_best) if previous_pp_best else None await session.delete(previous_pp_best) if previous_pp_best else None
await session.commit() await session.commit()
await session.refresh(score) await session.refresh(score)
await session.refresh(score_token) await session.refresh(score_token)

View File

@@ -1,15 +1,16 @@
from datetime import datetime from datetime import datetime
from app.models.model import UTCBaseModel
from app.models.score import GameMode from app.models.score import GameMode
from .beatmap import Beatmap from .beatmap import Beatmap
from .user import User from .lazer_user import User
from sqlalchemy import Column, DateTime, Index from sqlalchemy import Column, DateTime, Index
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
class ScoreTokenBase(SQLModel): class ScoreTokenBase(SQLModel, UTCBaseModel):
score_id: int | None = Field(sa_column=Column(BigInteger), default=None) score_id: int | None = Field(sa_column=Column(BigInteger), default=None)
ruleset_id: GameMode ruleset_id: GameMode
playlist_item_id: int | None = Field(default=None) # playlist playlist_item_id: int | None = Field(default=None) # playlist
@@ -34,10 +35,10 @@ class ScoreToken(ScoreTokenBase, table=True):
autoincrement=True, autoincrement=True,
), ),
) )
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
beatmap_id: int = Field(foreign_key="beatmaps.id") beatmap_id: int = Field(foreign_key="beatmaps.id")
user: "User" = Relationship() user: User = Relationship()
beatmap: "Beatmap" = Relationship() beatmap: Beatmap = Relationship()
class ScoreTokenResp(ScoreTokenBase): class ScoreTokenResp(ScoreTokenBase):

View File

@@ -1,14 +1,16 @@
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from app.models.model import UTCBaseModel
from sqlalchemy import Column, DateTime from sqlalchemy import Column, DateTime
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
if TYPE_CHECKING: if TYPE_CHECKING:
from .user import User from .lazer_user import User
class Team(SQLModel, table=True): class Team(SQLModel, UTCBaseModel, table=True):
__tablename__ = "teams" # pyright: ignore[reportAssignmentType] __tablename__ = "teams" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True) id: int | None = Field(default=None, primary_key=True, index=True)
@@ -22,15 +24,19 @@ class Team(SQLModel, table=True):
members: list["TeamMember"] = Relationship(back_populates="team") members: list["TeamMember"] = Relationship(back_populates="team")
class TeamMember(SQLModel, table=True): class TeamMember(SQLModel, UTCBaseModel, table=True):
__tablename__ = "team_members" # pyright: ignore[reportAssignmentType] __tablename__ = "team_members" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True) id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
team_id: int = Field(foreign_key="teams.id") team_id: int = Field(foreign_key="teams.id")
joined_at: datetime = Field( joined_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime) default_factory=datetime.utcnow, sa_column=Column(DateTime)
) )
user: "User" = Relationship(back_populates="team_membership") user: "User" = Relationship(
team: "Team" = Relationship(back_populates="members") back_populates="team_membership", sa_relationship_kwargs={"lazy": "joined"}
)
team: "Team" = Relationship(
back_populates="members", sa_relationship_kwargs={"lazy": "joined"}
)

View File

@@ -1,527 +0,0 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
from .legacy import LegacyUserStatistics
from .team import TeamMember
from sqlalchemy import DECIMAL, JSON, Column, Date, DateTime, Text
from sqlalchemy.dialects.mysql import VARCHAR
from sqlalchemy.orm import joinedload, selectinload
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel, select
class User(SQLModel, table=True):
__tablename__ = "users" # pyright: ignore[reportAssignmentType]
# 主键
id: int = Field(
default=None, sa_column=Column(BigInteger, primary_key=True, index=True)
)
# 基本信息(匹配 migrations_old 中的结构)
name: str = Field(max_length=32, unique=True, index=True) # 用户名
safe_name: str = Field(max_length=32, unique=True, index=True) # 安全用户名
email: str = Field(max_length=254, unique=True, index=True)
priv: int = Field(default=1) # 权限
pw_bcrypt: str = Field(max_length=60) # bcrypt 哈希密码
country: str = Field(default="CN", max_length=2) # 国家代码
# 状态和时间
silence_end: int = Field(default=0)
donor_end: int = Field(default=0)
creation_time: int = Field(default=0) # Unix 时间戳
latest_activity: int = Field(default=0) # Unix 时间戳
# 游戏相关
preferred_mode: int = Field(default=0) # 偏好游戏模式
play_style: int = Field(default=0) # 游戏风格
# 扩展信息
clan_id: int = Field(default=0)
clan_priv: int = Field(default=0)
custom_badge_name: str | None = Field(default=None, max_length=16)
custom_badge_icon: str | None = Field(default=None, max_length=64)
userpage_content: str | None = Field(default=None, max_length=2048)
api_key: str | None = Field(default=None, max_length=36, unique=True)
# 虚拟字段用于兼容性
@property
def username(self):
return self.name
@property
def country_code(self):
return self.country
@property
def join_date(self):
creation_time = getattr(self, "creation_time", 0)
return (
datetime.fromtimestamp(creation_time)
if creation_time > 0
else datetime.utcnow()
)
@property
def last_visit(self):
latest_activity = getattr(self, "latest_activity", 0)
return datetime.fromtimestamp(latest_activity) if latest_activity > 0 else None
@property
def is_supporter(self):
return self.lazer_profile.is_supporter if self.lazer_profile else False
# 关联关系
lazer_profile: Optional["LazerUserProfile"] = Relationship(back_populates="user")
lazer_statistics: list["LazerUserStatistics"] = Relationship(back_populates="user")
lazer_counts: Optional["LazerUserCounts"] = Relationship(back_populates="user")
lazer_achievements: list["LazerUserAchievement"] = Relationship(
back_populates="user"
)
lazer_profile_sections: list["LazerUserProfileSections"] = Relationship(
back_populates="user"
)
statistics: list["LegacyUserStatistics"] = Relationship(back_populates="user")
team_membership: Optional["TeamMember"] = Relationship(back_populates="user")
daily_challenge_stats: Optional["DailyChallengeStats"] = Relationship(
back_populates="user"
)
rank_history: list["RankHistory"] = Relationship(back_populates="user")
avatar: Optional["UserAvatar"] = Relationship(back_populates="user")
active_banners: list["LazerUserBanners"] = Relationship(back_populates="user")
lazer_badges: list["LazerUserBadge"] = Relationship(back_populates="user")
lazer_monthly_playcounts: list["LazerUserMonthlyPlaycounts"] = Relationship(
back_populates="user"
)
lazer_previous_usernames: list["LazerUserPreviousUsername"] = Relationship(
back_populates="user"
)
lazer_replays_watched: list["LazerUserReplaysWatched"] = Relationship(
back_populates="user"
)
@classmethod
def all_select_option(cls):
return (
joinedload(cls.lazer_profile), # pyright: ignore[reportArgumentType]
joinedload(cls.lazer_counts), # pyright: ignore[reportArgumentType]
joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType]
joinedload(cls.avatar), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_statistics), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_achievements), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_profile_sections), # pyright: ignore[reportArgumentType]
selectinload(cls.statistics), # pyright: ignore[reportArgumentType]
joinedload(cls.team_membership), # pyright: ignore[reportArgumentType]
selectinload(cls.rank_history), # pyright: ignore[reportArgumentType]
selectinload(cls.active_banners), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_badges), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_monthly_playcounts), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_previous_usernames), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_replays_watched), # pyright: ignore[reportArgumentType]
)
@classmethod
def all_select_clause(cls):
return select(cls).options(*cls.all_select_option())
# ============================================
# Lazer API 专用表模型
# ============================================
class LazerUserProfile(SQLModel, table=True):
__tablename__ = "lazer_user_profiles" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
primary_key=True,
),
)
# 基本状态字段
is_active: bool = Field(default=True)
is_bot: bool = Field(default=False)
is_deleted: bool = Field(default=False)
is_online: bool = Field(default=True)
is_supporter: bool = Field(default=False)
is_restricted: bool = Field(default=False)
session_verified: bool = Field(default=False)
has_supported: bool = Field(default=False)
pm_friends_only: bool = Field(default=False)
# 基本资料字段
default_group: str = Field(default="default", max_length=50)
last_visit: datetime | None = Field(default=None, sa_column=Column(DateTime))
join_date: datetime | None = Field(default=None, sa_column=Column(DateTime))
profile_colour: str | None = Field(default=None, max_length=7)
profile_hue: int | None = Field(default=None)
# 社交媒体和个人资料字段
avatar_url: str | None = Field(default=None, max_length=500)
cover_url: str | None = Field(default=None, max_length=500)
discord: str | None = Field(default=None, max_length=100)
twitter: str | None = Field(default=None, max_length=100)
website: str | None = Field(default=None, max_length=500)
title: str | None = Field(default=None, max_length=100)
title_url: str | None = Field(default=None, max_length=500)
interests: str | None = Field(default=None, sa_column=Column(Text))
location: str | None = Field(default=None, max_length=100)
occupation: str | None = Field(default=None) # 职业字段,默认为 None
# 游戏相关字段
playmode: str = Field(default="osu", max_length=10)
support_level: int = Field(default=0)
max_blocks: int = Field(default=100)
max_friends: int = Field(default=500)
post_count: int = Field(default=0)
# 页面内容
page_html: str | None = Field(default=None, sa_column=Column(Text))
page_raw: str | None = Field(default=None, sa_column=Column(Text))
profile_order: str = Field(
default="me,recent_activity,top_ranks,medals,historical,beatmaps,kudosu"
)
# 关联关系
user: "User" = Relationship(back_populates="lazer_profile")
class LazerUserProfileSections(SQLModel, table=True):
__tablename__ = "lazer_user_profile_sections" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
section_name: str = Field(sa_column=Column(VARCHAR(50)))
display_order: int | None = Field(default=None)
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: "User" = Relationship(back_populates="lazer_profile_sections")
class LazerUserCountry(SQLModel, table=True):
__tablename__ = "lazer_user_countries" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
primary_key=True,
),
)
code: str = Field(max_length=2)
name: str = Field(max_length=100)
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
class LazerUserKudosu(SQLModel, table=True):
__tablename__ = "lazer_user_kudosu" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
primary_key=True,
),
)
available: int = Field(default=0)
total: int = Field(default=0)
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
class LazerUserCounts(SQLModel, table=True):
__tablename__ = "lazer_user_counts" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
primary_key=True,
),
)
# 统计计数字段
beatmap_playcounts_count: int = Field(default=0)
comments_count: int = Field(default=0)
favourite_beatmapset_count: int = Field(default=0)
follower_count: int = Field(default=0)
graveyard_beatmapset_count: int = Field(default=0)
guest_beatmapset_count: int = Field(default=0)
loved_beatmapset_count: int = Field(default=0)
mapping_follower_count: int = Field(default=0)
nominated_beatmapset_count: int = Field(default=0)
pending_beatmapset_count: int = Field(default=0)
ranked_beatmapset_count: int = Field(default=0)
ranked_and_approved_beatmapset_count: int = Field(default=0)
unranked_beatmapset_count: int = Field(default=0)
scores_best_count: int = Field(default=0)
scores_first_count: int = Field(default=0)
scores_pinned_count: int = Field(default=0)
scores_recent_count: int = Field(default=0)
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
# 关联关系
user: "User" = Relationship(back_populates="lazer_counts")
class LazerUserStatistics(SQLModel, table=True):
__tablename__ = "lazer_user_statistics" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
primary_key=True,
),
)
mode: str = Field(default="osu", max_length=10, primary_key=True)
# 基本命中统计
count_100: int = Field(default=0)
count_300: int = Field(default=0)
count_50: int = Field(default=0)
count_miss: int = Field(default=0)
# 等级信息
level_current: int = Field(default=1)
level_progress: int = Field(default=0)
# 排名信息
global_rank: int | None = Field(default=None)
global_rank_exp: int | None = Field(default=None)
country_rank: int | None = Field(default=None)
# PP 和分数
pp: float = Field(default=0.00, sa_column=Column(DECIMAL(10, 2)))
pp_exp: float = Field(default=0.00, sa_column=Column(DECIMAL(10, 2)))
ranked_score: int = Field(default=0, sa_column=Column(BigInteger))
hit_accuracy: float = Field(default=0.00, sa_column=Column(DECIMAL(5, 2)))
total_score: int = Field(default=0, sa_column=Column(BigInteger))
total_hits: int = Field(default=0, sa_column=Column(BigInteger))
maximum_combo: int = Field(default=0)
# 游戏统计
play_count: int = Field(default=0)
play_time: int = Field(default=0) # 秒
replays_watched_by_others: int = Field(default=0)
is_ranked: bool = Field(default=False)
# 成绩等级计数
grade_ss: int = Field(default=0)
grade_ssh: int = Field(default=0)
grade_s: int = Field(default=0)
grade_sh: int = Field(default=0)
grade_a: int = Field(default=0)
# 最高排名记录
rank_highest: int | None = Field(default=None)
rank_highest_updated_at: datetime | None = Field(
default=None, sa_column=Column(DateTime)
)
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
# 关联关系
user: "User" = Relationship(back_populates="lazer_statistics")
class LazerUserBanners(SQLModel, table=True):
__tablename__ = "lazer_user_tournament_banners" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
tournament_id: int
image_url: str = Field(sa_column=Column(VARCHAR(500)))
is_active: bool | None = Field(default=None)
# 修正user关系的back_populates值
user: "User" = Relationship(back_populates="active_banners")
class LazerUserAchievement(SQLModel, table=True):
__tablename__ = "lazer_user_achievements" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
achievement_id: int
achieved_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: "User" = Relationship(back_populates="lazer_achievements")
class LazerUserBadge(SQLModel, table=True):
__tablename__ = "lazer_user_badges" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
badge_id: int
awarded_at: datetime | None = Field(default=None, sa_column=Column(DateTime))
description: str | None = Field(default=None, sa_column=Column(Text))
image_url: str | None = Field(default=None, max_length=500)
url: str | None = Field(default=None, max_length=500)
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: "User" = Relationship(back_populates="lazer_badges")
class LazerUserMonthlyPlaycounts(SQLModel, table=True):
__tablename__ = "lazer_user_monthly_playcounts" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
start_date: datetime = Field(sa_column=Column(Date))
play_count: int = Field(default=0)
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: "User" = Relationship(back_populates="lazer_monthly_playcounts")
class LazerUserPreviousUsername(SQLModel, table=True):
__tablename__ = "lazer_user_previous_usernames" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
username: str = Field(max_length=32)
changed_at: datetime = Field(sa_column=Column(DateTime))
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: "User" = Relationship(back_populates="lazer_previous_usernames")
class LazerUserReplaysWatched(SQLModel, table=True):
__tablename__ = "lazer_user_replays_watched" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
start_date: datetime = Field(sa_column=Column(Date))
count: int = Field(default=0)
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: "User" = Relationship(back_populates="lazer_replays_watched")
# 类型转换用的 UserAchievement不是 SQLAlchemy 模型)
@dataclass
class UserAchievement:
achieved_at: datetime
achievement_id: int
class DailyChallengeStats(SQLModel, table=True):
__tablename__ = "daily_challenge_stats" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("users.id"), unique=True)
)
daily_streak_best: int = Field(default=0)
daily_streak_current: int = Field(default=0)
last_update: datetime | None = Field(default=None, sa_column=Column(DateTime))
last_weekly_streak: datetime | None = Field(
default=None, sa_column=Column(DateTime)
)
playcount: int = Field(default=0)
top_10p_placements: int = Field(default=0)
top_50p_placements: int = Field(default=0)
weekly_streak_best: int = Field(default=0)
weekly_streak_current: int = Field(default=0)
user: "User" = Relationship(back_populates="daily_challenge_stats")
class RankHistory(SQLModel, table=True):
__tablename__ = "rank_history" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
mode: str = Field(max_length=10)
rank_data: list = Field(sa_column=Column(JSON)) # Array of ranks
date_recorded: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: "User" = Relationship(back_populates="rank_history")
class UserAvatar(SQLModel, table=True):
__tablename__ = "user_avatars" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
filename: str = Field(max_length=255)
original_filename: str = Field(max_length=255)
file_size: int
mime_type: str = Field(max_length=100)
is_active: bool = Field(default=True)
created_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
r2_original_url: str | None = Field(default=None, max_length=500)
r2_game_url: str | None = Field(default=None, max_length=500)
user: "User" = Relationship(back_populates="avatar")

View File

@@ -5,15 +5,11 @@ import json
from app.config import settings from app.config import settings
from pydantic import BaseModel from pydantic import BaseModel
import redis.asyncio as redis
from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import SQLModel from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
try:
import redis
except ImportError:
redis = None
def json_serializer(value): def json_serializer(value):
if isinstance(value, BaseModel | SQLModel): if isinstance(value, BaseModel | SQLModel):
@@ -25,10 +21,7 @@ def json_serializer(value):
engine = create_async_engine(settings.DATABASE_URL, json_serializer=json_serializer) engine = create_async_engine(settings.DATABASE_URL, json_serializer=json_serializer)
# Redis 连接 # Redis 连接
if redis: redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
else:
redis_client = None
# 数据库依赖 # 数据库依赖

View File

@@ -8,7 +8,7 @@ from app.log import logger
fetcher: Fetcher | None = None fetcher: Fetcher | None = None
def get_fetcher() -> Fetcher: async def get_fetcher() -> Fetcher:
global fetcher global fetcher
if fetcher is None: if fetcher is None:
fetcher = Fetcher( fetcher = Fetcher(
@@ -18,15 +18,14 @@ def get_fetcher() -> Fetcher:
settings.FETCHER_CALLBACK_URL, settings.FETCHER_CALLBACK_URL,
) )
redis = get_redis() redis = get_redis()
if redis: access_token = await redis.get(f"fetcher:access_token:{fetcher.client_id}")
access_token = redis.get(f"fetcher:access_token:{fetcher.client_id}") if access_token:
if access_token: fetcher.access_token = str(access_token)
fetcher.access_token = str(access_token) refresh_token = await redis.get(f"fetcher:refresh_token:{fetcher.client_id}")
refresh_token = redis.get(f"fetcher:refresh_token:{fetcher.client_id}") if refresh_token:
if refresh_token: fetcher.refresh_token = str(refresh_token)
fetcher.refresh_token = str(refresh_token) if not fetcher.access_token or not fetcher.refresh_token:
if not fetcher.access_token or not fetcher.refresh_token: logger.opt(colors=True).info(
logger.opt(colors=True).info( f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>"
f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>" )
)
return fetcher return fetcher

View File

@@ -1,14 +1,13 @@
from __future__ import annotations from __future__ import annotations
from app.auth import get_token_by_access_token from app.auth import get_token_by_access_token
from app.database import ( from app.database import User
User as DBUser,
)
from .database import get_db from .database import get_db
from fastapi import Depends, HTTPException from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
security = HTTPBearer() security = HTTPBearer()
@@ -17,7 +16,7 @@ security = HTTPBearer()
async def get_current_user( async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security), credentials: HTTPAuthorizationCredentials = Depends(security),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
) -> DBUser: ) -> User:
"""获取当前认证用户""" """获取当前认证用户"""
token = credentials.credentials token = credentials.credentials
@@ -27,13 +26,9 @@ async def get_current_user(
return user return user
async def get_current_user_by_token(token: str, db: AsyncSession) -> DBUser | None: async def get_current_user_by_token(token: str, db: AsyncSession) -> User | None:
token_record = await get_token_by_access_token(db, token) token_record = await get_token_by_access_token(db, token)
if not token_record: if not token_record:
return None return None
user = ( user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
await db.exec(
DBUser.all_select_clause().where(DBUser.id == token_record.user_id)
)
).first()
return user return user

View File

@@ -59,16 +59,15 @@ class BaseFetcher:
self.refresh_token = token_data.get("refresh_token", "") self.refresh_token = token_data.get("refresh_token", "")
self.token_expiry = int(time.time()) + token_data["expires_in"] self.token_expiry = int(time.time()) + token_data["expires_in"]
redis = get_redis() redis = get_redis()
if redis: await redis.set(
redis.set( f"fetcher:access_token:{self.client_id}",
f"fetcher:access_token:{self.client_id}", self.access_token,
self.access_token, ex=token_data["expires_in"],
ex=token_data["expires_in"], )
) await redis.set(
redis.set( f"fetcher:refresh_token:{self.client_id}",
f"fetcher:refresh_token:{self.client_id}", self.refresh_token,
self.refresh_token, )
)
async def refresh_access_token(self) -> None: async def refresh_access_token(self) -> None:
async with AsyncClient() as client: async with AsyncClient() as client:
@@ -87,13 +86,12 @@ class BaseFetcher:
self.refresh_token = token_data.get("refresh_token", "") self.refresh_token = token_data.get("refresh_token", "")
self.token_expiry = int(time.time()) + token_data["expires_in"] self.token_expiry = int(time.time()) + token_data["expires_in"]
redis = get_redis() redis = get_redis()
if redis: await redis.set(
redis.set( f"fetcher:access_token:{self.client_id}",
f"fetcher:access_token:{self.client_id}", self.access_token,
self.access_token, ex=token_data["expires_in"],
ex=token_data["expires_in"], )
) await redis.set(
redis.set( f"fetcher:refresh_token:{self.client_id}",
f"fetcher:refresh_token:{self.client_id}", self.refresh_token,
self.refresh_token, )
)

View File

@@ -4,7 +4,7 @@ from ._base import BaseFetcher
from httpx import AsyncClient from httpx import AsyncClient
from loguru import logger from loguru import logger
import redis import redis.asyncio as redis
class OsuDotDirectFetcher(BaseFetcher): class OsuDotDirectFetcher(BaseFetcher):
@@ -22,8 +22,8 @@ class OsuDotDirectFetcher(BaseFetcher):
async def get_or_fetch_beatmap_raw( async def get_or_fetch_beatmap_raw(
self, redis: redis.Redis, beatmap_id: int self, redis: redis.Redis, beatmap_id: int
) -> str: ) -> str:
if redis.exists(f"beatmap:{beatmap_id}:raw"): if await redis.exists(f"beatmap:{beatmap_id}:raw"):
return redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType] return await redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType]
raw = await self.get_beatmap_raw(beatmap_id) raw = await self.get_beatmap_raw(beatmap_id)
redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24) await redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24)
return raw return raw

View File

@@ -42,11 +42,12 @@ class Language(IntEnum):
KOREAN = 6 KOREAN = 6
FRENCH = 7 FRENCH = 7
GERMAN = 8 GERMAN = 8
ITALIAN = 9 SWEDISH = 9
SPANISH = 10 ITALIAN = 10
RUSSIAN = 11 SPANISH = 11
POLISH = 12 RUSSIAN = 12
OTHER = 13 POLISH = 13
OTHER = 14
class BeatmapAttributes(BaseModel): class BeatmapAttributes(BaseModel):

View File

@@ -1,7 +1,6 @@
# OAuth 相关模型 # OAuth 相关模型
from __future__ import annotations from __future__ import annotations
from typing import List
from pydantic import BaseModel from pydantic import BaseModel
@@ -39,18 +38,21 @@ class OAuthErrorResponse(BaseModel):
class RegistrationErrorResponse(BaseModel): class RegistrationErrorResponse(BaseModel):
"""注册错误响应模型""" """注册错误响应模型"""
form_error: dict form_error: dict
class UserRegistrationErrors(BaseModel): class UserRegistrationErrors(BaseModel):
"""用户注册错误模型""" """用户注册错误模型"""
username: List[str] = []
user_email: List[str] = [] username: list[str] = []
password: List[str] = [] user_email: list[str] = []
password: list[str] = []
class RegistrationRequestErrors(BaseModel): class RegistrationRequestErrors(BaseModel):
"""注册请求错误模型""" """注册请求错误模型"""
message: str | None = None message: str | None = None
redirect: str | None = None redirect: str | None = None
user: UserRegistrationErrors | None = None user: UserRegistrationErrors | None = None

View File

@@ -132,7 +132,7 @@ class HitResultInt(IntEnum):
class LeaderboardType(Enum): class LeaderboardType(Enum):
GLOBAL = "global" GLOBAL = "global"
FRIENDS = "friends" FRIENDS = "friend"
COUNTRY = "country" COUNTRY = "country"
TEAM = "team" TEAM = "team"

View File

@@ -2,15 +2,11 @@ from __future__ import annotations
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING
from .score import GameMode from .model import UTCBaseModel
from pydantic import BaseModel from pydantic import BaseModel
if TYPE_CHECKING:
from app.database import LazerUserAchievement, Team
class PlayStyle(str, Enum): class PlayStyle(str, Enum):
MOUSE = "mouse" MOUSE = "mouse"
@@ -77,24 +73,7 @@ class MonthlyPlaycount(BaseModel):
count: int count: int
class UserAchievement(BaseModel): class RankHighest(UTCBaseModel):
achieved_at: datetime
achievement_id: int
# 添加数据库模型转换方法
def to_db_model(self, user_id: int) -> "LazerUserAchievement":
from app.database import (
LazerUserAchievement,
)
return LazerUserAchievement(
user_id=user_id,
achievement_id=self.achievement_id,
achieved_at=self.achieved_at,
)
class RankHighest(BaseModel):
rank: int rank: int
updated_at: datetime updated_at: datetime
@@ -104,115 +83,6 @@ class RankHistory(BaseModel):
data: list[int] data: list[int]
class DailyChallengeStats(BaseModel):
daily_streak_best: int = 0
daily_streak_current: int = 0
last_update: datetime | None = None
last_weekly_streak: datetime | None = None
playcount: int = 0
top_10p_placements: int = 0
top_50p_placements: int = 0
user_id: int
weekly_streak_best: int = 0
weekly_streak_current: int = 0
class Page(BaseModel): class Page(BaseModel):
html: str = "" html: str = ""
raw: str = "" raw: str = ""
class User(BaseModel):
# 基本信息
id: int
username: str
avatar_url: str
country_code: str
default_group: str = "default"
is_active: bool = True
is_bot: bool = False
is_deleted: bool = False
is_online: bool = True
is_supporter: bool = False
is_restricted: bool = False
last_visit: datetime | None = None
pm_friends_only: bool = False
profile_colour: str | None = None
# 个人资料
cover_url: str | None = None
discord: str | None = None
has_supported: bool = False
interests: str | None = None
join_date: datetime
location: str | None = None
max_blocks: int = 100
max_friends: int = 500
occupation: str | None = None
playmode: GameMode = GameMode.OSU
playstyle: list[PlayStyle] = []
post_count: int = 0
profile_hue: int | None = None
profile_order: list[str] = [
"me",
"recent_activity",
"top_ranks",
"medals",
"historical",
"beatmaps",
"kudosu",
]
title: str | None = None
title_url: str | None = None
twitter: str | None = None
website: str | None = None
session_verified: bool = False
support_level: int = 0
# 关联对象
country: Country
cover: Cover
kudosu: Kudosu
statistics: Statistics
statistics_rulesets: dict[str, Statistics]
# 计数信息
beatmap_playcounts_count: int = 0
comments_count: int = 0
favourite_beatmapset_count: int = 0
follower_count: int = 0
graveyard_beatmapset_count: int = 0
guest_beatmapset_count: int = 0
loved_beatmapset_count: int = 0
mapping_follower_count: int = 0
nominated_beatmapset_count: int = 0
pending_beatmapset_count: int = 0
ranked_beatmapset_count: int = 0
ranked_and_approved_beatmapset_count: int = 0
unranked_beatmapset_count: int = 0
scores_best_count: int = 0
scores_first_count: int = 0
scores_pinned_count: int = 0
scores_recent_count: int = 0
# 历史数据
account_history: list[dict] = []
active_tournament_banner: dict | None = None
active_tournament_banners: list[dict] = []
badges: list[dict] = []
current_season_stats: dict | None = None
daily_challenge_user_stats: DailyChallengeStats | None = None
groups: list[dict] = []
monthly_playcounts: list[MonthlyPlaycount] = []
page: Page = Page()
previous_usernames: list[str] = []
rank_highest: RankHighest | None = None
rank_history: RankHistory | None = None
rankHistory: RankHistory | None = None # 兼容性别名
replays_watched_counts: list[dict] = []
team: "Team | None" = None
user_achievements: list[UserAchievement] = []
class APIUser(BaseModel):
id: int

View File

@@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from datetime import timedelta from datetime import UTC, datetime, timedelta
import re import re
from app.auth import ( from app.auth import (
@@ -12,17 +12,21 @@ from app.auth import (
store_token, store_token,
) )
from app.config import settings from app.config import settings
from app.database import User as DBUser from app.database import DailyChallengeStats, User
from app.database.statistics import UserStatistics
from app.dependencies import get_db from app.dependencies import get_db
from app.log import logger
from app.models.oauth import ( from app.models.oauth import (
OAuthErrorResponse, OAuthErrorResponse,
RegistrationRequestErrors, RegistrationRequestErrors,
TokenResponse, TokenResponse,
UserRegistrationErrors, UserRegistrationErrors,
) )
from app.models.score import GameMode
from fastapi import APIRouter, Depends, Form from fastapi import APIRouter, Depends, Form
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from sqlalchemy import text
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -110,12 +114,12 @@ async def register_user(
email_errors = validate_email(user_email) email_errors = validate_email(user_email)
password_errors = validate_password(user_password) password_errors = validate_password(user_password)
result = await db.exec(select(DBUser).where(DBUser.name == user_username)) result = await db.exec(select(User).where(User.username == user_username))
existing_user = result.first() existing_user = result.first()
if existing_user: if existing_user:
username_errors.append("Username is already taken") username_errors.append("Username is already taken")
result = await db.exec(select(DBUser).where(DBUser.email == user_email)) result = await db.exec(select(User).where(User.email == user_email))
existing_email = result.first() existing_email = result.first()
if existing_email: if existing_email:
email_errors.append("Email is already taken") email_errors.append("Email is already taken")
@@ -135,119 +139,41 @@ async def register_user(
try: try:
# 创建新用户 # 创建新用户
from datetime import datetime # 确保 AUTO_INCREMENT 值从3开始ID=1是BanchoBotID=2预留给ppy
import time result = await db.execute( # pyright: ignore[reportDeprecated]
text(
"SELECT AUTO_INCREMENT FROM information_schema.TABLES "
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'lazer_users'"
)
)
next_id = result.one()[0]
if next_id <= 2:
await db.execute(text("ALTER TABLE lazer_users AUTO_INCREMENT = 3"))
await db.commit()
new_user = DBUser( new_user = User(
name=user_username, username=user_username,
safe_name=user_username.lower(), # 安全用户名(小写)
email=user_email, email=user_email,
pw_bcrypt=get_password_hash(user_password), pw_bcrypt=get_password_hash(user_password),
priv=1, # 普通用户权限 priv=1, # 普通用户权限
country="CN", # 默认国家 country_code="CN", # 默认国家
creation_time=int(time.time()), join_date=datetime.now(UTC),
latest_activity=int(time.time()), last_visit=datetime.now(UTC),
preferred_mode=0, # 默认模式
play_style=0, # 默认游戏风格
) )
db.add(new_user) db.add(new_user)
await db.commit() await db.commit()
await db.refresh(new_user) await db.refresh(new_user)
assert new_user.id is not None, "New user ID should not be None"
# 保存用户ID因为会话可能会关闭 for i in GameMode:
user_id = new_user.id statistics = UserStatistics(mode=i, user_id=new_user.id)
db.add(statistics)
if user_id <= 2: daily_challenge_user_stats = DailyChallengeStats(user_id=new_user.id)
await db.rollback() db.add(daily_challenge_user_stats)
try:
from sqlalchemy import text
# 确保 AUTO_INCREMENT 值从3开始ID=1是BanchoBotID=2预留给ppy
await db.execute(text("ALTER TABLE users AUTO_INCREMENT = 3"))
await db.commit()
# 重新创建用户
new_user = DBUser(
name=user_username,
safe_name=user_username.lower(),
email=user_email,
pw_bcrypt=get_password_hash(user_password),
priv=1,
country="CN",
creation_time=int(time.time()),
latest_activity=int(time.time()),
preferred_mode=0,
play_style=0,
)
db.add(new_user)
await db.commit()
await db.refresh(new_user)
user_id = new_user.id
# 最终检查ID是否有效
if user_id <= 2:
await db.rollback()
errors = RegistrationRequestErrors(
message=(
"Failed to create account with valid ID. "
"Please contact support."
)
)
return JSONResponse(
status_code=500, content={"form_error": errors.model_dump()}
)
except Exception as fix_error:
await db.rollback()
print(f"Failed to fix AUTO_INCREMENT: {fix_error}")
errors = RegistrationRequestErrors(
message="Failed to create account with valid ID. Please try again."
)
return JSONResponse(
status_code=500, content={"form_error": errors.model_dump()}
)
# 创建默认的 lazer_profile
from app.database.user import LazerUserProfile
lazer_profile = LazerUserProfile(
user_id=user_id,
is_active=True,
is_bot=False,
is_deleted=False,
is_online=True,
is_supporter=False,
is_restricted=False,
session_verified=False,
has_supported=False,
pm_friends_only=False,
default_group="default",
join_date=datetime.utcnow(),
playmode="osu",
support_level=0,
max_blocks=50,
max_friends=250,
post_count=0,
)
db.add(lazer_profile)
await db.commit() await db.commit()
except Exception:
# 返回成功响应
return JSONResponse(
status_code=201,
content={"message": "Account created successfully", "user_id": user_id},
)
except Exception as e:
await db.rollback() await db.rollback()
# 打印详细错误信息用于调试 # 打印详细错误信息用于调试
print(f"Registration error: {e}") logger.exception(f"Registration error for user {user_username}")
import traceback
traceback.print_exc()
# 返回通用错误 # 返回通用错误
errors = RegistrationRequestErrors( errors = RegistrationRequestErrors(
@@ -323,6 +249,7 @@ async def oauth_token(
refresh_token_str = generate_refresh_token() refresh_token_str = generate_refresh_token()
# 存储令牌 # 存储令牌
assert user.id
await store_token( await store_token(
db, db,
user.id, user.id,

View File

@@ -5,12 +5,7 @@ import hashlib
import json import json
from app.calculator import calculate_beatmap_attribute from app.calculator import calculate_beatmap_attribute
from app.database import ( from app.database import Beatmap, BeatmapResp, User
Beatmap,
BeatmapResp,
User as DBUser,
)
from app.database.beatmapset import Beatmapset
from app.dependencies.database import get_db, get_redis from app.dependencies.database import get_db, get_redis
from app.dependencies.fetcher import get_fetcher from app.dependencies.fetcher import get_fetcher
from app.dependencies.user import get_current_user from app.dependencies.user import get_current_user
@@ -27,9 +22,8 @@ from .api_router import router
from fastapi import Depends, HTTPException, Query from fastapi import Depends, HTTPException, Query
from httpx import HTTPError, HTTPStatusError from httpx import HTTPError, HTTPStatusError
from pydantic import BaseModel from pydantic import BaseModel
from redis import Redis from redis.asyncio import Redis
import rosu_pp_py as rosu import rosu_pp_py as rosu
from sqlalchemy.orm import joinedload
from sqlmodel import col, select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -39,7 +33,7 @@ async def lookup_beatmap(
id: int | None = Query(default=None, alias="id"), id: int | None = Query(default=None, alias="id"),
md5: str | None = Query(default=None, alias="checksum"), md5: str | None = Query(default=None, alias="checksum"),
filename: str | None = Query(default=None, alias="filename"), filename: str | None = Query(default=None, alias="filename"),
current_user: DBUser = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher = Depends(get_fetcher),
): ):
@@ -56,19 +50,19 @@ async def lookup_beatmap(
if beatmap is None: if beatmap is None:
raise HTTPException(status_code=404, detail="Beatmap not found") raise HTTPException(status_code=404, detail="Beatmap not found")
return BeatmapResp.from_db(beatmap) return await BeatmapResp.from_db(beatmap, session=db, user=current_user)
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp) @router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
async def get_beatmap( async def get_beatmap(
bid: int, bid: int,
current_user: DBUser = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher = Depends(get_fetcher),
): ):
try: try:
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid) beatmap = await Beatmap.get_or_fetch(db, fetcher, bid)
return BeatmapResp.from_db(beatmap) return await BeatmapResp.from_db(beatmap, session=db, user=current_user)
except HTTPError: except HTTPError:
raise HTTPException(status_code=404, detail="Beatmap not found") raise HTTPException(status_code=404, detail="Beatmap not found")
@@ -81,42 +75,27 @@ class BatchGetResp(BaseModel):
@router.get("/beatmaps/", tags=["beatmap"], response_model=BatchGetResp) @router.get("/beatmaps/", tags=["beatmap"], response_model=BatchGetResp)
async def batch_get_beatmaps( async def batch_get_beatmaps(
b_ids: list[int] = Query(alias="id", default_factory=list), b_ids: list[int] = Query(alias="id", default_factory=list),
current_user: DBUser = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
if not b_ids: if not b_ids:
# select 50 beatmaps by last_updated # select 50 beatmaps by last_updated
beatmaps = ( beatmaps = (
await db.exec( await db.exec(
select(Beatmap) select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
.options(
joinedload(
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
).selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
)
)
.order_by(col(Beatmap.last_updated).desc())
.limit(50)
) )
).all() ).all()
else: else:
beatmaps = ( beatmaps = (
await db.exec( await db.exec(select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50))
select(Beatmap)
.options(
joinedload(
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
).selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
)
)
.where(col(Beatmap.id).in_(b_ids))
.limit(50)
)
).all() ).all()
return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps]) return BatchGetResp(
beatmaps=[
await BeatmapResp.from_db(bm, session=db, user=current_user)
for bm in beatmaps
]
)
@router.post( @router.post(
@@ -126,7 +105,7 @@ async def batch_get_beatmaps(
) )
async def get_beatmap_attributes( async def get_beatmap_attributes(
beatmap: int, beatmap: int,
current_user: DBUser = Depends(get_current_user), current_user: User = Depends(get_current_user),
mods: list[str] = Query(default_factory=list), mods: list[str] = Query(default_factory=list),
ruleset: GameMode | None = Query(default=None), ruleset: GameMode | None = Query(default=None),
ruleset_id: int | None = Query(default=None), ruleset_id: int | None = Query(default=None),
@@ -153,8 +132,8 @@ async def get_beatmap_attributes(
f"beatmap:{beatmap}:{ruleset}:" f"beatmap:{beatmap}:{ruleset}:"
f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes" f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
) )
if redis.exists(key): if await redis.exists(key):
return BeatmapAttributes.model_validate_json(redis.get(key)) # pyright: ignore[reportArgumentType] return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
try: try:
resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap) resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap)
@@ -164,7 +143,7 @@ async def get_beatmap_attributes(
) )
except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue] except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue]
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
redis.set(key, attr.model_dump_json()) await redis.set(key, attr.model_dump_json())
return attr return attr
except HTTPStatusError: except HTTPStatusError:
raise HTTPException(status_code=404, detail="Beatmap not found") raise HTTPException(status_code=404, detail="Beatmap not found")

View File

@@ -1,10 +1,8 @@
from __future__ import annotations from __future__ import annotations
from app.database import ( from typing import Literal
Beatmapset,
BeatmapsetResp, from app.database import Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
User as DBUser,
)
from app.dependencies.database import get_db from app.dependencies.database import get_db
from app.dependencies.fetcher import get_fetcher from app.dependencies.fetcher import get_fetcher
from app.dependencies.user import get_current_user from app.dependencies.user import get_current_user
@@ -12,9 +10,9 @@ from app.fetcher import Fetcher
from .api_router import router from .api_router import router
from fastapi import Depends, HTTPException from fastapi import Depends, Form, HTTPException, Query
from fastapi.responses import RedirectResponse
from httpx import HTTPStatusError from httpx import HTTPStatusError
from sqlalchemy.orm import selectinload
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -22,17 +20,11 @@ from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp) @router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp)
async def get_beatmapset( async def get_beatmapset(
sid: int, sid: int,
current_user: DBUser = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher = Depends(get_fetcher),
): ):
beatmapset = ( beatmapset = (await db.exec(select(Beatmapset).where(Beatmapset.id == sid))).first()
await db.exec(
select(Beatmapset)
.options(selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
.where(Beatmapset.id == sid)
)
).first()
if not beatmapset: if not beatmapset:
try: try:
resp = await fetcher.get_beatmapset(sid) resp = await fetcher.get_beatmapset(sid)
@@ -40,5 +32,55 @@ async def get_beatmapset(
except HTTPStatusError: except HTTPStatusError:
raise HTTPException(status_code=404, detail="Beatmapset not found") raise HTTPException(status_code=404, detail="Beatmapset not found")
else: else:
resp = BeatmapsetResp.from_db(beatmapset) resp = await BeatmapsetResp.from_db(
beatmapset, session=db, include=["recent_favourites"], user=current_user
)
return resp return resp
@router.get("/beatmapsets/{beatmapset}/download", tags=["beatmapset"])
async def download_beatmapset(
beatmapset: int,
no_video: bool = Query(True, alias="noVideo"),
current_user: User = Depends(get_current_user),
):
if current_user.country_code == "CN":
return RedirectResponse(
f"https://txy1.sayobot.cn/beatmaps/download/"
f"{'novideo' if no_video else 'full'}/{beatmapset}?server=auto"
)
else:
return RedirectResponse(
f"https://api.nerinyan.moe/d/{beatmapset}?noVideo={no_video}"
)
@router.post("/beatmapsets/{beatmapset}/favourites", tags=["beatmapset"])
async def favourite_beatmapset(
beatmapset: int,
action: Literal["favourite", "unfavourite"] = Form(),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
existing_favourite = (
await db.exec(
select(FavouriteBeatmapset).where(
FavouriteBeatmapset.user_id == current_user.id,
FavouriteBeatmapset.beatmapset_id == beatmapset,
)
)
).first()
if action == "favourite" and existing_favourite:
raise HTTPException(status_code=400, detail="Already favourited")
elif action == "unfavourite" and not existing_favourite:
raise HTTPException(status_code=400, detail="Not favourited")
if action == "favourite":
favourite = FavouriteBeatmapset(
user_id=current_user.id, beatmapset_id=beatmapset
)
db.add(favourite)
else:
await db.delete(existing_favourite)
await db.commit()

View File

@@ -1,28 +1,27 @@
from __future__ import annotations from __future__ import annotations
from typing import Literal from app.database import User, UserResp
from app.database.lazer_user import ALL_INCLUDED
from app.database import (
User as DBUser,
)
from app.dependencies import get_current_user from app.dependencies import get_current_user
from app.models.user import ( from app.dependencies.database import get_db
User as ApiUser, from app.models.score import GameMode
)
from app.utils import convert_db_user_to_api_user
from .api_router import router from .api_router import router
from fastapi import Depends from fastapi import Depends
from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/me/{ruleset}", response_model=ApiUser) @router.get("/me/{ruleset}", response_model=UserResp)
@router.get("/me/", response_model=ApiUser) @router.get("/me/", response_model=UserResp)
async def get_user_info_default( async def get_user_info_default(
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu", ruleset: GameMode | None = None,
current_user: DBUser = Depends(get_current_user), current_user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_db),
): ):
"""获取当前用户信息默认使用osu模式""" return await UserResp.from_db(
# 默认使用osu模式 current_user,
api_user = await convert_db_user_to_api_user(current_user, ruleset) session,
return api_user ALL_INCLUDED,
ruleset,
)

View File

@@ -8,7 +8,7 @@ from app.dependencies.user import get_current_user
from .api_router import router from .api_router import router
from fastapi import Depends, HTTPException, Query, Request from fastapi import Depends, HTTPException, Query, Request
from sqlalchemy.orm import joinedload from pydantic import BaseModel
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -26,17 +26,19 @@ async def get_relationship(
else RelationshipType.BLOCK else RelationshipType.BLOCK
) )
relationships = await db.exec( relationships = await db.exec(
select(Relationship) select(Relationship).where(
.options(joinedload(Relationship.target).options(*DBUser.all_select_option())) # pyright: ignore[reportArgumentType]
.where(
Relationship.user_id == current_user.id, Relationship.user_id == current_user.id,
Relationship.type == relationship_type, Relationship.type == relationship_type,
) )
) )
return [await RelationshipResp.from_db(db, rel) for rel in relationships] return [await RelationshipResp.from_db(db, rel) for rel in relationships.unique()]
@router.post("/friends", tags=["relationship"], response_model=RelationshipResp) class AddFriendResp(BaseModel):
user_relation: RelationshipResp
@router.post("/friends", tags=["relationship"], response_model=AddFriendResp)
@router.post("/blocks", tags=["relationship"]) @router.post("/blocks", tags=["relationship"])
async def add_relationship( async def add_relationship(
request: Request, request: Request,
@@ -87,14 +89,10 @@ async def add_relationship(
if origin_type == RelationshipType.FOLLOW: if origin_type == RelationshipType.FOLLOW:
relationship = ( relationship = (
await db.exec( await db.exec(
select(Relationship) select(Relationship).where(
.where(
Relationship.user_id == current_user_id, Relationship.user_id == current_user_id,
Relationship.target_id == target, Relationship.target_id == target,
) )
.options(
joinedload(Relationship.target).options(*DBUser.all_select_option()) # pyright: ignore[reportArgumentType]
)
) )
).first() ).first()
assert relationship, "Relationship should exist after commit" assert relationship, "Relationship should exist after commit"

View File

@@ -6,8 +6,10 @@ from app.dependencies.fetcher import get_fetcher
from app.fetcher import Fetcher from app.fetcher import Fetcher
from app.models.room import MultiplayerRoom, MultiplayerRoomState, Room from app.models.room import MultiplayerRoom, MultiplayerRoomState, Room
from api_router import router from .api_router import router
from fastapi import Depends, HTTPException, Query from fastapi import Depends, HTTPException, Query
from redis.asyncio import Redis
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -21,6 +23,7 @@ async def get_all_rooms(
), # TODO: 对房间根据分类进行筛选(真的有人用这功能吗) ), # TODO: 对房间根据分类进行筛选(真的有人用这功能吗)
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher = Depends(get_fetcher),
redis: Redis = Depends(get_redis),
): ):
all_roomID = (await db.exec(select(RoomIndex))).all() all_roomID = (await db.exec(select(RoomIndex))).all()
redis = get_redis() redis = get_redis()

View File

@@ -1,11 +1,7 @@
from __future__ import annotations from __future__ import annotations
from app.database import ( from app.database import Beatmap, Score, ScoreResp, ScoreToken, ScoreTokenResp, User
User as DBUser, from app.database.score import get_leaderboard, process_score, process_user
)
from app.database.beatmap import Beatmap
from app.database.score import Score, ScoreResp, process_score, process_user
from app.database.score_token import ScoreToken, ScoreTokenResp
from app.dependencies.database import get_db, get_redis from app.dependencies.database import get_db, get_redis
from app.dependencies.fetcher import get_fetcher from app.dependencies.fetcher import get_fetcher
from app.dependencies.user import get_current_user from app.dependencies.user import get_current_user
@@ -13,6 +9,7 @@ from app.models.beatmap import BeatmapRankStatus
from app.models.score import ( from app.models.score import (
INT_TO_MODE, INT_TO_MODE,
GameMode, GameMode,
LeaderboardType,
Rank, Rank,
SoloScoreSubmissionInfo, SoloScoreSubmissionInfo,
) )
@@ -21,9 +18,9 @@ from .api_router import router
from fastapi import Depends, Form, HTTPException, Query from fastapi import Depends, Form, HTTPException, Query
from pydantic import BaseModel from pydantic import BaseModel
from redis import Redis from redis.asyncio import Redis
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from sqlmodel import col, select, true from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -37,44 +34,26 @@ class BeatmapScores(BaseModel):
) )
async def get_beatmap_scores( async def get_beatmap_scores(
beatmap: int, beatmap: int,
mode: GameMode,
legacy_only: bool = Query(None), # TODO:加入对这个参数的查询 legacy_only: bool = Query(None), # TODO:加入对这个参数的查询
mode: GameMode | None = Query(None), mods: list[str] = Query(default_factory=set, alias="mods[]"),
# mods: List[APIMod] = Query(None), # TODO:加入指定MOD的查询 type: LeaderboardType = Query(LeaderboardType.GLOBAL),
type: str = Query(None), current_user: User = Depends(get_current_user),
current_user: DBUser = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
limit: int = Query(50, ge=1, le=200),
): ):
if legacy_only: if legacy_only:
raise HTTPException( raise HTTPException(
status_code=404, detail="this server only contains lazer scores" status_code=404, detail="this server only contains lazer scores"
) )
all_scores = ( all_scores, user_score = await get_leaderboard(
await db.exec( db, beatmap, mode, type=type, user=current_user, limit=limit, mods=mods
Score.select_clause_unique( )
Score.beatmap_id == beatmap,
col(Score.passed).is_(True),
Score.gamemode == mode if mode is not None else true(),
)
)
).all()
user_score = (
await db.exec(
Score.select_clause_unique(
Score.beatmap_id == beatmap,
Score.user_id == current_user.id,
col(Score.passed).is_(True),
Score.gamemode == mode if mode is not None else true(),
)
)
).first()
return BeatmapScores( return BeatmapScores(
scores=[await ScoreResp.from_db(db, score, score.user) for score in all_scores], scores=[await ScoreResp.from_db(db, score) for score in all_scores],
userScore=await ScoreResp.from_db(db, user_score, user_score.user) userScore=await ScoreResp.from_db(db, user_score) if user_score else None,
if user_score
else None,
) )
@@ -94,7 +73,7 @@ async def get_user_beatmap_score(
legacy_only: bool = Query(None), legacy_only: bool = Query(None),
mode: str = Query(None), mode: str = Query(None),
mods: str = Query(None), # TODO:添加mods筛选 mods: str = Query(None), # TODO:添加mods筛选
current_user: DBUser = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
if legacy_only: if legacy_only:
@@ -103,7 +82,7 @@ async def get_user_beatmap_score(
) )
user_score = ( user_score = (
await db.exec( await db.exec(
Score.select_clause(True) select(Score)
.where( .where(
Score.gamemode == mode if mode is not None else True, Score.gamemode == mode if mode is not None else True,
Score.beatmap_id == beatmap, Score.beatmap_id == beatmap,
@@ -120,7 +99,7 @@ async def get_user_beatmap_score(
else: else:
return BeatmapUserScore( return BeatmapUserScore(
position=user_score.position if user_score.position is not None else 0, position=user_score.position if user_score.position is not None else 0,
score=await ScoreResp.from_db(db, user_score, user_score.user), score=await ScoreResp.from_db(db, user_score),
) )
@@ -134,7 +113,7 @@ async def get_user_all_beatmap_scores(
user: int, user: int,
legacy_only: bool = Query(None), legacy_only: bool = Query(None),
ruleset: str = Query(None), ruleset: str = Query(None),
current_user: DBUser = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
if legacy_only: if legacy_only:
@@ -143,7 +122,7 @@ async def get_user_all_beatmap_scores(
) )
all_user_scores = ( all_user_scores = (
await db.exec( await db.exec(
Score.select_clause() select(Score)
.where( .where(
Score.gamemode == ruleset if ruleset is not None else True, Score.gamemode == ruleset if ruleset is not None else True,
Score.beatmap_id == beatmap, Score.beatmap_id == beatmap,
@@ -153,9 +132,7 @@ async def get_user_all_beatmap_scores(
) )
).all() ).all()
return [ return [await ScoreResp.from_db(db, score) for score in all_user_scores]
await ScoreResp.from_db(db, score, current_user) for score in all_user_scores
]
@router.post( @router.post(
@@ -166,9 +143,10 @@ async def create_solo_score(
version_hash: str = Form(""), version_hash: str = Form(""),
beatmap_hash: str = Form(), beatmap_hash: str = Form(),
ruleset_id: int = Form(..., ge=0, le=3), ruleset_id: int = Form(..., ge=0, le=3),
current_user: DBUser = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
assert current_user.id
async with db: async with db:
score_token = ScoreToken( score_token = ScoreToken(
user_id=current_user.id, user_id=current_user.id,
@@ -190,7 +168,7 @@ async def submit_solo_score(
beatmap: int, beatmap: int,
token: int, token: int,
info: SoloScoreSubmissionInfo, info: SoloScoreSubmissionInfo,
current_user: DBUser = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
fetcher=Depends(get_fetcher), fetcher=Depends(get_fetcher),
@@ -210,9 +188,7 @@ async def submit_solo_score(
if score_token.score_id: if score_token.score_id:
score = ( score = (
await db.exec( await db.exec(
select(Score) select(Score).where(
.options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType]
.where(
Score.id == score_token.score_id, Score.id == score_token.score_id,
Score.user_id == current_user.id, Score.user_id == current_user.id,
) )
@@ -246,8 +222,6 @@ async def submit_solo_score(
score_id = score.id score_id = score.id
score_token.score_id = score_id score_token.score_id = score_id
await process_user(db, current_user, score, ranked) await process_user(db, current_user, score, ranked)
score = ( score = (await db.exec(select(Score).where(Score.id == score_id))).first()
await db.exec(Score.select_clause().where(Score.id == score_id))
).first()
assert score is not None assert score is not None
return await ScoreResp.from_db(db, score, current_user) return await ScoreResp.from_db(db, score)

View File

@@ -1,12 +1,9 @@
from __future__ import annotations from __future__ import annotations
from typing import Literal from app.database import User, UserResp
from app.database.lazer_user import SEARCH_INCLUDED
from app.database import User as DBUser
from app.dependencies.database import get_db from app.dependencies.database import get_db
from app.models.score import INT_TO_MODE from app.models.score import GameMode
from app.models.user import User as ApiUser
from app.utils import convert_db_user_to_api_user
from .api_router import router from .api_router import router
@@ -17,28 +14,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import col from sqlmodel.sql.expression import col
# ---------- Shared Utility ----------
async def get_user_by_lookup(
db: AsyncSession, lookup: str, key: str = "id"
) -> DBUser | None:
"""根据查找方式获取用户"""
if key == "id":
try:
user_id = int(lookup)
result = await db.exec(select(DBUser).where(DBUser.id == user_id))
return result.first()
except ValueError:
return None
elif key == "username":
result = await db.exec(select(DBUser).where(DBUser.name == lookup))
return result.first()
else:
return None
# ---------- Batch Users ----------
class BatchUserResponse(BaseModel): class BatchUserResponse(BaseModel):
users: list[ApiUser] users: list[UserResp]
@router.get("/users", response_model=BatchUserResponse) @router.get("/users", response_model=BatchUserResponse)
@@ -51,75 +28,44 @@ async def get_users(
): ):
if user_ids: if user_ids:
searched_users = ( searched_users = (
await session.exec( await session.exec(select(User).limit(50).where(col(User.id).in_(user_ids)))
DBUser.all_select_clause().limit(50).where(col(DBUser.id).in_(user_ids))
)
).all() ).all()
else: else:
searched_users = ( searched_users = (await session.exec(select(User).limit(50))).all()
await session.exec(DBUser.all_select_clause().limit(50))
).all()
return BatchUserResponse( return BatchUserResponse(
users=[ users=[
await convert_db_user_to_api_user( await UserResp.from_db(
searched_user, ruleset=INT_TO_MODE[searched_user.preferred_mode].value searched_user,
session,
include=SEARCH_INCLUDED,
) )
for searched_user in searched_users for searched_user in searched_users
] ]
) )
# # ---------- Individual User ---------- @router.get("/users/{user}/{ruleset}", response_model=UserResp)
# @router.get("/users/{user_lookup}/{mode}", response_model=ApiUser) @router.get("/users/{user}/", response_model=UserResp)
# @router.get("/users/{user_lookup}/{mode}/", response_model=ApiUser) @router.get("/users/{user}", response_model=UserResp)
# async def get_user_with_mode(
# user_lookup: str,
# mode: Literal["osu", "taiko", "fruits", "mania"],
# key: Literal["id", "username"] = Query("id"),
# current_user: DBUser = Depends(get_current_user),
# db: AsyncSession = Depends(get_db),
# ):
# """获取指定游戏模式的用户信息"""
# user = await get_user_by_lookup(db, user_lookup, key)
# if not user:
# raise HTTPException(status_code=404, detail="User not found")
# return await convert_db_user_to_api_user(user, mode)
# @router.get("/users/{user_lookup}", response_model=ApiUser)
# @router.get("/users/{user_lookup}/", response_model=ApiUser)
# async def get_user_default(
# user_lookup: str,
# key: Literal["id", "username"] = Query("id"),
# current_user: DBUser = Depends(get_current_user),
# db: AsyncSession = Depends(get_db),
# ):
# """获取用户信息默认使用osu模式但包含所有模式的统计信息"""
# user = await get_user_by_lookup(db, user_lookup, key)
# if not user:
# raise HTTPException(status_code=404, detail="User not found")
# return await convert_db_user_to_api_user(user, "osu")
@router.get("/users/{user}/{ruleset}", response_model=ApiUser)
@router.get("/users/{user}/", response_model=ApiUser)
@router.get("/users/{user}", response_model=ApiUser)
async def get_user_info( async def get_user_info(
user: str, user: str,
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu", ruleset: GameMode | None = None,
session: AsyncSession = Depends(get_db), session: AsyncSession = Depends(get_db),
): ):
searched_user = ( searched_user = (
await session.exec( await session.exec(
DBUser.all_select_clause().where( select(User).where(
DBUser.id == int(user) User.id == int(user)
if user.isdigit() if user.isdigit()
else DBUser.name == user.removeprefix("@") else User.username == user.removeprefix("@")
) )
) )
).first() ).first()
if not searched_user: if not searched_user:
raise HTTPException(404, detail="User not found") raise HTTPException(404, detail="User not found")
return await convert_db_user_to_api_user(searched_user, ruleset=ruleset) return await UserResp.from_db(
searched_user,
session,
include=SEARCH_INCLUDED,
ruleset=ruleset,
)

View File

@@ -2,10 +2,11 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import Coroutine from collections.abc import Coroutine
from datetime import UTC, datetime
from typing import override from typing import override
from app.database.relationship import Relationship, RelationshipType from app.database import Relationship, RelationshipType, User
from app.dependencies.database import engine from app.dependencies.database import engine, get_redis
from app.models.metadata_hub import MetadataClientState, OnlineStatus, UserActivity from app.models.metadata_hub import MetadataClientState, OnlineStatus, UserActivity
from .hub import Client, Hub from .hub import Client, Hub
@@ -54,6 +55,18 @@ class MetadataHub(Hub[MetadataClientState]):
async def _clean_state(self, state: MetadataClientState) -> None: async def _clean_state(self, state: MetadataClientState) -> None:
if state.pushable: if state.pushable:
await asyncio.gather(*self.broadcast_tasks(int(state.connection_id), None)) await asyncio.gather(*self.broadcast_tasks(int(state.connection_id), None))
redis = get_redis()
if await redis.exists(f"metadata:online:{state.connection_id}"):
await redis.delete(f"metadata:online:{state.connection_id}")
async with AsyncSession(engine) as session:
async with session.begin():
user = (
await session.exec(
select(User).where(User.id == int(state.connection_id))
)
).one()
user.last_visit = datetime.now(UTC)
await session.commit()
@override @override
def create_state(self, client: Client) -> MetadataClientState: def create_state(self, client: Client) -> MetadataClientState:
@@ -93,6 +106,8 @@ class MetadataHub(Hub[MetadataClientState]):
) )
) )
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
redis = get_redis()
await redis.set(f"metadata:online:{user_id}", "")
async def UpdateStatus(self, client: Client, status: int) -> None: async def UpdateStatus(self, client: Client, status: int) -> None:
status_ = OnlineStatus(status) status_ = OnlineStatus(status)

View File

@@ -7,10 +7,9 @@ import struct
import time import time
from typing import override from typing import override
from app.database import Beatmap from app.database import Beatmap, User
from app.database.score import Score from app.database.score import Score
from app.database.score_token import ScoreToken from app.database.score_token import ScoreToken
from app.database.user import User
from app.dependencies.database import engine from app.dependencies.database import engine
from app.models.beatmap import BeatmapRankStatus from app.models.beatmap import BeatmapRankStatus
from app.models.mods import mods_to_int from app.models.mods import mods_to_int
@@ -197,7 +196,7 @@ class SpectatorHub(Hub[StoreClientState]):
).first() ).first()
if not user: if not user:
return return
name = user.name name = user.username
store.state = state store.state = state
store.beatmap_status = beatmap.beatmap_status store.beatmap_status = beatmap.beatmap_status
store.checksum = beatmap.checksum store.checksum = beatmap.checksum
@@ -241,65 +240,17 @@ class SpectatorHub(Hub[StoreClientState]):
user_id = int(client.connection_id) user_id = int(client.connection_id)
store = self.get_or_create_state(client) store = self.get_or_create_state(client)
score = store.score score = store.score
assert store.beatmap_status is not None
assert store.state is not None
assert store.score is not None
if not score or not store.score_token: if not score or not store.score_token:
return return
assert store.beatmap_status is not None
async def _save_replay():
assert store.checksum is not None
assert store.ruleset_id is not None
assert store.state is not None
assert store.score is not None
async with AsyncSession(engine) as session:
async with session:
start_time = time.time()
score_record = None
while time.time() - start_time < READ_SCORE_TIMEOUT:
sub_query = select(ScoreToken.score_id).where(
ScoreToken.id == store.score_token,
)
result = await session.exec(
select(Score)
.options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType]
.where(
Score.id == sub_query,
Score.user_id == user_id,
)
)
score_record = result.first()
if score_record:
break
if not score_record:
return
if not score_record.passed:
return
score_record.has_replay = True
await session.commit()
await session.refresh(score_record)
save_replay(
ruleset_id=store.ruleset_id,
md5=store.checksum,
username=store.score.score_info.user.name,
score=score_record,
statistics=score.score_info.statistics,
maximum_statistics=score.score_info.maximum_statistics,
frames=score.replay_frames,
)
if ( if (
( BeatmapRankStatus.PENDING < store.beatmap_status <= BeatmapRankStatus.LOVED
BeatmapRankStatus.PENDING ) and any(
< store.beatmap_status k.is_hit() and v > 0 for k, v in store.score.score_info.statistics.items()
<= BeatmapRankStatus.LOVED
)
and any(
k.is_hit() and v > 0 for k, v in score.score_info.statistics.items()
)
and state.state != SpectatedUserState.Failed
): ):
# save replay await self._process_score(store, client)
await _save_replay()
store.state = None store.state = None
store.beatmap_status = None store.beatmap_status = None
store.checksum = None store.checksum = None
@@ -308,6 +259,56 @@ class SpectatorHub(Hub[StoreClientState]):
store.score = None store.score = None
await self._end_session(user_id, state) await self._end_session(user_id, state)
async def _process_score(self, store: StoreClientState, client: Client) -> None:
user_id = int(client.connection_id)
assert store.state is not None
assert store.score_token is not None
assert store.checksum is not None
assert store.ruleset_id is not None
assert store.score is not None
async with AsyncSession(engine) as session:
async with session:
start_time = time.time()
score_record = None
while time.time() - start_time < READ_SCORE_TIMEOUT:
sub_query = select(ScoreToken.score_id).where(
ScoreToken.id == store.score_token,
)
result = await session.exec(
select(Score)
.options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType]
.where(
Score.id == sub_query,
Score.user_id == user_id,
)
)
score_record = result.first()
if score_record:
break
if not score_record:
return
if not score_record.passed:
return
await self.call_noblock(
client,
"UserScoreProcessed",
user_id,
score_record.id,
)
# save replay
score_record.has_replay = True
await session.commit()
await session.refresh(score_record)
save_replay(
ruleset_id=store.ruleset_id,
md5=store.checksum,
username=store.score.score_info.user.name,
score=score_record,
statistics=store.score.score_info.statistics,
maximum_statistics=store.score.score_info.maximum_statistics,
frames=store.score.replay_frames,
)
async def _end_session(self, user_id: int, state: SpectatorState) -> None: async def _end_session(self, user_id: int, state: SpectatorState) -> None:
if state.state == SpectatedUserState.Playing: if state.state == SpectatedUserState.Playing:
state.state = SpectatedUserState.Quit state.state = SpectatedUserState.Quit
@@ -336,7 +337,7 @@ class SpectatorHub(Hub[StoreClientState]):
async with AsyncSession(engine) as session: async with AsyncSession(engine) as session:
async with session.begin(): async with session.begin():
username = ( username = (
await session.exec(select(User.name).where(User.id == user_id)) await session.exec(select(User.username).where(User.id == user_id))
).first() ).first()
if not username: if not username:
return return

View File

@@ -1,465 +1,6 @@
from __future__ import annotations from __future__ import annotations
from datetime import UTC, datetime
from app.database import (
LazerUserCounts,
LazerUserProfile,
LazerUserStatistics,
User as DBUser,
)
from app.models.user import (
Country,
Cover,
DailyChallengeStats,
GradeCounts,
Kudosu,
Level,
Page,
RankHighest,
RankHistory,
Statistics,
User,
UserAchievement,
)
def unix_timestamp_to_windows(timestamp: int) -> int: def unix_timestamp_to_windows(timestamp: int) -> int:
"""Convert a Unix timestamp to a Windows timestamp.""" """Convert a Unix timestamp to a Windows timestamp."""
return (timestamp + 62135596800) * 10_000_000 return (timestamp + 62135596800) * 10_000_000
async def convert_db_user_to_api_user(db_user: DBUser, ruleset: str = "osu") -> User:
"""将数据库用户模型转换为API用户模型使用 Lazer 表)"""
# 从db_user获取基本字段值
user_id = getattr(db_user, "id")
user_name = getattr(db_user, "name")
user_country = getattr(db_user, "country")
user_country_code = user_country # 在User模型中country字段就是country_code
# 获取 Lazer 用户资料
profile = db_user.lazer_profile
if not profile:
# 如果没有 lazer 资料,使用默认值
profile = LazerUserProfile(
user_id=user_id,
)
# 获取 Lazer 用户计数 - 使用正确的 lazer_counts 关系
lzrcnt = db_user.lazer_counts
if not lzrcnt:
# 如果没有 lazer 计数,使用默认值
lzrcnt = LazerUserCounts(user_id=user_id)
# 获取指定模式的统计信息
user_stats = None
if db_user.lazer_statistics:
for stat in db_user.lazer_statistics:
if stat.mode == ruleset:
user_stats = stat
break
if not user_stats:
# 如果没有找到指定模式的统计,创建默认统计
user_stats = LazerUserStatistics(user_id=user_id)
# 获取国家信息
country_code = db_user.country_code if db_user.country_code is not None else "XX"
country = Country(code=str(country_code), name=get_country_name(str(country_code)))
# 获取 Kudosu 信息
kudosu = Kudosu(available=0, total=0)
# 获取计数信息
# counts = LazerUserCounts(user_id=user_id)
# 转换统计信息
statistics = Statistics(
count_100=user_stats.count_100,
count_300=user_stats.count_300,
count_50=user_stats.count_50,
count_miss=user_stats.count_miss,
level=Level(
current=user_stats.level_current, progress=user_stats.level_progress
),
global_rank=user_stats.global_rank,
global_rank_exp=user_stats.global_rank_exp,
pp=float(user_stats.pp) if user_stats.pp else 0.0,
pp_exp=float(user_stats.pp_exp) if user_stats.pp_exp else 0.0,
ranked_score=user_stats.ranked_score,
hit_accuracy=float(user_stats.hit_accuracy) if user_stats.hit_accuracy else 0.0,
play_count=user_stats.play_count,
play_time=user_stats.play_time,
total_score=user_stats.total_score,
total_hits=user_stats.total_hits,
maximum_combo=user_stats.maximum_combo,
replays_watched_by_others=user_stats.replays_watched_by_others,
is_ranked=user_stats.is_ranked,
grade_counts=GradeCounts(
ss=user_stats.grade_ss,
ssh=user_stats.grade_ssh,
s=user_stats.grade_s,
sh=user_stats.grade_sh,
a=user_stats.grade_a,
),
country_rank=user_stats.country_rank,
rank={"country": user_stats.country_rank} if user_stats.country_rank else None,
)
# 转换所有模式的统计信息
statistics_rulesets = {}
if db_user.lazer_statistics:
for stat in db_user.lazer_statistics:
statistics_rulesets[stat.mode] = Statistics(
count_100=stat.count_100,
count_300=stat.count_300,
count_50=stat.count_50,
count_miss=stat.count_miss,
level=Level(current=stat.level_current, progress=stat.level_progress),
global_rank=stat.global_rank,
global_rank_exp=stat.global_rank_exp,
pp=float(stat.pp) if stat.pp else 0.0,
pp_exp=float(stat.pp_exp) if stat.pp_exp else 0.0,
ranked_score=stat.ranked_score,
hit_accuracy=float(stat.hit_accuracy) if stat.hit_accuracy else 0.0,
play_count=stat.play_count,
play_time=stat.play_time,
total_score=stat.total_score,
total_hits=stat.total_hits,
maximum_combo=stat.maximum_combo,
replays_watched_by_others=stat.replays_watched_by_others,
is_ranked=stat.is_ranked,
grade_counts=GradeCounts(
ss=stat.grade_ss,
ssh=stat.grade_ssh,
s=stat.grade_s,
sh=stat.grade_sh,
a=stat.grade_a,
),
country_rank=stat.country_rank,
rank={"country": stat.country_rank} if stat.country_rank else None,
)
# 转换国家信息
country = Country(code=user_country_code, name=get_country_name(user_country_code))
# 转换封面信息
cover_url = (
profile.cover_url
if profile and profile.cover_url
else "https://assets.ppy.sh/user-profile-covers/default.jpeg"
)
cover = Cover(
custom_url=profile.cover_url if profile else None, url=str(cover_url), id=None
)
# 转换 Kudosu 信息
kudosu = Kudosu(available=0, total=0)
# 转换成就信息
user_achievements = []
if db_user.lazer_achievements:
for achievement in db_user.lazer_achievements:
user_achievements.append(
UserAchievement(
achieved_at=achievement.achieved_at,
achievement_id=achievement.achievement_id,
)
)
# 转换排名历史
rank_history = None
rank_history_data = None
for rh in db_user.rank_history:
if rh.mode == ruleset:
rank_history_data = rh.rank_data
break
if rank_history_data:
rank_history = RankHistory(mode=ruleset, data=rank_history_data)
# 转换每日挑战统计
# daily_challenge_stats = None
# if db_user.daily_challenge_stats:
# dcs = db_user.daily_challenge_stats
# daily_challenge_stats = DailyChallengeStats(
# daily_streak_best=dcs.daily_streak_best,
# daily_streak_current=dcs.daily_streak_current,
# last_update=dcs.last_update,
# last_weekly_streak=dcs.last_weekly_streak,
# playcount=dcs.playcount,
# top_10p_placements=dcs.top_10p_placements,
# top_50p_placements=dcs.top_50p_placements,
# user_id=dcs.user_id,
# weekly_streak_best=dcs.weekly_streak_best,
# weekly_streak_current=dcs.weekly_streak_current,
# )
# 转换最高排名
rank_highest = None
if user_stats.rank_highest:
rank_highest = RankHighest(
rank=user_stats.rank_highest,
updated_at=user_stats.rank_highest_updated_at or datetime.utcnow(),
)
# 转换团队信息
team = None
if db_user.team_membership:
team_member = db_user.team_membership # 假设用户只属于一个团队
team = team_member.team
# 创建用户对象
# 从db_user获取基本字段值
user_id = getattr(db_user, "id")
user_name = getattr(db_user, "name")
user_country = getattr(db_user, "country")
# 获取用户头像URL
avatar_url = None
# 首先检查 profile 中的 avatar_url
if profile and hasattr(profile, "avatar_url") and profile.avatar_url:
avatar_url = str(profile.avatar_url)
# 然后检查是否有关联的头像记录
if avatar_url is None and hasattr(db_user, "avatar") and db_user.avatar is not None:
if db_user.avatar.r2_game_url:
# 优先使用游戏用的头像URL
avatar_url = str(db_user.avatar.r2_game_url)
elif db_user.avatar.r2_original_url:
# 其次使用原始头像URL
avatar_url = str(db_user.avatar.r2_original_url)
# 如果还是没有找到,通过查询获取
# if db_session and avatar_url is None:
# try:
# # 导入UserAvatar模型
# # 尝试查找用户的头像记录
# statement = select(UserAvatar).where(
# UserAvatar.user_id == user_id, UserAvatar.is_active == True
# )
# avatar_record = db_session.exec(statement).first()
# if avatar_record is not None:
# if avatar_record.r2_game_url is not None:
# # 优先使用游戏用的头像URL
# avatar_url = str(avatar_record.r2_game_url)
# elif avatar_record.r2_original_url is not None:
# # 其次使用原始头像URL
# avatar_url = str(avatar_record.r2_original_url)
# except Exception as e:
# print(f"获取用户头像时出错: {e}")
# print(f"最终头像URL: {avatar_url}")
# 如果仍然没有找到头像URL则使用默认URL
if avatar_url is None:
avatar_url = "https://a.gu-osu.gmoe.cc/api/users/avatar/1"
# 处理 profile_order 列表排序
profile_order = [
"me",
"recent_activity",
"top_ranks",
"medals",
"historical",
"beatmaps",
"kudosu",
]
if profile and profile.profile_order:
profile_order = profile.profile_order.split(",")
# 在convert_db_user_to_api_user函数中添加active_tournament_banners处理
active_tournament_banners = []
if db_user.active_banners:
for banner in db_user.active_banners:
active_tournament_banners.append(
{
"tournament_id": banner.tournament_id,
"image_url": banner.image_url,
"is_active": banner.is_active,
}
)
# 在convert_db_user_to_api_user函数中添加badges处理
badges = []
if db_user.lazer_badges:
for badge in db_user.lazer_badges:
badges.append(
{
"badge_id": badge.badge_id,
"awarded_at": badge.awarded_at,
"description": badge.description,
"image_url": badge.image_url,
"url": badge.url,
}
)
# 在convert_db_user_to_api_user函数中添加monthly_playcounts处理
monthly_playcounts = []
if db_user.lazer_monthly_playcounts:
for playcount in db_user.lazer_monthly_playcounts:
monthly_playcounts.append(
{
"start_date": playcount.start_date.isoformat()
if playcount.start_date
else None,
"play_count": playcount.play_count,
}
)
# 在convert_db_user_to_api_user函数中添加previous_usernames处理
previous_usernames = []
if db_user.lazer_previous_usernames:
for username in db_user.lazer_previous_usernames:
previous_usernames.append(
{
"username": username.username,
"changed_at": username.changed_at.isoformat()
if username.changed_at
else None,
}
)
# 在convert_db_user_to_api_user函数中添加replays_watched_counts处理
replays_watched_counts = []
if hasattr(db_user, "lazer_replays_watched") and db_user.lazer_replays_watched:
for replay in db_user.lazer_replays_watched:
replays_watched_counts.append(
{
"start_date": replay.start_date.isoformat()
if replay.start_date
else None,
"count": replay.count,
}
)
# 创建用户对象
user = User(
id=user_id,
username=user_name,
avatar_url=avatar_url,
country_code=str(country_code),
default_group=profile.default_group if profile else "default",
is_active=profile.is_active,
is_bot=profile.is_bot,
is_deleted=profile.is_deleted,
is_online=profile.is_online,
is_supporter=profile.is_supporter,
is_restricted=profile.is_restricted,
last_visit=db_user.last_visit,
pm_friends_only=profile.pm_friends_only,
profile_colour=profile.profile_colour,
cover_url=profile.cover_url
if profile and profile.cover_url
else "https://assets.ppy.sh/user-profile-covers/default.jpeg",
discord=profile.discord if profile else None,
has_supported=profile.has_supported if profile else False,
interests=profile.interests if profile else None,
join_date=profile.join_date if profile.join_date else datetime.now(UTC),
location=profile.location if profile else None,
max_blocks=profile.max_blocks if profile and profile.max_blocks else 100,
max_friends=profile.max_friends if profile and profile.max_friends else 500,
post_count=profile.post_count if profile and profile.post_count else 0,
profile_hue=profile.profile_hue if profile and profile.profile_hue else None,
profile_order=profile_order, # 使用排序后的 profile_order
title=profile.title if profile else None,
title_url=profile.title_url if profile else None,
twitter=profile.twitter if profile else None,
website=profile.website if profile else None,
session_verified=True,
support_level=profile.support_level if profile else 0,
country=country,
cover=cover,
kudosu=kudosu,
statistics=statistics,
statistics_rulesets=statistics_rulesets,
beatmap_playcounts_count=lzrcnt.beatmap_playcounts_count if lzrcnt else 0,
comments_count=lzrcnt.comments_count if lzrcnt else 0,
favourite_beatmapset_count=lzrcnt.favourite_beatmapset_count if lzrcnt else 0,
follower_count=lzrcnt.follower_count if lzrcnt else 0,
graveyard_beatmapset_count=lzrcnt.graveyard_beatmapset_count if lzrcnt else 0,
guest_beatmapset_count=lzrcnt.guest_beatmapset_count if lzrcnt else 0,
loved_beatmapset_count=lzrcnt.loved_beatmapset_count if lzrcnt else 0,
mapping_follower_count=lzrcnt.mapping_follower_count if lzrcnt else 0,
nominated_beatmapset_count=lzrcnt.nominated_beatmapset_count if lzrcnt else 0,
pending_beatmapset_count=lzrcnt.pending_beatmapset_count if lzrcnt else 0,
ranked_beatmapset_count=lzrcnt.ranked_beatmapset_count if lzrcnt else 0,
ranked_and_approved_beatmapset_count=lzrcnt.ranked_and_approved_beatmapset_count
if lzrcnt
else 0,
unranked_beatmapset_count=lzrcnt.unranked_beatmapset_count if lzrcnt else 0,
scores_best_count=lzrcnt.scores_best_count if lzrcnt else 0,
scores_first_count=lzrcnt.scores_first_count if lzrcnt else 0,
scores_pinned_count=lzrcnt.scores_pinned_count,
scores_recent_count=lzrcnt.scores_recent_count if lzrcnt else 0,
account_history=[], # TODO: 获取用户历史账户信息
# active_tournament_banner=len(active_tournament_banners),
active_tournament_banners=active_tournament_banners,
badges=badges,
current_season_stats=None,
daily_challenge_user_stats=DailyChallengeStats(
user_id=user_id,
daily_streak_best=db_user.daily_challenge_stats.daily_streak_best
if db_user.daily_challenge_stats
else 0,
daily_streak_current=db_user.daily_challenge_stats.daily_streak_current
if db_user.daily_challenge_stats
else 0,
last_update=db_user.daily_challenge_stats.last_update
if db_user.daily_challenge_stats
else None,
last_weekly_streak=db_user.daily_challenge_stats.last_weekly_streak
if db_user.daily_challenge_stats
else None,
playcount=db_user.daily_challenge_stats.playcount
if db_user.daily_challenge_stats
else 0,
top_10p_placements=db_user.daily_challenge_stats.top_10p_placements
if db_user.daily_challenge_stats
else 0,
top_50p_placements=db_user.daily_challenge_stats.top_50p_placements
if db_user.daily_challenge_stats
else 0,
weekly_streak_best=db_user.daily_challenge_stats.weekly_streak_best
if db_user.daily_challenge_stats
else 0,
weekly_streak_current=db_user.daily_challenge_stats.weekly_streak_current
if db_user.daily_challenge_stats
else 0,
),
groups=[],
monthly_playcounts=monthly_playcounts,
page=Page(html=profile.page_html or "", raw=profile.page_raw or "")
if profile.page_html or profile.page_raw
else Page(),
previous_usernames=previous_usernames,
rank_highest=rank_highest,
rank_history=rank_history,
rankHistory=rank_history,
replays_watched_counts=replays_watched_counts,
team=team,
user_achievements=user_achievements,
)
return user
def get_country_name(country_code: str) -> str:
"""根据国家代码获取国家名称"""
country_names = {
"CN": "China",
"JP": "Japan",
"US": "United States",
"GB": "United Kingdom",
"DE": "Germany",
"FR": "France",
"KR": "South Korea",
"CA": "Canada",
"AU": "Australia",
"BR": "Brazil",
# 可以添加更多国家
}
return country_names.get(country_code, "Unknown")

107
main.py
View File

@@ -4,25 +4,22 @@ from contextlib import asynccontextmanager
from datetime import datetime from datetime import datetime
from app.config import settings from app.config import settings
from app.database import Team # noqa: F401 from app.dependencies.database import create_tables, engine, redis_client
from app.dependencies.database import create_tables, engine
from app.dependencies.fetcher import get_fetcher from app.dependencies.fetcher import get_fetcher
from app.models.user import User
from app.router import api_router, auth_router, fetcher_router, signalr_router from app.router import api_router, auth_router, fetcher_router, signalr_router
from fastapi import FastAPI from fastapi import FastAPI
User.model_rebuild()
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
# on startup # on startup
await create_tables() await create_tables()
get_fetcher() # 初始化 fetcher await get_fetcher() # 初始化 fetcher
# on shutdown # on shutdown
yield yield
await engine.dispose() await engine.dispose()
await redis_client.aclose()
app = FastAPI(title="osu! API 模拟服务器", version="1.0.0", lifespan=lifespan) app = FastAPI(title="osu! API 模拟服务器", version="1.0.0", lifespan=lifespan)
@@ -44,104 +41,6 @@ async def health_check():
return {"status": "ok", "timestamp": datetime.utcnow().isoformat()} return {"status": "ok", "timestamp": datetime.utcnow().isoformat()}
# @app.get("/api/v2/friends")
# async def get_friends():
# return JSONResponse(
# content=[
# {
# "id": 123456,
# "username": "BestFriend",
# "is_online": True,
# "is_supporter": False,
# "country": {"code": "US", "name": "United States"},
# }
# ]
# )
# @app.get("/api/v2/notifications")
# async def get_notifications():
# return JSONResponse(content={"notifications": [], "unread_count": 0})
# @app.post("/api/v2/chat/ack")
# async def chat_ack():
# return JSONResponse(content={"status": "ok"})
# @app.get("/api/v2/users/{user_id}/{mode}")
# async def get_user_mode(user_id: int, mode: str):
# return JSONResponse(
# content={
# "id": user_id,
# "username": "测试测试测",
# "statistics": {
# "level": {"current": 97, "progress": 96},
# "pp": 114514,
# "global_rank": 666,
# "country_rank": 1,
# "hit_accuracy": 100,
# },
# "country": {"code": "JP", "name": "Japan"},
# }
# )
# @app.get("/api/v2/me")
# async def get_me():
# return JSONResponse(
# content={
# "id": 15651670,
# "username": "Googujiang",
# "is_online": True,
# "country": {"code": "JP", "name": "Japan"},
# "statistics": {
# "level": {"current": 97, "progress": 96},
# "pp": 2826.26,
# "global_rank": 298026,
# "country_rank": 11220,
# "hit_accuracy": 95.7168,
# },
# }
# )
# @app.post("/signalr/metadata/negotiate")
# async def metadata_negotiate(negotiateVersion: int = 1):
# return JSONResponse(
# content={
# "connectionId": "abc123",
# "availableTransports": [
# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]}
# ],
# }
# )
# @app.post("/signalr/spectator/negotiate")
# async def spectator_negotiate(negotiateVersion: int = 1):
# return JSONResponse(
# content={
# "connectionId": "spec456",
# "availableTransports": [
# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]}
# ],
# }
# )
# @app.post("/signalr/multiplayer/negotiate")
# async def multiplayer_negotiate(negotiateVersion: int = 1):
# return JSONResponse(
# content={
# "connectionId": "multi789",
# "availableTransports": [
# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]}
# ],
# }
# )
if __name__ == "__main__": if __name__ == "__main__":
from app.log import logger # noqa: F401 from app.log import logger # noqa: F401

View File

@@ -1,8 +1,8 @@
"""score: remove best_id in database """beatmapset: support favourite count
Revision ID: 78be13c71791 Revision ID: 1178d0758ebf
Revises: dc4d25c428c7 Revises:
Create Date: 2025-07-29 07:57:33.764517 Create Date: 2025-08-01 04:05:09.882800
""" """
@@ -15,8 +15,8 @@ import sqlalchemy as sa
from sqlalchemy.dialects import mysql from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = "78be13c71791" revision: str = "1178d0758ebf"
down_revision: str | Sequence[str] | None = "dc4d25c428c7" down_revision: str | Sequence[str] | None = None
branch_labels: str | Sequence[str] | None = None branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None
@@ -24,7 +24,7 @@ depends_on: str | Sequence[str] | None = None
def upgrade() -> None: def upgrade() -> None:
"""Upgrade schema.""" """Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.drop_column("scores", "best_id") op.drop_column("beatmapsets", "favourite_count")
# ### end Alembic commands ### # ### end Alembic commands ###
@@ -32,7 +32,9 @@ def downgrade() -> None:
"""Downgrade schema.""" """Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.add_column( op.add_column(
"scores", "beatmapsets",
sa.Column("best_id", mysql.INTEGER(), autoincrement=False, nullable=True), sa.Column(
"favourite_count", mysql.INTEGER(), autoincrement=False, nullable=False
),
) )
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@@ -0,0 +1,54 @@
"""relationship: fix unique relationship
Revision ID: 58a11441d302
Revises: 1178d0758ebf
Create Date: 2025-08-01 04:23:02.498166
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
revision: str = "58a11441d302"
down_revision: str | Sequence[str] | None = "1178d0758ebf"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"relationship",
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
)
op.drop_constraint("PRIMARY", "relationship", type_="primary")
op.create_primary_key("pk_relationship", "relationship", ["id"])
op.alter_column(
"relationship", "user_id", existing_type=mysql.BIGINT(), nullable=True
)
op.alter_column(
"relationship", "target_id", existing_type=mysql.BIGINT(), nullable=True
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint("pk_relationship", "relationship", type_="primary")
op.create_primary_key("PRIMARY", "relationship", ["user_id", "target_id"])
op.alter_column(
"relationship", "target_id", existing_type=mysql.BIGINT(), nullable=False
)
op.alter_column(
"relationship", "user_id", existing_type=mysql.BIGINT(), nullable=False
)
op.drop_column("relationship", "id")
# ### end Alembic commands ###

View File

@@ -1,36 +0,0 @@
"""score: add nlarge_tick_hit & nsmall_tick_hit for pp calculator
Revision ID: dc4d25c428c7
Revises:
Create Date: 2025-07-29 01:43:40.221070
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "dc4d25c428c7"
down_revision: str | Sequence[str] | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("scores", sa.Column("nlarge_tick_hit", sa.Integer(), nullable=True))
op.add_column("scores", sa.Column("nsmall_tick_hit", sa.Integer(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("scores", "nsmall_tick_hit")
op.drop_column("scores", "nlarge_tick_hit")
# ### end Alembic commands ###