refactor(project): make pyright & ruff happy

This commit is contained in:
MingxuanGame
2025-08-22 08:21:52 +00:00
parent 3b1d7a2234
commit 598fcc8b38
157 changed files with 2382 additions and 4590 deletions

View File

@@ -28,18 +28,14 @@ if TYPE_CHECKING:
class UserAchievementBase(SQLModel, UTCBaseModel):
achievement_id: int
achieved_at: datetime = Field(
default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True))
)
achieved_at: datetime = Field(default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True)))
class UserAchievement(UserAchievementBase, table=True):
__tablename__ = "lazer_user_achievements" # pyright: ignore[reportAssignmentType]
__tablename__: str = "lazer_user_achievements"
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id")), exclude=True
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")), exclude=True)
user: "User" = Relationship(back_populates="achievement")
@@ -56,11 +52,7 @@ async def process_achievements(session: AsyncSession, redis: Redis, score_id: in
if not score:
return
achieved = (
await session.exec(
select(UserAchievement.achievement_id).where(
UserAchievement.user_id == score.user_id
)
)
await session.exec(select(UserAchievement.achievement_id).where(UserAchievement.user_id == score.user_id))
).all()
not_achieved = {k: v for k, v in MEDALS.items() if k.id not in achieved}
result: list[Achievement] = []
@@ -78,9 +70,7 @@ async def process_achievements(session: AsyncSession, redis: Redis, score_id: in
)
await redis.publish(
"chat:notification",
UserAchievementUnlock.init(
r, score.user_id, score.gamemode
).model_dump_json(),
UserAchievementUnlock.init(r, score.user_id, score.gamemode).model_dump_json(),
)
event = Event(
created_at=now,

View File

@@ -20,42 +20,34 @@ if TYPE_CHECKING:
class OAuthToken(UTCBaseModel, SQLModel, table=True):
__tablename__ = "oauth_tokens" # pyright: ignore[reportAssignmentType]
__tablename__: str = "oauth_tokens"
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
client_id: int = Field(index=True)
access_token: str = Field(max_length=500, unique=True)
refresh_token: str = Field(max_length=500, unique=True)
token_type: str = Field(default="Bearer", max_length=20)
scope: str = Field(default="*", max_length=100)
expires_at: datetime = Field(sa_column=Column(DateTime))
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
created_at: datetime = Field(default_factory=datetime.utcnow, sa_column=Column(DateTime))
user: "User" = Relationship()
class OAuthClient(SQLModel, table=True):
__tablename__ = "oauth_clients" # pyright: ignore[reportAssignmentType]
__tablename__: str = "oauth_clients"
name: str = Field(max_length=100, index=True)
description: str = Field(sa_column=Column(Text), default="")
client_id: int | None = Field(default=None, primary_key=True, index=True)
client_secret: str = Field(default_factory=secrets.token_hex, index=True)
redirect_uris: list[str] = Field(default_factory=list, sa_column=Column(JSON))
owner_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
class V1APIKeys(SQLModel, table=True):
__tablename__ = "v1_api_keys" # pyright: ignore[reportAssignmentType]
__tablename__: str = "v1_api_keys"
id: int | None = Field(default=None, primary_key=True)
name: str = Field(max_length=100, index=True)
key: str = Field(default_factory=secrets.token_hex, index=True)
owner_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))

View File

@@ -60,17 +60,13 @@ class BeatmapBase(SQLModel):
class Beatmap(BeatmapBase, table=True):
__tablename__ = "beatmaps" # pyright: ignore[reportAssignmentType]
__tablename__: str = "beatmaps"
id: int = Field(primary_key=True, index=True)
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
beatmap_status: BeatmapRankStatus = Field(index=True)
# optional
beatmapset: Beatmapset = Relationship(
back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"}
)
failtimes: FailTime | None = Relationship(
back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"}
)
beatmapset: Beatmapset = Relationship(back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"})
failtimes: FailTime | None = Relationship(back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"})
@classmethod
async def from_resp(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
@@ -84,21 +80,15 @@ class Beatmap(BeatmapBase, table=True):
"beatmap_status": BeatmapRankStatus(resp.ranked),
}
)
if not (
await session.exec(select(exists()).where(Beatmap.id == resp.id))
).first():
if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
session.add(beatmap)
await session.commit()
beatmap = (
await session.exec(select(Beatmap).where(Beatmap.id == resp.id))
).first()
beatmap = (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).first()
assert beatmap is not None, "Beatmap should not be None after commit"
return beatmap
@classmethod
async def from_resp_batch(
cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0
) -> list["Beatmap"]:
async def from_resp_batch(cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0) -> list["Beatmap"]:
beatmaps = []
for resp in inp:
if resp.id == from_:
@@ -113,9 +103,7 @@ class Beatmap(BeatmapBase, table=True):
"beatmap_status": BeatmapRankStatus(resp.ranked),
}
)
if not (
await session.exec(select(exists()).where(Beatmap.id == resp.id))
).first():
if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
session.add(beatmap)
beatmaps.append(beatmap)
await session.commit()
@@ -130,17 +118,11 @@ class Beatmap(BeatmapBase, table=True):
md5: str | None = None,
) -> "Beatmap":
beatmap = (
await session.exec(
select(Beatmap).where(
Beatmap.id == bid if bid is not None else Beatmap.checksum == md5
)
)
await session.exec(select(Beatmap).where(Beatmap.id == bid if bid is not None else Beatmap.checksum == md5))
).first()
if not beatmap:
resp = await fetcher.get_beatmap(bid, md5)
r = await session.exec(
select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id)
)
r = await session.exec(select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id))
if not r.first():
set_resp = await fetcher.get_beatmapset(resp.beatmapset_id)
await Beatmapset.from_resp(session, set_resp, from_=resp.id)
@@ -178,10 +160,7 @@ class BeatmapResp(BeatmapBase):
if query_mode is not None and beatmap.mode != query_mode:
beatmap_["convert"] = True
beatmap_["is_scoreable"] = beatmap_status.has_leaderboard()
if (
settings.enable_all_beatmap_leaderboard
and not beatmap_status.has_leaderboard()
):
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
beatmap_["ranked"] = BeatmapRankStatus.APPROVED.value
beatmap_["status"] = BeatmapRankStatus.APPROVED.name.lower()
else:
@@ -189,9 +168,7 @@ class BeatmapResp(BeatmapBase):
beatmap_["ranked"] = beatmap_status.value
beatmap_["mode_int"] = int(beatmap.mode)
if not from_set:
beatmap_["beatmapset"] = await BeatmapsetResp.from_db(
beatmap.beatmapset, session=session, user=user
)
beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset, session=session, user=user)
if beatmap.failtimes is not None:
beatmap_["failtimes"] = FailTimeResp.from_db(beatmap.failtimes)
else:
@@ -218,7 +195,7 @@ class BeatmapResp(BeatmapBase):
class BannedBeatmaps(SQLModel, table=True):
__tablename__ = "banned_beatmaps" # pyright: ignore[reportAssignmentType]
__tablename__: str = "banned_beatmaps"
id: int | None = Field(primary_key=True, index=True, default=None)
beatmap_id: int = Field(index=True)
@@ -230,15 +207,10 @@ async def calculate_beatmap_attributes(
redis: Redis,
fetcher: "Fetcher",
):
key = (
f"beatmap:{beatmap_id}:{ruleset}:"
f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
)
key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
if await redis.exists(key):
return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
return BeatmapAttributes.model_validate_json(await redis.get(key))
resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
attr = await asyncio.get_event_loop().run_in_executor(
None, calculate_beatmap_attribute, resp, ruleset, mods_
)
attr = await asyncio.get_event_loop().run_in_executor(None, calculate_beatmap_attribute, resp, ruleset, mods_)
await redis.set(key, attr.model_dump_json())
return attr

View File

@@ -23,15 +23,13 @@ if TYPE_CHECKING:
class BeatmapPlaycounts(AsyncAttrs, SQLModel, table=True):
__tablename__ = "beatmap_playcounts" # pyright: ignore[reportAssignmentType]
__tablename__: str = "beatmap_playcounts"
id: int | None = Field(
default=None,
sa_column=Column(BigInteger, primary_key=True, autoincrement=True),
)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
playcount: int = Field(default=0)
@@ -59,9 +57,7 @@ class BeatmapPlaycountsResp(BaseModel):
)
async def process_beatmap_playcount(
session: AsyncSession, user_id: int, beatmap_id: int
) -> None:
async def process_beatmap_playcount(session: AsyncSession, user_id: int, beatmap_id: int) -> None:
existing_playcount = (
await session.exec(
select(BeatmapPlaycounts).where(
@@ -89,7 +85,5 @@ async def process_beatmap_playcount(
}
session.add(playcount_event)
else:
new_playcount = BeatmapPlaycounts(
user_id=user_id, beatmap_id=beatmap_id, playcount=1
)
new_playcount = BeatmapPlaycounts(user_id=user_id, beatmap_id=beatmap_id, playcount=1)
session.add(new_playcount)

View File

@@ -86,9 +86,7 @@ class BeatmapsetBase(SQLModel):
# optional
# converts: list[Beatmap] = Relationship(back_populates="beatmapset")
current_nominations: list[BeatmapNomination] | None = Field(
None, sa_column=Column(JSON)
)
current_nominations: list[BeatmapNomination] | None = Field(None, sa_column=Column(JSON))
description: BeatmapDescription | None = Field(default=None, sa_column=Column(JSON))
# TODO: discussions: list[BeatmapsetDiscussion] = None
# TODO: current_user_attributes: Optional[CurrentUserAttributes] = None
@@ -105,22 +103,18 @@ class BeatmapsetBase(SQLModel):
can_be_hyped: bool = Field(default=False, sa_column=Column(Boolean))
discussion_locked: bool = Field(default=False, sa_column=Column(Boolean))
last_updated: datetime = Field(sa_column=Column(DateTime, index=True))
ranked_date: datetime | None = Field(
default=None, sa_column=Column(DateTime, index=True)
)
ranked_date: datetime | None = Field(default=None, sa_column=Column(DateTime, index=True))
storyboard: bool = Field(default=False, sa_column=Column(Boolean, index=True))
submitted_date: datetime = Field(sa_column=Column(DateTime, index=True))
tags: str = Field(default="", sa_column=Column(Text))
class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
__tablename__ = "beatmapsets" # pyright: ignore[reportAssignmentType]
__tablename__: str = "beatmapsets"
id: int | None = Field(default=None, primary_key=True, index=True)
id: int = Field(default=None, primary_key=True, index=True)
# Beatmapset
beatmap_status: BeatmapRankStatus = Field(
default=BeatmapRankStatus.GRAVEYARD, index=True
)
beatmap_status: BeatmapRankStatus = Field(default=BeatmapRankStatus.GRAVEYARD, index=True)
# optional
beatmaps: list["Beatmap"] = Relationship(back_populates="beatmapset")
@@ -137,9 +131,7 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset")
@classmethod
async def from_resp(
cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0
) -> "Beatmapset":
async def from_resp(cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0) -> "Beatmapset":
from .beatmap import Beatmap
d = resp.model_dump()
@@ -167,18 +159,14 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
"download_disabled": resp.availability.download_disabled or False,
}
)
if not (
await session.exec(select(exists()).where(Beatmapset.id == resp.id))
).first():
if not (await session.exec(select(exists()).where(Beatmapset.id == resp.id))).first():
session.add(beatmapset)
await session.commit()
await Beatmap.from_resp_batch(session, resp.beatmaps, from_=from_)
return beatmapset
@classmethod
async def get_or_fetch(
cls, session: AsyncSession, fetcher: "Fetcher", sid: int
) -> "Beatmapset":
async def get_or_fetch(cls, session: AsyncSession, fetcher: "Fetcher", sid: int) -> "Beatmapset":
beatmapset = await session.get(Beatmapset, sid)
if not beatmapset:
resp = await fetcher.get_beatmapset(sid)
@@ -227,13 +215,9 @@ class BeatmapsetResp(BeatmapsetBase):
@model_validator(mode="after")
def fix_genre_language(self) -> Self:
if self.genre is None:
self.genre = BeatmapTranslationText(
name=Genre(self.genre_id).name, id=self.genre_id
)
self.genre = BeatmapTranslationText(name=Genre(self.genre_id).name, id=self.genre_id)
if self.language is None:
self.language = BeatmapTranslationText(
name=Language(self.language_id).name, id=self.language_id
)
self.language = BeatmapTranslationText(name=Language(self.language_id).name, id=self.language_id)
return self
@classmethod
@@ -252,9 +236,7 @@ class BeatmapsetResp(BeatmapsetBase):
await BeatmapResp.from_db(beatmap, from_set=True, session=session)
for beatmap in await beatmapset.awaitable_attrs.beatmaps
],
"hype": BeatmapHype(
current=beatmapset.hype_current, required=beatmapset.hype_required
),
"hype": BeatmapHype(current=beatmapset.hype_current, required=beatmapset.hype_required),
"availability": BeatmapAvailability(
more_information=beatmapset.availability_info,
download_disabled=beatmapset.download_disabled,
@@ -282,10 +264,7 @@ class BeatmapsetResp(BeatmapsetBase):
update["ratings"] = []
beatmap_status = beatmapset.beatmap_status
if (
settings.enable_all_beatmap_leaderboard
and not beatmap_status.has_leaderboard()
):
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
update["status"] = BeatmapRankStatus.APPROVED.name.lower()
update["ranked"] = BeatmapRankStatus.APPROVED.value
else:
@@ -295,9 +274,7 @@ class BeatmapsetResp(BeatmapsetBase):
if session and user:
existing_favourite = (
await session.exec(
select(FavouriteBeatmapset).where(
FavouriteBeatmapset.beatmapset_id == beatmapset.id
)
select(FavouriteBeatmapset).where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
)
).first()
update["has_favourited"] = existing_favourite is not None

View File

@@ -20,13 +20,9 @@ if TYPE_CHECKING:
class BestScore(SQLModel, table=True):
__tablename__ = "total_score_best_scores" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
score_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
)
__tablename__: str = "total_score_best_scores"
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
score_id: int = Field(sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True))
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
gamemode: GameMode = Field(index=True)
total_score: int = Field(default=0, sa_column=Column(BigInteger))

View File

@@ -51,30 +51,22 @@ class ChatChannelBase(SQLModel):
class ChatChannel(ChatChannelBase, table=True):
__tablename__ = "chat_channels" # pyright: ignore[reportAssignmentType]
channel_id: int | None = Field(primary_key=True, index=True, default=None)
__tablename__: str = "chat_channels"
channel_id: int = Field(primary_key=True, index=True, default=None)
@classmethod
async def get(
cls, channel: str | int, session: AsyncSession
) -> "ChatChannel | None":
async def get(cls, channel: str | int, session: AsyncSession) -> "ChatChannel | None":
if isinstance(channel, int) or channel.isdigit():
# 使用查询而不是 get() 来确保对象完全加载
result = await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
result = await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))
channel_ = result.first()
if channel_ is not None:
return channel_
result = await session.exec(
select(ChatChannel).where(ChatChannel.name == channel)
)
result = await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
return result.first()
@classmethod
async def get_pm_channel(
cls, user1: int, user2: int, session: AsyncSession
) -> "ChatChannel | None":
async def get_pm_channel(cls, user1: int, user2: int, session: AsyncSession) -> "ChatChannel | None":
channel = await cls.get(f"pm_{user1}_{user2}", session)
if channel is None:
channel = await cls.get(f"pm_{user2}_{user1}", session)
@@ -153,18 +145,13 @@ class ChatChannelResp(ChatChannelBase):
.limit(10)
)
).all()
c.recent_messages = [
await ChatMessageResp.from_db(msg, session, user) for msg in messages
]
c.recent_messages = [await ChatMessageResp.from_db(msg, session, user) for msg in messages]
c.recent_messages.reverse()
if c.type == ChannelType.PM and users and len(users) == 2:
target_user_id = next(u for u in users if u != user.id)
target_name = await session.exec(
select(User.username).where(User.id == target_user_id)
)
target_name = await session.exec(select(User.username).where(User.id == target_user_id))
c.name = target_name.one()
assert user.id
c.users = [target_user_id, user.id]
return c
@@ -181,19 +168,15 @@ class MessageType(str, Enum):
class ChatMessageBase(UTCBaseModel, SQLModel):
channel_id: int = Field(index=True, foreign_key="chat_channels.channel_id")
content: str = Field(sa_column=Column(VARCHAR(1000)))
message_id: int | None = Field(index=True, primary_key=True, default=None)
sender_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
timestamp: datetime = Field(
sa_column=Column(DateTime, index=True), default=datetime.now(UTC)
)
message_id: int = Field(index=True, primary_key=True, default=None)
sender_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
timestamp: datetime = Field(sa_column=Column(DateTime, index=True), default=datetime.now(UTC))
type: MessageType = Field(default=MessageType.PLAIN, index=True, exclude=True)
uuid: str | None = Field(default=None)
class ChatMessage(ChatMessageBase, table=True):
__tablename__ = "chat_messages" # pyright: ignore[reportAssignmentType]
__tablename__: str = "chat_messages"
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
channel: ChatChannel = Relationship()
@@ -211,9 +194,7 @@ class ChatMessageResp(ChatMessageBase):
if user:
m.sender = await UserResp.from_db(user, session, RANKING_INCLUDES)
else:
m.sender = await UserResp.from_db(
db_message.user, session, RANKING_INCLUDES
)
m.sender = await UserResp.from_db(db_message.user, session, RANKING_INCLUDES)
return m
@@ -221,17 +202,13 @@ class ChatMessageResp(ChatMessageBase):
class SilenceUser(UTCBaseModel, SQLModel, table=True):
__tablename__ = "chat_silence_users" # pyright: ignore[reportAssignmentType]
id: int | None = Field(primary_key=True, default=None, index=True)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
__tablename__: str = "chat_silence_users"
id: int = Field(primary_key=True, default=None, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
channel_id: int = Field(foreign_key="chat_channels.channel_id", index=True)
until: datetime | None = Field(sa_column=Column(DateTime, index=True), default=None)
reason: str | None = Field(default=None, sa_column=Column(VARCHAR(255), index=True))
banned_at: datetime = Field(
sa_column=Column(DateTime, index=True), default=datetime.now(UTC)
)
banned_at: datetime = Field(sa_column=Column(DateTime, index=True), default=datetime.now(UTC))
class UserSilenceResp(SQLModel):
@@ -240,7 +217,6 @@ class UserSilenceResp(SQLModel):
@classmethod
def from_db(cls, db_silence: SilenceUser) -> "UserSilenceResp":
assert db_silence.id is not None
return cls(
id=db_silence.id,
user_id=db_silence.user_id,

View File

@@ -21,28 +21,24 @@ class CountBase(SQLModel):
class MonthlyPlaycounts(CountBase, table=True):
__tablename__ = "monthly_playcounts" # pyright: ignore[reportAssignmentType]
__tablename__: str = "monthly_playcounts"
id: int | None = Field(
default=None,
sa_column=Column(BigInteger, primary_key=True, autoincrement=True),
)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
user: "User" = Relationship(back_populates="monthly_playcounts")
class ReplayWatchedCount(CountBase, table=True):
__tablename__ = "replays_watched_counts" # pyright: ignore[reportAssignmentType]
__tablename__: str = "replays_watched_counts"
id: int | None = Field(
default=None,
sa_column=Column(BigInteger, primary_key=True, autoincrement=True),
)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
user: "User" = Relationship(back_populates="replays_watched_counts")

View File

@@ -24,9 +24,7 @@ class DailyChallengeStatsBase(SQLModel, UTCBaseModel):
daily_streak_best: int = Field(default=0)
daily_streak_current: int = Field(default=0)
last_update: datetime | None = Field(default=None, sa_column=Column(DateTime))
last_weekly_streak: datetime | None = Field(
default=None, sa_column=Column(DateTime)
)
last_weekly_streak: datetime | None = Field(default=None, sa_column=Column(DateTime))
playcount: int = Field(default=0)
top_10p_placements: int = Field(default=0)
top_50p_placements: int = Field(default=0)
@@ -35,7 +33,7 @@ class DailyChallengeStatsBase(SQLModel, UTCBaseModel):
class DailyChallengeStats(DailyChallengeStatsBase, table=True):
__tablename__ = "daily_challenge_stats" # pyright: ignore[reportAssignmentType]
__tablename__: str = "daily_challenge_stats"
user_id: int | None = Field(
default=None,
@@ -61,9 +59,7 @@ class DailyChallengeStatsResp(DailyChallengeStatsBase):
return cls.model_validate(obj)
async def process_daily_challenge_score(
session: AsyncSession, user_id: int, room_id: int
):
async def process_daily_challenge_score(session: AsyncSession, user_id: int, room_id: int):
from .playlist_best_score import PlaylistBestScore
score = (

View File

@@ -4,16 +4,17 @@
from __future__ import annotations
from datetime import datetime, UTC
from sqlmodel import SQLModel, Field
from sqlalchemy import Column, BigInteger, ForeignKey
from datetime import UTC, datetime
from sqlalchemy import BigInteger, Column, ForeignKey
from sqlmodel import Field, SQLModel
class EmailVerification(SQLModel, table=True):
"""邮件验证记录"""
__tablename__: str = "email_verifications"
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True))
email: str = Field(index=True)
@@ -28,9 +29,9 @@ class EmailVerification(SQLModel, table=True):
class LoginSession(SQLModel, table=True):
"""登录会话记录"""
__tablename__: str = "login_sessions"
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True))
session_token: str = Field(unique=True, index=True) # 会话令牌

View File

@@ -36,17 +36,13 @@ class EventType(str, Enum):
class EventBase(SQLModel):
id: int = Field(default=None, primary_key=True)
created_at: datetime = Field(
sa_column=Column(DateTime(timezone=True), default=datetime.now(UTC))
)
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), default=datetime.now(UTC)))
type: EventType
event_payload: dict = Field(
exclude=True, default_factory=dict, sa_column=Column(JSON)
)
event_payload: dict = Field(exclude=True, default_factory=dict, sa_column=Column(JSON))
class Event(EventBase, table=True):
__tablename__ = "user_events" # pyright: ignore[reportAssignmentType]
__tablename__: str = "user_events"
user_id: int | None = Field(
default=None,
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True),

View File

@@ -16,8 +16,8 @@ FAILTIME_STRUCT = Struct("<100i")
class FailTime(SQLModel, table=True):
__tablename__ = "failtime" # pyright: ignore[reportAssignmentType]
beatmap_id: int = Field(primary_key=True, index=True, foreign_key="beatmaps.id")
__tablename__: str = "failtime"
beatmap_id: int = Field(primary_key=True, foreign_key="beatmaps.id")
exit: bytes = Field(sa_column=Column(VARBINARY(400), nullable=False))
fail: bytes = Field(sa_column=Column(VARBINARY(400), nullable=False))
@@ -41,12 +41,8 @@ class FailTime(SQLModel, table=True):
class FailTimeResp(BaseModel):
exit: list[int] = Field(
default_factory=lambda: list(FAILTIME_STRUCT.unpack(b"\x00" * 400))
)
fail: list[int] = Field(
default_factory=lambda: list(FAILTIME_STRUCT.unpack(b"\x00" * 400))
)
exit: list[int] = Field(default_factory=lambda: list(FAILTIME_STRUCT.unpack(b"\x00" * 400)))
fail: list[int] = Field(default_factory=lambda: list(FAILTIME_STRUCT.unpack(b"\x00" * 400)))
@classmethod
def from_db(cls, failtime: FailTime) -> "FailTimeResp":

View File

@@ -16,7 +16,7 @@ from sqlmodel import (
class FavouriteBeatmapset(AsyncAttrs, SQLModel, table=True):
__tablename__ = "favourite_beatmapset" # pyright: ignore[reportAssignmentType]
__tablename__: str = "favourite_beatmapset"
id: int | None = Field(
default=None,
sa_column=Column(BigInteger, autoincrement=True, primary_key=True),

View File

@@ -75,9 +75,7 @@ class UserBase(UTCBaseModel, SQLModel):
is_active: bool = True
is_bot: bool = False
is_supporter: bool = False
last_visit: datetime | None = Field(
default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True))
)
last_visit: datetime | None = Field(default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True)))
pm_friends_only: bool = False
profile_colour: str | None = None
username: str = Field(max_length=32, unique=True, index=True)
@@ -90,9 +88,7 @@ class UserBase(UTCBaseModel, SQLModel):
is_restricted: bool = False
# blocks
cover: UserProfileCover = Field(
default=UserProfileCover(
url="https://assets.ppy.sh/user-profile-covers/default.jpeg"
),
default=UserProfileCover(url="https://assets.ppy.sh/user-profile-covers/default.jpeg"),
sa_column=Column(JSON),
)
beatmap_playcounts_count: int = 0
@@ -150,9 +146,9 @@ class UserBase(UTCBaseModel, SQLModel):
class User(AsyncAttrs, UserBase, table=True):
__tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType]
__tablename__: str = "lazer_users"
id: int | None = Field(
id: int = Field(
default=None,
sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True),
)
@@ -160,16 +156,10 @@ class User(AsyncAttrs, UserBase, table=True):
statistics: list[UserStatistics] = Relationship()
achievement: list[UserAchievement] = Relationship(back_populates="user")
team_membership: TeamMember | None = Relationship(back_populates="user")
daily_challenge_stats: DailyChallengeStats | None = Relationship(
back_populates="user"
)
daily_challenge_stats: DailyChallengeStats | None = Relationship(back_populates="user")
monthly_playcounts: list[MonthlyPlaycounts] = Relationship(back_populates="user")
replays_watched_counts: list[ReplayWatchedCount] = Relationship(
back_populates="user"
)
favourite_beatmapsets: list["FavouriteBeatmapset"] = Relationship(
back_populates="user"
)
replays_watched_counts: list[ReplayWatchedCount] = Relationship(back_populates="user")
favourite_beatmapsets: list["FavouriteBeatmapset"] = Relationship(back_populates="user")
rank_history: list[RankHistory] = Relationship(
back_populates="user",
)
@@ -178,16 +168,10 @@ class User(AsyncAttrs, UserBase, table=True):
email: str = Field(max_length=254, unique=True, index=True, exclude=True)
priv: int = Field(default=1, exclude=True)
pw_bcrypt: str = Field(max_length=60, exclude=True)
silence_end_at: datetime | None = Field(
default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
)
donor_end_at: datetime | None = Field(
default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
)
silence_end_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True)), exclude=True)
donor_end_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True)), exclude=True)
async def is_user_can_pm(
self, from_user: "User", session: AsyncSession
) -> tuple[bool, str]:
async def is_user_can_pm(self, from_user: "User", session: AsyncSession) -> tuple[bool, str]:
from .relationship import Relationship, RelationshipType
from_relationship = (
@@ -200,13 +184,10 @@ class User(AsyncAttrs, UserBase, table=True):
).first()
if from_relationship and from_relationship.type == RelationshipType.BLOCK:
return False, "You have blocked the target user."
if from_user.pm_friends_only and (
not from_relationship or from_relationship.type != RelationshipType.FOLLOW
):
if from_user.pm_friends_only and (not from_relationship or from_relationship.type != RelationshipType.FOLLOW):
return (
False,
"You have disabled non-friend communications "
"and target user is not your friend.",
"You have disabled non-friend communications and target user is not your friend.",
)
relationship = (
@@ -219,9 +200,7 @@ class User(AsyncAttrs, UserBase, table=True):
).first()
if relationship and relationship.type == RelationshipType.BLOCK:
return False, "Target user has blocked you."
if self.pm_friends_only and (
not relationship or relationship.type != RelationshipType.FOLLOW
):
if self.pm_friends_only and (not relationship or relationship.type != RelationshipType.FOLLOW):
return False, "Target user has disabled non-friend communications"
return True, ""
@@ -288,9 +267,7 @@ class UserResp(UserBase):
u = cls.model_validate(obj.model_dump())
u.id = obj.id
u.default_group = "bot" if u.is_bot else "default"
u.country = Country(
code=obj.country_code, name=COUNTRIES.get(obj.country_code, "Unknown")
)
u.country = Country(code=obj.country_code, name=COUNTRIES.get(obj.country_code, "Unknown"))
u.follower_count = (
await session.exec(
select(func.count())
@@ -314,9 +291,7 @@ class UserResp(UserBase):
redis = get_redis()
u.is_online = await redis.exists(f"metadata:online:{obj.id}")
u.cover_url = (
obj.cover.get(
"url", "https://assets.ppy.sh/user-profile-covers/default.jpeg"
)
obj.cover.get("url", "https://assets.ppy.sh/user-profile-covers/default.jpeg")
if obj.cover
else "https://assets.ppy.sh/user-profile-covers/default.jpeg"
)
@@ -335,22 +310,15 @@ class UserResp(UserBase):
]
if "team" in include:
if await obj.awaitable_attrs.team_membership:
assert obj.team_membership
u.team = obj.team_membership.team
if team_membership := await obj.awaitable_attrs.team_membership:
u.team = team_membership.team
if "account_history" in include:
u.account_history = [
UserAccountHistoryResp.from_db(ah)
for ah in await obj.awaitable_attrs.account_history
]
u.account_history = [UserAccountHistoryResp.from_db(ah) for ah in await obj.awaitable_attrs.account_history]
if "daily_challenge_user_stats":
if await obj.awaitable_attrs.daily_challenge_stats:
assert obj.daily_challenge_stats
u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(
obj.daily_challenge_stats
)
if daily_challenge_stats := await obj.awaitable_attrs.daily_challenge_stats:
u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(daily_challenge_stats)
if "statistics" in include:
current_stattistics = None
@@ -359,59 +327,40 @@ class UserResp(UserBase):
current_stattistics = i
break
u.statistics = (
await UserStatisticsResp.from_db(
current_stattistics, session, obj.country_code
)
await UserStatisticsResp.from_db(current_stattistics, session, obj.country_code)
if current_stattistics
else None
)
if "statistics_rulesets" in include:
u.statistics_rulesets = {
i.mode.value: await UserStatisticsResp.from_db(
i, session, obj.country_code
)
i.mode.value: await UserStatisticsResp.from_db(i, session, obj.country_code)
for i in await obj.awaitable_attrs.statistics
}
if "monthly_playcounts" in include:
u.monthly_playcounts = [
CountResp.from_db(pc)
for pc in await obj.awaitable_attrs.monthly_playcounts
]
u.monthly_playcounts = [CountResp.from_db(pc) for pc in await obj.awaitable_attrs.monthly_playcounts]
if len(u.monthly_playcounts) == 1:
d = u.monthly_playcounts[0].start_date
u.monthly_playcounts.insert(
0, CountResp(start_date=d - timedelta(days=20), count=0)
)
u.monthly_playcounts.insert(0, CountResp(start_date=d - timedelta(days=20), count=0))
if "replays_watched_counts" in include:
u.replay_watched_counts = [
CountResp.from_db(rwc)
for rwc in await obj.awaitable_attrs.replays_watched_counts
CountResp.from_db(rwc) for rwc in await obj.awaitable_attrs.replays_watched_counts
]
if len(u.replay_watched_counts) == 1:
d = u.replay_watched_counts[0].start_date
u.replay_watched_counts.insert(
0, CountResp(start_date=d - timedelta(days=20), count=0)
)
u.replay_watched_counts.insert(0, CountResp(start_date=d - timedelta(days=20), count=0))
if "achievements" in include:
u.user_achievements = [
UserAchievementResp.from_db(ua)
for ua in await obj.awaitable_attrs.achievement
]
u.user_achievements = [UserAchievementResp.from_db(ua) for ua in await obj.awaitable_attrs.achievement]
if "rank_history" in include:
rank_history = await RankHistoryResp.from_db(session, obj.id, ruleset)
if len(rank_history.data) != 0:
u.rank_history = rank_history
rank_top = (
await session.exec(
select(RankTop).where(
RankTop.user_id == obj.id, RankTop.mode == ruleset
)
)
await session.exec(select(RankTop).where(RankTop.user_id == obj.id, RankTop.mode == ruleset))
).first()
if rank_top:
u.rank_highest = (
@@ -425,9 +374,7 @@ class UserResp(UserBase):
u.favourite_beatmapset_count = (
await session.exec(
select(func.count())
.select_from(FavouriteBeatmapset)
.where(FavouriteBeatmapset.user_id == obj.id)
select(func.count()).select_from(FavouriteBeatmapset).where(FavouriteBeatmapset.user_id == obj.id)
)
).one()
u.scores_pinned_count = (
@@ -478,17 +425,19 @@ class UserResp(UserBase):
# 检查会话验证状态
# 如果邮件验证功能被禁用,则始终设置 session_verified 为 true
from app.config import settings
if not settings.enable_email_verification:
u.session_verified = True
else:
# 如果用户有未验证的登录会话,则设置 session_verified 为 false
from .email_verification import LoginSession
unverified_session = (
await session.exec(
select(LoginSession).where(
LoginSession.user_id == obj.id,
LoginSession.is_verified == False,
LoginSession.expires_at > datetime.now(UTC)
col(LoginSession.is_verified).is_(False),
LoginSession.expires_at > datetime.now(UTC),
)
)
).first()

View File

@@ -30,8 +30,8 @@ class MultiplayerEventBase(SQLModel, UTCBaseModel):
class MultiplayerEvent(MultiplayerEventBase, table=True):
__tablename__ = "multiplayer_events" # pyright: ignore[reportAssignmentType]
id: int | None = Field(
__tablename__: str = "multiplayer_events"
id: int = Field(
default=None,
sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True),
)

View File

@@ -17,7 +17,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
class Notification(SQLModel, table=True):
__tablename__ = "notifications" # pyright: ignore[reportAssignmentType]
__tablename__: str = "notifications"
id: int = Field(primary_key=True, index=True, default=None)
name: NotificationName = Field(index=True)
@@ -30,7 +30,7 @@ class Notification(SQLModel, table=True):
class UserNotification(SQLModel, table=True):
__tablename__ = "user_notifications" # pyright: ignore[reportAssignmentType]
__tablename__: str = "user_notifications"
id: int = Field(
sa_column=Column(
BigInteger,
@@ -40,9 +40,7 @@ class UserNotification(SQLModel, table=True):
default=None,
)
notification_id: int = Field(index=True, foreign_key="notifications.id")
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
is_read: bool = Field(index=True)
notification: Notification = Relationship(sa_relationship_kwargs={"lazy": "joined"})

View File

@@ -4,16 +4,17 @@
from __future__ import annotations
from datetime import datetime, UTC
from sqlmodel import SQLModel, Field
from sqlalchemy import Column, BigInteger, ForeignKey
from datetime import UTC, datetime
from sqlalchemy import BigInteger, Column, ForeignKey
from sqlmodel import Field, SQLModel
class PasswordReset(SQLModel, table=True):
"""密码重置记录"""
__tablename__: str = "password_resets"
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True))
email: str = Field(index=True)

View File

@@ -21,16 +21,14 @@ class ItemAttemptsCountBase(SQLModel):
room_id: int = Field(foreign_key="rooms.id", index=True)
attempts: int = Field(default=0)
completed: int = Field(default=0)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
accuracy: float = 0.0
pp: float = 0
total_score: int = 0
class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True):
__tablename__ = "item_attempts_count" # pyright: ignore[reportAssignmentType]
__tablename__: str = "item_attempts_count"
id: int | None = Field(default=None, primary_key=True)
user: User = Relationship()
@@ -63,9 +61,7 @@ class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True):
self.pp = sum(score.score.pp for score in playlist_scores)
self.completed = len([score for score in playlist_scores if score.score.passed])
self.accuracy = (
sum(score.score.accuracy for score in playlist_scores) / self.completed
if self.completed > 0
else 0.0
sum(score.score.accuracy for score in playlist_scores) / self.completed if self.completed > 0 else 0.0
)
await session.commit()
await session.refresh(self)

View File

@@ -21,14 +21,10 @@ if TYPE_CHECKING:
class PlaylistBestScore(SQLModel, table=True):
__tablename__ = "playlist_best_scores" # pyright: ignore[reportAssignmentType]
__tablename__: str = "playlist_best_scores"
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
score_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
score_id: int = Field(sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True))
room_id: int = Field(foreign_key="rooms.id", index=True)
playlist_id: int = Field(index=True)
total_score: int = Field(default=0, sa_column=Column(BigInteger))

View File

@@ -50,7 +50,7 @@ class PlaylistBase(SQLModel, UTCBaseModel):
class Playlist(PlaylistBase, table=True):
__tablename__ = "room_playlists" # pyright: ignore[reportAssignmentType]
__tablename__: str = "room_playlists"
db_id: int = Field(default=None, primary_key=True, index=True, exclude=True)
room_id: int = Field(foreign_key="rooms.id", exclude=True)
@@ -63,16 +63,12 @@ class Playlist(PlaylistBase, table=True):
@classmethod
async def get_next_id_for_room(cls, room_id: int, session: AsyncSession) -> int:
stmt = select(func.coalesce(func.max(cls.id), -1) + 1).where(
cls.room_id == room_id
)
stmt = select(func.coalesce(func.max(cls.id), -1) + 1).where(cls.room_id == room_id)
result = await session.exec(stmt)
return result.one()
@classmethod
async def from_hub(
cls, playlist: PlaylistItem, room_id: int, session: AsyncSession
) -> "Playlist":
async def from_hub(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession) -> "Playlist":
next_id = await cls.get_next_id_for_room(room_id, session=session)
return cls(
id=next_id,
@@ -90,9 +86,7 @@ class Playlist(PlaylistBase, table=True):
@classmethod
async def update(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession):
db_playlist = await session.exec(
select(cls).where(cls.id == playlist.id, cls.room_id == room_id)
)
db_playlist = await session.exec(select(cls).where(cls.id == playlist.id, cls.room_id == room_id))
db_playlist = db_playlist.first()
if db_playlist is None:
raise ValueError("Playlist item not found")
@@ -108,9 +102,7 @@ class Playlist(PlaylistBase, table=True):
await session.commit()
@classmethod
async def add_to_db(
cls, playlist: PlaylistItem, room_id: int, session: AsyncSession
):
async def add_to_db(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession):
db_playlist = await cls.from_hub(playlist, room_id, session)
session.add(db_playlist)
await session.commit()
@@ -119,9 +111,7 @@ class Playlist(PlaylistBase, table=True):
@classmethod
async def delete_item(cls, item_id: int, room_id: int, session: AsyncSession):
db_playlist = await session.exec(
select(cls).where(cls.id == item_id, cls.room_id == room_id)
)
db_playlist = await session.exec(select(cls).where(cls.id == item_id, cls.room_id == room_id))
db_playlist = db_playlist.first()
if db_playlist is None:
raise ValueError("Playlist item not found")
@@ -133,9 +123,7 @@ class PlaylistResp(PlaylistBase):
beatmap: BeatmapResp | None = None
@classmethod
async def from_db(
cls, playlist: Playlist, include: list[str] = []
) -> "PlaylistResp":
async def from_db(cls, playlist: Playlist, include: list[str] = []) -> "PlaylistResp":
data = playlist.model_dump()
if "beatmap" in include:
data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap)

View File

@@ -20,13 +20,9 @@ if TYPE_CHECKING:
class PPBestScore(SQLModel, table=True):
__tablename__ = "best_scores" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
score_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
)
__tablename__: str = "best_scores"
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
score_id: int = Field(sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True))
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
gamemode: GameMode = Field(index=True)
pp: float = Field(

View File

@@ -26,12 +26,10 @@ if TYPE_CHECKING:
class RankHistory(SQLModel, table=True):
__tablename__ = "rank_history" # pyright: ignore[reportAssignmentType]
__tablename__: str = "rank_history"
id: int | None = Field(default=None, sa_column=Column(BigInteger, primary_key=True))
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
mode: GameMode
rank: int
date: dt = Field(
@@ -43,12 +41,10 @@ class RankHistory(SQLModel, table=True):
class RankTop(SQLModel, table=True):
__tablename__ = "rank_top" # pyright: ignore[reportAssignmentType]
__tablename__: str = "rank_top"
id: int | None = Field(default=None, sa_column=Column(BigInteger, primary_key=True))
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
mode: GameMode
rank: int
date: dt = Field(
@@ -62,9 +58,7 @@ class RankHistoryResp(BaseModel):
data: list[int]
@classmethod
async def from_db(
cls, session: AsyncSession, user_id: int, mode: GameMode
) -> "RankHistoryResp":
async def from_db(cls, session: AsyncSession, user_id: int, mode: GameMode) -> "RankHistoryResp":
results = (
await session.exec(
select(RankHistory)

View File

@@ -21,7 +21,7 @@ class RelationshipType(str, Enum):
class Relationship(SQLModel, table=True):
__tablename__ = "relationship" # pyright: ignore[reportAssignmentType]
__tablename__: str = "relationship"
id: int | None = Field(
default=None,
sa_column=Column(BigInteger, autoincrement=True, primary_key=True),
@@ -59,9 +59,7 @@ class RelationshipResp(BaseModel):
type: RelationshipType
@classmethod
async def from_db(
cls, session: AsyncSession, relationship: Relationship
) -> "RelationshipResp":
async def from_db(cls, session: AsyncSession, relationship: Relationship) -> "RelationshipResp":
target_relationship = (
await session.exec(
select(Relationship).where(

View File

@@ -58,11 +58,9 @@ class RoomBase(SQLModel, UTCBaseModel):
class Room(AsyncAttrs, RoomBase, table=True):
__tablename__ = "rooms" # pyright: ignore[reportAssignmentType]
__tablename__: str = "rooms"
id: int = Field(default=None, primary_key=True, index=True)
host_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
host_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
host: User = Relationship()
playlist: list[Playlist] = Relationship(
@@ -109,12 +107,8 @@ class RoomResp(RoomBase):
if not playlist.expired:
stats.count_active += 1
rulesets.add(playlist.ruleset_id)
difficulty_range.min = min(
difficulty_range.min, playlist.beatmap.difficulty_rating
)
difficulty_range.max = max(
difficulty_range.max, playlist.beatmap.difficulty_rating
)
difficulty_range.min = min(difficulty_range.min, playlist.beatmap.difficulty_rating)
difficulty_range.max = max(difficulty_range.max, playlist.beatmap.difficulty_rating)
resp.playlist.append(await PlaylistResp.from_db(playlist, ["beatmap"]))
stats.ruleset_ids = list(rulesets)
resp.playlist_item_stats = stats
@@ -137,13 +131,9 @@ class RoomResp(RoomBase):
include=["statistics"],
)
)
resp.host = await UserResp.from_db(
await room.awaitable_attrs.host, session, include=["statistics"]
)
resp.host = await UserResp.from_db(await room.awaitable_attrs.host, session, include=["statistics"])
if "current_user_score" in include and user:
resp.current_user_score = await PlaylistAggregateScore.from_db(
room.id, user.id, session
)
resp.current_user_score = await PlaylistAggregateScore.from_db(room.id, user.id, session)
return resp
@classmethod

View File

@@ -18,22 +18,16 @@ if TYPE_CHECKING:
class RoomParticipatedUser(AsyncAttrs, SQLModel, table=True):
__tablename__ = "room_participated_users" # pyright: ignore[reportAssignmentType]
__tablename__: str = "room_participated_users"
id: int | None = Field(
default=None, sa_column=Column(BigInteger, primary_key=True, autoincrement=True)
)
id: int | None = Field(default=None, sa_column=Column(BigInteger, primary_key=True, autoincrement=True))
room_id: int = Field(sa_column=Column(ForeignKey("rooms.id"), nullable=False))
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False)
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False))
joined_at: datetime = Field(
sa_column=Column(DateTime(timezone=True), nullable=False),
default=datetime.now(UTC),
)
left_at: datetime | None = Field(
sa_column=Column(DateTime(timezone=True), nullable=True), default=None
)
left_at: datetime | None = Field(sa_column=Column(DateTime(timezone=True), nullable=True), default=None)
room: "Room" = Relationship()
user: "User" = Relationship()

View File

@@ -47,9 +47,9 @@ from .score_token import ScoreToken
from pydantic import field_serializer, field_validator
from redis.asyncio import Redis
from sqlalchemy import Boolean, Column, ColumnExpressionArgument, DateTime
from sqlalchemy import Boolean, Column, ColumnExpressionArgument, DateTime, TextClause
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import aliased
from sqlalchemy.orm import Mapped, aliased
from sqlalchemy.sql.elements import ColumnElement
from sqlmodel import (
JSON,
@@ -76,9 +76,7 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
accuracy: float
map_md5: str = Field(max_length=32, index=True)
build_id: int | None = Field(default=None)
classic_total_score: int | None = Field(
default=0, sa_column=Column(BigInteger)
) # solo_score
classic_total_score: int | None = Field(default=0, sa_column=Column(BigInteger)) # solo_score
ended_at: datetime = Field(sa_column=Column(DateTime))
has_replay: bool = Field(sa_column=Column(Boolean))
max_combo: int
@@ -91,14 +89,10 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
room_id: int | None = Field(default=None) # multiplayer
started_at: datetime = Field(sa_column=Column(DateTime))
total_score: int = Field(default=0, sa_column=Column(BigInteger))
total_score_without_mods: int = Field(
default=0, sa_column=Column(BigInteger), exclude=True
)
total_score_without_mods: int = Field(default=0, sa_column=Column(BigInteger), exclude=True)
type: str
beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
maximum_statistics: ScoreStatistics = Field(
sa_column=Column(JSON), default_factory=dict
)
maximum_statistics: ScoreStatistics = Field(sa_column=Column(JSON), default_factory=dict)
@field_validator("maximum_statistics", mode="before")
@classmethod
@@ -147,10 +141,8 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
class Score(ScoreBase, table=True):
__tablename__ = "scores" # pyright: ignore[reportAssignmentType]
id: int | None = Field(
default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True)
)
__tablename__: str = "scores"
id: int = Field(default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True))
user_id: int = Field(
default=None,
sa_column=Column(
@@ -193,8 +185,8 @@ class Score(ScoreBase, table=True):
return str(v)
# optional
beatmap: Beatmap = Relationship()
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
beatmap: Mapped[Beatmap] = Relationship()
user: Mapped[User] = Relationship(sa_relationship_kwargs={"lazy": "joined"})
@property
def is_perfect_combo(self) -> bool:
@@ -205,11 +197,7 @@ class Score(ScoreBase, table=True):
*where_clauses: ColumnExpressionArgument[bool] | bool,
) -> SelectOfScalar["Score"]:
rownum = (
func.row_number()
.over(
partition_by=col(Score.user_id), order_by=col(Score.total_score).desc()
)
.label("rn")
func.row_number().over(partition_by=col(Score.user_id), order_by=col(Score.total_score).desc()).label("rn")
)
subq = select(Score, rownum).where(*where_clauses).subquery()
best = aliased(Score, subq, adapt_on_names=True)
@@ -296,12 +284,9 @@ class ScoreResp(ScoreBase):
await session.refresh(score)
s = cls.model_validate(score.model_dump())
assert score.id
await score.awaitable_attrs.beatmap
s.beatmap = await BeatmapResp.from_db(score.beatmap)
s.beatmapset = await BeatmapsetResp.from_db(
score.beatmap.beatmapset, session=session, user=score.user
)
s.beatmapset = await BeatmapsetResp.from_db(score.beatmap.beatmapset, session=session, user=score.user)
s.is_perfect_combo = s.max_combo == s.beatmap.max_combo
s.legacy_perfect = s.max_combo == s.beatmap.max_combo
s.ruleset_id = int(score.gamemode)
@@ -371,11 +356,7 @@ class ScoreAround(SQLModel):
async def get_best_id(session: AsyncSession, score_id: int) -> None:
rownum = (
func.row_number()
.over(
partition_by=col(PPBestScore.user_id), order_by=col(PPBestScore.pp).desc()
)
.label("rn")
func.row_number().over(partition_by=col(PPBestScore.user_id), order_by=col(PPBestScore.pp).desc()).label("rn")
)
subq = select(PPBestScore, rownum).subquery()
stmt = select(subq.c.rn).where(subq.c.score_id == score_id)
@@ -389,8 +370,8 @@ async def _score_where(
mode: GameMode,
mods: list[str] | None = None,
user: User | None = None,
) -> list[ColumnElement[bool]] | None:
wheres = [
) -> list[ColumnElement[bool] | TextClause] | None:
wheres: list[ColumnElement[bool] | TextClause] = [
col(BestScore.beatmap_id) == beatmap,
col(BestScore.gamemode) == mode,
]
@@ -410,9 +391,7 @@ async def _score_where(
return None
elif type == LeaderboardType.COUNTRY:
if user and user.is_supporter:
wheres.append(
col(BestScore.user).has(col(User.country_code) == user.country_code)
)
wheres.append(col(BestScore.user).has(col(User.country_code) == user.country_code))
else:
return None
elif type == LeaderboardType.TEAM:
@@ -420,18 +399,14 @@ async def _score_where(
team_membership = await user.awaitable_attrs.team_membership
if team_membership:
team_id = team_membership.team_id
wheres.append(
col(BestScore.user).has(
col(User.team_membership).has(TeamMember.team_id == team_id)
)
)
wheres.append(col(BestScore.user).has(col(User.team_membership).has(TeamMember.team_id == team_id)))
if mods:
if user and user.is_supporter:
wheres.append(
text(
"JSON_CONTAINS(total_score_best_scores.mods, :w)"
" AND JSON_CONTAINS(:w, total_score_best_scores.mods)"
).params(w=json.dumps(mods)) # pyright: ignore[reportArgumentType]
).params(w=json.dumps(mods))
)
else:
return None
@@ -654,18 +629,14 @@ def calculate_playtime(score: Score, beatmap_length: int) -> tuple[int, bool]:
+ (score.nsmall_tick_hit or 0)
)
total_obj = 0
for statistics, count in (
score.maximum_statistics.items() if score.maximum_statistics else {}
):
for statistics, count in score.maximum_statistics.items() if score.maximum_statistics else {}:
if not isinstance(statistics, HitResult):
statistics = HitResult(statistics)
if statistics.is_scorable():
total_obj += count
return total_length, score.passed or (
total_length > 8
and score.total_score >= 5000
and total_obj_hited >= min(0.1 * total_obj, 20)
total_length > 8 and score.total_score >= 5000 and total_obj_hited >= min(0.1 * total_obj, 20)
)
@@ -678,12 +649,8 @@ async def process_user(
ranked: bool = False,
has_leaderboard: bool = False,
):
assert user.id
assert score.id
mod_for_save = mod_to_save(score.mods)
previous_score_best = await get_user_best_score_in_beatmap(
session, score.beatmap_id, user.id, score.gamemode
)
previous_score_best = await get_user_best_score_in_beatmap(session, score.beatmap_id, user.id, score.gamemode)
previous_score_best_mod = await get_user_best_score_with_mod_in_beatmap(
session, score.beatmap_id, user.id, mod_for_save, score.gamemode
)
@@ -698,9 +665,7 @@ async def process_user(
)
).first()
if mouthly_playcount is None:
mouthly_playcount = MonthlyPlaycounts(
user_id=user.id, year=date.today().year, month=date.today().month
)
mouthly_playcount = MonthlyPlaycounts(user_id=user.id, year=date.today().year, month=date.today().month)
add_to_db = True
statistics = None
for i in await user.awaitable_attrs.statistics:
@@ -708,17 +673,11 @@ async def process_user(
statistics = i
break
if statistics is None:
raise ValueError(
f"User {user.id} does not have statistics for mode {score.gamemode.value}"
)
raise ValueError(f"User {user.id} does not have statistics for mode {score.gamemode.value}")
# pc, pt, tth, tts
statistics.total_score += score.total_score
difference = (
score.total_score - previous_score_best.total_score
if previous_score_best
else score.total_score
)
difference = score.total_score - previous_score_best.total_score if previous_score_best else score.total_score
if difference > 0 and score.passed and ranked:
match score.rank:
case Rank.X:
@@ -746,11 +705,8 @@ async def process_user(
statistics.ranked_score += difference
statistics.level_current = calculate_score_to_level(statistics.total_score)
statistics.maximum_combo = max(statistics.maximum_combo, score.max_combo)
new_score_position = await get_score_position_by_user(
session, score.beatmap_id, user, score.gamemode
)
new_score_position = await get_score_position_by_user(session, score.beatmap_id, user, score.gamemode)
total_users = await session.exec(select(func.count()).select_from(User))
assert total_users is not None
score_range = min(50, math.ceil(float(total_users.one()) * 0.01))
if new_score_position <= score_range and new_score_position > 0:
# Get the scores that might be displaced
@@ -774,11 +730,7 @@ async def process_user(
)
# If this score was previously in top positions but now pushed out
if (
i < score_range
and displaced_position > score_range
and displaced_position is not None
):
if i < score_range and displaced_position > score_range and displaced_position is not None:
# Create rank lost event for the displaced user
rank_lost_event = Event(
created_at=datetime.now(UTC),
@@ -814,10 +766,7 @@ async def process_user(
)
# 情况3: 有最佳分数记录和该mod组合的记录且是同一个记录更新得分更高的情况
elif (
previous_score_best.score_id == previous_score_best_mod.score_id
and difference > 0
):
elif previous_score_best.score_id == previous_score_best_mod.score_id and difference > 0:
previous_score_best.total_score = score.total_score
previous_score_best.rank = score.rank
previous_score_best.score_id = score.id
@@ -847,9 +796,7 @@ async def process_user(
statistics.count_300 += score.n300 + score.ngeki
statistics.count_50 += score.n50
statistics.count_miss += score.nmiss
statistics.total_hits += (
score.n300 + score.n100 + score.n50 + score.ngeki + score.nkatu
)
statistics.total_hits += score.n300 + score.n100 + score.n50 + score.ngeki + score.nkatu
if score.passed and ranked:
with session.no_autoflush:
@@ -885,7 +832,6 @@ async def process_score(
item_id: int | None = None,
room_id: int | None = None,
) -> Score:
assert user.id
can_get_pp = info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods)
gamemode = GameMode.from_int(info.ruleset_id).to_special_mode(info.mods)
score = Score(
@@ -922,20 +868,15 @@ async def process_score(
if can_get_pp:
from app.calculator import pre_fetch_and_calculate_pp
pp = await pre_fetch_and_calculate_pp(
score, beatmap_id, session, redis, fetcher
)
pp = await pre_fetch_and_calculate_pp(score, beatmap_id, session, redis, fetcher)
score.pp = pp
session.add(score)
user_id = user.id
await session.commit()
await session.refresh(score)
if can_get_pp and score.pp != 0:
previous_pp_best = await get_user_best_pp_in_beatmap(
session, beatmap_id, user_id, score.gamemode
)
previous_pp_best = await get_user_best_pp_in_beatmap(session, beatmap_id, user_id, score.gamemode)
if previous_pp_best is None or score.pp > previous_pp_best.pp:
assert score.id
best_score = PPBestScore(
user_id=user_id,
score_id=score.id,

View File

@@ -7,6 +7,7 @@ from .beatmap import Beatmap
from .lazer_user import User
from sqlalchemy import Column, DateTime, Index
from sqlalchemy.orm import Mapped
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
@@ -14,16 +15,12 @@ class ScoreTokenBase(SQLModel, UTCBaseModel):
score_id: int | None = Field(sa_column=Column(BigInteger), default=None)
ruleset_id: GameMode
playlist_item_id: int | None = Field(default=None) # playlist
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
created_at: datetime = Field(default_factory=datetime.utcnow, sa_column=Column(DateTime))
updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column=Column(DateTime))
class ScoreToken(ScoreTokenBase, table=True):
__tablename__ = "score_tokens" # pyright: ignore[reportAssignmentType]
__tablename__: str = "score_tokens"
__table_args__ = (Index("idx_user_playlist", "user_id", "playlist_item_id"),)
id: int | None = Field(
@@ -37,8 +34,8 @@ class ScoreToken(ScoreTokenBase, table=True):
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
beatmap_id: int = Field(foreign_key="beatmaps.id")
user: User = Relationship()
beatmap: Beatmap = Relationship()
user: Mapped[User] = Relationship()
beatmap: Mapped[Beatmap] = Relationship()
class ScoreTokenResp(ScoreTokenBase):

View File

@@ -58,7 +58,7 @@ class UserStatisticsBase(SQLModel):
class UserStatistics(AsyncAttrs, UserStatisticsBase, table=True):
__tablename__ = "lazer_user_statistics" # pyright: ignore[reportAssignmentType]
__tablename__: str = "lazer_user_statistics"
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(
default=None,
@@ -123,9 +123,7 @@ class UserStatisticsResp(UserStatisticsBase):
if "user" in include:
from .lazer_user import RANKING_INCLUDES, UserResp
user = await UserResp.from_db(
await obj.awaitable_attrs.user, session, include=RANKING_INCLUDES
)
user = await UserResp.from_db(await obj.awaitable_attrs.user, session, include=RANKING_INCLUDES)
s.user = user
user_country = user.country_code
@@ -149,9 +147,7 @@ class UserStatisticsResp(UserStatisticsBase):
return s
async def get_rank(
session: AsyncSession, statistics: UserStatistics, country: str | None = None
) -> int | None:
async def get_rank(session: AsyncSession, statistics: UserStatistics, country: str | None = None) -> int | None:
from .lazer_user import User
query = select(
@@ -168,9 +164,7 @@ async def get_rank(
subq = query.subquery()
result = await session.exec(
select(subq.c.rank).where(subq.c.user_id == statistics.user_id)
)
result = await session.exec(select(subq.c.rank).where(subq.c.user_id == statistics.user_id))
rank = result.first()
if rank is None:

View File

@@ -11,9 +11,9 @@ if TYPE_CHECKING:
class Team(SQLModel, UTCBaseModel, table=True):
__tablename__ = "teams" # pyright: ignore[reportAssignmentType]
__tablename__: str = "teams"
id: int | None = Field(default=None, primary_key=True, index=True)
id: int = Field(default=None, primary_key=True, index=True)
name: str = Field(max_length=100)
short_name: str = Field(max_length=10)
flag_url: str | None = Field(default=None)
@@ -26,34 +26,22 @@ class Team(SQLModel, UTCBaseModel, table=True):
class TeamMember(SQLModel, UTCBaseModel, table=True):
__tablename__ = "team_members" # pyright: ignore[reportAssignmentType]
__tablename__: str = "team_members"
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True)
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True))
team_id: int = Field(foreign_key="teams.id")
joined_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
joined_at: datetime = Field(default_factory=datetime.utcnow, sa_column=Column(DateTime))
user: "User" = Relationship(
back_populates="team_membership", sa_relationship_kwargs={"lazy": "joined"}
)
team: "Team" = Relationship(
back_populates="members", sa_relationship_kwargs={"lazy": "joined"}
)
user: "User" = Relationship(back_populates="team_membership", sa_relationship_kwargs={"lazy": "joined"})
team: "Team" = Relationship(back_populates="members", sa_relationship_kwargs={"lazy": "joined"})
class TeamRequest(SQLModel, UTCBaseModel, table=True):
__tablename__ = "team_requests" # pyright: ignore[reportAssignmentType]
__tablename__: str = "team_requests"
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True)
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True))
team_id: int = Field(foreign_key="teams.id", primary_key=True)
requested_at: datetime = Field(
default=datetime.now(UTC), sa_column=Column(DateTime)
)
requested_at: datetime = Field(default=datetime.now(UTC), sa_column=Column(DateTime))
user: "User" = Relationship(sa_relationship_kwargs={"lazy": "joined"})
team: "Team" = Relationship(sa_relationship_kwargs={"lazy": "joined"})

View File

@@ -22,7 +22,7 @@ class UserAccountHistoryBase(SQLModel, UTCBaseModel):
class UserAccountHistory(UserAccountHistoryBase, table=True):
__tablename__ = "user_account_history" # pyright: ignore[reportAssignmentType]
__tablename__: str = "user_account_history"
id: int | None = Field(
sa_column=Column(
@@ -32,9 +32,7 @@ class UserAccountHistory(UserAccountHistoryBase, table=True):
primary_key=True,
)
)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
class UserAccountHistoryResp(UserAccountHistoryBase):

View File

@@ -10,27 +10,17 @@ from sqlmodel import Field, SQLModel
class UserLoginLog(SQLModel, table=True):
"""User login log table"""
__tablename__ = "user_login_log" # pyright: ignore[reportAssignmentType]
__tablename__: str = "user_login_log"
id: int | None = Field(default=None, primary_key=True, description="Record ID")
user_id: int = Field(index=True, description="User ID")
ip_address: str = Field(
max_length=45, index=True, description="IP address (supports IPv4 and IPv6)"
)
user_agent: str | None = Field(
default=None, max_length=500, description="User agent information"
)
login_time: datetime = Field(
default_factory=datetime.utcnow, description="Login time"
)
ip_address: str = Field(max_length=45, index=True, description="IP address (supports IPv4 and IPv6)")
user_agent: str | None = Field(default=None, max_length=500, description="User agent information")
login_time: datetime = Field(default_factory=datetime.utcnow, description="Login time")
# GeoIP information
country_code: str | None = Field(
default=None, max_length=2, description="Country code"
)
country_name: str | None = Field(
default=None, max_length=100, description="Country name"
)
country_code: str | None = Field(default=None, max_length=2, description="Country code")
country_name: str | None = Field(default=None, max_length=100, description="Country name")
city_name: str | None = Field(default=None, max_length=100, description="City name")
latitude: str | None = Field(default=None, max_length=20, description="Latitude")
longitude: str | None = Field(default=None, max_length=20, description="Longitude")
@@ -38,22 +28,14 @@ class UserLoginLog(SQLModel, table=True):
# ASN information
asn: int | None = Field(default=None, description="Autonomous System Number")
organization: str | None = Field(
default=None, max_length=200, description="Organization name"
)
organization: str | None = Field(default=None, max_length=200, description="Organization name")
# Login status
login_success: bool = Field(
default=True, description="Whether the login was successful"
)
login_method: str = Field(
max_length=50, description="Login method (password/oauth/etc.)"
)
login_success: bool = Field(default=True, description="Whether the login was successful")
login_method: str = Field(max_length=50, description="Login method (password/oauth/etc.)")
# Additional information
notes: str | None = Field(
default=None, max_length=500, description="Additional notes"
)
notes: str | None = Field(default=None, max_length=500, description="Additional notes")
class Config:
from_attributes = True