refactor(project): make pyright & ruff happy

This commit is contained in:
MingxuanGame
2025-08-22 08:21:52 +00:00
parent 3b1d7a2234
commit 598fcc8b38
157 changed files with 2382 additions and 4590 deletions

View File

@@ -65,9 +65,7 @@ async def to_the_core(
# using either of the mods specified: DT, NC # using either of the mods specified: DT, NC
if not score.passed: if not score.passed:
return False return False
if ( if ("Nightcore" not in beatmap.beatmapset.title) and "Nightcore" not in beatmap.beatmapset.artist:
"Nightcore" not in beatmap.beatmapset.title
) and "Nightcore" not in beatmap.beatmapset.artist:
return False return False
mods_ = mod_to_save(score.mods) mods_ = mod_to_save(score.mods)
if "DT" not in mods_ or "NC" not in mods_: if "DT" not in mods_ or "NC" not in mods_:
@@ -118,9 +116,7 @@ async def reckless_adandon(
fetcher = await get_fetcher() fetcher = await get_fetcher()
redis = get_redis() redis = get_redis()
mods_ = score.mods.copy() mods_ = score.mods.copy()
attribute = await calculate_beatmap_attributes( attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
beatmap.id, score.gamemode, mods_, redis, fetcher
)
if attribute.star_rating < 3: if attribute.star_rating < 3:
return False return False
return True return True
@@ -186,9 +182,7 @@ async def slow_and_steady(
fetcher = await get_fetcher() fetcher = await get_fetcher()
redis = get_redis() redis = get_redis()
mods_ = score.mods.copy() mods_ = score.mods.copy()
attribute = await calculate_beatmap_attributes( attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
beatmap.id, score.gamemode, mods_, redis, fetcher
)
return attribute.star_rating >= 3 return attribute.star_rating >= 3
@@ -218,9 +212,7 @@ async def sognare(
mods_ = mod_to_save(score.mods) mods_ = mod_to_save(score.mods)
if "HT" not in mods_: if "HT" not in mods_:
return False return False
return ( return beatmap.beatmapset.artist == "LeaF" and beatmap.beatmapset.title == "Evanescent"
beatmap.beatmapset.artist == "LeaF" and beatmap.beatmapset.title == "Evanescent"
)
async def realtor_extraordinaire( async def realtor_extraordinaire(
@@ -234,10 +226,7 @@ async def realtor_extraordinaire(
mods_ = mod_to_save(score.mods) mods_ = mod_to_save(score.mods)
if not ("DT" in mods_ or "NC" in mods_) or "HR" not in mods_: if not ("DT" in mods_ or "NC" in mods_) or "HR" not in mods_:
return False return False
return ( return beatmap.beatmapset.artist == "cYsmix" and beatmap.beatmapset.title == "House With Legs"
beatmap.beatmapset.artist == "cYsmix"
and beatmap.beatmapset.title == "House With Legs"
)
async def impeccable( async def impeccable(
@@ -255,9 +244,7 @@ async def impeccable(
fetcher = await get_fetcher() fetcher = await get_fetcher()
redis = get_redis() redis = get_redis()
mods_ = score.mods.copy() mods_ = score.mods.copy()
attribute = await calculate_beatmap_attributes( attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
beatmap.id, score.gamemode, mods_, redis, fetcher
)
return attribute.star_rating >= 4 return attribute.star_rating >= 4
@@ -274,18 +261,14 @@ async def aeon(
mods_ = mod_to_save(score.mods) mods_ = mod_to_save(score.mods)
if "FL" not in mods_ or "HD" not in mods_ or "HT" not in mods_: if "FL" not in mods_ or "HD" not in mods_ or "HT" not in mods_:
return False return False
if not beatmap.beatmapset.ranked_date or beatmap.beatmapset.ranked_date > datetime( if not beatmap.beatmapset.ranked_date or beatmap.beatmapset.ranked_date > datetime(2012, 1, 1):
2012, 1, 1
):
return False return False
if beatmap.total_length < 180: if beatmap.total_length < 180:
return False return False
fetcher = await get_fetcher() fetcher = await get_fetcher()
redis = get_redis() redis = get_redis()
mods_ = score.mods.copy() mods_ = score.mods.copy()
attribute = await calculate_beatmap_attributes( attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
beatmap.id, score.gamemode, mods_, redis, fetcher
)
return attribute.star_rating >= 4 return attribute.star_rating >= 4
@@ -297,10 +280,7 @@ async def quick_maths(
# Get exactly 34 misses on any difficulty of Function Phantom - Variable. # Get exactly 34 misses on any difficulty of Function Phantom - Variable.
if score.nmiss != 34: if score.nmiss != 34:
return False return False
return ( return beatmap.beatmapset.artist == "Function Phantom" and beatmap.beatmapset.title == "Variable"
beatmap.beatmapset.artist == "Function Phantom"
and beatmap.beatmapset.title == "Variable"
)
async def kaleidoscope( async def kaleidoscope(
@@ -328,8 +308,7 @@ async def valediction(
return ( return (
score.passed score.passed
and beatmap.beatmapset.artist == "a_hisa" and beatmap.beatmapset.artist == "a_hisa"
and beatmap.beatmapset.title and beatmap.beatmapset.title == "Alexithymia | Lupinus | Tokei no Heya to Seishin Sekai"
== "Alexithymia | Lupinus | Tokei no Heya to Seishin Sekai"
and score.accuracy >= 0.9 and score.accuracy >= 0.9
) )
@@ -342,9 +321,7 @@ async def right_on_time(
# Submit a score on Kola Kid - timer on the first minute of any hour # Submit a score on Kola Kid - timer on the first minute of any hour
if not score.passed: if not score.passed:
return False return False
if not ( if not (beatmap.beatmapset.artist == "Kola Kid" and beatmap.beatmapset.title == "timer"):
beatmap.beatmapset.artist == "Kola Kid" and beatmap.beatmapset.title == "timer"
):
return False return False
return score.ended_at.minute == 0 return score.ended_at.minute == 0
@@ -361,9 +338,7 @@ async def not_again(
return False return False
if score.accuracy < 0.99: if score.accuracy < 0.99:
return False return False
return ( return beatmap.beatmapset.artist == "ARForest" and beatmap.beatmapset.title == "Regret"
beatmap.beatmapset.artist == "ARForest" and beatmap.beatmapset.title == "Regret"
)
async def deliberation( async def deliberation(
@@ -377,18 +352,13 @@ async def deliberation(
mods_ = mod_to_save(score.mods) mods_ = mod_to_save(score.mods)
if "HT" not in mods_: if "HT" not in mods_:
return False return False
if ( if not beatmap.beatmap_status.has_pp() and beatmap.beatmap_status != BeatmapRankStatus.LOVED:
not beatmap.beatmap_status.has_pp()
and beatmap.beatmap_status != BeatmapRankStatus.LOVED
):
return False return False
fetcher = await get_fetcher() fetcher = await get_fetcher()
redis = get_redis() redis = get_redis()
mods_copy = score.mods.copy() mods_copy = score.mods.copy()
attribute = await calculate_beatmap_attributes( attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_copy, redis, fetcher)
beatmap.id, score.gamemode, mods_copy, redis, fetcher
)
return attribute.star_rating >= 6 return attribute.star_rating >= 6

View File

@@ -72,7 +72,7 @@ MEDALS: Medals = {
Achievement( Achievement(
id=93, id=93,
name="Sweet Rave Party", name="Sweet Rave Party",
desc="Founded in the fine tradition of changing things that were just fine as they were.", # noqa: E501 desc="Founded in the fine tradition of changing things that were just fine as they were.",
assets_id="all-intro-nightcore", assets_id="all-intro-nightcore",
): partial(process_mod, "NC"), ): partial(process_mod, "NC"),
Achievement( Achievement(

View File

@@ -16,11 +16,7 @@ async def process_combo(
score: Score, score: Score,
beatmap: Beatmap, beatmap: Beatmap,
) -> bool: ) -> bool:
if ( if not score.passed or not beatmap.beatmap_status.has_pp() or score.gamemode != GameMode.OSU:
not score.passed
or not beatmap.beatmap_status.has_pp()
or score.gamemode != GameMode.OSU
):
return False return False
if combo < 1: if combo < 1:
return False return False

View File

@@ -44,9 +44,7 @@ async def process_skill(
redis = get_redis() redis = get_redis()
mods_ = score.mods.copy() mods_ = score.mods.copy()
mods_.sort(key=lambda x: x["acronym"]) mods_.sort(key=lambda x: x["acronym"])
attribute = await calculate_beatmap_attributes( attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
beatmap.id, score.gamemode, mods_, redis, fetcher
)
if attribute.star_rating < star or attribute.star_rating >= star + 1: if attribute.star_rating < star or attribute.star_rating >= star + 1:
return False return False
if type == "fc" and not score.is_perfect_combo: if type == "fc" and not score.is_perfect_combo:

View File

@@ -43,9 +43,7 @@ def validate_username(username: str) -> list[str]:
# 检查用户名格式(只允许字母、数字、下划线、连字符) # 检查用户名格式(只允许字母、数字、下划线、连字符)
if not re.match(r"^[a-zA-Z0-9_-]+$", username): if not re.match(r"^[a-zA-Z0-9_-]+$", username):
errors.append( errors.append("Username can only contain letters, numbers, underscores, and hyphens")
"Username can only contain letters, numbers, underscores, and hyphens"
)
# 检查是否以数字开头 # 检查是否以数字开头
if username[0].isdigit(): if username[0].isdigit():
@@ -104,9 +102,7 @@ def get_password_hash(password: str) -> str:
return pw_bcrypt.decode() return pw_bcrypt.decode()
async def authenticate_user_legacy( async def authenticate_user_legacy(db: AsyncSession, name: str, password: str) -> User | None:
db: AsyncSession, name: str, password: str
) -> User | None:
""" """
验证用户身份 - 使用类似 from_login 的逻辑 验证用户身份 - 使用类似 from_login 的逻辑
""" """
@@ -145,9 +141,7 @@ async def authenticate_user_legacy(
return None return None
async def authenticate_user( async def authenticate_user(db: AsyncSession, username: str, password: str) -> User | None:
db: AsyncSession, username: str, password: str
) -> User | None:
"""验证用户身份""" """验证用户身份"""
return await authenticate_user_legacy(db, username, password) return await authenticate_user_legacy(db, username, password)
@@ -158,14 +152,10 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None) -> s
if expires_delta: if expires_delta:
expire = datetime.now(UTC) + expires_delta expire = datetime.now(UTC) + expires_delta
else: else:
expire = datetime.now(UTC) + timedelta( expire = datetime.now(UTC) + timedelta(minutes=settings.access_token_expire_minutes)
minutes=settings.access_token_expire_minutes
)
to_encode.update({"exp": expire, "random": secrets.token_hex(16)}) to_encode.update({"exp": expire, "random": secrets.token_hex(16)})
encoded_jwt = jwt.encode( encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm)
to_encode, settings.secret_key, algorithm=settings.algorithm
)
return encoded_jwt return encoded_jwt
@@ -178,20 +168,20 @@ def generate_refresh_token() -> str:
async def invalidate_user_tokens(db: AsyncSession, user_id: int) -> int: async def invalidate_user_tokens(db: AsyncSession, user_id: int) -> int:
"""使指定用户的所有令牌失效 """使指定用户的所有令牌失效
返回删除的令牌数量 返回删除的令牌数量
""" """
# 使用 select 先获取所有令牌 # 使用 select 先获取所有令牌
stmt = select(OAuthToken).where(OAuthToken.user_id == user_id) stmt = select(OAuthToken).where(OAuthToken.user_id == user_id)
result = await db.exec(stmt) result = await db.exec(stmt)
tokens = result.all() tokens = result.all()
# 逐个删除令牌 # 逐个删除令牌
count = 0 count = 0
for token in tokens: for token in tokens:
await db.delete(token) await db.delete(token)
count += 1 count += 1
# 提交更改 # 提交更改
await db.commit() await db.commit()
return count return count
@@ -200,9 +190,7 @@ async def invalidate_user_tokens(db: AsyncSession, user_id: int) -> int:
def verify_token(token: str) -> dict | None: def verify_token(token: str) -> dict | None:
"""验证访问令牌""" """验证访问令牌"""
try: try:
payload = jwt.decode( payload = jwt.decode(token, settings.secret_key, algorithms=[settings.algorithm])
token, settings.secret_key, algorithms=[settings.algorithm]
)
return payload return payload
except JWTError: except JWTError:
return None return None
@@ -221,17 +209,13 @@ async def store_token(
expires_at = datetime.utcnow() + timedelta(seconds=expires_in) expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
# 删除用户的旧令牌 # 删除用户的旧令牌
statement = select(OAuthToken).where( statement = select(OAuthToken).where(OAuthToken.user_id == user_id, OAuthToken.client_id == client_id)
OAuthToken.user_id == user_id, OAuthToken.client_id == client_id
)
old_tokens = (await db.exec(statement)).all() old_tokens = (await db.exec(statement)).all()
for token in old_tokens: for token in old_tokens:
await db.delete(token) await db.delete(token)
# 检查是否有重复的 access_token # 检查是否有重复的 access_token
duplicate_token = ( duplicate_token = (await db.exec(select(OAuthToken).where(OAuthToken.access_token == access_token))).first()
await db.exec(select(OAuthToken).where(OAuthToken.access_token == access_token))
).first()
if duplicate_token: if duplicate_token:
await db.delete(duplicate_token) await db.delete(duplicate_token)
@@ -250,9 +234,7 @@ async def store_token(
return token_record return token_record
async def get_token_by_access_token( async def get_token_by_access_token(db: AsyncSession, access_token: str) -> OAuthToken | None:
db: AsyncSession, access_token: str
) -> OAuthToken | None:
"""根据访问令牌获取令牌记录""" """根据访问令牌获取令牌记录"""
statement = select(OAuthToken).where( statement = select(OAuthToken).where(
OAuthToken.access_token == access_token, OAuthToken.access_token == access_token,
@@ -261,9 +243,7 @@ async def get_token_by_access_token(
return (await db.exec(statement)).first() return (await db.exec(statement)).first()
async def get_token_by_refresh_token( async def get_token_by_refresh_token(db: AsyncSession, refresh_token: str) -> OAuthToken | None:
db: AsyncSession, refresh_token: str
) -> OAuthToken | None:
"""根据刷新令牌获取令牌记录""" """根据刷新令牌获取令牌记录"""
statement = select(OAuthToken).where( statement = select(OAuthToken).where(
OAuthToken.refresh_token == refresh_token, OAuthToken.refresh_token == refresh_token,

View File

@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from copy import deepcopy from copy import deepcopy
from enum import Enum from enum import Enum
import math import math
@@ -67,11 +68,7 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
if settings.suspicious_score_check: if settings.suspicious_score_check:
beatmap_banned = ( beatmap_banned = (
await session.exec( await session.exec(select(exists()).where(col(BannedBeatmaps.beatmap_id) == score.beatmap_id))
select(exists()).where(
col(BannedBeatmaps.beatmap_id) == score.beatmap_id
)
)
).first() ).first()
if beatmap_banned: if beatmap_banned:
return 0 return 0
@@ -82,12 +79,9 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
logger.warning(f"Beatmap {score.beatmap_id} is suspicious, banned") logger.warning(f"Beatmap {score.beatmap_id} is suspicious, banned")
return 0 return 0
except Exception: except Exception:
logger.exception( logger.exception(f"Error checking if beatmap {score.beatmap_id} is suspicious")
f"Error checking if beatmap {score.beatmap_id} is suspicious"
)
# 使用线程池执行计算密集型操作以避免阻塞事件循环 # 使用线程池执行计算密集型操作以避免阻塞事件循环
import asyncio
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@@ -118,9 +112,7 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
pp = attrs.pp pp = attrs.pp
# mrekk bp1: 2048pp; ppy-sb top1 rxbp1: 2198pp # mrekk bp1: 2048pp; ppy-sb top1 rxbp1: 2198pp
if settings.suspicious_score_check and ( if settings.suspicious_score_check and ((attrs.difficulty.stars > 25 and score.accuracy < 0.8) or pp > 2300):
(attrs.difficulty.stars > 25 and score.accuracy < 0.8) or pp > 2300
):
logger.warning( logger.warning(
f"User {score.user_id} played {score.beatmap_id} " f"User {score.user_id} played {score.beatmap_id} "
f"(star={attrs.difficulty.stars}) with {pp=} " f"(star={attrs.difficulty.stars}) with {pp=} "
@@ -131,9 +123,7 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
return pp return pp
async def pre_fetch_and_calculate_pp( async def pre_fetch_and_calculate_pp(score: "Score", beatmap_id: int, session: AsyncSession, redis, fetcher) -> float:
score: "Score", beatmap_id: int, session: AsyncSession, redis, fetcher
) -> float:
""" """
优化版PP计算预先获取beatmap文件并使用缓存 优化版PP计算预先获取beatmap文件并使用缓存
""" """
@@ -144,9 +134,7 @@ async def pre_fetch_and_calculate_pp(
# 快速检查是否被封禁 # 快速检查是否被封禁
if settings.suspicious_score_check: if settings.suspicious_score_check:
beatmap_banned = ( beatmap_banned = (
await session.exec( await session.exec(select(exists()).where(col(BannedBeatmaps.beatmap_id) == beatmap_id))
select(exists()).where(col(BannedBeatmaps.beatmap_id) == beatmap_id)
)
).first() ).first()
if beatmap_banned: if beatmap_banned:
return 0 return 0
@@ -202,9 +190,7 @@ async def batch_calculate_pp(
banned_beatmaps = set() banned_beatmaps = set()
if settings.suspicious_score_check: if settings.suspicious_score_check:
banned_results = await session.exec( banned_results = await session.exec(
select(BannedBeatmaps.beatmap_id).where( select(BannedBeatmaps.beatmap_id).where(col(BannedBeatmaps.beatmap_id).in_(unique_beatmap_ids))
col(BannedBeatmaps.beatmap_id).in_(unique_beatmap_ids)
)
) )
banned_beatmaps = set(banned_results.all()) banned_beatmaps = set(banned_results.all())
@@ -380,9 +366,7 @@ def calculate_score_to_level(total_score: int) -> float:
level = 0.0 level = 0.0
while remaining_score > 0: while remaining_score > 0:
next_level_requirement = to_next_level[ next_level_requirement = to_next_level[min(len(to_next_level) - 1, round(level))]
min(len(to_next_level) - 1, round(level))
]
level += min(1, remaining_score / next_level_requirement) level += min(1, remaining_score / next_level_requirement)
remaining_score -= next_level_requirement remaining_score -= next_level_requirement
@@ -417,9 +401,7 @@ class Threshold(int, Enum):
NOTE_POSX_THRESHOLD = 512 # x: [-512,512] NOTE_POSX_THRESHOLD = 512 # x: [-512,512]
NOTE_POSY_THRESHOLD = 384 # y: [-384,384] NOTE_POSY_THRESHOLD = 384 # y: [-384,384]
POS_ERROR_THRESHOLD = ( POS_ERROR_THRESHOLD = 1280 * 50 # 超过这么多个物件(包括滑条控制点)的位置有问题就毙掉
1280 * 50
) # 超过这么多个物件(包括滑条控制点)的位置有问题就毙掉
SLIDER_REPEAT_THRESHOLD = 5000 SLIDER_REPEAT_THRESHOLD = 5000
@@ -469,10 +451,7 @@ def is_2b(hit_objects: list[HitObject]) -> bool:
def is_suspicious_beatmap(content: str) -> bool: def is_suspicious_beatmap(content: str) -> bool:
osufile = OsuFile(content=content.encode("utf-8")).parse_file() osufile = OsuFile(content=content.encode("utf-8")).parse_file()
if ( if osufile.hit_objects[-1].start_time - osufile.hit_objects[0].start_time > 24 * 60 * 60 * 1000:
osufile.hit_objects[-1].start_time - osufile.hit_objects[0].start_time
> 24 * 60 * 60 * 1000
):
return True return True
if osufile.mode == int(GameMode.TAIKO): if osufile.mode == int(GameMode.TAIKO):
if len(osufile.hit_objects) > Threshold.TAIKO_THRESHOLD: if len(osufile.hit_objects) > Threshold.TAIKO_THRESHOLD:

View File

@@ -124,14 +124,10 @@ class Settings(BaseSettings):
smtp_password: str = "" smtp_password: str = ""
from_email: str = "noreply@example.com" from_email: str = "noreply@example.com"
from_name: str = "osu! server" from_name: str = "osu! server"
# 邮件验证功能开关 # 邮件验证功能开关
enable_email_verification: bool = Field( enable_email_verification: bool = Field(default=True, description="是否启用邮件验证功能")
default=True, description="是否启用邮件验证功能" enable_email_sending: bool = Field(default=False, description="是否真实发送邮件False时仅模拟发送")
)
enable_email_sending: bool = Field(
default=False, description="是否真实发送邮件False时仅模拟发送"
)
# Sentry 配置 # Sentry 配置
sentry_dsn: HttpUrl | None = None sentry_dsn: HttpUrl | None = None
@@ -143,12 +139,8 @@ class Settings(BaseSettings):
geoip_update_hour: int = 2 # 每周更新的小时数0-23 geoip_update_hour: int = 2 # 每周更新的小时数0-23
# 游戏设置 # 游戏设置
enable_rx: bool = Field( enable_rx: bool = Field(default=False, validation_alias=AliasChoices("enable_rx", "enable_osu_rx"))
default=False, validation_alias=AliasChoices("enable_rx", "enable_osu_rx") enable_ap: bool = Field(default=False, validation_alias=AliasChoices("enable_ap", "enable_osu_ap"))
)
enable_ap: bool = Field(
default=False, validation_alias=AliasChoices("enable_ap", "enable_osu_ap")
)
enable_all_mods_pp: bool = False enable_all_mods_pp: bool = False
enable_supporter_for_all_users: bool = False enable_supporter_for_all_users: bool = False
enable_all_beatmap_leaderboard: bool = False enable_all_beatmap_leaderboard: bool = False
@@ -189,9 +181,7 @@ class Settings(BaseSettings):
# 存储设置 # 存储设置
storage_service: StorageServiceType = StorageServiceType.LOCAL storage_service: StorageServiceType = StorageServiceType.LOCAL
storage_settings: ( storage_settings: LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings = LocalStorageSettings()
LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings
) = LocalStorageSettings()
@field_validator("fetcher_scopes", mode="before") @field_validator("fetcher_scopes", mode="before")
def validate_fetcher_scopes(cls, v: Any) -> list[str]: def validate_fetcher_scopes(cls, v: Any) -> list[str]:
@@ -207,22 +197,13 @@ class Settings(BaseSettings):
) -> LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings: ) -> LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings:
if info.data.get("storage_service") == StorageServiceType.CLOUDFLARE_R2: if info.data.get("storage_service") == StorageServiceType.CLOUDFLARE_R2:
if not isinstance(v, CloudflareR2Settings): if not isinstance(v, CloudflareR2Settings):
raise ValueError( raise ValueError("When storage_service is 'r2', storage_settings must be CloudflareR2Settings")
"When storage_service is 'r2', "
"storage_settings must be CloudflareR2Settings"
)
elif info.data.get("storage_service") == StorageServiceType.LOCAL: elif info.data.get("storage_service") == StorageServiceType.LOCAL:
if not isinstance(v, LocalStorageSettings): if not isinstance(v, LocalStorageSettings):
raise ValueError( raise ValueError("When storage_service is 'local', storage_settings must be LocalStorageSettings")
"When storage_service is 'local', "
"storage_settings must be LocalStorageSettings"
)
elif info.data.get("storage_service") == StorageServiceType.AWS_S3: elif info.data.get("storage_service") == StorageServiceType.AWS_S3:
if not isinstance(v, AWSS3StorageSettings): if not isinstance(v, AWSS3StorageSettings):
raise ValueError( raise ValueError("When storage_service is 's3', storage_settings must be AWSS3StorageSettings")
"When storage_service is 's3', "
"storage_settings must be AWSS3StorageSettings"
)
return v return v

View File

@@ -28,18 +28,14 @@ if TYPE_CHECKING:
class UserAchievementBase(SQLModel, UTCBaseModel): class UserAchievementBase(SQLModel, UTCBaseModel):
achievement_id: int achievement_id: int
achieved_at: datetime = Field( achieved_at: datetime = Field(default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True)))
default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True))
)
class UserAchievement(UserAchievementBase, table=True): class UserAchievement(UserAchievementBase, table=True):
__tablename__ = "lazer_user_achievements" # pyright: ignore[reportAssignmentType] __tablename__: str = "lazer_user_achievements"
id: int | None = Field(default=None, primary_key=True, index=True) id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field( user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")), exclude=True)
sa_column=Column(BigInteger, ForeignKey("lazer_users.id")), exclude=True
)
user: "User" = Relationship(back_populates="achievement") user: "User" = Relationship(back_populates="achievement")
@@ -56,11 +52,7 @@ async def process_achievements(session: AsyncSession, redis: Redis, score_id: in
if not score: if not score:
return return
achieved = ( achieved = (
await session.exec( await session.exec(select(UserAchievement.achievement_id).where(UserAchievement.user_id == score.user_id))
select(UserAchievement.achievement_id).where(
UserAchievement.user_id == score.user_id
)
)
).all() ).all()
not_achieved = {k: v for k, v in MEDALS.items() if k.id not in achieved} not_achieved = {k: v for k, v in MEDALS.items() if k.id not in achieved}
result: list[Achievement] = [] result: list[Achievement] = []
@@ -78,9 +70,7 @@ async def process_achievements(session: AsyncSession, redis: Redis, score_id: in
) )
await redis.publish( await redis.publish(
"chat:notification", "chat:notification",
UserAchievementUnlock.init( UserAchievementUnlock.init(r, score.user_id, score.gamemode).model_dump_json(),
r, score.user_id, score.gamemode
).model_dump_json(),
) )
event = Event( event = Event(
created_at=now, created_at=now,

View File

@@ -20,42 +20,34 @@ if TYPE_CHECKING:
class OAuthToken(UTCBaseModel, SQLModel, table=True): class OAuthToken(UTCBaseModel, SQLModel, table=True):
__tablename__ = "oauth_tokens" # pyright: ignore[reportAssignmentType] __tablename__: str = "oauth_tokens"
id: int | None = Field(default=None, primary_key=True, index=True) id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field( user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
client_id: int = Field(index=True) client_id: int = Field(index=True)
access_token: str = Field(max_length=500, unique=True) access_token: str = Field(max_length=500, unique=True)
refresh_token: str = Field(max_length=500, unique=True) refresh_token: str = Field(max_length=500, unique=True)
token_type: str = Field(default="Bearer", max_length=20) token_type: str = Field(default="Bearer", max_length=20)
scope: str = Field(default="*", max_length=100) scope: str = Field(default="*", max_length=100)
expires_at: datetime = Field(sa_column=Column(DateTime)) expires_at: datetime = Field(sa_column=Column(DateTime))
created_at: datetime = Field( created_at: datetime = Field(default_factory=datetime.utcnow, sa_column=Column(DateTime))
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: "User" = Relationship() user: "User" = Relationship()
class OAuthClient(SQLModel, table=True): class OAuthClient(SQLModel, table=True):
__tablename__ = "oauth_clients" # pyright: ignore[reportAssignmentType] __tablename__: str = "oauth_clients"
name: str = Field(max_length=100, index=True) name: str = Field(max_length=100, index=True)
description: str = Field(sa_column=Column(Text), default="") description: str = Field(sa_column=Column(Text), default="")
client_id: int | None = Field(default=None, primary_key=True, index=True) client_id: int | None = Field(default=None, primary_key=True, index=True)
client_secret: str = Field(default_factory=secrets.token_hex, index=True) client_secret: str = Field(default_factory=secrets.token_hex, index=True)
redirect_uris: list[str] = Field(default_factory=list, sa_column=Column(JSON)) redirect_uris: list[str] = Field(default_factory=list, sa_column=Column(JSON))
owner_id: int = Field( owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
class V1APIKeys(SQLModel, table=True): class V1APIKeys(SQLModel, table=True):
__tablename__ = "v1_api_keys" # pyright: ignore[reportAssignmentType] __tablename__: str = "v1_api_keys"
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
name: str = Field(max_length=100, index=True) name: str = Field(max_length=100, index=True)
key: str = Field(default_factory=secrets.token_hex, index=True) key: str = Field(default_factory=secrets.token_hex, index=True)
owner_id: int = Field( owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)

View File

@@ -60,17 +60,13 @@ class BeatmapBase(SQLModel):
class Beatmap(BeatmapBase, table=True): class Beatmap(BeatmapBase, table=True):
__tablename__ = "beatmaps" # pyright: ignore[reportAssignmentType] __tablename__: str = "beatmaps"
id: int = Field(primary_key=True, index=True) id: int = Field(primary_key=True, index=True)
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True) beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
beatmap_status: BeatmapRankStatus = Field(index=True) beatmap_status: BeatmapRankStatus = Field(index=True)
# optional # optional
beatmapset: Beatmapset = Relationship( beatmapset: Beatmapset = Relationship(back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"})
back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"} failtimes: FailTime | None = Relationship(back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"})
)
failtimes: FailTime | None = Relationship(
back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"}
)
@classmethod @classmethod
async def from_resp(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap": async def from_resp(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
@@ -84,21 +80,15 @@ class Beatmap(BeatmapBase, table=True):
"beatmap_status": BeatmapRankStatus(resp.ranked), "beatmap_status": BeatmapRankStatus(resp.ranked),
} }
) )
if not ( if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
await session.exec(select(exists()).where(Beatmap.id == resp.id))
).first():
session.add(beatmap) session.add(beatmap)
await session.commit() await session.commit()
beatmap = ( beatmap = (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).first()
await session.exec(select(Beatmap).where(Beatmap.id == resp.id))
).first()
assert beatmap is not None, "Beatmap should not be None after commit" assert beatmap is not None, "Beatmap should not be None after commit"
return beatmap return beatmap
@classmethod @classmethod
async def from_resp_batch( async def from_resp_batch(cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0) -> list["Beatmap"]:
cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0
) -> list["Beatmap"]:
beatmaps = [] beatmaps = []
for resp in inp: for resp in inp:
if resp.id == from_: if resp.id == from_:
@@ -113,9 +103,7 @@ class Beatmap(BeatmapBase, table=True):
"beatmap_status": BeatmapRankStatus(resp.ranked), "beatmap_status": BeatmapRankStatus(resp.ranked),
} }
) )
if not ( if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
await session.exec(select(exists()).where(Beatmap.id == resp.id))
).first():
session.add(beatmap) session.add(beatmap)
beatmaps.append(beatmap) beatmaps.append(beatmap)
await session.commit() await session.commit()
@@ -130,17 +118,11 @@ class Beatmap(BeatmapBase, table=True):
md5: str | None = None, md5: str | None = None,
) -> "Beatmap": ) -> "Beatmap":
beatmap = ( beatmap = (
await session.exec( await session.exec(select(Beatmap).where(Beatmap.id == bid if bid is not None else Beatmap.checksum == md5))
select(Beatmap).where(
Beatmap.id == bid if bid is not None else Beatmap.checksum == md5
)
)
).first() ).first()
if not beatmap: if not beatmap:
resp = await fetcher.get_beatmap(bid, md5) resp = await fetcher.get_beatmap(bid, md5)
r = await session.exec( r = await session.exec(select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id))
select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id)
)
if not r.first(): if not r.first():
set_resp = await fetcher.get_beatmapset(resp.beatmapset_id) set_resp = await fetcher.get_beatmapset(resp.beatmapset_id)
await Beatmapset.from_resp(session, set_resp, from_=resp.id) await Beatmapset.from_resp(session, set_resp, from_=resp.id)
@@ -178,10 +160,7 @@ class BeatmapResp(BeatmapBase):
if query_mode is not None and beatmap.mode != query_mode: if query_mode is not None and beatmap.mode != query_mode:
beatmap_["convert"] = True beatmap_["convert"] = True
beatmap_["is_scoreable"] = beatmap_status.has_leaderboard() beatmap_["is_scoreable"] = beatmap_status.has_leaderboard()
if ( if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
settings.enable_all_beatmap_leaderboard
and not beatmap_status.has_leaderboard()
):
beatmap_["ranked"] = BeatmapRankStatus.APPROVED.value beatmap_["ranked"] = BeatmapRankStatus.APPROVED.value
beatmap_["status"] = BeatmapRankStatus.APPROVED.name.lower() beatmap_["status"] = BeatmapRankStatus.APPROVED.name.lower()
else: else:
@@ -189,9 +168,7 @@ class BeatmapResp(BeatmapBase):
beatmap_["ranked"] = beatmap_status.value beatmap_["ranked"] = beatmap_status.value
beatmap_["mode_int"] = int(beatmap.mode) beatmap_["mode_int"] = int(beatmap.mode)
if not from_set: if not from_set:
beatmap_["beatmapset"] = await BeatmapsetResp.from_db( beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset, session=session, user=user)
beatmap.beatmapset, session=session, user=user
)
if beatmap.failtimes is not None: if beatmap.failtimes is not None:
beatmap_["failtimes"] = FailTimeResp.from_db(beatmap.failtimes) beatmap_["failtimes"] = FailTimeResp.from_db(beatmap.failtimes)
else: else:
@@ -218,7 +195,7 @@ class BeatmapResp(BeatmapBase):
class BannedBeatmaps(SQLModel, table=True): class BannedBeatmaps(SQLModel, table=True):
__tablename__ = "banned_beatmaps" # pyright: ignore[reportAssignmentType] __tablename__: str = "banned_beatmaps"
id: int | None = Field(primary_key=True, index=True, default=None) id: int | None = Field(primary_key=True, index=True, default=None)
beatmap_id: int = Field(index=True) beatmap_id: int = Field(index=True)
@@ -230,15 +207,10 @@ async def calculate_beatmap_attributes(
redis: Redis, redis: Redis,
fetcher: "Fetcher", fetcher: "Fetcher",
): ):
key = ( key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
f"beatmap:{beatmap_id}:{ruleset}:"
f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
)
if await redis.exists(key): if await redis.exists(key):
return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType] return BeatmapAttributes.model_validate_json(await redis.get(key))
resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id) resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
attr = await asyncio.get_event_loop().run_in_executor( attr = await asyncio.get_event_loop().run_in_executor(None, calculate_beatmap_attribute, resp, ruleset, mods_)
None, calculate_beatmap_attribute, resp, ruleset, mods_
)
await redis.set(key, attr.model_dump_json()) await redis.set(key, attr.model_dump_json())
return attr return attr

View File

@@ -23,15 +23,13 @@ if TYPE_CHECKING:
class BeatmapPlaycounts(AsyncAttrs, SQLModel, table=True): class BeatmapPlaycounts(AsyncAttrs, SQLModel, table=True):
__tablename__ = "beatmap_playcounts" # pyright: ignore[reportAssignmentType] __tablename__: str = "beatmap_playcounts"
id: int | None = Field( id: int | None = Field(
default=None, default=None,
sa_column=Column(BigInteger, primary_key=True, autoincrement=True), sa_column=Column(BigInteger, primary_key=True, autoincrement=True),
) )
user_id: int = Field( user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True) beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
playcount: int = Field(default=0) playcount: int = Field(default=0)
@@ -59,9 +57,7 @@ class BeatmapPlaycountsResp(BaseModel):
) )
async def process_beatmap_playcount( async def process_beatmap_playcount(session: AsyncSession, user_id: int, beatmap_id: int) -> None:
session: AsyncSession, user_id: int, beatmap_id: int
) -> None:
existing_playcount = ( existing_playcount = (
await session.exec( await session.exec(
select(BeatmapPlaycounts).where( select(BeatmapPlaycounts).where(
@@ -89,7 +85,5 @@ async def process_beatmap_playcount(
} }
session.add(playcount_event) session.add(playcount_event)
else: else:
new_playcount = BeatmapPlaycounts( new_playcount = BeatmapPlaycounts(user_id=user_id, beatmap_id=beatmap_id, playcount=1)
user_id=user_id, beatmap_id=beatmap_id, playcount=1
)
session.add(new_playcount) session.add(new_playcount)

View File

@@ -86,9 +86,7 @@ class BeatmapsetBase(SQLModel):
# optional # optional
# converts: list[Beatmap] = Relationship(back_populates="beatmapset") # converts: list[Beatmap] = Relationship(back_populates="beatmapset")
current_nominations: list[BeatmapNomination] | None = Field( current_nominations: list[BeatmapNomination] | None = Field(None, sa_column=Column(JSON))
None, sa_column=Column(JSON)
)
description: BeatmapDescription | None = Field(default=None, sa_column=Column(JSON)) description: BeatmapDescription | None = Field(default=None, sa_column=Column(JSON))
# TODO: discussions: list[BeatmapsetDiscussion] = None # TODO: discussions: list[BeatmapsetDiscussion] = None
# TODO: current_user_attributes: Optional[CurrentUserAttributes] = None # TODO: current_user_attributes: Optional[CurrentUserAttributes] = None
@@ -105,22 +103,18 @@ class BeatmapsetBase(SQLModel):
can_be_hyped: bool = Field(default=False, sa_column=Column(Boolean)) can_be_hyped: bool = Field(default=False, sa_column=Column(Boolean))
discussion_locked: bool = Field(default=False, sa_column=Column(Boolean)) discussion_locked: bool = Field(default=False, sa_column=Column(Boolean))
last_updated: datetime = Field(sa_column=Column(DateTime, index=True)) last_updated: datetime = Field(sa_column=Column(DateTime, index=True))
ranked_date: datetime | None = Field( ranked_date: datetime | None = Field(default=None, sa_column=Column(DateTime, index=True))
default=None, sa_column=Column(DateTime, index=True)
)
storyboard: bool = Field(default=False, sa_column=Column(Boolean, index=True)) storyboard: bool = Field(default=False, sa_column=Column(Boolean, index=True))
submitted_date: datetime = Field(sa_column=Column(DateTime, index=True)) submitted_date: datetime = Field(sa_column=Column(DateTime, index=True))
tags: str = Field(default="", sa_column=Column(Text)) tags: str = Field(default="", sa_column=Column(Text))
class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True): class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
__tablename__ = "beatmapsets" # pyright: ignore[reportAssignmentType] __tablename__: str = "beatmapsets"
id: int | None = Field(default=None, primary_key=True, index=True) id: int = Field(default=None, primary_key=True, index=True)
# Beatmapset # Beatmapset
beatmap_status: BeatmapRankStatus = Field( beatmap_status: BeatmapRankStatus = Field(default=BeatmapRankStatus.GRAVEYARD, index=True)
default=BeatmapRankStatus.GRAVEYARD, index=True
)
# optional # optional
beatmaps: list["Beatmap"] = Relationship(back_populates="beatmapset") beatmaps: list["Beatmap"] = Relationship(back_populates="beatmapset")
@@ -137,9 +131,7 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset") favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset")
@classmethod @classmethod
async def from_resp( async def from_resp(cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0) -> "Beatmapset":
cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0
) -> "Beatmapset":
from .beatmap import Beatmap from .beatmap import Beatmap
d = resp.model_dump() d = resp.model_dump()
@@ -167,18 +159,14 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
"download_disabled": resp.availability.download_disabled or False, "download_disabled": resp.availability.download_disabled or False,
} }
) )
if not ( if not (await session.exec(select(exists()).where(Beatmapset.id == resp.id))).first():
await session.exec(select(exists()).where(Beatmapset.id == resp.id))
).first():
session.add(beatmapset) session.add(beatmapset)
await session.commit() await session.commit()
await Beatmap.from_resp_batch(session, resp.beatmaps, from_=from_) await Beatmap.from_resp_batch(session, resp.beatmaps, from_=from_)
return beatmapset return beatmapset
@classmethod @classmethod
async def get_or_fetch( async def get_or_fetch(cls, session: AsyncSession, fetcher: "Fetcher", sid: int) -> "Beatmapset":
cls, session: AsyncSession, fetcher: "Fetcher", sid: int
) -> "Beatmapset":
beatmapset = await session.get(Beatmapset, sid) beatmapset = await session.get(Beatmapset, sid)
if not beatmapset: if not beatmapset:
resp = await fetcher.get_beatmapset(sid) resp = await fetcher.get_beatmapset(sid)
@@ -227,13 +215,9 @@ class BeatmapsetResp(BeatmapsetBase):
@model_validator(mode="after") @model_validator(mode="after")
def fix_genre_language(self) -> Self: def fix_genre_language(self) -> Self:
if self.genre is None: if self.genre is None:
self.genre = BeatmapTranslationText( self.genre = BeatmapTranslationText(name=Genre(self.genre_id).name, id=self.genre_id)
name=Genre(self.genre_id).name, id=self.genre_id
)
if self.language is None: if self.language is None:
self.language = BeatmapTranslationText( self.language = BeatmapTranslationText(name=Language(self.language_id).name, id=self.language_id)
name=Language(self.language_id).name, id=self.language_id
)
return self return self
@classmethod @classmethod
@@ -252,9 +236,7 @@ class BeatmapsetResp(BeatmapsetBase):
await BeatmapResp.from_db(beatmap, from_set=True, session=session) await BeatmapResp.from_db(beatmap, from_set=True, session=session)
for beatmap in await beatmapset.awaitable_attrs.beatmaps for beatmap in await beatmapset.awaitable_attrs.beatmaps
], ],
"hype": BeatmapHype( "hype": BeatmapHype(current=beatmapset.hype_current, required=beatmapset.hype_required),
current=beatmapset.hype_current, required=beatmapset.hype_required
),
"availability": BeatmapAvailability( "availability": BeatmapAvailability(
more_information=beatmapset.availability_info, more_information=beatmapset.availability_info,
download_disabled=beatmapset.download_disabled, download_disabled=beatmapset.download_disabled,
@@ -282,10 +264,7 @@ class BeatmapsetResp(BeatmapsetBase):
update["ratings"] = [] update["ratings"] = []
beatmap_status = beatmapset.beatmap_status beatmap_status = beatmapset.beatmap_status
if ( if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
settings.enable_all_beatmap_leaderboard
and not beatmap_status.has_leaderboard()
):
update["status"] = BeatmapRankStatus.APPROVED.name.lower() update["status"] = BeatmapRankStatus.APPROVED.name.lower()
update["ranked"] = BeatmapRankStatus.APPROVED.value update["ranked"] = BeatmapRankStatus.APPROVED.value
else: else:
@@ -295,9 +274,7 @@ class BeatmapsetResp(BeatmapsetBase):
if session and user: if session and user:
existing_favourite = ( existing_favourite = (
await session.exec( await session.exec(
select(FavouriteBeatmapset).where( select(FavouriteBeatmapset).where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
FavouriteBeatmapset.beatmapset_id == beatmapset.id
)
) )
).first() ).first()
update["has_favourited"] = existing_favourite is not None update["has_favourited"] = existing_favourite is not None

View File

@@ -20,13 +20,9 @@ if TYPE_CHECKING:
class BestScore(SQLModel, table=True): class BestScore(SQLModel, table=True):
__tablename__ = "total_score_best_scores" # pyright: ignore[reportAssignmentType] __tablename__: str = "total_score_best_scores"
user_id: int = Field( user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) score_id: int = Field(sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True))
)
score_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
)
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True) beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
gamemode: GameMode = Field(index=True) gamemode: GameMode = Field(index=True)
total_score: int = Field(default=0, sa_column=Column(BigInteger)) total_score: int = Field(default=0, sa_column=Column(BigInteger))

View File

@@ -51,30 +51,22 @@ class ChatChannelBase(SQLModel):
class ChatChannel(ChatChannelBase, table=True): class ChatChannel(ChatChannelBase, table=True):
__tablename__ = "chat_channels" # pyright: ignore[reportAssignmentType] __tablename__: str = "chat_channels"
channel_id: int | None = Field(primary_key=True, index=True, default=None) channel_id: int = Field(primary_key=True, index=True, default=None)
@classmethod @classmethod
async def get( async def get(cls, channel: str | int, session: AsyncSession) -> "ChatChannel | None":
cls, channel: str | int, session: AsyncSession
) -> "ChatChannel | None":
if isinstance(channel, int) or channel.isdigit(): if isinstance(channel, int) or channel.isdigit():
# 使用查询而不是 get() 来确保对象完全加载 # 使用查询而不是 get() 来确保对象完全加载
result = await session.exec( result = await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
channel_ = result.first() channel_ = result.first()
if channel_ is not None: if channel_ is not None:
return channel_ return channel_
result = await session.exec( result = await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
select(ChatChannel).where(ChatChannel.name == channel)
)
return result.first() return result.first()
@classmethod @classmethod
async def get_pm_channel( async def get_pm_channel(cls, user1: int, user2: int, session: AsyncSession) -> "ChatChannel | None":
cls, user1: int, user2: int, session: AsyncSession
) -> "ChatChannel | None":
channel = await cls.get(f"pm_{user1}_{user2}", session) channel = await cls.get(f"pm_{user1}_{user2}", session)
if channel is None: if channel is None:
channel = await cls.get(f"pm_{user2}_{user1}", session) channel = await cls.get(f"pm_{user2}_{user1}", session)
@@ -153,18 +145,13 @@ class ChatChannelResp(ChatChannelBase):
.limit(10) .limit(10)
) )
).all() ).all()
c.recent_messages = [ c.recent_messages = [await ChatMessageResp.from_db(msg, session, user) for msg in messages]
await ChatMessageResp.from_db(msg, session, user) for msg in messages
]
c.recent_messages.reverse() c.recent_messages.reverse()
if c.type == ChannelType.PM and users and len(users) == 2: if c.type == ChannelType.PM and users and len(users) == 2:
target_user_id = next(u for u in users if u != user.id) target_user_id = next(u for u in users if u != user.id)
target_name = await session.exec( target_name = await session.exec(select(User.username).where(User.id == target_user_id))
select(User.username).where(User.id == target_user_id)
)
c.name = target_name.one() c.name = target_name.one()
assert user.id
c.users = [target_user_id, user.id] c.users = [target_user_id, user.id]
return c return c
@@ -181,19 +168,15 @@ class MessageType(str, Enum):
class ChatMessageBase(UTCBaseModel, SQLModel): class ChatMessageBase(UTCBaseModel, SQLModel):
channel_id: int = Field(index=True, foreign_key="chat_channels.channel_id") channel_id: int = Field(index=True, foreign_key="chat_channels.channel_id")
content: str = Field(sa_column=Column(VARCHAR(1000))) content: str = Field(sa_column=Column(VARCHAR(1000)))
message_id: int | None = Field(index=True, primary_key=True, default=None) message_id: int = Field(index=True, primary_key=True, default=None)
sender_id: int = Field( sender_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) timestamp: datetime = Field(sa_column=Column(DateTime, index=True), default=datetime.now(UTC))
)
timestamp: datetime = Field(
sa_column=Column(DateTime, index=True), default=datetime.now(UTC)
)
type: MessageType = Field(default=MessageType.PLAIN, index=True, exclude=True) type: MessageType = Field(default=MessageType.PLAIN, index=True, exclude=True)
uuid: str | None = Field(default=None) uuid: str | None = Field(default=None)
class ChatMessage(ChatMessageBase, table=True): class ChatMessage(ChatMessageBase, table=True):
__tablename__ = "chat_messages" # pyright: ignore[reportAssignmentType] __tablename__: str = "chat_messages"
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"}) user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
channel: ChatChannel = Relationship() channel: ChatChannel = Relationship()
@@ -211,9 +194,7 @@ class ChatMessageResp(ChatMessageBase):
if user: if user:
m.sender = await UserResp.from_db(user, session, RANKING_INCLUDES) m.sender = await UserResp.from_db(user, session, RANKING_INCLUDES)
else: else:
m.sender = await UserResp.from_db( m.sender = await UserResp.from_db(db_message.user, session, RANKING_INCLUDES)
db_message.user, session, RANKING_INCLUDES
)
return m return m
@@ -221,17 +202,13 @@ class ChatMessageResp(ChatMessageBase):
class SilenceUser(UTCBaseModel, SQLModel, table=True): class SilenceUser(UTCBaseModel, SQLModel, table=True):
__tablename__ = "chat_silence_users" # pyright: ignore[reportAssignmentType] __tablename__: str = "chat_silence_users"
id: int | None = Field(primary_key=True, default=None, index=True) id: int = Field(primary_key=True, default=None, index=True)
user_id: int = Field( user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
channel_id: int = Field(foreign_key="chat_channels.channel_id", index=True) channel_id: int = Field(foreign_key="chat_channels.channel_id", index=True)
until: datetime | None = Field(sa_column=Column(DateTime, index=True), default=None) until: datetime | None = Field(sa_column=Column(DateTime, index=True), default=None)
reason: str | None = Field(default=None, sa_column=Column(VARCHAR(255), index=True)) reason: str | None = Field(default=None, sa_column=Column(VARCHAR(255), index=True))
banned_at: datetime = Field( banned_at: datetime = Field(sa_column=Column(DateTime, index=True), default=datetime.now(UTC))
sa_column=Column(DateTime, index=True), default=datetime.now(UTC)
)
class UserSilenceResp(SQLModel): class UserSilenceResp(SQLModel):
@@ -240,7 +217,6 @@ class UserSilenceResp(SQLModel):
@classmethod @classmethod
def from_db(cls, db_silence: SilenceUser) -> "UserSilenceResp": def from_db(cls, db_silence: SilenceUser) -> "UserSilenceResp":
assert db_silence.id is not None
return cls( return cls(
id=db_silence.id, id=db_silence.id,
user_id=db_silence.user_id, user_id=db_silence.user_id,

View File

@@ -21,28 +21,24 @@ class CountBase(SQLModel):
class MonthlyPlaycounts(CountBase, table=True): class MonthlyPlaycounts(CountBase, table=True):
__tablename__ = "monthly_playcounts" # pyright: ignore[reportAssignmentType] __tablename__: str = "monthly_playcounts"
id: int | None = Field( id: int | None = Field(
default=None, default=None,
sa_column=Column(BigInteger, primary_key=True, autoincrement=True), sa_column=Column(BigInteger, primary_key=True, autoincrement=True),
) )
user_id: int = Field( user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
user: "User" = Relationship(back_populates="monthly_playcounts") user: "User" = Relationship(back_populates="monthly_playcounts")
class ReplayWatchedCount(CountBase, table=True): class ReplayWatchedCount(CountBase, table=True):
__tablename__ = "replays_watched_counts" # pyright: ignore[reportAssignmentType] __tablename__: str = "replays_watched_counts"
id: int | None = Field( id: int | None = Field(
default=None, default=None,
sa_column=Column(BigInteger, primary_key=True, autoincrement=True), sa_column=Column(BigInteger, primary_key=True, autoincrement=True),
) )
user_id: int = Field( user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
user: "User" = Relationship(back_populates="replays_watched_counts") user: "User" = Relationship(back_populates="replays_watched_counts")

View File

@@ -24,9 +24,7 @@ class DailyChallengeStatsBase(SQLModel, UTCBaseModel):
daily_streak_best: int = Field(default=0) daily_streak_best: int = Field(default=0)
daily_streak_current: int = Field(default=0) daily_streak_current: int = Field(default=0)
last_update: datetime | None = Field(default=None, sa_column=Column(DateTime)) last_update: datetime | None = Field(default=None, sa_column=Column(DateTime))
last_weekly_streak: datetime | None = Field( last_weekly_streak: datetime | None = Field(default=None, sa_column=Column(DateTime))
default=None, sa_column=Column(DateTime)
)
playcount: int = Field(default=0) playcount: int = Field(default=0)
top_10p_placements: int = Field(default=0) top_10p_placements: int = Field(default=0)
top_50p_placements: int = Field(default=0) top_50p_placements: int = Field(default=0)
@@ -35,7 +33,7 @@ class DailyChallengeStatsBase(SQLModel, UTCBaseModel):
class DailyChallengeStats(DailyChallengeStatsBase, table=True): class DailyChallengeStats(DailyChallengeStatsBase, table=True):
__tablename__ = "daily_challenge_stats" # pyright: ignore[reportAssignmentType] __tablename__: str = "daily_challenge_stats"
user_id: int | None = Field( user_id: int | None = Field(
default=None, default=None,
@@ -61,9 +59,7 @@ class DailyChallengeStatsResp(DailyChallengeStatsBase):
return cls.model_validate(obj) return cls.model_validate(obj)
async def process_daily_challenge_score( async def process_daily_challenge_score(session: AsyncSession, user_id: int, room_id: int):
session: AsyncSession, user_id: int, room_id: int
):
from .playlist_best_score import PlaylistBestScore from .playlist_best_score import PlaylistBestScore
score = ( score = (

View File

@@ -4,16 +4,17 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime, UTC from datetime import UTC, datetime
from sqlmodel import SQLModel, Field
from sqlalchemy import Column, BigInteger, ForeignKey from sqlalchemy import BigInteger, Column, ForeignKey
from sqlmodel import Field, SQLModel
class EmailVerification(SQLModel, table=True): class EmailVerification(SQLModel, table=True):
"""邮件验证记录""" """邮件验证记录"""
__tablename__: str = "email_verifications" __tablename__: str = "email_verifications"
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True)) user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True))
email: str = Field(index=True) email: str = Field(index=True)
@@ -28,9 +29,9 @@ class EmailVerification(SQLModel, table=True):
class LoginSession(SQLModel, table=True): class LoginSession(SQLModel, table=True):
"""登录会话记录""" """登录会话记录"""
__tablename__: str = "login_sessions" __tablename__: str = "login_sessions"
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True)) user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True))
session_token: str = Field(unique=True, index=True) # 会话令牌 session_token: str = Field(unique=True, index=True) # 会话令牌

View File

@@ -36,17 +36,13 @@ class EventType(str, Enum):
class EventBase(SQLModel): class EventBase(SQLModel):
id: int = Field(default=None, primary_key=True) id: int = Field(default=None, primary_key=True)
created_at: datetime = Field( created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), default=datetime.now(UTC)))
sa_column=Column(DateTime(timezone=True), default=datetime.now(UTC))
)
type: EventType type: EventType
event_payload: dict = Field( event_payload: dict = Field(exclude=True, default_factory=dict, sa_column=Column(JSON))
exclude=True, default_factory=dict, sa_column=Column(JSON)
)
class Event(EventBase, table=True): class Event(EventBase, table=True):
__tablename__ = "user_events" # pyright: ignore[reportAssignmentType] __tablename__: str = "user_events"
user_id: int | None = Field( user_id: int | None = Field(
default=None, default=None,
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True), sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True),

View File

@@ -16,8 +16,8 @@ FAILTIME_STRUCT = Struct("<100i")
class FailTime(SQLModel, table=True): class FailTime(SQLModel, table=True):
__tablename__ = "failtime" # pyright: ignore[reportAssignmentType] __tablename__: str = "failtime"
beatmap_id: int = Field(primary_key=True, index=True, foreign_key="beatmaps.id") beatmap_id: int = Field(primary_key=True, foreign_key="beatmaps.id")
exit: bytes = Field(sa_column=Column(VARBINARY(400), nullable=False)) exit: bytes = Field(sa_column=Column(VARBINARY(400), nullable=False))
fail: bytes = Field(sa_column=Column(VARBINARY(400), nullable=False)) fail: bytes = Field(sa_column=Column(VARBINARY(400), nullable=False))
@@ -41,12 +41,8 @@ class FailTime(SQLModel, table=True):
class FailTimeResp(BaseModel): class FailTimeResp(BaseModel):
exit: list[int] = Field( exit: list[int] = Field(default_factory=lambda: list(FAILTIME_STRUCT.unpack(b"\x00" * 400)))
default_factory=lambda: list(FAILTIME_STRUCT.unpack(b"\x00" * 400)) fail: list[int] = Field(default_factory=lambda: list(FAILTIME_STRUCT.unpack(b"\x00" * 400)))
)
fail: list[int] = Field(
default_factory=lambda: list(FAILTIME_STRUCT.unpack(b"\x00" * 400))
)
@classmethod @classmethod
def from_db(cls, failtime: FailTime) -> "FailTimeResp": def from_db(cls, failtime: FailTime) -> "FailTimeResp":

View File

@@ -16,7 +16,7 @@ from sqlmodel import (
class FavouriteBeatmapset(AsyncAttrs, SQLModel, table=True): class FavouriteBeatmapset(AsyncAttrs, SQLModel, table=True):
__tablename__ = "favourite_beatmapset" # pyright: ignore[reportAssignmentType] __tablename__: str = "favourite_beatmapset"
id: int | None = Field( id: int | None = Field(
default=None, default=None,
sa_column=Column(BigInteger, autoincrement=True, primary_key=True), sa_column=Column(BigInteger, autoincrement=True, primary_key=True),

View File

@@ -75,9 +75,7 @@ class UserBase(UTCBaseModel, SQLModel):
is_active: bool = True is_active: bool = True
is_bot: bool = False is_bot: bool = False
is_supporter: bool = False is_supporter: bool = False
last_visit: datetime | None = Field( last_visit: datetime | None = Field(default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True)))
default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True))
)
pm_friends_only: bool = False pm_friends_only: bool = False
profile_colour: str | None = None profile_colour: str | None = None
username: str = Field(max_length=32, unique=True, index=True) username: str = Field(max_length=32, unique=True, index=True)
@@ -90,9 +88,7 @@ class UserBase(UTCBaseModel, SQLModel):
is_restricted: bool = False is_restricted: bool = False
# blocks # blocks
cover: UserProfileCover = Field( cover: UserProfileCover = Field(
default=UserProfileCover( default=UserProfileCover(url="https://assets.ppy.sh/user-profile-covers/default.jpeg"),
url="https://assets.ppy.sh/user-profile-covers/default.jpeg"
),
sa_column=Column(JSON), sa_column=Column(JSON),
) )
beatmap_playcounts_count: int = 0 beatmap_playcounts_count: int = 0
@@ -150,9 +146,9 @@ class UserBase(UTCBaseModel, SQLModel):
class User(AsyncAttrs, UserBase, table=True): class User(AsyncAttrs, UserBase, table=True):
__tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType] __tablename__: str = "lazer_users"
id: int | None = Field( id: int = Field(
default=None, default=None,
sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True), sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True),
) )
@@ -160,16 +156,10 @@ class User(AsyncAttrs, UserBase, table=True):
statistics: list[UserStatistics] = Relationship() statistics: list[UserStatistics] = Relationship()
achievement: list[UserAchievement] = Relationship(back_populates="user") achievement: list[UserAchievement] = Relationship(back_populates="user")
team_membership: TeamMember | None = Relationship(back_populates="user") team_membership: TeamMember | None = Relationship(back_populates="user")
daily_challenge_stats: DailyChallengeStats | None = Relationship( daily_challenge_stats: DailyChallengeStats | None = Relationship(back_populates="user")
back_populates="user"
)
monthly_playcounts: list[MonthlyPlaycounts] = Relationship(back_populates="user") monthly_playcounts: list[MonthlyPlaycounts] = Relationship(back_populates="user")
replays_watched_counts: list[ReplayWatchedCount] = Relationship( replays_watched_counts: list[ReplayWatchedCount] = Relationship(back_populates="user")
back_populates="user" favourite_beatmapsets: list["FavouriteBeatmapset"] = Relationship(back_populates="user")
)
favourite_beatmapsets: list["FavouriteBeatmapset"] = Relationship(
back_populates="user"
)
rank_history: list[RankHistory] = Relationship( rank_history: list[RankHistory] = Relationship(
back_populates="user", back_populates="user",
) )
@@ -178,16 +168,10 @@ class User(AsyncAttrs, UserBase, table=True):
email: str = Field(max_length=254, unique=True, index=True, exclude=True) email: str = Field(max_length=254, unique=True, index=True, exclude=True)
priv: int = Field(default=1, exclude=True) priv: int = Field(default=1, exclude=True)
pw_bcrypt: str = Field(max_length=60, exclude=True) pw_bcrypt: str = Field(max_length=60, exclude=True)
silence_end_at: datetime | None = Field( silence_end_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True)), exclude=True)
default=None, sa_column=Column(DateTime(timezone=True)), exclude=True donor_end_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True)), exclude=True)
)
donor_end_at: datetime | None = Field(
default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
)
async def is_user_can_pm( async def is_user_can_pm(self, from_user: "User", session: AsyncSession) -> tuple[bool, str]:
self, from_user: "User", session: AsyncSession
) -> tuple[bool, str]:
from .relationship import Relationship, RelationshipType from .relationship import Relationship, RelationshipType
from_relationship = ( from_relationship = (
@@ -200,13 +184,10 @@ class User(AsyncAttrs, UserBase, table=True):
).first() ).first()
if from_relationship and from_relationship.type == RelationshipType.BLOCK: if from_relationship and from_relationship.type == RelationshipType.BLOCK:
return False, "You have blocked the target user." return False, "You have blocked the target user."
if from_user.pm_friends_only and ( if from_user.pm_friends_only and (not from_relationship or from_relationship.type != RelationshipType.FOLLOW):
not from_relationship or from_relationship.type != RelationshipType.FOLLOW
):
return ( return (
False, False,
"You have disabled non-friend communications " "You have disabled non-friend communications and target user is not your friend.",
"and target user is not your friend.",
) )
relationship = ( relationship = (
@@ -219,9 +200,7 @@ class User(AsyncAttrs, UserBase, table=True):
).first() ).first()
if relationship and relationship.type == RelationshipType.BLOCK: if relationship and relationship.type == RelationshipType.BLOCK:
return False, "Target user has blocked you." return False, "Target user has blocked you."
if self.pm_friends_only and ( if self.pm_friends_only and (not relationship or relationship.type != RelationshipType.FOLLOW):
not relationship or relationship.type != RelationshipType.FOLLOW
):
return False, "Target user has disabled non-friend communications" return False, "Target user has disabled non-friend communications"
return True, "" return True, ""
@@ -288,9 +267,7 @@ class UserResp(UserBase):
u = cls.model_validate(obj.model_dump()) u = cls.model_validate(obj.model_dump())
u.id = obj.id u.id = obj.id
u.default_group = "bot" if u.is_bot else "default" u.default_group = "bot" if u.is_bot else "default"
u.country = Country( u.country = Country(code=obj.country_code, name=COUNTRIES.get(obj.country_code, "Unknown"))
code=obj.country_code, name=COUNTRIES.get(obj.country_code, "Unknown")
)
u.follower_count = ( u.follower_count = (
await session.exec( await session.exec(
select(func.count()) select(func.count())
@@ -314,9 +291,7 @@ class UserResp(UserBase):
redis = get_redis() redis = get_redis()
u.is_online = await redis.exists(f"metadata:online:{obj.id}") u.is_online = await redis.exists(f"metadata:online:{obj.id}")
u.cover_url = ( u.cover_url = (
obj.cover.get( obj.cover.get("url", "https://assets.ppy.sh/user-profile-covers/default.jpeg")
"url", "https://assets.ppy.sh/user-profile-covers/default.jpeg"
)
if obj.cover if obj.cover
else "https://assets.ppy.sh/user-profile-covers/default.jpeg" else "https://assets.ppy.sh/user-profile-covers/default.jpeg"
) )
@@ -335,22 +310,15 @@ class UserResp(UserBase):
] ]
if "team" in include: if "team" in include:
if await obj.awaitable_attrs.team_membership: if team_membership := await obj.awaitable_attrs.team_membership:
assert obj.team_membership u.team = team_membership.team
u.team = obj.team_membership.team
if "account_history" in include: if "account_history" in include:
u.account_history = [ u.account_history = [UserAccountHistoryResp.from_db(ah) for ah in await obj.awaitable_attrs.account_history]
UserAccountHistoryResp.from_db(ah)
for ah in await obj.awaitable_attrs.account_history
]
if "daily_challenge_user_stats": if "daily_challenge_user_stats":
if await obj.awaitable_attrs.daily_challenge_stats: if daily_challenge_stats := await obj.awaitable_attrs.daily_challenge_stats:
assert obj.daily_challenge_stats u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(daily_challenge_stats)
u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(
obj.daily_challenge_stats
)
if "statistics" in include: if "statistics" in include:
current_stattistics = None current_stattistics = None
@@ -359,59 +327,40 @@ class UserResp(UserBase):
current_stattistics = i current_stattistics = i
break break
u.statistics = ( u.statistics = (
await UserStatisticsResp.from_db( await UserStatisticsResp.from_db(current_stattistics, session, obj.country_code)
current_stattistics, session, obj.country_code
)
if current_stattistics if current_stattistics
else None else None
) )
if "statistics_rulesets" in include: if "statistics_rulesets" in include:
u.statistics_rulesets = { u.statistics_rulesets = {
i.mode.value: await UserStatisticsResp.from_db( i.mode.value: await UserStatisticsResp.from_db(i, session, obj.country_code)
i, session, obj.country_code
)
for i in await obj.awaitable_attrs.statistics for i in await obj.awaitable_attrs.statistics
} }
if "monthly_playcounts" in include: if "monthly_playcounts" in include:
u.monthly_playcounts = [ u.monthly_playcounts = [CountResp.from_db(pc) for pc in await obj.awaitable_attrs.monthly_playcounts]
CountResp.from_db(pc)
for pc in await obj.awaitable_attrs.monthly_playcounts
]
if len(u.monthly_playcounts) == 1: if len(u.monthly_playcounts) == 1:
d = u.monthly_playcounts[0].start_date d = u.monthly_playcounts[0].start_date
u.monthly_playcounts.insert( u.monthly_playcounts.insert(0, CountResp(start_date=d - timedelta(days=20), count=0))
0, CountResp(start_date=d - timedelta(days=20), count=0)
)
if "replays_watched_counts" in include: if "replays_watched_counts" in include:
u.replay_watched_counts = [ u.replay_watched_counts = [
CountResp.from_db(rwc) CountResp.from_db(rwc) for rwc in await obj.awaitable_attrs.replays_watched_counts
for rwc in await obj.awaitable_attrs.replays_watched_counts
] ]
if len(u.replay_watched_counts) == 1: if len(u.replay_watched_counts) == 1:
d = u.replay_watched_counts[0].start_date d = u.replay_watched_counts[0].start_date
u.replay_watched_counts.insert( u.replay_watched_counts.insert(0, CountResp(start_date=d - timedelta(days=20), count=0))
0, CountResp(start_date=d - timedelta(days=20), count=0)
)
if "achievements" in include: if "achievements" in include:
u.user_achievements = [ u.user_achievements = [UserAchievementResp.from_db(ua) for ua in await obj.awaitable_attrs.achievement]
UserAchievementResp.from_db(ua)
for ua in await obj.awaitable_attrs.achievement
]
if "rank_history" in include: if "rank_history" in include:
rank_history = await RankHistoryResp.from_db(session, obj.id, ruleset) rank_history = await RankHistoryResp.from_db(session, obj.id, ruleset)
if len(rank_history.data) != 0: if len(rank_history.data) != 0:
u.rank_history = rank_history u.rank_history = rank_history
rank_top = ( rank_top = (
await session.exec( await session.exec(select(RankTop).where(RankTop.user_id == obj.id, RankTop.mode == ruleset))
select(RankTop).where(
RankTop.user_id == obj.id, RankTop.mode == ruleset
)
)
).first() ).first()
if rank_top: if rank_top:
u.rank_highest = ( u.rank_highest = (
@@ -425,9 +374,7 @@ class UserResp(UserBase):
u.favourite_beatmapset_count = ( u.favourite_beatmapset_count = (
await session.exec( await session.exec(
select(func.count()) select(func.count()).select_from(FavouriteBeatmapset).where(FavouriteBeatmapset.user_id == obj.id)
.select_from(FavouriteBeatmapset)
.where(FavouriteBeatmapset.user_id == obj.id)
) )
).one() ).one()
u.scores_pinned_count = ( u.scores_pinned_count = (
@@ -478,17 +425,19 @@ class UserResp(UserBase):
# 检查会话验证状态 # 检查会话验证状态
# 如果邮件验证功能被禁用,则始终设置 session_verified 为 true # 如果邮件验证功能被禁用,则始终设置 session_verified 为 true
from app.config import settings from app.config import settings
if not settings.enable_email_verification: if not settings.enable_email_verification:
u.session_verified = True u.session_verified = True
else: else:
# 如果用户有未验证的登录会话,则设置 session_verified 为 false # 如果用户有未验证的登录会话,则设置 session_verified 为 false
from .email_verification import LoginSession from .email_verification import LoginSession
unverified_session = ( unverified_session = (
await session.exec( await session.exec(
select(LoginSession).where( select(LoginSession).where(
LoginSession.user_id == obj.id, LoginSession.user_id == obj.id,
LoginSession.is_verified == False, col(LoginSession.is_verified).is_(False),
LoginSession.expires_at > datetime.now(UTC) LoginSession.expires_at > datetime.now(UTC),
) )
) )
).first() ).first()

View File

@@ -30,8 +30,8 @@ class MultiplayerEventBase(SQLModel, UTCBaseModel):
class MultiplayerEvent(MultiplayerEventBase, table=True): class MultiplayerEvent(MultiplayerEventBase, table=True):
__tablename__ = "multiplayer_events" # pyright: ignore[reportAssignmentType] __tablename__: str = "multiplayer_events"
id: int | None = Field( id: int = Field(
default=None, default=None,
sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True), sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True),
) )

View File

@@ -17,7 +17,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
class Notification(SQLModel, table=True): class Notification(SQLModel, table=True):
__tablename__ = "notifications" # pyright: ignore[reportAssignmentType] __tablename__: str = "notifications"
id: int = Field(primary_key=True, index=True, default=None) id: int = Field(primary_key=True, index=True, default=None)
name: NotificationName = Field(index=True) name: NotificationName = Field(index=True)
@@ -30,7 +30,7 @@ class Notification(SQLModel, table=True):
class UserNotification(SQLModel, table=True): class UserNotification(SQLModel, table=True):
__tablename__ = "user_notifications" # pyright: ignore[reportAssignmentType] __tablename__: str = "user_notifications"
id: int = Field( id: int = Field(
sa_column=Column( sa_column=Column(
BigInteger, BigInteger,
@@ -40,9 +40,7 @@ class UserNotification(SQLModel, table=True):
default=None, default=None,
) )
notification_id: int = Field(index=True, foreign_key="notifications.id") notification_id: int = Field(index=True, foreign_key="notifications.id")
user_id: int = Field( user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
is_read: bool = Field(index=True) is_read: bool = Field(index=True)
notification: Notification = Relationship(sa_relationship_kwargs={"lazy": "joined"}) notification: Notification = Relationship(sa_relationship_kwargs={"lazy": "joined"})

View File

@@ -4,16 +4,17 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime, UTC from datetime import UTC, datetime
from sqlmodel import SQLModel, Field
from sqlalchemy import Column, BigInteger, ForeignKey from sqlalchemy import BigInteger, Column, ForeignKey
from sqlmodel import Field, SQLModel
class PasswordReset(SQLModel, table=True): class PasswordReset(SQLModel, table=True):
"""密码重置记录""" """密码重置记录"""
__tablename__: str = "password_resets" __tablename__: str = "password_resets"
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True)) user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True))
email: str = Field(index=True) email: str = Field(index=True)

View File

@@ -21,16 +21,14 @@ class ItemAttemptsCountBase(SQLModel):
room_id: int = Field(foreign_key="rooms.id", index=True) room_id: int = Field(foreign_key="rooms.id", index=True)
attempts: int = Field(default=0) attempts: int = Field(default=0)
completed: int = Field(default=0) completed: int = Field(default=0)
user_id: int = Field( user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
accuracy: float = 0.0 accuracy: float = 0.0
pp: float = 0 pp: float = 0
total_score: int = 0 total_score: int = 0
class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True): class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True):
__tablename__ = "item_attempts_count" # pyright: ignore[reportAssignmentType] __tablename__: str = "item_attempts_count"
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
user: User = Relationship() user: User = Relationship()
@@ -63,9 +61,7 @@ class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True):
self.pp = sum(score.score.pp for score in playlist_scores) self.pp = sum(score.score.pp for score in playlist_scores)
self.completed = len([score for score in playlist_scores if score.score.passed]) self.completed = len([score for score in playlist_scores if score.score.passed])
self.accuracy = ( self.accuracy = (
sum(score.score.accuracy for score in playlist_scores) / self.completed sum(score.score.accuracy for score in playlist_scores) / self.completed if self.completed > 0 else 0.0
if self.completed > 0
else 0.0
) )
await session.commit() await session.commit()
await session.refresh(self) await session.refresh(self)

View File

@@ -21,14 +21,10 @@ if TYPE_CHECKING:
class PlaylistBestScore(SQLModel, table=True): class PlaylistBestScore(SQLModel, table=True):
__tablename__ = "playlist_best_scores" # pyright: ignore[reportAssignmentType] __tablename__: str = "playlist_best_scores"
user_id: int = Field( user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) score_id: int = Field(sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True))
)
score_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
)
room_id: int = Field(foreign_key="rooms.id", index=True) room_id: int = Field(foreign_key="rooms.id", index=True)
playlist_id: int = Field(index=True) playlist_id: int = Field(index=True)
total_score: int = Field(default=0, sa_column=Column(BigInteger)) total_score: int = Field(default=0, sa_column=Column(BigInteger))

View File

@@ -50,7 +50,7 @@ class PlaylistBase(SQLModel, UTCBaseModel):
class Playlist(PlaylistBase, table=True): class Playlist(PlaylistBase, table=True):
__tablename__ = "room_playlists" # pyright: ignore[reportAssignmentType] __tablename__: str = "room_playlists"
db_id: int = Field(default=None, primary_key=True, index=True, exclude=True) db_id: int = Field(default=None, primary_key=True, index=True, exclude=True)
room_id: int = Field(foreign_key="rooms.id", exclude=True) room_id: int = Field(foreign_key="rooms.id", exclude=True)
@@ -63,16 +63,12 @@ class Playlist(PlaylistBase, table=True):
@classmethod @classmethod
async def get_next_id_for_room(cls, room_id: int, session: AsyncSession) -> int: async def get_next_id_for_room(cls, room_id: int, session: AsyncSession) -> int:
stmt = select(func.coalesce(func.max(cls.id), -1) + 1).where( stmt = select(func.coalesce(func.max(cls.id), -1) + 1).where(cls.room_id == room_id)
cls.room_id == room_id
)
result = await session.exec(stmt) result = await session.exec(stmt)
return result.one() return result.one()
@classmethod @classmethod
async def from_hub( async def from_hub(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession) -> "Playlist":
cls, playlist: PlaylistItem, room_id: int, session: AsyncSession
) -> "Playlist":
next_id = await cls.get_next_id_for_room(room_id, session=session) next_id = await cls.get_next_id_for_room(room_id, session=session)
return cls( return cls(
id=next_id, id=next_id,
@@ -90,9 +86,7 @@ class Playlist(PlaylistBase, table=True):
@classmethod @classmethod
async def update(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession): async def update(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession):
db_playlist = await session.exec( db_playlist = await session.exec(select(cls).where(cls.id == playlist.id, cls.room_id == room_id))
select(cls).where(cls.id == playlist.id, cls.room_id == room_id)
)
db_playlist = db_playlist.first() db_playlist = db_playlist.first()
if db_playlist is None: if db_playlist is None:
raise ValueError("Playlist item not found") raise ValueError("Playlist item not found")
@@ -108,9 +102,7 @@ class Playlist(PlaylistBase, table=True):
await session.commit() await session.commit()
@classmethod @classmethod
async def add_to_db( async def add_to_db(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession):
cls, playlist: PlaylistItem, room_id: int, session: AsyncSession
):
db_playlist = await cls.from_hub(playlist, room_id, session) db_playlist = await cls.from_hub(playlist, room_id, session)
session.add(db_playlist) session.add(db_playlist)
await session.commit() await session.commit()
@@ -119,9 +111,7 @@ class Playlist(PlaylistBase, table=True):
@classmethod @classmethod
async def delete_item(cls, item_id: int, room_id: int, session: AsyncSession): async def delete_item(cls, item_id: int, room_id: int, session: AsyncSession):
db_playlist = await session.exec( db_playlist = await session.exec(select(cls).where(cls.id == item_id, cls.room_id == room_id))
select(cls).where(cls.id == item_id, cls.room_id == room_id)
)
db_playlist = db_playlist.first() db_playlist = db_playlist.first()
if db_playlist is None: if db_playlist is None:
raise ValueError("Playlist item not found") raise ValueError("Playlist item not found")
@@ -133,9 +123,7 @@ class PlaylistResp(PlaylistBase):
beatmap: BeatmapResp | None = None beatmap: BeatmapResp | None = None
@classmethod @classmethod
async def from_db( async def from_db(cls, playlist: Playlist, include: list[str] = []) -> "PlaylistResp":
cls, playlist: Playlist, include: list[str] = []
) -> "PlaylistResp":
data = playlist.model_dump() data = playlist.model_dump()
if "beatmap" in include: if "beatmap" in include:
data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap) data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap)

View File

@@ -20,13 +20,9 @@ if TYPE_CHECKING:
class PPBestScore(SQLModel, table=True): class PPBestScore(SQLModel, table=True):
__tablename__ = "best_scores" # pyright: ignore[reportAssignmentType] __tablename__: str = "best_scores"
user_id: int = Field( user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) score_id: int = Field(sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True))
)
score_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
)
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True) beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
gamemode: GameMode = Field(index=True) gamemode: GameMode = Field(index=True)
pp: float = Field( pp: float = Field(

View File

@@ -26,12 +26,10 @@ if TYPE_CHECKING:
class RankHistory(SQLModel, table=True): class RankHistory(SQLModel, table=True):
__tablename__ = "rank_history" # pyright: ignore[reportAssignmentType] __tablename__: str = "rank_history"
id: int | None = Field(default=None, sa_column=Column(BigInteger, primary_key=True)) id: int | None = Field(default=None, sa_column=Column(BigInteger, primary_key=True))
user_id: int = Field( user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
mode: GameMode mode: GameMode
rank: int rank: int
date: dt = Field( date: dt = Field(
@@ -43,12 +41,10 @@ class RankHistory(SQLModel, table=True):
class RankTop(SQLModel, table=True): class RankTop(SQLModel, table=True):
__tablename__ = "rank_top" # pyright: ignore[reportAssignmentType] __tablename__: str = "rank_top"
id: int | None = Field(default=None, sa_column=Column(BigInteger, primary_key=True)) id: int | None = Field(default=None, sa_column=Column(BigInteger, primary_key=True))
user_id: int = Field( user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
mode: GameMode mode: GameMode
rank: int rank: int
date: dt = Field( date: dt = Field(
@@ -62,9 +58,7 @@ class RankHistoryResp(BaseModel):
data: list[int] data: list[int]
@classmethod @classmethod
async def from_db( async def from_db(cls, session: AsyncSession, user_id: int, mode: GameMode) -> "RankHistoryResp":
cls, session: AsyncSession, user_id: int, mode: GameMode
) -> "RankHistoryResp":
results = ( results = (
await session.exec( await session.exec(
select(RankHistory) select(RankHistory)

View File

@@ -21,7 +21,7 @@ class RelationshipType(str, Enum):
class Relationship(SQLModel, table=True): class Relationship(SQLModel, table=True):
__tablename__ = "relationship" # pyright: ignore[reportAssignmentType] __tablename__: str = "relationship"
id: int | None = Field( id: int | None = Field(
default=None, default=None,
sa_column=Column(BigInteger, autoincrement=True, primary_key=True), sa_column=Column(BigInteger, autoincrement=True, primary_key=True),
@@ -59,9 +59,7 @@ class RelationshipResp(BaseModel):
type: RelationshipType type: RelationshipType
@classmethod @classmethod
async def from_db( async def from_db(cls, session: AsyncSession, relationship: Relationship) -> "RelationshipResp":
cls, session: AsyncSession, relationship: Relationship
) -> "RelationshipResp":
target_relationship = ( target_relationship = (
await session.exec( await session.exec(
select(Relationship).where( select(Relationship).where(

View File

@@ -58,11 +58,9 @@ class RoomBase(SQLModel, UTCBaseModel):
class Room(AsyncAttrs, RoomBase, table=True): class Room(AsyncAttrs, RoomBase, table=True):
__tablename__ = "rooms" # pyright: ignore[reportAssignmentType] __tablename__: str = "rooms"
id: int = Field(default=None, primary_key=True, index=True) id: int = Field(default=None, primary_key=True, index=True)
host_id: int = Field( host_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
host: User = Relationship() host: User = Relationship()
playlist: list[Playlist] = Relationship( playlist: list[Playlist] = Relationship(
@@ -109,12 +107,8 @@ class RoomResp(RoomBase):
if not playlist.expired: if not playlist.expired:
stats.count_active += 1 stats.count_active += 1
rulesets.add(playlist.ruleset_id) rulesets.add(playlist.ruleset_id)
difficulty_range.min = min( difficulty_range.min = min(difficulty_range.min, playlist.beatmap.difficulty_rating)
difficulty_range.min, playlist.beatmap.difficulty_rating difficulty_range.max = max(difficulty_range.max, playlist.beatmap.difficulty_rating)
)
difficulty_range.max = max(
difficulty_range.max, playlist.beatmap.difficulty_rating
)
resp.playlist.append(await PlaylistResp.from_db(playlist, ["beatmap"])) resp.playlist.append(await PlaylistResp.from_db(playlist, ["beatmap"]))
stats.ruleset_ids = list(rulesets) stats.ruleset_ids = list(rulesets)
resp.playlist_item_stats = stats resp.playlist_item_stats = stats
@@ -137,13 +131,9 @@ class RoomResp(RoomBase):
include=["statistics"], include=["statistics"],
) )
) )
resp.host = await UserResp.from_db( resp.host = await UserResp.from_db(await room.awaitable_attrs.host, session, include=["statistics"])
await room.awaitable_attrs.host, session, include=["statistics"]
)
if "current_user_score" in include and user: if "current_user_score" in include and user:
resp.current_user_score = await PlaylistAggregateScore.from_db( resp.current_user_score = await PlaylistAggregateScore.from_db(room.id, user.id, session)
room.id, user.id, session
)
return resp return resp
@classmethod @classmethod

View File

@@ -18,22 +18,16 @@ if TYPE_CHECKING:
class RoomParticipatedUser(AsyncAttrs, SQLModel, table=True): class RoomParticipatedUser(AsyncAttrs, SQLModel, table=True):
__tablename__ = "room_participated_users" # pyright: ignore[reportAssignmentType] __tablename__: str = "room_participated_users"
id: int | None = Field( id: int | None = Field(default=None, sa_column=Column(BigInteger, primary_key=True, autoincrement=True))
default=None, sa_column=Column(BigInteger, primary_key=True, autoincrement=True)
)
room_id: int = Field(sa_column=Column(ForeignKey("rooms.id"), nullable=False)) room_id: int = Field(sa_column=Column(ForeignKey("rooms.id"), nullable=False))
user_id: int = Field( user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False))
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False)
)
joined_at: datetime = Field( joined_at: datetime = Field(
sa_column=Column(DateTime(timezone=True), nullable=False), sa_column=Column(DateTime(timezone=True), nullable=False),
default=datetime.now(UTC), default=datetime.now(UTC),
) )
left_at: datetime | None = Field( left_at: datetime | None = Field(sa_column=Column(DateTime(timezone=True), nullable=True), default=None)
sa_column=Column(DateTime(timezone=True), nullable=True), default=None
)
room: "Room" = Relationship() room: "Room" = Relationship()
user: "User" = Relationship() user: "User" = Relationship()

View File

@@ -47,9 +47,9 @@ from .score_token import ScoreToken
from pydantic import field_serializer, field_validator from pydantic import field_serializer, field_validator
from redis.asyncio import Redis from redis.asyncio import Redis
from sqlalchemy import Boolean, Column, ColumnExpressionArgument, DateTime from sqlalchemy import Boolean, Column, ColumnExpressionArgument, DateTime, TextClause
from sqlalchemy.ext.asyncio import AsyncAttrs from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import aliased from sqlalchemy.orm import Mapped, aliased
from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import ColumnElement
from sqlmodel import ( from sqlmodel import (
JSON, JSON,
@@ -76,9 +76,7 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
accuracy: float accuracy: float
map_md5: str = Field(max_length=32, index=True) map_md5: str = Field(max_length=32, index=True)
build_id: int | None = Field(default=None) build_id: int | None = Field(default=None)
classic_total_score: int | None = Field( classic_total_score: int | None = Field(default=0, sa_column=Column(BigInteger)) # solo_score
default=0, sa_column=Column(BigInteger)
) # solo_score
ended_at: datetime = Field(sa_column=Column(DateTime)) ended_at: datetime = Field(sa_column=Column(DateTime))
has_replay: bool = Field(sa_column=Column(Boolean)) has_replay: bool = Field(sa_column=Column(Boolean))
max_combo: int max_combo: int
@@ -91,14 +89,10 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
room_id: int | None = Field(default=None) # multiplayer room_id: int | None = Field(default=None) # multiplayer
started_at: datetime = Field(sa_column=Column(DateTime)) started_at: datetime = Field(sa_column=Column(DateTime))
total_score: int = Field(default=0, sa_column=Column(BigInteger)) total_score: int = Field(default=0, sa_column=Column(BigInteger))
total_score_without_mods: int = Field( total_score_without_mods: int = Field(default=0, sa_column=Column(BigInteger), exclude=True)
default=0, sa_column=Column(BigInteger), exclude=True
)
type: str type: str
beatmap_id: int = Field(index=True, foreign_key="beatmaps.id") beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
maximum_statistics: ScoreStatistics = Field( maximum_statistics: ScoreStatistics = Field(sa_column=Column(JSON), default_factory=dict)
sa_column=Column(JSON), default_factory=dict
)
@field_validator("maximum_statistics", mode="before") @field_validator("maximum_statistics", mode="before")
@classmethod @classmethod
@@ -147,10 +141,8 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
class Score(ScoreBase, table=True): class Score(ScoreBase, table=True):
__tablename__ = "scores" # pyright: ignore[reportAssignmentType] __tablename__: str = "scores"
id: int | None = Field( id: int = Field(default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True))
default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True)
)
user_id: int = Field( user_id: int = Field(
default=None, default=None,
sa_column=Column( sa_column=Column(
@@ -193,8 +185,8 @@ class Score(ScoreBase, table=True):
return str(v) return str(v)
# optional # optional
beatmap: Beatmap = Relationship() beatmap: Mapped[Beatmap] = Relationship()
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"}) user: Mapped[User] = Relationship(sa_relationship_kwargs={"lazy": "joined"})
@property @property
def is_perfect_combo(self) -> bool: def is_perfect_combo(self) -> bool:
@@ -205,11 +197,7 @@ class Score(ScoreBase, table=True):
*where_clauses: ColumnExpressionArgument[bool] | bool, *where_clauses: ColumnExpressionArgument[bool] | bool,
) -> SelectOfScalar["Score"]: ) -> SelectOfScalar["Score"]:
rownum = ( rownum = (
func.row_number() func.row_number().over(partition_by=col(Score.user_id), order_by=col(Score.total_score).desc()).label("rn")
.over(
partition_by=col(Score.user_id), order_by=col(Score.total_score).desc()
)
.label("rn")
) )
subq = select(Score, rownum).where(*where_clauses).subquery() subq = select(Score, rownum).where(*where_clauses).subquery()
best = aliased(Score, subq, adapt_on_names=True) best = aliased(Score, subq, adapt_on_names=True)
@@ -296,12 +284,9 @@ class ScoreResp(ScoreBase):
await session.refresh(score) await session.refresh(score)
s = cls.model_validate(score.model_dump()) s = cls.model_validate(score.model_dump())
assert score.id
await score.awaitable_attrs.beatmap await score.awaitable_attrs.beatmap
s.beatmap = await BeatmapResp.from_db(score.beatmap) s.beatmap = await BeatmapResp.from_db(score.beatmap)
s.beatmapset = await BeatmapsetResp.from_db( s.beatmapset = await BeatmapsetResp.from_db(score.beatmap.beatmapset, session=session, user=score.user)
score.beatmap.beatmapset, session=session, user=score.user
)
s.is_perfect_combo = s.max_combo == s.beatmap.max_combo s.is_perfect_combo = s.max_combo == s.beatmap.max_combo
s.legacy_perfect = s.max_combo == s.beatmap.max_combo s.legacy_perfect = s.max_combo == s.beatmap.max_combo
s.ruleset_id = int(score.gamemode) s.ruleset_id = int(score.gamemode)
@@ -371,11 +356,7 @@ class ScoreAround(SQLModel):
async def get_best_id(session: AsyncSession, score_id: int) -> None: async def get_best_id(session: AsyncSession, score_id: int) -> None:
rownum = ( rownum = (
func.row_number() func.row_number().over(partition_by=col(PPBestScore.user_id), order_by=col(PPBestScore.pp).desc()).label("rn")
.over(
partition_by=col(PPBestScore.user_id), order_by=col(PPBestScore.pp).desc()
)
.label("rn")
) )
subq = select(PPBestScore, rownum).subquery() subq = select(PPBestScore, rownum).subquery()
stmt = select(subq.c.rn).where(subq.c.score_id == score_id) stmt = select(subq.c.rn).where(subq.c.score_id == score_id)
@@ -389,8 +370,8 @@ async def _score_where(
mode: GameMode, mode: GameMode,
mods: list[str] | None = None, mods: list[str] | None = None,
user: User | None = None, user: User | None = None,
) -> list[ColumnElement[bool]] | None: ) -> list[ColumnElement[bool] | TextClause] | None:
wheres = [ wheres: list[ColumnElement[bool] | TextClause] = [
col(BestScore.beatmap_id) == beatmap, col(BestScore.beatmap_id) == beatmap,
col(BestScore.gamemode) == mode, col(BestScore.gamemode) == mode,
] ]
@@ -410,9 +391,7 @@ async def _score_where(
return None return None
elif type == LeaderboardType.COUNTRY: elif type == LeaderboardType.COUNTRY:
if user and user.is_supporter: if user and user.is_supporter:
wheres.append( wheres.append(col(BestScore.user).has(col(User.country_code) == user.country_code))
col(BestScore.user).has(col(User.country_code) == user.country_code)
)
else: else:
return None return None
elif type == LeaderboardType.TEAM: elif type == LeaderboardType.TEAM:
@@ -420,18 +399,14 @@ async def _score_where(
team_membership = await user.awaitable_attrs.team_membership team_membership = await user.awaitable_attrs.team_membership
if team_membership: if team_membership:
team_id = team_membership.team_id team_id = team_membership.team_id
wheres.append( wheres.append(col(BestScore.user).has(col(User.team_membership).has(TeamMember.team_id == team_id)))
col(BestScore.user).has(
col(User.team_membership).has(TeamMember.team_id == team_id)
)
)
if mods: if mods:
if user and user.is_supporter: if user and user.is_supporter:
wheres.append( wheres.append(
text( text(
"JSON_CONTAINS(total_score_best_scores.mods, :w)" "JSON_CONTAINS(total_score_best_scores.mods, :w)"
" AND JSON_CONTAINS(:w, total_score_best_scores.mods)" " AND JSON_CONTAINS(:w, total_score_best_scores.mods)"
).params(w=json.dumps(mods)) # pyright: ignore[reportArgumentType] ).params(w=json.dumps(mods))
) )
else: else:
return None return None
@@ -654,18 +629,14 @@ def calculate_playtime(score: Score, beatmap_length: int) -> tuple[int, bool]:
+ (score.nsmall_tick_hit or 0) + (score.nsmall_tick_hit or 0)
) )
total_obj = 0 total_obj = 0
for statistics, count in ( for statistics, count in score.maximum_statistics.items() if score.maximum_statistics else {}:
score.maximum_statistics.items() if score.maximum_statistics else {}
):
if not isinstance(statistics, HitResult): if not isinstance(statistics, HitResult):
statistics = HitResult(statistics) statistics = HitResult(statistics)
if statistics.is_scorable(): if statistics.is_scorable():
total_obj += count total_obj += count
return total_length, score.passed or ( return total_length, score.passed or (
total_length > 8 total_length > 8 and score.total_score >= 5000 and total_obj_hited >= min(0.1 * total_obj, 20)
and score.total_score >= 5000
and total_obj_hited >= min(0.1 * total_obj, 20)
) )
@@ -678,12 +649,8 @@ async def process_user(
ranked: bool = False, ranked: bool = False,
has_leaderboard: bool = False, has_leaderboard: bool = False,
): ):
assert user.id
assert score.id
mod_for_save = mod_to_save(score.mods) mod_for_save = mod_to_save(score.mods)
previous_score_best = await get_user_best_score_in_beatmap( previous_score_best = await get_user_best_score_in_beatmap(session, score.beatmap_id, user.id, score.gamemode)
session, score.beatmap_id, user.id, score.gamemode
)
previous_score_best_mod = await get_user_best_score_with_mod_in_beatmap( previous_score_best_mod = await get_user_best_score_with_mod_in_beatmap(
session, score.beatmap_id, user.id, mod_for_save, score.gamemode session, score.beatmap_id, user.id, mod_for_save, score.gamemode
) )
@@ -698,9 +665,7 @@ async def process_user(
) )
).first() ).first()
if mouthly_playcount is None: if mouthly_playcount is None:
mouthly_playcount = MonthlyPlaycounts( mouthly_playcount = MonthlyPlaycounts(user_id=user.id, year=date.today().year, month=date.today().month)
user_id=user.id, year=date.today().year, month=date.today().month
)
add_to_db = True add_to_db = True
statistics = None statistics = None
for i in await user.awaitable_attrs.statistics: for i in await user.awaitable_attrs.statistics:
@@ -708,17 +673,11 @@ async def process_user(
statistics = i statistics = i
break break
if statistics is None: if statistics is None:
raise ValueError( raise ValueError(f"User {user.id} does not have statistics for mode {score.gamemode.value}")
f"User {user.id} does not have statistics for mode {score.gamemode.value}"
)
# pc, pt, tth, tts # pc, pt, tth, tts
statistics.total_score += score.total_score statistics.total_score += score.total_score
difference = ( difference = score.total_score - previous_score_best.total_score if previous_score_best else score.total_score
score.total_score - previous_score_best.total_score
if previous_score_best
else score.total_score
)
if difference > 0 and score.passed and ranked: if difference > 0 and score.passed and ranked:
match score.rank: match score.rank:
case Rank.X: case Rank.X:
@@ -746,11 +705,8 @@ async def process_user(
statistics.ranked_score += difference statistics.ranked_score += difference
statistics.level_current = calculate_score_to_level(statistics.total_score) statistics.level_current = calculate_score_to_level(statistics.total_score)
statistics.maximum_combo = max(statistics.maximum_combo, score.max_combo) statistics.maximum_combo = max(statistics.maximum_combo, score.max_combo)
new_score_position = await get_score_position_by_user( new_score_position = await get_score_position_by_user(session, score.beatmap_id, user, score.gamemode)
session, score.beatmap_id, user, score.gamemode
)
total_users = await session.exec(select(func.count()).select_from(User)) total_users = await session.exec(select(func.count()).select_from(User))
assert total_users is not None
score_range = min(50, math.ceil(float(total_users.one()) * 0.01)) score_range = min(50, math.ceil(float(total_users.one()) * 0.01))
if new_score_position <= score_range and new_score_position > 0: if new_score_position <= score_range and new_score_position > 0:
# Get the scores that might be displaced # Get the scores that might be displaced
@@ -774,11 +730,7 @@ async def process_user(
) )
# If this score was previously in top positions but now pushed out # If this score was previously in top positions but now pushed out
if ( if i < score_range and displaced_position > score_range and displaced_position is not None:
i < score_range
and displaced_position > score_range
and displaced_position is not None
):
# Create rank lost event for the displaced user # Create rank lost event for the displaced user
rank_lost_event = Event( rank_lost_event = Event(
created_at=datetime.now(UTC), created_at=datetime.now(UTC),
@@ -814,10 +766,7 @@ async def process_user(
) )
# 情况3: 有最佳分数记录和该mod组合的记录且是同一个记录更新得分更高的情况 # 情况3: 有最佳分数记录和该mod组合的记录且是同一个记录更新得分更高的情况
elif ( elif previous_score_best.score_id == previous_score_best_mod.score_id and difference > 0:
previous_score_best.score_id == previous_score_best_mod.score_id
and difference > 0
):
previous_score_best.total_score = score.total_score previous_score_best.total_score = score.total_score
previous_score_best.rank = score.rank previous_score_best.rank = score.rank
previous_score_best.score_id = score.id previous_score_best.score_id = score.id
@@ -847,9 +796,7 @@ async def process_user(
statistics.count_300 += score.n300 + score.ngeki statistics.count_300 += score.n300 + score.ngeki
statistics.count_50 += score.n50 statistics.count_50 += score.n50
statistics.count_miss += score.nmiss statistics.count_miss += score.nmiss
statistics.total_hits += ( statistics.total_hits += score.n300 + score.n100 + score.n50 + score.ngeki + score.nkatu
score.n300 + score.n100 + score.n50 + score.ngeki + score.nkatu
)
if score.passed and ranked: if score.passed and ranked:
with session.no_autoflush: with session.no_autoflush:
@@ -885,7 +832,6 @@ async def process_score(
item_id: int | None = None, item_id: int | None = None,
room_id: int | None = None, room_id: int | None = None,
) -> Score: ) -> Score:
assert user.id
can_get_pp = info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods) can_get_pp = info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods)
gamemode = GameMode.from_int(info.ruleset_id).to_special_mode(info.mods) gamemode = GameMode.from_int(info.ruleset_id).to_special_mode(info.mods)
score = Score( score = Score(
@@ -922,20 +868,15 @@ async def process_score(
if can_get_pp: if can_get_pp:
from app.calculator import pre_fetch_and_calculate_pp from app.calculator import pre_fetch_and_calculate_pp
pp = await pre_fetch_and_calculate_pp( pp = await pre_fetch_and_calculate_pp(score, beatmap_id, session, redis, fetcher)
score, beatmap_id, session, redis, fetcher
)
score.pp = pp score.pp = pp
session.add(score) session.add(score)
user_id = user.id user_id = user.id
await session.commit() await session.commit()
await session.refresh(score) await session.refresh(score)
if can_get_pp and score.pp != 0: if can_get_pp and score.pp != 0:
previous_pp_best = await get_user_best_pp_in_beatmap( previous_pp_best = await get_user_best_pp_in_beatmap(session, beatmap_id, user_id, score.gamemode)
session, beatmap_id, user_id, score.gamemode
)
if previous_pp_best is None or score.pp > previous_pp_best.pp: if previous_pp_best is None or score.pp > previous_pp_best.pp:
assert score.id
best_score = PPBestScore( best_score = PPBestScore(
user_id=user_id, user_id=user_id,
score_id=score.id, score_id=score.id,

View File

@@ -7,6 +7,7 @@ from .beatmap import Beatmap
from .lazer_user import User from .lazer_user import User
from sqlalchemy import Column, DateTime, Index from sqlalchemy import Column, DateTime, Index
from sqlalchemy.orm import Mapped
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
@@ -14,16 +15,12 @@ class ScoreTokenBase(SQLModel, UTCBaseModel):
score_id: int | None = Field(sa_column=Column(BigInteger), default=None) score_id: int | None = Field(sa_column=Column(BigInteger), default=None)
ruleset_id: GameMode ruleset_id: GameMode
playlist_item_id: int | None = Field(default=None) # playlist playlist_item_id: int | None = Field(default=None) # playlist
created_at: datetime = Field( created_at: datetime = Field(default_factory=datetime.utcnow, sa_column=Column(DateTime))
default_factory=datetime.utcnow, sa_column=Column(DateTime) updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column=Column(DateTime))
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
class ScoreToken(ScoreTokenBase, table=True): class ScoreToken(ScoreTokenBase, table=True):
__tablename__ = "score_tokens" # pyright: ignore[reportAssignmentType] __tablename__: str = "score_tokens"
__table_args__ = (Index("idx_user_playlist", "user_id", "playlist_item_id"),) __table_args__ = (Index("idx_user_playlist", "user_id", "playlist_item_id"),)
id: int | None = Field( id: int | None = Field(
@@ -37,8 +34,8 @@ class ScoreToken(ScoreTokenBase, table=True):
) )
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"))) user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
beatmap_id: int = Field(foreign_key="beatmaps.id") beatmap_id: int = Field(foreign_key="beatmaps.id")
user: User = Relationship() user: Mapped[User] = Relationship()
beatmap: Beatmap = Relationship() beatmap: Mapped[Beatmap] = Relationship()
class ScoreTokenResp(ScoreTokenBase): class ScoreTokenResp(ScoreTokenBase):

View File

@@ -58,7 +58,7 @@ class UserStatisticsBase(SQLModel):
class UserStatistics(AsyncAttrs, UserStatisticsBase, table=True): class UserStatistics(AsyncAttrs, UserStatisticsBase, table=True):
__tablename__ = "lazer_user_statistics" # pyright: ignore[reportAssignmentType] __tablename__: str = "lazer_user_statistics"
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
user_id: int = Field( user_id: int = Field(
default=None, default=None,
@@ -123,9 +123,7 @@ class UserStatisticsResp(UserStatisticsBase):
if "user" in include: if "user" in include:
from .lazer_user import RANKING_INCLUDES, UserResp from .lazer_user import RANKING_INCLUDES, UserResp
user = await UserResp.from_db( user = await UserResp.from_db(await obj.awaitable_attrs.user, session, include=RANKING_INCLUDES)
await obj.awaitable_attrs.user, session, include=RANKING_INCLUDES
)
s.user = user s.user = user
user_country = user.country_code user_country = user.country_code
@@ -149,9 +147,7 @@ class UserStatisticsResp(UserStatisticsBase):
return s return s
async def get_rank( async def get_rank(session: AsyncSession, statistics: UserStatistics, country: str | None = None) -> int | None:
session: AsyncSession, statistics: UserStatistics, country: str | None = None
) -> int | None:
from .lazer_user import User from .lazer_user import User
query = select( query = select(
@@ -168,9 +164,7 @@ async def get_rank(
subq = query.subquery() subq = query.subquery()
result = await session.exec( result = await session.exec(select(subq.c.rank).where(subq.c.user_id == statistics.user_id))
select(subq.c.rank).where(subq.c.user_id == statistics.user_id)
)
rank = result.first() rank = result.first()
if rank is None: if rank is None:

View File

@@ -11,9 +11,9 @@ if TYPE_CHECKING:
class Team(SQLModel, UTCBaseModel, table=True): class Team(SQLModel, UTCBaseModel, table=True):
__tablename__ = "teams" # pyright: ignore[reportAssignmentType] __tablename__: str = "teams"
id: int | None = Field(default=None, primary_key=True, index=True) id: int = Field(default=None, primary_key=True, index=True)
name: str = Field(max_length=100) name: str = Field(max_length=100)
short_name: str = Field(max_length=10) short_name: str = Field(max_length=10)
flag_url: str | None = Field(default=None) flag_url: str | None = Field(default=None)
@@ -26,34 +26,22 @@ class Team(SQLModel, UTCBaseModel, table=True):
class TeamMember(SQLModel, UTCBaseModel, table=True): class TeamMember(SQLModel, UTCBaseModel, table=True):
__tablename__ = "team_members" # pyright: ignore[reportAssignmentType] __tablename__: str = "team_members"
user_id: int = Field( user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True))
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True)
)
team_id: int = Field(foreign_key="teams.id") team_id: int = Field(foreign_key="teams.id")
joined_at: datetime = Field( joined_at: datetime = Field(default_factory=datetime.utcnow, sa_column=Column(DateTime))
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: "User" = Relationship( user: "User" = Relationship(back_populates="team_membership", sa_relationship_kwargs={"lazy": "joined"})
back_populates="team_membership", sa_relationship_kwargs={"lazy": "joined"} team: "Team" = Relationship(back_populates="members", sa_relationship_kwargs={"lazy": "joined"})
)
team: "Team" = Relationship(
back_populates="members", sa_relationship_kwargs={"lazy": "joined"}
)
class TeamRequest(SQLModel, UTCBaseModel, table=True): class TeamRequest(SQLModel, UTCBaseModel, table=True):
__tablename__ = "team_requests" # pyright: ignore[reportAssignmentType] __tablename__: str = "team_requests"
user_id: int = Field( user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True))
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True)
)
team_id: int = Field(foreign_key="teams.id", primary_key=True) team_id: int = Field(foreign_key="teams.id", primary_key=True)
requested_at: datetime = Field( requested_at: datetime = Field(default=datetime.now(UTC), sa_column=Column(DateTime))
default=datetime.now(UTC), sa_column=Column(DateTime)
)
user: "User" = Relationship(sa_relationship_kwargs={"lazy": "joined"}) user: "User" = Relationship(sa_relationship_kwargs={"lazy": "joined"})
team: "Team" = Relationship(sa_relationship_kwargs={"lazy": "joined"}) team: "Team" = Relationship(sa_relationship_kwargs={"lazy": "joined"})

View File

@@ -22,7 +22,7 @@ class UserAccountHistoryBase(SQLModel, UTCBaseModel):
class UserAccountHistory(UserAccountHistoryBase, table=True): class UserAccountHistory(UserAccountHistoryBase, table=True):
__tablename__ = "user_account_history" # pyright: ignore[reportAssignmentType] __tablename__: str = "user_account_history"
id: int | None = Field( id: int | None = Field(
sa_column=Column( sa_column=Column(
@@ -32,9 +32,7 @@ class UserAccountHistory(UserAccountHistoryBase, table=True):
primary_key=True, primary_key=True,
) )
) )
user_id: int = Field( user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
class UserAccountHistoryResp(UserAccountHistoryBase): class UserAccountHistoryResp(UserAccountHistoryBase):

View File

@@ -10,27 +10,17 @@ from sqlmodel import Field, SQLModel
class UserLoginLog(SQLModel, table=True): class UserLoginLog(SQLModel, table=True):
"""User login log table""" """User login log table"""
__tablename__ = "user_login_log" # pyright: ignore[reportAssignmentType] __tablename__: str = "user_login_log"
id: int | None = Field(default=None, primary_key=True, description="Record ID") id: int | None = Field(default=None, primary_key=True, description="Record ID")
user_id: int = Field(index=True, description="User ID") user_id: int = Field(index=True, description="User ID")
ip_address: str = Field( ip_address: str = Field(max_length=45, index=True, description="IP address (supports IPv4 and IPv6)")
max_length=45, index=True, description="IP address (supports IPv4 and IPv6)" user_agent: str | None = Field(default=None, max_length=500, description="User agent information")
) login_time: datetime = Field(default_factory=datetime.utcnow, description="Login time")
user_agent: str | None = Field(
default=None, max_length=500, description="User agent information"
)
login_time: datetime = Field(
default_factory=datetime.utcnow, description="Login time"
)
# GeoIP information # GeoIP information
country_code: str | None = Field( country_code: str | None = Field(default=None, max_length=2, description="Country code")
default=None, max_length=2, description="Country code" country_name: str | None = Field(default=None, max_length=100, description="Country name")
)
country_name: str | None = Field(
default=None, max_length=100, description="Country name"
)
city_name: str | None = Field(default=None, max_length=100, description="City name") city_name: str | None = Field(default=None, max_length=100, description="City name")
latitude: str | None = Field(default=None, max_length=20, description="Latitude") latitude: str | None = Field(default=None, max_length=20, description="Latitude")
longitude: str | None = Field(default=None, max_length=20, description="Longitude") longitude: str | None = Field(default=None, max_length=20, description="Longitude")
@@ -38,22 +28,14 @@ class UserLoginLog(SQLModel, table=True):
# ASN information # ASN information
asn: int | None = Field(default=None, description="Autonomous System Number") asn: int | None = Field(default=None, description="Autonomous System Number")
organization: str | None = Field( organization: str | None = Field(default=None, max_length=200, description="Organization name")
default=None, max_length=200, description="Organization name"
)
# Login status # Login status
login_success: bool = Field( login_success: bool = Field(default=True, description="Whether the login was successful")
default=True, description="Whether the login was successful" login_method: str = Field(max_length=50, description="Login method (password/oauth/etc.)")
)
login_method: str = Field(
max_length=50, description="Login method (password/oauth/etc.)"
)
# Additional information # Additional information
notes: str | None = Field( notes: str | None = Field(default=None, max_length=500, description="Additional notes")
default=None, max_length=500, description="Additional notes"
)
class Config: class Config:
from_attributes = True from_attributes = True

View File

@@ -40,15 +40,11 @@ engine = create_async_engine(
redis_client = redis.from_url(settings.redis_url, decode_responses=True) redis_client = redis.from_url(settings.redis_url, decode_responses=True)
# Redis 消息缓存连接 (db1) - 使用同步客户端在线程池中执行 # Redis 消息缓存连接 (db1) - 使用同步客户端在线程池中执行
redis_message_client = sync_redis.from_url( redis_message_client = sync_redis.from_url(settings.redis_url, decode_responses=True, db=1)
settings.redis_url, decode_responses=True, db=1
)
# 数据库依赖 # 数据库依赖
db_session_context: ContextVar[AsyncSession | None] = ContextVar( db_session_context: ContextVar[AsyncSession | None] = ContextVar("db_session_context", default=None)
"db_session_context", default=None
)
async def get_db(): async def get_db():

View File

@@ -25,7 +25,5 @@ async def get_fetcher() -> Fetcher:
if refresh_token: if refresh_token:
fetcher.refresh_token = str(refresh_token) fetcher.refresh_token = str(refresh_token)
if not fetcher.access_token or not fetcher.refresh_token: if not fetcher.access_token or not fetcher.refresh_token:
logger.opt(colors=True).info( logger.opt(colors=True).info(f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>")
f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>"
)
return fetcher return fetcher

View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
from datetime import UTC from datetime import UTC
from typing import cast
from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.schedulers.asyncio import AsyncIOScheduler
@@ -16,7 +17,7 @@ def get_scheduler() -> AsyncIOScheduler:
global scheduler global scheduler
if scheduler is None: if scheduler is None:
init_scheduler() init_scheduler()
return scheduler # pyright: ignore[reportReturnType] return cast(AsyncIOScheduler, scheduler)
def start_scheduler(): def start_scheduler():

View File

@@ -70,9 +70,7 @@ async def v1_authorize(
if not api_key: if not api_key:
raise HTTPException(status_code=401, detail="Missing API key") raise HTTPException(status_code=401, detail="Missing API key")
api_key_record = ( api_key_record = (await db.exec(select(V1APIKeys).where(V1APIKeys.key == api_key))).first()
await db.exec(select(V1APIKeys).where(V1APIKeys.key == api_key))
).first()
if not api_key_record: if not api_key_record:
raise HTTPException(status_code=401, detail="Invalid API key") raise HTTPException(status_code=401, detail="Invalid API key")
@@ -98,9 +96,7 @@ async def get_current_user(
security_scopes: SecurityScopes, security_scopes: SecurityScopes,
token_pw: Annotated[str | None, Depends(oauth2_password)] = None, token_pw: Annotated[str | None, Depends(oauth2_password)] = None,
token_code: Annotated[str | None, Depends(oauth2_code)] = None, token_code: Annotated[str | None, Depends(oauth2_code)] = None,
token_client_credentials: Annotated[ token_client_credentials: Annotated[str | None, Depends(oauth2_client_credentials)] = None,
str | None, Depends(oauth2_client_credentials)
] = None,
) -> User: ) -> User:
"""获取当前认证用户""" """获取当前认证用户"""
token = token_pw or token_code or token_client_credentials token = token_pw or token_code or token_client_credentials
@@ -119,9 +115,7 @@ async def get_current_user(
if not is_client: if not is_client:
for scope in security_scopes.scopes: for scope in security_scopes.scopes:
if scope not in token_record.scope.split(","): if scope not in token_record.scope.split(","):
raise HTTPException( raise HTTPException(status_code=403, detail=f"Insufficient scope: {scope}")
status_code=403, detail=f"Insufficient scope: {scope}"
)
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first() user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
if not user: if not user:

View File

@@ -121,14 +121,10 @@ class BaseFetcher:
except Exception as e: except Exception as e:
last_error = e last_error = e
if attempt < max_retries: if attempt < max_retries:
logger.warning( logger.warning(f"Request failed (attempt {attempt + 1}/{max_retries + 1}): {e}, retrying...")
f"Request failed (attempt {attempt + 1}/{max_retries + 1}): {e}, retrying..."
)
continue continue
else: else:
logger.error( logger.error(f"Request failed after {max_retries + 1} attempts: {e}")
f"Request failed after {max_retries + 1} attempts: {e}"
)
break break
# 如果所有重试都失败了 # 如果所有重试都失败了
@@ -196,13 +192,9 @@ class BaseFetcher:
f"fetcher:refresh_token:{self.client_id}", f"fetcher:refresh_token:{self.client_id}",
self.refresh_token, self.refresh_token,
) )
logger.info( logger.info(f"Successfully refreshed access token for client {self.client_id}")
f"Successfully refreshed access token for client {self.client_id}"
)
except Exception as e: except Exception as e:
logger.error( logger.error(f"Failed to refresh access token for client {self.client_id}: {e}")
f"Failed to refresh access token for client {self.client_id}: {e}"
)
# 清除无效的 token要求重新授权 # 清除无效的 token要求重新授权
self.access_token = "" self.access_token = ""
self.refresh_token = "" self.refresh_token = ""
@@ -210,9 +202,7 @@ class BaseFetcher:
redis = get_redis() redis = get_redis()
await redis.delete(f"fetcher:access_token:{self.client_id}") await redis.delete(f"fetcher:access_token:{self.client_id}")
await redis.delete(f"fetcher:refresh_token:{self.client_id}") await redis.delete(f"fetcher:refresh_token:{self.client_id}")
logger.warning( logger.warning(f"Cleared invalid tokens. Please re-authorize: {self.authorize_url}")
f"Cleared invalid tokens. Please re-authorize: {self.authorize_url}"
)
raise raise
async def _trigger_reauthorization(self) -> None: async def _trigger_reauthorization(self) -> None:
@@ -237,8 +227,7 @@ class BaseFetcher:
await redis.delete(f"fetcher:refresh_token:{self.client_id}") await redis.delete(f"fetcher:refresh_token:{self.client_id}")
logger.warning( logger.warning(
f"All tokens cleared for client {self.client_id}. " f"All tokens cleared for client {self.client_id}. Please re-authorize using: {self.authorize_url}"
f"Please re-authorize using: {self.authorize_url}"
) )
def reset_auth_retry_count(self) -> None: def reset_auth_retry_count(self) -> None:

View File

@@ -7,18 +7,14 @@ from ._base import BaseFetcher
class BeatmapFetcher(BaseFetcher): class BeatmapFetcher(BaseFetcher):
async def get_beatmap( async def get_beatmap(self, beatmap_id: int | None = None, beatmap_checksum: str | None = None) -> BeatmapResp:
self, beatmap_id: int | None = None, beatmap_checksum: str | None = None
) -> BeatmapResp:
if beatmap_id: if beatmap_id:
params = {"id": beatmap_id} params = {"id": beatmap_id}
elif beatmap_checksum: elif beatmap_checksum:
params = {"checksum": beatmap_checksum} params = {"checksum": beatmap_checksum}
else: else:
raise ValueError("Either beatmap_id or beatmap_checksum must be provided.") raise ValueError("Either beatmap_id or beatmap_checksum must be provided.")
logger.opt(colors=True).debug( logger.opt(colors=True).debug(f"<blue>[BeatmapFetcher]</blue> get_beatmap: <y>{params}</y>")
f"<blue>[BeatmapFetcher]</blue> get_beatmap: <y>{params}</y>"
)
return BeatmapResp.model_validate( return BeatmapResp.model_validate(
await self.request_api( await self.request_api(

View File

@@ -18,9 +18,7 @@ class BeatmapRawFetcher(BaseFetcher):
async def get_beatmap_raw(self, beatmap_id: int) -> str: async def get_beatmap_raw(self, beatmap_id: int) -> str:
for url in urls: for url in urls:
req_url = url.format(beatmap_id=beatmap_id) req_url = url.format(beatmap_id=beatmap_id)
logger.opt(colors=True).debug( logger.opt(colors=True).debug(f"<blue>[BeatmapRawFetcher]</blue> get_beatmap_raw: <y>{req_url}</y>")
f"<blue>[BeatmapRawFetcher]</blue> get_beatmap_raw: <y>{req_url}</y>"
)
resp = await self._request(req_url) resp = await self._request(req_url)
if resp.status_code >= 400: if resp.status_code >= 400:
continue continue
@@ -34,9 +32,7 @@ class BeatmapRawFetcher(BaseFetcher):
) )
return response return response
async def get_or_fetch_beatmap_raw( async def get_or_fetch_beatmap_raw(self, redis: redis.Redis, beatmap_id: int) -> str:
self, redis: redis.Redis, beatmap_id: int
) -> str:
from app.config import settings from app.config import settings
cache_key = f"beatmap:{beatmap_id}:raw" cache_key = f"beatmap:{beatmap_id}:raw"
@@ -48,7 +44,7 @@ class BeatmapRawFetcher(BaseFetcher):
if content: if content:
# 延长缓存时间 # 延长缓存时间
await redis.expire(cache_key, cache_expire) await redis.expire(cache_key, cache_expire)
return content # pyright: ignore[reportReturnType] return content
# 获取并缓存 # 获取并缓存
raw = await self.get_beatmap_raw(beatmap_id) raw = await self.get_beatmap_raw(beatmap_id)

View File

@@ -10,6 +10,7 @@ from app.helpers.rate_limiter import osu_api_rate_limiter
from app.log import logger from app.log import logger
from app.models.beatmap import SearchQueryModel from app.models.beatmap import SearchQueryModel
from app.models.model import Cursor from app.models.model import Cursor
from app.utils import bg_tasks
from ._base import BaseFetcher from ._base import BaseFetcher
@@ -81,9 +82,7 @@ class BeatmapsetFetcher(BaseFetcher):
cache_json = json.dumps(cache_data, sort_keys=True, separators=(",", ":")) cache_json = json.dumps(cache_data, sort_keys=True, separators=(",", ":"))
cache_hash = hashlib.md5(cache_json.encode()).hexdigest() cache_hash = hashlib.md5(cache_json.encode()).hexdigest()
logger.opt(colors=True).debug( logger.opt(colors=True).debug(f"<blue>[CacheKey]</blue> Query: {cache_data}, Hash: {cache_hash}")
f"<blue>[CacheKey]</blue> Query: {cache_data}, Hash: {cache_hash}"
)
return f"beatmapset:search:{cache_hash}" return f"beatmapset:search:{cache_hash}"
@@ -103,22 +102,16 @@ class BeatmapsetFetcher(BaseFetcher):
return {} return {}
async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetResp: async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetResp:
logger.opt(colors=True).debug( logger.opt(colors=True).debug(f"<blue>[BeatmapsetFetcher]</blue> get_beatmapset: <y>{beatmap_set_id}</y>")
f"<blue>[BeatmapsetFetcher]</blue> get_beatmapset: <y>{beatmap_set_id}</y>"
)
return BeatmapsetResp.model_validate( return BeatmapsetResp.model_validate(
await self.request_api( await self.request_api(f"https://osu.ppy.sh/api/v2/beatmapsets/{beatmap_set_id}")
f"https://osu.ppy.sh/api/v2/beatmapsets/{beatmap_set_id}"
)
) )
async def search_beatmapset( async def search_beatmapset(
self, query: SearchQueryModel, cursor: Cursor, redis_client: redis.Redis self, query: SearchQueryModel, cursor: Cursor, redis_client: redis.Redis
) -> SearchBeatmapsetsResp: ) -> SearchBeatmapsetsResp:
logger.opt(colors=True).debug( logger.opt(colors=True).debug(f"<blue>[BeatmapsetFetcher]</blue> search_beatmapset: <y>{query}</y>")
f"<blue>[BeatmapsetFetcher]</blue> search_beatmapset: <y>{query}</y>"
)
# 生成缓存键 # 生成缓存键
cache_key = self._generate_cache_key(query, cursor) cache_key = self._generate_cache_key(query, cursor)
@@ -126,9 +119,7 @@ class BeatmapsetFetcher(BaseFetcher):
# 尝试从缓存获取结果 # 尝试从缓存获取结果
cached_result = await redis_client.get(cache_key) cached_result = await redis_client.get(cache_key)
if cached_result: if cached_result:
logger.opt(colors=True).debug( logger.opt(colors=True).debug(f"<green>[BeatmapsetFetcher]</green> Cache hit for key: <y>{cache_key}</y>")
f"<green>[BeatmapsetFetcher]</green> Cache hit for key: <y>{cache_key}</y>"
)
try: try:
cached_data = json.loads(cached_result) cached_data = json.loads(cached_result)
return SearchBeatmapsetsResp.model_validate(cached_data) return SearchBeatmapsetsResp.model_validate(cached_data)
@@ -138,13 +129,9 @@ class BeatmapsetFetcher(BaseFetcher):
) )
# 缓存未命中,从 API 获取数据 # 缓存未命中,从 API 获取数据
logger.opt(colors=True).debug( logger.opt(colors=True).debug("<blue>[BeatmapsetFetcher]</blue> Cache miss, fetching from API")
"<blue>[BeatmapsetFetcher]</blue> Cache miss, fetching from API"
)
params = query.model_dump( params = query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True)
exclude_none=True, exclude_unset=True, exclude_defaults=True
)
if query.cursor_string: if query.cursor_string:
params["cursor_string"] = query.cursor_string params["cursor_string"] = query.cursor_string
@@ -164,39 +151,26 @@ class BeatmapsetFetcher(BaseFetcher):
# 将结果缓存 15 分钟 # 将结果缓存 15 分钟
cache_ttl = 15 * 60 # 15 分钟 cache_ttl = 15 * 60 # 15 分钟
await redis_client.set( await redis_client.set(cache_key, json.dumps(api_response, separators=(",", ":")), ex=cache_ttl)
cache_key, json.dumps(api_response, separators=(",", ":")), ex=cache_ttl
)
logger.opt(colors=True).debug( logger.opt(colors=True).debug(
f"<green>[BeatmapsetFetcher]</green> Cached result for key: " f"<green>[BeatmapsetFetcher]</green> Cached result for key: <y>{cache_key}</y> (TTL: {cache_ttl}s)"
f"<y>{cache_key}</y> (TTL: {cache_ttl}s)"
) )
resp = SearchBeatmapsetsResp.model_validate(api_response) resp = SearchBeatmapsetsResp.model_validate(api_response)
# 智能预取只在用户明确搜索时才预取避免过多API请求 # 智能预取只在用户明确搜索时才预取避免过多API请求
# 且只在有搜索词或特定条件时预取,避免首页浏览时的过度预取 # 且只在有搜索词或特定条件时预取,避免首页浏览时的过度预取
if api_response.get("cursor") and ( if api_response.get("cursor") and (query.q or query.s != "leaderboard" or cursor):
query.q or query.s != "leaderboard" or cursor
):
# 在后台预取下1页减少预取量 # 在后台预取下1页减少预取量
import asyncio import asyncio
# 不立即创建任务,而是延迟一段时间再预取 # 不立即创建任务,而是延迟一段时间再预取
async def delayed_prefetch(): async def delayed_prefetch():
await asyncio.sleep(3.0) # 延迟3秒 await asyncio.sleep(3.0) # 延迟3秒
await self.prefetch_next_pages( await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=1)
query, api_response["cursor"], redis_client, pages=1
)
# 创建延迟预取任务 bg_tasks.add_task(delayed_prefetch)
task = asyncio.create_task(delayed_prefetch())
# 添加到后台任务集合避免被垃圾回收
if not hasattr(self, "_background_tasks"):
self._background_tasks = set()
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
return resp return resp
@@ -218,18 +192,14 @@ class BeatmapsetFetcher(BaseFetcher):
# 使用当前 cursor 请求下一页 # 使用当前 cursor 请求下一页
next_query = query.model_copy() next_query = query.model_copy()
logger.opt(colors=True).debug( logger.opt(colors=True).debug(f"<cyan>[BeatmapsetFetcher]</cyan> Prefetching page {page + 1}")
f"<cyan>[BeatmapsetFetcher]</cyan> Prefetching page {page + 1}"
)
# 生成下一页的缓存键 # 生成下一页的缓存键
next_cache_key = self._generate_cache_key(next_query, cursor) next_cache_key = self._generate_cache_key(next_query, cursor)
# 检查是否已经缓存 # 检查是否已经缓存
if await redis_client.exists(next_cache_key): if await redis_client.exists(next_cache_key):
logger.opt(colors=True).debug( logger.opt(colors=True).debug(f"<cyan>[BeatmapsetFetcher]</cyan> Page {page + 1} already cached")
f"<cyan>[BeatmapsetFetcher]</cyan> Page {page + 1} already cached"
)
# 尝试从缓存获取cursor继续预取 # 尝试从缓存获取cursor继续预取
cached_data = await redis_client.get(next_cache_key) cached_data = await redis_client.get(next_cache_key)
if cached_data: if cached_data:
@@ -247,9 +217,7 @@ class BeatmapsetFetcher(BaseFetcher):
await asyncio.sleep(1.5) # 1.5秒延迟 await asyncio.sleep(1.5) # 1.5秒延迟
# 请求下一页数据 # 请求下一页数据
params = next_query.model_dump( params = next_query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True)
exclude_none=True, exclude_unset=True, exclude_defaults=True
)
for k, v in cursor.items(): for k, v in cursor.items():
params[f"cursor[{k}]"] = v params[f"cursor[{k}]"] = v
@@ -277,22 +245,18 @@ class BeatmapsetFetcher(BaseFetcher):
) )
logger.opt(colors=True).debug( logger.opt(colors=True).debug(
f"<cyan>[BeatmapsetFetcher]</cyan> Prefetched page {page + 1} " f"<cyan>[BeatmapsetFetcher]</cyan> Prefetched page {page + 1} (TTL: {prefetch_ttl}s)"
f"(TTL: {prefetch_ttl}s)"
) )
except Exception as e: except Exception as e:
logger.opt(colors=True).warning( logger.opt(colors=True).warning(f"<yellow>[BeatmapsetFetcher]</yellow> Prefetch failed: {e}")
f"<yellow>[BeatmapsetFetcher]</yellow> Prefetch failed: {e}"
)
async def warmup_homepage_cache(self, redis_client: redis.Redis) -> None: async def warmup_homepage_cache(self, redis_client: redis.Redis) -> None:
"""预热主页缓存""" """预热主页缓存"""
homepage_queries = self._get_homepage_queries() homepage_queries = self._get_homepage_queries()
logger.opt(colors=True).info( logger.opt(colors=True).info(
f"<magenta>[BeatmapsetFetcher]</magenta> Starting homepage cache warmup " f"<magenta>[BeatmapsetFetcher]</magenta> Starting homepage cache warmup ({len(homepage_queries)} queries)"
f"({len(homepage_queries)} queries)"
) )
for i, (query, cursor) in enumerate(homepage_queries): for i, (query, cursor) in enumerate(homepage_queries):
@@ -306,15 +270,12 @@ class BeatmapsetFetcher(BaseFetcher):
# 检查是否已经缓存 # 检查是否已经缓存
if await redis_client.exists(cache_key): if await redis_client.exists(cache_key):
logger.opt(colors=True).debug( logger.opt(colors=True).debug(
f"<magenta>[BeatmapsetFetcher]</magenta> " f"<magenta>[BeatmapsetFetcher]</magenta> Query {query.sort} already cached"
f"Query {query.sort} already cached"
) )
continue continue
# 请求并缓存 # 请求并缓存
params = query.model_dump( params = query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True)
exclude_none=True, exclude_unset=True, exclude_defaults=True
)
api_response = await self.request_api( api_response = await self.request_api(
"https://osu.ppy.sh/api/v2/beatmapsets/search", "https://osu.ppy.sh/api/v2/beatmapsets/search",
@@ -334,17 +295,13 @@ class BeatmapsetFetcher(BaseFetcher):
) )
logger.opt(colors=True).info( logger.opt(colors=True).info(
f"<magenta>[BeatmapsetFetcher]</magenta> " f"<magenta>[BeatmapsetFetcher]</magenta> Warmed up cache for {query.sort} (TTL: {cache_ttl}s)"
f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)"
) )
if api_response.get("cursor"): if api_response.get("cursor"):
await self.prefetch_next_pages( await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=2)
query, api_response["cursor"], redis_client, pages=2
)
except Exception as e: except Exception as e:
logger.opt(colors=True).error( logger.opt(colors=True).error(
f"<red>[BeatmapsetFetcher]</red> " f"<red>[BeatmapsetFetcher]</red> Failed to warmup cache for {query.sort}: {e}"
f"Failed to warmup cache for {query.sort}: {e}"
) )

View File

@@ -55,14 +55,9 @@ class GeoIPHelper:
- 临时目录退出后自动清理 - 临时目录退出后自动清理
""" """
if not self.license_key: if not self.license_key:
raise ValueError( raise ValueError("缺少 MaxMind License Key请传入或设置环境变量 MAXMIND_LICENSE_KEY")
"缺少 MaxMind License Key请传入或设置环境变量 MAXMIND_LICENSE_KEY"
)
url = ( url = f"{BASE_URL}?edition_id={edition_id}&license_key={self.license_key}&suffix=tar.gz"
f"{BASE_URL}?edition_id={edition_id}&"
f"license_key={self.license_key}&suffix=tar.gz"
)
with httpx.Client(follow_redirects=True, timeout=self.timeout) as client: with httpx.Client(follow_redirects=True, timeout=self.timeout) as client:
with client.stream("GET", url) as resp: with client.stream("GET", url) as resp:

View File

@@ -48,8 +48,7 @@ class RateLimiter:
if wait_time > 0: if wait_time > 0:
logger.opt(colors=True).info( logger.opt(colors=True).info(
f"<yellow>[RateLimiter]</yellow> Rate limit reached, " f"<yellow>[RateLimiter]</yellow> Rate limit reached, waiting {wait_time:.2f}s"
f"waiting {wait_time:.2f}s"
) )
await asyncio.sleep(wait_time) await asyncio.sleep(wait_time)
current_time = time.time() current_time = time.time()
@@ -107,11 +106,7 @@ class RateLimiter:
"max_requests_per_minute": self.max_requests_per_minute, "max_requests_per_minute": self.max_requests_per_minute,
"burst_requests": len(self.burst_times), "burst_requests": len(self.burst_times),
"burst_limit": self.burst_limit, "burst_limit": self.burst_limit,
"next_reset_in_seconds": ( "next_reset_in_seconds": (60.0 - (current_time - self.request_times[0]) if self.request_times else 0.0),
60.0 - (current_time - self.request_times[0])
if self.request_times
else 0.0
),
} }

View File

@@ -46,14 +46,10 @@ class InterceptHandler(logging.Handler):
color = True color = True
else: else:
color = False color = False
logger.opt(depth=depth, exception=record.exc_info, colors=color).log( logger.opt(depth=depth, exception=record.exc_info, colors=color).log(level, message)
level, message
)
def _format_uvicorn_error_log(self, message: str) -> str: def _format_uvicorn_error_log(self, message: str) -> str:
websocket_pattern = ( websocket_pattern = r'(\d+\.\d+\.\d+\.\d+:\d+)\s*-\s*"WebSocket\s+([^"]+)"\s+([\w\[\]]+)'
r'(\d+\.\d+\.\d+\.\d+:\d+)\s*-\s*"WebSocket\s+([^"]+)"\s+([\w\[\]]+)'
)
websocket_match = re.search(websocket_pattern, message) websocket_match = re.search(websocket_pattern, message)
if websocket_match: if websocket_match:
@@ -64,14 +60,8 @@ class InterceptHandler(logging.Handler):
"[accepted]": "<green>[accepted]</green>", "[accepted]": "<green>[accepted]</green>",
"403": "<red>403 [rejected]</red>", "403": "<red>403 [rejected]</red>",
} }
colored_status = status_colors.get( colored_status = status_colors.get(status.lower(), f"<white>{status}</white>")
status.lower(), f"<white>{status}</white>" return f'{colored_ip} - "<bold><magenta>WebSocket</magenta> {path}</bold>" {colored_status}'
)
return (
f'{colored_ip} - "<bold><magenta>WebSocket</magenta> '
f'{path}</bold>" '
f"{colored_status}"
)
else: else:
return message return message
@@ -121,9 +111,7 @@ logger.remove()
logger.add( logger.add(
stdout, stdout,
colorize=True, colorize=True,
format=( format=("<green>{time:YYYY-MM-DD HH:mm:ss}</green> [<level>{level}</level>] | {message}"),
"<green>{time:YYYY-MM-DD HH:mm:ss}</green> [<level>{level}</level>] | {message}"
),
level=settings.log_level, level=settings.log_level,
diagnose=settings.debug, diagnose=settings.debug,
) )

View File

@@ -19,17 +19,11 @@ class Achievement(NamedTuple):
@property @property
def url(self) -> str: def url(self) -> str:
return ( return self.medal_url or f"https://assets.ppy.sh/medals/client/{self.assets_id}.png"
self.medal_url
or f"https://assets.ppy.sh/medals/client/{self.assets_id}.png"
)
@property @property
def url2x(self) -> str: def url2x(self) -> str:
return ( return self.medal_url2x or f"https://assets.ppy.sh/medals/client/{self.assets_id}@2x.png"
self.medal_url2x
or f"https://assets.ppy.sh/medals/client/{self.assets_id}@2x.png"
)
MedalProcessor = Callable[[AsyncSession, "Score", "Beatmap"], Awaitable[bool]] MedalProcessor = Callable[[AsyncSession, "Score", "Beatmap"], Awaitable[bool]]

View File

@@ -11,7 +11,8 @@ class APIMe(UserResp):
""" """
/me 端点的响应模型 /me 端点的响应模型
对应 osu! 的 APIMe 类型,继承 APIUser(UserResp) 并包含 session_verified 字段 对应 osu! 的 APIMe 类型,继承 APIUser(UserResp) 并包含 session_verified 字段
session_verified 字段已经在 UserResp 中定义,这里不需要重复定义 session_verified 字段已经在 UserResp 中定义,这里不需要重复定义
""" """
pass pass

View File

@@ -95,11 +95,7 @@ class SearchQueryModel(BaseModel):
q: str = Field("", description="搜索关键词") q: str = Field("", description="搜索关键词")
c: Annotated[ c: Annotated[
list[ list[Literal["recommended", "converts", "follows", "spotlights", "featured_artists"]],
Literal[
"recommended", "converts", "follows", "spotlights", "featured_artists"
]
],
BeforeValidator(_parse_list), BeforeValidator(_parse_list),
PlainSerializer(lambda x: ".".join(x)), PlainSerializer(lambda x: ".".join(x)),
] = Field( ] = Field(
@@ -188,12 +184,10 @@ class SearchQueryModel(BaseModel):
list[Literal["video", "storyboard"]], list[Literal["video", "storyboard"]],
BeforeValidator(_parse_list), BeforeValidator(_parse_list),
PlainSerializer(lambda x: ".".join(x)), PlainSerializer(lambda x: ".".join(x)),
] = Field( ] = Field(default_factory=list, description=("其他video 有视频 / storyboard 有故事板"))
default_factory=list, description=("其他video 有视频 / storyboard 有故事板") r: Annotated[list[Rank], BeforeValidator(_parse_list), PlainSerializer(lambda x: ".".join(x))] = Field(
default_factory=list, description="成绩"
) )
r: Annotated[
list[Rank], BeforeValidator(_parse_list), PlainSerializer(lambda x: ".".join(x))
] = Field(default_factory=list, description="成绩")
played: bool = Field( played: bool = Field(
default=False, default=False,
description="玩过", description="玩过",

View File

@@ -9,12 +9,13 @@ from pydantic import BaseModel
class ExtendedTokenResponse(BaseModel): class ExtendedTokenResponse(BaseModel):
"""扩展的令牌响应,支持二次验证状态""" """扩展的令牌响应,支持二次验证状态"""
access_token: str | None = None access_token: str | None = None
token_type: str = "Bearer" token_type: str = "Bearer"
expires_in: int | None = None expires_in: int | None = None
refresh_token: str | None = None refresh_token: str | None = None
scope: str | None = None scope: str | None = None
# 二次验证相关字段 # 二次验证相关字段
requires_second_factor: bool = False requires_second_factor: bool = False
verification_message: str | None = None verification_message: str | None = None
@@ -23,6 +24,7 @@ class ExtendedTokenResponse(BaseModel):
class SessionState(BaseModel): class SessionState(BaseModel):
"""会话状态""" """会话状态"""
user_id: int user_id: int
username: str username: str
email: str email: str

View File

@@ -145,9 +145,7 @@ class MultiplayerPlaylistItemStats(BaseModel):
class MultiplayerRoomStats(BaseModel): class MultiplayerRoomStats(BaseModel):
room_id: int room_id: int
playlist_item_stats: dict[int, MultiplayerPlaylistItemStats] = Field( playlist_item_stats: dict[int, MultiplayerPlaylistItemStats] = Field(default_factory=dict)
default_factory=dict
)
class MultiplayerRoomScoreSetEvent(BaseModel): class MultiplayerRoomScoreSetEvent(BaseModel):

View File

@@ -174,11 +174,7 @@ def mods_can_get_pp(ruleset_id: int, mods: list[APIMod]) -> bool:
return True return True
ranked_mods = RANKED_MODS[ruleset_id] ranked_mods = RANKED_MODS[ruleset_id]
for mod in mods: for mod in mods:
if ( if app_settings.enable_rx and mod["acronym"] == "RX" and ruleset_id in {0, 1, 2}:
app_settings.enable_rx
and mod["acronym"] == "RX"
and ruleset_id in {0, 1, 2}
):
continue continue
if app_settings.enable_ap and mod["acronym"] == "AP" and ruleset_id == 0: if app_settings.enable_ap and mod["acronym"] == "AP" and ruleset_id == 0:
continue continue
@@ -251,10 +247,7 @@ def get_available_mods(ruleset_id: int, required_mods: list[APIMod]) -> list[API
if mod_acronym in incompatible_mods: if mod_acronym in incompatible_mods:
continue continue
if any( if any(required_acronym in mod_data["IncompatibleMods"] for required_acronym in required_mod_acronyms):
required_acronym in mod_data["IncompatibleMods"]
for required_acronym in required_mod_acronyms
):
continue continue
if mod_data.get("UserPlayable", False): if mod_data.get("UserPlayable", False):

View File

@@ -121,32 +121,21 @@ class PlaylistItem(BaseModel):
star_rating: float star_rating: float
freestyle: bool freestyle: bool
def _validate_mod_for_ruleset( def _validate_mod_for_ruleset(self, mod: APIMod, ruleset_key: int, context: str = "mod") -> None:
self, mod: APIMod, ruleset_key: int, context: str = "mod"
) -> None:
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key) typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
# Check if mod is valid for ruleset # Check if mod is valid for ruleset
if ( if typed_ruleset_key not in API_MODS or mod["acronym"] not in API_MODS[typed_ruleset_key]:
typed_ruleset_key not in API_MODS raise InvokeException(f"{context} {mod['acronym']} is invalid for this ruleset")
or mod["acronym"] not in API_MODS[typed_ruleset_key]
):
raise InvokeException(
f"{context} {mod['acronym']} is invalid for this ruleset"
)
mod_settings = API_MODS[typed_ruleset_key][mod["acronym"]] mod_settings = API_MODS[typed_ruleset_key][mod["acronym"]]
# Check if mod is unplayable in multiplayer # Check if mod is unplayable in multiplayer
if mod_settings.get("UserPlayable", True) is False: if mod_settings.get("UserPlayable", True) is False:
raise InvokeException( raise InvokeException(f"{context} {mod['acronym']} is not playable by users")
f"{context} {mod['acronym']} is not playable by users"
)
if mod_settings.get("ValidForMultiplayer", True) is False: if mod_settings.get("ValidForMultiplayer", True) is False:
raise InvokeException( raise InvokeException(f"{context} {mod['acronym']} is not valid for multiplayer")
f"{context} {mod['acronym']} is not valid for multiplayer"
)
def _check_mod_compatibility(self, mods: list[APIMod], ruleset_key: int) -> None: def _check_mod_compatibility(self, mods: list[APIMod], ruleset_key: int) -> None:
from typing import Literal, cast from typing import Literal, cast
@@ -159,10 +148,7 @@ class PlaylistItem(BaseModel):
incompatible = set(mod1_settings.get("IncompatibleMods", [])) incompatible = set(mod1_settings.get("IncompatibleMods", []))
for mod2 in mods[i + 1 :]: for mod2 in mods[i + 1 :]:
if mod2["acronym"] in incompatible: if mod2["acronym"] in incompatible:
raise InvokeException( raise InvokeException(f"Mods {mod1['acronym']} and {mod2['acronym']} are incompatible")
f"Mods {mod1['acronym']} and "
f"{mod2['acronym']} are incompatible"
)
def _check_required_allowed_compatibility(self, ruleset_key: int) -> None: def _check_required_allowed_compatibility(self, ruleset_key: int) -> None:
from typing import Literal, cast from typing import Literal, cast
@@ -178,10 +164,7 @@ class PlaylistItem(BaseModel):
conflicting_allowed = allowed_acronyms & incompatible conflicting_allowed = allowed_acronyms & incompatible
if conflicting_allowed: if conflicting_allowed:
conflict_list = ", ".join(conflicting_allowed) conflict_list = ", ".join(conflicting_allowed)
raise InvokeException( raise InvokeException(f"Required mod {req_acronym} conflicts with allowed mods: {conflict_list}")
f"Required mod {req_acronym} conflicts with "
f"allowed mods: {conflict_list}"
)
def validate_playlist_item_mods(self) -> None: def validate_playlist_item_mods(self) -> None:
ruleset_key = cast(Literal[0, 1, 2, 3], self.ruleset_id) ruleset_key = cast(Literal[0, 1, 2, 3], self.ruleset_id)
@@ -219,10 +202,7 @@ class PlaylistItem(BaseModel):
# Check if mods are valid for the ruleset # Check if mods are valid for the ruleset
for mod in proposed_mods: for mod in proposed_mods:
if ( if ruleset_key not in API_MODS or mod["acronym"] not in API_MODS[ruleset_key]:
ruleset_key not in API_MODS
or mod["acronym"] not in API_MODS[ruleset_key]
):
all_proposed_valid = False all_proposed_valid = False
continue continue
valid_mods.append(mod) valid_mods.append(mod)
@@ -252,9 +232,7 @@ class PlaylistItem(BaseModel):
# Check compatibility with required mods # Check compatibility with required mods
required_mod_acronyms = {mod["acronym"] for mod in self.required_mods} required_mod_acronyms = {mod["acronym"] for mod in self.required_mods}
all_mod_acronyms = { all_mod_acronyms = {mod["acronym"] for mod in final_valid_mods} | required_mod_acronyms
mod["acronym"] for mod in final_valid_mods
} | required_mod_acronyms
# Check for incompatibility between required and user mods # Check for incompatibility between required and user mods
filtered_valid_mods = [] filtered_valid_mods = []
@@ -288,9 +266,7 @@ class PlaylistItem(BaseModel):
class _MultiplayerCountdown(SignalRUnionMessage): class _MultiplayerCountdown(SignalRUnionMessage):
id: int = 0 id: int = 0
time_remaining: timedelta time_remaining: timedelta
is_exclusive: Annotated[ is_exclusive: Annotated[bool, Field(default=True), SignalRMeta(member_ignore=True)] = True
bool, Field(default=True), SignalRMeta(member_ignore=True)
] = True
class MatchStartCountdown(_MultiplayerCountdown): class MatchStartCountdown(_MultiplayerCountdown):
@@ -305,17 +281,13 @@ class ServerShuttingDownCountdown(_MultiplayerCountdown):
union_type: ClassVar[Literal[2]] = 2 union_type: ClassVar[Literal[2]] = 2
MultiplayerCountdown = ( MultiplayerCountdown = MatchStartCountdown | ForceGameplayStartCountdown | ServerShuttingDownCountdown
MatchStartCountdown | ForceGameplayStartCountdown | ServerShuttingDownCountdown
)
class MultiplayerRoomUser(BaseModel): class MultiplayerRoomUser(BaseModel):
user_id: int user_id: int
state: MultiplayerUserState = MultiplayerUserState.IDLE state: MultiplayerUserState = MultiplayerUserState.IDLE
availability: BeatmapAvailability = BeatmapAvailability( availability: BeatmapAvailability = BeatmapAvailability(state=DownloadState.UNKNOWN, download_progress=None)
state=DownloadState.UNKNOWN, download_progress=None
)
mods: list[APIMod] = Field(default_factory=list) mods: list[APIMod] = Field(default_factory=list)
match_state: MatchUserState | None = None match_state: MatchUserState | None = None
ruleset_id: int | None = None # freestyle ruleset_id: int | None = None # freestyle
@@ -358,9 +330,7 @@ class MultiplayerRoom(BaseModel):
expired=item.expired, expired=item.expired,
playlist_order=item.playlist_order, playlist_order=item.playlist_order,
played_at=item.played_at, played_at=item.played_at,
star_rating=item.beatmap.difficulty_rating star_rating=item.beatmap.difficulty_rating if item.beatmap is not None else 0.0,
if item.beatmap is not None
else 0.0,
freestyle=item.freestyle, freestyle=item.freestyle,
) )
) )
@@ -425,9 +395,7 @@ class MultiplayerQueue:
user_item_groups[item.owner_id] = [] user_item_groups[item.owner_id] = []
user_item_groups[item.owner_id].append(item) user_item_groups[item.owner_id].append(item)
max_items = max( max_items = max((len(items) for items in user_item_groups.values()), default=0)
(len(items) for items in user_item_groups.values()), default=0
)
for i in range(max_items): for i in range(max_items):
current_set = [] current_set = []
@@ -436,20 +404,13 @@ class MultiplayerQueue:
current_set.append(items[i]) current_set.append(items[i])
if is_first_set: if is_first_set:
current_set.sort( current_set.sort(key=lambda item: (item.playlist_order, item.id))
key=lambda item: (item.playlist_order, item.id)
)
ordered_active_items.extend(current_set) ordered_active_items.extend(current_set)
first_set_order_by_user_id = { first_set_order_by_user_id = {
item.owner_id: idx item.owner_id: idx for idx, item in enumerate(ordered_active_items)
for idx, item in enumerate(ordered_active_items)
} }
else: else:
current_set.sort( current_set.sort(key=lambda item: first_set_order_by_user_id.get(item.owner_id, 0))
key=lambda item: first_set_order_by_user_id.get(
item.owner_id, 0
)
)
ordered_active_items.extend(current_set) ordered_active_items.extend(current_set)
is_first_set = False is_first_set = False
@@ -464,9 +425,7 @@ class MultiplayerQueue:
continue continue
item.playlist_order = idx item.playlist_order = idx
await Playlist.update(item, self.room.room_id, session) await Playlist.update(item, self.room.room_id, session)
await self.hub.playlist_changed( await self.hub.playlist_changed(self.server_room, item, beatmap_changed=False)
self.server_room, item, beatmap_changed=False
)
async def update_current_item(self): async def update_current_item(self):
upcoming_items = self.upcoming_items upcoming_items = self.upcoming_items
@@ -494,16 +453,7 @@ class MultiplayerQueue:
raise InvokeException("You are not the host") raise InvokeException("You are not the host")
limit = HOST_LIMIT if is_host else PER_USER_LIMIT limit = HOST_LIMIT if is_host else PER_USER_LIMIT
if ( if len([True for u in self.room.playlist if u.owner_id == user.user_id and not u.expired]) >= limit:
len(
[
True
for u in self.room.playlist
if u.owner_id == user.user_id and not u.expired
]
)
>= limit
):
raise InvokeException(f"You can only have {limit} items in the queue") raise InvokeException(f"You can only have {limit} items in the queue")
if item.freestyle and len(item.allowed_mods) > 0: if item.freestyle and len(item.allowed_mods) > 0:
@@ -512,9 +462,7 @@ class MultiplayerQueue:
async with with_db() as session: async with with_db() as session:
fetcher = await get_fetcher() fetcher = await get_fetcher()
async with session: async with session:
beatmap = await Beatmap.get_or_fetch( beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=item.beatmap_id)
session, fetcher, bid=item.beatmap_id
)
if beatmap is None: if beatmap is None:
raise InvokeException("Beatmap not found") raise InvokeException("Beatmap not found")
if item.beatmap_checksum != beatmap.checksum: if item.beatmap_checksum != beatmap.checksum:
@@ -538,29 +486,19 @@ class MultiplayerQueue:
async with with_db() as session: async with with_db() as session:
fetcher = await get_fetcher() fetcher = await get_fetcher()
async with session: async with session:
beatmap = await Beatmap.get_or_fetch( beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=item.beatmap_id)
session, fetcher, bid=item.beatmap_id
)
if item.beatmap_checksum != beatmap.checksum: if item.beatmap_checksum != beatmap.checksum:
raise InvokeException("Checksum mismatch") raise InvokeException("Checksum mismatch")
existing_item = next( existing_item = next((i for i in self.room.playlist if i.id == item.id), None)
(i for i in self.room.playlist if i.id == item.id), None
)
if existing_item is None: if existing_item is None:
raise InvokeException( raise InvokeException("Attempted to change an item that doesn't exist")
"Attempted to change an item that doesn't exist"
)
if existing_item.owner_id != user.user_id and self.room.host != user: if existing_item.owner_id != user.user_id and self.room.host != user:
raise InvokeException( raise InvokeException("Attempted to change an item which is not owned by the user")
"Attempted to change an item which is not owned by the user"
)
if existing_item.expired: if existing_item.expired:
raise InvokeException( raise InvokeException("Attempted to change an item which has already been played")
"Attempted to change an item which has already been played"
)
item.validate_playlist_item_mods() item.validate_playlist_item_mods()
item.owner_id = user.user_id item.owner_id = user.user_id
@@ -578,8 +516,7 @@ class MultiplayerQueue:
await self.hub.playlist_changed( await self.hub.playlist_changed(
self.server_room, self.server_room,
item, item,
beatmap_changed=item.beatmap_checksum beatmap_changed=item.beatmap_checksum != existing_item.beatmap_checksum,
!= existing_item.beatmap_checksum,
) )
async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser): async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser):
@@ -600,14 +537,10 @@ class MultiplayerQueue:
raise InvokeException("The only item in the room cannot be removed") raise InvokeException("The only item in the room cannot be removed")
if item.owner_id != user.user_id and self.room.host != user: if item.owner_id != user.user_id and self.room.host != user:
raise InvokeException( raise InvokeException("Attempted to remove an item which is not owned by the user")
"Attempted to remove an item which is not owned by the user"
)
if item.expired: if item.expired:
raise InvokeException( raise InvokeException("Attempted to remove an item which has already been played")
"Attempted to remove an item which has already been played"
)
async with with_db() as session: async with with_db() as session:
await Playlist.delete_item(item.id, self.room.room_id, session) await Playlist.delete_item(item.id, self.room.room_id, session)
@@ -668,9 +601,7 @@ class CountdownInfo:
def __init__(self, countdown: MultiplayerCountdown): def __init__(self, countdown: MultiplayerCountdown):
self.countdown = countdown self.countdown = countdown
self.duration = ( self.duration = (
countdown.time_remaining countdown.time_remaining if countdown.time_remaining > timedelta(seconds=0) else timedelta(seconds=0)
if countdown.time_remaining > timedelta(seconds=0)
else timedelta(seconds=0)
) )
@@ -704,9 +635,7 @@ class MatchTypeHandler(ABC):
async def handle_join(self, user: MultiplayerRoomUser): ... async def handle_join(self, user: MultiplayerRoomUser): ...
@abstractmethod @abstractmethod
async def handle_request( async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest): ...
self, user: MultiplayerRoomUser, request: MatchRequest
): ...
@abstractmethod @abstractmethod
async def handle_leave(self, user: MultiplayerRoomUser): ... async def handle_leave(self, user: MultiplayerRoomUser): ...
@@ -723,9 +652,7 @@ class HeadToHeadHandler(MatchTypeHandler):
await self.hub.change_user_match_state(self.room, user) await self.hub.change_user_match_state(self.room, user)
@override @override
async def handle_request( async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest): ...
self, user: MultiplayerRoomUser, request: MatchRequest
): ...
@override @override
async def handle_leave(self, user: MultiplayerRoomUser): ... async def handle_leave(self, user: MultiplayerRoomUser): ...
@@ -762,9 +689,7 @@ class TeamVersusHandler(MatchTypeHandler):
team_counts = defaultdict(int) team_counts = defaultdict(int)
for user in self.room.room.users: for user in self.room.room.users:
if user.match_state is not None and isinstance( if user.match_state is not None and isinstance(user.match_state, TeamVersusUserState):
user.match_state, TeamVersusUserState
):
team_counts[user.match_state.team_id] += 1 team_counts[user.match_state.team_id] += 1
if team_counts: if team_counts:
@@ -798,9 +723,7 @@ class TeamVersusHandler(MatchTypeHandler):
def get_details(self) -> MatchStartedEventDetail: def get_details(self) -> MatchStartedEventDetail:
teams: dict[int, Literal["blue", "red"]] = {} teams: dict[int, Literal["blue", "red"]] = {}
for user in self.room.room.users: for user in self.room.room.users:
if user.match_state is not None and isinstance( if user.match_state is not None and isinstance(user.match_state, TeamVersusUserState):
user.match_state, TeamVersusUserState
):
teams[user.user_id] = "blue" if user.match_state.team_id == 1 else "red" teams[user.user_id] = "blue" if user.match_state.team_id == 1 else "red"
detail = MatchStartedEventDetail(room_type="team_versus", team=teams) detail = MatchStartedEventDetail(room_type="team_versus", team=teams)
return detail return detail
@@ -843,9 +766,7 @@ class ServerMultiplayerRoom:
self._tracked_countdown = {} self._tracked_countdown = {}
async def set_handler(self): async def set_handler(self):
self.match_type_handler = MATCH_TYPE_HANDLERS[self.room.settings.match_type]( self.match_type_handler = MATCH_TYPE_HANDLERS[self.room.settings.match_type](self)
self
)
for i in self.room.users: for i in self.room.users:
await self.match_type_handler.handle_join(i) await self.match_type_handler.handle_join(i)
@@ -871,9 +792,7 @@ class ServerMultiplayerRoom:
info = CountdownInfo(countdown) info = CountdownInfo(countdown)
self.room.active_countdowns.append(info.countdown) self.room.active_countdowns.append(info.countdown)
self._tracked_countdown[countdown.id] = info self._tracked_countdown[countdown.id] = info
await self.hub.send_match_event( await self.hub.send_match_event(self, CountdownStartedEvent(countdown=info.countdown))
self, CountdownStartedEvent(countdown=info.countdown)
)
info.task = asyncio.create_task(_countdown_task(self)) info.task = asyncio.create_task(_countdown_task(self))
async def stop_countdown(self, countdown: MultiplayerCountdown): async def stop_countdown(self, countdown: MultiplayerCountdown):

View File

@@ -53,7 +53,7 @@ class NotificationName(str, Enum):
NotificationName.BEATMAP_OWNER_CHANGE: "beatmap_owner_change", NotificationName.BEATMAP_OWNER_CHANGE: "beatmap_owner_change",
NotificationName.BEATMAPSET_DISCUSSION_LOCK: "beatmapset_discussion", NotificationName.BEATMAPSET_DISCUSSION_LOCK: "beatmapset_discussion",
NotificationName.BEATMAPSET_DISCUSSION_POST_NEW: "beatmapset_discussion", NotificationName.BEATMAPSET_DISCUSSION_POST_NEW: "beatmapset_discussion",
NotificationName.BEATMAPSET_DISCUSSION_QUALIFIED_PROBLEM: "beatmapset_problem", # noqa: E501 NotificationName.BEATMAPSET_DISCUSSION_QUALIFIED_PROBLEM: "beatmapset_problem",
NotificationName.BEATMAPSET_DISCUSSION_REVIEW_NEW: "beatmapset_discussion", NotificationName.BEATMAPSET_DISCUSSION_REVIEW_NEW: "beatmapset_discussion",
NotificationName.BEATMAPSET_DISCUSSION_UNLOCK: "beatmapset_discussion", NotificationName.BEATMAPSET_DISCUSSION_UNLOCK: "beatmapset_discussion",
NotificationName.BEATMAPSET_DISQUALIFY: "beatmapset_state", NotificationName.BEATMAPSET_DISQUALIFY: "beatmapset_state",
@@ -164,17 +164,11 @@ class ChannelMessageTeam(ChannelMessageBase):
from app.database import TeamMember from app.database import TeamMember
user_team_id = ( user_team_id = (
await session.exec( await session.exec(select(TeamMember.team_id).where(TeamMember.user_id == self._user.id))
select(TeamMember.team_id).where(TeamMember.user_id == self._user.id)
)
).first() ).first()
if not user_team_id: if not user_team_id:
return [] return []
user_ids = ( user_ids = (await session.exec(select(TeamMember.user_id).where(TeamMember.team_id == user_team_id))).all()
await session.exec(
select(TeamMember.user_id).where(TeamMember.team_id == user_team_id)
)
).all()
return list(user_ids) return list(user_ids)

View File

@@ -197,9 +197,7 @@ class SoloScoreSubmissionInfo(BaseModel):
# check incompatible mods # check incompatible mods
for mod in mods: for mod in mods:
if mod["acronym"] in incompatible_mods: if mod["acronym"] in incompatible_mods:
raise ValueError( raise ValueError(f"Mod {mod['acronym']} is incompatible with other mods")
f"Mod {mod['acronym']} is incompatible with other mods"
)
setting_mods = API_MODS[info.data["ruleset_id"]].get(mod["acronym"]) setting_mods = API_MODS[info.data["ruleset_id"]].get(mod["acronym"])
if not setting_mods: if not setting_mods:
raise ValueError(f"Invalid mod: {mod['acronym']}") raise ValueError(f"Invalid mod: {mod['acronym']}")

View File

@@ -22,9 +22,7 @@ class SignalRUnionMessage(BaseModel):
class Transport(BaseModel): class Transport(BaseModel):
transport: str transport: str
transfer_formats: list[str] = Field( transfer_formats: list[str] = Field(default_factory=lambda: ["Binary", "Text"], alias="transferFormats")
default_factory=lambda: ["Binary", "Text"], alias="transferFormats"
)
class NegotiateResponse(BaseModel): class NegotiateResponse(BaseModel):

View File

@@ -89,9 +89,7 @@ class LegacyReplayFrame(BaseModel):
mouse_y: float | None = None mouse_y: float | None = None
button_state: int button_state: int
header: Annotated[ header: Annotated[FrameHeader | None, Field(default=None), SignalRMeta(member_ignore=True)]
FrameHeader | None, Field(default=None), SignalRMeta(member_ignore=True)
]
class FrameDataBundle(BaseModel): class FrameDataBundle(BaseModel):

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
import re import re
from typing import Literal, Union from typing import Literal
from app.auth import ( from app.auth import (
authenticate_user, authenticate_user,
@@ -22,19 +22,19 @@ from app.dependencies.database import Database, get_redis
from app.dependencies.geoip import get_client_ip, get_geoip_helper from app.dependencies.geoip import get_client_ip, get_geoip_helper
from app.helpers.geoip_helper import GeoIPHelper from app.helpers.geoip_helper import GeoIPHelper
from app.log import logger from app.log import logger
from app.models.extended_auth import ExtendedTokenResponse
from app.models.oauth import ( from app.models.oauth import (
OAuthErrorResponse, OAuthErrorResponse,
RegistrationRequestErrors, RegistrationRequestErrors,
TokenResponse, TokenResponse,
UserRegistrationErrors, UserRegistrationErrors,
) )
from app.models.extended_auth import ExtendedTokenResponse
from app.models.score import GameMode from app.models.score import GameMode
from app.service.login_log_service import LoginLogService
from app.service.email_verification_service import ( from app.service.email_verification_service import (
EmailVerificationService, EmailVerificationService,
LoginSessionService LoginSessionService,
) )
from app.service.login_log_service import LoginLogService
from app.service.password_reset_service import password_reset_service from app.service.password_reset_service import password_reset_service
from fastapi import APIRouter, Depends, Form, Request from fastapi import APIRouter, Depends, Form, Request
@@ -44,13 +44,9 @@ from sqlalchemy import text
from sqlmodel import select from sqlmodel import select
def create_oauth_error_response( def create_oauth_error_response(error: str, description: str, hint: str, status_code: int = 400):
error: str, description: str, hint: str, status_code: int = 400
):
"""创建标准的 OAuth 错误响应""" """创建标准的 OAuth 错误响应"""
error_data = OAuthErrorResponse( error_data = OAuthErrorResponse(error=error, error_description=description, hint=hint, message=description)
error=error, error_description=description, hint=hint, message=description
)
return JSONResponse(status_code=status_code, content=error_data.model_dump()) return JSONResponse(status_code=status_code, content=error_data.model_dump())
@@ -123,9 +119,7 @@ async def register_user(
) )
) )
return JSONResponse( return JSONResponse(status_code=422, content={"form_error": errors.model_dump()})
status_code=422, content={"form_error": errors.model_dump()}
)
try: try:
# 获取客户端 IP 并查询地理位置 # 获取客户端 IP 并查询地理位置
@@ -137,10 +131,7 @@ async def register_user(
geo_info = geoip.lookup(client_ip) geo_info = geoip.lookup(client_ip)
if geo_info and geo_info.get("country_iso"): if geo_info and geo_info.get("country_iso"):
country_code = geo_info["country_iso"] country_code = geo_info["country_iso"]
logger.info( logger.info(f"User {user_username} registering from {client_ip}, country: {country_code}")
f"User {user_username} registering from "
f"{client_ip}, country: {country_code}"
)
else: else:
logger.warning(f"Could not determine country for IP {client_ip}") logger.warning(f"Could not determine country for IP {client_ip}")
except Exception as e: except Exception as e:
@@ -148,7 +139,7 @@ async def register_user(
# 创建新用户 # 创建新用户
# 确保 AUTO_INCREMENT 值从3开始ID=1是BanchoBotID=2预留给ppy # 确保 AUTO_INCREMENT 值从3开始ID=1是BanchoBotID=2预留给ppy
result = await db.execute( # pyright: ignore[reportDeprecated] result = await db.execute(
text( text(
"SELECT AUTO_INCREMENT FROM information_schema.TABLES " "SELECT AUTO_INCREMENT FROM information_schema.TABLES "
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'lazer_users'" "WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'lazer_users'"
@@ -173,7 +164,6 @@ async def register_user(
db.add(new_user) db.add(new_user)
await db.commit() await db.commit()
await db.refresh(new_user) await db.refresh(new_user)
assert new_user.id is not None, "New user ID should not be None"
for i in [GameMode.OSU, GameMode.TAIKO, GameMode.FRUITS, GameMode.MANIA]: for i in [GameMode.OSU, GameMode.TAIKO, GameMode.FRUITS, GameMode.MANIA]:
statistics = UserStatistics(mode=i, user_id=new_user.id) statistics = UserStatistics(mode=i, user_id=new_user.id)
db.add(statistics) db.add(statistics)
@@ -193,36 +183,30 @@ async def register_user(
logger.exception(f"Registration error for user {user_username}") logger.exception(f"Registration error for user {user_username}")
# 返回通用错误 # 返回通用错误
errors = RegistrationRequestErrors( errors = RegistrationRequestErrors(message="An error occurred while creating your account. Please try again.")
message="An error occurred while creating your account. Please try again."
)
return JSONResponse( return JSONResponse(status_code=500, content={"form_error": errors.model_dump()})
status_code=500, content={"form_error": errors.model_dump()}
)
@router.post( @router.post(
"/oauth/token", "/oauth/token",
response_model=Union[TokenResponse, ExtendedTokenResponse], response_model=TokenResponse | ExtendedTokenResponse,
name="获取访问令牌", name="获取访问令牌",
description="OAuth 令牌端点,支持密码、刷新令牌和授权码三种授权方式。", description="OAuth 令牌端点,支持密码、刷新令牌和授权码三种授权方式。",
) )
async def oauth_token( async def oauth_token(
db: Database, db: Database,
request: Request, request: Request,
grant_type: Literal[ grant_type: Literal["authorization_code", "refresh_token", "password", "client_credentials"] = Form(
"authorization_code", "refresh_token", "password", "client_credentials" ..., description="授权类型:密码/刷新令牌/授权码/客户端凭证"
] = Form(..., description="授权类型:密码/刷新令牌/授权码/客户端凭证"), ),
client_id: int = Form(..., description="客户端 ID"), client_id: int = Form(..., description="客户端 ID"),
client_secret: str = Form(..., description="客户端密钥"), client_secret: str = Form(..., description="客户端密钥"),
code: str | None = Form(None, description="授权码(仅授权码模式需要)"), code: str | None = Form(None, description="授权码(仅授权码模式需要)"),
scope: str = Form("*", description="权限范围(空格分隔,默认为 '*'"), scope: str = Form("*", description="权限范围(空格分隔,默认为 '*'"),
username: str | None = Form(None, description="用户名(仅密码模式需要)"), username: str | None = Form(None, description="用户名(仅密码模式需要)"),
password: str | None = Form(None, description="密码(仅密码模式需要)"), password: str | None = Form(None, description="密码(仅密码模式需要)"),
refresh_token: str | None = Form( refresh_token: str | None = Form(None, description="刷新令牌(仅刷新令牌模式需要)"),
None, description="刷新令牌(仅刷新令牌模式需要)"
),
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
geoip: GeoIPHelper = Depends(get_geoip_helper), geoip: GeoIPHelper = Depends(get_geoip_helper),
): ):
@@ -303,37 +287,33 @@ async def oauth_token(
await db.refresh(user) await db.refresh(user)
# 获取用户信息和客户端信息 # 获取用户信息和客户端信息
user_id = getattr(user, "id") user_id = user.id
assert user_id is not None, "User ID should not be None after authentication"
from app.dependencies.geoip import get_client_ip
ip_address = get_client_ip(request) ip_address = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "") user_agent = request.headers.get("User-Agent", "")
# 获取国家代码 # 获取国家代码
geo_info = geoip.lookup(ip_address) geo_info = geoip.lookup(ip_address)
country_code = geo_info.get("country_iso", "XX") country_code = geo_info.get("country_iso", "XX")
# 检查是否为新位置登录 # 检查是否为新位置登录
is_new_location = await LoginSessionService.check_new_location( is_new_location = await LoginSessionService.check_new_location(db, user_id, ip_address, country_code)
db, user_id, ip_address, country_code
)
# 创建登录会话记录 # 创建登录会话记录
login_session = await LoginSessionService.create_session( login_session = await LoginSessionService.create_session( # noqa: F841
db, redis, user_id, ip_address, user_agent, country_code, is_new_location db, redis, user_id, ip_address, user_agent, country_code, is_new_location
) )
# 如果是新位置登录,需要邮件验证 # 如果是新位置登录,需要邮件验证
if is_new_location and settings.enable_email_verification: if is_new_location and settings.enable_email_verification:
# 刷新用户对象以确保属性已加载 # 刷新用户对象以确保属性已加载
await db.refresh(user) await db.refresh(user)
# 发送邮件验证码 # 发送邮件验证码
verification_sent = await EmailVerificationService.send_verification_email( verification_sent = await EmailVerificationService.send_verification_email(
db, redis, user_id, user.username, user.email, ip_address, user_agent db, redis, user_id, user.username, user.email, ip_address, user_agent
) )
# 记录需要二次验证的登录尝试 # 记录需要二次验证的登录尝试
await LoginLogService.record_login( await LoginLogService.record_login(
db=db, db=db,
@@ -343,14 +323,16 @@ async def oauth_token(
login_method="password_pending_verification", login_method="password_pending_verification",
notes=f"新位置登录,需要邮件验证 - IP: {ip_address}, 国家: {country_code}", notes=f"新位置登录,需要邮件验证 - IP: {ip_address}, 国家: {country_code}",
) )
if not verification_sent: if not verification_sent:
# 邮件发送失败,记录错误 # 邮件发送失败,记录错误
logger.error(f"[Auth] Failed to send email verification code for user {user_id}") logger.error(f"[Auth] Failed to send email verification code for user {user_id}")
elif is_new_location and not settings.enable_email_verification: elif is_new_location and not settings.enable_email_verification:
# 新位置登录但邮件验证功能被禁用,直接标记会话为已验证 # 新位置登录但邮件验证功能被禁用,直接标记会话为已验证
await LoginSessionService.mark_session_verified(db, user_id) await LoginSessionService.mark_session_verified(db, user_id)
logger.debug(f"[Auth] New location login detected but email verification disabled, auto-verifying user {user_id}") logger.debug(
f"[Auth] New location login detected but email verification disabled, auto-verifying user {user_id}"
)
else: else:
# 不是新位置登录,正常登录 # 不是新位置登录,正常登录
await LoginLogService.record_login( await LoginLogService.record_login(
@@ -361,20 +343,17 @@ async def oauth_token(
login_method="password", login_method="password",
notes=f"正常登录 - IP: {ip_address}, 国家: {country_code}", notes=f"正常登录 - IP: {ip_address}, 国家: {country_code}",
) )
# 无论是否新位置登录都返回正常的token # 无论是否新位置登录都返回正常的token
# session_verified状态通过/me接口的session_verified字段来体现 # session_verified状态通过/me接口的session_verified字段来体现
# 生成令牌 # 生成令牌
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes) access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
# 获取用户ID避免触发延迟加载 # 获取用户ID避免触发延迟加载
access_token = create_access_token( access_token = create_access_token(data={"sub": str(user_id)}, expires_delta=access_token_expires)
data={"sub": str(user_id)}, expires_delta=access_token_expires
)
refresh_token_str = generate_refresh_token() refresh_token_str = generate_refresh_token()
# 存储令牌 # 存储令牌
assert user_id
await store_token( await store_token(
db, db,
user_id, user_id,
@@ -423,9 +402,7 @@ async def oauth_token(
# 生成新的访问令牌 # 生成新的访问令牌
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes) access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
access_token = create_access_token( access_token = create_access_token(data={"sub": str(token_record.user_id)}, expires_delta=access_token_expires)
data={"sub": str(token_record.user_id)}, expires_delta=access_token_expires
)
new_refresh_token = generate_refresh_token() new_refresh_token = generate_refresh_token()
# 更新令牌 # 更新令牌
@@ -489,17 +466,11 @@ async def oauth_token(
# 生成令牌 # 生成令牌
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes) access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
# 重新查询只获取ID避免触发延迟加载 user_id = user.id
id_result = await db.exec(select(User.id).where(User.username == username)) access_token = create_access_token(data={"sub": str(user_id)}, expires_delta=access_token_expires)
user_id = id_result.first()
access_token = create_access_token(
data={"sub": str(user_id)}, expires_delta=access_token_expires
)
refresh_token_str = generate_refresh_token() refresh_token_str = generate_refresh_token()
# 存储令牌 # 存储令牌
assert user_id
await store_token( await store_token(
db, db,
user_id, user_id,
@@ -539,9 +510,7 @@ async def oauth_token(
# 生成令牌 # 生成令牌
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes) access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
access_token = create_access_token( access_token = create_access_token(data={"sub": "3"}, expires_delta=access_token_expires)
data={"sub": "3"}, expires_delta=access_token_expires
)
refresh_token_str = generate_refresh_token() refresh_token_str = generate_refresh_token()
# 存储令牌 # 存储令牌
@@ -567,7 +536,7 @@ async def oauth_token(
@router.post( @router.post(
"/password-reset/request", "/password-reset/request",
name="请求密码重置", name="请求密码重置",
description="通过邮箱请求密码重置验证码" description="通过邮箱请求密码重置验证码",
) )
async def request_password_reset( async def request_password_reset(
request: Request, request: Request,
@@ -578,42 +547,26 @@ async def request_password_reset(
请求密码重置 请求密码重置
""" """
from app.dependencies.geoip import get_client_ip from app.dependencies.geoip import get_client_ip
# 获取客户端信息 # 获取客户端信息
ip_address = get_client_ip(request) ip_address = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "") user_agent = request.headers.get("User-Agent", "")
# 请求密码重置 # 请求密码重置
success, message = await password_reset_service.request_password_reset( success, message = await password_reset_service.request_password_reset(
email=email.lower().strip(), email=email.lower().strip(),
ip_address=ip_address, ip_address=ip_address,
user_agent=user_agent, user_agent=user_agent,
redis=redis redis=redis,
) )
if success: if success:
return JSONResponse( return JSONResponse(status_code=200, content={"success": True, "message": message})
status_code=200,
content={
"success": True,
"message": message
}
)
else: else:
return JSONResponse( return JSONResponse(status_code=400, content={"success": False, "error": message})
status_code=400,
content={
"success": False,
"error": message
}
)
@router.post( @router.post("/password-reset/reset", name="重置密码", description="使用验证码重置密码")
"/password-reset/reset",
name="重置密码",
description="使用验证码重置密码"
)
async def reset_password( async def reset_password(
request: Request, request: Request,
email: str = Form(..., description="邮箱地址"), email: str = Form(..., description="邮箱地址"),
@@ -625,32 +578,20 @@ async def reset_password(
重置密码 重置密码
""" """
from app.dependencies.geoip import get_client_ip from app.dependencies.geoip import get_client_ip
# 获取客户端信息 # 获取客户端信息
ip_address = get_client_ip(request) ip_address = get_client_ip(request)
# 重置密码 # 重置密码
success, message = await password_reset_service.reset_password( success, message = await password_reset_service.reset_password(
email=email.lower().strip(), email=email.lower().strip(),
reset_code=reset_code.strip(), reset_code=reset_code.strip(),
new_password=new_password, new_password=new_password,
ip_address=ip_address, ip_address=ip_address,
redis=redis redis=redis,
) )
if success: if success:
return JSONResponse( return JSONResponse(status_code=200, content={"success": True, "message": message})
status_code=200,
content={
"success": True,
"message": message
}
)
else: else:
return JSONResponse( return JSONResponse(status_code=400, content={"success": False, "error": message})
status_code=400,
content={
"success": False,
"error": message
}
)

View File

@@ -43,9 +43,9 @@ async def get_notifications(
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
): ):
if settings.server_url is not None: if settings.server_url is not None:
notification_endpoint = f"{settings.server_url}notification-server".replace( notification_endpoint = f"{settings.server_url}notification-server".replace("http://", "ws://").replace(
"http://", "ws://" "https://", "wss://"
).replace("https://", "wss://") )
else: else:
notification_endpoint = "/notification-server" notification_endpoint = "/notification-server"
query = select(UserNotification).where( query = select(UserNotification).where(
@@ -96,21 +96,15 @@ async def _get_notifications(
query = base_query.where(UserNotification.notification_id == identity.id) query = base_query.where(UserNotification.notification_id == identity.id)
if identity.object_id is not None: if identity.object_id is not None:
query = base_query.where( query = base_query.where(
col(UserNotification.notification).has( col(UserNotification.notification).has(col(Notification.object_id) == identity.object_id)
col(Notification.object_id) == identity.object_id
)
) )
if identity.object_type is not None: if identity.object_type is not None:
query = base_query.where( query = base_query.where(
col(UserNotification.notification).has( col(UserNotification.notification).has(col(Notification.object_type) == identity.object_type)
col(Notification.object_type) == identity.object_type
)
) )
if identity.category is not None: if identity.category is not None:
query = base_query.where( query = base_query.where(
col(UserNotification.notification).has( col(UserNotification.notification).has(col(Notification.category) == identity.category)
col(Notification.category) == identity.category
)
) )
result.update({n.notification_id: n for n in await session.exec(query)}) result.update({n.notification_id: n for n in await session.exec(query)})
return list(result.values()) return list(result.values())
@@ -134,7 +128,6 @@ async def mark_notifications_as_read(
for user_notification in user_notifications: for user_notification in user_notifications:
user_notification.is_read = True user_notification.is_read = True
assert current_user.id
await server.send_event( await server.send_event(
current_user.id, current_user.id,
ChatEvent( ChatEvent(

View File

@@ -91,9 +91,7 @@ class Bot:
if reply: if reply:
await self._send_reply(user, channel, reply, session) await self._send_reply(user, channel, reply, session)
async def _send_message( async def _send_message(self, channel: ChatChannel, content: str, session: AsyncSession) -> None:
self, channel: ChatChannel, content: str, session: AsyncSession
) -> None:
bot = await session.get(User, self.bot_user_id) bot = await session.get(User, self.bot_user_id)
if bot is None: if bot is None:
return return
@@ -101,7 +99,6 @@ class Bot:
if channel_id is None: if channel_id is None:
return return
assert bot.id is not None
msg = ChatMessage( msg = ChatMessage(
channel_id=channel_id, channel_id=channel_id,
content=content, content=content,
@@ -115,9 +112,7 @@ class Bot:
resp = await ChatMessageResp.from_db(msg, session, bot) resp = await ChatMessageResp.from_db(msg, session, bot)
await server.send_message_to_channel(resp) await server.send_message_to_channel(resp)
async def _ensure_pm_channel( async def _ensure_pm_channel(self, user: User, session: AsyncSession) -> ChatChannel | None:
self, user: User, session: AsyncSession
) -> ChatChannel | None:
user_id = user.id user_id = user.id
if user_id is None: if user_id is None:
return None return None
@@ -160,9 +155,7 @@ bot = Bot()
@bot.command("help") @bot.command("help")
async def _help( async def _help(user: User, args: list[str], _session: AsyncSession, channel: ChatChannel) -> str:
user: User, args: list[str], _session: AsyncSession, channel: ChatChannel
) -> str:
cmds = sorted(bot._handlers.keys()) cmds = sorted(bot._handlers.keys())
if args: if args:
target = args[0].lower() target = args[0].lower()
@@ -175,9 +168,7 @@ async def _help(
@bot.command("roll") @bot.command("roll")
def _roll( def _roll(user: User, args: list[str], _session: AsyncSession, channel: ChatChannel) -> str:
user: User, args: list[str], _session: AsyncSession, channel: ChatChannel
) -> str:
if len(args) > 0 and args[0].isdigit(): if len(args) > 0 and args[0].isdigit():
r = random.randint(1, int(args[0])) r = random.randint(1, int(args[0]))
else: else:
@@ -186,13 +177,9 @@ def _roll(
@bot.command("stats") @bot.command("stats")
async def _stats( async def _stats(user: User, args: list[str], session: AsyncSession, channel: ChatChannel) -> str:
user: User, args: list[str], session: AsyncSession, channel: ChatChannel
) -> str:
if len(args) >= 1: if len(args) >= 1:
target_user = ( target_user = (await session.exec(select(User).where(User.username == args[0]))).first()
await session.exec(select(User).where(User.username == args[0]))
).first()
if not target_user: if not target_user:
return f"User '{args[0]}' not found." return f"User '{args[0]}' not found."
else: else:
@@ -202,14 +189,8 @@ async def _stats(
if len(args) >= 2: if len(args) >= 2:
gamemode = GameMode.parse(args[1].upper()) gamemode = GameMode.parse(args[1].upper())
if gamemode is None: if gamemode is None:
subquery = ( subquery = select(func.max(Score.id)).where(Score.user_id == target_user.id).scalar_subquery()
select(func.max(Score.id)) last_score = (await session.exec(select(Score).where(Score.id == subquery))).first()
.where(Score.user_id == target_user.id)
.scalar_subquery()
)
last_score = (
await session.exec(select(Score).where(Score.id == subquery))
).first()
if last_score is not None: if last_score is not None:
gamemode = last_score.gamemode gamemode = last_score.gamemode
else: else:
@@ -295,9 +276,7 @@ async def _mp_host(
return "Usage: !mp host <username>" return "Usage: !mp host <username>"
username = args[0] username = args[0]
user_id = ( user_id = (await session.exec(select(User.id).where(User.username == username))).first()
await session.exec(select(User.id).where(User.username == username))
).first()
if not user_id: if not user_id:
return f"User '{username}' not found." return f"User '{username}' not found."
@@ -362,24 +341,18 @@ async def _mp_team(
if team is None: if team is None:
return "Invalid team colour. Use 'red' or 'blue'." return "Invalid team colour. Use 'red' or 'blue'."
user_id = ( user_id = (await session.exec(select(User.id).where(User.username == username))).first()
await session.exec(select(User.id).where(User.username == username))
).first()
if not user_id: if not user_id:
return f"User '{username}' not found." return f"User '{username}' not found."
user_client = MultiplayerHubs.get_client_by_id(str(user_id)) user_client = MultiplayerHubs.get_client_by_id(str(user_id))
if not user_client: if not user_client:
return f"User '{username}' is not in the room." return f"User '{username}' is not in the room."
if ( assert room.room.host
user_client.user_id != signalr_client.user_id if user_client.user_id != signalr_client.user_id and room.room.host.user_id != signalr_client.user_id:
and room.room.host.user_id != signalr_client.user_id
):
return "You are not allowed to change other users' teams." return "You are not allowed to change other users' teams."
try: try:
await MultiplayerHubs.SendMatchRequest( await MultiplayerHubs.SendMatchRequest(user_client, ChangeTeamRequest(team_id=team))
user_client, ChangeTeamRequest(team_id=team)
)
return "" return ""
except InvokeException as e: except InvokeException as e:
return e.message return e.message
@@ -414,9 +387,7 @@ async def _mp_kick(
return "Usage: !mp kick <username>" return "Usage: !mp kick <username>"
username = args[0] username = args[0]
user_id = ( user_id = (await session.exec(select(User.id).where(User.username == username))).first()
await session.exec(select(User.id).where(User.username == username))
).first()
if not user_id: if not user_id:
return f"User '{username}' not found." return f"User '{username}' not found."
@@ -456,10 +427,7 @@ async def _mp_map(
try: try:
beatmap = await Beatmap.get_or_fetch(session, await get_fetcher(), bid=map_id) beatmap = await Beatmap.get_or_fetch(session, await get_fetcher(), bid=map_id)
if beatmap.mode != GameMode.OSU and playmode and playmode != beatmap.mode: if beatmap.mode != GameMode.OSU and playmode and playmode != beatmap.mode:
return ( return f"Cannot convert to {playmode.value}. Original mode is {beatmap.mode.value}."
f"Cannot convert to {playmode.value}. "
f"Original mode is {beatmap.mode.value}."
)
except HTTPError: except HTTPError:
return "Beatmap not found" return "Beatmap not found"
@@ -530,9 +498,7 @@ async def _mp_mods(
if freestyle: if freestyle:
item.allowed_mods = [] item.allowed_mods = []
elif freemod: elif freemod:
item.allowed_mods = get_available_mods( item.allowed_mods = get_available_mods(current_item.ruleset_id, required_mods)
current_item.ruleset_id, required_mods
)
else: else:
item.allowed_mods = allowed_mods item.allowed_mods = allowed_mods
item.required_mods = required_mods item.required_mods = required_mods
@@ -601,14 +567,9 @@ async def _score(
include_fail: bool = False, include_fail: bool = False,
gamemode: GameMode | None = None, gamemode: GameMode | None = None,
) -> str: ) -> str:
q = ( q = select(Score).where(Score.user_id == user_id).order_by(col(Score.id).desc()).options(joinedload(Score.beatmap))
select(Score)
.where(Score.user_id == user_id)
.order_by(col(Score.id).desc())
.options(joinedload(Score.beatmap))
)
if not include_fail: if not include_fail:
q = q.where(Score.passed.is_(True)) q = q.where(col(Score.passed).is_(True))
if gamemode is not None: if gamemode is not None:
q = q.where(Score.gamemode == gamemode) q = q.where(Score.gamemode == gamemode)
@@ -619,17 +580,13 @@ async def _score(
result = f"""{score.beatmap.beatmapset.title} [{score.beatmap.version}] ({score.gamemode.name.lower()}) result = f"""{score.beatmap.beatmapset.title} [{score.beatmap.version}] ({score.gamemode.name.lower()})
Played at {score.started_at} Played at {score.started_at}
{score.pp:.2f}pp {score.accuracy:.2%} {",".join(mod_to_save(score.mods))} {score.rank.name.upper()} {score.pp:.2f}pp {score.accuracy:.2%} {",".join(mod_to_save(score.mods))} {score.rank.name.upper()}
Great: {score.n300}, Good: {score.n100}, Meh: {score.n50}, Miss: {score.nmiss}""" # noqa: E501 Great: {score.n300}, Good: {score.n100}, Meh: {score.n50}, Miss: {score.nmiss}"""
if score.gamemode == GameMode.MANIA: if score.gamemode == GameMode.MANIA:
keys = next( keys = next((mod["acronym"] for mod in score.mods if mod["acronym"].endswith("K")), None)
(mod["acronym"] for mod in score.mods if mod["acronym"].endswith("K")), None
)
if keys is None: if keys is None:
keys = f"{int(score.beatmap.cs)}K" keys = f"{int(score.beatmap.cs)}K"
p_d_g = f"{score.ngeki / score.n300:.2f}:1" if score.n300 > 0 else "inf:1" p_d_g = f"{score.ngeki / score.n300:.2f}:1" if score.n300 > 0 else "inf:1"
result += ( result += f"\nKeys: {keys}, Perfect: {score.ngeki}, Ok: {score.nkatu}, P/G: {p_d_g}"
f"\nKeys: {keys}, Perfect: {score.ngeki}, Ok: {score.nkatu}, P/G: {p_d_g}"
)
return result return result

View File

@@ -38,27 +38,18 @@ class UpdateResponse(BaseModel):
) )
async def get_update( async def get_update(
session: Database, session: Database,
history_since: int | None = Query( history_since: int | None = Query(None, description="获取自此禁言 ID 之后的禁言记录"),
None, description="获取自此禁言 ID 之后的禁言记录"
),
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"), since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
includes: list[str] = Query( includes: list[str] = Query(["presence", "silences"], alias="includes[]", description="要包含的更新类型"),
["presence", "silences"], alias="includes[]", description="要包含的更新类型"
),
current_user: User = Security(get_current_user, scopes=["chat.read"]), current_user: User = Security(get_current_user, scopes=["chat.read"]),
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
): ):
resp = UpdateResponse() resp = UpdateResponse()
if "presence" in includes: if "presence" in includes:
assert current_user.id
channel_ids = server.get_user_joined_channel(current_user.id) channel_ids = server.get_user_joined_channel(current_user.id)
for channel_id in channel_ids: for channel_id in channel_ids:
# 使用明确的查询避免延迟加载 # 使用明确的查询避免延迟加载
db_channel = ( db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
)
).first()
if db_channel: if db_channel:
# 提取必要的属性避免惰性加载 # 提取必要的属性避免惰性加载
channel_type = db_channel.type channel_type = db_channel.type
@@ -69,34 +60,20 @@ async def get_update(
session, session,
current_user, current_user,
redis, redis,
server.channels.get(channel_id, []) server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
if channel_type != ChannelType.PUBLIC
else None,
) )
) )
if "silences" in includes: if "silences" in includes:
if history_since: if history_since:
silences = ( silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.id) > history_since))).all()
await session.exec( resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
select(SilenceUser).where(col(SilenceUser.id) > history_since)
)
).all()
resp.silences.extend(
[UserSilenceResp.from_db(silence) for silence in silences]
)
elif since: elif since:
msg = await session.get(ChatMessage, since) msg = await session.get(ChatMessage, since)
if msg: if msg:
silences = ( silences = (
await session.exec( await session.exec(select(SilenceUser).where(col(SilenceUser.banned_at) > msg.timestamp))
select(SilenceUser).where(
col(SilenceUser.banned_at) > msg.timestamp
)
)
).all() ).all()
resp.silences.extend( resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
[UserSilenceResp.from_db(silence) for silence in silences]
)
return resp return resp
@@ -115,15 +92,9 @@ async def join_channel(
): ):
# 使用明确的查询避免延迟加载 # 使用明确的查询避免延迟加载
if channel.isdigit(): if channel.isdigit():
db_channel = ( db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
else: else:
db_channel = ( db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
).first()
if db_channel is None: if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found") raise HTTPException(status_code=404, detail="Channel not found")
@@ -145,15 +116,9 @@ async def leave_channel(
): ):
# 使用明确的查询避免延迟加载 # 使用明确的查询避免延迟加载
if channel.isdigit(): if channel.isdigit():
db_channel = ( db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
else: else:
db_channel = ( db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
).first()
if db_channel is None: if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found") raise HTTPException(status_code=404, detail="Channel not found")
@@ -173,27 +138,20 @@ async def get_channel_list(
current_user: User = Security(get_current_user, scopes=["chat.read"]), current_user: User = Security(get_current_user, scopes=["chat.read"]),
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
): ):
channels = ( channels = (await session.exec(select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC))).all()
await session.exec(
select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC)
)
).all()
results = [] results = []
for channel in channels: for channel in channels:
# 提取必要的属性避免惰性加载 # 提取必要的属性避免惰性加载
channel_id = channel.channel_id channel_id = channel.channel_id
channel_type = channel.type channel_type = channel.type
assert channel_id is not None
results.append( results.append(
await ChatChannelResp.from_db( await ChatChannelResp.from_db(
channel, channel,
session, session,
current_user, current_user,
redis, redis,
server.channels.get(channel_id, []) server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
if channel_type != ChannelType.PUBLIC
else None,
) )
) )
return results return results
@@ -219,15 +177,9 @@ async def get_channel(
): ):
# 使用明确的查询避免延迟加载 # 使用明确的查询避免延迟加载
if channel.isdigit(): if channel.isdigit():
db_channel = ( db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
else: else:
db_channel = ( db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
).first()
if db_channel is None: if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found") raise HTTPException(status_code=404, detail="Channel not found")
@@ -237,8 +189,6 @@ async def get_channel(
channel_type = db_channel.type channel_type = db_channel.type
channel_name = db_channel.name channel_name = db_channel.name
assert channel_id is not None
users = [] users = []
if channel_type == ChannelType.PM: if channel_type == ChannelType.PM:
user_ids = channel_name.split("_")[1:] user_ids = channel_name.split("_")[1:]
@@ -259,9 +209,7 @@ async def get_channel(
session, session,
current_user, current_user,
redis, redis,
server.channels.get(channel_id, []) server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
if channel_type != ChannelType.PUBLIC
else None,
) )
) )
@@ -284,9 +232,7 @@ class CreateChannelReq(BaseModel):
raise ValueError("target_id must be set for PM channels") raise ValueError("target_id must be set for PM channels")
else: else:
if self.target_ids is None or self.channel is None or self.message is None: if self.target_ids is None or self.channel is None or self.message is None:
raise ValueError( raise ValueError("target_ids, channel, and message must be set for ANNOUNCE channels")
"target_ids, channel, and message must be set for ANNOUNCE channels"
)
return self return self
@@ -312,24 +258,20 @@ async def create_channel(
raise HTTPException(status_code=403, detail=block) raise HTTPException(status_code=403, detail=block)
channel = await ChatChannel.get_pm_channel( channel = await ChatChannel.get_pm_channel(
current_user.id, # pyright: ignore[reportArgumentType] current_user.id,
req.target_id, # pyright: ignore[reportArgumentType] req.target_id, # pyright: ignore[reportArgumentType]
session, session,
) )
channel_name = f"pm_{current_user.id}_{req.target_id}" channel_name = f"pm_{current_user.id}_{req.target_id}"
else: else:
channel_name = req.channel.name if req.channel else "Unnamed Channel" channel_name = req.channel.name if req.channel else "Unnamed Channel"
result = await session.exec( result = await session.exec(select(ChatChannel).where(ChatChannel.name == channel_name))
select(ChatChannel).where(ChatChannel.name == channel_name)
)
channel = result.first() channel = result.first()
if channel is None: if channel is None:
channel = ChatChannel( channel = ChatChannel(
name=channel_name, name=channel_name,
description=req.channel.description description=req.channel.description if req.channel else "Private message channel",
if req.channel
else "Private message channel",
type=ChannelType.PM if req.type == "PM" else ChannelType.ANNOUNCE, type=ChannelType.PM if req.type == "PM" else ChannelType.ANNOUNCE,
) )
session.add(channel) session.add(channel)
@@ -340,16 +282,13 @@ async def create_channel(
await session.refresh(target) # pyright: ignore[reportPossiblyUnboundVariable] await session.refresh(target) # pyright: ignore[reportPossiblyUnboundVariable]
await server.batch_join_channel([target, current_user], channel, session) # pyright: ignore[reportPossiblyUnboundVariable] await server.batch_join_channel([target, current_user], channel, session) # pyright: ignore[reportPossiblyUnboundVariable]
else: else:
target_users = await session.exec( target_users = await session.exec(select(User).where(col(User.id).in_(req.target_ids or [])))
select(User).where(col(User.id).in_(req.target_ids or []))
)
await server.batch_join_channel([*target_users, current_user], channel, session) await server.batch_join_channel([*target_users, current_user], channel, session)
await server.join_channel(current_user, channel, session) await server.join_channel(current_user, channel, session)
# 提取必要的属性避免惰性加载 # 提取必要的属性避免惰性加载
channel_id = channel.channel_id channel_id = channel.channel_id
assert channel_id
return await ChatChannelResp.from_db( return await ChatChannelResp.from_db(
channel, channel,

View File

@@ -41,33 +41,19 @@ class KeepAliveResp(BaseModel):
) )
async def keep_alive( async def keep_alive(
session: Database, session: Database,
history_since: int | None = Query( history_since: int | None = Query(None, description="获取自此禁言 ID 之后的禁言记录"),
None, description="获取自此禁言 ID 之后的禁言记录"
),
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"), since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
current_user: User = Security(get_current_user, scopes=["chat.read"]), current_user: User = Security(get_current_user, scopes=["chat.read"]),
): ):
resp = KeepAliveResp() resp = KeepAliveResp()
if history_since: if history_since:
silences = ( silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.id) > history_since))).all()
await session.exec(
select(SilenceUser).where(col(SilenceUser.id) > history_since)
)
).all()
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences]) resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
elif since: elif since:
msg = await session.get(ChatMessage, since) msg = await session.get(ChatMessage, since)
if msg: if msg:
silences = ( silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.banned_at) > msg.timestamp))).all()
await session.exec( resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
select(SilenceUser).where(
col(SilenceUser.banned_at) > msg.timestamp
)
)
).all()
resp.silences.extend(
[UserSilenceResp.from_db(silence) for silence in silences]
)
return resp return resp
@@ -93,15 +79,9 @@ async def send_message(
): ):
# 使用明确的查询来获取 channel避免延迟加载 # 使用明确的查询来获取 channel避免延迟加载
if channel.isdigit(): if channel.isdigit():
db_channel = ( db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
else: else:
db_channel = ( db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
).first()
if db_channel is None: if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found") raise HTTPException(status_code=404, detail="Channel not found")
@@ -111,9 +91,6 @@ async def send_message(
channel_type = db_channel.type channel_type = db_channel.type
channel_name = db_channel.name channel_name = db_channel.name
assert channel_id is not None
assert current_user.id
# 使用 Redis 消息系统发送消息 - 立即返回 # 使用 Redis 消息系统发送消息 - 立即返回
resp = await redis_message_system.send_message( resp = await redis_message_system.send_message(
channel_id=channel_id, channel_id=channel_id,
@@ -125,9 +102,7 @@ async def send_message(
# 立即广播消息给所有客户端 # 立即广播消息给所有客户端
is_bot_command = req.message.startswith("!") is_bot_command = req.message.startswith("!")
await server.send_message_to_channel( await server.send_message_to_channel(resp, is_bot_command and channel_type == ChannelType.PUBLIC)
resp, is_bot_command and channel_type == ChannelType.PUBLIC
)
# 处理机器人命令 # 处理机器人命令
if is_bot_command: if is_bot_command:
@@ -147,14 +122,10 @@ async def send_message(
if channel_type == ChannelType.PM: if channel_type == ChannelType.PM:
user_ids = channel_name.split("_")[1:] user_ids = channel_name.split("_")[1:]
await server.new_private_notification( await server.new_private_notification(
ChannelMessage.init( ChannelMessage.init(temp_msg, current_user, [int(u) for u in user_ids], channel_type)
temp_msg, current_user, [int(u) for u in user_ids], channel_type
)
) )
elif channel_type == ChannelType.TEAM: elif channel_type == ChannelType.TEAM:
await server.new_private_notification( await server.new_private_notification(ChannelMessageTeam.init(temp_msg, current_user))
ChannelMessageTeam.init(temp_msg, current_user)
)
return resp return resp
@@ -176,22 +147,15 @@ async def get_message(
): ):
# 使用明确的查询获取 channel避免延迟加载 # 使用明确的查询获取 channel避免延迟加载
if channel.isdigit(): if channel.isdigit():
db_channel = ( db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
else: else:
db_channel = ( db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
).first()
if db_channel is None: if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found") raise HTTPException(status_code=404, detail="Channel not found")
# 提取必要的属性避免惰性加载 # 提取必要的属性避免惰性加载
channel_id = db_channel.channel_id channel_id = db_channel.channel_id
assert channel_id is not None
# 使用 Redis 消息系统获取消息 # 使用 Redis 消息系统获取消息
try: try:
@@ -230,23 +194,15 @@ async def mark_as_read(
): ):
# 使用明确的查询获取 channel避免延迟加载 # 使用明确的查询获取 channel避免延迟加载
if channel.isdigit(): if channel.isdigit():
db_channel = ( db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
else: else:
db_channel = ( db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
).first()
if db_channel is None: if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found") raise HTTPException(status_code=404, detail="Channel not found")
# 立即提取需要的属性 # 立即提取需要的属性
channel_id = db_channel.channel_id channel_id = db_channel.channel_id
assert channel_id
assert current_user.id
await server.mark_as_read(channel_id, current_user.id, message) await server.mark_as_read(channel_id, current_user.id, message)
@@ -283,7 +239,6 @@ async def create_new_pm(
if not is_can_pm: if not is_can_pm:
raise HTTPException(status_code=403, detail=block) raise HTTPException(status_code=403, detail=block)
assert user_id
channel = await ChatChannel.get_pm_channel(user_id, req.target_id, session) channel = await ChatChannel.get_pm_channel(user_id, req.target_id, session)
if channel is None: if channel is None:
channel = ChatChannel( channel = ChatChannel(
@@ -297,7 +252,6 @@ async def create_new_pm(
await session.refresh(target) await session.refresh(target)
await session.refresh(current_user) await session.refresh(current_user)
assert channel.channel_id
await server.batch_join_channel([target, current_user], channel, session) await server.batch_join_channel([target, current_user], channel, session)
channel_resp = await ChatChannelResp.from_db( channel_resp = await ChatChannelResp.from_db(
channel, session, current_user, redis, server.channels[channel.channel_id] channel, session, current_user, redis, server.channels[channel.channel_id]

View File

@@ -17,6 +17,7 @@ from app.log import logger
from app.models.chat import ChatEvent from app.models.chat import ChatEvent
from app.models.notification import NotificationDetail from app.models.notification import NotificationDetail
from app.service.subscribers.chat import ChatSubscriber from app.service.subscribers.chat import ChatSubscriber
from app.utils import bg_tasks
from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect
from fastapi.security import SecurityScopes from fastapi.security import SecurityScopes
@@ -37,20 +38,11 @@ class ChatServer:
self.ChatSubscriber.chat_server = self self.ChatSubscriber.chat_server = self
self._subscribed = False self._subscribed = False
def _add_task(self, task):
task = asyncio.create_task(task)
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
def connect(self, user_id: int, client: WebSocket): def connect(self, user_id: int, client: WebSocket):
self.connect_client[user_id] = client self.connect_client[user_id] = client
def get_user_joined_channel(self, user_id: int) -> list[int]: def get_user_joined_channel(self, user_id: int) -> list[int]:
return [ return [channel_id for channel_id, users in self.channels.items() if user_id in users]
channel_id
for channel_id, users in self.channels.items()
if user_id in users
]
async def disconnect(self, user: User, session: AsyncSession): async def disconnect(self, user: User, session: AsyncSession):
user_id = user.id user_id = user.id
@@ -61,9 +53,7 @@ class ChatServer:
channel.remove(user_id) channel.remove(user_id)
# 使用明确的查询避免延迟加载 # 使用明确的查询避免延迟加载
db_channel = ( db_channel = (
await session.exec( await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
)
).first() ).first()
if db_channel: if db_channel:
await self.leave_channel(user, db_channel, session) await self.leave_channel(user, db_channel, session)
@@ -93,11 +83,10 @@ class ChatServer:
async def mark_as_read(self, channel_id: int, user_id: int, message_id: int): async def mark_as_read(self, channel_id: int, user_id: int, message_id: int):
await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id) await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id)
async def send_message_to_channel( async def send_message_to_channel(self, message: ChatMessageResp, is_bot_command: bool = False):
self, message: ChatMessageResp, is_bot_command: bool = False
):
logger.info( logger.info(
f"Sending message to channel {message.channel_id}, message_id: {message.message_id}, is_bot_command: {is_bot_command}" f"Sending message to channel {message.channel_id}, message_id: "
f"{message.message_id}, is_bot_command: {is_bot_command}"
) )
event = ChatEvent( event = ChatEvent(
@@ -106,62 +95,44 @@ class ChatServer:
) )
if is_bot_command: if is_bot_command:
logger.info(f"Sending bot command to user {message.sender_id}") logger.info(f"Sending bot command to user {message.sender_id}")
self._add_task(self.send_event(message.sender_id, event)) bg_tasks.add_task(self.send_event, message.sender_id, event)
else: else:
# 总是广播消息无论是临时ID还是真实ID # 总是广播消息无论是临时ID还是真实ID
logger.info( logger.info(f"Broadcasting message to all users in channel {message.channel_id}")
f"Broadcasting message to all users in channel {message.channel_id}" bg_tasks.add_task(
) self.broadcast,
self._add_task( message.channel_id,
self.broadcast( event,
message.channel_id,
event,
)
) )
# 只有真实消息 ID正数且非零才进行标记已读和设置最后消息 # 只有真实消息 ID正数且非零才进行标记已读和设置最后消息
# Redis 消息系统生成的ID都是正数所以这里应该都能正常处理 # Redis 消息系统生成的ID都是正数所以这里应该都能正常处理
if message.message_id and message.message_id > 0: if message.message_id and message.message_id > 0:
await self.mark_as_read( await self.mark_as_read(message.channel_id, message.sender_id, message.message_id)
message.channel_id, message.sender_id, message.message_id await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
) logger.info(f"Updated last message ID for channel {message.channel_id} to {message.message_id}")
await self.redis.set(
f"chat:{message.channel_id}:last_msg", message.message_id
)
logger.info(
f"Updated last message ID for channel {message.channel_id} to {message.message_id}"
)
else: else:
logger.debug( logger.debug(f"Skipping last message update for message ID: {message.message_id}")
f"Skipping last message update for message ID: {message.message_id}"
)
async def batch_join_channel( async def batch_join_channel(self, users: list[User], channel: ChatChannel, session: AsyncSession):
self, users: list[User], channel: ChatChannel, session: AsyncSession
):
channel_id = channel.channel_id channel_id = channel.channel_id
assert channel_id is not None
not_joined = [] not_joined = []
if channel_id not in self.channels: if channel_id not in self.channels:
self.channels[channel_id] = [] self.channels[channel_id] = []
for user in users: for user in users:
assert user.id is not None
if user.id not in self.channels[channel_id]: if user.id not in self.channels[channel_id]:
self.channels[channel_id].append(user.id) self.channels[channel_id].append(user.id)
not_joined.append(user) not_joined.append(user)
for user in not_joined: for user in not_joined:
assert user.id is not None
channel_resp = await ChatChannelResp.from_db( channel_resp = await ChatChannelResp.from_db(
channel, channel,
session, session,
user, user,
self.redis, self.redis,
self.channels[channel_id] self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
if channel.type != ChannelType.PUBLIC
else None,
) )
await self.send_event( await self.send_event(
user.id, user.id,
@@ -171,13 +142,9 @@ class ChatServer:
), ),
) )
async def join_channel( async def join_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> ChatChannelResp:
self, user: User, channel: ChatChannel, session: AsyncSession
) -> ChatChannelResp:
user_id = user.id user_id = user.id
channel_id = channel.channel_id channel_id = channel.channel_id
assert channel_id is not None
assert user_id is not None
if channel_id not in self.channels: if channel_id not in self.channels:
self.channels[channel_id] = [] self.channels[channel_id] = []
@@ -202,13 +169,9 @@ class ChatServer:
return channel_resp return channel_resp
async def leave_channel( async def leave_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> None:
self, user: User, channel: ChatChannel, session: AsyncSession
) -> None:
user_id = user.id user_id = user.id
channel_id = channel.channel_id channel_id = channel.channel_id
assert channel_id is not None
assert user_id is not None
if channel_id in self.channels and user_id in self.channels[channel_id]: if channel_id in self.channels and user_id in self.channels[channel_id]:
self.channels[channel_id].remove(user_id) self.channels[channel_id].remove(user_id)
@@ -221,9 +184,7 @@ class ChatServer:
session, session,
user, user,
self.redis, self.redis,
self.channels.get(channel_id) self.channels.get(channel_id) if channel.type != ChannelType.PUBLIC else None,
if channel.type != ChannelType.PUBLIC
else None,
) )
await self.send_event( await self.send_event(
user_id, user_id,
@@ -236,11 +197,7 @@ class ChatServer:
async def join_room_channel(self, channel_id: int, user_id: int): async def join_room_channel(self, channel_id: int, user_id: int):
async with with_db() as session: async with with_db() as session:
# 使用明确的查询避免延迟加载 # 使用明确的查询避免延迟加载
db_channel = ( db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
)
).first()
if db_channel is None: if db_channel is None:
return return
@@ -253,11 +210,7 @@ class ChatServer:
async def leave_room_channel(self, channel_id: int, user_id: int): async def leave_room_channel(self, channel_id: int, user_id: int):
async with with_db() as session: async with with_db() as session:
# 使用明确的查询避免延迟加载 # 使用明确的查询避免延迟加载
db_channel = ( db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
)
).first()
if db_channel is None: if db_channel is None:
return return
@@ -270,13 +223,7 @@ class ChatServer:
async def new_private_notification(self, detail: NotificationDetail): async def new_private_notification(self, detail: NotificationDetail):
async with with_db() as session: async with with_db() as session:
id = await insert_notification(session, detail) id = await insert_notification(session, detail)
users = ( users = (await session.exec(select(UserNotification).where(UserNotification.notification_id == id))).all()
await session.exec(
select(UserNotification).where(
UserNotification.notification_id == id
)
)
).all()
for user_notification in users: for user_notification in users:
data = user_notification.notification.model_dump() data = user_notification.notification.model_dump()
data["is_read"] = user_notification.is_read data["is_read"] = user_notification.is_read
@@ -308,9 +255,7 @@ async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory):
await ws.close(code=1000) await ws.close(code=1000)
break break
except WebSocketDisconnect as e: except WebSocketDisconnect as e:
logger.info( logger.info(f"[NotificationServer] Client {user_id} disconnected: {e.code}, {e.reason}")
f"[NotificationServer] Client {user_id} disconnected: {e.code}, {e.reason}"
)
except RuntimeError as e: except RuntimeError as e:
if "disconnect message" in str(e): if "disconnect message" in str(e):
logger.info(f"[NotificationServer] Client {user_id} closed the connection.") logger.info(f"[NotificationServer] Client {user_id} closed the connection.")
@@ -332,11 +277,7 @@ async def chat_websocket(
async for session in factory(): async for session in factory():
token = authorization[7:] token = authorization[7:]
if ( if (user := await get_current_user(session, SecurityScopes(scopes=["chat.read"]), token_pw=token)) is None:
user := await get_current_user(
session, SecurityScopes(scopes=["chat.read"]), token_pw=token
)
) is None:
await websocket.close(code=1008) await websocket.close(code=1008)
return return
@@ -346,12 +287,9 @@ async def chat_websocket(
await websocket.close(code=1008) await websocket.close(code=1008)
return return
user_id = user.id user_id = user.id
assert user_id
server.connect(user_id, websocket) server.connect(user_id, websocket)
# 使用明确的查询避免延迟加载 # 使用明确的查询避免延迟加载
db_channel = ( db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))).first()
await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))
).first()
if db_channel is not None: if db_channel is not None:
await server.join_channel(user, db_channel, session) await server.join_channel(user, db_channel, session)

View File

@@ -2,22 +2,20 @@
密码重置管理接口 密码重置管理接口
""" """
from fastapi import APIRouter, Depends, HTTPException from __future__ import annotations
from fastapi.responses import JSONResponse
from redis.asyncio import Redis
from app.dependencies.database import get_redis from app.dependencies.database import get_redis
from app.service.password_reset_service import password_reset_service
from app.log import logger from app.log import logger
from app.service.password_reset_service import password_reset_service
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from redis.asyncio import Redis
router = APIRouter(prefix="/admin/password-reset", tags=["密码重置管理"]) router = APIRouter(prefix="/admin/password-reset", tags=["密码重置管理"])
@router.get( @router.get("/status/{email}", name="查询重置状态", description="查询指定邮箱的密码重置状态")
"/status/{email}",
name="查询重置状态",
description="查询指定邮箱的密码重置状态"
)
async def get_password_reset_status( async def get_password_reset_status(
email: str, email: str,
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
@@ -25,28 +23,16 @@ async def get_password_reset_status(
"""查询密码重置状态""" """查询密码重置状态"""
try: try:
info = await password_reset_service.get_reset_code_info(email, redis) info = await password_reset_service.get_reset_code_info(email, redis)
return JSONResponse( return JSONResponse(status_code=200, content={"success": True, "data": info})
status_code=200,
content={
"success": True,
"data": info
}
)
except Exception as e: except Exception as e:
logger.error(f"[Admin] Failed to get password reset status for {email}: {e}") logger.error(f"[Admin] Failed to get password reset status for {email}: {e}")
return JSONResponse( return JSONResponse(status_code=500, content={"success": False, "error": "获取状态失败"})
status_code=500,
content={
"success": False,
"error": "获取状态失败"
}
)
@router.delete( @router.delete(
"/cleanup/{email}", "/cleanup/{email}",
name="清理重置数据", name="清理重置数据",
description="强制清理指定邮箱的密码重置数据" description="强制清理指定邮箱的密码重置数据",
) )
async def force_cleanup_reset( async def force_cleanup_reset(
email: str, email: str,
@@ -55,38 +41,23 @@ async def force_cleanup_reset(
"""强制清理密码重置数据""" """强制清理密码重置数据"""
try: try:
success = await password_reset_service.force_cleanup_user_reset(email, redis) success = await password_reset_service.force_cleanup_user_reset(email, redis)
if success: if success:
return JSONResponse( return JSONResponse(
status_code=200, status_code=200,
content={ content={"success": True, "message": f"已清理邮箱 {email} 的重置数据"},
"success": True,
"message": f"已清理邮箱 {email} 的重置数据"
}
) )
else: else:
return JSONResponse( return JSONResponse(status_code=500, content={"success": False, "error": "清理失败"})
status_code=500,
content={
"success": False,
"error": "清理失败"
}
)
except Exception as e: except Exception as e:
logger.error(f"[Admin] Failed to cleanup password reset for {email}: {e}") logger.error(f"[Admin] Failed to cleanup password reset for {email}: {e}")
return JSONResponse( return JSONResponse(status_code=500, content={"success": False, "error": "清理操作失败"})
status_code=500,
content={
"success": False,
"error": "清理操作失败"
}
)
@router.post( @router.post(
"/cleanup/expired", "/cleanup/expired",
name="清理过期验证码", name="清理过期验证码",
description="清理所有过期的密码重置验证码" description="清理所有过期的密码重置验证码",
) )
async def cleanup_expired_codes( async def cleanup_expired_codes(
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
@@ -99,25 +70,15 @@ async def cleanup_expired_codes(
content={ content={
"success": True, "success": True,
"message": f"已清理 {count} 个过期的验证码", "message": f"已清理 {count} 个过期的验证码",
"cleaned_count": count "cleaned_count": count,
} },
) )
except Exception as e: except Exception as e:
logger.error(f"[Admin] Failed to cleanup expired codes: {e}") logger.error(f"[Admin] Failed to cleanup expired codes: {e}")
return JSONResponse( return JSONResponse(status_code=500, content={"success": False, "error": "清理操作失败"})
status_code=500,
content={
"success": False,
"error": "清理操作失败"
}
)
@router.get( @router.get("/stats", name="重置统计", description="获取密码重置的统计信息")
"/stats",
name="重置统计",
description="获取密码重置的统计信息"
)
async def get_reset_statistics( async def get_reset_statistics(
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
): ):
@@ -126,53 +87,42 @@ async def get_reset_statistics(
# 获取所有重置相关的键 # 获取所有重置相关的键
reset_keys = await redis.keys("password_reset:code:*") reset_keys = await redis.keys("password_reset:code:*")
rate_limit_keys = await redis.keys("password_reset:rate_limit:*") rate_limit_keys = await redis.keys("password_reset:rate_limit:*")
active_resets = 0 active_resets = 0
used_resets = 0 used_resets = 0
active_rate_limits = 0 active_rate_limits = 0
# 统计活跃重置 # 统计活跃重置
for key in reset_keys: for key in reset_keys:
data_str = await redis.get(key) data_str = await redis.get(key)
if data_str: if data_str:
try: try:
import json import json
data = json.loads(data_str) data = json.loads(data_str)
if data.get("used", False): if data.get("used", False):
used_resets += 1 used_resets += 1
else: else:
active_resets += 1 active_resets += 1
except: except Exception:
pass pass
# 统计频率限制 # 统计频率限制
for key in rate_limit_keys: for key in rate_limit_keys:
ttl = await redis.ttl(key) ttl = await redis.ttl(key)
if ttl > 0: if ttl > 0:
active_rate_limits += 1 active_rate_limits += 1
stats = { stats = {
"total_reset_codes": len(reset_keys), "total_reset_codes": len(reset_keys),
"active_resets": active_resets, "active_resets": active_resets,
"used_resets": used_resets, "used_resets": used_resets,
"active_rate_limits": active_rate_limits, "active_rate_limits": active_rate_limits,
"total_rate_limit_keys": len(rate_limit_keys) "total_rate_limit_keys": len(rate_limit_keys),
} }
return JSONResponse( return JSONResponse(status_code=200, content={"success": True, "data": stats})
status_code=200,
content={
"success": True,
"data": stats
}
)
except Exception as e: except Exception as e:
logger.error(f"[Admin] Failed to get reset statistics: {e}") logger.error(f"[Admin] Failed to get reset statistics: {e}")
return JSONResponse( return JSONResponse(status_code=500, content={"success": False, "error": "获取统计信息失败"})
status_code=500,
content={
"success": False,
"error": "获取统计信息失败"
}
)

View File

@@ -26,7 +26,7 @@ async def create_oauth_app(
redirect_uris: list[str] = Body(..., description="允许的重定向 URI 列表"), redirect_uris: list[str] = Body(..., description="允许的重定向 URI 列表"),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
): ):
result = await session.execute( # pyright: ignore[reportDeprecated] result = await session.execute(
text( text(
"SELECT AUTO_INCREMENT FROM information_schema.TABLES " "SELECT AUTO_INCREMENT FROM information_schema.TABLES "
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'oauth_clients'" "WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'oauth_clients'"
@@ -84,9 +84,7 @@ async def get_user_oauth_apps(
session: Database, session: Database,
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
): ):
oauth_apps = await session.exec( oauth_apps = await session.exec(select(OAuthClient).where(OAuthClient.owner_id == current_user.id))
select(OAuthClient).where(OAuthClient.owner_id == current_user.id)
)
return [ return [
{ {
"name": app.name, "name": app.name,
@@ -113,13 +111,9 @@ async def delete_oauth_app(
if not oauth_client: if not oauth_client:
raise HTTPException(status_code=404, detail="OAuth app not found") raise HTTPException(status_code=404, detail="OAuth app not found")
if oauth_client.owner_id != current_user.id: if oauth_client.owner_id != current_user.id:
raise HTTPException( raise HTTPException(status_code=403, detail="Forbidden: Not the owner of this app")
status_code=403, detail="Forbidden: Not the owner of this app"
)
tokens = await session.exec( tokens = await session.exec(select(OAuthToken).where(OAuthToken.client_id == client_id))
select(OAuthToken).where(OAuthToken.client_id == client_id)
)
for token in tokens: for token in tokens:
await session.delete(token) await session.delete(token)
@@ -144,9 +138,7 @@ async def update_oauth_app(
if not oauth_client: if not oauth_client:
raise HTTPException(status_code=404, detail="OAuth app not found") raise HTTPException(status_code=404, detail="OAuth app not found")
if oauth_client.owner_id != current_user.id: if oauth_client.owner_id != current_user.id:
raise HTTPException( raise HTTPException(status_code=403, detail="Forbidden: Not the owner of this app")
status_code=403, detail="Forbidden: Not the owner of this app"
)
oauth_client.name = name oauth_client.name = name
oauth_client.description = description oauth_client.description = description
@@ -176,14 +168,10 @@ async def refresh_secret(
if not oauth_client: if not oauth_client:
raise HTTPException(status_code=404, detail="OAuth app not found") raise HTTPException(status_code=404, detail="OAuth app not found")
if oauth_client.owner_id != current_user.id: if oauth_client.owner_id != current_user.id:
raise HTTPException( raise HTTPException(status_code=403, detail="Forbidden: Not the owner of this app")
status_code=403, detail="Forbidden: Not the owner of this app"
)
oauth_client.client_secret = secrets.token_hex() oauth_client.client_secret = secrets.token_hex()
tokens = await session.exec( tokens = await session.exec(select(OAuthToken).where(OAuthToken.client_id == client_id))
select(OAuthToken).where(OAuthToken.client_id == client_id)
)
for token in tokens: for token in tokens:
await session.delete(token) await session.delete(token)
@@ -215,9 +203,7 @@ async def generate_oauth_code(
raise HTTPException(status_code=404, detail="OAuth app not found") raise HTTPException(status_code=404, detail="OAuth app not found")
if redirect_uri not in client.redirect_uris: if redirect_uri not in client.redirect_uris:
raise HTTPException( raise HTTPException(status_code=403, detail="Redirect URI not allowed for this client")
status_code=403, detail="Redirect URI not allowed for this client"
)
code = secrets.token_urlsafe(80) code = secrets.token_urlsafe(80)
await redis.hset( # pyright: ignore[reportGeneralTypeIssues] await redis.hset( # pyright: ignore[reportGeneralTypeIssues]

View File

@@ -50,12 +50,8 @@ async def check_user_relationship(
) )
).first() ).first()
is_followed = bool( is_followed = bool(target_relationship and target_relationship.type == RelationshipType.FOLLOW)
target_relationship and target_relationship.type == RelationshipType.FOLLOW is_following = bool(my_relationship and my_relationship.type == RelationshipType.FOLLOW)
)
is_following = bool(
my_relationship and my_relationship.type == RelationshipType.FOLLOW
)
return CheckResponse( return CheckResponse(
is_followed=is_followed, is_followed=is_followed,

View File

@@ -40,16 +40,13 @@ async def create_team(
支持的图片格式: PNG、JPEG、GIF 支持的图片格式: PNG、JPEG、GIF
""" """
user_id = current_user.id user_id = current_user.id
assert user_id
if (await current_user.awaitable_attrs.team_membership) is not None: if (await current_user.awaitable_attrs.team_membership) is not None:
raise HTTPException(status_code=403, detail="You are already in a team") raise HTTPException(status_code=403, detail="You are already in a team")
is_existed = (await session.exec(select(exists()).where(Team.name == name))).first() is_existed = (await session.exec(select(exists()).where(Team.name == name))).first()
if is_existed: if is_existed:
raise HTTPException(status_code=409, detail="Name already exists") raise HTTPException(status_code=409, detail="Name already exists")
is_existed = ( is_existed = (await session.exec(select(exists()).where(Team.short_name == short_name))).first()
await session.exec(select(exists()).where(Team.short_name == short_name))
).first()
if is_existed: if is_existed:
raise HTTPException(status_code=409, detail="Short name already exists") raise HTTPException(status_code=409, detail="Short name already exists")
@@ -101,7 +98,6 @@ async def update_team(
""" """
team = await session.get(Team, team_id) team = await session.get(Team, team_id)
user_id = current_user.id user_id = current_user.id
assert user_id
if not team: if not team:
raise HTTPException(status_code=404, detail="Team not found") raise HTTPException(status_code=404, detail="Team not found")
if team.leader_id != user_id: if team.leader_id != user_id:
@@ -110,9 +106,7 @@ async def update_team(
is_existed = (await session.exec(select(exists()).where(Team.name == name))).first() is_existed = (await session.exec(select(exists()).where(Team.name == name))).first()
if is_existed: if is_existed:
raise HTTPException(status_code=409, detail="Name already exists") raise HTTPException(status_code=409, detail="Name already exists")
is_existed = ( is_existed = (await session.exec(select(exists()).where(Team.short_name == short_name))).first()
await session.exec(select(exists()).where(Team.short_name == short_name))
).first()
if is_existed: if is_existed:
raise HTTPException(status_code=409, detail="Short name already exists") raise HTTPException(status_code=409, detail="Short name already exists")
@@ -132,20 +126,12 @@ async def update_team(
team.cover_url = await storage.get_file_url(storage_path) team.cover_url = await storage.get_file_url(storage_path)
if leader_id is not None: if leader_id is not None:
if not ( if not (await session.exec(select(exists()).where(User.id == leader_id))).first():
await session.exec(select(exists()).where(User.id == leader_id))
).first():
raise HTTPException(status_code=404, detail="Leader not found") raise HTTPException(status_code=404, detail="Leader not found")
if not ( if not (
await session.exec( await session.exec(select(TeamMember).where(TeamMember.user_id == leader_id, TeamMember.team_id == team.id))
select(TeamMember).where(
TeamMember.user_id == leader_id, TeamMember.team_id == team.id
)
)
).first(): ).first():
raise HTTPException( raise HTTPException(status_code=404, detail="Leader is not a member of the team")
status_code=404, detail="Leader is not a member of the team"
)
team.leader_id = leader_id team.leader_id = leader_id
await session.commit() await session.commit()
@@ -166,9 +152,7 @@ async def delete_team(
if team.leader_id != current_user.id: if team.leader_id != current_user.id:
raise HTTPException(status_code=403, detail="You are not the team leader") raise HTTPException(status_code=403, detail="You are not the team leader")
team_members = await session.exec( team_members = await session.exec(select(TeamMember).where(TeamMember.team_id == team_id))
select(TeamMember).where(TeamMember.team_id == team_id)
)
for member in team_members: for member in team_members:
await session.delete(member) await session.delete(member)
@@ -186,15 +170,10 @@ async def get_team(
session: Database, session: Database,
team_id: int = Path(..., description="战队 ID"), team_id: int = Path(..., description="战队 ID"),
): ):
members = ( members = (await session.exec(select(TeamMember).where(TeamMember.team_id == team_id))).all()
await session.exec(select(TeamMember).where(TeamMember.team_id == team_id))
).all()
return TeamQueryResp( return TeamQueryResp(
team=members[0].team, team=members[0].team,
members=[ members=[await UserResp.from_db(m.user, session, include=BASE_INCLUDES) for m in members],
await UserResp.from_db(m.user, session, include=BASE_INCLUDES)
for m in members
],
) )
@@ -213,15 +192,11 @@ async def request_join_team(
if ( if (
await session.exec( await session.exec(
select(exists()).where( select(exists()).where(TeamRequest.team_id == team_id, TeamRequest.user_id == current_user.id)
TeamRequest.team_id == team_id, TeamRequest.user_id == current_user.id
)
) )
).first(): ).first():
raise HTTPException(status_code=409, detail="Join request already exists") raise HTTPException(status_code=409, detail="Join request already exists")
team_request = TeamRequest( team_request = TeamRequest(user_id=current_user.id, team_id=team_id, requested_at=datetime.now(UTC))
user_id=current_user.id, team_id=team_id, requested_at=datetime.now(UTC)
)
session.add(team_request) session.add(team_request)
await session.commit() await session.commit()
await session.refresh(team_request) await session.refresh(team_request)
@@ -229,9 +204,7 @@ async def request_join_team(
@router.post("/team/{team_id}/{user_id}/request", name="接受加入请求", status_code=204) @router.post("/team/{team_id}/{user_id}/request", name="接受加入请求", status_code=204)
@router.delete( @router.delete("/team/{team_id}/{user_id}/request", name="拒绝加入请求", status_code=204)
"/team/{team_id}/{user_id}/request", name="拒绝加入请求", status_code=204
)
async def handle_request( async def handle_request(
req: Request, req: Request,
session: Database, session: Database,
@@ -247,11 +220,7 @@ async def handle_request(
raise HTTPException(status_code=403, detail="You are not the team leader") raise HTTPException(status_code=403, detail="You are not the team leader")
team_request = ( team_request = (
await session.exec( await session.exec(select(TeamRequest).where(TeamRequest.team_id == team_id, TeamRequest.user_id == user_id))
select(TeamRequest).where(
TeamRequest.team_id == team_id, TeamRequest.user_id == user_id
)
)
).first() ).first()
if not team_request: if not team_request:
raise HTTPException(status_code=404, detail="Join request not found") raise HTTPException(status_code=404, detail="Join request not found")
@@ -261,16 +230,10 @@ async def handle_request(
raise HTTPException(status_code=404, detail="User not found") raise HTTPException(status_code=404, detail="User not found")
if req.method == "POST": if req.method == "POST":
if ( if (await session.exec(select(exists()).where(TeamMember.user_id == user_id))).first():
await session.exec(select(exists()).where(TeamMember.user_id == user_id)) raise HTTPException(status_code=409, detail="User is already a member of the team")
).first():
raise HTTPException(
status_code=409, detail="User is already a member of the team"
)
session.add( session.add(TeamMember(user_id=user_id, team_id=team_id, joined_at=datetime.now(UTC)))
TeamMember(user_id=user_id, team_id=team_id, joined_at=datetime.now(UTC))
)
await server.new_private_notification(TeamApplicationAccept.init(team_request)) await server.new_private_notification(TeamApplicationAccept.init(team_request))
else: else:
@@ -294,19 +257,13 @@ async def kick_member(
raise HTTPException(status_code=403, detail="You are not the team leader") raise HTTPException(status_code=403, detail="You are not the team leader")
team_member = ( team_member = (
await session.exec( await session.exec(select(TeamMember).where(TeamMember.team_id == team_id, TeamMember.user_id == user_id))
select(TeamMember).where(
TeamMember.team_id == team_id, TeamMember.user_id == user_id
)
)
).first() ).first()
if not team_member: if not team_member:
raise HTTPException(status_code=404, detail="User is not a member of the team") raise HTTPException(status_code=404, detail="User is not a member of the team")
if team.leader_id == current_user.id: if team.leader_id == current_user.id:
raise HTTPException( raise HTTPException(status_code=403, detail="You cannot leave because you are the team leader")
status_code=403, detail="You cannot leave because you are the team leader"
)
await session.delete(team_member) await session.delete(team_member)
await session.commit() await session.commit()

View File

@@ -35,10 +35,7 @@ async def user_rename(
返回: 返回:
- 成功: None - 成功: None
""" """
assert current_user is not None samename_user = (await session.exec(select(User).where(User.username == new_name))).first()
samename_user = (
await session.exec(select(User).where(User.username == new_name))
).first()
if samename_user: if samename_user:
raise HTTPException(409, "Username Exisits") raise HTTPException(409, "Username Exisits")
errors = validate_username(new_name) errors = validate_username(new_name)

View File

@@ -106,9 +106,7 @@ class V1Beatmap(AllStrModel):
await session.exec( await session.exec(
select(func.count()) select(func.count())
.select_from(FavouriteBeatmapset) .select_from(FavouriteBeatmapset)
.where( .where(FavouriteBeatmapset.beatmapset_id == db_beatmap.beatmapset.id)
FavouriteBeatmapset.beatmapset_id == db_beatmap.beatmapset.id
)
) )
).one(), ).one(),
rating=0, # TODO rating=0, # TODO
@@ -154,12 +152,8 @@ async def get_beatmaps(
beatmapset_id: int | None = Query(None, alias="s", description="谱面集 ID"), beatmapset_id: int | None = Query(None, alias="s", description="谱面集 ID"),
beatmap_id: int | None = Query(None, alias="b", description="谱面 ID"), beatmap_id: int | None = Query(None, alias="b", description="谱面 ID"),
user: str | None = Query(None, alias="u", description="谱师"), user: str | None = Query(None, alias="u", description="谱师"),
type: Literal["string", "id"] | None = Query( type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"),
None, description="用户类型string 用户名称 / id 用户 ID" ruleset_id: int | None = Query(None, alias="m", description="Ruleset ID", ge=0, le=3), # TODO
),
ruleset_id: int | None = Query(
None, alias="m", description="Ruleset ID", ge=0, le=3
), # TODO
convert: bool = Query(False, alias="a", description="转谱"), # TODO convert: bool = Query(False, alias="a", description="转谱"), # TODO
checksum: str | None = Query(None, alias="h", description="谱面文件 MD5"), checksum: str | None = Query(None, alias="h", description="谱面文件 MD5"),
limit: int = Query(500, ge=1, le=500, description="返回结果数量限制"), limit: int = Query(500, ge=1, le=500, description="返回结果数量限制"),
@@ -181,11 +175,7 @@ async def get_beatmaps(
else: else:
beatmaps = beatmapset.beatmaps beatmaps = beatmapset.beatmaps
elif user is not None: elif user is not None:
where = ( where = Beatmapset.user_id == user if type == "id" or user.isdigit() else Beatmapset.creator == user
Beatmapset.user_id == user
if type == "id" or user.isdigit()
else Beatmapset.creator == user
)
beatmapsets = (await session.exec(select(Beatmapset).where(where))).all() beatmapsets = (await session.exec(select(Beatmapset).where(where))).all()
for beatmapset in beatmapsets: for beatmapset in beatmapsets:
if len(beatmaps) >= limit: if len(beatmaps) >= limit:
@@ -193,11 +183,7 @@ async def get_beatmaps(
beatmaps.extend(beatmapset.beatmaps) beatmaps.extend(beatmapset.beatmaps)
elif since is not None: elif since is not None:
beatmapsets = ( beatmapsets = (
await session.exec( await session.exec(select(Beatmapset).where(col(Beatmapset.ranked_date) > since).limit(limit))
select(Beatmapset)
.where(col(Beatmapset.ranked_date) > since)
.limit(limit)
)
).all() ).all()
for beatmapset in beatmapsets: for beatmapset in beatmapsets:
if len(beatmaps) >= limit: if len(beatmaps) >= limit:
@@ -214,11 +200,7 @@ async def get_beatmaps(
redis, redis,
fetcher, fetcher,
) )
results.append( results.append(await V1Beatmap.from_db(session, beatmap, attrs.aim_difficulty, attrs.speed_difficulty))
await V1Beatmap.from_db(
session, beatmap, attrs.aim_difficulty, attrs.speed_difficulty
)
)
continue continue
except Exception: except Exception:
... ...

View File

@@ -41,9 +41,7 @@ async def download_replay(
ge=0, ge=0,
), ),
score_id: int | None = Query(None, alias="s", description="成绩 ID"), score_id: int | None = Query(None, alias="s", description="成绩 ID"),
type: Literal["string", "id"] | None = Query( type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"),
None, description="用户类型string 用户名称 / id 用户 ID"
),
mods: int = Query(0, description="成绩的 MOD"), mods: int = Query(0, description="成绩的 MOD"),
storage_service: StorageService = Depends(get_storage_service), storage_service: StorageService = Depends(get_storage_service),
): ):
@@ -58,13 +56,9 @@ async def download_replay(
await session.exec( await session.exec(
select(Score).where( select(Score).where(
Score.beatmap_id == beatmap, Score.beatmap_id == beatmap,
Score.user_id == user Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user),
if type == "id" or user.isdigit()
else col(Score.user).has(username=user),
Score.mods == mods_, Score.mods == mods_,
Score.gamemode == GameMode.from_int_extra(ruleset_id) Score.gamemode == GameMode.from_int_extra(ruleset_id) if ruleset_id is not None else True,
if ruleset_id is not None
else True,
) )
) )
).first() ).first()
@@ -73,10 +67,7 @@ async def download_replay(
except KeyError: except KeyError:
raise HTTPException(status_code=400, detail="Invalid request") raise HTTPException(status_code=400, detail="Invalid request")
filepath = ( filepath = f"replays/{score_record.id}_{score_record.beatmap_id}_{score_record.user_id}_lazer_replay.osr"
f"replays/{score_record.id}_{score_record.beatmap_id}"
f"_{score_record.user_id}_lazer_replay.osr"
)
if not await storage_service.is_exists(filepath): if not await storage_service.is_exists(filepath):
raise HTTPException(status_code=404, detail="Replay file not found") raise HTTPException(status_code=404, detail="Replay file not found")
@@ -100,6 +91,4 @@ async def download_replay(
await session.commit() await session.commit()
data = await storage_service.read_file(filepath) data = await storage_service.read_file(filepath)
return ReplayModel( return ReplayModel(content=base64.b64encode(data).decode("utf-8"), encoding="base64")
content=base64.b64encode(data).decode("utf-8"), encoding="base64"
)

View File

@@ -8,9 +8,7 @@ from app.dependencies.user import v1_authorize
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from pydantic import BaseModel, field_serializer from pydantic import BaseModel, field_serializer
router = APIRouter( router = APIRouter(prefix="/api/v1", dependencies=[Depends(v1_authorize)], tags=["V1 API"])
prefix="/api/v1", dependencies=[Depends(v1_authorize)], tags=["V1 API"]
)
class AllStrModel(BaseModel): class AllStrModel(BaseModel):

View File

@@ -70,9 +70,7 @@ async def get_user_best(
session: Database, session: Database,
user: str = Query(..., alias="u", description="用户"), user: str = Query(..., alias="u", description="用户"),
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0), ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
type: Literal["string", "id"] | None = Query( type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"),
None, description="用户类型string 用户名称 / id 用户 ID"
),
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"), limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
): ):
try: try:
@@ -80,9 +78,7 @@ async def get_user_best(
await session.exec( await session.exec(
select(Score) select(Score)
.where( .where(
Score.user_id == user Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user),
if type == "id" or user.isdigit()
else col(Score.user).has(username=user),
Score.gamemode == GameMode.from_int_extra(ruleset_id), Score.gamemode == GameMode.from_int_extra(ruleset_id),
exists().where(col(PPBestScore.score_id) == Score.id), exists().where(col(PPBestScore.score_id) == Score.id),
) )
@@ -106,9 +102,7 @@ async def get_user_recent(
session: Database, session: Database,
user: str = Query(..., alias="u", description="用户"), user: str = Query(..., alias="u", description="用户"),
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0), ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
type: Literal["string", "id"] | None = Query( type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"),
None, description="用户类型string 用户名称 / id 用户 ID"
),
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"), limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
): ):
try: try:
@@ -116,9 +110,7 @@ async def get_user_recent(
await session.exec( await session.exec(
select(Score) select(Score)
.where( .where(
Score.user_id == user Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user),
if type == "id" or user.isdigit()
else col(Score.user).has(username=user),
Score.gamemode == GameMode.from_int_extra(ruleset_id), Score.gamemode == GameMode.from_int_extra(ruleset_id),
Score.ended_at > datetime.now(UTC) - timedelta(hours=24), Score.ended_at > datetime.now(UTC) - timedelta(hours=24),
) )
@@ -143,9 +135,7 @@ async def get_scores(
user: str | None = Query(None, alias="u", description="用户"), user: str | None = Query(None, alias="u", description="用户"),
beatmap_id: int = Query(alias="b", description="谱面 ID"), beatmap_id: int = Query(alias="b", description="谱面 ID"),
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0), ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
type: Literal["string", "id"] | None = Query( type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"),
None, description="用户类型string 用户名称 / id 用户 ID"
),
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"), limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
mods: int = Query(0, description="成绩的 MOD"), mods: int = Query(0, description="成绩的 MOD"),
): ):
@@ -157,9 +147,7 @@ async def get_scores(
.where( .where(
Score.gamemode == GameMode.from_int_extra(ruleset_id), Score.gamemode == GameMode.from_int_extra(ruleset_id),
Score.beatmap_id == beatmap_id, Score.beatmap_id == beatmap_id,
Score.user_id == user Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user),
if type == "id" or user.isdigit()
else col(Score.user).has(username=user),
) )
.options(joinedload(Score.beatmap)) .options(joinedload(Score.beatmap))
.order_by(col(Score.classic_total_score).desc()) .order_by(col(Score.classic_total_score).desc())

View File

@@ -1,6 +1,5 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from datetime import datetime from datetime import datetime
from typing import Literal from typing import Literal
@@ -13,7 +12,7 @@ from app.service.user_cache_service import get_user_cache_service
from .router import AllStrModel, router from .router import AllStrModel, router
from fastapi import HTTPException, Query from fastapi import BackgroundTasks, HTTPException, Query
from sqlmodel import select from sqlmodel import select
@@ -49,9 +48,7 @@ class V1User(AllStrModel):
return f"v1_user:{user_id}" return f"v1_user:{user_id}"
@classmethod @classmethod
async def from_db( async def from_db(cls, session: Database, db_user: User, ruleset: GameMode | None = None) -> "V1User":
cls, session: Database, db_user: User, ruleset: GameMode | None = None
) -> "V1User":
# 确保 user_id 不为 None # 确保 user_id 不为 None
if db_user.id is None: if db_user.id is None:
raise ValueError("User ID cannot be None") raise ValueError("User ID cannot be None")
@@ -63,9 +60,7 @@ class V1User(AllStrModel):
current_statistics = i current_statistics = i
break break
if current_statistics: if current_statistics:
statistics = await UserStatisticsResp.from_db( statistics = await UserStatisticsResp.from_db(current_statistics, session, db_user.country_code)
current_statistics, session, db_user.country_code
)
else: else:
statistics = None statistics = None
return cls( return cls(
@@ -78,9 +73,7 @@ class V1User(AllStrModel):
playcount=statistics.play_count if statistics else 0, playcount=statistics.play_count if statistics else 0,
ranked_score=statistics.ranked_score if statistics else 0, ranked_score=statistics.ranked_score if statistics else 0,
total_score=statistics.total_score if statistics else 0, total_score=statistics.total_score if statistics else 0,
pp_rank=statistics.global_rank pp_rank=statistics.global_rank if statistics and statistics.global_rank else 0,
if statistics and statistics.global_rank
else 0,
level=current_statistics.level_current if current_statistics else 0, level=current_statistics.level_current if current_statistics else 0,
pp_raw=statistics.pp if statistics else 0.0, pp_raw=statistics.pp if statistics else 0.0,
accuracy=statistics.hit_accuracy if statistics else 0, accuracy=statistics.hit_accuracy if statistics else 0,
@@ -91,9 +84,7 @@ class V1User(AllStrModel):
count_rank_a=current_statistics.grade_a if current_statistics else 0, count_rank_a=current_statistics.grade_a if current_statistics else 0,
country=db_user.country_code, country=db_user.country_code,
total_seconds_played=statistics.play_time if statistics else 0, total_seconds_played=statistics.play_time if statistics else 0,
pp_country_rank=statistics.country_rank pp_country_rank=statistics.country_rank if statistics and statistics.country_rank else 0,
if statistics and statistics.country_rank
else 0,
events=[], # TODO events=[], # TODO
) )
@@ -106,14 +97,11 @@ class V1User(AllStrModel):
) )
async def get_user( async def get_user(
session: Database, session: Database,
background_tasks: BackgroundTasks,
user: str = Query(..., alias="u", description="用户"), user: str = Query(..., alias="u", description="用户"),
ruleset_id: int | None = Query(None, alias="m", description="Ruleset ID", ge=0), ruleset_id: int | None = Query(None, alias="m", description="Ruleset ID", ge=0),
type: Literal["string", "id"] | None = Query( type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"),
None, description="用户类型string 用户名称 / id 用户 ID" event_days: int = Query(default=1, ge=1, le=31, description="从现在起所有事件的最大天数"),
),
event_days: int = Query(
default=1, ge=1, le=31, description="从现在起所有事件的最大天数"
),
): ):
redis = get_redis() redis = get_redis()
cache_service = get_user_cache_service(redis) cache_service = get_user_cache_service(redis)
@@ -131,9 +119,7 @@ async def get_user(
if is_id_query: if is_id_query:
try: try:
user_id_for_cache = int(user) user_id_for_cache = int(user)
cached_v1_user = await cache_service.get_v1_user_from_cache( cached_v1_user = await cache_service.get_v1_user_from_cache(user_id_for_cache, ruleset)
user_id_for_cache, ruleset
)
if cached_v1_user: if cached_v1_user:
return [V1User(**cached_v1_user)] return [V1User(**cached_v1_user)]
except (ValueError, TypeError): except (ValueError, TypeError):
@@ -158,9 +144,7 @@ async def get_user(
# 异步缓存结果如果有用户ID # 异步缓存结果如果有用户ID
if db_user.id is not None: if db_user.id is not None:
user_data = v1_user.model_dump() user_data = v1_user.model_dump()
asyncio.create_task( background_tasks.add_task(cache_service.cache_v1_user, user_data, db_user.id, ruleset)
cache_service.cache_v1_user(user_data, db_user.id, ruleset)
)
return [v1_user] return [v1_user]

View File

@@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401 from . import ( # noqa: F401
beatmap, beatmap,
beatmapset, beatmapset,
me, me,

View File

@@ -40,18 +40,13 @@ class BatchGetResp(BaseModel):
tags=["谱面"], tags=["谱面"],
name="查询单个谱面", name="查询单个谱面",
response_model=BeatmapResp, response_model=BeatmapResp,
description=( description=("根据谱面 ID / MD5 / 文件名 查询单个谱面。至少提供 id / checksum / filename 之一。"),
"根据谱面 ID / MD5 / 文件名 查询单个谱面。"
"至少提供 id / checksum / filename 之一。"
),
) )
async def lookup_beatmap( async def lookup_beatmap(
db: Database, db: Database,
id: int | None = Query(default=None, alias="id", description="谱面 ID"), id: int | None = Query(default=None, alias="id", description="谱面 ID"),
md5: str | None = Query(default=None, alias="checksum", description="谱面文件 MD5"), md5: str | None = Query(default=None, alias="checksum", description="谱面文件 MD5"),
filename: str | None = Query( filename: str | None = Query(default=None, alias="filename", description="谱面文件名"),
default=None, alias="filename", description="谱面文件名"
),
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher = Depends(get_fetcher),
): ):
@@ -96,43 +91,23 @@ async def get_beatmap(
tags=["谱面"], tags=["谱面"],
name="批量获取谱面", name="批量获取谱面",
response_model=BatchGetResp, response_model=BatchGetResp,
description=( description=("批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。为空时按最近更新时间返回。"),
"批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。"
"为空时按最近更新时间返回。"
),
) )
async def batch_get_beatmaps( async def batch_get_beatmaps(
db: Database, db: Database,
beatmap_ids: list[int] = Query( beatmap_ids: list[int] = Query(alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"),
alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"
),
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher = Depends(get_fetcher),
): ):
if not beatmap_ids: if not beatmap_ids:
beatmaps = ( beatmaps = (await db.exec(select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50))).all()
await db.exec(
select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
)
).all()
else: else:
beatmaps = list( beatmaps = list((await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)).limit(50))).all())
( not_found_beatmaps = [bid for bid in beatmap_ids if bid not in [bm.id for bm in beatmaps]]
await db.exec(
select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)).limit(50)
)
).all()
)
not_found_beatmaps = [
bid for bid in beatmap_ids if bid not in [bm.id for bm in beatmaps]
]
beatmaps.extend( beatmaps.extend(
beatmap beatmap
for beatmap in await asyncio.gather( for beatmap in await asyncio.gather(
*[ *[Beatmap.get_or_fetch(db, fetcher, bid=bid) for bid in not_found_beatmaps],
Beatmap.get_or_fetch(db, fetcher, bid=bid)
for bid in not_found_beatmaps
],
return_exceptions=True, return_exceptions=True,
) )
if isinstance(beatmap, Beatmap) if isinstance(beatmap, Beatmap)
@@ -140,12 +115,7 @@ async def batch_get_beatmaps(
for beatmap in beatmaps: for beatmap in beatmaps:
await db.refresh(beatmap) await db.refresh(beatmap)
return BatchGetResp( return BatchGetResp(beatmaps=[await BeatmapResp.from_db(bm, session=db, user=current_user) for bm in beatmaps])
beatmaps=[
await BeatmapResp.from_db(bm, session=db, user=current_user)
for bm in beatmaps
]
)
@router.post( @router.post(
@@ -163,12 +133,8 @@ async def get_beatmap_attributes(
default_factory=list, default_factory=list,
description="Mods 列表;可为整型位掩码(单元素)或 JSON/简称", description="Mods 列表;可为整型位掩码(单元素)或 JSON/简称",
), ),
ruleset: GameMode | None = Query( ruleset: GameMode | None = Query(default=None, description="指定 ruleset为空则使用谱面自身模式"),
default=None, description="指定 ruleset;为空则使用谱面自身模式" ruleset_id: int | None = Query(default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3),
),
ruleset_id: int | None = Query(
default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3
),
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher = Depends(get_fetcher),
): ):
@@ -187,16 +153,11 @@ async def get_beatmap_attributes(
if ruleset is None: if ruleset is None:
beatmap_db = await Beatmap.get_or_fetch(db, fetcher, beatmap_id) beatmap_db = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
ruleset = beatmap_db.mode ruleset = beatmap_db.mode
key = ( key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
f"beatmap:{beatmap_id}:{ruleset}:"
f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
)
if await redis.exists(key): if await redis.exists(key):
return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType] return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
try: try:
return await calculate_beatmap_attributes( return await calculate_beatmap_attributes(beatmap_id, ruleset, mods_, redis, fetcher)
beatmap_id, ruleset, mods_, redis, fetcher
)
except HTTPStatusError: except HTTPStatusError:
raise HTTPException(status_code=404, detail="Beatmap not found") raise HTTPException(status_code=404, detail="Beatmap not found")
except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue] except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue]

View File

@@ -35,9 +35,7 @@ from sqlmodel import exists, select
async def _save_to_db(sets: SearchBeatmapsetsResp): async def _save_to_db(sets: SearchBeatmapsetsResp):
async with with_db() as session: async with with_db() as session:
for s in sets.beatmapsets: for s in sets.beatmapsets:
if not ( if not (await session.exec(select(exists()).where(Beatmapset.id == s.id))).first():
await session.exec(select(exists()).where(Beatmapset.id == s.id))
).first():
await Beatmapset.from_resp(session, s) await Beatmapset.from_resp(session, s)
@@ -117,9 +115,7 @@ async def lookup_beatmapset(
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher = Depends(get_fetcher),
): ):
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id) beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
resp = await BeatmapsetResp.from_db( resp = await BeatmapsetResp.from_db(beatmap.beatmapset, session=db, user=current_user)
beatmap.beatmapset, session=db, user=current_user
)
return resp return resp
@@ -138,9 +134,7 @@ async def get_beatmapset(
): ):
try: try:
beatmapset = await Beatmapset.get_or_fetch(db, fetcher, beatmapset_id) beatmapset = await Beatmapset.get_or_fetch(db, fetcher, beatmapset_id)
return await BeatmapsetResp.from_db( return await BeatmapsetResp.from_db(beatmapset, session=db, include=["recent_favourites"], user=current_user)
beatmapset, session=db, include=["recent_favourites"], user=current_user
)
except HTTPError: except HTTPError:
raise HTTPException(status_code=404, detail="Beatmapset not found") raise HTTPException(status_code=404, detail="Beatmapset not found")
@@ -165,9 +159,7 @@ async def download_beatmapset(
country_code = geo_info.get("country_iso", "") country_code = geo_info.get("country_iso", "")
# 优先使用IP地理位置判断如果获取失败则回退到用户账户的国家代码 # 优先使用IP地理位置判断如果获取失败则回退到用户账户的国家代码
is_china = country_code == "CN" or ( is_china = country_code == "CN" or (not country_code and current_user.country_code == "CN")
not country_code and current_user.country_code == "CN"
)
try: try:
# 使用负载均衡服务获取下载URL # 使用负载均衡服务获取下载URL
@@ -179,13 +171,10 @@ async def download_beatmapset(
# 如果负载均衡服务失败,回退到原有逻辑 # 如果负载均衡服务失败,回退到原有逻辑
if is_china: if is_china:
return RedirectResponse( return RedirectResponse(
f"https://dl.sayobot.cn/beatmaps/download/" f"https://dl.sayobot.cn/beatmaps/download/{'novideo' if no_video else 'full'}/{beatmapset_id}"
f"{'novideo' if no_video else 'full'}/{beatmapset_id}"
) )
else: else:
return RedirectResponse( return RedirectResponse(f"https://api.nerinyan.moe/d/{beatmapset_id}?noVideo={no_video}")
f"https://api.nerinyan.moe/d/{beatmapset_id}?noVideo={no_video}"
)
@router.post( @router.post(
@@ -197,12 +186,9 @@ async def download_beatmapset(
async def favourite_beatmapset( async def favourite_beatmapset(
db: Database, db: Database,
beatmapset_id: int = Path(..., description="谱面集 ID"), beatmapset_id: int = Path(..., description="谱面集 ID"),
action: Literal["favourite", "unfavourite"] = Form( action: Literal["favourite", "unfavourite"] = Form(description="操作类型favourite 收藏 / unfavourite 取消收藏"),
description="操作类型favourite 收藏 / unfavourite 取消收藏"
),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
): ):
assert current_user.id is not None
existing_favourite = ( existing_favourite = (
await db.exec( await db.exec(
select(FavouriteBeatmapset).where( select(FavouriteBeatmapset).where(
@@ -212,15 +198,11 @@ async def favourite_beatmapset(
) )
).first() ).first()
if (action == "favourite" and existing_favourite) or ( if (action == "favourite" and existing_favourite) or (action == "unfavourite" and not existing_favourite):
action == "unfavourite" and not existing_favourite
):
return return
if action == "favourite": if action == "favourite":
favourite = FavouriteBeatmapset( favourite = FavouriteBeatmapset(user_id=current_user.id, beatmapset_id=beatmapset_id)
user_id=current_user.id, beatmapset_id=beatmapset_id
)
db.add(favourite) db.add(favourite)
else: else:
await db.delete(existing_favourite) await db.delete(existing_favourite)

View File

@@ -4,8 +4,8 @@ from app.database import User
from app.database.lazer_user import ALL_INCLUDED from app.database.lazer_user import ALL_INCLUDED
from app.dependencies import get_current_user from app.dependencies import get_current_user
from app.dependencies.database import Database from app.dependencies.database import Database
from app.models.score import GameMode
from app.models.api_me import APIMe from app.models.api_me import APIMe
from app.models.score import GameMode
from .router import router from .router import router

View File

@@ -33,6 +33,4 @@ class BackgroundsResp(BaseModel):
description="获取当前季节背景图列表。", description="获取当前季节背景图列表。",
) )
async def get_seasonal_backgrounds(): async def get_seasonal_backgrounds():
return BackgroundsResp( return BackgroundsResp(backgrounds=[Background(url=url) for url in settings.seasonal_backgrounds])
backgrounds=[Background(url=url) for url in settings.seasonal_backgrounds]
)

View File

@@ -12,7 +12,7 @@ from app.service.ranking_cache_service import get_ranking_cache_service
from .router import router from .router import router
from fastapi import Path, Query, Security from fastapi import BackgroundTasks, Path, Query, Security
from pydantic import BaseModel from pydantic import BaseModel
from sqlmodel import col, select from sqlmodel import col, select
@@ -38,6 +38,7 @@ class CountryResponse(BaseModel):
) )
async def get_country_ranking( async def get_country_ranking(
session: Database, session: Database,
background_tasks: BackgroundTasks,
ruleset: GameMode = Path(..., description="指定 ruleset"), ruleset: GameMode = Path(..., description="指定 ruleset"),
page: int = Query(1, ge=1, description="页码"), # TODO page: int = Query(1, ge=1, description="页码"), # TODO
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
@@ -51,9 +52,7 @@ async def get_country_ranking(
if cached_data: if cached_data:
# 从缓存返回数据 # 从缓存返回数据
return CountryResponse( return CountryResponse(ranking=[CountryStatistics.model_validate(item) for item in cached_data])
ranking=[CountryStatistics.model_validate(item) for item in cached_data]
)
# 缓存未命中,从数据库查询 # 缓存未命中,从数据库查询
response = CountryResponse(ranking=[]) response = CountryResponse(ranking=[])
@@ -105,14 +104,15 @@ async def get_country_ranking(
# 异步缓存数据(不等待完成) # 异步缓存数据(不等待完成)
cache_data = [item.model_dump() for item in current_page_data] cache_data = [item.model_dump() for item in current_page_data]
cache_task = cache_service.cache_country_ranking(
ruleset, cache_data, page, ttl=settings.ranking_cache_expire_minutes * 60
)
# 创建后台任务来缓存数据 # 创建后台任务来缓存数据
import asyncio background_tasks.add_task(
cache_service.cache_country_ranking,
asyncio.create_task(cache_task) ruleset,
cache_data,
page,
ttl=settings.ranking_cache_expire_minutes * 60,
)
# 返回当前页的结果 # 返回当前页的结果
response.ranking = current_page_data response.ranking = current_page_data
@@ -132,10 +132,9 @@ class TopUsersResponse(BaseModel):
) )
async def get_user_ranking( async def get_user_ranking(
session: Database, session: Database,
background_tasks: BackgroundTasks,
ruleset: GameMode = Path(..., description="指定 ruleset"), ruleset: GameMode = Path(..., description="指定 ruleset"),
type: Literal["performance", "score"] = Path( type: Literal["performance", "score"] = Path(..., description="排名类型performance 表现分 / score 计分成绩总分"),
..., description="排名类型performance 表现分 / score 计分成绩总分"
),
country: str | None = Query(None, description="国家代码"), country: str | None = Query(None, description="国家代码"),
page: int = Query(1, ge=1, description="页码"), page: int = Query(1, ge=1, description="页码"),
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
@@ -149,9 +148,7 @@ async def get_user_ranking(
if cached_data: if cached_data:
# 从缓存返回数据 # 从缓存返回数据
return TopUsersResponse( return TopUsersResponse(ranking=[UserStatisticsResp.model_validate(item) for item in cached_data])
ranking=[UserStatisticsResp.model_validate(item) for item in cached_data]
)
# 缓存未命中,从数据库查询 # 缓存未命中,从数据库查询
wheres = [ wheres = [
@@ -169,25 +166,22 @@ async def get_user_ranking(
wheres.append(col(UserStatistics.user).has(country_code=country.upper())) wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
statistics_list = await session.exec( statistics_list = await session.exec(
select(UserStatistics) select(UserStatistics).where(*wheres).order_by(order_by).limit(50).offset(50 * (page - 1))
.where(*wheres)
.order_by(order_by)
.limit(50)
.offset(50 * (page - 1))
) )
# 转换为响应格式 # 转换为响应格式
ranking_data = [] ranking_data = []
for statistics in statistics_list: for statistics in statistics_list:
user_stats_resp = await UserStatisticsResp.from_db( user_stats_resp = await UserStatisticsResp.from_db(statistics, session, None, include)
statistics, session, None, include
)
ranking_data.append(user_stats_resp) ranking_data.append(user_stats_resp)
# 异步缓存数据(不等待完成) # 异步缓存数据(不等待完成)
# 使用配置文件中的TTL设置 # 使用配置文件中的TTL设置
cache_data = [item.model_dump() for item in ranking_data] cache_data = [item.model_dump() for item in ranking_data]
cache_task = cache_service.cache_ranking( # 创建后台任务来缓存数据
background_tasks.add_task(
cache_service.cache_ranking,
ruleset, ruleset,
type, type,
cache_data, cache_data,
@@ -196,139 +190,134 @@ async def get_user_ranking(
ttl=settings.ranking_cache_expire_minutes * 60, ttl=settings.ranking_cache_expire_minutes * 60,
) )
# 创建后台任务来缓存数据
import asyncio
asyncio.create_task(cache_task)
resp = TopUsersResponse(ranking=ranking_data) resp = TopUsersResponse(ranking=ranking_data)
return resp return resp
""" @router.post( # @router.post(
"/rankings/cache/refresh", # "/rankings/cache/refresh",
name="刷新排行榜缓存", # name="刷新排行榜缓存",
description="手动刷新排行榜缓存(管理员功能)", # description="手动刷新排行榜缓存(管理员功能)",
tags=["排行榜", "管理"], # tags=["排行榜", "管理"],
) # )
async def refresh_ranking_cache( # async def refresh_ranking_cache(
session: Database, # session: Database,
ruleset: GameMode | None = Query(None, description="指定要刷新的游戏模式,不指定则刷新所有"), # ruleset: GameMode | None = Query(None, description="指定要刷新的游戏模式,不指定则刷新所有"),
type: Literal["performance", "score"] | None = Query(None, description="指定要刷新的排名类型,不指定则刷新所有"), # type: Literal["performance", "score"] | None = Query(None, description="指定要刷新的排名类型,不指定则刷新所有"),
country: str | None = Query(None, description="指定要刷新的国家,不指定则刷新所有"), # country: str | None = Query(None, description="指定要刷新的国家,不指定则刷新所有"),
include_country_ranking: bool = Query(True, description="是否包含地区排行榜"), # include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限 # current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
): # ):
redis = get_redis() # redis = get_redis()
cache_service = get_ranking_cache_service(redis) # cache_service = get_ranking_cache_service(redis)
if ruleset and type: # if ruleset and type:
# 刷新特定的用户排行榜 # # 刷新特定的用户排行榜
await cache_service.refresh_ranking_cache(session, ruleset, type, country) # await cache_service.refresh_ranking_cache(session, ruleset, type, country)
message = f"Refreshed ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "") # message = f"Refreshed ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
# 如果请求刷新地区排行榜 # # 如果请求刷新地区排行榜
if include_country_ranking and not country: # 地区排行榜不依赖于国家参数 # if include_country_ranking and not country: # 地区排行榜不依赖于国家参数
await cache_service.refresh_country_ranking_cache(session, ruleset) # await cache_service.refresh_country_ranking_cache(session, ruleset)
message += f" and country ranking for {ruleset}" # message += f" and country ranking for {ruleset}"
return {"message": message} # return {"message": message}
elif ruleset: # elif ruleset:
# 刷新特定游戏模式的所有排行榜 # # 刷新特定游戏模式的所有排行榜
ranking_types: list[Literal["performance", "score"]] = ["performance", "score"] # ranking_types: list[Literal["performance", "score"]] = ["performance", "score"]
for ranking_type in ranking_types: # for ranking_type in ranking_types:
await cache_service.refresh_ranking_cache(session, ruleset, ranking_type, country) # await cache_service.refresh_ranking_cache(session, ruleset, ranking_type, country)
if include_country_ranking: # if include_country_ranking:
await cache_service.refresh_country_ranking_cache(session, ruleset) # await cache_service.refresh_country_ranking_cache(session, ruleset)
return {"message": f"Refreshed all ranking caches for {ruleset}"} # return {"message": f"Refreshed all ranking caches for {ruleset}"}
else: # else:
# 刷新所有排行榜 # # 刷新所有排行榜
await cache_service.refresh_all_rankings(session) # await cache_service.refresh_all_rankings(session)
return {"message": "Refreshed all ranking caches"} # return {"message": "Refreshed all ranking caches"}
@router.post( # @router.post(
"/rankings/{ruleset}/country/cache/refresh", # "/rankings/{ruleset}/country/cache/refresh",
name="刷新地区排行榜缓存", # name="刷新地区排行榜缓存",
description="手动刷新地区排行榜缓存(管理员功能)", # description="手动刷新地区排行榜缓存(管理员功能)",
tags=["排行榜", "管理"], # tags=["排行榜", "管理"],
) # )
async def refresh_country_ranking_cache( # async def refresh_country_ranking_cache(
session: Database, # session: Database,
ruleset: GameMode = Path(..., description="指定要刷新的游戏模式"), # ruleset: GameMode = Path(..., description="指定要刷新的游戏模式"),
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限 # current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
): # ):
redis = get_redis() # redis = get_redis()
cache_service = get_ranking_cache_service(redis) # cache_service = get_ranking_cache_service(redis)
await cache_service.refresh_country_ranking_cache(session, ruleset) # await cache_service.refresh_country_ranking_cache(session, ruleset)
return {"message": f"Refreshed country ranking cache for {ruleset}"} # return {"message": f"Refreshed country ranking cache for {ruleset}"}
@router.delete( # @router.delete(
"/rankings/cache", # "/rankings/cache",
name="清除排行榜缓存", # name="清除排行榜缓存",
description="清除排行榜缓存(管理员功能)", # description="清除排行榜缓存(管理员功能)",
tags=["排行榜", "管理"], # tags=["排行榜", "管理"],
) # )
async def clear_ranking_cache( # async def clear_ranking_cache(
ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"), # ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
type: Literal["performance", "score"] | None = Query(None, description="指定要清除的排名类型,不指定则清除所有"), # type: Literal["performance", "score"] | None = Query(None, description="指定要清除的排名类型,不指定则清除所有"),
country: str | None = Query(None, description="指定要清除的国家,不指定则清除所有"), # country: str | None = Query(None, description="指定要清除的国家,不指定则清除所有"),
include_country_ranking: bool = Query(True, description="是否包含地区排行榜"), # include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限 # current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
): # ):
redis = get_redis() # redis = get_redis()
cache_service = get_ranking_cache_service(redis) # cache_service = get_ranking_cache_service(redis)
await cache_service.invalidate_cache(ruleset, type, country, include_country_ranking) # await cache_service.invalidate_cache(ruleset, type, country, include_country_ranking)
if ruleset and type: # if ruleset and type:
message = f"Cleared ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "") # message = f"Cleared ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
if include_country_ranking: # if include_country_ranking:
message += " and country ranking" # message += " and country ranking"
return {"message": message} # return {"message": message}
else: # else:
message = "Cleared all ranking caches" # message = "Cleared all ranking caches"
if include_country_ranking: # if include_country_ranking:
message += " including country rankings" # message += " including country rankings"
return {"message": message} # return {"message": message}
@router.delete( # @router.delete(
"/rankings/{ruleset}/country/cache", # "/rankings/{ruleset}/country/cache",
name="清除地区排行榜缓存", # name="清除地区排行榜缓存",
description="清除地区排行榜缓存(管理员功能)", # description="清除地区排行榜缓存(管理员功能)",
tags=["排行榜", "管理"], # tags=["排行榜", "管理"],
) # )
async def clear_country_ranking_cache( # async def clear_country_ranking_cache(
ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"), # ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限 # current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
): # ):
redis = get_redis() # redis = get_redis()
cache_service = get_ranking_cache_service(redis) # cache_service = get_ranking_cache_service(redis)
await cache_service.invalidate_country_cache(ruleset) # await cache_service.invalidate_country_cache(ruleset)
if ruleset: # if ruleset:
return {"message": f"Cleared country ranking cache for {ruleset}"} # return {"message": f"Cleared country ranking cache for {ruleset}"}
else: # else:
return {"message": "Cleared all country ranking caches"} # return {"message": "Cleared all country ranking caches"}
@router.get( # @router.get(
"/rankings/cache/stats", # "/rankings/cache/stats",
name="获取排行榜缓存统计", # name="获取排行榜缓存统计",
description="获取排行榜缓存统计信息(管理员功能)", # description="获取排行榜缓存统计信息(管理员功能)",
tags=["排行榜", "管理"], # tags=["排行榜", "管理"],
) # )
async def get_ranking_cache_stats( # async def get_ranking_cache_stats(
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限 # current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
): # ):
redis = get_redis() # redis = get_redis()
cache_service = get_ranking_cache_service(redis) # cache_service = get_ranking_cache_service(redis)
stats = await cache_service.get_cache_stats() # stats = await cache_service.get_cache_stats()
return stats """ # return stats

View File

@@ -30,11 +30,7 @@ async def get_relationship(
request: Request, request: Request,
current_user: User = Security(get_current_user, scopes=["friends.read"]), current_user: User = Security(get_current_user, scopes=["friends.read"]),
): ):
relationship_type = ( relationship_type = RelationshipType.FOLLOW if request.url.path.endswith("/friends") else RelationshipType.BLOCK
RelationshipType.FOLLOW
if request.url.path.endswith("/friends")
else RelationshipType.BLOCK
)
relationships = await db.exec( relationships = await db.exec(
select(Relationship).where( select(Relationship).where(
Relationship.user_id == current_user.id, Relationship.user_id == current_user.id,
@@ -71,12 +67,7 @@ async def add_relationship(
target: int = Query(description="目标用户 ID"), target: int = Query(description="目标用户 ID"),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
): ):
assert current_user.id is not None relationship_type = RelationshipType.FOLLOW if request.url.path.endswith("/friends") else RelationshipType.BLOCK
relationship_type = (
RelationshipType.FOLLOW
if request.url.path.endswith("/friends")
else RelationshipType.BLOCK
)
if target == current_user.id: if target == current_user.id:
raise HTTPException(422, "Cannot add relationship to yourself") raise HTTPException(422, "Cannot add relationship to yourself")
relationship = ( relationship = (
@@ -120,11 +111,8 @@ async def add_relationship(
Relationship.target_id == target, Relationship.target_id == target,
) )
) )
).first() ).one()
assert relationship, "Relationship should exist after commit" return AddFriendResp(user_relation=await RelationshipResp.from_db(db, relationship))
return AddFriendResp(
user_relation=await RelationshipResp.from_db(db, relationship)
)
@router.delete( @router.delete(
@@ -145,11 +133,7 @@ async def delete_relationship(
target: int = Path(..., description="目标用户 ID"), target: int = Path(..., description="目标用户 ID"),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
): ):
relationship_type = ( relationship_type = RelationshipType.BLOCK if "/blocks/" in request.url.path else RelationshipType.FOLLOW
RelationshipType.BLOCK
if "/blocks/" in request.url.path
else RelationshipType.FOLLOW
)
relationship = ( relationship = (
await db.exec( await db.exec(
select(Relationship).where( select(Relationship).where(

View File

@@ -39,17 +39,11 @@ async def get_all_rooms(
db: Database, db: Database,
mode: Literal["open", "ended", "participated", "owned", None] = Query( mode: Literal["open", "ended", "participated", "owned", None] = Query(
default="open", default="open",
description=( description=("房间模式open 当前开放 / ended 已经结束 / participated 参与过 / owned 自己创建的房间"),
"房间模式open 当前开放 / ended 已经结束 / "
"participated 参与过 / owned 自己创建的房间"
),
), ),
category: RoomCategory = Query( category: RoomCategory = Query(
RoomCategory.NORMAL, RoomCategory.NORMAL,
description=( description=("房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战"),
"房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间"
" / DAILY_CHALLENGE 每日挑战"
),
), ),
status: RoomStatus | None = Query(None, description="房间状态(可选)"), status: RoomStatus | None = Query(None, description="房间状态(可选)"),
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
@@ -60,10 +54,7 @@ async def get_all_rooms(
if status is not None: if status is not None:
where_clauses.append(col(Room.status) == status) where_clauses.append(col(Room.status) == status)
if mode == "open": if mode == "open":
where_clauses.append( where_clauses.append((col(Room.ends_at).is_(None)) | (col(Room.ends_at) > now.replace(tzinfo=UTC)))
(col(Room.ends_at).is_(None))
| (col(Room.ends_at) > now.replace(tzinfo=UTC))
)
if category == RoomCategory.REALTIME: if category == RoomCategory.REALTIME:
where_clauses.append(col(Room.id).in_(MultiplayerHubs.rooms.keys())) where_clauses.append(col(Room.id).in_(MultiplayerHubs.rooms.keys()))
if mode == "participated": if mode == "participated":
@@ -76,10 +67,7 @@ async def get_all_rooms(
if mode == "owned": if mode == "owned":
where_clauses.append(col(Room.host_id) == current_user.id) where_clauses.append(col(Room.host_id) == current_user.id)
if mode == "ended": if mode == "ended":
where_clauses.append( where_clauses.append((col(Room.ends_at).is_not(None)) & (col(Room.ends_at) < now.replace(tzinfo=UTC)))
(col(Room.ends_at).is_not(None))
& (col(Room.ends_at) < now.replace(tzinfo=UTC))
)
db_rooms = ( db_rooms = (
( (
@@ -97,11 +85,7 @@ async def get_all_rooms(
resp = await RoomResp.from_db(room, db) resp = await RoomResp.from_db(room, db)
if category == RoomCategory.REALTIME: if category == RoomCategory.REALTIME:
mp_room = MultiplayerHubs.rooms.get(room.id) mp_room = MultiplayerHubs.rooms.get(room.id)
resp.has_password = ( resp.has_password = bool(mp_room.room.settings.password.strip()) if mp_room is not None else False
bool(mp_room.room.settings.password.strip())
if mp_room is not None
else False
)
resp.category = RoomCategory.NORMAL resp.category = RoomCategory.NORMAL
resp_list.append(resp) resp_list.append(resp)
@@ -115,9 +99,7 @@ class APICreatedRoom(RoomResp):
error: str = "" error: str = ""
async def _participate_room( async def _participate_room(room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis):
room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis
):
participated_user = ( participated_user = (
await session.exec( await session.exec(
select(RoomParticipatedUser).where( select(RoomParticipatedUser).where(
@@ -154,7 +136,6 @@ async def create_room(
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
): ):
assert current_user.id is not None
user_id = current_user.id user_id = current_user.id
db_room = await create_playlist_room_from_api(db, room, user_id) db_room = await create_playlist_room_from_api(db, room, user_id)
await _participate_room(db_room.id, user_id, db_room, db, redis) await _participate_room(db_room.id, user_id, db_room, db, redis)
@@ -177,10 +158,7 @@ async def get_room(
room_id: int = Path(..., description="房间 ID"), room_id: int = Path(..., description="房间 ID"),
category: str = Query( category: str = Query(
default="", default="",
description=( description=("房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战 (可选)"),
"房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间"
" / DAILY_CHALLENGE 每日挑战 (可选)"
),
), ),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
@@ -188,9 +166,7 @@ async def get_room(
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first() db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is None: if db_room is None:
raise HTTPException(404, "Room not found") raise HTTPException(404, "Room not found")
resp = await RoomResp.from_db( resp = await RoomResp.from_db(db_room, include=["current_user_score"], session=db, user=current_user)
db_room, include=["current_user_score"], session=db, user=current_user
)
return resp return resp
@@ -400,7 +376,6 @@ async def get_room_events(
for score in scores: for score in scores:
user_ids.add(score.user_id) user_ids.add(score.user_id)
beatmap_ids.add(score.beatmap_id) beatmap_ids.add(score.beatmap_id)
assert event.id is not None
first_event_id = min(first_event_id, event.id) first_event_id = min(first_event_id, event.id)
last_event_id = max(last_event_id, event.id) last_event_id = max(last_event_id, event.id)
@@ -416,16 +391,12 @@ async def get_room_events(
users = await db.exec(select(User).where(col(User.id).in_(user_ids))) users = await db.exec(select(User).where(col(User.id).in_(user_ids)))
user_resps = [await UserResp.from_db(user, db) for user in users] user_resps = [await UserResp.from_db(user, db) for user in users]
beatmaps = await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids))) beatmaps = await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)))
beatmap_resps = [ beatmap_resps = [await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps]
await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps
]
beatmapset_resps = {} beatmapset_resps = {}
for beatmap_resp in beatmap_resps: for beatmap_resp in beatmap_resps:
beatmapset_resps[beatmap_resp.beatmapset_id] = beatmap_resp.beatmapset beatmapset_resps[beatmap_resp.beatmapset_id] = beatmap_resp.beatmapset
playlist_items_resps = [ playlist_items_resps = [await PlaylistResp.from_db(item) for item in playlist_items.values()]
await PlaylistResp.from_db(item) for item in playlist_items.values()
]
return RoomEvents( return RoomEvents(
beatmaps=beatmap_resps, beatmaps=beatmap_resps,

View File

@@ -104,11 +104,7 @@ async def submit_score(
if not info.passed: if not info.passed:
info.rank = Rank.F info.rank = Rank.F
score_token = ( score_token = (
await db.exec( await db.exec(select(ScoreToken).options(joinedload(ScoreToken.beatmap)).where(ScoreToken.id == token))
select(ScoreToken)
.options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType]
.where(ScoreToken.id == token)
)
).first() ).first()
if not score_token or score_token.user_id != user_id: if not score_token or score_token.user_id != user_id:
raise HTTPException(status_code=404, detail="Score token not found") raise HTTPException(status_code=404, detail="Score token not found")
@@ -138,10 +134,7 @@ async def submit_score(
except HTTPError: except HTTPError:
raise HTTPException(status_code=404, detail="Beatmap not found") raise HTTPException(status_code=404, detail="Beatmap not found")
has_pp = db_beatmap.beatmap_status.has_pp() | settings.enable_all_beatmap_pp has_pp = db_beatmap.beatmap_status.has_pp() | settings.enable_all_beatmap_pp
has_leaderboard = ( has_leaderboard = db_beatmap.beatmap_status.has_leaderboard() | settings.enable_all_beatmap_leaderboard
db_beatmap.beatmap_status.has_leaderboard()
| settings.enable_all_beatmap_leaderboard
)
beatmap_length = db_beatmap.total_length beatmap_length = db_beatmap.total_length
score = await process_score( score = await process_score(
current_user, current_user,
@@ -167,21 +160,11 @@ async def submit_score(
has_pp, has_pp,
has_leaderboard, has_leaderboard,
) )
score = ( score = (await db.exec(select(Score).options(joinedload(Score.user)).where(Score.id == score_id))).one()
await db.exec(
select(Score)
.options(joinedload(Score.user)) # pyright: ignore[reportArgumentType]
.where(Score.id == score_id)
)
).first()
assert score is not None
resp = await ScoreResp.from_db(db, score) resp = await ScoreResp.from_db(db, score)
total_users = (await db.exec(select(func.count()).select_from(User))).first() total_users = (await db.exec(select(func.count()).select_from(User))).one()
assert total_users is not None if resp.rank_global is not None and resp.rank_global <= min(math.ceil(float(total_users) * 0.01), 50):
if resp.rank_global is not None and resp.rank_global <= min(
math.ceil(float(total_users) * 0.01), 50
):
rank_event = Event( rank_event = Event(
created_at=datetime.now(UTC), created_at=datetime.now(UTC),
type=EventType.RANK, type=EventType.RANK,
@@ -207,9 +190,7 @@ async def submit_score(
score_gamemode = score.gamemode score_gamemode = score.gamemode
if user_id is not None: if user_id is not None:
background_task.add_task( background_task.add_task(_refresh_user_cache_background, redis, user_id, score_gamemode)
_refresh_user_cache_background, redis, user_id, score_gamemode
)
background_task.add_task(process_user_achievement, resp.id) background_task.add_task(process_user_achievement, resp.id)
return resp return resp
@@ -225,9 +206,7 @@ async def _refresh_user_cache_background(redis: Redis, user_id: int, mode: GameM
# 创建独立的数据库会话 # 创建独立的数据库会话
session = AsyncSession(engine) session = AsyncSession(engine)
try: try:
await user_cache_service.refresh_user_cache_on_score_submit( await user_cache_service.refresh_user_cache_on_score_submit(session, user_id, mode)
session, user_id, mode
)
finally: finally:
await session.close() await session.close()
except Exception as e: except Exception as e:
@@ -280,22 +259,16 @@ async def get_beatmap_scores(
beatmap_id: int = Path(description="谱面 ID"), beatmap_id: int = Path(description="谱面 ID"),
mode: GameMode = Query(description="指定 auleset"), mode: GameMode = Query(description="指定 auleset"),
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"), legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
mods: list[str] = Query( mods: list[str] = Query(default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)"),
default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)"
),
type: LeaderboardType = Query( type: LeaderboardType = Query(
LeaderboardType.GLOBAL, LeaderboardType.GLOBAL,
description=( description=("排行榜类型GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队"),
"排行榜类型GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队"
),
), ),
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
limit: int = Query(50, ge=1, le=200, description="返回条数 (1-200)"), limit: int = Query(50, ge=1, le=200, description="返回条数 (1-200)"),
): ):
if legacy_only: if legacy_only:
raise HTTPException( raise HTTPException(status_code=404, detail="this server only contains lazer scores")
status_code=404, detail="this server only contains lazer scores"
)
all_scores, user_score, count = await get_leaderboard( all_scores, user_score, count = await get_leaderboard(
db, db,
@@ -310,9 +283,7 @@ async def get_beatmap_scores(
user_score_resp = await ScoreResp.from_db(db, user_score) if user_score else None user_score_resp = await ScoreResp.from_db(db, user_score) if user_score else None
resp = BeatmapScores( resp = BeatmapScores(
scores=[await ScoreResp.from_db(db, score) for score in all_scores], scores=[await ScoreResp.from_db(db, score) for score in all_scores],
user_score=BeatmapUserScore( user_score=BeatmapUserScore(score=user_score_resp, position=user_score_resp.rank_global or 0)
score=user_score_resp, position=user_score_resp.rank_global or 0
)
if user_score_resp if user_score_resp
else None, else None,
score_count=count, score_count=count,
@@ -342,9 +313,7 @@ async def get_user_beatmap_score(
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
): ):
if legacy_only: if legacy_only:
raise HTTPException( raise HTTPException(status_code=404, detail="This server only contains non-legacy scores")
status_code=404, detail="This server only contains non-legacy scores"
)
user_score = ( user_score = (
await db.exec( await db.exec(
select(Score) select(Score)
@@ -386,9 +355,7 @@ async def get_user_all_beatmap_scores(
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
): ):
if legacy_only: if legacy_only:
raise HTTPException( raise HTTPException(status_code=404, detail="This server only contains non-legacy scores")
status_code=404, detail="This server only contains non-legacy scores"
)
all_user_scores = ( all_user_scores = (
await db.exec( await db.exec(
select(Score) select(Score)
@@ -420,7 +387,6 @@ async def create_solo_score(
ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"), ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
): ):
assert current_user.id is not None
# 立即获取用户ID避免懒加载问题 # 立即获取用户ID避免懒加载问题
user_id = current_user.id user_id = current_user.id
@@ -454,10 +420,7 @@ async def submit_solo_score(
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
fetcher=Depends(get_fetcher), fetcher=Depends(get_fetcher),
): ):
assert current_user.id is not None return await submit_score(background_task, info, beatmap_id, token, current_user, db, redis, fetcher)
return await submit_score(
background_task, info, beatmap_id, token, current_user, db, redis, fetcher
)
@router.post( @router.post(
@@ -478,7 +441,6 @@ async def create_playlist_score(
version_hash: str = Form("", description="谱面版本哈希"), version_hash: str = Form("", description="谱面版本哈希"),
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
): ):
assert current_user.id is not None
# 立即获取用户ID避免懒加载问题 # 立即获取用户ID避免懒加载问题
user_id = current_user.id user_id = current_user.id
@@ -488,26 +450,16 @@ async def create_playlist_score(
db_room_time = room.ends_at.replace(tzinfo=UTC) if room.ends_at else None db_room_time = room.ends_at.replace(tzinfo=UTC) if room.ends_at else None
if db_room_time and db_room_time < datetime.now(UTC).replace(tzinfo=UTC): if db_room_time and db_room_time < datetime.now(UTC).replace(tzinfo=UTC):
raise HTTPException(status_code=400, detail="Room has ended") raise HTTPException(status_code=400, detail="Room has ended")
item = ( item = (await session.exec(select(Playlist).where(Playlist.id == playlist_id, Playlist.room_id == room_id))).first()
await session.exec(
select(Playlist).where(
Playlist.id == playlist_id, Playlist.room_id == room_id
)
)
).first()
if not item: if not item:
raise HTTPException(status_code=404, detail="Playlist not found") raise HTTPException(status_code=404, detail="Playlist not found")
# validate # validate
if not item.freestyle: if not item.freestyle:
if item.ruleset_id != ruleset_id: if item.ruleset_id != ruleset_id:
raise HTTPException( raise HTTPException(status_code=400, detail="Ruleset mismatch in playlist item")
status_code=400, detail="Ruleset mismatch in playlist item"
)
if item.beatmap_id != beatmap_id: if item.beatmap_id != beatmap_id:
raise HTTPException( raise HTTPException(status_code=400, detail="Beatmap ID mismatch in playlist item")
status_code=400, detail="Beatmap ID mismatch in playlist item"
)
agg = await session.exec( agg = await session.exec(
select(ItemAttemptsCount).where( select(ItemAttemptsCount).where(
ItemAttemptsCount.room_id == room_id, ItemAttemptsCount.room_id == room_id,
@@ -523,9 +475,7 @@ async def create_playlist_score(
if item.expired: if item.expired:
raise HTTPException(status_code=400, detail="Playlist item has expired") raise HTTPException(status_code=400, detail="Playlist item has expired")
if item.played_at: if item.played_at:
raise HTTPException( raise HTTPException(status_code=400, detail="Playlist item has already been played")
status_code=400, detail="Playlist item has already been played"
)
# 这里应该不用验证mod了吧。。。 # 这里应该不用验证mod了吧。。。
background_task.add_task(_preload_beatmap_for_pp_calculation, beatmap_id) background_task.add_task(_preload_beatmap_for_pp_calculation, beatmap_id)
score_token = ScoreToken( score_token = ScoreToken(
@@ -557,18 +507,10 @@ async def submit_playlist_score(
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
fetcher: Fetcher = Depends(get_fetcher), fetcher: Fetcher = Depends(get_fetcher),
): ):
assert current_user.id is not None
# 立即获取用户ID避免懒加载问题 # 立即获取用户ID避免懒加载问题
user_id = current_user.id user_id = current_user.id
item = ( item = (await session.exec(select(Playlist).where(Playlist.id == playlist_id, Playlist.room_id == room_id))).first()
await session.exec(
select(Playlist).where(
Playlist.id == playlist_id, Playlist.room_id == room_id
)
)
).first()
if not item: if not item:
raise HTTPException(status_code=404, detail="Playlist item not found") raise HTTPException(status_code=404, detail="Playlist item not found")
room = await session.get(Room, room_id) room = await session.get(Room, room_id)
@@ -621,9 +563,7 @@ async def index_playlist_scores(
room_id: int, room_id: int,
playlist_id: int, playlist_id: int,
limit: int = Query(50, ge=1, le=50, description="返回条数 (1-50)"), limit: int = Query(50, ge=1, le=50, description="返回条数 (1-50)"),
cursor: int = Query( cursor: int = Query(2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)"),
2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)"
),
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
): ):
# 立即获取用户ID避免懒加载问题 # 立即获取用户ID避免懒加载问题
@@ -693,9 +633,6 @@ async def show_playlist_score(
current_user: User = Security(get_client_user), current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis), redis: Redis = Depends(get_redis),
): ):
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
room = await session.get(Room, room_id) room = await session.get(Room, room_id)
if not room: if not room:
raise HTTPException(status_code=404, detail="Room not found") raise HTTPException(status_code=404, detail="Room not found")
@@ -715,9 +652,7 @@ async def show_playlist_score(
) )
) )
).first() ).first()
if completed_players := await redis.get( if completed_players := await redis.get(f"multiplayer:{room_id}:gameplay:players"):
f"multiplayer:{room_id}:gameplay:players"
):
completed = completed_players == "0" completed = completed_players == "0"
if score_record and completed: if score_record and completed:
break break
@@ -784,9 +719,7 @@ async def get_user_playlist_score(
raise HTTPException(status_code=404, detail="Score not found") raise HTTPException(status_code=404, detail="Score not found")
resp = await ScoreResp.from_db(session, score_record.score) resp = await ScoreResp.from_db(session, score_record.score)
resp.position = await get_position( resp.position = await get_position(room_id, playlist_id, score_record.score_id, session)
room_id, playlist_id, score_record.score_id, session
)
return resp return resp
@@ -850,11 +783,7 @@ async def unpin_score(
# 立即获取用户ID避免懒加载问题 # 立即获取用户ID避免懒加载问题
user_id = current_user.id user_id = current_user.id
score_record = ( score_record = (await db.exec(select(Score).where(Score.id == score_id, Score.user_id == user_id))).first()
await db.exec(
select(Score).where(Score.id == score_id, Score.user_id == user_id)
)
).first()
if not score_record: if not score_record:
raise HTTPException(status_code=404, detail="Score not found") raise HTTPException(status_code=404, detail="Score not found")
@@ -878,10 +807,7 @@ async def unpin_score(
"/score-pins/{score_id}/reorder", "/score-pins/{score_id}/reorder",
status_code=204, status_code=204,
name="调整置顶成绩顺序", name="调整置顶成绩顺序",
description=( description=("**客户端专属**\n调整已置顶成绩的展示顺序。仅提供 after_score_id 或 before_score_id 之一。"),
"**客户端专属**\n调整已置顶成绩的展示顺序。"
"仅提供 after_score_id 或 before_score_id 之一。"
),
tags=["成绩"], tags=["成绩"],
) )
async def reorder_score_pin( async def reorder_score_pin(
@@ -894,11 +820,7 @@ async def reorder_score_pin(
# 立即获取用户ID避免懒加载问题 # 立即获取用户ID避免懒加载问题
user_id = current_user.id user_id = current_user.id
score_record = ( score_record = (await db.exec(select(Score).where(Score.id == score_id, Score.user_id == user_id))).first()
await db.exec(
select(Score).where(Score.id == score_id, Score.user_id == user_id)
)
).first()
if not score_record: if not score_record:
raise HTTPException(status_code=404, detail="Score not found") raise HTTPException(status_code=404, detail="Score not found")
@@ -908,8 +830,7 @@ async def reorder_score_pin(
if (after_score_id is None) == (before_score_id is None): if (after_score_id is None) == (before_score_id is None):
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail="Either after_score_id or before_score_id " detail="Either after_score_id or before_score_id must be provided (but not both)",
"must be provided (but not both)",
) )
all_pinned_scores = ( all_pinned_scores = (
@@ -927,9 +848,7 @@ async def reorder_score_pin(
target_order = None target_order = None
reference_score_id = after_score_id or before_score_id reference_score_id = after_score_id or before_score_id
reference_score = next( reference_score = next((s for s in all_pinned_scores if s.id == reference_score_id), None)
(s for s in all_pinned_scores if s.id == reference_score_id), None
)
if not reference_score: if not reference_score:
detail = "After score not found" if after_score_id else "Before score not found" detail = "After score not found" if after_score_id else "Before score not found"
raise HTTPException(status_code=404, detail=detail) raise HTTPException(status_code=404, detail=detail)
@@ -951,9 +870,7 @@ async def reorder_score_pin(
if current_order < s.pinned_order <= target_order and s.id != score_id: if current_order < s.pinned_order <= target_order and s.id != score_id:
updates.append((s.id, s.pinned_order - 1)) updates.append((s.id, s.pinned_order - 1))
if after_score_id: if after_score_id:
final_target = ( final_target = target_order - 1 if target_order > current_order else target_order
target_order - 1 if target_order > current_order else target_order
)
else: else:
final_target = target_order final_target = target_order
else: else:
@@ -964,9 +881,7 @@ async def reorder_score_pin(
for score_id, new_order in updates: for score_id, new_order in updates:
await db.exec(select(Score).where(Score.id == score_id)) await db.exec(select(Score).where(Score.id == score_id))
score_to_update = ( score_to_update = (await db.exec(select(Score).where(Score.id == score_id))).first()
await db.exec(select(Score).where(Score.id == score_id))
).first()
if score_to_update: if score_to_update:
score_to_update.pinned_order = new_order score_to_update.pinned_order = new_order

View File

@@ -4,34 +4,29 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime, UTC
from typing import Annotated from typing import Annotated
from app.auth import authenticate_user
from app.config import settings
from app.database import User from app.database import User
from app.dependencies import get_current_user from app.dependencies import get_current_user
from app.dependencies.database import Database, get_redis from app.dependencies.database import Database, get_redis
from app.dependencies.geoip import GeoIPHelper, get_geoip_helper from app.dependencies.geoip import GeoIPHelper, get_geoip_helper
from app.database.email_verification import EmailVerification, LoginSession
from app.service.email_verification_service import ( from app.service.email_verification_service import (
EmailVerificationService, EmailVerificationService,
LoginSessionService LoginSessionService,
) )
from app.service.login_log_service import LoginLogService from app.service.login_log_service import LoginLogService
from app.models.extended_auth import ExtendedTokenResponse
from fastapi import Form, Depends, Request, HTTPException, status, Security
from fastapi.responses import JSONResponse, Response
from pydantic import BaseModel
from redis.asyncio import Redis
from sqlmodel import select
from .router import router from .router import router
from fastapi import Depends, Form, HTTPException, Request, Security, status
from fastapi.responses import Response
from pydantic import BaseModel
from redis.asyncio import Redis
class SessionReissueResponse(BaseModel): class SessionReissueResponse(BaseModel):
"""重新发送验证码响应""" """重新发送验证码响应"""
success: bool success: bool
message: str message: str
@@ -40,39 +35,35 @@ class SessionReissueResponse(BaseModel):
"/session/verify", "/session/verify",
name="验证会话", name="验证会话",
description="验证邮件验证码并完成会话认证", description="验证邮件验证码并完成会话认证",
status_code=204 status_code=204,
) )
async def verify_session( async def verify_session(
request: Request, request: Request,
db: Database, db: Database,
redis: Annotated[Redis, Depends(get_redis)], redis: Annotated[Redis, Depends(get_redis)],
verification_key: str = Form(..., description="8位邮件验证码"), verification_key: str = Form(..., description="8位邮件验证码"),
current_user: User = Security(get_current_user) current_user: User = Security(get_current_user),
) -> Response: ) -> Response:
""" """
验证邮件验证码并完成会话认证 验证邮件验证码并完成会话认证
对应 osu! 的 session/verify 接口 对应 osu! 的 session/verify 接口
成功时返回 204 No Content失败时返回 401 Unauthorized 成功时返回 204 No Content失败时返回 401 Unauthorized
""" """
try: try:
from app.dependencies.geoip import get_client_ip from app.dependencies.geoip import get_client_ip
ip_address = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "Unknown") ip_address = get_client_ip(request) # noqa: F841
user_agent = request.headers.get("User-Agent", "Unknown") # noqa: F841
# 从当前认证用户获取信息 # 从当前认证用户获取信息
user_id = current_user.id user_id = current_user.id
if not user_id: if not user_id:
raise HTTPException( raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户未认证")
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户未认证"
)
# 验证邮件验证码 # 验证邮件验证码
success, message = await EmailVerificationService.verify_code( success, message = await EmailVerificationService.verify_code(db, redis, user_id, verification_key)
db, redis, user_id, verification_key
)
if success: if success:
# 记录成功的邮件验证 # 记录成功的邮件验证
await LoginLogService.record_login( await LoginLogService.record_login(
@@ -81,9 +72,9 @@ async def verify_session(
request=request, request=request,
login_method="email_verification", login_method="email_verification",
login_success=True, login_success=True,
notes=f"邮件验证成功" notes="邮件验证成功",
) )
# 返回 204 No Content 表示验证成功 # 返回 204 No Content 表示验证成功
return Response(status_code=status.HTTP_204_NO_CONTENT) return Response(status_code=status.HTTP_204_NO_CONTENT)
else: else:
@@ -93,83 +84,69 @@ async def verify_session(
request=request, request=request,
attempted_username=current_user.username, attempted_username=current_user.username,
login_method="email_verification", login_method="email_verification",
notes=f"邮件验证失败: {message}" notes=f"邮件验证失败: {message}",
) )
# 返回 401 Unauthorized 表示验证失败 # 返回 401 Unauthorized 表示验证失败
raise HTTPException( raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=message)
status_code=status.HTTP_401_UNAUTHORIZED,
detail=message
)
except ValueError: except ValueError:
raise HTTPException( raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的用户会话")
status_code=status.HTTP_401_UNAUTHORIZED, except Exception:
detail="无效的用户会话" raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="验证过程中发生错误")
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="验证过程中发生错误"
)
@router.post( @router.post(
"/session/verify/reissue", "/session/verify/reissue",
name="重新发送验证码", name="重新发送验证码",
description="重新发送邮件验证码", description="重新发送邮件验证码",
response_model=SessionReissueResponse response_model=SessionReissueResponse,
) )
async def reissue_verification_code( async def reissue_verification_code(
request: Request, request: Request,
db: Database, db: Database,
redis: Annotated[Redis, Depends(get_redis)], redis: Annotated[Redis, Depends(get_redis)],
current_user: User = Security(get_current_user) current_user: User = Security(get_current_user),
) -> SessionReissueResponse: ) -> SessionReissueResponse:
""" """
重新发送邮件验证码 重新发送邮件验证码
对应 osu! 的 session/verify/reissue 接口 对应 osu! 的 session/verify/reissue 接口
""" """
try: try:
from app.dependencies.geoip import get_client_ip from app.dependencies.geoip import get_client_ip
ip_address = get_client_ip(request) ip_address = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "Unknown") user_agent = request.headers.get("User-Agent", "Unknown")
# 从当前认证用户获取信息 # 从当前认证用户获取信息
user_id = current_user.id user_id = current_user.id
if not user_id: if not user_id:
return SessionReissueResponse( return SessionReissueResponse(success=False, message="用户未认证")
success=False,
message="用户未认证"
)
# 重新发送验证码 # 重新发送验证码
success, message = await EmailVerificationService.resend_verification_code( success, message = await EmailVerificationService.resend_verification_code(
db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent db,
redis,
user_id,
current_user.username,
current_user.email,
ip_address,
user_agent,
) )
return SessionReissueResponse( return SessionReissueResponse(success=success, message=message)
success=success,
message=message
)
except ValueError: except ValueError:
return SessionReissueResponse( return SessionReissueResponse(success=False, message="无效的用户会话")
success=False, except Exception:
message="无效的用户会话" return SessionReissueResponse(success=False, message="重新发送过程中发生错误")
)
except Exception as e:
return SessionReissueResponse(
success=False,
message="重新发送过程中发生错误"
)
@router.post( @router.post(
"/session/check-new-location", "/session/check-new-location",
name="检查新位置登录", name="检查新位置登录",
description="检查登录是否来自新位置(内部接口)" description="检查登录是否来自新位置(内部接口)",
) )
async def check_new_location( async def check_new_location(
request: Request, request: Request,
@@ -183,22 +160,21 @@ async def check_new_location(
""" """
try: try:
from app.dependencies.geoip import get_client_ip from app.dependencies.geoip import get_client_ip
ip_address = get_client_ip(request) ip_address = get_client_ip(request)
geo_info = geoip.lookup(ip_address) geo_info = geoip.lookup(ip_address)
country_code = geo_info.get("country_iso", "XX") country_code = geo_info.get("country_iso", "XX")
is_new_location = await LoginSessionService.check_new_location( is_new_location = await LoginSessionService.check_new_location(db, user_id, ip_address, country_code)
db, user_id, ip_address, country_code
)
return { return {
"is_new_location": is_new_location, "is_new_location": is_new_location,
"ip_address": ip_address, "ip_address": ip_address,
"country_code": country_code "country_code": country_code,
} }
except Exception as e: except Exception as e:
return { return {
"is_new_location": True, # 出错时默认为新位置 "is_new_location": True, # 出错时默认为新位置
"error": str(e) "error": str(e),
} }

View File

@@ -1,73 +1,80 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from datetime import datetime, timedelta
import json
from typing import Any
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
import json
from app.dependencies.database import get_redis, get_redis_message from app.dependencies.database import get_redis, get_redis_message
from app.log import logger from app.log import logger
from app.utils import bg_tasks
from .router import router from .router import router
from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
# Redis key constants # Redis key constants
REDIS_ONLINE_USERS_KEY = "server:online_users" REDIS_ONLINE_USERS_KEY = "server:online_users"
REDIS_PLAYING_USERS_KEY = "server:playing_users" REDIS_PLAYING_USERS_KEY = "server:playing_users"
REDIS_REGISTERED_USERS_KEY = "server:registered_users" REDIS_REGISTERED_USERS_KEY = "server:registered_users"
REDIS_ONLINE_HISTORY_KEY = "server:online_history" REDIS_ONLINE_HISTORY_KEY = "server:online_history"
# 线程池用于同步Redis操作 # 线程池用于同步Redis操作
_executor = ThreadPoolExecutor(max_workers=2) _executor = ThreadPoolExecutor(max_workers=2)
async def _redis_exec(func, *args, **kwargs): async def _redis_exec(func, *args, **kwargs):
"""在线程池中执行同步Redis操作""" """在线程池中执行同步Redis操作"""
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return await loop.run_in_executor(_executor, func, *args, **kwargs) return await loop.run_in_executor(_executor, func, *args, **kwargs)
class ServerStats(BaseModel): class ServerStats(BaseModel):
"""服务器统计信息响应模型""" """服务器统计信息响应模型"""
registered_users: int registered_users: int
online_users: int online_users: int
playing_users: int playing_users: int
timestamp: datetime timestamp: datetime
class OnlineHistoryPoint(BaseModel): class OnlineHistoryPoint(BaseModel):
"""在线历史数据点""" """在线历史数据点"""
timestamp: datetime timestamp: datetime
online_count: int online_count: int
playing_count: int playing_count: int
class OnlineHistoryResponse(BaseModel): class OnlineHistoryResponse(BaseModel):
"""24小时在线历史响应模型""" """24小时在线历史响应模型"""
history: list[OnlineHistoryPoint] history: list[OnlineHistoryPoint]
current_stats: ServerStats current_stats: ServerStats
@router.get("/stats", response_model=ServerStats, tags=["统计"]) @router.get("/stats", response_model=ServerStats, tags=["统计"])
async def get_server_stats() -> ServerStats: async def get_server_stats() -> ServerStats:
""" """
获取服务器实时统计信息 获取服务器实时统计信息
返回服务器注册用户数、在线用户数、正在游玩用户数等实时统计信息 返回服务器注册用户数、在线用户数、正在游玩用户数等实时统计信息
""" """
redis = get_redis() redis = get_redis()
try: try:
# 并行获取所有统计数据 # 并行获取所有统计数据
registered_count, online_count, playing_count = await asyncio.gather( registered_count, online_count, playing_count = await asyncio.gather(
_get_registered_users_count(redis), _get_registered_users_count(redis),
_get_online_users_count(redis), _get_online_users_count(redis),
_get_playing_users_count(redis) _get_playing_users_count(redis),
) )
return ServerStats( return ServerStats(
registered_users=registered_count, registered_users=registered_count,
online_users=online_count, online_users=online_count,
playing_users=playing_count, playing_users=playing_count,
timestamp=datetime.utcnow() timestamp=datetime.utcnow(),
) )
except Exception as e: except Exception as e:
logger.error(f"Error getting server stats: {e}") logger.error(f"Error getting server stats: {e}")
@@ -76,14 +83,15 @@ async def get_server_stats() -> ServerStats:
registered_users=0, registered_users=0,
online_users=0, online_users=0,
playing_users=0, playing_users=0,
timestamp=datetime.utcnow() timestamp=datetime.utcnow(),
) )
@router.get("/stats/history", response_model=OnlineHistoryResponse, tags=["统计"]) @router.get("/stats/history", response_model=OnlineHistoryResponse, tags=["统计"])
async def get_online_history() -> OnlineHistoryResponse: async def get_online_history() -> OnlineHistoryResponse:
""" """
获取最近24小时在线统计历史 获取最近24小时在线统计历史
返回过去24小时内每小时的在线用户数和游玩用户数统计 返回过去24小时内每小时的在线用户数和游玩用户数统计
包含当前实时数据作为最新数据点 包含当前实时数据作为最新数据点
""" """
@@ -92,80 +100,80 @@ async def get_online_history() -> OnlineHistoryResponse:
redis_sync = get_redis_message() redis_sync = get_redis_message()
history_data = await _redis_exec(redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1) history_data = await _redis_exec(redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1)
history_points = [] history_points = []
# 处理历史数据 # 处理历史数据
for data in history_data: for data in history_data:
try: try:
point_data = json.loads(data) point_data = json.loads(data)
# 只保留基本字段 # 只保留基本字段
history_points.append(OnlineHistoryPoint( history_points.append(
timestamp=datetime.fromisoformat(point_data["timestamp"]), OnlineHistoryPoint(
online_count=point_data["online_count"], timestamp=datetime.fromisoformat(point_data["timestamp"]),
playing_count=point_data["playing_count"] online_count=point_data["online_count"],
)) playing_count=point_data["playing_count"],
)
)
except (json.JSONDecodeError, KeyError, ValueError) as e: except (json.JSONDecodeError, KeyError, ValueError) as e:
logger.warning(f"Invalid history data point: {data}, error: {e}") logger.warning(f"Invalid history data point: {data}, error: {e}")
continue continue
# 获取当前实时统计信息 # 获取当前实时统计信息
current_stats = await get_server_stats() current_stats = await get_server_stats()
# 如果历史数据为空或者最新数据超过15分钟添加当前数据点 # 如果历史数据为空或者最新数据超过15分钟添加当前数据点
if not history_points or ( if not history_points or (
history_points and history_points
(current_stats.timestamp - max(history_points, key=lambda x: x.timestamp).timestamp).total_seconds() > 15 * 60 and (current_stats.timestamp - max(history_points, key=lambda x: x.timestamp).timestamp).total_seconds()
> 15 * 60
): ):
history_points.append(OnlineHistoryPoint( history_points.append(
timestamp=current_stats.timestamp, OnlineHistoryPoint(
online_count=current_stats.online_users, timestamp=current_stats.timestamp,
playing_count=current_stats.playing_users online_count=current_stats.online_users,
)) playing_count=current_stats.playing_users,
)
)
# 按时间排序(最新的在前) # 按时间排序(最新的在前)
history_points.sort(key=lambda x: x.timestamp, reverse=True) history_points.sort(key=lambda x: x.timestamp, reverse=True)
# 限制到最多48个数据点24小时 # 限制到最多48个数据点24小时
history_points = history_points[:48] history_points = history_points[:48]
return OnlineHistoryResponse( return OnlineHistoryResponse(history=history_points, current_stats=current_stats)
history=history_points,
current_stats=current_stats
)
except Exception as e: except Exception as e:
logger.error(f"Error getting online history: {e}") logger.error(f"Error getting online history: {e}")
# 返回空历史和当前状态 # 返回空历史和当前状态
current_stats = await get_server_stats() current_stats = await get_server_stats()
return OnlineHistoryResponse( return OnlineHistoryResponse(history=[], current_stats=current_stats)
history=[],
current_stats=current_stats
)
@router.get("/stats/debug", tags=["统计"]) @router.get("/stats/debug", tags=["统计"])
async def get_stats_debug_info(): async def get_stats_debug_info():
""" """
获取统计系统调试信息 获取统计系统调试信息
用于调试时间对齐和区间统计问题 用于调试时间对齐和区间统计问题
""" """
try: try:
from app.service.enhanced_interval_stats import EnhancedIntervalStatsManager from app.service.enhanced_interval_stats import EnhancedIntervalStatsManager
current_time = datetime.utcnow() current_time = datetime.utcnow()
current_interval = await EnhancedIntervalStatsManager.get_current_interval_info() current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
interval_stats = await EnhancedIntervalStatsManager.get_current_interval_stats() interval_stats = await EnhancedIntervalStatsManager.get_current_interval_stats()
# 获取Redis中的实际数据 # 获取Redis中的实际数据
redis_sync = get_redis_message() redis_sync = get_redis_message()
online_key = f"server:interval_online_users:{current_interval.interval_key}" online_key = f"server:interval_online_users:{current_interval.interval_key}"
playing_key = f"server:interval_playing_users:{current_interval.interval_key}" playing_key = f"server:interval_playing_users:{current_interval.interval_key}"
online_users_raw = await _redis_exec(redis_sync.smembers, online_key) online_users_raw = await _redis_exec(redis_sync.smembers, online_key)
playing_users_raw = await _redis_exec(redis_sync.smembers, playing_key) playing_users_raw = await _redis_exec(redis_sync.smembers, playing_key)
online_users = [int(uid.decode() if isinstance(uid, bytes) else uid) for uid in online_users_raw] online_users = [int(uid.decode() if isinstance(uid, bytes) else uid) for uid in online_users_raw]
playing_users = [int(uid.decode() if isinstance(uid, bytes) else uid) for uid in playing_users_raw] playing_users = [int(uid.decode() if isinstance(uid, bytes) else uid) for uid in playing_users_raw]
return { return {
"current_time": current_time.isoformat(), "current_time": current_time.isoformat(),
"current_interval": { "current_interval": {
@@ -175,28 +183,29 @@ async def get_stats_debug_info():
"is_current": current_interval.is_current(), "is_current": current_interval.is_current(),
"minutes_remaining": int((current_interval.end_time - current_time).total_seconds() / 60), "minutes_remaining": int((current_interval.end_time - current_time).total_seconds() / 60),
"seconds_remaining": int((current_interval.end_time - current_time).total_seconds()), "seconds_remaining": int((current_interval.end_time - current_time).total_seconds()),
"progress_percentage": round((1 - (current_interval.end_time - current_time).total_seconds() / (30 * 60)) * 100, 1) "progress_percentage": round(
(1 - (current_interval.end_time - current_time).total_seconds() / (30 * 60)) * 100,
1,
),
}, },
"interval_statistics": interval_stats.to_dict() if interval_stats else None, "interval_statistics": interval_stats.to_dict() if interval_stats else None,
"redis_data": { "redis_data": {
"online_users": online_users, "online_users": online_users,
"playing_users": playing_users, "playing_users": playing_users,
"online_count": len(online_users), "online_count": len(online_users),
"playing_count": len(playing_users) "playing_count": len(playing_users),
}, },
"system_status": { "system_status": {
"stats_system": "enhanced_interval_stats", "stats_system": "enhanced_interval_stats",
"data_alignment": "30_minute_boundaries", "data_alignment": "30_minute_boundaries",
"real_time_updates": True, "real_time_updates": True,
"auto_24h_fill": True "auto_24h_fill": True,
} },
} }
except Exception as e: except Exception as e:
logger.error(f"Error getting debug info: {e}") logger.error(f"Error getting debug info: {e}")
return { return {"error": "Failed to retrieve debug information", "message": str(e)}
"error": "Failed to retrieve debug information",
"message": str(e)
}
async def _get_registered_users_count(redis) -> int: async def _get_registered_users_count(redis) -> int:
"""获取注册用户总数(从缓存)""" """获取注册用户总数(从缓存)"""
@@ -207,6 +216,7 @@ async def _get_registered_users_count(redis) -> int:
logger.error(f"Error getting registered users count: {e}") logger.error(f"Error getting registered users count: {e}")
return 0 return 0
async def _get_online_users_count(redis) -> int: async def _get_online_users_count(redis) -> int:
"""获取当前在线用户数""" """获取当前在线用户数"""
try: try:
@@ -216,6 +226,7 @@ async def _get_online_users_count(redis) -> int:
logger.error(f"Error getting online users count: {e}") logger.error(f"Error getting online users count: {e}")
return 0 return 0
async def _get_playing_users_count(redis) -> int: async def _get_playing_users_count(redis) -> int:
"""获取当前游玩用户数""" """获取当前游玩用户数"""
try: try:
@@ -225,27 +236,28 @@ async def _get_playing_users_count(redis) -> int:
logger.error(f"Error getting playing users count: {e}") logger.error(f"Error getting playing users count: {e}")
return 0 return 0
# 统计更新功能 # 统计更新功能
async def update_registered_users_count() -> None: async def update_registered_users_count() -> None:
"""更新注册用户数缓存""" """更新注册用户数缓存"""
from app.dependencies.database import with_db
from app.database import User
from app.const import BANCHOBOT_ID from app.const import BANCHOBOT_ID
from sqlmodel import select, func from app.database import User
from app.dependencies.database import with_db
from sqlmodel import func, select
redis = get_redis() redis = get_redis()
try: try:
async with with_db() as db: async with with_db() as db:
# 排除机器人用户BANCHOBOT_ID # 排除机器人用户BANCHOBOT_ID
result = await db.exec( result = await db.exec(select(func.count()).select_from(User).where(User.id != BANCHOBOT_ID))
select(func.count()).select_from(User).where(User.id != BANCHOBOT_ID)
)
count = result.first() count = result.first()
await redis.set(REDIS_REGISTERED_USERS_KEY, count or 0, ex=300) # 5分钟过期 await redis.set(REDIS_REGISTERED_USERS_KEY, count or 0, ex=300) # 5分钟过期
logger.debug(f"Updated registered users count: {count}") logger.debug(f"Updated registered users count: {count}")
except Exception as e: except Exception as e:
logger.error(f"Error updating registered users count: {e}") logger.error(f"Error updating registered users count: {e}")
async def add_online_user(user_id: int) -> None: async def add_online_user(user_id: int) -> None:
"""添加在线用户""" """添加在线用户"""
redis_sync = get_redis_message() redis_sync = get_redis_message()
@@ -257,14 +269,20 @@ async def add_online_user(user_id: int) -> None:
if ttl <= 0: # -1表示永不过期-2表示不存在0表示已过期 if ttl <= 0: # -1表示永不过期-2表示不存在0表示已过期
await redis_async.expire(REDIS_ONLINE_USERS_KEY, 3 * 3600) # 3小时过期 await redis_async.expire(REDIS_ONLINE_USERS_KEY, 3 * 3600) # 3小时过期
logger.debug(f"Added online user {user_id}") logger.debug(f"Added online user {user_id}")
# 立即更新当前区间统计 # 立即更新当前区间统计
from app.service.enhanced_interval_stats import update_user_activity_in_interval from app.service.enhanced_interval_stats import update_user_activity_in_interval
asyncio.create_task(update_user_activity_in_interval(user_id, is_playing=False))
bg_tasks.add_task(
update_user_activity_in_interval,
user_id,
is_playing=False,
)
except Exception as e: except Exception as e:
logger.error(f"Error adding online user {user_id}: {e}") logger.error(f"Error adding online user {user_id}: {e}")
async def remove_online_user(user_id: int) -> None: async def remove_online_user(user_id: int) -> None:
"""移除在线用户""" """移除在线用户"""
redis_sync = get_redis_message() redis_sync = get_redis_message()
@@ -274,6 +292,7 @@ async def remove_online_user(user_id: int) -> None:
except Exception as e: except Exception as e:
logger.error(f"Error removing online user {user_id}: {e}") logger.error(f"Error removing online user {user_id}: {e}")
async def add_playing_user(user_id: int) -> None: async def add_playing_user(user_id: int) -> None:
"""添加游玩用户""" """添加游玩用户"""
redis_sync = get_redis_message() redis_sync = get_redis_message()
@@ -285,14 +304,16 @@ async def add_playing_user(user_id: int) -> None:
if ttl <= 0: # -1表示永不过期-2表示不存在0表示已过期 if ttl <= 0: # -1表示永不过期-2表示不存在0表示已过期
await redis_async.expire(REDIS_PLAYING_USERS_KEY, 3 * 3600) # 3小时过期 await redis_async.expire(REDIS_PLAYING_USERS_KEY, 3 * 3600) # 3小时过期
logger.debug(f"Added playing user {user_id}") logger.debug(f"Added playing user {user_id}")
# 立即更新当前区间统计 # 立即更新当前区间统计
from app.service.enhanced_interval_stats import update_user_activity_in_interval from app.service.enhanced_interval_stats import update_user_activity_in_interval
asyncio.create_task(update_user_activity_in_interval(user_id, is_playing=True))
bg_tasks.add_task(update_user_activity_in_interval, user_id, is_playing=True)
except Exception as e: except Exception as e:
logger.error(f"Error adding playing user {user_id}: {e}") logger.error(f"Error adding playing user {user_id}: {e}")
async def remove_playing_user(user_id: int) -> None: async def remove_playing_user(user_id: int) -> None:
"""移除游玩用户""" """移除游玩用户"""
redis_sync = get_redis_message() redis_sync = get_redis_message()
@@ -301,6 +322,7 @@ async def remove_playing_user(user_id: int) -> None:
except Exception as e: except Exception as e:
logger.error(f"Error removing playing user {user_id}: {e}") logger.error(f"Error removing playing user {user_id}: {e}")
async def record_hourly_stats() -> None: async def record_hourly_stats() -> None:
"""记录统计数据 - 简化版本主要作为fallback使用""" """记录统计数据 - 简化版本主要作为fallback使用"""
redis_sync = get_redis_message() redis_sync = get_redis_message()
@@ -308,24 +330,27 @@ async def record_hourly_stats() -> None:
try: try:
# 先确保Redis连接正常 # 先确保Redis连接正常
await redis_async.ping() await redis_async.ping()
online_count = await _get_online_users_count(redis_async) online_count = await _get_online_users_count(redis_async)
playing_count = await _get_playing_users_count(redis_async) playing_count = await _get_playing_users_count(redis_async)
current_time = datetime.utcnow() current_time = datetime.utcnow()
history_point = { history_point = {
"timestamp": current_time.isoformat(), "timestamp": current_time.isoformat(),
"online_count": online_count, "online_count": online_count,
"playing_count": playing_count "playing_count": playing_count,
} }
# 添加到历史记录 # 添加到历史记录
await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point)) await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point))
# 只保留48个数据点24小时每30分钟一个点 # 只保留48个数据点24小时每30分钟一个点
await _redis_exec(redis_sync.ltrim, REDIS_ONLINE_HISTORY_KEY, 0, 47) await _redis_exec(redis_sync.ltrim, REDIS_ONLINE_HISTORY_KEY, 0, 47)
# 设置过期时间为26小时确保有足够缓冲 # 设置过期时间为26小时确保有足够缓冲
await redis_async.expire(REDIS_ONLINE_HISTORY_KEY, 26 * 3600) await redis_async.expire(REDIS_ONLINE_HISTORY_KEY, 26 * 3600)
logger.info(f"Recorded fallback stats: online={online_count}, playing={playing_count} at {current_time.strftime('%H:%M:%S')}") logger.info(
f"Recorded fallback stats: online={online_count}, playing={playing_count} "
f"at {current_time.strftime('%H:%M:%S')}"
)
except Exception as e: except Exception as e:
logger.error(f"Error recording fallback stats: {e}") logger.error(f"Error recording fallback stats: {e}")

View File

@@ -1,6 +1,5 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from typing import Literal from typing import Literal
@@ -26,7 +25,7 @@ from app.service.user_cache_service import get_user_cache_service
from .router import router from .router import router
from fastapi import HTTPException, Path, Query, Security from fastapi import BackgroundTasks, HTTPException, Path, Query, Security
from pydantic import BaseModel from pydantic import BaseModel
from sqlmodel import exists, false, select from sqlmodel import exists, false, select
from sqlmodel.sql.expression import col from sqlmodel.sql.expression import col
@@ -47,13 +46,10 @@ class BatchUserResponse(BaseModel):
@router.get("/users/lookup/", response_model=BatchUserResponse, include_in_schema=False) @router.get("/users/lookup/", response_model=BatchUserResponse, include_in_schema=False)
async def get_users( async def get_users(
session: Database, session: Database,
user_ids: list[int] = Query( background_task: BackgroundTasks,
default_factory=list, alias="ids[]", description="要查询的用户 ID 列表" user_ids: list[int] = Query(default_factory=list, alias="ids[]", description="要查询的用户 ID 列表"),
),
# current_user: User = Security(get_current_user, scopes=["public"]), # current_user: User = Security(get_current_user, scopes=["public"]),
include_variant_statistics: bool = Query( include_variant_statistics: bool = Query(default=False, description="是否包含各模式的统计信息"), # TODO: future use
default=False, description="是否包含各模式的统计信息"
), # TODO: future use
): ):
redis = get_redis() redis = get_redis()
cache_service = get_user_cache_service(redis) cache_service = get_user_cache_service(redis)
@@ -72,11 +68,7 @@ async def get_users(
# 查询未缓存的用户 # 查询未缓存的用户
if uncached_user_ids: if uncached_user_ids:
searched_users = ( searched_users = (await session.exec(select(User).where(col(User.id).in_(uncached_user_ids)))).all()
await session.exec(
select(User).where(col(User.id).in_(uncached_user_ids))
)
).all()
# 将查询到的用户添加到缓存并返回 # 将查询到的用户添加到缓存并返回
for searched_user in searched_users: for searched_user in searched_users:
@@ -88,7 +80,7 @@ async def get_users(
) )
cached_users.append(user_resp) cached_users.append(user_resp)
# 异步缓存,不阻塞响应 # 异步缓存,不阻塞响应
asyncio.create_task(cache_service.cache_user(user_resp)) background_task.add_task(cache_service.cache_user, user_resp)
return BatchUserResponse(users=cached_users) return BatchUserResponse(users=cached_users)
else: else:
@@ -103,7 +95,7 @@ async def get_users(
) )
users.append(user_resp) users.append(user_resp)
# 异步缓存 # 异步缓存
asyncio.create_task(cache_service.cache_user(user_resp)) background_task.add_task(cache_service.cache_user, user_resp)
return BatchUserResponse(users=users) return BatchUserResponse(users=users)
@@ -117,6 +109,7 @@ async def get_users(
) )
async def get_user_info_ruleset( async def get_user_info_ruleset(
session: Database, session: Database,
background_task: BackgroundTasks,
user_id: str = Path(description="用户 ID 或用户名"), user_id: str = Path(description="用户 ID 或用户名"),
ruleset: GameMode | None = Path(description="指定 ruleset"), ruleset: GameMode | None = Path(description="指定 ruleset"),
# current_user: User = Security(get_current_user, scopes=["public"]), # current_user: User = Security(get_current_user, scopes=["public"]),
@@ -134,9 +127,7 @@ async def get_user_info_ruleset(
searched_user = ( searched_user = (
await session.exec( await session.exec(
select(User).where( select(User).where(
User.id == int(user_id) User.id == int(user_id) if user_id.isdigit() else User.username == user_id.removeprefix("@")
if user_id.isdigit()
else User.username == user_id.removeprefix("@")
) )
) )
).first() ).first()
@@ -151,7 +142,7 @@ async def get_user_info_ruleset(
) )
# 异步缓存结果 # 异步缓存结果
asyncio.create_task(cache_service.cache_user(user_resp, ruleset)) background_task.add_task(cache_service.cache_user, user_resp, ruleset)
return user_resp return user_resp
@@ -165,6 +156,7 @@ async def get_user_info_ruleset(
tags=["用户"], tags=["用户"],
) )
async def get_user_info( async def get_user_info(
background_task: BackgroundTasks,
session: Database, session: Database,
user_id: str = Path(description="用户 ID 或用户名"), user_id: str = Path(description="用户 ID 或用户名"),
# current_user: User = Security(get_current_user, scopes=["public"]), # current_user: User = Security(get_current_user, scopes=["public"]),
@@ -182,9 +174,7 @@ async def get_user_info(
searched_user = ( searched_user = (
await session.exec( await session.exec(
select(User).where( select(User).where(
User.id == int(user_id) User.id == int(user_id) if user_id.isdigit() else User.username == user_id.removeprefix("@")
if user_id.isdigit()
else User.username == user_id.removeprefix("@")
) )
) )
).first() ).first()
@@ -198,7 +188,7 @@ async def get_user_info(
) )
# 异步缓存结果 # 异步缓存结果
asyncio.create_task(cache_service.cache_user(user_resp)) background_task.add_task(cache_service.cache_user, user_resp)
return user_resp return user_resp
@@ -212,6 +202,7 @@ async def get_user_info(
) )
async def get_user_beatmapsets( async def get_user_beatmapsets(
session: Database, session: Database,
background_task: BackgroundTasks,
user_id: int = Path(description="用户 ID"), user_id: int = Path(description="用户 ID"),
type: BeatmapsetType = Path(description="谱面集类型"), type: BeatmapsetType = Path(description="谱面集类型"),
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
@@ -222,9 +213,7 @@ async def get_user_beatmapsets(
cache_service = get_user_cache_service(redis) cache_service = get_user_cache_service(redis)
# 先尝试从缓存获取 # 先尝试从缓存获取
cached_result = await cache_service.get_user_beatmapsets_from_cache( cached_result = await cache_service.get_user_beatmapsets_from_cache(user_id, type.value, limit, offset)
user_id, type.value, limit, offset
)
if cached_result is not None: if cached_result is not None:
# 根据类型恢复对象 # 根据类型恢复对象
if type == BeatmapsetType.MOST_PLAYED: if type == BeatmapsetType.MOST_PLAYED:
@@ -253,10 +242,7 @@ async def get_user_beatmapsets(
raise HTTPException(404, detail="User not found") raise HTTPException(404, detail="User not found")
favourites = await user.awaitable_attrs.favourite_beatmapsets favourites = await user.awaitable_attrs.favourite_beatmapsets
resp = [ resp = [
await BeatmapsetResp.from_db( await BeatmapsetResp.from_db(favourite.beatmapset, session=session, user=user) for favourite in favourites
favourite.beatmapset, session=session, user=user
)
for favourite in favourites
] ]
elif type == BeatmapsetType.MOST_PLAYED: elif type == BeatmapsetType.MOST_PLAYED:
@@ -267,25 +253,18 @@ async def get_user_beatmapsets(
.limit(limit) .limit(limit)
.offset(offset) .offset(offset)
) )
resp = [ resp = [await BeatmapPlaycountsResp.from_db(most_played_beatmap) for most_played_beatmap in most_played]
await BeatmapPlaycountsResp.from_db(most_played_beatmap)
for most_played_beatmap in most_played
]
else: else:
raise HTTPException(400, detail="Invalid beatmapset type") raise HTTPException(400, detail="Invalid beatmapset type")
# 异步缓存结果 # 异步缓存结果
async def cache_beatmapsets(): async def cache_beatmapsets():
try: try:
await cache_service.cache_user_beatmapsets( await cache_service.cache_user_beatmapsets(user_id, type.value, resp, limit, offset)
user_id, type.value, resp, limit, offset
)
except Exception as e: except Exception as e:
logger.error( logger.error(f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}")
f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}"
)
asyncio.create_task(cache_beatmapsets()) background_task.add_task(cache_beatmapsets)
return resp return resp
@@ -299,18 +278,14 @@ async def get_user_beatmapsets(
) )
async def get_user_scores( async def get_user_scores(
session: Database, session: Database,
background_task: BackgroundTasks,
user_id: int = Path(description="用户 ID"), user_id: int = Path(description="用户 ID"),
type: Literal["best", "recent", "firsts", "pinned"] = Path( type: Literal["best", "recent", "firsts", "pinned"] = Path(
description=( description=("成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩 / firsts 第一名成绩 / pinned 置顶成绩")
"成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩"
" / firsts 第一名成绩 / pinned 置顶成绩"
)
), ),
legacy_only: bool = Query(False, description="是否只查询 Stable 成绩"), legacy_only: bool = Query(False, description="是否只查询 Stable 成绩"),
include_fails: bool = Query(False, description="是否包含失败的成绩"), include_fails: bool = Query(False, description="是否包含失败的成绩"),
mode: GameMode | None = Query( mode: GameMode | None = Query(None, description="指定 ruleset (可选,默认为用户主模式)"),
None, description="指定 ruleset (可选,默认为用户主模式)"
),
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"), limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
offset: int = Query(0, ge=0, description="偏移量"), offset: int = Query(0, ge=0, description="偏移量"),
current_user: User = Security(get_current_user, scopes=["public"]), current_user: User = Security(get_current_user, scopes=["public"]),
@@ -320,9 +295,7 @@ async def get_user_scores(
# 先尝试从缓存获取对于recent类型使用较短的缓存时间 # 先尝试从缓存获取对于recent类型使用较短的缓存时间
cache_expire = 30 if type == "recent" else settings.user_scores_cache_expire_seconds cache_expire = 30 if type == "recent" else settings.user_scores_cache_expire_seconds
cached_scores = await cache_service.get_user_scores_from_cache( cached_scores = await cache_service.get_user_scores_from_cache(user_id, type, mode, limit, offset)
user_id, type, mode, limit, offset
)
if cached_scores is not None: if cached_scores is not None:
return cached_scores return cached_scores
@@ -332,9 +305,7 @@ async def get_user_scores(
gamemode = mode or db_user.playmode gamemode = mode or db_user.playmode
order_by = None order_by = None
where_clause = (col(Score.user_id) == db_user.id) & ( where_clause = (col(Score.user_id) == db_user.id) & (col(Score.gamemode) == gamemode)
col(Score.gamemode) == gamemode
)
if not include_fails: if not include_fails:
where_clause &= col(Score.passed).is_(True) where_clause &= col(Score.passed).is_(True)
if type == "pinned": if type == "pinned":
@@ -351,13 +322,7 @@ async def get_user_scores(
where_clause &= false() where_clause &= false()
scores = ( scores = (
await session.exec( await session.exec(select(Score).where(where_clause).order_by(order_by).limit(limit).offset(offset))
select(Score)
.where(where_clause)
.order_by(order_by)
.limit(limit)
.offset(offset)
)
).all() ).all()
if not scores: if not scores:
return [] return []
@@ -371,18 +336,14 @@ async def get_user_scores(
] ]
# 异步缓存结果 # 异步缓存结果
asyncio.create_task( background_task.add_task(
cache_service.cache_user_scores( cache_service.cache_user_scores, user_id, type, score_responses, mode, limit, offset, cache_expire
user_id, type, score_responses, mode, limit, offset, cache_expire
)
) )
return score_responses return score_responses
@router.get( @router.get("/users/{user}/recent_activity", tags=["用户"], response_model=list[EventResp])
"/users/{user}/recent_activity", tags=["用户"], response_model=list[EventResp]
)
async def get_user_events( async def get_user_events(
session: Database, session: Database,
user: int, user: int,

View File

@@ -59,9 +59,7 @@ class CacheScheduler:
# 从配置文件获取间隔设置 # 从配置文件获取间隔设置
check_interval = 5 * 60 # 5分钟检查间隔 check_interval = 5 * 60 # 5分钟检查间隔
beatmap_cache_interval = 30 * 60 # 30分钟beatmap缓存间隔 beatmap_cache_interval = 30 * 60 # 30分钟beatmap缓存间隔
ranking_cache_interval = ( ranking_cache_interval = settings.ranking_cache_refresh_interval_minutes * 60 # 从配置读取
settings.ranking_cache_refresh_interval_minutes * 60
) # 从配置读取
user_cache_interval = 15 * 60 # 15分钟用户缓存预加载间隔 user_cache_interval = 15 * 60 # 15分钟用户缓存预加载间隔
user_cleanup_interval = 60 * 60 # 60分钟用户缓存清理间隔 user_cleanup_interval = 60 * 60 # 60分钟用户缓存清理间隔

View File

@@ -5,9 +5,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from datetime import datetime
from app.config import settings
from app.dependencies.database import engine from app.dependencies.database import engine
from app.log import logger from app.log import logger
from app.service.database_cleanup_service import DatabaseCleanupService from app.service.database_cleanup_service import DatabaseCleanupService
@@ -51,16 +49,16 @@ class DatabaseCleanupScheduler:
try: try:
# 每小时运行一次清理 # 每小时运行一次清理
await asyncio.sleep(3600) # 3600秒 = 1小时 await asyncio.sleep(3600) # 3600秒 = 1小时
if not self.running: if not self.running:
break break
await self._run_cleanup() await self._run_cleanup()
except asyncio.CancelledError: except asyncio.CancelledError:
break break
except Exception as e: except Exception as e:
logger.error(f"Database cleanup scheduler error: {str(e)}") logger.error(f"Database cleanup scheduler error: {e!s}")
# 发生错误后等待5分钟再继续 # 发生错误后等待5分钟再继续
await asyncio.sleep(300) await asyncio.sleep(300)
@@ -69,20 +67,20 @@ class DatabaseCleanupScheduler:
try: try:
async with AsyncSession(engine) as db: async with AsyncSession(engine) as db:
logger.debug("Starting scheduled database cleanup...") logger.debug("Starting scheduled database cleanup...")
# 清理过期的验证码 # 清理过期的验证码
expired_codes = await DatabaseCleanupService.cleanup_expired_verification_codes(db) expired_codes = await DatabaseCleanupService.cleanup_expired_verification_codes(db)
# 清理过期的登录会话 # 清理过期的登录会话
expired_sessions = await DatabaseCleanupService.cleanup_expired_login_sessions(db) expired_sessions = await DatabaseCleanupService.cleanup_expired_login_sessions(db)
# 只在有清理记录时输出总结 # 只在有清理记录时输出总结
total_cleaned = expired_codes + expired_sessions total_cleaned = expired_codes + expired_sessions
if total_cleaned > 0: if total_cleaned > 0:
logger.debug(f"Scheduled cleanup completed - codes: {expired_codes}, sessions: {expired_sessions}") logger.debug(f"Scheduled cleanup completed - codes: {expired_codes}, sessions: {expired_sessions}")
except Exception as e: except Exception as e:
logger.error(f"Error during scheduled database cleanup: {str(e)}") logger.error(f"Error during scheduled database cleanup: {e!s}")
async def run_manual_cleanup(self): async def run_manual_cleanup(self):
"""手动运行完整清理""" """手动运行完整清理"""
@@ -95,7 +93,7 @@ class DatabaseCleanupScheduler:
logger.debug(f"Manual cleanup completed, total records cleaned: {total}") logger.debug(f"Manual cleanup completed, total records cleaned: {total}")
return results return results
except Exception as e: except Exception as e:
logger.error(f"Error during manual database cleanup: {str(e)}") logger.error(f"Error during manual database cleanup: {e!s}")
return {} return {}

View File

@@ -63,10 +63,7 @@ class BeatmapCacheService:
if preload_tasks: if preload_tasks:
results = await asyncio.gather(*preload_tasks, return_exceptions=True) results = await asyncio.gather(*preload_tasks, return_exceptions=True)
success_count = sum(1 for r in results if r is True) success_count = sum(1 for r in results if r is True)
logger.info( logger.info(f"Preloaded {success_count}/{len(preload_tasks)} beatmaps successfully")
f"Preloaded {success_count}/{len(preload_tasks)} "
f"beatmaps successfully"
)
except Exception as e: except Exception as e:
logger.error(f"Error during beatmap preloading: {e}") logger.error(f"Error during beatmap preloading: {e}")
@@ -119,9 +116,7 @@ class BeatmapCacheService:
return { return {
"cached_beatmaps": len(keys), "cached_beatmaps": len(keys),
"estimated_total_size_mb": ( "estimated_total_size_mb": (round(total_size / 1024 / 1024, 2) if total_size > 0 else 0),
round(total_size / 1024 / 1024, 2) if total_size > 0 else 0
),
"preloading": self._preloading, "preloading": self._preloading,
} }
except Exception as e: except Exception as e:
@@ -155,9 +150,7 @@ def get_beatmap_cache_service(redis: Redis, fetcher: "Fetcher") -> BeatmapCacheS
return _cache_service return _cache_service
async def schedule_preload_task( async def schedule_preload_task(session: AsyncSession, redis: Redis, fetcher: "Fetcher"):
session: AsyncSession, redis: Redis, fetcher: "Fetcher"
):
""" """
定时预加载任务 定时预加载任务
""" """

View File

@@ -192,22 +192,16 @@ class BeatmapDownloadService:
healthy_endpoints.sort(key=lambda x: x.priority) healthy_endpoints.sort(key=lambda x: x.priority)
return healthy_endpoints return healthy_endpoints
def get_download_url( def get_download_url(self, beatmapset_id: int, no_video: bool, is_china: bool) -> str:
self, beatmapset_id: int, no_video: bool, is_china: bool
) -> str:
"""获取下载URL带负载均衡和故障转移""" """获取下载URL带负载均衡和故障转移"""
healthy_endpoints = self.get_healthy_endpoints(is_china) healthy_endpoints = self.get_healthy_endpoints(is_china)
if not healthy_endpoints: if not healthy_endpoints:
# 如果没有健康的端点,记录错误并回退到所有端点中优先级最高的 # 如果没有健康的端点,记录错误并回退到所有端点中优先级最高的
logger.error(f"No healthy endpoints available for is_china={is_china}") logger.error(f"No healthy endpoints available for is_china={is_china}")
endpoints = ( endpoints = self.china_endpoints if is_china else self.international_endpoints
self.china_endpoints if is_china else self.international_endpoints
)
if not endpoints: if not endpoints:
raise HTTPException( raise HTTPException(status_code=503, detail="No download endpoints available")
status_code=503, detail="No download endpoints available"
)
endpoint = min(endpoints, key=lambda x: x.priority) endpoint = min(endpoints, key=lambda x: x.priority)
else: else:
# 使用第一个健康的端点(已按优先级排序) # 使用第一个健康的端点(已按优先级排序)
@@ -218,9 +212,7 @@ class BeatmapDownloadService:
video_type = "novideo" if no_video else "full" video_type = "novideo" if no_video else "full"
return endpoint.url_template.format(type=video_type, sid=beatmapset_id) return endpoint.url_template.format(type=video_type, sid=beatmapset_id)
elif endpoint.name == "Nerinyan": elif endpoint.name == "Nerinyan":
return endpoint.url_template.format( return endpoint.url_template.format(sid=beatmapset_id, no_video="true" if no_video else "false")
sid=beatmapset_id, no_video="true" if no_video else "false"
)
elif endpoint.name == "OsuDirect": elif endpoint.name == "OsuDirect":
# osu.direct 似乎没有no_video参数直接使用基础URL # osu.direct 似乎没有no_video参数直接使用基础URL
return endpoint.url_template.format(sid=beatmapset_id) return endpoint.url_template.format(sid=beatmapset_id)
@@ -239,9 +231,7 @@ class BeatmapDownloadService:
for name, status in self.endpoint_status.items(): for name, status in self.endpoint_status.items():
status_info["endpoints"][name] = { status_info["endpoints"][name] = {
"healthy": status.is_healthy, "healthy": status.is_healthy,
"last_check": status.last_check.isoformat() "last_check": status.last_check.isoformat() if status.last_check else None,
if status.last_check
else None,
"consecutive_failures": status.consecutive_failures, "consecutive_failures": status.consecutive_failures,
"last_error": status.last_error, "last_error": status.last_error,
"priority": status.endpoint.priority, "priority": status.endpoint.priority,

View File

@@ -11,9 +11,7 @@ from app.models.score import GameMode
from sqlmodel import col, exists, select, update from sqlmodel import col, exists, select, update
@get_scheduler().scheduled_job( @get_scheduler().scheduled_job("cron", hour=0, minute=0, second=0, id="calculate_user_rank")
"cron", hour=0, minute=0, second=0, id="calculate_user_rank"
)
async def calculate_user_rank(is_today: bool = False): async def calculate_user_rank(is_today: bool = False):
today = datetime.now(UTC).date() today = datetime.now(UTC).date()
target_date = today if is_today else today - timedelta(days=1) target_date = today if is_today else today - timedelta(days=1)

View File

@@ -11,9 +11,7 @@ from sqlmodel import exists, select
async def create_banchobot(): async def create_banchobot():
async with with_db() as session: async with with_db() as session:
is_exist = ( is_exist = (await session.exec(select(exists()).where(User.id == BANCHOBOT_ID))).first()
await session.exec(select(exists()).where(User.id == BANCHOBOT_ID))
).first()
if not is_exist: if not is_exist:
banchobot = User( banchobot = User(
username="BanchoBot", username="BanchoBot",

View File

@@ -82,8 +82,7 @@ async def daily_challenge_job():
if beatmap is None or ruleset_id is None: if beatmap is None or ruleset_id is None:
logger.warning( logger.warning(
f"[DailyChallenge] Missing required data for daily challenge {now}." f"[DailyChallenge] Missing required data for daily challenge {now}. Will try again in 5 minutes."
" Will try again in 5 minutes."
) )
get_scheduler().add_job( get_scheduler().add_job(
daily_challenge_job, daily_challenge_job,
@@ -104,9 +103,7 @@ async def daily_challenge_job():
else: else:
allowed_mods_list = get_available_mods(ruleset_id_int, required_mods_list) allowed_mods_list = get_available_mods(ruleset_id_int, required_mods_list)
next_day = (now + timedelta(days=1)).replace( next_day = (now + timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0)
hour=0, minute=0, second=0, microsecond=0
)
room = await create_daily_challenge_room( room = await create_daily_challenge_room(
beatmap=beatmap_int, beatmap=beatmap_int,
ruleset_id=ruleset_id_int, ruleset_id=ruleset_id_int,
@@ -114,24 +111,13 @@ async def daily_challenge_job():
allowed_mods=allowed_mods_list, allowed_mods=allowed_mods_list,
duration=int((next_day - now - timedelta(minutes=2)).total_seconds() / 60), duration=int((next_day - now - timedelta(minutes=2)).total_seconds() / 60),
) )
await MetadataHubs.broadcast_call( await MetadataHubs.broadcast_call("DailyChallengeUpdated", DailyChallengeInfo(room_id=room.id))
"DailyChallengeUpdated", DailyChallengeInfo(room_id=room.id) logger.success(f"[DailyChallenge] Added today's daily challenge: {beatmap=}, {ruleset_id=}, {required_mods=}")
)
logger.success(
"[DailyChallenge] Added today's daily challenge: "
f"{beatmap=}, {ruleset_id=}, {required_mods=}"
)
return return
except (ValueError, json.JSONDecodeError) as e: except (ValueError, json.JSONDecodeError) as e:
logger.warning( logger.warning(f"[DailyChallenge] Error processing daily challenge data: {e} Will try again in 5 minutes.")
f"[DailyChallenge] Error processing daily challenge data: {e}"
" Will try again in 5 minutes."
)
except Exception as e: except Exception as e:
logger.exception( logger.exception(f"[DailyChallenge] Unexpected error in daily challenge job: {e} Will try again in 5 minutes.")
f"[DailyChallenge] Unexpected error in daily challenge job: {e}"
" Will try again in 5 minutes."
)
get_scheduler().add_job( get_scheduler().add_job(
daily_challenge_job, daily_challenge_job,
"date", "date",
@@ -139,9 +125,7 @@ async def daily_challenge_job():
) )
@get_scheduler().scheduled_job( @get_scheduler().scheduled_job("cron", hour=0, minute=1, second=0, id="daily_challenge_last_top")
"cron", hour=0, minute=1, second=0, id="daily_challenge_last_top"
)
async def process_daily_challenge_top(): async def process_daily_challenge_top():
async with with_db() as session: async with with_db() as session:
now = datetime.now(UTC) now = datetime.now(UTC)
@@ -182,11 +166,7 @@ async def process_daily_challenge_top():
await session.commit() await session.commit()
del s del s
user_ids = ( user_ids = (await session.exec(select(User.id).where(col(User.id).not_in(participated_users)))).all()
await session.exec(
select(User.id).where(col(User.id).not_in(participated_users))
)
).all()
for id in user_ids: for id in user_ids:
stats = await session.get(DailyChallengeStats, id) stats = await session.get(DailyChallengeStats, id)
if stats is None: # not execute if stats is None: # not execute

View File

@@ -4,14 +4,13 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime, UTC, timedelta from datetime import UTC, datetime, timedelta
from app.database.email_verification import EmailVerification, LoginSession from app.database.email_verification import EmailVerification, LoginSession
from app.log import logger from app.log import logger
from sqlmodel import select from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from sqlalchemy import and_
class DatabaseCleanupService: class DatabaseCleanupService:
@@ -21,211 +20,207 @@ class DatabaseCleanupService:
async def cleanup_expired_verification_codes(db: AsyncSession) -> int: async def cleanup_expired_verification_codes(db: AsyncSession) -> int:
""" """
清理过期的邮件验证码 清理过期的邮件验证码
Args: Args:
db: 数据库会话 db: 数据库会话
Returns: Returns:
int: 清理的记录数 int: 清理的记录数
""" """
try: try:
# 查找过期的验证码记录 # 查找过期的验证码记录
current_time = datetime.now(UTC) current_time = datetime.now(UTC)
stmt = select(EmailVerification).where( stmt = select(EmailVerification).where(EmailVerification.expires_at < current_time)
EmailVerification.expires_at < current_time
)
result = await db.exec(stmt) result = await db.exec(stmt)
expired_codes = result.all() expired_codes = result.all()
# 删除过期的记录 # 删除过期的记录
deleted_count = 0 deleted_count = 0
for code in expired_codes: for code in expired_codes:
await db.delete(code) await db.delete(code)
deleted_count += 1 deleted_count += 1
await db.commit() await db.commit()
if deleted_count > 0: if deleted_count > 0:
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired email verification codes") logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired email verification codes")
return deleted_count return deleted_count
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"[Cleanup Service] Error cleaning expired verification codes: {str(e)}") logger.error(f"[Cleanup Service] Error cleaning expired verification codes: {e!s}")
return 0 return 0
@staticmethod @staticmethod
async def cleanup_expired_login_sessions(db: AsyncSession) -> int: async def cleanup_expired_login_sessions(db: AsyncSession) -> int:
""" """
清理过期的登录会话 清理过期的登录会话
Args: Args:
db: 数据库会话 db: 数据库会话
Returns: Returns:
int: 清理的记录数 int: 清理的记录数
""" """
try: try:
# 查找过期的登录会话记录 # 查找过期的登录会话记录
current_time = datetime.now(UTC) current_time = datetime.now(UTC)
stmt = select(LoginSession).where( stmt = select(LoginSession).where(LoginSession.expires_at < current_time)
LoginSession.expires_at < current_time
)
result = await db.exec(stmt) result = await db.exec(stmt)
expired_sessions = result.all() expired_sessions = result.all()
# 删除过期的记录 # 删除过期的记录
deleted_count = 0 deleted_count = 0
for session in expired_sessions: for session in expired_sessions:
await db.delete(session) await db.delete(session)
deleted_count += 1 deleted_count += 1
await db.commit() await db.commit()
if deleted_count > 0: if deleted_count > 0:
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired login sessions") logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired login sessions")
return deleted_count return deleted_count
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"[Cleanup Service] Error cleaning expired login sessions: {str(e)}") logger.error(f"[Cleanup Service] Error cleaning expired login sessions: {e!s}")
return 0 return 0
@staticmethod @staticmethod
async def cleanup_old_used_verification_codes(db: AsyncSession, days_old: int = 7) -> int: async def cleanup_old_used_verification_codes(db: AsyncSession, days_old: int = 7) -> int:
""" """
清理旧的已使用验证码记录 清理旧的已使用验证码记录
Args: Args:
db: 数据库会话 db: 数据库会话
days_old: 清理多少天前的已使用记录默认7天 days_old: 清理多少天前的已使用记录默认7天
Returns: Returns:
int: 清理的记录数 int: 清理的记录数
""" """
try: try:
# 查找指定天数前的已使用验证码记录 # 查找指定天数前的已使用验证码记录
cutoff_time = datetime.now(UTC) - timedelta(days=days_old) cutoff_time = datetime.now(UTC) - timedelta(days=days_old)
stmt = select(EmailVerification).where( stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True))
EmailVerification.is_used == True
)
result = await db.exec(stmt) result = await db.exec(stmt)
all_used_codes = result.all() all_used_codes = result.all()
# 筛选出过期的记录 # 筛选出过期的记录
old_used_codes = [ old_used_codes = [code for code in all_used_codes if code.used_at and code.used_at < cutoff_time]
code for code in all_used_codes
if code.used_at and code.used_at < cutoff_time
]
# 删除旧的已使用记录 # 删除旧的已使用记录
deleted_count = 0 deleted_count = 0
for code in old_used_codes: for code in old_used_codes:
await db.delete(code) await db.delete(code)
deleted_count += 1 deleted_count += 1
await db.commit() await db.commit()
if deleted_count > 0: if deleted_count > 0:
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} used verification codes older than {days_old} days") logger.debug(
f"[Cleanup Service] Cleaned up {deleted_count} used verification codes older than {days_old} days"
)
return deleted_count return deleted_count
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"[Cleanup Service] Error cleaning old used verification codes: {str(e)}") logger.error(f"[Cleanup Service] Error cleaning old used verification codes: {e!s}")
return 0 return 0
@staticmethod @staticmethod
async def cleanup_old_verified_sessions(db: AsyncSession, days_old: int = 30) -> int: async def cleanup_old_verified_sessions(db: AsyncSession, days_old: int = 30) -> int:
""" """
清理旧的已验证会话记录 清理旧的已验证会话记录
Args: Args:
db: 数据库会话 db: 数据库会话
days_old: 清理多少天前的已验证记录默认30天 days_old: 清理多少天前的已验证记录默认30天
Returns: Returns:
int: 清理的记录数 int: 清理的记录数
""" """
try: try:
# 查找指定天数前的已验证会话记录 # 查找指定天数前的已验证会话记录
cutoff_time = datetime.now(UTC) - timedelta(days=days_old) cutoff_time = datetime.now(UTC) - timedelta(days=days_old)
stmt = select(LoginSession).where( stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
LoginSession.is_verified == True
)
result = await db.exec(stmt) result = await db.exec(stmt)
all_verified_sessions = result.all() all_verified_sessions = result.all()
# 筛选出过期的记录 # 筛选出过期的记录
old_verified_sessions = [ old_verified_sessions = [
session for session in all_verified_sessions session
for session in all_verified_sessions
if session.verified_at and session.verified_at < cutoff_time if session.verified_at and session.verified_at < cutoff_time
] ]
# 删除旧的已验证记录 # 删除旧的已验证记录
deleted_count = 0 deleted_count = 0
for session in old_verified_sessions: for session in old_verified_sessions:
await db.delete(session) await db.delete(session)
deleted_count += 1 deleted_count += 1
await db.commit() await db.commit()
if deleted_count > 0: if deleted_count > 0:
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} verified sessions older than {days_old} days") logger.debug(
f"[Cleanup Service] Cleaned up {deleted_count} verified sessions older than {days_old} days"
)
return deleted_count return deleted_count
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"[Cleanup Service] Error cleaning old verified sessions: {str(e)}") logger.error(f"[Cleanup Service] Error cleaning old verified sessions: {e!s}")
return 0 return 0
@staticmethod @staticmethod
async def run_full_cleanup(db: AsyncSession) -> dict[str, int]: async def run_full_cleanup(db: AsyncSession) -> dict[str, int]:
""" """
运行完整的清理流程 运行完整的清理流程
Args: Args:
db: 数据库会话 db: 数据库会话
Returns: Returns:
dict: 各项清理的结果统计 dict: 各项清理的结果统计
""" """
results = {} results = {}
# 清理过期的验证码 # 清理过期的验证码
results["expired_verification_codes"] = await DatabaseCleanupService.cleanup_expired_verification_codes(db) results["expired_verification_codes"] = await DatabaseCleanupService.cleanup_expired_verification_codes(db)
# 清理过期的登录会话 # 清理过期的登录会话
results["expired_login_sessions"] = await DatabaseCleanupService.cleanup_expired_login_sessions(db) results["expired_login_sessions"] = await DatabaseCleanupService.cleanup_expired_login_sessions(db)
# 清理7天前的已使用验证码 # 清理7天前的已使用验证码
results["old_used_verification_codes"] = await DatabaseCleanupService.cleanup_old_used_verification_codes(db, 7) results["old_used_verification_codes"] = await DatabaseCleanupService.cleanup_old_used_verification_codes(db, 7)
# 清理30天前的已验证会话 # 清理30天前的已验证会话
results["old_verified_sessions"] = await DatabaseCleanupService.cleanup_old_verified_sessions(db, 30) results["old_verified_sessions"] = await DatabaseCleanupService.cleanup_old_verified_sessions(db, 30)
total_cleaned = sum(results.values()) total_cleaned = sum(results.values())
if total_cleaned > 0: if total_cleaned > 0:
logger.debug(f"[Cleanup Service] Full cleanup completed, total cleaned: {total_cleaned} records - {results}") logger.debug(
f"[Cleanup Service] Full cleanup completed, total cleaned: {total_cleaned} records - {results}"
)
return results return results
@staticmethod @staticmethod
async def get_cleanup_statistics(db: AsyncSession) -> dict[str, int]: async def get_cleanup_statistics(db: AsyncSession) -> dict[str, int]:
""" """
获取清理统计信息 获取清理统计信息
Args: Args:
db: 数据库会话 db: 数据库会话
Returns: Returns:
dict: 统计信息 dict: 统计信息
""" """
@@ -233,57 +228,54 @@ class DatabaseCleanupService:
current_time = datetime.now(UTC) current_time = datetime.now(UTC)
cutoff_7_days = current_time - timedelta(days=7) cutoff_7_days = current_time - timedelta(days=7)
cutoff_30_days = current_time - timedelta(days=30) cutoff_30_days = current_time - timedelta(days=30)
# 统计过期的验证码数量 # 统计过期的验证码数量
expired_codes_stmt = select(EmailVerification).where( expired_codes_stmt = select(EmailVerification).where(EmailVerification.expires_at < current_time)
EmailVerification.expires_at < current_time
)
expired_codes_result = await db.exec(expired_codes_stmt) expired_codes_result = await db.exec(expired_codes_stmt)
expired_codes_count = len(expired_codes_result.all()) expired_codes_count = len(expired_codes_result.all())
# 统计过期的登录会话数量 # 统计过期的登录会话数量
expired_sessions_stmt = select(LoginSession).where( expired_sessions_stmt = select(LoginSession).where(LoginSession.expires_at < current_time)
LoginSession.expires_at < current_time
)
expired_sessions_result = await db.exec(expired_sessions_stmt) expired_sessions_result = await db.exec(expired_sessions_stmt)
expired_sessions_count = len(expired_sessions_result.all()) expired_sessions_count = len(expired_sessions_result.all())
# 统计7天前的已使用验证码数量 # 统计7天前的已使用验证码数量
old_used_codes_stmt = select(EmailVerification).where( old_used_codes_stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True))
EmailVerification.is_used == True
)
old_used_codes_result = await db.exec(old_used_codes_stmt) old_used_codes_result = await db.exec(old_used_codes_stmt)
all_used_codes = old_used_codes_result.all() all_used_codes = old_used_codes_result.all()
old_used_codes_count = len([ old_used_codes_count = len(
code for code in all_used_codes [code for code in all_used_codes if code.used_at and code.used_at < cutoff_7_days]
if code.used_at and code.used_at < cutoff_7_days
])
# 统计30天前的已验证会话数量
old_verified_sessions_stmt = select(LoginSession).where(
LoginSession.is_verified == True
) )
# 统计30天前的已验证会话数量
old_verified_sessions_stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
old_verified_sessions_result = await db.exec(old_verified_sessions_stmt) old_verified_sessions_result = await db.exec(old_verified_sessions_stmt)
all_verified_sessions = old_verified_sessions_result.all() all_verified_sessions = old_verified_sessions_result.all()
old_verified_sessions_count = len([ old_verified_sessions_count = len(
session for session in all_verified_sessions [
if session.verified_at and session.verified_at < cutoff_30_days session
]) for session in all_verified_sessions
if session.verified_at and session.verified_at < cutoff_30_days
]
)
return { return {
"expired_verification_codes": expired_codes_count, "expired_verification_codes": expired_codes_count,
"expired_login_sessions": expired_sessions_count, "expired_login_sessions": expired_sessions_count,
"old_used_verification_codes": old_used_codes_count, "old_used_verification_codes": old_used_codes_count,
"old_verified_sessions": old_verified_sessions_count, "old_verified_sessions": old_verified_sessions_count,
"total_cleanable": expired_codes_count + expired_sessions_count + old_used_codes_count + old_verified_sessions_count "total_cleanable": expired_codes_count
+ expired_sessions_count
+ old_used_codes_count
+ old_verified_sessions_count,
} }
except Exception as e: except Exception as e:
logger.error(f"[Cleanup Service] Error getting cleanup statistics: {str(e)}") logger.error(f"[Cleanup Service] Error getting cleanup statistics: {e!s}")
return { return {
"expired_verification_codes": 0, "expired_verification_codes": 0,
"expired_login_sessions": 0, "expired_login_sessions": 0,
"old_used_verification_codes": 0, "old_used_verification_codes": 0,
"old_verified_sessions": 0, "old_verified_sessions": 0,
"total_cleanable": 0 "total_cleanable": 0,
} }

View File

@@ -8,17 +8,18 @@ from __future__ import annotations
import asyncio import asyncio
import concurrent.futures import concurrent.futures
from datetime import datetime from datetime import datetime
import json
import uuid
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart from email.mime.multipart import MIMEMultipart
from typing import Dict, Any, Optional from email.mime.text import MIMEText
import redis as sync_redis # 添加同步Redis导入 import json
import smtplib
from typing import Any
import uuid
from app.config import settings from app.config import settings
from app.dependencies.database import redis_message_client # 使用同步Redis客户端
from app.log import logger from app.log import logger
from app.utils import bg_tasks # 添加同步Redis导入
import redis as sync_redis
class EmailQueue: class EmailQueue:
@@ -30,14 +31,14 @@ class EmailQueue:
self._processing = False self._processing = False
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=2) self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=2)
self._retry_limit = 3 # 重试次数限制 self._retry_limit = 3 # 重试次数限制
# 邮件配置 # 邮件配置
self.smtp_server = getattr(settings, 'smtp_server', 'localhost') self.smtp_server = getattr(settings, "smtp_server", "localhost")
self.smtp_port = getattr(settings, 'smtp_port', 587) self.smtp_port = getattr(settings, "smtp_port", 587)
self.smtp_username = getattr(settings, 'smtp_username', '') self.smtp_username = getattr(settings, "smtp_username", "")
self.smtp_password = getattr(settings, 'smtp_password', '') self.smtp_password = getattr(settings, "smtp_password", "")
self.from_email = getattr(settings, 'from_email', 'noreply@example.com') self.from_email = getattr(settings, "from_email", "noreply@example.com")
self.from_name = getattr(settings, 'from_name', 'osu! server') self.from_name = getattr(settings, "from_name", "osu! server")
async def _run_in_executor(self, func, *args): async def _run_in_executor(self, func, *args):
"""在线程池中运行同步操作""" """在线程池中运行同步操作"""
@@ -48,7 +49,7 @@ class EmailQueue:
"""启动邮件处理任务""" """启动邮件处理任务"""
if not self._processing: if not self._processing:
self._processing = True self._processing = True
asyncio.create_task(self._process_email_queue()) bg_tasks.add_task(self._process_email_queue)
logger.info("Email queue processing started") logger.info("Email queue processing started")
async def stop_processing(self): async def stop_processing(self):
@@ -56,27 +57,29 @@ class EmailQueue:
self._processing = False self._processing = False
logger.info("Email queue processing stopped") logger.info("Email queue processing stopped")
async def enqueue_email(self, async def enqueue_email(
to_email: str, self,
subject: str, to_email: str,
content: str, subject: str,
html_content: Optional[str] = None, content: str,
metadata: Optional[Dict[str, Any]] = None) -> str: html_content: str | None = None,
metadata: dict[str, Any] | None = None,
) -> str:
""" """
将邮件加入队列等待发送 将邮件加入队列等待发送
Args: Args:
to_email: 收件人邮箱地址 to_email: 收件人邮箱地址
subject: 邮件主题 subject: 邮件主题
content: 邮件纯文本内容 content: 邮件纯文本内容
html_content: 邮件HTML内容如果有 html_content: 邮件HTML内容如果有
metadata: 额外元数据如密码重置ID等 metadata: 额外元数据如密码重置ID等
Returns: Returns:
邮件任务ID 邮件任务ID
""" """
email_id = str(uuid.uuid4()) email_id = str(uuid.uuid4())
email_data = { email_data = {
"id": email_id, "id": email_id,
"to_email": to_email, "to_email": to_email,
@@ -86,125 +89,117 @@ class EmailQueue:
"metadata": json.dumps(metadata) if metadata else "{}", "metadata": json.dumps(metadata) if metadata else "{}",
"created_at": datetime.now().isoformat(), "created_at": datetime.now().isoformat(),
"status": "pending", # pending, sending, sent, failed "status": "pending", # pending, sending, sent, failed
"retry_count": "0" "retry_count": "0",
} }
# 将邮件数据存入Redis # 将邮件数据存入Redis
await self._run_in_executor( await self._run_in_executor(lambda: self.redis.hset(f"email:{email_id}", mapping=email_data))
lambda: self.redis.hset(f"email:{email_id}", mapping=email_data)
)
# 设置24小时过期防止数据堆积 # 设置24小时过期防止数据堆积
await self._run_in_executor( await self._run_in_executor(self.redis.expire, f"email:{email_id}", 86400)
self.redis.expire, f"email:{email_id}", 86400
)
# 加入发送队列 # 加入发送队列
await self._run_in_executor( await self._run_in_executor(self.redis.lpush, "email_queue", email_id)
self.redis.lpush, "email_queue", email_id
)
logger.info(f"Email enqueued with id: {email_id} to {to_email}") logger.info(f"Email enqueued with id: {email_id} to {to_email}")
return email_id return email_id
async def get_email_status(self, email_id: str) -> Dict[str, Any]: async def get_email_status(self, email_id: str) -> dict[str, Any]:
""" """
获取邮件发送状态 获取邮件发送状态
Args: Args:
email_id: 邮件任务ID email_id: 邮件任务ID
Returns: Returns:
邮件任务状态信息 邮件任务状态信息
""" """
email_data = await self._run_in_executor( email_data = await self._run_in_executor(self.redis.hgetall, f"email:{email_id}")
self.redis.hgetall, f"email:{email_id}"
)
# 解码Redis返回的字节数据 # 解码Redis返回的字节数据
if email_data: if email_data:
return { return {
k.decode("utf-8") if isinstance(k, bytes) else k: k.decode("utf-8") if isinstance(k, bytes) else k: v.decode("utf-8") if isinstance(v, bytes) else v
v.decode("utf-8") if isinstance(v, bytes) else v
for k, v in email_data.items() for k, v in email_data.items()
} }
return {"status": "not_found"} return {"status": "not_found"}
async def _process_email_queue(self): async def _process_email_queue(self):
"""处理邮件队列""" """处理邮件队列"""
logger.info("Starting email queue processor") logger.info("Starting email queue processor")
while self._processing: while self._processing:
try: try:
# 从队列获取邮件ID # 从队列获取邮件ID
def brpop_operation(): def brpop_operation():
return self.redis.brpop(["email_queue"], timeout=5) return self.redis.brpop(["email_queue"], timeout=5)
result = await self._run_in_executor(brpop_operation) result = await self._run_in_executor(brpop_operation)
if not result: if not result:
await asyncio.sleep(1) await asyncio.sleep(1)
continue continue
# 解包返回结果(列表名和值) # 解包返回结果(列表名和值)
queue_name, email_id = result queue_name, email_id = result
if isinstance(email_id, bytes): if isinstance(email_id, bytes):
email_id = email_id.decode("utf-8") email_id = email_id.decode("utf-8")
# 获取邮件数据 # 获取邮件数据
email_data = await self.get_email_status(email_id) email_data = await self.get_email_status(email_id)
if email_data.get("status") == "not_found": if email_data.get("status") == "not_found":
logger.warning(f"Email data not found for id: {email_id}") logger.warning(f"Email data not found for id: {email_id}")
continue continue
# 更新状态为发送中 # 更新状态为发送中
await self._run_in_executor( await self._run_in_executor(self.redis.hset, f"email:{email_id}", "status", "sending")
self.redis.hset, f"email:{email_id}", "status", "sending"
)
# 尝试发送邮件 # 尝试发送邮件
success = await self._send_email(email_data) success = await self._send_email(email_data)
if success: if success:
# 更新状态为已发送 # 更新状态为已发送
await self._run_in_executor(self.redis.hset, f"email:{email_id}", "status", "sent")
await self._run_in_executor( await self._run_in_executor(
self.redis.hset, f"email:{email_id}", "status", "sent" self.redis.hset,
) f"email:{email_id}",
await self._run_in_executor( "sent_at",
self.redis.hset, f"email:{email_id}", "sent_at", datetime.now().isoformat() datetime.now().isoformat(),
) )
logger.info(f"Email {email_id} sent successfully to {email_data.get('to_email')}") logger.info(f"Email {email_id} sent successfully to {email_data.get('to_email')}")
else: else:
# 计算重试次数 # 计算重试次数
retry_count = int(email_data.get("retry_count", "0")) + 1 retry_count = int(email_data.get("retry_count", "0")) + 1
if retry_count <= self._retry_limit: if retry_count <= self._retry_limit:
# 重新入队,稍后重试 # 重新入队,稍后重试
await self._run_in_executor( await self._run_in_executor(
self.redis.hset, f"email:{email_id}", "retry_count", str(retry_count) self.redis.hset,
f"email:{email_id}",
"retry_count",
str(retry_count),
) )
await self._run_in_executor(self.redis.hset, f"email:{email_id}", "status", "pending")
await self._run_in_executor( await self._run_in_executor(
self.redis.hset, f"email:{email_id}", "status", "pending" self.redis.hset,
f"email:{email_id}",
"last_retry",
datetime.now().isoformat(),
) )
await self._run_in_executor(
self.redis.hset, f"email:{email_id}", "last_retry", datetime.now().isoformat()
)
# 延迟重试(使用指数退避) # 延迟重试(使用指数退避)
delay = 60 * (2 ** (retry_count - 1)) # 1分钟2分钟4分钟... delay = 60 * (2 ** (retry_count - 1)) # 1分钟2分钟4分钟...
# 创建延迟任务 # 创建延迟任务
asyncio.create_task(self._delayed_retry(email_id, delay)) bg_tasks.add_task(self._delayed_retry, email_id, delay)
logger.warning(f"Email {email_id} will be retried in {delay} seconds (attempt {retry_count})") logger.warning(f"Email {email_id} will be retried in {delay} seconds (attempt {retry_count})")
else: else:
# 超过重试次数,标记为失败 # 超过重试次数,标记为失败
await self._run_in_executor( await self._run_in_executor(self.redis.hset, f"email:{email_id}", "status", "failed")
self.redis.hset, f"email:{email_id}", "status", "failed"
)
logger.error(f"Email {email_id} failed after {retry_count} attempts") logger.error(f"Email {email_id} failed after {retry_count} attempts")
except Exception as e: except Exception as e:
logger.error(f"Error processing email queue: {e}") logger.error(f"Error processing email queue: {e}")
await asyncio.sleep(5) # 出错后等待5秒 await asyncio.sleep(5) # 出错后等待5秒
@@ -212,53 +207,51 @@ class EmailQueue:
async def _delayed_retry(self, email_id: str, delay: int): async def _delayed_retry(self, email_id: str, delay: int):
"""延迟重试发送邮件""" """延迟重试发送邮件"""
await asyncio.sleep(delay) await asyncio.sleep(delay)
await self._run_in_executor( await self._run_in_executor(self.redis.lpush, "email_queue", email_id)
self.redis.lpush, "email_queue", email_id
)
logger.info(f"Re-queued email {email_id} for retry after {delay} seconds") logger.info(f"Re-queued email {email_id} for retry after {delay} seconds")
async def _send_email(self, email_data: Dict[str, Any]) -> bool: async def _send_email(self, email_data: dict[str, Any]) -> bool:
""" """
实际发送邮件 实际发送邮件
Args: Args:
email_data: 邮件数据 email_data: 邮件数据
Returns: Returns:
是否发送成功 是否发送成功
""" """
try: try:
# 如果邮件发送功能被禁用,则只记录日志 # 如果邮件发送功能被禁用,则只记录日志
if not getattr(settings, 'enable_email_sending', True): if not getattr(settings, "enable_email_sending", True):
logger.info(f"[Mock Email] Would send to {email_data.get('to_email')}: {email_data.get('subject')}") logger.info(f"[Mock Email] Would send to {email_data.get('to_email')}: {email_data.get('subject')}")
return True return True
# 创建邮件 # 创建邮件
msg = MIMEMultipart('alternative') msg = MIMEMultipart("alternative")
msg['From'] = f"{self.from_name} <{self.from_email}>" msg["From"] = f"{self.from_name} <{self.from_email}>"
msg['To'] = email_data.get('to_email', '') msg["To"] = email_data.get("to_email", "")
msg['Subject'] = email_data.get('subject', '') msg["Subject"] = email_data.get("subject", "")
# 添加纯文本内容 # 添加纯文本内容
content = email_data.get('content', '') content = email_data.get("content", "")
if content: if content:
msg.attach(MIMEText(content, 'plain', 'utf-8')) msg.attach(MIMEText(content, "plain", "utf-8"))
# 添加HTML内容如果有 # 添加HTML内容如果有
html_content = email_data.get('html_content', '') html_content = email_data.get("html_content", "")
if html_content: if html_content:
msg.attach(MIMEText(html_content, 'html', 'utf-8')) msg.attach(MIMEText(html_content, "html", "utf-8"))
# 发送邮件 # 发送邮件
with smtplib.SMTP(self.smtp_server, self.smtp_port) as server: with smtplib.SMTP(self.smtp_server, self.smtp_port) as server:
if self.smtp_username and self.smtp_password: if self.smtp_username and self.smtp_password:
server.starttls() server.starttls()
server.login(self.smtp_username, self.smtp_password) server.login(self.smtp_username, self.smtp_password)
server.send_message(msg) server.send_message(msg)
return True return True
except Exception as e: except Exception as e:
logger.error(f"Failed to send email: {e}") logger.error(f"Failed to send email: {e}")
return False return False
@@ -267,10 +260,12 @@ class EmailQueue:
# 全局邮件队列实例 # 全局邮件队列实例
email_queue = EmailQueue() email_queue = EmailQueue()
# 在应用启动时调用 # 在应用启动时调用
async def start_email_processor(): async def start_email_processor():
await email_queue.start_processing() await email_queue.start_processing()
# 在应用关闭时调用 # 在应用关闭时调用
async def stop_email_processor(): async def stop_email_processor():
await email_queue.stop_processing() await email_queue.stop_processing()

View File

@@ -4,13 +4,11 @@
from __future__ import annotations from __future__ import annotations
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
import secrets import secrets
import smtplib
import string import string
from datetime import datetime, UTC, timedelta
from typing import Optional
from app.config import settings from app.config import settings
from app.log import logger from app.log import logger
@@ -18,28 +16,28 @@ from app.log import logger
class EmailService: class EmailService:
"""邮件发送服务""" """邮件发送服务"""
def __init__(self): def __init__(self):
self.smtp_server = getattr(settings, 'smtp_server', 'localhost') self.smtp_server = getattr(settings, "smtp_server", "localhost")
self.smtp_port = getattr(settings, 'smtp_port', 587) self.smtp_port = getattr(settings, "smtp_port", 587)
self.smtp_username = getattr(settings, 'smtp_username', '') self.smtp_username = getattr(settings, "smtp_username", "")
self.smtp_password = getattr(settings, 'smtp_password', '') self.smtp_password = getattr(settings, "smtp_password", "")
self.from_email = getattr(settings, 'from_email', 'noreply@example.com') self.from_email = getattr(settings, "from_email", "noreply@example.com")
self.from_name = getattr(settings, 'from_name', 'osu! server') self.from_name = getattr(settings, "from_name", "osu! server")
def generate_verification_code(self) -> str: def generate_verification_code(self) -> str:
"""生成8位验证码""" """生成8位验证码"""
# 只使用数字,避免混淆 # 只使用数字,避免混淆
return ''.join(secrets.choice(string.digits) for _ in range(8)) return "".join(secrets.choice(string.digits) for _ in range(8))
async def send_verification_email(self, email: str, code: str, username: str) -> bool: async def send_verification_email(self, email: str, code: str, username: str) -> bool:
"""发送验证邮件""" """发送验证邮件"""
try: try:
msg = MIMEMultipart() msg = MIMEMultipart()
msg['From'] = f"{self.from_name} <{self.from_email}>" msg["From"] = f"{self.from_name} <{self.from_email}>"
msg['To'] = email msg["To"] = email
msg['Subject'] = "邮箱验证 - Email Verification" msg["Subject"] = "邮箱验证 - Email Verification"
# HTML 邮件内容 # HTML 邮件内容
html_content = f""" html_content = f"""
<!DOCTYPE html> <!DOCTYPE html>
@@ -101,15 +99,15 @@ class EmailService:
<h1>osu! 邮箱验证</h1> <h1>osu! 邮箱验证</h1>
<p>Email Verification</p> <p>Email Verification</p>
</div> </div>
<div class="content"> <div class="content">
<h2>你好 {username}</h2> <h2>你好 {username}</h2>
<p>感谢你注册我们的 osu! 服务器。为了完成账户验证,请输入以下验证码:</p> <p>感谢你注册我们的 osu! 服务器。为了完成账户验证,请输入以下验证码:</p>
<div class="code">{code}</div> <div class="code">{code}</div>
<p>这个验证码将在 <strong>10 分钟后过期</strong>。</p> <p>这个验证码将在 <strong>10 分钟后过期</strong>。</p>
<div class="warning"> <div class="warning">
<strong>注意:</strong> <strong>注意:</strong>
<ul> <ul>
@@ -118,19 +116,19 @@ class EmailService:
<li>验证码只能使用一次</li> <li>验证码只能使用一次</li>
</ul> </ul>
</div> </div>
<p>如果你有任何问题,请联系我们的支持团队。</p> <p>如果你有任何问题,请联系我们的支持团队。</p>
<hr style="border: none; border-top: 1px solid #ddd; margin: 20px 0;"> <hr style="border: none; border-top: 1px solid #ddd; margin: 20px 0;">
<h3>Hello {username}!</h3> <h3>Hello {username}!</h3>
<p>Thank you for registering on our osu! server. To complete your account verification, please enter the following verification code:</p> <p>Thank you for registering on our osu! server. To complete your account verification, please enter the following verification code:</p>
<p>This verification code will expire in <strong>10 minutes</strong>.</p> <p>This verification code will expire in <strong>10 minutes</strong>.</p>
<p><strong>Important:</strong> Do not share this verification code with anyone. If you did not request this code, please ignore this email.</p> <p><strong>Important:</strong> Do not share this verification code with anyone. If you did not request this code, please ignore this email.</p>
</div> </div>
<div class="footer"> <div class="footer">
<p>© 2025 g0v0! Private Server. 此邮件由系统自动发送,请勿回复。</p> <p>© 2025 g0v0! Private Server. 此邮件由系统自动发送,请勿回复。</p>
<p>This email was sent automatically, please do not reply.</p> <p>This email was sent automatically, please do not reply.</p>
@@ -138,26 +136,26 @@ class EmailService:
</div> </div>
</body> </body>
</html> </html>
""" """ # noqa: E501
msg.attach(MIMEText(html_content, 'html', 'utf-8')) msg.attach(MIMEText(html_content, "html", "utf-8"))
# 发送邮件 # 发送邮件
if not settings.enable_email_sending: if not settings.enable_email_sending:
# 邮件发送功能禁用时只记录日志,不实际发送 # 邮件发送功能禁用时只记录日志,不实际发送
logger.info(f"[Email Verification] Mock sending verification code to {email}: {code}") logger.info(f"[Email Verification] Mock sending verification code to {email}: {code}")
return True return True
with smtplib.SMTP(self.smtp_server, self.smtp_port) as server: with smtplib.SMTP(self.smtp_server, self.smtp_port) as server:
if self.smtp_username and self.smtp_password: if self.smtp_username and self.smtp_password:
server.starttls() server.starttls()
server.login(self.smtp_username, self.smtp_password) server.login(self.smtp_username, self.smtp_password)
server.send_message(msg) server.send_message(msg)
logger.info(f"[Email Verification] Successfully sent verification code to {email}") logger.info(f"[Email Verification] Successfully sent verification code to {email}")
return True return True
except Exception as e: except Exception as e:
logger.error(f"[Email Verification] Failed to send email: {e}") logger.error(f"[Email Verification] Failed to send email: {e}")
return False return False

View File

@@ -4,40 +4,38 @@
from __future__ import annotations from __future__ import annotations
from datetime import UTC, datetime, timedelta
import secrets import secrets
import string import string
from datetime import datetime, UTC, timedelta
from typing import Optional
from app.database.email_verification import EmailVerification, LoginSession
from app.service.email_service import email_service
from app.service.email_queue import email_queue # 导入邮件队列
from app.log import logger
from app.config import settings from app.config import settings
from app.database.email_verification import EmailVerification, LoginSession
from app.log import logger
from app.service.email_queue import email_queue # 导入邮件队列
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel import select
from redis.asyncio import Redis from redis.asyncio import Redis
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
class EmailVerificationService: class EmailVerificationService:
"""邮件验证服务""" """邮件验证服务"""
@staticmethod @staticmethod
def generate_verification_code() -> str: def generate_verification_code() -> str:
"""生成8位验证码""" """生成8位验证码"""
return ''.join(secrets.choice(string.digits) for _ in range(8)) return "".join(secrets.choice(string.digits) for _ in range(8))
@staticmethod @staticmethod
async def send_verification_email_via_queue(email: str, code: str, username: str, user_id: int) -> bool: async def send_verification_email_via_queue(email: str, code: str, username: str, user_id: int) -> bool:
"""使用邮件队列发送验证邮件 """使用邮件队列发送验证邮件
Args: Args:
email: 接收验证码的邮箱地址 email: 接收验证码的邮箱地址
code: 验证码 code: 验证码
username: 用户名 username: 用户名
user_id: 用户ID user_id: 用户ID
Returns: Returns:
是否成功将邮件加入队列 是否成功将邮件加入队列
""" """
@@ -103,15 +101,15 @@ class EmailVerificationService:
<h1>osu! 邮箱验证</h1> <h1>osu! 邮箱验证</h1>
<p>Email Verification</p> <p>Email Verification</p>
</div> </div>
<div class="content"> <div class="content">
<h2>你好 {username}</h2> <h2>你好 {username}</h2>
<p>请使用以下验证码验证您的账户:</p> <p>请使用以下验证码验证您的账户:</p>
<div class="code">{code}</div> <div class="code">{code}</div>
<p>验证码将在 <strong>10 分钟内有效</strong>。</p> <p>验证码将在 <strong>10 分钟内有效</strong>。</p>
<div class="warning"> <div class="warning">
<p><strong>重要提示:</strong></p> <p><strong>重要提示:</strong></p>
<ul> <ul>
@@ -120,17 +118,17 @@ class EmailVerificationService:
<li>为了账户安全,请勿在其他网站使用相同的密码</li> <li>为了账户安全,请勿在其他网站使用相同的密码</li>
</ul> </ul>
</div> </div>
<hr style="border: none; border-top: 1px solid #ddd; margin: 20px 0;"> <hr style="border: none; border-top: 1px solid #ddd; margin: 20px 0;">
<h3>Hello {username}!</h3> <h3>Hello {username}!</h3>
<p>Please use the following verification code to verify your account:</p> <p>Please use the following verification code to verify your account:</p>
<p>This verification code will be valid for <strong>10 minutes</strong>.</p> <p>This verification code will be valid for <strong>10 minutes</strong>.</p>
<p><strong>Important:</strong> Do not share this verification code with anyone. If you did not request this code, please ignore this email.</p> <p><strong>Important:</strong> Do not share this verification code with anyone. If you did not request this code, please ignore this email.</p>
</div> </div>
<div class="footer"> <div class="footer">
<p>© 2025 g0v0! Private Server. 此邮件由系统自动发送,请勿回复。</p> <p>© 2025 g0v0! Private Server. 此邮件由系统自动发送,请勿回复。</p>
<p>This email was sent automatically, please do not reply.</p> <p>This email was sent automatically, please do not reply.</p>
@@ -138,8 +136,8 @@ class EmailVerificationService:
</div> </div>
</body> </body>
</html> </html>
""" """ # noqa: E501
# 纯文本备用内容 # 纯文本备用内容
plain_content = f""" plain_content = f"""
你好 {username} 你好 {username}
@@ -162,34 +160,30 @@ This verification code will be valid for 10 minutes.
© 2025 g0v0! Private Server. 此邮件由系统自动发送,请勿回复。 © 2025 g0v0! Private Server. 此邮件由系统自动发送,请勿回复。
This email was sent automatically, please do not reply. This email was sent automatically, please do not reply.
""" """
# 将邮件加入队列 # 将邮件加入队列
subject = "邮箱验证 - Email Verification" subject = "邮箱验证 - Email Verification"
metadata = { metadata = {"type": "email_verification", "user_id": user_id, "code": code}
"type": "email_verification",
"user_id": user_id,
"code": code
}
await email_queue.enqueue_email( await email_queue.enqueue_email(
to_email=email, to_email=email,
subject=subject, subject=subject,
content=plain_content, content=plain_content,
html_content=html_content, html_content=html_content,
metadata=metadata metadata=metadata,
) )
return True return True
except Exception as e: except Exception as e:
logger.error(f"[Email Verification] Failed to enqueue email: {e}") logger.error(f"[Email Verification] Failed to enqueue email: {e}")
return False return False
@staticmethod @staticmethod
def generate_session_token() -> str: def generate_session_token() -> str:
"""生成会话令牌""" """生成会话令牌"""
return secrets.token_urlsafe(32) return secrets.token_urlsafe(32)
@staticmethod @staticmethod
async def create_verification_record( async def create_verification_record(
db: AsyncSession, db: AsyncSession,
@@ -197,27 +191,27 @@ This email was sent automatically, please do not reply.
user_id: int, user_id: int,
email: str, email: str,
ip_address: str | None = None, ip_address: str | None = None,
user_agent: str | None = None user_agent: str | None = None,
) -> tuple[EmailVerification, str]: ) -> tuple[EmailVerification, str]:
"""创建邮件验证记录""" """创建邮件验证记录"""
# 检查是否有未过期的验证码 # 检查是否有未过期的验证码
existing_result = await db.exec( existing_result = await db.exec(
select(EmailVerification).where( select(EmailVerification).where(
EmailVerification.user_id == user_id, EmailVerification.user_id == user_id,
EmailVerification.is_used == False, col(EmailVerification.is_used).is_(False),
EmailVerification.expires_at > datetime.now(UTC) EmailVerification.expires_at > datetime.now(UTC),
) )
) )
existing = existing_result.first() existing = existing_result.first()
if existing: if existing:
# 如果有未过期的验证码,直接返回 # 如果有未过期的验证码,直接返回
return existing, existing.verification_code return existing, existing.verification_code
# 生成新的验证码 # 生成新的验证码
code = EmailVerificationService.generate_verification_code() code = EmailVerificationService.generate_verification_code()
# 创建验证记录 # 创建验证记录
verification = EmailVerification( verification = EmailVerification(
user_id=user_id, user_id=user_id,
@@ -225,23 +219,23 @@ This email was sent automatically, please do not reply.
verification_code=code, verification_code=code,
expires_at=datetime.now(UTC) + timedelta(minutes=10), # 10分钟过期 expires_at=datetime.now(UTC) + timedelta(minutes=10), # 10分钟过期
ip_address=ip_address, ip_address=ip_address,
user_agent=user_agent user_agent=user_agent,
) )
db.add(verification) db.add(verification)
await db.commit() await db.commit()
await db.refresh(verification) await db.refresh(verification)
# 存储到 Redis用于快速验证 # 存储到 Redis用于快速验证
await redis.setex( await redis.setex(
f"email_verification:{user_id}:{code}", f"email_verification:{user_id}:{code}",
600, # 10分钟过期 600, # 10分钟过期
str(verification.id) if verification.id else "0" str(verification.id) if verification.id else "0",
) )
logger.info(f"[Email Verification] Created verification code for user {user_id}: {code}") logger.info(f"[Email Verification] Created verification code for user {user_id}: {code}")
return verification, code return verification, code
@staticmethod @staticmethod
async def send_verification_email( async def send_verification_email(
db: AsyncSession, db: AsyncSession,
@@ -250,7 +244,7 @@ This email was sent automatically, please do not reply.
username: str, username: str,
email: str, email: str,
ip_address: str | None = None, ip_address: str | None = None,
user_agent: str | None = None user_agent: str | None = None,
) -> bool: ) -> bool:
"""发送验证邮件""" """发送验证邮件"""
try: try:
@@ -258,33 +252,38 @@ This email was sent automatically, please do not reply.
if not settings.enable_email_verification: if not settings.enable_email_verification:
logger.debug(f"[Email Verification] Email verification is disabled, skipping for user {user_id}") logger.debug(f"[Email Verification] Email verification is disabled, skipping for user {user_id}")
return True # 返回成功,但不执行验证流程 return True # 返回成功,但不执行验证流程
# 创建验证记录 # 创建验证记录
verification, code = await EmailVerificationService.create_verification_record( (
verification,
code,
) = await EmailVerificationService.create_verification_record(
db, redis, user_id, email, ip_address, user_agent db, redis, user_id, email, ip_address, user_agent
) )
# 使用邮件队列发送验证邮件 # 使用邮件队列发送验证邮件
success = await EmailVerificationService.send_verification_email_via_queue(email, code, username, user_id) success = await EmailVerificationService.send_verification_email_via_queue(email, code, username, user_id)
if success: if success:
logger.info(f"[Email Verification] Successfully enqueued verification email to {email} (user: {username})") logger.info(
f"[Email Verification] Successfully enqueued verification email to {email} (user: {username})"
)
return True return True
else: else:
logger.error(f"[Email Verification] Failed to enqueue verification email: {email} (user: {username})") logger.error(f"[Email Verification] Failed to enqueue verification email: {email} (user: {username})")
return False return False
except Exception as e: except Exception as e:
logger.error(f"[Email Verification] Exception during sending verification email: {e}") logger.error(f"[Email Verification] Exception during sending verification email: {e}")
return False return False
@staticmethod @staticmethod
async def verify_code( async def verify_code(
db: AsyncSession, db: AsyncSession,
redis: Redis, redis: Redis,
user_id: int, user_id: int,
code: str, code: str,
ip_address: str | None = None ip_address: str | None = None,
) -> tuple[bool, str]: ) -> tuple[bool, str]:
"""验证验证码""" """验证验证码"""
try: try:
@@ -294,46 +293,46 @@ This email was sent automatically, please do not reply.
# 仍然标记登录会话为已验证 # 仍然标记登录会话为已验证
await LoginSessionService.mark_session_verified(db, user_id) await LoginSessionService.mark_session_verified(db, user_id)
return True, "验证成功(邮件验证功能已禁用)" return True, "验证成功(邮件验证功能已禁用)"
# 先从 Redis 检查 # 先从 Redis 检查
verification_id = await redis.get(f"email_verification:{user_id}:{code}") verification_id = await redis.get(f"email_verification:{user_id}:{code}")
if not verification_id: if not verification_id:
return False, "验证码无效或已过期" return False, "验证码无效或已过期"
# 从数据库获取验证记录 # 从数据库获取验证记录
result = await db.exec( result = await db.exec(
select(EmailVerification).where( select(EmailVerification).where(
EmailVerification.id == int(verification_id), EmailVerification.id == int(verification_id),
EmailVerification.user_id == user_id, EmailVerification.user_id == user_id,
EmailVerification.verification_code == code, EmailVerification.verification_code == code,
EmailVerification.is_used == False, col(EmailVerification.is_used).is_(False),
EmailVerification.expires_at > datetime.now(UTC) EmailVerification.expires_at > datetime.now(UTC),
) )
) )
verification = result.first() verification = result.first()
if not verification: if not verification:
return False, "验证码无效或已过期" return False, "验证码无效或已过期"
# 标记为已使用 # 标记为已使用
verification.is_used = True verification.is_used = True
verification.used_at = datetime.now(UTC) verification.used_at = datetime.now(UTC)
# 同时更新对应的登录会话状态 # 同时更新对应的登录会话状态
await LoginSessionService.mark_session_verified(db, user_id) await LoginSessionService.mark_session_verified(db, user_id)
await db.commit() await db.commit()
# 删除 Redis 记录 # 删除 Redis 记录
await redis.delete(f"email_verification:{user_id}:{code}") await redis.delete(f"email_verification:{user_id}:{code}")
logger.info(f"[Email Verification] User {user_id} verification code verified successfully") logger.info(f"[Email Verification] User {user_id} verification code verified successfully")
return True, "验证成功" return True, "验证成功"
except Exception as e: except Exception as e:
logger.error(f"[Email Verification] Exception during verification code validation: {e}") logger.error(f"[Email Verification] Exception during verification code validation: {e}")
return False, "验证过程中发生错误" return False, "验证过程中发生错误"
@staticmethod @staticmethod
async def resend_verification_code( async def resend_verification_code(
db: AsyncSession, db: AsyncSession,
@@ -342,7 +341,7 @@ This email was sent automatically, please do not reply.
username: str, username: str,
email: str, email: str,
ip_address: str | None = None, ip_address: str | None = None,
user_agent: str | None = None user_agent: str | None = None,
) -> tuple[bool, str]: ) -> tuple[bool, str]:
"""重新发送验证码""" """重新发送验证码"""
try: try:
@@ -350,25 +349,25 @@ This email was sent automatically, please do not reply.
if not settings.enable_email_verification: if not settings.enable_email_verification:
logger.debug(f"[Email Verification] Email verification is disabled, skipping resend for user {user_id}") logger.debug(f"[Email Verification] Email verification is disabled, skipping resend for user {user_id}")
return True, "验证码已发送(邮件验证功能已禁用)" return True, "验证码已发送(邮件验证功能已禁用)"
# 检查重发频率限制60秒内只能发送一次 # 检查重发频率限制60秒内只能发送一次
rate_limit_key = f"email_verification_rate_limit:{user_id}" rate_limit_key = f"email_verification_rate_limit:{user_id}"
if await redis.get(rate_limit_key): if await redis.get(rate_limit_key):
return False, "请等待60秒后再重新发送" return False, "请等待60秒后再重新发送"
# 设置频率限制 # 设置频率限制
await redis.setex(rate_limit_key, 60, "1") await redis.setex(rate_limit_key, 60, "1")
# 生成新的验证码 # 生成新的验证码
success = await EmailVerificationService.send_verification_email( success = await EmailVerificationService.send_verification_email(
db, redis, user_id, username, email, ip_address, user_agent db, redis, user_id, username, email, ip_address, user_agent
) )
if success: if success:
return True, "验证码已重新发送" return True, "验证码已重新发送"
else: else:
return False, "重新发送失败,请稍后再试" return False, "重新发送失败,请稍后再试"
except Exception as e: except Exception as e:
logger.error(f"[Email Verification] Exception during resending verification code: {e}") logger.error(f"[Email Verification] Exception during resending verification code: {e}")
return False, "重新发送过程中发生错误" return False, "重新发送过程中发生错误"
@@ -376,7 +375,7 @@ This email was sent automatically, please do not reply.
class LoginSessionService: class LoginSessionService:
"""登录会话服务""" """登录会话服务"""
@staticmethod @staticmethod
async def create_session( async def create_session(
db: AsyncSession, db: AsyncSession,
@@ -385,47 +384,40 @@ class LoginSessionService:
ip_address: str, ip_address: str,
user_agent: str | None = None, user_agent: str | None = None,
country_code: str | None = None, country_code: str | None = None,
is_new_location: bool = False is_new_location: bool = False,
) -> LoginSession: ) -> LoginSession:
"""创建登录会话""" """创建登录会话"""
from app.utils import simplify_user_agent
session_token = EmailVerificationService.generate_session_token() session_token = EmailVerificationService.generate_session_token()
# 简化 User-Agent 字符串
simplified_user_agent = simplify_user_agent(user_agent, max_length=250)
session = LoginSession( session = LoginSession(
user_id=user_id, user_id=user_id,
session_token=session_token, session_token=session_token,
ip_address=ip_address, ip_address=ip_address,
user_agent=simplified_user_agent, user_agent=None,
country_code=country_code, country_code=country_code,
is_new_location=is_new_location, is_new_location=is_new_location,
expires_at=datetime.now(UTC) + timedelta(hours=24), # 24小时过期 expires_at=datetime.now(UTC) + timedelta(hours=24), # 24小时过期
is_verified=not is_new_location # 新位置需要验证 is_verified=not is_new_location, # 新位置需要验证
) )
db.add(session) db.add(session)
await db.commit() await db.commit()
await db.refresh(session) await db.refresh(session)
# 存储到 Redis # 存储到 Redis
await redis.setex( await redis.setex(
f"login_session:{session_token}", f"login_session:{session_token}",
86400, # 24小时 86400, # 24小时
user_id user_id,
) )
logger.info(f"[Login Session] Created session for user {user_id} (new location: {is_new_location})") logger.info(f"[Login Session] Created session for user {user_id} (new location: {is_new_location})")
return session return session
@staticmethod @staticmethod
async def verify_session( async def verify_session(
db: AsyncSession, db: AsyncSession, redis: Redis, session_token: str, verification_code: str
redis: Redis,
session_token: str,
verification_code: str
) -> tuple[bool, str]: ) -> tuple[bool, str]:
"""验证会话(通过邮件验证码)""" """验证会话(通过邮件验证码)"""
try: try:
@@ -433,98 +425,89 @@ class LoginSessionService:
user_id = await redis.get(f"login_session:{session_token}") user_id = await redis.get(f"login_session:{session_token}")
if not user_id: if not user_id:
return False, "会话无效或已过期" return False, "会话无效或已过期"
user_id = int(user_id) user_id = int(user_id)
# 验证邮件验证码 # 验证邮件验证码
success, message = await EmailVerificationService.verify_code( success, message = await EmailVerificationService.verify_code(db, redis, user_id, verification_code)
db, redis, user_id, verification_code
)
if not success: if not success:
return False, message return False, message
# 更新会话状态 # 更新会话状态
result = await db.exec( result = await db.exec(
select(LoginSession).where( select(LoginSession).where(
LoginSession.session_token == session_token, LoginSession.session_token == session_token,
LoginSession.user_id == user_id, LoginSession.user_id == user_id,
LoginSession.is_verified == False col(LoginSession.is_verified).is_(False),
) )
) )
session = result.first() session = result.first()
if session: if session:
session.is_verified = True session.is_verified = True
session.verified_at = datetime.now(UTC) session.verified_at = datetime.now(UTC)
await db.commit() await db.commit()
logger.info(f"[Login Session] User {user_id} session verification successful") logger.info(f"[Login Session] User {user_id} session verification successful")
return True, "会话验证成功" return True, "会话验证成功"
except Exception as e: except Exception as e:
logger.error(f"[Login Session] Exception during session verification: {e}") logger.error(f"[Login Session] Exception during session verification: {e}")
return False, "验证过程中发生错误" return False, "验证过程中发生错误"
@staticmethod @staticmethod
async def check_new_location( async def check_new_location(
db: AsyncSession, db: AsyncSession, user_id: int, ip_address: str, country_code: str | None = None
user_id: int,
ip_address: str,
country_code: str | None = None
) -> bool: ) -> bool:
"""检查是否为新位置登录""" """检查是否为新位置登录"""
try: try:
# 查看过去30天内是否有相同IP或相同国家的登录记录 # 查看过去30天内是否有相同IP或相同国家的登录记录
thirty_days_ago = datetime.now(UTC) - timedelta(days=30) thirty_days_ago = datetime.now(UTC) - timedelta(days=30)
result = await db.exec( result = await db.exec(
select(LoginSession).where( select(LoginSession).where(
LoginSession.user_id == user_id, LoginSession.user_id == user_id,
LoginSession.created_at > thirty_days_ago, LoginSession.created_at > thirty_days_ago,
(LoginSession.ip_address == ip_address) | (LoginSession.ip_address == ip_address) | (LoginSession.country_code == country_code),
(LoginSession.country_code == country_code)
) )
) )
existing_sessions = result.all() existing_sessions = result.all()
# 如果有历史记录,则不是新位置 # 如果有历史记录,则不是新位置
return len(existing_sessions) == 0 return len(existing_sessions) == 0
except Exception as e: except Exception as e:
logger.error(f"[Login Session] Exception during new location check: {e}") logger.error(f"[Login Session] Exception during new location check: {e}")
# 出错时默认为新位置(更安全) # 出错时默认为新位置(更安全)
return True return True
@staticmethod @staticmethod
async def mark_session_verified( async def mark_session_verified(db: AsyncSession, user_id: int) -> bool:
db: AsyncSession,
user_id: int
) -> bool:
"""标记用户的未验证会话为已验证""" """标记用户的未验证会话为已验证"""
try: try:
# 查找用户所有未验证且未过期的会话 # 查找用户所有未验证且未过期的会话
result = await db.exec( result = await db.exec(
select(LoginSession).where( select(LoginSession).where(
LoginSession.user_id == user_id, LoginSession.user_id == user_id,
LoginSession.is_verified == False, col(LoginSession.is_verified).is_(False),
LoginSession.expires_at > datetime.now(UTC) LoginSession.expires_at > datetime.now(UTC),
) )
) )
sessions = result.all() sessions = result.all()
# 标记所有会话为已验证 # 标记所有会话为已验证
for session in sessions: for session in sessions:
session.is_verified = True session.is_verified = True
session.verified_at = datetime.now(UTC) session.verified_at = datetime.now(UTC)
if sessions: if sessions:
logger.info(f"[Login Session] Marked {len(sessions)} session(s) as verified for user {user_id}") logger.info(f"[Login Session] Marked {len(sessions)} session(s) as verified for user {user_id}")
return len(sessions) > 0 return len(sessions) > 0
except Exception as e: except Exception as e:
logger.error(f"[Login Session] Exception during marking sessions as verified: {e}") logger.error(f"[Login Session] Exception during marking sessions as verified: {e}")
return False return False

View File

@@ -117,14 +117,10 @@ class EnhancedIntervalStatsManager:
@staticmethod @staticmethod
async def get_current_interval_info() -> IntervalInfo: async def get_current_interval_info() -> IntervalInfo:
"""获取当前区间信息""" """获取当前区间信息"""
start_time, end_time = ( start_time, end_time = EnhancedIntervalStatsManager.get_current_interval_boundaries()
EnhancedIntervalStatsManager.get_current_interval_boundaries()
)
interval_key = EnhancedIntervalStatsManager.generate_interval_key(start_time) interval_key = EnhancedIntervalStatsManager.generate_interval_key(start_time)
return IntervalInfo( return IntervalInfo(start_time=start_time, end_time=end_time, interval_key=interval_key)
start_time=start_time, end_time=end_time, interval_key=interval_key
)
@staticmethod @staticmethod
async def initialize_current_interval() -> None: async def initialize_current_interval() -> None:
@@ -133,9 +129,7 @@ class EnhancedIntervalStatsManager:
redis_async = get_redis() redis_async = get_redis()
try: try:
current_interval = ( current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
await EnhancedIntervalStatsManager.get_current_interval_info()
)
# 存储当前区间信息 # 存储当前区间信息
await _redis_exec( await _redis_exec(
@@ -147,9 +141,7 @@ class EnhancedIntervalStatsManager:
# 初始化区间用户集合(如果不存在) # 初始化区间用户集合(如果不存在)
online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}" online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}"
playing_key = ( playing_key = f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
)
# 设置过期时间为35分钟 # 设置过期时间为35分钟
await redis_async.expire(online_key, 35 * 60) await redis_async.expire(online_key, 35 * 60)
@@ -179,7 +171,8 @@ class EnhancedIntervalStatsManager:
await EnhancedIntervalStatsManager._ensure_24h_history_exists() await EnhancedIntervalStatsManager._ensure_24h_history_exists()
logger.info( logger.info(
f"Initialized interval stats for {current_interval.start_time.strftime('%H:%M')} - {current_interval.end_time.strftime('%H:%M')}" f"Initialized interval stats for {current_interval.start_time.strftime('%H:%M')}"
f" - {current_interval.end_time.strftime('%H:%M')}"
) )
except Exception as e: except Exception as e:
@@ -193,42 +186,32 @@ class EnhancedIntervalStatsManager:
try: try:
# 检查现有历史数据数量 # 检查现有历史数据数量
history_length = await _redis_exec( history_length = await _redis_exec(redis_sync.llen, REDIS_ONLINE_HISTORY_KEY)
redis_sync.llen, REDIS_ONLINE_HISTORY_KEY
)
if history_length < 48: # 少于48个数据点24小时*2 if history_length < 48: # 少于48个数据点24小时*2
logger.info( logger.info(f"History has only {history_length} points, filling with zeros for 24h")
f"History has only {history_length} points, filling with zeros for 24h"
)
# 计算需要填充的数据点数量 # 计算需要填充的数据点数量
needed_points = 48 - history_length needed_points = 48 - history_length
# 从当前时间往前推创建缺失的时间点都填充为0 # 从当前时间往前推创建缺失的时间点都填充为0
current_time = datetime.utcnow() current_time = datetime.utcnow() # noqa: F841
current_interval_start, _ = ( current_interval_start, _ = EnhancedIntervalStatsManager.get_current_interval_boundaries()
EnhancedIntervalStatsManager.get_current_interval_boundaries()
)
# 从当前区间开始往前推创建历史数据点确保时间对齐到30分钟边界 # 从当前区间开始往前推创建历史数据点确保时间对齐到30分钟边界
fill_points = [] fill_points = []
for i in range(needed_points): for i in range(needed_points):
# 每次往前推30分钟确保时间对齐 # 每次往前推30分钟确保时间对齐
point_time = current_interval_start - timedelta( point_time = current_interval_start - timedelta(minutes=30 * (i + 1))
minutes=30 * (i + 1)
)
# 确保时间对齐到30分钟边界 # 确保时间对齐到30分钟边界
aligned_minute = (point_time.minute // 30) * 30 aligned_minute = (point_time.minute // 30) * 30
point_time = point_time.replace( point_time = point_time.replace(minute=aligned_minute, second=0, microsecond=0)
minute=aligned_minute, second=0, microsecond=0
)
history_point = { history_point = {
"timestamp": point_time.isoformat(), "timestamp": point_time.isoformat(),
"online_count": 0, "online_count": 0,
"playing_count": 0 "playing_count": 0,
} }
fill_points.append(json.dumps(history_point)) fill_points.append(json.dumps(history_point))
@@ -238,9 +221,7 @@ class EnhancedIntervalStatsManager:
temp_key = f"{REDIS_ONLINE_HISTORY_KEY}_temp" temp_key = f"{REDIS_ONLINE_HISTORY_KEY}_temp"
if history_length > 0: if history_length > 0:
# 复制现有数据到临时key # 复制现有数据到临时key
existing_data = await _redis_exec( existing_data = await _redis_exec(redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1)
redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1
)
if existing_data: if existing_data:
for data in existing_data: for data in existing_data:
await _redis_exec(redis_sync.rpush, temp_key, data) await _redis_exec(redis_sync.rpush, temp_key, data)
@@ -250,19 +231,13 @@ class EnhancedIntervalStatsManager:
# 先添加填充数据(最旧的) # 先添加填充数据(最旧的)
for point in reversed(fill_points): # 反向添加,最旧的在最后 for point in reversed(fill_points): # 反向添加,最旧的在最后
await _redis_exec( await _redis_exec(redis_sync.rpush, REDIS_ONLINE_HISTORY_KEY, point)
redis_sync.rpush, REDIS_ONLINE_HISTORY_KEY, point
)
# 再添加原有数据(较新的) # 再添加原有数据(较新的)
if history_length > 0: if history_length > 0:
existing_data = await _redis_exec( existing_data = await _redis_exec(redis_sync.lrange, temp_key, 0, -1)
redis_sync.lrange, temp_key, 0, -1
)
for data in existing_data: for data in existing_data:
await _redis_exec( await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, data)
redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, data
)
# 清理临时key # 清理临时key
await redis_async.delete(temp_key) await redis_async.delete(temp_key)
@@ -273,9 +248,7 @@ class EnhancedIntervalStatsManager:
# 设置过期时间 # 设置过期时间
await redis_async.expire(REDIS_ONLINE_HISTORY_KEY, 26 * 3600) await redis_async.expire(REDIS_ONLINE_HISTORY_KEY, 26 * 3600)
logger.info( logger.info(f"Filled {len(fill_points)} historical data points with zeros")
f"Filled {len(fill_points)} historical data points with zeros"
)
except Exception as e: except Exception as e:
logger.error(f"Error ensuring 24h history exists: {e}") logger.error(f"Error ensuring 24h history exists: {e}")
@@ -287,9 +260,7 @@ class EnhancedIntervalStatsManager:
redis_async = get_redis() redis_async = get_redis()
try: try:
current_interval = ( current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
await EnhancedIntervalStatsManager.get_current_interval_info()
)
# 添加到区间在线用户集合 # 添加到区间在线用户集合
online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}" online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}"
@@ -298,9 +269,7 @@ class EnhancedIntervalStatsManager:
# 如果用户在游玩,也添加到游玩用户集合 # 如果用户在游玩,也添加到游玩用户集合
if is_playing: if is_playing:
playing_key = ( playing_key = f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
)
await _redis_exec(redis_sync.sadd, playing_key, str(user_id)) await _redis_exec(redis_sync.sadd, playing_key, str(user_id))
await redis_async.expire(playing_key, 35 * 60) await redis_async.expire(playing_key, 35 * 60)
@@ -308,7 +277,8 @@ class EnhancedIntervalStatsManager:
await EnhancedIntervalStatsManager._update_interval_stats() await EnhancedIntervalStatsManager._update_interval_stats()
logger.debug( logger.debug(
f"Added user {user_id} to current interval {current_interval.start_time.strftime('%H:%M')}-{current_interval.end_time.strftime('%H:%M')}" f"Added user {user_id} to current interval {current_interval.start_time.strftime('%H:%M')}"
f"-{current_interval.end_time.strftime('%H:%M')}"
) )
except Exception as e: except Exception as e:
@@ -321,15 +291,11 @@ class EnhancedIntervalStatsManager:
redis_async = get_redis() redis_async = get_redis()
try: try:
current_interval = ( current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
await EnhancedIntervalStatsManager.get_current_interval_info()
)
# 获取区间内独特用户数 # 获取区间内独特用户数
online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}" online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}"
playing_key = ( playing_key = f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
)
unique_online = await _redis_exec(redis_sync.scard, online_key) unique_online = await _redis_exec(redis_sync.scard, online_key)
unique_playing = await _redis_exec(redis_sync.scard, playing_key) unique_playing = await _redis_exec(redis_sync.scard, playing_key)
@@ -339,16 +305,12 @@ class EnhancedIntervalStatsManager:
current_playing = await _get_playing_users_count(redis_async) current_playing = await _get_playing_users_count(redis_async)
# 获取现有统计数据 # 获取现有统计数据
existing_data = await _redis_exec( existing_data = await _redis_exec(redis_sync.get, current_interval.interval_key)
redis_sync.get, current_interval.interval_key
)
if existing_data: if existing_data:
stats = IntervalStats.from_dict(json.loads(existing_data)) stats = IntervalStats.from_dict(json.loads(existing_data))
# 更新峰值 # 更新峰值
stats.peak_online_count = max(stats.peak_online_count, current_online) stats.peak_online_count = max(stats.peak_online_count, current_online)
stats.peak_playing_count = max( stats.peak_playing_count = max(stats.peak_playing_count, current_playing)
stats.peak_playing_count, current_playing
)
stats.total_samples += 1 stats.total_samples += 1
else: else:
# 创建新的统计记录 # 创建新的统计记录
@@ -377,7 +339,8 @@ class EnhancedIntervalStatsManager:
await redis_async.expire(current_interval.interval_key, 35 * 60) await redis_async.expire(current_interval.interval_key, 35 * 60)
logger.debug( logger.debug(
f"Updated interval stats: online={unique_online}, playing={unique_playing}, peak_online={stats.peak_online_count}, peak_playing={stats.peak_playing_count}" f"Updated interval stats: online={unique_online}, playing={unique_playing}, "
f"peak_online={stats.peak_online_count}, peak_playing={stats.peak_playing_count}"
) )
except Exception as e: except Exception as e:
@@ -395,21 +358,21 @@ class EnhancedIntervalStatsManager:
# 上一个区间开始时间是当前区间开始时间减去30分钟 # 上一个区间开始时间是当前区间开始时间减去30分钟
previous_start = current_start - timedelta(minutes=30) previous_start = current_start - timedelta(minutes=30)
previous_end = current_start # 上一个区间的结束时间就是当前区间的开始时间 previous_end = current_start # 上一个区间的结束时间就是当前区间的开始时间
interval_key = EnhancedIntervalStatsManager.generate_interval_key(previous_start) interval_key = EnhancedIntervalStatsManager.generate_interval_key(previous_start)
previous_interval = IntervalInfo( previous_interval = IntervalInfo(
start_time=previous_start, start_time=previous_start,
end_time=previous_end, end_time=previous_end,
interval_key=interval_key interval_key=interval_key,
) )
# 获取最终统计数据 # 获取最终统计数据
stats_data = await _redis_exec( stats_data = await _redis_exec(redis_sync.get, previous_interval.interval_key)
redis_sync.get, previous_interval.interval_key
)
if not stats_data: if not stats_data:
logger.warning(f"No interval stats found to finalize for {previous_interval.start_time.strftime('%H:%M')}") logger.warning(
f"No interval stats found to finalize for {previous_interval.start_time.strftime('%H:%M')}"
)
return None return None
stats = IntervalStats.from_dict(json.loads(stats_data)) stats = IntervalStats.from_dict(json.loads(stats_data))
@@ -418,13 +381,11 @@ class EnhancedIntervalStatsManager:
history_point = { history_point = {
"timestamp": previous_interval.start_time.isoformat(), "timestamp": previous_interval.start_time.isoformat(),
"online_count": stats.unique_online_users, "online_count": stats.unique_online_users,
"playing_count": stats.unique_playing_users "playing_count": stats.unique_playing_users,
} }
# 添加到历史记录 # 添加到历史记录
await _redis_exec( await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point))
redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point)
)
# 只保留48个数据点24小时每30分钟一个点 # 只保留48个数据点24小时每30分钟一个点
await _redis_exec(redis_sync.ltrim, REDIS_ONLINE_HISTORY_KEY, 0, 47) await _redis_exec(redis_sync.ltrim, REDIS_ONLINE_HISTORY_KEY, 0, 47)
# 设置过期时间为26小时确保有足够缓冲 # 设置过期时间为26小时确保有足够缓冲
@@ -452,12 +413,8 @@ class EnhancedIntervalStatsManager:
redis_sync = get_redis_message() redis_sync = get_redis_message()
try: try:
current_interval = ( current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
await EnhancedIntervalStatsManager.get_current_interval_info() stats_data = await _redis_exec(redis_sync.get, current_interval.interval_key)
)
stats_data = await _redis_exec(
redis_sync.get, current_interval.interval_key
)
if stats_data: if stats_data:
return IntervalStats.from_dict(json.loads(stats_data)) return IntervalStats.from_dict(json.loads(stats_data))
@@ -506,8 +463,6 @@ class EnhancedIntervalStatsManager:
# 便捷函数,用于替换现有的统计更新函数 # 便捷函数,用于替换现有的统计更新函数
async def update_user_activity_in_interval( async def update_user_activity_in_interval(user_id: int, is_playing: bool = False) -> None:
user_id: int, is_playing: bool = False
) -> None:
"""用户活动时更新区间统计(在登录、开始游玩等时调用)""" """用户活动时更新区间统计(在登录、开始游玩等时调用)"""
await EnhancedIntervalStatsManager.add_user_to_interval(user_id, is_playing) await EnhancedIntervalStatsManager.add_user_to_interval(user_id, is_playing)

Some files were not shown because too many files have changed in this diff Show More