refactor(project): make pyright & ruff happy
This commit is contained in:
@@ -28,18 +28,14 @@ if TYPE_CHECKING:
|
||||
|
||||
class UserAchievementBase(SQLModel, UTCBaseModel):
|
||||
achievement_id: int
|
||||
achieved_at: datetime = Field(
|
||||
default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True))
|
||||
)
|
||||
achieved_at: datetime = Field(default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True)))
|
||||
|
||||
|
||||
class UserAchievement(UserAchievementBase, table=True):
|
||||
__tablename__ = "lazer_user_achievements" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "lazer_user_achievements"
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id")), exclude=True
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")), exclude=True)
|
||||
user: "User" = Relationship(back_populates="achievement")
|
||||
|
||||
|
||||
@@ -56,11 +52,7 @@ async def process_achievements(session: AsyncSession, redis: Redis, score_id: in
|
||||
if not score:
|
||||
return
|
||||
achieved = (
|
||||
await session.exec(
|
||||
select(UserAchievement.achievement_id).where(
|
||||
UserAchievement.user_id == score.user_id
|
||||
)
|
||||
)
|
||||
await session.exec(select(UserAchievement.achievement_id).where(UserAchievement.user_id == score.user_id))
|
||||
).all()
|
||||
not_achieved = {k: v for k, v in MEDALS.items() if k.id not in achieved}
|
||||
result: list[Achievement] = []
|
||||
@@ -78,9 +70,7 @@ async def process_achievements(session: AsyncSession, redis: Redis, score_id: in
|
||||
)
|
||||
await redis.publish(
|
||||
"chat:notification",
|
||||
UserAchievementUnlock.init(
|
||||
r, score.user_id, score.gamemode
|
||||
).model_dump_json(),
|
||||
UserAchievementUnlock.init(r, score.user_id, score.gamemode).model_dump_json(),
|
||||
)
|
||||
event = Event(
|
||||
created_at=now,
|
||||
|
||||
@@ -20,42 +20,34 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class OAuthToken(UTCBaseModel, SQLModel, table=True):
|
||||
__tablename__ = "oauth_tokens" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "oauth_tokens"
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
client_id: int = Field(index=True)
|
||||
access_token: str = Field(max_length=500, unique=True)
|
||||
refresh_token: str = Field(max_length=500, unique=True)
|
||||
token_type: str = Field(default="Bearer", max_length=20)
|
||||
scope: str = Field(default="*", max_length=100)
|
||||
expires_at: datetime = Field(sa_column=Column(DateTime))
|
||||
created_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, sa_column=Column(DateTime))
|
||||
|
||||
user: "User" = Relationship()
|
||||
|
||||
|
||||
class OAuthClient(SQLModel, table=True):
|
||||
__tablename__ = "oauth_clients" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "oauth_clients"
|
||||
name: str = Field(max_length=100, index=True)
|
||||
description: str = Field(sa_column=Column(Text), default="")
|
||||
client_id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
client_secret: str = Field(default_factory=secrets.token_hex, index=True)
|
||||
redirect_uris: list[str] = Field(default_factory=list, sa_column=Column(JSON))
|
||||
owner_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
|
||||
|
||||
class V1APIKeys(SQLModel, table=True):
|
||||
__tablename__ = "v1_api_keys" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "v1_api_keys"
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
name: str = Field(max_length=100, index=True)
|
||||
key: str = Field(default_factory=secrets.token_hex, index=True)
|
||||
owner_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
|
||||
@@ -60,17 +60,13 @@ class BeatmapBase(SQLModel):
|
||||
|
||||
|
||||
class Beatmap(BeatmapBase, table=True):
|
||||
__tablename__ = "beatmaps" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "beatmaps"
|
||||
id: int = Field(primary_key=True, index=True)
|
||||
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
|
||||
beatmap_status: BeatmapRankStatus = Field(index=True)
|
||||
# optional
|
||||
beatmapset: Beatmapset = Relationship(
|
||||
back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"}
|
||||
)
|
||||
failtimes: FailTime | None = Relationship(
|
||||
back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"}
|
||||
)
|
||||
beatmapset: Beatmapset = Relationship(back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"})
|
||||
failtimes: FailTime | None = Relationship(back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"})
|
||||
|
||||
@classmethod
|
||||
async def from_resp(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
|
||||
@@ -84,21 +80,15 @@ class Beatmap(BeatmapBase, table=True):
|
||||
"beatmap_status": BeatmapRankStatus(resp.ranked),
|
||||
}
|
||||
)
|
||||
if not (
|
||||
await session.exec(select(exists()).where(Beatmap.id == resp.id))
|
||||
).first():
|
||||
if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
|
||||
session.add(beatmap)
|
||||
await session.commit()
|
||||
beatmap = (
|
||||
await session.exec(select(Beatmap).where(Beatmap.id == resp.id))
|
||||
).first()
|
||||
beatmap = (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).first()
|
||||
assert beatmap is not None, "Beatmap should not be None after commit"
|
||||
return beatmap
|
||||
|
||||
@classmethod
|
||||
async def from_resp_batch(
|
||||
cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0
|
||||
) -> list["Beatmap"]:
|
||||
async def from_resp_batch(cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0) -> list["Beatmap"]:
|
||||
beatmaps = []
|
||||
for resp in inp:
|
||||
if resp.id == from_:
|
||||
@@ -113,9 +103,7 @@ class Beatmap(BeatmapBase, table=True):
|
||||
"beatmap_status": BeatmapRankStatus(resp.ranked),
|
||||
}
|
||||
)
|
||||
if not (
|
||||
await session.exec(select(exists()).where(Beatmap.id == resp.id))
|
||||
).first():
|
||||
if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
|
||||
session.add(beatmap)
|
||||
beatmaps.append(beatmap)
|
||||
await session.commit()
|
||||
@@ -130,17 +118,11 @@ class Beatmap(BeatmapBase, table=True):
|
||||
md5: str | None = None,
|
||||
) -> "Beatmap":
|
||||
beatmap = (
|
||||
await session.exec(
|
||||
select(Beatmap).where(
|
||||
Beatmap.id == bid if bid is not None else Beatmap.checksum == md5
|
||||
)
|
||||
)
|
||||
await session.exec(select(Beatmap).where(Beatmap.id == bid if bid is not None else Beatmap.checksum == md5))
|
||||
).first()
|
||||
if not beatmap:
|
||||
resp = await fetcher.get_beatmap(bid, md5)
|
||||
r = await session.exec(
|
||||
select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id)
|
||||
)
|
||||
r = await session.exec(select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id))
|
||||
if not r.first():
|
||||
set_resp = await fetcher.get_beatmapset(resp.beatmapset_id)
|
||||
await Beatmapset.from_resp(session, set_resp, from_=resp.id)
|
||||
@@ -178,10 +160,7 @@ class BeatmapResp(BeatmapBase):
|
||||
if query_mode is not None and beatmap.mode != query_mode:
|
||||
beatmap_["convert"] = True
|
||||
beatmap_["is_scoreable"] = beatmap_status.has_leaderboard()
|
||||
if (
|
||||
settings.enable_all_beatmap_leaderboard
|
||||
and not beatmap_status.has_leaderboard()
|
||||
):
|
||||
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
|
||||
beatmap_["ranked"] = BeatmapRankStatus.APPROVED.value
|
||||
beatmap_["status"] = BeatmapRankStatus.APPROVED.name.lower()
|
||||
else:
|
||||
@@ -189,9 +168,7 @@ class BeatmapResp(BeatmapBase):
|
||||
beatmap_["ranked"] = beatmap_status.value
|
||||
beatmap_["mode_int"] = int(beatmap.mode)
|
||||
if not from_set:
|
||||
beatmap_["beatmapset"] = await BeatmapsetResp.from_db(
|
||||
beatmap.beatmapset, session=session, user=user
|
||||
)
|
||||
beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset, session=session, user=user)
|
||||
if beatmap.failtimes is not None:
|
||||
beatmap_["failtimes"] = FailTimeResp.from_db(beatmap.failtimes)
|
||||
else:
|
||||
@@ -218,7 +195,7 @@ class BeatmapResp(BeatmapBase):
|
||||
|
||||
|
||||
class BannedBeatmaps(SQLModel, table=True):
|
||||
__tablename__ = "banned_beatmaps" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "banned_beatmaps"
|
||||
id: int | None = Field(primary_key=True, index=True, default=None)
|
||||
beatmap_id: int = Field(index=True)
|
||||
|
||||
@@ -230,15 +207,10 @@ async def calculate_beatmap_attributes(
|
||||
redis: Redis,
|
||||
fetcher: "Fetcher",
|
||||
):
|
||||
key = (
|
||||
f"beatmap:{beatmap_id}:{ruleset}:"
|
||||
f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
|
||||
)
|
||||
key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
|
||||
if await redis.exists(key):
|
||||
return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
|
||||
return BeatmapAttributes.model_validate_json(await redis.get(key))
|
||||
resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
|
||||
attr = await asyncio.get_event_loop().run_in_executor(
|
||||
None, calculate_beatmap_attribute, resp, ruleset, mods_
|
||||
)
|
||||
attr = await asyncio.get_event_loop().run_in_executor(None, calculate_beatmap_attribute, resp, ruleset, mods_)
|
||||
await redis.set(key, attr.model_dump_json())
|
||||
return attr
|
||||
|
||||
@@ -23,15 +23,13 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class BeatmapPlaycounts(AsyncAttrs, SQLModel, table=True):
|
||||
__tablename__ = "beatmap_playcounts" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "beatmap_playcounts"
|
||||
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, primary_key=True, autoincrement=True),
|
||||
)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
|
||||
playcount: int = Field(default=0)
|
||||
|
||||
@@ -59,9 +57,7 @@ class BeatmapPlaycountsResp(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
async def process_beatmap_playcount(
|
||||
session: AsyncSession, user_id: int, beatmap_id: int
|
||||
) -> None:
|
||||
async def process_beatmap_playcount(session: AsyncSession, user_id: int, beatmap_id: int) -> None:
|
||||
existing_playcount = (
|
||||
await session.exec(
|
||||
select(BeatmapPlaycounts).where(
|
||||
@@ -89,7 +85,5 @@ async def process_beatmap_playcount(
|
||||
}
|
||||
session.add(playcount_event)
|
||||
else:
|
||||
new_playcount = BeatmapPlaycounts(
|
||||
user_id=user_id, beatmap_id=beatmap_id, playcount=1
|
||||
)
|
||||
new_playcount = BeatmapPlaycounts(user_id=user_id, beatmap_id=beatmap_id, playcount=1)
|
||||
session.add(new_playcount)
|
||||
|
||||
@@ -86,9 +86,7 @@ class BeatmapsetBase(SQLModel):
|
||||
|
||||
# optional
|
||||
# converts: list[Beatmap] = Relationship(back_populates="beatmapset")
|
||||
current_nominations: list[BeatmapNomination] | None = Field(
|
||||
None, sa_column=Column(JSON)
|
||||
)
|
||||
current_nominations: list[BeatmapNomination] | None = Field(None, sa_column=Column(JSON))
|
||||
description: BeatmapDescription | None = Field(default=None, sa_column=Column(JSON))
|
||||
# TODO: discussions: list[BeatmapsetDiscussion] = None
|
||||
# TODO: current_user_attributes: Optional[CurrentUserAttributes] = None
|
||||
@@ -105,22 +103,18 @@ class BeatmapsetBase(SQLModel):
|
||||
can_be_hyped: bool = Field(default=False, sa_column=Column(Boolean))
|
||||
discussion_locked: bool = Field(default=False, sa_column=Column(Boolean))
|
||||
last_updated: datetime = Field(sa_column=Column(DateTime, index=True))
|
||||
ranked_date: datetime | None = Field(
|
||||
default=None, sa_column=Column(DateTime, index=True)
|
||||
)
|
||||
ranked_date: datetime | None = Field(default=None, sa_column=Column(DateTime, index=True))
|
||||
storyboard: bool = Field(default=False, sa_column=Column(Boolean, index=True))
|
||||
submitted_date: datetime = Field(sa_column=Column(DateTime, index=True))
|
||||
tags: str = Field(default="", sa_column=Column(Text))
|
||||
|
||||
|
||||
class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
|
||||
__tablename__ = "beatmapsets" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "beatmapsets"
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
id: int = Field(default=None, primary_key=True, index=True)
|
||||
# Beatmapset
|
||||
beatmap_status: BeatmapRankStatus = Field(
|
||||
default=BeatmapRankStatus.GRAVEYARD, index=True
|
||||
)
|
||||
beatmap_status: BeatmapRankStatus = Field(default=BeatmapRankStatus.GRAVEYARD, index=True)
|
||||
|
||||
# optional
|
||||
beatmaps: list["Beatmap"] = Relationship(back_populates="beatmapset")
|
||||
@@ -137,9 +131,7 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
|
||||
favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset")
|
||||
|
||||
@classmethod
|
||||
async def from_resp(
|
||||
cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0
|
||||
) -> "Beatmapset":
|
||||
async def from_resp(cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0) -> "Beatmapset":
|
||||
from .beatmap import Beatmap
|
||||
|
||||
d = resp.model_dump()
|
||||
@@ -167,18 +159,14 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
|
||||
"download_disabled": resp.availability.download_disabled or False,
|
||||
}
|
||||
)
|
||||
if not (
|
||||
await session.exec(select(exists()).where(Beatmapset.id == resp.id))
|
||||
).first():
|
||||
if not (await session.exec(select(exists()).where(Beatmapset.id == resp.id))).first():
|
||||
session.add(beatmapset)
|
||||
await session.commit()
|
||||
await Beatmap.from_resp_batch(session, resp.beatmaps, from_=from_)
|
||||
return beatmapset
|
||||
|
||||
@classmethod
|
||||
async def get_or_fetch(
|
||||
cls, session: AsyncSession, fetcher: "Fetcher", sid: int
|
||||
) -> "Beatmapset":
|
||||
async def get_or_fetch(cls, session: AsyncSession, fetcher: "Fetcher", sid: int) -> "Beatmapset":
|
||||
beatmapset = await session.get(Beatmapset, sid)
|
||||
if not beatmapset:
|
||||
resp = await fetcher.get_beatmapset(sid)
|
||||
@@ -227,13 +215,9 @@ class BeatmapsetResp(BeatmapsetBase):
|
||||
@model_validator(mode="after")
|
||||
def fix_genre_language(self) -> Self:
|
||||
if self.genre is None:
|
||||
self.genre = BeatmapTranslationText(
|
||||
name=Genre(self.genre_id).name, id=self.genre_id
|
||||
)
|
||||
self.genre = BeatmapTranslationText(name=Genre(self.genre_id).name, id=self.genre_id)
|
||||
if self.language is None:
|
||||
self.language = BeatmapTranslationText(
|
||||
name=Language(self.language_id).name, id=self.language_id
|
||||
)
|
||||
self.language = BeatmapTranslationText(name=Language(self.language_id).name, id=self.language_id)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
@@ -252,9 +236,7 @@ class BeatmapsetResp(BeatmapsetBase):
|
||||
await BeatmapResp.from_db(beatmap, from_set=True, session=session)
|
||||
for beatmap in await beatmapset.awaitable_attrs.beatmaps
|
||||
],
|
||||
"hype": BeatmapHype(
|
||||
current=beatmapset.hype_current, required=beatmapset.hype_required
|
||||
),
|
||||
"hype": BeatmapHype(current=beatmapset.hype_current, required=beatmapset.hype_required),
|
||||
"availability": BeatmapAvailability(
|
||||
more_information=beatmapset.availability_info,
|
||||
download_disabled=beatmapset.download_disabled,
|
||||
@@ -282,10 +264,7 @@ class BeatmapsetResp(BeatmapsetBase):
|
||||
update["ratings"] = []
|
||||
|
||||
beatmap_status = beatmapset.beatmap_status
|
||||
if (
|
||||
settings.enable_all_beatmap_leaderboard
|
||||
and not beatmap_status.has_leaderboard()
|
||||
):
|
||||
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
|
||||
update["status"] = BeatmapRankStatus.APPROVED.name.lower()
|
||||
update["ranked"] = BeatmapRankStatus.APPROVED.value
|
||||
else:
|
||||
@@ -295,9 +274,7 @@ class BeatmapsetResp(BeatmapsetBase):
|
||||
if session and user:
|
||||
existing_favourite = (
|
||||
await session.exec(
|
||||
select(FavouriteBeatmapset).where(
|
||||
FavouriteBeatmapset.beatmapset_id == beatmapset.id
|
||||
)
|
||||
select(FavouriteBeatmapset).where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
|
||||
)
|
||||
).first()
|
||||
update["has_favourited"] = existing_favourite is not None
|
||||
|
||||
@@ -20,13 +20,9 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class BestScore(SQLModel, table=True):
|
||||
__tablename__ = "total_score_best_scores" # pyright: ignore[reportAssignmentType]
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
score_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
|
||||
)
|
||||
__tablename__: str = "total_score_best_scores"
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
score_id: int = Field(sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True))
|
||||
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
|
||||
gamemode: GameMode = Field(index=True)
|
||||
total_score: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
|
||||
@@ -51,30 +51,22 @@ class ChatChannelBase(SQLModel):
|
||||
|
||||
|
||||
class ChatChannel(ChatChannelBase, table=True):
|
||||
__tablename__ = "chat_channels" # pyright: ignore[reportAssignmentType]
|
||||
channel_id: int | None = Field(primary_key=True, index=True, default=None)
|
||||
__tablename__: str = "chat_channels"
|
||||
channel_id: int = Field(primary_key=True, index=True, default=None)
|
||||
|
||||
@classmethod
|
||||
async def get(
|
||||
cls, channel: str | int, session: AsyncSession
|
||||
) -> "ChatChannel | None":
|
||||
async def get(cls, channel: str | int, session: AsyncSession) -> "ChatChannel | None":
|
||||
if isinstance(channel, int) or channel.isdigit():
|
||||
# 使用查询而不是 get() 来确保对象完全加载
|
||||
result = await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
result = await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))
|
||||
channel_ = result.first()
|
||||
if channel_ is not None:
|
||||
return channel_
|
||||
result = await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.name == channel)
|
||||
)
|
||||
result = await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
return result.first()
|
||||
|
||||
@classmethod
|
||||
async def get_pm_channel(
|
||||
cls, user1: int, user2: int, session: AsyncSession
|
||||
) -> "ChatChannel | None":
|
||||
async def get_pm_channel(cls, user1: int, user2: int, session: AsyncSession) -> "ChatChannel | None":
|
||||
channel = await cls.get(f"pm_{user1}_{user2}", session)
|
||||
if channel is None:
|
||||
channel = await cls.get(f"pm_{user2}_{user1}", session)
|
||||
@@ -153,18 +145,13 @@ class ChatChannelResp(ChatChannelBase):
|
||||
.limit(10)
|
||||
)
|
||||
).all()
|
||||
c.recent_messages = [
|
||||
await ChatMessageResp.from_db(msg, session, user) for msg in messages
|
||||
]
|
||||
c.recent_messages = [await ChatMessageResp.from_db(msg, session, user) for msg in messages]
|
||||
c.recent_messages.reverse()
|
||||
|
||||
if c.type == ChannelType.PM and users and len(users) == 2:
|
||||
target_user_id = next(u for u in users if u != user.id)
|
||||
target_name = await session.exec(
|
||||
select(User.username).where(User.id == target_user_id)
|
||||
)
|
||||
target_name = await session.exec(select(User.username).where(User.id == target_user_id))
|
||||
c.name = target_name.one()
|
||||
assert user.id
|
||||
c.users = [target_user_id, user.id]
|
||||
return c
|
||||
|
||||
@@ -181,19 +168,15 @@ class MessageType(str, Enum):
|
||||
class ChatMessageBase(UTCBaseModel, SQLModel):
|
||||
channel_id: int = Field(index=True, foreign_key="chat_channels.channel_id")
|
||||
content: str = Field(sa_column=Column(VARCHAR(1000)))
|
||||
message_id: int | None = Field(index=True, primary_key=True, default=None)
|
||||
sender_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
timestamp: datetime = Field(
|
||||
sa_column=Column(DateTime, index=True), default=datetime.now(UTC)
|
||||
)
|
||||
message_id: int = Field(index=True, primary_key=True, default=None)
|
||||
sender_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
timestamp: datetime = Field(sa_column=Column(DateTime, index=True), default=datetime.now(UTC))
|
||||
type: MessageType = Field(default=MessageType.PLAIN, index=True, exclude=True)
|
||||
uuid: str | None = Field(default=None)
|
||||
|
||||
|
||||
class ChatMessage(ChatMessageBase, table=True):
|
||||
__tablename__ = "chat_messages" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "chat_messages"
|
||||
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
|
||||
channel: ChatChannel = Relationship()
|
||||
|
||||
@@ -211,9 +194,7 @@ class ChatMessageResp(ChatMessageBase):
|
||||
if user:
|
||||
m.sender = await UserResp.from_db(user, session, RANKING_INCLUDES)
|
||||
else:
|
||||
m.sender = await UserResp.from_db(
|
||||
db_message.user, session, RANKING_INCLUDES
|
||||
)
|
||||
m.sender = await UserResp.from_db(db_message.user, session, RANKING_INCLUDES)
|
||||
return m
|
||||
|
||||
|
||||
@@ -221,17 +202,13 @@ class ChatMessageResp(ChatMessageBase):
|
||||
|
||||
|
||||
class SilenceUser(UTCBaseModel, SQLModel, table=True):
|
||||
__tablename__ = "chat_silence_users" # pyright: ignore[reportAssignmentType]
|
||||
id: int | None = Field(primary_key=True, default=None, index=True)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
__tablename__: str = "chat_silence_users"
|
||||
id: int = Field(primary_key=True, default=None, index=True)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
channel_id: int = Field(foreign_key="chat_channels.channel_id", index=True)
|
||||
until: datetime | None = Field(sa_column=Column(DateTime, index=True), default=None)
|
||||
reason: str | None = Field(default=None, sa_column=Column(VARCHAR(255), index=True))
|
||||
banned_at: datetime = Field(
|
||||
sa_column=Column(DateTime, index=True), default=datetime.now(UTC)
|
||||
)
|
||||
banned_at: datetime = Field(sa_column=Column(DateTime, index=True), default=datetime.now(UTC))
|
||||
|
||||
|
||||
class UserSilenceResp(SQLModel):
|
||||
@@ -240,7 +217,6 @@ class UserSilenceResp(SQLModel):
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_silence: SilenceUser) -> "UserSilenceResp":
|
||||
assert db_silence.id is not None
|
||||
return cls(
|
||||
id=db_silence.id,
|
||||
user_id=db_silence.user_id,
|
||||
|
||||
@@ -21,28 +21,24 @@ class CountBase(SQLModel):
|
||||
|
||||
|
||||
class MonthlyPlaycounts(CountBase, table=True):
|
||||
__tablename__ = "monthly_playcounts" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "monthly_playcounts"
|
||||
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, primary_key=True, autoincrement=True),
|
||||
)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
user: "User" = Relationship(back_populates="monthly_playcounts")
|
||||
|
||||
|
||||
class ReplayWatchedCount(CountBase, table=True):
|
||||
__tablename__ = "replays_watched_counts" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "replays_watched_counts"
|
||||
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, primary_key=True, autoincrement=True),
|
||||
)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
user: "User" = Relationship(back_populates="replays_watched_counts")
|
||||
|
||||
|
||||
|
||||
@@ -24,9 +24,7 @@ class DailyChallengeStatsBase(SQLModel, UTCBaseModel):
|
||||
daily_streak_best: int = Field(default=0)
|
||||
daily_streak_current: int = Field(default=0)
|
||||
last_update: datetime | None = Field(default=None, sa_column=Column(DateTime))
|
||||
last_weekly_streak: datetime | None = Field(
|
||||
default=None, sa_column=Column(DateTime)
|
||||
)
|
||||
last_weekly_streak: datetime | None = Field(default=None, sa_column=Column(DateTime))
|
||||
playcount: int = Field(default=0)
|
||||
top_10p_placements: int = Field(default=0)
|
||||
top_50p_placements: int = Field(default=0)
|
||||
@@ -35,7 +33,7 @@ class DailyChallengeStatsBase(SQLModel, UTCBaseModel):
|
||||
|
||||
|
||||
class DailyChallengeStats(DailyChallengeStatsBase, table=True):
|
||||
__tablename__ = "daily_challenge_stats" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "daily_challenge_stats"
|
||||
|
||||
user_id: int | None = Field(
|
||||
default=None,
|
||||
@@ -61,9 +59,7 @@ class DailyChallengeStatsResp(DailyChallengeStatsBase):
|
||||
return cls.model_validate(obj)
|
||||
|
||||
|
||||
async def process_daily_challenge_score(
|
||||
session: AsyncSession, user_id: int, room_id: int
|
||||
):
|
||||
async def process_daily_challenge_score(session: AsyncSession, user_id: int, room_id: int):
|
||||
from .playlist_best_score import PlaylistBestScore
|
||||
|
||||
score = (
|
||||
|
||||
@@ -4,16 +4,17 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, UTC
|
||||
from sqlmodel import SQLModel, Field
|
||||
from sqlalchemy import Column, BigInteger, ForeignKey
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import BigInteger, Column, ForeignKey
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class EmailVerification(SQLModel, table=True):
|
||||
"""邮件验证记录"""
|
||||
|
||||
|
||||
__tablename__: str = "email_verifications"
|
||||
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True))
|
||||
email: str = Field(index=True)
|
||||
@@ -28,9 +29,9 @@ class EmailVerification(SQLModel, table=True):
|
||||
|
||||
class LoginSession(SQLModel, table=True):
|
||||
"""登录会话记录"""
|
||||
|
||||
|
||||
__tablename__: str = "login_sessions"
|
||||
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True))
|
||||
session_token: str = Field(unique=True, index=True) # 会话令牌
|
||||
|
||||
@@ -36,17 +36,13 @@ class EventType(str, Enum):
|
||||
|
||||
class EventBase(SQLModel):
|
||||
id: int = Field(default=None, primary_key=True)
|
||||
created_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), default=datetime.now(UTC))
|
||||
)
|
||||
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), default=datetime.now(UTC)))
|
||||
type: EventType
|
||||
event_payload: dict = Field(
|
||||
exclude=True, default_factory=dict, sa_column=Column(JSON)
|
||||
)
|
||||
event_payload: dict = Field(exclude=True, default_factory=dict, sa_column=Column(JSON))
|
||||
|
||||
|
||||
class Event(EventBase, table=True):
|
||||
__tablename__ = "user_events" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "user_events"
|
||||
user_id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True),
|
||||
|
||||
@@ -16,8 +16,8 @@ FAILTIME_STRUCT = Struct("<100i")
|
||||
|
||||
|
||||
class FailTime(SQLModel, table=True):
|
||||
__tablename__ = "failtime" # pyright: ignore[reportAssignmentType]
|
||||
beatmap_id: int = Field(primary_key=True, index=True, foreign_key="beatmaps.id")
|
||||
__tablename__: str = "failtime"
|
||||
beatmap_id: int = Field(primary_key=True, foreign_key="beatmaps.id")
|
||||
exit: bytes = Field(sa_column=Column(VARBINARY(400), nullable=False))
|
||||
fail: bytes = Field(sa_column=Column(VARBINARY(400), nullable=False))
|
||||
|
||||
@@ -41,12 +41,8 @@ class FailTime(SQLModel, table=True):
|
||||
|
||||
|
||||
class FailTimeResp(BaseModel):
|
||||
exit: list[int] = Field(
|
||||
default_factory=lambda: list(FAILTIME_STRUCT.unpack(b"\x00" * 400))
|
||||
)
|
||||
fail: list[int] = Field(
|
||||
default_factory=lambda: list(FAILTIME_STRUCT.unpack(b"\x00" * 400))
|
||||
)
|
||||
exit: list[int] = Field(default_factory=lambda: list(FAILTIME_STRUCT.unpack(b"\x00" * 400)))
|
||||
fail: list[int] = Field(default_factory=lambda: list(FAILTIME_STRUCT.unpack(b"\x00" * 400)))
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, failtime: FailTime) -> "FailTimeResp":
|
||||
|
||||
@@ -16,7 +16,7 @@ from sqlmodel import (
|
||||
|
||||
|
||||
class FavouriteBeatmapset(AsyncAttrs, SQLModel, table=True):
|
||||
__tablename__ = "favourite_beatmapset" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "favourite_beatmapset"
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, autoincrement=True, primary_key=True),
|
||||
|
||||
@@ -75,9 +75,7 @@ class UserBase(UTCBaseModel, SQLModel):
|
||||
is_active: bool = True
|
||||
is_bot: bool = False
|
||||
is_supporter: bool = False
|
||||
last_visit: datetime | None = Field(
|
||||
default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True))
|
||||
)
|
||||
last_visit: datetime | None = Field(default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True)))
|
||||
pm_friends_only: bool = False
|
||||
profile_colour: str | None = None
|
||||
username: str = Field(max_length=32, unique=True, index=True)
|
||||
@@ -90,9 +88,7 @@ class UserBase(UTCBaseModel, SQLModel):
|
||||
is_restricted: bool = False
|
||||
# blocks
|
||||
cover: UserProfileCover = Field(
|
||||
default=UserProfileCover(
|
||||
url="https://assets.ppy.sh/user-profile-covers/default.jpeg"
|
||||
),
|
||||
default=UserProfileCover(url="https://assets.ppy.sh/user-profile-covers/default.jpeg"),
|
||||
sa_column=Column(JSON),
|
||||
)
|
||||
beatmap_playcounts_count: int = 0
|
||||
@@ -150,9 +146,9 @@ class UserBase(UTCBaseModel, SQLModel):
|
||||
|
||||
|
||||
class User(AsyncAttrs, UserBase, table=True):
|
||||
__tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "lazer_users"
|
||||
|
||||
id: int | None = Field(
|
||||
id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True),
|
||||
)
|
||||
@@ -160,16 +156,10 @@ class User(AsyncAttrs, UserBase, table=True):
|
||||
statistics: list[UserStatistics] = Relationship()
|
||||
achievement: list[UserAchievement] = Relationship(back_populates="user")
|
||||
team_membership: TeamMember | None = Relationship(back_populates="user")
|
||||
daily_challenge_stats: DailyChallengeStats | None = Relationship(
|
||||
back_populates="user"
|
||||
)
|
||||
daily_challenge_stats: DailyChallengeStats | None = Relationship(back_populates="user")
|
||||
monthly_playcounts: list[MonthlyPlaycounts] = Relationship(back_populates="user")
|
||||
replays_watched_counts: list[ReplayWatchedCount] = Relationship(
|
||||
back_populates="user"
|
||||
)
|
||||
favourite_beatmapsets: list["FavouriteBeatmapset"] = Relationship(
|
||||
back_populates="user"
|
||||
)
|
||||
replays_watched_counts: list[ReplayWatchedCount] = Relationship(back_populates="user")
|
||||
favourite_beatmapsets: list["FavouriteBeatmapset"] = Relationship(back_populates="user")
|
||||
rank_history: list[RankHistory] = Relationship(
|
||||
back_populates="user",
|
||||
)
|
||||
@@ -178,16 +168,10 @@ class User(AsyncAttrs, UserBase, table=True):
|
||||
email: str = Field(max_length=254, unique=True, index=True, exclude=True)
|
||||
priv: int = Field(default=1, exclude=True)
|
||||
pw_bcrypt: str = Field(max_length=60, exclude=True)
|
||||
silence_end_at: datetime | None = Field(
|
||||
default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
|
||||
)
|
||||
donor_end_at: datetime | None = Field(
|
||||
default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
|
||||
)
|
||||
silence_end_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True)), exclude=True)
|
||||
donor_end_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True)), exclude=True)
|
||||
|
||||
async def is_user_can_pm(
|
||||
self, from_user: "User", session: AsyncSession
|
||||
) -> tuple[bool, str]:
|
||||
async def is_user_can_pm(self, from_user: "User", session: AsyncSession) -> tuple[bool, str]:
|
||||
from .relationship import Relationship, RelationshipType
|
||||
|
||||
from_relationship = (
|
||||
@@ -200,13 +184,10 @@ class User(AsyncAttrs, UserBase, table=True):
|
||||
).first()
|
||||
if from_relationship and from_relationship.type == RelationshipType.BLOCK:
|
||||
return False, "You have blocked the target user."
|
||||
if from_user.pm_friends_only and (
|
||||
not from_relationship or from_relationship.type != RelationshipType.FOLLOW
|
||||
):
|
||||
if from_user.pm_friends_only and (not from_relationship or from_relationship.type != RelationshipType.FOLLOW):
|
||||
return (
|
||||
False,
|
||||
"You have disabled non-friend communications "
|
||||
"and target user is not your friend.",
|
||||
"You have disabled non-friend communications and target user is not your friend.",
|
||||
)
|
||||
|
||||
relationship = (
|
||||
@@ -219,9 +200,7 @@ class User(AsyncAttrs, UserBase, table=True):
|
||||
).first()
|
||||
if relationship and relationship.type == RelationshipType.BLOCK:
|
||||
return False, "Target user has blocked you."
|
||||
if self.pm_friends_only and (
|
||||
not relationship or relationship.type != RelationshipType.FOLLOW
|
||||
):
|
||||
if self.pm_friends_only and (not relationship or relationship.type != RelationshipType.FOLLOW):
|
||||
return False, "Target user has disabled non-friend communications"
|
||||
return True, ""
|
||||
|
||||
@@ -288,9 +267,7 @@ class UserResp(UserBase):
|
||||
u = cls.model_validate(obj.model_dump())
|
||||
u.id = obj.id
|
||||
u.default_group = "bot" if u.is_bot else "default"
|
||||
u.country = Country(
|
||||
code=obj.country_code, name=COUNTRIES.get(obj.country_code, "Unknown")
|
||||
)
|
||||
u.country = Country(code=obj.country_code, name=COUNTRIES.get(obj.country_code, "Unknown"))
|
||||
u.follower_count = (
|
||||
await session.exec(
|
||||
select(func.count())
|
||||
@@ -314,9 +291,7 @@ class UserResp(UserBase):
|
||||
redis = get_redis()
|
||||
u.is_online = await redis.exists(f"metadata:online:{obj.id}")
|
||||
u.cover_url = (
|
||||
obj.cover.get(
|
||||
"url", "https://assets.ppy.sh/user-profile-covers/default.jpeg"
|
||||
)
|
||||
obj.cover.get("url", "https://assets.ppy.sh/user-profile-covers/default.jpeg")
|
||||
if obj.cover
|
||||
else "https://assets.ppy.sh/user-profile-covers/default.jpeg"
|
||||
)
|
||||
@@ -335,22 +310,15 @@ class UserResp(UserBase):
|
||||
]
|
||||
|
||||
if "team" in include:
|
||||
if await obj.awaitable_attrs.team_membership:
|
||||
assert obj.team_membership
|
||||
u.team = obj.team_membership.team
|
||||
if team_membership := await obj.awaitable_attrs.team_membership:
|
||||
u.team = team_membership.team
|
||||
|
||||
if "account_history" in include:
|
||||
u.account_history = [
|
||||
UserAccountHistoryResp.from_db(ah)
|
||||
for ah in await obj.awaitable_attrs.account_history
|
||||
]
|
||||
u.account_history = [UserAccountHistoryResp.from_db(ah) for ah in await obj.awaitable_attrs.account_history]
|
||||
|
||||
if "daily_challenge_user_stats":
|
||||
if await obj.awaitable_attrs.daily_challenge_stats:
|
||||
assert obj.daily_challenge_stats
|
||||
u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(
|
||||
obj.daily_challenge_stats
|
||||
)
|
||||
if daily_challenge_stats := await obj.awaitable_attrs.daily_challenge_stats:
|
||||
u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(daily_challenge_stats)
|
||||
|
||||
if "statistics" in include:
|
||||
current_stattistics = None
|
||||
@@ -359,59 +327,40 @@ class UserResp(UserBase):
|
||||
current_stattistics = i
|
||||
break
|
||||
u.statistics = (
|
||||
await UserStatisticsResp.from_db(
|
||||
current_stattistics, session, obj.country_code
|
||||
)
|
||||
await UserStatisticsResp.from_db(current_stattistics, session, obj.country_code)
|
||||
if current_stattistics
|
||||
else None
|
||||
)
|
||||
|
||||
if "statistics_rulesets" in include:
|
||||
u.statistics_rulesets = {
|
||||
i.mode.value: await UserStatisticsResp.from_db(
|
||||
i, session, obj.country_code
|
||||
)
|
||||
i.mode.value: await UserStatisticsResp.from_db(i, session, obj.country_code)
|
||||
for i in await obj.awaitable_attrs.statistics
|
||||
}
|
||||
|
||||
if "monthly_playcounts" in include:
|
||||
u.monthly_playcounts = [
|
||||
CountResp.from_db(pc)
|
||||
for pc in await obj.awaitable_attrs.monthly_playcounts
|
||||
]
|
||||
u.monthly_playcounts = [CountResp.from_db(pc) for pc in await obj.awaitable_attrs.monthly_playcounts]
|
||||
if len(u.monthly_playcounts) == 1:
|
||||
d = u.monthly_playcounts[0].start_date
|
||||
u.monthly_playcounts.insert(
|
||||
0, CountResp(start_date=d - timedelta(days=20), count=0)
|
||||
)
|
||||
u.monthly_playcounts.insert(0, CountResp(start_date=d - timedelta(days=20), count=0))
|
||||
|
||||
if "replays_watched_counts" in include:
|
||||
u.replay_watched_counts = [
|
||||
CountResp.from_db(rwc)
|
||||
for rwc in await obj.awaitable_attrs.replays_watched_counts
|
||||
CountResp.from_db(rwc) for rwc in await obj.awaitable_attrs.replays_watched_counts
|
||||
]
|
||||
if len(u.replay_watched_counts) == 1:
|
||||
d = u.replay_watched_counts[0].start_date
|
||||
u.replay_watched_counts.insert(
|
||||
0, CountResp(start_date=d - timedelta(days=20), count=0)
|
||||
)
|
||||
u.replay_watched_counts.insert(0, CountResp(start_date=d - timedelta(days=20), count=0))
|
||||
|
||||
if "achievements" in include:
|
||||
u.user_achievements = [
|
||||
UserAchievementResp.from_db(ua)
|
||||
for ua in await obj.awaitable_attrs.achievement
|
||||
]
|
||||
u.user_achievements = [UserAchievementResp.from_db(ua) for ua in await obj.awaitable_attrs.achievement]
|
||||
if "rank_history" in include:
|
||||
rank_history = await RankHistoryResp.from_db(session, obj.id, ruleset)
|
||||
if len(rank_history.data) != 0:
|
||||
u.rank_history = rank_history
|
||||
|
||||
rank_top = (
|
||||
await session.exec(
|
||||
select(RankTop).where(
|
||||
RankTop.user_id == obj.id, RankTop.mode == ruleset
|
||||
)
|
||||
)
|
||||
await session.exec(select(RankTop).where(RankTop.user_id == obj.id, RankTop.mode == ruleset))
|
||||
).first()
|
||||
if rank_top:
|
||||
u.rank_highest = (
|
||||
@@ -425,9 +374,7 @@ class UserResp(UserBase):
|
||||
|
||||
u.favourite_beatmapset_count = (
|
||||
await session.exec(
|
||||
select(func.count())
|
||||
.select_from(FavouriteBeatmapset)
|
||||
.where(FavouriteBeatmapset.user_id == obj.id)
|
||||
select(func.count()).select_from(FavouriteBeatmapset).where(FavouriteBeatmapset.user_id == obj.id)
|
||||
)
|
||||
).one()
|
||||
u.scores_pinned_count = (
|
||||
@@ -478,17 +425,19 @@ class UserResp(UserBase):
|
||||
# 检查会话验证状态
|
||||
# 如果邮件验证功能被禁用,则始终设置 session_verified 为 true
|
||||
from app.config import settings
|
||||
|
||||
if not settings.enable_email_verification:
|
||||
u.session_verified = True
|
||||
else:
|
||||
# 如果用户有未验证的登录会话,则设置 session_verified 为 false
|
||||
from .email_verification import LoginSession
|
||||
|
||||
unverified_session = (
|
||||
await session.exec(
|
||||
select(LoginSession).where(
|
||||
LoginSession.user_id == obj.id,
|
||||
LoginSession.is_verified == False,
|
||||
LoginSession.expires_at > datetime.now(UTC)
|
||||
col(LoginSession.is_verified).is_(False),
|
||||
LoginSession.expires_at > datetime.now(UTC),
|
||||
)
|
||||
)
|
||||
).first()
|
||||
|
||||
@@ -30,8 +30,8 @@ class MultiplayerEventBase(SQLModel, UTCBaseModel):
|
||||
|
||||
|
||||
class MultiplayerEvent(MultiplayerEventBase, table=True):
|
||||
__tablename__ = "multiplayer_events" # pyright: ignore[reportAssignmentType]
|
||||
id: int | None = Field(
|
||||
__tablename__: str = "multiplayer_events"
|
||||
id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True),
|
||||
)
|
||||
|
||||
@@ -17,7 +17,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class Notification(SQLModel, table=True):
|
||||
__tablename__ = "notifications" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "notifications"
|
||||
|
||||
id: int = Field(primary_key=True, index=True, default=None)
|
||||
name: NotificationName = Field(index=True)
|
||||
@@ -30,7 +30,7 @@ class Notification(SQLModel, table=True):
|
||||
|
||||
|
||||
class UserNotification(SQLModel, table=True):
|
||||
__tablename__ = "user_notifications" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "user_notifications"
|
||||
id: int = Field(
|
||||
sa_column=Column(
|
||||
BigInteger,
|
||||
@@ -40,9 +40,7 @@ class UserNotification(SQLModel, table=True):
|
||||
default=None,
|
||||
)
|
||||
notification_id: int = Field(index=True, foreign_key="notifications.id")
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
is_read: bool = Field(index=True)
|
||||
|
||||
notification: Notification = Relationship(sa_relationship_kwargs={"lazy": "joined"})
|
||||
|
||||
@@ -4,16 +4,17 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, UTC
|
||||
from sqlmodel import SQLModel, Field
|
||||
from sqlalchemy import Column, BigInteger, ForeignKey
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import BigInteger, Column, ForeignKey
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class PasswordReset(SQLModel, table=True):
|
||||
"""密码重置记录"""
|
||||
|
||||
|
||||
__tablename__: str = "password_resets"
|
||||
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True))
|
||||
email: str = Field(index=True)
|
||||
|
||||
@@ -21,16 +21,14 @@ class ItemAttemptsCountBase(SQLModel):
|
||||
room_id: int = Field(foreign_key="rooms.id", index=True)
|
||||
attempts: int = Field(default=0)
|
||||
completed: int = Field(default=0)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
accuracy: float = 0.0
|
||||
pp: float = 0
|
||||
total_score: int = 0
|
||||
|
||||
|
||||
class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True):
|
||||
__tablename__ = "item_attempts_count" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "item_attempts_count"
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
|
||||
user: User = Relationship()
|
||||
@@ -63,9 +61,7 @@ class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True):
|
||||
self.pp = sum(score.score.pp for score in playlist_scores)
|
||||
self.completed = len([score for score in playlist_scores if score.score.passed])
|
||||
self.accuracy = (
|
||||
sum(score.score.accuracy for score in playlist_scores) / self.completed
|
||||
if self.completed > 0
|
||||
else 0.0
|
||||
sum(score.score.accuracy for score in playlist_scores) / self.completed if self.completed > 0 else 0.0
|
||||
)
|
||||
await session.commit()
|
||||
await session.refresh(self)
|
||||
|
||||
@@ -21,14 +21,10 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class PlaylistBestScore(SQLModel, table=True):
|
||||
__tablename__ = "playlist_best_scores" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "playlist_best_scores"
|
||||
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
score_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
score_id: int = Field(sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True))
|
||||
room_id: int = Field(foreign_key="rooms.id", index=True)
|
||||
playlist_id: int = Field(index=True)
|
||||
total_score: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
|
||||
@@ -50,7 +50,7 @@ class PlaylistBase(SQLModel, UTCBaseModel):
|
||||
|
||||
|
||||
class Playlist(PlaylistBase, table=True):
|
||||
__tablename__ = "room_playlists" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "room_playlists"
|
||||
db_id: int = Field(default=None, primary_key=True, index=True, exclude=True)
|
||||
room_id: int = Field(foreign_key="rooms.id", exclude=True)
|
||||
|
||||
@@ -63,16 +63,12 @@ class Playlist(PlaylistBase, table=True):
|
||||
|
||||
@classmethod
|
||||
async def get_next_id_for_room(cls, room_id: int, session: AsyncSession) -> int:
|
||||
stmt = select(func.coalesce(func.max(cls.id), -1) + 1).where(
|
||||
cls.room_id == room_id
|
||||
)
|
||||
stmt = select(func.coalesce(func.max(cls.id), -1) + 1).where(cls.room_id == room_id)
|
||||
result = await session.exec(stmt)
|
||||
return result.one()
|
||||
|
||||
@classmethod
|
||||
async def from_hub(
|
||||
cls, playlist: PlaylistItem, room_id: int, session: AsyncSession
|
||||
) -> "Playlist":
|
||||
async def from_hub(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession) -> "Playlist":
|
||||
next_id = await cls.get_next_id_for_room(room_id, session=session)
|
||||
return cls(
|
||||
id=next_id,
|
||||
@@ -90,9 +86,7 @@ class Playlist(PlaylistBase, table=True):
|
||||
|
||||
@classmethod
|
||||
async def update(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession):
|
||||
db_playlist = await session.exec(
|
||||
select(cls).where(cls.id == playlist.id, cls.room_id == room_id)
|
||||
)
|
||||
db_playlist = await session.exec(select(cls).where(cls.id == playlist.id, cls.room_id == room_id))
|
||||
db_playlist = db_playlist.first()
|
||||
if db_playlist is None:
|
||||
raise ValueError("Playlist item not found")
|
||||
@@ -108,9 +102,7 @@ class Playlist(PlaylistBase, table=True):
|
||||
await session.commit()
|
||||
|
||||
@classmethod
|
||||
async def add_to_db(
|
||||
cls, playlist: PlaylistItem, room_id: int, session: AsyncSession
|
||||
):
|
||||
async def add_to_db(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession):
|
||||
db_playlist = await cls.from_hub(playlist, room_id, session)
|
||||
session.add(db_playlist)
|
||||
await session.commit()
|
||||
@@ -119,9 +111,7 @@ class Playlist(PlaylistBase, table=True):
|
||||
|
||||
@classmethod
|
||||
async def delete_item(cls, item_id: int, room_id: int, session: AsyncSession):
|
||||
db_playlist = await session.exec(
|
||||
select(cls).where(cls.id == item_id, cls.room_id == room_id)
|
||||
)
|
||||
db_playlist = await session.exec(select(cls).where(cls.id == item_id, cls.room_id == room_id))
|
||||
db_playlist = db_playlist.first()
|
||||
if db_playlist is None:
|
||||
raise ValueError("Playlist item not found")
|
||||
@@ -133,9 +123,7 @@ class PlaylistResp(PlaylistBase):
|
||||
beatmap: BeatmapResp | None = None
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls, playlist: Playlist, include: list[str] = []
|
||||
) -> "PlaylistResp":
|
||||
async def from_db(cls, playlist: Playlist, include: list[str] = []) -> "PlaylistResp":
|
||||
data = playlist.model_dump()
|
||||
if "beatmap" in include:
|
||||
data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap)
|
||||
|
||||
@@ -20,13 +20,9 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class PPBestScore(SQLModel, table=True):
|
||||
__tablename__ = "best_scores" # pyright: ignore[reportAssignmentType]
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
score_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
|
||||
)
|
||||
__tablename__: str = "best_scores"
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
score_id: int = Field(sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True))
|
||||
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
|
||||
gamemode: GameMode = Field(index=True)
|
||||
pp: float = Field(
|
||||
|
||||
@@ -26,12 +26,10 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class RankHistory(SQLModel, table=True):
|
||||
__tablename__ = "rank_history" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "rank_history"
|
||||
|
||||
id: int | None = Field(default=None, sa_column=Column(BigInteger, primary_key=True))
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
mode: GameMode
|
||||
rank: int
|
||||
date: dt = Field(
|
||||
@@ -43,12 +41,10 @@ class RankHistory(SQLModel, table=True):
|
||||
|
||||
|
||||
class RankTop(SQLModel, table=True):
|
||||
__tablename__ = "rank_top" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "rank_top"
|
||||
|
||||
id: int | None = Field(default=None, sa_column=Column(BigInteger, primary_key=True))
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
mode: GameMode
|
||||
rank: int
|
||||
date: dt = Field(
|
||||
@@ -62,9 +58,7 @@ class RankHistoryResp(BaseModel):
|
||||
data: list[int]
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls, session: AsyncSession, user_id: int, mode: GameMode
|
||||
) -> "RankHistoryResp":
|
||||
async def from_db(cls, session: AsyncSession, user_id: int, mode: GameMode) -> "RankHistoryResp":
|
||||
results = (
|
||||
await session.exec(
|
||||
select(RankHistory)
|
||||
|
||||
@@ -21,7 +21,7 @@ class RelationshipType(str, Enum):
|
||||
|
||||
|
||||
class Relationship(SQLModel, table=True):
|
||||
__tablename__ = "relationship" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "relationship"
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, autoincrement=True, primary_key=True),
|
||||
@@ -59,9 +59,7 @@ class RelationshipResp(BaseModel):
|
||||
type: RelationshipType
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls, session: AsyncSession, relationship: Relationship
|
||||
) -> "RelationshipResp":
|
||||
async def from_db(cls, session: AsyncSession, relationship: Relationship) -> "RelationshipResp":
|
||||
target_relationship = (
|
||||
await session.exec(
|
||||
select(Relationship).where(
|
||||
|
||||
@@ -58,11 +58,9 @@ class RoomBase(SQLModel, UTCBaseModel):
|
||||
|
||||
|
||||
class Room(AsyncAttrs, RoomBase, table=True):
|
||||
__tablename__ = "rooms" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "rooms"
|
||||
id: int = Field(default=None, primary_key=True, index=True)
|
||||
host_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
host_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
|
||||
host: User = Relationship()
|
||||
playlist: list[Playlist] = Relationship(
|
||||
@@ -109,12 +107,8 @@ class RoomResp(RoomBase):
|
||||
if not playlist.expired:
|
||||
stats.count_active += 1
|
||||
rulesets.add(playlist.ruleset_id)
|
||||
difficulty_range.min = min(
|
||||
difficulty_range.min, playlist.beatmap.difficulty_rating
|
||||
)
|
||||
difficulty_range.max = max(
|
||||
difficulty_range.max, playlist.beatmap.difficulty_rating
|
||||
)
|
||||
difficulty_range.min = min(difficulty_range.min, playlist.beatmap.difficulty_rating)
|
||||
difficulty_range.max = max(difficulty_range.max, playlist.beatmap.difficulty_rating)
|
||||
resp.playlist.append(await PlaylistResp.from_db(playlist, ["beatmap"]))
|
||||
stats.ruleset_ids = list(rulesets)
|
||||
resp.playlist_item_stats = stats
|
||||
@@ -137,13 +131,9 @@ class RoomResp(RoomBase):
|
||||
include=["statistics"],
|
||||
)
|
||||
)
|
||||
resp.host = await UserResp.from_db(
|
||||
await room.awaitable_attrs.host, session, include=["statistics"]
|
||||
)
|
||||
resp.host = await UserResp.from_db(await room.awaitable_attrs.host, session, include=["statistics"])
|
||||
if "current_user_score" in include and user:
|
||||
resp.current_user_score = await PlaylistAggregateScore.from_db(
|
||||
room.id, user.id, session
|
||||
)
|
||||
resp.current_user_score = await PlaylistAggregateScore.from_db(room.id, user.id, session)
|
||||
return resp
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -18,22 +18,16 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class RoomParticipatedUser(AsyncAttrs, SQLModel, table=True):
|
||||
__tablename__ = "room_participated_users" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "room_participated_users"
|
||||
|
||||
id: int | None = Field(
|
||||
default=None, sa_column=Column(BigInteger, primary_key=True, autoincrement=True)
|
||||
)
|
||||
id: int | None = Field(default=None, sa_column=Column(BigInteger, primary_key=True, autoincrement=True))
|
||||
room_id: int = Field(sa_column=Column(ForeignKey("rooms.id"), nullable=False))
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False)
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False))
|
||||
joined_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=False),
|
||||
default=datetime.now(UTC),
|
||||
)
|
||||
left_at: datetime | None = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=True), default=None
|
||||
)
|
||||
left_at: datetime | None = Field(sa_column=Column(DateTime(timezone=True), nullable=True), default=None)
|
||||
|
||||
room: "Room" = Relationship()
|
||||
user: "User" = Relationship()
|
||||
|
||||
@@ -47,9 +47,9 @@ from .score_token import ScoreToken
|
||||
|
||||
from pydantic import field_serializer, field_validator
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy import Boolean, Column, ColumnExpressionArgument, DateTime
|
||||
from sqlalchemy import Boolean, Column, ColumnExpressionArgument, DateTime, TextClause
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import Mapped, aliased
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlmodel import (
|
||||
JSON,
|
||||
@@ -76,9 +76,7 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
|
||||
accuracy: float
|
||||
map_md5: str = Field(max_length=32, index=True)
|
||||
build_id: int | None = Field(default=None)
|
||||
classic_total_score: int | None = Field(
|
||||
default=0, sa_column=Column(BigInteger)
|
||||
) # solo_score
|
||||
classic_total_score: int | None = Field(default=0, sa_column=Column(BigInteger)) # solo_score
|
||||
ended_at: datetime = Field(sa_column=Column(DateTime))
|
||||
has_replay: bool = Field(sa_column=Column(Boolean))
|
||||
max_combo: int
|
||||
@@ -91,14 +89,10 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
|
||||
room_id: int | None = Field(default=None) # multiplayer
|
||||
started_at: datetime = Field(sa_column=Column(DateTime))
|
||||
total_score: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
total_score_without_mods: int = Field(
|
||||
default=0, sa_column=Column(BigInteger), exclude=True
|
||||
)
|
||||
total_score_without_mods: int = Field(default=0, sa_column=Column(BigInteger), exclude=True)
|
||||
type: str
|
||||
beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
|
||||
maximum_statistics: ScoreStatistics = Field(
|
||||
sa_column=Column(JSON), default_factory=dict
|
||||
)
|
||||
maximum_statistics: ScoreStatistics = Field(sa_column=Column(JSON), default_factory=dict)
|
||||
|
||||
@field_validator("maximum_statistics", mode="before")
|
||||
@classmethod
|
||||
@@ -147,10 +141,8 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
|
||||
|
||||
|
||||
class Score(ScoreBase, table=True):
|
||||
__tablename__ = "scores" # pyright: ignore[reportAssignmentType]
|
||||
id: int | None = Field(
|
||||
default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True)
|
||||
)
|
||||
__tablename__: str = "scores"
|
||||
id: int = Field(default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True))
|
||||
user_id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
@@ -193,8 +185,8 @@ class Score(ScoreBase, table=True):
|
||||
return str(v)
|
||||
|
||||
# optional
|
||||
beatmap: Beatmap = Relationship()
|
||||
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
|
||||
beatmap: Mapped[Beatmap] = Relationship()
|
||||
user: Mapped[User] = Relationship(sa_relationship_kwargs={"lazy": "joined"})
|
||||
|
||||
@property
|
||||
def is_perfect_combo(self) -> bool:
|
||||
@@ -205,11 +197,7 @@ class Score(ScoreBase, table=True):
|
||||
*where_clauses: ColumnExpressionArgument[bool] | bool,
|
||||
) -> SelectOfScalar["Score"]:
|
||||
rownum = (
|
||||
func.row_number()
|
||||
.over(
|
||||
partition_by=col(Score.user_id), order_by=col(Score.total_score).desc()
|
||||
)
|
||||
.label("rn")
|
||||
func.row_number().over(partition_by=col(Score.user_id), order_by=col(Score.total_score).desc()).label("rn")
|
||||
)
|
||||
subq = select(Score, rownum).where(*where_clauses).subquery()
|
||||
best = aliased(Score, subq, adapt_on_names=True)
|
||||
@@ -296,12 +284,9 @@ class ScoreResp(ScoreBase):
|
||||
await session.refresh(score)
|
||||
|
||||
s = cls.model_validate(score.model_dump())
|
||||
assert score.id
|
||||
await score.awaitable_attrs.beatmap
|
||||
s.beatmap = await BeatmapResp.from_db(score.beatmap)
|
||||
s.beatmapset = await BeatmapsetResp.from_db(
|
||||
score.beatmap.beatmapset, session=session, user=score.user
|
||||
)
|
||||
s.beatmapset = await BeatmapsetResp.from_db(score.beatmap.beatmapset, session=session, user=score.user)
|
||||
s.is_perfect_combo = s.max_combo == s.beatmap.max_combo
|
||||
s.legacy_perfect = s.max_combo == s.beatmap.max_combo
|
||||
s.ruleset_id = int(score.gamemode)
|
||||
@@ -371,11 +356,7 @@ class ScoreAround(SQLModel):
|
||||
|
||||
async def get_best_id(session: AsyncSession, score_id: int) -> None:
|
||||
rownum = (
|
||||
func.row_number()
|
||||
.over(
|
||||
partition_by=col(PPBestScore.user_id), order_by=col(PPBestScore.pp).desc()
|
||||
)
|
||||
.label("rn")
|
||||
func.row_number().over(partition_by=col(PPBestScore.user_id), order_by=col(PPBestScore.pp).desc()).label("rn")
|
||||
)
|
||||
subq = select(PPBestScore, rownum).subquery()
|
||||
stmt = select(subq.c.rn).where(subq.c.score_id == score_id)
|
||||
@@ -389,8 +370,8 @@ async def _score_where(
|
||||
mode: GameMode,
|
||||
mods: list[str] | None = None,
|
||||
user: User | None = None,
|
||||
) -> list[ColumnElement[bool]] | None:
|
||||
wheres = [
|
||||
) -> list[ColumnElement[bool] | TextClause] | None:
|
||||
wheres: list[ColumnElement[bool] | TextClause] = [
|
||||
col(BestScore.beatmap_id) == beatmap,
|
||||
col(BestScore.gamemode) == mode,
|
||||
]
|
||||
@@ -410,9 +391,7 @@ async def _score_where(
|
||||
return None
|
||||
elif type == LeaderboardType.COUNTRY:
|
||||
if user and user.is_supporter:
|
||||
wheres.append(
|
||||
col(BestScore.user).has(col(User.country_code) == user.country_code)
|
||||
)
|
||||
wheres.append(col(BestScore.user).has(col(User.country_code) == user.country_code))
|
||||
else:
|
||||
return None
|
||||
elif type == LeaderboardType.TEAM:
|
||||
@@ -420,18 +399,14 @@ async def _score_where(
|
||||
team_membership = await user.awaitable_attrs.team_membership
|
||||
if team_membership:
|
||||
team_id = team_membership.team_id
|
||||
wheres.append(
|
||||
col(BestScore.user).has(
|
||||
col(User.team_membership).has(TeamMember.team_id == team_id)
|
||||
)
|
||||
)
|
||||
wheres.append(col(BestScore.user).has(col(User.team_membership).has(TeamMember.team_id == team_id)))
|
||||
if mods:
|
||||
if user and user.is_supporter:
|
||||
wheres.append(
|
||||
text(
|
||||
"JSON_CONTAINS(total_score_best_scores.mods, :w)"
|
||||
" AND JSON_CONTAINS(:w, total_score_best_scores.mods)"
|
||||
).params(w=json.dumps(mods)) # pyright: ignore[reportArgumentType]
|
||||
).params(w=json.dumps(mods))
|
||||
)
|
||||
else:
|
||||
return None
|
||||
@@ -654,18 +629,14 @@ def calculate_playtime(score: Score, beatmap_length: int) -> tuple[int, bool]:
|
||||
+ (score.nsmall_tick_hit or 0)
|
||||
)
|
||||
total_obj = 0
|
||||
for statistics, count in (
|
||||
score.maximum_statistics.items() if score.maximum_statistics else {}
|
||||
):
|
||||
for statistics, count in score.maximum_statistics.items() if score.maximum_statistics else {}:
|
||||
if not isinstance(statistics, HitResult):
|
||||
statistics = HitResult(statistics)
|
||||
if statistics.is_scorable():
|
||||
total_obj += count
|
||||
|
||||
return total_length, score.passed or (
|
||||
total_length > 8
|
||||
and score.total_score >= 5000
|
||||
and total_obj_hited >= min(0.1 * total_obj, 20)
|
||||
total_length > 8 and score.total_score >= 5000 and total_obj_hited >= min(0.1 * total_obj, 20)
|
||||
)
|
||||
|
||||
|
||||
@@ -678,12 +649,8 @@ async def process_user(
|
||||
ranked: bool = False,
|
||||
has_leaderboard: bool = False,
|
||||
):
|
||||
assert user.id
|
||||
assert score.id
|
||||
mod_for_save = mod_to_save(score.mods)
|
||||
previous_score_best = await get_user_best_score_in_beatmap(
|
||||
session, score.beatmap_id, user.id, score.gamemode
|
||||
)
|
||||
previous_score_best = await get_user_best_score_in_beatmap(session, score.beatmap_id, user.id, score.gamemode)
|
||||
previous_score_best_mod = await get_user_best_score_with_mod_in_beatmap(
|
||||
session, score.beatmap_id, user.id, mod_for_save, score.gamemode
|
||||
)
|
||||
@@ -698,9 +665,7 @@ async def process_user(
|
||||
)
|
||||
).first()
|
||||
if mouthly_playcount is None:
|
||||
mouthly_playcount = MonthlyPlaycounts(
|
||||
user_id=user.id, year=date.today().year, month=date.today().month
|
||||
)
|
||||
mouthly_playcount = MonthlyPlaycounts(user_id=user.id, year=date.today().year, month=date.today().month)
|
||||
add_to_db = True
|
||||
statistics = None
|
||||
for i in await user.awaitable_attrs.statistics:
|
||||
@@ -708,17 +673,11 @@ async def process_user(
|
||||
statistics = i
|
||||
break
|
||||
if statistics is None:
|
||||
raise ValueError(
|
||||
f"User {user.id} does not have statistics for mode {score.gamemode.value}"
|
||||
)
|
||||
raise ValueError(f"User {user.id} does not have statistics for mode {score.gamemode.value}")
|
||||
|
||||
# pc, pt, tth, tts
|
||||
statistics.total_score += score.total_score
|
||||
difference = (
|
||||
score.total_score - previous_score_best.total_score
|
||||
if previous_score_best
|
||||
else score.total_score
|
||||
)
|
||||
difference = score.total_score - previous_score_best.total_score if previous_score_best else score.total_score
|
||||
if difference > 0 and score.passed and ranked:
|
||||
match score.rank:
|
||||
case Rank.X:
|
||||
@@ -746,11 +705,8 @@ async def process_user(
|
||||
statistics.ranked_score += difference
|
||||
statistics.level_current = calculate_score_to_level(statistics.total_score)
|
||||
statistics.maximum_combo = max(statistics.maximum_combo, score.max_combo)
|
||||
new_score_position = await get_score_position_by_user(
|
||||
session, score.beatmap_id, user, score.gamemode
|
||||
)
|
||||
new_score_position = await get_score_position_by_user(session, score.beatmap_id, user, score.gamemode)
|
||||
total_users = await session.exec(select(func.count()).select_from(User))
|
||||
assert total_users is not None
|
||||
score_range = min(50, math.ceil(float(total_users.one()) * 0.01))
|
||||
if new_score_position <= score_range and new_score_position > 0:
|
||||
# Get the scores that might be displaced
|
||||
@@ -774,11 +730,7 @@ async def process_user(
|
||||
)
|
||||
|
||||
# If this score was previously in top positions but now pushed out
|
||||
if (
|
||||
i < score_range
|
||||
and displaced_position > score_range
|
||||
and displaced_position is not None
|
||||
):
|
||||
if i < score_range and displaced_position > score_range and displaced_position is not None:
|
||||
# Create rank lost event for the displaced user
|
||||
rank_lost_event = Event(
|
||||
created_at=datetime.now(UTC),
|
||||
@@ -814,10 +766,7 @@ async def process_user(
|
||||
)
|
||||
|
||||
# 情况3: 有最佳分数记录和该mod组合的记录,且是同一个记录,更新得分更高的情况
|
||||
elif (
|
||||
previous_score_best.score_id == previous_score_best_mod.score_id
|
||||
and difference > 0
|
||||
):
|
||||
elif previous_score_best.score_id == previous_score_best_mod.score_id and difference > 0:
|
||||
previous_score_best.total_score = score.total_score
|
||||
previous_score_best.rank = score.rank
|
||||
previous_score_best.score_id = score.id
|
||||
@@ -847,9 +796,7 @@ async def process_user(
|
||||
statistics.count_300 += score.n300 + score.ngeki
|
||||
statistics.count_50 += score.n50
|
||||
statistics.count_miss += score.nmiss
|
||||
statistics.total_hits += (
|
||||
score.n300 + score.n100 + score.n50 + score.ngeki + score.nkatu
|
||||
)
|
||||
statistics.total_hits += score.n300 + score.n100 + score.n50 + score.ngeki + score.nkatu
|
||||
|
||||
if score.passed and ranked:
|
||||
with session.no_autoflush:
|
||||
@@ -885,7 +832,6 @@ async def process_score(
|
||||
item_id: int | None = None,
|
||||
room_id: int | None = None,
|
||||
) -> Score:
|
||||
assert user.id
|
||||
can_get_pp = info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods)
|
||||
gamemode = GameMode.from_int(info.ruleset_id).to_special_mode(info.mods)
|
||||
score = Score(
|
||||
@@ -922,20 +868,15 @@ async def process_score(
|
||||
if can_get_pp:
|
||||
from app.calculator import pre_fetch_and_calculate_pp
|
||||
|
||||
pp = await pre_fetch_and_calculate_pp(
|
||||
score, beatmap_id, session, redis, fetcher
|
||||
)
|
||||
pp = await pre_fetch_and_calculate_pp(score, beatmap_id, session, redis, fetcher)
|
||||
score.pp = pp
|
||||
session.add(score)
|
||||
user_id = user.id
|
||||
await session.commit()
|
||||
await session.refresh(score)
|
||||
if can_get_pp and score.pp != 0:
|
||||
previous_pp_best = await get_user_best_pp_in_beatmap(
|
||||
session, beatmap_id, user_id, score.gamemode
|
||||
)
|
||||
previous_pp_best = await get_user_best_pp_in_beatmap(session, beatmap_id, user_id, score.gamemode)
|
||||
if previous_pp_best is None or score.pp > previous_pp_best.pp:
|
||||
assert score.id
|
||||
best_score = PPBestScore(
|
||||
user_id=user_id,
|
||||
score_id=score.id,
|
||||
|
||||
@@ -7,6 +7,7 @@ from .beatmap import Beatmap
|
||||
from .lazer_user import User
|
||||
|
||||
from sqlalchemy import Column, DateTime, Index
|
||||
from sqlalchemy.orm import Mapped
|
||||
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
|
||||
|
||||
|
||||
@@ -14,16 +15,12 @@ class ScoreTokenBase(SQLModel, UTCBaseModel):
|
||||
score_id: int | None = Field(sa_column=Column(BigInteger), default=None)
|
||||
ruleset_id: GameMode
|
||||
playlist_item_id: int | None = Field(default=None) # playlist
|
||||
created_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
updated_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, sa_column=Column(DateTime))
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column=Column(DateTime))
|
||||
|
||||
|
||||
class ScoreToken(ScoreTokenBase, table=True):
|
||||
__tablename__ = "score_tokens" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "score_tokens"
|
||||
__table_args__ = (Index("idx_user_playlist", "user_id", "playlist_item_id"),)
|
||||
|
||||
id: int | None = Field(
|
||||
@@ -37,8 +34,8 @@ class ScoreToken(ScoreTokenBase, table=True):
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
|
||||
beatmap_id: int = Field(foreign_key="beatmaps.id")
|
||||
user: User = Relationship()
|
||||
beatmap: Beatmap = Relationship()
|
||||
user: Mapped[User] = Relationship()
|
||||
beatmap: Mapped[Beatmap] = Relationship()
|
||||
|
||||
|
||||
class ScoreTokenResp(ScoreTokenBase):
|
||||
|
||||
@@ -58,7 +58,7 @@ class UserStatisticsBase(SQLModel):
|
||||
|
||||
|
||||
class UserStatistics(AsyncAttrs, UserStatisticsBase, table=True):
|
||||
__tablename__ = "lazer_user_statistics" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "lazer_user_statistics"
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(
|
||||
default=None,
|
||||
@@ -123,9 +123,7 @@ class UserStatisticsResp(UserStatisticsBase):
|
||||
if "user" in include:
|
||||
from .lazer_user import RANKING_INCLUDES, UserResp
|
||||
|
||||
user = await UserResp.from_db(
|
||||
await obj.awaitable_attrs.user, session, include=RANKING_INCLUDES
|
||||
)
|
||||
user = await UserResp.from_db(await obj.awaitable_attrs.user, session, include=RANKING_INCLUDES)
|
||||
s.user = user
|
||||
user_country = user.country_code
|
||||
|
||||
@@ -149,9 +147,7 @@ class UserStatisticsResp(UserStatisticsBase):
|
||||
return s
|
||||
|
||||
|
||||
async def get_rank(
|
||||
session: AsyncSession, statistics: UserStatistics, country: str | None = None
|
||||
) -> int | None:
|
||||
async def get_rank(session: AsyncSession, statistics: UserStatistics, country: str | None = None) -> int | None:
|
||||
from .lazer_user import User
|
||||
|
||||
query = select(
|
||||
@@ -168,9 +164,7 @@ async def get_rank(
|
||||
|
||||
subq = query.subquery()
|
||||
|
||||
result = await session.exec(
|
||||
select(subq.c.rank).where(subq.c.user_id == statistics.user_id)
|
||||
)
|
||||
result = await session.exec(select(subq.c.rank).where(subq.c.user_id == statistics.user_id))
|
||||
|
||||
rank = result.first()
|
||||
if rank is None:
|
||||
|
||||
@@ -11,9 +11,9 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class Team(SQLModel, UTCBaseModel, table=True):
|
||||
__tablename__ = "teams" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "teams"
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
id: int = Field(default=None, primary_key=True, index=True)
|
||||
name: str = Field(max_length=100)
|
||||
short_name: str = Field(max_length=10)
|
||||
flag_url: str | None = Field(default=None)
|
||||
@@ -26,34 +26,22 @@ class Team(SQLModel, UTCBaseModel, table=True):
|
||||
|
||||
|
||||
class TeamMember(SQLModel, UTCBaseModel, table=True):
|
||||
__tablename__ = "team_members" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "team_members"
|
||||
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True)
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True))
|
||||
team_id: int = Field(foreign_key="teams.id")
|
||||
joined_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
joined_at: datetime = Field(default_factory=datetime.utcnow, sa_column=Column(DateTime))
|
||||
|
||||
user: "User" = Relationship(
|
||||
back_populates="team_membership", sa_relationship_kwargs={"lazy": "joined"}
|
||||
)
|
||||
team: "Team" = Relationship(
|
||||
back_populates="members", sa_relationship_kwargs={"lazy": "joined"}
|
||||
)
|
||||
user: "User" = Relationship(back_populates="team_membership", sa_relationship_kwargs={"lazy": "joined"})
|
||||
team: "Team" = Relationship(back_populates="members", sa_relationship_kwargs={"lazy": "joined"})
|
||||
|
||||
|
||||
class TeamRequest(SQLModel, UTCBaseModel, table=True):
|
||||
__tablename__ = "team_requests" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "team_requests"
|
||||
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True)
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True))
|
||||
team_id: int = Field(foreign_key="teams.id", primary_key=True)
|
||||
requested_at: datetime = Field(
|
||||
default=datetime.now(UTC), sa_column=Column(DateTime)
|
||||
)
|
||||
requested_at: datetime = Field(default=datetime.now(UTC), sa_column=Column(DateTime))
|
||||
|
||||
user: "User" = Relationship(sa_relationship_kwargs={"lazy": "joined"})
|
||||
team: "Team" = Relationship(sa_relationship_kwargs={"lazy": "joined"})
|
||||
|
||||
@@ -22,7 +22,7 @@ class UserAccountHistoryBase(SQLModel, UTCBaseModel):
|
||||
|
||||
|
||||
class UserAccountHistory(UserAccountHistoryBase, table=True):
|
||||
__tablename__ = "user_account_history" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "user_account_history"
|
||||
|
||||
id: int | None = Field(
|
||||
sa_column=Column(
|
||||
@@ -32,9 +32,7 @@ class UserAccountHistory(UserAccountHistoryBase, table=True):
|
||||
primary_key=True,
|
||||
)
|
||||
)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
|
||||
|
||||
class UserAccountHistoryResp(UserAccountHistoryBase):
|
||||
|
||||
@@ -10,27 +10,17 @@ from sqlmodel import Field, SQLModel
|
||||
class UserLoginLog(SQLModel, table=True):
|
||||
"""User login log table"""
|
||||
|
||||
__tablename__ = "user_login_log" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "user_login_log"
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, description="Record ID")
|
||||
user_id: int = Field(index=True, description="User ID")
|
||||
ip_address: str = Field(
|
||||
max_length=45, index=True, description="IP address (supports IPv4 and IPv6)"
|
||||
)
|
||||
user_agent: str | None = Field(
|
||||
default=None, max_length=500, description="User agent information"
|
||||
)
|
||||
login_time: datetime = Field(
|
||||
default_factory=datetime.utcnow, description="Login time"
|
||||
)
|
||||
ip_address: str = Field(max_length=45, index=True, description="IP address (supports IPv4 and IPv6)")
|
||||
user_agent: str | None = Field(default=None, max_length=500, description="User agent information")
|
||||
login_time: datetime = Field(default_factory=datetime.utcnow, description="Login time")
|
||||
|
||||
# GeoIP information
|
||||
country_code: str | None = Field(
|
||||
default=None, max_length=2, description="Country code"
|
||||
)
|
||||
country_name: str | None = Field(
|
||||
default=None, max_length=100, description="Country name"
|
||||
)
|
||||
country_code: str | None = Field(default=None, max_length=2, description="Country code")
|
||||
country_name: str | None = Field(default=None, max_length=100, description="Country name")
|
||||
city_name: str | None = Field(default=None, max_length=100, description="City name")
|
||||
latitude: str | None = Field(default=None, max_length=20, description="Latitude")
|
||||
longitude: str | None = Field(default=None, max_length=20, description="Longitude")
|
||||
@@ -38,22 +28,14 @@ class UserLoginLog(SQLModel, table=True):
|
||||
|
||||
# ASN information
|
||||
asn: int | None = Field(default=None, description="Autonomous System Number")
|
||||
organization: str | None = Field(
|
||||
default=None, max_length=200, description="Organization name"
|
||||
)
|
||||
organization: str | None = Field(default=None, max_length=200, description="Organization name")
|
||||
|
||||
# Login status
|
||||
login_success: bool = Field(
|
||||
default=True, description="Whether the login was successful"
|
||||
)
|
||||
login_method: str = Field(
|
||||
max_length=50, description="Login method (password/oauth/etc.)"
|
||||
)
|
||||
login_success: bool = Field(default=True, description="Whether the login was successful")
|
||||
login_method: str = Field(max_length=50, description="Login method (password/oauth/etc.)")
|
||||
|
||||
# Additional information
|
||||
notes: str | None = Field(
|
||||
default=None, max_length=500, description="Additional notes"
|
||||
)
|
||||
notes: str | None = Field(default=None, max_length=500, description="Additional notes")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
Reference in New Issue
Block a user