refactor(project): make pyright & ruff happy

This commit is contained in:
MingxuanGame
2025-08-22 08:21:52 +00:00
parent 3b1d7a2234
commit 598fcc8b38
157 changed files with 2382 additions and 4590 deletions

View File

@@ -65,9 +65,7 @@ async def to_the_core(
# using either of the mods specified: DT, NC
if not score.passed:
return False
if (
"Nightcore" not in beatmap.beatmapset.title
) and "Nightcore" not in beatmap.beatmapset.artist:
if ("Nightcore" not in beatmap.beatmapset.title) and "Nightcore" not in beatmap.beatmapset.artist:
return False
mods_ = mod_to_save(score.mods)
if "DT" not in mods_ or "NC" not in mods_:
@@ -118,9 +116,7 @@ async def reckless_adandon(
fetcher = await get_fetcher()
redis = get_redis()
mods_ = score.mods.copy()
attribute = await calculate_beatmap_attributes(
beatmap.id, score.gamemode, mods_, redis, fetcher
)
attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
if attribute.star_rating < 3:
return False
return True
@@ -186,9 +182,7 @@ async def slow_and_steady(
fetcher = await get_fetcher()
redis = get_redis()
mods_ = score.mods.copy()
attribute = await calculate_beatmap_attributes(
beatmap.id, score.gamemode, mods_, redis, fetcher
)
attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
return attribute.star_rating >= 3
@@ -218,9 +212,7 @@ async def sognare(
mods_ = mod_to_save(score.mods)
if "HT" not in mods_:
return False
return (
beatmap.beatmapset.artist == "LeaF" and beatmap.beatmapset.title == "Evanescent"
)
return beatmap.beatmapset.artist == "LeaF" and beatmap.beatmapset.title == "Evanescent"
async def realtor_extraordinaire(
@@ -234,10 +226,7 @@ async def realtor_extraordinaire(
mods_ = mod_to_save(score.mods)
if not ("DT" in mods_ or "NC" in mods_) or "HR" not in mods_:
return False
return (
beatmap.beatmapset.artist == "cYsmix"
and beatmap.beatmapset.title == "House With Legs"
)
return beatmap.beatmapset.artist == "cYsmix" and beatmap.beatmapset.title == "House With Legs"
async def impeccable(
@@ -255,9 +244,7 @@ async def impeccable(
fetcher = await get_fetcher()
redis = get_redis()
mods_ = score.mods.copy()
attribute = await calculate_beatmap_attributes(
beatmap.id, score.gamemode, mods_, redis, fetcher
)
attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
return attribute.star_rating >= 4
@@ -274,18 +261,14 @@ async def aeon(
mods_ = mod_to_save(score.mods)
if "FL" not in mods_ or "HD" not in mods_ or "HT" not in mods_:
return False
if not beatmap.beatmapset.ranked_date or beatmap.beatmapset.ranked_date > datetime(
2012, 1, 1
):
if not beatmap.beatmapset.ranked_date or beatmap.beatmapset.ranked_date > datetime(2012, 1, 1):
return False
if beatmap.total_length < 180:
return False
fetcher = await get_fetcher()
redis = get_redis()
mods_ = score.mods.copy()
attribute = await calculate_beatmap_attributes(
beatmap.id, score.gamemode, mods_, redis, fetcher
)
attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
return attribute.star_rating >= 4
@@ -297,10 +280,7 @@ async def quick_maths(
# Get exactly 34 misses on any difficulty of Function Phantom - Variable.
if score.nmiss != 34:
return False
return (
beatmap.beatmapset.artist == "Function Phantom"
and beatmap.beatmapset.title == "Variable"
)
return beatmap.beatmapset.artist == "Function Phantom" and beatmap.beatmapset.title == "Variable"
async def kaleidoscope(
@@ -328,8 +308,7 @@ async def valediction(
return (
score.passed
and beatmap.beatmapset.artist == "a_hisa"
and beatmap.beatmapset.title
== "Alexithymia | Lupinus | Tokei no Heya to Seishin Sekai"
and beatmap.beatmapset.title == "Alexithymia | Lupinus | Tokei no Heya to Seishin Sekai"
and score.accuracy >= 0.9
)
@@ -342,9 +321,7 @@ async def right_on_time(
# Submit a score on Kola Kid - timer on the first minute of any hour
if not score.passed:
return False
if not (
beatmap.beatmapset.artist == "Kola Kid" and beatmap.beatmapset.title == "timer"
):
if not (beatmap.beatmapset.artist == "Kola Kid" and beatmap.beatmapset.title == "timer"):
return False
return score.ended_at.minute == 0
@@ -361,9 +338,7 @@ async def not_again(
return False
if score.accuracy < 0.99:
return False
return (
beatmap.beatmapset.artist == "ARForest" and beatmap.beatmapset.title == "Regret"
)
return beatmap.beatmapset.artist == "ARForest" and beatmap.beatmapset.title == "Regret"
async def deliberation(
@@ -377,18 +352,13 @@ async def deliberation(
mods_ = mod_to_save(score.mods)
if "HT" not in mods_:
return False
if (
not beatmap.beatmap_status.has_pp()
and beatmap.beatmap_status != BeatmapRankStatus.LOVED
):
if not beatmap.beatmap_status.has_pp() and beatmap.beatmap_status != BeatmapRankStatus.LOVED:
return False
fetcher = await get_fetcher()
redis = get_redis()
mods_copy = score.mods.copy()
attribute = await calculate_beatmap_attributes(
beatmap.id, score.gamemode, mods_copy, redis, fetcher
)
attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_copy, redis, fetcher)
return attribute.star_rating >= 6

View File

@@ -72,7 +72,7 @@ MEDALS: Medals = {
Achievement(
id=93,
name="Sweet Rave Party",
desc="Founded in the fine tradition of changing things that were just fine as they were.", # noqa: E501
desc="Founded in the fine tradition of changing things that were just fine as they were.",
assets_id="all-intro-nightcore",
): partial(process_mod, "NC"),
Achievement(

View File

@@ -16,11 +16,7 @@ async def process_combo(
score: Score,
beatmap: Beatmap,
) -> bool:
if (
not score.passed
or not beatmap.beatmap_status.has_pp()
or score.gamemode != GameMode.OSU
):
if not score.passed or not beatmap.beatmap_status.has_pp() or score.gamemode != GameMode.OSU:
return False
if combo < 1:
return False

View File

@@ -44,9 +44,7 @@ async def process_skill(
redis = get_redis()
mods_ = score.mods.copy()
mods_.sort(key=lambda x: x["acronym"])
attribute = await calculate_beatmap_attributes(
beatmap.id, score.gamemode, mods_, redis, fetcher
)
attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
if attribute.star_rating < star or attribute.star_rating >= star + 1:
return False
if type == "fc" and not score.is_perfect_combo:

View File

@@ -43,9 +43,7 @@ def validate_username(username: str) -> list[str]:
# 检查用户名格式(只允许字母、数字、下划线、连字符)
if not re.match(r"^[a-zA-Z0-9_-]+$", username):
errors.append(
"Username can only contain letters, numbers, underscores, and hyphens"
)
errors.append("Username can only contain letters, numbers, underscores, and hyphens")
# 检查是否以数字开头
if username[0].isdigit():
@@ -104,9 +102,7 @@ def get_password_hash(password: str) -> str:
return pw_bcrypt.decode()
async def authenticate_user_legacy(
db: AsyncSession, name: str, password: str
) -> User | None:
async def authenticate_user_legacy(db: AsyncSession, name: str, password: str) -> User | None:
"""
验证用户身份 - 使用类似 from_login 的逻辑
"""
@@ -145,9 +141,7 @@ async def authenticate_user_legacy(
return None
async def authenticate_user(
db: AsyncSession, username: str, password: str
) -> User | None:
async def authenticate_user(db: AsyncSession, username: str, password: str) -> User | None:
"""验证用户身份"""
return await authenticate_user_legacy(db, username, password)
@@ -158,14 +152,10 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None) -> s
if expires_delta:
expire = datetime.now(UTC) + expires_delta
else:
expire = datetime.now(UTC) + timedelta(
minutes=settings.access_token_expire_minutes
)
expire = datetime.now(UTC) + timedelta(minutes=settings.access_token_expire_minutes)
to_encode.update({"exp": expire, "random": secrets.token_hex(16)})
encoded_jwt = jwt.encode(
to_encode, settings.secret_key, algorithm=settings.algorithm
)
encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm)
return encoded_jwt
@@ -178,20 +168,20 @@ def generate_refresh_token() -> str:
async def invalidate_user_tokens(db: AsyncSession, user_id: int) -> int:
"""使指定用户的所有令牌失效
返回删除的令牌数量
"""
# 使用 select 先获取所有令牌
stmt = select(OAuthToken).where(OAuthToken.user_id == user_id)
result = await db.exec(stmt)
tokens = result.all()
# 逐个删除令牌
count = 0
for token in tokens:
await db.delete(token)
count += 1
# 提交更改
await db.commit()
return count
@@ -200,9 +190,7 @@ async def invalidate_user_tokens(db: AsyncSession, user_id: int) -> int:
def verify_token(token: str) -> dict | None:
"""验证访问令牌"""
try:
payload = jwt.decode(
token, settings.secret_key, algorithms=[settings.algorithm]
)
payload = jwt.decode(token, settings.secret_key, algorithms=[settings.algorithm])
return payload
except JWTError:
return None
@@ -221,17 +209,13 @@ async def store_token(
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
# 删除用户的旧令牌
statement = select(OAuthToken).where(
OAuthToken.user_id == user_id, OAuthToken.client_id == client_id
)
statement = select(OAuthToken).where(OAuthToken.user_id == user_id, OAuthToken.client_id == client_id)
old_tokens = (await db.exec(statement)).all()
for token in old_tokens:
await db.delete(token)
# 检查是否有重复的 access_token
duplicate_token = (
await db.exec(select(OAuthToken).where(OAuthToken.access_token == access_token))
).first()
duplicate_token = (await db.exec(select(OAuthToken).where(OAuthToken.access_token == access_token))).first()
if duplicate_token:
await db.delete(duplicate_token)
@@ -250,9 +234,7 @@ async def store_token(
return token_record
async def get_token_by_access_token(
db: AsyncSession, access_token: str
) -> OAuthToken | None:
async def get_token_by_access_token(db: AsyncSession, access_token: str) -> OAuthToken | None:
"""根据访问令牌获取令牌记录"""
statement = select(OAuthToken).where(
OAuthToken.access_token == access_token,
@@ -261,9 +243,7 @@ async def get_token_by_access_token(
return (await db.exec(statement)).first()
async def get_token_by_refresh_token(
db: AsyncSession, refresh_token: str
) -> OAuthToken | None:
async def get_token_by_refresh_token(db: AsyncSession, refresh_token: str) -> OAuthToken | None:
"""根据刷新令牌获取令牌记录"""
statement = select(OAuthToken).where(
OAuthToken.refresh_token == refresh_token,

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
from copy import deepcopy
from enum import Enum
import math
@@ -67,11 +68,7 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
if settings.suspicious_score_check:
beatmap_banned = (
await session.exec(
select(exists()).where(
col(BannedBeatmaps.beatmap_id) == score.beatmap_id
)
)
await session.exec(select(exists()).where(col(BannedBeatmaps.beatmap_id) == score.beatmap_id))
).first()
if beatmap_banned:
return 0
@@ -82,12 +79,9 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
logger.warning(f"Beatmap {score.beatmap_id} is suspicious, banned")
return 0
except Exception:
logger.exception(
f"Error checking if beatmap {score.beatmap_id} is suspicious"
)
logger.exception(f"Error checking if beatmap {score.beatmap_id} is suspicious")
# 使用线程池执行计算密集型操作以避免阻塞事件循环
import asyncio
loop = asyncio.get_event_loop()
@@ -118,9 +112,7 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
pp = attrs.pp
# mrekk bp1: 2048pp; ppy-sb top1 rxbp1: 2198pp
if settings.suspicious_score_check and (
(attrs.difficulty.stars > 25 and score.accuracy < 0.8) or pp > 2300
):
if settings.suspicious_score_check and ((attrs.difficulty.stars > 25 and score.accuracy < 0.8) or pp > 2300):
logger.warning(
f"User {score.user_id} played {score.beatmap_id} "
f"(star={attrs.difficulty.stars}) with {pp=} "
@@ -131,9 +123,7 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
return pp
async def pre_fetch_and_calculate_pp(
score: "Score", beatmap_id: int, session: AsyncSession, redis, fetcher
) -> float:
async def pre_fetch_and_calculate_pp(score: "Score", beatmap_id: int, session: AsyncSession, redis, fetcher) -> float:
"""
优化版PP计算预先获取beatmap文件并使用缓存
"""
@@ -144,9 +134,7 @@ async def pre_fetch_and_calculate_pp(
# 快速检查是否被封禁
if settings.suspicious_score_check:
beatmap_banned = (
await session.exec(
select(exists()).where(col(BannedBeatmaps.beatmap_id) == beatmap_id)
)
await session.exec(select(exists()).where(col(BannedBeatmaps.beatmap_id) == beatmap_id))
).first()
if beatmap_banned:
return 0
@@ -202,9 +190,7 @@ async def batch_calculate_pp(
banned_beatmaps = set()
if settings.suspicious_score_check:
banned_results = await session.exec(
select(BannedBeatmaps.beatmap_id).where(
col(BannedBeatmaps.beatmap_id).in_(unique_beatmap_ids)
)
select(BannedBeatmaps.beatmap_id).where(col(BannedBeatmaps.beatmap_id).in_(unique_beatmap_ids))
)
banned_beatmaps = set(banned_results.all())
@@ -380,9 +366,7 @@ def calculate_score_to_level(total_score: int) -> float:
level = 0.0
while remaining_score > 0:
next_level_requirement = to_next_level[
min(len(to_next_level) - 1, round(level))
]
next_level_requirement = to_next_level[min(len(to_next_level) - 1, round(level))]
level += min(1, remaining_score / next_level_requirement)
remaining_score -= next_level_requirement
@@ -417,9 +401,7 @@ class Threshold(int, Enum):
NOTE_POSX_THRESHOLD = 512 # x: [-512,512]
NOTE_POSY_THRESHOLD = 384 # y: [-384,384]
POS_ERROR_THRESHOLD = (
1280 * 50
) # 超过这么多个物件(包括滑条控制点)的位置有问题就毙掉
POS_ERROR_THRESHOLD = 1280 * 50 # 超过这么多个物件(包括滑条控制点)的位置有问题就毙掉
SLIDER_REPEAT_THRESHOLD = 5000
@@ -469,10 +451,7 @@ def is_2b(hit_objects: list[HitObject]) -> bool:
def is_suspicious_beatmap(content: str) -> bool:
osufile = OsuFile(content=content.encode("utf-8")).parse_file()
if (
osufile.hit_objects[-1].start_time - osufile.hit_objects[0].start_time
> 24 * 60 * 60 * 1000
):
if osufile.hit_objects[-1].start_time - osufile.hit_objects[0].start_time > 24 * 60 * 60 * 1000:
return True
if osufile.mode == int(GameMode.TAIKO):
if len(osufile.hit_objects) > Threshold.TAIKO_THRESHOLD:

View File

@@ -124,14 +124,10 @@ class Settings(BaseSettings):
smtp_password: str = ""
from_email: str = "noreply@example.com"
from_name: str = "osu! server"
# 邮件验证功能开关
enable_email_verification: bool = Field(
default=True, description="是否启用邮件验证功能"
)
enable_email_sending: bool = Field(
default=False, description="是否真实发送邮件False时仅模拟发送"
)
enable_email_verification: bool = Field(default=True, description="是否启用邮件验证功能")
enable_email_sending: bool = Field(default=False, description="是否真实发送邮件False时仅模拟发送")
# Sentry 配置
sentry_dsn: HttpUrl | None = None
@@ -143,12 +139,8 @@ class Settings(BaseSettings):
geoip_update_hour: int = 2 # 每周更新的小时数0-23
# 游戏设置
enable_rx: bool = Field(
default=False, validation_alias=AliasChoices("enable_rx", "enable_osu_rx")
)
enable_ap: bool = Field(
default=False, validation_alias=AliasChoices("enable_ap", "enable_osu_ap")
)
enable_rx: bool = Field(default=False, validation_alias=AliasChoices("enable_rx", "enable_osu_rx"))
enable_ap: bool = Field(default=False, validation_alias=AliasChoices("enable_ap", "enable_osu_ap"))
enable_all_mods_pp: bool = False
enable_supporter_for_all_users: bool = False
enable_all_beatmap_leaderboard: bool = False
@@ -189,9 +181,7 @@ class Settings(BaseSettings):
# 存储设置
storage_service: StorageServiceType = StorageServiceType.LOCAL
storage_settings: (
LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings
) = LocalStorageSettings()
storage_settings: LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings = LocalStorageSettings()
@field_validator("fetcher_scopes", mode="before")
def validate_fetcher_scopes(cls, v: Any) -> list[str]:
@@ -207,22 +197,13 @@ class Settings(BaseSettings):
) -> LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings:
if info.data.get("storage_service") == StorageServiceType.CLOUDFLARE_R2:
if not isinstance(v, CloudflareR2Settings):
raise ValueError(
"When storage_service is 'r2', "
"storage_settings must be CloudflareR2Settings"
)
raise ValueError("When storage_service is 'r2', storage_settings must be CloudflareR2Settings")
elif info.data.get("storage_service") == StorageServiceType.LOCAL:
if not isinstance(v, LocalStorageSettings):
raise ValueError(
"When storage_service is 'local', "
"storage_settings must be LocalStorageSettings"
)
raise ValueError("When storage_service is 'local', storage_settings must be LocalStorageSettings")
elif info.data.get("storage_service") == StorageServiceType.AWS_S3:
if not isinstance(v, AWSS3StorageSettings):
raise ValueError(
"When storage_service is 's3', "
"storage_settings must be AWSS3StorageSettings"
)
raise ValueError("When storage_service is 's3', storage_settings must be AWSS3StorageSettings")
return v

View File

@@ -28,18 +28,14 @@ if TYPE_CHECKING:
class UserAchievementBase(SQLModel, UTCBaseModel):
achievement_id: int
achieved_at: datetime = Field(
default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True))
)
achieved_at: datetime = Field(default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True)))
class UserAchievement(UserAchievementBase, table=True):
__tablename__ = "lazer_user_achievements" # pyright: ignore[reportAssignmentType]
__tablename__: str = "lazer_user_achievements"
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id")), exclude=True
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")), exclude=True)
user: "User" = Relationship(back_populates="achievement")
@@ -56,11 +52,7 @@ async def process_achievements(session: AsyncSession, redis: Redis, score_id: in
if not score:
return
achieved = (
await session.exec(
select(UserAchievement.achievement_id).where(
UserAchievement.user_id == score.user_id
)
)
await session.exec(select(UserAchievement.achievement_id).where(UserAchievement.user_id == score.user_id))
).all()
not_achieved = {k: v for k, v in MEDALS.items() if k.id not in achieved}
result: list[Achievement] = []
@@ -78,9 +70,7 @@ async def process_achievements(session: AsyncSession, redis: Redis, score_id: in
)
await redis.publish(
"chat:notification",
UserAchievementUnlock.init(
r, score.user_id, score.gamemode
).model_dump_json(),
UserAchievementUnlock.init(r, score.user_id, score.gamemode).model_dump_json(),
)
event = Event(
created_at=now,

View File

@@ -20,42 +20,34 @@ if TYPE_CHECKING:
class OAuthToken(UTCBaseModel, SQLModel, table=True):
__tablename__ = "oauth_tokens" # pyright: ignore[reportAssignmentType]
__tablename__: str = "oauth_tokens"
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
client_id: int = Field(index=True)
access_token: str = Field(max_length=500, unique=True)
refresh_token: str = Field(max_length=500, unique=True)
token_type: str = Field(default="Bearer", max_length=20)
scope: str = Field(default="*", max_length=100)
expires_at: datetime = Field(sa_column=Column(DateTime))
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
created_at: datetime = Field(default_factory=datetime.utcnow, sa_column=Column(DateTime))
user: "User" = Relationship()
class OAuthClient(SQLModel, table=True):
__tablename__ = "oauth_clients" # pyright: ignore[reportAssignmentType]
__tablename__: str = "oauth_clients"
name: str = Field(max_length=100, index=True)
description: str = Field(sa_column=Column(Text), default="")
client_id: int | None = Field(default=None, primary_key=True, index=True)
client_secret: str = Field(default_factory=secrets.token_hex, index=True)
redirect_uris: list[str] = Field(default_factory=list, sa_column=Column(JSON))
owner_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
class V1APIKeys(SQLModel, table=True):
__tablename__ = "v1_api_keys" # pyright: ignore[reportAssignmentType]
__tablename__: str = "v1_api_keys"
id: int | None = Field(default=None, primary_key=True)
name: str = Field(max_length=100, index=True)
key: str = Field(default_factory=secrets.token_hex, index=True)
owner_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))

View File

@@ -60,17 +60,13 @@ class BeatmapBase(SQLModel):
class Beatmap(BeatmapBase, table=True):
__tablename__ = "beatmaps" # pyright: ignore[reportAssignmentType]
__tablename__: str = "beatmaps"
id: int = Field(primary_key=True, index=True)
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
beatmap_status: BeatmapRankStatus = Field(index=True)
# optional
beatmapset: Beatmapset = Relationship(
back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"}
)
failtimes: FailTime | None = Relationship(
back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"}
)
beatmapset: Beatmapset = Relationship(back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"})
failtimes: FailTime | None = Relationship(back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"})
@classmethod
async def from_resp(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
@@ -84,21 +80,15 @@ class Beatmap(BeatmapBase, table=True):
"beatmap_status": BeatmapRankStatus(resp.ranked),
}
)
if not (
await session.exec(select(exists()).where(Beatmap.id == resp.id))
).first():
if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
session.add(beatmap)
await session.commit()
beatmap = (
await session.exec(select(Beatmap).where(Beatmap.id == resp.id))
).first()
beatmap = (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).first()
assert beatmap is not None, "Beatmap should not be None after commit"
return beatmap
@classmethod
async def from_resp_batch(
cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0
) -> list["Beatmap"]:
async def from_resp_batch(cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0) -> list["Beatmap"]:
beatmaps = []
for resp in inp:
if resp.id == from_:
@@ -113,9 +103,7 @@ class Beatmap(BeatmapBase, table=True):
"beatmap_status": BeatmapRankStatus(resp.ranked),
}
)
if not (
await session.exec(select(exists()).where(Beatmap.id == resp.id))
).first():
if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
session.add(beatmap)
beatmaps.append(beatmap)
await session.commit()
@@ -130,17 +118,11 @@ class Beatmap(BeatmapBase, table=True):
md5: str | None = None,
) -> "Beatmap":
beatmap = (
await session.exec(
select(Beatmap).where(
Beatmap.id == bid if bid is not None else Beatmap.checksum == md5
)
)
await session.exec(select(Beatmap).where(Beatmap.id == bid if bid is not None else Beatmap.checksum == md5))
).first()
if not beatmap:
resp = await fetcher.get_beatmap(bid, md5)
r = await session.exec(
select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id)
)
r = await session.exec(select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id))
if not r.first():
set_resp = await fetcher.get_beatmapset(resp.beatmapset_id)
await Beatmapset.from_resp(session, set_resp, from_=resp.id)
@@ -178,10 +160,7 @@ class BeatmapResp(BeatmapBase):
if query_mode is not None and beatmap.mode != query_mode:
beatmap_["convert"] = True
beatmap_["is_scoreable"] = beatmap_status.has_leaderboard()
if (
settings.enable_all_beatmap_leaderboard
and not beatmap_status.has_leaderboard()
):
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
beatmap_["ranked"] = BeatmapRankStatus.APPROVED.value
beatmap_["status"] = BeatmapRankStatus.APPROVED.name.lower()
else:
@@ -189,9 +168,7 @@ class BeatmapResp(BeatmapBase):
beatmap_["ranked"] = beatmap_status.value
beatmap_["mode_int"] = int(beatmap.mode)
if not from_set:
beatmap_["beatmapset"] = await BeatmapsetResp.from_db(
beatmap.beatmapset, session=session, user=user
)
beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset, session=session, user=user)
if beatmap.failtimes is not None:
beatmap_["failtimes"] = FailTimeResp.from_db(beatmap.failtimes)
else:
@@ -218,7 +195,7 @@ class BeatmapResp(BeatmapBase):
class BannedBeatmaps(SQLModel, table=True):
__tablename__ = "banned_beatmaps" # pyright: ignore[reportAssignmentType]
__tablename__: str = "banned_beatmaps"
id: int | None = Field(primary_key=True, index=True, default=None)
beatmap_id: int = Field(index=True)
@@ -230,15 +207,10 @@ async def calculate_beatmap_attributes(
redis: Redis,
fetcher: "Fetcher",
):
key = (
f"beatmap:{beatmap_id}:{ruleset}:"
f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
)
key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
if await redis.exists(key):
return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
return BeatmapAttributes.model_validate_json(await redis.get(key))
resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
attr = await asyncio.get_event_loop().run_in_executor(
None, calculate_beatmap_attribute, resp, ruleset, mods_
)
attr = await asyncio.get_event_loop().run_in_executor(None, calculate_beatmap_attribute, resp, ruleset, mods_)
await redis.set(key, attr.model_dump_json())
return attr

View File

@@ -23,15 +23,13 @@ if TYPE_CHECKING:
class BeatmapPlaycounts(AsyncAttrs, SQLModel, table=True):
__tablename__ = "beatmap_playcounts" # pyright: ignore[reportAssignmentType]
__tablename__: str = "beatmap_playcounts"
id: int | None = Field(
default=None,
sa_column=Column(BigInteger, primary_key=True, autoincrement=True),
)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
playcount: int = Field(default=0)
@@ -59,9 +57,7 @@ class BeatmapPlaycountsResp(BaseModel):
)
async def process_beatmap_playcount(
session: AsyncSession, user_id: int, beatmap_id: int
) -> None:
async def process_beatmap_playcount(session: AsyncSession, user_id: int, beatmap_id: int) -> None:
existing_playcount = (
await session.exec(
select(BeatmapPlaycounts).where(
@@ -89,7 +85,5 @@ async def process_beatmap_playcount(
}
session.add(playcount_event)
else:
new_playcount = BeatmapPlaycounts(
user_id=user_id, beatmap_id=beatmap_id, playcount=1
)
new_playcount = BeatmapPlaycounts(user_id=user_id, beatmap_id=beatmap_id, playcount=1)
session.add(new_playcount)

View File

@@ -86,9 +86,7 @@ class BeatmapsetBase(SQLModel):
# optional
# converts: list[Beatmap] = Relationship(back_populates="beatmapset")
current_nominations: list[BeatmapNomination] | None = Field(
None, sa_column=Column(JSON)
)
current_nominations: list[BeatmapNomination] | None = Field(None, sa_column=Column(JSON))
description: BeatmapDescription | None = Field(default=None, sa_column=Column(JSON))
# TODO: discussions: list[BeatmapsetDiscussion] = None
# TODO: current_user_attributes: Optional[CurrentUserAttributes] = None
@@ -105,22 +103,18 @@ class BeatmapsetBase(SQLModel):
can_be_hyped: bool = Field(default=False, sa_column=Column(Boolean))
discussion_locked: bool = Field(default=False, sa_column=Column(Boolean))
last_updated: datetime = Field(sa_column=Column(DateTime, index=True))
ranked_date: datetime | None = Field(
default=None, sa_column=Column(DateTime, index=True)
)
ranked_date: datetime | None = Field(default=None, sa_column=Column(DateTime, index=True))
storyboard: bool = Field(default=False, sa_column=Column(Boolean, index=True))
submitted_date: datetime = Field(sa_column=Column(DateTime, index=True))
tags: str = Field(default="", sa_column=Column(Text))
class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
__tablename__ = "beatmapsets" # pyright: ignore[reportAssignmentType]
__tablename__: str = "beatmapsets"
id: int | None = Field(default=None, primary_key=True, index=True)
id: int = Field(default=None, primary_key=True, index=True)
# Beatmapset
beatmap_status: BeatmapRankStatus = Field(
default=BeatmapRankStatus.GRAVEYARD, index=True
)
beatmap_status: BeatmapRankStatus = Field(default=BeatmapRankStatus.GRAVEYARD, index=True)
# optional
beatmaps: list["Beatmap"] = Relationship(back_populates="beatmapset")
@@ -137,9 +131,7 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset")
@classmethod
async def from_resp(
cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0
) -> "Beatmapset":
async def from_resp(cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0) -> "Beatmapset":
from .beatmap import Beatmap
d = resp.model_dump()
@@ -167,18 +159,14 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
"download_disabled": resp.availability.download_disabled or False,
}
)
if not (
await session.exec(select(exists()).where(Beatmapset.id == resp.id))
).first():
if not (await session.exec(select(exists()).where(Beatmapset.id == resp.id))).first():
session.add(beatmapset)
await session.commit()
await Beatmap.from_resp_batch(session, resp.beatmaps, from_=from_)
return beatmapset
@classmethod
async def get_or_fetch(
cls, session: AsyncSession, fetcher: "Fetcher", sid: int
) -> "Beatmapset":
async def get_or_fetch(cls, session: AsyncSession, fetcher: "Fetcher", sid: int) -> "Beatmapset":
beatmapset = await session.get(Beatmapset, sid)
if not beatmapset:
resp = await fetcher.get_beatmapset(sid)
@@ -227,13 +215,9 @@ class BeatmapsetResp(BeatmapsetBase):
@model_validator(mode="after")
def fix_genre_language(self) -> Self:
if self.genre is None:
self.genre = BeatmapTranslationText(
name=Genre(self.genre_id).name, id=self.genre_id
)
self.genre = BeatmapTranslationText(name=Genre(self.genre_id).name, id=self.genre_id)
if self.language is None:
self.language = BeatmapTranslationText(
name=Language(self.language_id).name, id=self.language_id
)
self.language = BeatmapTranslationText(name=Language(self.language_id).name, id=self.language_id)
return self
@classmethod
@@ -252,9 +236,7 @@ class BeatmapsetResp(BeatmapsetBase):
await BeatmapResp.from_db(beatmap, from_set=True, session=session)
for beatmap in await beatmapset.awaitable_attrs.beatmaps
],
"hype": BeatmapHype(
current=beatmapset.hype_current, required=beatmapset.hype_required
),
"hype": BeatmapHype(current=beatmapset.hype_current, required=beatmapset.hype_required),
"availability": BeatmapAvailability(
more_information=beatmapset.availability_info,
download_disabled=beatmapset.download_disabled,
@@ -282,10 +264,7 @@ class BeatmapsetResp(BeatmapsetBase):
update["ratings"] = []
beatmap_status = beatmapset.beatmap_status
if (
settings.enable_all_beatmap_leaderboard
and not beatmap_status.has_leaderboard()
):
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
update["status"] = BeatmapRankStatus.APPROVED.name.lower()
update["ranked"] = BeatmapRankStatus.APPROVED.value
else:
@@ -295,9 +274,7 @@ class BeatmapsetResp(BeatmapsetBase):
if session and user:
existing_favourite = (
await session.exec(
select(FavouriteBeatmapset).where(
FavouriteBeatmapset.beatmapset_id == beatmapset.id
)
select(FavouriteBeatmapset).where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
)
).first()
update["has_favourited"] = existing_favourite is not None

View File

@@ -20,13 +20,9 @@ if TYPE_CHECKING:
class BestScore(SQLModel, table=True):
__tablename__ = "total_score_best_scores" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
score_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
)
__tablename__: str = "total_score_best_scores"
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
score_id: int = Field(sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True))
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
gamemode: GameMode = Field(index=True)
total_score: int = Field(default=0, sa_column=Column(BigInteger))

View File

@@ -51,30 +51,22 @@ class ChatChannelBase(SQLModel):
class ChatChannel(ChatChannelBase, table=True):
__tablename__ = "chat_channels" # pyright: ignore[reportAssignmentType]
channel_id: int | None = Field(primary_key=True, index=True, default=None)
__tablename__: str = "chat_channels"
channel_id: int = Field(primary_key=True, index=True, default=None)
@classmethod
async def get(
cls, channel: str | int, session: AsyncSession
) -> "ChatChannel | None":
async def get(cls, channel: str | int, session: AsyncSession) -> "ChatChannel | None":
if isinstance(channel, int) or channel.isdigit():
# 使用查询而不是 get() 来确保对象完全加载
result = await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
result = await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))
channel_ = result.first()
if channel_ is not None:
return channel_
result = await session.exec(
select(ChatChannel).where(ChatChannel.name == channel)
)
result = await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
return result.first()
@classmethod
async def get_pm_channel(
cls, user1: int, user2: int, session: AsyncSession
) -> "ChatChannel | None":
async def get_pm_channel(cls, user1: int, user2: int, session: AsyncSession) -> "ChatChannel | None":
channel = await cls.get(f"pm_{user1}_{user2}", session)
if channel is None:
channel = await cls.get(f"pm_{user2}_{user1}", session)
@@ -153,18 +145,13 @@ class ChatChannelResp(ChatChannelBase):
.limit(10)
)
).all()
c.recent_messages = [
await ChatMessageResp.from_db(msg, session, user) for msg in messages
]
c.recent_messages = [await ChatMessageResp.from_db(msg, session, user) for msg in messages]
c.recent_messages.reverse()
if c.type == ChannelType.PM and users and len(users) == 2:
target_user_id = next(u for u in users if u != user.id)
target_name = await session.exec(
select(User.username).where(User.id == target_user_id)
)
target_name = await session.exec(select(User.username).where(User.id == target_user_id))
c.name = target_name.one()
assert user.id
c.users = [target_user_id, user.id]
return c
@@ -181,19 +168,15 @@ class MessageType(str, Enum):
class ChatMessageBase(UTCBaseModel, SQLModel):
channel_id: int = Field(index=True, foreign_key="chat_channels.channel_id")
content: str = Field(sa_column=Column(VARCHAR(1000)))
message_id: int | None = Field(index=True, primary_key=True, default=None)
sender_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
timestamp: datetime = Field(
sa_column=Column(DateTime, index=True), default=datetime.now(UTC)
)
message_id: int = Field(index=True, primary_key=True, default=None)
sender_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
timestamp: datetime = Field(sa_column=Column(DateTime, index=True), default=datetime.now(UTC))
type: MessageType = Field(default=MessageType.PLAIN, index=True, exclude=True)
uuid: str | None = Field(default=None)
class ChatMessage(ChatMessageBase, table=True):
__tablename__ = "chat_messages" # pyright: ignore[reportAssignmentType]
__tablename__: str = "chat_messages"
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
channel: ChatChannel = Relationship()
@@ -211,9 +194,7 @@ class ChatMessageResp(ChatMessageBase):
if user:
m.sender = await UserResp.from_db(user, session, RANKING_INCLUDES)
else:
m.sender = await UserResp.from_db(
db_message.user, session, RANKING_INCLUDES
)
m.sender = await UserResp.from_db(db_message.user, session, RANKING_INCLUDES)
return m
@@ -221,17 +202,13 @@ class ChatMessageResp(ChatMessageBase):
class SilenceUser(UTCBaseModel, SQLModel, table=True):
__tablename__ = "chat_silence_users" # pyright: ignore[reportAssignmentType]
id: int | None = Field(primary_key=True, default=None, index=True)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
__tablename__: str = "chat_silence_users"
id: int = Field(primary_key=True, default=None, index=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
channel_id: int = Field(foreign_key="chat_channels.channel_id", index=True)
until: datetime | None = Field(sa_column=Column(DateTime, index=True), default=None)
reason: str | None = Field(default=None, sa_column=Column(VARCHAR(255), index=True))
banned_at: datetime = Field(
sa_column=Column(DateTime, index=True), default=datetime.now(UTC)
)
banned_at: datetime = Field(sa_column=Column(DateTime, index=True), default=datetime.now(UTC))
class UserSilenceResp(SQLModel):
@@ -240,7 +217,6 @@ class UserSilenceResp(SQLModel):
@classmethod
def from_db(cls, db_silence: SilenceUser) -> "UserSilenceResp":
assert db_silence.id is not None
return cls(
id=db_silence.id,
user_id=db_silence.user_id,

View File

@@ -21,28 +21,24 @@ class CountBase(SQLModel):
class MonthlyPlaycounts(CountBase, table=True):
__tablename__ = "monthly_playcounts" # pyright: ignore[reportAssignmentType]
__tablename__: str = "monthly_playcounts"
id: int | None = Field(
default=None,
sa_column=Column(BigInteger, primary_key=True, autoincrement=True),
)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
user: "User" = Relationship(back_populates="monthly_playcounts")
class ReplayWatchedCount(CountBase, table=True):
__tablename__ = "replays_watched_counts" # pyright: ignore[reportAssignmentType]
__tablename__: str = "replays_watched_counts"
id: int | None = Field(
default=None,
sa_column=Column(BigInteger, primary_key=True, autoincrement=True),
)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
user: "User" = Relationship(back_populates="replays_watched_counts")

View File

@@ -24,9 +24,7 @@ class DailyChallengeStatsBase(SQLModel, UTCBaseModel):
daily_streak_best: int = Field(default=0)
daily_streak_current: int = Field(default=0)
last_update: datetime | None = Field(default=None, sa_column=Column(DateTime))
last_weekly_streak: datetime | None = Field(
default=None, sa_column=Column(DateTime)
)
last_weekly_streak: datetime | None = Field(default=None, sa_column=Column(DateTime))
playcount: int = Field(default=0)
top_10p_placements: int = Field(default=0)
top_50p_placements: int = Field(default=0)
@@ -35,7 +33,7 @@ class DailyChallengeStatsBase(SQLModel, UTCBaseModel):
class DailyChallengeStats(DailyChallengeStatsBase, table=True):
__tablename__ = "daily_challenge_stats" # pyright: ignore[reportAssignmentType]
__tablename__: str = "daily_challenge_stats"
user_id: int | None = Field(
default=None,
@@ -61,9 +59,7 @@ class DailyChallengeStatsResp(DailyChallengeStatsBase):
return cls.model_validate(obj)
async def process_daily_challenge_score(
session: AsyncSession, user_id: int, room_id: int
):
async def process_daily_challenge_score(session: AsyncSession, user_id: int, room_id: int):
from .playlist_best_score import PlaylistBestScore
score = (

View File

@@ -4,16 +4,17 @@
from __future__ import annotations
from datetime import datetime, UTC
from sqlmodel import SQLModel, Field
from sqlalchemy import Column, BigInteger, ForeignKey
from datetime import UTC, datetime
from sqlalchemy import BigInteger, Column, ForeignKey
from sqlmodel import Field, SQLModel
class EmailVerification(SQLModel, table=True):
"""邮件验证记录"""
__tablename__: str = "email_verifications"
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True))
email: str = Field(index=True)
@@ -28,9 +29,9 @@ class EmailVerification(SQLModel, table=True):
class LoginSession(SQLModel, table=True):
"""登录会话记录"""
__tablename__: str = "login_sessions"
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True))
session_token: str = Field(unique=True, index=True) # 会话令牌

View File

@@ -36,17 +36,13 @@ class EventType(str, Enum):
class EventBase(SQLModel):
id: int = Field(default=None, primary_key=True)
created_at: datetime = Field(
sa_column=Column(DateTime(timezone=True), default=datetime.now(UTC))
)
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), default=datetime.now(UTC)))
type: EventType
event_payload: dict = Field(
exclude=True, default_factory=dict, sa_column=Column(JSON)
)
event_payload: dict = Field(exclude=True, default_factory=dict, sa_column=Column(JSON))
class Event(EventBase, table=True):
__tablename__ = "user_events" # pyright: ignore[reportAssignmentType]
__tablename__: str = "user_events"
user_id: int | None = Field(
default=None,
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True),

View File

@@ -16,8 +16,8 @@ FAILTIME_STRUCT = Struct("<100i")
class FailTime(SQLModel, table=True):
__tablename__ = "failtime" # pyright: ignore[reportAssignmentType]
beatmap_id: int = Field(primary_key=True, index=True, foreign_key="beatmaps.id")
__tablename__: str = "failtime"
beatmap_id: int = Field(primary_key=True, foreign_key="beatmaps.id")
exit: bytes = Field(sa_column=Column(VARBINARY(400), nullable=False))
fail: bytes = Field(sa_column=Column(VARBINARY(400), nullable=False))
@@ -41,12 +41,8 @@ class FailTime(SQLModel, table=True):
class FailTimeResp(BaseModel):
exit: list[int] = Field(
default_factory=lambda: list(FAILTIME_STRUCT.unpack(b"\x00" * 400))
)
fail: list[int] = Field(
default_factory=lambda: list(FAILTIME_STRUCT.unpack(b"\x00" * 400))
)
exit: list[int] = Field(default_factory=lambda: list(FAILTIME_STRUCT.unpack(b"\x00" * 400)))
fail: list[int] = Field(default_factory=lambda: list(FAILTIME_STRUCT.unpack(b"\x00" * 400)))
@classmethod
def from_db(cls, failtime: FailTime) -> "FailTimeResp":

View File

@@ -16,7 +16,7 @@ from sqlmodel import (
class FavouriteBeatmapset(AsyncAttrs, SQLModel, table=True):
__tablename__ = "favourite_beatmapset" # pyright: ignore[reportAssignmentType]
__tablename__: str = "favourite_beatmapset"
id: int | None = Field(
default=None,
sa_column=Column(BigInteger, autoincrement=True, primary_key=True),

View File

@@ -75,9 +75,7 @@ class UserBase(UTCBaseModel, SQLModel):
is_active: bool = True
is_bot: bool = False
is_supporter: bool = False
last_visit: datetime | None = Field(
default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True))
)
last_visit: datetime | None = Field(default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True)))
pm_friends_only: bool = False
profile_colour: str | None = None
username: str = Field(max_length=32, unique=True, index=True)
@@ -90,9 +88,7 @@ class UserBase(UTCBaseModel, SQLModel):
is_restricted: bool = False
# blocks
cover: UserProfileCover = Field(
default=UserProfileCover(
url="https://assets.ppy.sh/user-profile-covers/default.jpeg"
),
default=UserProfileCover(url="https://assets.ppy.sh/user-profile-covers/default.jpeg"),
sa_column=Column(JSON),
)
beatmap_playcounts_count: int = 0
@@ -150,9 +146,9 @@ class UserBase(UTCBaseModel, SQLModel):
class User(AsyncAttrs, UserBase, table=True):
__tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType]
__tablename__: str = "lazer_users"
id: int | None = Field(
id: int = Field(
default=None,
sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True),
)
@@ -160,16 +156,10 @@ class User(AsyncAttrs, UserBase, table=True):
statistics: list[UserStatistics] = Relationship()
achievement: list[UserAchievement] = Relationship(back_populates="user")
team_membership: TeamMember | None = Relationship(back_populates="user")
daily_challenge_stats: DailyChallengeStats | None = Relationship(
back_populates="user"
)
daily_challenge_stats: DailyChallengeStats | None = Relationship(back_populates="user")
monthly_playcounts: list[MonthlyPlaycounts] = Relationship(back_populates="user")
replays_watched_counts: list[ReplayWatchedCount] = Relationship(
back_populates="user"
)
favourite_beatmapsets: list["FavouriteBeatmapset"] = Relationship(
back_populates="user"
)
replays_watched_counts: list[ReplayWatchedCount] = Relationship(back_populates="user")
favourite_beatmapsets: list["FavouriteBeatmapset"] = Relationship(back_populates="user")
rank_history: list[RankHistory] = Relationship(
back_populates="user",
)
@@ -178,16 +168,10 @@ class User(AsyncAttrs, UserBase, table=True):
email: str = Field(max_length=254, unique=True, index=True, exclude=True)
priv: int = Field(default=1, exclude=True)
pw_bcrypt: str = Field(max_length=60, exclude=True)
silence_end_at: datetime | None = Field(
default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
)
donor_end_at: datetime | None = Field(
default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
)
silence_end_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True)), exclude=True)
donor_end_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True)), exclude=True)
async def is_user_can_pm(
self, from_user: "User", session: AsyncSession
) -> tuple[bool, str]:
async def is_user_can_pm(self, from_user: "User", session: AsyncSession) -> tuple[bool, str]:
from .relationship import Relationship, RelationshipType
from_relationship = (
@@ -200,13 +184,10 @@ class User(AsyncAttrs, UserBase, table=True):
).first()
if from_relationship and from_relationship.type == RelationshipType.BLOCK:
return False, "You have blocked the target user."
if from_user.pm_friends_only and (
not from_relationship or from_relationship.type != RelationshipType.FOLLOW
):
if from_user.pm_friends_only and (not from_relationship or from_relationship.type != RelationshipType.FOLLOW):
return (
False,
"You have disabled non-friend communications "
"and target user is not your friend.",
"You have disabled non-friend communications and target user is not your friend.",
)
relationship = (
@@ -219,9 +200,7 @@ class User(AsyncAttrs, UserBase, table=True):
).first()
if relationship and relationship.type == RelationshipType.BLOCK:
return False, "Target user has blocked you."
if self.pm_friends_only and (
not relationship or relationship.type != RelationshipType.FOLLOW
):
if self.pm_friends_only and (not relationship or relationship.type != RelationshipType.FOLLOW):
return False, "Target user has disabled non-friend communications"
return True, ""
@@ -288,9 +267,7 @@ class UserResp(UserBase):
u = cls.model_validate(obj.model_dump())
u.id = obj.id
u.default_group = "bot" if u.is_bot else "default"
u.country = Country(
code=obj.country_code, name=COUNTRIES.get(obj.country_code, "Unknown")
)
u.country = Country(code=obj.country_code, name=COUNTRIES.get(obj.country_code, "Unknown"))
u.follower_count = (
await session.exec(
select(func.count())
@@ -314,9 +291,7 @@ class UserResp(UserBase):
redis = get_redis()
u.is_online = await redis.exists(f"metadata:online:{obj.id}")
u.cover_url = (
obj.cover.get(
"url", "https://assets.ppy.sh/user-profile-covers/default.jpeg"
)
obj.cover.get("url", "https://assets.ppy.sh/user-profile-covers/default.jpeg")
if obj.cover
else "https://assets.ppy.sh/user-profile-covers/default.jpeg"
)
@@ -335,22 +310,15 @@ class UserResp(UserBase):
]
if "team" in include:
if await obj.awaitable_attrs.team_membership:
assert obj.team_membership
u.team = obj.team_membership.team
if team_membership := await obj.awaitable_attrs.team_membership:
u.team = team_membership.team
if "account_history" in include:
u.account_history = [
UserAccountHistoryResp.from_db(ah)
for ah in await obj.awaitable_attrs.account_history
]
u.account_history = [UserAccountHistoryResp.from_db(ah) for ah in await obj.awaitable_attrs.account_history]
if "daily_challenge_user_stats":
if await obj.awaitable_attrs.daily_challenge_stats:
assert obj.daily_challenge_stats
u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(
obj.daily_challenge_stats
)
if daily_challenge_stats := await obj.awaitable_attrs.daily_challenge_stats:
u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(daily_challenge_stats)
if "statistics" in include:
current_stattistics = None
@@ -359,59 +327,40 @@ class UserResp(UserBase):
current_stattistics = i
break
u.statistics = (
await UserStatisticsResp.from_db(
current_stattistics, session, obj.country_code
)
await UserStatisticsResp.from_db(current_stattistics, session, obj.country_code)
if current_stattistics
else None
)
if "statistics_rulesets" in include:
u.statistics_rulesets = {
i.mode.value: await UserStatisticsResp.from_db(
i, session, obj.country_code
)
i.mode.value: await UserStatisticsResp.from_db(i, session, obj.country_code)
for i in await obj.awaitable_attrs.statistics
}
if "monthly_playcounts" in include:
u.monthly_playcounts = [
CountResp.from_db(pc)
for pc in await obj.awaitable_attrs.monthly_playcounts
]
u.monthly_playcounts = [CountResp.from_db(pc) for pc in await obj.awaitable_attrs.monthly_playcounts]
if len(u.monthly_playcounts) == 1:
d = u.monthly_playcounts[0].start_date
u.monthly_playcounts.insert(
0, CountResp(start_date=d - timedelta(days=20), count=0)
)
u.monthly_playcounts.insert(0, CountResp(start_date=d - timedelta(days=20), count=0))
if "replays_watched_counts" in include:
u.replay_watched_counts = [
CountResp.from_db(rwc)
for rwc in await obj.awaitable_attrs.replays_watched_counts
CountResp.from_db(rwc) for rwc in await obj.awaitable_attrs.replays_watched_counts
]
if len(u.replay_watched_counts) == 1:
d = u.replay_watched_counts[0].start_date
u.replay_watched_counts.insert(
0, CountResp(start_date=d - timedelta(days=20), count=0)
)
u.replay_watched_counts.insert(0, CountResp(start_date=d - timedelta(days=20), count=0))
if "achievements" in include:
u.user_achievements = [
UserAchievementResp.from_db(ua)
for ua in await obj.awaitable_attrs.achievement
]
u.user_achievements = [UserAchievementResp.from_db(ua) for ua in await obj.awaitable_attrs.achievement]
if "rank_history" in include:
rank_history = await RankHistoryResp.from_db(session, obj.id, ruleset)
if len(rank_history.data) != 0:
u.rank_history = rank_history
rank_top = (
await session.exec(
select(RankTop).where(
RankTop.user_id == obj.id, RankTop.mode == ruleset
)
)
await session.exec(select(RankTop).where(RankTop.user_id == obj.id, RankTop.mode == ruleset))
).first()
if rank_top:
u.rank_highest = (
@@ -425,9 +374,7 @@ class UserResp(UserBase):
u.favourite_beatmapset_count = (
await session.exec(
select(func.count())
.select_from(FavouriteBeatmapset)
.where(FavouriteBeatmapset.user_id == obj.id)
select(func.count()).select_from(FavouriteBeatmapset).where(FavouriteBeatmapset.user_id == obj.id)
)
).one()
u.scores_pinned_count = (
@@ -478,17 +425,19 @@ class UserResp(UserBase):
# 检查会话验证状态
# 如果邮件验证功能被禁用,则始终设置 session_verified 为 true
from app.config import settings
if not settings.enable_email_verification:
u.session_verified = True
else:
# 如果用户有未验证的登录会话,则设置 session_verified 为 false
from .email_verification import LoginSession
unverified_session = (
await session.exec(
select(LoginSession).where(
LoginSession.user_id == obj.id,
LoginSession.is_verified == False,
LoginSession.expires_at > datetime.now(UTC)
col(LoginSession.is_verified).is_(False),
LoginSession.expires_at > datetime.now(UTC),
)
)
).first()

View File

@@ -30,8 +30,8 @@ class MultiplayerEventBase(SQLModel, UTCBaseModel):
class MultiplayerEvent(MultiplayerEventBase, table=True):
__tablename__ = "multiplayer_events" # pyright: ignore[reportAssignmentType]
id: int | None = Field(
__tablename__: str = "multiplayer_events"
id: int = Field(
default=None,
sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True),
)

View File

@@ -17,7 +17,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
class Notification(SQLModel, table=True):
__tablename__ = "notifications" # pyright: ignore[reportAssignmentType]
__tablename__: str = "notifications"
id: int = Field(primary_key=True, index=True, default=None)
name: NotificationName = Field(index=True)
@@ -30,7 +30,7 @@ class Notification(SQLModel, table=True):
class UserNotification(SQLModel, table=True):
__tablename__ = "user_notifications" # pyright: ignore[reportAssignmentType]
__tablename__: str = "user_notifications"
id: int = Field(
sa_column=Column(
BigInteger,
@@ -40,9 +40,7 @@ class UserNotification(SQLModel, table=True):
default=None,
)
notification_id: int = Field(index=True, foreign_key="notifications.id")
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
is_read: bool = Field(index=True)
notification: Notification = Relationship(sa_relationship_kwargs={"lazy": "joined"})

View File

@@ -4,16 +4,17 @@
from __future__ import annotations
from datetime import datetime, UTC
from sqlmodel import SQLModel, Field
from sqlalchemy import Column, BigInteger, ForeignKey
from datetime import UTC, datetime
from sqlalchemy import BigInteger, Column, ForeignKey
from sqlmodel import Field, SQLModel
class PasswordReset(SQLModel, table=True):
"""密码重置记录"""
__tablename__: str = "password_resets"
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True))
email: str = Field(index=True)

View File

@@ -21,16 +21,14 @@ class ItemAttemptsCountBase(SQLModel):
room_id: int = Field(foreign_key="rooms.id", index=True)
attempts: int = Field(default=0)
completed: int = Field(default=0)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
accuracy: float = 0.0
pp: float = 0
total_score: int = 0
class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True):
__tablename__ = "item_attempts_count" # pyright: ignore[reportAssignmentType]
__tablename__: str = "item_attempts_count"
id: int | None = Field(default=None, primary_key=True)
user: User = Relationship()
@@ -63,9 +61,7 @@ class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True):
self.pp = sum(score.score.pp for score in playlist_scores)
self.completed = len([score for score in playlist_scores if score.score.passed])
self.accuracy = (
sum(score.score.accuracy for score in playlist_scores) / self.completed
if self.completed > 0
else 0.0
sum(score.score.accuracy for score in playlist_scores) / self.completed if self.completed > 0 else 0.0
)
await session.commit()
await session.refresh(self)

View File

@@ -21,14 +21,10 @@ if TYPE_CHECKING:
class PlaylistBestScore(SQLModel, table=True):
__tablename__ = "playlist_best_scores" # pyright: ignore[reportAssignmentType]
__tablename__: str = "playlist_best_scores"
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
score_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
score_id: int = Field(sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True))
room_id: int = Field(foreign_key="rooms.id", index=True)
playlist_id: int = Field(index=True)
total_score: int = Field(default=0, sa_column=Column(BigInteger))

View File

@@ -50,7 +50,7 @@ class PlaylistBase(SQLModel, UTCBaseModel):
class Playlist(PlaylistBase, table=True):
__tablename__ = "room_playlists" # pyright: ignore[reportAssignmentType]
__tablename__: str = "room_playlists"
db_id: int = Field(default=None, primary_key=True, index=True, exclude=True)
room_id: int = Field(foreign_key="rooms.id", exclude=True)
@@ -63,16 +63,12 @@ class Playlist(PlaylistBase, table=True):
@classmethod
async def get_next_id_for_room(cls, room_id: int, session: AsyncSession) -> int:
stmt = select(func.coalesce(func.max(cls.id), -1) + 1).where(
cls.room_id == room_id
)
stmt = select(func.coalesce(func.max(cls.id), -1) + 1).where(cls.room_id == room_id)
result = await session.exec(stmt)
return result.one()
@classmethod
async def from_hub(
cls, playlist: PlaylistItem, room_id: int, session: AsyncSession
) -> "Playlist":
async def from_hub(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession) -> "Playlist":
next_id = await cls.get_next_id_for_room(room_id, session=session)
return cls(
id=next_id,
@@ -90,9 +86,7 @@ class Playlist(PlaylistBase, table=True):
@classmethod
async def update(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession):
db_playlist = await session.exec(
select(cls).where(cls.id == playlist.id, cls.room_id == room_id)
)
db_playlist = await session.exec(select(cls).where(cls.id == playlist.id, cls.room_id == room_id))
db_playlist = db_playlist.first()
if db_playlist is None:
raise ValueError("Playlist item not found")
@@ -108,9 +102,7 @@ class Playlist(PlaylistBase, table=True):
await session.commit()
@classmethod
async def add_to_db(
cls, playlist: PlaylistItem, room_id: int, session: AsyncSession
):
async def add_to_db(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession):
db_playlist = await cls.from_hub(playlist, room_id, session)
session.add(db_playlist)
await session.commit()
@@ -119,9 +111,7 @@ class Playlist(PlaylistBase, table=True):
@classmethod
async def delete_item(cls, item_id: int, room_id: int, session: AsyncSession):
db_playlist = await session.exec(
select(cls).where(cls.id == item_id, cls.room_id == room_id)
)
db_playlist = await session.exec(select(cls).where(cls.id == item_id, cls.room_id == room_id))
db_playlist = db_playlist.first()
if db_playlist is None:
raise ValueError("Playlist item not found")
@@ -133,9 +123,7 @@ class PlaylistResp(PlaylistBase):
beatmap: BeatmapResp | None = None
@classmethod
async def from_db(
cls, playlist: Playlist, include: list[str] = []
) -> "PlaylistResp":
async def from_db(cls, playlist: Playlist, include: list[str] = []) -> "PlaylistResp":
data = playlist.model_dump()
if "beatmap" in include:
data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap)

View File

@@ -20,13 +20,9 @@ if TYPE_CHECKING:
class PPBestScore(SQLModel, table=True):
__tablename__ = "best_scores" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
score_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
)
__tablename__: str = "best_scores"
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
score_id: int = Field(sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True))
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
gamemode: GameMode = Field(index=True)
pp: float = Field(

View File

@@ -26,12 +26,10 @@ if TYPE_CHECKING:
class RankHistory(SQLModel, table=True):
__tablename__ = "rank_history" # pyright: ignore[reportAssignmentType]
__tablename__: str = "rank_history"
id: int | None = Field(default=None, sa_column=Column(BigInteger, primary_key=True))
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
mode: GameMode
rank: int
date: dt = Field(
@@ -43,12 +41,10 @@ class RankHistory(SQLModel, table=True):
class RankTop(SQLModel, table=True):
__tablename__ = "rank_top" # pyright: ignore[reportAssignmentType]
__tablename__: str = "rank_top"
id: int | None = Field(default=None, sa_column=Column(BigInteger, primary_key=True))
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
mode: GameMode
rank: int
date: dt = Field(
@@ -62,9 +58,7 @@ class RankHistoryResp(BaseModel):
data: list[int]
@classmethod
async def from_db(
cls, session: AsyncSession, user_id: int, mode: GameMode
) -> "RankHistoryResp":
async def from_db(cls, session: AsyncSession, user_id: int, mode: GameMode) -> "RankHistoryResp":
results = (
await session.exec(
select(RankHistory)

View File

@@ -21,7 +21,7 @@ class RelationshipType(str, Enum):
class Relationship(SQLModel, table=True):
__tablename__ = "relationship" # pyright: ignore[reportAssignmentType]
__tablename__: str = "relationship"
id: int | None = Field(
default=None,
sa_column=Column(BigInteger, autoincrement=True, primary_key=True),
@@ -59,9 +59,7 @@ class RelationshipResp(BaseModel):
type: RelationshipType
@classmethod
async def from_db(
cls, session: AsyncSession, relationship: Relationship
) -> "RelationshipResp":
async def from_db(cls, session: AsyncSession, relationship: Relationship) -> "RelationshipResp":
target_relationship = (
await session.exec(
select(Relationship).where(

View File

@@ -58,11 +58,9 @@ class RoomBase(SQLModel, UTCBaseModel):
class Room(AsyncAttrs, RoomBase, table=True):
__tablename__ = "rooms" # pyright: ignore[reportAssignmentType]
__tablename__: str = "rooms"
id: int = Field(default=None, primary_key=True, index=True)
host_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
host_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
host: User = Relationship()
playlist: list[Playlist] = Relationship(
@@ -109,12 +107,8 @@ class RoomResp(RoomBase):
if not playlist.expired:
stats.count_active += 1
rulesets.add(playlist.ruleset_id)
difficulty_range.min = min(
difficulty_range.min, playlist.beatmap.difficulty_rating
)
difficulty_range.max = max(
difficulty_range.max, playlist.beatmap.difficulty_rating
)
difficulty_range.min = min(difficulty_range.min, playlist.beatmap.difficulty_rating)
difficulty_range.max = max(difficulty_range.max, playlist.beatmap.difficulty_rating)
resp.playlist.append(await PlaylistResp.from_db(playlist, ["beatmap"]))
stats.ruleset_ids = list(rulesets)
resp.playlist_item_stats = stats
@@ -137,13 +131,9 @@ class RoomResp(RoomBase):
include=["statistics"],
)
)
resp.host = await UserResp.from_db(
await room.awaitable_attrs.host, session, include=["statistics"]
)
resp.host = await UserResp.from_db(await room.awaitable_attrs.host, session, include=["statistics"])
if "current_user_score" in include and user:
resp.current_user_score = await PlaylistAggregateScore.from_db(
room.id, user.id, session
)
resp.current_user_score = await PlaylistAggregateScore.from_db(room.id, user.id, session)
return resp
@classmethod

View File

@@ -18,22 +18,16 @@ if TYPE_CHECKING:
class RoomParticipatedUser(AsyncAttrs, SQLModel, table=True):
__tablename__ = "room_participated_users" # pyright: ignore[reportAssignmentType]
__tablename__: str = "room_participated_users"
id: int | None = Field(
default=None, sa_column=Column(BigInteger, primary_key=True, autoincrement=True)
)
id: int | None = Field(default=None, sa_column=Column(BigInteger, primary_key=True, autoincrement=True))
room_id: int = Field(sa_column=Column(ForeignKey("rooms.id"), nullable=False))
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False)
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False))
joined_at: datetime = Field(
sa_column=Column(DateTime(timezone=True), nullable=False),
default=datetime.now(UTC),
)
left_at: datetime | None = Field(
sa_column=Column(DateTime(timezone=True), nullable=True), default=None
)
left_at: datetime | None = Field(sa_column=Column(DateTime(timezone=True), nullable=True), default=None)
room: "Room" = Relationship()
user: "User" = Relationship()

View File

@@ -47,9 +47,9 @@ from .score_token import ScoreToken
from pydantic import field_serializer, field_validator
from redis.asyncio import Redis
from sqlalchemy import Boolean, Column, ColumnExpressionArgument, DateTime
from sqlalchemy import Boolean, Column, ColumnExpressionArgument, DateTime, TextClause
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import aliased
from sqlalchemy.orm import Mapped, aliased
from sqlalchemy.sql.elements import ColumnElement
from sqlmodel import (
JSON,
@@ -76,9 +76,7 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
accuracy: float
map_md5: str = Field(max_length=32, index=True)
build_id: int | None = Field(default=None)
classic_total_score: int | None = Field(
default=0, sa_column=Column(BigInteger)
) # solo_score
classic_total_score: int | None = Field(default=0, sa_column=Column(BigInteger)) # solo_score
ended_at: datetime = Field(sa_column=Column(DateTime))
has_replay: bool = Field(sa_column=Column(Boolean))
max_combo: int
@@ -91,14 +89,10 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
room_id: int | None = Field(default=None) # multiplayer
started_at: datetime = Field(sa_column=Column(DateTime))
total_score: int = Field(default=0, sa_column=Column(BigInteger))
total_score_without_mods: int = Field(
default=0, sa_column=Column(BigInteger), exclude=True
)
total_score_without_mods: int = Field(default=0, sa_column=Column(BigInteger), exclude=True)
type: str
beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
maximum_statistics: ScoreStatistics = Field(
sa_column=Column(JSON), default_factory=dict
)
maximum_statistics: ScoreStatistics = Field(sa_column=Column(JSON), default_factory=dict)
@field_validator("maximum_statistics", mode="before")
@classmethod
@@ -147,10 +141,8 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
class Score(ScoreBase, table=True):
__tablename__ = "scores" # pyright: ignore[reportAssignmentType]
id: int | None = Field(
default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True)
)
__tablename__: str = "scores"
id: int = Field(default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True))
user_id: int = Field(
default=None,
sa_column=Column(
@@ -193,8 +185,8 @@ class Score(ScoreBase, table=True):
return str(v)
# optional
beatmap: Beatmap = Relationship()
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
beatmap: Mapped[Beatmap] = Relationship()
user: Mapped[User] = Relationship(sa_relationship_kwargs={"lazy": "joined"})
@property
def is_perfect_combo(self) -> bool:
@@ -205,11 +197,7 @@ class Score(ScoreBase, table=True):
*where_clauses: ColumnExpressionArgument[bool] | bool,
) -> SelectOfScalar["Score"]:
rownum = (
func.row_number()
.over(
partition_by=col(Score.user_id), order_by=col(Score.total_score).desc()
)
.label("rn")
func.row_number().over(partition_by=col(Score.user_id), order_by=col(Score.total_score).desc()).label("rn")
)
subq = select(Score, rownum).where(*where_clauses).subquery()
best = aliased(Score, subq, adapt_on_names=True)
@@ -296,12 +284,9 @@ class ScoreResp(ScoreBase):
await session.refresh(score)
s = cls.model_validate(score.model_dump())
assert score.id
await score.awaitable_attrs.beatmap
s.beatmap = await BeatmapResp.from_db(score.beatmap)
s.beatmapset = await BeatmapsetResp.from_db(
score.beatmap.beatmapset, session=session, user=score.user
)
s.beatmapset = await BeatmapsetResp.from_db(score.beatmap.beatmapset, session=session, user=score.user)
s.is_perfect_combo = s.max_combo == s.beatmap.max_combo
s.legacy_perfect = s.max_combo == s.beatmap.max_combo
s.ruleset_id = int(score.gamemode)
@@ -371,11 +356,7 @@ class ScoreAround(SQLModel):
async def get_best_id(session: AsyncSession, score_id: int) -> None:
rownum = (
func.row_number()
.over(
partition_by=col(PPBestScore.user_id), order_by=col(PPBestScore.pp).desc()
)
.label("rn")
func.row_number().over(partition_by=col(PPBestScore.user_id), order_by=col(PPBestScore.pp).desc()).label("rn")
)
subq = select(PPBestScore, rownum).subquery()
stmt = select(subq.c.rn).where(subq.c.score_id == score_id)
@@ -389,8 +370,8 @@ async def _score_where(
mode: GameMode,
mods: list[str] | None = None,
user: User | None = None,
) -> list[ColumnElement[bool]] | None:
wheres = [
) -> list[ColumnElement[bool] | TextClause] | None:
wheres: list[ColumnElement[bool] | TextClause] = [
col(BestScore.beatmap_id) == beatmap,
col(BestScore.gamemode) == mode,
]
@@ -410,9 +391,7 @@ async def _score_where(
return None
elif type == LeaderboardType.COUNTRY:
if user and user.is_supporter:
wheres.append(
col(BestScore.user).has(col(User.country_code) == user.country_code)
)
wheres.append(col(BestScore.user).has(col(User.country_code) == user.country_code))
else:
return None
elif type == LeaderboardType.TEAM:
@@ -420,18 +399,14 @@ async def _score_where(
team_membership = await user.awaitable_attrs.team_membership
if team_membership:
team_id = team_membership.team_id
wheres.append(
col(BestScore.user).has(
col(User.team_membership).has(TeamMember.team_id == team_id)
)
)
wheres.append(col(BestScore.user).has(col(User.team_membership).has(TeamMember.team_id == team_id)))
if mods:
if user and user.is_supporter:
wheres.append(
text(
"JSON_CONTAINS(total_score_best_scores.mods, :w)"
" AND JSON_CONTAINS(:w, total_score_best_scores.mods)"
).params(w=json.dumps(mods)) # pyright: ignore[reportArgumentType]
).params(w=json.dumps(mods))
)
else:
return None
@@ -654,18 +629,14 @@ def calculate_playtime(score: Score, beatmap_length: int) -> tuple[int, bool]:
+ (score.nsmall_tick_hit or 0)
)
total_obj = 0
for statistics, count in (
score.maximum_statistics.items() if score.maximum_statistics else {}
):
for statistics, count in score.maximum_statistics.items() if score.maximum_statistics else {}:
if not isinstance(statistics, HitResult):
statistics = HitResult(statistics)
if statistics.is_scorable():
total_obj += count
return total_length, score.passed or (
total_length > 8
and score.total_score >= 5000
and total_obj_hited >= min(0.1 * total_obj, 20)
total_length > 8 and score.total_score >= 5000 and total_obj_hited >= min(0.1 * total_obj, 20)
)
@@ -678,12 +649,8 @@ async def process_user(
ranked: bool = False,
has_leaderboard: bool = False,
):
assert user.id
assert score.id
mod_for_save = mod_to_save(score.mods)
previous_score_best = await get_user_best_score_in_beatmap(
session, score.beatmap_id, user.id, score.gamemode
)
previous_score_best = await get_user_best_score_in_beatmap(session, score.beatmap_id, user.id, score.gamemode)
previous_score_best_mod = await get_user_best_score_with_mod_in_beatmap(
session, score.beatmap_id, user.id, mod_for_save, score.gamemode
)
@@ -698,9 +665,7 @@ async def process_user(
)
).first()
if mouthly_playcount is None:
mouthly_playcount = MonthlyPlaycounts(
user_id=user.id, year=date.today().year, month=date.today().month
)
mouthly_playcount = MonthlyPlaycounts(user_id=user.id, year=date.today().year, month=date.today().month)
add_to_db = True
statistics = None
for i in await user.awaitable_attrs.statistics:
@@ -708,17 +673,11 @@ async def process_user(
statistics = i
break
if statistics is None:
raise ValueError(
f"User {user.id} does not have statistics for mode {score.gamemode.value}"
)
raise ValueError(f"User {user.id} does not have statistics for mode {score.gamemode.value}")
# pc, pt, tth, tts
statistics.total_score += score.total_score
difference = (
score.total_score - previous_score_best.total_score
if previous_score_best
else score.total_score
)
difference = score.total_score - previous_score_best.total_score if previous_score_best else score.total_score
if difference > 0 and score.passed and ranked:
match score.rank:
case Rank.X:
@@ -746,11 +705,8 @@ async def process_user(
statistics.ranked_score += difference
statistics.level_current = calculate_score_to_level(statistics.total_score)
statistics.maximum_combo = max(statistics.maximum_combo, score.max_combo)
new_score_position = await get_score_position_by_user(
session, score.beatmap_id, user, score.gamemode
)
new_score_position = await get_score_position_by_user(session, score.beatmap_id, user, score.gamemode)
total_users = await session.exec(select(func.count()).select_from(User))
assert total_users is not None
score_range = min(50, math.ceil(float(total_users.one()) * 0.01))
if new_score_position <= score_range and new_score_position > 0:
# Get the scores that might be displaced
@@ -774,11 +730,7 @@ async def process_user(
)
# If this score was previously in top positions but now pushed out
if (
i < score_range
and displaced_position > score_range
and displaced_position is not None
):
if i < score_range and displaced_position > score_range and displaced_position is not None:
# Create rank lost event for the displaced user
rank_lost_event = Event(
created_at=datetime.now(UTC),
@@ -814,10 +766,7 @@ async def process_user(
)
# 情况3: 有最佳分数记录和该mod组合的记录且是同一个记录更新得分更高的情况
elif (
previous_score_best.score_id == previous_score_best_mod.score_id
and difference > 0
):
elif previous_score_best.score_id == previous_score_best_mod.score_id and difference > 0:
previous_score_best.total_score = score.total_score
previous_score_best.rank = score.rank
previous_score_best.score_id = score.id
@@ -847,9 +796,7 @@ async def process_user(
statistics.count_300 += score.n300 + score.ngeki
statistics.count_50 += score.n50
statistics.count_miss += score.nmiss
statistics.total_hits += (
score.n300 + score.n100 + score.n50 + score.ngeki + score.nkatu
)
statistics.total_hits += score.n300 + score.n100 + score.n50 + score.ngeki + score.nkatu
if score.passed and ranked:
with session.no_autoflush:
@@ -885,7 +832,6 @@ async def process_score(
item_id: int | None = None,
room_id: int | None = None,
) -> Score:
assert user.id
can_get_pp = info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods)
gamemode = GameMode.from_int(info.ruleset_id).to_special_mode(info.mods)
score = Score(
@@ -922,20 +868,15 @@ async def process_score(
if can_get_pp:
from app.calculator import pre_fetch_and_calculate_pp
pp = await pre_fetch_and_calculate_pp(
score, beatmap_id, session, redis, fetcher
)
pp = await pre_fetch_and_calculate_pp(score, beatmap_id, session, redis, fetcher)
score.pp = pp
session.add(score)
user_id = user.id
await session.commit()
await session.refresh(score)
if can_get_pp and score.pp != 0:
previous_pp_best = await get_user_best_pp_in_beatmap(
session, beatmap_id, user_id, score.gamemode
)
previous_pp_best = await get_user_best_pp_in_beatmap(session, beatmap_id, user_id, score.gamemode)
if previous_pp_best is None or score.pp > previous_pp_best.pp:
assert score.id
best_score = PPBestScore(
user_id=user_id,
score_id=score.id,

View File

@@ -7,6 +7,7 @@ from .beatmap import Beatmap
from .lazer_user import User
from sqlalchemy import Column, DateTime, Index
from sqlalchemy.orm import Mapped
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
@@ -14,16 +15,12 @@ class ScoreTokenBase(SQLModel, UTCBaseModel):
score_id: int | None = Field(sa_column=Column(BigInteger), default=None)
ruleset_id: GameMode
playlist_item_id: int | None = Field(default=None) # playlist
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
created_at: datetime = Field(default_factory=datetime.utcnow, sa_column=Column(DateTime))
updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column=Column(DateTime))
class ScoreToken(ScoreTokenBase, table=True):
__tablename__ = "score_tokens" # pyright: ignore[reportAssignmentType]
__tablename__: str = "score_tokens"
__table_args__ = (Index("idx_user_playlist", "user_id", "playlist_item_id"),)
id: int | None = Field(
@@ -37,8 +34,8 @@ class ScoreToken(ScoreTokenBase, table=True):
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
beatmap_id: int = Field(foreign_key="beatmaps.id")
user: User = Relationship()
beatmap: Beatmap = Relationship()
user: Mapped[User] = Relationship()
beatmap: Mapped[Beatmap] = Relationship()
class ScoreTokenResp(ScoreTokenBase):

View File

@@ -58,7 +58,7 @@ class UserStatisticsBase(SQLModel):
class UserStatistics(AsyncAttrs, UserStatisticsBase, table=True):
__tablename__ = "lazer_user_statistics" # pyright: ignore[reportAssignmentType]
__tablename__: str = "lazer_user_statistics"
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(
default=None,
@@ -123,9 +123,7 @@ class UserStatisticsResp(UserStatisticsBase):
if "user" in include:
from .lazer_user import RANKING_INCLUDES, UserResp
user = await UserResp.from_db(
await obj.awaitable_attrs.user, session, include=RANKING_INCLUDES
)
user = await UserResp.from_db(await obj.awaitable_attrs.user, session, include=RANKING_INCLUDES)
s.user = user
user_country = user.country_code
@@ -149,9 +147,7 @@ class UserStatisticsResp(UserStatisticsBase):
return s
async def get_rank(
session: AsyncSession, statistics: UserStatistics, country: str | None = None
) -> int | None:
async def get_rank(session: AsyncSession, statistics: UserStatistics, country: str | None = None) -> int | None:
from .lazer_user import User
query = select(
@@ -168,9 +164,7 @@ async def get_rank(
subq = query.subquery()
result = await session.exec(
select(subq.c.rank).where(subq.c.user_id == statistics.user_id)
)
result = await session.exec(select(subq.c.rank).where(subq.c.user_id == statistics.user_id))
rank = result.first()
if rank is None:

View File

@@ -11,9 +11,9 @@ if TYPE_CHECKING:
class Team(SQLModel, UTCBaseModel, table=True):
__tablename__ = "teams" # pyright: ignore[reportAssignmentType]
__tablename__: str = "teams"
id: int | None = Field(default=None, primary_key=True, index=True)
id: int = Field(default=None, primary_key=True, index=True)
name: str = Field(max_length=100)
short_name: str = Field(max_length=10)
flag_url: str | None = Field(default=None)
@@ -26,34 +26,22 @@ class Team(SQLModel, UTCBaseModel, table=True):
class TeamMember(SQLModel, UTCBaseModel, table=True):
__tablename__ = "team_members" # pyright: ignore[reportAssignmentType]
__tablename__: str = "team_members"
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True)
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True))
team_id: int = Field(foreign_key="teams.id")
joined_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
joined_at: datetime = Field(default_factory=datetime.utcnow, sa_column=Column(DateTime))
user: "User" = Relationship(
back_populates="team_membership", sa_relationship_kwargs={"lazy": "joined"}
)
team: "Team" = Relationship(
back_populates="members", sa_relationship_kwargs={"lazy": "joined"}
)
user: "User" = Relationship(back_populates="team_membership", sa_relationship_kwargs={"lazy": "joined"})
team: "Team" = Relationship(back_populates="members", sa_relationship_kwargs={"lazy": "joined"})
class TeamRequest(SQLModel, UTCBaseModel, table=True):
__tablename__ = "team_requests" # pyright: ignore[reportAssignmentType]
__tablename__: str = "team_requests"
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True)
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True))
team_id: int = Field(foreign_key="teams.id", primary_key=True)
requested_at: datetime = Field(
default=datetime.now(UTC), sa_column=Column(DateTime)
)
requested_at: datetime = Field(default=datetime.now(UTC), sa_column=Column(DateTime))
user: "User" = Relationship(sa_relationship_kwargs={"lazy": "joined"})
team: "Team" = Relationship(sa_relationship_kwargs={"lazy": "joined"})

View File

@@ -22,7 +22,7 @@ class UserAccountHistoryBase(SQLModel, UTCBaseModel):
class UserAccountHistory(UserAccountHistoryBase, table=True):
__tablename__ = "user_account_history" # pyright: ignore[reportAssignmentType]
__tablename__: str = "user_account_history"
id: int | None = Field(
sa_column=Column(
@@ -32,9 +32,7 @@ class UserAccountHistory(UserAccountHistoryBase, table=True):
primary_key=True,
)
)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
class UserAccountHistoryResp(UserAccountHistoryBase):

View File

@@ -10,27 +10,17 @@ from sqlmodel import Field, SQLModel
class UserLoginLog(SQLModel, table=True):
"""User login log table"""
__tablename__ = "user_login_log" # pyright: ignore[reportAssignmentType]
__tablename__: str = "user_login_log"
id: int | None = Field(default=None, primary_key=True, description="Record ID")
user_id: int = Field(index=True, description="User ID")
ip_address: str = Field(
max_length=45, index=True, description="IP address (supports IPv4 and IPv6)"
)
user_agent: str | None = Field(
default=None, max_length=500, description="User agent information"
)
login_time: datetime = Field(
default_factory=datetime.utcnow, description="Login time"
)
ip_address: str = Field(max_length=45, index=True, description="IP address (supports IPv4 and IPv6)")
user_agent: str | None = Field(default=None, max_length=500, description="User agent information")
login_time: datetime = Field(default_factory=datetime.utcnow, description="Login time")
# GeoIP information
country_code: str | None = Field(
default=None, max_length=2, description="Country code"
)
country_name: str | None = Field(
default=None, max_length=100, description="Country name"
)
country_code: str | None = Field(default=None, max_length=2, description="Country code")
country_name: str | None = Field(default=None, max_length=100, description="Country name")
city_name: str | None = Field(default=None, max_length=100, description="City name")
latitude: str | None = Field(default=None, max_length=20, description="Latitude")
longitude: str | None = Field(default=None, max_length=20, description="Longitude")
@@ -38,22 +28,14 @@ class UserLoginLog(SQLModel, table=True):
# ASN information
asn: int | None = Field(default=None, description="Autonomous System Number")
organization: str | None = Field(
default=None, max_length=200, description="Organization name"
)
organization: str | None = Field(default=None, max_length=200, description="Organization name")
# Login status
login_success: bool = Field(
default=True, description="Whether the login was successful"
)
login_method: str = Field(
max_length=50, description="Login method (password/oauth/etc.)"
)
login_success: bool = Field(default=True, description="Whether the login was successful")
login_method: str = Field(max_length=50, description="Login method (password/oauth/etc.)")
# Additional information
notes: str | None = Field(
default=None, max_length=500, description="Additional notes"
)
notes: str | None = Field(default=None, max_length=500, description="Additional notes")
class Config:
from_attributes = True

View File

@@ -40,15 +40,11 @@ engine = create_async_engine(
redis_client = redis.from_url(settings.redis_url, decode_responses=True)
# Redis 消息缓存连接 (db1) - 使用同步客户端在线程池中执行
redis_message_client = sync_redis.from_url(
settings.redis_url, decode_responses=True, db=1
)
redis_message_client = sync_redis.from_url(settings.redis_url, decode_responses=True, db=1)
# 数据库依赖
db_session_context: ContextVar[AsyncSession | None] = ContextVar(
"db_session_context", default=None
)
db_session_context: ContextVar[AsyncSession | None] = ContextVar("db_session_context", default=None)
async def get_db():

View File

@@ -25,7 +25,5 @@ async def get_fetcher() -> Fetcher:
if refresh_token:
fetcher.refresh_token = str(refresh_token)
if not fetcher.access_token or not fetcher.refresh_token:
logger.opt(colors=True).info(
f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>"
)
logger.opt(colors=True).info(f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>")
return fetcher

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
from datetime import UTC
from typing import cast
from apscheduler.schedulers.asyncio import AsyncIOScheduler
@@ -16,7 +17,7 @@ def get_scheduler() -> AsyncIOScheduler:
global scheduler
if scheduler is None:
init_scheduler()
return scheduler # pyright: ignore[reportReturnType]
return cast(AsyncIOScheduler, scheduler)
def start_scheduler():

View File

@@ -70,9 +70,7 @@ async def v1_authorize(
if not api_key:
raise HTTPException(status_code=401, detail="Missing API key")
api_key_record = (
await db.exec(select(V1APIKeys).where(V1APIKeys.key == api_key))
).first()
api_key_record = (await db.exec(select(V1APIKeys).where(V1APIKeys.key == api_key))).first()
if not api_key_record:
raise HTTPException(status_code=401, detail="Invalid API key")
@@ -98,9 +96,7 @@ async def get_current_user(
security_scopes: SecurityScopes,
token_pw: Annotated[str | None, Depends(oauth2_password)] = None,
token_code: Annotated[str | None, Depends(oauth2_code)] = None,
token_client_credentials: Annotated[
str | None, Depends(oauth2_client_credentials)
] = None,
token_client_credentials: Annotated[str | None, Depends(oauth2_client_credentials)] = None,
) -> User:
"""获取当前认证用户"""
token = token_pw or token_code or token_client_credentials
@@ -119,9 +115,7 @@ async def get_current_user(
if not is_client:
for scope in security_scopes.scopes:
if scope not in token_record.scope.split(","):
raise HTTPException(
status_code=403, detail=f"Insufficient scope: {scope}"
)
raise HTTPException(status_code=403, detail=f"Insufficient scope: {scope}")
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
if not user:

View File

@@ -121,14 +121,10 @@ class BaseFetcher:
except Exception as e:
last_error = e
if attempt < max_retries:
logger.warning(
f"Request failed (attempt {attempt + 1}/{max_retries + 1}): {e}, retrying..."
)
logger.warning(f"Request failed (attempt {attempt + 1}/{max_retries + 1}): {e}, retrying...")
continue
else:
logger.error(
f"Request failed after {max_retries + 1} attempts: {e}"
)
logger.error(f"Request failed after {max_retries + 1} attempts: {e}")
break
# 如果所有重试都失败了
@@ -196,13 +192,9 @@ class BaseFetcher:
f"fetcher:refresh_token:{self.client_id}",
self.refresh_token,
)
logger.info(
f"Successfully refreshed access token for client {self.client_id}"
)
logger.info(f"Successfully refreshed access token for client {self.client_id}")
except Exception as e:
logger.error(
f"Failed to refresh access token for client {self.client_id}: {e}"
)
logger.error(f"Failed to refresh access token for client {self.client_id}: {e}")
# 清除无效的 token要求重新授权
self.access_token = ""
self.refresh_token = ""
@@ -210,9 +202,7 @@ class BaseFetcher:
redis = get_redis()
await redis.delete(f"fetcher:access_token:{self.client_id}")
await redis.delete(f"fetcher:refresh_token:{self.client_id}")
logger.warning(
f"Cleared invalid tokens. Please re-authorize: {self.authorize_url}"
)
logger.warning(f"Cleared invalid tokens. Please re-authorize: {self.authorize_url}")
raise
async def _trigger_reauthorization(self) -> None:
@@ -237,8 +227,7 @@ class BaseFetcher:
await redis.delete(f"fetcher:refresh_token:{self.client_id}")
logger.warning(
f"All tokens cleared for client {self.client_id}. "
f"Please re-authorize using: {self.authorize_url}"
f"All tokens cleared for client {self.client_id}. Please re-authorize using: {self.authorize_url}"
)
def reset_auth_retry_count(self) -> None:

View File

@@ -7,18 +7,14 @@ from ._base import BaseFetcher
class BeatmapFetcher(BaseFetcher):
async def get_beatmap(
self, beatmap_id: int | None = None, beatmap_checksum: str | None = None
) -> BeatmapResp:
async def get_beatmap(self, beatmap_id: int | None = None, beatmap_checksum: str | None = None) -> BeatmapResp:
if beatmap_id:
params = {"id": beatmap_id}
elif beatmap_checksum:
params = {"checksum": beatmap_checksum}
else:
raise ValueError("Either beatmap_id or beatmap_checksum must be provided.")
logger.opt(colors=True).debug(
f"<blue>[BeatmapFetcher]</blue> get_beatmap: <y>{params}</y>"
)
logger.opt(colors=True).debug(f"<blue>[BeatmapFetcher]</blue> get_beatmap: <y>{params}</y>")
return BeatmapResp.model_validate(
await self.request_api(

View File

@@ -18,9 +18,7 @@ class BeatmapRawFetcher(BaseFetcher):
async def get_beatmap_raw(self, beatmap_id: int) -> str:
for url in urls:
req_url = url.format(beatmap_id=beatmap_id)
logger.opt(colors=True).debug(
f"<blue>[BeatmapRawFetcher]</blue> get_beatmap_raw: <y>{req_url}</y>"
)
logger.opt(colors=True).debug(f"<blue>[BeatmapRawFetcher]</blue> get_beatmap_raw: <y>{req_url}</y>")
resp = await self._request(req_url)
if resp.status_code >= 400:
continue
@@ -34,9 +32,7 @@ class BeatmapRawFetcher(BaseFetcher):
)
return response
async def get_or_fetch_beatmap_raw(
self, redis: redis.Redis, beatmap_id: int
) -> str:
async def get_or_fetch_beatmap_raw(self, redis: redis.Redis, beatmap_id: int) -> str:
from app.config import settings
cache_key = f"beatmap:{beatmap_id}:raw"
@@ -48,7 +44,7 @@ class BeatmapRawFetcher(BaseFetcher):
if content:
# 延长缓存时间
await redis.expire(cache_key, cache_expire)
return content # pyright: ignore[reportReturnType]
return content
# 获取并缓存
raw = await self.get_beatmap_raw(beatmap_id)

View File

@@ -10,6 +10,7 @@ from app.helpers.rate_limiter import osu_api_rate_limiter
from app.log import logger
from app.models.beatmap import SearchQueryModel
from app.models.model import Cursor
from app.utils import bg_tasks
from ._base import BaseFetcher
@@ -81,9 +82,7 @@ class BeatmapsetFetcher(BaseFetcher):
cache_json = json.dumps(cache_data, sort_keys=True, separators=(",", ":"))
cache_hash = hashlib.md5(cache_json.encode()).hexdigest()
logger.opt(colors=True).debug(
f"<blue>[CacheKey]</blue> Query: {cache_data}, Hash: {cache_hash}"
)
logger.opt(colors=True).debug(f"<blue>[CacheKey]</blue> Query: {cache_data}, Hash: {cache_hash}")
return f"beatmapset:search:{cache_hash}"
@@ -103,22 +102,16 @@ class BeatmapsetFetcher(BaseFetcher):
return {}
async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetResp:
logger.opt(colors=True).debug(
f"<blue>[BeatmapsetFetcher]</blue> get_beatmapset: <y>{beatmap_set_id}</y>"
)
logger.opt(colors=True).debug(f"<blue>[BeatmapsetFetcher]</blue> get_beatmapset: <y>{beatmap_set_id}</y>")
return BeatmapsetResp.model_validate(
await self.request_api(
f"https://osu.ppy.sh/api/v2/beatmapsets/{beatmap_set_id}"
)
await self.request_api(f"https://osu.ppy.sh/api/v2/beatmapsets/{beatmap_set_id}")
)
async def search_beatmapset(
self, query: SearchQueryModel, cursor: Cursor, redis_client: redis.Redis
) -> SearchBeatmapsetsResp:
logger.opt(colors=True).debug(
f"<blue>[BeatmapsetFetcher]</blue> search_beatmapset: <y>{query}</y>"
)
logger.opt(colors=True).debug(f"<blue>[BeatmapsetFetcher]</blue> search_beatmapset: <y>{query}</y>")
# 生成缓存键
cache_key = self._generate_cache_key(query, cursor)
@@ -126,9 +119,7 @@ class BeatmapsetFetcher(BaseFetcher):
# 尝试从缓存获取结果
cached_result = await redis_client.get(cache_key)
if cached_result:
logger.opt(colors=True).debug(
f"<green>[BeatmapsetFetcher]</green> Cache hit for key: <y>{cache_key}</y>"
)
logger.opt(colors=True).debug(f"<green>[BeatmapsetFetcher]</green> Cache hit for key: <y>{cache_key}</y>")
try:
cached_data = json.loads(cached_result)
return SearchBeatmapsetsResp.model_validate(cached_data)
@@ -138,13 +129,9 @@ class BeatmapsetFetcher(BaseFetcher):
)
# 缓存未命中,从 API 获取数据
logger.opt(colors=True).debug(
"<blue>[BeatmapsetFetcher]</blue> Cache miss, fetching from API"
)
logger.opt(colors=True).debug("<blue>[BeatmapsetFetcher]</blue> Cache miss, fetching from API")
params = query.model_dump(
exclude_none=True, exclude_unset=True, exclude_defaults=True
)
params = query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True)
if query.cursor_string:
params["cursor_string"] = query.cursor_string
@@ -164,39 +151,26 @@ class BeatmapsetFetcher(BaseFetcher):
# 将结果缓存 15 分钟
cache_ttl = 15 * 60 # 15 分钟
await redis_client.set(
cache_key, json.dumps(api_response, separators=(",", ":")), ex=cache_ttl
)
await redis_client.set(cache_key, json.dumps(api_response, separators=(",", ":")), ex=cache_ttl)
logger.opt(colors=True).debug(
f"<green>[BeatmapsetFetcher]</green> Cached result for key: "
f"<y>{cache_key}</y> (TTL: {cache_ttl}s)"
f"<green>[BeatmapsetFetcher]</green> Cached result for key: <y>{cache_key}</y> (TTL: {cache_ttl}s)"
)
resp = SearchBeatmapsetsResp.model_validate(api_response)
# 智能预取只在用户明确搜索时才预取避免过多API请求
# 且只在有搜索词或特定条件时预取,避免首页浏览时的过度预取
if api_response.get("cursor") and (
query.q or query.s != "leaderboard" or cursor
):
if api_response.get("cursor") and (query.q or query.s != "leaderboard" or cursor):
# 在后台预取下1页减少预取量
import asyncio
# 不立即创建任务,而是延迟一段时间再预取
async def delayed_prefetch():
await asyncio.sleep(3.0) # 延迟3秒
await self.prefetch_next_pages(
query, api_response["cursor"], redis_client, pages=1
)
await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=1)
# 创建延迟预取任务
task = asyncio.create_task(delayed_prefetch())
# 添加到后台任务集合避免被垃圾回收
if not hasattr(self, "_background_tasks"):
self._background_tasks = set()
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
bg_tasks.add_task(delayed_prefetch)
return resp
@@ -218,18 +192,14 @@ class BeatmapsetFetcher(BaseFetcher):
# 使用当前 cursor 请求下一页
next_query = query.model_copy()
logger.opt(colors=True).debug(
f"<cyan>[BeatmapsetFetcher]</cyan> Prefetching page {page + 1}"
)
logger.opt(colors=True).debug(f"<cyan>[BeatmapsetFetcher]</cyan> Prefetching page {page + 1}")
# 生成下一页的缓存键
next_cache_key = self._generate_cache_key(next_query, cursor)
# 检查是否已经缓存
if await redis_client.exists(next_cache_key):
logger.opt(colors=True).debug(
f"<cyan>[BeatmapsetFetcher]</cyan> Page {page + 1} already cached"
)
logger.opt(colors=True).debug(f"<cyan>[BeatmapsetFetcher]</cyan> Page {page + 1} already cached")
# 尝试从缓存获取cursor继续预取
cached_data = await redis_client.get(next_cache_key)
if cached_data:
@@ -247,9 +217,7 @@ class BeatmapsetFetcher(BaseFetcher):
await asyncio.sleep(1.5) # 1.5秒延迟
# 请求下一页数据
params = next_query.model_dump(
exclude_none=True, exclude_unset=True, exclude_defaults=True
)
params = next_query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True)
for k, v in cursor.items():
params[f"cursor[{k}]"] = v
@@ -277,22 +245,18 @@ class BeatmapsetFetcher(BaseFetcher):
)
logger.opt(colors=True).debug(
f"<cyan>[BeatmapsetFetcher]</cyan> Prefetched page {page + 1} "
f"(TTL: {prefetch_ttl}s)"
f"<cyan>[BeatmapsetFetcher]</cyan> Prefetched page {page + 1} (TTL: {prefetch_ttl}s)"
)
except Exception as e:
logger.opt(colors=True).warning(
f"<yellow>[BeatmapsetFetcher]</yellow> Prefetch failed: {e}"
)
logger.opt(colors=True).warning(f"<yellow>[BeatmapsetFetcher]</yellow> Prefetch failed: {e}")
async def warmup_homepage_cache(self, redis_client: redis.Redis) -> None:
"""预热主页缓存"""
homepage_queries = self._get_homepage_queries()
logger.opt(colors=True).info(
f"<magenta>[BeatmapsetFetcher]</magenta> Starting homepage cache warmup "
f"({len(homepage_queries)} queries)"
f"<magenta>[BeatmapsetFetcher]</magenta> Starting homepage cache warmup ({len(homepage_queries)} queries)"
)
for i, (query, cursor) in enumerate(homepage_queries):
@@ -306,15 +270,12 @@ class BeatmapsetFetcher(BaseFetcher):
# 检查是否已经缓存
if await redis_client.exists(cache_key):
logger.opt(colors=True).debug(
f"<magenta>[BeatmapsetFetcher]</magenta> "
f"Query {query.sort} already cached"
f"<magenta>[BeatmapsetFetcher]</magenta> Query {query.sort} already cached"
)
continue
# 请求并缓存
params = query.model_dump(
exclude_none=True, exclude_unset=True, exclude_defaults=True
)
params = query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True)
api_response = await self.request_api(
"https://osu.ppy.sh/api/v2/beatmapsets/search",
@@ -334,17 +295,13 @@ class BeatmapsetFetcher(BaseFetcher):
)
logger.opt(colors=True).info(
f"<magenta>[BeatmapsetFetcher]</magenta> "
f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)"
f"<magenta>[BeatmapsetFetcher]</magenta> Warmed up cache for {query.sort} (TTL: {cache_ttl}s)"
)
if api_response.get("cursor"):
await self.prefetch_next_pages(
query, api_response["cursor"], redis_client, pages=2
)
await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=2)
except Exception as e:
logger.opt(colors=True).error(
f"<red>[BeatmapsetFetcher]</red> "
f"Failed to warmup cache for {query.sort}: {e}"
f"<red>[BeatmapsetFetcher]</red> Failed to warmup cache for {query.sort}: {e}"
)

View File

@@ -55,14 +55,9 @@ class GeoIPHelper:
- 临时目录退出后自动清理
"""
if not self.license_key:
raise ValueError(
"缺少 MaxMind License Key请传入或设置环境变量 MAXMIND_LICENSE_KEY"
)
raise ValueError("缺少 MaxMind License Key请传入或设置环境变量 MAXMIND_LICENSE_KEY")
url = (
f"{BASE_URL}?edition_id={edition_id}&"
f"license_key={self.license_key}&suffix=tar.gz"
)
url = f"{BASE_URL}?edition_id={edition_id}&license_key={self.license_key}&suffix=tar.gz"
with httpx.Client(follow_redirects=True, timeout=self.timeout) as client:
with client.stream("GET", url) as resp:

View File

@@ -48,8 +48,7 @@ class RateLimiter:
if wait_time > 0:
logger.opt(colors=True).info(
f"<yellow>[RateLimiter]</yellow> Rate limit reached, "
f"waiting {wait_time:.2f}s"
f"<yellow>[RateLimiter]</yellow> Rate limit reached, waiting {wait_time:.2f}s"
)
await asyncio.sleep(wait_time)
current_time = time.time()
@@ -107,11 +106,7 @@ class RateLimiter:
"max_requests_per_minute": self.max_requests_per_minute,
"burst_requests": len(self.burst_times),
"burst_limit": self.burst_limit,
"next_reset_in_seconds": (
60.0 - (current_time - self.request_times[0])
if self.request_times
else 0.0
),
"next_reset_in_seconds": (60.0 - (current_time - self.request_times[0]) if self.request_times else 0.0),
}

View File

@@ -46,14 +46,10 @@ class InterceptHandler(logging.Handler):
color = True
else:
color = False
logger.opt(depth=depth, exception=record.exc_info, colors=color).log(
level, message
)
logger.opt(depth=depth, exception=record.exc_info, colors=color).log(level, message)
def _format_uvicorn_error_log(self, message: str) -> str:
websocket_pattern = (
r'(\d+\.\d+\.\d+\.\d+:\d+)\s*-\s*"WebSocket\s+([^"]+)"\s+([\w\[\]]+)'
)
websocket_pattern = r'(\d+\.\d+\.\d+\.\d+:\d+)\s*-\s*"WebSocket\s+([^"]+)"\s+([\w\[\]]+)'
websocket_match = re.search(websocket_pattern, message)
if websocket_match:
@@ -64,14 +60,8 @@ class InterceptHandler(logging.Handler):
"[accepted]": "<green>[accepted]</green>",
"403": "<red>403 [rejected]</red>",
}
colored_status = status_colors.get(
status.lower(), f"<white>{status}</white>"
)
return (
f'{colored_ip} - "<bold><magenta>WebSocket</magenta> '
f'{path}</bold>" '
f"{colored_status}"
)
colored_status = status_colors.get(status.lower(), f"<white>{status}</white>")
return f'{colored_ip} - "<bold><magenta>WebSocket</magenta> {path}</bold>" {colored_status}'
else:
return message
@@ -121,9 +111,7 @@ logger.remove()
logger.add(
stdout,
colorize=True,
format=(
"<green>{time:YYYY-MM-DD HH:mm:ss}</green> [<level>{level}</level>] | {message}"
),
format=("<green>{time:YYYY-MM-DD HH:mm:ss}</green> [<level>{level}</level>] | {message}"),
level=settings.log_level,
diagnose=settings.debug,
)

View File

@@ -19,17 +19,11 @@ class Achievement(NamedTuple):
@property
def url(self) -> str:
return (
self.medal_url
or f"https://assets.ppy.sh/medals/client/{self.assets_id}.png"
)
return self.medal_url or f"https://assets.ppy.sh/medals/client/{self.assets_id}.png"
@property
def url2x(self) -> str:
return (
self.medal_url2x
or f"https://assets.ppy.sh/medals/client/{self.assets_id}@2x.png"
)
return self.medal_url2x or f"https://assets.ppy.sh/medals/client/{self.assets_id}@2x.png"
MedalProcessor = Callable[[AsyncSession, "Score", "Beatmap"], Awaitable[bool]]

View File

@@ -11,7 +11,8 @@ class APIMe(UserResp):
"""
/me 端点的响应模型
对应 osu! 的 APIMe 类型,继承 APIUser(UserResp) 并包含 session_verified 字段
session_verified 字段已经在 UserResp 中定义,这里不需要重复定义
"""
pass

View File

@@ -95,11 +95,7 @@ class SearchQueryModel(BaseModel):
q: str = Field("", description="搜索关键词")
c: Annotated[
list[
Literal[
"recommended", "converts", "follows", "spotlights", "featured_artists"
]
],
list[Literal["recommended", "converts", "follows", "spotlights", "featured_artists"]],
BeforeValidator(_parse_list),
PlainSerializer(lambda x: ".".join(x)),
] = Field(
@@ -188,12 +184,10 @@ class SearchQueryModel(BaseModel):
list[Literal["video", "storyboard"]],
BeforeValidator(_parse_list),
PlainSerializer(lambda x: ".".join(x)),
] = Field(
default_factory=list, description=("其他video 有视频 / storyboard 有故事板")
] = Field(default_factory=list, description=("其他video 有视频 / storyboard 有故事板"))
r: Annotated[list[Rank], BeforeValidator(_parse_list), PlainSerializer(lambda x: ".".join(x))] = Field(
default_factory=list, description="成绩"
)
r: Annotated[
list[Rank], BeforeValidator(_parse_list), PlainSerializer(lambda x: ".".join(x))
] = Field(default_factory=list, description="成绩")
played: bool = Field(
default=False,
description="玩过",

View File

@@ -9,12 +9,13 @@ from pydantic import BaseModel
class ExtendedTokenResponse(BaseModel):
"""扩展的令牌响应,支持二次验证状态"""
access_token: str | None = None
token_type: str = "Bearer"
expires_in: int | None = None
refresh_token: str | None = None
scope: str | None = None
# 二次验证相关字段
requires_second_factor: bool = False
verification_message: str | None = None
@@ -23,6 +24,7 @@ class ExtendedTokenResponse(BaseModel):
class SessionState(BaseModel):
"""会话状态"""
user_id: int
username: str
email: str

View File

@@ -145,9 +145,7 @@ class MultiplayerPlaylistItemStats(BaseModel):
class MultiplayerRoomStats(BaseModel):
room_id: int
playlist_item_stats: dict[int, MultiplayerPlaylistItemStats] = Field(
default_factory=dict
)
playlist_item_stats: dict[int, MultiplayerPlaylistItemStats] = Field(default_factory=dict)
class MultiplayerRoomScoreSetEvent(BaseModel):

View File

@@ -174,11 +174,7 @@ def mods_can_get_pp(ruleset_id: int, mods: list[APIMod]) -> bool:
return True
ranked_mods = RANKED_MODS[ruleset_id]
for mod in mods:
if (
app_settings.enable_rx
and mod["acronym"] == "RX"
and ruleset_id in {0, 1, 2}
):
if app_settings.enable_rx and mod["acronym"] == "RX" and ruleset_id in {0, 1, 2}:
continue
if app_settings.enable_ap and mod["acronym"] == "AP" and ruleset_id == 0:
continue
@@ -251,10 +247,7 @@ def get_available_mods(ruleset_id: int, required_mods: list[APIMod]) -> list[API
if mod_acronym in incompatible_mods:
continue
if any(
required_acronym in mod_data["IncompatibleMods"]
for required_acronym in required_mod_acronyms
):
if any(required_acronym in mod_data["IncompatibleMods"] for required_acronym in required_mod_acronyms):
continue
if mod_data.get("UserPlayable", False):

View File

@@ -121,32 +121,21 @@ class PlaylistItem(BaseModel):
star_rating: float
freestyle: bool
def _validate_mod_for_ruleset(
self, mod: APIMod, ruleset_key: int, context: str = "mod"
) -> None:
def _validate_mod_for_ruleset(self, mod: APIMod, ruleset_key: int, context: str = "mod") -> None:
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
# Check if mod is valid for ruleset
if (
typed_ruleset_key not in API_MODS
or mod["acronym"] not in API_MODS[typed_ruleset_key]
):
raise InvokeException(
f"{context} {mod['acronym']} is invalid for this ruleset"
)
if typed_ruleset_key not in API_MODS or mod["acronym"] not in API_MODS[typed_ruleset_key]:
raise InvokeException(f"{context} {mod['acronym']} is invalid for this ruleset")
mod_settings = API_MODS[typed_ruleset_key][mod["acronym"]]
# Check if mod is unplayable in multiplayer
if mod_settings.get("UserPlayable", True) is False:
raise InvokeException(
f"{context} {mod['acronym']} is not playable by users"
)
raise InvokeException(f"{context} {mod['acronym']} is not playable by users")
if mod_settings.get("ValidForMultiplayer", True) is False:
raise InvokeException(
f"{context} {mod['acronym']} is not valid for multiplayer"
)
raise InvokeException(f"{context} {mod['acronym']} is not valid for multiplayer")
def _check_mod_compatibility(self, mods: list[APIMod], ruleset_key: int) -> None:
from typing import Literal, cast
@@ -159,10 +148,7 @@ class PlaylistItem(BaseModel):
incompatible = set(mod1_settings.get("IncompatibleMods", []))
for mod2 in mods[i + 1 :]:
if mod2["acronym"] in incompatible:
raise InvokeException(
f"Mods {mod1['acronym']} and "
f"{mod2['acronym']} are incompatible"
)
raise InvokeException(f"Mods {mod1['acronym']} and {mod2['acronym']} are incompatible")
def _check_required_allowed_compatibility(self, ruleset_key: int) -> None:
from typing import Literal, cast
@@ -178,10 +164,7 @@ class PlaylistItem(BaseModel):
conflicting_allowed = allowed_acronyms & incompatible
if conflicting_allowed:
conflict_list = ", ".join(conflicting_allowed)
raise InvokeException(
f"Required mod {req_acronym} conflicts with "
f"allowed mods: {conflict_list}"
)
raise InvokeException(f"Required mod {req_acronym} conflicts with allowed mods: {conflict_list}")
def validate_playlist_item_mods(self) -> None:
ruleset_key = cast(Literal[0, 1, 2, 3], self.ruleset_id)
@@ -219,10 +202,7 @@ class PlaylistItem(BaseModel):
# Check if mods are valid for the ruleset
for mod in proposed_mods:
if (
ruleset_key not in API_MODS
or mod["acronym"] not in API_MODS[ruleset_key]
):
if ruleset_key not in API_MODS or mod["acronym"] not in API_MODS[ruleset_key]:
all_proposed_valid = False
continue
valid_mods.append(mod)
@@ -252,9 +232,7 @@ class PlaylistItem(BaseModel):
# Check compatibility with required mods
required_mod_acronyms = {mod["acronym"] for mod in self.required_mods}
all_mod_acronyms = {
mod["acronym"] for mod in final_valid_mods
} | required_mod_acronyms
all_mod_acronyms = {mod["acronym"] for mod in final_valid_mods} | required_mod_acronyms
# Check for incompatibility between required and user mods
filtered_valid_mods = []
@@ -288,9 +266,7 @@ class PlaylistItem(BaseModel):
class _MultiplayerCountdown(SignalRUnionMessage):
id: int = 0
time_remaining: timedelta
is_exclusive: Annotated[
bool, Field(default=True), SignalRMeta(member_ignore=True)
] = True
is_exclusive: Annotated[bool, Field(default=True), SignalRMeta(member_ignore=True)] = True
class MatchStartCountdown(_MultiplayerCountdown):
@@ -305,17 +281,13 @@ class ServerShuttingDownCountdown(_MultiplayerCountdown):
union_type: ClassVar[Literal[2]] = 2
MultiplayerCountdown = (
MatchStartCountdown | ForceGameplayStartCountdown | ServerShuttingDownCountdown
)
MultiplayerCountdown = MatchStartCountdown | ForceGameplayStartCountdown | ServerShuttingDownCountdown
class MultiplayerRoomUser(BaseModel):
user_id: int
state: MultiplayerUserState = MultiplayerUserState.IDLE
availability: BeatmapAvailability = BeatmapAvailability(
state=DownloadState.UNKNOWN, download_progress=None
)
availability: BeatmapAvailability = BeatmapAvailability(state=DownloadState.UNKNOWN, download_progress=None)
mods: list[APIMod] = Field(default_factory=list)
match_state: MatchUserState | None = None
ruleset_id: int | None = None # freestyle
@@ -358,9 +330,7 @@ class MultiplayerRoom(BaseModel):
expired=item.expired,
playlist_order=item.playlist_order,
played_at=item.played_at,
star_rating=item.beatmap.difficulty_rating
if item.beatmap is not None
else 0.0,
star_rating=item.beatmap.difficulty_rating if item.beatmap is not None else 0.0,
freestyle=item.freestyle,
)
)
@@ -425,9 +395,7 @@ class MultiplayerQueue:
user_item_groups[item.owner_id] = []
user_item_groups[item.owner_id].append(item)
max_items = max(
(len(items) for items in user_item_groups.values()), default=0
)
max_items = max((len(items) for items in user_item_groups.values()), default=0)
for i in range(max_items):
current_set = []
@@ -436,20 +404,13 @@ class MultiplayerQueue:
current_set.append(items[i])
if is_first_set:
current_set.sort(
key=lambda item: (item.playlist_order, item.id)
)
current_set.sort(key=lambda item: (item.playlist_order, item.id))
ordered_active_items.extend(current_set)
first_set_order_by_user_id = {
item.owner_id: idx
for idx, item in enumerate(ordered_active_items)
item.owner_id: idx for idx, item in enumerate(ordered_active_items)
}
else:
current_set.sort(
key=lambda item: first_set_order_by_user_id.get(
item.owner_id, 0
)
)
current_set.sort(key=lambda item: first_set_order_by_user_id.get(item.owner_id, 0))
ordered_active_items.extend(current_set)
is_first_set = False
@@ -464,9 +425,7 @@ class MultiplayerQueue:
continue
item.playlist_order = idx
await Playlist.update(item, self.room.room_id, session)
await self.hub.playlist_changed(
self.server_room, item, beatmap_changed=False
)
await self.hub.playlist_changed(self.server_room, item, beatmap_changed=False)
async def update_current_item(self):
upcoming_items = self.upcoming_items
@@ -494,16 +453,7 @@ class MultiplayerQueue:
raise InvokeException("You are not the host")
limit = HOST_LIMIT if is_host else PER_USER_LIMIT
if (
len(
[
True
for u in self.room.playlist
if u.owner_id == user.user_id and not u.expired
]
)
>= limit
):
if len([True for u in self.room.playlist if u.owner_id == user.user_id and not u.expired]) >= limit:
raise InvokeException(f"You can only have {limit} items in the queue")
if item.freestyle and len(item.allowed_mods) > 0:
@@ -512,9 +462,7 @@ class MultiplayerQueue:
async with with_db() as session:
fetcher = await get_fetcher()
async with session:
beatmap = await Beatmap.get_or_fetch(
session, fetcher, bid=item.beatmap_id
)
beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=item.beatmap_id)
if beatmap is None:
raise InvokeException("Beatmap not found")
if item.beatmap_checksum != beatmap.checksum:
@@ -538,29 +486,19 @@ class MultiplayerQueue:
async with with_db() as session:
fetcher = await get_fetcher()
async with session:
beatmap = await Beatmap.get_or_fetch(
session, fetcher, bid=item.beatmap_id
)
beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=item.beatmap_id)
if item.beatmap_checksum != beatmap.checksum:
raise InvokeException("Checksum mismatch")
existing_item = next(
(i for i in self.room.playlist if i.id == item.id), None
)
existing_item = next((i for i in self.room.playlist if i.id == item.id), None)
if existing_item is None:
raise InvokeException(
"Attempted to change an item that doesn't exist"
)
raise InvokeException("Attempted to change an item that doesn't exist")
if existing_item.owner_id != user.user_id and self.room.host != user:
raise InvokeException(
"Attempted to change an item which is not owned by the user"
)
raise InvokeException("Attempted to change an item which is not owned by the user")
if existing_item.expired:
raise InvokeException(
"Attempted to change an item which has already been played"
)
raise InvokeException("Attempted to change an item which has already been played")
item.validate_playlist_item_mods()
item.owner_id = user.user_id
@@ -578,8 +516,7 @@ class MultiplayerQueue:
await self.hub.playlist_changed(
self.server_room,
item,
beatmap_changed=item.beatmap_checksum
!= existing_item.beatmap_checksum,
beatmap_changed=item.beatmap_checksum != existing_item.beatmap_checksum,
)
async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser):
@@ -600,14 +537,10 @@ class MultiplayerQueue:
raise InvokeException("The only item in the room cannot be removed")
if item.owner_id != user.user_id and self.room.host != user:
raise InvokeException(
"Attempted to remove an item which is not owned by the user"
)
raise InvokeException("Attempted to remove an item which is not owned by the user")
if item.expired:
raise InvokeException(
"Attempted to remove an item which has already been played"
)
raise InvokeException("Attempted to remove an item which has already been played")
async with with_db() as session:
await Playlist.delete_item(item.id, self.room.room_id, session)
@@ -668,9 +601,7 @@ class CountdownInfo:
def __init__(self, countdown: MultiplayerCountdown):
self.countdown = countdown
self.duration = (
countdown.time_remaining
if countdown.time_remaining > timedelta(seconds=0)
else timedelta(seconds=0)
countdown.time_remaining if countdown.time_remaining > timedelta(seconds=0) else timedelta(seconds=0)
)
@@ -704,9 +635,7 @@ class MatchTypeHandler(ABC):
async def handle_join(self, user: MultiplayerRoomUser): ...
@abstractmethod
async def handle_request(
self, user: MultiplayerRoomUser, request: MatchRequest
): ...
async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest): ...
@abstractmethod
async def handle_leave(self, user: MultiplayerRoomUser): ...
@@ -723,9 +652,7 @@ class HeadToHeadHandler(MatchTypeHandler):
await self.hub.change_user_match_state(self.room, user)
@override
async def handle_request(
self, user: MultiplayerRoomUser, request: MatchRequest
): ...
async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest): ...
@override
async def handle_leave(self, user: MultiplayerRoomUser): ...
@@ -762,9 +689,7 @@ class TeamVersusHandler(MatchTypeHandler):
team_counts = defaultdict(int)
for user in self.room.room.users:
if user.match_state is not None and isinstance(
user.match_state, TeamVersusUserState
):
if user.match_state is not None and isinstance(user.match_state, TeamVersusUserState):
team_counts[user.match_state.team_id] += 1
if team_counts:
@@ -798,9 +723,7 @@ class TeamVersusHandler(MatchTypeHandler):
def get_details(self) -> MatchStartedEventDetail:
teams: dict[int, Literal["blue", "red"]] = {}
for user in self.room.room.users:
if user.match_state is not None and isinstance(
user.match_state, TeamVersusUserState
):
if user.match_state is not None and isinstance(user.match_state, TeamVersusUserState):
teams[user.user_id] = "blue" if user.match_state.team_id == 1 else "red"
detail = MatchStartedEventDetail(room_type="team_versus", team=teams)
return detail
@@ -843,9 +766,7 @@ class ServerMultiplayerRoom:
self._tracked_countdown = {}
async def set_handler(self):
self.match_type_handler = MATCH_TYPE_HANDLERS[self.room.settings.match_type](
self
)
self.match_type_handler = MATCH_TYPE_HANDLERS[self.room.settings.match_type](self)
for i in self.room.users:
await self.match_type_handler.handle_join(i)
@@ -871,9 +792,7 @@ class ServerMultiplayerRoom:
info = CountdownInfo(countdown)
self.room.active_countdowns.append(info.countdown)
self._tracked_countdown[countdown.id] = info
await self.hub.send_match_event(
self, CountdownStartedEvent(countdown=info.countdown)
)
await self.hub.send_match_event(self, CountdownStartedEvent(countdown=info.countdown))
info.task = asyncio.create_task(_countdown_task(self))
async def stop_countdown(self, countdown: MultiplayerCountdown):

View File

@@ -53,7 +53,7 @@ class NotificationName(str, Enum):
NotificationName.BEATMAP_OWNER_CHANGE: "beatmap_owner_change",
NotificationName.BEATMAPSET_DISCUSSION_LOCK: "beatmapset_discussion",
NotificationName.BEATMAPSET_DISCUSSION_POST_NEW: "beatmapset_discussion",
NotificationName.BEATMAPSET_DISCUSSION_QUALIFIED_PROBLEM: "beatmapset_problem", # noqa: E501
NotificationName.BEATMAPSET_DISCUSSION_QUALIFIED_PROBLEM: "beatmapset_problem",
NotificationName.BEATMAPSET_DISCUSSION_REVIEW_NEW: "beatmapset_discussion",
NotificationName.BEATMAPSET_DISCUSSION_UNLOCK: "beatmapset_discussion",
NotificationName.BEATMAPSET_DISQUALIFY: "beatmapset_state",
@@ -164,17 +164,11 @@ class ChannelMessageTeam(ChannelMessageBase):
from app.database import TeamMember
user_team_id = (
await session.exec(
select(TeamMember.team_id).where(TeamMember.user_id == self._user.id)
)
await session.exec(select(TeamMember.team_id).where(TeamMember.user_id == self._user.id))
).first()
if not user_team_id:
return []
user_ids = (
await session.exec(
select(TeamMember.user_id).where(TeamMember.team_id == user_team_id)
)
).all()
user_ids = (await session.exec(select(TeamMember.user_id).where(TeamMember.team_id == user_team_id))).all()
return list(user_ids)

View File

@@ -197,9 +197,7 @@ class SoloScoreSubmissionInfo(BaseModel):
# check incompatible mods
for mod in mods:
if mod["acronym"] in incompatible_mods:
raise ValueError(
f"Mod {mod['acronym']} is incompatible with other mods"
)
raise ValueError(f"Mod {mod['acronym']} is incompatible with other mods")
setting_mods = API_MODS[info.data["ruleset_id"]].get(mod["acronym"])
if not setting_mods:
raise ValueError(f"Invalid mod: {mod['acronym']}")

View File

@@ -22,9 +22,7 @@ class SignalRUnionMessage(BaseModel):
class Transport(BaseModel):
transport: str
transfer_formats: list[str] = Field(
default_factory=lambda: ["Binary", "Text"], alias="transferFormats"
)
transfer_formats: list[str] = Field(default_factory=lambda: ["Binary", "Text"], alias="transferFormats")
class NegotiateResponse(BaseModel):

View File

@@ -89,9 +89,7 @@ class LegacyReplayFrame(BaseModel):
mouse_y: float | None = None
button_state: int
header: Annotated[
FrameHeader | None, Field(default=None), SignalRMeta(member_ignore=True)
]
header: Annotated[FrameHeader | None, Field(default=None), SignalRMeta(member_ignore=True)]
class FrameDataBundle(BaseModel):

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from datetime import UTC, datetime, timedelta
import re
from typing import Literal, Union
from typing import Literal
from app.auth import (
authenticate_user,
@@ -22,19 +22,19 @@ from app.dependencies.database import Database, get_redis
from app.dependencies.geoip import get_client_ip, get_geoip_helper
from app.helpers.geoip_helper import GeoIPHelper
from app.log import logger
from app.models.extended_auth import ExtendedTokenResponse
from app.models.oauth import (
OAuthErrorResponse,
RegistrationRequestErrors,
TokenResponse,
UserRegistrationErrors,
)
from app.models.extended_auth import ExtendedTokenResponse
from app.models.score import GameMode
from app.service.login_log_service import LoginLogService
from app.service.email_verification_service import (
EmailVerificationService,
LoginSessionService
LoginSessionService,
)
from app.service.login_log_service import LoginLogService
from app.service.password_reset_service import password_reset_service
from fastapi import APIRouter, Depends, Form, Request
@@ -44,13 +44,9 @@ from sqlalchemy import text
from sqlmodel import select
def create_oauth_error_response(
error: str, description: str, hint: str, status_code: int = 400
):
def create_oauth_error_response(error: str, description: str, hint: str, status_code: int = 400):
"""创建标准的 OAuth 错误响应"""
error_data = OAuthErrorResponse(
error=error, error_description=description, hint=hint, message=description
)
error_data = OAuthErrorResponse(error=error, error_description=description, hint=hint, message=description)
return JSONResponse(status_code=status_code, content=error_data.model_dump())
@@ -123,9 +119,7 @@ async def register_user(
)
)
return JSONResponse(
status_code=422, content={"form_error": errors.model_dump()}
)
return JSONResponse(status_code=422, content={"form_error": errors.model_dump()})
try:
# 获取客户端 IP 并查询地理位置
@@ -137,10 +131,7 @@ async def register_user(
geo_info = geoip.lookup(client_ip)
if geo_info and geo_info.get("country_iso"):
country_code = geo_info["country_iso"]
logger.info(
f"User {user_username} registering from "
f"{client_ip}, country: {country_code}"
)
logger.info(f"User {user_username} registering from {client_ip}, country: {country_code}")
else:
logger.warning(f"Could not determine country for IP {client_ip}")
except Exception as e:
@@ -148,7 +139,7 @@ async def register_user(
# 创建新用户
# 确保 AUTO_INCREMENT 值从3开始ID=1是BanchoBotID=2预留给ppy
result = await db.execute( # pyright: ignore[reportDeprecated]
result = await db.execute(
text(
"SELECT AUTO_INCREMENT FROM information_schema.TABLES "
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'lazer_users'"
@@ -173,7 +164,6 @@ async def register_user(
db.add(new_user)
await db.commit()
await db.refresh(new_user)
assert new_user.id is not None, "New user ID should not be None"
for i in [GameMode.OSU, GameMode.TAIKO, GameMode.FRUITS, GameMode.MANIA]:
statistics = UserStatistics(mode=i, user_id=new_user.id)
db.add(statistics)
@@ -193,36 +183,30 @@ async def register_user(
logger.exception(f"Registration error for user {user_username}")
# 返回通用错误
errors = RegistrationRequestErrors(
message="An error occurred while creating your account. Please try again."
)
errors = RegistrationRequestErrors(message="An error occurred while creating your account. Please try again.")
return JSONResponse(
status_code=500, content={"form_error": errors.model_dump()}
)
return JSONResponse(status_code=500, content={"form_error": errors.model_dump()})
@router.post(
"/oauth/token",
response_model=Union[TokenResponse, ExtendedTokenResponse],
response_model=TokenResponse | ExtendedTokenResponse,
name="获取访问令牌",
description="OAuth 令牌端点,支持密码、刷新令牌和授权码三种授权方式。",
)
async def oauth_token(
db: Database,
request: Request,
grant_type: Literal[
"authorization_code", "refresh_token", "password", "client_credentials"
] = Form(..., description="授权类型:密码/刷新令牌/授权码/客户端凭证"),
grant_type: Literal["authorization_code", "refresh_token", "password", "client_credentials"] = Form(
..., description="授权类型:密码/刷新令牌/授权码/客户端凭证"
),
client_id: int = Form(..., description="客户端 ID"),
client_secret: str = Form(..., description="客户端密钥"),
code: str | None = Form(None, description="授权码(仅授权码模式需要)"),
scope: str = Form("*", description="权限范围(空格分隔,默认为 '*'"),
username: str | None = Form(None, description="用户名(仅密码模式需要)"),
password: str | None = Form(None, description="密码(仅密码模式需要)"),
refresh_token: str | None = Form(
None, description="刷新令牌(仅刷新令牌模式需要)"
),
refresh_token: str | None = Form(None, description="刷新令牌(仅刷新令牌模式需要)"),
redis: Redis = Depends(get_redis),
geoip: GeoIPHelper = Depends(get_geoip_helper),
):
@@ -303,37 +287,33 @@ async def oauth_token(
await db.refresh(user)
# 获取用户信息和客户端信息
user_id = getattr(user, "id")
assert user_id is not None, "User ID should not be None after authentication"
from app.dependencies.geoip import get_client_ip
user_id = user.id
ip_address = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "")
# 获取国家代码
geo_info = geoip.lookup(ip_address)
country_code = geo_info.get("country_iso", "XX")
# 检查是否为新位置登录
is_new_location = await LoginSessionService.check_new_location(
db, user_id, ip_address, country_code
)
is_new_location = await LoginSessionService.check_new_location(db, user_id, ip_address, country_code)
# 创建登录会话记录
login_session = await LoginSessionService.create_session(
login_session = await LoginSessionService.create_session( # noqa: F841
db, redis, user_id, ip_address, user_agent, country_code, is_new_location
)
# 如果是新位置登录,需要邮件验证
if is_new_location and settings.enable_email_verification:
# 刷新用户对象以确保属性已加载
await db.refresh(user)
# 发送邮件验证码
verification_sent = await EmailVerificationService.send_verification_email(
db, redis, user_id, user.username, user.email, ip_address, user_agent
)
# 记录需要二次验证的登录尝试
await LoginLogService.record_login(
db=db,
@@ -343,14 +323,16 @@ async def oauth_token(
login_method="password_pending_verification",
notes=f"新位置登录,需要邮件验证 - IP: {ip_address}, 国家: {country_code}",
)
if not verification_sent:
# 邮件发送失败,记录错误
logger.error(f"[Auth] Failed to send email verification code for user {user_id}")
elif is_new_location and not settings.enable_email_verification:
# 新位置登录但邮件验证功能被禁用,直接标记会话为已验证
await LoginSessionService.mark_session_verified(db, user_id)
logger.debug(f"[Auth] New location login detected but email verification disabled, auto-verifying user {user_id}")
logger.debug(
f"[Auth] New location login detected but email verification disabled, auto-verifying user {user_id}"
)
else:
# 不是新位置登录,正常登录
await LoginLogService.record_login(
@@ -361,20 +343,17 @@ async def oauth_token(
login_method="password",
notes=f"正常登录 - IP: {ip_address}, 国家: {country_code}",
)
# 无论是否新位置登录都返回正常的token
# session_verified状态通过/me接口的session_verified字段来体现
# 生成令牌
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
# 获取用户ID避免触发延迟加载
access_token = create_access_token(
data={"sub": str(user_id)}, expires_delta=access_token_expires
)
access_token = create_access_token(data={"sub": str(user_id)}, expires_delta=access_token_expires)
refresh_token_str = generate_refresh_token()
# 存储令牌
assert user_id
await store_token(
db,
user_id,
@@ -423,9 +402,7 @@ async def oauth_token(
# 生成新的访问令牌
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
access_token = create_access_token(
data={"sub": str(token_record.user_id)}, expires_delta=access_token_expires
)
access_token = create_access_token(data={"sub": str(token_record.user_id)}, expires_delta=access_token_expires)
new_refresh_token = generate_refresh_token()
# 更新令牌
@@ -489,17 +466,11 @@ async def oauth_token(
# 生成令牌
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
# 重新查询只获取ID避免触发延迟加载
id_result = await db.exec(select(User.id).where(User.username == username))
user_id = id_result.first()
access_token = create_access_token(
data={"sub": str(user_id)}, expires_delta=access_token_expires
)
user_id = user.id
access_token = create_access_token(data={"sub": str(user_id)}, expires_delta=access_token_expires)
refresh_token_str = generate_refresh_token()
# 存储令牌
assert user_id
await store_token(
db,
user_id,
@@ -539,9 +510,7 @@ async def oauth_token(
# 生成令牌
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
access_token = create_access_token(
data={"sub": "3"}, expires_delta=access_token_expires
)
access_token = create_access_token(data={"sub": "3"}, expires_delta=access_token_expires)
refresh_token_str = generate_refresh_token()
# 存储令牌
@@ -567,7 +536,7 @@ async def oauth_token(
@router.post(
"/password-reset/request",
name="请求密码重置",
description="通过邮箱请求密码重置验证码"
description="通过邮箱请求密码重置验证码",
)
async def request_password_reset(
request: Request,
@@ -578,42 +547,26 @@ async def request_password_reset(
请求密码重置
"""
from app.dependencies.geoip import get_client_ip
# 获取客户端信息
ip_address = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "")
# 请求密码重置
success, message = await password_reset_service.request_password_reset(
email=email.lower().strip(),
ip_address=ip_address,
user_agent=user_agent,
redis=redis
redis=redis,
)
if success:
return JSONResponse(
status_code=200,
content={
"success": True,
"message": message
}
)
return JSONResponse(status_code=200, content={"success": True, "message": message})
else:
return JSONResponse(
status_code=400,
content={
"success": False,
"error": message
}
)
return JSONResponse(status_code=400, content={"success": False, "error": message})
@router.post(
"/password-reset/reset",
name="重置密码",
description="使用验证码重置密码"
)
@router.post("/password-reset/reset", name="重置密码", description="使用验证码重置密码")
async def reset_password(
request: Request,
email: str = Form(..., description="邮箱地址"),
@@ -625,32 +578,20 @@ async def reset_password(
重置密码
"""
from app.dependencies.geoip import get_client_ip
# 获取客户端信息
ip_address = get_client_ip(request)
# 重置密码
success, message = await password_reset_service.reset_password(
email=email.lower().strip(),
reset_code=reset_code.strip(),
new_password=new_password,
ip_address=ip_address,
redis=redis
redis=redis,
)
if success:
return JSONResponse(
status_code=200,
content={
"success": True,
"message": message
}
)
return JSONResponse(status_code=200, content={"success": True, "message": message})
else:
return JSONResponse(
status_code=400,
content={
"success": False,
"error": message
}
)
return JSONResponse(status_code=400, content={"success": False, "error": message})

View File

@@ -43,9 +43,9 @@ async def get_notifications(
current_user: User = Security(get_client_user),
):
if settings.server_url is not None:
notification_endpoint = f"{settings.server_url}notification-server".replace(
"http://", "ws://"
).replace("https://", "wss://")
notification_endpoint = f"{settings.server_url}notification-server".replace("http://", "ws://").replace(
"https://", "wss://"
)
else:
notification_endpoint = "/notification-server"
query = select(UserNotification).where(
@@ -96,21 +96,15 @@ async def _get_notifications(
query = base_query.where(UserNotification.notification_id == identity.id)
if identity.object_id is not None:
query = base_query.where(
col(UserNotification.notification).has(
col(Notification.object_id) == identity.object_id
)
col(UserNotification.notification).has(col(Notification.object_id) == identity.object_id)
)
if identity.object_type is not None:
query = base_query.where(
col(UserNotification.notification).has(
col(Notification.object_type) == identity.object_type
)
col(UserNotification.notification).has(col(Notification.object_type) == identity.object_type)
)
if identity.category is not None:
query = base_query.where(
col(UserNotification.notification).has(
col(Notification.category) == identity.category
)
col(UserNotification.notification).has(col(Notification.category) == identity.category)
)
result.update({n.notification_id: n for n in await session.exec(query)})
return list(result.values())
@@ -134,7 +128,6 @@ async def mark_notifications_as_read(
for user_notification in user_notifications:
user_notification.is_read = True
assert current_user.id
await server.send_event(
current_user.id,
ChatEvent(

View File

@@ -91,9 +91,7 @@ class Bot:
if reply:
await self._send_reply(user, channel, reply, session)
async def _send_message(
self, channel: ChatChannel, content: str, session: AsyncSession
) -> None:
async def _send_message(self, channel: ChatChannel, content: str, session: AsyncSession) -> None:
bot = await session.get(User, self.bot_user_id)
if bot is None:
return
@@ -101,7 +99,6 @@ class Bot:
if channel_id is None:
return
assert bot.id is not None
msg = ChatMessage(
channel_id=channel_id,
content=content,
@@ -115,9 +112,7 @@ class Bot:
resp = await ChatMessageResp.from_db(msg, session, bot)
await server.send_message_to_channel(resp)
async def _ensure_pm_channel(
self, user: User, session: AsyncSession
) -> ChatChannel | None:
async def _ensure_pm_channel(self, user: User, session: AsyncSession) -> ChatChannel | None:
user_id = user.id
if user_id is None:
return None
@@ -160,9 +155,7 @@ bot = Bot()
@bot.command("help")
async def _help(
user: User, args: list[str], _session: AsyncSession, channel: ChatChannel
) -> str:
async def _help(user: User, args: list[str], _session: AsyncSession, channel: ChatChannel) -> str:
cmds = sorted(bot._handlers.keys())
if args:
target = args[0].lower()
@@ -175,9 +168,7 @@ async def _help(
@bot.command("roll")
def _roll(
user: User, args: list[str], _session: AsyncSession, channel: ChatChannel
) -> str:
def _roll(user: User, args: list[str], _session: AsyncSession, channel: ChatChannel) -> str:
if len(args) > 0 and args[0].isdigit():
r = random.randint(1, int(args[0]))
else:
@@ -186,13 +177,9 @@ def _roll(
@bot.command("stats")
async def _stats(
user: User, args: list[str], session: AsyncSession, channel: ChatChannel
) -> str:
async def _stats(user: User, args: list[str], session: AsyncSession, channel: ChatChannel) -> str:
if len(args) >= 1:
target_user = (
await session.exec(select(User).where(User.username == args[0]))
).first()
target_user = (await session.exec(select(User).where(User.username == args[0]))).first()
if not target_user:
return f"User '{args[0]}' not found."
else:
@@ -202,14 +189,8 @@ async def _stats(
if len(args) >= 2:
gamemode = GameMode.parse(args[1].upper())
if gamemode is None:
subquery = (
select(func.max(Score.id))
.where(Score.user_id == target_user.id)
.scalar_subquery()
)
last_score = (
await session.exec(select(Score).where(Score.id == subquery))
).first()
subquery = select(func.max(Score.id)).where(Score.user_id == target_user.id).scalar_subquery()
last_score = (await session.exec(select(Score).where(Score.id == subquery))).first()
if last_score is not None:
gamemode = last_score.gamemode
else:
@@ -295,9 +276,7 @@ async def _mp_host(
return "Usage: !mp host <username>"
username = args[0]
user_id = (
await session.exec(select(User.id).where(User.username == username))
).first()
user_id = (await session.exec(select(User.id).where(User.username == username))).first()
if not user_id:
return f"User '{username}' not found."
@@ -362,24 +341,18 @@ async def _mp_team(
if team is None:
return "Invalid team colour. Use 'red' or 'blue'."
user_id = (
await session.exec(select(User.id).where(User.username == username))
).first()
user_id = (await session.exec(select(User.id).where(User.username == username))).first()
if not user_id:
return f"User '{username}' not found."
user_client = MultiplayerHubs.get_client_by_id(str(user_id))
if not user_client:
return f"User '{username}' is not in the room."
if (
user_client.user_id != signalr_client.user_id
and room.room.host.user_id != signalr_client.user_id
):
assert room.room.host
if user_client.user_id != signalr_client.user_id and room.room.host.user_id != signalr_client.user_id:
return "You are not allowed to change other users' teams."
try:
await MultiplayerHubs.SendMatchRequest(
user_client, ChangeTeamRequest(team_id=team)
)
await MultiplayerHubs.SendMatchRequest(user_client, ChangeTeamRequest(team_id=team))
return ""
except InvokeException as e:
return e.message
@@ -414,9 +387,7 @@ async def _mp_kick(
return "Usage: !mp kick <username>"
username = args[0]
user_id = (
await session.exec(select(User.id).where(User.username == username))
).first()
user_id = (await session.exec(select(User.id).where(User.username == username))).first()
if not user_id:
return f"User '{username}' not found."
@@ -456,10 +427,7 @@ async def _mp_map(
try:
beatmap = await Beatmap.get_or_fetch(session, await get_fetcher(), bid=map_id)
if beatmap.mode != GameMode.OSU and playmode and playmode != beatmap.mode:
return (
f"Cannot convert to {playmode.value}. "
f"Original mode is {beatmap.mode.value}."
)
return f"Cannot convert to {playmode.value}. Original mode is {beatmap.mode.value}."
except HTTPError:
return "Beatmap not found"
@@ -530,9 +498,7 @@ async def _mp_mods(
if freestyle:
item.allowed_mods = []
elif freemod:
item.allowed_mods = get_available_mods(
current_item.ruleset_id, required_mods
)
item.allowed_mods = get_available_mods(current_item.ruleset_id, required_mods)
else:
item.allowed_mods = allowed_mods
item.required_mods = required_mods
@@ -601,14 +567,9 @@ async def _score(
include_fail: bool = False,
gamemode: GameMode | None = None,
) -> str:
q = (
select(Score)
.where(Score.user_id == user_id)
.order_by(col(Score.id).desc())
.options(joinedload(Score.beatmap))
)
q = select(Score).where(Score.user_id == user_id).order_by(col(Score.id).desc()).options(joinedload(Score.beatmap))
if not include_fail:
q = q.where(Score.passed.is_(True))
q = q.where(col(Score.passed).is_(True))
if gamemode is not None:
q = q.where(Score.gamemode == gamemode)
@@ -619,17 +580,13 @@ async def _score(
result = f"""{score.beatmap.beatmapset.title} [{score.beatmap.version}] ({score.gamemode.name.lower()})
Played at {score.started_at}
{score.pp:.2f}pp {score.accuracy:.2%} {",".join(mod_to_save(score.mods))} {score.rank.name.upper()}
Great: {score.n300}, Good: {score.n100}, Meh: {score.n50}, Miss: {score.nmiss}""" # noqa: E501
Great: {score.n300}, Good: {score.n100}, Meh: {score.n50}, Miss: {score.nmiss}"""
if score.gamemode == GameMode.MANIA:
keys = next(
(mod["acronym"] for mod in score.mods if mod["acronym"].endswith("K")), None
)
keys = next((mod["acronym"] for mod in score.mods if mod["acronym"].endswith("K")), None)
if keys is None:
keys = f"{int(score.beatmap.cs)}K"
p_d_g = f"{score.ngeki / score.n300:.2f}:1" if score.n300 > 0 else "inf:1"
result += (
f"\nKeys: {keys}, Perfect: {score.ngeki}, Ok: {score.nkatu}, P/G: {p_d_g}"
)
result += f"\nKeys: {keys}, Perfect: {score.ngeki}, Ok: {score.nkatu}, P/G: {p_d_g}"
return result

View File

@@ -38,27 +38,18 @@ class UpdateResponse(BaseModel):
)
async def get_update(
session: Database,
history_since: int | None = Query(
None, description="获取自此禁言 ID 之后的禁言记录"
),
history_since: int | None = Query(None, description="获取自此禁言 ID 之后的禁言记录"),
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
includes: list[str] = Query(
["presence", "silences"], alias="includes[]", description="要包含的更新类型"
),
includes: list[str] = Query(["presence", "silences"], alias="includes[]", description="要包含的更新类型"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
redis: Redis = Depends(get_redis),
):
resp = UpdateResponse()
if "presence" in includes:
assert current_user.id
channel_ids = server.get_user_joined_channel(current_user.id)
for channel_id in channel_ids:
# 使用明确的查询避免延迟加载
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
)
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
if db_channel:
# 提取必要的属性避免惰性加载
channel_type = db_channel.type
@@ -69,34 +60,20 @@ async def get_update(
session,
current_user,
redis,
server.channels.get(channel_id, [])
if channel_type != ChannelType.PUBLIC
else None,
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
)
)
if "silences" in includes:
if history_since:
silences = (
await session.exec(
select(SilenceUser).where(col(SilenceUser.id) > history_since)
)
).all()
resp.silences.extend(
[UserSilenceResp.from_db(silence) for silence in silences]
)
silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.id) > history_since))).all()
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
elif since:
msg = await session.get(ChatMessage, since)
if msg:
silences = (
await session.exec(
select(SilenceUser).where(
col(SilenceUser.banned_at) > msg.timestamp
)
)
await session.exec(select(SilenceUser).where(col(SilenceUser.banned_at) > msg.timestamp))
).all()
resp.silences.extend(
[UserSilenceResp.from_db(silence) for silence in silences]
)
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
return resp
@@ -115,15 +92,9 @@ async def join_channel(
):
# 使用明确的查询避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
else:
db_channel = (
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
@@ -145,15 +116,9 @@ async def leave_channel(
):
# 使用明确的查询避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
else:
db_channel = (
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
@@ -173,27 +138,20 @@ async def get_channel_list(
current_user: User = Security(get_current_user, scopes=["chat.read"]),
redis: Redis = Depends(get_redis),
):
channels = (
await session.exec(
select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC)
)
).all()
channels = (await session.exec(select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC))).all()
results = []
for channel in channels:
# 提取必要的属性避免惰性加载
channel_id = channel.channel_id
channel_type = channel.type
assert channel_id is not None
results.append(
await ChatChannelResp.from_db(
channel,
session,
current_user,
redis,
server.channels.get(channel_id, [])
if channel_type != ChannelType.PUBLIC
else None,
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
)
)
return results
@@ -219,15 +177,9 @@ async def get_channel(
):
# 使用明确的查询避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
else:
db_channel = (
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
@@ -237,8 +189,6 @@ async def get_channel(
channel_type = db_channel.type
channel_name = db_channel.name
assert channel_id is not None
users = []
if channel_type == ChannelType.PM:
user_ids = channel_name.split("_")[1:]
@@ -259,9 +209,7 @@ async def get_channel(
session,
current_user,
redis,
server.channels.get(channel_id, [])
if channel_type != ChannelType.PUBLIC
else None,
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
)
)
@@ -284,9 +232,7 @@ class CreateChannelReq(BaseModel):
raise ValueError("target_id must be set for PM channels")
else:
if self.target_ids is None or self.channel is None or self.message is None:
raise ValueError(
"target_ids, channel, and message must be set for ANNOUNCE channels"
)
raise ValueError("target_ids, channel, and message must be set for ANNOUNCE channels")
return self
@@ -312,24 +258,20 @@ async def create_channel(
raise HTTPException(status_code=403, detail=block)
channel = await ChatChannel.get_pm_channel(
current_user.id, # pyright: ignore[reportArgumentType]
current_user.id,
req.target_id, # pyright: ignore[reportArgumentType]
session,
)
channel_name = f"pm_{current_user.id}_{req.target_id}"
else:
channel_name = req.channel.name if req.channel else "Unnamed Channel"
result = await session.exec(
select(ChatChannel).where(ChatChannel.name == channel_name)
)
result = await session.exec(select(ChatChannel).where(ChatChannel.name == channel_name))
channel = result.first()
if channel is None:
channel = ChatChannel(
name=channel_name,
description=req.channel.description
if req.channel
else "Private message channel",
description=req.channel.description if req.channel else "Private message channel",
type=ChannelType.PM if req.type == "PM" else ChannelType.ANNOUNCE,
)
session.add(channel)
@@ -340,16 +282,13 @@ async def create_channel(
await session.refresh(target) # pyright: ignore[reportPossiblyUnboundVariable]
await server.batch_join_channel([target, current_user], channel, session) # pyright: ignore[reportPossiblyUnboundVariable]
else:
target_users = await session.exec(
select(User).where(col(User.id).in_(req.target_ids or []))
)
target_users = await session.exec(select(User).where(col(User.id).in_(req.target_ids or [])))
await server.batch_join_channel([*target_users, current_user], channel, session)
await server.join_channel(current_user, channel, session)
# 提取必要的属性避免惰性加载
channel_id = channel.channel_id
assert channel_id
return await ChatChannelResp.from_db(
channel,

View File

@@ -41,33 +41,19 @@ class KeepAliveResp(BaseModel):
)
async def keep_alive(
session: Database,
history_since: int | None = Query(
None, description="获取自此禁言 ID 之后的禁言记录"
),
history_since: int | None = Query(None, description="获取自此禁言 ID 之后的禁言记录"),
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
):
resp = KeepAliveResp()
if history_since:
silences = (
await session.exec(
select(SilenceUser).where(col(SilenceUser.id) > history_since)
)
).all()
silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.id) > history_since))).all()
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
elif since:
msg = await session.get(ChatMessage, since)
if msg:
silences = (
await session.exec(
select(SilenceUser).where(
col(SilenceUser.banned_at) > msg.timestamp
)
)
).all()
resp.silences.extend(
[UserSilenceResp.from_db(silence) for silence in silences]
)
silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.banned_at) > msg.timestamp))).all()
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
return resp
@@ -93,15 +79,9 @@ async def send_message(
):
# 使用明确的查询来获取 channel避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
else:
db_channel = (
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
@@ -111,9 +91,6 @@ async def send_message(
channel_type = db_channel.type
channel_name = db_channel.name
assert channel_id is not None
assert current_user.id
# 使用 Redis 消息系统发送消息 - 立即返回
resp = await redis_message_system.send_message(
channel_id=channel_id,
@@ -125,9 +102,7 @@ async def send_message(
# 立即广播消息给所有客户端
is_bot_command = req.message.startswith("!")
await server.send_message_to_channel(
resp, is_bot_command and channel_type == ChannelType.PUBLIC
)
await server.send_message_to_channel(resp, is_bot_command and channel_type == ChannelType.PUBLIC)
# 处理机器人命令
if is_bot_command:
@@ -147,14 +122,10 @@ async def send_message(
if channel_type == ChannelType.PM:
user_ids = channel_name.split("_")[1:]
await server.new_private_notification(
ChannelMessage.init(
temp_msg, current_user, [int(u) for u in user_ids], channel_type
)
ChannelMessage.init(temp_msg, current_user, [int(u) for u in user_ids], channel_type)
)
elif channel_type == ChannelType.TEAM:
await server.new_private_notification(
ChannelMessageTeam.init(temp_msg, current_user)
)
await server.new_private_notification(ChannelMessageTeam.init(temp_msg, current_user))
return resp
@@ -176,22 +147,15 @@ async def get_message(
):
# 使用明确的查询获取 channel避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
else:
db_channel = (
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
# 提取必要的属性避免惰性加载
channel_id = db_channel.channel_id
assert channel_id is not None
# 使用 Redis 消息系统获取消息
try:
@@ -230,23 +194,15 @@ async def mark_as_read(
):
# 使用明确的查询获取 channel避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
else:
db_channel = (
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
# 立即提取需要的属性
channel_id = db_channel.channel_id
assert channel_id
assert current_user.id
await server.mark_as_read(channel_id, current_user.id, message)
@@ -283,7 +239,6 @@ async def create_new_pm(
if not is_can_pm:
raise HTTPException(status_code=403, detail=block)
assert user_id
channel = await ChatChannel.get_pm_channel(user_id, req.target_id, session)
if channel is None:
channel = ChatChannel(
@@ -297,7 +252,6 @@ async def create_new_pm(
await session.refresh(target)
await session.refresh(current_user)
assert channel.channel_id
await server.batch_join_channel([target, current_user], channel, session)
channel_resp = await ChatChannelResp.from_db(
channel, session, current_user, redis, server.channels[channel.channel_id]

View File

@@ -17,6 +17,7 @@ from app.log import logger
from app.models.chat import ChatEvent
from app.models.notification import NotificationDetail
from app.service.subscribers.chat import ChatSubscriber
from app.utils import bg_tasks
from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect
from fastapi.security import SecurityScopes
@@ -37,20 +38,11 @@ class ChatServer:
self.ChatSubscriber.chat_server = self
self._subscribed = False
def _add_task(self, task):
task = asyncio.create_task(task)
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
def connect(self, user_id: int, client: WebSocket):
self.connect_client[user_id] = client
def get_user_joined_channel(self, user_id: int) -> list[int]:
return [
channel_id
for channel_id, users in self.channels.items()
if user_id in users
]
return [channel_id for channel_id, users in self.channels.items() if user_id in users]
async def disconnect(self, user: User, session: AsyncSession):
user_id = user.id
@@ -61,9 +53,7 @@ class ChatServer:
channel.remove(user_id)
# 使用明确的查询避免延迟加载
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
)
await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))
).first()
if db_channel:
await self.leave_channel(user, db_channel, session)
@@ -93,11 +83,10 @@ class ChatServer:
async def mark_as_read(self, channel_id: int, user_id: int, message_id: int):
await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id)
async def send_message_to_channel(
self, message: ChatMessageResp, is_bot_command: bool = False
):
async def send_message_to_channel(self, message: ChatMessageResp, is_bot_command: bool = False):
logger.info(
f"Sending message to channel {message.channel_id}, message_id: {message.message_id}, is_bot_command: {is_bot_command}"
f"Sending message to channel {message.channel_id}, message_id: "
f"{message.message_id}, is_bot_command: {is_bot_command}"
)
event = ChatEvent(
@@ -106,62 +95,44 @@ class ChatServer:
)
if is_bot_command:
logger.info(f"Sending bot command to user {message.sender_id}")
self._add_task(self.send_event(message.sender_id, event))
bg_tasks.add_task(self.send_event, message.sender_id, event)
else:
# 总是广播消息无论是临时ID还是真实ID
logger.info(
f"Broadcasting message to all users in channel {message.channel_id}"
)
self._add_task(
self.broadcast(
message.channel_id,
event,
)
logger.info(f"Broadcasting message to all users in channel {message.channel_id}")
bg_tasks.add_task(
self.broadcast,
message.channel_id,
event,
)
# 只有真实消息 ID正数且非零才进行标记已读和设置最后消息
# Redis 消息系统生成的ID都是正数所以这里应该都能正常处理
if message.message_id and message.message_id > 0:
await self.mark_as_read(
message.channel_id, message.sender_id, message.message_id
)
await self.redis.set(
f"chat:{message.channel_id}:last_msg", message.message_id
)
logger.info(
f"Updated last message ID for channel {message.channel_id} to {message.message_id}"
)
await self.mark_as_read(message.channel_id, message.sender_id, message.message_id)
await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
logger.info(f"Updated last message ID for channel {message.channel_id} to {message.message_id}")
else:
logger.debug(
f"Skipping last message update for message ID: {message.message_id}"
)
logger.debug(f"Skipping last message update for message ID: {message.message_id}")
async def batch_join_channel(
self, users: list[User], channel: ChatChannel, session: AsyncSession
):
async def batch_join_channel(self, users: list[User], channel: ChatChannel, session: AsyncSession):
channel_id = channel.channel_id
assert channel_id is not None
not_joined = []
if channel_id not in self.channels:
self.channels[channel_id] = []
for user in users:
assert user.id is not None
if user.id not in self.channels[channel_id]:
self.channels[channel_id].append(user.id)
not_joined.append(user)
for user in not_joined:
assert user.id is not None
channel_resp = await ChatChannelResp.from_db(
channel,
session,
user,
self.redis,
self.channels[channel_id]
if channel.type != ChannelType.PUBLIC
else None,
self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
)
await self.send_event(
user.id,
@@ -171,13 +142,9 @@ class ChatServer:
),
)
async def join_channel(
self, user: User, channel: ChatChannel, session: AsyncSession
) -> ChatChannelResp:
async def join_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> ChatChannelResp:
user_id = user.id
channel_id = channel.channel_id
assert channel_id is not None
assert user_id is not None
if channel_id not in self.channels:
self.channels[channel_id] = []
@@ -202,13 +169,9 @@ class ChatServer:
return channel_resp
async def leave_channel(
self, user: User, channel: ChatChannel, session: AsyncSession
) -> None:
async def leave_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> None:
user_id = user.id
channel_id = channel.channel_id
assert channel_id is not None
assert user_id is not None
if channel_id in self.channels and user_id in self.channels[channel_id]:
self.channels[channel_id].remove(user_id)
@@ -221,9 +184,7 @@ class ChatServer:
session,
user,
self.redis,
self.channels.get(channel_id)
if channel.type != ChannelType.PUBLIC
else None,
self.channels.get(channel_id) if channel.type != ChannelType.PUBLIC else None,
)
await self.send_event(
user_id,
@@ -236,11 +197,7 @@ class ChatServer:
async def join_room_channel(self, channel_id: int, user_id: int):
async with with_db() as session:
# 使用明确的查询避免延迟加载
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
)
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
if db_channel is None:
return
@@ -253,11 +210,7 @@ class ChatServer:
async def leave_room_channel(self, channel_id: int, user_id: int):
async with with_db() as session:
# 使用明确的查询避免延迟加载
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
)
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
if db_channel is None:
return
@@ -270,13 +223,7 @@ class ChatServer:
async def new_private_notification(self, detail: NotificationDetail):
async with with_db() as session:
id = await insert_notification(session, detail)
users = (
await session.exec(
select(UserNotification).where(
UserNotification.notification_id == id
)
)
).all()
users = (await session.exec(select(UserNotification).where(UserNotification.notification_id == id))).all()
for user_notification in users:
data = user_notification.notification.model_dump()
data["is_read"] = user_notification.is_read
@@ -308,9 +255,7 @@ async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory):
await ws.close(code=1000)
break
except WebSocketDisconnect as e:
logger.info(
f"[NotificationServer] Client {user_id} disconnected: {e.code}, {e.reason}"
)
logger.info(f"[NotificationServer] Client {user_id} disconnected: {e.code}, {e.reason}")
except RuntimeError as e:
if "disconnect message" in str(e):
logger.info(f"[NotificationServer] Client {user_id} closed the connection.")
@@ -332,11 +277,7 @@ async def chat_websocket(
async for session in factory():
token = authorization[7:]
if (
user := await get_current_user(
session, SecurityScopes(scopes=["chat.read"]), token_pw=token
)
) is None:
if (user := await get_current_user(session, SecurityScopes(scopes=["chat.read"]), token_pw=token)) is None:
await websocket.close(code=1008)
return
@@ -346,12 +287,9 @@ async def chat_websocket(
await websocket.close(code=1008)
return
user_id = user.id
assert user_id
server.connect(user_id, websocket)
# 使用明确的查询避免延迟加载
db_channel = (
await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))).first()
if db_channel is not None:
await server.join_channel(user, db_channel, session)

View File

@@ -2,22 +2,20 @@
密码重置管理接口
"""
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from redis.asyncio import Redis
from __future__ import annotations
from app.dependencies.database import get_redis
from app.service.password_reset_service import password_reset_service
from app.log import logger
from app.service.password_reset_service import password_reset_service
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from redis.asyncio import Redis
router = APIRouter(prefix="/admin/password-reset", tags=["密码重置管理"])
@router.get(
"/status/{email}",
name="查询重置状态",
description="查询指定邮箱的密码重置状态"
)
@router.get("/status/{email}", name="查询重置状态", description="查询指定邮箱的密码重置状态")
async def get_password_reset_status(
email: str,
redis: Redis = Depends(get_redis),
@@ -25,28 +23,16 @@ async def get_password_reset_status(
"""查询密码重置状态"""
try:
info = await password_reset_service.get_reset_code_info(email, redis)
return JSONResponse(
status_code=200,
content={
"success": True,
"data": info
}
)
return JSONResponse(status_code=200, content={"success": True, "data": info})
except Exception as e:
logger.error(f"[Admin] Failed to get password reset status for {email}: {e}")
return JSONResponse(
status_code=500,
content={
"success": False,
"error": "获取状态失败"
}
)
return JSONResponse(status_code=500, content={"success": False, "error": "获取状态失败"})
@router.delete(
"/cleanup/{email}",
name="清理重置数据",
description="强制清理指定邮箱的密码重置数据"
description="强制清理指定邮箱的密码重置数据",
)
async def force_cleanup_reset(
email: str,
@@ -55,38 +41,23 @@ async def force_cleanup_reset(
"""强制清理密码重置数据"""
try:
success = await password_reset_service.force_cleanup_user_reset(email, redis)
if success:
return JSONResponse(
status_code=200,
content={
"success": True,
"message": f"已清理邮箱 {email} 的重置数据"
}
content={"success": True, "message": f"已清理邮箱 {email} 的重置数据"},
)
else:
return JSONResponse(
status_code=500,
content={
"success": False,
"error": "清理失败"
}
)
return JSONResponse(status_code=500, content={"success": False, "error": "清理失败"})
except Exception as e:
logger.error(f"[Admin] Failed to cleanup password reset for {email}: {e}")
return JSONResponse(
status_code=500,
content={
"success": False,
"error": "清理操作失败"
}
)
return JSONResponse(status_code=500, content={"success": False, "error": "清理操作失败"})
@router.post(
"/cleanup/expired",
name="清理过期验证码",
description="清理所有过期的密码重置验证码"
description="清理所有过期的密码重置验证码",
)
async def cleanup_expired_codes(
redis: Redis = Depends(get_redis),
@@ -99,25 +70,15 @@ async def cleanup_expired_codes(
content={
"success": True,
"message": f"已清理 {count} 个过期的验证码",
"cleaned_count": count
}
"cleaned_count": count,
},
)
except Exception as e:
logger.error(f"[Admin] Failed to cleanup expired codes: {e}")
return JSONResponse(
status_code=500,
content={
"success": False,
"error": "清理操作失败"
}
)
return JSONResponse(status_code=500, content={"success": False, "error": "清理操作失败"})
@router.get(
"/stats",
name="重置统计",
description="获取密码重置的统计信息"
)
@router.get("/stats", name="重置统计", description="获取密码重置的统计信息")
async def get_reset_statistics(
redis: Redis = Depends(get_redis),
):
@@ -126,53 +87,42 @@ async def get_reset_statistics(
# 获取所有重置相关的键
reset_keys = await redis.keys("password_reset:code:*")
rate_limit_keys = await redis.keys("password_reset:rate_limit:*")
active_resets = 0
used_resets = 0
active_rate_limits = 0
# 统计活跃重置
for key in reset_keys:
data_str = await redis.get(key)
if data_str:
try:
import json
data = json.loads(data_str)
if data.get("used", False):
used_resets += 1
else:
active_resets += 1
except:
except Exception:
pass
# 统计频率限制
for key in rate_limit_keys:
ttl = await redis.ttl(key)
if ttl > 0:
active_rate_limits += 1
stats = {
"total_reset_codes": len(reset_keys),
"active_resets": active_resets,
"used_resets": used_resets,
"active_rate_limits": active_rate_limits,
"total_rate_limit_keys": len(rate_limit_keys)
"total_rate_limit_keys": len(rate_limit_keys),
}
return JSONResponse(
status_code=200,
content={
"success": True,
"data": stats
}
)
return JSONResponse(status_code=200, content={"success": True, "data": stats})
except Exception as e:
logger.error(f"[Admin] Failed to get reset statistics: {e}")
return JSONResponse(
status_code=500,
content={
"success": False,
"error": "获取统计信息失败"
}
)
return JSONResponse(status_code=500, content={"success": False, "error": "获取统计信息失败"})

View File

@@ -26,7 +26,7 @@ async def create_oauth_app(
redirect_uris: list[str] = Body(..., description="允许的重定向 URI 列表"),
current_user: User = Security(get_client_user),
):
result = await session.execute( # pyright: ignore[reportDeprecated]
result = await session.execute(
text(
"SELECT AUTO_INCREMENT FROM information_schema.TABLES "
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'oauth_clients'"
@@ -84,9 +84,7 @@ async def get_user_oauth_apps(
session: Database,
current_user: User = Security(get_client_user),
):
oauth_apps = await session.exec(
select(OAuthClient).where(OAuthClient.owner_id == current_user.id)
)
oauth_apps = await session.exec(select(OAuthClient).where(OAuthClient.owner_id == current_user.id))
return [
{
"name": app.name,
@@ -113,13 +111,9 @@ async def delete_oauth_app(
if not oauth_client:
raise HTTPException(status_code=404, detail="OAuth app not found")
if oauth_client.owner_id != current_user.id:
raise HTTPException(
status_code=403, detail="Forbidden: Not the owner of this app"
)
raise HTTPException(status_code=403, detail="Forbidden: Not the owner of this app")
tokens = await session.exec(
select(OAuthToken).where(OAuthToken.client_id == client_id)
)
tokens = await session.exec(select(OAuthToken).where(OAuthToken.client_id == client_id))
for token in tokens:
await session.delete(token)
@@ -144,9 +138,7 @@ async def update_oauth_app(
if not oauth_client:
raise HTTPException(status_code=404, detail="OAuth app not found")
if oauth_client.owner_id != current_user.id:
raise HTTPException(
status_code=403, detail="Forbidden: Not the owner of this app"
)
raise HTTPException(status_code=403, detail="Forbidden: Not the owner of this app")
oauth_client.name = name
oauth_client.description = description
@@ -176,14 +168,10 @@ async def refresh_secret(
if not oauth_client:
raise HTTPException(status_code=404, detail="OAuth app not found")
if oauth_client.owner_id != current_user.id:
raise HTTPException(
status_code=403, detail="Forbidden: Not the owner of this app"
)
raise HTTPException(status_code=403, detail="Forbidden: Not the owner of this app")
oauth_client.client_secret = secrets.token_hex()
tokens = await session.exec(
select(OAuthToken).where(OAuthToken.client_id == client_id)
)
tokens = await session.exec(select(OAuthToken).where(OAuthToken.client_id == client_id))
for token in tokens:
await session.delete(token)
@@ -215,9 +203,7 @@ async def generate_oauth_code(
raise HTTPException(status_code=404, detail="OAuth app not found")
if redirect_uri not in client.redirect_uris:
raise HTTPException(
status_code=403, detail="Redirect URI not allowed for this client"
)
raise HTTPException(status_code=403, detail="Redirect URI not allowed for this client")
code = secrets.token_urlsafe(80)
await redis.hset( # pyright: ignore[reportGeneralTypeIssues]

View File

@@ -50,12 +50,8 @@ async def check_user_relationship(
)
).first()
is_followed = bool(
target_relationship and target_relationship.type == RelationshipType.FOLLOW
)
is_following = bool(
my_relationship and my_relationship.type == RelationshipType.FOLLOW
)
is_followed = bool(target_relationship and target_relationship.type == RelationshipType.FOLLOW)
is_following = bool(my_relationship and my_relationship.type == RelationshipType.FOLLOW)
return CheckResponse(
is_followed=is_followed,

View File

@@ -40,16 +40,13 @@ async def create_team(
支持的图片格式: PNG、JPEG、GIF
"""
user_id = current_user.id
assert user_id
if (await current_user.awaitable_attrs.team_membership) is not None:
raise HTTPException(status_code=403, detail="You are already in a team")
is_existed = (await session.exec(select(exists()).where(Team.name == name))).first()
if is_existed:
raise HTTPException(status_code=409, detail="Name already exists")
is_existed = (
await session.exec(select(exists()).where(Team.short_name == short_name))
).first()
is_existed = (await session.exec(select(exists()).where(Team.short_name == short_name))).first()
if is_existed:
raise HTTPException(status_code=409, detail="Short name already exists")
@@ -101,7 +98,6 @@ async def update_team(
"""
team = await session.get(Team, team_id)
user_id = current_user.id
assert user_id
if not team:
raise HTTPException(status_code=404, detail="Team not found")
if team.leader_id != user_id:
@@ -110,9 +106,7 @@ async def update_team(
is_existed = (await session.exec(select(exists()).where(Team.name == name))).first()
if is_existed:
raise HTTPException(status_code=409, detail="Name already exists")
is_existed = (
await session.exec(select(exists()).where(Team.short_name == short_name))
).first()
is_existed = (await session.exec(select(exists()).where(Team.short_name == short_name))).first()
if is_existed:
raise HTTPException(status_code=409, detail="Short name already exists")
@@ -132,20 +126,12 @@ async def update_team(
team.cover_url = await storage.get_file_url(storage_path)
if leader_id is not None:
if not (
await session.exec(select(exists()).where(User.id == leader_id))
).first():
if not (await session.exec(select(exists()).where(User.id == leader_id))).first():
raise HTTPException(status_code=404, detail="Leader not found")
if not (
await session.exec(
select(TeamMember).where(
TeamMember.user_id == leader_id, TeamMember.team_id == team.id
)
)
await session.exec(select(TeamMember).where(TeamMember.user_id == leader_id, TeamMember.team_id == team.id))
).first():
raise HTTPException(
status_code=404, detail="Leader is not a member of the team"
)
raise HTTPException(status_code=404, detail="Leader is not a member of the team")
team.leader_id = leader_id
await session.commit()
@@ -166,9 +152,7 @@ async def delete_team(
if team.leader_id != current_user.id:
raise HTTPException(status_code=403, detail="You are not the team leader")
team_members = await session.exec(
select(TeamMember).where(TeamMember.team_id == team_id)
)
team_members = await session.exec(select(TeamMember).where(TeamMember.team_id == team_id))
for member in team_members:
await session.delete(member)
@@ -186,15 +170,10 @@ async def get_team(
session: Database,
team_id: int = Path(..., description="战队 ID"),
):
members = (
await session.exec(select(TeamMember).where(TeamMember.team_id == team_id))
).all()
members = (await session.exec(select(TeamMember).where(TeamMember.team_id == team_id))).all()
return TeamQueryResp(
team=members[0].team,
members=[
await UserResp.from_db(m.user, session, include=BASE_INCLUDES)
for m in members
],
members=[await UserResp.from_db(m.user, session, include=BASE_INCLUDES) for m in members],
)
@@ -213,15 +192,11 @@ async def request_join_team(
if (
await session.exec(
select(exists()).where(
TeamRequest.team_id == team_id, TeamRequest.user_id == current_user.id
)
select(exists()).where(TeamRequest.team_id == team_id, TeamRequest.user_id == current_user.id)
)
).first():
raise HTTPException(status_code=409, detail="Join request already exists")
team_request = TeamRequest(
user_id=current_user.id, team_id=team_id, requested_at=datetime.now(UTC)
)
team_request = TeamRequest(user_id=current_user.id, team_id=team_id, requested_at=datetime.now(UTC))
session.add(team_request)
await session.commit()
await session.refresh(team_request)
@@ -229,9 +204,7 @@ async def request_join_team(
@router.post("/team/{team_id}/{user_id}/request", name="接受加入请求", status_code=204)
@router.delete(
"/team/{team_id}/{user_id}/request", name="拒绝加入请求", status_code=204
)
@router.delete("/team/{team_id}/{user_id}/request", name="拒绝加入请求", status_code=204)
async def handle_request(
req: Request,
session: Database,
@@ -247,11 +220,7 @@ async def handle_request(
raise HTTPException(status_code=403, detail="You are not the team leader")
team_request = (
await session.exec(
select(TeamRequest).where(
TeamRequest.team_id == team_id, TeamRequest.user_id == user_id
)
)
await session.exec(select(TeamRequest).where(TeamRequest.team_id == team_id, TeamRequest.user_id == user_id))
).first()
if not team_request:
raise HTTPException(status_code=404, detail="Join request not found")
@@ -261,16 +230,10 @@ async def handle_request(
raise HTTPException(status_code=404, detail="User not found")
if req.method == "POST":
if (
await session.exec(select(exists()).where(TeamMember.user_id == user_id))
).first():
raise HTTPException(
status_code=409, detail="User is already a member of the team"
)
if (await session.exec(select(exists()).where(TeamMember.user_id == user_id))).first():
raise HTTPException(status_code=409, detail="User is already a member of the team")
session.add(
TeamMember(user_id=user_id, team_id=team_id, joined_at=datetime.now(UTC))
)
session.add(TeamMember(user_id=user_id, team_id=team_id, joined_at=datetime.now(UTC)))
await server.new_private_notification(TeamApplicationAccept.init(team_request))
else:
@@ -294,19 +257,13 @@ async def kick_member(
raise HTTPException(status_code=403, detail="You are not the team leader")
team_member = (
await session.exec(
select(TeamMember).where(
TeamMember.team_id == team_id, TeamMember.user_id == user_id
)
)
await session.exec(select(TeamMember).where(TeamMember.team_id == team_id, TeamMember.user_id == user_id))
).first()
if not team_member:
raise HTTPException(status_code=404, detail="User is not a member of the team")
if team.leader_id == current_user.id:
raise HTTPException(
status_code=403, detail="You cannot leave because you are the team leader"
)
raise HTTPException(status_code=403, detail="You cannot leave because you are the team leader")
await session.delete(team_member)
await session.commit()

View File

@@ -35,10 +35,7 @@ async def user_rename(
返回:
- 成功: None
"""
assert current_user is not None
samename_user = (
await session.exec(select(User).where(User.username == new_name))
).first()
samename_user = (await session.exec(select(User).where(User.username == new_name))).first()
if samename_user:
raise HTTPException(409, "Username Exisits")
errors = validate_username(new_name)

View File

@@ -106,9 +106,7 @@ class V1Beatmap(AllStrModel):
await session.exec(
select(func.count())
.select_from(FavouriteBeatmapset)
.where(
FavouriteBeatmapset.beatmapset_id == db_beatmap.beatmapset.id
)
.where(FavouriteBeatmapset.beatmapset_id == db_beatmap.beatmapset.id)
)
).one(),
rating=0, # TODO
@@ -154,12 +152,8 @@ async def get_beatmaps(
beatmapset_id: int | None = Query(None, alias="s", description="谱面集 ID"),
beatmap_id: int | None = Query(None, alias="b", description="谱面 ID"),
user: str | None = Query(None, alias="u", description="谱师"),
type: Literal["string", "id"] | None = Query(
None, description="用户类型string 用户名称 / id 用户 ID"
),
ruleset_id: int | None = Query(
None, alias="m", description="Ruleset ID", ge=0, le=3
), # TODO
type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"),
ruleset_id: int | None = Query(None, alias="m", description="Ruleset ID", ge=0, le=3), # TODO
convert: bool = Query(False, alias="a", description="转谱"), # TODO
checksum: str | None = Query(None, alias="h", description="谱面文件 MD5"),
limit: int = Query(500, ge=1, le=500, description="返回结果数量限制"),
@@ -181,11 +175,7 @@ async def get_beatmaps(
else:
beatmaps = beatmapset.beatmaps
elif user is not None:
where = (
Beatmapset.user_id == user
if type == "id" or user.isdigit()
else Beatmapset.creator == user
)
where = Beatmapset.user_id == user if type == "id" or user.isdigit() else Beatmapset.creator == user
beatmapsets = (await session.exec(select(Beatmapset).where(where))).all()
for beatmapset in beatmapsets:
if len(beatmaps) >= limit:
@@ -193,11 +183,7 @@ async def get_beatmaps(
beatmaps.extend(beatmapset.beatmaps)
elif since is not None:
beatmapsets = (
await session.exec(
select(Beatmapset)
.where(col(Beatmapset.ranked_date) > since)
.limit(limit)
)
await session.exec(select(Beatmapset).where(col(Beatmapset.ranked_date) > since).limit(limit))
).all()
for beatmapset in beatmapsets:
if len(beatmaps) >= limit:
@@ -214,11 +200,7 @@ async def get_beatmaps(
redis,
fetcher,
)
results.append(
await V1Beatmap.from_db(
session, beatmap, attrs.aim_difficulty, attrs.speed_difficulty
)
)
results.append(await V1Beatmap.from_db(session, beatmap, attrs.aim_difficulty, attrs.speed_difficulty))
continue
except Exception:
...

View File

@@ -41,9 +41,7 @@ async def download_replay(
ge=0,
),
score_id: int | None = Query(None, alias="s", description="成绩 ID"),
type: Literal["string", "id"] | None = Query(
None, description="用户类型string 用户名称 / id 用户 ID"
),
type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"),
mods: int = Query(0, description="成绩的 MOD"),
storage_service: StorageService = Depends(get_storage_service),
):
@@ -58,13 +56,9 @@ async def download_replay(
await session.exec(
select(Score).where(
Score.beatmap_id == beatmap,
Score.user_id == user
if type == "id" or user.isdigit()
else col(Score.user).has(username=user),
Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user),
Score.mods == mods_,
Score.gamemode == GameMode.from_int_extra(ruleset_id)
if ruleset_id is not None
else True,
Score.gamemode == GameMode.from_int_extra(ruleset_id) if ruleset_id is not None else True,
)
)
).first()
@@ -73,10 +67,7 @@ async def download_replay(
except KeyError:
raise HTTPException(status_code=400, detail="Invalid request")
filepath = (
f"replays/{score_record.id}_{score_record.beatmap_id}"
f"_{score_record.user_id}_lazer_replay.osr"
)
filepath = f"replays/{score_record.id}_{score_record.beatmap_id}_{score_record.user_id}_lazer_replay.osr"
if not await storage_service.is_exists(filepath):
raise HTTPException(status_code=404, detail="Replay file not found")
@@ -100,6 +91,4 @@ async def download_replay(
await session.commit()
data = await storage_service.read_file(filepath)
return ReplayModel(
content=base64.b64encode(data).decode("utf-8"), encoding="base64"
)
return ReplayModel(content=base64.b64encode(data).decode("utf-8"), encoding="base64")

View File

@@ -8,9 +8,7 @@ from app.dependencies.user import v1_authorize
from fastapi import APIRouter, Depends
from pydantic import BaseModel, field_serializer
router = APIRouter(
prefix="/api/v1", dependencies=[Depends(v1_authorize)], tags=["V1 API"]
)
router = APIRouter(prefix="/api/v1", dependencies=[Depends(v1_authorize)], tags=["V1 API"])
class AllStrModel(BaseModel):

View File

@@ -70,9 +70,7 @@ async def get_user_best(
session: Database,
user: str = Query(..., alias="u", description="用户"),
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
type: Literal["string", "id"] | None = Query(
None, description="用户类型string 用户名称 / id 用户 ID"
),
type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"),
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
):
try:
@@ -80,9 +78,7 @@ async def get_user_best(
await session.exec(
select(Score)
.where(
Score.user_id == user
if type == "id" or user.isdigit()
else col(Score.user).has(username=user),
Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user),
Score.gamemode == GameMode.from_int_extra(ruleset_id),
exists().where(col(PPBestScore.score_id) == Score.id),
)
@@ -106,9 +102,7 @@ async def get_user_recent(
session: Database,
user: str = Query(..., alias="u", description="用户"),
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
type: Literal["string", "id"] | None = Query(
None, description="用户类型string 用户名称 / id 用户 ID"
),
type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"),
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
):
try:
@@ -116,9 +110,7 @@ async def get_user_recent(
await session.exec(
select(Score)
.where(
Score.user_id == user
if type == "id" or user.isdigit()
else col(Score.user).has(username=user),
Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user),
Score.gamemode == GameMode.from_int_extra(ruleset_id),
Score.ended_at > datetime.now(UTC) - timedelta(hours=24),
)
@@ -143,9 +135,7 @@ async def get_scores(
user: str | None = Query(None, alias="u", description="用户"),
beatmap_id: int = Query(alias="b", description="谱面 ID"),
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
type: Literal["string", "id"] | None = Query(
None, description="用户类型string 用户名称 / id 用户 ID"
),
type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"),
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
mods: int = Query(0, description="成绩的 MOD"),
):
@@ -157,9 +147,7 @@ async def get_scores(
.where(
Score.gamemode == GameMode.from_int_extra(ruleset_id),
Score.beatmap_id == beatmap_id,
Score.user_id == user
if type == "id" or user.isdigit()
else col(Score.user).has(username=user),
Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user),
)
.options(joinedload(Score.beatmap))
.order_by(col(Score.classic_total_score).desc())

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import asyncio
from datetime import datetime
from typing import Literal
@@ -13,7 +12,7 @@ from app.service.user_cache_service import get_user_cache_service
from .router import AllStrModel, router
from fastapi import HTTPException, Query
from fastapi import BackgroundTasks, HTTPException, Query
from sqlmodel import select
@@ -49,9 +48,7 @@ class V1User(AllStrModel):
return f"v1_user:{user_id}"
@classmethod
async def from_db(
cls, session: Database, db_user: User, ruleset: GameMode | None = None
) -> "V1User":
async def from_db(cls, session: Database, db_user: User, ruleset: GameMode | None = None) -> "V1User":
# 确保 user_id 不为 None
if db_user.id is None:
raise ValueError("User ID cannot be None")
@@ -63,9 +60,7 @@ class V1User(AllStrModel):
current_statistics = i
break
if current_statistics:
statistics = await UserStatisticsResp.from_db(
current_statistics, session, db_user.country_code
)
statistics = await UserStatisticsResp.from_db(current_statistics, session, db_user.country_code)
else:
statistics = None
return cls(
@@ -78,9 +73,7 @@ class V1User(AllStrModel):
playcount=statistics.play_count if statistics else 0,
ranked_score=statistics.ranked_score if statistics else 0,
total_score=statistics.total_score if statistics else 0,
pp_rank=statistics.global_rank
if statistics and statistics.global_rank
else 0,
pp_rank=statistics.global_rank if statistics and statistics.global_rank else 0,
level=current_statistics.level_current if current_statistics else 0,
pp_raw=statistics.pp if statistics else 0.0,
accuracy=statistics.hit_accuracy if statistics else 0,
@@ -91,9 +84,7 @@ class V1User(AllStrModel):
count_rank_a=current_statistics.grade_a if current_statistics else 0,
country=db_user.country_code,
total_seconds_played=statistics.play_time if statistics else 0,
pp_country_rank=statistics.country_rank
if statistics and statistics.country_rank
else 0,
pp_country_rank=statistics.country_rank if statistics and statistics.country_rank else 0,
events=[], # TODO
)
@@ -106,14 +97,11 @@ class V1User(AllStrModel):
)
async def get_user(
session: Database,
background_tasks: BackgroundTasks,
user: str = Query(..., alias="u", description="用户"),
ruleset_id: int | None = Query(None, alias="m", description="Ruleset ID", ge=0),
type: Literal["string", "id"] | None = Query(
None, description="用户类型string 用户名称 / id 用户 ID"
),
event_days: int = Query(
default=1, ge=1, le=31, description="从现在起所有事件的最大天数"
),
type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"),
event_days: int = Query(default=1, ge=1, le=31, description="从现在起所有事件的最大天数"),
):
redis = get_redis()
cache_service = get_user_cache_service(redis)
@@ -131,9 +119,7 @@ async def get_user(
if is_id_query:
try:
user_id_for_cache = int(user)
cached_v1_user = await cache_service.get_v1_user_from_cache(
user_id_for_cache, ruleset
)
cached_v1_user = await cache_service.get_v1_user_from_cache(user_id_for_cache, ruleset)
if cached_v1_user:
return [V1User(**cached_v1_user)]
except (ValueError, TypeError):
@@ -158,9 +144,7 @@ async def get_user(
# 异步缓存结果如果有用户ID
if db_user.id is not None:
user_data = v1_user.model_dump()
asyncio.create_task(
cache_service.cache_v1_user(user_data, db_user.id, ruleset)
)
background_tasks.add_task(cache_service.cache_v1_user, user_data, db_user.id, ruleset)
return [v1_user]

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401
from . import ( # noqa: F401
beatmap,
beatmapset,
me,

View File

@@ -40,18 +40,13 @@ class BatchGetResp(BaseModel):
tags=["谱面"],
name="查询单个谱面",
response_model=BeatmapResp,
description=(
"根据谱面 ID / MD5 / 文件名 查询单个谱面。"
"至少提供 id / checksum / filename 之一。"
),
description=("根据谱面 ID / MD5 / 文件名 查询单个谱面。至少提供 id / checksum / filename 之一。"),
)
async def lookup_beatmap(
db: Database,
id: int | None = Query(default=None, alias="id", description="谱面 ID"),
md5: str | None = Query(default=None, alias="checksum", description="谱面文件 MD5"),
filename: str | None = Query(
default=None, alias="filename", description="谱面文件名"
),
filename: str | None = Query(default=None, alias="filename", description="谱面文件名"),
current_user: User = Security(get_current_user, scopes=["public"]),
fetcher: Fetcher = Depends(get_fetcher),
):
@@ -96,43 +91,23 @@ async def get_beatmap(
tags=["谱面"],
name="批量获取谱面",
response_model=BatchGetResp,
description=(
"批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。"
"为空时按最近更新时间返回。"
),
description=("批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。为空时按最近更新时间返回。"),
)
async def batch_get_beatmaps(
db: Database,
beatmap_ids: list[int] = Query(
alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"
),
beatmap_ids: list[int] = Query(alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"),
current_user: User = Security(get_current_user, scopes=["public"]),
fetcher: Fetcher = Depends(get_fetcher),
):
if not beatmap_ids:
beatmaps = (
await db.exec(
select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
)
).all()
beatmaps = (await db.exec(select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50))).all()
else:
beatmaps = list(
(
await db.exec(
select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)).limit(50)
)
).all()
)
not_found_beatmaps = [
bid for bid in beatmap_ids if bid not in [bm.id for bm in beatmaps]
]
beatmaps = list((await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)).limit(50))).all())
not_found_beatmaps = [bid for bid in beatmap_ids if bid not in [bm.id for bm in beatmaps]]
beatmaps.extend(
beatmap
for beatmap in await asyncio.gather(
*[
Beatmap.get_or_fetch(db, fetcher, bid=bid)
for bid in not_found_beatmaps
],
*[Beatmap.get_or_fetch(db, fetcher, bid=bid) for bid in not_found_beatmaps],
return_exceptions=True,
)
if isinstance(beatmap, Beatmap)
@@ -140,12 +115,7 @@ async def batch_get_beatmaps(
for beatmap in beatmaps:
await db.refresh(beatmap)
return BatchGetResp(
beatmaps=[
await BeatmapResp.from_db(bm, session=db, user=current_user)
for bm in beatmaps
]
)
return BatchGetResp(beatmaps=[await BeatmapResp.from_db(bm, session=db, user=current_user) for bm in beatmaps])
@router.post(
@@ -163,12 +133,8 @@ async def get_beatmap_attributes(
default_factory=list,
description="Mods 列表;可为整型位掩码(单元素)或 JSON/简称",
),
ruleset: GameMode | None = Query(
default=None, description="指定 ruleset;为空则使用谱面自身模式"
),
ruleset_id: int | None = Query(
default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3
),
ruleset: GameMode | None = Query(default=None, description="指定 ruleset为空则使用谱面自身模式"),
ruleset_id: int | None = Query(default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3),
redis: Redis = Depends(get_redis),
fetcher: Fetcher = Depends(get_fetcher),
):
@@ -187,16 +153,11 @@ async def get_beatmap_attributes(
if ruleset is None:
beatmap_db = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
ruleset = beatmap_db.mode
key = (
f"beatmap:{beatmap_id}:{ruleset}:"
f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
)
key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
if await redis.exists(key):
return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
try:
return await calculate_beatmap_attributes(
beatmap_id, ruleset, mods_, redis, fetcher
)
return await calculate_beatmap_attributes(beatmap_id, ruleset, mods_, redis, fetcher)
except HTTPStatusError:
raise HTTPException(status_code=404, detail="Beatmap not found")
except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue]

View File

@@ -35,9 +35,7 @@ from sqlmodel import exists, select
async def _save_to_db(sets: SearchBeatmapsetsResp):
async with with_db() as session:
for s in sets.beatmapsets:
if not (
await session.exec(select(exists()).where(Beatmapset.id == s.id))
).first():
if not (await session.exec(select(exists()).where(Beatmapset.id == s.id))).first():
await Beatmapset.from_resp(session, s)
@@ -117,9 +115,7 @@ async def lookup_beatmapset(
fetcher: Fetcher = Depends(get_fetcher),
):
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
resp = await BeatmapsetResp.from_db(
beatmap.beatmapset, session=db, user=current_user
)
resp = await BeatmapsetResp.from_db(beatmap.beatmapset, session=db, user=current_user)
return resp
@@ -138,9 +134,7 @@ async def get_beatmapset(
):
try:
beatmapset = await Beatmapset.get_or_fetch(db, fetcher, beatmapset_id)
return await BeatmapsetResp.from_db(
beatmapset, session=db, include=["recent_favourites"], user=current_user
)
return await BeatmapsetResp.from_db(beatmapset, session=db, include=["recent_favourites"], user=current_user)
except HTTPError:
raise HTTPException(status_code=404, detail="Beatmapset not found")
@@ -165,9 +159,7 @@ async def download_beatmapset(
country_code = geo_info.get("country_iso", "")
# 优先使用IP地理位置判断如果获取失败则回退到用户账户的国家代码
is_china = country_code == "CN" or (
not country_code and current_user.country_code == "CN"
)
is_china = country_code == "CN" or (not country_code and current_user.country_code == "CN")
try:
# 使用负载均衡服务获取下载URL
@@ -179,13 +171,10 @@ async def download_beatmapset(
# 如果负载均衡服务失败,回退到原有逻辑
if is_china:
return RedirectResponse(
f"https://dl.sayobot.cn/beatmaps/download/"
f"{'novideo' if no_video else 'full'}/{beatmapset_id}"
f"https://dl.sayobot.cn/beatmaps/download/{'novideo' if no_video else 'full'}/{beatmapset_id}"
)
else:
return RedirectResponse(
f"https://api.nerinyan.moe/d/{beatmapset_id}?noVideo={no_video}"
)
return RedirectResponse(f"https://api.nerinyan.moe/d/{beatmapset_id}?noVideo={no_video}")
@router.post(
@@ -197,12 +186,9 @@ async def download_beatmapset(
async def favourite_beatmapset(
db: Database,
beatmapset_id: int = Path(..., description="谱面集 ID"),
action: Literal["favourite", "unfavourite"] = Form(
description="操作类型favourite 收藏 / unfavourite 取消收藏"
),
action: Literal["favourite", "unfavourite"] = Form(description="操作类型favourite 收藏 / unfavourite 取消收藏"),
current_user: User = Security(get_client_user),
):
assert current_user.id is not None
existing_favourite = (
await db.exec(
select(FavouriteBeatmapset).where(
@@ -212,15 +198,11 @@ async def favourite_beatmapset(
)
).first()
if (action == "favourite" and existing_favourite) or (
action == "unfavourite" and not existing_favourite
):
if (action == "favourite" and existing_favourite) or (action == "unfavourite" and not existing_favourite):
return
if action == "favourite":
favourite = FavouriteBeatmapset(
user_id=current_user.id, beatmapset_id=beatmapset_id
)
favourite = FavouriteBeatmapset(user_id=current_user.id, beatmapset_id=beatmapset_id)
db.add(favourite)
else:
await db.delete(existing_favourite)

View File

@@ -4,8 +4,8 @@ from app.database import User
from app.database.lazer_user import ALL_INCLUDED
from app.dependencies import get_current_user
from app.dependencies.database import Database
from app.models.score import GameMode
from app.models.api_me import APIMe
from app.models.score import GameMode
from .router import router

View File

@@ -33,6 +33,4 @@ class BackgroundsResp(BaseModel):
description="获取当前季节背景图列表。",
)
async def get_seasonal_backgrounds():
return BackgroundsResp(
backgrounds=[Background(url=url) for url in settings.seasonal_backgrounds]
)
return BackgroundsResp(backgrounds=[Background(url=url) for url in settings.seasonal_backgrounds])

View File

@@ -12,7 +12,7 @@ from app.service.ranking_cache_service import get_ranking_cache_service
from .router import router
from fastapi import Path, Query, Security
from fastapi import BackgroundTasks, Path, Query, Security
from pydantic import BaseModel
from sqlmodel import col, select
@@ -38,6 +38,7 @@ class CountryResponse(BaseModel):
)
async def get_country_ranking(
session: Database,
background_tasks: BackgroundTasks,
ruleset: GameMode = Path(..., description="指定 ruleset"),
page: int = Query(1, ge=1, description="页码"), # TODO
current_user: User = Security(get_current_user, scopes=["public"]),
@@ -51,9 +52,7 @@ async def get_country_ranking(
if cached_data:
# 从缓存返回数据
return CountryResponse(
ranking=[CountryStatistics.model_validate(item) for item in cached_data]
)
return CountryResponse(ranking=[CountryStatistics.model_validate(item) for item in cached_data])
# 缓存未命中,从数据库查询
response = CountryResponse(ranking=[])
@@ -105,14 +104,15 @@ async def get_country_ranking(
# 异步缓存数据(不等待完成)
cache_data = [item.model_dump() for item in current_page_data]
cache_task = cache_service.cache_country_ranking(
ruleset, cache_data, page, ttl=settings.ranking_cache_expire_minutes * 60
)
# 创建后台任务来缓存数据
import asyncio
asyncio.create_task(cache_task)
background_tasks.add_task(
cache_service.cache_country_ranking,
ruleset,
cache_data,
page,
ttl=settings.ranking_cache_expire_minutes * 60,
)
# 返回当前页的结果
response.ranking = current_page_data
@@ -132,10 +132,9 @@ class TopUsersResponse(BaseModel):
)
async def get_user_ranking(
session: Database,
background_tasks: BackgroundTasks,
ruleset: GameMode = Path(..., description="指定 ruleset"),
type: Literal["performance", "score"] = Path(
..., description="排名类型performance 表现分 / score 计分成绩总分"
),
type: Literal["performance", "score"] = Path(..., description="排名类型performance 表现分 / score 计分成绩总分"),
country: str | None = Query(None, description="国家代码"),
page: int = Query(1, ge=1, description="页码"),
current_user: User = Security(get_current_user, scopes=["public"]),
@@ -149,9 +148,7 @@ async def get_user_ranking(
if cached_data:
# 从缓存返回数据
return TopUsersResponse(
ranking=[UserStatisticsResp.model_validate(item) for item in cached_data]
)
return TopUsersResponse(ranking=[UserStatisticsResp.model_validate(item) for item in cached_data])
# 缓存未命中,从数据库查询
wheres = [
@@ -169,25 +166,22 @@ async def get_user_ranking(
wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
statistics_list = await session.exec(
select(UserStatistics)
.where(*wheres)
.order_by(order_by)
.limit(50)
.offset(50 * (page - 1))
select(UserStatistics).where(*wheres).order_by(order_by).limit(50).offset(50 * (page - 1))
)
# 转换为响应格式
ranking_data = []
for statistics in statistics_list:
user_stats_resp = await UserStatisticsResp.from_db(
statistics, session, None, include
)
user_stats_resp = await UserStatisticsResp.from_db(statistics, session, None, include)
ranking_data.append(user_stats_resp)
# 异步缓存数据(不等待完成)
# 使用配置文件中的TTL设置
cache_data = [item.model_dump() for item in ranking_data]
cache_task = cache_service.cache_ranking(
# 创建后台任务来缓存数据
background_tasks.add_task(
cache_service.cache_ranking,
ruleset,
type,
cache_data,
@@ -196,139 +190,134 @@ async def get_user_ranking(
ttl=settings.ranking_cache_expire_minutes * 60,
)
# 创建后台任务来缓存数据
import asyncio
asyncio.create_task(cache_task)
resp = TopUsersResponse(ranking=ranking_data)
return resp
""" @router.post(
"/rankings/cache/refresh",
name="刷新排行榜缓存",
description="手动刷新排行榜缓存(管理员功能)",
tags=["排行榜", "管理"],
)
async def refresh_ranking_cache(
session: Database,
ruleset: GameMode | None = Query(None, description="指定要刷新的游戏模式,不指定则刷新所有"),
type: Literal["performance", "score"] | None = Query(None, description="指定要刷新的排名类型,不指定则刷新所有"),
country: str | None = Query(None, description="指定要刷新的国家,不指定则刷新所有"),
include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
):
redis = get_redis()
cache_service = get_ranking_cache_service(redis)
if ruleset and type:
# 刷新特定的用户排行榜
await cache_service.refresh_ranking_cache(session, ruleset, type, country)
message = f"Refreshed ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
# 如果请求刷新地区排行榜
if include_country_ranking and not country: # 地区排行榜不依赖于国家参数
await cache_service.refresh_country_ranking_cache(session, ruleset)
message += f" and country ranking for {ruleset}"
return {"message": message}
elif ruleset:
# 刷新特定游戏模式的所有排行榜
ranking_types: list[Literal["performance", "score"]] = ["performance", "score"]
for ranking_type in ranking_types:
await cache_service.refresh_ranking_cache(session, ruleset, ranking_type, country)
if include_country_ranking:
await cache_service.refresh_country_ranking_cache(session, ruleset)
return {"message": f"Refreshed all ranking caches for {ruleset}"}
else:
# 刷新所有排行榜
await cache_service.refresh_all_rankings(session)
return {"message": "Refreshed all ranking caches"}
# @router.post(
# "/rankings/cache/refresh",
# name="刷新排行榜缓存",
# description="手动刷新排行榜缓存(管理员功能)",
# tags=["排行榜", "管理"],
# )
# async def refresh_ranking_cache(
# session: Database,
# ruleset: GameMode | None = Query(None, description="指定要刷新的游戏模式,不指定则刷新所有"),
# type: Literal["performance", "score"] | None = Query(None, description="指定要刷新的排名类型,不指定则刷新所有"),
# country: str | None = Query(None, description="指定要刷新的国家,不指定则刷新所有"),
# include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
# ):
# redis = get_redis()
# cache_service = get_ranking_cache_service(redis)
# if ruleset and type:
# # 刷新特定的用户排行榜
# await cache_service.refresh_ranking_cache(session, ruleset, type, country)
# message = f"Refreshed ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
# # 如果请求刷新地区排行榜
# if include_country_ranking and not country: # 地区排行榜不依赖于国家参数
# await cache_service.refresh_country_ranking_cache(session, ruleset)
# message += f" and country ranking for {ruleset}"
# return {"message": message}
# elif ruleset:
# # 刷新特定游戏模式的所有排行榜
# ranking_types: list[Literal["performance", "score"]] = ["performance", "score"]
# for ranking_type in ranking_types:
# await cache_service.refresh_ranking_cache(session, ruleset, ranking_type, country)
# if include_country_ranking:
# await cache_service.refresh_country_ranking_cache(session, ruleset)
# return {"message": f"Refreshed all ranking caches for {ruleset}"}
# else:
# # 刷新所有排行榜
# await cache_service.refresh_all_rankings(session)
# return {"message": "Refreshed all ranking caches"}
@router.post(
"/rankings/{ruleset}/country/cache/refresh",
name="刷新地区排行榜缓存",
description="手动刷新地区排行榜缓存(管理员功能)",
tags=["排行榜", "管理"],
)
async def refresh_country_ranking_cache(
session: Database,
ruleset: GameMode = Path(..., description="指定要刷新的游戏模式"),
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
):
redis = get_redis()
cache_service = get_ranking_cache_service(redis)
await cache_service.refresh_country_ranking_cache(session, ruleset)
return {"message": f"Refreshed country ranking cache for {ruleset}"}
# @router.post(
# "/rankings/{ruleset}/country/cache/refresh",
# name="刷新地区排行榜缓存",
# description="手动刷新地区排行榜缓存(管理员功能)",
# tags=["排行榜", "管理"],
# )
# async def refresh_country_ranking_cache(
# session: Database,
# ruleset: GameMode = Path(..., description="指定要刷新的游戏模式"),
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
# ):
# redis = get_redis()
# cache_service = get_ranking_cache_service(redis)
# await cache_service.refresh_country_ranking_cache(session, ruleset)
# return {"message": f"Refreshed country ranking cache for {ruleset}"}
@router.delete(
"/rankings/cache",
name="清除排行榜缓存",
description="清除排行榜缓存(管理员功能)",
tags=["排行榜", "管理"],
)
async def clear_ranking_cache(
ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
type: Literal["performance", "score"] | None = Query(None, description="指定要清除的排名类型,不指定则清除所有"),
country: str | None = Query(None, description="指定要清除的国家,不指定则清除所有"),
include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
):
redis = get_redis()
cache_service = get_ranking_cache_service(redis)
await cache_service.invalidate_cache(ruleset, type, country, include_country_ranking)
if ruleset and type:
message = f"Cleared ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
if include_country_ranking:
message += " and country ranking"
return {"message": message}
else:
message = "Cleared all ranking caches"
if include_country_ranking:
message += " including country rankings"
return {"message": message}
# @router.delete(
# "/rankings/cache",
# name="清除排行榜缓存",
# description="清除排行榜缓存(管理员功能)",
# tags=["排行榜", "管理"],
# )
# async def clear_ranking_cache(
# ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
# type: Literal["performance", "score"] | None = Query(None, description="指定要清除的排名类型,不指定则清除所有"),
# country: str | None = Query(None, description="指定要清除的国家,不指定则清除所有"),
# include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
# ):
# redis = get_redis()
# cache_service = get_ranking_cache_service(redis)
# await cache_service.invalidate_cache(ruleset, type, country, include_country_ranking)
# if ruleset and type:
# message = f"Cleared ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
# if include_country_ranking:
# message += " and country ranking"
# return {"message": message}
# else:
# message = "Cleared all ranking caches"
# if include_country_ranking:
# message += " including country rankings"
# return {"message": message}
@router.delete(
"/rankings/{ruleset}/country/cache",
name="清除地区排行榜缓存",
description="清除地区排行榜缓存(管理员功能)",
tags=["排行榜", "管理"],
)
async def clear_country_ranking_cache(
ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
):
redis = get_redis()
cache_service = get_ranking_cache_service(redis)
await cache_service.invalidate_country_cache(ruleset)
if ruleset:
return {"message": f"Cleared country ranking cache for {ruleset}"}
else:
return {"message": "Cleared all country ranking caches"}
# @router.delete(
# "/rankings/{ruleset}/country/cache",
# name="清除地区排行榜缓存",
# description="清除地区排行榜缓存(管理员功能)",
# tags=["排行榜", "管理"],
# )
# async def clear_country_ranking_cache(
# ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
# ):
# redis = get_redis()
# cache_service = get_ranking_cache_service(redis)
# await cache_service.invalidate_country_cache(ruleset)
# if ruleset:
# return {"message": f"Cleared country ranking cache for {ruleset}"}
# else:
# return {"message": "Cleared all country ranking caches"}
@router.get(
"/rankings/cache/stats",
name="获取排行榜缓存统计",
description="获取排行榜缓存统计信息(管理员功能)",
tags=["排行榜", "管理"],
)
async def get_ranking_cache_stats(
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
):
redis = get_redis()
cache_service = get_ranking_cache_service(redis)
stats = await cache_service.get_cache_stats()
return stats """
# @router.get(
# "/rankings/cache/stats",
# name="获取排行榜缓存统计",
# description="获取排行榜缓存统计信息(管理员功能)",
# tags=["排行榜", "管理"],
# )
# async def get_ranking_cache_stats(
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
# ):
# redis = get_redis()
# cache_service = get_ranking_cache_service(redis)
# stats = await cache_service.get_cache_stats()
# return stats

View File

@@ -30,11 +30,7 @@ async def get_relationship(
request: Request,
current_user: User = Security(get_current_user, scopes=["friends.read"]),
):
relationship_type = (
RelationshipType.FOLLOW
if request.url.path.endswith("/friends")
else RelationshipType.BLOCK
)
relationship_type = RelationshipType.FOLLOW if request.url.path.endswith("/friends") else RelationshipType.BLOCK
relationships = await db.exec(
select(Relationship).where(
Relationship.user_id == current_user.id,
@@ -71,12 +67,7 @@ async def add_relationship(
target: int = Query(description="目标用户 ID"),
current_user: User = Security(get_client_user),
):
assert current_user.id is not None
relationship_type = (
RelationshipType.FOLLOW
if request.url.path.endswith("/friends")
else RelationshipType.BLOCK
)
relationship_type = RelationshipType.FOLLOW if request.url.path.endswith("/friends") else RelationshipType.BLOCK
if target == current_user.id:
raise HTTPException(422, "Cannot add relationship to yourself")
relationship = (
@@ -120,11 +111,8 @@ async def add_relationship(
Relationship.target_id == target,
)
)
).first()
assert relationship, "Relationship should exist after commit"
return AddFriendResp(
user_relation=await RelationshipResp.from_db(db, relationship)
)
).one()
return AddFriendResp(user_relation=await RelationshipResp.from_db(db, relationship))
@router.delete(
@@ -145,11 +133,7 @@ async def delete_relationship(
target: int = Path(..., description="目标用户 ID"),
current_user: User = Security(get_client_user),
):
relationship_type = (
RelationshipType.BLOCK
if "/blocks/" in request.url.path
else RelationshipType.FOLLOW
)
relationship_type = RelationshipType.BLOCK if "/blocks/" in request.url.path else RelationshipType.FOLLOW
relationship = (
await db.exec(
select(Relationship).where(

View File

@@ -39,17 +39,11 @@ async def get_all_rooms(
db: Database,
mode: Literal["open", "ended", "participated", "owned", None] = Query(
default="open",
description=(
"房间模式open 当前开放 / ended 已经结束 / "
"participated 参与过 / owned 自己创建的房间"
),
description=("房间模式open 当前开放 / ended 已经结束 / participated 参与过 / owned 自己创建的房间"),
),
category: RoomCategory = Query(
RoomCategory.NORMAL,
description=(
"房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间"
" / DAILY_CHALLENGE 每日挑战"
),
description=("房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战"),
),
status: RoomStatus | None = Query(None, description="房间状态(可选)"),
current_user: User = Security(get_current_user, scopes=["public"]),
@@ -60,10 +54,7 @@ async def get_all_rooms(
if status is not None:
where_clauses.append(col(Room.status) == status)
if mode == "open":
where_clauses.append(
(col(Room.ends_at).is_(None))
| (col(Room.ends_at) > now.replace(tzinfo=UTC))
)
where_clauses.append((col(Room.ends_at).is_(None)) | (col(Room.ends_at) > now.replace(tzinfo=UTC)))
if category == RoomCategory.REALTIME:
where_clauses.append(col(Room.id).in_(MultiplayerHubs.rooms.keys()))
if mode == "participated":
@@ -76,10 +67,7 @@ async def get_all_rooms(
if mode == "owned":
where_clauses.append(col(Room.host_id) == current_user.id)
if mode == "ended":
where_clauses.append(
(col(Room.ends_at).is_not(None))
& (col(Room.ends_at) < now.replace(tzinfo=UTC))
)
where_clauses.append((col(Room.ends_at).is_not(None)) & (col(Room.ends_at) < now.replace(tzinfo=UTC)))
db_rooms = (
(
@@ -97,11 +85,7 @@ async def get_all_rooms(
resp = await RoomResp.from_db(room, db)
if category == RoomCategory.REALTIME:
mp_room = MultiplayerHubs.rooms.get(room.id)
resp.has_password = (
bool(mp_room.room.settings.password.strip())
if mp_room is not None
else False
)
resp.has_password = bool(mp_room.room.settings.password.strip()) if mp_room is not None else False
resp.category = RoomCategory.NORMAL
resp_list.append(resp)
@@ -115,9 +99,7 @@ class APICreatedRoom(RoomResp):
error: str = ""
async def _participate_room(
room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis
):
async def _participate_room(room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis):
participated_user = (
await session.exec(
select(RoomParticipatedUser).where(
@@ -154,7 +136,6 @@ async def create_room(
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
):
assert current_user.id is not None
user_id = current_user.id
db_room = await create_playlist_room_from_api(db, room, user_id)
await _participate_room(db_room.id, user_id, db_room, db, redis)
@@ -177,10 +158,7 @@ async def get_room(
room_id: int = Path(..., description="房间 ID"),
category: str = Query(
default="",
description=(
"房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间"
" / DAILY_CHALLENGE 每日挑战 (可选)"
),
description=("房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战 (可选)"),
),
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
@@ -188,9 +166,7 @@ async def get_room(
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is None:
raise HTTPException(404, "Room not found")
resp = await RoomResp.from_db(
db_room, include=["current_user_score"], session=db, user=current_user
)
resp = await RoomResp.from_db(db_room, include=["current_user_score"], session=db, user=current_user)
return resp
@@ -400,7 +376,6 @@ async def get_room_events(
for score in scores:
user_ids.add(score.user_id)
beatmap_ids.add(score.beatmap_id)
assert event.id is not None
first_event_id = min(first_event_id, event.id)
last_event_id = max(last_event_id, event.id)
@@ -416,16 +391,12 @@ async def get_room_events(
users = await db.exec(select(User).where(col(User.id).in_(user_ids)))
user_resps = [await UserResp.from_db(user, db) for user in users]
beatmaps = await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)))
beatmap_resps = [
await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps
]
beatmap_resps = [await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps]
beatmapset_resps = {}
for beatmap_resp in beatmap_resps:
beatmapset_resps[beatmap_resp.beatmapset_id] = beatmap_resp.beatmapset
playlist_items_resps = [
await PlaylistResp.from_db(item) for item in playlist_items.values()
]
playlist_items_resps = [await PlaylistResp.from_db(item) for item in playlist_items.values()]
return RoomEvents(
beatmaps=beatmap_resps,

View File

@@ -104,11 +104,7 @@ async def submit_score(
if not info.passed:
info.rank = Rank.F
score_token = (
await db.exec(
select(ScoreToken)
.options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType]
.where(ScoreToken.id == token)
)
await db.exec(select(ScoreToken).options(joinedload(ScoreToken.beatmap)).where(ScoreToken.id == token))
).first()
if not score_token or score_token.user_id != user_id:
raise HTTPException(status_code=404, detail="Score token not found")
@@ -138,10 +134,7 @@ async def submit_score(
except HTTPError:
raise HTTPException(status_code=404, detail="Beatmap not found")
has_pp = db_beatmap.beatmap_status.has_pp() | settings.enable_all_beatmap_pp
has_leaderboard = (
db_beatmap.beatmap_status.has_leaderboard()
| settings.enable_all_beatmap_leaderboard
)
has_leaderboard = db_beatmap.beatmap_status.has_leaderboard() | settings.enable_all_beatmap_leaderboard
beatmap_length = db_beatmap.total_length
score = await process_score(
current_user,
@@ -167,21 +160,11 @@ async def submit_score(
has_pp,
has_leaderboard,
)
score = (
await db.exec(
select(Score)
.options(joinedload(Score.user)) # pyright: ignore[reportArgumentType]
.where(Score.id == score_id)
)
).first()
assert score is not None
score = (await db.exec(select(Score).options(joinedload(Score.user)).where(Score.id == score_id))).one()
resp = await ScoreResp.from_db(db, score)
total_users = (await db.exec(select(func.count()).select_from(User))).first()
assert total_users is not None
if resp.rank_global is not None and resp.rank_global <= min(
math.ceil(float(total_users) * 0.01), 50
):
total_users = (await db.exec(select(func.count()).select_from(User))).one()
if resp.rank_global is not None and resp.rank_global <= min(math.ceil(float(total_users) * 0.01), 50):
rank_event = Event(
created_at=datetime.now(UTC),
type=EventType.RANK,
@@ -207,9 +190,7 @@ async def submit_score(
score_gamemode = score.gamemode
if user_id is not None:
background_task.add_task(
_refresh_user_cache_background, redis, user_id, score_gamemode
)
background_task.add_task(_refresh_user_cache_background, redis, user_id, score_gamemode)
background_task.add_task(process_user_achievement, resp.id)
return resp
@@ -225,9 +206,7 @@ async def _refresh_user_cache_background(redis: Redis, user_id: int, mode: GameM
# 创建独立的数据库会话
session = AsyncSession(engine)
try:
await user_cache_service.refresh_user_cache_on_score_submit(
session, user_id, mode
)
await user_cache_service.refresh_user_cache_on_score_submit(session, user_id, mode)
finally:
await session.close()
except Exception as e:
@@ -280,22 +259,16 @@ async def get_beatmap_scores(
beatmap_id: int = Path(description="谱面 ID"),
mode: GameMode = Query(description="指定 auleset"),
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
mods: list[str] = Query(
default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)"
),
mods: list[str] = Query(default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)"),
type: LeaderboardType = Query(
LeaderboardType.GLOBAL,
description=(
"排行榜类型GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队"
),
description=("排行榜类型GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队"),
),
current_user: User = Security(get_current_user, scopes=["public"]),
limit: int = Query(50, ge=1, le=200, description="返回条数 (1-200)"),
):
if legacy_only:
raise HTTPException(
status_code=404, detail="this server only contains lazer scores"
)
raise HTTPException(status_code=404, detail="this server only contains lazer scores")
all_scores, user_score, count = await get_leaderboard(
db,
@@ -310,9 +283,7 @@ async def get_beatmap_scores(
user_score_resp = await ScoreResp.from_db(db, user_score) if user_score else None
resp = BeatmapScores(
scores=[await ScoreResp.from_db(db, score) for score in all_scores],
user_score=BeatmapUserScore(
score=user_score_resp, position=user_score_resp.rank_global or 0
)
user_score=BeatmapUserScore(score=user_score_resp, position=user_score_resp.rank_global or 0)
if user_score_resp
else None,
score_count=count,
@@ -342,9 +313,7 @@ async def get_user_beatmap_score(
current_user: User = Security(get_current_user, scopes=["public"]),
):
if legacy_only:
raise HTTPException(
status_code=404, detail="This server only contains non-legacy scores"
)
raise HTTPException(status_code=404, detail="This server only contains non-legacy scores")
user_score = (
await db.exec(
select(Score)
@@ -386,9 +355,7 @@ async def get_user_all_beatmap_scores(
current_user: User = Security(get_current_user, scopes=["public"]),
):
if legacy_only:
raise HTTPException(
status_code=404, detail="This server only contains non-legacy scores"
)
raise HTTPException(status_code=404, detail="This server only contains non-legacy scores")
all_user_scores = (
await db.exec(
select(Score)
@@ -420,7 +387,6 @@ async def create_solo_score(
ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"),
current_user: User = Security(get_client_user),
):
assert current_user.id is not None
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
@@ -454,10 +420,7 @@ async def submit_solo_score(
redis: Redis = Depends(get_redis),
fetcher=Depends(get_fetcher),
):
assert current_user.id is not None
return await submit_score(
background_task, info, beatmap_id, token, current_user, db, redis, fetcher
)
return await submit_score(background_task, info, beatmap_id, token, current_user, db, redis, fetcher)
@router.post(
@@ -478,7 +441,6 @@ async def create_playlist_score(
version_hash: str = Form("", description="谱面版本哈希"),
current_user: User = Security(get_client_user),
):
assert current_user.id is not None
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
@@ -488,26 +450,16 @@ async def create_playlist_score(
db_room_time = room.ends_at.replace(tzinfo=UTC) if room.ends_at else None
if db_room_time and db_room_time < datetime.now(UTC).replace(tzinfo=UTC):
raise HTTPException(status_code=400, detail="Room has ended")
item = (
await session.exec(
select(Playlist).where(
Playlist.id == playlist_id, Playlist.room_id == room_id
)
)
).first()
item = (await session.exec(select(Playlist).where(Playlist.id == playlist_id, Playlist.room_id == room_id))).first()
if not item:
raise HTTPException(status_code=404, detail="Playlist not found")
# validate
if not item.freestyle:
if item.ruleset_id != ruleset_id:
raise HTTPException(
status_code=400, detail="Ruleset mismatch in playlist item"
)
raise HTTPException(status_code=400, detail="Ruleset mismatch in playlist item")
if item.beatmap_id != beatmap_id:
raise HTTPException(
status_code=400, detail="Beatmap ID mismatch in playlist item"
)
raise HTTPException(status_code=400, detail="Beatmap ID mismatch in playlist item")
agg = await session.exec(
select(ItemAttemptsCount).where(
ItemAttemptsCount.room_id == room_id,
@@ -523,9 +475,7 @@ async def create_playlist_score(
if item.expired:
raise HTTPException(status_code=400, detail="Playlist item has expired")
if item.played_at:
raise HTTPException(
status_code=400, detail="Playlist item has already been played"
)
raise HTTPException(status_code=400, detail="Playlist item has already been played")
# 这里应该不用验证mod了吧。。。
background_task.add_task(_preload_beatmap_for_pp_calculation, beatmap_id)
score_token = ScoreToken(
@@ -557,18 +507,10 @@ async def submit_playlist_score(
redis: Redis = Depends(get_redis),
fetcher: Fetcher = Depends(get_fetcher),
):
assert current_user.id is not None
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
item = (
await session.exec(
select(Playlist).where(
Playlist.id == playlist_id, Playlist.room_id == room_id
)
)
).first()
item = (await session.exec(select(Playlist).where(Playlist.id == playlist_id, Playlist.room_id == room_id))).first()
if not item:
raise HTTPException(status_code=404, detail="Playlist item not found")
room = await session.get(Room, room_id)
@@ -621,9 +563,7 @@ async def index_playlist_scores(
room_id: int,
playlist_id: int,
limit: int = Query(50, ge=1, le=50, description="返回条数 (1-50)"),
cursor: int = Query(
2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)"
),
cursor: int = Query(2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)"),
current_user: User = Security(get_current_user, scopes=["public"]),
):
# 立即获取用户ID避免懒加载问题
@@ -693,9 +633,6 @@ async def show_playlist_score(
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
):
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
room = await session.get(Room, room_id)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
@@ -715,9 +652,7 @@ async def show_playlist_score(
)
)
).first()
if completed_players := await redis.get(
f"multiplayer:{room_id}:gameplay:players"
):
if completed_players := await redis.get(f"multiplayer:{room_id}:gameplay:players"):
completed = completed_players == "0"
if score_record and completed:
break
@@ -784,9 +719,7 @@ async def get_user_playlist_score(
raise HTTPException(status_code=404, detail="Score not found")
resp = await ScoreResp.from_db(session, score_record.score)
resp.position = await get_position(
room_id, playlist_id, score_record.score_id, session
)
resp.position = await get_position(room_id, playlist_id, score_record.score_id, session)
return resp
@@ -850,11 +783,7 @@ async def unpin_score(
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
score_record = (
await db.exec(
select(Score).where(Score.id == score_id, Score.user_id == user_id)
)
).first()
score_record = (await db.exec(select(Score).where(Score.id == score_id, Score.user_id == user_id))).first()
if not score_record:
raise HTTPException(status_code=404, detail="Score not found")
@@ -878,10 +807,7 @@ async def unpin_score(
"/score-pins/{score_id}/reorder",
status_code=204,
name="调整置顶成绩顺序",
description=(
"**客户端专属**\n调整已置顶成绩的展示顺序。"
"仅提供 after_score_id 或 before_score_id 之一。"
),
description=("**客户端专属**\n调整已置顶成绩的展示顺序。仅提供 after_score_id 或 before_score_id 之一。"),
tags=["成绩"],
)
async def reorder_score_pin(
@@ -894,11 +820,7 @@ async def reorder_score_pin(
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
score_record = (
await db.exec(
select(Score).where(Score.id == score_id, Score.user_id == user_id)
)
).first()
score_record = (await db.exec(select(Score).where(Score.id == score_id, Score.user_id == user_id))).first()
if not score_record:
raise HTTPException(status_code=404, detail="Score not found")
@@ -908,8 +830,7 @@ async def reorder_score_pin(
if (after_score_id is None) == (before_score_id is None):
raise HTTPException(
status_code=400,
detail="Either after_score_id or before_score_id "
"must be provided (but not both)",
detail="Either after_score_id or before_score_id must be provided (but not both)",
)
all_pinned_scores = (
@@ -927,9 +848,7 @@ async def reorder_score_pin(
target_order = None
reference_score_id = after_score_id or before_score_id
reference_score = next(
(s for s in all_pinned_scores if s.id == reference_score_id), None
)
reference_score = next((s for s in all_pinned_scores if s.id == reference_score_id), None)
if not reference_score:
detail = "After score not found" if after_score_id else "Before score not found"
raise HTTPException(status_code=404, detail=detail)
@@ -951,9 +870,7 @@ async def reorder_score_pin(
if current_order < s.pinned_order <= target_order and s.id != score_id:
updates.append((s.id, s.pinned_order - 1))
if after_score_id:
final_target = (
target_order - 1 if target_order > current_order else target_order
)
final_target = target_order - 1 if target_order > current_order else target_order
else:
final_target = target_order
else:
@@ -964,9 +881,7 @@ async def reorder_score_pin(
for score_id, new_order in updates:
await db.exec(select(Score).where(Score.id == score_id))
score_to_update = (
await db.exec(select(Score).where(Score.id == score_id))
).first()
score_to_update = (await db.exec(select(Score).where(Score.id == score_id))).first()
if score_to_update:
score_to_update.pinned_order = new_order

View File

@@ -4,34 +4,29 @@
from __future__ import annotations
from datetime import datetime, UTC
from typing import Annotated
from app.auth import authenticate_user
from app.config import settings
from app.database import User
from app.dependencies import get_current_user
from app.dependencies.database import Database, get_redis
from app.dependencies.geoip import GeoIPHelper, get_geoip_helper
from app.database.email_verification import EmailVerification, LoginSession
from app.service.email_verification_service import (
EmailVerificationService,
LoginSessionService
EmailVerificationService,
LoginSessionService,
)
from app.service.login_log_service import LoginLogService
from app.models.extended_auth import ExtendedTokenResponse
from fastapi import Form, Depends, Request, HTTPException, status, Security
from fastapi.responses import JSONResponse, Response
from pydantic import BaseModel
from redis.asyncio import Redis
from sqlmodel import select
from .router import router
from fastapi import Depends, Form, HTTPException, Request, Security, status
from fastapi.responses import Response
from pydantic import BaseModel
from redis.asyncio import Redis
class SessionReissueResponse(BaseModel):
"""重新发送验证码响应"""
success: bool
message: str
@@ -40,39 +35,35 @@ class SessionReissueResponse(BaseModel):
"/session/verify",
name="验证会话",
description="验证邮件验证码并完成会话认证",
status_code=204
status_code=204,
)
async def verify_session(
request: Request,
db: Database,
redis: Annotated[Redis, Depends(get_redis)],
verification_key: str = Form(..., description="8位邮件验证码"),
current_user: User = Security(get_current_user)
current_user: User = Security(get_current_user),
) -> Response:
"""
验证邮件验证码并完成会话认证
对应 osu! 的 session/verify 接口
成功时返回 204 No Content失败时返回 401 Unauthorized
"""
try:
from app.dependencies.geoip import get_client_ip
ip_address = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "Unknown")
ip_address = get_client_ip(request) # noqa: F841
user_agent = request.headers.get("User-Agent", "Unknown") # noqa: F841
# 从当前认证用户获取信息
user_id = current_user.id
if not user_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户未认证"
)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户未认证")
# 验证邮件验证码
success, message = await EmailVerificationService.verify_code(
db, redis, user_id, verification_key
)
success, message = await EmailVerificationService.verify_code(db, redis, user_id, verification_key)
if success:
# 记录成功的邮件验证
await LoginLogService.record_login(
@@ -81,9 +72,9 @@ async def verify_session(
request=request,
login_method="email_verification",
login_success=True,
notes=f"邮件验证成功"
notes="邮件验证成功",
)
# 返回 204 No Content 表示验证成功
return Response(status_code=status.HTTP_204_NO_CONTENT)
else:
@@ -93,83 +84,69 @@ async def verify_session(
request=request,
attempted_username=current_user.username,
login_method="email_verification",
notes=f"邮件验证失败: {message}"
notes=f"邮件验证失败: {message}",
)
# 返回 401 Unauthorized 表示验证失败
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=message
)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=message)
except ValueError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的用户会话"
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="验证过程中发生错误"
)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的用户会话")
except Exception:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="验证过程中发生错误")
@router.post(
"/session/verify/reissue",
name="重新发送验证码",
description="重新发送邮件验证码",
response_model=SessionReissueResponse
response_model=SessionReissueResponse,
)
async def reissue_verification_code(
request: Request,
db: Database,
redis: Annotated[Redis, Depends(get_redis)],
current_user: User = Security(get_current_user)
current_user: User = Security(get_current_user),
) -> SessionReissueResponse:
"""
重新发送邮件验证码
对应 osu! 的 session/verify/reissue 接口
"""
try:
from app.dependencies.geoip import get_client_ip
ip_address = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "Unknown")
# 从当前认证用户获取信息
user_id = current_user.id
if not user_id:
return SessionReissueResponse(
success=False,
message="用户未认证"
)
return SessionReissueResponse(success=False, message="用户未认证")
# 重新发送验证码
success, message = await EmailVerificationService.resend_verification_code(
db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent
db,
redis,
user_id,
current_user.username,
current_user.email,
ip_address,
user_agent,
)
return SessionReissueResponse(
success=success,
message=message
)
return SessionReissueResponse(success=success, message=message)
except ValueError:
return SessionReissueResponse(
success=False,
message="无效的用户会话"
)
except Exception as e:
return SessionReissueResponse(
success=False,
message="重新发送过程中发生错误"
)
return SessionReissueResponse(success=False, message="无效的用户会话")
except Exception:
return SessionReissueResponse(success=False, message="重新发送过程中发生错误")
@router.post(
"/session/check-new-location",
name="检查新位置登录",
description="检查登录是否来自新位置(内部接口)"
description="检查登录是否来自新位置(内部接口)",
)
async def check_new_location(
request: Request,
@@ -183,22 +160,21 @@ async def check_new_location(
"""
try:
from app.dependencies.geoip import get_client_ip
ip_address = get_client_ip(request)
geo_info = geoip.lookup(ip_address)
country_code = geo_info.get("country_iso", "XX")
is_new_location = await LoginSessionService.check_new_location(
db, user_id, ip_address, country_code
)
is_new_location = await LoginSessionService.check_new_location(db, user_id, ip_address, country_code)
return {
"is_new_location": is_new_location,
"ip_address": ip_address,
"country_code": country_code
"country_code": country_code,
}
except Exception as e:
return {
"is_new_location": True, # 出错时默认为新位置
"error": str(e)
"error": str(e),
}

View File

@@ -1,73 +1,80 @@
from __future__ import annotations
import asyncio
from datetime import datetime, timedelta
import json
from typing import Any
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
import json
from app.dependencies.database import get_redis, get_redis_message
from app.log import logger
from app.utils import bg_tasks
from .router import router
from fastapi import APIRouter
from pydantic import BaseModel
# Redis key constants
REDIS_ONLINE_USERS_KEY = "server:online_users"
REDIS_PLAYING_USERS_KEY = "server:playing_users"
REDIS_PLAYING_USERS_KEY = "server:playing_users"
REDIS_REGISTERED_USERS_KEY = "server:registered_users"
REDIS_ONLINE_HISTORY_KEY = "server:online_history"
# 线程池用于同步Redis操作
_executor = ThreadPoolExecutor(max_workers=2)
async def _redis_exec(func, *args, **kwargs):
"""在线程池中执行同步Redis操作"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(_executor, func, *args, **kwargs)
class ServerStats(BaseModel):
"""服务器统计信息响应模型"""
registered_users: int
online_users: int
playing_users: int
timestamp: datetime
class OnlineHistoryPoint(BaseModel):
"""在线历史数据点"""
timestamp: datetime
online_count: int
playing_count: int
class OnlineHistoryResponse(BaseModel):
"""24小时在线历史响应模型"""
history: list[OnlineHistoryPoint]
current_stats: ServerStats
@router.get("/stats", response_model=ServerStats, tags=["统计"])
async def get_server_stats() -> ServerStats:
"""
获取服务器实时统计信息
返回服务器注册用户数、在线用户数、正在游玩用户数等实时统计信息
"""
redis = get_redis()
try:
# 并行获取所有统计数据
registered_count, online_count, playing_count = await asyncio.gather(
_get_registered_users_count(redis),
_get_online_users_count(redis),
_get_playing_users_count(redis)
_get_playing_users_count(redis),
)
return ServerStats(
registered_users=registered_count,
online_users=online_count,
playing_users=playing_count,
timestamp=datetime.utcnow()
timestamp=datetime.utcnow(),
)
except Exception as e:
logger.error(f"Error getting server stats: {e}")
@@ -76,14 +83,15 @@ async def get_server_stats() -> ServerStats:
registered_users=0,
online_users=0,
playing_users=0,
timestamp=datetime.utcnow()
timestamp=datetime.utcnow(),
)
@router.get("/stats/history", response_model=OnlineHistoryResponse, tags=["统计"])
async def get_online_history() -> OnlineHistoryResponse:
"""
获取最近24小时在线统计历史
返回过去24小时内每小时的在线用户数和游玩用户数统计
包含当前实时数据作为最新数据点
"""
@@ -92,80 +100,80 @@ async def get_online_history() -> OnlineHistoryResponse:
redis_sync = get_redis_message()
history_data = await _redis_exec(redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1)
history_points = []
# 处理历史数据
for data in history_data:
try:
point_data = json.loads(data)
# 只保留基本字段
history_points.append(OnlineHistoryPoint(
timestamp=datetime.fromisoformat(point_data["timestamp"]),
online_count=point_data["online_count"],
playing_count=point_data["playing_count"]
))
history_points.append(
OnlineHistoryPoint(
timestamp=datetime.fromisoformat(point_data["timestamp"]),
online_count=point_data["online_count"],
playing_count=point_data["playing_count"],
)
)
except (json.JSONDecodeError, KeyError, ValueError) as e:
logger.warning(f"Invalid history data point: {data}, error: {e}")
continue
# 获取当前实时统计信息
current_stats = await get_server_stats()
# 如果历史数据为空或者最新数据超过15分钟添加当前数据点
if not history_points or (
history_points and
(current_stats.timestamp - max(history_points, key=lambda x: x.timestamp).timestamp).total_seconds() > 15 * 60
history_points
and (current_stats.timestamp - max(history_points, key=lambda x: x.timestamp).timestamp).total_seconds()
> 15 * 60
):
history_points.append(OnlineHistoryPoint(
timestamp=current_stats.timestamp,
online_count=current_stats.online_users,
playing_count=current_stats.playing_users
))
history_points.append(
OnlineHistoryPoint(
timestamp=current_stats.timestamp,
online_count=current_stats.online_users,
playing_count=current_stats.playing_users,
)
)
# 按时间排序(最新的在前)
history_points.sort(key=lambda x: x.timestamp, reverse=True)
# 限制到最多48个数据点24小时
history_points = history_points[:48]
return OnlineHistoryResponse(
history=history_points,
current_stats=current_stats
)
return OnlineHistoryResponse(history=history_points, current_stats=current_stats)
except Exception as e:
logger.error(f"Error getting online history: {e}")
# 返回空历史和当前状态
current_stats = await get_server_stats()
return OnlineHistoryResponse(
history=[],
current_stats=current_stats
)
return OnlineHistoryResponse(history=[], current_stats=current_stats)
@router.get("/stats/debug", tags=["统计"])
async def get_stats_debug_info():
"""
获取统计系统调试信息
用于调试时间对齐和区间统计问题
"""
try:
from app.service.enhanced_interval_stats import EnhancedIntervalStatsManager
current_time = datetime.utcnow()
current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
interval_stats = await EnhancedIntervalStatsManager.get_current_interval_stats()
# 获取Redis中的实际数据
redis_sync = get_redis_message()
online_key = f"server:interval_online_users:{current_interval.interval_key}"
playing_key = f"server:interval_playing_users:{current_interval.interval_key}"
online_users_raw = await _redis_exec(redis_sync.smembers, online_key)
playing_users_raw = await _redis_exec(redis_sync.smembers, playing_key)
online_users = [int(uid.decode() if isinstance(uid, bytes) else uid) for uid in online_users_raw]
playing_users = [int(uid.decode() if isinstance(uid, bytes) else uid) for uid in playing_users_raw]
return {
"current_time": current_time.isoformat(),
"current_interval": {
@@ -175,28 +183,29 @@ async def get_stats_debug_info():
"is_current": current_interval.is_current(),
"minutes_remaining": int((current_interval.end_time - current_time).total_seconds() / 60),
"seconds_remaining": int((current_interval.end_time - current_time).total_seconds()),
"progress_percentage": round((1 - (current_interval.end_time - current_time).total_seconds() / (30 * 60)) * 100, 1)
"progress_percentage": round(
(1 - (current_interval.end_time - current_time).total_seconds() / (30 * 60)) * 100,
1,
),
},
"interval_statistics": interval_stats.to_dict() if interval_stats else None,
"redis_data": {
"online_users": online_users,
"playing_users": playing_users,
"online_count": len(online_users),
"playing_count": len(playing_users)
"playing_count": len(playing_users),
},
"system_status": {
"stats_system": "enhanced_interval_stats",
"data_alignment": "30_minute_boundaries",
"real_time_updates": True,
"auto_24h_fill": True
}
"auto_24h_fill": True,
},
}
except Exception as e:
logger.error(f"Error getting debug info: {e}")
return {
"error": "Failed to retrieve debug information",
"message": str(e)
}
return {"error": "Failed to retrieve debug information", "message": str(e)}
async def _get_registered_users_count(redis) -> int:
"""获取注册用户总数(从缓存)"""
@@ -207,6 +216,7 @@ async def _get_registered_users_count(redis) -> int:
logger.error(f"Error getting registered users count: {e}")
return 0
async def _get_online_users_count(redis) -> int:
"""获取当前在线用户数"""
try:
@@ -216,6 +226,7 @@ async def _get_online_users_count(redis) -> int:
logger.error(f"Error getting online users count: {e}")
return 0
async def _get_playing_users_count(redis) -> int:
"""获取当前游玩用户数"""
try:
@@ -225,27 +236,28 @@ async def _get_playing_users_count(redis) -> int:
logger.error(f"Error getting playing users count: {e}")
return 0
# 统计更新功能
async def update_registered_users_count() -> None:
"""更新注册用户数缓存"""
from app.dependencies.database import with_db
from app.database import User
from app.const import BANCHOBOT_ID
from sqlmodel import select, func
from app.database import User
from app.dependencies.database import with_db
from sqlmodel import func, select
redis = get_redis()
try:
async with with_db() as db:
# 排除机器人用户BANCHOBOT_ID
result = await db.exec(
select(func.count()).select_from(User).where(User.id != BANCHOBOT_ID)
)
result = await db.exec(select(func.count()).select_from(User).where(User.id != BANCHOBOT_ID))
count = result.first()
await redis.set(REDIS_REGISTERED_USERS_KEY, count or 0, ex=300) # 5分钟过期
logger.debug(f"Updated registered users count: {count}")
except Exception as e:
logger.error(f"Error updating registered users count: {e}")
async def add_online_user(user_id: int) -> None:
"""添加在线用户"""
redis_sync = get_redis_message()
@@ -257,14 +269,20 @@ async def add_online_user(user_id: int) -> None:
if ttl <= 0: # -1表示永不过期-2表示不存在0表示已过期
await redis_async.expire(REDIS_ONLINE_USERS_KEY, 3 * 3600) # 3小时过期
logger.debug(f"Added online user {user_id}")
# 立即更新当前区间统计
from app.service.enhanced_interval_stats import update_user_activity_in_interval
asyncio.create_task(update_user_activity_in_interval(user_id, is_playing=False))
bg_tasks.add_task(
update_user_activity_in_interval,
user_id,
is_playing=False,
)
except Exception as e:
logger.error(f"Error adding online user {user_id}: {e}")
async def remove_online_user(user_id: int) -> None:
"""移除在线用户"""
redis_sync = get_redis_message()
@@ -274,6 +292,7 @@ async def remove_online_user(user_id: int) -> None:
except Exception as e:
logger.error(f"Error removing online user {user_id}: {e}")
async def add_playing_user(user_id: int) -> None:
"""添加游玩用户"""
redis_sync = get_redis_message()
@@ -285,14 +304,16 @@ async def add_playing_user(user_id: int) -> None:
if ttl <= 0: # -1表示永不过期-2表示不存在0表示已过期
await redis_async.expire(REDIS_PLAYING_USERS_KEY, 3 * 3600) # 3小时过期
logger.debug(f"Added playing user {user_id}")
# 立即更新当前区间统计
from app.service.enhanced_interval_stats import update_user_activity_in_interval
asyncio.create_task(update_user_activity_in_interval(user_id, is_playing=True))
bg_tasks.add_task(update_user_activity_in_interval, user_id, is_playing=True)
except Exception as e:
logger.error(f"Error adding playing user {user_id}: {e}")
async def remove_playing_user(user_id: int) -> None:
"""移除游玩用户"""
redis_sync = get_redis_message()
@@ -301,6 +322,7 @@ async def remove_playing_user(user_id: int) -> None:
except Exception as e:
logger.error(f"Error removing playing user {user_id}: {e}")
async def record_hourly_stats() -> None:
"""记录统计数据 - 简化版本主要作为fallback使用"""
redis_sync = get_redis_message()
@@ -308,24 +330,27 @@ async def record_hourly_stats() -> None:
try:
# 先确保Redis连接正常
await redis_async.ping()
online_count = await _get_online_users_count(redis_async)
playing_count = await _get_playing_users_count(redis_async)
current_time = datetime.utcnow()
history_point = {
"timestamp": current_time.isoformat(),
"online_count": online_count,
"playing_count": playing_count
"playing_count": playing_count,
}
# 添加到历史记录
await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point))
# 只保留48个数据点24小时每30分钟一个点
await _redis_exec(redis_sync.ltrim, REDIS_ONLINE_HISTORY_KEY, 0, 47)
# 设置过期时间为26小时确保有足够缓冲
await redis_async.expire(REDIS_ONLINE_HISTORY_KEY, 26 * 3600)
logger.info(f"Recorded fallback stats: online={online_count}, playing={playing_count} at {current_time.strftime('%H:%M:%S')}")
logger.info(
f"Recorded fallback stats: online={online_count}, playing={playing_count} "
f"at {current_time.strftime('%H:%M:%S')}"
)
except Exception as e:
logger.error(f"Error recording fallback stats: {e}")

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import asyncio
from datetime import UTC, datetime, timedelta
from typing import Literal
@@ -26,7 +25,7 @@ from app.service.user_cache_service import get_user_cache_service
from .router import router
from fastapi import HTTPException, Path, Query, Security
from fastapi import BackgroundTasks, HTTPException, Path, Query, Security
from pydantic import BaseModel
from sqlmodel import exists, false, select
from sqlmodel.sql.expression import col
@@ -47,13 +46,10 @@ class BatchUserResponse(BaseModel):
@router.get("/users/lookup/", response_model=BatchUserResponse, include_in_schema=False)
async def get_users(
session: Database,
user_ids: list[int] = Query(
default_factory=list, alias="ids[]", description="要查询的用户 ID 列表"
),
background_task: BackgroundTasks,
user_ids: list[int] = Query(default_factory=list, alias="ids[]", description="要查询的用户 ID 列表"),
# current_user: User = Security(get_current_user, scopes=["public"]),
include_variant_statistics: bool = Query(
default=False, description="是否包含各模式的统计信息"
), # TODO: future use
include_variant_statistics: bool = Query(default=False, description="是否包含各模式的统计信息"), # TODO: future use
):
redis = get_redis()
cache_service = get_user_cache_service(redis)
@@ -72,11 +68,7 @@ async def get_users(
# 查询未缓存的用户
if uncached_user_ids:
searched_users = (
await session.exec(
select(User).where(col(User.id).in_(uncached_user_ids))
)
).all()
searched_users = (await session.exec(select(User).where(col(User.id).in_(uncached_user_ids)))).all()
# 将查询到的用户添加到缓存并返回
for searched_user in searched_users:
@@ -88,7 +80,7 @@ async def get_users(
)
cached_users.append(user_resp)
# 异步缓存,不阻塞响应
asyncio.create_task(cache_service.cache_user(user_resp))
background_task.add_task(cache_service.cache_user, user_resp)
return BatchUserResponse(users=cached_users)
else:
@@ -103,7 +95,7 @@ async def get_users(
)
users.append(user_resp)
# 异步缓存
asyncio.create_task(cache_service.cache_user(user_resp))
background_task.add_task(cache_service.cache_user, user_resp)
return BatchUserResponse(users=users)
@@ -117,6 +109,7 @@ async def get_users(
)
async def get_user_info_ruleset(
session: Database,
background_task: BackgroundTasks,
user_id: str = Path(description="用户 ID 或用户名"),
ruleset: GameMode | None = Path(description="指定 ruleset"),
# current_user: User = Security(get_current_user, scopes=["public"]),
@@ -134,9 +127,7 @@ async def get_user_info_ruleset(
searched_user = (
await session.exec(
select(User).where(
User.id == int(user_id)
if user_id.isdigit()
else User.username == user_id.removeprefix("@")
User.id == int(user_id) if user_id.isdigit() else User.username == user_id.removeprefix("@")
)
)
).first()
@@ -151,7 +142,7 @@ async def get_user_info_ruleset(
)
# 异步缓存结果
asyncio.create_task(cache_service.cache_user(user_resp, ruleset))
background_task.add_task(cache_service.cache_user, user_resp, ruleset)
return user_resp
@@ -165,6 +156,7 @@ async def get_user_info_ruleset(
tags=["用户"],
)
async def get_user_info(
background_task: BackgroundTasks,
session: Database,
user_id: str = Path(description="用户 ID 或用户名"),
# current_user: User = Security(get_current_user, scopes=["public"]),
@@ -182,9 +174,7 @@ async def get_user_info(
searched_user = (
await session.exec(
select(User).where(
User.id == int(user_id)
if user_id.isdigit()
else User.username == user_id.removeprefix("@")
User.id == int(user_id) if user_id.isdigit() else User.username == user_id.removeprefix("@")
)
)
).first()
@@ -198,7 +188,7 @@ async def get_user_info(
)
# 异步缓存结果
asyncio.create_task(cache_service.cache_user(user_resp))
background_task.add_task(cache_service.cache_user, user_resp)
return user_resp
@@ -212,6 +202,7 @@ async def get_user_info(
)
async def get_user_beatmapsets(
session: Database,
background_task: BackgroundTasks,
user_id: int = Path(description="用户 ID"),
type: BeatmapsetType = Path(description="谱面集类型"),
current_user: User = Security(get_current_user, scopes=["public"]),
@@ -222,9 +213,7 @@ async def get_user_beatmapsets(
cache_service = get_user_cache_service(redis)
# 先尝试从缓存获取
cached_result = await cache_service.get_user_beatmapsets_from_cache(
user_id, type.value, limit, offset
)
cached_result = await cache_service.get_user_beatmapsets_from_cache(user_id, type.value, limit, offset)
if cached_result is not None:
# 根据类型恢复对象
if type == BeatmapsetType.MOST_PLAYED:
@@ -253,10 +242,7 @@ async def get_user_beatmapsets(
raise HTTPException(404, detail="User not found")
favourites = await user.awaitable_attrs.favourite_beatmapsets
resp = [
await BeatmapsetResp.from_db(
favourite.beatmapset, session=session, user=user
)
for favourite in favourites
await BeatmapsetResp.from_db(favourite.beatmapset, session=session, user=user) for favourite in favourites
]
elif type == BeatmapsetType.MOST_PLAYED:
@@ -267,25 +253,18 @@ async def get_user_beatmapsets(
.limit(limit)
.offset(offset)
)
resp = [
await BeatmapPlaycountsResp.from_db(most_played_beatmap)
for most_played_beatmap in most_played
]
resp = [await BeatmapPlaycountsResp.from_db(most_played_beatmap) for most_played_beatmap in most_played]
else:
raise HTTPException(400, detail="Invalid beatmapset type")
# 异步缓存结果
async def cache_beatmapsets():
try:
await cache_service.cache_user_beatmapsets(
user_id, type.value, resp, limit, offset
)
await cache_service.cache_user_beatmapsets(user_id, type.value, resp, limit, offset)
except Exception as e:
logger.error(
f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}"
)
logger.error(f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}")
asyncio.create_task(cache_beatmapsets())
background_task.add_task(cache_beatmapsets)
return resp
@@ -299,18 +278,14 @@ async def get_user_beatmapsets(
)
async def get_user_scores(
session: Database,
background_task: BackgroundTasks,
user_id: int = Path(description="用户 ID"),
type: Literal["best", "recent", "firsts", "pinned"] = Path(
description=(
"成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩"
" / firsts 第一名成绩 / pinned 置顶成绩"
)
description=("成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩 / firsts 第一名成绩 / pinned 置顶成绩")
),
legacy_only: bool = Query(False, description="是否只查询 Stable 成绩"),
include_fails: bool = Query(False, description="是否包含失败的成绩"),
mode: GameMode | None = Query(
None, description="指定 ruleset (可选,默认为用户主模式)"
),
mode: GameMode | None = Query(None, description="指定 ruleset (可选,默认为用户主模式)"),
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
offset: int = Query(0, ge=0, description="偏移量"),
current_user: User = Security(get_current_user, scopes=["public"]),
@@ -320,9 +295,7 @@ async def get_user_scores(
# 先尝试从缓存获取对于recent类型使用较短的缓存时间
cache_expire = 30 if type == "recent" else settings.user_scores_cache_expire_seconds
cached_scores = await cache_service.get_user_scores_from_cache(
user_id, type, mode, limit, offset
)
cached_scores = await cache_service.get_user_scores_from_cache(user_id, type, mode, limit, offset)
if cached_scores is not None:
return cached_scores
@@ -332,9 +305,7 @@ async def get_user_scores(
gamemode = mode or db_user.playmode
order_by = None
where_clause = (col(Score.user_id) == db_user.id) & (
col(Score.gamemode) == gamemode
)
where_clause = (col(Score.user_id) == db_user.id) & (col(Score.gamemode) == gamemode)
if not include_fails:
where_clause &= col(Score.passed).is_(True)
if type == "pinned":
@@ -351,13 +322,7 @@ async def get_user_scores(
where_clause &= false()
scores = (
await session.exec(
select(Score)
.where(where_clause)
.order_by(order_by)
.limit(limit)
.offset(offset)
)
await session.exec(select(Score).where(where_clause).order_by(order_by).limit(limit).offset(offset))
).all()
if not scores:
return []
@@ -371,18 +336,14 @@ async def get_user_scores(
]
# 异步缓存结果
asyncio.create_task(
cache_service.cache_user_scores(
user_id, type, score_responses, mode, limit, offset, cache_expire
)
background_task.add_task(
cache_service.cache_user_scores, user_id, type, score_responses, mode, limit, offset, cache_expire
)
return score_responses
@router.get(
"/users/{user}/recent_activity", tags=["用户"], response_model=list[EventResp]
)
@router.get("/users/{user}/recent_activity", tags=["用户"], response_model=list[EventResp])
async def get_user_events(
session: Database,
user: int,

View File

@@ -59,9 +59,7 @@ class CacheScheduler:
# 从配置文件获取间隔设置
check_interval = 5 * 60 # 5分钟检查间隔
beatmap_cache_interval = 30 * 60 # 30分钟beatmap缓存间隔
ranking_cache_interval = (
settings.ranking_cache_refresh_interval_minutes * 60
) # 从配置读取
ranking_cache_interval = settings.ranking_cache_refresh_interval_minutes * 60 # 从配置读取
user_cache_interval = 15 * 60 # 15分钟用户缓存预加载间隔
user_cleanup_interval = 60 * 60 # 60分钟用户缓存清理间隔

View File

@@ -5,9 +5,7 @@
from __future__ import annotations
import asyncio
from datetime import datetime
from app.config import settings
from app.dependencies.database import engine
from app.log import logger
from app.service.database_cleanup_service import DatabaseCleanupService
@@ -51,16 +49,16 @@ class DatabaseCleanupScheduler:
try:
# 每小时运行一次清理
await asyncio.sleep(3600) # 3600秒 = 1小时
if not self.running:
break
await self._run_cleanup()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Database cleanup scheduler error: {str(e)}")
logger.error(f"Database cleanup scheduler error: {e!s}")
# 发生错误后等待5分钟再继续
await asyncio.sleep(300)
@@ -69,20 +67,20 @@ class DatabaseCleanupScheduler:
try:
async with AsyncSession(engine) as db:
logger.debug("Starting scheduled database cleanup...")
# 清理过期的验证码
expired_codes = await DatabaseCleanupService.cleanup_expired_verification_codes(db)
# 清理过期的登录会话
expired_sessions = await DatabaseCleanupService.cleanup_expired_login_sessions(db)
# 只在有清理记录时输出总结
total_cleaned = expired_codes + expired_sessions
if total_cleaned > 0:
logger.debug(f"Scheduled cleanup completed - codes: {expired_codes}, sessions: {expired_sessions}")
except Exception as e:
logger.error(f"Error during scheduled database cleanup: {str(e)}")
logger.error(f"Error during scheduled database cleanup: {e!s}")
async def run_manual_cleanup(self):
"""手动运行完整清理"""
@@ -95,7 +93,7 @@ class DatabaseCleanupScheduler:
logger.debug(f"Manual cleanup completed, total records cleaned: {total}")
return results
except Exception as e:
logger.error(f"Error during manual database cleanup: {str(e)}")
logger.error(f"Error during manual database cleanup: {e!s}")
return {}

View File

@@ -63,10 +63,7 @@ class BeatmapCacheService:
if preload_tasks:
results = await asyncio.gather(*preload_tasks, return_exceptions=True)
success_count = sum(1 for r in results if r is True)
logger.info(
f"Preloaded {success_count}/{len(preload_tasks)} "
f"beatmaps successfully"
)
logger.info(f"Preloaded {success_count}/{len(preload_tasks)} beatmaps successfully")
except Exception as e:
logger.error(f"Error during beatmap preloading: {e}")
@@ -119,9 +116,7 @@ class BeatmapCacheService:
return {
"cached_beatmaps": len(keys),
"estimated_total_size_mb": (
round(total_size / 1024 / 1024, 2) if total_size > 0 else 0
),
"estimated_total_size_mb": (round(total_size / 1024 / 1024, 2) if total_size > 0 else 0),
"preloading": self._preloading,
}
except Exception as e:
@@ -155,9 +150,7 @@ def get_beatmap_cache_service(redis: Redis, fetcher: "Fetcher") -> BeatmapCacheS
return _cache_service
async def schedule_preload_task(
session: AsyncSession, redis: Redis, fetcher: "Fetcher"
):
async def schedule_preload_task(session: AsyncSession, redis: Redis, fetcher: "Fetcher"):
"""
定时预加载任务
"""

View File

@@ -192,22 +192,16 @@ class BeatmapDownloadService:
healthy_endpoints.sort(key=lambda x: x.priority)
return healthy_endpoints
def get_download_url(
self, beatmapset_id: int, no_video: bool, is_china: bool
) -> str:
def get_download_url(self, beatmapset_id: int, no_video: bool, is_china: bool) -> str:
"""获取下载URL带负载均衡和故障转移"""
healthy_endpoints = self.get_healthy_endpoints(is_china)
if not healthy_endpoints:
# 如果没有健康的端点,记录错误并回退到所有端点中优先级最高的
logger.error(f"No healthy endpoints available for is_china={is_china}")
endpoints = (
self.china_endpoints if is_china else self.international_endpoints
)
endpoints = self.china_endpoints if is_china else self.international_endpoints
if not endpoints:
raise HTTPException(
status_code=503, detail="No download endpoints available"
)
raise HTTPException(status_code=503, detail="No download endpoints available")
endpoint = min(endpoints, key=lambda x: x.priority)
else:
# 使用第一个健康的端点(已按优先级排序)
@@ -218,9 +212,7 @@ class BeatmapDownloadService:
video_type = "novideo" if no_video else "full"
return endpoint.url_template.format(type=video_type, sid=beatmapset_id)
elif endpoint.name == "Nerinyan":
return endpoint.url_template.format(
sid=beatmapset_id, no_video="true" if no_video else "false"
)
return endpoint.url_template.format(sid=beatmapset_id, no_video="true" if no_video else "false")
elif endpoint.name == "OsuDirect":
# osu.direct 似乎没有no_video参数直接使用基础URL
return endpoint.url_template.format(sid=beatmapset_id)
@@ -239,9 +231,7 @@ class BeatmapDownloadService:
for name, status in self.endpoint_status.items():
status_info["endpoints"][name] = {
"healthy": status.is_healthy,
"last_check": status.last_check.isoformat()
if status.last_check
else None,
"last_check": status.last_check.isoformat() if status.last_check else None,
"consecutive_failures": status.consecutive_failures,
"last_error": status.last_error,
"priority": status.endpoint.priority,

View File

@@ -11,9 +11,7 @@ from app.models.score import GameMode
from sqlmodel import col, exists, select, update
@get_scheduler().scheduled_job(
"cron", hour=0, minute=0, second=0, id="calculate_user_rank"
)
@get_scheduler().scheduled_job("cron", hour=0, minute=0, second=0, id="calculate_user_rank")
async def calculate_user_rank(is_today: bool = False):
today = datetime.now(UTC).date()
target_date = today if is_today else today - timedelta(days=1)

View File

@@ -11,9 +11,7 @@ from sqlmodel import exists, select
async def create_banchobot():
async with with_db() as session:
is_exist = (
await session.exec(select(exists()).where(User.id == BANCHOBOT_ID))
).first()
is_exist = (await session.exec(select(exists()).where(User.id == BANCHOBOT_ID))).first()
if not is_exist:
banchobot = User(
username="BanchoBot",

View File

@@ -82,8 +82,7 @@ async def daily_challenge_job():
if beatmap is None or ruleset_id is None:
logger.warning(
f"[DailyChallenge] Missing required data for daily challenge {now}."
" Will try again in 5 minutes."
f"[DailyChallenge] Missing required data for daily challenge {now}. Will try again in 5 minutes."
)
get_scheduler().add_job(
daily_challenge_job,
@@ -104,9 +103,7 @@ async def daily_challenge_job():
else:
allowed_mods_list = get_available_mods(ruleset_id_int, required_mods_list)
next_day = (now + timedelta(days=1)).replace(
hour=0, minute=0, second=0, microsecond=0
)
next_day = (now + timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0)
room = await create_daily_challenge_room(
beatmap=beatmap_int,
ruleset_id=ruleset_id_int,
@@ -114,24 +111,13 @@ async def daily_challenge_job():
allowed_mods=allowed_mods_list,
duration=int((next_day - now - timedelta(minutes=2)).total_seconds() / 60),
)
await MetadataHubs.broadcast_call(
"DailyChallengeUpdated", DailyChallengeInfo(room_id=room.id)
)
logger.success(
"[DailyChallenge] Added today's daily challenge: "
f"{beatmap=}, {ruleset_id=}, {required_mods=}"
)
await MetadataHubs.broadcast_call("DailyChallengeUpdated", DailyChallengeInfo(room_id=room.id))
logger.success(f"[DailyChallenge] Added today's daily challenge: {beatmap=}, {ruleset_id=}, {required_mods=}")
return
except (ValueError, json.JSONDecodeError) as e:
logger.warning(
f"[DailyChallenge] Error processing daily challenge data: {e}"
" Will try again in 5 minutes."
)
logger.warning(f"[DailyChallenge] Error processing daily challenge data: {e} Will try again in 5 minutes.")
except Exception as e:
logger.exception(
f"[DailyChallenge] Unexpected error in daily challenge job: {e}"
" Will try again in 5 minutes."
)
logger.exception(f"[DailyChallenge] Unexpected error in daily challenge job: {e} Will try again in 5 minutes.")
get_scheduler().add_job(
daily_challenge_job,
"date",
@@ -139,9 +125,7 @@ async def daily_challenge_job():
)
@get_scheduler().scheduled_job(
"cron", hour=0, minute=1, second=0, id="daily_challenge_last_top"
)
@get_scheduler().scheduled_job("cron", hour=0, minute=1, second=0, id="daily_challenge_last_top")
async def process_daily_challenge_top():
async with with_db() as session:
now = datetime.now(UTC)
@@ -182,11 +166,7 @@ async def process_daily_challenge_top():
await session.commit()
del s
user_ids = (
await session.exec(
select(User.id).where(col(User.id).not_in(participated_users))
)
).all()
user_ids = (await session.exec(select(User.id).where(col(User.id).not_in(participated_users)))).all()
for id in user_ids:
stats = await session.get(DailyChallengeStats, id)
if stats is None: # not execute

View File

@@ -4,14 +4,13 @@
from __future__ import annotations
from datetime import datetime, UTC, timedelta
from datetime import UTC, datetime, timedelta
from app.database.email_verification import EmailVerification, LoginSession
from app.log import logger
from sqlmodel import select
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlalchemy import and_
class DatabaseCleanupService:
@@ -21,211 +20,207 @@ class DatabaseCleanupService:
async def cleanup_expired_verification_codes(db: AsyncSession) -> int:
"""
清理过期的邮件验证码
Args:
db: 数据库会话
Returns:
int: 清理的记录数
"""
try:
# 查找过期的验证码记录
current_time = datetime.now(UTC)
stmt = select(EmailVerification).where(
EmailVerification.expires_at < current_time
)
stmt = select(EmailVerification).where(EmailVerification.expires_at < current_time)
result = await db.exec(stmt)
expired_codes = result.all()
# 删除过期的记录
deleted_count = 0
for code in expired_codes:
await db.delete(code)
deleted_count += 1
await db.commit()
if deleted_count > 0:
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired email verification codes")
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"[Cleanup Service] Error cleaning expired verification codes: {str(e)}")
logger.error(f"[Cleanup Service] Error cleaning expired verification codes: {e!s}")
return 0
@staticmethod
async def cleanup_expired_login_sessions(db: AsyncSession) -> int:
"""
清理过期的登录会话
Args:
db: 数据库会话
Returns:
int: 清理的记录数
"""
try:
# 查找过期的登录会话记录
current_time = datetime.now(UTC)
stmt = select(LoginSession).where(
LoginSession.expires_at < current_time
)
stmt = select(LoginSession).where(LoginSession.expires_at < current_time)
result = await db.exec(stmt)
expired_sessions = result.all()
# 删除过期的记录
deleted_count = 0
for session in expired_sessions:
await db.delete(session)
deleted_count += 1
await db.commit()
if deleted_count > 0:
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired login sessions")
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"[Cleanup Service] Error cleaning expired login sessions: {str(e)}")
logger.error(f"[Cleanup Service] Error cleaning expired login sessions: {e!s}")
return 0
@staticmethod
async def cleanup_old_used_verification_codes(db: AsyncSession, days_old: int = 7) -> int:
"""
清理旧的已使用验证码记录
Args:
db: 数据库会话
days_old: 清理多少天前的已使用记录默认7天
Returns:
int: 清理的记录数
"""
try:
# 查找指定天数前的已使用验证码记录
cutoff_time = datetime.now(UTC) - timedelta(days=days_old)
stmt = select(EmailVerification).where(
EmailVerification.is_used == True
)
stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True))
result = await db.exec(stmt)
all_used_codes = result.all()
# 筛选出过期的记录
old_used_codes = [
code for code in all_used_codes
if code.used_at and code.used_at < cutoff_time
]
old_used_codes = [code for code in all_used_codes if code.used_at and code.used_at < cutoff_time]
# 删除旧的已使用记录
deleted_count = 0
for code in old_used_codes:
await db.delete(code)
deleted_count += 1
await db.commit()
if deleted_count > 0:
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} used verification codes older than {days_old} days")
logger.debug(
f"[Cleanup Service] Cleaned up {deleted_count} used verification codes older than {days_old} days"
)
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"[Cleanup Service] Error cleaning old used verification codes: {str(e)}")
logger.error(f"[Cleanup Service] Error cleaning old used verification codes: {e!s}")
return 0
@staticmethod
async def cleanup_old_verified_sessions(db: AsyncSession, days_old: int = 30) -> int:
"""
清理旧的已验证会话记录
Args:
db: 数据库会话
days_old: 清理多少天前的已验证记录默认30天
Returns:
int: 清理的记录数
"""
try:
# 查找指定天数前的已验证会话记录
cutoff_time = datetime.now(UTC) - timedelta(days=days_old)
stmt = select(LoginSession).where(
LoginSession.is_verified == True
)
stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
result = await db.exec(stmt)
all_verified_sessions = result.all()
# 筛选出过期的记录
old_verified_sessions = [
session for session in all_verified_sessions
session
for session in all_verified_sessions
if session.verified_at and session.verified_at < cutoff_time
]
# 删除旧的已验证记录
deleted_count = 0
for session in old_verified_sessions:
await db.delete(session)
deleted_count += 1
await db.commit()
if deleted_count > 0:
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} verified sessions older than {days_old} days")
logger.debug(
f"[Cleanup Service] Cleaned up {deleted_count} verified sessions older than {days_old} days"
)
return deleted_count
except Exception as e:
await db.rollback()
logger.error(f"[Cleanup Service] Error cleaning old verified sessions: {str(e)}")
logger.error(f"[Cleanup Service] Error cleaning old verified sessions: {e!s}")
return 0
@staticmethod
async def run_full_cleanup(db: AsyncSession) -> dict[str, int]:
"""
运行完整的清理流程
Args:
db: 数据库会话
Returns:
dict: 各项清理的结果统计
"""
results = {}
# 清理过期的验证码
results["expired_verification_codes"] = await DatabaseCleanupService.cleanup_expired_verification_codes(db)
# 清理过期的登录会话
results["expired_login_sessions"] = await DatabaseCleanupService.cleanup_expired_login_sessions(db)
# 清理7天前的已使用验证码
results["old_used_verification_codes"] = await DatabaseCleanupService.cleanup_old_used_verification_codes(db, 7)
# 清理30天前的已验证会话
results["old_verified_sessions"] = await DatabaseCleanupService.cleanup_old_verified_sessions(db, 30)
total_cleaned = sum(results.values())
if total_cleaned > 0:
logger.debug(f"[Cleanup Service] Full cleanup completed, total cleaned: {total_cleaned} records - {results}")
logger.debug(
f"[Cleanup Service] Full cleanup completed, total cleaned: {total_cleaned} records - {results}"
)
return results
@staticmethod
async def get_cleanup_statistics(db: AsyncSession) -> dict[str, int]:
"""
获取清理统计信息
Args:
db: 数据库会话
Returns:
dict: 统计信息
"""
@@ -233,57 +228,54 @@ class DatabaseCleanupService:
current_time = datetime.now(UTC)
cutoff_7_days = current_time - timedelta(days=7)
cutoff_30_days = current_time - timedelta(days=30)
# 统计过期的验证码数量
expired_codes_stmt = select(EmailVerification).where(
EmailVerification.expires_at < current_time
)
expired_codes_stmt = select(EmailVerification).where(EmailVerification.expires_at < current_time)
expired_codes_result = await db.exec(expired_codes_stmt)
expired_codes_count = len(expired_codes_result.all())
# 统计过期的登录会话数量
expired_sessions_stmt = select(LoginSession).where(
LoginSession.expires_at < current_time
)
expired_sessions_stmt = select(LoginSession).where(LoginSession.expires_at < current_time)
expired_sessions_result = await db.exec(expired_sessions_stmt)
expired_sessions_count = len(expired_sessions_result.all())
# 统计7天前的已使用验证码数量
old_used_codes_stmt = select(EmailVerification).where(
EmailVerification.is_used == True
)
old_used_codes_stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True))
old_used_codes_result = await db.exec(old_used_codes_stmt)
all_used_codes = old_used_codes_result.all()
old_used_codes_count = len([
code for code in all_used_codes
if code.used_at and code.used_at < cutoff_7_days
])
# 统计30天前的已验证会话数量
old_verified_sessions_stmt = select(LoginSession).where(
LoginSession.is_verified == True
old_used_codes_count = len(
[code for code in all_used_codes if code.used_at and code.used_at < cutoff_7_days]
)
# 统计30天前的已验证会话数量
old_verified_sessions_stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
old_verified_sessions_result = await db.exec(old_verified_sessions_stmt)
all_verified_sessions = old_verified_sessions_result.all()
old_verified_sessions_count = len([
session for session in all_verified_sessions
if session.verified_at and session.verified_at < cutoff_30_days
])
old_verified_sessions_count = len(
[
session
for session in all_verified_sessions
if session.verified_at and session.verified_at < cutoff_30_days
]
)
return {
"expired_verification_codes": expired_codes_count,
"expired_login_sessions": expired_sessions_count,
"old_used_verification_codes": old_used_codes_count,
"old_verified_sessions": old_verified_sessions_count,
"total_cleanable": expired_codes_count + expired_sessions_count + old_used_codes_count + old_verified_sessions_count
"total_cleanable": expired_codes_count
+ expired_sessions_count
+ old_used_codes_count
+ old_verified_sessions_count,
}
except Exception as e:
logger.error(f"[Cleanup Service] Error getting cleanup statistics: {str(e)}")
logger.error(f"[Cleanup Service] Error getting cleanup statistics: {e!s}")
return {
"expired_verification_codes": 0,
"expired_login_sessions": 0,
"old_used_verification_codes": 0,
"old_verified_sessions": 0,
"total_cleanable": 0
"total_cleanable": 0,
}

View File

@@ -8,17 +8,18 @@ from __future__ import annotations
import asyncio
import concurrent.futures
from datetime import datetime
import json
import uuid
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from typing import Dict, Any, Optional
import redis as sync_redis # 添加同步Redis导入
from email.mime.text import MIMEText
import json
import smtplib
from typing import Any
import uuid
from app.config import settings
from app.dependencies.database import redis_message_client # 使用同步Redis客户端
from app.log import logger
from app.utils import bg_tasks # 添加同步Redis导入
import redis as sync_redis
class EmailQueue:
@@ -30,14 +31,14 @@ class EmailQueue:
self._processing = False
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=2)
self._retry_limit = 3 # 重试次数限制
# 邮件配置
self.smtp_server = getattr(settings, 'smtp_server', 'localhost')
self.smtp_port = getattr(settings, 'smtp_port', 587)
self.smtp_username = getattr(settings, 'smtp_username', '')
self.smtp_password = getattr(settings, 'smtp_password', '')
self.from_email = getattr(settings, 'from_email', 'noreply@example.com')
self.from_name = getattr(settings, 'from_name', 'osu! server')
self.smtp_server = getattr(settings, "smtp_server", "localhost")
self.smtp_port = getattr(settings, "smtp_port", 587)
self.smtp_username = getattr(settings, "smtp_username", "")
self.smtp_password = getattr(settings, "smtp_password", "")
self.from_email = getattr(settings, "from_email", "noreply@example.com")
self.from_name = getattr(settings, "from_name", "osu! server")
async def _run_in_executor(self, func, *args):
"""在线程池中运行同步操作"""
@@ -48,7 +49,7 @@ class EmailQueue:
"""启动邮件处理任务"""
if not self._processing:
self._processing = True
asyncio.create_task(self._process_email_queue())
bg_tasks.add_task(self._process_email_queue)
logger.info("Email queue processing started")
async def stop_processing(self):
@@ -56,27 +57,29 @@ class EmailQueue:
self._processing = False
logger.info("Email queue processing stopped")
async def enqueue_email(self,
to_email: str,
subject: str,
content: str,
html_content: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None) -> str:
async def enqueue_email(
self,
to_email: str,
subject: str,
content: str,
html_content: str | None = None,
metadata: dict[str, Any] | None = None,
) -> str:
"""
将邮件加入队列等待发送
Args:
to_email: 收件人邮箱地址
subject: 邮件主题
content: 邮件纯文本内容
html_content: 邮件HTML内容如果有
metadata: 额外元数据如密码重置ID等
Returns:
邮件任务ID
"""
email_id = str(uuid.uuid4())
email_data = {
"id": email_id,
"to_email": to_email,
@@ -86,125 +89,117 @@ class EmailQueue:
"metadata": json.dumps(metadata) if metadata else "{}",
"created_at": datetime.now().isoformat(),
"status": "pending", # pending, sending, sent, failed
"retry_count": "0"
"retry_count": "0",
}
# 将邮件数据存入Redis
await self._run_in_executor(
lambda: self.redis.hset(f"email:{email_id}", mapping=email_data)
)
await self._run_in_executor(lambda: self.redis.hset(f"email:{email_id}", mapping=email_data))
# 设置24小时过期防止数据堆积
await self._run_in_executor(
self.redis.expire, f"email:{email_id}", 86400
)
await self._run_in_executor(self.redis.expire, f"email:{email_id}", 86400)
# 加入发送队列
await self._run_in_executor(
self.redis.lpush, "email_queue", email_id
)
await self._run_in_executor(self.redis.lpush, "email_queue", email_id)
logger.info(f"Email enqueued with id: {email_id} to {to_email}")
return email_id
async def get_email_status(self, email_id: str) -> Dict[str, Any]:
async def get_email_status(self, email_id: str) -> dict[str, Any]:
"""
获取邮件发送状态
Args:
email_id: 邮件任务ID
Returns:
邮件任务状态信息
"""
email_data = await self._run_in_executor(
self.redis.hgetall, f"email:{email_id}"
)
email_data = await self._run_in_executor(self.redis.hgetall, f"email:{email_id}")
# 解码Redis返回的字节数据
if email_data:
return {
k.decode("utf-8") if isinstance(k, bytes) else k:
v.decode("utf-8") if isinstance(v, bytes) else v
k.decode("utf-8") if isinstance(k, bytes) else k: v.decode("utf-8") if isinstance(v, bytes) else v
for k, v in email_data.items()
}
return {"status": "not_found"}
async def _process_email_queue(self):
"""处理邮件队列"""
logger.info("Starting email queue processor")
while self._processing:
try:
# 从队列获取邮件ID
def brpop_operation():
return self.redis.brpop(["email_queue"], timeout=5)
result = await self._run_in_executor(brpop_operation)
if not result:
await asyncio.sleep(1)
continue
# 解包返回结果(列表名和值)
queue_name, email_id = result
if isinstance(email_id, bytes):
email_id = email_id.decode("utf-8")
# 获取邮件数据
email_data = await self.get_email_status(email_id)
if email_data.get("status") == "not_found":
logger.warning(f"Email data not found for id: {email_id}")
continue
# 更新状态为发送中
await self._run_in_executor(
self.redis.hset, f"email:{email_id}", "status", "sending"
)
await self._run_in_executor(self.redis.hset, f"email:{email_id}", "status", "sending")
# 尝试发送邮件
success = await self._send_email(email_data)
if success:
# 更新状态为已发送
await self._run_in_executor(self.redis.hset, f"email:{email_id}", "status", "sent")
await self._run_in_executor(
self.redis.hset, f"email:{email_id}", "status", "sent"
)
await self._run_in_executor(
self.redis.hset, f"email:{email_id}", "sent_at", datetime.now().isoformat()
self.redis.hset,
f"email:{email_id}",
"sent_at",
datetime.now().isoformat(),
)
logger.info(f"Email {email_id} sent successfully to {email_data.get('to_email')}")
else:
# 计算重试次数
retry_count = int(email_data.get("retry_count", "0")) + 1
if retry_count <= self._retry_limit:
# 重新入队,稍后重试
await self._run_in_executor(
self.redis.hset, f"email:{email_id}", "retry_count", str(retry_count)
self.redis.hset,
f"email:{email_id}",
"retry_count",
str(retry_count),
)
await self._run_in_executor(self.redis.hset, f"email:{email_id}", "status", "pending")
await self._run_in_executor(
self.redis.hset, f"email:{email_id}", "status", "pending"
self.redis.hset,
f"email:{email_id}",
"last_retry",
datetime.now().isoformat(),
)
await self._run_in_executor(
self.redis.hset, f"email:{email_id}", "last_retry", datetime.now().isoformat()
)
# 延迟重试(使用指数退避)
delay = 60 * (2 ** (retry_count - 1)) # 1分钟2分钟4分钟...
# 创建延迟任务
asyncio.create_task(self._delayed_retry(email_id, delay))
bg_tasks.add_task(self._delayed_retry, email_id, delay)
logger.warning(f"Email {email_id} will be retried in {delay} seconds (attempt {retry_count})")
else:
# 超过重试次数,标记为失败
await self._run_in_executor(
self.redis.hset, f"email:{email_id}", "status", "failed"
)
await self._run_in_executor(self.redis.hset, f"email:{email_id}", "status", "failed")
logger.error(f"Email {email_id} failed after {retry_count} attempts")
except Exception as e:
logger.error(f"Error processing email queue: {e}")
await asyncio.sleep(5) # 出错后等待5秒
@@ -212,53 +207,51 @@ class EmailQueue:
async def _delayed_retry(self, email_id: str, delay: int):
"""延迟重试发送邮件"""
await asyncio.sleep(delay)
await self._run_in_executor(
self.redis.lpush, "email_queue", email_id
)
await self._run_in_executor(self.redis.lpush, "email_queue", email_id)
logger.info(f"Re-queued email {email_id} for retry after {delay} seconds")
async def _send_email(self, email_data: Dict[str, Any]) -> bool:
async def _send_email(self, email_data: dict[str, Any]) -> bool:
"""
实际发送邮件
Args:
email_data: 邮件数据
Returns:
是否发送成功
"""
try:
# 如果邮件发送功能被禁用,则只记录日志
if not getattr(settings, 'enable_email_sending', True):
if not getattr(settings, "enable_email_sending", True):
logger.info(f"[Mock Email] Would send to {email_data.get('to_email')}: {email_data.get('subject')}")
return True
# 创建邮件
msg = MIMEMultipart('alternative')
msg['From'] = f"{self.from_name} <{self.from_email}>"
msg['To'] = email_data.get('to_email', '')
msg['Subject'] = email_data.get('subject', '')
msg = MIMEMultipart("alternative")
msg["From"] = f"{self.from_name} <{self.from_email}>"
msg["To"] = email_data.get("to_email", "")
msg["Subject"] = email_data.get("subject", "")
# 添加纯文本内容
content = email_data.get('content', '')
content = email_data.get("content", "")
if content:
msg.attach(MIMEText(content, 'plain', 'utf-8'))
msg.attach(MIMEText(content, "plain", "utf-8"))
# 添加HTML内容如果有
html_content = email_data.get('html_content', '')
html_content = email_data.get("html_content", "")
if html_content:
msg.attach(MIMEText(html_content, 'html', 'utf-8'))
msg.attach(MIMEText(html_content, "html", "utf-8"))
# 发送邮件
with smtplib.SMTP(self.smtp_server, self.smtp_port) as server:
if self.smtp_username and self.smtp_password:
server.starttls()
server.login(self.smtp_username, self.smtp_password)
server.send_message(msg)
return True
except Exception as e:
logger.error(f"Failed to send email: {e}")
return False
@@ -267,10 +260,12 @@ class EmailQueue:
# 全局邮件队列实例
email_queue = EmailQueue()
# 在应用启动时调用
async def start_email_processor():
await email_queue.start_processing()
# 在应用关闭时调用
async def stop_email_processor():
await email_queue.stop_processing()

View File

@@ -4,13 +4,11 @@
from __future__ import annotations
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
import secrets
import smtplib
import string
from datetime import datetime, UTC, timedelta
from typing import Optional
from app.config import settings
from app.log import logger
@@ -18,28 +16,28 @@ from app.log import logger
class EmailService:
"""邮件发送服务"""
def __init__(self):
self.smtp_server = getattr(settings, 'smtp_server', 'localhost')
self.smtp_port = getattr(settings, 'smtp_port', 587)
self.smtp_username = getattr(settings, 'smtp_username', '')
self.smtp_password = getattr(settings, 'smtp_password', '')
self.from_email = getattr(settings, 'from_email', 'noreply@example.com')
self.from_name = getattr(settings, 'from_name', 'osu! server')
self.smtp_server = getattr(settings, "smtp_server", "localhost")
self.smtp_port = getattr(settings, "smtp_port", 587)
self.smtp_username = getattr(settings, "smtp_username", "")
self.smtp_password = getattr(settings, "smtp_password", "")
self.from_email = getattr(settings, "from_email", "noreply@example.com")
self.from_name = getattr(settings, "from_name", "osu! server")
def generate_verification_code(self) -> str:
"""生成8位验证码"""
# 只使用数字,避免混淆
return ''.join(secrets.choice(string.digits) for _ in range(8))
return "".join(secrets.choice(string.digits) for _ in range(8))
async def send_verification_email(self, email: str, code: str, username: str) -> bool:
"""发送验证邮件"""
try:
msg = MIMEMultipart()
msg['From'] = f"{self.from_name} <{self.from_email}>"
msg['To'] = email
msg['Subject'] = "邮箱验证 - Email Verification"
msg["From"] = f"{self.from_name} <{self.from_email}>"
msg["To"] = email
msg["Subject"] = "邮箱验证 - Email Verification"
# HTML 邮件内容
html_content = f"""
<!DOCTYPE html>
@@ -101,15 +99,15 @@ class EmailService:
<h1>osu! 邮箱验证</h1>
<p>Email Verification</p>
</div>
<div class="content">
<h2>你好 {username}</h2>
<p>感谢你注册我们的 osu! 服务器。为了完成账户验证,请输入以下验证码:</p>
<div class="code">{code}</div>
<p>这个验证码将在 <strong>10 分钟后过期</strong>。</p>
<div class="warning">
<strong>注意:</strong>
<ul>
@@ -118,19 +116,19 @@ class EmailService:
<li>验证码只能使用一次</li>
</ul>
</div>
<p>如果你有任何问题,请联系我们的支持团队。</p>
<hr style="border: none; border-top: 1px solid #ddd; margin: 20px 0;">
<h3>Hello {username}!</h3>
<p>Thank you for registering on our osu! server. To complete your account verification, please enter the following verification code:</p>
<p>This verification code will expire in <strong>10 minutes</strong>.</p>
<p><strong>Important:</strong> Do not share this verification code with anyone. If you did not request this code, please ignore this email.</p>
</div>
<div class="footer">
<p>© 2025 g0v0! Private Server. 此邮件由系统自动发送,请勿回复。</p>
<p>This email was sent automatically, please do not reply.</p>
@@ -138,26 +136,26 @@ class EmailService:
</div>
</body>
</html>
"""
msg.attach(MIMEText(html_content, 'html', 'utf-8'))
""" # noqa: E501
msg.attach(MIMEText(html_content, "html", "utf-8"))
# 发送邮件
if not settings.enable_email_sending:
# 邮件发送功能禁用时只记录日志,不实际发送
logger.info(f"[Email Verification] Mock sending verification code to {email}: {code}")
return True
with smtplib.SMTP(self.smtp_server, self.smtp_port) as server:
if self.smtp_username and self.smtp_password:
server.starttls()
server.login(self.smtp_username, self.smtp_password)
server.send_message(msg)
logger.info(f"[Email Verification] Successfully sent verification code to {email}")
return True
except Exception as e:
logger.error(f"[Email Verification] Failed to send email: {e}")
return False

View File

@@ -4,40 +4,38 @@
from __future__ import annotations
from datetime import UTC, datetime, timedelta
import secrets
import string
from datetime import datetime, UTC, timedelta
from typing import Optional
from app.database.email_verification import EmailVerification, LoginSession
from app.service.email_service import email_service
from app.service.email_queue import email_queue # 导入邮件队列
from app.log import logger
from app.config import settings
from app.database.email_verification import EmailVerification, LoginSession
from app.log import logger
from app.service.email_queue import email_queue # 导入邮件队列
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel import select
from redis.asyncio import Redis
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
class EmailVerificationService:
"""邮件验证服务"""
@staticmethod
def generate_verification_code() -> str:
"""生成8位验证码"""
return ''.join(secrets.choice(string.digits) for _ in range(8))
return "".join(secrets.choice(string.digits) for _ in range(8))
@staticmethod
async def send_verification_email_via_queue(email: str, code: str, username: str, user_id: int) -> bool:
"""使用邮件队列发送验证邮件
Args:
email: 接收验证码的邮箱地址
code: 验证码
username: 用户名
user_id: 用户ID
Returns:
是否成功将邮件加入队列
"""
@@ -103,15 +101,15 @@ class EmailVerificationService:
<h1>osu! 邮箱验证</h1>
<p>Email Verification</p>
</div>
<div class="content">
<h2>你好 {username}</h2>
<p>请使用以下验证码验证您的账户:</p>
<div class="code">{code}</div>
<p>验证码将在 <strong>10 分钟内有效</strong>。</p>
<div class="warning">
<p><strong>重要提示:</strong></p>
<ul>
@@ -120,17 +118,17 @@ class EmailVerificationService:
<li>为了账户安全,请勿在其他网站使用相同的密码</li>
</ul>
</div>
<hr style="border: none; border-top: 1px solid #ddd; margin: 20px 0;">
<h3>Hello {username}!</h3>
<p>Please use the following verification code to verify your account:</p>
<p>This verification code will be valid for <strong>10 minutes</strong>.</p>
<p><strong>Important:</strong> Do not share this verification code with anyone. If you did not request this code, please ignore this email.</p>
</div>
<div class="footer">
<p>© 2025 g0v0! Private Server. 此邮件由系统自动发送,请勿回复。</p>
<p>This email was sent automatically, please do not reply.</p>
@@ -138,8 +136,8 @@ class EmailVerificationService:
</div>
</body>
</html>
"""
""" # noqa: E501
# 纯文本备用内容
plain_content = f"""
你好 {username}
@@ -162,34 +160,30 @@ This verification code will be valid for 10 minutes.
© 2025 g0v0! Private Server. 此邮件由系统自动发送,请勿回复。
This email was sent automatically, please do not reply.
"""
# 将邮件加入队列
subject = "邮箱验证 - Email Verification"
metadata = {
"type": "email_verification",
"user_id": user_id,
"code": code
}
metadata = {"type": "email_verification", "user_id": user_id, "code": code}
await email_queue.enqueue_email(
to_email=email,
subject=subject,
content=plain_content,
html_content=html_content,
metadata=metadata
metadata=metadata,
)
return True
except Exception as e:
logger.error(f"[Email Verification] Failed to enqueue email: {e}")
return False
@staticmethod
def generate_session_token() -> str:
"""生成会话令牌"""
return secrets.token_urlsafe(32)
@staticmethod
async def create_verification_record(
db: AsyncSession,
@@ -197,27 +191,27 @@ This email was sent automatically, please do not reply.
user_id: int,
email: str,
ip_address: str | None = None,
user_agent: str | None = None
user_agent: str | None = None,
) -> tuple[EmailVerification, str]:
"""创建邮件验证记录"""
# 检查是否有未过期的验证码
existing_result = await db.exec(
select(EmailVerification).where(
EmailVerification.user_id == user_id,
EmailVerification.is_used == False,
EmailVerification.expires_at > datetime.now(UTC)
col(EmailVerification.is_used).is_(False),
EmailVerification.expires_at > datetime.now(UTC),
)
)
existing = existing_result.first()
if existing:
# 如果有未过期的验证码,直接返回
return existing, existing.verification_code
# 生成新的验证码
code = EmailVerificationService.generate_verification_code()
# 创建验证记录
verification = EmailVerification(
user_id=user_id,
@@ -225,23 +219,23 @@ This email was sent automatically, please do not reply.
verification_code=code,
expires_at=datetime.now(UTC) + timedelta(minutes=10), # 10分钟过期
ip_address=ip_address,
user_agent=user_agent
user_agent=user_agent,
)
db.add(verification)
await db.commit()
await db.refresh(verification)
# 存储到 Redis用于快速验证
await redis.setex(
f"email_verification:{user_id}:{code}",
600, # 10分钟过期
str(verification.id) if verification.id else "0"
str(verification.id) if verification.id else "0",
)
logger.info(f"[Email Verification] Created verification code for user {user_id}: {code}")
return verification, code
@staticmethod
async def send_verification_email(
db: AsyncSession,
@@ -250,7 +244,7 @@ This email was sent automatically, please do not reply.
username: str,
email: str,
ip_address: str | None = None,
user_agent: str | None = None
user_agent: str | None = None,
) -> bool:
"""发送验证邮件"""
try:
@@ -258,33 +252,38 @@ This email was sent automatically, please do not reply.
if not settings.enable_email_verification:
logger.debug(f"[Email Verification] Email verification is disabled, skipping for user {user_id}")
return True # 返回成功,但不执行验证流程
# 创建验证记录
verification, code = await EmailVerificationService.create_verification_record(
(
verification,
code,
) = await EmailVerificationService.create_verification_record(
db, redis, user_id, email, ip_address, user_agent
)
# 使用邮件队列发送验证邮件
success = await EmailVerificationService.send_verification_email_via_queue(email, code, username, user_id)
if success:
logger.info(f"[Email Verification] Successfully enqueued verification email to {email} (user: {username})")
logger.info(
f"[Email Verification] Successfully enqueued verification email to {email} (user: {username})"
)
return True
else:
logger.error(f"[Email Verification] Failed to enqueue verification email: {email} (user: {username})")
return False
except Exception as e:
logger.error(f"[Email Verification] Exception during sending verification email: {e}")
return False
@staticmethod
async def verify_code(
db: AsyncSession,
redis: Redis,
user_id: int,
code: str,
ip_address: str | None = None
ip_address: str | None = None,
) -> tuple[bool, str]:
"""验证验证码"""
try:
@@ -294,46 +293,46 @@ This email was sent automatically, please do not reply.
# 仍然标记登录会话为已验证
await LoginSessionService.mark_session_verified(db, user_id)
return True, "验证成功(邮件验证功能已禁用)"
# 先从 Redis 检查
verification_id = await redis.get(f"email_verification:{user_id}:{code}")
if not verification_id:
return False, "验证码无效或已过期"
# 从数据库获取验证记录
result = await db.exec(
select(EmailVerification).where(
EmailVerification.id == int(verification_id),
EmailVerification.user_id == user_id,
EmailVerification.verification_code == code,
EmailVerification.is_used == False,
EmailVerification.expires_at > datetime.now(UTC)
col(EmailVerification.is_used).is_(False),
EmailVerification.expires_at > datetime.now(UTC),
)
)
verification = result.first()
if not verification:
return False, "验证码无效或已过期"
# 标记为已使用
verification.is_used = True
verification.used_at = datetime.now(UTC)
# 同时更新对应的登录会话状态
await LoginSessionService.mark_session_verified(db, user_id)
await db.commit()
# 删除 Redis 记录
await redis.delete(f"email_verification:{user_id}:{code}")
logger.info(f"[Email Verification] User {user_id} verification code verified successfully")
return True, "验证成功"
except Exception as e:
logger.error(f"[Email Verification] Exception during verification code validation: {e}")
return False, "验证过程中发生错误"
@staticmethod
async def resend_verification_code(
db: AsyncSession,
@@ -342,7 +341,7 @@ This email was sent automatically, please do not reply.
username: str,
email: str,
ip_address: str | None = None,
user_agent: str | None = None
user_agent: str | None = None,
) -> tuple[bool, str]:
"""重新发送验证码"""
try:
@@ -350,25 +349,25 @@ This email was sent automatically, please do not reply.
if not settings.enable_email_verification:
logger.debug(f"[Email Verification] Email verification is disabled, skipping resend for user {user_id}")
return True, "验证码已发送(邮件验证功能已禁用)"
# 检查重发频率限制60秒内只能发送一次
rate_limit_key = f"email_verification_rate_limit:{user_id}"
if await redis.get(rate_limit_key):
return False, "请等待60秒后再重新发送"
# 设置频率限制
await redis.setex(rate_limit_key, 60, "1")
# 生成新的验证码
success = await EmailVerificationService.send_verification_email(
db, redis, user_id, username, email, ip_address, user_agent
)
if success:
return True, "验证码已重新发送"
else:
return False, "重新发送失败,请稍后再试"
except Exception as e:
logger.error(f"[Email Verification] Exception during resending verification code: {e}")
return False, "重新发送过程中发生错误"
@@ -376,7 +375,7 @@ This email was sent automatically, please do not reply.
class LoginSessionService:
"""登录会话服务"""
@staticmethod
async def create_session(
db: AsyncSession,
@@ -385,47 +384,40 @@ class LoginSessionService:
ip_address: str,
user_agent: str | None = None,
country_code: str | None = None,
is_new_location: bool = False
is_new_location: bool = False,
) -> LoginSession:
"""创建登录会话"""
from app.utils import simplify_user_agent
session_token = EmailVerificationService.generate_session_token()
# 简化 User-Agent 字符串
simplified_user_agent = simplify_user_agent(user_agent, max_length=250)
session = LoginSession(
user_id=user_id,
session_token=session_token,
ip_address=ip_address,
user_agent=simplified_user_agent,
user_agent=None,
country_code=country_code,
is_new_location=is_new_location,
expires_at=datetime.now(UTC) + timedelta(hours=24), # 24小时过期
is_verified=not is_new_location # 新位置需要验证
is_verified=not is_new_location, # 新位置需要验证
)
db.add(session)
await db.commit()
await db.refresh(session)
# 存储到 Redis
await redis.setex(
f"login_session:{session_token}",
86400, # 24小时
user_id
user_id,
)
logger.info(f"[Login Session] Created session for user {user_id} (new location: {is_new_location})")
return session
@staticmethod
async def verify_session(
db: AsyncSession,
redis: Redis,
session_token: str,
verification_code: str
db: AsyncSession, redis: Redis, session_token: str, verification_code: str
) -> tuple[bool, str]:
"""验证会话(通过邮件验证码)"""
try:
@@ -433,98 +425,89 @@ class LoginSessionService:
user_id = await redis.get(f"login_session:{session_token}")
if not user_id:
return False, "会话无效或已过期"
user_id = int(user_id)
# 验证邮件验证码
success, message = await EmailVerificationService.verify_code(
db, redis, user_id, verification_code
)
success, message = await EmailVerificationService.verify_code(db, redis, user_id, verification_code)
if not success:
return False, message
# 更新会话状态
result = await db.exec(
select(LoginSession).where(
LoginSession.session_token == session_token,
LoginSession.user_id == user_id,
LoginSession.is_verified == False
col(LoginSession.is_verified).is_(False),
)
)
session = result.first()
if session:
session.is_verified = True
session.verified_at = datetime.now(UTC)
await db.commit()
logger.info(f"[Login Session] User {user_id} session verification successful")
return True, "会话验证成功"
except Exception as e:
logger.error(f"[Login Session] Exception during session verification: {e}")
return False, "验证过程中发生错误"
@staticmethod
async def check_new_location(
db: AsyncSession,
user_id: int,
ip_address: str,
country_code: str | None = None
db: AsyncSession, user_id: int, ip_address: str, country_code: str | None = None
) -> bool:
"""检查是否为新位置登录"""
try:
# 查看过去30天内是否有相同IP或相同国家的登录记录
thirty_days_ago = datetime.now(UTC) - timedelta(days=30)
result = await db.exec(
select(LoginSession).where(
LoginSession.user_id == user_id,
LoginSession.created_at > thirty_days_ago,
(LoginSession.ip_address == ip_address) |
(LoginSession.country_code == country_code)
(LoginSession.ip_address == ip_address) | (LoginSession.country_code == country_code),
)
)
existing_sessions = result.all()
# 如果有历史记录,则不是新位置
return len(existing_sessions) == 0
except Exception as e:
logger.error(f"[Login Session] Exception during new location check: {e}")
# 出错时默认为新位置(更安全)
return True
@staticmethod
async def mark_session_verified(
db: AsyncSession,
user_id: int
) -> bool:
async def mark_session_verified(db: AsyncSession, user_id: int) -> bool:
"""标记用户的未验证会话为已验证"""
try:
# 查找用户所有未验证且未过期的会话
result = await db.exec(
select(LoginSession).where(
LoginSession.user_id == user_id,
LoginSession.is_verified == False,
LoginSession.expires_at > datetime.now(UTC)
col(LoginSession.is_verified).is_(False),
LoginSession.expires_at > datetime.now(UTC),
)
)
sessions = result.all()
# 标记所有会话为已验证
for session in sessions:
session.is_verified = True
session.verified_at = datetime.now(UTC)
if sessions:
logger.info(f"[Login Session] Marked {len(sessions)} session(s) as verified for user {user_id}")
return len(sessions) > 0
except Exception as e:
logger.error(f"[Login Session] Exception during marking sessions as verified: {e}")
return False

View File

@@ -117,14 +117,10 @@ class EnhancedIntervalStatsManager:
@staticmethod
async def get_current_interval_info() -> IntervalInfo:
"""获取当前区间信息"""
start_time, end_time = (
EnhancedIntervalStatsManager.get_current_interval_boundaries()
)
start_time, end_time = EnhancedIntervalStatsManager.get_current_interval_boundaries()
interval_key = EnhancedIntervalStatsManager.generate_interval_key(start_time)
return IntervalInfo(
start_time=start_time, end_time=end_time, interval_key=interval_key
)
return IntervalInfo(start_time=start_time, end_time=end_time, interval_key=interval_key)
@staticmethod
async def initialize_current_interval() -> None:
@@ -133,9 +129,7 @@ class EnhancedIntervalStatsManager:
redis_async = get_redis()
try:
current_interval = (
await EnhancedIntervalStatsManager.get_current_interval_info()
)
current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
# 存储当前区间信息
await _redis_exec(
@@ -147,9 +141,7 @@ class EnhancedIntervalStatsManager:
# 初始化区间用户集合(如果不存在)
online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}"
playing_key = (
f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
)
playing_key = f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
# 设置过期时间为35分钟
await redis_async.expire(online_key, 35 * 60)
@@ -179,7 +171,8 @@ class EnhancedIntervalStatsManager:
await EnhancedIntervalStatsManager._ensure_24h_history_exists()
logger.info(
f"Initialized interval stats for {current_interval.start_time.strftime('%H:%M')} - {current_interval.end_time.strftime('%H:%M')}"
f"Initialized interval stats for {current_interval.start_time.strftime('%H:%M')}"
f" - {current_interval.end_time.strftime('%H:%M')}"
)
except Exception as e:
@@ -193,42 +186,32 @@ class EnhancedIntervalStatsManager:
try:
# 检查现有历史数据数量
history_length = await _redis_exec(
redis_sync.llen, REDIS_ONLINE_HISTORY_KEY
)
history_length = await _redis_exec(redis_sync.llen, REDIS_ONLINE_HISTORY_KEY)
if history_length < 48: # 少于48个数据点24小时*2
logger.info(
f"History has only {history_length} points, filling with zeros for 24h"
)
logger.info(f"History has only {history_length} points, filling with zeros for 24h")
# 计算需要填充的数据点数量
needed_points = 48 - history_length
# 从当前时间往前推创建缺失的时间点都填充为0
current_time = datetime.utcnow()
current_interval_start, _ = (
EnhancedIntervalStatsManager.get_current_interval_boundaries()
)
current_time = datetime.utcnow() # noqa: F841
current_interval_start, _ = EnhancedIntervalStatsManager.get_current_interval_boundaries()
# 从当前区间开始往前推创建历史数据点确保时间对齐到30分钟边界
fill_points = []
for i in range(needed_points):
# 每次往前推30分钟确保时间对齐
point_time = current_interval_start - timedelta(
minutes=30 * (i + 1)
)
point_time = current_interval_start - timedelta(minutes=30 * (i + 1))
# 确保时间对齐到30分钟边界
aligned_minute = (point_time.minute // 30) * 30
point_time = point_time.replace(
minute=aligned_minute, second=0, microsecond=0
)
point_time = point_time.replace(minute=aligned_minute, second=0, microsecond=0)
history_point = {
"timestamp": point_time.isoformat(),
"online_count": 0,
"playing_count": 0
"playing_count": 0,
}
fill_points.append(json.dumps(history_point))
@@ -238,9 +221,7 @@ class EnhancedIntervalStatsManager:
temp_key = f"{REDIS_ONLINE_HISTORY_KEY}_temp"
if history_length > 0:
# 复制现有数据到临时key
existing_data = await _redis_exec(
redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1
)
existing_data = await _redis_exec(redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1)
if existing_data:
for data in existing_data:
await _redis_exec(redis_sync.rpush, temp_key, data)
@@ -250,19 +231,13 @@ class EnhancedIntervalStatsManager:
# 先添加填充数据(最旧的)
for point in reversed(fill_points): # 反向添加,最旧的在最后
await _redis_exec(
redis_sync.rpush, REDIS_ONLINE_HISTORY_KEY, point
)
await _redis_exec(redis_sync.rpush, REDIS_ONLINE_HISTORY_KEY, point)
# 再添加原有数据(较新的)
if history_length > 0:
existing_data = await _redis_exec(
redis_sync.lrange, temp_key, 0, -1
)
existing_data = await _redis_exec(redis_sync.lrange, temp_key, 0, -1)
for data in existing_data:
await _redis_exec(
redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, data
)
await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, data)
# 清理临时key
await redis_async.delete(temp_key)
@@ -273,9 +248,7 @@ class EnhancedIntervalStatsManager:
# 设置过期时间
await redis_async.expire(REDIS_ONLINE_HISTORY_KEY, 26 * 3600)
logger.info(
f"Filled {len(fill_points)} historical data points with zeros"
)
logger.info(f"Filled {len(fill_points)} historical data points with zeros")
except Exception as e:
logger.error(f"Error ensuring 24h history exists: {e}")
@@ -287,9 +260,7 @@ class EnhancedIntervalStatsManager:
redis_async = get_redis()
try:
current_interval = (
await EnhancedIntervalStatsManager.get_current_interval_info()
)
current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
# 添加到区间在线用户集合
online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}"
@@ -298,9 +269,7 @@ class EnhancedIntervalStatsManager:
# 如果用户在游玩,也添加到游玩用户集合
if is_playing:
playing_key = (
f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
)
playing_key = f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
await _redis_exec(redis_sync.sadd, playing_key, str(user_id))
await redis_async.expire(playing_key, 35 * 60)
@@ -308,7 +277,8 @@ class EnhancedIntervalStatsManager:
await EnhancedIntervalStatsManager._update_interval_stats()
logger.debug(
f"Added user {user_id} to current interval {current_interval.start_time.strftime('%H:%M')}-{current_interval.end_time.strftime('%H:%M')}"
f"Added user {user_id} to current interval {current_interval.start_time.strftime('%H:%M')}"
f"-{current_interval.end_time.strftime('%H:%M')}"
)
except Exception as e:
@@ -321,15 +291,11 @@ class EnhancedIntervalStatsManager:
redis_async = get_redis()
try:
current_interval = (
await EnhancedIntervalStatsManager.get_current_interval_info()
)
current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
# 获取区间内独特用户数
online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}"
playing_key = (
f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
)
playing_key = f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
unique_online = await _redis_exec(redis_sync.scard, online_key)
unique_playing = await _redis_exec(redis_sync.scard, playing_key)
@@ -339,16 +305,12 @@ class EnhancedIntervalStatsManager:
current_playing = await _get_playing_users_count(redis_async)
# 获取现有统计数据
existing_data = await _redis_exec(
redis_sync.get, current_interval.interval_key
)
existing_data = await _redis_exec(redis_sync.get, current_interval.interval_key)
if existing_data:
stats = IntervalStats.from_dict(json.loads(existing_data))
# 更新峰值
stats.peak_online_count = max(stats.peak_online_count, current_online)
stats.peak_playing_count = max(
stats.peak_playing_count, current_playing
)
stats.peak_playing_count = max(stats.peak_playing_count, current_playing)
stats.total_samples += 1
else:
# 创建新的统计记录
@@ -377,7 +339,8 @@ class EnhancedIntervalStatsManager:
await redis_async.expire(current_interval.interval_key, 35 * 60)
logger.debug(
f"Updated interval stats: online={unique_online}, playing={unique_playing}, peak_online={stats.peak_online_count}, peak_playing={stats.peak_playing_count}"
f"Updated interval stats: online={unique_online}, playing={unique_playing}, "
f"peak_online={stats.peak_online_count}, peak_playing={stats.peak_playing_count}"
)
except Exception as e:
@@ -395,21 +358,21 @@ class EnhancedIntervalStatsManager:
# 上一个区间开始时间是当前区间开始时间减去30分钟
previous_start = current_start - timedelta(minutes=30)
previous_end = current_start # 上一个区间的结束时间就是当前区间的开始时间
interval_key = EnhancedIntervalStatsManager.generate_interval_key(previous_start)
previous_interval = IntervalInfo(
start_time=previous_start,
end_time=previous_end,
interval_key=interval_key
interval_key=interval_key,
)
# 获取最终统计数据
stats_data = await _redis_exec(
redis_sync.get, previous_interval.interval_key
)
stats_data = await _redis_exec(redis_sync.get, previous_interval.interval_key)
if not stats_data:
logger.warning(f"No interval stats found to finalize for {previous_interval.start_time.strftime('%H:%M')}")
logger.warning(
f"No interval stats found to finalize for {previous_interval.start_time.strftime('%H:%M')}"
)
return None
stats = IntervalStats.from_dict(json.loads(stats_data))
@@ -418,13 +381,11 @@ class EnhancedIntervalStatsManager:
history_point = {
"timestamp": previous_interval.start_time.isoformat(),
"online_count": stats.unique_online_users,
"playing_count": stats.unique_playing_users
"playing_count": stats.unique_playing_users,
}
# 添加到历史记录
await _redis_exec(
redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point)
)
await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point))
# 只保留48个数据点24小时每30分钟一个点
await _redis_exec(redis_sync.ltrim, REDIS_ONLINE_HISTORY_KEY, 0, 47)
# 设置过期时间为26小时确保有足够缓冲
@@ -452,12 +413,8 @@ class EnhancedIntervalStatsManager:
redis_sync = get_redis_message()
try:
current_interval = (
await EnhancedIntervalStatsManager.get_current_interval_info()
)
stats_data = await _redis_exec(
redis_sync.get, current_interval.interval_key
)
current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
stats_data = await _redis_exec(redis_sync.get, current_interval.interval_key)
if stats_data:
return IntervalStats.from_dict(json.loads(stats_data))
@@ -506,8 +463,6 @@ class EnhancedIntervalStatsManager:
# 便捷函数,用于替换现有的统计更新函数
async def update_user_activity_in_interval(
user_id: int, is_playing: bool = False
) -> None:
async def update_user_activity_in_interval(user_id: int, is_playing: bool = False) -> None:
"""用户活动时更新区间统计(在登录、开始游玩等时调用)"""
await EnhancedIntervalStatsManager.add_user_to_interval(user_id, is_playing)

Some files were not shown because too many files have changed in this diff Show More