chore(merge): merge branch 'main' into feat/multiplayer-api

This commit is contained in:
MingxuanGame
2025-08-01 05:24:12 +00:00
39 changed files with 971 additions and 2191 deletions

View File

@@ -8,7 +8,7 @@ import string
from app.config import settings
from app.database import (
OAuthToken,
User as DBUser,
User,
)
from app.log import logger
@@ -74,7 +74,7 @@ def get_password_hash(password: str) -> str:
async def authenticate_user_legacy(
db: AsyncSession, name: str, password: str
) -> DBUser | None:
) -> User | None:
"""
验证用户身份 - 使用类似 from_login 的逻辑
"""
@@ -82,7 +82,7 @@ async def authenticate_user_legacy(
pw_md5 = hashlib.md5(password.encode()).hexdigest()
# 2. 根据用户名查找用户
statement = select(DBUser).where(DBUser.name == name)
statement = select(User).where(User.username == name)
user = (await db.exec(statement)).first()
if not user:
return None
@@ -113,7 +113,7 @@ async def authenticate_user_legacy(
async def authenticate_user(
db: AsyncSession, username: str, password: str
) -> DBUser | None:
) -> User | None:
"""验证用户身份"""
return await authenticate_user_legacy(db, username, password)

View File

@@ -1,3 +1,4 @@
from .achievement import UserAchievement, UserAchievementResp
from .auth import OAuthToken
from .beatmap import (
Beatmap as Beatmap,
@@ -8,7 +9,13 @@ from .beatmapset import (
BeatmapsetResp as BeatmapsetResp,
)
from .best_score import BestScore
from .legacy import LegacyOAuthToken, LegacyUserStatistics
from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp
from .favourite_beatmapset import FavouriteBeatmapset
from .lazer_user import (
User,
UserResp,
)
from .pp_best_score import PPBestScore
from .relationship import Relationship, RelationshipResp, RelationshipType
from .score import (
Score,
@@ -17,52 +24,27 @@ from .score import (
ScoreStatistics,
)
from .score_token import ScoreToken, ScoreTokenResp
from .statistics import (
UserStatistics,
UserStatisticsResp,
)
from .team import Team, TeamMember
from .user import (
DailyChallengeStats,
LazerUserAchievement,
LazerUserBadge,
LazerUserBanners,
LazerUserCountry,
LazerUserCounts,
LazerUserKudosu,
LazerUserMonthlyPlaycounts,
LazerUserPreviousUsername,
LazerUserProfile,
LazerUserProfileSections,
LazerUserReplaysWatched,
LazerUserStatistics,
RankHistory,
User,
UserAchievement,
UserAvatar,
from .user_account_history import (
UserAccountHistory,
UserAccountHistoryResp,
UserAccountHistoryType,
)
BeatmapsetResp.model_rebuild()
BeatmapResp.model_rebuild()
__all__ = [
"Beatmap",
"BeatmapResp",
"Beatmapset",
"BeatmapsetResp",
"BestScore",
"DailyChallengeStats",
"LazerUserAchievement",
"LazerUserBadge",
"LazerUserBanners",
"LazerUserCountry",
"LazerUserCounts",
"LazerUserKudosu",
"LazerUserMonthlyPlaycounts",
"LazerUserPreviousUsername",
"LazerUserProfile",
"LazerUserProfileSections",
"LazerUserReplaysWatched",
"LazerUserStatistics",
"LegacyOAuthToken",
"LegacyUserStatistics",
"DailyChallengeStatsResp",
"FavouriteBeatmapset",
"OAuthToken",
"RankHistory",
"PPBestScore",
"Relationship",
"RelationshipResp",
"RelationshipType",
@@ -75,6 +57,17 @@ __all__ = [
"Team",
"TeamMember",
"User",
"UserAccountHistory",
"UserAccountHistoryResp",
"UserAccountHistoryType",
"UserAchievement",
"UserAvatar",
"UserAchievement",
"UserAchievementResp",
"UserResp",
"UserStatistics",
"UserStatisticsResp",
]
for i in __all__:
if i.endswith("Resp"):
globals()[i].model_rebuild() # type: ignore[call-arg]

View File

@@ -1,19 +1,21 @@
from datetime import datetime
from typing import TYPE_CHECKING
from app.models.model import UTCBaseModel
from sqlalchemy import Column, DateTime
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
if TYPE_CHECKING:
from .user import User
from .lazer_user import User
class OAuthToken(SQLModel, table=True):
class OAuthToken(UTCBaseModel, SQLModel, table=True):
__tablename__ = "oauth_tokens" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("users.id"), index=True)
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
access_token: str = Field(max_length=500, unique=True)
refresh_token: str = Field(max_length=500, unique=True)

View File

@@ -7,13 +7,14 @@ from app.models.score import MODE_TO_INT, GameMode
from .beatmapset import Beatmapset, BeatmapsetResp
from sqlalchemy import DECIMAL, Column, DateTime
from sqlalchemy.orm import joinedload
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.fetcher import Fetcher
from .lazer_user import User
class BeatmapOwner(SQLModel):
id: int
@@ -66,7 +67,9 @@ class Beatmap(BeatmapBase, table=True):
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
beatmap_status: BeatmapRankStatus
# optional
beatmapset: Beatmapset = Relationship(back_populates="beatmaps")
beatmapset: Beatmapset = Relationship(
back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"}
)
@property
def can_ranked(self) -> bool:
@@ -87,13 +90,7 @@ class Beatmap(BeatmapBase, table=True):
session.add(beatmap)
await session.commit()
beatmap = (
await session.exec(
select(Beatmap)
.options(
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
)
.where(Beatmap.id == resp.id)
)
await session.exec(select(Beatmap).where(Beatmap.id == resp.id))
).first()
assert beatmap is not None, "Beatmap should not be None after commit"
return beatmap
@@ -131,13 +128,9 @@ class Beatmap(BeatmapBase, table=True):
) -> "Beatmap":
beatmap = (
await session.exec(
select(Beatmap)
.where(
select(Beatmap).where(
Beatmap.id == bid if bid is not None else Beatmap.checksum == md5
)
.options(
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
)
)
).first()
if not beatmap:
@@ -164,11 +157,13 @@ class BeatmapResp(BeatmapBase):
url: str = ""
@classmethod
def from_db(
async def from_db(
cls,
beatmap: Beatmap,
query_mode: GameMode | None = None,
from_set: bool = False,
session: AsyncSession | None = None,
user: "User | None" = None,
) -> "BeatmapResp":
beatmap_ = beatmap.model_dump()
if query_mode is not None and beatmap.mode != query_mode:
@@ -178,5 +173,7 @@ class BeatmapResp(BeatmapBase):
beatmap_["ranked"] = beatmap.beatmap_status.value
beatmap_["mode_int"] = MODE_TO_INT[beatmap.mode]
if not from_set:
beatmap_["beatmapset"] = BeatmapsetResp.from_db(beatmap.beatmapset)
beatmap_["beatmapset"] = await BeatmapsetResp.from_db(
beatmap.beatmapset, session=session, user=user
)
return cls.model_validate(beatmap_)

View File

@@ -4,13 +4,17 @@ from typing import TYPE_CHECKING, TypedDict, cast
from app.models.beatmap import BeatmapRankStatus, Genre, Language
from app.models.score import GameMode
from .lazer_user import BASE_INCLUDES, User, UserResp
from pydantic import BaseModel, model_serializer
from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text
from sqlmodel import Field, Relationship, SQLModel
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import Field, Relationship, SQLModel, col, func, select
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from .beatmap import Beatmap, BeatmapResp
from .favourite_beatmapset import FavouriteBeatmapset
class BeatmapCovers(SQLModel):
@@ -88,7 +92,6 @@ class BeatmapsetBase(SQLModel):
artist_unicode: str = Field(index=True)
covers: BeatmapCovers | None = Field(sa_column=Column(JSON))
creator: str
favourite_count: int
nsfw: bool = Field(default=False)
play_count: int
preview_url: str
@@ -112,11 +115,9 @@ class BeatmapsetBase(SQLModel):
pack_tags: list[str] = Field(default=[], sa_column=Column(JSON))
ratings: list[int] = Field(default=None, sa_column=Column(JSON))
# TODO: recent_favourites: Optional[list[User]] = None
# TODO: related_users: Optional[list[User]] = None
# TODO: user: Optional[User] = Field(default=None)
track_id: int | None = Field(default=None) # feature artist?
# TODO: has_favourited
# BeatmapsetExtended
bpm: float = Field(default=0.0, sa_column=Column(DECIMAL(10, 2)))
@@ -129,7 +130,7 @@ class BeatmapsetBase(SQLModel):
tags: str = Field(default="", sa_column=Column(Text))
class Beatmapset(BeatmapsetBase, table=True):
class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
__tablename__ = "beatmapsets" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
@@ -150,6 +151,7 @@ class Beatmapset(BeatmapsetBase, table=True):
hype_required: int = Field(default=0)
availability_info: str | None = Field(default=None)
download_disabled: bool = Field(default=False)
favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset")
@classmethod
async def from_resp(
@@ -197,40 +199,88 @@ class BeatmapsetResp(BeatmapsetBase):
genre: BeatmapTranslationText | None = None
language: BeatmapTranslationText | None = None
nominations: BeatmapNominations | None = None
has_favourited: bool = False
favourite_count: int = 0
recent_favourites: list[UserResp] = Field(default_factory=list)
@classmethod
def from_db(cls, beatmapset: Beatmapset) -> "BeatmapsetResp":
async def from_db(
cls,
beatmapset: Beatmapset,
include: list[str] = [],
session: AsyncSession | None = None,
user: User | None = None,
) -> "BeatmapsetResp":
from .beatmap import BeatmapResp
from .favourite_beatmapset import FavouriteBeatmapset
beatmaps = [
BeatmapResp.from_db(beatmap, from_set=True)
for beatmap in beatmapset.beatmaps
]
update = {
"beatmaps": [
await BeatmapResp.from_db(beatmap, from_set=True)
for beatmap in await beatmapset.awaitable_attrs.beatmaps
],
"hype": BeatmapHype(
current=beatmapset.hype_current, required=beatmapset.hype_required
),
"availability": BeatmapAvailability(
more_information=beatmapset.availability_info,
download_disabled=beatmapset.download_disabled,
),
"genre": BeatmapTranslationText(
name=beatmapset.beatmap_genre.name,
id=beatmapset.beatmap_genre.value,
),
"language": BeatmapTranslationText(
name=beatmapset.beatmap_language.name,
id=beatmapset.beatmap_language.value,
),
"nominations": BeatmapNominations(
required=beatmapset.nominations_required,
current=beatmapset.nominations_current,
),
"status": beatmapset.beatmap_status.name.lower(),
"ranked": beatmapset.beatmap_status.value,
"is_scoreable": beatmapset.beatmap_status > BeatmapRankStatus.PENDING,
**beatmapset.model_dump(),
}
if session and user:
existing_favourite = (
await session.exec(
select(FavouriteBeatmapset).where(
FavouriteBeatmapset.beatmapset_id == beatmapset.id
)
)
).first()
update["has_favourited"] = existing_favourite is not None
if session and "recent_favourites" in include:
recent_favourites = (
await session.exec(
select(FavouriteBeatmapset)
.where(
FavouriteBeatmapset.beatmapset_id == beatmapset.id,
)
.order_by(col(FavouriteBeatmapset.date).desc())
.limit(50)
)
).all()
update["recent_favourites"] = [
await UserResp.from_db(
await favourite.awaitable_attrs.user,
session=session,
include=BASE_INCLUDES,
)
for favourite in recent_favourites
]
if session:
update["favourite_count"] = (
await session.exec(
select(func.count())
.select_from(FavouriteBeatmapset)
.where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
)
).one()
return cls.model_validate(
{
"beatmaps": beatmaps,
"hype": BeatmapHype(
current=beatmapset.hype_current, required=beatmapset.hype_required
),
"availability": BeatmapAvailability(
more_information=beatmapset.availability_info,
download_disabled=beatmapset.download_disabled,
),
"genre": BeatmapTranslationText(
name=beatmapset.beatmap_genre.name,
id=beatmapset.beatmap_genre.value,
),
"language": BeatmapTranslationText(
name=beatmapset.beatmap_language.name,
id=beatmapset.beatmap_language.value,
),
"nominations": BeatmapNominations(
required=beatmapset.nominations_required,
current=beatmapset.nominations_current,
),
"status": beatmapset.beatmap_status.name.lower(),
"ranked": beatmapset.beatmap_status.value,
"is_scoreable": beatmapset.beatmap_status > BeatmapRankStatus.PENDING,
**beatmapset.model_dump(),
}
update,
)

View File

@@ -1,14 +1,14 @@
from typing import TYPE_CHECKING
from app.models.score import GameMode
from app.models.score import GameMode, Rank
from .user import User
from .lazer_user import User
from sqlmodel import (
JSON,
BigInteger,
Column,
Field,
Float,
ForeignKey,
Relationship,
SQLModel,
@@ -20,22 +20,29 @@ if TYPE_CHECKING:
class BestScore(SQLModel, table=True):
__tablename__ = "best_scores" # pyright: ignore[reportAssignmentType]
__tablename__ = "total_score_best_scores" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("users.id"), index=True)
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
score_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
)
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
gamemode: GameMode = Field(index=True)
pp: float = Field(
sa_column=Column(Float, default=0),
total_score: int = Field(
default=0, sa_column=Column(BigInteger, ForeignKey("scores.total_score"))
)
acc: float = Field(
sa_column=Column(Float, default=0),
mods: list[str] = Field(
default_factory=list,
sa_column=Column(JSON),
)
rank: Rank
user: User = Relationship()
score: "Score" = Relationship()
score: "Score" = Relationship(
sa_relationship_kwargs={
"foreign_keys": "[BestScore.score_id]",
"lazy": "joined",
}
)
beatmap: "Beatmap" = Relationship()

View File

@@ -0,0 +1,53 @@
import datetime
from app.database.beatmapset import Beatmapset
from app.database.lazer_user import User
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import (
BigInteger,
Column,
DateTime,
Field,
ForeignKey,
Relationship,
SQLModel,
)
class FavouriteBeatmapset(AsyncAttrs, SQLModel, table=True):
__tablename__ = "favourite_beatmapset" # pyright: ignore[reportAssignmentType]
id: int | None = Field(
default=None,
sa_column=Column(BigInteger, autoincrement=True, primary_key=True),
exclude=True,
)
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("lazer_users.id"),
index=True,
),
)
beatmapset_id: int = Field(
default=None,
sa_column=Column(
ForeignKey("beatmapsets.id"),
index=True,
),
)
date: datetime.datetime = Field(
default=datetime.datetime.now(datetime.UTC),
sa_column=Column(
DateTime,
),
)
user: User = Relationship(back_populates="favourite_beatmapsets")
beatmapset: Beatmapset = Relationship(
sa_relationship_kwargs={
"lazy": "selectin",
},
back_populates="favourites",
)

View File

@@ -12,7 +12,7 @@ from .statistics import UserStatistics, UserStatisticsResp
from .team import Team, TeamMember
from .user_account_history import UserAccountHistory, UserAccountHistoryResp
from sqlalchemy.orm import joinedload, selectinload
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import (
JSON,
BigInteger,
@@ -27,7 +27,8 @@ from sqlmodel import (
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.database.relationship import RelationshipResp
from .favourite_beatmapset import FavouriteBeatmapset
from .relationship import RelationshipResp
class Kudosu(TypedDict):
@@ -128,7 +129,7 @@ class UserBase(UTCBaseModel, SQLModel):
is_bng: bool = False
class User(UserBase, table=True):
class User(AsyncAttrs, UserBase, table=True):
__tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType]
id: int | None = Field(
@@ -143,6 +144,9 @@ class User(UserBase, table=True):
back_populates="user"
)
monthly_playcounts: list[MonthlyPlaycounts] = Relationship(back_populates="user")
favourite_beatmapsets: list["FavouriteBeatmapset"] = Relationship(
back_populates="user"
)
email: str = Field(max_length=254, unique=True, index=True, exclude=True)
priv: int = Field(default=1, exclude=True)
@@ -154,21 +158,10 @@ class User(UserBase, table=True):
default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
)
@classmethod
def all_select_option(cls):
return (
selectinload(cls.account_history), # pyright: ignore[reportArgumentType]
selectinload(cls.statistics), # pyright: ignore[reportArgumentType]
selectinload(cls.achievement), # pyright: ignore[reportArgumentType]
joinedload(cls.team_membership).joinedload(TeamMember.team), # pyright: ignore[reportArgumentType]
joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType]
selectinload(cls.monthly_playcounts), # pyright: ignore[reportArgumentType]
)
class UserResp(UserBase):
id: int | None = None
is_online: bool = True # TODO
is_online: bool = False
groups: list = [] # TODO
country: Country = Field(default_factory=lambda: Country(code="CN", name="China"))
favourite_beatmapset_count: int = 0 # TODO
@@ -211,6 +204,8 @@ class UserResp(UserBase):
include: list[str] = [],
ruleset: GameMode | None = None,
) -> "UserResp":
from app.dependencies.database import get_redis
from .best_score import BestScore
from .relationship import Relationship, RelationshipResp, RelationshipType
@@ -236,6 +231,8 @@ class UserResp(UserBase):
.limit(200)
)
).one()
redis = get_redis()
u.is_online = await redis.exists(f"metadata:online:{obj.id}")
u.cover_url = (
obj.cover.get(
"url", "https://assets.ppy.sh/user-profile-covers/default.jpeg"
@@ -249,13 +246,7 @@ class UserResp(UserBase):
await RelationshipResp.from_db(session, r)
for r in (
await session.exec(
select(Relationship)
.options(
joinedload(Relationship.target).options( # pyright: ignore[reportArgumentType]
*User.all_select_option()
)
)
.where(
select(Relationship).where(
Relationship.user_id == obj.id,
Relationship.type == RelationshipType.FOLLOW,
)
@@ -264,23 +255,26 @@ class UserResp(UserBase):
]
if "team" in include:
if obj.team_membership:
if await obj.awaitable_attrs.team_membership:
assert obj.team_membership
u.team = obj.team_membership.team
if "account_history" in include:
u.account_history = [
UserAccountHistoryResp.from_db(ah) for ah in obj.account_history
UserAccountHistoryResp.from_db(ah)
for ah in await obj.awaitable_attrs.account_history
]
if "daily_challenge_user_stats":
if obj.daily_challenge_stats:
if await obj.awaitable_attrs.daily_challenge_stats:
assert obj.daily_challenge_stats
u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(
obj.daily_challenge_stats
)
if "statistics" in include:
current_stattistics = None
for i in obj.statistics:
for i in await obj.awaitable_attrs.statistics:
if i.mode == (ruleset or obj.playmode):
current_stattistics = i
break
@@ -292,17 +286,20 @@ class UserResp(UserBase):
if "statistics_rulesets" in include:
u.statistics_rulesets = {
i.mode.value: UserStatisticsResp.from_db(i) for i in obj.statistics
i.mode.value: UserStatisticsResp.from_db(i)
for i in await obj.awaitable_attrs.statistics
}
if "monthly_playcounts" in include:
u.monthly_playcounts = [
MonthlyPlaycountsResp.from_db(pc) for pc in obj.monthly_playcounts
MonthlyPlaycountsResp.from_db(pc)
for pc in await obj.awaitable_attrs.monthly_playcounts
]
if "achievements" in include:
u.user_achievements = [
UserAchievementResp.from_db(ua) for ua in obj.achievement
UserAchievementResp.from_db(ua)
for ua in await obj.awaitable_attrs.achievement
]
return u
@@ -328,3 +325,9 @@ SEARCH_INCLUDED = [
"achievements",
"monthly_playcounts",
]
BASE_INCLUDES = [
"team",
"daily_challenge_user_stats",
"statistics",
]

View File

@@ -1,94 +0,0 @@
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import JSON, Column, DateTime
from sqlalchemy.orm import Mapped
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
if TYPE_CHECKING:
from .user import User
# ============================================
# 旧的兼容性表模型(保留以便向后兼容)
# ============================================
class LegacyUserStatistics(SQLModel, table=True):
__tablename__ = "user_statistics" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
mode: str = Field(max_length=10) # osu, taiko, fruits, mania
# 基本统计
count_100: int = Field(default=0)
count_300: int = Field(default=0)
count_50: int = Field(default=0)
count_miss: int = Field(default=0)
# 等级信息
level_current: int = Field(default=1)
level_progress: int = Field(default=0)
# 排名信息
global_rank: int | None = Field(default=None)
global_rank_exp: int | None = Field(default=None)
country_rank: int | None = Field(default=None)
# PP 和分数
pp: float = Field(default=0.0)
pp_exp: float = Field(default=0.0)
ranked_score: int = Field(default=0)
hit_accuracy: float = Field(default=0.0)
total_score: int = Field(default=0)
total_hits: int = Field(default=0)
maximum_combo: int = Field(default=0)
# 游戏统计
play_count: int = Field(default=0)
play_time: int = Field(default=0)
replays_watched_by_others: int = Field(default=0)
is_ranked: bool = Field(default=False)
# 成绩等级计数
grade_ss: int = Field(default=0)
grade_ssh: int = Field(default=0)
grade_s: int = Field(default=0)
grade_sh: int = Field(default=0)
grade_a: int = Field(default=0)
# 最高排名记录
rank_highest: int | None = Field(default=None)
rank_highest_updated_at: datetime | None = Field(
default=None, sa_column=Column(DateTime)
)
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
# 关联关系
user: Mapped["User"] = Relationship(back_populates="statistics")
class LegacyOAuthToken(SQLModel, table=True):
__tablename__ = "legacy_oauth_tokens" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
access_token: str = Field(max_length=255, index=True)
refresh_token: str = Field(max_length=255, index=True)
expires_at: datetime = Field(sa_column=Column(DateTime))
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
previous_usernames: list = Field(default_factory=list, sa_column=Column(JSON))
replays_watched_counts: list = Field(default_factory=list, sa_column=Column(JSON))
# 用户关系
user: "User" = Relationship()

View File

@@ -0,0 +1,41 @@
from typing import TYPE_CHECKING
from app.models.score import GameMode
from .lazer_user import User
from sqlmodel import (
BigInteger,
Column,
Field,
Float,
ForeignKey,
Relationship,
SQLModel,
)
if TYPE_CHECKING:
from .beatmap import Beatmap
from .score import Score
class PPBestScore(SQLModel, table=True):
__tablename__ = "best_scores" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
score_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
)
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
gamemode: GameMode = Field(index=True)
pp: float = Field(
sa_column=Column(Float, default=0),
)
acc: float = Field(
sa_column=Column(Float, default=0),
)
user: User = Relationship()
score: "Score" = Relationship()
beatmap: "Beatmap" = Relationship()

View File

@@ -1,8 +1,6 @@
from enum import Enum
from app.models.user import User as APIUser
from .user import User as DBUser
from .lazer_user import User, UserResp
from pydantic import BaseModel
from sqlmodel import (
@@ -24,12 +22,16 @@ class RelationshipType(str, Enum):
class Relationship(SQLModel, table=True):
__tablename__ = "relationship" # pyright: ignore[reportAssignmentType]
id: int | None = Field(
default=None,
sa_column=Column(BigInteger, autoincrement=True, primary_key=True),
exclude=True,
)
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
primary_key=True,
ForeignKey("lazer_users.id"),
index=True,
),
)
@@ -37,20 +39,22 @@ class Relationship(SQLModel, table=True):
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
primary_key=True,
ForeignKey("lazer_users.id"),
index=True,
),
)
type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False)
target: DBUser = SQLRelationship(
sa_relationship_kwargs={"foreign_keys": "[Relationship.target_id]"}
target: User = SQLRelationship(
sa_relationship_kwargs={
"foreign_keys": "[Relationship.target_id]",
"lazy": "selectin",
}
)
class RelationshipResp(BaseModel):
target_id: int
target: APIUser
target: UserResp
mutual: bool = False
type: RelationshipType
@@ -58,8 +62,6 @@ class RelationshipResp(BaseModel):
async def from_db(
cls, session: AsyncSession, relationship: Relationship
) -> "RelationshipResp":
from app.utils import convert_db_user_to_api_user
target_relationship = (
await session.exec(
select(Relationship).where(
@@ -75,7 +77,16 @@ class RelationshipResp(BaseModel):
)
return cls(
target_id=relationship.target_id,
target=await convert_db_user_to_api_user(relationship.target),
target=await UserResp.from_db(
relationship.target,
session,
include=[
"team",
"daily_challenge_user_stats",
"statistics",
"statistics_rulesets",
],
),
mutual=mutual,
type=relationship.type,
)

View File

@@ -1,6 +1,7 @@
import asyncio
from collections.abc import Sequence
from datetime import UTC, datetime
from datetime import UTC, date, datetime
import json
import math
from typing import TYPE_CHECKING
@@ -12,9 +13,8 @@ from app.calculator import (
calculate_weighted_pp,
clamp,
)
from app.database.score_token import ScoreToken
from app.database.user import LazerUserStatistics, User
from app.models.beatmap import BeatmapRankStatus
from app.database.team import TeamMember
from app.models.model import UTCBaseModel
from app.models.mods import APIMod, mods_can_get_pp
from app.models.score import (
INT_TO_MODE,
@@ -26,15 +26,24 @@ from app.models.score import (
ScoreStatistics,
SoloScoreSubmissionInfo,
)
from app.models.user import User as APIUser
from .beatmap import Beatmap, BeatmapResp
from .beatmapset import Beatmapset, BeatmapsetResp
from .beatmapset import BeatmapsetResp
from .best_score import BestScore
from .lazer_user import User, UserResp
from .monthly_playcounts import MonthlyPlaycounts
from .pp_best_score import PPBestScore
from .relationship import (
Relationship as DBRelationship,
RelationshipType,
)
from .score_token import ScoreToken
from redis import Redis
from redis.asyncio import Redis
from sqlalchemy import Column, ColumnExpressionArgument, DateTime
from sqlalchemy.orm import aliased, joinedload
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import aliased
from sqlalchemy.sql.elements import ColumnElement
from sqlmodel import (
JSON,
BigInteger,
@@ -43,9 +52,10 @@ from sqlmodel import (
Relationship,
SQLModel,
col,
false,
func,
select,
text,
true,
)
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql._expression_select_cls import SelectOfScalar
@@ -54,7 +64,7 @@ if TYPE_CHECKING:
from app.fetcher import Fetcher
class ScoreBase(SQLModel):
class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
# 基本字段
accuracy: float
map_md5: str = Field(max_length=32, index=True)
@@ -94,7 +104,7 @@ class Score(ScoreBase, table=True):
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
ForeignKey("lazer_users.id"),
index=True,
),
)
@@ -112,28 +122,13 @@ class Score(ScoreBase, table=True):
gamemode: GameMode = Field(index=True)
# optional
beatmap: "Beatmap" = Relationship()
user: "User" = Relationship()
beatmap: Beatmap = Relationship()
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
@property
def is_perfect_combo(self) -> bool:
return self.max_combo == self.beatmap.max_combo
@staticmethod
def select_clause(with_user: bool = True) -> SelectOfScalar["Score"]:
clause = select(Score).options(
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
.selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
),
)
if with_user:
return clause.options(
joinedload(Score.user).options(*User.all_select_option()) # pyright: ignore[reportArgumentType]
)
return clause
@staticmethod
def select_clause_unique(
*where_clauses: ColumnExpressionArgument[bool] | bool,
@@ -147,18 +142,7 @@ class Score(ScoreBase, table=True):
)
subq = select(Score, rownum).where(*where_clauses).subquery()
best = aliased(Score, subq, adapt_on_names=True)
return (
select(best)
.where(subq.c.rn == 1)
.options(
joinedload(best.beatmap) # pyright: ignore[reportArgumentType]
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
.selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
),
joinedload(best.user).options(*User.all_select_option()), # pyright: ignore[reportArgumentType]
)
)
return select(best).where(subq.c.rn == 1)
class ScoreResp(ScoreBase):
@@ -173,22 +157,21 @@ class ScoreResp(ScoreBase):
ruleset_id: int | None = None
beatmap: BeatmapResp | None = None
beatmapset: BeatmapsetResp | None = None
user: APIUser | None = None
user: UserResp | None = None
statistics: ScoreStatistics | None = None
maximum_statistics: ScoreStatistics | None = None
rank_global: int | None = None
rank_country: int | None = None
@classmethod
async def from_db(
cls, session: AsyncSession, score: Score, user: User | None = None
) -> "ScoreResp":
from app.utils import convert_db_user_to_api_user
async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp":
s = cls.model_validate(score.model_dump())
assert score.id
s.beatmap = BeatmapResp.from_db(score.beatmap)
s.beatmapset = BeatmapsetResp.from_db(score.beatmap.beatmapset)
await score.awaitable_attrs.beatmap
s.beatmap = await BeatmapResp.from_db(score.beatmap)
s.beatmapset = await BeatmapsetResp.from_db(
score.beatmap.beatmapset, session=session, user=score.user
)
s.is_perfect_combo = s.max_combo == s.beatmap.max_combo
s.legacy_perfect = s.max_combo == s.beatmap.max_combo
s.ruleset_id = MODE_TO_INT[score.gamemode]
@@ -220,25 +203,30 @@ class ScoreResp(ScoreBase):
s.maximum_statistics = {
HitResult.GREAT: score.beatmap.max_combo,
}
if user:
s.user = await convert_db_user_to_api_user(user)
s.user = await UserResp.from_db(
score.user,
session,
include=["statistics", "team", "daily_challenge_user_stats"],
ruleset=score.gamemode,
)
s.rank_global = (
await get_score_position_by_id(
session,
score.map_md5,
score.beatmap_id,
score.id,
mode=score.gamemode,
user=user or score.user,
user=score.user,
)
or None
)
s.rank_country = (
await get_score_position_by_id(
session,
score.map_md5,
score.beatmap_id,
score.id,
score.gamemode,
user or score.user,
score.user,
type=LeaderboardType.COUNTRY,
)
or None
)
@@ -248,135 +236,137 @@ class ScoreResp(ScoreBase):
async def get_best_id(session: AsyncSession, score_id: int) -> None:
rownum = (
func.row_number()
.over(partition_by=col(BestScore.user_id), order_by=col(BestScore.pp).desc())
.over(
partition_by=col(PPBestScore.user_id), order_by=col(PPBestScore.pp).desc()
)
.label("rn")
)
subq = select(BestScore, rownum).subquery()
subq = select(PPBestScore, rownum).subquery()
stmt = select(subq.c.rn).where(subq.c.score_id == score_id)
result = await session.exec(stmt)
return result.one_or_none()
async def _score_where(
type: LeaderboardType,
beatmap: int,
mode: GameMode,
mods: list[str] | None = None,
user: User | None = None,
) -> list[ColumnElement[bool]] | None:
wheres = [
col(BestScore.beatmap_id) == beatmap,
col(BestScore.gamemode) == mode,
]
if type == LeaderboardType.FRIENDS:
if user and user.is_supporter:
subq = (
select(DBRelationship.target_id)
.where(
DBRelationship.type == RelationshipType.FOLLOW,
DBRelationship.user_id == user.id,
)
.subquery()
)
wheres.append(col(BestScore.user_id).in_(select(subq.c.target_id)))
else:
return None
elif type == LeaderboardType.COUNTRY:
if user and user.is_supporter:
wheres.append(
col(BestScore.user).has(col(User.country_code) == user.country_code)
)
else:
return None
elif type == LeaderboardType.TEAM:
if user:
team_membership = await user.awaitable_attrs.team_membership
if team_membership:
team_id = team_membership.team_id
wheres.append(
col(BestScore.user).has(
col(User.team_membership).has(TeamMember.team_id == team_id)
)
)
if mods:
if user and user.is_supporter:
wheres.append(
text(
"JSON_CONTAINS(total_score_best_scores.mods, :w)"
" AND JSON_CONTAINS(:w, total_score_best_scores.mods)"
) # pyright: ignore[reportArgumentType]
)
else:
return None
return wheres
async def get_leaderboard(
session: AsyncSession,
beatmap_md5: str,
beatmap: int,
mode: GameMode,
type: LeaderboardType = LeaderboardType.GLOBAL,
mods: list[APIMod] | None = None,
mods: list[str] | None = None,
user: User | None = None,
limit: int = 50,
) -> list[Score]:
scores = []
if type == LeaderboardType.GLOBAL:
query = (
select(Score)
.where(
col(Beatmap.beatmap_status).in_(
[
BeatmapRankStatus.RANKED,
BeatmapRankStatus.LOVED,
BeatmapRankStatus.QUALIFIED,
BeatmapRankStatus.APPROVED,
]
),
Score.map_md5 == beatmap_md5,
Score.gamemode == mode,
col(Score.passed).is_(True),
Score.mods == mods if user and user.is_supporter else false(),
)
.limit(limit)
.order_by(
col(Score.total_score).desc(),
)
)
result = await session.exec(query)
scores = list[Score](result.all())
elif type == LeaderboardType.FRIENDS and user and user.is_supporter:
# TODO
...
elif type == LeaderboardType.TEAM and user and user.team_membership:
team_id = user.team_membership.team_id
query = (
select(Score)
.join(Beatmap)
.options(joinedload(Score.user)) # pyright: ignore[reportArgumentType]
.where(
Score.map_md5 == beatmap_md5,
Score.gamemode == mode,
col(Score.passed).is_(True),
col(Score.user.team_membership).is_not(None),
Score.user.team_membership.team_id == team_id, # pyright: ignore[reportOptionalMemberAccess]
Score.mods == mods if user and user.is_supporter else false(),
)
.limit(limit)
.order_by(
col(Score.total_score).desc(),
)
)
result = await session.exec(query)
scores = list[Score](result.all())
) -> tuple[list[Score], Score | None]:
wheres = await _score_where(type, beatmap, mode, mods, user)
if wheres is None:
return [], None
query = (
select(BestScore)
.where(*wheres)
.limit(limit)
.order_by(col(BestScore.total_score).desc())
)
if mods:
query = query.params(w=json.dumps(mods))
scores = [s.score for s in await session.exec(query)]
user_score = None
if user:
user_score = (
await session.exec(
select(Score).where(
Score.map_md5 == beatmap_md5,
Score.gamemode == mode,
Score.user_id == user.id,
col(Score.passed).is_(True),
self_query = (
select(BestScore)
.where(BestScore.user_id == user.id)
.order_by(col(BestScore.total_score).desc())
.limit(1)
)
if mods:
self_query = self_query.where(
text(
"JSON_CONTAINS(total_score_best_scores.mods, :w)"
" AND JSON_CONTAINS(:w, total_score_best_scores.mods)"
)
)
).first()
).params(w=json.dumps(mods))
user_bs = (await session.exec(self_query)).first()
if user_bs:
user_score = user_bs.score
if user_score and user_score not in scores:
scores.append(user_score)
return scores
return scores, user_score
async def get_score_position_by_user(
session: AsyncSession,
beatmap_md5: str,
beatmap: int,
user: User,
mode: GameMode,
type: LeaderboardType = LeaderboardType.GLOBAL,
mods: list[APIMod] | None = None,
mods: list[str] | None = None,
) -> int:
where_clause = [
Score.map_md5 == beatmap_md5,
Score.gamemode == mode,
col(Score.passed).is_(True),
col(Beatmap.beatmap_status).in_(
[
BeatmapRankStatus.RANKED,
BeatmapRankStatus.LOVED,
BeatmapRankStatus.QUALIFIED,
BeatmapRankStatus.APPROVED,
]
),
]
if mods and user.is_supporter:
where_clause.append(Score.mods == mods)
else:
where_clause.append(false())
if type == LeaderboardType.FRIENDS and user.is_supporter:
# TODO
...
elif type == LeaderboardType.TEAM and user.team_membership:
team_id = user.team_membership.team_id
where_clause.append(
col(Score.user.team_membership).is_not(None),
)
where_clause.append(
Score.user.team_membership.team_id == team_id, # pyright: ignore[reportOptionalMemberAccess]
)
wheres = await _score_where(type, beatmap, mode, mods, user=user)
if wheres is None:
return 0
rownum = (
func.row_number()
.over(
partition_by=Score.map_md5,
order_by=col(Score.total_score).desc(),
partition_by=col(BestScore.beatmap_id),
order_by=col(BestScore.total_score).desc(),
)
.label("row_number")
)
subq = select(Score, rownum).join(Beatmap).where(*where_clause).subquery()
stmt = select(subq.c.row_number).where(subq.c.user == user)
subq = select(BestScore, rownum).join(Beatmap).where(*wheres).subquery()
stmt = select(subq.c.row_number).where(subq.c.user_id == user.id)
result = await session.exec(stmt)
s = result.one_or_none()
return s if s else 0
@@ -384,57 +374,26 @@ async def get_score_position_by_user(
async def get_score_position_by_id(
session: AsyncSession,
beatmap_md5: str,
beatmap: int,
score_id: int,
mode: GameMode,
user: User | None = None,
type: LeaderboardType = LeaderboardType.GLOBAL,
mods: list[APIMod] | None = None,
mods: list[str] | None = None,
) -> int:
where_clause = [
Score.map_md5 == beatmap_md5,
Score.id == score_id,
Score.gamemode == mode,
col(Score.passed).is_(True),
col(Beatmap.beatmap_status).in_(
[
BeatmapRankStatus.RANKED,
BeatmapRankStatus.LOVED,
BeatmapRankStatus.QUALIFIED,
BeatmapRankStatus.APPROVED,
]
),
]
if mods and user and user.is_supporter:
where_clause.append(Score.mods == mods)
elif mods:
where_clause.append(false())
wheres = await _score_where(type, beatmap, mode, mods, user=user)
if wheres is None:
return 0
rownum = (
func.row_number()
.over(
partition_by=[col(Score.user_id), col(Score.map_md5)],
order_by=col(Score.total_score).desc(),
partition_by=col(BestScore.beatmap_id),
order_by=col(BestScore.total_score).desc(),
)
.label("rownum")
.label("row_number")
)
subq = (
select(Score.user_id, Score.id, Score.total_score, rownum)
.join(Beatmap)
.where(*where_clause)
.subquery()
)
best_scores = aliased(subq)
overall_rank = (
func.rank().over(order_by=best_scores.c.total_score.desc()).label("global_rank")
)
final_q = (
select(best_scores.c.id, overall_rank)
.select_from(best_scores)
.where(best_scores.c.rownum == 1)
.subquery()
)
stmt = select(final_q.c.global_rank).where(final_q.c.id == score_id)
subq = select(BestScore, rownum).join(Beatmap).where(*wheres).subquery()
stmt = select(subq.c.row_number).where(subq.c.score_id == score_id)
result = await session.exec(stmt)
s = result.one_or_none()
return s if s else 0
@@ -445,16 +404,38 @@ async def get_user_best_score_in_beatmap(
beatmap: int,
user: int,
mode: GameMode | None = None,
) -> Score | None:
) -> BestScore | None:
return (
await session.exec(
Score.select_clause(False)
select(BestScore)
.where(
Score.gamemode == mode if mode is not None else True,
Score.beatmap_id == beatmap,
Score.user_id == user,
BestScore.gamemode == mode if mode is not None else true(),
BestScore.beatmap_id == beatmap,
BestScore.user_id == user,
)
.order_by(col(Score.total_score).desc())
.order_by(col(BestScore.total_score).desc())
)
).first()
# FIXME
async def get_user_best_score_with_mod_in_beatmap(
session: AsyncSession,
beatmap: int,
user: int,
mod: list[str],
mode: GameMode | None = None,
) -> BestScore | None:
return (
await session.exec(
select(BestScore)
.where(
BestScore.gamemode == mode if mode is not None else True,
BestScore.beatmap_id == beatmap,
BestScore.user_id == user,
# BestScore.mods == mod,
)
.order_by(col(BestScore.total_score).desc())
)
).first()
@@ -464,13 +445,13 @@ async def get_user_best_pp_in_beatmap(
beatmap: int,
user: int,
mode: GameMode,
) -> BestScore | None:
) -> PPBestScore | None:
return (
await session.exec(
select(BestScore).where(
BestScore.beatmap_id == beatmap,
BestScore.user_id == user,
BestScore.gamemode == mode,
select(PPBestScore).where(
PPBestScore.beatmap_id == beatmap,
PPBestScore.user_id == user,
PPBestScore.gamemode == mode,
)
)
).first()
@@ -480,12 +461,12 @@ async def get_user_best_pp(
session: AsyncSession,
user: int,
limit: int = 200,
) -> Sequence[BestScore]:
) -> Sequence[PPBestScore]:
return (
await session.exec(
select(BestScore)
.where(BestScore.user_id == user)
.order_by(col(BestScore.pp).desc())
select(PPBestScore)
.where(PPBestScore.user_id == user)
.order_by(col(PPBestScore.pp).desc())
.limit(limit)
)
).all()
@@ -494,27 +475,45 @@ async def get_user_best_pp(
async def process_user(
session: AsyncSession, user: User, score: Score, ranked: bool = False
):
assert user.id
assert score.id
mod_for_save = list({mod["acronym"] for mod in score.mods})
previous_score_best = await get_user_best_score_in_beatmap(
session, score.beatmap_id, user.id, score.gamemode
)
statistics = None
previous_score_best_mod = await get_user_best_score_with_mod_in_beatmap(
session, score.beatmap_id, user.id, mod_for_save, score.gamemode
)
add_to_db = False
for i in user.lazer_statistics:
mouthly_playcount = (
await session.exec(
select(MonthlyPlaycounts).where(
MonthlyPlaycounts.user_id == user.id,
MonthlyPlaycounts.year == date.today().year,
MonthlyPlaycounts.month == date.today().month,
)
)
).first()
if mouthly_playcount is None:
mouthly_playcount = MonthlyPlaycounts(
user_id=user.id, year=date.today().year, month=date.today().month
)
add_to_db = True
statistics = None
for i in await user.awaitable_attrs.statistics:
if i.mode == score.gamemode.value:
statistics = i
break
if statistics is None:
statistics = LazerUserStatistics(
mode=score.gamemode.value,
user_id=user.id,
raise ValueError(
f"User {user.id} does not have statistics for mode {score.gamemode.value}"
)
add_to_db = True
# pc, pt, tth, tts
statistics.total_score += score.total_score
difference = (
score.total_score - previous_score_best.total_score
if previous_score_best and previous_score_best.id != score.id
if previous_score_best
else score.total_score
)
if difference > 0 and score.passed and ranked:
@@ -541,11 +540,48 @@ async def process_user(
statistics.grade_sh -= 1
case Rank.A:
statistics.grade_a -= 1
else:
previous_score_best = BestScore(
user_id=user.id,
beatmap_id=score.beatmap_id,
gamemode=score.gamemode,
score_id=score.id,
total_score=score.total_score,
rank=score.rank,
mods=mod_for_save,
)
session.add(previous_score_best)
statistics.ranked_score += difference
statistics.level_current = calculate_score_to_level(statistics.ranked_score)
statistics.maximum_combo = max(statistics.maximum_combo, score.max_combo)
if score.passed and ranked:
if previous_score_best_mod is not None:
previous_score_best_mod.mods = mod_for_save
previous_score_best_mod.score_id = score.id
previous_score_best_mod.rank = score.rank
previous_score_best_mod.total_score = score.total_score
elif (
previous_score_best is not None and previous_score_best.score_id != score.id
):
session.add(
BestScore(
user_id=user.id,
beatmap_id=score.beatmap_id,
gamemode=score.gamemode,
score_id=score.id,
total_score=score.total_score,
rank=score.rank,
mods=mod_for_save,
)
)
statistics.play_count += 1
mouthly_playcount.playcount += 1
statistics.play_time += int((score.ended_at - score.started_at).total_seconds())
statistics.count_100 += score.n100 + score.nkatu
statistics.count_300 += score.n300 + score.ngeki
statistics.count_50 += score.n50
statistics.count_miss += score.nmiss
statistics.total_hits += (
score.n300 + score.n100 + score.n50 + score.ngeki + score.nkatu
)
@@ -563,11 +599,8 @@ async def process_user(
acc_sum = clamp(acc_sum, 0.0, 100.0)
statistics.pp = pp_sum
statistics.hit_accuracy = acc_sum
statistics.updated_at = datetime.now(UTC)
if add_to_db:
session.add(statistics)
session.add(mouthly_playcount)
await session.commit()
await session.refresh(user)
@@ -582,6 +615,8 @@ async def process_score(
session: AsyncSession,
redis: Redis,
) -> Score:
assert user.id
can_get_pp = info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods)
score = Score(
accuracy=info.accuracy,
max_combo=info.max_combo,
@@ -611,7 +646,7 @@ async def process_score(
nlarge_tick_hit=info.statistics.get(HitResult.LARGE_TICK_HIT, 0),
nslider_tail_hit=info.statistics.get(HitResult.SLIDER_TAIL_HIT, 0),
)
if info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods):
if can_get_pp:
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
pp = await asyncio.get_event_loop().run_in_executor(
None, calculate_pp, score, beatmap_raw
@@ -621,13 +656,13 @@ async def process_score(
user_id = user.id
await session.commit()
await session.refresh(score)
if score.passed and ranked:
if can_get_pp:
previous_pp_best = await get_user_best_pp_in_beatmap(
session, beatmap_id, user_id, score.gamemode
)
if previous_pp_best is None or score.pp > previous_pp_best.pp:
assert score.id
best_score = BestScore(
best_score = PPBestScore(
user_id=user_id,
score_id=score.id,
beatmap_id=beatmap_id,
@@ -636,7 +671,7 @@ async def process_score(
acc=score.accuracy,
)
session.add(best_score)
session.delete(previous_pp_best) if previous_pp_best else None
await session.delete(previous_pp_best) if previous_pp_best else None
await session.commit()
await session.refresh(score)
await session.refresh(score_token)

View File

@@ -1,15 +1,16 @@
from datetime import datetime
from app.models.model import UTCBaseModel
from app.models.score import GameMode
from .beatmap import Beatmap
from .user import User
from .lazer_user import User
from sqlalchemy import Column, DateTime, Index
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
class ScoreTokenBase(SQLModel):
class ScoreTokenBase(SQLModel, UTCBaseModel):
score_id: int | None = Field(sa_column=Column(BigInteger), default=None)
ruleset_id: GameMode
playlist_item_id: int | None = Field(default=None) # playlist
@@ -34,10 +35,10 @@ class ScoreToken(ScoreTokenBase, table=True):
autoincrement=True,
),
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
beatmap_id: int = Field(foreign_key="beatmaps.id")
user: "User" = Relationship()
beatmap: "Beatmap" = Relationship()
user: User = Relationship()
beatmap: Beatmap = Relationship()
class ScoreTokenResp(ScoreTokenBase):

View File

@@ -1,14 +1,16 @@
from datetime import datetime
from typing import TYPE_CHECKING
from app.models.model import UTCBaseModel
from sqlalchemy import Column, DateTime
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
if TYPE_CHECKING:
from .user import User
from .lazer_user import User
class Team(SQLModel, table=True):
class Team(SQLModel, UTCBaseModel, table=True):
__tablename__ = "teams" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
@@ -22,15 +24,19 @@ class Team(SQLModel, table=True):
members: list["TeamMember"] = Relationship(back_populates="team")
class TeamMember(SQLModel, table=True):
class TeamMember(SQLModel, UTCBaseModel, table=True):
__tablename__ = "team_members" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
team_id: int = Field(foreign_key="teams.id")
joined_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: "User" = Relationship(back_populates="team_membership")
team: "Team" = Relationship(back_populates="members")
user: "User" = Relationship(
back_populates="team_membership", sa_relationship_kwargs={"lazy": "joined"}
)
team: "Team" = Relationship(
back_populates="members", sa_relationship_kwargs={"lazy": "joined"}
)

View File

@@ -1,527 +0,0 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
from .legacy import LegacyUserStatistics
from .team import TeamMember
from sqlalchemy import DECIMAL, JSON, Column, Date, DateTime, Text
from sqlalchemy.dialects.mysql import VARCHAR
from sqlalchemy.orm import joinedload, selectinload
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel, select
class User(SQLModel, table=True):
__tablename__ = "users" # pyright: ignore[reportAssignmentType]
# 主键
id: int = Field(
default=None, sa_column=Column(BigInteger, primary_key=True, index=True)
)
# 基本信息(匹配 migrations_old 中的结构)
name: str = Field(max_length=32, unique=True, index=True) # 用户名
safe_name: str = Field(max_length=32, unique=True, index=True) # 安全用户名
email: str = Field(max_length=254, unique=True, index=True)
priv: int = Field(default=1) # 权限
pw_bcrypt: str = Field(max_length=60) # bcrypt 哈希密码
country: str = Field(default="CN", max_length=2) # 国家代码
# 状态和时间
silence_end: int = Field(default=0)
donor_end: int = Field(default=0)
creation_time: int = Field(default=0) # Unix 时间戳
latest_activity: int = Field(default=0) # Unix 时间戳
# 游戏相关
preferred_mode: int = Field(default=0) # 偏好游戏模式
play_style: int = Field(default=0) # 游戏风格
# 扩展信息
clan_id: int = Field(default=0)
clan_priv: int = Field(default=0)
custom_badge_name: str | None = Field(default=None, max_length=16)
custom_badge_icon: str | None = Field(default=None, max_length=64)
userpage_content: str | None = Field(default=None, max_length=2048)
api_key: str | None = Field(default=None, max_length=36, unique=True)
# 虚拟字段用于兼容性
@property
def username(self):
return self.name
@property
def country_code(self):
return self.country
@property
def join_date(self):
creation_time = getattr(self, "creation_time", 0)
return (
datetime.fromtimestamp(creation_time)
if creation_time > 0
else datetime.utcnow()
)
@property
def last_visit(self):
latest_activity = getattr(self, "latest_activity", 0)
return datetime.fromtimestamp(latest_activity) if latest_activity > 0 else None
@property
def is_supporter(self):
return self.lazer_profile.is_supporter if self.lazer_profile else False
# 关联关系
lazer_profile: Optional["LazerUserProfile"] = Relationship(back_populates="user")
lazer_statistics: list["LazerUserStatistics"] = Relationship(back_populates="user")
lazer_counts: Optional["LazerUserCounts"] = Relationship(back_populates="user")
lazer_achievements: list["LazerUserAchievement"] = Relationship(
back_populates="user"
)
lazer_profile_sections: list["LazerUserProfileSections"] = Relationship(
back_populates="user"
)
statistics: list["LegacyUserStatistics"] = Relationship(back_populates="user")
team_membership: Optional["TeamMember"] = Relationship(back_populates="user")
daily_challenge_stats: Optional["DailyChallengeStats"] = Relationship(
back_populates="user"
)
rank_history: list["RankHistory"] = Relationship(back_populates="user")
avatar: Optional["UserAvatar"] = Relationship(back_populates="user")
active_banners: list["LazerUserBanners"] = Relationship(back_populates="user")
lazer_badges: list["LazerUserBadge"] = Relationship(back_populates="user")
lazer_monthly_playcounts: list["LazerUserMonthlyPlaycounts"] = Relationship(
back_populates="user"
)
lazer_previous_usernames: list["LazerUserPreviousUsername"] = Relationship(
back_populates="user"
)
lazer_replays_watched: list["LazerUserReplaysWatched"] = Relationship(
back_populates="user"
)
@classmethod
def all_select_option(cls):
return (
joinedload(cls.lazer_profile), # pyright: ignore[reportArgumentType]
joinedload(cls.lazer_counts), # pyright: ignore[reportArgumentType]
joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType]
joinedload(cls.avatar), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_statistics), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_achievements), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_profile_sections), # pyright: ignore[reportArgumentType]
selectinload(cls.statistics), # pyright: ignore[reportArgumentType]
joinedload(cls.team_membership), # pyright: ignore[reportArgumentType]
selectinload(cls.rank_history), # pyright: ignore[reportArgumentType]
selectinload(cls.active_banners), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_badges), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_monthly_playcounts), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_previous_usernames), # pyright: ignore[reportArgumentType]
selectinload(cls.lazer_replays_watched), # pyright: ignore[reportArgumentType]
)
@classmethod
def all_select_clause(cls):
return select(cls).options(*cls.all_select_option())
# ============================================
# Lazer API 专用表模型
# ============================================
class LazerUserProfile(SQLModel, table=True):
__tablename__ = "lazer_user_profiles" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
primary_key=True,
),
)
# 基本状态字段
is_active: bool = Field(default=True)
is_bot: bool = Field(default=False)
is_deleted: bool = Field(default=False)
is_online: bool = Field(default=True)
is_supporter: bool = Field(default=False)
is_restricted: bool = Field(default=False)
session_verified: bool = Field(default=False)
has_supported: bool = Field(default=False)
pm_friends_only: bool = Field(default=False)
# 基本资料字段
default_group: str = Field(default="default", max_length=50)
last_visit: datetime | None = Field(default=None, sa_column=Column(DateTime))
join_date: datetime | None = Field(default=None, sa_column=Column(DateTime))
profile_colour: str | None = Field(default=None, max_length=7)
profile_hue: int | None = Field(default=None)
# 社交媒体和个人资料字段
avatar_url: str | None = Field(default=None, max_length=500)
cover_url: str | None = Field(default=None, max_length=500)
discord: str | None = Field(default=None, max_length=100)
twitter: str | None = Field(default=None, max_length=100)
website: str | None = Field(default=None, max_length=500)
title: str | None = Field(default=None, max_length=100)
title_url: str | None = Field(default=None, max_length=500)
interests: str | None = Field(default=None, sa_column=Column(Text))
location: str | None = Field(default=None, max_length=100)
occupation: str | None = Field(default=None) # 职业字段,默认为 None
# 游戏相关字段
playmode: str = Field(default="osu", max_length=10)
support_level: int = Field(default=0)
max_blocks: int = Field(default=100)
max_friends: int = Field(default=500)
post_count: int = Field(default=0)
# 页面内容
page_html: str | None = Field(default=None, sa_column=Column(Text))
page_raw: str | None = Field(default=None, sa_column=Column(Text))
profile_order: str = Field(
default="me,recent_activity,top_ranks,medals,historical,beatmaps,kudosu"
)
# 关联关系
user: "User" = Relationship(back_populates="lazer_profile")
class LazerUserProfileSections(SQLModel, table=True):
__tablename__ = "lazer_user_profile_sections" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
section_name: str = Field(sa_column=Column(VARCHAR(50)))
display_order: int | None = Field(default=None)
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: "User" = Relationship(back_populates="lazer_profile_sections")
class LazerUserCountry(SQLModel, table=True):
__tablename__ = "lazer_user_countries" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
primary_key=True,
),
)
code: str = Field(max_length=2)
name: str = Field(max_length=100)
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
class LazerUserKudosu(SQLModel, table=True):
__tablename__ = "lazer_user_kudosu" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
primary_key=True,
),
)
available: int = Field(default=0)
total: int = Field(default=0)
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
class LazerUserCounts(SQLModel, table=True):
__tablename__ = "lazer_user_counts" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
primary_key=True,
),
)
# 统计计数字段
beatmap_playcounts_count: int = Field(default=0)
comments_count: int = Field(default=0)
favourite_beatmapset_count: int = Field(default=0)
follower_count: int = Field(default=0)
graveyard_beatmapset_count: int = Field(default=0)
guest_beatmapset_count: int = Field(default=0)
loved_beatmapset_count: int = Field(default=0)
mapping_follower_count: int = Field(default=0)
nominated_beatmapset_count: int = Field(default=0)
pending_beatmapset_count: int = Field(default=0)
ranked_beatmapset_count: int = Field(default=0)
ranked_and_approved_beatmapset_count: int = Field(default=0)
unranked_beatmapset_count: int = Field(default=0)
scores_best_count: int = Field(default=0)
scores_first_count: int = Field(default=0)
scores_pinned_count: int = Field(default=0)
scores_recent_count: int = Field(default=0)
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
# 关联关系
user: "User" = Relationship(back_populates="lazer_counts")
class LazerUserStatistics(SQLModel, table=True):
__tablename__ = "lazer_user_statistics" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
primary_key=True,
),
)
mode: str = Field(default="osu", max_length=10, primary_key=True)
# 基本命中统计
count_100: int = Field(default=0)
count_300: int = Field(default=0)
count_50: int = Field(default=0)
count_miss: int = Field(default=0)
# 等级信息
level_current: int = Field(default=1)
level_progress: int = Field(default=0)
# 排名信息
global_rank: int | None = Field(default=None)
global_rank_exp: int | None = Field(default=None)
country_rank: int | None = Field(default=None)
# PP 和分数
pp: float = Field(default=0.00, sa_column=Column(DECIMAL(10, 2)))
pp_exp: float = Field(default=0.00, sa_column=Column(DECIMAL(10, 2)))
ranked_score: int = Field(default=0, sa_column=Column(BigInteger))
hit_accuracy: float = Field(default=0.00, sa_column=Column(DECIMAL(5, 2)))
total_score: int = Field(default=0, sa_column=Column(BigInteger))
total_hits: int = Field(default=0, sa_column=Column(BigInteger))
maximum_combo: int = Field(default=0)
# 游戏统计
play_count: int = Field(default=0)
play_time: int = Field(default=0) # 秒
replays_watched_by_others: int = Field(default=0)
is_ranked: bool = Field(default=False)
# 成绩等级计数
grade_ss: int = Field(default=0)
grade_ssh: int = Field(default=0)
grade_s: int = Field(default=0)
grade_sh: int = Field(default=0)
grade_a: int = Field(default=0)
# 最高排名记录
rank_highest: int | None = Field(default=None)
rank_highest_updated_at: datetime | None = Field(
default=None, sa_column=Column(DateTime)
)
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
# 关联关系
user: "User" = Relationship(back_populates="lazer_statistics")
class LazerUserBanners(SQLModel, table=True):
__tablename__ = "lazer_user_tournament_banners" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
tournament_id: int
image_url: str = Field(sa_column=Column(VARCHAR(500)))
is_active: bool | None = Field(default=None)
# 修正user关系的back_populates值
user: "User" = Relationship(back_populates="active_banners")
class LazerUserAchievement(SQLModel, table=True):
__tablename__ = "lazer_user_achievements" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
achievement_id: int
achieved_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: "User" = Relationship(back_populates="lazer_achievements")
class LazerUserBadge(SQLModel, table=True):
__tablename__ = "lazer_user_badges" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
badge_id: int
awarded_at: datetime | None = Field(default=None, sa_column=Column(DateTime))
description: str | None = Field(default=None, sa_column=Column(Text))
image_url: str | None = Field(default=None, max_length=500)
url: str | None = Field(default=None, max_length=500)
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: "User" = Relationship(back_populates="lazer_badges")
class LazerUserMonthlyPlaycounts(SQLModel, table=True):
__tablename__ = "lazer_user_monthly_playcounts" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
start_date: datetime = Field(sa_column=Column(Date))
play_count: int = Field(default=0)
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: "User" = Relationship(back_populates="lazer_monthly_playcounts")
class LazerUserPreviousUsername(SQLModel, table=True):
__tablename__ = "lazer_user_previous_usernames" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
username: str = Field(max_length=32)
changed_at: datetime = Field(sa_column=Column(DateTime))
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: "User" = Relationship(back_populates="lazer_previous_usernames")
class LazerUserReplaysWatched(SQLModel, table=True):
__tablename__ = "lazer_user_replays_watched" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
start_date: datetime = Field(sa_column=Column(Date))
count: int = Field(default=0)
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: "User" = Relationship(back_populates="lazer_replays_watched")
# 类型转换用的 UserAchievement不是 SQLAlchemy 模型)
@dataclass
class UserAchievement:
achieved_at: datetime
achievement_id: int
class DailyChallengeStats(SQLModel, table=True):
__tablename__ = "daily_challenge_stats" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("users.id"), unique=True)
)
daily_streak_best: int = Field(default=0)
daily_streak_current: int = Field(default=0)
last_update: datetime | None = Field(default=None, sa_column=Column(DateTime))
last_weekly_streak: datetime | None = Field(
default=None, sa_column=Column(DateTime)
)
playcount: int = Field(default=0)
top_10p_placements: int = Field(default=0)
top_50p_placements: int = Field(default=0)
weekly_streak_best: int = Field(default=0)
weekly_streak_current: int = Field(default=0)
user: "User" = Relationship(back_populates="daily_challenge_stats")
class RankHistory(SQLModel, table=True):
__tablename__ = "rank_history" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
mode: str = Field(max_length=10)
rank_data: list = Field(sa_column=Column(JSON)) # Array of ranks
date_recorded: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: "User" = Relationship(back_populates="rank_history")
class UserAvatar(SQLModel, table=True):
__tablename__ = "user_avatars" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
filename: str = Field(max_length=255)
original_filename: str = Field(max_length=255)
file_size: int
mime_type: str = Field(max_length=100)
is_active: bool = Field(default=True)
created_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
r2_original_url: str | None = Field(default=None, max_length=500)
r2_game_url: str | None = Field(default=None, max_length=500)
user: "User" = Relationship(back_populates="avatar")

View File

@@ -5,15 +5,11 @@ import json
from app.config import settings
from pydantic import BaseModel
import redis.asyncio as redis
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
try:
import redis
except ImportError:
redis = None
def json_serializer(value):
if isinstance(value, BaseModel | SQLModel):
@@ -25,10 +21,7 @@ def json_serializer(value):
engine = create_async_engine(settings.DATABASE_URL, json_serializer=json_serializer)
# Redis 连接
if redis:
redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
else:
redis_client = None
redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
# 数据库依赖

View File

@@ -8,7 +8,7 @@ from app.log import logger
fetcher: Fetcher | None = None
def get_fetcher() -> Fetcher:
async def get_fetcher() -> Fetcher:
global fetcher
if fetcher is None:
fetcher = Fetcher(
@@ -18,15 +18,14 @@ def get_fetcher() -> Fetcher:
settings.FETCHER_CALLBACK_URL,
)
redis = get_redis()
if redis:
access_token = redis.get(f"fetcher:access_token:{fetcher.client_id}")
if access_token:
fetcher.access_token = str(access_token)
refresh_token = redis.get(f"fetcher:refresh_token:{fetcher.client_id}")
if refresh_token:
fetcher.refresh_token = str(refresh_token)
if not fetcher.access_token or not fetcher.refresh_token:
logger.opt(colors=True).info(
f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>"
)
access_token = await redis.get(f"fetcher:access_token:{fetcher.client_id}")
if access_token:
fetcher.access_token = str(access_token)
refresh_token = await redis.get(f"fetcher:refresh_token:{fetcher.client_id}")
if refresh_token:
fetcher.refresh_token = str(refresh_token)
if not fetcher.access_token or not fetcher.refresh_token:
logger.opt(colors=True).info(
f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>"
)
return fetcher

View File

@@ -1,14 +1,13 @@
from __future__ import annotations
from app.auth import get_token_by_access_token
from app.database import (
User as DBUser,
)
from app.database import User
from .database import get_db
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
security = HTTPBearer()
@@ -17,7 +16,7 @@ security = HTTPBearer()
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security),
db: AsyncSession = Depends(get_db),
) -> DBUser:
) -> User:
"""获取当前认证用户"""
token = credentials.credentials
@@ -27,13 +26,9 @@ async def get_current_user(
return user
async def get_current_user_by_token(token: str, db: AsyncSession) -> DBUser | None:
async def get_current_user_by_token(token: str, db: AsyncSession) -> User | None:
token_record = await get_token_by_access_token(db, token)
if not token_record:
return None
user = (
await db.exec(
DBUser.all_select_clause().where(DBUser.id == token_record.user_id)
)
).first()
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
return user

View File

@@ -59,16 +59,15 @@ class BaseFetcher:
self.refresh_token = token_data.get("refresh_token", "")
self.token_expiry = int(time.time()) + token_data["expires_in"]
redis = get_redis()
if redis:
redis.set(
f"fetcher:access_token:{self.client_id}",
self.access_token,
ex=token_data["expires_in"],
)
redis.set(
f"fetcher:refresh_token:{self.client_id}",
self.refresh_token,
)
await redis.set(
f"fetcher:access_token:{self.client_id}",
self.access_token,
ex=token_data["expires_in"],
)
await redis.set(
f"fetcher:refresh_token:{self.client_id}",
self.refresh_token,
)
async def refresh_access_token(self) -> None:
async with AsyncClient() as client:
@@ -87,13 +86,12 @@ class BaseFetcher:
self.refresh_token = token_data.get("refresh_token", "")
self.token_expiry = int(time.time()) + token_data["expires_in"]
redis = get_redis()
if redis:
redis.set(
f"fetcher:access_token:{self.client_id}",
self.access_token,
ex=token_data["expires_in"],
)
redis.set(
f"fetcher:refresh_token:{self.client_id}",
self.refresh_token,
)
await redis.set(
f"fetcher:access_token:{self.client_id}",
self.access_token,
ex=token_data["expires_in"],
)
await redis.set(
f"fetcher:refresh_token:{self.client_id}",
self.refresh_token,
)

View File

@@ -4,7 +4,7 @@ from ._base import BaseFetcher
from httpx import AsyncClient
from loguru import logger
import redis
import redis.asyncio as redis
class OsuDotDirectFetcher(BaseFetcher):
@@ -22,8 +22,8 @@ class OsuDotDirectFetcher(BaseFetcher):
async def get_or_fetch_beatmap_raw(
self, redis: redis.Redis, beatmap_id: int
) -> str:
if redis.exists(f"beatmap:{beatmap_id}:raw"):
return redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType]
if await redis.exists(f"beatmap:{beatmap_id}:raw"):
return await redis.get(f"beatmap:{beatmap_id}:raw") # pyright: ignore[reportReturnType]
raw = await self.get_beatmap_raw(beatmap_id)
redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24)
await redis.set(f"beatmap:{beatmap_id}:raw", raw, ex=60 * 60 * 24)
return raw

View File

@@ -42,11 +42,12 @@ class Language(IntEnum):
KOREAN = 6
FRENCH = 7
GERMAN = 8
ITALIAN = 9
SPANISH = 10
RUSSIAN = 11
POLISH = 12
OTHER = 13
SWEDISH = 9
ITALIAN = 10
SPANISH = 11
RUSSIAN = 12
POLISH = 13
OTHER = 14
class BeatmapAttributes(BaseModel):

View File

@@ -1,7 +1,6 @@
# OAuth 相关模型
from __future__ import annotations
from typing import List
from pydantic import BaseModel
@@ -39,18 +38,21 @@ class OAuthErrorResponse(BaseModel):
class RegistrationErrorResponse(BaseModel):
"""注册错误响应模型"""
form_error: dict
class UserRegistrationErrors(BaseModel):
"""用户注册错误模型"""
username: List[str] = []
user_email: List[str] = []
password: List[str] = []
username: list[str] = []
user_email: list[str] = []
password: list[str] = []
class RegistrationRequestErrors(BaseModel):
"""注册请求错误模型"""
message: str | None = None
redirect: str | None = None
user: UserRegistrationErrors | None = None

View File

@@ -132,7 +132,7 @@ class HitResultInt(IntEnum):
class LeaderboardType(Enum):
GLOBAL = "global"
FRIENDS = "friends"
FRIENDS = "friend"
COUNTRY = "country"
TEAM = "team"

View File

@@ -2,15 +2,11 @@ from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING
from .score import GameMode
from .model import UTCBaseModel
from pydantic import BaseModel
if TYPE_CHECKING:
from app.database import LazerUserAchievement, Team
class PlayStyle(str, Enum):
MOUSE = "mouse"
@@ -77,24 +73,7 @@ class MonthlyPlaycount(BaseModel):
count: int
class UserAchievement(BaseModel):
achieved_at: datetime
achievement_id: int
# 添加数据库模型转换方法
def to_db_model(self, user_id: int) -> "LazerUserAchievement":
from app.database import (
LazerUserAchievement,
)
return LazerUserAchievement(
user_id=user_id,
achievement_id=self.achievement_id,
achieved_at=self.achieved_at,
)
class RankHighest(BaseModel):
class RankHighest(UTCBaseModel):
rank: int
updated_at: datetime
@@ -104,115 +83,6 @@ class RankHistory(BaseModel):
data: list[int]
class DailyChallengeStats(BaseModel):
daily_streak_best: int = 0
daily_streak_current: int = 0
last_update: datetime | None = None
last_weekly_streak: datetime | None = None
playcount: int = 0
top_10p_placements: int = 0
top_50p_placements: int = 0
user_id: int
weekly_streak_best: int = 0
weekly_streak_current: int = 0
class Page(BaseModel):
html: str = ""
raw: str = ""
class User(BaseModel):
# 基本信息
id: int
username: str
avatar_url: str
country_code: str
default_group: str = "default"
is_active: bool = True
is_bot: bool = False
is_deleted: bool = False
is_online: bool = True
is_supporter: bool = False
is_restricted: bool = False
last_visit: datetime | None = None
pm_friends_only: bool = False
profile_colour: str | None = None
# 个人资料
cover_url: str | None = None
discord: str | None = None
has_supported: bool = False
interests: str | None = None
join_date: datetime
location: str | None = None
max_blocks: int = 100
max_friends: int = 500
occupation: str | None = None
playmode: GameMode = GameMode.OSU
playstyle: list[PlayStyle] = []
post_count: int = 0
profile_hue: int | None = None
profile_order: list[str] = [
"me",
"recent_activity",
"top_ranks",
"medals",
"historical",
"beatmaps",
"kudosu",
]
title: str | None = None
title_url: str | None = None
twitter: str | None = None
website: str | None = None
session_verified: bool = False
support_level: int = 0
# 关联对象
country: Country
cover: Cover
kudosu: Kudosu
statistics: Statistics
statistics_rulesets: dict[str, Statistics]
# 计数信息
beatmap_playcounts_count: int = 0
comments_count: int = 0
favourite_beatmapset_count: int = 0
follower_count: int = 0
graveyard_beatmapset_count: int = 0
guest_beatmapset_count: int = 0
loved_beatmapset_count: int = 0
mapping_follower_count: int = 0
nominated_beatmapset_count: int = 0
pending_beatmapset_count: int = 0
ranked_beatmapset_count: int = 0
ranked_and_approved_beatmapset_count: int = 0
unranked_beatmapset_count: int = 0
scores_best_count: int = 0
scores_first_count: int = 0
scores_pinned_count: int = 0
scores_recent_count: int = 0
# 历史数据
account_history: list[dict] = []
active_tournament_banner: dict | None = None
active_tournament_banners: list[dict] = []
badges: list[dict] = []
current_season_stats: dict | None = None
daily_challenge_user_stats: DailyChallengeStats | None = None
groups: list[dict] = []
monthly_playcounts: list[MonthlyPlaycount] = []
page: Page = Page()
previous_usernames: list[str] = []
rank_highest: RankHighest | None = None
rank_history: RankHistory | None = None
rankHistory: RankHistory | None = None # 兼容性别名
replays_watched_counts: list[dict] = []
team: "Team | None" = None
user_achievements: list[UserAchievement] = []
class APIUser(BaseModel):
id: int

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from datetime import timedelta
from datetime import UTC, datetime, timedelta
import re
from app.auth import (
@@ -12,17 +12,21 @@ from app.auth import (
store_token,
)
from app.config import settings
from app.database import User as DBUser
from app.database import DailyChallengeStats, User
from app.database.statistics import UserStatistics
from app.dependencies import get_db
from app.log import logger
from app.models.oauth import (
OAuthErrorResponse,
RegistrationRequestErrors,
TokenResponse,
UserRegistrationErrors,
)
from app.models.score import GameMode
from fastapi import APIRouter, Depends, Form
from fastapi.responses import JSONResponse
from sqlalchemy import text
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -110,12 +114,12 @@ async def register_user(
email_errors = validate_email(user_email)
password_errors = validate_password(user_password)
result = await db.exec(select(DBUser).where(DBUser.name == user_username))
result = await db.exec(select(User).where(User.username == user_username))
existing_user = result.first()
if existing_user:
username_errors.append("Username is already taken")
result = await db.exec(select(DBUser).where(DBUser.email == user_email))
result = await db.exec(select(User).where(User.email == user_email))
existing_email = result.first()
if existing_email:
email_errors.append("Email is already taken")
@@ -135,119 +139,41 @@ async def register_user(
try:
# 创建新用户
from datetime import datetime
import time
# 确保 AUTO_INCREMENT 值从3开始ID=1是BanchoBotID=2预留给ppy
result = await db.execute( # pyright: ignore[reportDeprecated]
text(
"SELECT AUTO_INCREMENT FROM information_schema.TABLES "
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'lazer_users'"
)
)
next_id = result.one()[0]
if next_id <= 2:
await db.execute(text("ALTER TABLE lazer_users AUTO_INCREMENT = 3"))
await db.commit()
new_user = DBUser(
name=user_username,
safe_name=user_username.lower(), # 安全用户名(小写)
new_user = User(
username=user_username,
email=user_email,
pw_bcrypt=get_password_hash(user_password),
priv=1, # 普通用户权限
country="CN", # 默认国家
creation_time=int(time.time()),
latest_activity=int(time.time()),
preferred_mode=0, # 默认模式
play_style=0, # 默认游戏风格
country_code="CN", # 默认国家
join_date=datetime.now(UTC),
last_visit=datetime.now(UTC),
)
db.add(new_user)
await db.commit()
await db.refresh(new_user)
# 保存用户ID因为会话可能会关闭
user_id = new_user.id
if user_id <= 2:
await db.rollback()
try:
from sqlalchemy import text
# 确保 AUTO_INCREMENT 值从3开始ID=1是BanchoBotID=2预留给ppy
await db.execute(text("ALTER TABLE users AUTO_INCREMENT = 3"))
await db.commit()
# 重新创建用户
new_user = DBUser(
name=user_username,
safe_name=user_username.lower(),
email=user_email,
pw_bcrypt=get_password_hash(user_password),
priv=1,
country="CN",
creation_time=int(time.time()),
latest_activity=int(time.time()),
preferred_mode=0,
play_style=0,
)
db.add(new_user)
await db.commit()
await db.refresh(new_user)
user_id = new_user.id
# 最终检查ID是否有效
if user_id <= 2:
await db.rollback()
errors = RegistrationRequestErrors(
message=(
"Failed to create account with valid ID. "
"Please contact support."
)
)
return JSONResponse(
status_code=500, content={"form_error": errors.model_dump()}
)
except Exception as fix_error:
await db.rollback()
print(f"Failed to fix AUTO_INCREMENT: {fix_error}")
errors = RegistrationRequestErrors(
message="Failed to create account with valid ID. Please try again."
)
return JSONResponse(
status_code=500, content={"form_error": errors.model_dump()}
)
# 创建默认的 lazer_profile
from app.database.user import LazerUserProfile
lazer_profile = LazerUserProfile(
user_id=user_id,
is_active=True,
is_bot=False,
is_deleted=False,
is_online=True,
is_supporter=False,
is_restricted=False,
session_verified=False,
has_supported=False,
pm_friends_only=False,
default_group="default",
join_date=datetime.utcnow(),
playmode="osu",
support_level=0,
max_blocks=50,
max_friends=250,
post_count=0,
)
db.add(lazer_profile)
assert new_user.id is not None, "New user ID should not be None"
for i in GameMode:
statistics = UserStatistics(mode=i, user_id=new_user.id)
db.add(statistics)
daily_challenge_user_stats = DailyChallengeStats(user_id=new_user.id)
db.add(daily_challenge_user_stats)
await db.commit()
# 返回成功响应
return JSONResponse(
status_code=201,
content={"message": "Account created successfully", "user_id": user_id},
)
except Exception as e:
except Exception:
await db.rollback()
# 打印详细错误信息用于调试
print(f"Registration error: {e}")
import traceback
traceback.print_exc()
logger.exception(f"Registration error for user {user_username}")
# 返回通用错误
errors = RegistrationRequestErrors(
@@ -323,6 +249,7 @@ async def oauth_token(
refresh_token_str = generate_refresh_token()
# 存储令牌
assert user.id
await store_token(
db,
user.id,

View File

@@ -5,12 +5,7 @@ import hashlib
import json
from app.calculator import calculate_beatmap_attribute
from app.database import (
Beatmap,
BeatmapResp,
User as DBUser,
)
from app.database.beatmapset import Beatmapset
from app.database import Beatmap, BeatmapResp, User
from app.dependencies.database import get_db, get_redis
from app.dependencies.fetcher import get_fetcher
from app.dependencies.user import get_current_user
@@ -27,9 +22,8 @@ from .api_router import router
from fastapi import Depends, HTTPException, Query
from httpx import HTTPError, HTTPStatusError
from pydantic import BaseModel
from redis import Redis
from redis.asyncio import Redis
import rosu_pp_py as rosu
from sqlalchemy.orm import joinedload
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -39,7 +33,7 @@ async def lookup_beatmap(
id: int | None = Query(default=None, alias="id"),
md5: str | None = Query(default=None, alias="checksum"),
filename: str | None = Query(default=None, alias="filename"),
current_user: DBUser = Depends(get_current_user),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
@@ -56,19 +50,19 @@ async def lookup_beatmap(
if beatmap is None:
raise HTTPException(status_code=404, detail="Beatmap not found")
return BeatmapResp.from_db(beatmap)
return await BeatmapResp.from_db(beatmap, session=db, user=current_user)
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
async def get_beatmap(
bid: int,
current_user: DBUser = Depends(get_current_user),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
try:
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid)
return BeatmapResp.from_db(beatmap)
return await BeatmapResp.from_db(beatmap, session=db, user=current_user)
except HTTPError:
raise HTTPException(status_code=404, detail="Beatmap not found")
@@ -81,42 +75,27 @@ class BatchGetResp(BaseModel):
@router.get("/beatmaps/", tags=["beatmap"], response_model=BatchGetResp)
async def batch_get_beatmaps(
b_ids: list[int] = Query(alias="id", default_factory=list),
current_user: DBUser = Depends(get_current_user),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
if not b_ids:
# select 50 beatmaps by last_updated
beatmaps = (
await db.exec(
select(Beatmap)
.options(
joinedload(
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
).selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
)
)
.order_by(col(Beatmap.last_updated).desc())
.limit(50)
select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
)
).all()
else:
beatmaps = (
await db.exec(
select(Beatmap)
.options(
joinedload(
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
).selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
)
)
.where(col(Beatmap.id).in_(b_ids))
.limit(50)
)
await db.exec(select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50))
).all()
return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps])
return BatchGetResp(
beatmaps=[
await BeatmapResp.from_db(bm, session=db, user=current_user)
for bm in beatmaps
]
)
@router.post(
@@ -126,7 +105,7 @@ async def batch_get_beatmaps(
)
async def get_beatmap_attributes(
beatmap: int,
current_user: DBUser = Depends(get_current_user),
current_user: User = Depends(get_current_user),
mods: list[str] = Query(default_factory=list),
ruleset: GameMode | None = Query(default=None),
ruleset_id: int | None = Query(default=None),
@@ -153,8 +132,8 @@ async def get_beatmap_attributes(
f"beatmap:{beatmap}:{ruleset}:"
f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
)
if redis.exists(key):
return BeatmapAttributes.model_validate_json(redis.get(key)) # pyright: ignore[reportArgumentType]
if await redis.exists(key):
return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
try:
resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap)
@@ -164,7 +143,7 @@ async def get_beatmap_attributes(
)
except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue]
raise HTTPException(status_code=400, detail=str(e))
redis.set(key, attr.model_dump_json())
await redis.set(key, attr.model_dump_json())
return attr
except HTTPStatusError:
raise HTTPException(status_code=404, detail="Beatmap not found")

View File

@@ -1,10 +1,8 @@
from __future__ import annotations
from app.database import (
Beatmapset,
BeatmapsetResp,
User as DBUser,
)
from typing import Literal
from app.database import Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
from app.dependencies.database import get_db
from app.dependencies.fetcher import get_fetcher
from app.dependencies.user import get_current_user
@@ -12,9 +10,9 @@ from app.fetcher import Fetcher
from .api_router import router
from fastapi import Depends, HTTPException
from fastapi import Depends, Form, HTTPException, Query
from fastapi.responses import RedirectResponse
from httpx import HTTPStatusError
from sqlalchemy.orm import selectinload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -22,17 +20,11 @@ from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp)
async def get_beatmapset(
sid: int,
current_user: DBUser = Depends(get_current_user),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
beatmapset = (
await db.exec(
select(Beatmapset)
.options(selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
.where(Beatmapset.id == sid)
)
).first()
beatmapset = (await db.exec(select(Beatmapset).where(Beatmapset.id == sid))).first()
if not beatmapset:
try:
resp = await fetcher.get_beatmapset(sid)
@@ -40,5 +32,55 @@ async def get_beatmapset(
except HTTPStatusError:
raise HTTPException(status_code=404, detail="Beatmapset not found")
else:
resp = BeatmapsetResp.from_db(beatmapset)
resp = await BeatmapsetResp.from_db(
beatmapset, session=db, include=["recent_favourites"], user=current_user
)
return resp
@router.get("/beatmapsets/{beatmapset}/download", tags=["beatmapset"])
async def download_beatmapset(
beatmapset: int,
no_video: bool = Query(True, alias="noVideo"),
current_user: User = Depends(get_current_user),
):
if current_user.country_code == "CN":
return RedirectResponse(
f"https://txy1.sayobot.cn/beatmaps/download/"
f"{'novideo' if no_video else 'full'}/{beatmapset}?server=auto"
)
else:
return RedirectResponse(
f"https://api.nerinyan.moe/d/{beatmapset}?noVideo={no_video}"
)
@router.post("/beatmapsets/{beatmapset}/favourites", tags=["beatmapset"])
async def favourite_beatmapset(
beatmapset: int,
action: Literal["favourite", "unfavourite"] = Form(),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
existing_favourite = (
await db.exec(
select(FavouriteBeatmapset).where(
FavouriteBeatmapset.user_id == current_user.id,
FavouriteBeatmapset.beatmapset_id == beatmapset,
)
)
).first()
if action == "favourite" and existing_favourite:
raise HTTPException(status_code=400, detail="Already favourited")
elif action == "unfavourite" and not existing_favourite:
raise HTTPException(status_code=400, detail="Not favourited")
if action == "favourite":
favourite = FavouriteBeatmapset(
user_id=current_user.id, beatmapset_id=beatmapset
)
db.add(favourite)
else:
await db.delete(existing_favourite)
await db.commit()

View File

@@ -1,28 +1,27 @@
from __future__ import annotations
from typing import Literal
from app.database import (
User as DBUser,
)
from app.database import User, UserResp
from app.database.lazer_user import ALL_INCLUDED
from app.dependencies import get_current_user
from app.models.user import (
User as ApiUser,
)
from app.utils import convert_db_user_to_api_user
from app.dependencies.database import get_db
from app.models.score import GameMode
from .api_router import router
from fastapi import Depends
from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/me/{ruleset}", response_model=ApiUser)
@router.get("/me/", response_model=ApiUser)
@router.get("/me/{ruleset}", response_model=UserResp)
@router.get("/me/", response_model=UserResp)
async def get_user_info_default(
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
current_user: DBUser = Depends(get_current_user),
ruleset: GameMode | None = None,
current_user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_db),
):
"""获取当前用户信息默认使用osu模式"""
# 默认使用osu模式
api_user = await convert_db_user_to_api_user(current_user, ruleset)
return api_user
return await UserResp.from_db(
current_user,
session,
ALL_INCLUDED,
ruleset,
)

View File

@@ -8,7 +8,7 @@ from app.dependencies.user import get_current_user
from .api_router import router
from fastapi import Depends, HTTPException, Query, Request
from sqlalchemy.orm import joinedload
from pydantic import BaseModel
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -26,17 +26,19 @@ async def get_relationship(
else RelationshipType.BLOCK
)
relationships = await db.exec(
select(Relationship)
.options(joinedload(Relationship.target).options(*DBUser.all_select_option())) # pyright: ignore[reportArgumentType]
.where(
select(Relationship).where(
Relationship.user_id == current_user.id,
Relationship.type == relationship_type,
)
)
return [await RelationshipResp.from_db(db, rel) for rel in relationships]
return [await RelationshipResp.from_db(db, rel) for rel in relationships.unique()]
@router.post("/friends", tags=["relationship"], response_model=RelationshipResp)
class AddFriendResp(BaseModel):
user_relation: RelationshipResp
@router.post("/friends", tags=["relationship"], response_model=AddFriendResp)
@router.post("/blocks", tags=["relationship"])
async def add_relationship(
request: Request,
@@ -87,14 +89,10 @@ async def add_relationship(
if origin_type == RelationshipType.FOLLOW:
relationship = (
await db.exec(
select(Relationship)
.where(
select(Relationship).where(
Relationship.user_id == current_user_id,
Relationship.target_id == target,
)
.options(
joinedload(Relationship.target).options(*DBUser.all_select_option()) # pyright: ignore[reportArgumentType]
)
)
).first()
assert relationship, "Relationship should exist after commit"

View File

@@ -6,8 +6,10 @@ from app.dependencies.fetcher import get_fetcher
from app.fetcher import Fetcher
from app.models.room import MultiplayerRoom, MultiplayerRoomState, Room
from api_router import router
from .api_router import router
from fastapi import Depends, HTTPException, Query
from redis.asyncio import Redis
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -21,6 +23,7 @@ async def get_all_rooms(
), # TODO: 对房间根据分类进行筛选(真的有人用这功能吗)
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
redis: Redis = Depends(get_redis),
):
all_roomID = (await db.exec(select(RoomIndex))).all()
redis = get_redis()

View File

@@ -1,11 +1,7 @@
from __future__ import annotations
from app.database import (
User as DBUser,
)
from app.database.beatmap import Beatmap
from app.database.score import Score, ScoreResp, process_score, process_user
from app.database.score_token import ScoreToken, ScoreTokenResp
from app.database import Beatmap, Score, ScoreResp, ScoreToken, ScoreTokenResp, User
from app.database.score import get_leaderboard, process_score, process_user
from app.dependencies.database import get_db, get_redis
from app.dependencies.fetcher import get_fetcher
from app.dependencies.user import get_current_user
@@ -13,6 +9,7 @@ from app.models.beatmap import BeatmapRankStatus
from app.models.score import (
INT_TO_MODE,
GameMode,
LeaderboardType,
Rank,
SoloScoreSubmissionInfo,
)
@@ -21,9 +18,9 @@ from .api_router import router
from fastapi import Depends, Form, HTTPException, Query
from pydantic import BaseModel
from redis import Redis
from redis.asyncio import Redis
from sqlalchemy.orm import joinedload
from sqlmodel import col, select, true
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -37,44 +34,26 @@ class BeatmapScores(BaseModel):
)
async def get_beatmap_scores(
beatmap: int,
mode: GameMode,
legacy_only: bool = Query(None), # TODO:加入对这个参数的查询
mode: GameMode | None = Query(None),
# mods: List[APIMod] = Query(None), # TODO:加入指定MOD的查询
type: str = Query(None),
current_user: DBUser = Depends(get_current_user),
mods: list[str] = Query(default_factory=set, alias="mods[]"),
type: LeaderboardType = Query(LeaderboardType.GLOBAL),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
limit: int = Query(50, ge=1, le=200),
):
if legacy_only:
raise HTTPException(
status_code=404, detail="this server only contains lazer scores"
)
all_scores = (
await db.exec(
Score.select_clause_unique(
Score.beatmap_id == beatmap,
col(Score.passed).is_(True),
Score.gamemode == mode if mode is not None else true(),
)
)
).all()
user_score = (
await db.exec(
Score.select_clause_unique(
Score.beatmap_id == beatmap,
Score.user_id == current_user.id,
col(Score.passed).is_(True),
Score.gamemode == mode if mode is not None else true(),
)
)
).first()
all_scores, user_score = await get_leaderboard(
db, beatmap, mode, type=type, user=current_user, limit=limit, mods=mods
)
return BeatmapScores(
scores=[await ScoreResp.from_db(db, score, score.user) for score in all_scores],
userScore=await ScoreResp.from_db(db, user_score, user_score.user)
if user_score
else None,
scores=[await ScoreResp.from_db(db, score) for score in all_scores],
userScore=await ScoreResp.from_db(db, user_score) if user_score else None,
)
@@ -94,7 +73,7 @@ async def get_user_beatmap_score(
legacy_only: bool = Query(None),
mode: str = Query(None),
mods: str = Query(None), # TODO:添加mods筛选
current_user: DBUser = Depends(get_current_user),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
if legacy_only:
@@ -103,7 +82,7 @@ async def get_user_beatmap_score(
)
user_score = (
await db.exec(
Score.select_clause(True)
select(Score)
.where(
Score.gamemode == mode if mode is not None else True,
Score.beatmap_id == beatmap,
@@ -120,7 +99,7 @@ async def get_user_beatmap_score(
else:
return BeatmapUserScore(
position=user_score.position if user_score.position is not None else 0,
score=await ScoreResp.from_db(db, user_score, user_score.user),
score=await ScoreResp.from_db(db, user_score),
)
@@ -134,7 +113,7 @@ async def get_user_all_beatmap_scores(
user: int,
legacy_only: bool = Query(None),
ruleset: str = Query(None),
current_user: DBUser = Depends(get_current_user),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
if legacy_only:
@@ -143,7 +122,7 @@ async def get_user_all_beatmap_scores(
)
all_user_scores = (
await db.exec(
Score.select_clause()
select(Score)
.where(
Score.gamemode == ruleset if ruleset is not None else True,
Score.beatmap_id == beatmap,
@@ -153,9 +132,7 @@ async def get_user_all_beatmap_scores(
)
).all()
return [
await ScoreResp.from_db(db, score, current_user) for score in all_user_scores
]
return [await ScoreResp.from_db(db, score) for score in all_user_scores]
@router.post(
@@ -166,9 +143,10 @@ async def create_solo_score(
version_hash: str = Form(""),
beatmap_hash: str = Form(),
ruleset_id: int = Form(..., ge=0, le=3),
current_user: DBUser = Depends(get_current_user),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
assert current_user.id
async with db:
score_token = ScoreToken(
user_id=current_user.id,
@@ -190,7 +168,7 @@ async def submit_solo_score(
beatmap: int,
token: int,
info: SoloScoreSubmissionInfo,
current_user: DBUser = Depends(get_current_user),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
fetcher=Depends(get_fetcher),
@@ -210,9 +188,7 @@ async def submit_solo_score(
if score_token.score_id:
score = (
await db.exec(
select(Score)
.options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType]
.where(
select(Score).where(
Score.id == score_token.score_id,
Score.user_id == current_user.id,
)
@@ -246,8 +222,6 @@ async def submit_solo_score(
score_id = score.id
score_token.score_id = score_id
await process_user(db, current_user, score, ranked)
score = (
await db.exec(Score.select_clause().where(Score.id == score_id))
).first()
score = (await db.exec(select(Score).where(Score.id == score_id))).first()
assert score is not None
return await ScoreResp.from_db(db, score, current_user)
return await ScoreResp.from_db(db, score)

View File

@@ -1,12 +1,9 @@
from __future__ import annotations
from typing import Literal
from app.database import User as DBUser
from app.database import User, UserResp
from app.database.lazer_user import SEARCH_INCLUDED
from app.dependencies.database import get_db
from app.models.score import INT_TO_MODE
from app.models.user import User as ApiUser
from app.utils import convert_db_user_to_api_user
from app.models.score import GameMode
from .api_router import router
@@ -17,28 +14,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import col
# ---------- Shared Utility ----------
async def get_user_by_lookup(
db: AsyncSession, lookup: str, key: str = "id"
) -> DBUser | None:
"""根据查找方式获取用户"""
if key == "id":
try:
user_id = int(lookup)
result = await db.exec(select(DBUser).where(DBUser.id == user_id))
return result.first()
except ValueError:
return None
elif key == "username":
result = await db.exec(select(DBUser).where(DBUser.name == lookup))
return result.first()
else:
return None
# ---------- Batch Users ----------
class BatchUserResponse(BaseModel):
users: list[ApiUser]
users: list[UserResp]
@router.get("/users", response_model=BatchUserResponse)
@@ -51,75 +28,44 @@ async def get_users(
):
if user_ids:
searched_users = (
await session.exec(
DBUser.all_select_clause().limit(50).where(col(DBUser.id).in_(user_ids))
)
await session.exec(select(User).limit(50).where(col(User.id).in_(user_ids)))
).all()
else:
searched_users = (
await session.exec(DBUser.all_select_clause().limit(50))
).all()
searched_users = (await session.exec(select(User).limit(50))).all()
return BatchUserResponse(
users=[
await convert_db_user_to_api_user(
searched_user, ruleset=INT_TO_MODE[searched_user.preferred_mode].value
await UserResp.from_db(
searched_user,
session,
include=SEARCH_INCLUDED,
)
for searched_user in searched_users
]
)
# # ---------- Individual User ----------
# @router.get("/users/{user_lookup}/{mode}", response_model=ApiUser)
# @router.get("/users/{user_lookup}/{mode}/", response_model=ApiUser)
# async def get_user_with_mode(
# user_lookup: str,
# mode: Literal["osu", "taiko", "fruits", "mania"],
# key: Literal["id", "username"] = Query("id"),
# current_user: DBUser = Depends(get_current_user),
# db: AsyncSession = Depends(get_db),
# ):
# """获取指定游戏模式的用户信息"""
# user = await get_user_by_lookup(db, user_lookup, key)
# if not user:
# raise HTTPException(status_code=404, detail="User not found")
# return await convert_db_user_to_api_user(user, mode)
# @router.get("/users/{user_lookup}", response_model=ApiUser)
# @router.get("/users/{user_lookup}/", response_model=ApiUser)
# async def get_user_default(
# user_lookup: str,
# key: Literal["id", "username"] = Query("id"),
# current_user: DBUser = Depends(get_current_user),
# db: AsyncSession = Depends(get_db),
# ):
# """获取用户信息默认使用osu模式但包含所有模式的统计信息"""
# user = await get_user_by_lookup(db, user_lookup, key)
# if not user:
# raise HTTPException(status_code=404, detail="User not found")
# return await convert_db_user_to_api_user(user, "osu")
@router.get("/users/{user}/{ruleset}", response_model=ApiUser)
@router.get("/users/{user}/", response_model=ApiUser)
@router.get("/users/{user}", response_model=ApiUser)
@router.get("/users/{user}/{ruleset}", response_model=UserResp)
@router.get("/users/{user}/", response_model=UserResp)
@router.get("/users/{user}", response_model=UserResp)
async def get_user_info(
user: str,
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
ruleset: GameMode | None = None,
session: AsyncSession = Depends(get_db),
):
searched_user = (
await session.exec(
DBUser.all_select_clause().where(
DBUser.id == int(user)
select(User).where(
User.id == int(user)
if user.isdigit()
else DBUser.name == user.removeprefix("@")
else User.username == user.removeprefix("@")
)
)
).first()
if not searched_user:
raise HTTPException(404, detail="User not found")
return await convert_db_user_to_api_user(searched_user, ruleset=ruleset)
return await UserResp.from_db(
searched_user,
session,
include=SEARCH_INCLUDED,
ruleset=ruleset,
)

View File

@@ -2,10 +2,11 @@ from __future__ import annotations
import asyncio
from collections.abc import Coroutine
from datetime import UTC, datetime
from typing import override
from app.database.relationship import Relationship, RelationshipType
from app.dependencies.database import engine
from app.database import Relationship, RelationshipType, User
from app.dependencies.database import engine, get_redis
from app.models.metadata_hub import MetadataClientState, OnlineStatus, UserActivity
from .hub import Client, Hub
@@ -54,6 +55,18 @@ class MetadataHub(Hub[MetadataClientState]):
async def _clean_state(self, state: MetadataClientState) -> None:
if state.pushable:
await asyncio.gather(*self.broadcast_tasks(int(state.connection_id), None))
redis = get_redis()
if await redis.exists(f"metadata:online:{state.connection_id}"):
await redis.delete(f"metadata:online:{state.connection_id}")
async with AsyncSession(engine) as session:
async with session.begin():
user = (
await session.exec(
select(User).where(User.id == int(state.connection_id))
)
).one()
user.last_visit = datetime.now(UTC)
await session.commit()
@override
def create_state(self, client: Client) -> MetadataClientState:
@@ -93,6 +106,8 @@ class MetadataHub(Hub[MetadataClientState]):
)
)
await asyncio.gather(*tasks)
redis = get_redis()
await redis.set(f"metadata:online:{user_id}", "")
async def UpdateStatus(self, client: Client, status: int) -> None:
status_ = OnlineStatus(status)

View File

@@ -7,10 +7,9 @@ import struct
import time
from typing import override
from app.database import Beatmap
from app.database import Beatmap, User
from app.database.score import Score
from app.database.score_token import ScoreToken
from app.database.user import User
from app.dependencies.database import engine
from app.models.beatmap import BeatmapRankStatus
from app.models.mods import mods_to_int
@@ -197,7 +196,7 @@ class SpectatorHub(Hub[StoreClientState]):
).first()
if not user:
return
name = user.name
name = user.username
store.state = state
store.beatmap_status = beatmap.beatmap_status
store.checksum = beatmap.checksum
@@ -241,65 +240,17 @@ class SpectatorHub(Hub[StoreClientState]):
user_id = int(client.connection_id)
store = self.get_or_create_state(client)
score = store.score
assert store.beatmap_status is not None
assert store.state is not None
assert store.score is not None
if not score or not store.score_token:
return
assert store.beatmap_status is not None
async def _save_replay():
assert store.checksum is not None
assert store.ruleset_id is not None
assert store.state is not None
assert store.score is not None
async with AsyncSession(engine) as session:
async with session:
start_time = time.time()
score_record = None
while time.time() - start_time < READ_SCORE_TIMEOUT:
sub_query = select(ScoreToken.score_id).where(
ScoreToken.id == store.score_token,
)
result = await session.exec(
select(Score)
.options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType]
.where(
Score.id == sub_query,
Score.user_id == user_id,
)
)
score_record = result.first()
if score_record:
break
if not score_record:
return
if not score_record.passed:
return
score_record.has_replay = True
await session.commit()
await session.refresh(score_record)
save_replay(
ruleset_id=store.ruleset_id,
md5=store.checksum,
username=store.score.score_info.user.name,
score=score_record,
statistics=score.score_info.statistics,
maximum_statistics=score.score_info.maximum_statistics,
frames=score.replay_frames,
)
if (
(
BeatmapRankStatus.PENDING
< store.beatmap_status
<= BeatmapRankStatus.LOVED
)
and any(
k.is_hit() and v > 0 for k, v in score.score_info.statistics.items()
)
and state.state != SpectatedUserState.Failed
BeatmapRankStatus.PENDING < store.beatmap_status <= BeatmapRankStatus.LOVED
) and any(
k.is_hit() and v > 0 for k, v in store.score.score_info.statistics.items()
):
# save replay
await _save_replay()
await self._process_score(store, client)
store.state = None
store.beatmap_status = None
store.checksum = None
@@ -308,6 +259,56 @@ class SpectatorHub(Hub[StoreClientState]):
store.score = None
await self._end_session(user_id, state)
async def _process_score(self, store: StoreClientState, client: Client) -> None:
user_id = int(client.connection_id)
assert store.state is not None
assert store.score_token is not None
assert store.checksum is not None
assert store.ruleset_id is not None
assert store.score is not None
async with AsyncSession(engine) as session:
async with session:
start_time = time.time()
score_record = None
while time.time() - start_time < READ_SCORE_TIMEOUT:
sub_query = select(ScoreToken.score_id).where(
ScoreToken.id == store.score_token,
)
result = await session.exec(
select(Score)
.options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType]
.where(
Score.id == sub_query,
Score.user_id == user_id,
)
)
score_record = result.first()
if score_record:
break
if not score_record:
return
if not score_record.passed:
return
await self.call_noblock(
client,
"UserScoreProcessed",
user_id,
score_record.id,
)
# save replay
score_record.has_replay = True
await session.commit()
await session.refresh(score_record)
save_replay(
ruleset_id=store.ruleset_id,
md5=store.checksum,
username=store.score.score_info.user.name,
score=score_record,
statistics=store.score.score_info.statistics,
maximum_statistics=store.score.score_info.maximum_statistics,
frames=store.score.replay_frames,
)
async def _end_session(self, user_id: int, state: SpectatorState) -> None:
if state.state == SpectatedUserState.Playing:
state.state = SpectatedUserState.Quit
@@ -336,7 +337,7 @@ class SpectatorHub(Hub[StoreClientState]):
async with AsyncSession(engine) as session:
async with session.begin():
username = (
await session.exec(select(User.name).where(User.id == user_id))
await session.exec(select(User.username).where(User.id == user_id))
).first()
if not username:
return

View File

@@ -1,465 +1,6 @@
from __future__ import annotations
from datetime import UTC, datetime
from app.database import (
LazerUserCounts,
LazerUserProfile,
LazerUserStatistics,
User as DBUser,
)
from app.models.user import (
Country,
Cover,
DailyChallengeStats,
GradeCounts,
Kudosu,
Level,
Page,
RankHighest,
RankHistory,
Statistics,
User,
UserAchievement,
)
def unix_timestamp_to_windows(timestamp: int) -> int:
"""Convert a Unix timestamp to a Windows timestamp."""
return (timestamp + 62135596800) * 10_000_000
async def convert_db_user_to_api_user(db_user: DBUser, ruleset: str = "osu") -> User:
"""将数据库用户模型转换为API用户模型使用 Lazer 表)"""
# 从db_user获取基本字段值
user_id = getattr(db_user, "id")
user_name = getattr(db_user, "name")
user_country = getattr(db_user, "country")
user_country_code = user_country # 在User模型中country字段就是country_code
# 获取 Lazer 用户资料
profile = db_user.lazer_profile
if not profile:
# 如果没有 lazer 资料,使用默认值
profile = LazerUserProfile(
user_id=user_id,
)
# 获取 Lazer 用户计数 - 使用正确的 lazer_counts 关系
lzrcnt = db_user.lazer_counts
if not lzrcnt:
# 如果没有 lazer 计数,使用默认值
lzrcnt = LazerUserCounts(user_id=user_id)
# 获取指定模式的统计信息
user_stats = None
if db_user.lazer_statistics:
for stat in db_user.lazer_statistics:
if stat.mode == ruleset:
user_stats = stat
break
if not user_stats:
# 如果没有找到指定模式的统计,创建默认统计
user_stats = LazerUserStatistics(user_id=user_id)
# 获取国家信息
country_code = db_user.country_code if db_user.country_code is not None else "XX"
country = Country(code=str(country_code), name=get_country_name(str(country_code)))
# 获取 Kudosu 信息
kudosu = Kudosu(available=0, total=0)
# 获取计数信息
# counts = LazerUserCounts(user_id=user_id)
# 转换统计信息
statistics = Statistics(
count_100=user_stats.count_100,
count_300=user_stats.count_300,
count_50=user_stats.count_50,
count_miss=user_stats.count_miss,
level=Level(
current=user_stats.level_current, progress=user_stats.level_progress
),
global_rank=user_stats.global_rank,
global_rank_exp=user_stats.global_rank_exp,
pp=float(user_stats.pp) if user_stats.pp else 0.0,
pp_exp=float(user_stats.pp_exp) if user_stats.pp_exp else 0.0,
ranked_score=user_stats.ranked_score,
hit_accuracy=float(user_stats.hit_accuracy) if user_stats.hit_accuracy else 0.0,
play_count=user_stats.play_count,
play_time=user_stats.play_time,
total_score=user_stats.total_score,
total_hits=user_stats.total_hits,
maximum_combo=user_stats.maximum_combo,
replays_watched_by_others=user_stats.replays_watched_by_others,
is_ranked=user_stats.is_ranked,
grade_counts=GradeCounts(
ss=user_stats.grade_ss,
ssh=user_stats.grade_ssh,
s=user_stats.grade_s,
sh=user_stats.grade_sh,
a=user_stats.grade_a,
),
country_rank=user_stats.country_rank,
rank={"country": user_stats.country_rank} if user_stats.country_rank else None,
)
# 转换所有模式的统计信息
statistics_rulesets = {}
if db_user.lazer_statistics:
for stat in db_user.lazer_statistics:
statistics_rulesets[stat.mode] = Statistics(
count_100=stat.count_100,
count_300=stat.count_300,
count_50=stat.count_50,
count_miss=stat.count_miss,
level=Level(current=stat.level_current, progress=stat.level_progress),
global_rank=stat.global_rank,
global_rank_exp=stat.global_rank_exp,
pp=float(stat.pp) if stat.pp else 0.0,
pp_exp=float(stat.pp_exp) if stat.pp_exp else 0.0,
ranked_score=stat.ranked_score,
hit_accuracy=float(stat.hit_accuracy) if stat.hit_accuracy else 0.0,
play_count=stat.play_count,
play_time=stat.play_time,
total_score=stat.total_score,
total_hits=stat.total_hits,
maximum_combo=stat.maximum_combo,
replays_watched_by_others=stat.replays_watched_by_others,
is_ranked=stat.is_ranked,
grade_counts=GradeCounts(
ss=stat.grade_ss,
ssh=stat.grade_ssh,
s=stat.grade_s,
sh=stat.grade_sh,
a=stat.grade_a,
),
country_rank=stat.country_rank,
rank={"country": stat.country_rank} if stat.country_rank else None,
)
# 转换国家信息
country = Country(code=user_country_code, name=get_country_name(user_country_code))
# 转换封面信息
cover_url = (
profile.cover_url
if profile and profile.cover_url
else "https://assets.ppy.sh/user-profile-covers/default.jpeg"
)
cover = Cover(
custom_url=profile.cover_url if profile else None, url=str(cover_url), id=None
)
# 转换 Kudosu 信息
kudosu = Kudosu(available=0, total=0)
# 转换成就信息
user_achievements = []
if db_user.lazer_achievements:
for achievement in db_user.lazer_achievements:
user_achievements.append(
UserAchievement(
achieved_at=achievement.achieved_at,
achievement_id=achievement.achievement_id,
)
)
# 转换排名历史
rank_history = None
rank_history_data = None
for rh in db_user.rank_history:
if rh.mode == ruleset:
rank_history_data = rh.rank_data
break
if rank_history_data:
rank_history = RankHistory(mode=ruleset, data=rank_history_data)
# 转换每日挑战统计
# daily_challenge_stats = None
# if db_user.daily_challenge_stats:
# dcs = db_user.daily_challenge_stats
# daily_challenge_stats = DailyChallengeStats(
# daily_streak_best=dcs.daily_streak_best,
# daily_streak_current=dcs.daily_streak_current,
# last_update=dcs.last_update,
# last_weekly_streak=dcs.last_weekly_streak,
# playcount=dcs.playcount,
# top_10p_placements=dcs.top_10p_placements,
# top_50p_placements=dcs.top_50p_placements,
# user_id=dcs.user_id,
# weekly_streak_best=dcs.weekly_streak_best,
# weekly_streak_current=dcs.weekly_streak_current,
# )
# 转换最高排名
rank_highest = None
if user_stats.rank_highest:
rank_highest = RankHighest(
rank=user_stats.rank_highest,
updated_at=user_stats.rank_highest_updated_at or datetime.utcnow(),
)
# 转换团队信息
team = None
if db_user.team_membership:
team_member = db_user.team_membership # 假设用户只属于一个团队
team = team_member.team
# 创建用户对象
# 从db_user获取基本字段值
user_id = getattr(db_user, "id")
user_name = getattr(db_user, "name")
user_country = getattr(db_user, "country")
# 获取用户头像URL
avatar_url = None
# 首先检查 profile 中的 avatar_url
if profile and hasattr(profile, "avatar_url") and profile.avatar_url:
avatar_url = str(profile.avatar_url)
# 然后检查是否有关联的头像记录
if avatar_url is None and hasattr(db_user, "avatar") and db_user.avatar is not None:
if db_user.avatar.r2_game_url:
# 优先使用游戏用的头像URL
avatar_url = str(db_user.avatar.r2_game_url)
elif db_user.avatar.r2_original_url:
# 其次使用原始头像URL
avatar_url = str(db_user.avatar.r2_original_url)
# 如果还是没有找到,通过查询获取
# if db_session and avatar_url is None:
# try:
# # 导入UserAvatar模型
# # 尝试查找用户的头像记录
# statement = select(UserAvatar).where(
# UserAvatar.user_id == user_id, UserAvatar.is_active == True
# )
# avatar_record = db_session.exec(statement).first()
# if avatar_record is not None:
# if avatar_record.r2_game_url is not None:
# # 优先使用游戏用的头像URL
# avatar_url = str(avatar_record.r2_game_url)
# elif avatar_record.r2_original_url is not None:
# # 其次使用原始头像URL
# avatar_url = str(avatar_record.r2_original_url)
# except Exception as e:
# print(f"获取用户头像时出错: {e}")
# print(f"最终头像URL: {avatar_url}")
# 如果仍然没有找到头像URL则使用默认URL
if avatar_url is None:
avatar_url = "https://a.gu-osu.gmoe.cc/api/users/avatar/1"
# 处理 profile_order 列表排序
profile_order = [
"me",
"recent_activity",
"top_ranks",
"medals",
"historical",
"beatmaps",
"kudosu",
]
if profile and profile.profile_order:
profile_order = profile.profile_order.split(",")
# 在convert_db_user_to_api_user函数中添加active_tournament_banners处理
active_tournament_banners = []
if db_user.active_banners:
for banner in db_user.active_banners:
active_tournament_banners.append(
{
"tournament_id": banner.tournament_id,
"image_url": banner.image_url,
"is_active": banner.is_active,
}
)
# 在convert_db_user_to_api_user函数中添加badges处理
badges = []
if db_user.lazer_badges:
for badge in db_user.lazer_badges:
badges.append(
{
"badge_id": badge.badge_id,
"awarded_at": badge.awarded_at,
"description": badge.description,
"image_url": badge.image_url,
"url": badge.url,
}
)
# 在convert_db_user_to_api_user函数中添加monthly_playcounts处理
monthly_playcounts = []
if db_user.lazer_monthly_playcounts:
for playcount in db_user.lazer_monthly_playcounts:
monthly_playcounts.append(
{
"start_date": playcount.start_date.isoformat()
if playcount.start_date
else None,
"play_count": playcount.play_count,
}
)
# 在convert_db_user_to_api_user函数中添加previous_usernames处理
previous_usernames = []
if db_user.lazer_previous_usernames:
for username in db_user.lazer_previous_usernames:
previous_usernames.append(
{
"username": username.username,
"changed_at": username.changed_at.isoformat()
if username.changed_at
else None,
}
)
# 在convert_db_user_to_api_user函数中添加replays_watched_counts处理
replays_watched_counts = []
if hasattr(db_user, "lazer_replays_watched") and db_user.lazer_replays_watched:
for replay in db_user.lazer_replays_watched:
replays_watched_counts.append(
{
"start_date": replay.start_date.isoformat()
if replay.start_date
else None,
"count": replay.count,
}
)
# 创建用户对象
user = User(
id=user_id,
username=user_name,
avatar_url=avatar_url,
country_code=str(country_code),
default_group=profile.default_group if profile else "default",
is_active=profile.is_active,
is_bot=profile.is_bot,
is_deleted=profile.is_deleted,
is_online=profile.is_online,
is_supporter=profile.is_supporter,
is_restricted=profile.is_restricted,
last_visit=db_user.last_visit,
pm_friends_only=profile.pm_friends_only,
profile_colour=profile.profile_colour,
cover_url=profile.cover_url
if profile and profile.cover_url
else "https://assets.ppy.sh/user-profile-covers/default.jpeg",
discord=profile.discord if profile else None,
has_supported=profile.has_supported if profile else False,
interests=profile.interests if profile else None,
join_date=profile.join_date if profile.join_date else datetime.now(UTC),
location=profile.location if profile else None,
max_blocks=profile.max_blocks if profile and profile.max_blocks else 100,
max_friends=profile.max_friends if profile and profile.max_friends else 500,
post_count=profile.post_count if profile and profile.post_count else 0,
profile_hue=profile.profile_hue if profile and profile.profile_hue else None,
profile_order=profile_order, # 使用排序后的 profile_order
title=profile.title if profile else None,
title_url=profile.title_url if profile else None,
twitter=profile.twitter if profile else None,
website=profile.website if profile else None,
session_verified=True,
support_level=profile.support_level if profile else 0,
country=country,
cover=cover,
kudosu=kudosu,
statistics=statistics,
statistics_rulesets=statistics_rulesets,
beatmap_playcounts_count=lzrcnt.beatmap_playcounts_count if lzrcnt else 0,
comments_count=lzrcnt.comments_count if lzrcnt else 0,
favourite_beatmapset_count=lzrcnt.favourite_beatmapset_count if lzrcnt else 0,
follower_count=lzrcnt.follower_count if lzrcnt else 0,
graveyard_beatmapset_count=lzrcnt.graveyard_beatmapset_count if lzrcnt else 0,
guest_beatmapset_count=lzrcnt.guest_beatmapset_count if lzrcnt else 0,
loved_beatmapset_count=lzrcnt.loved_beatmapset_count if lzrcnt else 0,
mapping_follower_count=lzrcnt.mapping_follower_count if lzrcnt else 0,
nominated_beatmapset_count=lzrcnt.nominated_beatmapset_count if lzrcnt else 0,
pending_beatmapset_count=lzrcnt.pending_beatmapset_count if lzrcnt else 0,
ranked_beatmapset_count=lzrcnt.ranked_beatmapset_count if lzrcnt else 0,
ranked_and_approved_beatmapset_count=lzrcnt.ranked_and_approved_beatmapset_count
if lzrcnt
else 0,
unranked_beatmapset_count=lzrcnt.unranked_beatmapset_count if lzrcnt else 0,
scores_best_count=lzrcnt.scores_best_count if lzrcnt else 0,
scores_first_count=lzrcnt.scores_first_count if lzrcnt else 0,
scores_pinned_count=lzrcnt.scores_pinned_count,
scores_recent_count=lzrcnt.scores_recent_count if lzrcnt else 0,
account_history=[], # TODO: 获取用户历史账户信息
# active_tournament_banner=len(active_tournament_banners),
active_tournament_banners=active_tournament_banners,
badges=badges,
current_season_stats=None,
daily_challenge_user_stats=DailyChallengeStats(
user_id=user_id,
daily_streak_best=db_user.daily_challenge_stats.daily_streak_best
if db_user.daily_challenge_stats
else 0,
daily_streak_current=db_user.daily_challenge_stats.daily_streak_current
if db_user.daily_challenge_stats
else 0,
last_update=db_user.daily_challenge_stats.last_update
if db_user.daily_challenge_stats
else None,
last_weekly_streak=db_user.daily_challenge_stats.last_weekly_streak
if db_user.daily_challenge_stats
else None,
playcount=db_user.daily_challenge_stats.playcount
if db_user.daily_challenge_stats
else 0,
top_10p_placements=db_user.daily_challenge_stats.top_10p_placements
if db_user.daily_challenge_stats
else 0,
top_50p_placements=db_user.daily_challenge_stats.top_50p_placements
if db_user.daily_challenge_stats
else 0,
weekly_streak_best=db_user.daily_challenge_stats.weekly_streak_best
if db_user.daily_challenge_stats
else 0,
weekly_streak_current=db_user.daily_challenge_stats.weekly_streak_current
if db_user.daily_challenge_stats
else 0,
),
groups=[],
monthly_playcounts=monthly_playcounts,
page=Page(html=profile.page_html or "", raw=profile.page_raw or "")
if profile.page_html or profile.page_raw
else Page(),
previous_usernames=previous_usernames,
rank_highest=rank_highest,
rank_history=rank_history,
rankHistory=rank_history,
replays_watched_counts=replays_watched_counts,
team=team,
user_achievements=user_achievements,
)
return user
def get_country_name(country_code: str) -> str:
"""根据国家代码获取国家名称"""
country_names = {
"CN": "China",
"JP": "Japan",
"US": "United States",
"GB": "United Kingdom",
"DE": "Germany",
"FR": "France",
"KR": "South Korea",
"CA": "Canada",
"AU": "Australia",
"BR": "Brazil",
# 可以添加更多国家
}
return country_names.get(country_code, "Unknown")

107
main.py
View File

@@ -4,25 +4,22 @@ from contextlib import asynccontextmanager
from datetime import datetime
from app.config import settings
from app.database import Team # noqa: F401
from app.dependencies.database import create_tables, engine
from app.dependencies.database import create_tables, engine, redis_client
from app.dependencies.fetcher import get_fetcher
from app.models.user import User
from app.router import api_router, auth_router, fetcher_router, signalr_router
from fastapi import FastAPI
User.model_rebuild()
@asynccontextmanager
async def lifespan(app: FastAPI):
# on startup
await create_tables()
get_fetcher() # 初始化 fetcher
await get_fetcher() # 初始化 fetcher
# on shutdown
yield
await engine.dispose()
await redis_client.aclose()
app = FastAPI(title="osu! API 模拟服务器", version="1.0.0", lifespan=lifespan)
@@ -44,104 +41,6 @@ async def health_check():
return {"status": "ok", "timestamp": datetime.utcnow().isoformat()}
# @app.get("/api/v2/friends")
# async def get_friends():
# return JSONResponse(
# content=[
# {
# "id": 123456,
# "username": "BestFriend",
# "is_online": True,
# "is_supporter": False,
# "country": {"code": "US", "name": "United States"},
# }
# ]
# )
# @app.get("/api/v2/notifications")
# async def get_notifications():
# return JSONResponse(content={"notifications": [], "unread_count": 0})
# @app.post("/api/v2/chat/ack")
# async def chat_ack():
# return JSONResponse(content={"status": "ok"})
# @app.get("/api/v2/users/{user_id}/{mode}")
# async def get_user_mode(user_id: int, mode: str):
# return JSONResponse(
# content={
# "id": user_id,
# "username": "测试测试测",
# "statistics": {
# "level": {"current": 97, "progress": 96},
# "pp": 114514,
# "global_rank": 666,
# "country_rank": 1,
# "hit_accuracy": 100,
# },
# "country": {"code": "JP", "name": "Japan"},
# }
# )
# @app.get("/api/v2/me")
# async def get_me():
# return JSONResponse(
# content={
# "id": 15651670,
# "username": "Googujiang",
# "is_online": True,
# "country": {"code": "JP", "name": "Japan"},
# "statistics": {
# "level": {"current": 97, "progress": 96},
# "pp": 2826.26,
# "global_rank": 298026,
# "country_rank": 11220,
# "hit_accuracy": 95.7168,
# },
# }
# )
# @app.post("/signalr/metadata/negotiate")
# async def metadata_negotiate(negotiateVersion: int = 1):
# return JSONResponse(
# content={
# "connectionId": "abc123",
# "availableTransports": [
# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]}
# ],
# }
# )
# @app.post("/signalr/spectator/negotiate")
# async def spectator_negotiate(negotiateVersion: int = 1):
# return JSONResponse(
# content={
# "connectionId": "spec456",
# "availableTransports": [
# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]}
# ],
# }
# )
# @app.post("/signalr/multiplayer/negotiate")
# async def multiplayer_negotiate(negotiateVersion: int = 1):
# return JSONResponse(
# content={
# "connectionId": "multi789",
# "availableTransports": [
# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]}
# ],
# }
# )
if __name__ == "__main__":
from app.log import logger # noqa: F401

View File

@@ -1,8 +1,8 @@
"""score: remove best_id in database
"""beatmapset: support favourite count
Revision ID: 78be13c71791
Revises: dc4d25c428c7
Create Date: 2025-07-29 07:57:33.764517
Revision ID: 1178d0758ebf
Revises:
Create Date: 2025-08-01 04:05:09.882800
"""
@@ -15,8 +15,8 @@ import sqlalchemy as sa
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
revision: str = "78be13c71791"
down_revision: str | Sequence[str] | None = "dc4d25c428c7"
revision: str = "1178d0758ebf"
down_revision: str | Sequence[str] | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
@@ -24,7 +24,7 @@ depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("scores", "best_id")
op.drop_column("beatmapsets", "favourite_count")
# ### end Alembic commands ###
@@ -32,7 +32,9 @@ def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"scores",
sa.Column("best_id", mysql.INTEGER(), autoincrement=False, nullable=True),
"beatmapsets",
sa.Column(
"favourite_count", mysql.INTEGER(), autoincrement=False, nullable=False
),
)
# ### end Alembic commands ###

View File

@@ -0,0 +1,54 @@
"""relationship: fix unique relationship
Revision ID: 58a11441d302
Revises: 1178d0758ebf
Create Date: 2025-08-01 04:23:02.498166
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
revision: str = "58a11441d302"
down_revision: str | Sequence[str] | None = "1178d0758ebf"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"relationship",
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
)
op.drop_constraint("PRIMARY", "relationship", type_="primary")
op.create_primary_key("pk_relationship", "relationship", ["id"])
op.alter_column(
"relationship", "user_id", existing_type=mysql.BIGINT(), nullable=True
)
op.alter_column(
"relationship", "target_id", existing_type=mysql.BIGINT(), nullable=True
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint("pk_relationship", "relationship", type_="primary")
op.create_primary_key("PRIMARY", "relationship", ["user_id", "target_id"])
op.alter_column(
"relationship", "target_id", existing_type=mysql.BIGINT(), nullable=False
)
op.alter_column(
"relationship", "user_id", existing_type=mysql.BIGINT(), nullable=False
)
op.drop_column("relationship", "id")
# ### end Alembic commands ###

View File

@@ -1,36 +0,0 @@
"""score: add nlarge_tick_hit & nsmall_tick_hit for pp calculator
Revision ID: dc4d25c428c7
Revises:
Create Date: 2025-07-29 01:43:40.221070
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "dc4d25c428c7"
down_revision: str | Sequence[str] | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("scores", sa.Column("nlarge_tick_hit", sa.Integer(), nullable=True))
op.add_column("scores", sa.Column("nsmall_tick_hit", sa.Integer(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("scores", "nsmall_tick_hit")
op.drop_column("scores", "nlarge_tick_hit")
# ### end Alembic commands ###