From 598fcc8b38b4ffd651e38f0e97116d1f4865ffcd Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Fri, 22 Aug 2025 08:21:52 +0000 Subject: [PATCH] refactor(project): make pyright & ruff happy --- app/achievements/hush_hush.py | 58 +-- app/achievements/mods.py | 2 +- app/achievements/osu_combo.py | 6 +- app/achievements/skill.py | 4 +- app/auth.py | 46 +-- app/calculator.py | 41 +-- app/config.py | 37 +- app/database/achievement.py | 20 +- app/database/auth.py | 22 +- app/database/beatmap.py | 58 +-- app/database/beatmap_playcounts.py | 14 +- app/database/beatmapset.py | 49 +-- app/database/best_score.py | 10 +- app/database/chat.py | 58 +-- app/database/counts.py | 12 +- app/database/daily_challenge.py | 10 +- app/database/email_verification.py | 15 +- app/database/events.py | 10 +- app/database/failtime.py | 12 +- app/database/favourite_beatmapset.py | 2 +- app/database/lazer_user.py | 117 ++---- app/database/multiplayer_event.py | 4 +- app/database/notification.py | 8 +- app/database/password_reset.py | 11 +- app/database/playlist_attempts.py | 10 +- app/database/playlist_best_score.py | 10 +- app/database/playlists.py | 26 +- app/database/pp_best_score.py | 10 +- app/database/rank_history.py | 16 +- app/database/relationship.py | 6 +- app/database/room.py | 22 +- app/database/room_participated_user.py | 14 +- app/database/score.py | 117 ++---- app/database/score_token.py | 15 +- app/database/statistics.py | 14 +- app/database/team.py | 32 +- app/database/user_account_history.py | 6 +- app/database/user_login_log.py | 38 +- app/dependencies/database.py | 8 +- app/dependencies/fetcher.py | 4 +- app/dependencies/scheduler.py | 3 +- app/dependencies/user.py | 12 +- app/fetcher/_base.py | 23 +- app/fetcher/beatmap.py | 8 +- app/fetcher/beatmap_raw.py | 10 +- app/fetcher/beatmapset.py | 91 ++--- app/helpers/geoip_helper.py | 9 +- app/helpers/rate_limiter.py | 9 +- app/log.py | 22 +- app/models/achievement.py | 10 +- app/models/api_me.py | 3 +- app/models/beatmap.py | 14 +- app/models/extended_auth.py | 4 +- app/models/metadata_hub.py | 4 +- app/models/mods.py | 11 +- app/models/multiplayer_hub.py | 151 ++------ app/models/notification.py | 12 +- app/models/score.py | 4 +- app/models/signalr.py | 4 +- app/models/spectator_hub.py | 4 +- app/router/auth.py | 159 +++----- app/router/notification/__init__.py | 19 +- app/router/notification/banchobot.py | 85 ++--- app/router/notification/channel.py | 105 ++---- app/router/notification/message.py | 72 +--- app/router/notification/server.py | 116 ++---- app/router/password_reset_admin.py | 110 ++---- app/router/private/oauth.py | 30 +- app/router/private/relationship.py | 8 +- app/router/private/team.py | 77 +--- app/router/private/username.py | 5 +- app/router/v1/beatmap.py | 30 +- app/router/v1/replay.py | 21 +- app/router/v1/router.py | 4 +- app/router/v1/score.py | 24 +- app/router/v1/user.py | 36 +- app/router/v2/__init__.py | 2 +- app/router/v2/beatmap.py | 65 +--- app/router/v2/beatmapset.py | 36 +- app/router/v2/me.py | 2 +- app/router/v2/misc.py | 4 +- app/router/v2/ranking.py | 287 +++++++-------- app/router/v2/relationship.py | 26 +- app/router/v2/room.py | 49 +-- app/router/v2/score.py | 143 ++----- app/router/v2/session_verify.py | 136 +++---- app/router/v2/stats.py | 169 +++++---- app/router/v2/user.py | 97 ++--- app/scheduler/cache_scheduler.py | 4 +- app/scheduler/database_cleanup_scheduler.py | 22 +- app/service/beatmap_cache_service.py | 13 +- app/service/beatmap_download_service.py | 20 +- app/service/calculate_all_user_rank.py | 4 +- app/service/create_banchobot.py | 4 +- app/service/daily_challenge.py | 36 +- app/service/database_cleanup_service.py | 192 +++++----- app/service/email_queue.py | 193 +++++----- app/service/email_service.py | 70 ++-- app/service/email_verification_service.py | 239 ++++++------ app/service/enhanced_interval_stats.py | 125 ++----- app/service/load_achievements.py | 8 +- app/service/login_log_service.py | 18 +- app/service/message_queue.py | 71 +--- app/service/message_queue_processor.py | 84 ++--- app/service/online_status_maintenance.py | 22 +- app/service/online_status_manager.py | 55 +-- app/service/optimized_message.py | 9 +- app/service/password_reset_service.py | 184 +++++---- app/service/ranking_cache_service.py | 38 +- app/service/recalculate.py | 48 +-- app/service/redis_message_system.py | 126 ++----- app/service/room.py | 16 +- app/service/session_manager.py | 35 +- app/service/stats_cleanup.py | 25 +- app/service/stats_scheduler.py | 22 +- app/service/subscribers/base.py | 9 +- app/service/subscribers/score_processed.py | 7 +- app/service/user_cache_service.py | 84 +---- app/signalr/hub/hub.py | 43 +-- app/signalr/hub/metadata.py | 66 ++-- app/signalr/hub/multiplayer.py | 348 +++++------------- app/signalr/hub/spectator.py | 151 +++----- app/signalr/packet.py | 74 +--- app/signalr/router.py | 8 +- app/signalr/store.py | 4 +- app/signalr/utils.py | 4 +- app/utils.py | 105 ++++-- main.py | 21 +- migrations/env.py | 4 +- ...6348cdfd2_add_email_verification_tables.py | 3 +- .../versions/198227d190b8_user_add_events.py | 4 +- .../19cdc9ce4dcb_gamemode_add_osurx_osupp.py | 28 +- ...d04d3f4dc_fix_user_login_log_table_name.py | 52 +-- .../3eef4794ded1_add_user_login_log_table.py | 48 +-- ...6c43d8601_notification_add_notification.py | 20 +- .../59c9a0827de0_beatmap_add_indexes.py | 40 +- ..._increase_the_length_limit_of_the_user_.py | 23 +- ...5e7dc8d5905_team_add_team_request_table.py | 8 +- .../7e9d5e012d37_auth_add_v1_keys_table.py | 4 +- ...d764a5_statistics_remove_level_progress.py | 4 +- ...1a2188e691_score_add_rx_for_taiko_catch.py | 44 +-- ...laylist_best_scores_remove_foreign_key_.py | 4 +- .../9f6b27e8ea51_add_table_banned_beatmaps.py | 4 +- ...a8669ba11e96_auth_support_custom_client.py | 12 +- ...13f905_count_add_replays_watched_counts.py | 12 +- .../b6a304d96a2d_user_support_rank.py | 28 +- ...f0a5674_beatmap_make_max_combo_nullable.py | 8 +- .../d103d442dc24_add_password_reset_table.py | 344 ++++++++++------- .../versions/dd33d89aa2c2_chat_add_chat.py | 20 +- .../df9f725a077c_room_add_channel_id.py | 16 +- ...49e18ca_achievement_remove_primary_key_.py | 24 +- ...onvert_event_event_payload_from_str_to_.py | 8 +- migrations/versions/fdb3822a30ba_init.py | 192 +++------- pyproject.toml | 5 +- test_spectator_buffer.py | 84 ----- tools/add_daily_challenge.py | 8 +- uv.lock | 15 + 157 files changed, 2382 insertions(+), 4590 deletions(-) delete mode 100644 test_spectator_buffer.py diff --git a/app/achievements/hush_hush.py b/app/achievements/hush_hush.py index df97c93..a887693 100644 --- a/app/achievements/hush_hush.py +++ b/app/achievements/hush_hush.py @@ -65,9 +65,7 @@ async def to_the_core( # using either of the mods specified: DT, NC if not score.passed: return False - if ( - "Nightcore" not in beatmap.beatmapset.title - ) and "Nightcore" not in beatmap.beatmapset.artist: + if ("Nightcore" not in beatmap.beatmapset.title) and "Nightcore" not in beatmap.beatmapset.artist: return False mods_ = mod_to_save(score.mods) if "DT" not in mods_ or "NC" not in mods_: @@ -118,9 +116,7 @@ async def reckless_adandon( fetcher = await get_fetcher() redis = get_redis() mods_ = score.mods.copy() - attribute = await calculate_beatmap_attributes( - beatmap.id, score.gamemode, mods_, redis, fetcher - ) + attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher) if attribute.star_rating < 3: return False return True @@ -186,9 +182,7 @@ async def slow_and_steady( fetcher = await get_fetcher() redis = get_redis() mods_ = score.mods.copy() - attribute = await calculate_beatmap_attributes( - beatmap.id, score.gamemode, mods_, redis, fetcher - ) + attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher) return attribute.star_rating >= 3 @@ -218,9 +212,7 @@ async def sognare( mods_ = mod_to_save(score.mods) if "HT" not in mods_: return False - return ( - beatmap.beatmapset.artist == "LeaF" and beatmap.beatmapset.title == "Evanescent" - ) + return beatmap.beatmapset.artist == "LeaF" and beatmap.beatmapset.title == "Evanescent" async def realtor_extraordinaire( @@ -234,10 +226,7 @@ async def realtor_extraordinaire( mods_ = mod_to_save(score.mods) if not ("DT" in mods_ or "NC" in mods_) or "HR" not in mods_: return False - return ( - beatmap.beatmapset.artist == "cYsmix" - and beatmap.beatmapset.title == "House With Legs" - ) + return beatmap.beatmapset.artist == "cYsmix" and beatmap.beatmapset.title == "House With Legs" async def impeccable( @@ -255,9 +244,7 @@ async def impeccable( fetcher = await get_fetcher() redis = get_redis() mods_ = score.mods.copy() - attribute = await calculate_beatmap_attributes( - beatmap.id, score.gamemode, mods_, redis, fetcher - ) + attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher) return attribute.star_rating >= 4 @@ -274,18 +261,14 @@ async def aeon( mods_ = mod_to_save(score.mods) if "FL" not in mods_ or "HD" not in mods_ or "HT" not in mods_: return False - if not beatmap.beatmapset.ranked_date or beatmap.beatmapset.ranked_date > datetime( - 2012, 1, 1 - ): + if not beatmap.beatmapset.ranked_date or beatmap.beatmapset.ranked_date > datetime(2012, 1, 1): return False if beatmap.total_length < 180: return False fetcher = await get_fetcher() redis = get_redis() mods_ = score.mods.copy() - attribute = await calculate_beatmap_attributes( - beatmap.id, score.gamemode, mods_, redis, fetcher - ) + attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher) return attribute.star_rating >= 4 @@ -297,10 +280,7 @@ async def quick_maths( # Get exactly 34 misses on any difficulty of Function Phantom - Variable. if score.nmiss != 34: return False - return ( - beatmap.beatmapset.artist == "Function Phantom" - and beatmap.beatmapset.title == "Variable" - ) + return beatmap.beatmapset.artist == "Function Phantom" and beatmap.beatmapset.title == "Variable" async def kaleidoscope( @@ -328,8 +308,7 @@ async def valediction( return ( score.passed and beatmap.beatmapset.artist == "a_hisa" - and beatmap.beatmapset.title - == "Alexithymia | Lupinus | Tokei no Heya to Seishin Sekai" + and beatmap.beatmapset.title == "Alexithymia | Lupinus | Tokei no Heya to Seishin Sekai" and score.accuracy >= 0.9 ) @@ -342,9 +321,7 @@ async def right_on_time( # Submit a score on Kola Kid - timer on the first minute of any hour if not score.passed: return False - if not ( - beatmap.beatmapset.artist == "Kola Kid" and beatmap.beatmapset.title == "timer" - ): + if not (beatmap.beatmapset.artist == "Kola Kid" and beatmap.beatmapset.title == "timer"): return False return score.ended_at.minute == 0 @@ -361,9 +338,7 @@ async def not_again( return False if score.accuracy < 0.99: return False - return ( - beatmap.beatmapset.artist == "ARForest" and beatmap.beatmapset.title == "Regret" - ) + return beatmap.beatmapset.artist == "ARForest" and beatmap.beatmapset.title == "Regret" async def deliberation( @@ -377,18 +352,13 @@ async def deliberation( mods_ = mod_to_save(score.mods) if "HT" not in mods_: return False - if ( - not beatmap.beatmap_status.has_pp() - and beatmap.beatmap_status != BeatmapRankStatus.LOVED - ): + if not beatmap.beatmap_status.has_pp() and beatmap.beatmap_status != BeatmapRankStatus.LOVED: return False fetcher = await get_fetcher() redis = get_redis() mods_copy = score.mods.copy() - attribute = await calculate_beatmap_attributes( - beatmap.id, score.gamemode, mods_copy, redis, fetcher - ) + attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_copy, redis, fetcher) return attribute.star_rating >= 6 diff --git a/app/achievements/mods.py b/app/achievements/mods.py index a0d2b25..d157a15 100644 --- a/app/achievements/mods.py +++ b/app/achievements/mods.py @@ -72,7 +72,7 @@ MEDALS: Medals = { Achievement( id=93, name="Sweet Rave Party", - desc="Founded in the fine tradition of changing things that were just fine as they were.", # noqa: E501 + desc="Founded in the fine tradition of changing things that were just fine as they were.", assets_id="all-intro-nightcore", ): partial(process_mod, "NC"), Achievement( diff --git a/app/achievements/osu_combo.py b/app/achievements/osu_combo.py index cb67276..25abbbd 100644 --- a/app/achievements/osu_combo.py +++ b/app/achievements/osu_combo.py @@ -16,11 +16,7 @@ async def process_combo( score: Score, beatmap: Beatmap, ) -> bool: - if ( - not score.passed - or not beatmap.beatmap_status.has_pp() - or score.gamemode != GameMode.OSU - ): + if not score.passed or not beatmap.beatmap_status.has_pp() or score.gamemode != GameMode.OSU: return False if combo < 1: return False diff --git a/app/achievements/skill.py b/app/achievements/skill.py index bc1f8e6..66123cc 100644 --- a/app/achievements/skill.py +++ b/app/achievements/skill.py @@ -44,9 +44,7 @@ async def process_skill( redis = get_redis() mods_ = score.mods.copy() mods_.sort(key=lambda x: x["acronym"]) - attribute = await calculate_beatmap_attributes( - beatmap.id, score.gamemode, mods_, redis, fetcher - ) + attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher) if attribute.star_rating < star or attribute.star_rating >= star + 1: return False if type == "fc" and not score.is_perfect_combo: diff --git a/app/auth.py b/app/auth.py index ef1b247..7f7c884 100644 --- a/app/auth.py +++ b/app/auth.py @@ -43,9 +43,7 @@ def validate_username(username: str) -> list[str]: # 检查用户名格式(只允许字母、数字、下划线、连字符) if not re.match(r"^[a-zA-Z0-9_-]+$", username): - errors.append( - "Username can only contain letters, numbers, underscores, and hyphens" - ) + errors.append("Username can only contain letters, numbers, underscores, and hyphens") # 检查是否以数字开头 if username[0].isdigit(): @@ -104,9 +102,7 @@ def get_password_hash(password: str) -> str: return pw_bcrypt.decode() -async def authenticate_user_legacy( - db: AsyncSession, name: str, password: str -) -> User | None: +async def authenticate_user_legacy(db: AsyncSession, name: str, password: str) -> User | None: """ 验证用户身份 - 使用类似 from_login 的逻辑 """ @@ -145,9 +141,7 @@ async def authenticate_user_legacy( return None -async def authenticate_user( - db: AsyncSession, username: str, password: str -) -> User | None: +async def authenticate_user(db: AsyncSession, username: str, password: str) -> User | None: """验证用户身份""" return await authenticate_user_legacy(db, username, password) @@ -158,14 +152,10 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None) -> s if expires_delta: expire = datetime.now(UTC) + expires_delta else: - expire = datetime.now(UTC) + timedelta( - minutes=settings.access_token_expire_minutes - ) + expire = datetime.now(UTC) + timedelta(minutes=settings.access_token_expire_minutes) to_encode.update({"exp": expire, "random": secrets.token_hex(16)}) - encoded_jwt = jwt.encode( - to_encode, settings.secret_key, algorithm=settings.algorithm - ) + encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm) return encoded_jwt @@ -178,20 +168,20 @@ def generate_refresh_token() -> str: async def invalidate_user_tokens(db: AsyncSession, user_id: int) -> int: """使指定用户的所有令牌失效 - + 返回删除的令牌数量 """ # 使用 select 先获取所有令牌 stmt = select(OAuthToken).where(OAuthToken.user_id == user_id) result = await db.exec(stmt) tokens = result.all() - + # 逐个删除令牌 count = 0 for token in tokens: await db.delete(token) count += 1 - + # 提交更改 await db.commit() return count @@ -200,9 +190,7 @@ async def invalidate_user_tokens(db: AsyncSession, user_id: int) -> int: def verify_token(token: str) -> dict | None: """验证访问令牌""" try: - payload = jwt.decode( - token, settings.secret_key, algorithms=[settings.algorithm] - ) + payload = jwt.decode(token, settings.secret_key, algorithms=[settings.algorithm]) return payload except JWTError: return None @@ -221,17 +209,13 @@ async def store_token( expires_at = datetime.utcnow() + timedelta(seconds=expires_in) # 删除用户的旧令牌 - statement = select(OAuthToken).where( - OAuthToken.user_id == user_id, OAuthToken.client_id == client_id - ) + statement = select(OAuthToken).where(OAuthToken.user_id == user_id, OAuthToken.client_id == client_id) old_tokens = (await db.exec(statement)).all() for token in old_tokens: await db.delete(token) # 检查是否有重复的 access_token - duplicate_token = ( - await db.exec(select(OAuthToken).where(OAuthToken.access_token == access_token)) - ).first() + duplicate_token = (await db.exec(select(OAuthToken).where(OAuthToken.access_token == access_token))).first() if duplicate_token: await db.delete(duplicate_token) @@ -250,9 +234,7 @@ async def store_token( return token_record -async def get_token_by_access_token( - db: AsyncSession, access_token: str -) -> OAuthToken | None: +async def get_token_by_access_token(db: AsyncSession, access_token: str) -> OAuthToken | None: """根据访问令牌获取令牌记录""" statement = select(OAuthToken).where( OAuthToken.access_token == access_token, @@ -261,9 +243,7 @@ async def get_token_by_access_token( return (await db.exec(statement)).first() -async def get_token_by_refresh_token( - db: AsyncSession, refresh_token: str -) -> OAuthToken | None: +async def get_token_by_refresh_token(db: AsyncSession, refresh_token: str) -> OAuthToken | None: """根据刷新令牌获取令牌记录""" statement = select(OAuthToken).where( OAuthToken.refresh_token == refresh_token, diff --git a/app/calculator.py b/app/calculator.py index 344bafe..cc76c0e 100644 --- a/app/calculator.py +++ b/app/calculator.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from copy import deepcopy from enum import Enum import math @@ -67,11 +68,7 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f if settings.suspicious_score_check: beatmap_banned = ( - await session.exec( - select(exists()).where( - col(BannedBeatmaps.beatmap_id) == score.beatmap_id - ) - ) + await session.exec(select(exists()).where(col(BannedBeatmaps.beatmap_id) == score.beatmap_id)) ).first() if beatmap_banned: return 0 @@ -82,12 +79,9 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f logger.warning(f"Beatmap {score.beatmap_id} is suspicious, banned") return 0 except Exception: - logger.exception( - f"Error checking if beatmap {score.beatmap_id} is suspicious" - ) + logger.exception(f"Error checking if beatmap {score.beatmap_id} is suspicious") # 使用线程池执行计算密集型操作以避免阻塞事件循环 - import asyncio loop = asyncio.get_event_loop() @@ -118,9 +112,7 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f pp = attrs.pp # mrekk bp1: 2048pp; ppy-sb top1 rxbp1: 2198pp - if settings.suspicious_score_check and ( - (attrs.difficulty.stars > 25 and score.accuracy < 0.8) or pp > 2300 - ): + if settings.suspicious_score_check and ((attrs.difficulty.stars > 25 and score.accuracy < 0.8) or pp > 2300): logger.warning( f"User {score.user_id} played {score.beatmap_id} " f"(star={attrs.difficulty.stars}) with {pp=} " @@ -131,9 +123,7 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f return pp -async def pre_fetch_and_calculate_pp( - score: "Score", beatmap_id: int, session: AsyncSession, redis, fetcher -) -> float: +async def pre_fetch_and_calculate_pp(score: "Score", beatmap_id: int, session: AsyncSession, redis, fetcher) -> float: """ 优化版PP计算:预先获取beatmap文件并使用缓存 """ @@ -144,9 +134,7 @@ async def pre_fetch_and_calculate_pp( # 快速检查是否被封禁 if settings.suspicious_score_check: beatmap_banned = ( - await session.exec( - select(exists()).where(col(BannedBeatmaps.beatmap_id) == beatmap_id) - ) + await session.exec(select(exists()).where(col(BannedBeatmaps.beatmap_id) == beatmap_id)) ).first() if beatmap_banned: return 0 @@ -202,9 +190,7 @@ async def batch_calculate_pp( banned_beatmaps = set() if settings.suspicious_score_check: banned_results = await session.exec( - select(BannedBeatmaps.beatmap_id).where( - col(BannedBeatmaps.beatmap_id).in_(unique_beatmap_ids) - ) + select(BannedBeatmaps.beatmap_id).where(col(BannedBeatmaps.beatmap_id).in_(unique_beatmap_ids)) ) banned_beatmaps = set(banned_results.all()) @@ -380,9 +366,7 @@ def calculate_score_to_level(total_score: int) -> float: level = 0.0 while remaining_score > 0: - next_level_requirement = to_next_level[ - min(len(to_next_level) - 1, round(level)) - ] + next_level_requirement = to_next_level[min(len(to_next_level) - 1, round(level))] level += min(1, remaining_score / next_level_requirement) remaining_score -= next_level_requirement @@ -417,9 +401,7 @@ class Threshold(int, Enum): NOTE_POSX_THRESHOLD = 512 # x: [-512,512] NOTE_POSY_THRESHOLD = 384 # y: [-384,384] - POS_ERROR_THRESHOLD = ( - 1280 * 50 - ) # 超过这么多个物件(包括滑条控制点)的位置有问题就毙掉 + POS_ERROR_THRESHOLD = 1280 * 50 # 超过这么多个物件(包括滑条控制点)的位置有问题就毙掉 SLIDER_REPEAT_THRESHOLD = 5000 @@ -469,10 +451,7 @@ def is_2b(hit_objects: list[HitObject]) -> bool: def is_suspicious_beatmap(content: str) -> bool: osufile = OsuFile(content=content.encode("utf-8")).parse_file() - if ( - osufile.hit_objects[-1].start_time - osufile.hit_objects[0].start_time - > 24 * 60 * 60 * 1000 - ): + if osufile.hit_objects[-1].start_time - osufile.hit_objects[0].start_time > 24 * 60 * 60 * 1000: return True if osufile.mode == int(GameMode.TAIKO): if len(osufile.hit_objects) > Threshold.TAIKO_THRESHOLD: diff --git a/app/config.py b/app/config.py index 42140b7..ed29100 100644 --- a/app/config.py +++ b/app/config.py @@ -124,14 +124,10 @@ class Settings(BaseSettings): smtp_password: str = "" from_email: str = "noreply@example.com" from_name: str = "osu! server" - + # 邮件验证功能开关 - enable_email_verification: bool = Field( - default=True, description="是否启用邮件验证功能" - ) - enable_email_sending: bool = Field( - default=False, description="是否真实发送邮件(False时仅模拟发送)" - ) + enable_email_verification: bool = Field(default=True, description="是否启用邮件验证功能") + enable_email_sending: bool = Field(default=False, description="是否真实发送邮件(False时仅模拟发送)") # Sentry 配置 sentry_dsn: HttpUrl | None = None @@ -143,12 +139,8 @@ class Settings(BaseSettings): geoip_update_hour: int = 2 # 每周更新的小时数(0-23) # 游戏设置 - enable_rx: bool = Field( - default=False, validation_alias=AliasChoices("enable_rx", "enable_osu_rx") - ) - enable_ap: bool = Field( - default=False, validation_alias=AliasChoices("enable_ap", "enable_osu_ap") - ) + enable_rx: bool = Field(default=False, validation_alias=AliasChoices("enable_rx", "enable_osu_rx")) + enable_ap: bool = Field(default=False, validation_alias=AliasChoices("enable_ap", "enable_osu_ap")) enable_all_mods_pp: bool = False enable_supporter_for_all_users: bool = False enable_all_beatmap_leaderboard: bool = False @@ -189,9 +181,7 @@ class Settings(BaseSettings): # 存储设置 storage_service: StorageServiceType = StorageServiceType.LOCAL - storage_settings: ( - LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings - ) = LocalStorageSettings() + storage_settings: LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings = LocalStorageSettings() @field_validator("fetcher_scopes", mode="before") def validate_fetcher_scopes(cls, v: Any) -> list[str]: @@ -207,22 +197,13 @@ class Settings(BaseSettings): ) -> LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings: if info.data.get("storage_service") == StorageServiceType.CLOUDFLARE_R2: if not isinstance(v, CloudflareR2Settings): - raise ValueError( - "When storage_service is 'r2', " - "storage_settings must be CloudflareR2Settings" - ) + raise ValueError("When storage_service is 'r2', storage_settings must be CloudflareR2Settings") elif info.data.get("storage_service") == StorageServiceType.LOCAL: if not isinstance(v, LocalStorageSettings): - raise ValueError( - "When storage_service is 'local', " - "storage_settings must be LocalStorageSettings" - ) + raise ValueError("When storage_service is 'local', storage_settings must be LocalStorageSettings") elif info.data.get("storage_service") == StorageServiceType.AWS_S3: if not isinstance(v, AWSS3StorageSettings): - raise ValueError( - "When storage_service is 's3', " - "storage_settings must be AWSS3StorageSettings" - ) + raise ValueError("When storage_service is 's3', storage_settings must be AWSS3StorageSettings") return v diff --git a/app/database/achievement.py b/app/database/achievement.py index 2782c7c..d26f29b 100644 --- a/app/database/achievement.py +++ b/app/database/achievement.py @@ -28,18 +28,14 @@ if TYPE_CHECKING: class UserAchievementBase(SQLModel, UTCBaseModel): achievement_id: int - achieved_at: datetime = Field( - default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True)) - ) + achieved_at: datetime = Field(default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True))) class UserAchievement(UserAchievementBase, table=True): - __tablename__ = "lazer_user_achievements" # pyright: ignore[reportAssignmentType] + __tablename__: str = "lazer_user_achievements" id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("lazer_users.id")), exclude=True - ) + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")), exclude=True) user: "User" = Relationship(back_populates="achievement") @@ -56,11 +52,7 @@ async def process_achievements(session: AsyncSession, redis: Redis, score_id: in if not score: return achieved = ( - await session.exec( - select(UserAchievement.achievement_id).where( - UserAchievement.user_id == score.user_id - ) - ) + await session.exec(select(UserAchievement.achievement_id).where(UserAchievement.user_id == score.user_id)) ).all() not_achieved = {k: v for k, v in MEDALS.items() if k.id not in achieved} result: list[Achievement] = [] @@ -78,9 +70,7 @@ async def process_achievements(session: AsyncSession, redis: Redis, score_id: in ) await redis.publish( "chat:notification", - UserAchievementUnlock.init( - r, score.user_id, score.gamemode - ).model_dump_json(), + UserAchievementUnlock.init(r, score.user_id, score.gamemode).model_dump_json(), ) event = Event( created_at=now, diff --git a/app/database/auth.py b/app/database/auth.py index 632be23..74bee24 100644 --- a/app/database/auth.py +++ b/app/database/auth.py @@ -20,42 +20,34 @@ if TYPE_CHECKING: class OAuthToken(UTCBaseModel, SQLModel, table=True): - __tablename__ = "oauth_tokens" # pyright: ignore[reportAssignmentType] + __tablename__: str = "oauth_tokens" id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) - ) + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) client_id: int = Field(index=True) access_token: str = Field(max_length=500, unique=True) refresh_token: str = Field(max_length=500, unique=True) token_type: str = Field(default="Bearer", max_length=20) scope: str = Field(default="*", max_length=100) expires_at: datetime = Field(sa_column=Column(DateTime)) - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) + created_at: datetime = Field(default_factory=datetime.utcnow, sa_column=Column(DateTime)) user: "User" = Relationship() class OAuthClient(SQLModel, table=True): - __tablename__ = "oauth_clients" # pyright: ignore[reportAssignmentType] + __tablename__: str = "oauth_clients" name: str = Field(max_length=100, index=True) description: str = Field(sa_column=Column(Text), default="") client_id: int | None = Field(default=None, primary_key=True, index=True) client_secret: str = Field(default_factory=secrets.token_hex, index=True) redirect_uris: list[str] = Field(default_factory=list, sa_column=Column(JSON)) - owner_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) - ) + owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) class V1APIKeys(SQLModel, table=True): - __tablename__ = "v1_api_keys" # pyright: ignore[reportAssignmentType] + __tablename__: str = "v1_api_keys" id: int | None = Field(default=None, primary_key=True) name: str = Field(max_length=100, index=True) key: str = Field(default_factory=secrets.token_hex, index=True) - owner_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) - ) + owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) diff --git a/app/database/beatmap.py b/app/database/beatmap.py index 278d722..1f83aa2 100644 --- a/app/database/beatmap.py +++ b/app/database/beatmap.py @@ -60,17 +60,13 @@ class BeatmapBase(SQLModel): class Beatmap(BeatmapBase, table=True): - __tablename__ = "beatmaps" # pyright: ignore[reportAssignmentType] + __tablename__: str = "beatmaps" id: int = Field(primary_key=True, index=True) beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True) beatmap_status: BeatmapRankStatus = Field(index=True) # optional - beatmapset: Beatmapset = Relationship( - back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"} - ) - failtimes: FailTime | None = Relationship( - back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"} - ) + beatmapset: Beatmapset = Relationship(back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"}) + failtimes: FailTime | None = Relationship(back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"}) @classmethod async def from_resp(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap": @@ -84,21 +80,15 @@ class Beatmap(BeatmapBase, table=True): "beatmap_status": BeatmapRankStatus(resp.ranked), } ) - if not ( - await session.exec(select(exists()).where(Beatmap.id == resp.id)) - ).first(): + if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first(): session.add(beatmap) await session.commit() - beatmap = ( - await session.exec(select(Beatmap).where(Beatmap.id == resp.id)) - ).first() + beatmap = (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).first() assert beatmap is not None, "Beatmap should not be None after commit" return beatmap @classmethod - async def from_resp_batch( - cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0 - ) -> list["Beatmap"]: + async def from_resp_batch(cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0) -> list["Beatmap"]: beatmaps = [] for resp in inp: if resp.id == from_: @@ -113,9 +103,7 @@ class Beatmap(BeatmapBase, table=True): "beatmap_status": BeatmapRankStatus(resp.ranked), } ) - if not ( - await session.exec(select(exists()).where(Beatmap.id == resp.id)) - ).first(): + if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first(): session.add(beatmap) beatmaps.append(beatmap) await session.commit() @@ -130,17 +118,11 @@ class Beatmap(BeatmapBase, table=True): md5: str | None = None, ) -> "Beatmap": beatmap = ( - await session.exec( - select(Beatmap).where( - Beatmap.id == bid if bid is not None else Beatmap.checksum == md5 - ) - ) + await session.exec(select(Beatmap).where(Beatmap.id == bid if bid is not None else Beatmap.checksum == md5)) ).first() if not beatmap: resp = await fetcher.get_beatmap(bid, md5) - r = await session.exec( - select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id) - ) + r = await session.exec(select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id)) if not r.first(): set_resp = await fetcher.get_beatmapset(resp.beatmapset_id) await Beatmapset.from_resp(session, set_resp, from_=resp.id) @@ -178,10 +160,7 @@ class BeatmapResp(BeatmapBase): if query_mode is not None and beatmap.mode != query_mode: beatmap_["convert"] = True beatmap_["is_scoreable"] = beatmap_status.has_leaderboard() - if ( - settings.enable_all_beatmap_leaderboard - and not beatmap_status.has_leaderboard() - ): + if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard(): beatmap_["ranked"] = BeatmapRankStatus.APPROVED.value beatmap_["status"] = BeatmapRankStatus.APPROVED.name.lower() else: @@ -189,9 +168,7 @@ class BeatmapResp(BeatmapBase): beatmap_["ranked"] = beatmap_status.value beatmap_["mode_int"] = int(beatmap.mode) if not from_set: - beatmap_["beatmapset"] = await BeatmapsetResp.from_db( - beatmap.beatmapset, session=session, user=user - ) + beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset, session=session, user=user) if beatmap.failtimes is not None: beatmap_["failtimes"] = FailTimeResp.from_db(beatmap.failtimes) else: @@ -218,7 +195,7 @@ class BeatmapResp(BeatmapBase): class BannedBeatmaps(SQLModel, table=True): - __tablename__ = "banned_beatmaps" # pyright: ignore[reportAssignmentType] + __tablename__: str = "banned_beatmaps" id: int | None = Field(primary_key=True, index=True, default=None) beatmap_id: int = Field(index=True) @@ -230,15 +207,10 @@ async def calculate_beatmap_attributes( redis: Redis, fetcher: "Fetcher", ): - key = ( - f"beatmap:{beatmap_id}:{ruleset}:" - f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes" - ) + key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes" if await redis.exists(key): - return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType] + return BeatmapAttributes.model_validate_json(await redis.get(key)) resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id) - attr = await asyncio.get_event_loop().run_in_executor( - None, calculate_beatmap_attribute, resp, ruleset, mods_ - ) + attr = await asyncio.get_event_loop().run_in_executor(None, calculate_beatmap_attribute, resp, ruleset, mods_) await redis.set(key, attr.model_dump_json()) return attr diff --git a/app/database/beatmap_playcounts.py b/app/database/beatmap_playcounts.py index a6eed31..0917d5a 100644 --- a/app/database/beatmap_playcounts.py +++ b/app/database/beatmap_playcounts.py @@ -23,15 +23,13 @@ if TYPE_CHECKING: class BeatmapPlaycounts(AsyncAttrs, SQLModel, table=True): - __tablename__ = "beatmap_playcounts" # pyright: ignore[reportAssignmentType] + __tablename__: str = "beatmap_playcounts" id: int | None = Field( default=None, sa_column=Column(BigInteger, primary_key=True, autoincrement=True), ) - user_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) - ) + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) beatmap_id: int = Field(foreign_key="beatmaps.id", index=True) playcount: int = Field(default=0) @@ -59,9 +57,7 @@ class BeatmapPlaycountsResp(BaseModel): ) -async def process_beatmap_playcount( - session: AsyncSession, user_id: int, beatmap_id: int -) -> None: +async def process_beatmap_playcount(session: AsyncSession, user_id: int, beatmap_id: int) -> None: existing_playcount = ( await session.exec( select(BeatmapPlaycounts).where( @@ -89,7 +85,5 @@ async def process_beatmap_playcount( } session.add(playcount_event) else: - new_playcount = BeatmapPlaycounts( - user_id=user_id, beatmap_id=beatmap_id, playcount=1 - ) + new_playcount = BeatmapPlaycounts(user_id=user_id, beatmap_id=beatmap_id, playcount=1) session.add(new_playcount) diff --git a/app/database/beatmapset.py b/app/database/beatmapset.py index 48b0408..32549f8 100644 --- a/app/database/beatmapset.py +++ b/app/database/beatmapset.py @@ -86,9 +86,7 @@ class BeatmapsetBase(SQLModel): # optional # converts: list[Beatmap] = Relationship(back_populates="beatmapset") - current_nominations: list[BeatmapNomination] | None = Field( - None, sa_column=Column(JSON) - ) + current_nominations: list[BeatmapNomination] | None = Field(None, sa_column=Column(JSON)) description: BeatmapDescription | None = Field(default=None, sa_column=Column(JSON)) # TODO: discussions: list[BeatmapsetDiscussion] = None # TODO: current_user_attributes: Optional[CurrentUserAttributes] = None @@ -105,22 +103,18 @@ class BeatmapsetBase(SQLModel): can_be_hyped: bool = Field(default=False, sa_column=Column(Boolean)) discussion_locked: bool = Field(default=False, sa_column=Column(Boolean)) last_updated: datetime = Field(sa_column=Column(DateTime, index=True)) - ranked_date: datetime | None = Field( - default=None, sa_column=Column(DateTime, index=True) - ) + ranked_date: datetime | None = Field(default=None, sa_column=Column(DateTime, index=True)) storyboard: bool = Field(default=False, sa_column=Column(Boolean, index=True)) submitted_date: datetime = Field(sa_column=Column(DateTime, index=True)) tags: str = Field(default="", sa_column=Column(Text)) class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True): - __tablename__ = "beatmapsets" # pyright: ignore[reportAssignmentType] + __tablename__: str = "beatmapsets" - id: int | None = Field(default=None, primary_key=True, index=True) + id: int = Field(default=None, primary_key=True, index=True) # Beatmapset - beatmap_status: BeatmapRankStatus = Field( - default=BeatmapRankStatus.GRAVEYARD, index=True - ) + beatmap_status: BeatmapRankStatus = Field(default=BeatmapRankStatus.GRAVEYARD, index=True) # optional beatmaps: list["Beatmap"] = Relationship(back_populates="beatmapset") @@ -137,9 +131,7 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True): favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset") @classmethod - async def from_resp( - cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0 - ) -> "Beatmapset": + async def from_resp(cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0) -> "Beatmapset": from .beatmap import Beatmap d = resp.model_dump() @@ -167,18 +159,14 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True): "download_disabled": resp.availability.download_disabled or False, } ) - if not ( - await session.exec(select(exists()).where(Beatmapset.id == resp.id)) - ).first(): + if not (await session.exec(select(exists()).where(Beatmapset.id == resp.id))).first(): session.add(beatmapset) await session.commit() await Beatmap.from_resp_batch(session, resp.beatmaps, from_=from_) return beatmapset @classmethod - async def get_or_fetch( - cls, session: AsyncSession, fetcher: "Fetcher", sid: int - ) -> "Beatmapset": + async def get_or_fetch(cls, session: AsyncSession, fetcher: "Fetcher", sid: int) -> "Beatmapset": beatmapset = await session.get(Beatmapset, sid) if not beatmapset: resp = await fetcher.get_beatmapset(sid) @@ -227,13 +215,9 @@ class BeatmapsetResp(BeatmapsetBase): @model_validator(mode="after") def fix_genre_language(self) -> Self: if self.genre is None: - self.genre = BeatmapTranslationText( - name=Genre(self.genre_id).name, id=self.genre_id - ) + self.genre = BeatmapTranslationText(name=Genre(self.genre_id).name, id=self.genre_id) if self.language is None: - self.language = BeatmapTranslationText( - name=Language(self.language_id).name, id=self.language_id - ) + self.language = BeatmapTranslationText(name=Language(self.language_id).name, id=self.language_id) return self @classmethod @@ -252,9 +236,7 @@ class BeatmapsetResp(BeatmapsetBase): await BeatmapResp.from_db(beatmap, from_set=True, session=session) for beatmap in await beatmapset.awaitable_attrs.beatmaps ], - "hype": BeatmapHype( - current=beatmapset.hype_current, required=beatmapset.hype_required - ), + "hype": BeatmapHype(current=beatmapset.hype_current, required=beatmapset.hype_required), "availability": BeatmapAvailability( more_information=beatmapset.availability_info, download_disabled=beatmapset.download_disabled, @@ -282,10 +264,7 @@ class BeatmapsetResp(BeatmapsetBase): update["ratings"] = [] beatmap_status = beatmapset.beatmap_status - if ( - settings.enable_all_beatmap_leaderboard - and not beatmap_status.has_leaderboard() - ): + if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard(): update["status"] = BeatmapRankStatus.APPROVED.name.lower() update["ranked"] = BeatmapRankStatus.APPROVED.value else: @@ -295,9 +274,7 @@ class BeatmapsetResp(BeatmapsetBase): if session and user: existing_favourite = ( await session.exec( - select(FavouriteBeatmapset).where( - FavouriteBeatmapset.beatmapset_id == beatmapset.id - ) + select(FavouriteBeatmapset).where(FavouriteBeatmapset.beatmapset_id == beatmapset.id) ) ).first() update["has_favourited"] = existing_favourite is not None diff --git a/app/database/best_score.py b/app/database/best_score.py index 8688d5b..8179453 100644 --- a/app/database/best_score.py +++ b/app/database/best_score.py @@ -20,13 +20,9 @@ if TYPE_CHECKING: class BestScore(SQLModel, table=True): - __tablename__ = "total_score_best_scores" # pyright: ignore[reportAssignmentType] - user_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) - ) - score_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True) - ) + __tablename__: str = "total_score_best_scores" + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) + score_id: int = Field(sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)) beatmap_id: int = Field(foreign_key="beatmaps.id", index=True) gamemode: GameMode = Field(index=True) total_score: int = Field(default=0, sa_column=Column(BigInteger)) diff --git a/app/database/chat.py b/app/database/chat.py index 8ac1f82..a1a66d8 100644 --- a/app/database/chat.py +++ b/app/database/chat.py @@ -51,30 +51,22 @@ class ChatChannelBase(SQLModel): class ChatChannel(ChatChannelBase, table=True): - __tablename__ = "chat_channels" # pyright: ignore[reportAssignmentType] - channel_id: int | None = Field(primary_key=True, index=True, default=None) + __tablename__: str = "chat_channels" + channel_id: int = Field(primary_key=True, index=True, default=None) @classmethod - async def get( - cls, channel: str | int, session: AsyncSession - ) -> "ChatChannel | None": + async def get(cls, channel: str | int, session: AsyncSession) -> "ChatChannel | None": if isinstance(channel, int) or channel.isdigit(): # 使用查询而不是 get() 来确保对象完全加载 - result = await session.exec( - select(ChatChannel).where(ChatChannel.channel_id == int(channel)) - ) + result = await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel))) channel_ = result.first() if channel_ is not None: return channel_ - result = await session.exec( - select(ChatChannel).where(ChatChannel.name == channel) - ) + result = await session.exec(select(ChatChannel).where(ChatChannel.name == channel)) return result.first() @classmethod - async def get_pm_channel( - cls, user1: int, user2: int, session: AsyncSession - ) -> "ChatChannel | None": + async def get_pm_channel(cls, user1: int, user2: int, session: AsyncSession) -> "ChatChannel | None": channel = await cls.get(f"pm_{user1}_{user2}", session) if channel is None: channel = await cls.get(f"pm_{user2}_{user1}", session) @@ -153,18 +145,13 @@ class ChatChannelResp(ChatChannelBase): .limit(10) ) ).all() - c.recent_messages = [ - await ChatMessageResp.from_db(msg, session, user) for msg in messages - ] + c.recent_messages = [await ChatMessageResp.from_db(msg, session, user) for msg in messages] c.recent_messages.reverse() if c.type == ChannelType.PM and users and len(users) == 2: target_user_id = next(u for u in users if u != user.id) - target_name = await session.exec( - select(User.username).where(User.id == target_user_id) - ) + target_name = await session.exec(select(User.username).where(User.id == target_user_id)) c.name = target_name.one() - assert user.id c.users = [target_user_id, user.id] return c @@ -181,19 +168,15 @@ class MessageType(str, Enum): class ChatMessageBase(UTCBaseModel, SQLModel): channel_id: int = Field(index=True, foreign_key="chat_channels.channel_id") content: str = Field(sa_column=Column(VARCHAR(1000))) - message_id: int | None = Field(index=True, primary_key=True, default=None) - sender_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) - ) - timestamp: datetime = Field( - sa_column=Column(DateTime, index=True), default=datetime.now(UTC) - ) + message_id: int = Field(index=True, primary_key=True, default=None) + sender_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) + timestamp: datetime = Field(sa_column=Column(DateTime, index=True), default=datetime.now(UTC)) type: MessageType = Field(default=MessageType.PLAIN, index=True, exclude=True) uuid: str | None = Field(default=None) class ChatMessage(ChatMessageBase, table=True): - __tablename__ = "chat_messages" # pyright: ignore[reportAssignmentType] + __tablename__: str = "chat_messages" user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"}) channel: ChatChannel = Relationship() @@ -211,9 +194,7 @@ class ChatMessageResp(ChatMessageBase): if user: m.sender = await UserResp.from_db(user, session, RANKING_INCLUDES) else: - m.sender = await UserResp.from_db( - db_message.user, session, RANKING_INCLUDES - ) + m.sender = await UserResp.from_db(db_message.user, session, RANKING_INCLUDES) return m @@ -221,17 +202,13 @@ class ChatMessageResp(ChatMessageBase): class SilenceUser(UTCBaseModel, SQLModel, table=True): - __tablename__ = "chat_silence_users" # pyright: ignore[reportAssignmentType] - id: int | None = Field(primary_key=True, default=None, index=True) - user_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) - ) + __tablename__: str = "chat_silence_users" + id: int = Field(primary_key=True, default=None, index=True) + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) channel_id: int = Field(foreign_key="chat_channels.channel_id", index=True) until: datetime | None = Field(sa_column=Column(DateTime, index=True), default=None) reason: str | None = Field(default=None, sa_column=Column(VARCHAR(255), index=True)) - banned_at: datetime = Field( - sa_column=Column(DateTime, index=True), default=datetime.now(UTC) - ) + banned_at: datetime = Field(sa_column=Column(DateTime, index=True), default=datetime.now(UTC)) class UserSilenceResp(SQLModel): @@ -240,7 +217,6 @@ class UserSilenceResp(SQLModel): @classmethod def from_db(cls, db_silence: SilenceUser) -> "UserSilenceResp": - assert db_silence.id is not None return cls( id=db_silence.id, user_id=db_silence.user_id, diff --git a/app/database/counts.py b/app/database/counts.py index c999471..57e2b46 100644 --- a/app/database/counts.py +++ b/app/database/counts.py @@ -21,28 +21,24 @@ class CountBase(SQLModel): class MonthlyPlaycounts(CountBase, table=True): - __tablename__ = "monthly_playcounts" # pyright: ignore[reportAssignmentType] + __tablename__: str = "monthly_playcounts" id: int | None = Field( default=None, sa_column=Column(BigInteger, primary_key=True, autoincrement=True), ) - user_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) - ) + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) user: "User" = Relationship(back_populates="monthly_playcounts") class ReplayWatchedCount(CountBase, table=True): - __tablename__ = "replays_watched_counts" # pyright: ignore[reportAssignmentType] + __tablename__: str = "replays_watched_counts" id: int | None = Field( default=None, sa_column=Column(BigInteger, primary_key=True, autoincrement=True), ) - user_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) - ) + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) user: "User" = Relationship(back_populates="replays_watched_counts") diff --git a/app/database/daily_challenge.py b/app/database/daily_challenge.py index 476f326..071c606 100644 --- a/app/database/daily_challenge.py +++ b/app/database/daily_challenge.py @@ -24,9 +24,7 @@ class DailyChallengeStatsBase(SQLModel, UTCBaseModel): daily_streak_best: int = Field(default=0) daily_streak_current: int = Field(default=0) last_update: datetime | None = Field(default=None, sa_column=Column(DateTime)) - last_weekly_streak: datetime | None = Field( - default=None, sa_column=Column(DateTime) - ) + last_weekly_streak: datetime | None = Field(default=None, sa_column=Column(DateTime)) playcount: int = Field(default=0) top_10p_placements: int = Field(default=0) top_50p_placements: int = Field(default=0) @@ -35,7 +33,7 @@ class DailyChallengeStatsBase(SQLModel, UTCBaseModel): class DailyChallengeStats(DailyChallengeStatsBase, table=True): - __tablename__ = "daily_challenge_stats" # pyright: ignore[reportAssignmentType] + __tablename__: str = "daily_challenge_stats" user_id: int | None = Field( default=None, @@ -61,9 +59,7 @@ class DailyChallengeStatsResp(DailyChallengeStatsBase): return cls.model_validate(obj) -async def process_daily_challenge_score( - session: AsyncSession, user_id: int, room_id: int -): +async def process_daily_challenge_score(session: AsyncSession, user_id: int, room_id: int): from .playlist_best_score import PlaylistBestScore score = ( diff --git a/app/database/email_verification.py b/app/database/email_verification.py index 46a55b9..07d57a5 100644 --- a/app/database/email_verification.py +++ b/app/database/email_verification.py @@ -4,16 +4,17 @@ from __future__ import annotations -from datetime import datetime, UTC -from sqlmodel import SQLModel, Field -from sqlalchemy import Column, BigInteger, ForeignKey +from datetime import UTC, datetime + +from sqlalchemy import BigInteger, Column, ForeignKey +from sqlmodel import Field, SQLModel class EmailVerification(SQLModel, table=True): """邮件验证记录""" - + __tablename__: str = "email_verifications" - + id: int | None = Field(default=None, primary_key=True) user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True)) email: str = Field(index=True) @@ -28,9 +29,9 @@ class EmailVerification(SQLModel, table=True): class LoginSession(SQLModel, table=True): """登录会话记录""" - + __tablename__: str = "login_sessions" - + id: int | None = Field(default=None, primary_key=True) user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True)) session_token: str = Field(unique=True, index=True) # 会话令牌 diff --git a/app/database/events.py b/app/database/events.py index 19fe515..980396c 100644 --- a/app/database/events.py +++ b/app/database/events.py @@ -36,17 +36,13 @@ class EventType(str, Enum): class EventBase(SQLModel): id: int = Field(default=None, primary_key=True) - created_at: datetime = Field( - sa_column=Column(DateTime(timezone=True), default=datetime.now(UTC)) - ) + created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), default=datetime.now(UTC))) type: EventType - event_payload: dict = Field( - exclude=True, default_factory=dict, sa_column=Column(JSON) - ) + event_payload: dict = Field(exclude=True, default_factory=dict, sa_column=Column(JSON)) class Event(EventBase, table=True): - __tablename__ = "user_events" # pyright: ignore[reportAssignmentType] + __tablename__: str = "user_events" user_id: int | None = Field( default=None, sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True), diff --git a/app/database/failtime.py b/app/database/failtime.py index ae0f66f..5e08a73 100644 --- a/app/database/failtime.py +++ b/app/database/failtime.py @@ -16,8 +16,8 @@ FAILTIME_STRUCT = Struct("<100i") class FailTime(SQLModel, table=True): - __tablename__ = "failtime" # pyright: ignore[reportAssignmentType] - beatmap_id: int = Field(primary_key=True, index=True, foreign_key="beatmaps.id") + __tablename__: str = "failtime" + beatmap_id: int = Field(primary_key=True, foreign_key="beatmaps.id") exit: bytes = Field(sa_column=Column(VARBINARY(400), nullable=False)) fail: bytes = Field(sa_column=Column(VARBINARY(400), nullable=False)) @@ -41,12 +41,8 @@ class FailTime(SQLModel, table=True): class FailTimeResp(BaseModel): - exit: list[int] = Field( - default_factory=lambda: list(FAILTIME_STRUCT.unpack(b"\x00" * 400)) - ) - fail: list[int] = Field( - default_factory=lambda: list(FAILTIME_STRUCT.unpack(b"\x00" * 400)) - ) + exit: list[int] = Field(default_factory=lambda: list(FAILTIME_STRUCT.unpack(b"\x00" * 400))) + fail: list[int] = Field(default_factory=lambda: list(FAILTIME_STRUCT.unpack(b"\x00" * 400))) @classmethod def from_db(cls, failtime: FailTime) -> "FailTimeResp": diff --git a/app/database/favourite_beatmapset.py b/app/database/favourite_beatmapset.py index 51bd578..c59521c 100644 --- a/app/database/favourite_beatmapset.py +++ b/app/database/favourite_beatmapset.py @@ -16,7 +16,7 @@ from sqlmodel import ( class FavouriteBeatmapset(AsyncAttrs, SQLModel, table=True): - __tablename__ = "favourite_beatmapset" # pyright: ignore[reportAssignmentType] + __tablename__: str = "favourite_beatmapset" id: int | None = Field( default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True), diff --git a/app/database/lazer_user.py b/app/database/lazer_user.py index e215dab..8132137 100644 --- a/app/database/lazer_user.py +++ b/app/database/lazer_user.py @@ -75,9 +75,7 @@ class UserBase(UTCBaseModel, SQLModel): is_active: bool = True is_bot: bool = False is_supporter: bool = False - last_visit: datetime | None = Field( - default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True)) - ) + last_visit: datetime | None = Field(default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True))) pm_friends_only: bool = False profile_colour: str | None = None username: str = Field(max_length=32, unique=True, index=True) @@ -90,9 +88,7 @@ class UserBase(UTCBaseModel, SQLModel): is_restricted: bool = False # blocks cover: UserProfileCover = Field( - default=UserProfileCover( - url="https://assets.ppy.sh/user-profile-covers/default.jpeg" - ), + default=UserProfileCover(url="https://assets.ppy.sh/user-profile-covers/default.jpeg"), sa_column=Column(JSON), ) beatmap_playcounts_count: int = 0 @@ -150,9 +146,9 @@ class UserBase(UTCBaseModel, SQLModel): class User(AsyncAttrs, UserBase, table=True): - __tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType] + __tablename__: str = "lazer_users" - id: int | None = Field( + id: int = Field( default=None, sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True), ) @@ -160,16 +156,10 @@ class User(AsyncAttrs, UserBase, table=True): statistics: list[UserStatistics] = Relationship() achievement: list[UserAchievement] = Relationship(back_populates="user") team_membership: TeamMember | None = Relationship(back_populates="user") - daily_challenge_stats: DailyChallengeStats | None = Relationship( - back_populates="user" - ) + daily_challenge_stats: DailyChallengeStats | None = Relationship(back_populates="user") monthly_playcounts: list[MonthlyPlaycounts] = Relationship(back_populates="user") - replays_watched_counts: list[ReplayWatchedCount] = Relationship( - back_populates="user" - ) - favourite_beatmapsets: list["FavouriteBeatmapset"] = Relationship( - back_populates="user" - ) + replays_watched_counts: list[ReplayWatchedCount] = Relationship(back_populates="user") + favourite_beatmapsets: list["FavouriteBeatmapset"] = Relationship(back_populates="user") rank_history: list[RankHistory] = Relationship( back_populates="user", ) @@ -178,16 +168,10 @@ class User(AsyncAttrs, UserBase, table=True): email: str = Field(max_length=254, unique=True, index=True, exclude=True) priv: int = Field(default=1, exclude=True) pw_bcrypt: str = Field(max_length=60, exclude=True) - silence_end_at: datetime | None = Field( - default=None, sa_column=Column(DateTime(timezone=True)), exclude=True - ) - donor_end_at: datetime | None = Field( - default=None, sa_column=Column(DateTime(timezone=True)), exclude=True - ) + silence_end_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True)), exclude=True) + donor_end_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True)), exclude=True) - async def is_user_can_pm( - self, from_user: "User", session: AsyncSession - ) -> tuple[bool, str]: + async def is_user_can_pm(self, from_user: "User", session: AsyncSession) -> tuple[bool, str]: from .relationship import Relationship, RelationshipType from_relationship = ( @@ -200,13 +184,10 @@ class User(AsyncAttrs, UserBase, table=True): ).first() if from_relationship and from_relationship.type == RelationshipType.BLOCK: return False, "You have blocked the target user." - if from_user.pm_friends_only and ( - not from_relationship or from_relationship.type != RelationshipType.FOLLOW - ): + if from_user.pm_friends_only and (not from_relationship or from_relationship.type != RelationshipType.FOLLOW): return ( False, - "You have disabled non-friend communications " - "and target user is not your friend.", + "You have disabled non-friend communications and target user is not your friend.", ) relationship = ( @@ -219,9 +200,7 @@ class User(AsyncAttrs, UserBase, table=True): ).first() if relationship and relationship.type == RelationshipType.BLOCK: return False, "Target user has blocked you." - if self.pm_friends_only and ( - not relationship or relationship.type != RelationshipType.FOLLOW - ): + if self.pm_friends_only and (not relationship or relationship.type != RelationshipType.FOLLOW): return False, "Target user has disabled non-friend communications" return True, "" @@ -288,9 +267,7 @@ class UserResp(UserBase): u = cls.model_validate(obj.model_dump()) u.id = obj.id u.default_group = "bot" if u.is_bot else "default" - u.country = Country( - code=obj.country_code, name=COUNTRIES.get(obj.country_code, "Unknown") - ) + u.country = Country(code=obj.country_code, name=COUNTRIES.get(obj.country_code, "Unknown")) u.follower_count = ( await session.exec( select(func.count()) @@ -314,9 +291,7 @@ class UserResp(UserBase): redis = get_redis() u.is_online = await redis.exists(f"metadata:online:{obj.id}") u.cover_url = ( - obj.cover.get( - "url", "https://assets.ppy.sh/user-profile-covers/default.jpeg" - ) + obj.cover.get("url", "https://assets.ppy.sh/user-profile-covers/default.jpeg") if obj.cover else "https://assets.ppy.sh/user-profile-covers/default.jpeg" ) @@ -335,22 +310,15 @@ class UserResp(UserBase): ] if "team" in include: - if await obj.awaitable_attrs.team_membership: - assert obj.team_membership - u.team = obj.team_membership.team + if team_membership := await obj.awaitable_attrs.team_membership: + u.team = team_membership.team if "account_history" in include: - u.account_history = [ - UserAccountHistoryResp.from_db(ah) - for ah in await obj.awaitable_attrs.account_history - ] + u.account_history = [UserAccountHistoryResp.from_db(ah) for ah in await obj.awaitable_attrs.account_history] if "daily_challenge_user_stats": - if await obj.awaitable_attrs.daily_challenge_stats: - assert obj.daily_challenge_stats - u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db( - obj.daily_challenge_stats - ) + if daily_challenge_stats := await obj.awaitable_attrs.daily_challenge_stats: + u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(daily_challenge_stats) if "statistics" in include: current_stattistics = None @@ -359,59 +327,40 @@ class UserResp(UserBase): current_stattistics = i break u.statistics = ( - await UserStatisticsResp.from_db( - current_stattistics, session, obj.country_code - ) + await UserStatisticsResp.from_db(current_stattistics, session, obj.country_code) if current_stattistics else None ) if "statistics_rulesets" in include: u.statistics_rulesets = { - i.mode.value: await UserStatisticsResp.from_db( - i, session, obj.country_code - ) + i.mode.value: await UserStatisticsResp.from_db(i, session, obj.country_code) for i in await obj.awaitable_attrs.statistics } if "monthly_playcounts" in include: - u.monthly_playcounts = [ - CountResp.from_db(pc) - for pc in await obj.awaitable_attrs.monthly_playcounts - ] + u.monthly_playcounts = [CountResp.from_db(pc) for pc in await obj.awaitable_attrs.monthly_playcounts] if len(u.monthly_playcounts) == 1: d = u.monthly_playcounts[0].start_date - u.monthly_playcounts.insert( - 0, CountResp(start_date=d - timedelta(days=20), count=0) - ) + u.monthly_playcounts.insert(0, CountResp(start_date=d - timedelta(days=20), count=0)) if "replays_watched_counts" in include: u.replay_watched_counts = [ - CountResp.from_db(rwc) - for rwc in await obj.awaitable_attrs.replays_watched_counts + CountResp.from_db(rwc) for rwc in await obj.awaitable_attrs.replays_watched_counts ] if len(u.replay_watched_counts) == 1: d = u.replay_watched_counts[0].start_date - u.replay_watched_counts.insert( - 0, CountResp(start_date=d - timedelta(days=20), count=0) - ) + u.replay_watched_counts.insert(0, CountResp(start_date=d - timedelta(days=20), count=0)) if "achievements" in include: - u.user_achievements = [ - UserAchievementResp.from_db(ua) - for ua in await obj.awaitable_attrs.achievement - ] + u.user_achievements = [UserAchievementResp.from_db(ua) for ua in await obj.awaitable_attrs.achievement] if "rank_history" in include: rank_history = await RankHistoryResp.from_db(session, obj.id, ruleset) if len(rank_history.data) != 0: u.rank_history = rank_history rank_top = ( - await session.exec( - select(RankTop).where( - RankTop.user_id == obj.id, RankTop.mode == ruleset - ) - ) + await session.exec(select(RankTop).where(RankTop.user_id == obj.id, RankTop.mode == ruleset)) ).first() if rank_top: u.rank_highest = ( @@ -425,9 +374,7 @@ class UserResp(UserBase): u.favourite_beatmapset_count = ( await session.exec( - select(func.count()) - .select_from(FavouriteBeatmapset) - .where(FavouriteBeatmapset.user_id == obj.id) + select(func.count()).select_from(FavouriteBeatmapset).where(FavouriteBeatmapset.user_id == obj.id) ) ).one() u.scores_pinned_count = ( @@ -478,17 +425,19 @@ class UserResp(UserBase): # 检查会话验证状态 # 如果邮件验证功能被禁用,则始终设置 session_verified 为 true from app.config import settings + if not settings.enable_email_verification: u.session_verified = True else: # 如果用户有未验证的登录会话,则设置 session_verified 为 false from .email_verification import LoginSession + unverified_session = ( await session.exec( select(LoginSession).where( LoginSession.user_id == obj.id, - LoginSession.is_verified == False, - LoginSession.expires_at > datetime.now(UTC) + col(LoginSession.is_verified).is_(False), + LoginSession.expires_at > datetime.now(UTC), ) ) ).first() diff --git a/app/database/multiplayer_event.py b/app/database/multiplayer_event.py index 904fbe4..ce3f024 100644 --- a/app/database/multiplayer_event.py +++ b/app/database/multiplayer_event.py @@ -30,8 +30,8 @@ class MultiplayerEventBase(SQLModel, UTCBaseModel): class MultiplayerEvent(MultiplayerEventBase, table=True): - __tablename__ = "multiplayer_events" # pyright: ignore[reportAssignmentType] - id: int | None = Field( + __tablename__: str = "multiplayer_events" + id: int = Field( default=None, sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True), ) diff --git a/app/database/notification.py b/app/database/notification.py index a0f568b..6cb44fb 100644 --- a/app/database/notification.py +++ b/app/database/notification.py @@ -17,7 +17,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession class Notification(SQLModel, table=True): - __tablename__ = "notifications" # pyright: ignore[reportAssignmentType] + __tablename__: str = "notifications" id: int = Field(primary_key=True, index=True, default=None) name: NotificationName = Field(index=True) @@ -30,7 +30,7 @@ class Notification(SQLModel, table=True): class UserNotification(SQLModel, table=True): - __tablename__ = "user_notifications" # pyright: ignore[reportAssignmentType] + __tablename__: str = "user_notifications" id: int = Field( sa_column=Column( BigInteger, @@ -40,9 +40,7 @@ class UserNotification(SQLModel, table=True): default=None, ) notification_id: int = Field(index=True, foreign_key="notifications.id") - user_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) - ) + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) is_read: bool = Field(index=True) notification: Notification = Relationship(sa_relationship_kwargs={"lazy": "joined"}) diff --git a/app/database/password_reset.py b/app/database/password_reset.py index 38cc962..51651ce 100644 --- a/app/database/password_reset.py +++ b/app/database/password_reset.py @@ -4,16 +4,17 @@ from __future__ import annotations -from datetime import datetime, UTC -from sqlmodel import SQLModel, Field -from sqlalchemy import Column, BigInteger, ForeignKey +from datetime import UTC, datetime + +from sqlalchemy import BigInteger, Column, ForeignKey +from sqlmodel import Field, SQLModel class PasswordReset(SQLModel, table=True): """密码重置记录""" - + __tablename__: str = "password_resets" - + id: int | None = Field(default=None, primary_key=True) user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True)) email: str = Field(index=True) diff --git a/app/database/playlist_attempts.py b/app/database/playlist_attempts.py index f8a8f0c..4b7530d 100644 --- a/app/database/playlist_attempts.py +++ b/app/database/playlist_attempts.py @@ -21,16 +21,14 @@ class ItemAttemptsCountBase(SQLModel): room_id: int = Field(foreign_key="rooms.id", index=True) attempts: int = Field(default=0) completed: int = Field(default=0) - user_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) - ) + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) accuracy: float = 0.0 pp: float = 0 total_score: int = 0 class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True): - __tablename__ = "item_attempts_count" # pyright: ignore[reportAssignmentType] + __tablename__: str = "item_attempts_count" id: int | None = Field(default=None, primary_key=True) user: User = Relationship() @@ -63,9 +61,7 @@ class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True): self.pp = sum(score.score.pp for score in playlist_scores) self.completed = len([score for score in playlist_scores if score.score.passed]) self.accuracy = ( - sum(score.score.accuracy for score in playlist_scores) / self.completed - if self.completed > 0 - else 0.0 + sum(score.score.accuracy for score in playlist_scores) / self.completed if self.completed > 0 else 0.0 ) await session.commit() await session.refresh(self) diff --git a/app/database/playlist_best_score.py b/app/database/playlist_best_score.py index 72f83ec..411797d 100644 --- a/app/database/playlist_best_score.py +++ b/app/database/playlist_best_score.py @@ -21,14 +21,10 @@ if TYPE_CHECKING: class PlaylistBestScore(SQLModel, table=True): - __tablename__ = "playlist_best_scores" # pyright: ignore[reportAssignmentType] + __tablename__: str = "playlist_best_scores" - user_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) - ) - score_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True) - ) + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) + score_id: int = Field(sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)) room_id: int = Field(foreign_key="rooms.id", index=True) playlist_id: int = Field(index=True) total_score: int = Field(default=0, sa_column=Column(BigInteger)) diff --git a/app/database/playlists.py b/app/database/playlists.py index c177432..79e5c0d 100644 --- a/app/database/playlists.py +++ b/app/database/playlists.py @@ -50,7 +50,7 @@ class PlaylistBase(SQLModel, UTCBaseModel): class Playlist(PlaylistBase, table=True): - __tablename__ = "room_playlists" # pyright: ignore[reportAssignmentType] + __tablename__: str = "room_playlists" db_id: int = Field(default=None, primary_key=True, index=True, exclude=True) room_id: int = Field(foreign_key="rooms.id", exclude=True) @@ -63,16 +63,12 @@ class Playlist(PlaylistBase, table=True): @classmethod async def get_next_id_for_room(cls, room_id: int, session: AsyncSession) -> int: - stmt = select(func.coalesce(func.max(cls.id), -1) + 1).where( - cls.room_id == room_id - ) + stmt = select(func.coalesce(func.max(cls.id), -1) + 1).where(cls.room_id == room_id) result = await session.exec(stmt) return result.one() @classmethod - async def from_hub( - cls, playlist: PlaylistItem, room_id: int, session: AsyncSession - ) -> "Playlist": + async def from_hub(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession) -> "Playlist": next_id = await cls.get_next_id_for_room(room_id, session=session) return cls( id=next_id, @@ -90,9 +86,7 @@ class Playlist(PlaylistBase, table=True): @classmethod async def update(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession): - db_playlist = await session.exec( - select(cls).where(cls.id == playlist.id, cls.room_id == room_id) - ) + db_playlist = await session.exec(select(cls).where(cls.id == playlist.id, cls.room_id == room_id)) db_playlist = db_playlist.first() if db_playlist is None: raise ValueError("Playlist item not found") @@ -108,9 +102,7 @@ class Playlist(PlaylistBase, table=True): await session.commit() @classmethod - async def add_to_db( - cls, playlist: PlaylistItem, room_id: int, session: AsyncSession - ): + async def add_to_db(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession): db_playlist = await cls.from_hub(playlist, room_id, session) session.add(db_playlist) await session.commit() @@ -119,9 +111,7 @@ class Playlist(PlaylistBase, table=True): @classmethod async def delete_item(cls, item_id: int, room_id: int, session: AsyncSession): - db_playlist = await session.exec( - select(cls).where(cls.id == item_id, cls.room_id == room_id) - ) + db_playlist = await session.exec(select(cls).where(cls.id == item_id, cls.room_id == room_id)) db_playlist = db_playlist.first() if db_playlist is None: raise ValueError("Playlist item not found") @@ -133,9 +123,7 @@ class PlaylistResp(PlaylistBase): beatmap: BeatmapResp | None = None @classmethod - async def from_db( - cls, playlist: Playlist, include: list[str] = [] - ) -> "PlaylistResp": + async def from_db(cls, playlist: Playlist, include: list[str] = []) -> "PlaylistResp": data = playlist.model_dump() if "beatmap" in include: data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap) diff --git a/app/database/pp_best_score.py b/app/database/pp_best_score.py index ffc74d3..2141630 100644 --- a/app/database/pp_best_score.py +++ b/app/database/pp_best_score.py @@ -20,13 +20,9 @@ if TYPE_CHECKING: class PPBestScore(SQLModel, table=True): - __tablename__ = "best_scores" # pyright: ignore[reportAssignmentType] - user_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) - ) - score_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True) - ) + __tablename__: str = "best_scores" + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) + score_id: int = Field(sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)) beatmap_id: int = Field(foreign_key="beatmaps.id", index=True) gamemode: GameMode = Field(index=True) pp: float = Field( diff --git a/app/database/rank_history.py b/app/database/rank_history.py index 403dbca..81caf22 100644 --- a/app/database/rank_history.py +++ b/app/database/rank_history.py @@ -26,12 +26,10 @@ if TYPE_CHECKING: class RankHistory(SQLModel, table=True): - __tablename__ = "rank_history" # pyright: ignore[reportAssignmentType] + __tablename__: str = "rank_history" id: int | None = Field(default=None, sa_column=Column(BigInteger, primary_key=True)) - user_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) - ) + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) mode: GameMode rank: int date: dt = Field( @@ -43,12 +41,10 @@ class RankHistory(SQLModel, table=True): class RankTop(SQLModel, table=True): - __tablename__ = "rank_top" # pyright: ignore[reportAssignmentType] + __tablename__: str = "rank_top" id: int | None = Field(default=None, sa_column=Column(BigInteger, primary_key=True)) - user_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) - ) + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) mode: GameMode rank: int date: dt = Field( @@ -62,9 +58,7 @@ class RankHistoryResp(BaseModel): data: list[int] @classmethod - async def from_db( - cls, session: AsyncSession, user_id: int, mode: GameMode - ) -> "RankHistoryResp": + async def from_db(cls, session: AsyncSession, user_id: int, mode: GameMode) -> "RankHistoryResp": results = ( await session.exec( select(RankHistory) diff --git a/app/database/relationship.py b/app/database/relationship.py index b941c28..a6d109e 100644 --- a/app/database/relationship.py +++ b/app/database/relationship.py @@ -21,7 +21,7 @@ class RelationshipType(str, Enum): class Relationship(SQLModel, table=True): - __tablename__ = "relationship" # pyright: ignore[reportAssignmentType] + __tablename__: str = "relationship" id: int | None = Field( default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True), @@ -59,9 +59,7 @@ class RelationshipResp(BaseModel): type: RelationshipType @classmethod - async def from_db( - cls, session: AsyncSession, relationship: Relationship - ) -> "RelationshipResp": + async def from_db(cls, session: AsyncSession, relationship: Relationship) -> "RelationshipResp": target_relationship = ( await session.exec( select(Relationship).where( diff --git a/app/database/room.py b/app/database/room.py index 54497db..491b55b 100644 --- a/app/database/room.py +++ b/app/database/room.py @@ -58,11 +58,9 @@ class RoomBase(SQLModel, UTCBaseModel): class Room(AsyncAttrs, RoomBase, table=True): - __tablename__ = "rooms" # pyright: ignore[reportAssignmentType] + __tablename__: str = "rooms" id: int = Field(default=None, primary_key=True, index=True) - host_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) - ) + host_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) host: User = Relationship() playlist: list[Playlist] = Relationship( @@ -109,12 +107,8 @@ class RoomResp(RoomBase): if not playlist.expired: stats.count_active += 1 rulesets.add(playlist.ruleset_id) - difficulty_range.min = min( - difficulty_range.min, playlist.beatmap.difficulty_rating - ) - difficulty_range.max = max( - difficulty_range.max, playlist.beatmap.difficulty_rating - ) + difficulty_range.min = min(difficulty_range.min, playlist.beatmap.difficulty_rating) + difficulty_range.max = max(difficulty_range.max, playlist.beatmap.difficulty_rating) resp.playlist.append(await PlaylistResp.from_db(playlist, ["beatmap"])) stats.ruleset_ids = list(rulesets) resp.playlist_item_stats = stats @@ -137,13 +131,9 @@ class RoomResp(RoomBase): include=["statistics"], ) ) - resp.host = await UserResp.from_db( - await room.awaitable_attrs.host, session, include=["statistics"] - ) + resp.host = await UserResp.from_db(await room.awaitable_attrs.host, session, include=["statistics"]) if "current_user_score" in include and user: - resp.current_user_score = await PlaylistAggregateScore.from_db( - room.id, user.id, session - ) + resp.current_user_score = await PlaylistAggregateScore.from_db(room.id, user.id, session) return resp @classmethod diff --git a/app/database/room_participated_user.py b/app/database/room_participated_user.py index 18b0aeb..3779516 100644 --- a/app/database/room_participated_user.py +++ b/app/database/room_participated_user.py @@ -18,22 +18,16 @@ if TYPE_CHECKING: class RoomParticipatedUser(AsyncAttrs, SQLModel, table=True): - __tablename__ = "room_participated_users" # pyright: ignore[reportAssignmentType] + __tablename__: str = "room_participated_users" - id: int | None = Field( - default=None, sa_column=Column(BigInteger, primary_key=True, autoincrement=True) - ) + id: int | None = Field(default=None, sa_column=Column(BigInteger, primary_key=True, autoincrement=True)) room_id: int = Field(sa_column=Column(ForeignKey("rooms.id"), nullable=False)) - user_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False) - ) + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False)) joined_at: datetime = Field( sa_column=Column(DateTime(timezone=True), nullable=False), default=datetime.now(UTC), ) - left_at: datetime | None = Field( - sa_column=Column(DateTime(timezone=True), nullable=True), default=None - ) + left_at: datetime | None = Field(sa_column=Column(DateTime(timezone=True), nullable=True), default=None) room: "Room" = Relationship() user: "User" = Relationship() diff --git a/app/database/score.py b/app/database/score.py index da9cbc4..66e1cc2 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -47,9 +47,9 @@ from .score_token import ScoreToken from pydantic import field_serializer, field_validator from redis.asyncio import Redis -from sqlalchemy import Boolean, Column, ColumnExpressionArgument, DateTime +from sqlalchemy import Boolean, Column, ColumnExpressionArgument, DateTime, TextClause from sqlalchemy.ext.asyncio import AsyncAttrs -from sqlalchemy.orm import aliased +from sqlalchemy.orm import Mapped, aliased from sqlalchemy.sql.elements import ColumnElement from sqlmodel import ( JSON, @@ -76,9 +76,7 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel): accuracy: float map_md5: str = Field(max_length=32, index=True) build_id: int | None = Field(default=None) - classic_total_score: int | None = Field( - default=0, sa_column=Column(BigInteger) - ) # solo_score + classic_total_score: int | None = Field(default=0, sa_column=Column(BigInteger)) # solo_score ended_at: datetime = Field(sa_column=Column(DateTime)) has_replay: bool = Field(sa_column=Column(Boolean)) max_combo: int @@ -91,14 +89,10 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel): room_id: int | None = Field(default=None) # multiplayer started_at: datetime = Field(sa_column=Column(DateTime)) total_score: int = Field(default=0, sa_column=Column(BigInteger)) - total_score_without_mods: int = Field( - default=0, sa_column=Column(BigInteger), exclude=True - ) + total_score_without_mods: int = Field(default=0, sa_column=Column(BigInteger), exclude=True) type: str beatmap_id: int = Field(index=True, foreign_key="beatmaps.id") - maximum_statistics: ScoreStatistics = Field( - sa_column=Column(JSON), default_factory=dict - ) + maximum_statistics: ScoreStatistics = Field(sa_column=Column(JSON), default_factory=dict) @field_validator("maximum_statistics", mode="before") @classmethod @@ -147,10 +141,8 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel): class Score(ScoreBase, table=True): - __tablename__ = "scores" # pyright: ignore[reportAssignmentType] - id: int | None = Field( - default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True) - ) + __tablename__: str = "scores" + id: int = Field(default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True)) user_id: int = Field( default=None, sa_column=Column( @@ -193,8 +185,8 @@ class Score(ScoreBase, table=True): return str(v) # optional - beatmap: Beatmap = Relationship() - user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"}) + beatmap: Mapped[Beatmap] = Relationship() + user: Mapped[User] = Relationship(sa_relationship_kwargs={"lazy": "joined"}) @property def is_perfect_combo(self) -> bool: @@ -205,11 +197,7 @@ class Score(ScoreBase, table=True): *where_clauses: ColumnExpressionArgument[bool] | bool, ) -> SelectOfScalar["Score"]: rownum = ( - func.row_number() - .over( - partition_by=col(Score.user_id), order_by=col(Score.total_score).desc() - ) - .label("rn") + func.row_number().over(partition_by=col(Score.user_id), order_by=col(Score.total_score).desc()).label("rn") ) subq = select(Score, rownum).where(*where_clauses).subquery() best = aliased(Score, subq, adapt_on_names=True) @@ -296,12 +284,9 @@ class ScoreResp(ScoreBase): await session.refresh(score) s = cls.model_validate(score.model_dump()) - assert score.id await score.awaitable_attrs.beatmap s.beatmap = await BeatmapResp.from_db(score.beatmap) - s.beatmapset = await BeatmapsetResp.from_db( - score.beatmap.beatmapset, session=session, user=score.user - ) + s.beatmapset = await BeatmapsetResp.from_db(score.beatmap.beatmapset, session=session, user=score.user) s.is_perfect_combo = s.max_combo == s.beatmap.max_combo s.legacy_perfect = s.max_combo == s.beatmap.max_combo s.ruleset_id = int(score.gamemode) @@ -371,11 +356,7 @@ class ScoreAround(SQLModel): async def get_best_id(session: AsyncSession, score_id: int) -> None: rownum = ( - func.row_number() - .over( - partition_by=col(PPBestScore.user_id), order_by=col(PPBestScore.pp).desc() - ) - .label("rn") + func.row_number().over(partition_by=col(PPBestScore.user_id), order_by=col(PPBestScore.pp).desc()).label("rn") ) subq = select(PPBestScore, rownum).subquery() stmt = select(subq.c.rn).where(subq.c.score_id == score_id) @@ -389,8 +370,8 @@ async def _score_where( mode: GameMode, mods: list[str] | None = None, user: User | None = None, -) -> list[ColumnElement[bool]] | None: - wheres = [ +) -> list[ColumnElement[bool] | TextClause] | None: + wheres: list[ColumnElement[bool] | TextClause] = [ col(BestScore.beatmap_id) == beatmap, col(BestScore.gamemode) == mode, ] @@ -410,9 +391,7 @@ async def _score_where( return None elif type == LeaderboardType.COUNTRY: if user and user.is_supporter: - wheres.append( - col(BestScore.user).has(col(User.country_code) == user.country_code) - ) + wheres.append(col(BestScore.user).has(col(User.country_code) == user.country_code)) else: return None elif type == LeaderboardType.TEAM: @@ -420,18 +399,14 @@ async def _score_where( team_membership = await user.awaitable_attrs.team_membership if team_membership: team_id = team_membership.team_id - wheres.append( - col(BestScore.user).has( - col(User.team_membership).has(TeamMember.team_id == team_id) - ) - ) + wheres.append(col(BestScore.user).has(col(User.team_membership).has(TeamMember.team_id == team_id))) if mods: if user and user.is_supporter: wheres.append( text( "JSON_CONTAINS(total_score_best_scores.mods, :w)" " AND JSON_CONTAINS(:w, total_score_best_scores.mods)" - ).params(w=json.dumps(mods)) # pyright: ignore[reportArgumentType] + ).params(w=json.dumps(mods)) ) else: return None @@ -654,18 +629,14 @@ def calculate_playtime(score: Score, beatmap_length: int) -> tuple[int, bool]: + (score.nsmall_tick_hit or 0) ) total_obj = 0 - for statistics, count in ( - score.maximum_statistics.items() if score.maximum_statistics else {} - ): + for statistics, count in score.maximum_statistics.items() if score.maximum_statistics else {}: if not isinstance(statistics, HitResult): statistics = HitResult(statistics) if statistics.is_scorable(): total_obj += count return total_length, score.passed or ( - total_length > 8 - and score.total_score >= 5000 - and total_obj_hited >= min(0.1 * total_obj, 20) + total_length > 8 and score.total_score >= 5000 and total_obj_hited >= min(0.1 * total_obj, 20) ) @@ -678,12 +649,8 @@ async def process_user( ranked: bool = False, has_leaderboard: bool = False, ): - assert user.id - assert score.id mod_for_save = mod_to_save(score.mods) - previous_score_best = await get_user_best_score_in_beatmap( - session, score.beatmap_id, user.id, score.gamemode - ) + previous_score_best = await get_user_best_score_in_beatmap(session, score.beatmap_id, user.id, score.gamemode) previous_score_best_mod = await get_user_best_score_with_mod_in_beatmap( session, score.beatmap_id, user.id, mod_for_save, score.gamemode ) @@ -698,9 +665,7 @@ async def process_user( ) ).first() if mouthly_playcount is None: - mouthly_playcount = MonthlyPlaycounts( - user_id=user.id, year=date.today().year, month=date.today().month - ) + mouthly_playcount = MonthlyPlaycounts(user_id=user.id, year=date.today().year, month=date.today().month) add_to_db = True statistics = None for i in await user.awaitable_attrs.statistics: @@ -708,17 +673,11 @@ async def process_user( statistics = i break if statistics is None: - raise ValueError( - f"User {user.id} does not have statistics for mode {score.gamemode.value}" - ) + raise ValueError(f"User {user.id} does not have statistics for mode {score.gamemode.value}") # pc, pt, tth, tts statistics.total_score += score.total_score - difference = ( - score.total_score - previous_score_best.total_score - if previous_score_best - else score.total_score - ) + difference = score.total_score - previous_score_best.total_score if previous_score_best else score.total_score if difference > 0 and score.passed and ranked: match score.rank: case Rank.X: @@ -746,11 +705,8 @@ async def process_user( statistics.ranked_score += difference statistics.level_current = calculate_score_to_level(statistics.total_score) statistics.maximum_combo = max(statistics.maximum_combo, score.max_combo) - new_score_position = await get_score_position_by_user( - session, score.beatmap_id, user, score.gamemode - ) + new_score_position = await get_score_position_by_user(session, score.beatmap_id, user, score.gamemode) total_users = await session.exec(select(func.count()).select_from(User)) - assert total_users is not None score_range = min(50, math.ceil(float(total_users.one()) * 0.01)) if new_score_position <= score_range and new_score_position > 0: # Get the scores that might be displaced @@ -774,11 +730,7 @@ async def process_user( ) # If this score was previously in top positions but now pushed out - if ( - i < score_range - and displaced_position > score_range - and displaced_position is not None - ): + if i < score_range and displaced_position > score_range and displaced_position is not None: # Create rank lost event for the displaced user rank_lost_event = Event( created_at=datetime.now(UTC), @@ -814,10 +766,7 @@ async def process_user( ) # 情况3: 有最佳分数记录和该mod组合的记录,且是同一个记录,更新得分更高的情况 - elif ( - previous_score_best.score_id == previous_score_best_mod.score_id - and difference > 0 - ): + elif previous_score_best.score_id == previous_score_best_mod.score_id and difference > 0: previous_score_best.total_score = score.total_score previous_score_best.rank = score.rank previous_score_best.score_id = score.id @@ -847,9 +796,7 @@ async def process_user( statistics.count_300 += score.n300 + score.ngeki statistics.count_50 += score.n50 statistics.count_miss += score.nmiss - statistics.total_hits += ( - score.n300 + score.n100 + score.n50 + score.ngeki + score.nkatu - ) + statistics.total_hits += score.n300 + score.n100 + score.n50 + score.ngeki + score.nkatu if score.passed and ranked: with session.no_autoflush: @@ -885,7 +832,6 @@ async def process_score( item_id: int | None = None, room_id: int | None = None, ) -> Score: - assert user.id can_get_pp = info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods) gamemode = GameMode.from_int(info.ruleset_id).to_special_mode(info.mods) score = Score( @@ -922,20 +868,15 @@ async def process_score( if can_get_pp: from app.calculator import pre_fetch_and_calculate_pp - pp = await pre_fetch_and_calculate_pp( - score, beatmap_id, session, redis, fetcher - ) + pp = await pre_fetch_and_calculate_pp(score, beatmap_id, session, redis, fetcher) score.pp = pp session.add(score) user_id = user.id await session.commit() await session.refresh(score) if can_get_pp and score.pp != 0: - previous_pp_best = await get_user_best_pp_in_beatmap( - session, beatmap_id, user_id, score.gamemode - ) + previous_pp_best = await get_user_best_pp_in_beatmap(session, beatmap_id, user_id, score.gamemode) if previous_pp_best is None or score.pp > previous_pp_best.pp: - assert score.id best_score = PPBestScore( user_id=user_id, score_id=score.id, diff --git a/app/database/score_token.py b/app/database/score_token.py index 4467b8b..ac6c9a9 100644 --- a/app/database/score_token.py +++ b/app/database/score_token.py @@ -7,6 +7,7 @@ from .beatmap import Beatmap from .lazer_user import User from sqlalchemy import Column, DateTime, Index +from sqlalchemy.orm import Mapped from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel @@ -14,16 +15,12 @@ class ScoreTokenBase(SQLModel, UTCBaseModel): score_id: int | None = Field(sa_column=Column(BigInteger), default=None) ruleset_id: GameMode playlist_item_id: int | None = Field(default=None) # playlist - created_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) - updated_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) + created_at: datetime = Field(default_factory=datetime.utcnow, sa_column=Column(DateTime)) + updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column=Column(DateTime)) class ScoreToken(ScoreTokenBase, table=True): - __tablename__ = "score_tokens" # pyright: ignore[reportAssignmentType] + __tablename__: str = "score_tokens" __table_args__ = (Index("idx_user_playlist", "user_id", "playlist_item_id"),) id: int | None = Field( @@ -37,8 +34,8 @@ class ScoreToken(ScoreTokenBase, table=True): ) user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"))) beatmap_id: int = Field(foreign_key="beatmaps.id") - user: User = Relationship() - beatmap: Beatmap = Relationship() + user: Mapped[User] = Relationship() + beatmap: Mapped[Beatmap] = Relationship() class ScoreTokenResp(ScoreTokenBase): diff --git a/app/database/statistics.py b/app/database/statistics.py index 01226cb..0422138 100644 --- a/app/database/statistics.py +++ b/app/database/statistics.py @@ -58,7 +58,7 @@ class UserStatisticsBase(SQLModel): class UserStatistics(AsyncAttrs, UserStatisticsBase, table=True): - __tablename__ = "lazer_user_statistics" # pyright: ignore[reportAssignmentType] + __tablename__: str = "lazer_user_statistics" id: int | None = Field(default=None, primary_key=True) user_id: int = Field( default=None, @@ -123,9 +123,7 @@ class UserStatisticsResp(UserStatisticsBase): if "user" in include: from .lazer_user import RANKING_INCLUDES, UserResp - user = await UserResp.from_db( - await obj.awaitable_attrs.user, session, include=RANKING_INCLUDES - ) + user = await UserResp.from_db(await obj.awaitable_attrs.user, session, include=RANKING_INCLUDES) s.user = user user_country = user.country_code @@ -149,9 +147,7 @@ class UserStatisticsResp(UserStatisticsBase): return s -async def get_rank( - session: AsyncSession, statistics: UserStatistics, country: str | None = None -) -> int | None: +async def get_rank(session: AsyncSession, statistics: UserStatistics, country: str | None = None) -> int | None: from .lazer_user import User query = select( @@ -168,9 +164,7 @@ async def get_rank( subq = query.subquery() - result = await session.exec( - select(subq.c.rank).where(subq.c.user_id == statistics.user_id) - ) + result = await session.exec(select(subq.c.rank).where(subq.c.user_id == statistics.user_id)) rank = result.first() if rank is None: diff --git a/app/database/team.py b/app/database/team.py index e61b4eb..e60b548 100644 --- a/app/database/team.py +++ b/app/database/team.py @@ -11,9 +11,9 @@ if TYPE_CHECKING: class Team(SQLModel, UTCBaseModel, table=True): - __tablename__ = "teams" # pyright: ignore[reportAssignmentType] + __tablename__: str = "teams" - id: int | None = Field(default=None, primary_key=True, index=True) + id: int = Field(default=None, primary_key=True, index=True) name: str = Field(max_length=100) short_name: str = Field(max_length=10) flag_url: str | None = Field(default=None) @@ -26,34 +26,22 @@ class Team(SQLModel, UTCBaseModel, table=True): class TeamMember(SQLModel, UTCBaseModel, table=True): - __tablename__ = "team_members" # pyright: ignore[reportAssignmentType] + __tablename__: str = "team_members" - user_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True) - ) + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True)) team_id: int = Field(foreign_key="teams.id") - joined_at: datetime = Field( - default_factory=datetime.utcnow, sa_column=Column(DateTime) - ) + joined_at: datetime = Field(default_factory=datetime.utcnow, sa_column=Column(DateTime)) - user: "User" = Relationship( - back_populates="team_membership", sa_relationship_kwargs={"lazy": "joined"} - ) - team: "Team" = Relationship( - back_populates="members", sa_relationship_kwargs={"lazy": "joined"} - ) + user: "User" = Relationship(back_populates="team_membership", sa_relationship_kwargs={"lazy": "joined"}) + team: "Team" = Relationship(back_populates="members", sa_relationship_kwargs={"lazy": "joined"}) class TeamRequest(SQLModel, UTCBaseModel, table=True): - __tablename__ = "team_requests" # pyright: ignore[reportAssignmentType] + __tablename__: str = "team_requests" - user_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True) - ) + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True)) team_id: int = Field(foreign_key="teams.id", primary_key=True) - requested_at: datetime = Field( - default=datetime.now(UTC), sa_column=Column(DateTime) - ) + requested_at: datetime = Field(default=datetime.now(UTC), sa_column=Column(DateTime)) user: "User" = Relationship(sa_relationship_kwargs={"lazy": "joined"}) team: "Team" = Relationship(sa_relationship_kwargs={"lazy": "joined"}) diff --git a/app/database/user_account_history.py b/app/database/user_account_history.py index 217c8eb..e09c209 100644 --- a/app/database/user_account_history.py +++ b/app/database/user_account_history.py @@ -22,7 +22,7 @@ class UserAccountHistoryBase(SQLModel, UTCBaseModel): class UserAccountHistory(UserAccountHistoryBase, table=True): - __tablename__ = "user_account_history" # pyright: ignore[reportAssignmentType] + __tablename__: str = "user_account_history" id: int | None = Field( sa_column=Column( @@ -32,9 +32,7 @@ class UserAccountHistory(UserAccountHistoryBase, table=True): primary_key=True, ) ) - user_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) - ) + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) class UserAccountHistoryResp(UserAccountHistoryBase): diff --git a/app/database/user_login_log.py b/app/database/user_login_log.py index 09e4c49..7ebe2c9 100644 --- a/app/database/user_login_log.py +++ b/app/database/user_login_log.py @@ -10,27 +10,17 @@ from sqlmodel import Field, SQLModel class UserLoginLog(SQLModel, table=True): """User login log table""" - __tablename__ = "user_login_log" # pyright: ignore[reportAssignmentType] + __tablename__: str = "user_login_log" id: int | None = Field(default=None, primary_key=True, description="Record ID") user_id: int = Field(index=True, description="User ID") - ip_address: str = Field( - max_length=45, index=True, description="IP address (supports IPv4 and IPv6)" - ) - user_agent: str | None = Field( - default=None, max_length=500, description="User agent information" - ) - login_time: datetime = Field( - default_factory=datetime.utcnow, description="Login time" - ) + ip_address: str = Field(max_length=45, index=True, description="IP address (supports IPv4 and IPv6)") + user_agent: str | None = Field(default=None, max_length=500, description="User agent information") + login_time: datetime = Field(default_factory=datetime.utcnow, description="Login time") # GeoIP information - country_code: str | None = Field( - default=None, max_length=2, description="Country code" - ) - country_name: str | None = Field( - default=None, max_length=100, description="Country name" - ) + country_code: str | None = Field(default=None, max_length=2, description="Country code") + country_name: str | None = Field(default=None, max_length=100, description="Country name") city_name: str | None = Field(default=None, max_length=100, description="City name") latitude: str | None = Field(default=None, max_length=20, description="Latitude") longitude: str | None = Field(default=None, max_length=20, description="Longitude") @@ -38,22 +28,14 @@ class UserLoginLog(SQLModel, table=True): # ASN information asn: int | None = Field(default=None, description="Autonomous System Number") - organization: str | None = Field( - default=None, max_length=200, description="Organization name" - ) + organization: str | None = Field(default=None, max_length=200, description="Organization name") # Login status - login_success: bool = Field( - default=True, description="Whether the login was successful" - ) - login_method: str = Field( - max_length=50, description="Login method (password/oauth/etc.)" - ) + login_success: bool = Field(default=True, description="Whether the login was successful") + login_method: str = Field(max_length=50, description="Login method (password/oauth/etc.)") # Additional information - notes: str | None = Field( - default=None, max_length=500, description="Additional notes" - ) + notes: str | None = Field(default=None, max_length=500, description="Additional notes") class Config: from_attributes = True diff --git a/app/dependencies/database.py b/app/dependencies/database.py index 77648d1..0eae622 100644 --- a/app/dependencies/database.py +++ b/app/dependencies/database.py @@ -40,15 +40,11 @@ engine = create_async_engine( redis_client = redis.from_url(settings.redis_url, decode_responses=True) # Redis 消息缓存连接 (db1) - 使用同步客户端在线程池中执行 -redis_message_client = sync_redis.from_url( - settings.redis_url, decode_responses=True, db=1 -) +redis_message_client = sync_redis.from_url(settings.redis_url, decode_responses=True, db=1) # 数据库依赖 -db_session_context: ContextVar[AsyncSession | None] = ContextVar( - "db_session_context", default=None -) +db_session_context: ContextVar[AsyncSession | None] = ContextVar("db_session_context", default=None) async def get_db(): diff --git a/app/dependencies/fetcher.py b/app/dependencies/fetcher.py index 51964f0..b4db26c 100644 --- a/app/dependencies/fetcher.py +++ b/app/dependencies/fetcher.py @@ -25,7 +25,5 @@ async def get_fetcher() -> Fetcher: if refresh_token: fetcher.refresh_token = str(refresh_token) if not fetcher.access_token or not fetcher.refresh_token: - logger.opt(colors=True).info( - f"Login to initialize fetcher: {fetcher.authorize_url}" - ) + logger.opt(colors=True).info(f"Login to initialize fetcher: {fetcher.authorize_url}") return fetcher diff --git a/app/dependencies/scheduler.py b/app/dependencies/scheduler.py index e24c6fa..2bcee52 100644 --- a/app/dependencies/scheduler.py +++ b/app/dependencies/scheduler.py @@ -1,6 +1,7 @@ from __future__ import annotations from datetime import UTC +from typing import cast from apscheduler.schedulers.asyncio import AsyncIOScheduler @@ -16,7 +17,7 @@ def get_scheduler() -> AsyncIOScheduler: global scheduler if scheduler is None: init_scheduler() - return scheduler # pyright: ignore[reportReturnType] + return cast(AsyncIOScheduler, scheduler) def start_scheduler(): diff --git a/app/dependencies/user.py b/app/dependencies/user.py index d3787dc..69f3edd 100644 --- a/app/dependencies/user.py +++ b/app/dependencies/user.py @@ -70,9 +70,7 @@ async def v1_authorize( if not api_key: raise HTTPException(status_code=401, detail="Missing API key") - api_key_record = ( - await db.exec(select(V1APIKeys).where(V1APIKeys.key == api_key)) - ).first() + api_key_record = (await db.exec(select(V1APIKeys).where(V1APIKeys.key == api_key))).first() if not api_key_record: raise HTTPException(status_code=401, detail="Invalid API key") @@ -98,9 +96,7 @@ async def get_current_user( security_scopes: SecurityScopes, token_pw: Annotated[str | None, Depends(oauth2_password)] = None, token_code: Annotated[str | None, Depends(oauth2_code)] = None, - token_client_credentials: Annotated[ - str | None, Depends(oauth2_client_credentials) - ] = None, + token_client_credentials: Annotated[str | None, Depends(oauth2_client_credentials)] = None, ) -> User: """获取当前认证用户""" token = token_pw or token_code or token_client_credentials @@ -119,9 +115,7 @@ async def get_current_user( if not is_client: for scope in security_scopes.scopes: if scope not in token_record.scope.split(","): - raise HTTPException( - status_code=403, detail=f"Insufficient scope: {scope}" - ) + raise HTTPException(status_code=403, detail=f"Insufficient scope: {scope}") user = (await db.exec(select(User).where(User.id == token_record.user_id))).first() if not user: diff --git a/app/fetcher/_base.py b/app/fetcher/_base.py index 97d16db..edf9041 100644 --- a/app/fetcher/_base.py +++ b/app/fetcher/_base.py @@ -121,14 +121,10 @@ class BaseFetcher: except Exception as e: last_error = e if attempt < max_retries: - logger.warning( - f"Request failed (attempt {attempt + 1}/{max_retries + 1}): {e}, retrying..." - ) + logger.warning(f"Request failed (attempt {attempt + 1}/{max_retries + 1}): {e}, retrying...") continue else: - logger.error( - f"Request failed after {max_retries + 1} attempts: {e}" - ) + logger.error(f"Request failed after {max_retries + 1} attempts: {e}") break # 如果所有重试都失败了 @@ -196,13 +192,9 @@ class BaseFetcher: f"fetcher:refresh_token:{self.client_id}", self.refresh_token, ) - logger.info( - f"Successfully refreshed access token for client {self.client_id}" - ) + logger.info(f"Successfully refreshed access token for client {self.client_id}") except Exception as e: - logger.error( - f"Failed to refresh access token for client {self.client_id}: {e}" - ) + logger.error(f"Failed to refresh access token for client {self.client_id}: {e}") # 清除无效的 token,要求重新授权 self.access_token = "" self.refresh_token = "" @@ -210,9 +202,7 @@ class BaseFetcher: redis = get_redis() await redis.delete(f"fetcher:access_token:{self.client_id}") await redis.delete(f"fetcher:refresh_token:{self.client_id}") - logger.warning( - f"Cleared invalid tokens. Please re-authorize: {self.authorize_url}" - ) + logger.warning(f"Cleared invalid tokens. Please re-authorize: {self.authorize_url}") raise async def _trigger_reauthorization(self) -> None: @@ -237,8 +227,7 @@ class BaseFetcher: await redis.delete(f"fetcher:refresh_token:{self.client_id}") logger.warning( - f"All tokens cleared for client {self.client_id}. " - f"Please re-authorize using: {self.authorize_url}" + f"All tokens cleared for client {self.client_id}. Please re-authorize using: {self.authorize_url}" ) def reset_auth_retry_count(self) -> None: diff --git a/app/fetcher/beatmap.py b/app/fetcher/beatmap.py index cf68fbe..fa49cf4 100644 --- a/app/fetcher/beatmap.py +++ b/app/fetcher/beatmap.py @@ -7,18 +7,14 @@ from ._base import BaseFetcher class BeatmapFetcher(BaseFetcher): - async def get_beatmap( - self, beatmap_id: int | None = None, beatmap_checksum: str | None = None - ) -> BeatmapResp: + async def get_beatmap(self, beatmap_id: int | None = None, beatmap_checksum: str | None = None) -> BeatmapResp: if beatmap_id: params = {"id": beatmap_id} elif beatmap_checksum: params = {"checksum": beatmap_checksum} else: raise ValueError("Either beatmap_id or beatmap_checksum must be provided.") - logger.opt(colors=True).debug( - f"[BeatmapFetcher] get_beatmap: {params}" - ) + logger.opt(colors=True).debug(f"[BeatmapFetcher] get_beatmap: {params}") return BeatmapResp.model_validate( await self.request_api( diff --git a/app/fetcher/beatmap_raw.py b/app/fetcher/beatmap_raw.py index 25e9152..ccb19b8 100644 --- a/app/fetcher/beatmap_raw.py +++ b/app/fetcher/beatmap_raw.py @@ -18,9 +18,7 @@ class BeatmapRawFetcher(BaseFetcher): async def get_beatmap_raw(self, beatmap_id: int) -> str: for url in urls: req_url = url.format(beatmap_id=beatmap_id) - logger.opt(colors=True).debug( - f"[BeatmapRawFetcher] get_beatmap_raw: {req_url}" - ) + logger.opt(colors=True).debug(f"[BeatmapRawFetcher] get_beatmap_raw: {req_url}") resp = await self._request(req_url) if resp.status_code >= 400: continue @@ -34,9 +32,7 @@ class BeatmapRawFetcher(BaseFetcher): ) return response - async def get_or_fetch_beatmap_raw( - self, redis: redis.Redis, beatmap_id: int - ) -> str: + async def get_or_fetch_beatmap_raw(self, redis: redis.Redis, beatmap_id: int) -> str: from app.config import settings cache_key = f"beatmap:{beatmap_id}:raw" @@ -48,7 +44,7 @@ class BeatmapRawFetcher(BaseFetcher): if content: # 延长缓存时间 await redis.expire(cache_key, cache_expire) - return content # pyright: ignore[reportReturnType] + return content # 获取并缓存 raw = await self.get_beatmap_raw(beatmap_id) diff --git a/app/fetcher/beatmapset.py b/app/fetcher/beatmapset.py index 9f3c025..641e1e6 100644 --- a/app/fetcher/beatmapset.py +++ b/app/fetcher/beatmapset.py @@ -10,6 +10,7 @@ from app.helpers.rate_limiter import osu_api_rate_limiter from app.log import logger from app.models.beatmap import SearchQueryModel from app.models.model import Cursor +from app.utils import bg_tasks from ._base import BaseFetcher @@ -81,9 +82,7 @@ class BeatmapsetFetcher(BaseFetcher): cache_json = json.dumps(cache_data, sort_keys=True, separators=(",", ":")) cache_hash = hashlib.md5(cache_json.encode()).hexdigest() - logger.opt(colors=True).debug( - f"[CacheKey] Query: {cache_data}, Hash: {cache_hash}" - ) + logger.opt(colors=True).debug(f"[CacheKey] Query: {cache_data}, Hash: {cache_hash}") return f"beatmapset:search:{cache_hash}" @@ -103,22 +102,16 @@ class BeatmapsetFetcher(BaseFetcher): return {} async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetResp: - logger.opt(colors=True).debug( - f"[BeatmapsetFetcher] get_beatmapset: {beatmap_set_id}" - ) + logger.opt(colors=True).debug(f"[BeatmapsetFetcher] get_beatmapset: {beatmap_set_id}") return BeatmapsetResp.model_validate( - await self.request_api( - f"https://osu.ppy.sh/api/v2/beatmapsets/{beatmap_set_id}" - ) + await self.request_api(f"https://osu.ppy.sh/api/v2/beatmapsets/{beatmap_set_id}") ) async def search_beatmapset( self, query: SearchQueryModel, cursor: Cursor, redis_client: redis.Redis ) -> SearchBeatmapsetsResp: - logger.opt(colors=True).debug( - f"[BeatmapsetFetcher] search_beatmapset: {query}" - ) + logger.opt(colors=True).debug(f"[BeatmapsetFetcher] search_beatmapset: {query}") # 生成缓存键 cache_key = self._generate_cache_key(query, cursor) @@ -126,9 +119,7 @@ class BeatmapsetFetcher(BaseFetcher): # 尝试从缓存获取结果 cached_result = await redis_client.get(cache_key) if cached_result: - logger.opt(colors=True).debug( - f"[BeatmapsetFetcher] Cache hit for key: {cache_key}" - ) + logger.opt(colors=True).debug(f"[BeatmapsetFetcher] Cache hit for key: {cache_key}") try: cached_data = json.loads(cached_result) return SearchBeatmapsetsResp.model_validate(cached_data) @@ -138,13 +129,9 @@ class BeatmapsetFetcher(BaseFetcher): ) # 缓存未命中,从 API 获取数据 - logger.opt(colors=True).debug( - "[BeatmapsetFetcher] Cache miss, fetching from API" - ) + logger.opt(colors=True).debug("[BeatmapsetFetcher] Cache miss, fetching from API") - params = query.model_dump( - exclude_none=True, exclude_unset=True, exclude_defaults=True - ) + params = query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True) if query.cursor_string: params["cursor_string"] = query.cursor_string @@ -164,39 +151,26 @@ class BeatmapsetFetcher(BaseFetcher): # 将结果缓存 15 分钟 cache_ttl = 15 * 60 # 15 分钟 - await redis_client.set( - cache_key, json.dumps(api_response, separators=(",", ":")), ex=cache_ttl - ) + await redis_client.set(cache_key, json.dumps(api_response, separators=(",", ":")), ex=cache_ttl) logger.opt(colors=True).debug( - f"[BeatmapsetFetcher] Cached result for key: " - f"{cache_key} (TTL: {cache_ttl}s)" + f"[BeatmapsetFetcher] Cached result for key: {cache_key} (TTL: {cache_ttl}s)" ) resp = SearchBeatmapsetsResp.model_validate(api_response) # 智能预取:只在用户明确搜索时才预取,避免过多API请求 # 且只在有搜索词或特定条件时预取,避免首页浏览时的过度预取 - if api_response.get("cursor") and ( - query.q or query.s != "leaderboard" or cursor - ): + if api_response.get("cursor") and (query.q or query.s != "leaderboard" or cursor): # 在后台预取下1页(减少预取量) import asyncio # 不立即创建任务,而是延迟一段时间再预取 async def delayed_prefetch(): await asyncio.sleep(3.0) # 延迟3秒 - await self.prefetch_next_pages( - query, api_response["cursor"], redis_client, pages=1 - ) + await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=1) - # 创建延迟预取任务 - task = asyncio.create_task(delayed_prefetch()) - # 添加到后台任务集合避免被垃圾回收 - if not hasattr(self, "_background_tasks"): - self._background_tasks = set() - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) + bg_tasks.add_task(delayed_prefetch) return resp @@ -218,18 +192,14 @@ class BeatmapsetFetcher(BaseFetcher): # 使用当前 cursor 请求下一页 next_query = query.model_copy() - logger.opt(colors=True).debug( - f"[BeatmapsetFetcher] Prefetching page {page + 1}" - ) + logger.opt(colors=True).debug(f"[BeatmapsetFetcher] Prefetching page {page + 1}") # 生成下一页的缓存键 next_cache_key = self._generate_cache_key(next_query, cursor) # 检查是否已经缓存 if await redis_client.exists(next_cache_key): - logger.opt(colors=True).debug( - f"[BeatmapsetFetcher] Page {page + 1} already cached" - ) + logger.opt(colors=True).debug(f"[BeatmapsetFetcher] Page {page + 1} already cached") # 尝试从缓存获取cursor继续预取 cached_data = await redis_client.get(next_cache_key) if cached_data: @@ -247,9 +217,7 @@ class BeatmapsetFetcher(BaseFetcher): await asyncio.sleep(1.5) # 1.5秒延迟 # 请求下一页数据 - params = next_query.model_dump( - exclude_none=True, exclude_unset=True, exclude_defaults=True - ) + params = next_query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True) for k, v in cursor.items(): params[f"cursor[{k}]"] = v @@ -277,22 +245,18 @@ class BeatmapsetFetcher(BaseFetcher): ) logger.opt(colors=True).debug( - f"[BeatmapsetFetcher] Prefetched page {page + 1} " - f"(TTL: {prefetch_ttl}s)" + f"[BeatmapsetFetcher] Prefetched page {page + 1} (TTL: {prefetch_ttl}s)" ) except Exception as e: - logger.opt(colors=True).warning( - f"[BeatmapsetFetcher] Prefetch failed: {e}" - ) + logger.opt(colors=True).warning(f"[BeatmapsetFetcher] Prefetch failed: {e}") async def warmup_homepage_cache(self, redis_client: redis.Redis) -> None: """预热主页缓存""" homepage_queries = self._get_homepage_queries() logger.opt(colors=True).info( - f"[BeatmapsetFetcher] Starting homepage cache warmup " - f"({len(homepage_queries)} queries)" + f"[BeatmapsetFetcher] Starting homepage cache warmup ({len(homepage_queries)} queries)" ) for i, (query, cursor) in enumerate(homepage_queries): @@ -306,15 +270,12 @@ class BeatmapsetFetcher(BaseFetcher): # 检查是否已经缓存 if await redis_client.exists(cache_key): logger.opt(colors=True).debug( - f"[BeatmapsetFetcher] " - f"Query {query.sort} already cached" + f"[BeatmapsetFetcher] Query {query.sort} already cached" ) continue # 请求并缓存 - params = query.model_dump( - exclude_none=True, exclude_unset=True, exclude_defaults=True - ) + params = query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True) api_response = await self.request_api( "https://osu.ppy.sh/api/v2/beatmapsets/search", @@ -334,17 +295,13 @@ class BeatmapsetFetcher(BaseFetcher): ) logger.opt(colors=True).info( - f"[BeatmapsetFetcher] " - f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)" + f"[BeatmapsetFetcher] Warmed up cache for {query.sort} (TTL: {cache_ttl}s)" ) if api_response.get("cursor"): - await self.prefetch_next_pages( - query, api_response["cursor"], redis_client, pages=2 - ) + await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=2) except Exception as e: logger.opt(colors=True).error( - f"[BeatmapsetFetcher] " - f"Failed to warmup cache for {query.sort}: {e}" + f"[BeatmapsetFetcher] Failed to warmup cache for {query.sort}: {e}" ) diff --git a/app/helpers/geoip_helper.py b/app/helpers/geoip_helper.py index b30b39d..3d22927 100644 --- a/app/helpers/geoip_helper.py +++ b/app/helpers/geoip_helper.py @@ -55,14 +55,9 @@ class GeoIPHelper: - 临时目录退出后自动清理 """ if not self.license_key: - raise ValueError( - "缺少 MaxMind License Key,请传入或设置环境变量 MAXMIND_LICENSE_KEY" - ) + raise ValueError("缺少 MaxMind License Key,请传入或设置环境变量 MAXMIND_LICENSE_KEY") - url = ( - f"{BASE_URL}?edition_id={edition_id}&" - f"license_key={self.license_key}&suffix=tar.gz" - ) + url = f"{BASE_URL}?edition_id={edition_id}&license_key={self.license_key}&suffix=tar.gz" with httpx.Client(follow_redirects=True, timeout=self.timeout) as client: with client.stream("GET", url) as resp: diff --git a/app/helpers/rate_limiter.py b/app/helpers/rate_limiter.py index 40efc9b..0002e5c 100644 --- a/app/helpers/rate_limiter.py +++ b/app/helpers/rate_limiter.py @@ -48,8 +48,7 @@ class RateLimiter: if wait_time > 0: logger.opt(colors=True).info( - f"[RateLimiter] Rate limit reached, " - f"waiting {wait_time:.2f}s" + f"[RateLimiter] Rate limit reached, waiting {wait_time:.2f}s" ) await asyncio.sleep(wait_time) current_time = time.time() @@ -107,11 +106,7 @@ class RateLimiter: "max_requests_per_minute": self.max_requests_per_minute, "burst_requests": len(self.burst_times), "burst_limit": self.burst_limit, - "next_reset_in_seconds": ( - 60.0 - (current_time - self.request_times[0]) - if self.request_times - else 0.0 - ), + "next_reset_in_seconds": (60.0 - (current_time - self.request_times[0]) if self.request_times else 0.0), } diff --git a/app/log.py b/app/log.py index 4b29f17..9186f49 100644 --- a/app/log.py +++ b/app/log.py @@ -46,14 +46,10 @@ class InterceptHandler(logging.Handler): color = True else: color = False - logger.opt(depth=depth, exception=record.exc_info, colors=color).log( - level, message - ) + logger.opt(depth=depth, exception=record.exc_info, colors=color).log(level, message) def _format_uvicorn_error_log(self, message: str) -> str: - websocket_pattern = ( - r'(\d+\.\d+\.\d+\.\d+:\d+)\s*-\s*"WebSocket\s+([^"]+)"\s+([\w\[\]]+)' - ) + websocket_pattern = r'(\d+\.\d+\.\d+\.\d+:\d+)\s*-\s*"WebSocket\s+([^"]+)"\s+([\w\[\]]+)' websocket_match = re.search(websocket_pattern, message) if websocket_match: @@ -64,14 +60,8 @@ class InterceptHandler(logging.Handler): "[accepted]": "[accepted]", "403": "403 [rejected]", } - colored_status = status_colors.get( - status.lower(), f"{status}" - ) - return ( - f'{colored_ip} - "WebSocket ' - f'{path}" ' - f"{colored_status}" - ) + colored_status = status_colors.get(status.lower(), f"{status}") + return f'{colored_ip} - "WebSocket {path}" {colored_status}' else: return message @@ -121,9 +111,7 @@ logger.remove() logger.add( stdout, colorize=True, - format=( - "{time:YYYY-MM-DD HH:mm:ss} [{level}] | {message}" - ), + format=("{time:YYYY-MM-DD HH:mm:ss} [{level}] | {message}"), level=settings.log_level, diagnose=settings.debug, ) diff --git a/app/models/achievement.py b/app/models/achievement.py index 6f14dd9..7b37155 100644 --- a/app/models/achievement.py +++ b/app/models/achievement.py @@ -19,17 +19,11 @@ class Achievement(NamedTuple): @property def url(self) -> str: - return ( - self.medal_url - or f"https://assets.ppy.sh/medals/client/{self.assets_id}.png" - ) + return self.medal_url or f"https://assets.ppy.sh/medals/client/{self.assets_id}.png" @property def url2x(self) -> str: - return ( - self.medal_url2x - or f"https://assets.ppy.sh/medals/client/{self.assets_id}@2x.png" - ) + return self.medal_url2x or f"https://assets.ppy.sh/medals/client/{self.assets_id}@2x.png" MedalProcessor = Callable[[AsyncSession, "Score", "Beatmap"], Awaitable[bool]] diff --git a/app/models/api_me.py b/app/models/api_me.py index 8e632e5..dab1256 100644 --- a/app/models/api_me.py +++ b/app/models/api_me.py @@ -11,7 +11,8 @@ class APIMe(UserResp): """ /me 端点的响应模型 对应 osu! 的 APIMe 类型,继承 APIUser(UserResp) 并包含 session_verified 字段 - + session_verified 字段已经在 UserResp 中定义,这里不需要重复定义 """ + pass diff --git a/app/models/beatmap.py b/app/models/beatmap.py index 69304b6..2e58857 100644 --- a/app/models/beatmap.py +++ b/app/models/beatmap.py @@ -95,11 +95,7 @@ class SearchQueryModel(BaseModel): q: str = Field("", description="搜索关键词") c: Annotated[ - list[ - Literal[ - "recommended", "converts", "follows", "spotlights", "featured_artists" - ] - ], + list[Literal["recommended", "converts", "follows", "spotlights", "featured_artists"]], BeforeValidator(_parse_list), PlainSerializer(lambda x: ".".join(x)), ] = Field( @@ -188,12 +184,10 @@ class SearchQueryModel(BaseModel): list[Literal["video", "storyboard"]], BeforeValidator(_parse_list), PlainSerializer(lambda x: ".".join(x)), - ] = Field( - default_factory=list, description=("其他:video 有视频 / storyboard 有故事板") + ] = Field(default_factory=list, description=("其他:video 有视频 / storyboard 有故事板")) + r: Annotated[list[Rank], BeforeValidator(_parse_list), PlainSerializer(lambda x: ".".join(x))] = Field( + default_factory=list, description="成绩" ) - r: Annotated[ - list[Rank], BeforeValidator(_parse_list), PlainSerializer(lambda x: ".".join(x)) - ] = Field(default_factory=list, description="成绩") played: bool = Field( default=False, description="玩过", diff --git a/app/models/extended_auth.py b/app/models/extended_auth.py index e98ea7d..b3fc831 100644 --- a/app/models/extended_auth.py +++ b/app/models/extended_auth.py @@ -9,12 +9,13 @@ from pydantic import BaseModel class ExtendedTokenResponse(BaseModel): """扩展的令牌响应,支持二次验证状态""" + access_token: str | None = None token_type: str = "Bearer" expires_in: int | None = None refresh_token: str | None = None scope: str | None = None - + # 二次验证相关字段 requires_second_factor: bool = False verification_message: str | None = None @@ -23,6 +24,7 @@ class ExtendedTokenResponse(BaseModel): class SessionState(BaseModel): """会话状态""" + user_id: int username: str email: str diff --git a/app/models/metadata_hub.py b/app/models/metadata_hub.py index 8bf237d..7188235 100644 --- a/app/models/metadata_hub.py +++ b/app/models/metadata_hub.py @@ -145,9 +145,7 @@ class MultiplayerPlaylistItemStats(BaseModel): class MultiplayerRoomStats(BaseModel): room_id: int - playlist_item_stats: dict[int, MultiplayerPlaylistItemStats] = Field( - default_factory=dict - ) + playlist_item_stats: dict[int, MultiplayerPlaylistItemStats] = Field(default_factory=dict) class MultiplayerRoomScoreSetEvent(BaseModel): diff --git a/app/models/mods.py b/app/models/mods.py index f407500..f168fae 100644 --- a/app/models/mods.py +++ b/app/models/mods.py @@ -174,11 +174,7 @@ def mods_can_get_pp(ruleset_id: int, mods: list[APIMod]) -> bool: return True ranked_mods = RANKED_MODS[ruleset_id] for mod in mods: - if ( - app_settings.enable_rx - and mod["acronym"] == "RX" - and ruleset_id in {0, 1, 2} - ): + if app_settings.enable_rx and mod["acronym"] == "RX" and ruleset_id in {0, 1, 2}: continue if app_settings.enable_ap and mod["acronym"] == "AP" and ruleset_id == 0: continue @@ -251,10 +247,7 @@ def get_available_mods(ruleset_id: int, required_mods: list[APIMod]) -> list[API if mod_acronym in incompatible_mods: continue - if any( - required_acronym in mod_data["IncompatibleMods"] - for required_acronym in required_mod_acronyms - ): + if any(required_acronym in mod_data["IncompatibleMods"] for required_acronym in required_mod_acronyms): continue if mod_data.get("UserPlayable", False): diff --git a/app/models/multiplayer_hub.py b/app/models/multiplayer_hub.py index adc3f00..13294b8 100644 --- a/app/models/multiplayer_hub.py +++ b/app/models/multiplayer_hub.py @@ -121,32 +121,21 @@ class PlaylistItem(BaseModel): star_rating: float freestyle: bool - def _validate_mod_for_ruleset( - self, mod: APIMod, ruleset_key: int, context: str = "mod" - ) -> None: + def _validate_mod_for_ruleset(self, mod: APIMod, ruleset_key: int, context: str = "mod") -> None: typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key) # Check if mod is valid for ruleset - if ( - typed_ruleset_key not in API_MODS - or mod["acronym"] not in API_MODS[typed_ruleset_key] - ): - raise InvokeException( - f"{context} {mod['acronym']} is invalid for this ruleset" - ) + if typed_ruleset_key not in API_MODS or mod["acronym"] not in API_MODS[typed_ruleset_key]: + raise InvokeException(f"{context} {mod['acronym']} is invalid for this ruleset") mod_settings = API_MODS[typed_ruleset_key][mod["acronym"]] # Check if mod is unplayable in multiplayer if mod_settings.get("UserPlayable", True) is False: - raise InvokeException( - f"{context} {mod['acronym']} is not playable by users" - ) + raise InvokeException(f"{context} {mod['acronym']} is not playable by users") if mod_settings.get("ValidForMultiplayer", True) is False: - raise InvokeException( - f"{context} {mod['acronym']} is not valid for multiplayer" - ) + raise InvokeException(f"{context} {mod['acronym']} is not valid for multiplayer") def _check_mod_compatibility(self, mods: list[APIMod], ruleset_key: int) -> None: from typing import Literal, cast @@ -159,10 +148,7 @@ class PlaylistItem(BaseModel): incompatible = set(mod1_settings.get("IncompatibleMods", [])) for mod2 in mods[i + 1 :]: if mod2["acronym"] in incompatible: - raise InvokeException( - f"Mods {mod1['acronym']} and " - f"{mod2['acronym']} are incompatible" - ) + raise InvokeException(f"Mods {mod1['acronym']} and {mod2['acronym']} are incompatible") def _check_required_allowed_compatibility(self, ruleset_key: int) -> None: from typing import Literal, cast @@ -178,10 +164,7 @@ class PlaylistItem(BaseModel): conflicting_allowed = allowed_acronyms & incompatible if conflicting_allowed: conflict_list = ", ".join(conflicting_allowed) - raise InvokeException( - f"Required mod {req_acronym} conflicts with " - f"allowed mods: {conflict_list}" - ) + raise InvokeException(f"Required mod {req_acronym} conflicts with allowed mods: {conflict_list}") def validate_playlist_item_mods(self) -> None: ruleset_key = cast(Literal[0, 1, 2, 3], self.ruleset_id) @@ -219,10 +202,7 @@ class PlaylistItem(BaseModel): # Check if mods are valid for the ruleset for mod in proposed_mods: - if ( - ruleset_key not in API_MODS - or mod["acronym"] not in API_MODS[ruleset_key] - ): + if ruleset_key not in API_MODS or mod["acronym"] not in API_MODS[ruleset_key]: all_proposed_valid = False continue valid_mods.append(mod) @@ -252,9 +232,7 @@ class PlaylistItem(BaseModel): # Check compatibility with required mods required_mod_acronyms = {mod["acronym"] for mod in self.required_mods} - all_mod_acronyms = { - mod["acronym"] for mod in final_valid_mods - } | required_mod_acronyms + all_mod_acronyms = {mod["acronym"] for mod in final_valid_mods} | required_mod_acronyms # Check for incompatibility between required and user mods filtered_valid_mods = [] @@ -288,9 +266,7 @@ class PlaylistItem(BaseModel): class _MultiplayerCountdown(SignalRUnionMessage): id: int = 0 time_remaining: timedelta - is_exclusive: Annotated[ - bool, Field(default=True), SignalRMeta(member_ignore=True) - ] = True + is_exclusive: Annotated[bool, Field(default=True), SignalRMeta(member_ignore=True)] = True class MatchStartCountdown(_MultiplayerCountdown): @@ -305,17 +281,13 @@ class ServerShuttingDownCountdown(_MultiplayerCountdown): union_type: ClassVar[Literal[2]] = 2 -MultiplayerCountdown = ( - MatchStartCountdown | ForceGameplayStartCountdown | ServerShuttingDownCountdown -) +MultiplayerCountdown = MatchStartCountdown | ForceGameplayStartCountdown | ServerShuttingDownCountdown class MultiplayerRoomUser(BaseModel): user_id: int state: MultiplayerUserState = MultiplayerUserState.IDLE - availability: BeatmapAvailability = BeatmapAvailability( - state=DownloadState.UNKNOWN, download_progress=None - ) + availability: BeatmapAvailability = BeatmapAvailability(state=DownloadState.UNKNOWN, download_progress=None) mods: list[APIMod] = Field(default_factory=list) match_state: MatchUserState | None = None ruleset_id: int | None = None # freestyle @@ -358,9 +330,7 @@ class MultiplayerRoom(BaseModel): expired=item.expired, playlist_order=item.playlist_order, played_at=item.played_at, - star_rating=item.beatmap.difficulty_rating - if item.beatmap is not None - else 0.0, + star_rating=item.beatmap.difficulty_rating if item.beatmap is not None else 0.0, freestyle=item.freestyle, ) ) @@ -425,9 +395,7 @@ class MultiplayerQueue: user_item_groups[item.owner_id] = [] user_item_groups[item.owner_id].append(item) - max_items = max( - (len(items) for items in user_item_groups.values()), default=0 - ) + max_items = max((len(items) for items in user_item_groups.values()), default=0) for i in range(max_items): current_set = [] @@ -436,20 +404,13 @@ class MultiplayerQueue: current_set.append(items[i]) if is_first_set: - current_set.sort( - key=lambda item: (item.playlist_order, item.id) - ) + current_set.sort(key=lambda item: (item.playlist_order, item.id)) ordered_active_items.extend(current_set) first_set_order_by_user_id = { - item.owner_id: idx - for idx, item in enumerate(ordered_active_items) + item.owner_id: idx for idx, item in enumerate(ordered_active_items) } else: - current_set.sort( - key=lambda item: first_set_order_by_user_id.get( - item.owner_id, 0 - ) - ) + current_set.sort(key=lambda item: first_set_order_by_user_id.get(item.owner_id, 0)) ordered_active_items.extend(current_set) is_first_set = False @@ -464,9 +425,7 @@ class MultiplayerQueue: continue item.playlist_order = idx await Playlist.update(item, self.room.room_id, session) - await self.hub.playlist_changed( - self.server_room, item, beatmap_changed=False - ) + await self.hub.playlist_changed(self.server_room, item, beatmap_changed=False) async def update_current_item(self): upcoming_items = self.upcoming_items @@ -494,16 +453,7 @@ class MultiplayerQueue: raise InvokeException("You are not the host") limit = HOST_LIMIT if is_host else PER_USER_LIMIT - if ( - len( - [ - True - for u in self.room.playlist - if u.owner_id == user.user_id and not u.expired - ] - ) - >= limit - ): + if len([True for u in self.room.playlist if u.owner_id == user.user_id and not u.expired]) >= limit: raise InvokeException(f"You can only have {limit} items in the queue") if item.freestyle and len(item.allowed_mods) > 0: @@ -512,9 +462,7 @@ class MultiplayerQueue: async with with_db() as session: fetcher = await get_fetcher() async with session: - beatmap = await Beatmap.get_or_fetch( - session, fetcher, bid=item.beatmap_id - ) + beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=item.beatmap_id) if beatmap is None: raise InvokeException("Beatmap not found") if item.beatmap_checksum != beatmap.checksum: @@ -538,29 +486,19 @@ class MultiplayerQueue: async with with_db() as session: fetcher = await get_fetcher() async with session: - beatmap = await Beatmap.get_or_fetch( - session, fetcher, bid=item.beatmap_id - ) + beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=item.beatmap_id) if item.beatmap_checksum != beatmap.checksum: raise InvokeException("Checksum mismatch") - existing_item = next( - (i for i in self.room.playlist if i.id == item.id), None - ) + existing_item = next((i for i in self.room.playlist if i.id == item.id), None) if existing_item is None: - raise InvokeException( - "Attempted to change an item that doesn't exist" - ) + raise InvokeException("Attempted to change an item that doesn't exist") if existing_item.owner_id != user.user_id and self.room.host != user: - raise InvokeException( - "Attempted to change an item which is not owned by the user" - ) + raise InvokeException("Attempted to change an item which is not owned by the user") if existing_item.expired: - raise InvokeException( - "Attempted to change an item which has already been played" - ) + raise InvokeException("Attempted to change an item which has already been played") item.validate_playlist_item_mods() item.owner_id = user.user_id @@ -578,8 +516,7 @@ class MultiplayerQueue: await self.hub.playlist_changed( self.server_room, item, - beatmap_changed=item.beatmap_checksum - != existing_item.beatmap_checksum, + beatmap_changed=item.beatmap_checksum != existing_item.beatmap_checksum, ) async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser): @@ -600,14 +537,10 @@ class MultiplayerQueue: raise InvokeException("The only item in the room cannot be removed") if item.owner_id != user.user_id and self.room.host != user: - raise InvokeException( - "Attempted to remove an item which is not owned by the user" - ) + raise InvokeException("Attempted to remove an item which is not owned by the user") if item.expired: - raise InvokeException( - "Attempted to remove an item which has already been played" - ) + raise InvokeException("Attempted to remove an item which has already been played") async with with_db() as session: await Playlist.delete_item(item.id, self.room.room_id, session) @@ -668,9 +601,7 @@ class CountdownInfo: def __init__(self, countdown: MultiplayerCountdown): self.countdown = countdown self.duration = ( - countdown.time_remaining - if countdown.time_remaining > timedelta(seconds=0) - else timedelta(seconds=0) + countdown.time_remaining if countdown.time_remaining > timedelta(seconds=0) else timedelta(seconds=0) ) @@ -704,9 +635,7 @@ class MatchTypeHandler(ABC): async def handle_join(self, user: MultiplayerRoomUser): ... @abstractmethod - async def handle_request( - self, user: MultiplayerRoomUser, request: MatchRequest - ): ... + async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest): ... @abstractmethod async def handle_leave(self, user: MultiplayerRoomUser): ... @@ -723,9 +652,7 @@ class HeadToHeadHandler(MatchTypeHandler): await self.hub.change_user_match_state(self.room, user) @override - async def handle_request( - self, user: MultiplayerRoomUser, request: MatchRequest - ): ... + async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest): ... @override async def handle_leave(self, user: MultiplayerRoomUser): ... @@ -762,9 +689,7 @@ class TeamVersusHandler(MatchTypeHandler): team_counts = defaultdict(int) for user in self.room.room.users: - if user.match_state is not None and isinstance( - user.match_state, TeamVersusUserState - ): + if user.match_state is not None and isinstance(user.match_state, TeamVersusUserState): team_counts[user.match_state.team_id] += 1 if team_counts: @@ -798,9 +723,7 @@ class TeamVersusHandler(MatchTypeHandler): def get_details(self) -> MatchStartedEventDetail: teams: dict[int, Literal["blue", "red"]] = {} for user in self.room.room.users: - if user.match_state is not None and isinstance( - user.match_state, TeamVersusUserState - ): + if user.match_state is not None and isinstance(user.match_state, TeamVersusUserState): teams[user.user_id] = "blue" if user.match_state.team_id == 1 else "red" detail = MatchStartedEventDetail(room_type="team_versus", team=teams) return detail @@ -843,9 +766,7 @@ class ServerMultiplayerRoom: self._tracked_countdown = {} async def set_handler(self): - self.match_type_handler = MATCH_TYPE_HANDLERS[self.room.settings.match_type]( - self - ) + self.match_type_handler = MATCH_TYPE_HANDLERS[self.room.settings.match_type](self) for i in self.room.users: await self.match_type_handler.handle_join(i) @@ -871,9 +792,7 @@ class ServerMultiplayerRoom: info = CountdownInfo(countdown) self.room.active_countdowns.append(info.countdown) self._tracked_countdown[countdown.id] = info - await self.hub.send_match_event( - self, CountdownStartedEvent(countdown=info.countdown) - ) + await self.hub.send_match_event(self, CountdownStartedEvent(countdown=info.countdown)) info.task = asyncio.create_task(_countdown_task(self)) async def stop_countdown(self, countdown: MultiplayerCountdown): diff --git a/app/models/notification.py b/app/models/notification.py index 199c4fb..dbfcaa3 100644 --- a/app/models/notification.py +++ b/app/models/notification.py @@ -53,7 +53,7 @@ class NotificationName(str, Enum): NotificationName.BEATMAP_OWNER_CHANGE: "beatmap_owner_change", NotificationName.BEATMAPSET_DISCUSSION_LOCK: "beatmapset_discussion", NotificationName.BEATMAPSET_DISCUSSION_POST_NEW: "beatmapset_discussion", - NotificationName.BEATMAPSET_DISCUSSION_QUALIFIED_PROBLEM: "beatmapset_problem", # noqa: E501 + NotificationName.BEATMAPSET_DISCUSSION_QUALIFIED_PROBLEM: "beatmapset_problem", NotificationName.BEATMAPSET_DISCUSSION_REVIEW_NEW: "beatmapset_discussion", NotificationName.BEATMAPSET_DISCUSSION_UNLOCK: "beatmapset_discussion", NotificationName.BEATMAPSET_DISQUALIFY: "beatmapset_state", @@ -164,17 +164,11 @@ class ChannelMessageTeam(ChannelMessageBase): from app.database import TeamMember user_team_id = ( - await session.exec( - select(TeamMember.team_id).where(TeamMember.user_id == self._user.id) - ) + await session.exec(select(TeamMember.team_id).where(TeamMember.user_id == self._user.id)) ).first() if not user_team_id: return [] - user_ids = ( - await session.exec( - select(TeamMember.user_id).where(TeamMember.team_id == user_team_id) - ) - ).all() + user_ids = (await session.exec(select(TeamMember.user_id).where(TeamMember.team_id == user_team_id))).all() return list(user_ids) diff --git a/app/models/score.py b/app/models/score.py index 4e42baf..a6adb53 100644 --- a/app/models/score.py +++ b/app/models/score.py @@ -197,9 +197,7 @@ class SoloScoreSubmissionInfo(BaseModel): # check incompatible mods for mod in mods: if mod["acronym"] in incompatible_mods: - raise ValueError( - f"Mod {mod['acronym']} is incompatible with other mods" - ) + raise ValueError(f"Mod {mod['acronym']} is incompatible with other mods") setting_mods = API_MODS[info.data["ruleset_id"]].get(mod["acronym"]) if not setting_mods: raise ValueError(f"Invalid mod: {mod['acronym']}") diff --git a/app/models/signalr.py b/app/models/signalr.py index ffbaf6b..8a60b26 100644 --- a/app/models/signalr.py +++ b/app/models/signalr.py @@ -22,9 +22,7 @@ class SignalRUnionMessage(BaseModel): class Transport(BaseModel): transport: str - transfer_formats: list[str] = Field( - default_factory=lambda: ["Binary", "Text"], alias="transferFormats" - ) + transfer_formats: list[str] = Field(default_factory=lambda: ["Binary", "Text"], alias="transferFormats") class NegotiateResponse(BaseModel): diff --git a/app/models/spectator_hub.py b/app/models/spectator_hub.py index 9f35932..8a5eb71 100644 --- a/app/models/spectator_hub.py +++ b/app/models/spectator_hub.py @@ -89,9 +89,7 @@ class LegacyReplayFrame(BaseModel): mouse_y: float | None = None button_state: int - header: Annotated[ - FrameHeader | None, Field(default=None), SignalRMeta(member_ignore=True) - ] + header: Annotated[FrameHeader | None, Field(default=None), SignalRMeta(member_ignore=True)] class FrameDataBundle(BaseModel): diff --git a/app/router/auth.py b/app/router/auth.py index db8cfe1..c92ddd7 100644 --- a/app/router/auth.py +++ b/app/router/auth.py @@ -2,7 +2,7 @@ from __future__ import annotations from datetime import UTC, datetime, timedelta import re -from typing import Literal, Union +from typing import Literal from app.auth import ( authenticate_user, @@ -22,19 +22,19 @@ from app.dependencies.database import Database, get_redis from app.dependencies.geoip import get_client_ip, get_geoip_helper from app.helpers.geoip_helper import GeoIPHelper from app.log import logger +from app.models.extended_auth import ExtendedTokenResponse from app.models.oauth import ( OAuthErrorResponse, RegistrationRequestErrors, TokenResponse, UserRegistrationErrors, ) -from app.models.extended_auth import ExtendedTokenResponse from app.models.score import GameMode -from app.service.login_log_service import LoginLogService from app.service.email_verification_service import ( EmailVerificationService, - LoginSessionService + LoginSessionService, ) +from app.service.login_log_service import LoginLogService from app.service.password_reset_service import password_reset_service from fastapi import APIRouter, Depends, Form, Request @@ -44,13 +44,9 @@ from sqlalchemy import text from sqlmodel import select -def create_oauth_error_response( - error: str, description: str, hint: str, status_code: int = 400 -): +def create_oauth_error_response(error: str, description: str, hint: str, status_code: int = 400): """创建标准的 OAuth 错误响应""" - error_data = OAuthErrorResponse( - error=error, error_description=description, hint=hint, message=description - ) + error_data = OAuthErrorResponse(error=error, error_description=description, hint=hint, message=description) return JSONResponse(status_code=status_code, content=error_data.model_dump()) @@ -123,9 +119,7 @@ async def register_user( ) ) - return JSONResponse( - status_code=422, content={"form_error": errors.model_dump()} - ) + return JSONResponse(status_code=422, content={"form_error": errors.model_dump()}) try: # 获取客户端 IP 并查询地理位置 @@ -137,10 +131,7 @@ async def register_user( geo_info = geoip.lookup(client_ip) if geo_info and geo_info.get("country_iso"): country_code = geo_info["country_iso"] - logger.info( - f"User {user_username} registering from " - f"{client_ip}, country: {country_code}" - ) + logger.info(f"User {user_username} registering from {client_ip}, country: {country_code}") else: logger.warning(f"Could not determine country for IP {client_ip}") except Exception as e: @@ -148,7 +139,7 @@ async def register_user( # 创建新用户 # 确保 AUTO_INCREMENT 值从3开始(ID=1是BanchoBot,ID=2预留给ppy) - result = await db.execute( # pyright: ignore[reportDeprecated] + result = await db.execute( text( "SELECT AUTO_INCREMENT FROM information_schema.TABLES " "WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'lazer_users'" @@ -173,7 +164,6 @@ async def register_user( db.add(new_user) await db.commit() await db.refresh(new_user) - assert new_user.id is not None, "New user ID should not be None" for i in [GameMode.OSU, GameMode.TAIKO, GameMode.FRUITS, GameMode.MANIA]: statistics = UserStatistics(mode=i, user_id=new_user.id) db.add(statistics) @@ -193,36 +183,30 @@ async def register_user( logger.exception(f"Registration error for user {user_username}") # 返回通用错误 - errors = RegistrationRequestErrors( - message="An error occurred while creating your account. Please try again." - ) + errors = RegistrationRequestErrors(message="An error occurred while creating your account. Please try again.") - return JSONResponse( - status_code=500, content={"form_error": errors.model_dump()} - ) + return JSONResponse(status_code=500, content={"form_error": errors.model_dump()}) @router.post( "/oauth/token", - response_model=Union[TokenResponse, ExtendedTokenResponse], + response_model=TokenResponse | ExtendedTokenResponse, name="获取访问令牌", description="OAuth 令牌端点,支持密码、刷新令牌和授权码三种授权方式。", ) async def oauth_token( db: Database, request: Request, - grant_type: Literal[ - "authorization_code", "refresh_token", "password", "client_credentials" - ] = Form(..., description="授权类型:密码/刷新令牌/授权码/客户端凭证"), + grant_type: Literal["authorization_code", "refresh_token", "password", "client_credentials"] = Form( + ..., description="授权类型:密码/刷新令牌/授权码/客户端凭证" + ), client_id: int = Form(..., description="客户端 ID"), client_secret: str = Form(..., description="客户端密钥"), code: str | None = Form(None, description="授权码(仅授权码模式需要)"), scope: str = Form("*", description="权限范围(空格分隔,默认为 '*')"), username: str | None = Form(None, description="用户名(仅密码模式需要)"), password: str | None = Form(None, description="密码(仅密码模式需要)"), - refresh_token: str | None = Form( - None, description="刷新令牌(仅刷新令牌模式需要)" - ), + refresh_token: str | None = Form(None, description="刷新令牌(仅刷新令牌模式需要)"), redis: Redis = Depends(get_redis), geoip: GeoIPHelper = Depends(get_geoip_helper), ): @@ -303,37 +287,33 @@ async def oauth_token( await db.refresh(user) # 获取用户信息和客户端信息 - user_id = getattr(user, "id") - assert user_id is not None, "User ID should not be None after authentication" - - from app.dependencies.geoip import get_client_ip + user_id = user.id + ip_address = get_client_ip(request) user_agent = request.headers.get("User-Agent", "") - + # 获取国家代码 geo_info = geoip.lookup(ip_address) country_code = geo_info.get("country_iso", "XX") - + # 检查是否为新位置登录 - is_new_location = await LoginSessionService.check_new_location( - db, user_id, ip_address, country_code - ) - + is_new_location = await LoginSessionService.check_new_location(db, user_id, ip_address, country_code) + # 创建登录会话记录 - login_session = await LoginSessionService.create_session( + login_session = await LoginSessionService.create_session( # noqa: F841 db, redis, user_id, ip_address, user_agent, country_code, is_new_location ) - + # 如果是新位置登录,需要邮件验证 if is_new_location and settings.enable_email_verification: # 刷新用户对象以确保属性已加载 await db.refresh(user) - + # 发送邮件验证码 verification_sent = await EmailVerificationService.send_verification_email( db, redis, user_id, user.username, user.email, ip_address, user_agent ) - + # 记录需要二次验证的登录尝试 await LoginLogService.record_login( db=db, @@ -343,14 +323,16 @@ async def oauth_token( login_method="password_pending_verification", notes=f"新位置登录,需要邮件验证 - IP: {ip_address}, 国家: {country_code}", ) - + if not verification_sent: # 邮件发送失败,记录错误 logger.error(f"[Auth] Failed to send email verification code for user {user_id}") elif is_new_location and not settings.enable_email_verification: # 新位置登录但邮件验证功能被禁用,直接标记会话为已验证 await LoginSessionService.mark_session_verified(db, user_id) - logger.debug(f"[Auth] New location login detected but email verification disabled, auto-verifying user {user_id}") + logger.debug( + f"[Auth] New location login detected but email verification disabled, auto-verifying user {user_id}" + ) else: # 不是新位置登录,正常登录 await LoginLogService.record_login( @@ -361,20 +343,17 @@ async def oauth_token( login_method="password", notes=f"正常登录 - IP: {ip_address}, 国家: {country_code}", ) - + # 无论是否新位置登录,都返回正常的token # session_verified状态通过/me接口的session_verified字段来体现 # 生成令牌 access_token_expires = timedelta(minutes=settings.access_token_expire_minutes) # 获取用户ID,避免触发延迟加载 - access_token = create_access_token( - data={"sub": str(user_id)}, expires_delta=access_token_expires - ) + access_token = create_access_token(data={"sub": str(user_id)}, expires_delta=access_token_expires) refresh_token_str = generate_refresh_token() # 存储令牌 - assert user_id await store_token( db, user_id, @@ -423,9 +402,7 @@ async def oauth_token( # 生成新的访问令牌 access_token_expires = timedelta(minutes=settings.access_token_expire_minutes) - access_token = create_access_token( - data={"sub": str(token_record.user_id)}, expires_delta=access_token_expires - ) + access_token = create_access_token(data={"sub": str(token_record.user_id)}, expires_delta=access_token_expires) new_refresh_token = generate_refresh_token() # 更新令牌 @@ -489,17 +466,11 @@ async def oauth_token( # 生成令牌 access_token_expires = timedelta(minutes=settings.access_token_expire_minutes) - # 重新查询只获取ID,避免触发延迟加载 - id_result = await db.exec(select(User.id).where(User.username == username)) - user_id = id_result.first() - - access_token = create_access_token( - data={"sub": str(user_id)}, expires_delta=access_token_expires - ) + user_id = user.id + access_token = create_access_token(data={"sub": str(user_id)}, expires_delta=access_token_expires) refresh_token_str = generate_refresh_token() # 存储令牌 - assert user_id await store_token( db, user_id, @@ -539,9 +510,7 @@ async def oauth_token( # 生成令牌 access_token_expires = timedelta(minutes=settings.access_token_expire_minutes) - access_token = create_access_token( - data={"sub": "3"}, expires_delta=access_token_expires - ) + access_token = create_access_token(data={"sub": "3"}, expires_delta=access_token_expires) refresh_token_str = generate_refresh_token() # 存储令牌 @@ -567,7 +536,7 @@ async def oauth_token( @router.post( "/password-reset/request", name="请求密码重置", - description="通过邮箱请求密码重置验证码" + description="通过邮箱请求密码重置验证码", ) async def request_password_reset( request: Request, @@ -578,42 +547,26 @@ async def request_password_reset( 请求密码重置 """ from app.dependencies.geoip import get_client_ip - + # 获取客户端信息 ip_address = get_client_ip(request) user_agent = request.headers.get("User-Agent", "") - + # 请求密码重置 success, message = await password_reset_service.request_password_reset( email=email.lower().strip(), ip_address=ip_address, user_agent=user_agent, - redis=redis + redis=redis, ) - + if success: - return JSONResponse( - status_code=200, - content={ - "success": True, - "message": message - } - ) + return JSONResponse(status_code=200, content={"success": True, "message": message}) else: - return JSONResponse( - status_code=400, - content={ - "success": False, - "error": message - } - ) + return JSONResponse(status_code=400, content={"success": False, "error": message}) -@router.post( - "/password-reset/reset", - name="重置密码", - description="使用验证码重置密码" -) +@router.post("/password-reset/reset", name="重置密码", description="使用验证码重置密码") async def reset_password( request: Request, email: str = Form(..., description="邮箱地址"), @@ -625,32 +578,20 @@ async def reset_password( 重置密码 """ from app.dependencies.geoip import get_client_ip - + # 获取客户端信息 ip_address = get_client_ip(request) - + # 重置密码 success, message = await password_reset_service.reset_password( email=email.lower().strip(), reset_code=reset_code.strip(), new_password=new_password, ip_address=ip_address, - redis=redis + redis=redis, ) - + if success: - return JSONResponse( - status_code=200, - content={ - "success": True, - "message": message - } - ) + return JSONResponse(status_code=200, content={"success": True, "message": message}) else: - return JSONResponse( - status_code=400, - content={ - "success": False, - "error": message - } - ) + return JSONResponse(status_code=400, content={"success": False, "error": message}) diff --git a/app/router/notification/__init__.py b/app/router/notification/__init__.py index 206abd8..fef679f 100644 --- a/app/router/notification/__init__.py +++ b/app/router/notification/__init__.py @@ -43,9 +43,9 @@ async def get_notifications( current_user: User = Security(get_client_user), ): if settings.server_url is not None: - notification_endpoint = f"{settings.server_url}notification-server".replace( - "http://", "ws://" - ).replace("https://", "wss://") + notification_endpoint = f"{settings.server_url}notification-server".replace("http://", "ws://").replace( + "https://", "wss://" + ) else: notification_endpoint = "/notification-server" query = select(UserNotification).where( @@ -96,21 +96,15 @@ async def _get_notifications( query = base_query.where(UserNotification.notification_id == identity.id) if identity.object_id is not None: query = base_query.where( - col(UserNotification.notification).has( - col(Notification.object_id) == identity.object_id - ) + col(UserNotification.notification).has(col(Notification.object_id) == identity.object_id) ) if identity.object_type is not None: query = base_query.where( - col(UserNotification.notification).has( - col(Notification.object_type) == identity.object_type - ) + col(UserNotification.notification).has(col(Notification.object_type) == identity.object_type) ) if identity.category is not None: query = base_query.where( - col(UserNotification.notification).has( - col(Notification.category) == identity.category - ) + col(UserNotification.notification).has(col(Notification.category) == identity.category) ) result.update({n.notification_id: n for n in await session.exec(query)}) return list(result.values()) @@ -134,7 +128,6 @@ async def mark_notifications_as_read( for user_notification in user_notifications: user_notification.is_read = True - assert current_user.id await server.send_event( current_user.id, ChatEvent( diff --git a/app/router/notification/banchobot.py b/app/router/notification/banchobot.py index 7aefccc..4463255 100644 --- a/app/router/notification/banchobot.py +++ b/app/router/notification/banchobot.py @@ -91,9 +91,7 @@ class Bot: if reply: await self._send_reply(user, channel, reply, session) - async def _send_message( - self, channel: ChatChannel, content: str, session: AsyncSession - ) -> None: + async def _send_message(self, channel: ChatChannel, content: str, session: AsyncSession) -> None: bot = await session.get(User, self.bot_user_id) if bot is None: return @@ -101,7 +99,6 @@ class Bot: if channel_id is None: return - assert bot.id is not None msg = ChatMessage( channel_id=channel_id, content=content, @@ -115,9 +112,7 @@ class Bot: resp = await ChatMessageResp.from_db(msg, session, bot) await server.send_message_to_channel(resp) - async def _ensure_pm_channel( - self, user: User, session: AsyncSession - ) -> ChatChannel | None: + async def _ensure_pm_channel(self, user: User, session: AsyncSession) -> ChatChannel | None: user_id = user.id if user_id is None: return None @@ -160,9 +155,7 @@ bot = Bot() @bot.command("help") -async def _help( - user: User, args: list[str], _session: AsyncSession, channel: ChatChannel -) -> str: +async def _help(user: User, args: list[str], _session: AsyncSession, channel: ChatChannel) -> str: cmds = sorted(bot._handlers.keys()) if args: target = args[0].lower() @@ -175,9 +168,7 @@ async def _help( @bot.command("roll") -def _roll( - user: User, args: list[str], _session: AsyncSession, channel: ChatChannel -) -> str: +def _roll(user: User, args: list[str], _session: AsyncSession, channel: ChatChannel) -> str: if len(args) > 0 and args[0].isdigit(): r = random.randint(1, int(args[0])) else: @@ -186,13 +177,9 @@ def _roll( @bot.command("stats") -async def _stats( - user: User, args: list[str], session: AsyncSession, channel: ChatChannel -) -> str: +async def _stats(user: User, args: list[str], session: AsyncSession, channel: ChatChannel) -> str: if len(args) >= 1: - target_user = ( - await session.exec(select(User).where(User.username == args[0])) - ).first() + target_user = (await session.exec(select(User).where(User.username == args[0]))).first() if not target_user: return f"User '{args[0]}' not found." else: @@ -202,14 +189,8 @@ async def _stats( if len(args) >= 2: gamemode = GameMode.parse(args[1].upper()) if gamemode is None: - subquery = ( - select(func.max(Score.id)) - .where(Score.user_id == target_user.id) - .scalar_subquery() - ) - last_score = ( - await session.exec(select(Score).where(Score.id == subquery)) - ).first() + subquery = select(func.max(Score.id)).where(Score.user_id == target_user.id).scalar_subquery() + last_score = (await session.exec(select(Score).where(Score.id == subquery))).first() if last_score is not None: gamemode = last_score.gamemode else: @@ -295,9 +276,7 @@ async def _mp_host( return "Usage: !mp host " username = args[0] - user_id = ( - await session.exec(select(User.id).where(User.username == username)) - ).first() + user_id = (await session.exec(select(User.id).where(User.username == username))).first() if not user_id: return f"User '{username}' not found." @@ -362,24 +341,18 @@ async def _mp_team( if team is None: return "Invalid team colour. Use 'red' or 'blue'." - user_id = ( - await session.exec(select(User.id).where(User.username == username)) - ).first() + user_id = (await session.exec(select(User.id).where(User.username == username))).first() if not user_id: return f"User '{username}' not found." user_client = MultiplayerHubs.get_client_by_id(str(user_id)) if not user_client: return f"User '{username}' is not in the room." - if ( - user_client.user_id != signalr_client.user_id - and room.room.host.user_id != signalr_client.user_id - ): + assert room.room.host + if user_client.user_id != signalr_client.user_id and room.room.host.user_id != signalr_client.user_id: return "You are not allowed to change other users' teams." try: - await MultiplayerHubs.SendMatchRequest( - user_client, ChangeTeamRequest(team_id=team) - ) + await MultiplayerHubs.SendMatchRequest(user_client, ChangeTeamRequest(team_id=team)) return "" except InvokeException as e: return e.message @@ -414,9 +387,7 @@ async def _mp_kick( return "Usage: !mp kick " username = args[0] - user_id = ( - await session.exec(select(User.id).where(User.username == username)) - ).first() + user_id = (await session.exec(select(User.id).where(User.username == username))).first() if not user_id: return f"User '{username}' not found." @@ -456,10 +427,7 @@ async def _mp_map( try: beatmap = await Beatmap.get_or_fetch(session, await get_fetcher(), bid=map_id) if beatmap.mode != GameMode.OSU and playmode and playmode != beatmap.mode: - return ( - f"Cannot convert to {playmode.value}. " - f"Original mode is {beatmap.mode.value}." - ) + return f"Cannot convert to {playmode.value}. Original mode is {beatmap.mode.value}." except HTTPError: return "Beatmap not found" @@ -530,9 +498,7 @@ async def _mp_mods( if freestyle: item.allowed_mods = [] elif freemod: - item.allowed_mods = get_available_mods( - current_item.ruleset_id, required_mods - ) + item.allowed_mods = get_available_mods(current_item.ruleset_id, required_mods) else: item.allowed_mods = allowed_mods item.required_mods = required_mods @@ -601,14 +567,9 @@ async def _score( include_fail: bool = False, gamemode: GameMode | None = None, ) -> str: - q = ( - select(Score) - .where(Score.user_id == user_id) - .order_by(col(Score.id).desc()) - .options(joinedload(Score.beatmap)) - ) + q = select(Score).where(Score.user_id == user_id).order_by(col(Score.id).desc()).options(joinedload(Score.beatmap)) if not include_fail: - q = q.where(Score.passed.is_(True)) + q = q.where(col(Score.passed).is_(True)) if gamemode is not None: q = q.where(Score.gamemode == gamemode) @@ -619,17 +580,13 @@ async def _score( result = f"""{score.beatmap.beatmapset.title} [{score.beatmap.version}] ({score.gamemode.name.lower()}) Played at {score.started_at} {score.pp:.2f}pp {score.accuracy:.2%} {",".join(mod_to_save(score.mods))} {score.rank.name.upper()} -Great: {score.n300}, Good: {score.n100}, Meh: {score.n50}, Miss: {score.nmiss}""" # noqa: E501 +Great: {score.n300}, Good: {score.n100}, Meh: {score.n50}, Miss: {score.nmiss}""" if score.gamemode == GameMode.MANIA: - keys = next( - (mod["acronym"] for mod in score.mods if mod["acronym"].endswith("K")), None - ) + keys = next((mod["acronym"] for mod in score.mods if mod["acronym"].endswith("K")), None) if keys is None: keys = f"{int(score.beatmap.cs)}K" p_d_g = f"{score.ngeki / score.n300:.2f}:1" if score.n300 > 0 else "inf:1" - result += ( - f"\nKeys: {keys}, Perfect: {score.ngeki}, Ok: {score.nkatu}, P/G: {p_d_g}" - ) + result += f"\nKeys: {keys}, Perfect: {score.ngeki}, Ok: {score.nkatu}, P/G: {p_d_g}" return result diff --git a/app/router/notification/channel.py b/app/router/notification/channel.py index f2018dc..9f76d7e 100644 --- a/app/router/notification/channel.py +++ b/app/router/notification/channel.py @@ -38,27 +38,18 @@ class UpdateResponse(BaseModel): ) async def get_update( session: Database, - history_since: int | None = Query( - None, description="获取自此禁言 ID 之后的禁言记录" - ), + history_since: int | None = Query(None, description="获取自此禁言 ID 之后的禁言记录"), since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"), - includes: list[str] = Query( - ["presence", "silences"], alias="includes[]", description="要包含的更新类型" - ), + includes: list[str] = Query(["presence", "silences"], alias="includes[]", description="要包含的更新类型"), current_user: User = Security(get_current_user, scopes=["chat.read"]), redis: Redis = Depends(get_redis), ): resp = UpdateResponse() if "presence" in includes: - assert current_user.id channel_ids = server.get_user_joined_channel(current_user.id) for channel_id in channel_ids: # 使用明确的查询避免延迟加载 - db_channel = ( - await session.exec( - select(ChatChannel).where(ChatChannel.channel_id == channel_id) - ) - ).first() + db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first() if db_channel: # 提取必要的属性避免惰性加载 channel_type = db_channel.type @@ -69,34 +60,20 @@ async def get_update( session, current_user, redis, - server.channels.get(channel_id, []) - if channel_type != ChannelType.PUBLIC - else None, + server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None, ) ) if "silences" in includes: if history_since: - silences = ( - await session.exec( - select(SilenceUser).where(col(SilenceUser.id) > history_since) - ) - ).all() - resp.silences.extend( - [UserSilenceResp.from_db(silence) for silence in silences] - ) + silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.id) > history_since))).all() + resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences]) elif since: msg = await session.get(ChatMessage, since) if msg: silences = ( - await session.exec( - select(SilenceUser).where( - col(SilenceUser.banned_at) > msg.timestamp - ) - ) + await session.exec(select(SilenceUser).where(col(SilenceUser.banned_at) > msg.timestamp)) ).all() - resp.silences.extend( - [UserSilenceResp.from_db(silence) for silence in silences] - ) + resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences]) return resp @@ -115,15 +92,9 @@ async def join_channel( ): # 使用明确的查询避免延迟加载 if channel.isdigit(): - db_channel = ( - await session.exec( - select(ChatChannel).where(ChatChannel.channel_id == int(channel)) - ) - ).first() + db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first() else: - db_channel = ( - await session.exec(select(ChatChannel).where(ChatChannel.name == channel)) - ).first() + db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first() if db_channel is None: raise HTTPException(status_code=404, detail="Channel not found") @@ -145,15 +116,9 @@ async def leave_channel( ): # 使用明确的查询避免延迟加载 if channel.isdigit(): - db_channel = ( - await session.exec( - select(ChatChannel).where(ChatChannel.channel_id == int(channel)) - ) - ).first() + db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first() else: - db_channel = ( - await session.exec(select(ChatChannel).where(ChatChannel.name == channel)) - ).first() + db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first() if db_channel is None: raise HTTPException(status_code=404, detail="Channel not found") @@ -173,27 +138,20 @@ async def get_channel_list( current_user: User = Security(get_current_user, scopes=["chat.read"]), redis: Redis = Depends(get_redis), ): - channels = ( - await session.exec( - select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC) - ) - ).all() + channels = (await session.exec(select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC))).all() results = [] for channel in channels: # 提取必要的属性避免惰性加载 channel_id = channel.channel_id channel_type = channel.type - assert channel_id is not None results.append( await ChatChannelResp.from_db( channel, session, current_user, redis, - server.channels.get(channel_id, []) - if channel_type != ChannelType.PUBLIC - else None, + server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None, ) ) return results @@ -219,15 +177,9 @@ async def get_channel( ): # 使用明确的查询避免延迟加载 if channel.isdigit(): - db_channel = ( - await session.exec( - select(ChatChannel).where(ChatChannel.channel_id == int(channel)) - ) - ).first() + db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first() else: - db_channel = ( - await session.exec(select(ChatChannel).where(ChatChannel.name == channel)) - ).first() + db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first() if db_channel is None: raise HTTPException(status_code=404, detail="Channel not found") @@ -237,8 +189,6 @@ async def get_channel( channel_type = db_channel.type channel_name = db_channel.name - assert channel_id is not None - users = [] if channel_type == ChannelType.PM: user_ids = channel_name.split("_")[1:] @@ -259,9 +209,7 @@ async def get_channel( session, current_user, redis, - server.channels.get(channel_id, []) - if channel_type != ChannelType.PUBLIC - else None, + server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None, ) ) @@ -284,9 +232,7 @@ class CreateChannelReq(BaseModel): raise ValueError("target_id must be set for PM channels") else: if self.target_ids is None or self.channel is None or self.message is None: - raise ValueError( - "target_ids, channel, and message must be set for ANNOUNCE channels" - ) + raise ValueError("target_ids, channel, and message must be set for ANNOUNCE channels") return self @@ -312,24 +258,20 @@ async def create_channel( raise HTTPException(status_code=403, detail=block) channel = await ChatChannel.get_pm_channel( - current_user.id, # pyright: ignore[reportArgumentType] + current_user.id, req.target_id, # pyright: ignore[reportArgumentType] session, ) channel_name = f"pm_{current_user.id}_{req.target_id}" else: channel_name = req.channel.name if req.channel else "Unnamed Channel" - result = await session.exec( - select(ChatChannel).where(ChatChannel.name == channel_name) - ) + result = await session.exec(select(ChatChannel).where(ChatChannel.name == channel_name)) channel = result.first() if channel is None: channel = ChatChannel( name=channel_name, - description=req.channel.description - if req.channel - else "Private message channel", + description=req.channel.description if req.channel else "Private message channel", type=ChannelType.PM if req.type == "PM" else ChannelType.ANNOUNCE, ) session.add(channel) @@ -340,16 +282,13 @@ async def create_channel( await session.refresh(target) # pyright: ignore[reportPossiblyUnboundVariable] await server.batch_join_channel([target, current_user], channel, session) # pyright: ignore[reportPossiblyUnboundVariable] else: - target_users = await session.exec( - select(User).where(col(User.id).in_(req.target_ids or [])) - ) + target_users = await session.exec(select(User).where(col(User.id).in_(req.target_ids or []))) await server.batch_join_channel([*target_users, current_user], channel, session) await server.join_channel(current_user, channel, session) # 提取必要的属性避免惰性加载 channel_id = channel.channel_id - assert channel_id return await ChatChannelResp.from_db( channel, diff --git a/app/router/notification/message.py b/app/router/notification/message.py index 390db6b..a323918 100644 --- a/app/router/notification/message.py +++ b/app/router/notification/message.py @@ -41,33 +41,19 @@ class KeepAliveResp(BaseModel): ) async def keep_alive( session: Database, - history_since: int | None = Query( - None, description="获取自此禁言 ID 之后的禁言记录" - ), + history_since: int | None = Query(None, description="获取自此禁言 ID 之后的禁言记录"), since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"), current_user: User = Security(get_current_user, scopes=["chat.read"]), ): resp = KeepAliveResp() if history_since: - silences = ( - await session.exec( - select(SilenceUser).where(col(SilenceUser.id) > history_since) - ) - ).all() + silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.id) > history_since))).all() resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences]) elif since: msg = await session.get(ChatMessage, since) if msg: - silences = ( - await session.exec( - select(SilenceUser).where( - col(SilenceUser.banned_at) > msg.timestamp - ) - ) - ).all() - resp.silences.extend( - [UserSilenceResp.from_db(silence) for silence in silences] - ) + silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.banned_at) > msg.timestamp))).all() + resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences]) return resp @@ -93,15 +79,9 @@ async def send_message( ): # 使用明确的查询来获取 channel,避免延迟加载 if channel.isdigit(): - db_channel = ( - await session.exec( - select(ChatChannel).where(ChatChannel.channel_id == int(channel)) - ) - ).first() + db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first() else: - db_channel = ( - await session.exec(select(ChatChannel).where(ChatChannel.name == channel)) - ).first() + db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first() if db_channel is None: raise HTTPException(status_code=404, detail="Channel not found") @@ -111,9 +91,6 @@ async def send_message( channel_type = db_channel.type channel_name = db_channel.name - assert channel_id is not None - assert current_user.id - # 使用 Redis 消息系统发送消息 - 立即返回 resp = await redis_message_system.send_message( channel_id=channel_id, @@ -125,9 +102,7 @@ async def send_message( # 立即广播消息给所有客户端 is_bot_command = req.message.startswith("!") - await server.send_message_to_channel( - resp, is_bot_command and channel_type == ChannelType.PUBLIC - ) + await server.send_message_to_channel(resp, is_bot_command and channel_type == ChannelType.PUBLIC) # 处理机器人命令 if is_bot_command: @@ -147,14 +122,10 @@ async def send_message( if channel_type == ChannelType.PM: user_ids = channel_name.split("_")[1:] await server.new_private_notification( - ChannelMessage.init( - temp_msg, current_user, [int(u) for u in user_ids], channel_type - ) + ChannelMessage.init(temp_msg, current_user, [int(u) for u in user_ids], channel_type) ) elif channel_type == ChannelType.TEAM: - await server.new_private_notification( - ChannelMessageTeam.init(temp_msg, current_user) - ) + await server.new_private_notification(ChannelMessageTeam.init(temp_msg, current_user)) return resp @@ -176,22 +147,15 @@ async def get_message( ): # 使用明确的查询获取 channel,避免延迟加载 if channel.isdigit(): - db_channel = ( - await session.exec( - select(ChatChannel).where(ChatChannel.channel_id == int(channel)) - ) - ).first() + db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first() else: - db_channel = ( - await session.exec(select(ChatChannel).where(ChatChannel.name == channel)) - ).first() + db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first() if db_channel is None: raise HTTPException(status_code=404, detail="Channel not found") # 提取必要的属性避免惰性加载 channel_id = db_channel.channel_id - assert channel_id is not None # 使用 Redis 消息系统获取消息 try: @@ -230,23 +194,15 @@ async def mark_as_read( ): # 使用明确的查询获取 channel,避免延迟加载 if channel.isdigit(): - db_channel = ( - await session.exec( - select(ChatChannel).where(ChatChannel.channel_id == int(channel)) - ) - ).first() + db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first() else: - db_channel = ( - await session.exec(select(ChatChannel).where(ChatChannel.name == channel)) - ).first() + db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first() if db_channel is None: raise HTTPException(status_code=404, detail="Channel not found") # 立即提取需要的属性 channel_id = db_channel.channel_id - assert channel_id - assert current_user.id await server.mark_as_read(channel_id, current_user.id, message) @@ -283,7 +239,6 @@ async def create_new_pm( if not is_can_pm: raise HTTPException(status_code=403, detail=block) - assert user_id channel = await ChatChannel.get_pm_channel(user_id, req.target_id, session) if channel is None: channel = ChatChannel( @@ -297,7 +252,6 @@ async def create_new_pm( await session.refresh(target) await session.refresh(current_user) - assert channel.channel_id await server.batch_join_channel([target, current_user], channel, session) channel_resp = await ChatChannelResp.from_db( channel, session, current_user, redis, server.channels[channel.channel_id] diff --git a/app/router/notification/server.py b/app/router/notification/server.py index 71a5106..d4810a0 100644 --- a/app/router/notification/server.py +++ b/app/router/notification/server.py @@ -17,6 +17,7 @@ from app.log import logger from app.models.chat import ChatEvent from app.models.notification import NotificationDetail from app.service.subscribers.chat import ChatSubscriber +from app.utils import bg_tasks from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect from fastapi.security import SecurityScopes @@ -37,20 +38,11 @@ class ChatServer: self.ChatSubscriber.chat_server = self self._subscribed = False - def _add_task(self, task): - task = asyncio.create_task(task) - self.tasks.add(task) - task.add_done_callback(self.tasks.discard) - def connect(self, user_id: int, client: WebSocket): self.connect_client[user_id] = client def get_user_joined_channel(self, user_id: int) -> list[int]: - return [ - channel_id - for channel_id, users in self.channels.items() - if user_id in users - ] + return [channel_id for channel_id, users in self.channels.items() if user_id in users] async def disconnect(self, user: User, session: AsyncSession): user_id = user.id @@ -61,9 +53,7 @@ class ChatServer: channel.remove(user_id) # 使用明确的查询避免延迟加载 db_channel = ( - await session.exec( - select(ChatChannel).where(ChatChannel.channel_id == channel_id) - ) + await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id)) ).first() if db_channel: await self.leave_channel(user, db_channel, session) @@ -93,11 +83,10 @@ class ChatServer: async def mark_as_read(self, channel_id: int, user_id: int, message_id: int): await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id) - async def send_message_to_channel( - self, message: ChatMessageResp, is_bot_command: bool = False - ): + async def send_message_to_channel(self, message: ChatMessageResp, is_bot_command: bool = False): logger.info( - f"Sending message to channel {message.channel_id}, message_id: {message.message_id}, is_bot_command: {is_bot_command}" + f"Sending message to channel {message.channel_id}, message_id: " + f"{message.message_id}, is_bot_command: {is_bot_command}" ) event = ChatEvent( @@ -106,62 +95,44 @@ class ChatServer: ) if is_bot_command: logger.info(f"Sending bot command to user {message.sender_id}") - self._add_task(self.send_event(message.sender_id, event)) + bg_tasks.add_task(self.send_event, message.sender_id, event) else: # 总是广播消息,无论是临时ID还是真实ID - logger.info( - f"Broadcasting message to all users in channel {message.channel_id}" - ) - self._add_task( - self.broadcast( - message.channel_id, - event, - ) + logger.info(f"Broadcasting message to all users in channel {message.channel_id}") + bg_tasks.add_task( + self.broadcast, + message.channel_id, + event, ) # 只有真实消息 ID(正数且非零)才进行标记已读和设置最后消息 # Redis 消息系统生成的ID都是正数,所以这里应该都能正常处理 if message.message_id and message.message_id > 0: - await self.mark_as_read( - message.channel_id, message.sender_id, message.message_id - ) - await self.redis.set( - f"chat:{message.channel_id}:last_msg", message.message_id - ) - logger.info( - f"Updated last message ID for channel {message.channel_id} to {message.message_id}" - ) + await self.mark_as_read(message.channel_id, message.sender_id, message.message_id) + await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id) + logger.info(f"Updated last message ID for channel {message.channel_id} to {message.message_id}") else: - logger.debug( - f"Skipping last message update for message ID: {message.message_id}" - ) + logger.debug(f"Skipping last message update for message ID: {message.message_id}") - async def batch_join_channel( - self, users: list[User], channel: ChatChannel, session: AsyncSession - ): + async def batch_join_channel(self, users: list[User], channel: ChatChannel, session: AsyncSession): channel_id = channel.channel_id - assert channel_id is not None not_joined = [] if channel_id not in self.channels: self.channels[channel_id] = [] for user in users: - assert user.id is not None if user.id not in self.channels[channel_id]: self.channels[channel_id].append(user.id) not_joined.append(user) for user in not_joined: - assert user.id is not None channel_resp = await ChatChannelResp.from_db( channel, session, user, self.redis, - self.channels[channel_id] - if channel.type != ChannelType.PUBLIC - else None, + self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None, ) await self.send_event( user.id, @@ -171,13 +142,9 @@ class ChatServer: ), ) - async def join_channel( - self, user: User, channel: ChatChannel, session: AsyncSession - ) -> ChatChannelResp: + async def join_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> ChatChannelResp: user_id = user.id channel_id = channel.channel_id - assert channel_id is not None - assert user_id is not None if channel_id not in self.channels: self.channels[channel_id] = [] @@ -202,13 +169,9 @@ class ChatServer: return channel_resp - async def leave_channel( - self, user: User, channel: ChatChannel, session: AsyncSession - ) -> None: + async def leave_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> None: user_id = user.id channel_id = channel.channel_id - assert channel_id is not None - assert user_id is not None if channel_id in self.channels and user_id in self.channels[channel_id]: self.channels[channel_id].remove(user_id) @@ -221,9 +184,7 @@ class ChatServer: session, user, self.redis, - self.channels.get(channel_id) - if channel.type != ChannelType.PUBLIC - else None, + self.channels.get(channel_id) if channel.type != ChannelType.PUBLIC else None, ) await self.send_event( user_id, @@ -236,11 +197,7 @@ class ChatServer: async def join_room_channel(self, channel_id: int, user_id: int): async with with_db() as session: # 使用明确的查询避免延迟加载 - db_channel = ( - await session.exec( - select(ChatChannel).where(ChatChannel.channel_id == channel_id) - ) - ).first() + db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first() if db_channel is None: return @@ -253,11 +210,7 @@ class ChatServer: async def leave_room_channel(self, channel_id: int, user_id: int): async with with_db() as session: # 使用明确的查询避免延迟加载 - db_channel = ( - await session.exec( - select(ChatChannel).where(ChatChannel.channel_id == channel_id) - ) - ).first() + db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first() if db_channel is None: return @@ -270,13 +223,7 @@ class ChatServer: async def new_private_notification(self, detail: NotificationDetail): async with with_db() as session: id = await insert_notification(session, detail) - users = ( - await session.exec( - select(UserNotification).where( - UserNotification.notification_id == id - ) - ) - ).all() + users = (await session.exec(select(UserNotification).where(UserNotification.notification_id == id))).all() for user_notification in users: data = user_notification.notification.model_dump() data["is_read"] = user_notification.is_read @@ -308,9 +255,7 @@ async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory): await ws.close(code=1000) break except WebSocketDisconnect as e: - logger.info( - f"[NotificationServer] Client {user_id} disconnected: {e.code}, {e.reason}" - ) + logger.info(f"[NotificationServer] Client {user_id} disconnected: {e.code}, {e.reason}") except RuntimeError as e: if "disconnect message" in str(e): logger.info(f"[NotificationServer] Client {user_id} closed the connection.") @@ -332,11 +277,7 @@ async def chat_websocket( async for session in factory(): token = authorization[7:] - if ( - user := await get_current_user( - session, SecurityScopes(scopes=["chat.read"]), token_pw=token - ) - ) is None: + if (user := await get_current_user(session, SecurityScopes(scopes=["chat.read"]), token_pw=token)) is None: await websocket.close(code=1008) return @@ -346,12 +287,9 @@ async def chat_websocket( await websocket.close(code=1008) return user_id = user.id - assert user_id server.connect(user_id, websocket) # 使用明确的查询避免延迟加载 - db_channel = ( - await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1)) - ).first() + db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))).first() if db_channel is not None: await server.join_channel(user, db_channel, session) diff --git a/app/router/password_reset_admin.py b/app/router/password_reset_admin.py index 8abd81c..226fb44 100644 --- a/app/router/password_reset_admin.py +++ b/app/router/password_reset_admin.py @@ -2,22 +2,20 @@ 密码重置管理接口 """ -from fastapi import APIRouter, Depends, HTTPException -from fastapi.responses import JSONResponse -from redis.asyncio import Redis +from __future__ import annotations from app.dependencies.database import get_redis -from app.service.password_reset_service import password_reset_service from app.log import logger +from app.service.password_reset_service import password_reset_service + +from fastapi import APIRouter, Depends +from fastapi.responses import JSONResponse +from redis.asyncio import Redis router = APIRouter(prefix="/admin/password-reset", tags=["密码重置管理"]) -@router.get( - "/status/{email}", - name="查询重置状态", - description="查询指定邮箱的密码重置状态" -) +@router.get("/status/{email}", name="查询重置状态", description="查询指定邮箱的密码重置状态") async def get_password_reset_status( email: str, redis: Redis = Depends(get_redis), @@ -25,28 +23,16 @@ async def get_password_reset_status( """查询密码重置状态""" try: info = await password_reset_service.get_reset_code_info(email, redis) - return JSONResponse( - status_code=200, - content={ - "success": True, - "data": info - } - ) + return JSONResponse(status_code=200, content={"success": True, "data": info}) except Exception as e: logger.error(f"[Admin] Failed to get password reset status for {email}: {e}") - return JSONResponse( - status_code=500, - content={ - "success": False, - "error": "获取状态失败" - } - ) + return JSONResponse(status_code=500, content={"success": False, "error": "获取状态失败"}) @router.delete( "/cleanup/{email}", name="清理重置数据", - description="强制清理指定邮箱的密码重置数据" + description="强制清理指定邮箱的密码重置数据", ) async def force_cleanup_reset( email: str, @@ -55,38 +41,23 @@ async def force_cleanup_reset( """强制清理密码重置数据""" try: success = await password_reset_service.force_cleanup_user_reset(email, redis) - + if success: return JSONResponse( status_code=200, - content={ - "success": True, - "message": f"已清理邮箱 {email} 的重置数据" - } + content={"success": True, "message": f"已清理邮箱 {email} 的重置数据"}, ) else: - return JSONResponse( - status_code=500, - content={ - "success": False, - "error": "清理失败" - } - ) + return JSONResponse(status_code=500, content={"success": False, "error": "清理失败"}) except Exception as e: logger.error(f"[Admin] Failed to cleanup password reset for {email}: {e}") - return JSONResponse( - status_code=500, - content={ - "success": False, - "error": "清理操作失败" - } - ) + return JSONResponse(status_code=500, content={"success": False, "error": "清理操作失败"}) @router.post( "/cleanup/expired", name="清理过期验证码", - description="清理所有过期的密码重置验证码" + description="清理所有过期的密码重置验证码", ) async def cleanup_expired_codes( redis: Redis = Depends(get_redis), @@ -99,25 +70,15 @@ async def cleanup_expired_codes( content={ "success": True, "message": f"已清理 {count} 个过期的验证码", - "cleaned_count": count - } + "cleaned_count": count, + }, ) except Exception as e: logger.error(f"[Admin] Failed to cleanup expired codes: {e}") - return JSONResponse( - status_code=500, - content={ - "success": False, - "error": "清理操作失败" - } - ) + return JSONResponse(status_code=500, content={"success": False, "error": "清理操作失败"}) -@router.get( - "/stats", - name="重置统计", - description="获取密码重置的统计信息" -) +@router.get("/stats", name="重置统计", description="获取密码重置的统计信息") async def get_reset_statistics( redis: Redis = Depends(get_redis), ): @@ -126,53 +87,42 @@ async def get_reset_statistics( # 获取所有重置相关的键 reset_keys = await redis.keys("password_reset:code:*") rate_limit_keys = await redis.keys("password_reset:rate_limit:*") - + active_resets = 0 used_resets = 0 active_rate_limits = 0 - + # 统计活跃重置 for key in reset_keys: data_str = await redis.get(key) if data_str: try: import json + data = json.loads(data_str) if data.get("used", False): used_resets += 1 else: active_resets += 1 - except: + except Exception: pass - + # 统计频率限制 for key in rate_limit_keys: ttl = await redis.ttl(key) if ttl > 0: active_rate_limits += 1 - + stats = { "total_reset_codes": len(reset_keys), "active_resets": active_resets, "used_resets": used_resets, "active_rate_limits": active_rate_limits, - "total_rate_limit_keys": len(rate_limit_keys) + "total_rate_limit_keys": len(rate_limit_keys), } - - return JSONResponse( - status_code=200, - content={ - "success": True, - "data": stats - } - ) - + + return JSONResponse(status_code=200, content={"success": True, "data": stats}) + except Exception as e: logger.error(f"[Admin] Failed to get reset statistics: {e}") - return JSONResponse( - status_code=500, - content={ - "success": False, - "error": "获取统计信息失败" - } - ) + return JSONResponse(status_code=500, content={"success": False, "error": "获取统计信息失败"}) diff --git a/app/router/private/oauth.py b/app/router/private/oauth.py index c18d8bd..576f154 100644 --- a/app/router/private/oauth.py +++ b/app/router/private/oauth.py @@ -26,7 +26,7 @@ async def create_oauth_app( redirect_uris: list[str] = Body(..., description="允许的重定向 URI 列表"), current_user: User = Security(get_client_user), ): - result = await session.execute( # pyright: ignore[reportDeprecated] + result = await session.execute( text( "SELECT AUTO_INCREMENT FROM information_schema.TABLES " "WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'oauth_clients'" @@ -84,9 +84,7 @@ async def get_user_oauth_apps( session: Database, current_user: User = Security(get_client_user), ): - oauth_apps = await session.exec( - select(OAuthClient).where(OAuthClient.owner_id == current_user.id) - ) + oauth_apps = await session.exec(select(OAuthClient).where(OAuthClient.owner_id == current_user.id)) return [ { "name": app.name, @@ -113,13 +111,9 @@ async def delete_oauth_app( if not oauth_client: raise HTTPException(status_code=404, detail="OAuth app not found") if oauth_client.owner_id != current_user.id: - raise HTTPException( - status_code=403, detail="Forbidden: Not the owner of this app" - ) + raise HTTPException(status_code=403, detail="Forbidden: Not the owner of this app") - tokens = await session.exec( - select(OAuthToken).where(OAuthToken.client_id == client_id) - ) + tokens = await session.exec(select(OAuthToken).where(OAuthToken.client_id == client_id)) for token in tokens: await session.delete(token) @@ -144,9 +138,7 @@ async def update_oauth_app( if not oauth_client: raise HTTPException(status_code=404, detail="OAuth app not found") if oauth_client.owner_id != current_user.id: - raise HTTPException( - status_code=403, detail="Forbidden: Not the owner of this app" - ) + raise HTTPException(status_code=403, detail="Forbidden: Not the owner of this app") oauth_client.name = name oauth_client.description = description @@ -176,14 +168,10 @@ async def refresh_secret( if not oauth_client: raise HTTPException(status_code=404, detail="OAuth app not found") if oauth_client.owner_id != current_user.id: - raise HTTPException( - status_code=403, detail="Forbidden: Not the owner of this app" - ) + raise HTTPException(status_code=403, detail="Forbidden: Not the owner of this app") oauth_client.client_secret = secrets.token_hex() - tokens = await session.exec( - select(OAuthToken).where(OAuthToken.client_id == client_id) - ) + tokens = await session.exec(select(OAuthToken).where(OAuthToken.client_id == client_id)) for token in tokens: await session.delete(token) @@ -215,9 +203,7 @@ async def generate_oauth_code( raise HTTPException(status_code=404, detail="OAuth app not found") if redirect_uri not in client.redirect_uris: - raise HTTPException( - status_code=403, detail="Redirect URI not allowed for this client" - ) + raise HTTPException(status_code=403, detail="Redirect URI not allowed for this client") code = secrets.token_urlsafe(80) await redis.hset( # pyright: ignore[reportGeneralTypeIssues] diff --git a/app/router/private/relationship.py b/app/router/private/relationship.py index 1a8af44..1dde7af 100644 --- a/app/router/private/relationship.py +++ b/app/router/private/relationship.py @@ -50,12 +50,8 @@ async def check_user_relationship( ) ).first() - is_followed = bool( - target_relationship and target_relationship.type == RelationshipType.FOLLOW - ) - is_following = bool( - my_relationship and my_relationship.type == RelationshipType.FOLLOW - ) + is_followed = bool(target_relationship and target_relationship.type == RelationshipType.FOLLOW) + is_following = bool(my_relationship and my_relationship.type == RelationshipType.FOLLOW) return CheckResponse( is_followed=is_followed, diff --git a/app/router/private/team.py b/app/router/private/team.py index 3f18645..32584e6 100644 --- a/app/router/private/team.py +++ b/app/router/private/team.py @@ -40,16 +40,13 @@ async def create_team( 支持的图片格式: PNG、JPEG、GIF """ user_id = current_user.id - assert user_id if (await current_user.awaitable_attrs.team_membership) is not None: raise HTTPException(status_code=403, detail="You are already in a team") is_existed = (await session.exec(select(exists()).where(Team.name == name))).first() if is_existed: raise HTTPException(status_code=409, detail="Name already exists") - is_existed = ( - await session.exec(select(exists()).where(Team.short_name == short_name)) - ).first() + is_existed = (await session.exec(select(exists()).where(Team.short_name == short_name))).first() if is_existed: raise HTTPException(status_code=409, detail="Short name already exists") @@ -101,7 +98,6 @@ async def update_team( """ team = await session.get(Team, team_id) user_id = current_user.id - assert user_id if not team: raise HTTPException(status_code=404, detail="Team not found") if team.leader_id != user_id: @@ -110,9 +106,7 @@ async def update_team( is_existed = (await session.exec(select(exists()).where(Team.name == name))).first() if is_existed: raise HTTPException(status_code=409, detail="Name already exists") - is_existed = ( - await session.exec(select(exists()).where(Team.short_name == short_name)) - ).first() + is_existed = (await session.exec(select(exists()).where(Team.short_name == short_name))).first() if is_existed: raise HTTPException(status_code=409, detail="Short name already exists") @@ -132,20 +126,12 @@ async def update_team( team.cover_url = await storage.get_file_url(storage_path) if leader_id is not None: - if not ( - await session.exec(select(exists()).where(User.id == leader_id)) - ).first(): + if not (await session.exec(select(exists()).where(User.id == leader_id))).first(): raise HTTPException(status_code=404, detail="Leader not found") if not ( - await session.exec( - select(TeamMember).where( - TeamMember.user_id == leader_id, TeamMember.team_id == team.id - ) - ) + await session.exec(select(TeamMember).where(TeamMember.user_id == leader_id, TeamMember.team_id == team.id)) ).first(): - raise HTTPException( - status_code=404, detail="Leader is not a member of the team" - ) + raise HTTPException(status_code=404, detail="Leader is not a member of the team") team.leader_id = leader_id await session.commit() @@ -166,9 +152,7 @@ async def delete_team( if team.leader_id != current_user.id: raise HTTPException(status_code=403, detail="You are not the team leader") - team_members = await session.exec( - select(TeamMember).where(TeamMember.team_id == team_id) - ) + team_members = await session.exec(select(TeamMember).where(TeamMember.team_id == team_id)) for member in team_members: await session.delete(member) @@ -186,15 +170,10 @@ async def get_team( session: Database, team_id: int = Path(..., description="战队 ID"), ): - members = ( - await session.exec(select(TeamMember).where(TeamMember.team_id == team_id)) - ).all() + members = (await session.exec(select(TeamMember).where(TeamMember.team_id == team_id))).all() return TeamQueryResp( team=members[0].team, - members=[ - await UserResp.from_db(m.user, session, include=BASE_INCLUDES) - for m in members - ], + members=[await UserResp.from_db(m.user, session, include=BASE_INCLUDES) for m in members], ) @@ -213,15 +192,11 @@ async def request_join_team( if ( await session.exec( - select(exists()).where( - TeamRequest.team_id == team_id, TeamRequest.user_id == current_user.id - ) + select(exists()).where(TeamRequest.team_id == team_id, TeamRequest.user_id == current_user.id) ) ).first(): raise HTTPException(status_code=409, detail="Join request already exists") - team_request = TeamRequest( - user_id=current_user.id, team_id=team_id, requested_at=datetime.now(UTC) - ) + team_request = TeamRequest(user_id=current_user.id, team_id=team_id, requested_at=datetime.now(UTC)) session.add(team_request) await session.commit() await session.refresh(team_request) @@ -229,9 +204,7 @@ async def request_join_team( @router.post("/team/{team_id}/{user_id}/request", name="接受加入请求", status_code=204) -@router.delete( - "/team/{team_id}/{user_id}/request", name="拒绝加入请求", status_code=204 -) +@router.delete("/team/{team_id}/{user_id}/request", name="拒绝加入请求", status_code=204) async def handle_request( req: Request, session: Database, @@ -247,11 +220,7 @@ async def handle_request( raise HTTPException(status_code=403, detail="You are not the team leader") team_request = ( - await session.exec( - select(TeamRequest).where( - TeamRequest.team_id == team_id, TeamRequest.user_id == user_id - ) - ) + await session.exec(select(TeamRequest).where(TeamRequest.team_id == team_id, TeamRequest.user_id == user_id)) ).first() if not team_request: raise HTTPException(status_code=404, detail="Join request not found") @@ -261,16 +230,10 @@ async def handle_request( raise HTTPException(status_code=404, detail="User not found") if req.method == "POST": - if ( - await session.exec(select(exists()).where(TeamMember.user_id == user_id)) - ).first(): - raise HTTPException( - status_code=409, detail="User is already a member of the team" - ) + if (await session.exec(select(exists()).where(TeamMember.user_id == user_id))).first(): + raise HTTPException(status_code=409, detail="User is already a member of the team") - session.add( - TeamMember(user_id=user_id, team_id=team_id, joined_at=datetime.now(UTC)) - ) + session.add(TeamMember(user_id=user_id, team_id=team_id, joined_at=datetime.now(UTC))) await server.new_private_notification(TeamApplicationAccept.init(team_request)) else: @@ -294,19 +257,13 @@ async def kick_member( raise HTTPException(status_code=403, detail="You are not the team leader") team_member = ( - await session.exec( - select(TeamMember).where( - TeamMember.team_id == team_id, TeamMember.user_id == user_id - ) - ) + await session.exec(select(TeamMember).where(TeamMember.team_id == team_id, TeamMember.user_id == user_id)) ).first() if not team_member: raise HTTPException(status_code=404, detail="User is not a member of the team") if team.leader_id == current_user.id: - raise HTTPException( - status_code=403, detail="You cannot leave because you are the team leader" - ) + raise HTTPException(status_code=403, detail="You cannot leave because you are the team leader") await session.delete(team_member) await session.commit() diff --git a/app/router/private/username.py b/app/router/private/username.py index 0e66bd8..0d94efe 100644 --- a/app/router/private/username.py +++ b/app/router/private/username.py @@ -35,10 +35,7 @@ async def user_rename( 返回: - 成功: None """ - assert current_user is not None - samename_user = ( - await session.exec(select(User).where(User.username == new_name)) - ).first() + samename_user = (await session.exec(select(User).where(User.username == new_name))).first() if samename_user: raise HTTPException(409, "Username Exisits") errors = validate_username(new_name) diff --git a/app/router/v1/beatmap.py b/app/router/v1/beatmap.py index 82209dc..3301fd2 100644 --- a/app/router/v1/beatmap.py +++ b/app/router/v1/beatmap.py @@ -106,9 +106,7 @@ class V1Beatmap(AllStrModel): await session.exec( select(func.count()) .select_from(FavouriteBeatmapset) - .where( - FavouriteBeatmapset.beatmapset_id == db_beatmap.beatmapset.id - ) + .where(FavouriteBeatmapset.beatmapset_id == db_beatmap.beatmapset.id) ) ).one(), rating=0, # TODO @@ -154,12 +152,8 @@ async def get_beatmaps( beatmapset_id: int | None = Query(None, alias="s", description="谱面集 ID"), beatmap_id: int | None = Query(None, alias="b", description="谱面 ID"), user: str | None = Query(None, alias="u", description="谱师"), - type: Literal["string", "id"] | None = Query( - None, description="用户类型:string 用户名称 / id 用户 ID" - ), - ruleset_id: int | None = Query( - None, alias="m", description="Ruleset ID", ge=0, le=3 - ), # TODO + type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"), + ruleset_id: int | None = Query(None, alias="m", description="Ruleset ID", ge=0, le=3), # TODO convert: bool = Query(False, alias="a", description="转谱"), # TODO checksum: str | None = Query(None, alias="h", description="谱面文件 MD5"), limit: int = Query(500, ge=1, le=500, description="返回结果数量限制"), @@ -181,11 +175,7 @@ async def get_beatmaps( else: beatmaps = beatmapset.beatmaps elif user is not None: - where = ( - Beatmapset.user_id == user - if type == "id" or user.isdigit() - else Beatmapset.creator == user - ) + where = Beatmapset.user_id == user if type == "id" or user.isdigit() else Beatmapset.creator == user beatmapsets = (await session.exec(select(Beatmapset).where(where))).all() for beatmapset in beatmapsets: if len(beatmaps) >= limit: @@ -193,11 +183,7 @@ async def get_beatmaps( beatmaps.extend(beatmapset.beatmaps) elif since is not None: beatmapsets = ( - await session.exec( - select(Beatmapset) - .where(col(Beatmapset.ranked_date) > since) - .limit(limit) - ) + await session.exec(select(Beatmapset).where(col(Beatmapset.ranked_date) > since).limit(limit)) ).all() for beatmapset in beatmapsets: if len(beatmaps) >= limit: @@ -214,11 +200,7 @@ async def get_beatmaps( redis, fetcher, ) - results.append( - await V1Beatmap.from_db( - session, beatmap, attrs.aim_difficulty, attrs.speed_difficulty - ) - ) + results.append(await V1Beatmap.from_db(session, beatmap, attrs.aim_difficulty, attrs.speed_difficulty)) continue except Exception: ... diff --git a/app/router/v1/replay.py b/app/router/v1/replay.py index f1cc97d..884dabf 100644 --- a/app/router/v1/replay.py +++ b/app/router/v1/replay.py @@ -41,9 +41,7 @@ async def download_replay( ge=0, ), score_id: int | None = Query(None, alias="s", description="成绩 ID"), - type: Literal["string", "id"] | None = Query( - None, description="用户类型:string 用户名称 / id 用户 ID" - ), + type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"), mods: int = Query(0, description="成绩的 MOD"), storage_service: StorageService = Depends(get_storage_service), ): @@ -58,13 +56,9 @@ async def download_replay( await session.exec( select(Score).where( Score.beatmap_id == beatmap, - Score.user_id == user - if type == "id" or user.isdigit() - else col(Score.user).has(username=user), + Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user), Score.mods == mods_, - Score.gamemode == GameMode.from_int_extra(ruleset_id) - if ruleset_id is not None - else True, + Score.gamemode == GameMode.from_int_extra(ruleset_id) if ruleset_id is not None else True, ) ) ).first() @@ -73,10 +67,7 @@ async def download_replay( except KeyError: raise HTTPException(status_code=400, detail="Invalid request") - filepath = ( - f"replays/{score_record.id}_{score_record.beatmap_id}" - f"_{score_record.user_id}_lazer_replay.osr" - ) + filepath = f"replays/{score_record.id}_{score_record.beatmap_id}_{score_record.user_id}_lazer_replay.osr" if not await storage_service.is_exists(filepath): raise HTTPException(status_code=404, detail="Replay file not found") @@ -100,6 +91,4 @@ async def download_replay( await session.commit() data = await storage_service.read_file(filepath) - return ReplayModel( - content=base64.b64encode(data).decode("utf-8"), encoding="base64" - ) + return ReplayModel(content=base64.b64encode(data).decode("utf-8"), encoding="base64") diff --git a/app/router/v1/router.py b/app/router/v1/router.py index 268cd1f..ecd9006 100644 --- a/app/router/v1/router.py +++ b/app/router/v1/router.py @@ -8,9 +8,7 @@ from app.dependencies.user import v1_authorize from fastapi import APIRouter, Depends from pydantic import BaseModel, field_serializer -router = APIRouter( - prefix="/api/v1", dependencies=[Depends(v1_authorize)], tags=["V1 API"] -) +router = APIRouter(prefix="/api/v1", dependencies=[Depends(v1_authorize)], tags=["V1 API"]) class AllStrModel(BaseModel): diff --git a/app/router/v1/score.py b/app/router/v1/score.py index d382522..fdedd59 100644 --- a/app/router/v1/score.py +++ b/app/router/v1/score.py @@ -70,9 +70,7 @@ async def get_user_best( session: Database, user: str = Query(..., alias="u", description="用户"), ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0), - type: Literal["string", "id"] | None = Query( - None, description="用户类型:string 用户名称 / id 用户 ID" - ), + type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"), limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"), ): try: @@ -80,9 +78,7 @@ async def get_user_best( await session.exec( select(Score) .where( - Score.user_id == user - if type == "id" or user.isdigit() - else col(Score.user).has(username=user), + Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user), Score.gamemode == GameMode.from_int_extra(ruleset_id), exists().where(col(PPBestScore.score_id) == Score.id), ) @@ -106,9 +102,7 @@ async def get_user_recent( session: Database, user: str = Query(..., alias="u", description="用户"), ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0), - type: Literal["string", "id"] | None = Query( - None, description="用户类型:string 用户名称 / id 用户 ID" - ), + type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"), limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"), ): try: @@ -116,9 +110,7 @@ async def get_user_recent( await session.exec( select(Score) .where( - Score.user_id == user - if type == "id" or user.isdigit() - else col(Score.user).has(username=user), + Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user), Score.gamemode == GameMode.from_int_extra(ruleset_id), Score.ended_at > datetime.now(UTC) - timedelta(hours=24), ) @@ -143,9 +135,7 @@ async def get_scores( user: str | None = Query(None, alias="u", description="用户"), beatmap_id: int = Query(alias="b", description="谱面 ID"), ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0), - type: Literal["string", "id"] | None = Query( - None, description="用户类型:string 用户名称 / id 用户 ID" - ), + type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"), limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"), mods: int = Query(0, description="成绩的 MOD"), ): @@ -157,9 +147,7 @@ async def get_scores( .where( Score.gamemode == GameMode.from_int_extra(ruleset_id), Score.beatmap_id == beatmap_id, - Score.user_id == user - if type == "id" or user.isdigit() - else col(Score.user).has(username=user), + Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user), ) .options(joinedload(Score.beatmap)) .order_by(col(Score.classic_total_score).desc()) diff --git a/app/router/v1/user.py b/app/router/v1/user.py index bb57c3a..84bb9b0 100644 --- a/app/router/v1/user.py +++ b/app/router/v1/user.py @@ -1,6 +1,5 @@ from __future__ import annotations -import asyncio from datetime import datetime from typing import Literal @@ -13,7 +12,7 @@ from app.service.user_cache_service import get_user_cache_service from .router import AllStrModel, router -from fastapi import HTTPException, Query +from fastapi import BackgroundTasks, HTTPException, Query from sqlmodel import select @@ -49,9 +48,7 @@ class V1User(AllStrModel): return f"v1_user:{user_id}" @classmethod - async def from_db( - cls, session: Database, db_user: User, ruleset: GameMode | None = None - ) -> "V1User": + async def from_db(cls, session: Database, db_user: User, ruleset: GameMode | None = None) -> "V1User": # 确保 user_id 不为 None if db_user.id is None: raise ValueError("User ID cannot be None") @@ -63,9 +60,7 @@ class V1User(AllStrModel): current_statistics = i break if current_statistics: - statistics = await UserStatisticsResp.from_db( - current_statistics, session, db_user.country_code - ) + statistics = await UserStatisticsResp.from_db(current_statistics, session, db_user.country_code) else: statistics = None return cls( @@ -78,9 +73,7 @@ class V1User(AllStrModel): playcount=statistics.play_count if statistics else 0, ranked_score=statistics.ranked_score if statistics else 0, total_score=statistics.total_score if statistics else 0, - pp_rank=statistics.global_rank - if statistics and statistics.global_rank - else 0, + pp_rank=statistics.global_rank if statistics and statistics.global_rank else 0, level=current_statistics.level_current if current_statistics else 0, pp_raw=statistics.pp if statistics else 0.0, accuracy=statistics.hit_accuracy if statistics else 0, @@ -91,9 +84,7 @@ class V1User(AllStrModel): count_rank_a=current_statistics.grade_a if current_statistics else 0, country=db_user.country_code, total_seconds_played=statistics.play_time if statistics else 0, - pp_country_rank=statistics.country_rank - if statistics and statistics.country_rank - else 0, + pp_country_rank=statistics.country_rank if statistics and statistics.country_rank else 0, events=[], # TODO ) @@ -106,14 +97,11 @@ class V1User(AllStrModel): ) async def get_user( session: Database, + background_tasks: BackgroundTasks, user: str = Query(..., alias="u", description="用户"), ruleset_id: int | None = Query(None, alias="m", description="Ruleset ID", ge=0), - type: Literal["string", "id"] | None = Query( - None, description="用户类型:string 用户名称 / id 用户 ID" - ), - event_days: int = Query( - default=1, ge=1, le=31, description="从现在起所有事件的最大天数" - ), + type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"), + event_days: int = Query(default=1, ge=1, le=31, description="从现在起所有事件的最大天数"), ): redis = get_redis() cache_service = get_user_cache_service(redis) @@ -131,9 +119,7 @@ async def get_user( if is_id_query: try: user_id_for_cache = int(user) - cached_v1_user = await cache_service.get_v1_user_from_cache( - user_id_for_cache, ruleset - ) + cached_v1_user = await cache_service.get_v1_user_from_cache(user_id_for_cache, ruleset) if cached_v1_user: return [V1User(**cached_v1_user)] except (ValueError, TypeError): @@ -158,9 +144,7 @@ async def get_user( # 异步缓存结果(如果有用户ID) if db_user.id is not None: user_data = v1_user.model_dump() - asyncio.create_task( - cache_service.cache_v1_user(user_data, db_user.id, ruleset) - ) + background_tasks.add_task(cache_service.cache_v1_user, user_data, db_user.id, ruleset) return [v1_user] diff --git a/app/router/v2/__init__.py b/app/router/v2/__init__.py index 8e59b4e..968b030 100644 --- a/app/router/v2/__init__.py +++ b/app/router/v2/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations -from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401 +from . import ( # noqa: F401 beatmap, beatmapset, me, diff --git a/app/router/v2/beatmap.py b/app/router/v2/beatmap.py index e5a2775..506f9f3 100644 --- a/app/router/v2/beatmap.py +++ b/app/router/v2/beatmap.py @@ -40,18 +40,13 @@ class BatchGetResp(BaseModel): tags=["谱面"], name="查询单个谱面", response_model=BeatmapResp, - description=( - "根据谱面 ID / MD5 / 文件名 查询单个谱面。" - "至少提供 id / checksum / filename 之一。" - ), + description=("根据谱面 ID / MD5 / 文件名 查询单个谱面。至少提供 id / checksum / filename 之一。"), ) async def lookup_beatmap( db: Database, id: int | None = Query(default=None, alias="id", description="谱面 ID"), md5: str | None = Query(default=None, alias="checksum", description="谱面文件 MD5"), - filename: str | None = Query( - default=None, alias="filename", description="谱面文件名" - ), + filename: str | None = Query(default=None, alias="filename", description="谱面文件名"), current_user: User = Security(get_current_user, scopes=["public"]), fetcher: Fetcher = Depends(get_fetcher), ): @@ -96,43 +91,23 @@ async def get_beatmap( tags=["谱面"], name="批量获取谱面", response_model=BatchGetResp, - description=( - "批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。" - "为空时按最近更新时间返回。" - ), + description=("批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。为空时按最近更新时间返回。"), ) async def batch_get_beatmaps( db: Database, - beatmap_ids: list[int] = Query( - alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)" - ), + beatmap_ids: list[int] = Query(alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"), current_user: User = Security(get_current_user, scopes=["public"]), fetcher: Fetcher = Depends(get_fetcher), ): if not beatmap_ids: - beatmaps = ( - await db.exec( - select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50) - ) - ).all() + beatmaps = (await db.exec(select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50))).all() else: - beatmaps = list( - ( - await db.exec( - select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)).limit(50) - ) - ).all() - ) - not_found_beatmaps = [ - bid for bid in beatmap_ids if bid not in [bm.id for bm in beatmaps] - ] + beatmaps = list((await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)).limit(50))).all()) + not_found_beatmaps = [bid for bid in beatmap_ids if bid not in [bm.id for bm in beatmaps]] beatmaps.extend( beatmap for beatmap in await asyncio.gather( - *[ - Beatmap.get_or_fetch(db, fetcher, bid=bid) - for bid in not_found_beatmaps - ], + *[Beatmap.get_or_fetch(db, fetcher, bid=bid) for bid in not_found_beatmaps], return_exceptions=True, ) if isinstance(beatmap, Beatmap) @@ -140,12 +115,7 @@ async def batch_get_beatmaps( for beatmap in beatmaps: await db.refresh(beatmap) - return BatchGetResp( - beatmaps=[ - await BeatmapResp.from_db(bm, session=db, user=current_user) - for bm in beatmaps - ] - ) + return BatchGetResp(beatmaps=[await BeatmapResp.from_db(bm, session=db, user=current_user) for bm in beatmaps]) @router.post( @@ -163,12 +133,8 @@ async def get_beatmap_attributes( default_factory=list, description="Mods 列表;可为整型位掩码(单元素)或 JSON/简称", ), - ruleset: GameMode | None = Query( - default=None, description="指定 ruleset;为空则使用谱面自身模式" - ), - ruleset_id: int | None = Query( - default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3 - ), + ruleset: GameMode | None = Query(default=None, description="指定 ruleset;为空则使用谱面自身模式"), + ruleset_id: int | None = Query(default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3), redis: Redis = Depends(get_redis), fetcher: Fetcher = Depends(get_fetcher), ): @@ -187,16 +153,11 @@ async def get_beatmap_attributes( if ruleset is None: beatmap_db = await Beatmap.get_or_fetch(db, fetcher, beatmap_id) ruleset = beatmap_db.mode - key = ( - f"beatmap:{beatmap_id}:{ruleset}:" - f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes" - ) + key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes" if await redis.exists(key): return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType] try: - return await calculate_beatmap_attributes( - beatmap_id, ruleset, mods_, redis, fetcher - ) + return await calculate_beatmap_attributes(beatmap_id, ruleset, mods_, redis, fetcher) except HTTPStatusError: raise HTTPException(status_code=404, detail="Beatmap not found") except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue] diff --git a/app/router/v2/beatmapset.py b/app/router/v2/beatmapset.py index 5d4dadc..c0f50df 100644 --- a/app/router/v2/beatmapset.py +++ b/app/router/v2/beatmapset.py @@ -35,9 +35,7 @@ from sqlmodel import exists, select async def _save_to_db(sets: SearchBeatmapsetsResp): async with with_db() as session: for s in sets.beatmapsets: - if not ( - await session.exec(select(exists()).where(Beatmapset.id == s.id)) - ).first(): + if not (await session.exec(select(exists()).where(Beatmapset.id == s.id))).first(): await Beatmapset.from_resp(session, s) @@ -117,9 +115,7 @@ async def lookup_beatmapset( fetcher: Fetcher = Depends(get_fetcher), ): beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id) - resp = await BeatmapsetResp.from_db( - beatmap.beatmapset, session=db, user=current_user - ) + resp = await BeatmapsetResp.from_db(beatmap.beatmapset, session=db, user=current_user) return resp @@ -138,9 +134,7 @@ async def get_beatmapset( ): try: beatmapset = await Beatmapset.get_or_fetch(db, fetcher, beatmapset_id) - return await BeatmapsetResp.from_db( - beatmapset, session=db, include=["recent_favourites"], user=current_user - ) + return await BeatmapsetResp.from_db(beatmapset, session=db, include=["recent_favourites"], user=current_user) except HTTPError: raise HTTPException(status_code=404, detail="Beatmapset not found") @@ -165,9 +159,7 @@ async def download_beatmapset( country_code = geo_info.get("country_iso", "") # 优先使用IP地理位置判断,如果获取失败则回退到用户账户的国家代码 - is_china = country_code == "CN" or ( - not country_code and current_user.country_code == "CN" - ) + is_china = country_code == "CN" or (not country_code and current_user.country_code == "CN") try: # 使用负载均衡服务获取下载URL @@ -179,13 +171,10 @@ async def download_beatmapset( # 如果负载均衡服务失败,回退到原有逻辑 if is_china: return RedirectResponse( - f"https://dl.sayobot.cn/beatmaps/download/" - f"{'novideo' if no_video else 'full'}/{beatmapset_id}" + f"https://dl.sayobot.cn/beatmaps/download/{'novideo' if no_video else 'full'}/{beatmapset_id}" ) else: - return RedirectResponse( - f"https://api.nerinyan.moe/d/{beatmapset_id}?noVideo={no_video}" - ) + return RedirectResponse(f"https://api.nerinyan.moe/d/{beatmapset_id}?noVideo={no_video}") @router.post( @@ -197,12 +186,9 @@ async def download_beatmapset( async def favourite_beatmapset( db: Database, beatmapset_id: int = Path(..., description="谱面集 ID"), - action: Literal["favourite", "unfavourite"] = Form( - description="操作类型:favourite 收藏 / unfavourite 取消收藏" - ), + action: Literal["favourite", "unfavourite"] = Form(description="操作类型:favourite 收藏 / unfavourite 取消收藏"), current_user: User = Security(get_client_user), ): - assert current_user.id is not None existing_favourite = ( await db.exec( select(FavouriteBeatmapset).where( @@ -212,15 +198,11 @@ async def favourite_beatmapset( ) ).first() - if (action == "favourite" and existing_favourite) or ( - action == "unfavourite" and not existing_favourite - ): + if (action == "favourite" and existing_favourite) or (action == "unfavourite" and not existing_favourite): return if action == "favourite": - favourite = FavouriteBeatmapset( - user_id=current_user.id, beatmapset_id=beatmapset_id - ) + favourite = FavouriteBeatmapset(user_id=current_user.id, beatmapset_id=beatmapset_id) db.add(favourite) else: await db.delete(existing_favourite) diff --git a/app/router/v2/me.py b/app/router/v2/me.py index e7b80ae..16d0c40 100644 --- a/app/router/v2/me.py +++ b/app/router/v2/me.py @@ -4,8 +4,8 @@ from app.database import User from app.database.lazer_user import ALL_INCLUDED from app.dependencies import get_current_user from app.dependencies.database import Database -from app.models.score import GameMode from app.models.api_me import APIMe +from app.models.score import GameMode from .router import router diff --git a/app/router/v2/misc.py b/app/router/v2/misc.py index e0e58db..bd67695 100644 --- a/app/router/v2/misc.py +++ b/app/router/v2/misc.py @@ -33,6 +33,4 @@ class BackgroundsResp(BaseModel): description="获取当前季节背景图列表。", ) async def get_seasonal_backgrounds(): - return BackgroundsResp( - backgrounds=[Background(url=url) for url in settings.seasonal_backgrounds] - ) + return BackgroundsResp(backgrounds=[Background(url=url) for url in settings.seasonal_backgrounds]) diff --git a/app/router/v2/ranking.py b/app/router/v2/ranking.py index dcb6d8f..fe884c9 100644 --- a/app/router/v2/ranking.py +++ b/app/router/v2/ranking.py @@ -12,7 +12,7 @@ from app.service.ranking_cache_service import get_ranking_cache_service from .router import router -from fastapi import Path, Query, Security +from fastapi import BackgroundTasks, Path, Query, Security from pydantic import BaseModel from sqlmodel import col, select @@ -38,6 +38,7 @@ class CountryResponse(BaseModel): ) async def get_country_ranking( session: Database, + background_tasks: BackgroundTasks, ruleset: GameMode = Path(..., description="指定 ruleset"), page: int = Query(1, ge=1, description="页码"), # TODO current_user: User = Security(get_current_user, scopes=["public"]), @@ -51,9 +52,7 @@ async def get_country_ranking( if cached_data: # 从缓存返回数据 - return CountryResponse( - ranking=[CountryStatistics.model_validate(item) for item in cached_data] - ) + return CountryResponse(ranking=[CountryStatistics.model_validate(item) for item in cached_data]) # 缓存未命中,从数据库查询 response = CountryResponse(ranking=[]) @@ -105,14 +104,15 @@ async def get_country_ranking( # 异步缓存数据(不等待完成) cache_data = [item.model_dump() for item in current_page_data] - cache_task = cache_service.cache_country_ranking( - ruleset, cache_data, page, ttl=settings.ranking_cache_expire_minutes * 60 - ) # 创建后台任务来缓存数据 - import asyncio - - asyncio.create_task(cache_task) + background_tasks.add_task( + cache_service.cache_country_ranking, + ruleset, + cache_data, + page, + ttl=settings.ranking_cache_expire_minutes * 60, + ) # 返回当前页的结果 response.ranking = current_page_data @@ -132,10 +132,9 @@ class TopUsersResponse(BaseModel): ) async def get_user_ranking( session: Database, + background_tasks: BackgroundTasks, ruleset: GameMode = Path(..., description="指定 ruleset"), - type: Literal["performance", "score"] = Path( - ..., description="排名类型:performance 表现分 / score 计分成绩总分" - ), + type: Literal["performance", "score"] = Path(..., description="排名类型:performance 表现分 / score 计分成绩总分"), country: str | None = Query(None, description="国家代码"), page: int = Query(1, ge=1, description="页码"), current_user: User = Security(get_current_user, scopes=["public"]), @@ -149,9 +148,7 @@ async def get_user_ranking( if cached_data: # 从缓存返回数据 - return TopUsersResponse( - ranking=[UserStatisticsResp.model_validate(item) for item in cached_data] - ) + return TopUsersResponse(ranking=[UserStatisticsResp.model_validate(item) for item in cached_data]) # 缓存未命中,从数据库查询 wheres = [ @@ -169,25 +166,22 @@ async def get_user_ranking( wheres.append(col(UserStatistics.user).has(country_code=country.upper())) statistics_list = await session.exec( - select(UserStatistics) - .where(*wheres) - .order_by(order_by) - .limit(50) - .offset(50 * (page - 1)) + select(UserStatistics).where(*wheres).order_by(order_by).limit(50).offset(50 * (page - 1)) ) # 转换为响应格式 ranking_data = [] for statistics in statistics_list: - user_stats_resp = await UserStatisticsResp.from_db( - statistics, session, None, include - ) + user_stats_resp = await UserStatisticsResp.from_db(statistics, session, None, include) ranking_data.append(user_stats_resp) # 异步缓存数据(不等待完成) # 使用配置文件中的TTL设置 cache_data = [item.model_dump() for item in ranking_data] - cache_task = cache_service.cache_ranking( + # 创建后台任务来缓存数据 + + background_tasks.add_task( + cache_service.cache_ranking, ruleset, type, cache_data, @@ -196,139 +190,134 @@ async def get_user_ranking( ttl=settings.ranking_cache_expire_minutes * 60, ) - # 创建后台任务来缓存数据 - import asyncio - - asyncio.create_task(cache_task) - resp = TopUsersResponse(ranking=ranking_data) return resp -""" @router.post( - "/rankings/cache/refresh", - name="刷新排行榜缓存", - description="手动刷新排行榜缓存(管理员功能)", - tags=["排行榜", "管理"], -) -async def refresh_ranking_cache( - session: Database, - ruleset: GameMode | None = Query(None, description="指定要刷新的游戏模式,不指定则刷新所有"), - type: Literal["performance", "score"] | None = Query(None, description="指定要刷新的排名类型,不指定则刷新所有"), - country: str | None = Query(None, description="指定要刷新的国家,不指定则刷新所有"), - include_country_ranking: bool = Query(True, description="是否包含地区排行榜"), - current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限 -): - redis = get_redis() - cache_service = get_ranking_cache_service(redis) - - if ruleset and type: - # 刷新特定的用户排行榜 - await cache_service.refresh_ranking_cache(session, ruleset, type, country) - message = f"Refreshed ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "") - - # 如果请求刷新地区排行榜 - if include_country_ranking and not country: # 地区排行榜不依赖于国家参数 - await cache_service.refresh_country_ranking_cache(session, ruleset) - message += f" and country ranking for {ruleset}" - - return {"message": message} - elif ruleset: - # 刷新特定游戏模式的所有排行榜 - ranking_types: list[Literal["performance", "score"]] = ["performance", "score"] - for ranking_type in ranking_types: - await cache_service.refresh_ranking_cache(session, ruleset, ranking_type, country) - - if include_country_ranking: - await cache_service.refresh_country_ranking_cache(session, ruleset) - - return {"message": f"Refreshed all ranking caches for {ruleset}"} - else: - # 刷新所有排行榜 - await cache_service.refresh_all_rankings(session) - return {"message": "Refreshed all ranking caches"} +# @router.post( +# "/rankings/cache/refresh", +# name="刷新排行榜缓存", +# description="手动刷新排行榜缓存(管理员功能)", +# tags=["排行榜", "管理"], +# ) +# async def refresh_ranking_cache( +# session: Database, +# ruleset: GameMode | None = Query(None, description="指定要刷新的游戏模式,不指定则刷新所有"), +# type: Literal["performance", "score"] | None = Query(None, description="指定要刷新的排名类型,不指定则刷新所有"), +# country: str | None = Query(None, description="指定要刷新的国家,不指定则刷新所有"), +# include_country_ranking: bool = Query(True, description="是否包含地区排行榜"), +# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限 +# ): +# redis = get_redis() +# cache_service = get_ranking_cache_service(redis) + +# if ruleset and type: +# # 刷新特定的用户排行榜 +# await cache_service.refresh_ranking_cache(session, ruleset, type, country) +# message = f"Refreshed ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "") + +# # 如果请求刷新地区排行榜 +# if include_country_ranking and not country: # 地区排行榜不依赖于国家参数 +# await cache_service.refresh_country_ranking_cache(session, ruleset) +# message += f" and country ranking for {ruleset}" + +# return {"message": message} +# elif ruleset: +# # 刷新特定游戏模式的所有排行榜 +# ranking_types: list[Literal["performance", "score"]] = ["performance", "score"] +# for ranking_type in ranking_types: +# await cache_service.refresh_ranking_cache(session, ruleset, ranking_type, country) + +# if include_country_ranking: +# await cache_service.refresh_country_ranking_cache(session, ruleset) + +# return {"message": f"Refreshed all ranking caches for {ruleset}"} +# else: +# # 刷新所有排行榜 +# await cache_service.refresh_all_rankings(session) +# return {"message": "Refreshed all ranking caches"} -@router.post( - "/rankings/{ruleset}/country/cache/refresh", - name="刷新地区排行榜缓存", - description="手动刷新地区排行榜缓存(管理员功能)", - tags=["排行榜", "管理"], -) -async def refresh_country_ranking_cache( - session: Database, - ruleset: GameMode = Path(..., description="指定要刷新的游戏模式"), - current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限 -): - redis = get_redis() - cache_service = get_ranking_cache_service(redis) - - await cache_service.refresh_country_ranking_cache(session, ruleset) - return {"message": f"Refreshed country ranking cache for {ruleset}"} +# @router.post( +# "/rankings/{ruleset}/country/cache/refresh", +# name="刷新地区排行榜缓存", +# description="手动刷新地区排行榜缓存(管理员功能)", +# tags=["排行榜", "管理"], +# ) +# async def refresh_country_ranking_cache( +# session: Database, +# ruleset: GameMode = Path(..., description="指定要刷新的游戏模式"), +# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限 +# ): +# redis = get_redis() +# cache_service = get_ranking_cache_service(redis) + +# await cache_service.refresh_country_ranking_cache(session, ruleset) +# return {"message": f"Refreshed country ranking cache for {ruleset}"} -@router.delete( - "/rankings/cache", - name="清除排行榜缓存", - description="清除排行榜缓存(管理员功能)", - tags=["排行榜", "管理"], -) -async def clear_ranking_cache( - ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"), - type: Literal["performance", "score"] | None = Query(None, description="指定要清除的排名类型,不指定则清除所有"), - country: str | None = Query(None, description="指定要清除的国家,不指定则清除所有"), - include_country_ranking: bool = Query(True, description="是否包含地区排行榜"), - current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限 -): - redis = get_redis() - cache_service = get_ranking_cache_service(redis) - - await cache_service.invalidate_cache(ruleset, type, country, include_country_ranking) - - if ruleset and type: - message = f"Cleared ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "") - if include_country_ranking: - message += " and country ranking" - return {"message": message} - else: - message = "Cleared all ranking caches" - if include_country_ranking: - message += " including country rankings" - return {"message": message} +# @router.delete( +# "/rankings/cache", +# name="清除排行榜缓存", +# description="清除排行榜缓存(管理员功能)", +# tags=["排行榜", "管理"], +# ) +# async def clear_ranking_cache( +# ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"), +# type: Literal["performance", "score"] | None = Query(None, description="指定要清除的排名类型,不指定则清除所有"), +# country: str | None = Query(None, description="指定要清除的国家,不指定则清除所有"), +# include_country_ranking: bool = Query(True, description="是否包含地区排行榜"), +# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限 +# ): +# redis = get_redis() +# cache_service = get_ranking_cache_service(redis) + +# await cache_service.invalidate_cache(ruleset, type, country, include_country_ranking) + +# if ruleset and type: +# message = f"Cleared ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "") +# if include_country_ranking: +# message += " and country ranking" +# return {"message": message} +# else: +# message = "Cleared all ranking caches" +# if include_country_ranking: +# message += " including country rankings" +# return {"message": message} -@router.delete( - "/rankings/{ruleset}/country/cache", - name="清除地区排行榜缓存", - description="清除地区排行榜缓存(管理员功能)", - tags=["排行榜", "管理"], -) -async def clear_country_ranking_cache( - ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"), - current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限 -): - redis = get_redis() - cache_service = get_ranking_cache_service(redis) - - await cache_service.invalidate_country_cache(ruleset) - - if ruleset: - return {"message": f"Cleared country ranking cache for {ruleset}"} - else: - return {"message": "Cleared all country ranking caches"} +# @router.delete( +# "/rankings/{ruleset}/country/cache", +# name="清除地区排行榜缓存", +# description="清除地区排行榜缓存(管理员功能)", +# tags=["排行榜", "管理"], +# ) +# async def clear_country_ranking_cache( +# ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"), +# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限 +# ): +# redis = get_redis() +# cache_service = get_ranking_cache_service(redis) + +# await cache_service.invalidate_country_cache(ruleset) + +# if ruleset: +# return {"message": f"Cleared country ranking cache for {ruleset}"} +# else: +# return {"message": "Cleared all country ranking caches"} -@router.get( - "/rankings/cache/stats", - name="获取排行榜缓存统计", - description="获取排行榜缓存统计信息(管理员功能)", - tags=["排行榜", "管理"], -) -async def get_ranking_cache_stats( - current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限 -): - redis = get_redis() - cache_service = get_ranking_cache_service(redis) - - stats = await cache_service.get_cache_stats() - return stats """ +# @router.get( +# "/rankings/cache/stats", +# name="获取排行榜缓存统计", +# description="获取排行榜缓存统计信息(管理员功能)", +# tags=["排行榜", "管理"], +# ) +# async def get_ranking_cache_stats( +# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限 +# ): +# redis = get_redis() +# cache_service = get_ranking_cache_service(redis) + +# stats = await cache_service.get_cache_stats() +# return stats diff --git a/app/router/v2/relationship.py b/app/router/v2/relationship.py index accbde8..a941f8a 100644 --- a/app/router/v2/relationship.py +++ b/app/router/v2/relationship.py @@ -30,11 +30,7 @@ async def get_relationship( request: Request, current_user: User = Security(get_current_user, scopes=["friends.read"]), ): - relationship_type = ( - RelationshipType.FOLLOW - if request.url.path.endswith("/friends") - else RelationshipType.BLOCK - ) + relationship_type = RelationshipType.FOLLOW if request.url.path.endswith("/friends") else RelationshipType.BLOCK relationships = await db.exec( select(Relationship).where( Relationship.user_id == current_user.id, @@ -71,12 +67,7 @@ async def add_relationship( target: int = Query(description="目标用户 ID"), current_user: User = Security(get_client_user), ): - assert current_user.id is not None - relationship_type = ( - RelationshipType.FOLLOW - if request.url.path.endswith("/friends") - else RelationshipType.BLOCK - ) + relationship_type = RelationshipType.FOLLOW if request.url.path.endswith("/friends") else RelationshipType.BLOCK if target == current_user.id: raise HTTPException(422, "Cannot add relationship to yourself") relationship = ( @@ -120,11 +111,8 @@ async def add_relationship( Relationship.target_id == target, ) ) - ).first() - assert relationship, "Relationship should exist after commit" - return AddFriendResp( - user_relation=await RelationshipResp.from_db(db, relationship) - ) + ).one() + return AddFriendResp(user_relation=await RelationshipResp.from_db(db, relationship)) @router.delete( @@ -145,11 +133,7 @@ async def delete_relationship( target: int = Path(..., description="目标用户 ID"), current_user: User = Security(get_client_user), ): - relationship_type = ( - RelationshipType.BLOCK - if "/blocks/" in request.url.path - else RelationshipType.FOLLOW - ) + relationship_type = RelationshipType.BLOCK if "/blocks/" in request.url.path else RelationshipType.FOLLOW relationship = ( await db.exec( select(Relationship).where( diff --git a/app/router/v2/room.py b/app/router/v2/room.py index 1f4fd79..03341d0 100644 --- a/app/router/v2/room.py +++ b/app/router/v2/room.py @@ -39,17 +39,11 @@ async def get_all_rooms( db: Database, mode: Literal["open", "ended", "participated", "owned", None] = Query( default="open", - description=( - "房间模式:open 当前开放 / ended 已经结束 / " - "participated 参与过 / owned 自己创建的房间" - ), + description=("房间模式:open 当前开放 / ended 已经结束 / participated 参与过 / owned 自己创建的房间"), ), category: RoomCategory = Query( RoomCategory.NORMAL, - description=( - "房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间" - " / DAILY_CHALLENGE 每日挑战" - ), + description=("房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战"), ), status: RoomStatus | None = Query(None, description="房间状态(可选)"), current_user: User = Security(get_current_user, scopes=["public"]), @@ -60,10 +54,7 @@ async def get_all_rooms( if status is not None: where_clauses.append(col(Room.status) == status) if mode == "open": - where_clauses.append( - (col(Room.ends_at).is_(None)) - | (col(Room.ends_at) > now.replace(tzinfo=UTC)) - ) + where_clauses.append((col(Room.ends_at).is_(None)) | (col(Room.ends_at) > now.replace(tzinfo=UTC))) if category == RoomCategory.REALTIME: where_clauses.append(col(Room.id).in_(MultiplayerHubs.rooms.keys())) if mode == "participated": @@ -76,10 +67,7 @@ async def get_all_rooms( if mode == "owned": where_clauses.append(col(Room.host_id) == current_user.id) if mode == "ended": - where_clauses.append( - (col(Room.ends_at).is_not(None)) - & (col(Room.ends_at) < now.replace(tzinfo=UTC)) - ) + where_clauses.append((col(Room.ends_at).is_not(None)) & (col(Room.ends_at) < now.replace(tzinfo=UTC))) db_rooms = ( ( @@ -97,11 +85,7 @@ async def get_all_rooms( resp = await RoomResp.from_db(room, db) if category == RoomCategory.REALTIME: mp_room = MultiplayerHubs.rooms.get(room.id) - resp.has_password = ( - bool(mp_room.room.settings.password.strip()) - if mp_room is not None - else False - ) + resp.has_password = bool(mp_room.room.settings.password.strip()) if mp_room is not None else False resp.category = RoomCategory.NORMAL resp_list.append(resp) @@ -115,9 +99,7 @@ class APICreatedRoom(RoomResp): error: str = "" -async def _participate_room( - room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis -): +async def _participate_room(room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis): participated_user = ( await session.exec( select(RoomParticipatedUser).where( @@ -154,7 +136,6 @@ async def create_room( current_user: User = Security(get_client_user), redis: Redis = Depends(get_redis), ): - assert current_user.id is not None user_id = current_user.id db_room = await create_playlist_room_from_api(db, room, user_id) await _participate_room(db_room.id, user_id, db_room, db, redis) @@ -177,10 +158,7 @@ async def get_room( room_id: int = Path(..., description="房间 ID"), category: str = Query( default="", - description=( - "房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间" - " / DAILY_CHALLENGE 每日挑战 (可选)" - ), + description=("房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战 (可选)"), ), current_user: User = Security(get_client_user), redis: Redis = Depends(get_redis), @@ -188,9 +166,7 @@ async def get_room( db_room = (await db.exec(select(Room).where(Room.id == room_id))).first() if db_room is None: raise HTTPException(404, "Room not found") - resp = await RoomResp.from_db( - db_room, include=["current_user_score"], session=db, user=current_user - ) + resp = await RoomResp.from_db(db_room, include=["current_user_score"], session=db, user=current_user) return resp @@ -400,7 +376,6 @@ async def get_room_events( for score in scores: user_ids.add(score.user_id) beatmap_ids.add(score.beatmap_id) - assert event.id is not None first_event_id = min(first_event_id, event.id) last_event_id = max(last_event_id, event.id) @@ -416,16 +391,12 @@ async def get_room_events( users = await db.exec(select(User).where(col(User.id).in_(user_ids))) user_resps = [await UserResp.from_db(user, db) for user in users] beatmaps = await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids))) - beatmap_resps = [ - await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps - ] + beatmap_resps = [await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps] beatmapset_resps = {} for beatmap_resp in beatmap_resps: beatmapset_resps[beatmap_resp.beatmapset_id] = beatmap_resp.beatmapset - playlist_items_resps = [ - await PlaylistResp.from_db(item) for item in playlist_items.values() - ] + playlist_items_resps = [await PlaylistResp.from_db(item) for item in playlist_items.values()] return RoomEvents( beatmaps=beatmap_resps, diff --git a/app/router/v2/score.py b/app/router/v2/score.py index 8838d8f..9007bb6 100644 --- a/app/router/v2/score.py +++ b/app/router/v2/score.py @@ -104,11 +104,7 @@ async def submit_score( if not info.passed: info.rank = Rank.F score_token = ( - await db.exec( - select(ScoreToken) - .options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType] - .where(ScoreToken.id == token) - ) + await db.exec(select(ScoreToken).options(joinedload(ScoreToken.beatmap)).where(ScoreToken.id == token)) ).first() if not score_token or score_token.user_id != user_id: raise HTTPException(status_code=404, detail="Score token not found") @@ -138,10 +134,7 @@ async def submit_score( except HTTPError: raise HTTPException(status_code=404, detail="Beatmap not found") has_pp = db_beatmap.beatmap_status.has_pp() | settings.enable_all_beatmap_pp - has_leaderboard = ( - db_beatmap.beatmap_status.has_leaderboard() - | settings.enable_all_beatmap_leaderboard - ) + has_leaderboard = db_beatmap.beatmap_status.has_leaderboard() | settings.enable_all_beatmap_leaderboard beatmap_length = db_beatmap.total_length score = await process_score( current_user, @@ -167,21 +160,11 @@ async def submit_score( has_pp, has_leaderboard, ) - score = ( - await db.exec( - select(Score) - .options(joinedload(Score.user)) # pyright: ignore[reportArgumentType] - .where(Score.id == score_id) - ) - ).first() - assert score is not None + score = (await db.exec(select(Score).options(joinedload(Score.user)).where(Score.id == score_id))).one() resp = await ScoreResp.from_db(db, score) - total_users = (await db.exec(select(func.count()).select_from(User))).first() - assert total_users is not None - if resp.rank_global is not None and resp.rank_global <= min( - math.ceil(float(total_users) * 0.01), 50 - ): + total_users = (await db.exec(select(func.count()).select_from(User))).one() + if resp.rank_global is not None and resp.rank_global <= min(math.ceil(float(total_users) * 0.01), 50): rank_event = Event( created_at=datetime.now(UTC), type=EventType.RANK, @@ -207,9 +190,7 @@ async def submit_score( score_gamemode = score.gamemode if user_id is not None: - background_task.add_task( - _refresh_user_cache_background, redis, user_id, score_gamemode - ) + background_task.add_task(_refresh_user_cache_background, redis, user_id, score_gamemode) background_task.add_task(process_user_achievement, resp.id) return resp @@ -225,9 +206,7 @@ async def _refresh_user_cache_background(redis: Redis, user_id: int, mode: GameM # 创建独立的数据库会话 session = AsyncSession(engine) try: - await user_cache_service.refresh_user_cache_on_score_submit( - session, user_id, mode - ) + await user_cache_service.refresh_user_cache_on_score_submit(session, user_id, mode) finally: await session.close() except Exception as e: @@ -280,22 +259,16 @@ async def get_beatmap_scores( beatmap_id: int = Path(description="谱面 ID"), mode: GameMode = Query(description="指定 auleset"), legacy_only: bool = Query(None, description="是否只查询 Stable 分数"), - mods: list[str] = Query( - default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)" - ), + mods: list[str] = Query(default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)"), type: LeaderboardType = Query( LeaderboardType.GLOBAL, - description=( - "排行榜类型:GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队" - ), + description=("排行榜类型:GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队"), ), current_user: User = Security(get_current_user, scopes=["public"]), limit: int = Query(50, ge=1, le=200, description="返回条数 (1-200)"), ): if legacy_only: - raise HTTPException( - status_code=404, detail="this server only contains lazer scores" - ) + raise HTTPException(status_code=404, detail="this server only contains lazer scores") all_scores, user_score, count = await get_leaderboard( db, @@ -310,9 +283,7 @@ async def get_beatmap_scores( user_score_resp = await ScoreResp.from_db(db, user_score) if user_score else None resp = BeatmapScores( scores=[await ScoreResp.from_db(db, score) for score in all_scores], - user_score=BeatmapUserScore( - score=user_score_resp, position=user_score_resp.rank_global or 0 - ) + user_score=BeatmapUserScore(score=user_score_resp, position=user_score_resp.rank_global or 0) if user_score_resp else None, score_count=count, @@ -342,9 +313,7 @@ async def get_user_beatmap_score( current_user: User = Security(get_current_user, scopes=["public"]), ): if legacy_only: - raise HTTPException( - status_code=404, detail="This server only contains non-legacy scores" - ) + raise HTTPException(status_code=404, detail="This server only contains non-legacy scores") user_score = ( await db.exec( select(Score) @@ -386,9 +355,7 @@ async def get_user_all_beatmap_scores( current_user: User = Security(get_current_user, scopes=["public"]), ): if legacy_only: - raise HTTPException( - status_code=404, detail="This server only contains non-legacy scores" - ) + raise HTTPException(status_code=404, detail="This server only contains non-legacy scores") all_user_scores = ( await db.exec( select(Score) @@ -420,7 +387,6 @@ async def create_solo_score( ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"), current_user: User = Security(get_client_user), ): - assert current_user.id is not None # 立即获取用户ID,避免懒加载问题 user_id = current_user.id @@ -454,10 +420,7 @@ async def submit_solo_score( redis: Redis = Depends(get_redis), fetcher=Depends(get_fetcher), ): - assert current_user.id is not None - return await submit_score( - background_task, info, beatmap_id, token, current_user, db, redis, fetcher - ) + return await submit_score(background_task, info, beatmap_id, token, current_user, db, redis, fetcher) @router.post( @@ -478,7 +441,6 @@ async def create_playlist_score( version_hash: str = Form("", description="谱面版本哈希"), current_user: User = Security(get_client_user), ): - assert current_user.id is not None # 立即获取用户ID,避免懒加载问题 user_id = current_user.id @@ -488,26 +450,16 @@ async def create_playlist_score( db_room_time = room.ends_at.replace(tzinfo=UTC) if room.ends_at else None if db_room_time and db_room_time < datetime.now(UTC).replace(tzinfo=UTC): raise HTTPException(status_code=400, detail="Room has ended") - item = ( - await session.exec( - select(Playlist).where( - Playlist.id == playlist_id, Playlist.room_id == room_id - ) - ) - ).first() + item = (await session.exec(select(Playlist).where(Playlist.id == playlist_id, Playlist.room_id == room_id))).first() if not item: raise HTTPException(status_code=404, detail="Playlist not found") # validate if not item.freestyle: if item.ruleset_id != ruleset_id: - raise HTTPException( - status_code=400, detail="Ruleset mismatch in playlist item" - ) + raise HTTPException(status_code=400, detail="Ruleset mismatch in playlist item") if item.beatmap_id != beatmap_id: - raise HTTPException( - status_code=400, detail="Beatmap ID mismatch in playlist item" - ) + raise HTTPException(status_code=400, detail="Beatmap ID mismatch in playlist item") agg = await session.exec( select(ItemAttemptsCount).where( ItemAttemptsCount.room_id == room_id, @@ -523,9 +475,7 @@ async def create_playlist_score( if item.expired: raise HTTPException(status_code=400, detail="Playlist item has expired") if item.played_at: - raise HTTPException( - status_code=400, detail="Playlist item has already been played" - ) + raise HTTPException(status_code=400, detail="Playlist item has already been played") # 这里应该不用验证mod了吧。。。 background_task.add_task(_preload_beatmap_for_pp_calculation, beatmap_id) score_token = ScoreToken( @@ -557,18 +507,10 @@ async def submit_playlist_score( redis: Redis = Depends(get_redis), fetcher: Fetcher = Depends(get_fetcher), ): - assert current_user.id is not None - # 立即获取用户ID,避免懒加载问题 user_id = current_user.id - item = ( - await session.exec( - select(Playlist).where( - Playlist.id == playlist_id, Playlist.room_id == room_id - ) - ) - ).first() + item = (await session.exec(select(Playlist).where(Playlist.id == playlist_id, Playlist.room_id == room_id))).first() if not item: raise HTTPException(status_code=404, detail="Playlist item not found") room = await session.get(Room, room_id) @@ -621,9 +563,7 @@ async def index_playlist_scores( room_id: int, playlist_id: int, limit: int = Query(50, ge=1, le=50, description="返回条数 (1-50)"), - cursor: int = Query( - 2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)" - ), + cursor: int = Query(2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)"), current_user: User = Security(get_current_user, scopes=["public"]), ): # 立即获取用户ID,避免懒加载问题 @@ -693,9 +633,6 @@ async def show_playlist_score( current_user: User = Security(get_client_user), redis: Redis = Depends(get_redis), ): - # 立即获取用户ID,避免懒加载问题 - user_id = current_user.id - room = await session.get(Room, room_id) if not room: raise HTTPException(status_code=404, detail="Room not found") @@ -715,9 +652,7 @@ async def show_playlist_score( ) ) ).first() - if completed_players := await redis.get( - f"multiplayer:{room_id}:gameplay:players" - ): + if completed_players := await redis.get(f"multiplayer:{room_id}:gameplay:players"): completed = completed_players == "0" if score_record and completed: break @@ -784,9 +719,7 @@ async def get_user_playlist_score( raise HTTPException(status_code=404, detail="Score not found") resp = await ScoreResp.from_db(session, score_record.score) - resp.position = await get_position( - room_id, playlist_id, score_record.score_id, session - ) + resp.position = await get_position(room_id, playlist_id, score_record.score_id, session) return resp @@ -850,11 +783,7 @@ async def unpin_score( # 立即获取用户ID,避免懒加载问题 user_id = current_user.id - score_record = ( - await db.exec( - select(Score).where(Score.id == score_id, Score.user_id == user_id) - ) - ).first() + score_record = (await db.exec(select(Score).where(Score.id == score_id, Score.user_id == user_id))).first() if not score_record: raise HTTPException(status_code=404, detail="Score not found") @@ -878,10 +807,7 @@ async def unpin_score( "/score-pins/{score_id}/reorder", status_code=204, name="调整置顶成绩顺序", - description=( - "**客户端专属**\n调整已置顶成绩的展示顺序。" - "仅提供 after_score_id 或 before_score_id 之一。" - ), + description=("**客户端专属**\n调整已置顶成绩的展示顺序。仅提供 after_score_id 或 before_score_id 之一。"), tags=["成绩"], ) async def reorder_score_pin( @@ -894,11 +820,7 @@ async def reorder_score_pin( # 立即获取用户ID,避免懒加载问题 user_id = current_user.id - score_record = ( - await db.exec( - select(Score).where(Score.id == score_id, Score.user_id == user_id) - ) - ).first() + score_record = (await db.exec(select(Score).where(Score.id == score_id, Score.user_id == user_id))).first() if not score_record: raise HTTPException(status_code=404, detail="Score not found") @@ -908,8 +830,7 @@ async def reorder_score_pin( if (after_score_id is None) == (before_score_id is None): raise HTTPException( status_code=400, - detail="Either after_score_id or before_score_id " - "must be provided (but not both)", + detail="Either after_score_id or before_score_id must be provided (but not both)", ) all_pinned_scores = ( @@ -927,9 +848,7 @@ async def reorder_score_pin( target_order = None reference_score_id = after_score_id or before_score_id - reference_score = next( - (s for s in all_pinned_scores if s.id == reference_score_id), None - ) + reference_score = next((s for s in all_pinned_scores if s.id == reference_score_id), None) if not reference_score: detail = "After score not found" if after_score_id else "Before score not found" raise HTTPException(status_code=404, detail=detail) @@ -951,9 +870,7 @@ async def reorder_score_pin( if current_order < s.pinned_order <= target_order and s.id != score_id: updates.append((s.id, s.pinned_order - 1)) if after_score_id: - final_target = ( - target_order - 1 if target_order > current_order else target_order - ) + final_target = target_order - 1 if target_order > current_order else target_order else: final_target = target_order else: @@ -964,9 +881,7 @@ async def reorder_score_pin( for score_id, new_order in updates: await db.exec(select(Score).where(Score.id == score_id)) - score_to_update = ( - await db.exec(select(Score).where(Score.id == score_id)) - ).first() + score_to_update = (await db.exec(select(Score).where(Score.id == score_id))).first() if score_to_update: score_to_update.pinned_order = new_order diff --git a/app/router/v2/session_verify.py b/app/router/v2/session_verify.py index 5203785..d8596ee 100644 --- a/app/router/v2/session_verify.py +++ b/app/router/v2/session_verify.py @@ -4,34 +4,29 @@ from __future__ import annotations -from datetime import datetime, UTC from typing import Annotated -from app.auth import authenticate_user -from app.config import settings from app.database import User from app.dependencies import get_current_user from app.dependencies.database import Database, get_redis from app.dependencies.geoip import GeoIPHelper, get_geoip_helper -from app.database.email_verification import EmailVerification, LoginSession from app.service.email_verification_service import ( - EmailVerificationService, - LoginSessionService + EmailVerificationService, + LoginSessionService, ) from app.service.login_log_service import LoginLogService -from app.models.extended_auth import ExtendedTokenResponse - -from fastapi import Form, Depends, Request, HTTPException, status, Security -from fastapi.responses import JSONResponse, Response -from pydantic import BaseModel -from redis.asyncio import Redis -from sqlmodel import select from .router import router +from fastapi import Depends, Form, HTTPException, Request, Security, status +from fastapi.responses import Response +from pydantic import BaseModel +from redis.asyncio import Redis + class SessionReissueResponse(BaseModel): """重新发送验证码响应""" + success: bool message: str @@ -40,39 +35,35 @@ class SessionReissueResponse(BaseModel): "/session/verify", name="验证会话", description="验证邮件验证码并完成会话认证", - status_code=204 + status_code=204, ) async def verify_session( request: Request, db: Database, redis: Annotated[Redis, Depends(get_redis)], verification_key: str = Form(..., description="8位邮件验证码"), - current_user: User = Security(get_current_user) + current_user: User = Security(get_current_user), ) -> Response: """ 验证邮件验证码并完成会话认证 - + 对应 osu! 的 session/verify 接口 成功时返回 204 No Content,失败时返回 401 Unauthorized """ try: from app.dependencies.geoip import get_client_ip - ip_address = get_client_ip(request) - user_agent = request.headers.get("User-Agent", "Unknown") - + + ip_address = get_client_ip(request) # noqa: F841 + user_agent = request.headers.get("User-Agent", "Unknown") # noqa: F841 + # 从当前认证用户获取信息 user_id = current_user.id if not user_id: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="用户未认证" - ) - + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户未认证") + # 验证邮件验证码 - success, message = await EmailVerificationService.verify_code( - db, redis, user_id, verification_key - ) - + success, message = await EmailVerificationService.verify_code(db, redis, user_id, verification_key) + if success: # 记录成功的邮件验证 await LoginLogService.record_login( @@ -81,9 +72,9 @@ async def verify_session( request=request, login_method="email_verification", login_success=True, - notes=f"邮件验证成功" + notes="邮件验证成功", ) - + # 返回 204 No Content 表示验证成功 return Response(status_code=status.HTTP_204_NO_CONTENT) else: @@ -93,83 +84,69 @@ async def verify_session( request=request, attempted_username=current_user.username, login_method="email_verification", - notes=f"邮件验证失败: {message}" + notes=f"邮件验证失败: {message}", ) - + # 返回 401 Unauthorized 表示验证失败 - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=message - ) - + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=message) + except ValueError: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="无效的用户会话" - ) - except Exception as e: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="验证过程中发生错误" - ) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的用户会话") + except Exception: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="验证过程中发生错误") @router.post( "/session/verify/reissue", name="重新发送验证码", description="重新发送邮件验证码", - response_model=SessionReissueResponse + response_model=SessionReissueResponse, ) async def reissue_verification_code( request: Request, db: Database, redis: Annotated[Redis, Depends(get_redis)], - current_user: User = Security(get_current_user) + current_user: User = Security(get_current_user), ) -> SessionReissueResponse: """ 重新发送邮件验证码 - + 对应 osu! 的 session/verify/reissue 接口 """ try: from app.dependencies.geoip import get_client_ip + ip_address = get_client_ip(request) user_agent = request.headers.get("User-Agent", "Unknown") - + # 从当前认证用户获取信息 user_id = current_user.id if not user_id: - return SessionReissueResponse( - success=False, - message="用户未认证" - ) - + return SessionReissueResponse(success=False, message="用户未认证") + # 重新发送验证码 success, message = await EmailVerificationService.resend_verification_code( - db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent + db, + redis, + user_id, + current_user.username, + current_user.email, + ip_address, + user_agent, ) - - return SessionReissueResponse( - success=success, - message=message - ) - + + return SessionReissueResponse(success=success, message=message) + except ValueError: - return SessionReissueResponse( - success=False, - message="无效的用户会话" - ) - except Exception as e: - return SessionReissueResponse( - success=False, - message="重新发送过程中发生错误" - ) + return SessionReissueResponse(success=False, message="无效的用户会话") + except Exception: + return SessionReissueResponse(success=False, message="重新发送过程中发生错误") @router.post( "/session/check-new-location", name="检查新位置登录", - description="检查登录是否来自新位置(内部接口)" + description="检查登录是否来自新位置(内部接口)", ) async def check_new_location( request: Request, @@ -183,22 +160,21 @@ async def check_new_location( """ try: from app.dependencies.geoip import get_client_ip + ip_address = get_client_ip(request) geo_info = geoip.lookup(ip_address) country_code = geo_info.get("country_iso", "XX") - - is_new_location = await LoginSessionService.check_new_location( - db, user_id, ip_address, country_code - ) - + + is_new_location = await LoginSessionService.check_new_location(db, user_id, ip_address, country_code) + return { "is_new_location": is_new_location, "ip_address": ip_address, - "country_code": country_code + "country_code": country_code, } - + except Exception as e: return { "is_new_location": True, # 出错时默认为新位置 - "error": str(e) + "error": str(e), } diff --git a/app/router/v2/stats.py b/app/router/v2/stats.py index c0c83dd..c4744fa 100644 --- a/app/router/v2/stats.py +++ b/app/router/v2/stats.py @@ -1,73 +1,80 @@ from __future__ import annotations import asyncio -from datetime import datetime, timedelta -import json -from typing import Any from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +import json from app.dependencies.database import get_redis, get_redis_message from app.log import logger +from app.utils import bg_tasks from .router import router -from fastapi import APIRouter from pydantic import BaseModel # Redis key constants REDIS_ONLINE_USERS_KEY = "server:online_users" -REDIS_PLAYING_USERS_KEY = "server:playing_users" +REDIS_PLAYING_USERS_KEY = "server:playing_users" REDIS_REGISTERED_USERS_KEY = "server:registered_users" REDIS_ONLINE_HISTORY_KEY = "server:online_history" # 线程池用于同步Redis操作 _executor = ThreadPoolExecutor(max_workers=2) + async def _redis_exec(func, *args, **kwargs): """在线程池中执行同步Redis操作""" loop = asyncio.get_event_loop() return await loop.run_in_executor(_executor, func, *args, **kwargs) + class ServerStats(BaseModel): """服务器统计信息响应模型""" + registered_users: int online_users: int playing_users: int timestamp: datetime + class OnlineHistoryPoint(BaseModel): """在线历史数据点""" + timestamp: datetime online_count: int playing_count: int + class OnlineHistoryResponse(BaseModel): """24小时在线历史响应模型""" + history: list[OnlineHistoryPoint] current_stats: ServerStats + @router.get("/stats", response_model=ServerStats, tags=["统计"]) async def get_server_stats() -> ServerStats: """ 获取服务器实时统计信息 - + 返回服务器注册用户数、在线用户数、正在游玩用户数等实时统计信息 """ redis = get_redis() - + try: # 并行获取所有统计数据 registered_count, online_count, playing_count = await asyncio.gather( _get_registered_users_count(redis), _get_online_users_count(redis), - _get_playing_users_count(redis) + _get_playing_users_count(redis), ) - + return ServerStats( registered_users=registered_count, online_users=online_count, playing_users=playing_count, - timestamp=datetime.utcnow() + timestamp=datetime.utcnow(), ) except Exception as e: logger.error(f"Error getting server stats: {e}") @@ -76,14 +83,15 @@ async def get_server_stats() -> ServerStats: registered_users=0, online_users=0, playing_users=0, - timestamp=datetime.utcnow() + timestamp=datetime.utcnow(), ) + @router.get("/stats/history", response_model=OnlineHistoryResponse, tags=["统计"]) async def get_online_history() -> OnlineHistoryResponse: """ 获取最近24小时在线统计历史 - + 返回过去24小时内每小时的在线用户数和游玩用户数统计, 包含当前实时数据作为最新数据点 """ @@ -92,80 +100,80 @@ async def get_online_history() -> OnlineHistoryResponse: redis_sync = get_redis_message() history_data = await _redis_exec(redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1) history_points = [] - + # 处理历史数据 for data in history_data: try: point_data = json.loads(data) # 只保留基本字段 - history_points.append(OnlineHistoryPoint( - timestamp=datetime.fromisoformat(point_data["timestamp"]), - online_count=point_data["online_count"], - playing_count=point_data["playing_count"] - )) + history_points.append( + OnlineHistoryPoint( + timestamp=datetime.fromisoformat(point_data["timestamp"]), + online_count=point_data["online_count"], + playing_count=point_data["playing_count"], + ) + ) except (json.JSONDecodeError, KeyError, ValueError) as e: logger.warning(f"Invalid history data point: {data}, error: {e}") continue - + # 获取当前实时统计信息 current_stats = await get_server_stats() - + # 如果历史数据为空或者最新数据超过15分钟,添加当前数据点 if not history_points or ( - history_points and - (current_stats.timestamp - max(history_points, key=lambda x: x.timestamp).timestamp).total_seconds() > 15 * 60 + history_points + and (current_stats.timestamp - max(history_points, key=lambda x: x.timestamp).timestamp).total_seconds() + > 15 * 60 ): - history_points.append(OnlineHistoryPoint( - timestamp=current_stats.timestamp, - online_count=current_stats.online_users, - playing_count=current_stats.playing_users - )) - + history_points.append( + OnlineHistoryPoint( + timestamp=current_stats.timestamp, + online_count=current_stats.online_users, + playing_count=current_stats.playing_users, + ) + ) + # 按时间排序(最新的在前) history_points.sort(key=lambda x: x.timestamp, reverse=True) - + # 限制到最多48个数据点(24小时) history_points = history_points[:48] - - return OnlineHistoryResponse( - history=history_points, - current_stats=current_stats - ) + + return OnlineHistoryResponse(history=history_points, current_stats=current_stats) except Exception as e: logger.error(f"Error getting online history: {e}") # 返回空历史和当前状态 current_stats = await get_server_stats() - return OnlineHistoryResponse( - history=[], - current_stats=current_stats - ) + return OnlineHistoryResponse(history=[], current_stats=current_stats) + @router.get("/stats/debug", tags=["统计"]) async def get_stats_debug_info(): """ 获取统计系统调试信息 - + 用于调试时间对齐和区间统计问题 """ try: from app.service.enhanced_interval_stats import EnhancedIntervalStatsManager - + current_time = datetime.utcnow() current_interval = await EnhancedIntervalStatsManager.get_current_interval_info() interval_stats = await EnhancedIntervalStatsManager.get_current_interval_stats() - + # 获取Redis中的实际数据 redis_sync = get_redis_message() - + online_key = f"server:interval_online_users:{current_interval.interval_key}" playing_key = f"server:interval_playing_users:{current_interval.interval_key}" - + online_users_raw = await _redis_exec(redis_sync.smembers, online_key) playing_users_raw = await _redis_exec(redis_sync.smembers, playing_key) - + online_users = [int(uid.decode() if isinstance(uid, bytes) else uid) for uid in online_users_raw] playing_users = [int(uid.decode() if isinstance(uid, bytes) else uid) for uid in playing_users_raw] - + return { "current_time": current_time.isoformat(), "current_interval": { @@ -175,28 +183,29 @@ async def get_stats_debug_info(): "is_current": current_interval.is_current(), "minutes_remaining": int((current_interval.end_time - current_time).total_seconds() / 60), "seconds_remaining": int((current_interval.end_time - current_time).total_seconds()), - "progress_percentage": round((1 - (current_interval.end_time - current_time).total_seconds() / (30 * 60)) * 100, 1) + "progress_percentage": round( + (1 - (current_interval.end_time - current_time).total_seconds() / (30 * 60)) * 100, + 1, + ), }, "interval_statistics": interval_stats.to_dict() if interval_stats else None, "redis_data": { "online_users": online_users, "playing_users": playing_users, "online_count": len(online_users), - "playing_count": len(playing_users) + "playing_count": len(playing_users), }, "system_status": { "stats_system": "enhanced_interval_stats", "data_alignment": "30_minute_boundaries", "real_time_updates": True, - "auto_24h_fill": True - } + "auto_24h_fill": True, + }, } except Exception as e: logger.error(f"Error getting debug info: {e}") - return { - "error": "Failed to retrieve debug information", - "message": str(e) - } + return {"error": "Failed to retrieve debug information", "message": str(e)} + async def _get_registered_users_count(redis) -> int: """获取注册用户总数(从缓存)""" @@ -207,6 +216,7 @@ async def _get_registered_users_count(redis) -> int: logger.error(f"Error getting registered users count: {e}") return 0 + async def _get_online_users_count(redis) -> int: """获取当前在线用户数""" try: @@ -216,6 +226,7 @@ async def _get_online_users_count(redis) -> int: logger.error(f"Error getting online users count: {e}") return 0 + async def _get_playing_users_count(redis) -> int: """获取当前游玩用户数""" try: @@ -225,27 +236,28 @@ async def _get_playing_users_count(redis) -> int: logger.error(f"Error getting playing users count: {e}") return 0 + # 统计更新功能 async def update_registered_users_count() -> None: """更新注册用户数缓存""" - from app.dependencies.database import with_db - from app.database import User from app.const import BANCHOBOT_ID - from sqlmodel import select, func - + from app.database import User + from app.dependencies.database import with_db + + from sqlmodel import func, select + redis = get_redis() try: async with with_db() as db: # 排除机器人用户(BANCHOBOT_ID) - result = await db.exec( - select(func.count()).select_from(User).where(User.id != BANCHOBOT_ID) - ) + result = await db.exec(select(func.count()).select_from(User).where(User.id != BANCHOBOT_ID)) count = result.first() await redis.set(REDIS_REGISTERED_USERS_KEY, count or 0, ex=300) # 5分钟过期 logger.debug(f"Updated registered users count: {count}") except Exception as e: logger.error(f"Error updating registered users count: {e}") + async def add_online_user(user_id: int) -> None: """添加在线用户""" redis_sync = get_redis_message() @@ -257,14 +269,20 @@ async def add_online_user(user_id: int) -> None: if ttl <= 0: # -1表示永不过期,-2表示不存在,0表示已过期 await redis_async.expire(REDIS_ONLINE_USERS_KEY, 3 * 3600) # 3小时过期 logger.debug(f"Added online user {user_id}") - + # 立即更新当前区间统计 from app.service.enhanced_interval_stats import update_user_activity_in_interval - asyncio.create_task(update_user_activity_in_interval(user_id, is_playing=False)) - + + bg_tasks.add_task( + update_user_activity_in_interval, + user_id, + is_playing=False, + ) + except Exception as e: logger.error(f"Error adding online user {user_id}: {e}") + async def remove_online_user(user_id: int) -> None: """移除在线用户""" redis_sync = get_redis_message() @@ -274,6 +292,7 @@ async def remove_online_user(user_id: int) -> None: except Exception as e: logger.error(f"Error removing online user {user_id}: {e}") + async def add_playing_user(user_id: int) -> None: """添加游玩用户""" redis_sync = get_redis_message() @@ -285,14 +304,16 @@ async def add_playing_user(user_id: int) -> None: if ttl <= 0: # -1表示永不过期,-2表示不存在,0表示已过期 await redis_async.expire(REDIS_PLAYING_USERS_KEY, 3 * 3600) # 3小时过期 logger.debug(f"Added playing user {user_id}") - + # 立即更新当前区间统计 from app.service.enhanced_interval_stats import update_user_activity_in_interval - asyncio.create_task(update_user_activity_in_interval(user_id, is_playing=True)) - + + bg_tasks.add_task(update_user_activity_in_interval, user_id, is_playing=True) + except Exception as e: logger.error(f"Error adding playing user {user_id}: {e}") + async def remove_playing_user(user_id: int) -> None: """移除游玩用户""" redis_sync = get_redis_message() @@ -301,6 +322,7 @@ async def remove_playing_user(user_id: int) -> None: except Exception as e: logger.error(f"Error removing playing user {user_id}: {e}") + async def record_hourly_stats() -> None: """记录统计数据 - 简化版本,主要作为fallback使用""" redis_sync = get_redis_message() @@ -308,24 +330,27 @@ async def record_hourly_stats() -> None: try: # 先确保Redis连接正常 await redis_async.ping() - + online_count = await _get_online_users_count(redis_async) playing_count = await _get_playing_users_count(redis_async) - + current_time = datetime.utcnow() history_point = { "timestamp": current_time.isoformat(), "online_count": online_count, - "playing_count": playing_count + "playing_count": playing_count, } - + # 添加到历史记录 await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point)) # 只保留48个数据点(24小时,每30分钟一个点) await _redis_exec(redis_sync.ltrim, REDIS_ONLINE_HISTORY_KEY, 0, 47) # 设置过期时间为26小时,确保有足够缓冲 await redis_async.expire(REDIS_ONLINE_HISTORY_KEY, 26 * 3600) - - logger.info(f"Recorded fallback stats: online={online_count}, playing={playing_count} at {current_time.strftime('%H:%M:%S')}") + + logger.info( + f"Recorded fallback stats: online={online_count}, playing={playing_count} " + f"at {current_time.strftime('%H:%M:%S')}" + ) except Exception as e: logger.error(f"Error recording fallback stats: {e}") diff --git a/app/router/v2/user.py b/app/router/v2/user.py index 9e67de6..501bce0 100644 --- a/app/router/v2/user.py +++ b/app/router/v2/user.py @@ -1,6 +1,5 @@ from __future__ import annotations -import asyncio from datetime import UTC, datetime, timedelta from typing import Literal @@ -26,7 +25,7 @@ from app.service.user_cache_service import get_user_cache_service from .router import router -from fastapi import HTTPException, Path, Query, Security +from fastapi import BackgroundTasks, HTTPException, Path, Query, Security from pydantic import BaseModel from sqlmodel import exists, false, select from sqlmodel.sql.expression import col @@ -47,13 +46,10 @@ class BatchUserResponse(BaseModel): @router.get("/users/lookup/", response_model=BatchUserResponse, include_in_schema=False) async def get_users( session: Database, - user_ids: list[int] = Query( - default_factory=list, alias="ids[]", description="要查询的用户 ID 列表" - ), + background_task: BackgroundTasks, + user_ids: list[int] = Query(default_factory=list, alias="ids[]", description="要查询的用户 ID 列表"), # current_user: User = Security(get_current_user, scopes=["public"]), - include_variant_statistics: bool = Query( - default=False, description="是否包含各模式的统计信息" - ), # TODO: future use + include_variant_statistics: bool = Query(default=False, description="是否包含各模式的统计信息"), # TODO: future use ): redis = get_redis() cache_service = get_user_cache_service(redis) @@ -72,11 +68,7 @@ async def get_users( # 查询未缓存的用户 if uncached_user_ids: - searched_users = ( - await session.exec( - select(User).where(col(User.id).in_(uncached_user_ids)) - ) - ).all() + searched_users = (await session.exec(select(User).where(col(User.id).in_(uncached_user_ids)))).all() # 将查询到的用户添加到缓存并返回 for searched_user in searched_users: @@ -88,7 +80,7 @@ async def get_users( ) cached_users.append(user_resp) # 异步缓存,不阻塞响应 - asyncio.create_task(cache_service.cache_user(user_resp)) + background_task.add_task(cache_service.cache_user, user_resp) return BatchUserResponse(users=cached_users) else: @@ -103,7 +95,7 @@ async def get_users( ) users.append(user_resp) # 异步缓存 - asyncio.create_task(cache_service.cache_user(user_resp)) + background_task.add_task(cache_service.cache_user, user_resp) return BatchUserResponse(users=users) @@ -117,6 +109,7 @@ async def get_users( ) async def get_user_info_ruleset( session: Database, + background_task: BackgroundTasks, user_id: str = Path(description="用户 ID 或用户名"), ruleset: GameMode | None = Path(description="指定 ruleset"), # current_user: User = Security(get_current_user, scopes=["public"]), @@ -134,9 +127,7 @@ async def get_user_info_ruleset( searched_user = ( await session.exec( select(User).where( - User.id == int(user_id) - if user_id.isdigit() - else User.username == user_id.removeprefix("@") + User.id == int(user_id) if user_id.isdigit() else User.username == user_id.removeprefix("@") ) ) ).first() @@ -151,7 +142,7 @@ async def get_user_info_ruleset( ) # 异步缓存结果 - asyncio.create_task(cache_service.cache_user(user_resp, ruleset)) + background_task.add_task(cache_service.cache_user, user_resp, ruleset) return user_resp @@ -165,6 +156,7 @@ async def get_user_info_ruleset( tags=["用户"], ) async def get_user_info( + background_task: BackgroundTasks, session: Database, user_id: str = Path(description="用户 ID 或用户名"), # current_user: User = Security(get_current_user, scopes=["public"]), @@ -182,9 +174,7 @@ async def get_user_info( searched_user = ( await session.exec( select(User).where( - User.id == int(user_id) - if user_id.isdigit() - else User.username == user_id.removeprefix("@") + User.id == int(user_id) if user_id.isdigit() else User.username == user_id.removeprefix("@") ) ) ).first() @@ -198,7 +188,7 @@ async def get_user_info( ) # 异步缓存结果 - asyncio.create_task(cache_service.cache_user(user_resp)) + background_task.add_task(cache_service.cache_user, user_resp) return user_resp @@ -212,6 +202,7 @@ async def get_user_info( ) async def get_user_beatmapsets( session: Database, + background_task: BackgroundTasks, user_id: int = Path(description="用户 ID"), type: BeatmapsetType = Path(description="谱面集类型"), current_user: User = Security(get_current_user, scopes=["public"]), @@ -222,9 +213,7 @@ async def get_user_beatmapsets( cache_service = get_user_cache_service(redis) # 先尝试从缓存获取 - cached_result = await cache_service.get_user_beatmapsets_from_cache( - user_id, type.value, limit, offset - ) + cached_result = await cache_service.get_user_beatmapsets_from_cache(user_id, type.value, limit, offset) if cached_result is not None: # 根据类型恢复对象 if type == BeatmapsetType.MOST_PLAYED: @@ -253,10 +242,7 @@ async def get_user_beatmapsets( raise HTTPException(404, detail="User not found") favourites = await user.awaitable_attrs.favourite_beatmapsets resp = [ - await BeatmapsetResp.from_db( - favourite.beatmapset, session=session, user=user - ) - for favourite in favourites + await BeatmapsetResp.from_db(favourite.beatmapset, session=session, user=user) for favourite in favourites ] elif type == BeatmapsetType.MOST_PLAYED: @@ -267,25 +253,18 @@ async def get_user_beatmapsets( .limit(limit) .offset(offset) ) - resp = [ - await BeatmapPlaycountsResp.from_db(most_played_beatmap) - for most_played_beatmap in most_played - ] + resp = [await BeatmapPlaycountsResp.from_db(most_played_beatmap) for most_played_beatmap in most_played] else: raise HTTPException(400, detail="Invalid beatmapset type") # 异步缓存结果 async def cache_beatmapsets(): try: - await cache_service.cache_user_beatmapsets( - user_id, type.value, resp, limit, offset - ) + await cache_service.cache_user_beatmapsets(user_id, type.value, resp, limit, offset) except Exception as e: - logger.error( - f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}" - ) + logger.error(f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}") - asyncio.create_task(cache_beatmapsets()) + background_task.add_task(cache_beatmapsets) return resp @@ -299,18 +278,14 @@ async def get_user_beatmapsets( ) async def get_user_scores( session: Database, + background_task: BackgroundTasks, user_id: int = Path(description="用户 ID"), type: Literal["best", "recent", "firsts", "pinned"] = Path( - description=( - "成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩" - " / firsts 第一名成绩 / pinned 置顶成绩" - ) + description=("成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩 / firsts 第一名成绩 / pinned 置顶成绩") ), legacy_only: bool = Query(False, description="是否只查询 Stable 成绩"), include_fails: bool = Query(False, description="是否包含失败的成绩"), - mode: GameMode | None = Query( - None, description="指定 ruleset (可选,默认为用户主模式)" - ), + mode: GameMode | None = Query(None, description="指定 ruleset (可选,默认为用户主模式)"), limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"), offset: int = Query(0, ge=0, description="偏移量"), current_user: User = Security(get_current_user, scopes=["public"]), @@ -320,9 +295,7 @@ async def get_user_scores( # 先尝试从缓存获取(对于recent类型使用较短的缓存时间) cache_expire = 30 if type == "recent" else settings.user_scores_cache_expire_seconds - cached_scores = await cache_service.get_user_scores_from_cache( - user_id, type, mode, limit, offset - ) + cached_scores = await cache_service.get_user_scores_from_cache(user_id, type, mode, limit, offset) if cached_scores is not None: return cached_scores @@ -332,9 +305,7 @@ async def get_user_scores( gamemode = mode or db_user.playmode order_by = None - where_clause = (col(Score.user_id) == db_user.id) & ( - col(Score.gamemode) == gamemode - ) + where_clause = (col(Score.user_id) == db_user.id) & (col(Score.gamemode) == gamemode) if not include_fails: where_clause &= col(Score.passed).is_(True) if type == "pinned": @@ -351,13 +322,7 @@ async def get_user_scores( where_clause &= false() scores = ( - await session.exec( - select(Score) - .where(where_clause) - .order_by(order_by) - .limit(limit) - .offset(offset) - ) + await session.exec(select(Score).where(where_clause).order_by(order_by).limit(limit).offset(offset)) ).all() if not scores: return [] @@ -371,18 +336,14 @@ async def get_user_scores( ] # 异步缓存结果 - asyncio.create_task( - cache_service.cache_user_scores( - user_id, type, score_responses, mode, limit, offset, cache_expire - ) + background_task.add_task( + cache_service.cache_user_scores, user_id, type, score_responses, mode, limit, offset, cache_expire ) return score_responses -@router.get( - "/users/{user}/recent_activity", tags=["用户"], response_model=list[EventResp] -) +@router.get("/users/{user}/recent_activity", tags=["用户"], response_model=list[EventResp]) async def get_user_events( session: Database, user: int, diff --git a/app/scheduler/cache_scheduler.py b/app/scheduler/cache_scheduler.py index 8edecfb..9a36ddb 100644 --- a/app/scheduler/cache_scheduler.py +++ b/app/scheduler/cache_scheduler.py @@ -59,9 +59,7 @@ class CacheScheduler: # 从配置文件获取间隔设置 check_interval = 5 * 60 # 5分钟检查间隔 beatmap_cache_interval = 30 * 60 # 30分钟beatmap缓存间隔 - ranking_cache_interval = ( - settings.ranking_cache_refresh_interval_minutes * 60 - ) # 从配置读取 + ranking_cache_interval = settings.ranking_cache_refresh_interval_minutes * 60 # 从配置读取 user_cache_interval = 15 * 60 # 15分钟用户缓存预加载间隔 user_cleanup_interval = 60 * 60 # 60分钟用户缓存清理间隔 diff --git a/app/scheduler/database_cleanup_scheduler.py b/app/scheduler/database_cleanup_scheduler.py index 4dfb21a..1a21bcf 100644 --- a/app/scheduler/database_cleanup_scheduler.py +++ b/app/scheduler/database_cleanup_scheduler.py @@ -5,9 +5,7 @@ from __future__ import annotations import asyncio -from datetime import datetime -from app.config import settings from app.dependencies.database import engine from app.log import logger from app.service.database_cleanup_service import DatabaseCleanupService @@ -51,16 +49,16 @@ class DatabaseCleanupScheduler: try: # 每小时运行一次清理 await asyncio.sleep(3600) # 3600秒 = 1小时 - + if not self.running: break - + await self._run_cleanup() - + except asyncio.CancelledError: break except Exception as e: - logger.error(f"Database cleanup scheduler error: {str(e)}") + logger.error(f"Database cleanup scheduler error: {e!s}") # 发生错误后等待5分钟再继续 await asyncio.sleep(300) @@ -69,20 +67,20 @@ class DatabaseCleanupScheduler: try: async with AsyncSession(engine) as db: logger.debug("Starting scheduled database cleanup...") - + # 清理过期的验证码 expired_codes = await DatabaseCleanupService.cleanup_expired_verification_codes(db) - + # 清理过期的登录会话 expired_sessions = await DatabaseCleanupService.cleanup_expired_login_sessions(db) - + # 只在有清理记录时输出总结 total_cleaned = expired_codes + expired_sessions if total_cleaned > 0: logger.debug(f"Scheduled cleanup completed - codes: {expired_codes}, sessions: {expired_sessions}") - + except Exception as e: - logger.error(f"Error during scheduled database cleanup: {str(e)}") + logger.error(f"Error during scheduled database cleanup: {e!s}") async def run_manual_cleanup(self): """手动运行完整清理""" @@ -95,7 +93,7 @@ class DatabaseCleanupScheduler: logger.debug(f"Manual cleanup completed, total records cleaned: {total}") return results except Exception as e: - logger.error(f"Error during manual database cleanup: {str(e)}") + logger.error(f"Error during manual database cleanup: {e!s}") return {} diff --git a/app/service/beatmap_cache_service.py b/app/service/beatmap_cache_service.py index 2f76426..5351fd1 100644 --- a/app/service/beatmap_cache_service.py +++ b/app/service/beatmap_cache_service.py @@ -63,10 +63,7 @@ class BeatmapCacheService: if preload_tasks: results = await asyncio.gather(*preload_tasks, return_exceptions=True) success_count = sum(1 for r in results if r is True) - logger.info( - f"Preloaded {success_count}/{len(preload_tasks)} " - f"beatmaps successfully" - ) + logger.info(f"Preloaded {success_count}/{len(preload_tasks)} beatmaps successfully") except Exception as e: logger.error(f"Error during beatmap preloading: {e}") @@ -119,9 +116,7 @@ class BeatmapCacheService: return { "cached_beatmaps": len(keys), - "estimated_total_size_mb": ( - round(total_size / 1024 / 1024, 2) if total_size > 0 else 0 - ), + "estimated_total_size_mb": (round(total_size / 1024 / 1024, 2) if total_size > 0 else 0), "preloading": self._preloading, } except Exception as e: @@ -155,9 +150,7 @@ def get_beatmap_cache_service(redis: Redis, fetcher: "Fetcher") -> BeatmapCacheS return _cache_service -async def schedule_preload_task( - session: AsyncSession, redis: Redis, fetcher: "Fetcher" -): +async def schedule_preload_task(session: AsyncSession, redis: Redis, fetcher: "Fetcher"): """ 定时预加载任务 """ diff --git a/app/service/beatmap_download_service.py b/app/service/beatmap_download_service.py index d90856a..6484fb5 100644 --- a/app/service/beatmap_download_service.py +++ b/app/service/beatmap_download_service.py @@ -192,22 +192,16 @@ class BeatmapDownloadService: healthy_endpoints.sort(key=lambda x: x.priority) return healthy_endpoints - def get_download_url( - self, beatmapset_id: int, no_video: bool, is_china: bool - ) -> str: + def get_download_url(self, beatmapset_id: int, no_video: bool, is_china: bool) -> str: """获取下载URL,带负载均衡和故障转移""" healthy_endpoints = self.get_healthy_endpoints(is_china) if not healthy_endpoints: # 如果没有健康的端点,记录错误并回退到所有端点中优先级最高的 logger.error(f"No healthy endpoints available for is_china={is_china}") - endpoints = ( - self.china_endpoints if is_china else self.international_endpoints - ) + endpoints = self.china_endpoints if is_china else self.international_endpoints if not endpoints: - raise HTTPException( - status_code=503, detail="No download endpoints available" - ) + raise HTTPException(status_code=503, detail="No download endpoints available") endpoint = min(endpoints, key=lambda x: x.priority) else: # 使用第一个健康的端点(已按优先级排序) @@ -218,9 +212,7 @@ class BeatmapDownloadService: video_type = "novideo" if no_video else "full" return endpoint.url_template.format(type=video_type, sid=beatmapset_id) elif endpoint.name == "Nerinyan": - return endpoint.url_template.format( - sid=beatmapset_id, no_video="true" if no_video else "false" - ) + return endpoint.url_template.format(sid=beatmapset_id, no_video="true" if no_video else "false") elif endpoint.name == "OsuDirect": # osu.direct 似乎没有no_video参数,直接使用基础URL return endpoint.url_template.format(sid=beatmapset_id) @@ -239,9 +231,7 @@ class BeatmapDownloadService: for name, status in self.endpoint_status.items(): status_info["endpoints"][name] = { "healthy": status.is_healthy, - "last_check": status.last_check.isoformat() - if status.last_check - else None, + "last_check": status.last_check.isoformat() if status.last_check else None, "consecutive_failures": status.consecutive_failures, "last_error": status.last_error, "priority": status.endpoint.priority, diff --git a/app/service/calculate_all_user_rank.py b/app/service/calculate_all_user_rank.py index bc4e074..3d089c5 100644 --- a/app/service/calculate_all_user_rank.py +++ b/app/service/calculate_all_user_rank.py @@ -11,9 +11,7 @@ from app.models.score import GameMode from sqlmodel import col, exists, select, update -@get_scheduler().scheduled_job( - "cron", hour=0, minute=0, second=0, id="calculate_user_rank" -) +@get_scheduler().scheduled_job("cron", hour=0, minute=0, second=0, id="calculate_user_rank") async def calculate_user_rank(is_today: bool = False): today = datetime.now(UTC).date() target_date = today if is_today else today - timedelta(days=1) diff --git a/app/service/create_banchobot.py b/app/service/create_banchobot.py index dceec18..0393396 100644 --- a/app/service/create_banchobot.py +++ b/app/service/create_banchobot.py @@ -11,9 +11,7 @@ from sqlmodel import exists, select async def create_banchobot(): async with with_db() as session: - is_exist = ( - await session.exec(select(exists()).where(User.id == BANCHOBOT_ID)) - ).first() + is_exist = (await session.exec(select(exists()).where(User.id == BANCHOBOT_ID))).first() if not is_exist: banchobot = User( username="BanchoBot", diff --git a/app/service/daily_challenge.py b/app/service/daily_challenge.py index 7b10874..10117b3 100644 --- a/app/service/daily_challenge.py +++ b/app/service/daily_challenge.py @@ -82,8 +82,7 @@ async def daily_challenge_job(): if beatmap is None or ruleset_id is None: logger.warning( - f"[DailyChallenge] Missing required data for daily challenge {now}." - " Will try again in 5 minutes." + f"[DailyChallenge] Missing required data for daily challenge {now}. Will try again in 5 minutes." ) get_scheduler().add_job( daily_challenge_job, @@ -104,9 +103,7 @@ async def daily_challenge_job(): else: allowed_mods_list = get_available_mods(ruleset_id_int, required_mods_list) - next_day = (now + timedelta(days=1)).replace( - hour=0, minute=0, second=0, microsecond=0 - ) + next_day = (now + timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0) room = await create_daily_challenge_room( beatmap=beatmap_int, ruleset_id=ruleset_id_int, @@ -114,24 +111,13 @@ async def daily_challenge_job(): allowed_mods=allowed_mods_list, duration=int((next_day - now - timedelta(minutes=2)).total_seconds() / 60), ) - await MetadataHubs.broadcast_call( - "DailyChallengeUpdated", DailyChallengeInfo(room_id=room.id) - ) - logger.success( - "[DailyChallenge] Added today's daily challenge: " - f"{beatmap=}, {ruleset_id=}, {required_mods=}" - ) + await MetadataHubs.broadcast_call("DailyChallengeUpdated", DailyChallengeInfo(room_id=room.id)) + logger.success(f"[DailyChallenge] Added today's daily challenge: {beatmap=}, {ruleset_id=}, {required_mods=}") return except (ValueError, json.JSONDecodeError) as e: - logger.warning( - f"[DailyChallenge] Error processing daily challenge data: {e}" - " Will try again in 5 minutes." - ) + logger.warning(f"[DailyChallenge] Error processing daily challenge data: {e} Will try again in 5 minutes.") except Exception as e: - logger.exception( - f"[DailyChallenge] Unexpected error in daily challenge job: {e}" - " Will try again in 5 minutes." - ) + logger.exception(f"[DailyChallenge] Unexpected error in daily challenge job: {e} Will try again in 5 minutes.") get_scheduler().add_job( daily_challenge_job, "date", @@ -139,9 +125,7 @@ async def daily_challenge_job(): ) -@get_scheduler().scheduled_job( - "cron", hour=0, minute=1, second=0, id="daily_challenge_last_top" -) +@get_scheduler().scheduled_job("cron", hour=0, minute=1, second=0, id="daily_challenge_last_top") async def process_daily_challenge_top(): async with with_db() as session: now = datetime.now(UTC) @@ -182,11 +166,7 @@ async def process_daily_challenge_top(): await session.commit() del s - user_ids = ( - await session.exec( - select(User.id).where(col(User.id).not_in(participated_users)) - ) - ).all() + user_ids = (await session.exec(select(User.id).where(col(User.id).not_in(participated_users)))).all() for id in user_ids: stats = await session.get(DailyChallengeStats, id) if stats is None: # not execute diff --git a/app/service/database_cleanup_service.py b/app/service/database_cleanup_service.py index a4558d8..a024c9a 100644 --- a/app/service/database_cleanup_service.py +++ b/app/service/database_cleanup_service.py @@ -4,14 +4,13 @@ from __future__ import annotations -from datetime import datetime, UTC, timedelta +from datetime import UTC, datetime, timedelta from app.database.email_verification import EmailVerification, LoginSession from app.log import logger -from sqlmodel import select +from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession -from sqlalchemy import and_ class DatabaseCleanupService: @@ -21,211 +20,207 @@ class DatabaseCleanupService: async def cleanup_expired_verification_codes(db: AsyncSession) -> int: """ 清理过期的邮件验证码 - + Args: db: 数据库会话 - + Returns: int: 清理的记录数 """ try: # 查找过期的验证码记录 current_time = datetime.now(UTC) - - stmt = select(EmailVerification).where( - EmailVerification.expires_at < current_time - ) + + stmt = select(EmailVerification).where(EmailVerification.expires_at < current_time) result = await db.exec(stmt) expired_codes = result.all() - + # 删除过期的记录 deleted_count = 0 for code in expired_codes: await db.delete(code) deleted_count += 1 - + await db.commit() - + if deleted_count > 0: logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired email verification codes") - + return deleted_count - + except Exception as e: await db.rollback() - logger.error(f"[Cleanup Service] Error cleaning expired verification codes: {str(e)}") + logger.error(f"[Cleanup Service] Error cleaning expired verification codes: {e!s}") return 0 @staticmethod async def cleanup_expired_login_sessions(db: AsyncSession) -> int: """ 清理过期的登录会话 - + Args: db: 数据库会话 - + Returns: int: 清理的记录数 """ try: # 查找过期的登录会话记录 current_time = datetime.now(UTC) - - stmt = select(LoginSession).where( - LoginSession.expires_at < current_time - ) + + stmt = select(LoginSession).where(LoginSession.expires_at < current_time) result = await db.exec(stmt) expired_sessions = result.all() - + # 删除过期的记录 deleted_count = 0 for session in expired_sessions: await db.delete(session) deleted_count += 1 - + await db.commit() - + if deleted_count > 0: logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired login sessions") - + return deleted_count - + except Exception as e: await db.rollback() - logger.error(f"[Cleanup Service] Error cleaning expired login sessions: {str(e)}") + logger.error(f"[Cleanup Service] Error cleaning expired login sessions: {e!s}") return 0 @staticmethod async def cleanup_old_used_verification_codes(db: AsyncSession, days_old: int = 7) -> int: """ 清理旧的已使用验证码记录 - + Args: db: 数据库会话 days_old: 清理多少天前的已使用记录,默认7天 - + Returns: int: 清理的记录数 """ try: # 查找指定天数前的已使用验证码记录 cutoff_time = datetime.now(UTC) - timedelta(days=days_old) - - stmt = select(EmailVerification).where( - EmailVerification.is_used == True - ) + + stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True)) result = await db.exec(stmt) all_used_codes = result.all() - + # 筛选出过期的记录 - old_used_codes = [ - code for code in all_used_codes - if code.used_at and code.used_at < cutoff_time - ] - + old_used_codes = [code for code in all_used_codes if code.used_at and code.used_at < cutoff_time] + # 删除旧的已使用记录 deleted_count = 0 for code in old_used_codes: await db.delete(code) deleted_count += 1 - + await db.commit() - + if deleted_count > 0: - logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} used verification codes older than {days_old} days") - + logger.debug( + f"[Cleanup Service] Cleaned up {deleted_count} used verification codes older than {days_old} days" + ) + return deleted_count - + except Exception as e: await db.rollback() - logger.error(f"[Cleanup Service] Error cleaning old used verification codes: {str(e)}") + logger.error(f"[Cleanup Service] Error cleaning old used verification codes: {e!s}") return 0 @staticmethod async def cleanup_old_verified_sessions(db: AsyncSession, days_old: int = 30) -> int: """ 清理旧的已验证会话记录 - + Args: db: 数据库会话 days_old: 清理多少天前的已验证记录,默认30天 - + Returns: int: 清理的记录数 """ try: # 查找指定天数前的已验证会话记录 cutoff_time = datetime.now(UTC) - timedelta(days=days_old) - - stmt = select(LoginSession).where( - LoginSession.is_verified == True - ) + + stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True)) result = await db.exec(stmt) all_verified_sessions = result.all() - + # 筛选出过期的记录 old_verified_sessions = [ - session for session in all_verified_sessions + session + for session in all_verified_sessions if session.verified_at and session.verified_at < cutoff_time ] - + # 删除旧的已验证记录 deleted_count = 0 for session in old_verified_sessions: await db.delete(session) deleted_count += 1 - + await db.commit() - + if deleted_count > 0: - logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} verified sessions older than {days_old} days") - + logger.debug( + f"[Cleanup Service] Cleaned up {deleted_count} verified sessions older than {days_old} days" + ) + return deleted_count - + except Exception as e: await db.rollback() - logger.error(f"[Cleanup Service] Error cleaning old verified sessions: {str(e)}") + logger.error(f"[Cleanup Service] Error cleaning old verified sessions: {e!s}") return 0 @staticmethod async def run_full_cleanup(db: AsyncSession) -> dict[str, int]: """ 运行完整的清理流程 - + Args: db: 数据库会话 - + Returns: dict: 各项清理的结果统计 """ results = {} - + # 清理过期的验证码 results["expired_verification_codes"] = await DatabaseCleanupService.cleanup_expired_verification_codes(db) - + # 清理过期的登录会话 results["expired_login_sessions"] = await DatabaseCleanupService.cleanup_expired_login_sessions(db) - + # 清理7天前的已使用验证码 results["old_used_verification_codes"] = await DatabaseCleanupService.cleanup_old_used_verification_codes(db, 7) - + # 清理30天前的已验证会话 results["old_verified_sessions"] = await DatabaseCleanupService.cleanup_old_verified_sessions(db, 30) - + total_cleaned = sum(results.values()) if total_cleaned > 0: - logger.debug(f"[Cleanup Service] Full cleanup completed, total cleaned: {total_cleaned} records - {results}") - + logger.debug( + f"[Cleanup Service] Full cleanup completed, total cleaned: {total_cleaned} records - {results}" + ) + return results @staticmethod async def get_cleanup_statistics(db: AsyncSession) -> dict[str, int]: """ 获取清理统计信息 - + Args: db: 数据库会话 - + Returns: dict: 统计信息 """ @@ -233,57 +228,54 @@ class DatabaseCleanupService: current_time = datetime.now(UTC) cutoff_7_days = current_time - timedelta(days=7) cutoff_30_days = current_time - timedelta(days=30) - + # 统计过期的验证码数量 - expired_codes_stmt = select(EmailVerification).where( - EmailVerification.expires_at < current_time - ) + expired_codes_stmt = select(EmailVerification).where(EmailVerification.expires_at < current_time) expired_codes_result = await db.exec(expired_codes_stmt) expired_codes_count = len(expired_codes_result.all()) - + # 统计过期的登录会话数量 - expired_sessions_stmt = select(LoginSession).where( - LoginSession.expires_at < current_time - ) + expired_sessions_stmt = select(LoginSession).where(LoginSession.expires_at < current_time) expired_sessions_result = await db.exec(expired_sessions_stmt) expired_sessions_count = len(expired_sessions_result.all()) - + # 统计7天前的已使用验证码数量 - old_used_codes_stmt = select(EmailVerification).where( - EmailVerification.is_used == True - ) + old_used_codes_stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True)) old_used_codes_result = await db.exec(old_used_codes_stmt) all_used_codes = old_used_codes_result.all() - old_used_codes_count = len([ - code for code in all_used_codes - if code.used_at and code.used_at < cutoff_7_days - ]) - - # 统计30天前的已验证会话数量 - old_verified_sessions_stmt = select(LoginSession).where( - LoginSession.is_verified == True + old_used_codes_count = len( + [code for code in all_used_codes if code.used_at and code.used_at < cutoff_7_days] ) + + # 统计30天前的已验证会话数量 + old_verified_sessions_stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True)) old_verified_sessions_result = await db.exec(old_verified_sessions_stmt) all_verified_sessions = old_verified_sessions_result.all() - old_verified_sessions_count = len([ - session for session in all_verified_sessions - if session.verified_at and session.verified_at < cutoff_30_days - ]) - + old_verified_sessions_count = len( + [ + session + for session in all_verified_sessions + if session.verified_at and session.verified_at < cutoff_30_days + ] + ) + return { "expired_verification_codes": expired_codes_count, "expired_login_sessions": expired_sessions_count, "old_used_verification_codes": old_used_codes_count, "old_verified_sessions": old_verified_sessions_count, - "total_cleanable": expired_codes_count + expired_sessions_count + old_used_codes_count + old_verified_sessions_count + "total_cleanable": expired_codes_count + + expired_sessions_count + + old_used_codes_count + + old_verified_sessions_count, } - + except Exception as e: - logger.error(f"[Cleanup Service] Error getting cleanup statistics: {str(e)}") + logger.error(f"[Cleanup Service] Error getting cleanup statistics: {e!s}") return { "expired_verification_codes": 0, "expired_login_sessions": 0, "old_used_verification_codes": 0, "old_verified_sessions": 0, - "total_cleanable": 0 + "total_cleanable": 0, } diff --git a/app/service/email_queue.py b/app/service/email_queue.py index a583836..7335c8c 100644 --- a/app/service/email_queue.py +++ b/app/service/email_queue.py @@ -8,17 +8,18 @@ from __future__ import annotations import asyncio import concurrent.futures from datetime import datetime -import json -import uuid -import smtplib -from email.mime.text import MIMEText from email.mime.multipart import MIMEMultipart -from typing import Dict, Any, Optional -import redis as sync_redis # 添加同步Redis导入 +from email.mime.text import MIMEText +import json +import smtplib +from typing import Any +import uuid from app.config import settings -from app.dependencies.database import redis_message_client # 使用同步Redis客户端 from app.log import logger +from app.utils import bg_tasks # 添加同步Redis导入 + +import redis as sync_redis class EmailQueue: @@ -30,14 +31,14 @@ class EmailQueue: self._processing = False self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=2) self._retry_limit = 3 # 重试次数限制 - + # 邮件配置 - self.smtp_server = getattr(settings, 'smtp_server', 'localhost') - self.smtp_port = getattr(settings, 'smtp_port', 587) - self.smtp_username = getattr(settings, 'smtp_username', '') - self.smtp_password = getattr(settings, 'smtp_password', '') - self.from_email = getattr(settings, 'from_email', 'noreply@example.com') - self.from_name = getattr(settings, 'from_name', 'osu! server') + self.smtp_server = getattr(settings, "smtp_server", "localhost") + self.smtp_port = getattr(settings, "smtp_port", 587) + self.smtp_username = getattr(settings, "smtp_username", "") + self.smtp_password = getattr(settings, "smtp_password", "") + self.from_email = getattr(settings, "from_email", "noreply@example.com") + self.from_name = getattr(settings, "from_name", "osu! server") async def _run_in_executor(self, func, *args): """在线程池中运行同步操作""" @@ -48,7 +49,7 @@ class EmailQueue: """启动邮件处理任务""" if not self._processing: self._processing = True - asyncio.create_task(self._process_email_queue()) + bg_tasks.add_task(self._process_email_queue) logger.info("Email queue processing started") async def stop_processing(self): @@ -56,27 +57,29 @@ class EmailQueue: self._processing = False logger.info("Email queue processing stopped") - async def enqueue_email(self, - to_email: str, - subject: str, - content: str, - html_content: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None) -> str: + async def enqueue_email( + self, + to_email: str, + subject: str, + content: str, + html_content: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> str: """ 将邮件加入队列等待发送 - + Args: to_email: 收件人邮箱地址 subject: 邮件主题 content: 邮件纯文本内容 html_content: 邮件HTML内容(如果有) metadata: 额外元数据(如密码重置ID等) - + Returns: 邮件任务ID """ email_id = str(uuid.uuid4()) - + email_data = { "id": email_id, "to_email": to_email, @@ -86,125 +89,117 @@ class EmailQueue: "metadata": json.dumps(metadata) if metadata else "{}", "created_at": datetime.now().isoformat(), "status": "pending", # pending, sending, sent, failed - "retry_count": "0" + "retry_count": "0", } - + # 将邮件数据存入Redis - await self._run_in_executor( - lambda: self.redis.hset(f"email:{email_id}", mapping=email_data) - ) - + await self._run_in_executor(lambda: self.redis.hset(f"email:{email_id}", mapping=email_data)) + # 设置24小时过期(防止数据堆积) - await self._run_in_executor( - self.redis.expire, f"email:{email_id}", 86400 - ) - + await self._run_in_executor(self.redis.expire, f"email:{email_id}", 86400) + # 加入发送队列 - await self._run_in_executor( - self.redis.lpush, "email_queue", email_id - ) - + await self._run_in_executor(self.redis.lpush, "email_queue", email_id) + logger.info(f"Email enqueued with id: {email_id} to {to_email}") return email_id - async def get_email_status(self, email_id: str) -> Dict[str, Any]: + async def get_email_status(self, email_id: str) -> dict[str, Any]: """ 获取邮件发送状态 - + Args: email_id: 邮件任务ID - + Returns: 邮件任务状态信息 """ - email_data = await self._run_in_executor( - self.redis.hgetall, f"email:{email_id}" - ) - + email_data = await self._run_in_executor(self.redis.hgetall, f"email:{email_id}") + # 解码Redis返回的字节数据 if email_data: return { - k.decode("utf-8") if isinstance(k, bytes) else k: - v.decode("utf-8") if isinstance(v, bytes) else v + k.decode("utf-8") if isinstance(k, bytes) else k: v.decode("utf-8") if isinstance(v, bytes) else v for k, v in email_data.items() } - + return {"status": "not_found"} async def _process_email_queue(self): """处理邮件队列""" logger.info("Starting email queue processor") - + while self._processing: try: # 从队列获取邮件ID def brpop_operation(): return self.redis.brpop(["email_queue"], timeout=5) - + result = await self._run_in_executor(brpop_operation) - + if not result: await asyncio.sleep(1) continue - + # 解包返回结果(列表名和值) queue_name, email_id = result if isinstance(email_id, bytes): email_id = email_id.decode("utf-8") - + # 获取邮件数据 email_data = await self.get_email_status(email_id) if email_data.get("status") == "not_found": logger.warning(f"Email data not found for id: {email_id}") continue - + # 更新状态为发送中 - await self._run_in_executor( - self.redis.hset, f"email:{email_id}", "status", "sending" - ) - + await self._run_in_executor(self.redis.hset, f"email:{email_id}", "status", "sending") + # 尝试发送邮件 success = await self._send_email(email_data) - + if success: # 更新状态为已发送 + await self._run_in_executor(self.redis.hset, f"email:{email_id}", "status", "sent") await self._run_in_executor( - self.redis.hset, f"email:{email_id}", "status", "sent" - ) - await self._run_in_executor( - self.redis.hset, f"email:{email_id}", "sent_at", datetime.now().isoformat() + self.redis.hset, + f"email:{email_id}", + "sent_at", + datetime.now().isoformat(), ) logger.info(f"Email {email_id} sent successfully to {email_data.get('to_email')}") else: # 计算重试次数 retry_count = int(email_data.get("retry_count", "0")) + 1 - + if retry_count <= self._retry_limit: # 重新入队,稍后重试 await self._run_in_executor( - self.redis.hset, f"email:{email_id}", "retry_count", str(retry_count) + self.redis.hset, + f"email:{email_id}", + "retry_count", + str(retry_count), ) + await self._run_in_executor(self.redis.hset, f"email:{email_id}", "status", "pending") await self._run_in_executor( - self.redis.hset, f"email:{email_id}", "status", "pending" + self.redis.hset, + f"email:{email_id}", + "last_retry", + datetime.now().isoformat(), ) - await self._run_in_executor( - self.redis.hset, f"email:{email_id}", "last_retry", datetime.now().isoformat() - ) - + # 延迟重试(使用指数退避) delay = 60 * (2 ** (retry_count - 1)) # 1分钟,2分钟,4分钟... - + # 创建延迟任务 - asyncio.create_task(self._delayed_retry(email_id, delay)) - + bg_tasks.add_task(self._delayed_retry, email_id, delay) + logger.warning(f"Email {email_id} will be retried in {delay} seconds (attempt {retry_count})") else: # 超过重试次数,标记为失败 - await self._run_in_executor( - self.redis.hset, f"email:{email_id}", "status", "failed" - ) + await self._run_in_executor(self.redis.hset, f"email:{email_id}", "status", "failed") logger.error(f"Email {email_id} failed after {retry_count} attempts") - + except Exception as e: logger.error(f"Error processing email queue: {e}") await asyncio.sleep(5) # 出错后等待5秒 @@ -212,53 +207,51 @@ class EmailQueue: async def _delayed_retry(self, email_id: str, delay: int): """延迟重试发送邮件""" await asyncio.sleep(delay) - await self._run_in_executor( - self.redis.lpush, "email_queue", email_id - ) + await self._run_in_executor(self.redis.lpush, "email_queue", email_id) logger.info(f"Re-queued email {email_id} for retry after {delay} seconds") - async def _send_email(self, email_data: Dict[str, Any]) -> bool: + async def _send_email(self, email_data: dict[str, Any]) -> bool: """ 实际发送邮件 - + Args: email_data: 邮件数据 - + Returns: 是否发送成功 """ try: # 如果邮件发送功能被禁用,则只记录日志 - if not getattr(settings, 'enable_email_sending', True): + if not getattr(settings, "enable_email_sending", True): logger.info(f"[Mock Email] Would send to {email_data.get('to_email')}: {email_data.get('subject')}") return True - + # 创建邮件 - msg = MIMEMultipart('alternative') - msg['From'] = f"{self.from_name} <{self.from_email}>" - msg['To'] = email_data.get('to_email', '') - msg['Subject'] = email_data.get('subject', '') - + msg = MIMEMultipart("alternative") + msg["From"] = f"{self.from_name} <{self.from_email}>" + msg["To"] = email_data.get("to_email", "") + msg["Subject"] = email_data.get("subject", "") + # 添加纯文本内容 - content = email_data.get('content', '') + content = email_data.get("content", "") if content: - msg.attach(MIMEText(content, 'plain', 'utf-8')) - + msg.attach(MIMEText(content, "plain", "utf-8")) + # 添加HTML内容(如果有) - html_content = email_data.get('html_content', '') + html_content = email_data.get("html_content", "") if html_content: - msg.attach(MIMEText(html_content, 'html', 'utf-8')) - + msg.attach(MIMEText(html_content, "html", "utf-8")) + # 发送邮件 with smtplib.SMTP(self.smtp_server, self.smtp_port) as server: if self.smtp_username and self.smtp_password: server.starttls() server.login(self.smtp_username, self.smtp_password) - + server.send_message(msg) - + return True - + except Exception as e: logger.error(f"Failed to send email: {e}") return False @@ -267,10 +260,12 @@ class EmailQueue: # 全局邮件队列实例 email_queue = EmailQueue() + # 在应用启动时调用 async def start_email_processor(): await email_queue.start_processing() + # 在应用关闭时调用 async def stop_email_processor(): await email_queue.stop_processing() diff --git a/app/service/email_service.py b/app/service/email_service.py index 0562c4d..12cb70c 100644 --- a/app/service/email_service.py +++ b/app/service/email_service.py @@ -4,13 +4,11 @@ from __future__ import annotations -import smtplib -from email.mime.text import MIMEText from email.mime.multipart import MIMEMultipart +from email.mime.text import MIMEText import secrets +import smtplib import string -from datetime import datetime, UTC, timedelta -from typing import Optional from app.config import settings from app.log import logger @@ -18,28 +16,28 @@ from app.log import logger class EmailService: """邮件发送服务""" - + def __init__(self): - self.smtp_server = getattr(settings, 'smtp_server', 'localhost') - self.smtp_port = getattr(settings, 'smtp_port', 587) - self.smtp_username = getattr(settings, 'smtp_username', '') - self.smtp_password = getattr(settings, 'smtp_password', '') - self.from_email = getattr(settings, 'from_email', 'noreply@example.com') - self.from_name = getattr(settings, 'from_name', 'osu! server') - + self.smtp_server = getattr(settings, "smtp_server", "localhost") + self.smtp_port = getattr(settings, "smtp_port", 587) + self.smtp_username = getattr(settings, "smtp_username", "") + self.smtp_password = getattr(settings, "smtp_password", "") + self.from_email = getattr(settings, "from_email", "noreply@example.com") + self.from_name = getattr(settings, "from_name", "osu! server") + def generate_verification_code(self) -> str: """生成8位验证码""" # 只使用数字,避免混淆 - return ''.join(secrets.choice(string.digits) for _ in range(8)) - + return "".join(secrets.choice(string.digits) for _ in range(8)) + async def send_verification_email(self, email: str, code: str, username: str) -> bool: """发送验证邮件""" try: msg = MIMEMultipart() - msg['From'] = f"{self.from_name} <{self.from_email}>" - msg['To'] = email - msg['Subject'] = "邮箱验证 - Email Verification" - + msg["From"] = f"{self.from_name} <{self.from_email}>" + msg["To"] = email + msg["Subject"] = "邮箱验证 - Email Verification" + # HTML 邮件内容 html_content = f""" @@ -101,15 +99,15 @@ class EmailService:

osu! 邮箱验证

Email Verification

- +

你好 {username}!

感谢你注册我们的 osu! 服务器。为了完成账户验证,请输入以下验证码:

- +
{code}
- +

这个验证码将在 10 分钟后过期

- +
注意:
    @@ -118,19 +116,19 @@ class EmailService:
  • 验证码只能使用一次
- +

如果你有任何问题,请联系我们的支持团队。

- +
- +

Hello {username}!

Thank you for registering on our osu! server. To complete your account verification, please enter the following verification code:

- +

This verification code will expire in 10 minutes.

- +

Important: Do not share this verification code with anyone. If you did not request this code, please ignore this email.

- + - """ - - msg.attach(MIMEText(html_content, 'html', 'utf-8')) - + """ # noqa: E501 + + msg.attach(MIMEText(html_content, "html", "utf-8")) + # 发送邮件 if not settings.enable_email_sending: # 邮件发送功能禁用时只记录日志,不实际发送 logger.info(f"[Email Verification] Mock sending verification code to {email}: {code}") return True - + with smtplib.SMTP(self.smtp_server, self.smtp_port) as server: if self.smtp_username and self.smtp_password: server.starttls() server.login(self.smtp_username, self.smtp_password) - + server.send_message(msg) - + logger.info(f"[Email Verification] Successfully sent verification code to {email}") return True - + except Exception as e: logger.error(f"[Email Verification] Failed to send email: {e}") return False diff --git a/app/service/email_verification_service.py b/app/service/email_verification_service.py index f37cffe..ef32fcb 100644 --- a/app/service/email_verification_service.py +++ b/app/service/email_verification_service.py @@ -4,40 +4,38 @@ from __future__ import annotations +from datetime import UTC, datetime, timedelta import secrets import string -from datetime import datetime, UTC, timedelta -from typing import Optional -from app.database.email_verification import EmailVerification, LoginSession -from app.service.email_service import email_service -from app.service.email_queue import email_queue # 导入邮件队列 -from app.log import logger from app.config import settings +from app.database.email_verification import EmailVerification, LoginSession +from app.log import logger +from app.service.email_queue import email_queue # 导入邮件队列 -from sqlmodel.ext.asyncio.session import AsyncSession -from sqlmodel import select from redis.asyncio import Redis +from sqlmodel import col, select +from sqlmodel.ext.asyncio.session import AsyncSession class EmailVerificationService: """邮件验证服务""" - + @staticmethod def generate_verification_code() -> str: """生成8位验证码""" - return ''.join(secrets.choice(string.digits) for _ in range(8)) - + return "".join(secrets.choice(string.digits) for _ in range(8)) + @staticmethod async def send_verification_email_via_queue(email: str, code: str, username: str, user_id: int) -> bool: """使用邮件队列发送验证邮件 - + Args: email: 接收验证码的邮箱地址 code: 验证码 username: 用户名 user_id: 用户ID - + Returns: 是否成功将邮件加入队列 """ @@ -103,15 +101,15 @@ class EmailVerificationService:

osu! 邮箱验证

Email Verification

- +

你好 {username}!

请使用以下验证码验证您的账户:

- +
{code}
- +

验证码将在 10 分钟内有效

- +

重要提示:

    @@ -120,17 +118,17 @@ class EmailVerificationService:
  • 为了账户安全,请勿在其他网站使用相同的密码
- +
- +

Hello {username}!

Please use the following verification code to verify your account:

- +

This verification code will be valid for 10 minutes.

- +

Important: Do not share this verification code with anyone. If you did not request this code, please ignore this email.

- + - """ - + """ # noqa: E501 + # 纯文本备用内容 plain_content = f""" 你好 {username}! @@ -162,34 +160,30 @@ This verification code will be valid for 10 minutes. © 2025 g0v0! Private Server. 此邮件由系统自动发送,请勿回复。 This email was sent automatically, please do not reply. """ - + # 将邮件加入队列 subject = "邮箱验证 - Email Verification" - metadata = { - "type": "email_verification", - "user_id": user_id, - "code": code - } - + metadata = {"type": "email_verification", "user_id": user_id, "code": code} + await email_queue.enqueue_email( to_email=email, subject=subject, content=plain_content, html_content=html_content, - metadata=metadata + metadata=metadata, ) - + return True - + except Exception as e: logger.error(f"[Email Verification] Failed to enqueue email: {e}") return False - + @staticmethod def generate_session_token() -> str: """生成会话令牌""" return secrets.token_urlsafe(32) - + @staticmethod async def create_verification_record( db: AsyncSession, @@ -197,27 +191,27 @@ This email was sent automatically, please do not reply. user_id: int, email: str, ip_address: str | None = None, - user_agent: str | None = None + user_agent: str | None = None, ) -> tuple[EmailVerification, str]: """创建邮件验证记录""" - + # 检查是否有未过期的验证码 existing_result = await db.exec( select(EmailVerification).where( EmailVerification.user_id == user_id, - EmailVerification.is_used == False, - EmailVerification.expires_at > datetime.now(UTC) + col(EmailVerification.is_used).is_(False), + EmailVerification.expires_at > datetime.now(UTC), ) ) existing = existing_result.first() - + if existing: # 如果有未过期的验证码,直接返回 return existing, existing.verification_code - + # 生成新的验证码 code = EmailVerificationService.generate_verification_code() - + # 创建验证记录 verification = EmailVerification( user_id=user_id, @@ -225,23 +219,23 @@ This email was sent automatically, please do not reply. verification_code=code, expires_at=datetime.now(UTC) + timedelta(minutes=10), # 10分钟过期 ip_address=ip_address, - user_agent=user_agent + user_agent=user_agent, ) - + db.add(verification) await db.commit() await db.refresh(verification) - + # 存储到 Redis(用于快速验证) await redis.setex( f"email_verification:{user_id}:{code}", 600, # 10分钟过期 - str(verification.id) if verification.id else "0" + str(verification.id) if verification.id else "0", ) - + logger.info(f"[Email Verification] Created verification code for user {user_id}: {code}") return verification, code - + @staticmethod async def send_verification_email( db: AsyncSession, @@ -250,7 +244,7 @@ This email was sent automatically, please do not reply. username: str, email: str, ip_address: str | None = None, - user_agent: str | None = None + user_agent: str | None = None, ) -> bool: """发送验证邮件""" try: @@ -258,33 +252,38 @@ This email was sent automatically, please do not reply. if not settings.enable_email_verification: logger.debug(f"[Email Verification] Email verification is disabled, skipping for user {user_id}") return True # 返回成功,但不执行验证流程 - + # 创建验证记录 - verification, code = await EmailVerificationService.create_verification_record( + ( + verification, + code, + ) = await EmailVerificationService.create_verification_record( db, redis, user_id, email, ip_address, user_agent ) - + # 使用邮件队列发送验证邮件 success = await EmailVerificationService.send_verification_email_via_queue(email, code, username, user_id) - + if success: - logger.info(f"[Email Verification] Successfully enqueued verification email to {email} (user: {username})") + logger.info( + f"[Email Verification] Successfully enqueued verification email to {email} (user: {username})" + ) return True else: logger.error(f"[Email Verification] Failed to enqueue verification email: {email} (user: {username})") return False - + except Exception as e: logger.error(f"[Email Verification] Exception during sending verification email: {e}") return False - + @staticmethod async def verify_code( db: AsyncSession, redis: Redis, user_id: int, code: str, - ip_address: str | None = None + ip_address: str | None = None, ) -> tuple[bool, str]: """验证验证码""" try: @@ -294,46 +293,46 @@ This email was sent automatically, please do not reply. # 仍然标记登录会话为已验证 await LoginSessionService.mark_session_verified(db, user_id) return True, "验证成功(邮件验证功能已禁用)" - + # 先从 Redis 检查 verification_id = await redis.get(f"email_verification:{user_id}:{code}") if not verification_id: return False, "验证码无效或已过期" - + # 从数据库获取验证记录 result = await db.exec( select(EmailVerification).where( EmailVerification.id == int(verification_id), EmailVerification.user_id == user_id, EmailVerification.verification_code == code, - EmailVerification.is_used == False, - EmailVerification.expires_at > datetime.now(UTC) + col(EmailVerification.is_used).is_(False), + EmailVerification.expires_at > datetime.now(UTC), ) ) - + verification = result.first() if not verification: return False, "验证码无效或已过期" - + # 标记为已使用 verification.is_used = True verification.used_at = datetime.now(UTC) - + # 同时更新对应的登录会话状态 await LoginSessionService.mark_session_verified(db, user_id) - + await db.commit() - + # 删除 Redis 记录 await redis.delete(f"email_verification:{user_id}:{code}") - + logger.info(f"[Email Verification] User {user_id} verification code verified successfully") return True, "验证成功" - + except Exception as e: logger.error(f"[Email Verification] Exception during verification code validation: {e}") return False, "验证过程中发生错误" - + @staticmethod async def resend_verification_code( db: AsyncSession, @@ -342,7 +341,7 @@ This email was sent automatically, please do not reply. username: str, email: str, ip_address: str | None = None, - user_agent: str | None = None + user_agent: str | None = None, ) -> tuple[bool, str]: """重新发送验证码""" try: @@ -350,25 +349,25 @@ This email was sent automatically, please do not reply. if not settings.enable_email_verification: logger.debug(f"[Email Verification] Email verification is disabled, skipping resend for user {user_id}") return True, "验证码已发送(邮件验证功能已禁用)" - + # 检查重发频率限制(60秒内只能发送一次) rate_limit_key = f"email_verification_rate_limit:{user_id}" if await redis.get(rate_limit_key): return False, "请等待60秒后再重新发送" - + # 设置频率限制 await redis.setex(rate_limit_key, 60, "1") - + # 生成新的验证码 success = await EmailVerificationService.send_verification_email( db, redis, user_id, username, email, ip_address, user_agent ) - + if success: return True, "验证码已重新发送" else: return False, "重新发送失败,请稍后再试" - + except Exception as e: logger.error(f"[Email Verification] Exception during resending verification code: {e}") return False, "重新发送过程中发生错误" @@ -376,7 +375,7 @@ This email was sent automatically, please do not reply. class LoginSessionService: """登录会话服务""" - + @staticmethod async def create_session( db: AsyncSession, @@ -385,47 +384,40 @@ class LoginSessionService: ip_address: str, user_agent: str | None = None, country_code: str | None = None, - is_new_location: bool = False + is_new_location: bool = False, ) -> LoginSession: """创建登录会话""" - from app.utils import simplify_user_agent - + session_token = EmailVerificationService.generate_session_token() - - # 简化 User-Agent 字符串 - simplified_user_agent = simplify_user_agent(user_agent, max_length=250) - + session = LoginSession( user_id=user_id, session_token=session_token, ip_address=ip_address, - user_agent=simplified_user_agent, + user_agent=None, country_code=country_code, is_new_location=is_new_location, expires_at=datetime.now(UTC) + timedelta(hours=24), # 24小时过期 - is_verified=not is_new_location # 新位置需要验证 + is_verified=not is_new_location, # 新位置需要验证 ) - + db.add(session) await db.commit() await db.refresh(session) - + # 存储到 Redis await redis.setex( f"login_session:{session_token}", 86400, # 24小时 - user_id + user_id, ) - + logger.info(f"[Login Session] Created session for user {user_id} (new location: {is_new_location})") return session - + @staticmethod async def verify_session( - db: AsyncSession, - redis: Redis, - session_token: str, - verification_code: str + db: AsyncSession, redis: Redis, session_token: str, verification_code: str ) -> tuple[bool, str]: """验证会话(通过邮件验证码)""" try: @@ -433,98 +425,89 @@ class LoginSessionService: user_id = await redis.get(f"login_session:{session_token}") if not user_id: return False, "会话无效或已过期" - + user_id = int(user_id) - + # 验证邮件验证码 - success, message = await EmailVerificationService.verify_code( - db, redis, user_id, verification_code - ) - + success, message = await EmailVerificationService.verify_code(db, redis, user_id, verification_code) + if not success: return False, message - + # 更新会话状态 result = await db.exec( select(LoginSession).where( LoginSession.session_token == session_token, LoginSession.user_id == user_id, - LoginSession.is_verified == False + col(LoginSession.is_verified).is_(False), ) ) - + session = result.first() if session: session.is_verified = True session.verified_at = datetime.now(UTC) await db.commit() - + logger.info(f"[Login Session] User {user_id} session verification successful") return True, "会话验证成功" - + except Exception as e: logger.error(f"[Login Session] Exception during session verification: {e}") return False, "验证过程中发生错误" - + @staticmethod async def check_new_location( - db: AsyncSession, - user_id: int, - ip_address: str, - country_code: str | None = None + db: AsyncSession, user_id: int, ip_address: str, country_code: str | None = None ) -> bool: """检查是否为新位置登录""" try: # 查看过去30天内是否有相同IP或相同国家的登录记录 thirty_days_ago = datetime.now(UTC) - timedelta(days=30) - + result = await db.exec( select(LoginSession).where( LoginSession.user_id == user_id, LoginSession.created_at > thirty_days_ago, - (LoginSession.ip_address == ip_address) | - (LoginSession.country_code == country_code) + (LoginSession.ip_address == ip_address) | (LoginSession.country_code == country_code), ) ) - + existing_sessions = result.all() - + # 如果有历史记录,则不是新位置 return len(existing_sessions) == 0 - + except Exception as e: logger.error(f"[Login Session] Exception during new location check: {e}") # 出错时默认为新位置(更安全) return True @staticmethod - async def mark_session_verified( - db: AsyncSession, - user_id: int - ) -> bool: + async def mark_session_verified(db: AsyncSession, user_id: int) -> bool: """标记用户的未验证会话为已验证""" try: # 查找用户所有未验证且未过期的会话 result = await db.exec( select(LoginSession).where( LoginSession.user_id == user_id, - LoginSession.is_verified == False, - LoginSession.expires_at > datetime.now(UTC) + col(LoginSession.is_verified).is_(False), + LoginSession.expires_at > datetime.now(UTC), ) ) - + sessions = result.all() - + # 标记所有会话为已验证 for session in sessions: session.is_verified = True session.verified_at = datetime.now(UTC) - + if sessions: logger.info(f"[Login Session] Marked {len(sessions)} session(s) as verified for user {user_id}") - + return len(sessions) > 0 - + except Exception as e: logger.error(f"[Login Session] Exception during marking sessions as verified: {e}") return False diff --git a/app/service/enhanced_interval_stats.py b/app/service/enhanced_interval_stats.py index 6ea6f17..3da85d7 100644 --- a/app/service/enhanced_interval_stats.py +++ b/app/service/enhanced_interval_stats.py @@ -117,14 +117,10 @@ class EnhancedIntervalStatsManager: @staticmethod async def get_current_interval_info() -> IntervalInfo: """获取当前区间信息""" - start_time, end_time = ( - EnhancedIntervalStatsManager.get_current_interval_boundaries() - ) + start_time, end_time = EnhancedIntervalStatsManager.get_current_interval_boundaries() interval_key = EnhancedIntervalStatsManager.generate_interval_key(start_time) - return IntervalInfo( - start_time=start_time, end_time=end_time, interval_key=interval_key - ) + return IntervalInfo(start_time=start_time, end_time=end_time, interval_key=interval_key) @staticmethod async def initialize_current_interval() -> None: @@ -133,9 +129,7 @@ class EnhancedIntervalStatsManager: redis_async = get_redis() try: - current_interval = ( - await EnhancedIntervalStatsManager.get_current_interval_info() - ) + current_interval = await EnhancedIntervalStatsManager.get_current_interval_info() # 存储当前区间信息 await _redis_exec( @@ -147,9 +141,7 @@ class EnhancedIntervalStatsManager: # 初始化区间用户集合(如果不存在) online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}" - playing_key = ( - f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}" - ) + playing_key = f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}" # 设置过期时间为35分钟 await redis_async.expire(online_key, 35 * 60) @@ -179,7 +171,8 @@ class EnhancedIntervalStatsManager: await EnhancedIntervalStatsManager._ensure_24h_history_exists() logger.info( - f"Initialized interval stats for {current_interval.start_time.strftime('%H:%M')} - {current_interval.end_time.strftime('%H:%M')}" + f"Initialized interval stats for {current_interval.start_time.strftime('%H:%M')}" + f" - {current_interval.end_time.strftime('%H:%M')}" ) except Exception as e: @@ -193,42 +186,32 @@ class EnhancedIntervalStatsManager: try: # 检查现有历史数据数量 - history_length = await _redis_exec( - redis_sync.llen, REDIS_ONLINE_HISTORY_KEY - ) + history_length = await _redis_exec(redis_sync.llen, REDIS_ONLINE_HISTORY_KEY) if history_length < 48: # 少于48个数据点(24小时*2) - logger.info( - f"History has only {history_length} points, filling with zeros for 24h" - ) + logger.info(f"History has only {history_length} points, filling with zeros for 24h") # 计算需要填充的数据点数量 needed_points = 48 - history_length # 从当前时间往前推,创建缺失的时间点(都填充为0) - current_time = datetime.utcnow() - current_interval_start, _ = ( - EnhancedIntervalStatsManager.get_current_interval_boundaries() - ) + current_time = datetime.utcnow() # noqa: F841 + current_interval_start, _ = EnhancedIntervalStatsManager.get_current_interval_boundaries() # 从当前区间开始往前推,创建历史数据点(确保时间对齐到30分钟边界) fill_points = [] for i in range(needed_points): # 每次往前推30分钟,确保时间对齐 - point_time = current_interval_start - timedelta( - minutes=30 * (i + 1) - ) + point_time = current_interval_start - timedelta(minutes=30 * (i + 1)) # 确保时间对齐到30分钟边界 aligned_minute = (point_time.minute // 30) * 30 - point_time = point_time.replace( - minute=aligned_minute, second=0, microsecond=0 - ) + point_time = point_time.replace(minute=aligned_minute, second=0, microsecond=0) history_point = { "timestamp": point_time.isoformat(), "online_count": 0, - "playing_count": 0 + "playing_count": 0, } fill_points.append(json.dumps(history_point)) @@ -238,9 +221,7 @@ class EnhancedIntervalStatsManager: temp_key = f"{REDIS_ONLINE_HISTORY_KEY}_temp" if history_length > 0: # 复制现有数据到临时key - existing_data = await _redis_exec( - redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1 - ) + existing_data = await _redis_exec(redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1) if existing_data: for data in existing_data: await _redis_exec(redis_sync.rpush, temp_key, data) @@ -250,19 +231,13 @@ class EnhancedIntervalStatsManager: # 先添加填充数据(最旧的) for point in reversed(fill_points): # 反向添加,最旧的在最后 - await _redis_exec( - redis_sync.rpush, REDIS_ONLINE_HISTORY_KEY, point - ) + await _redis_exec(redis_sync.rpush, REDIS_ONLINE_HISTORY_KEY, point) # 再添加原有数据(较新的) if history_length > 0: - existing_data = await _redis_exec( - redis_sync.lrange, temp_key, 0, -1 - ) + existing_data = await _redis_exec(redis_sync.lrange, temp_key, 0, -1) for data in existing_data: - await _redis_exec( - redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, data - ) + await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, data) # 清理临时key await redis_async.delete(temp_key) @@ -273,9 +248,7 @@ class EnhancedIntervalStatsManager: # 设置过期时间 await redis_async.expire(REDIS_ONLINE_HISTORY_KEY, 26 * 3600) - logger.info( - f"Filled {len(fill_points)} historical data points with zeros" - ) + logger.info(f"Filled {len(fill_points)} historical data points with zeros") except Exception as e: logger.error(f"Error ensuring 24h history exists: {e}") @@ -287,9 +260,7 @@ class EnhancedIntervalStatsManager: redis_async = get_redis() try: - current_interval = ( - await EnhancedIntervalStatsManager.get_current_interval_info() - ) + current_interval = await EnhancedIntervalStatsManager.get_current_interval_info() # 添加到区间在线用户集合 online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}" @@ -298,9 +269,7 @@ class EnhancedIntervalStatsManager: # 如果用户在游玩,也添加到游玩用户集合 if is_playing: - playing_key = ( - f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}" - ) + playing_key = f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}" await _redis_exec(redis_sync.sadd, playing_key, str(user_id)) await redis_async.expire(playing_key, 35 * 60) @@ -308,7 +277,8 @@ class EnhancedIntervalStatsManager: await EnhancedIntervalStatsManager._update_interval_stats() logger.debug( - f"Added user {user_id} to current interval {current_interval.start_time.strftime('%H:%M')}-{current_interval.end_time.strftime('%H:%M')}" + f"Added user {user_id} to current interval {current_interval.start_time.strftime('%H:%M')}" + f"-{current_interval.end_time.strftime('%H:%M')}" ) except Exception as e: @@ -321,15 +291,11 @@ class EnhancedIntervalStatsManager: redis_async = get_redis() try: - current_interval = ( - await EnhancedIntervalStatsManager.get_current_interval_info() - ) + current_interval = await EnhancedIntervalStatsManager.get_current_interval_info() # 获取区间内独特用户数 online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}" - playing_key = ( - f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}" - ) + playing_key = f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}" unique_online = await _redis_exec(redis_sync.scard, online_key) unique_playing = await _redis_exec(redis_sync.scard, playing_key) @@ -339,16 +305,12 @@ class EnhancedIntervalStatsManager: current_playing = await _get_playing_users_count(redis_async) # 获取现有统计数据 - existing_data = await _redis_exec( - redis_sync.get, current_interval.interval_key - ) + existing_data = await _redis_exec(redis_sync.get, current_interval.interval_key) if existing_data: stats = IntervalStats.from_dict(json.loads(existing_data)) # 更新峰值 stats.peak_online_count = max(stats.peak_online_count, current_online) - stats.peak_playing_count = max( - stats.peak_playing_count, current_playing - ) + stats.peak_playing_count = max(stats.peak_playing_count, current_playing) stats.total_samples += 1 else: # 创建新的统计记录 @@ -377,7 +339,8 @@ class EnhancedIntervalStatsManager: await redis_async.expire(current_interval.interval_key, 35 * 60) logger.debug( - f"Updated interval stats: online={unique_online}, playing={unique_playing}, peak_online={stats.peak_online_count}, peak_playing={stats.peak_playing_count}" + f"Updated interval stats: online={unique_online}, playing={unique_playing}, " + f"peak_online={stats.peak_online_count}, peak_playing={stats.peak_playing_count}" ) except Exception as e: @@ -395,21 +358,21 @@ class EnhancedIntervalStatsManager: # 上一个区间开始时间是当前区间开始时间减去30分钟 previous_start = current_start - timedelta(minutes=30) previous_end = current_start # 上一个区间的结束时间就是当前区间的开始时间 - + interval_key = EnhancedIntervalStatsManager.generate_interval_key(previous_start) - + previous_interval = IntervalInfo( start_time=previous_start, end_time=previous_end, - interval_key=interval_key + interval_key=interval_key, ) # 获取最终统计数据 - stats_data = await _redis_exec( - redis_sync.get, previous_interval.interval_key - ) + stats_data = await _redis_exec(redis_sync.get, previous_interval.interval_key) if not stats_data: - logger.warning(f"No interval stats found to finalize for {previous_interval.start_time.strftime('%H:%M')}") + logger.warning( + f"No interval stats found to finalize for {previous_interval.start_time.strftime('%H:%M')}" + ) return None stats = IntervalStats.from_dict(json.loads(stats_data)) @@ -418,13 +381,11 @@ class EnhancedIntervalStatsManager: history_point = { "timestamp": previous_interval.start_time.isoformat(), "online_count": stats.unique_online_users, - "playing_count": stats.unique_playing_users + "playing_count": stats.unique_playing_users, } # 添加到历史记录 - await _redis_exec( - redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point) - ) + await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point)) # 只保留48个数据点(24小时,每30分钟一个点) await _redis_exec(redis_sync.ltrim, REDIS_ONLINE_HISTORY_KEY, 0, 47) # 设置过期时间为26小时,确保有足够缓冲 @@ -452,12 +413,8 @@ class EnhancedIntervalStatsManager: redis_sync = get_redis_message() try: - current_interval = ( - await EnhancedIntervalStatsManager.get_current_interval_info() - ) - stats_data = await _redis_exec( - redis_sync.get, current_interval.interval_key - ) + current_interval = await EnhancedIntervalStatsManager.get_current_interval_info() + stats_data = await _redis_exec(redis_sync.get, current_interval.interval_key) if stats_data: return IntervalStats.from_dict(json.loads(stats_data)) @@ -506,8 +463,6 @@ class EnhancedIntervalStatsManager: # 便捷函数,用于替换现有的统计更新函数 -async def update_user_activity_in_interval( - user_id: int, is_playing: bool = False -) -> None: +async def update_user_activity_in_interval(user_id: int, is_playing: bool = False) -> None: """用户活动时更新区间统计(在登录、开始游玩等时调用)""" await EnhancedIntervalStatsManager.add_user_to_interval(user_id, is_playing) diff --git a/app/service/load_achievements.py b/app/service/load_achievements.py index 6d20503..e491274 100644 --- a/app/service/load_achievements.py +++ b/app/service/load_achievements.py @@ -11,12 +11,8 @@ def load_achievements() -> Medals: for module in ACHIEVEMENTS_DIR.iterdir(): if module.is_file() and module.suffix == ".py": module_name = module.stem - module_achievements = importlib.import_module( - f"app.achievements.{module_name}" - ) + module_achievements = importlib.import_module(f"app.achievements.{module_name}") medals = getattr(module_achievements, "MEDALS", {}) MEDALS.update(medals) - logger.success( - f"Successfully loaded {len(medals)} achievements from {module_name}.py" - ) + logger.success(f"Successfully loaded {len(medals)} achievements from {module_name}.py") return MEDALS diff --git a/app/service/login_log_service.py b/app/service/login_log_service.py index 8f2298d..b8b6766 100644 --- a/app/service/login_log_service.py +++ b/app/service/login_log_service.py @@ -47,6 +47,7 @@ class LoginLogService: # 获取并简化User-Agent from app.utils import simplify_user_agent + raw_user_agent = request.headers.get("User-Agent", "") user_agent = simplify_user_agent(raw_user_agent, max_length=500) @@ -67,9 +68,7 @@ class LoginLogService: # 在后台线程中运行GeoIP查询(避免阻塞) loop = asyncio.get_event_loop() - geo_info = await loop.run_in_executor( - None, lambda: geoip.lookup(ip_address) - ) + geo_info = await loop.run_in_executor(None, lambda: geoip.lookup(ip_address)) if geo_info: login_log.country_code = geo_info.get("country_iso", "") @@ -89,10 +88,7 @@ class LoginLogService: login_log.organization = geo_info.get("organization", "") - logger.debug( - f"GeoIP lookup for {ip_address}: " - f"{geo_info.get('country_name', 'Unknown')}" - ) + logger.debug(f"GeoIP lookup for {ip_address}: {geo_info.get('country_name', 'Unknown')}") else: logger.warning(f"GeoIP lookup failed for {ip_address}") @@ -104,9 +100,7 @@ class LoginLogService: await db.commit() await db.refresh(login_log) - logger.info( - f"Login recorded for user {user_id} from {ip_address} ({login_method})" - ) + logger.info(f"Login recorded for user {user_id} from {ip_address} ({login_method})") return login_log @staticmethod @@ -137,9 +131,7 @@ class LoginLogService: request=request, login_success=False, login_method=login_method, - notes=f"Failed login attempt: {attempted_username}" - if attempted_username - else "Failed login attempt", + notes=f"Failed login attempt: {attempted_username}" if attempted_username else "Failed login attempt", ) diff --git a/app/service/message_queue.py b/app/service/message_queue.py index 0665796..4f1c0f1 100644 --- a/app/service/message_queue.py +++ b/app/service/message_queue.py @@ -13,6 +13,7 @@ import uuid from app.database.chat import ChatMessage, MessageType from app.dependencies.database import get_redis, with_db from app.log import logger +from app.utils import bg_tasks class MessageQueue: @@ -34,7 +35,7 @@ class MessageQueue: """启动消息处理任务""" if not self._processing: self._processing = True - asyncio.create_task(self._process_message_queue()) + bg_tasks.add_task(self._process_message_queue) logger.info("Message queue processing started") async def stop_processing(self): @@ -59,12 +60,8 @@ class MessageQueue: message_data["status"] = "pending" # pending, processing, completed, failed # 将消息存储到 Redis - await self._run_in_executor( - lambda: self.redis.hset(f"msg:{temp_uuid}", mapping=message_data) - ) - await self._run_in_executor( - self.redis.expire, f"msg:{temp_uuid}", 3600 - ) # 1小时过期 + await self._run_in_executor(lambda: self.redis.hset(f"msg:{temp_uuid}", mapping=message_data)) + await self._run_in_executor(self.redis.expire, f"msg:{temp_uuid}", 3600) # 1小时过期 # 加入处理队列 await self._run_in_executor(self.redis.lpush, "message_queue", temp_uuid) @@ -74,17 +71,13 @@ class MessageQueue: async def get_message_status(self, temp_uuid: str) -> dict | None: """获取消息状态""" - message_data = await self._run_in_executor( - self.redis.hgetall, f"msg:{temp_uuid}" - ) + message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{temp_uuid}") if not message_data: return None return message_data - async def get_cached_messages( - self, channel_id: int, limit: int = 50, since: int = 0 - ) -> list[dict]: + async def get_cached_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict]: """ 从 Redis 获取缓存的消息 @@ -97,15 +90,11 @@ class MessageQueue: 消息列表 """ # 从 Redis 获取频道最近的消息 UUID 列表 - message_uuids = await self._run_in_executor( - self.redis.lrange, f"channel:{channel_id}:messages", 0, limit - 1 - ) + message_uuids = await self._run_in_executor(self.redis.lrange, f"channel:{channel_id}:messages", 0, limit - 1) messages = [] for uuid_str in message_uuids: - message_data = await self._run_in_executor( - self.redis.hgetall, f"msg:{uuid_str}" - ) + message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{uuid_str}") if message_data: # 检查是否满足 since 条件 if since > 0 and "message_id" in message_data: @@ -116,22 +105,14 @@ class MessageQueue: return messages[::-1] # 返回时间顺序 - async def cache_channel_message( - self, channel_id: int, temp_uuid: str, max_cache: int = 100 - ): + async def cache_channel_message(self, channel_id: int, temp_uuid: str, max_cache: int = 100): """将消息 UUID 缓存到频道消息列表""" # 添加到频道消息列表开头 - await self._run_in_executor( - self.redis.lpush, f"channel:{channel_id}:messages", temp_uuid - ) + await self._run_in_executor(self.redis.lpush, f"channel:{channel_id}:messages", temp_uuid) # 限制缓存大小 - await self._run_in_executor( - self.redis.ltrim, f"channel:{channel_id}:messages", 0, max_cache - 1 - ) + await self._run_in_executor(self.redis.ltrim, f"channel:{channel_id}:messages", 0, max_cache - 1) # 设置过期时间(24小时) - await self._run_in_executor( - self.redis.expire, f"channel:{channel_id}:messages", 86400 - ) + await self._run_in_executor(self.redis.expire, f"channel:{channel_id}:messages", 86400) async def _process_message_queue(self): """异步处理消息队列,批量写入数据库""" @@ -140,9 +121,7 @@ class MessageQueue: # 批量获取消息 message_uuids = [] for _ in range(self._batch_size): - result = await self._run_in_executor( - lambda: self.redis.brpop(["message_queue"], timeout=1) - ) + result = await self._run_in_executor(lambda: self.redis.brpop(["message_queue"], timeout=1)) if result: message_uuids.append(result[1]) else: @@ -166,16 +145,12 @@ class MessageQueue: for temp_uuid in message_uuids: try: # 获取消息数据 - message_data = await self._run_in_executor( - self.redis.hgetall, f"msg:{temp_uuid}" - ) + message_data = await self._run_in_executor(self.redis.hgetall, f"msg:{temp_uuid}") if not message_data: continue # 更新状态为处理中 - await self._run_in_executor( - self.redis.hset, f"msg:{temp_uuid}", "status", "processing" - ) + await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "processing") # 创建数据库消息对象 msg = ChatMessage( @@ -190,9 +165,7 @@ class MessageQueue: except Exception as e: logger.error(f"Error preparing message {temp_uuid}: {e}") - await self._run_in_executor( - self.redis.hset, f"msg:{temp_uuid}", "status", "failed" - ) + await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "failed") if messages_to_insert: try: @@ -211,16 +184,12 @@ class MessageQueue: mapping={ "status": "completed", "message_id": str(msg.message_id), - "created_at": msg.timestamp.isoformat() - if msg.timestamp - else "", + "created_at": msg.timestamp.isoformat() if msg.timestamp else "", }, ) ) - logger.info( - f"Message {temp_uuid} persisted to DB with ID {msg.message_id}" - ) + logger.info(f"Message {temp_uuid} persisted to DB with ID {msg.message_id}") except Exception as e: logger.error(f"Error inserting messages to database: {e}") @@ -228,9 +197,7 @@ class MessageQueue: # 标记所有消息为失败 for _, temp_uuid in messages_to_insert: - await self._run_in_executor( - self.redis.hset, f"msg:{temp_uuid}", "status", "failed" - ) + await self._run_in_executor(self.redis.hset, f"msg:{temp_uuid}", "status", "failed") # 全局消息队列实例 diff --git a/app/service/message_queue_processor.py b/app/service/message_queue_processor.py index 3ecce55..41ef9bf 100644 --- a/app/service/message_queue_processor.py +++ b/app/service/message_queue_processor.py @@ -33,36 +33,22 @@ class MessageQueueProcessor: """将消息缓存到 Redis""" try: # 存储消息数据 - await self._redis_exec( - self.redis_message.hset, f"msg:{temp_uuid}", mapping=message_data - ) - await self._redis_exec( - self.redis_message.expire, f"msg:{temp_uuid}", 3600 - ) # 1小时过期 + await self._redis_exec(self.redis_message.hset, f"msg:{temp_uuid}", mapping=message_data) + await self._redis_exec(self.redis_message.expire, f"msg:{temp_uuid}", 3600) # 1小时过期 # 加入频道消息列表 - await self._redis_exec( - self.redis_message.lpush, f"channel:{channel_id}:messages", temp_uuid - ) - await self._redis_exec( - self.redis_message.ltrim, f"channel:{channel_id}:messages", 0, 99 - ) # 保持最新100条 - await self._redis_exec( - self.redis_message.expire, f"channel:{channel_id}:messages", 86400 - ) # 24小时过期 + await self._redis_exec(self.redis_message.lpush, f"channel:{channel_id}:messages", temp_uuid) + await self._redis_exec(self.redis_message.ltrim, f"channel:{channel_id}:messages", 0, 99) # 保持最新100条 + await self._redis_exec(self.redis_message.expire, f"channel:{channel_id}:messages", 86400) # 24小时过期 # 加入异步处理队列 - await self._redis_exec( - self.redis_message.lpush, "message_write_queue", temp_uuid - ) + await self._redis_exec(self.redis_message.lpush, "message_write_queue", temp_uuid) logger.info(f"Message cached to Redis: {temp_uuid}") except Exception as e: logger.error(f"Failed to cache message to Redis: {e}") - async def get_cached_messages( - self, channel_id: int, limit: int = 50, since: int = 0 - ) -> list[dict]: + async def get_cached_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict]: """从 Redis 获取缓存的消息""" try: message_uuids = await self._redis_exec( @@ -78,15 +64,11 @@ class MessageQueueProcessor: if isinstance(temp_uuid, bytes): temp_uuid = temp_uuid.decode("utf-8") - raw_data = await self._redis_exec( - self.redis_message.hgetall, f"msg:{temp_uuid}" - ) + raw_data = await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}") if raw_data: # 解码 Redis 返回的字节数据 message_data = { - k.decode("utf-8") if isinstance(k, bytes) else k: v.decode( - "utf-8" - ) + k.decode("utf-8") if isinstance(k, bytes) else k: v.decode("utf-8") if isinstance(v, bytes) else v for k, v in raw_data.items() @@ -103,9 +85,7 @@ class MessageQueueProcessor: logger.error(f"Failed to get cached messages: {e}") return [] - async def update_message_status( - self, temp_uuid: str, status: str, message_id: int | None = None - ): + async def update_message_status(self, temp_uuid: str, status: str, message_id: int | None = None): """更新消息状态""" try: update_data = {"status": status} @@ -113,26 +93,20 @@ class MessageQueueProcessor: update_data["message_id"] = str(message_id) update_data["db_timestamp"] = datetime.now().isoformat() - await self._redis_exec( - self.redis_message.hset, f"msg:{temp_uuid}", mapping=update_data - ) + await self._redis_exec(self.redis_message.hset, f"msg:{temp_uuid}", mapping=update_data) except Exception as e: logger.error(f"Failed to update message status: {e}") async def get_message_status(self, temp_uuid: str) -> dict | None: """获取消息状态""" try: - raw_data = await self._redis_exec( - self.redis_message.hgetall, f"msg:{temp_uuid}" - ) + raw_data = await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}") if not raw_data: return None # 解码 Redis 返回的字节数据 return { - k.decode("utf-8") if isinstance(k, bytes) else k: v.decode("utf-8") - if isinstance(v, bytes) - else v + k.decode("utf-8") if isinstance(k, bytes) else k: v.decode("utf-8") if isinstance(v, bytes) else v for k, v in raw_data.items() } except Exception as e: @@ -148,9 +122,7 @@ class MessageQueueProcessor: # 批量获取消息 message_uuids = [] for _ in range(20): # 批量处理20条消息 - result = await self._redis_exec( - self.redis_message.brpop, ["message_write_queue"], timeout=1 - ) + result = await self._redis_exec(self.redis_message.brpop, ["message_write_queue"], timeout=1) if result: # result是 (queue_name, value) 的元组,需要解码 uuid_value = result[1] @@ -179,17 +151,13 @@ class MessageQueueProcessor: for temp_uuid in message_uuids: try: # 获取消息数据并解码 - raw_data = await self._redis_exec( - self.redis_message.hgetall, f"msg:{temp_uuid}" - ) + raw_data = await self._redis_exec(self.redis_message.hgetall, f"msg:{temp_uuid}") if not raw_data: continue # 解码 Redis 返回的字节数据 message_data = { - k.decode("utf-8") if isinstance(k, bytes) else k: v.decode( - "utf-8" - ) + k.decode("utf-8") if isinstance(k, bytes) else k: v.decode("utf-8") if isinstance(v, bytes) else v for k, v in raw_data.items() @@ -215,10 +183,7 @@ class MessageQueueProcessor: await session.refresh(msg) # 更新成功状态,包含临时消息ID映射 - assert msg.message_id is not None - await self.update_message_status( - temp_uuid, "completed", msg.message_id - ) + await self.update_message_status(temp_uuid, "completed", msg.message_id) # 如果有临时消息ID,存储映射关系并通知客户端更新 if message_data.get("temp_message_id"): @@ -232,12 +197,11 @@ class MessageQueueProcessor: # 发送消息ID更新通知到频道 channel_id = int(message_data["channel_id"]) - await self._notify_message_update( - channel_id, temp_msg_id, msg.message_id, message_data - ) + await self._notify_message_update(channel_id, temp_msg_id, msg.message_id, message_data) logger.info( - f"Message {temp_uuid} persisted to DB with ID {msg.message_id}, temp_id: {message_data.get('temp_message_id')}" + f"Message {temp_uuid} persisted to DB with ID {msg.message_id}, " + f"temp_id: {message_data.get('temp_message_id')}" ) except Exception as e: @@ -272,9 +236,7 @@ class MessageQueueProcessor: json.dumps(update_event), ) - logger.info( - f"Published message update: temp_id={temp_message_id}, real_id={real_message_id}" - ) + logger.info(f"Published message update: temp_id={temp_message_id}, real_id={real_message_id}") except Exception as e: logger.error(f"Failed to notify message update: {e}") @@ -320,9 +282,7 @@ async def cache_message_to_redis(channel_id: int, message_data: dict, temp_uuid: await message_queue_processor.cache_message(channel_id, message_data, temp_uuid) -async def get_cached_messages( - channel_id: int, limit: int = 50, since: int = 0 -) -> list[dict]: +async def get_cached_messages(channel_id: int, limit: int = 50, since: int = 0) -> list[dict]: """从 Redis 获取缓存的消息 - 便捷接口""" return await message_queue_processor.get_cached_messages(channel_id, limit, since) diff --git a/app/service/online_status_maintenance.py b/app/service/online_status_maintenance.py index e7a1b05..f5a3e12 100644 --- a/app/service/online_status_maintenance.py +++ b/app/service/online_status_maintenance.py @@ -4,10 +4,10 @@ 此模块提供在游玩状态下维护用户在线状态的功能, 解决游玩时显示离线的问题。 """ + from __future__ import annotations import asyncio -from datetime import datetime, timedelta from app.dependencies.database import get_redis from app.log import logger @@ -17,32 +17,32 @@ from app.router.v2.stats import REDIS_PLAYING_USERS_KEY, _redis_exec, get_redis_ async def maintain_playing_users_online_status(): """ 维护正在游玩用户的在线状态 - + 定期刷新正在游玩用户的metadata在线标记, 确保他们在游玩过程中显示为在线状态。 """ redis_sync = get_redis_message() redis_async = get_redis() - + try: # 获取所有正在游玩的用户 playing_users = await _redis_exec(redis_sync.smembers, REDIS_PLAYING_USERS_KEY) - + if not playing_users: return - + logger.debug(f"Maintaining online status for {len(playing_users)} playing users") - + # 为每个游玩用户刷新metadata在线标记 for user_id in playing_users: user_id_str = user_id.decode() if isinstance(user_id, bytes) else str(user_id) metadata_key = f"metadata:online:{user_id_str}" - + # 设置或刷新metadata在线标记,过期时间为1小时 await redis_async.set(metadata_key, "playing", ex=3600) - + logger.debug(f"Updated metadata online status for {len(playing_users)} playing users") - + except Exception as e: logger.error(f"Error maintaining playing users online status: {e}") @@ -50,11 +50,11 @@ async def maintain_playing_users_online_status(): async def start_online_status_maintenance_task(): """ 启动在线状态维护任务 - + 每5分钟运行一次维护任务,确保游玩用户保持在线状态 """ logger.info("Starting online status maintenance task") - + while True: try: await maintain_playing_users_online_status() diff --git a/app/service/online_status_manager.py b/app/service/online_status_manager.py index b4ef91e..c5a4236 100644 --- a/app/service/online_status_manager.py +++ b/app/service/online_status_manager.py @@ -3,9 +3,9 @@ 此模块负责统一管理用户的在线状态,确保用户在连接WebSocket后立即显示为在线。 """ + from __future__ import annotations -import asyncio from datetime import datetime from app.dependencies.database import get_redis @@ -15,92 +15,93 @@ from app.router.v2.stats import add_online_user class OnlineStatusManager: """在线状态管理器""" - + @staticmethod async def set_user_online(user_id: int, hub_type: str = "general") -> None: """ 设置用户为在线状态 - + Args: user_id: 用户ID hub_type: Hub类型 (metadata, spectator, multiplayer等) """ try: redis = get_redis() - + # 1. 添加到在线用户集合 await add_online_user(user_id) - + # 2. 设置metadata在线标记,这是is_online检查的关键 metadata_key = f"metadata:online:{user_id}" await redis.set(metadata_key, hub_type, ex=7200) # 2小时过期 - + # 3. 设置最后活跃时间戳 last_seen_key = f"user:last_seen:{user_id}" await redis.set(last_seen_key, int(datetime.utcnow().timestamp()), ex=7200) - + logger.debug(f"[OnlineStatusManager] User {user_id} set online via {hub_type}") - + except Exception as e: logger.error(f"[OnlineStatusManager] Error setting user {user_id} online: {e}") - + @staticmethod async def refresh_user_online_status(user_id: int, hub_type: str = "active") -> None: """ 刷新用户的在线状态 - + Args: user_id: 用户ID hub_type: 当前活动类型 """ try: redis = get_redis() - + # 刷新metadata在线标记 metadata_key = f"metadata:online:{user_id}" await redis.set(metadata_key, hub_type, ex=7200) - + # 刷新最后活跃时间 last_seen_key = f"user:last_seen:{user_id}" await redis.set(last_seen_key, int(datetime.utcnow().timestamp()), ex=7200) - + logger.debug(f"[OnlineStatusManager] Refreshed online status for user {user_id}") - + except Exception as e: logger.error(f"[OnlineStatusManager] Error refreshing user {user_id} status: {e}") - + @staticmethod async def set_user_offline(user_id: int) -> None: """ 设置用户为离线状态 - + Args: user_id: 用户ID """ try: redis = get_redis() - + # 删除metadata在线标记 metadata_key = f"metadata:online:{user_id}" await redis.delete(metadata_key) - + # 从在线用户集合中移除 from app.router.v2.stats import remove_online_user + await remove_online_user(user_id) - + logger.debug(f"[OnlineStatusManager] User {user_id} set offline") - + except Exception as e: logger.error(f"[OnlineStatusManager] Error setting user {user_id} offline: {e}") - + @staticmethod async def is_user_online(user_id: int) -> bool: """ 检查用户是否在线 - + Args: user_id: 用户ID - + Returns: bool: 用户是否在线 """ @@ -112,19 +113,19 @@ class OnlineStatusManager: except Exception as e: logger.error(f"[OnlineStatusManager] Error checking user {user_id} online status: {e}") return False - + @staticmethod async def get_online_users_count() -> int: """ 获取在线用户数量 - + Returns: int: 在线用户数量 """ try: - from app.router.v2.stats import _get_online_users_count from app.dependencies.database import get_redis - + from app.router.v2.stats import _get_online_users_count + redis = get_redis() return await _get_online_users_count(redis) except Exception as e: diff --git a/app/service/optimized_message.py b/app/service/optimized_message.py index 76eb6b4..06a5d99 100644 --- a/app/service/optimized_message.py +++ b/app/service/optimized_message.py @@ -50,7 +50,6 @@ class OptimizedMessageService: Returns: 消息响应对象 """ - assert sender.id is not None # 准备消息数据 message_data = { @@ -97,9 +96,7 @@ class OptimizedMessageService: logger.info(f"Message sent to channel {channel_id} with temp_uuid {temp_uuid}") return temp_response - async def get_cached_messages( - self, channel_id: int, limit: int = 50, since: int = 0 - ) -> list[dict]: + async def get_cached_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict]: """ 获取缓存的消息 @@ -125,9 +122,7 @@ class OptimizedMessageService: """ return await self.message_queue.get_message_status(temp_uuid) - async def wait_for_message_persisted( - self, temp_uuid: str, timeout: int = 30 - ) -> dict | None: + async def wait_for_message_persisted(self, temp_uuid: str, timeout: int = 30) -> dict | None: # noqa: ASYNC109 """ 等待消息持久化到数据库 diff --git a/app/service/password_reset_service.py b/app/service/password_reset_service.py index 193b897..fc0c02e 100644 --- a/app/service/password_reset_service.py +++ b/app/service/password_reset_service.py @@ -4,74 +4,67 @@ from __future__ import annotations +from datetime import UTC, datetime +import json import secrets import string -from datetime import datetime, UTC, timedelta -from typing import Optional, Tuple -import json -from app.config import settings +from app.auth import get_password_hash, invalidate_user_tokens from app.database import User from app.dependencies.database import with_db -from app.service.email_service import EmailService -from app.service.email_queue import email_queue # 导入邮件队列 from app.log import logger -from app.auth import get_password_hash, invalidate_user_tokens +from app.service.email_queue import email_queue # 导入邮件队列 +from app.service.email_service import EmailService -from sqlmodel import select -from sqlmodel.ext.asyncio.session import AsyncSession from redis.asyncio import Redis +from sqlmodel import select class PasswordResetService: """密码重置服务 - 使用Redis管理验证码""" - + # Redis键前缀 RESET_CODE_PREFIX = "password_reset:code:" # 存储验证码 RESET_RATE_LIMIT_PREFIX = "password_reset:rate_limit:" # 限制请求频率 - + def __init__(self): self.email_service = EmailService() - + def generate_reset_code(self) -> str: """生成8位重置验证码""" - return ''.join(secrets.choice(string.digits) for _ in range(8)) - + return "".join(secrets.choice(string.digits) for _ in range(8)) + def _get_reset_code_key(self, email: str) -> str: """获取验证码Redis键""" return f"{self.RESET_CODE_PREFIX}{email.lower()}" - + def _get_rate_limit_key(self, email: str) -> str: """获取频率限制Redis键""" return f"{self.RESET_RATE_LIMIT_PREFIX}{email.lower()}" - + async def request_password_reset( - self, - email: str, - ip_address: str, - user_agent: str, - redis: Redis - ) -> Tuple[bool, str]: + self, email: str, ip_address: str, user_agent: str, redis: Redis + ) -> tuple[bool, str]: """ 请求密码重置 - + Args: email: 邮箱地址 ip_address: 请求IP user_agent: 用户代理 redis: Redis连接 - + Returns: Tuple[success, message] """ email = email.lower().strip() - + async with with_db() as session: # 查找用户 user_query = select(User).where(User.email == email) user_result = await session.exec(user_query) user = user_result.first() - + if not user: # 为了安全考虑,不告诉用户邮箱不存在,但仍然要检查频率限制 rate_limit_key = self._get_rate_limit_key(email) @@ -80,15 +73,15 @@ class PasswordResetService: # 设置一个假的频率限制,防止恶意用户探测邮箱 await redis.setex(rate_limit_key, 60, "1") return True, "如果该邮箱地址存在,您将收到密码重置邮件" - + # 检查频率限制 rate_limit_key = self._get_rate_limit_key(email) if await redis.get(rate_limit_key): return False, "请求过于频繁,请稍后再试" - + # 生成重置验证码 reset_code = self.generate_reset_code() - + # 存储验证码信息到Redis reset_code_key = self._get_reset_code_key(email) reset_data = { @@ -98,22 +91,18 @@ class PasswordResetService: "created_at": datetime.now(UTC).isoformat(), "ip_address": ip_address, "user_agent": user_agent, - "used": False + "used": False, } - + try: # 先设置频率限制 await redis.setex(rate_limit_key, 60, "1") # 存储验证码,10分钟过期 await redis.setex(reset_code_key, 600, json.dumps(reset_data)) - + # 发送重置邮件 - email_sent = await self.send_password_reset_email( - email=email, - code=reset_code, - username=user.username - ) - + email_sent = await self.send_password_reset_email(email=email, code=reset_code, username=user.username) + if email_sent: logger.info(f"[Password Reset] Sent reset code to user {user.id} ({email})") return True, "密码重置邮件已发送,请查收邮箱" @@ -123,17 +112,17 @@ class PasswordResetService: await redis.delete(rate_limit_key) logger.warning(f"[Password Reset] Email sending failed, cleaned up Redis data for {email}") return False, "邮件发送失败,请稍后重试" - - except Exception as e: + + except Exception: # Redis操作失败,清理可能的部分数据 try: await redis.delete(reset_code_key) await redis.delete(rate_limit_key) - except: + except Exception: pass - logger.error(f"[Password Reset] Redis operation failed: {e}") + logger.exception("[Password Reset] Redis operation failed") return False, "服务暂时不可用,请稍后重试" - + async def send_password_reset_email(self, email: str, code: str, username: str) -> bool: """发送密码重置邮件(使用邮件队列)""" try: @@ -206,15 +195,15 @@ class PasswordResetService:

osu! 密码重置

Password Reset Request

- +

你好 {username}!

我们收到了您的密码重置请求。如果这是您本人操作,请使用以下验证码重置密码:

- +
{code}
- +

这个验证码将在 10 分钟后过期

- +
⚠️ 安全提醒:
    @@ -224,19 +213,19 @@ class PasswordResetService:
  • 建议设置一个强密码以保护您的账户安全
- +

如果您有任何问题,请联系我们的支持团队。

- +
- +

Hello {username}!

We received a request to reset your password. If this was you, please use the following verification code to reset your password:

- +

This verification code will expire in 10 minutes.

- +

Security Notice: Do not share this verification code with anyone. If you did not request a password reset, please ignore this email.

- + - """ - + """ # noqa: E501 + # 纯文本内容(作为备用) plain_content = f""" 你好 {username}! @@ -270,120 +259,123 @@ class PasswordResetService: # 添加邮件到队列 subject = "密码重置 - Password Reset" metadata = {"type": "password_reset", "email": email, "code": code} - + await email_queue.enqueue_email( to_email=email, subject=subject, content=plain_content, html_content=html_content, - metadata=metadata + metadata=metadata, ) - + logger.info(f"[Password Reset] Enqueued reset code email to {email}") return True - + except Exception as e: logger.error(f"[Password Reset] Failed to enqueue email: {e}") return False - + async def reset_password( self, email: str, reset_code: str, new_password: str, ip_address: str, - redis: Redis - ) -> Tuple[bool, str]: + redis: Redis, + ) -> tuple[bool, str]: """ 重置密码 - + Args: email: 邮箱地址 reset_code: 重置验证码 new_password: 新密码 ip_address: 请求IP redis: Redis连接 - + Returns: Tuple[success, message] """ email = email.lower().strip() reset_code = reset_code.strip() - + async with with_db() as session: # 从Redis获取验证码数据 reset_code_key = self._get_reset_code_key(email) reset_data_str = await redis.get(reset_code_key) - + if not reset_data_str: return False, "验证码无效或已过期" - + try: reset_data = json.loads(reset_data_str) except json.JSONDecodeError: return False, "验证码数据格式错误" - + # 验证验证码 if reset_data.get("reset_code") != reset_code: return False, "验证码错误" - + # 检查是否已使用 if reset_data.get("used", False): return False, "验证码已使用" - + # 验证邮箱匹配 if reset_data.get("email") != email: return False, "邮箱地址不匹配" - + # 查找用户 user_query = select(User).where(User.email == email) user_result = await session.exec(user_query) user = user_result.first() - + if not user: return False, "用户不存在" - + if user.id is None: return False, "用户ID无效" - + # 验证用户ID匹配 if reset_data.get("user_id") != user.id: return False, "用户信息不匹配" - + # 密码强度检查 if len(new_password) < 6: return False, "密码长度至少为6位" - + try: # 先标记验证码为已使用(在数据库操作之前) reset_data["used"] = True reset_data["used_at"] = datetime.now(UTC).isoformat() - + # 保存用户ID用于日志记录 user_id = user.id - + # 更新用户密码 password_hash = get_password_hash(new_password) user.pw_bcrypt = password_hash # 使用正确的字段名称 pw_bcrypt 而不是 password_hash - + # 提交数据库更改 await session.commit() - + # 使该用户的所有现有令牌失效(使其他客户端登录失效) tokens_deleted = await invalidate_user_tokens(session, user_id) - + # 数据库操作成功后,更新Redis状态 await redis.setex(reset_code_key, 300, json.dumps(reset_data)) # 保留5分钟用于日志记录 - - logger.info(f"[Password Reset] User {user_id} ({email}) successfully reset password from IP {ip_address}, invalidated {tokens_deleted} tokens") + + logger.info( + f"[Password Reset] User {user_id} ({email}) successfully reset password from IP {ip_address}," + f" invalidated {tokens_deleted} tokens" + ) return True, "密码重置成功,所有设备已被登出" - + except Exception as e: # 不要在异常处理中访问user.id,可能触发数据库操作 user_id = reset_data.get("user_id", "未知") logger.error(f"[Password Reset] Failed to reset password for user {user_id}: {e}") await session.rollback() - + # 数据库回滚时,需要恢复Redis中的验证码状态 try: # 恢复验证码为未使用状态 @@ -394,35 +386,39 @@ class PasswordResetService: "created_at": reset_data.get("created_at"), "ip_address": reset_data.get("ip_address"), "user_agent": reset_data.get("user_agent"), - "used": False # 恢复为未使用状态 + "used": False, # 恢复为未使用状态 } - + # 计算剩余的TTL时间 created_at = datetime.fromisoformat(reset_data.get("created_at", "")) elapsed = (datetime.now(UTC) - created_at).total_seconds() remaining_ttl = max(0, 600 - int(elapsed)) # 600秒总过期时间 - + if remaining_ttl > 0: - await redis.setex(reset_code_key, remaining_ttl, json.dumps(original_reset_data)) + await redis.setex( + reset_code_key, + remaining_ttl, + json.dumps(original_reset_data), + ) logger.info(f"[Password Reset] Restored Redis state after database rollback for {email}") else: # 如果已经过期,直接删除 await redis.delete(reset_code_key) logger.info(f"[Password Reset] Removed expired reset code after database rollback for {email}") - + except Exception as redis_error: logger.error(f"[Password Reset] Failed to restore Redis state after rollback: {redis_error}") - + return False, "密码重置失败,请稍后重试" - + async def get_reset_attempts_count(self, email: str, redis: Redis) -> int: """ 获取邮箱的重置尝试次数(通过检查频率限制键) - + Args: email: 邮箱地址 redis: Redis连接 - + Returns: 尝试次数 """ diff --git a/app/service/ranking_cache_service.py b/app/service/ranking_cache_service.py index d8f51a6..ea37306 100644 --- a/app/service/ranking_cache_service.py +++ b/app/service/ranking_cache_service.py @@ -34,9 +34,7 @@ class DateTimeEncoder(json.JSONEncoder): def safe_json_dumps(data) -> str: """安全的 JSON 序列化,支持 datetime 对象""" - return json.dumps( - data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":") - ) + return json.dumps(data, cls=DateTimeEncoder, ensure_ascii=False, separators=(",", ":")) class RankingCacheService: @@ -225,9 +223,7 @@ class RankingCacheService: ) -> None: """刷新排行榜缓存""" if self._refreshing: - logger.debug( - f"Ranking cache refresh already in progress for {ruleset}:{type}" - ) + logger.debug(f"Ranking cache refresh already in progress for {ruleset}:{type}") return # 使用配置文件的设置 @@ -253,9 +249,7 @@ class RankingCacheService: order_by = col(UserStatistics.ranked_score).desc() if country: - wheres.append( - col(UserStatistics.user).has(country_code=country.upper()) - ) + wheres.append(col(UserStatistics.user).has(country_code=country.upper())) # 获取总用户数用于统计 total_users_query = select(UserStatistics).where(*wheres) @@ -277,11 +271,7 @@ class RankingCacheService: for page in range(1, max_pages + 1): try: statistics_list = await session.exec( - select(UserStatistics) - .where(*wheres) - .order_by(order_by) - .limit(50) - .offset(50 * (page - 1)) + select(UserStatistics).where(*wheres).order_by(order_by).limit(50).offset(50 * (page - 1)) ) statistics_data = statistics_list.all() @@ -291,9 +281,7 @@ class RankingCacheService: # 转换为响应格式并确保正确序列化 ranking_data = [] for statistics in statistics_data: - user_stats_resp = await UserStatisticsResp.from_db( - statistics, session, None, include - ) + user_stats_resp = await UserStatisticsResp.from_db(statistics, session, None, include) # 将 UserStatisticsResp 转换为字典,处理所有序列化问题 user_dict = json.loads(user_stats_resp.model_dump_json()) ranking_data.append(user_dict) @@ -323,9 +311,7 @@ class RankingCacheService: ) -> None: """刷新地区排行榜缓存""" if self._refreshing: - logger.debug( - f"Country ranking cache refresh already in progress for {ruleset}" - ) + logger.debug(f"Country ranking cache refresh already in progress for {ruleset}") return if max_pages is None: @@ -449,9 +435,7 @@ class RankingCacheService: for country in top_countries: for mode in game_modes: for ranking_type in ranking_types: - task = self.refresh_ranking_cache( - session, mode, ranking_type, country - ) + task = self.refresh_ranking_cache(session, mode, ranking_type, country) refresh_tasks.append(task) # 地区排行榜 @@ -493,9 +477,7 @@ class RankingCacheService: if keys: await self.redis.delete(*keys) deleted_keys += len(keys) - logger.info( - f"Invalidated {len(keys)} cache keys for {ruleset}:{type}" - ) + logger.info(f"Invalidated {len(keys)} cache keys for {ruleset}:{type}") elif ruleset: # 删除特定游戏模式的所有缓存 patterns = [ @@ -563,9 +545,7 @@ class RankingCacheService: "cached_user_rankings": len(ranking_keys), "cached_country_rankings": len(country_keys), "total_cached_rankings": len(total_keys), - "estimated_total_size_mb": ( - round(total_size / 1024 / 1024, 2) if total_size > 0 else 0 - ), + "estimated_total_size_mb": (round(total_size / 1024 / 1024, 2) if total_size > 0 else 0), "refreshing": self._refreshing, } except Exception as e: diff --git a/app/service/recalculate.py b/app/service/recalculate.py index fd12ab4..0112bd2 100644 --- a/app/service/recalculate.py +++ b/app/service/recalculate.py @@ -35,12 +35,8 @@ async def recalculate(): fetcher = await get_fetcher() redis = get_redis() for mode in GameMode: - await session.execute( - delete(PPBestScore).where(col(PPBestScore.gamemode) == mode) - ) - await session.execute( - delete(BestScore).where(col(BestScore.gamemode) == mode) - ) + await session.execute(delete(PPBestScore).where(col(PPBestScore.gamemode) == mode)) + await session.execute(delete(BestScore).where(col(BestScore.gamemode) == mode)) await session.commit() logger.info(f"Recalculating for mode: {mode}") statistics_list = ( @@ -53,32 +49,21 @@ async def recalculate(): ).all() await asyncio.gather( *[ - _recalculate_pp( - statistics.user_id, statistics.mode, session, fetcher, redis - ) + _recalculate_pp(statistics.user_id, statistics.mode, session, fetcher, redis) for statistics in statistics_list ] ) await asyncio.gather( *[ - _recalculate_best_score( - statistics.user_id, statistics.mode, session - ) + _recalculate_best_score(statistics.user_id, statistics.mode, session) for statistics in statistics_list ] ) await session.commit() - await asyncio.gather( - *[ - _recalculate_statistics(statistics, session) - for statistics in statistics_list - ] - ) + await asyncio.gather(*[_recalculate_statistics(statistics, session) for statistics in statistics_list]) await session.commit() - logger.success( - f"Recalculated for mode: {mode}, total users: {len(statistics_list)}" - ) + logger.success(f"Recalculated for mode: {mode}, total users: {len(statistics_list)}") async def _recalculate_pp( @@ -104,9 +89,7 @@ async def _recalculate_pp( beatmap_id = score.beatmap_id while time > 0: try: - db_beatmap = await Beatmap.get_or_fetch( - session, fetcher, bid=beatmap_id - ) + db_beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=beatmap_id) except HTTPError: time -= 1 await asyncio.sleep(2) @@ -116,9 +99,7 @@ async def _recalculate_pp( score.pp = 0 return try: - pp = await pre_fetch_and_calculate_pp( - score, beatmap_id, session, redis, fetcher - ) + pp = await pre_fetch_and_calculate_pp(score, beatmap_id, session, redis, fetcher) score.pp = pp if pp == 0: return @@ -138,15 +119,10 @@ async def _recalculate_pp( await asyncio.sleep(2) continue except Exception: - logger.exception( - f"Error calculating pp for score {score.id} on beatmap {beatmap_id}" - ) + logger.exception(f"Error calculating pp for score {score.id} on beatmap {beatmap_id}") return if time <= 0: - logger.warning( - f"Failed to fetch beatmap {beatmap_id} after 10 attempts, " - "retrying later..." - ) + logger.warning(f"Failed to fetch beatmap {beatmap_id} after 10 attempts, retrying later...") return score while len(scores) > 0: @@ -271,9 +247,7 @@ async def _recalculate_statistics(statistics: UserStatistics, session: AsyncSess statistics.count_100 += score.n100 + score.nkatu statistics.count_50 += score.n50 statistics.count_miss += score.nmiss - statistics.total_hits += ( - score.n300 + score.ngeki + score.n100 + score.nkatu + score.n50 - ) + statistics.total_hits += score.n300 + score.ngeki + score.n100 + score.nkatu + score.n50 if ranked and score.passed: statistics.maximum_combo = max(statistics.maximum_combo, score.max_combo) diff --git a/app/service/redis_message_system.py b/app/service/redis_message_system.py index 51c7df6..8d3a033 100644 --- a/app/service/redis_message_system.py +++ b/app/service/redis_message_system.py @@ -18,6 +18,7 @@ from app.database.chat import ChatMessage, ChatMessageResp, MessageType from app.database.lazer_user import RANKING_INCLUDES, User, UserResp from app.dependencies.database import get_redis_message, with_db from app.log import logger +from app.utils import bg_tasks class RedisMessageSystem: @@ -67,12 +68,11 @@ class RedisMessageSystem: # 获取频道类型以判断是否需要存储到数据库 async with with_db() as session: - from app.database.chat import ChatChannel, ChannelType + from app.database.chat import ChannelType, ChatChannel + from sqlmodel import select - - channel_result = await session.exec( - select(ChatChannel.type).where(ChatChannel.channel_id == channel_id) - ) + + channel_result = await session.exec(select(ChatChannel.type).where(ChatChannel.channel_id == channel_id)) channel_type = channel_result.first() is_multiplayer = channel_type == ChannelType.MULTIPLAYER @@ -132,17 +132,14 @@ class RedisMessageSystem: if is_multiplayer: logger.info( - f"Multiplayer message {message_id} sent to Redis cache for channel {channel_id}, will not be persisted to database" + f"Multiplayer message {message_id} sent to Redis cache for channel {channel_id}," + " will not be persisted to database" ) else: - logger.info( - f"Message {message_id} sent to Redis cache for channel {channel_id}" - ) + logger.info(f"Message {message_id} sent to Redis cache for channel {channel_id}") return response - async def get_messages( - self, channel_id: int, limit: int = 50, since: int = 0 - ) -> list[ChatMessageResp]: + async def get_messages(self, channel_id: int, limit: int = 50, since: int = 0) -> list[ChatMessageResp]: """ 获取频道消息 - 优先从 Redis 获取最新消息 @@ -166,9 +163,7 @@ class RedisMessageSystem: # 获取发送者信息 sender = await session.get(User, msg_data["sender_id"]) if sender: - user_resp = await UserResp.from_db( - sender, session, RANKING_INCLUDES - ) + user_resp = await UserResp.from_db(sender, session, RANKING_INCLUDES) if user_resp.statistics is None: from app.database.statistics import UserStatisticsResp @@ -223,39 +218,28 @@ class RedisMessageSystem: async def _generate_message_id(self, channel_id: int) -> int: """生成唯一的消息ID - 确保全局唯一且严格递增""" # 使用全局计数器确保所有频道的消息ID都是严格递增的 - message_id = await self._redis_exec( - self.redis.incr, "global_message_id_counter" - ) + message_id = await self._redis_exec(self.redis.incr, "global_message_id_counter") # 同时更新频道的最后消息ID,用于客户端状态同步 - await self._redis_exec( - self.redis.set, f"channel:{channel_id}:last_msg_id", message_id - ) + await self._redis_exec(self.redis.set, f"channel:{channel_id}:last_msg_id", message_id) return message_id - async def _store_to_redis( - self, message_id: int, channel_id: int, message_data: dict[str, Any] - ): + async def _store_to_redis(self, message_id: int, channel_id: int, message_data: dict[str, Any]): """存储消息到 Redis""" try: # 检查是否是多人房间消息 is_multiplayer = message_data.get("is_multiplayer", False) - + # 存储消息数据 await self._redis_exec( self.redis.hset, f"msg:{channel_id}:{message_id}", - mapping={ - k: json.dumps(v) if isinstance(v, (dict, list)) else str(v) - for k, v in message_data.items() - }, + mapping={k: json.dumps(v) if isinstance(v, dict | list) else str(v) for k, v in message_data.items()}, ) # 设置消息过期时间(7天) - await self._redis_exec( - self.redis.expire, f"msg:{channel_id}:{message_id}", 604800 - ) + await self._redis_exec(self.redis.expire, f"msg:{channel_id}:{message_id}", 604800) # 清理可能存在的错误类型键,然后添加到频道消息列表(按时间排序) channel_messages_key = f"channel:{channel_id}:messages" @@ -264,14 +248,10 @@ class RedisMessageSystem: try: key_type = await self._redis_exec(self.redis.type, channel_messages_key) if key_type and key_type != "zset": - logger.warning( - f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}" - ) + logger.warning(f"Deleting Redis key {channel_messages_key} with wrong type: {key_type}") await self._redis_exec(self.redis.delete, channel_messages_key) except Exception as type_check_error: - logger.warning( - f"Failed to check key type for {channel_messages_key}: {type_check_error}" - ) + logger.warning(f"Failed to check key type for {channel_messages_key}: {type_check_error}") # 如果检查失败,直接删除键以确保清理 await self._redis_exec(self.redis.delete, channel_messages_key) @@ -283,15 +263,11 @@ class RedisMessageSystem: ) # 保持频道消息列表大小(最多1000条) - await self._redis_exec( - self.redis.zremrangebyrank, channel_messages_key, 0, -1001 - ) + await self._redis_exec(self.redis.zremrangebyrank, channel_messages_key, 0, -1001) # 只有非多人房间消息才添加到待持久化队列 if not is_multiplayer: - await self._redis_exec( - self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}" - ) + await self._redis_exec(self.redis.lpush, "pending_messages", f"{channel_id}:{message_id}") logger.debug(f"Message {message_id} added to persistence queue") else: logger.debug(f"Message {message_id} in multiplayer room, skipped persistence queue") @@ -300,9 +276,7 @@ class RedisMessageSystem: logger.error(f"Failed to store message to Redis: {e}") raise - async def _get_from_redis( - self, channel_id: int, limit: int = 50, since: int = 0 - ) -> list[dict[str, Any]]: + async def _get_from_redis(self, channel_id: int, limit: int = 50, since: int = 0) -> list[dict[str, Any]]: """从 Redis 获取消息""" try: # 获取消息键列表,按消息ID排序 @@ -340,9 +314,7 @@ class RedisMessageSystem: # 尝试解析 JSON try: - if k in ["grade_counts", "level"] or v.startswith( - ("{", "[") - ): + if k in ["grade_counts", "level"] or v.startswith(("{", "[")): message_data[k] = json.loads(v) elif k in ["message_id", "channel_id", "sender_id"]: message_data[k] = int(v) @@ -368,9 +340,7 @@ class RedisMessageSystem: logger.error(f"Failed to get messages from Redis: {e}") return [] - async def _backfill_from_database( - self, channel_id: int, existing_messages: list[ChatMessageResp], limit: int - ): + async def _backfill_from_database(self, channel_id: int, existing_messages: list[ChatMessageResp], limit: int): """从数据库补充历史消息""" try: # 找到最小的消息ID @@ -404,9 +374,7 @@ class RedisMessageSystem: except Exception as e: logger.error(f"Failed to backfill from database: {e}") - async def _get_from_database_only( - self, channel_id: int, limit: int, since: int - ) -> list[ChatMessageResp]: + async def _get_from_database_only(self, channel_id: int, limit: int, since: int) -> list[ChatMessageResp]: """仅从数据库获取消息(回退方案)""" try: async with with_db() as session: @@ -417,20 +385,14 @@ class RedisMessageSystem: if since > 0: # 获取指定ID之后的消息,按ID正序 query = query.where(col(ChatMessage.message_id) > since) - query = query.order_by(col(ChatMessage.message_id).asc()).limit( - limit - ) + query = query.order_by(col(ChatMessage.message_id).asc()).limit(limit) else: # 获取最新消息,按ID倒序(最新的在前面) - query = query.order_by(col(ChatMessage.message_id).desc()).limit( - limit - ) + query = query.order_by(col(ChatMessage.message_id).desc()).limit(limit) messages = (await session.exec(query)).all() - results = [ - await ChatMessageResp.from_db(msg, session) for msg in messages - ] + results = [await ChatMessageResp.from_db(msg, session) for msg in messages] # 如果是 since > 0,保持正序;否则反转为时间正序 if since == 0: @@ -451,9 +413,7 @@ class RedisMessageSystem: # 获取待处理的消息 message_keys = [] for _ in range(self.max_batch_size): - key = await self._redis_exec( - self.redis.brpop, ["pending_messages"], timeout=1 - ) + key = await self._redis_exec(self.redis.brpop, ["pending_messages"], timeout=1) if key: # key 是 (queue_name, value) 的元组 value = key[1] @@ -483,9 +443,7 @@ class RedisMessageSystem: channel_id, message_id = map(int, key.split(":")) # 从 Redis 获取消息数据 - raw_data = await self._redis_exec( - self.redis.hgetall, f"msg:{channel_id}:{message_id}" - ) + raw_data = await self._redis_exec(self.redis.hgetall, f"msg:{channel_id}:{message_id}") if not raw_data: continue @@ -546,9 +504,7 @@ class RedisMessageSystem: # 提交批次 try: await session.commit() - logger.info( - f"Batch of {len(message_keys)} messages committed to database" - ) + logger.info(f"Batch of {len(message_keys)} messages committed to database") except Exception as e: logger.error(f"Failed to commit message batch: {e}") await session.rollback() @@ -559,7 +515,7 @@ class RedisMessageSystem: self._running = True self._batch_timer = asyncio.create_task(self._batch_persist_to_database()) # 启动时初始化消息ID计数器 - asyncio.create_task(self._initialize_message_counter()) + bg_tasks.add_task(self._initialize_message_counter) logger.info("Redis message system started") async def _initialize_message_counter(self): @@ -576,27 +532,19 @@ class RedisMessageSystem: max_id = result.one() or 0 # 检查 Redis 中的计数器值 - current_counter = await self._redis_exec( - self.redis.get, "global_message_id_counter" - ) + current_counter = await self._redis_exec(self.redis.get, "global_message_id_counter") current_counter = int(current_counter) if current_counter else 0 # 设置计数器为两者中的最大值 initial_counter = max(max_id, current_counter) - await self._redis_exec( - self.redis.set, "global_message_id_counter", initial_counter - ) + await self._redis_exec(self.redis.set, "global_message_id_counter", initial_counter) - logger.info( - f"Initialized global message ID counter to {initial_counter}" - ) + logger.info(f"Initialized global message ID counter to {initial_counter}") except Exception as e: logger.error(f"Failed to initialize message counter: {e}") # 如果初始化失败,设置一个安全的起始值 - await self._redis_exec( - self.redis.setnx, "global_message_id_counter", 1000000 - ) + await self._redis_exec(self.redis.setnx, "global_message_id_counter", 1000000) async def _cleanup_redis_keys(self): """清理可能存在问题的 Redis 键""" @@ -612,9 +560,7 @@ class RedisMessageSystem: try: key_type = await self._redis_exec(self.redis.type, key) if key_type and key_type != "zset": - logger.warning( - f"Cleaning up Redis key {key} with wrong type: {key_type}" - ) + logger.warning(f"Cleaning up Redis key {key} with wrong type: {key_type}") await self._redis_exec(self.redis.delete, key) except Exception as cleanup_error: logger.warning(f"Failed to cleanup key {key}: {cleanup_error}") diff --git a/app/service/room.py b/app/service/room.py index d6fdcc6..99ef917 100644 --- a/app/service/room.py +++ b/app/service/room.py @@ -14,15 +14,11 @@ from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession -async def create_playlist_room_from_api( - session: AsyncSession, room: APIUploadedRoom, host_id: int -) -> Room: +async def create_playlist_room_from_api(session: AsyncSession, room: APIUploadedRoom, host_id: int) -> Room: db_room = room.to_room() db_room.host_id = host_id db_room.starts_at = datetime.now(UTC) - db_room.ends_at = db_room.starts_at + timedelta( - minutes=db_room.duration if db_room.duration is not None else 0 - ) + db_room.ends_at = db_room.starts_at + timedelta(minutes=db_room.duration if db_room.duration is not None else 0) session.add(db_room) await session.commit() await session.refresh(db_room) @@ -87,13 +83,9 @@ async def create_playlist_room( return db_room -async def add_playlists_to_room( - session: AsyncSession, room_id: int, playlist: list[Playlist], owner_id: int -): +async def add_playlists_to_room(session: AsyncSession, room_id: int, playlist: list[Playlist], owner_id: int): for item in playlist: - if not ( - await session.exec(select(exists().where(col(Beatmap.id) == item.beatmap))) - ).first(): + if not (await session.exec(select(exists().where(col(Beatmap.id) == item.beatmap)))).first(): fetcher = await get_fetcher() await Beatmap.get_or_fetch(session, fetcher, item.beatmap_id) item.id = await Playlist.get_next_id_for_room(room_id, session) diff --git a/app/service/session_manager.py b/app/service/session_manager.py index 11d78c4..29e73f3 100644 --- a/app/service/session_manager.py +++ b/app/service/session_manager.py @@ -4,15 +4,15 @@ API 状态管理 - 模拟 osu! 的 APIState 和会话管理 from __future__ import annotations -from enum import Enum -from typing import Optional from datetime import datetime +from enum import Enum from pydantic import BaseModel class APIState(str, Enum): """API 连接状态,对应 osu! 的 APIState""" + OFFLINE = "offline" CONNECTING = "connecting" REQUIRES_SECOND_FACTOR_AUTH = "requires_second_factor_auth" # 需要二次验证 @@ -22,6 +22,7 @@ class APIState(str, Enum): class UserSession(BaseModel): """用户会话信息""" + user_id: int username: str email: str @@ -38,10 +39,10 @@ class UserSession(BaseModel): class SessionManager: """会话管理器""" - + def __init__(self): self._sessions: dict[str, UserSession] = {} - + def create_session( self, user_id: int, @@ -49,19 +50,19 @@ class SessionManager: email: str, ip_address: str, country_code: str | None = None, - is_new_location: bool = False + is_new_location: bool = False, ) -> UserSession: """创建新的用户会话""" import secrets - + session_token = secrets.token_urlsafe(32) - + # 根据是否为新位置决定初始状态 if is_new_location: state = APIState.REQUIRES_SECOND_FACTOR_AUTH else: state = APIState.ONLINE - + session = UserSession( user_id=user_id, username=username, @@ -71,33 +72,33 @@ class SessionManager: requires_verification=is_new_location, ip_address=ip_address, country_code=country_code, - is_new_location=is_new_location + is_new_location=is_new_location, ) - + self._sessions[session_token] = session return session - + def get_session(self, session_token: str) -> UserSession | None: """获取会话""" return self._sessions.get(session_token) - + def update_session_state(self, session_token: str, state: APIState): """更新会话状态""" if session_token in self._sessions: self._sessions[session_token].state = state - + def mark_verification_sent(self, session_token: str): """标记验证邮件已发送""" if session_token in self._sessions: session = self._sessions[session_token] session.verification_sent = True session.last_verification_attempt = datetime.now() - + def increment_failed_attempts(self, session_token: str): """增加失败尝试次数""" if session_token in self._sessions: self._sessions[session_token].failed_attempts += 1 - + def verify_session(self, session_token: str) -> bool: """验证会话成功""" if session_token in self._sessions: @@ -106,11 +107,11 @@ class SessionManager: session.requires_verification = False return True return False - + def remove_session(self, session_token: str): """移除会话""" self._sessions.pop(session_token, None) - + def cleanup_expired_sessions(self): """清理过期会话""" # 这里可以实现清理逻辑 diff --git a/app/service/stats_cleanup.py b/app/service/stats_cleanup.py index a7c57d0..e83b662 100644 --- a/app/service/stats_cleanup.py +++ b/app/service/stats_cleanup.py @@ -26,14 +26,12 @@ async def cleanup_stale_online_users() -> tuple[int, int]: # 检查在线用户的最后活动时间 current_time = datetime.utcnow() - stale_threshold = current_time - timedelta(hours=2) # 2小时无活动视为过期 + stale_threshold = current_time - timedelta(hours=2) # 2小时无活动视为过期 # noqa: F841 # 对于在线用户,我们检查metadata在线标记 stale_online_users = [] for user_id in online_users: - user_id_str = ( - user_id.decode() if isinstance(user_id, bytes) else str(user_id) - ) + user_id_str = user_id.decode() if isinstance(user_id, bytes) else str(user_id) metadata_key = f"metadata:online:{user_id_str}" # 如果metadata标记不存在,说明用户已经离线 @@ -42,9 +40,7 @@ async def cleanup_stale_online_users() -> tuple[int, int]: # 清理过期的在线用户 if stale_online_users: - await _redis_exec( - redis_sync.srem, REDIS_ONLINE_USERS_KEY, *stale_online_users - ) + await _redis_exec(redis_sync.srem, REDIS_ONLINE_USERS_KEY, *stale_online_users) online_cleaned = len(stale_online_users) logger.info(f"Cleaned {online_cleaned} stale online users") @@ -52,22 +48,19 @@ async def cleanup_stale_online_users() -> tuple[int, int]: # 只有当用户明确不在任何hub连接中时才移除 stale_playing_users = [] for user_id in playing_users: - user_id_str = ( - user_id.decode() if isinstance(user_id, bytes) else str(user_id) - ) + user_id_str = user_id.decode() if isinstance(user_id, bytes) else str(user_id) metadata_key = f"metadata:online:{user_id_str}" - + # 只有当metadata在线标记完全不存在且用户也不在在线列表中时, # 才认为用户真正离线 - if (not await redis_async.exists(metadata_key) and - user_id_str not in [u.decode() if isinstance(u, bytes) else str(u) for u in online_users]): + if not await redis_async.exists(metadata_key) and user_id_str not in [ + u.decode() if isinstance(u, bytes) else str(u) for u in online_users + ]: stale_playing_users.append(user_id_str) # 清理过期的游玩用户 if stale_playing_users: - await _redis_exec( - redis_sync.srem, REDIS_PLAYING_USERS_KEY, *stale_playing_users - ) + await _redis_exec(redis_sync.srem, REDIS_PLAYING_USERS_KEY, *stale_playing_users) playing_cleaned = len(stale_playing_users) logger.info(f"Cleaned {playing_cleaned} stale playing users") diff --git a/app/service/stats_scheduler.py b/app/service/stats_scheduler.py index cef88d1..7c72a12 100644 --- a/app/service/stats_scheduler.py +++ b/app/service/stats_scheduler.py @@ -61,26 +61,29 @@ class StatsScheduler: try: # 计算下次区间结束时间 now = datetime.utcnow() - + # 计算当前区间的结束时间 current_minute = (now.minute // 30) * 30 - current_interval_end = now.replace(minute=current_minute, second=0, microsecond=0) + timedelta(minutes=30) - + current_interval_end = now.replace(minute=current_minute, second=0, microsecond=0) + timedelta( + minutes=30 + ) + # 如果当前时间已经超过了当前区间结束时间,说明需要等待下一个区间结束 if now >= current_interval_end: current_interval_end += timedelta(minutes=30) - + # 计算需要等待的时间 sleep_seconds = (current_interval_end - now).total_seconds() - + # 添加小的缓冲时间,确保区间真正结束后再处理 sleep_seconds += 10 # 额外等待10秒 - + # 限制等待时间范围 sleep_seconds = max(min(sleep_seconds, 32 * 60), 10) - + logger.debug( - f"Next interval finalization in {sleep_seconds / 60:.1f} minutes at {current_interval_end.strftime('%H:%M:%S')}" + f"Next interval finalization in {sleep_seconds / 60:.1f} " + f"minutes at {current_interval_end.strftime('%H:%M:%S')}" ) await asyncio.sleep(sleep_seconds) @@ -137,7 +140,8 @@ class StatsScheduler: online_cleaned, playing_cleaned = await cleanup_stale_online_users() if online_cleaned > 0 or playing_cleaned > 0: logger.info( - f"Initial cleanup: removed {online_cleaned} stale online users, {playing_cleaned} stale playing users" + f"Initial cleanup: removed {online_cleaned} stale online users," + f" {playing_cleaned} stale playing users" ) await refresh_redis_key_expiry() diff --git a/app/service/subscribers/base.py b/app/service/subscribers/base.py index 39af4cb..89a9e9a 100644 --- a/app/service/subscribers/base.py +++ b/app/service/subscribers/base.py @@ -31,9 +31,7 @@ class RedisSubscriber: async def listen(self): while True: - message = await self.pubsub.get_message( - ignore_subscribe_messages=True, timeout=None - ) + message = await self.pubsub.get_message(ignore_subscribe_messages=True, timeout=None) if message is not None and message["type"] == "message": matched_handlers: list[Callable[[str, str], Awaitable[Any]]] = [] @@ -53,10 +51,7 @@ class RedisSubscriber: if matched_handlers: await asyncio.gather( - *[ - handler(message["channel"], message["data"]) - for handler in matched_handlers - ] + *[handler(message["channel"], message["data"]) for handler in matched_handlers] ) def start(self): diff --git a/app/service/subscribers/score_processed.py b/app/service/subscribers/score_processed.py index 2b69740..9613987 100644 --- a/app/service/subscribers/score_processed.py +++ b/app/service/subscribers/score_processed.py @@ -46,12 +46,7 @@ class ScoreSubscriber(RedisSubscriber): return async with with_db() as session: score = await session.get(Score, score_id) - if ( - not score - or not score.passed - or score.room_id is None - or score.playlist_item_id is None - ): + if not score or not score.passed or score.room_id is None or score.playlist_item_id is None: return if not self.room_subscriber.get(score.room_id, []): return diff --git a/app/service/user_cache_service.py b/app/service/user_cache_service.py index f999fe8..7b01be4 100644 --- a/app/service/user_cache_service.py +++ b/app/service/user_cache_service.py @@ -47,17 +47,13 @@ class UserCacheService: self._refreshing = False self._background_tasks: set = set() - def _get_v1_user_cache_key( - self, user_id: int, ruleset: GameMode | None = None - ) -> str: + def _get_v1_user_cache_key(self, user_id: int, ruleset: GameMode | None = None) -> str: """生成 V1 用户缓存键""" if ruleset: return f"v1_user:{user_id}:ruleset:{ruleset}" return f"v1_user:{user_id}" - async def get_v1_user_from_cache( - self, user_id: int, ruleset: GameMode | None = None - ) -> dict | None: + async def get_v1_user_from_cache(self, user_id: int, ruleset: GameMode | None = None) -> dict | None: """从缓存获取 V1 用户信息""" try: cache_key = self._get_v1_user_cache_key(user_id, ruleset) @@ -96,9 +92,7 @@ class UserCacheService: keys = await self.redis.keys(pattern) if keys: await self.redis.delete(*keys) - logger.info( - f"Invalidated {len(keys)} V1 cache entries for user {user_id}" - ) + logger.info(f"Invalidated {len(keys)} V1 cache entries for user {user_id}") except Exception as e: logger.error(f"Error invalidating V1 user cache: {e}") @@ -126,9 +120,7 @@ class UserCacheService: """生成用户谱面集缓存键""" return f"user:{user_id}:beatmapsets:{beatmapset_type}:limit:{limit}:offset:{offset}" - async def get_user_from_cache( - self, user_id: int, ruleset: GameMode | None = None - ) -> UserResp | None: + async def get_user_from_cache(self, user_id: int, ruleset: GameMode | None = None) -> UserResp | None: """从缓存获取用户信息""" try: cache_key = self._get_user_cache_key(user_id, ruleset) @@ -172,14 +164,10 @@ class UserCacheService: ) -> list[ScoreResp] | None: """从缓存获取用户成绩""" try: - cache_key = self._get_user_scores_cache_key( - user_id, score_type, mode, limit, offset - ) + cache_key = self._get_user_scores_cache_key(user_id, score_type, mode, limit, offset) cached_data = await self.redis.get(cache_key) if cached_data: - logger.debug( - f"User scores cache hit for user {user_id}, type {score_type}" - ) + logger.debug(f"User scores cache hit for user {user_id}, type {score_type}") data = json.loads(cached_data) return [ScoreResp(**score_data) for score_data in data] return None @@ -201,16 +189,12 @@ class UserCacheService: try: if expire_seconds is None: expire_seconds = settings.user_scores_cache_expire_seconds - cache_key = self._get_user_scores_cache_key( - user_id, score_type, mode, limit, offset - ) + cache_key = self._get_user_scores_cache_key(user_id, score_type, mode, limit, offset) # 使用 model_dump_json() 而不是 model_dump() + json.dumps() scores_json_list = [score.model_dump_json() for score in scores] cached_data = f"[{','.join(scores_json_list)}]" await self.redis.setex(cache_key, expire_seconds, cached_data) - logger.debug( - f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s" - ) + logger.debug(f"Cached user {user_id} scores ({score_type}) for {expire_seconds}s") except Exception as e: logger.error(f"Error caching user scores: {e}") @@ -219,14 +203,10 @@ class UserCacheService: ) -> list[Any] | None: """从缓存获取用户谱面集""" try: - cache_key = self._get_user_beatmapsets_cache_key( - user_id, beatmapset_type, limit, offset - ) + cache_key = self._get_user_beatmapsets_cache_key(user_id, beatmapset_type, limit, offset) cached_data = await self.redis.get(cache_key) if cached_data: - logger.debug( - f"User beatmapsets cache hit for user {user_id}, type {beatmapset_type}" - ) + logger.debug(f"User beatmapsets cache hit for user {user_id}, type {beatmapset_type}") return json.loads(cached_data) return None except Exception as e: @@ -246,9 +226,7 @@ class UserCacheService: try: if expire_seconds is None: expire_seconds = settings.user_beatmapsets_cache_expire_seconds - cache_key = self._get_user_beatmapsets_cache_key( - user_id, beatmapset_type, limit, offset - ) + cache_key = self._get_user_beatmapsets_cache_key(user_id, beatmapset_type, limit, offset) # 使用 model_dump_json() 处理有 model_dump_json 方法的对象,否则使用 safe_json_dumps serialized_beatmapsets = [] for bms in beatmapsets: @@ -258,9 +236,7 @@ class UserCacheService: serialized_beatmapsets.append(safe_json_dumps(bms)) cached_data = f"[{','.join(serialized_beatmapsets)}]" await self.redis.setex(cache_key, expire_seconds, cached_data) - logger.debug( - f"Cached user {user_id} beatmapsets ({beatmapset_type}) for {expire_seconds}s" - ) + logger.debug(f"Cached user {user_id} beatmapsets ({beatmapset_type}) for {expire_seconds}s") except Exception as e: logger.error(f"Error caching user beatmapsets: {e}") @@ -276,9 +252,7 @@ class UserCacheService: except Exception as e: logger.error(f"Error invalidating user cache: {e}") - async def invalidate_user_scores_cache( - self, user_id: int, mode: GameMode | None = None - ): + async def invalidate_user_scores_cache(self, user_id: int, mode: GameMode | None = None): """使用户成绩缓存失效""" try: # 删除用户成绩相关缓存 @@ -287,9 +261,7 @@ class UserCacheService: keys = await self.redis.keys(pattern) if keys: await self.redis.delete(*keys) - logger.info( - f"Invalidated {len(keys)} score cache entries for user {user_id}" - ) + logger.info(f"Invalidated {len(keys)} score cache entries for user {user_id}") except Exception as e: logger.error(f"Error invalidating user scores cache: {e}") @@ -303,9 +275,7 @@ class UserCacheService: logger.info(f"Preloading cache for {len(user_ids)} users") # 批量获取用户 - users = ( - await session.exec(select(User).where(col(User.id).in_(user_ids))) - ).all() + users = (await session.exec(select(User).where(col(User.id).in_(user_ids)))).all() # 串行缓存用户信息,避免并发数据库访问问题 cached_count = 0 @@ -332,9 +302,7 @@ class UserCacheService: except Exception as e: logger.error(f"Error caching single user {user.id}: {e}") - async def refresh_user_cache_on_score_submit( - self, session: AsyncSession, user_id: int, mode: GameMode - ): + async def refresh_user_cache_on_score_submit(self, session: AsyncSession, user_id: int, mode: GameMode): """成绩提交后刷新用户缓存""" try: # 使相关缓存失效(包括 v1 和 v2) @@ -367,24 +335,12 @@ class UserCacheService: continue return { - "cached_users": len( - [ - k - for k in user_keys - if ":scores:" not in k and ":beatmapsets:" not in k - ] - ), - "cached_v1_users": len( - [k for k in v1_user_keys if ":scores:" not in k] - ), + "cached_users": len([k for k in user_keys if ":scores:" not in k and ":beatmapsets:" not in k]), + "cached_v1_users": len([k for k in v1_user_keys if ":scores:" not in k]), "cached_user_scores": len([k for k in user_keys if ":scores:" in k]), - "cached_user_beatmapsets": len( - [k for k in user_keys if ":beatmapsets:" in k] - ), + "cached_user_beatmapsets": len([k for k in user_keys if ":beatmapsets:" in k]), "total_cached_entries": len(all_keys), - "estimated_total_size_mb": ( - round(total_size / 1024 / 1024, 2) if total_size > 0 else 0 - ), + "estimated_total_size_mb": (round(total_size / 1024 / 1024, 2) if total_size > 0 else 0), "refreshing": self._refreshing, } except Exception as e: diff --git a/app/signalr/hub/hub.py b/app/signalr/hub/hub.py index cfa4739..88bf7af 100644 --- a/app/signalr/hub/hub.py +++ b/app/signalr/hub/hub.py @@ -145,14 +145,9 @@ class Hub[TState: UserState]: connection: WebSocket, ) -> Client: if connection_token in self.clients: - raise ValueError( - f"Client with connection token {connection_token} already exists." - ) + raise ValueError(f"Client with connection token {connection_token} already exists.") if connection_token in self.waited_clients: - if ( - self.waited_clients[connection_token] - < time.time() - settings.signalr_negotiate_timeout - ): + if self.waited_clients[connection_token] < time.time() - settings.signalr_negotiate_timeout: raise TimeoutError(f"Connection {connection_id} has waited too long.") del self.waited_clients[connection_token] client = Client(connection_id, connection_token, connection, protocol) @@ -196,9 +191,7 @@ class Hub[TState: UserState]: try: await client.send_packet(packet) except WebSocketDisconnect as e: - logger.info( - f"Client {client.connection_id} disconnected: {e.code}, {e.reason}" - ) + logger.info(f"Client {client.connection_id} disconnected: {e.code}, {e.reason}") await self.remove_client(client) except RuntimeError as e: if "disconnect message" in str(e): @@ -216,9 +209,7 @@ class Hub[TState: UserState]: tasks.append(self.call_noblock(client, method, *args)) await asyncio.gather(*tasks) - async def broadcast_group_call( - self, group_id: str, method: str, *args: Any - ) -> None: + async def broadcast_group_call(self, group_id: str, method: str, *args: Any) -> None: tasks = [] for client in self.groups.get(group_id, []): tasks.append(self.call_noblock(client, method, *args)) @@ -241,9 +232,7 @@ class Hub[TState: UserState]: self.tasks.add(task) task.add_done_callback(self.tasks.discard) except WebSocketDisconnect as e: - logger.info( - f"Client {client.connection_id} disconnected: {e.code}, {e.reason}" - ) + logger.info(f"Client {client.connection_id} disconnected: {e.code}, {e.reason}") except RuntimeError as e: if "disconnect message" in str(e): logger.info(f"Client {client.connection_id} closed the connection.") @@ -251,12 +240,8 @@ class Hub[TState: UserState]: logger.exception(f"RuntimeError in client {client.connection_id}: {e}") except CloseConnection as e: if not e.from_client: - await client.send_packet( - ClosePacket(error=e.message, allow_reconnect=e.allow_reconnect) - ) - logger.info( - f"Client {client.connection_id} closed the connection: {e.message}" - ) + await client.send_packet(ClosePacket(error=e.message, allow_reconnect=e.allow_reconnect)) + logger.info(f"Client {client.connection_id} closed the connection: {e.message}") except Exception: logger.exception(f"Error in client {client.connection_id}") @@ -273,15 +258,9 @@ class Hub[TState: UserState]: result = await self.invoke_method(client, packet.target, args) except InvokeException as e: error = e.message - logger.debug( - f"Client {client.connection_token} call {packet.target}" - f" failed: {error}" - ) + logger.debug(f"Client {client.connection_token} call {packet.target} failed: {error}") except Exception: - logger.exception( - f"Error invoking method {packet.target} for " - f"client {client.connection_id}" - ) + logger.exception(f"Error invoking method {packet.target} for client {client.connection_id}") error = "Unknown error occured in server" if packet.invocation_id is not None: await client.send_packet( @@ -303,9 +282,7 @@ class Hub[TState: UserState]: for name, param in signature.parameters.items(): if name == "self" or param.annotation is Client: continue - call_params.append( - client.protocol.validate_object(args.pop(0), param.annotation) - ) + call_params.append(client.protocol.validate_object(args.pop(0), param.annotation)) return await method_(client, *call_params) async def call(self, client: Client, method: str, *args: Any) -> Any: diff --git a/app/signalr/hub/metadata.py b/app/signalr/hub/metadata.py index 757e9c8..34a283d 100644 --- a/app/signalr/hub/metadata.py +++ b/app/signalr/hub/metadata.py @@ -13,7 +13,7 @@ from app.database.playlist_best_score import PlaylistBestScore from app.database.playlists import Playlist from app.database.room import Room from app.database.score import Score -from app.dependencies.database import get_redis, with_db +from app.dependencies.database import with_db from app.log import logger from app.models.metadata_hub import ( TOTAL_SCORE_DISTRIBUTION_BINS, @@ -44,13 +44,8 @@ class MetadataHub(Hub[MetadataClientState]): self._today = datetime.now(UTC).date() self._lock = asyncio.Lock() - def get_daily_challenge_stats( - self, daily_challenge_room: int - ) -> MultiplayerRoomStats: - if ( - self._daily_challenge_stats is None - or self._today != datetime.now(UTC).date() - ): + def get_daily_challenge_stats(self, daily_challenge_room: int) -> MultiplayerRoomStats: + if self._daily_challenge_stats is None or self._today != datetime.now(UTC).date(): self._daily_challenge_stats = MultiplayerRoomStats( room_id=daily_challenge_room, playlist_item_stats={}, @@ -65,9 +60,7 @@ class MetadataHub(Hub[MetadataClientState]): def room_watcher_group(room_id: int) -> str: return f"metadata:multiplayer-room-watchers:{room_id}" - def broadcast_tasks( - self, user_id: int, store: MetadataClientState | None - ) -> set[Coroutine]: + def broadcast_tasks(self, user_id: int, store: MetadataClientState | None) -> set[Coroutine]: if store is not None and not store.pushable: return set() data = store.for_push if store else None @@ -96,18 +89,15 @@ class MetadataHub(Hub[MetadataClientState]): # Use centralized offline status management from app.service.online_status_manager import online_status_manager + await online_status_manager.set_user_offline(user_id) if state.pushable: await asyncio.gather(*self.broadcast_tasks(user_id, None)) - + async with with_db() as session: async with session.begin(): - user = ( - await session.exec( - select(User).where(User.id == int(state.connection_id)) - ) - ).one() + user = (await session.exec(select(User).where(User.id == int(state.connection_id)))).one() user.last_visit = datetime.now(UTC) await session.commit() @@ -124,6 +114,7 @@ class MetadataHub(Hub[MetadataClientState]): # Use centralized online status management from app.service.online_status_manager import online_status_manager + await online_status_manager.set_user_online(user_id, "metadata") # CRITICAL FIX: Set online status IMMEDIATELY upon connection @@ -143,20 +134,14 @@ class MetadataHub(Hub[MetadataClientState]): ).all() tasks = [] for friend_id in friends: - self.groups.setdefault( - self.friend_presence_watchers_group(friend_id), set() - ).add(client) - if ( - friend_state := self.state.get(friend_id) - ) and friend_state.pushable: + self.groups.setdefault(self.friend_presence_watchers_group(friend_id), set()).add(client) + if (friend_state := self.state.get(friend_id)) and friend_state.pushable: tasks.append( self.broadcast_group_call( self.friend_presence_watchers_group(friend_id), "FriendPresenceUpdated", friend_id, - friend_state.for_push - if friend_state.pushable - else None, + friend_state.for_push if friend_state.pushable else None, ) ) await asyncio.gather(*tasks) @@ -177,7 +162,7 @@ class MetadataHub(Hub[MetadataClientState]): room_id=daily_challenge_room.id, ), ) - + # CRITICAL FIX: Immediately broadcast the user's online status to all watchers # This ensures the user appears as "currently online" right after connection # Similar to the C# implementation's immediate broadcast logic @@ -185,7 +170,7 @@ class MetadataHub(Hub[MetadataClientState]): if online_presence_tasks: await asyncio.gather(*online_presence_tasks) logger.info(f"[MetadataHub] Broadcasted online status for user {user_id} to watchers") - + # Also send the user's own presence update to confirm online status await self.call_noblock( client, @@ -213,9 +198,7 @@ class MetadataHub(Hub[MetadataClientState]): ) await asyncio.gather(*tasks) - async def UpdateActivity( - self, client: Client, activity: UserActivity | None - ) -> None: + async def UpdateActivity(self, client: Client, activity: UserActivity | None) -> None: user_id = int(client.connection_id) store = self.get_or_create_state(client) store.activity = activity @@ -246,15 +229,16 @@ class MetadataHub(Hub[MetadataClientState]): ] ) self.add_to_group(client, self.online_presence_watchers_group()) - logger.info(f"[MetadataHub] Client {client.connection_id} now watching user presence, sent {len([s for s in self.state.values() if s.pushable])} online users") + logger.info( + f"[MetadataHub] Client {client.connection_id} now watching user presence, " + f"sent {len([s for s in self.state.values() if s.pushable])} online users" + ) async def EndWatchingUserPresence(self, client: Client) -> None: self.remove_from_group(client, self.online_presence_watchers_group()) async def notify_room_score_processed(self, event: MultiplayerRoomScoreSetEvent): - await self.broadcast_group_call( - self.room_watcher_group(event.room_id), "MultiplayerRoomScoreSet", event - ) + await self.broadcast_group_call(self.room_watcher_group(event.room_id), "MultiplayerRoomScoreSet", event) async def BeginWatchingMultiplayerRoom(self, client: Client, room_id: int): self.add_to_group(client, self.room_watcher_group(room_id)) @@ -289,9 +273,7 @@ class MetadataHub(Hub[MetadataClientState]): PlaylistBestScore.room_id == stats.room_id, PlaylistBestScore.playlist_id == playlist_id, PlaylistBestScore.score_id > last_processed_score_id, - col(PlaylistBestScore.score).has( - col(Score.passed).is_(True) - ), + col(PlaylistBestScore.score).has(col(Score.passed).is_(True)), ) ) ).all() @@ -311,17 +293,13 @@ class MetadataHub(Hub[MetadataClientState]): ) totals[bin_index] += 1 - item.cumulative_score += sum( - score.total_score for score in scores - ) + item.cumulative_score += sum(score.total_score for score in scores) for j in range(TOTAL_SCORE_DISTRIBUTION_BINS): item.total_score_distribution[j] += totals.get(j, 0) if scores: - item.last_processed_score_id = max( - score.score_id for score in scores - ) + item.last_processed_score_id = max(score.score_id for score in scores) async def EndWatchingMultiplayerRoom(self, client: Client, room_id: int): self.remove_from_group(client, self.room_watcher_group(room_id)) diff --git a/app/signalr/hub/multiplayer.py b/app/signalr/hub/multiplayer.py index 2ea759a..ca850c8 100644 --- a/app/signalr/hub/multiplayer.py +++ b/app/signalr/hub/multiplayer.py @@ -114,9 +114,7 @@ class MultiplayerEventLogger: ) await self.log_event(event) - async def game_started( - self, room_id: int, playlist_item_id: int, details: MatchStartedEventDetail - ): + async def game_started(self, room_id: int, playlist_item_id: int, details: MatchStartedEventDetail): event = MultiplayerEvent( room_id=room_id, playlist_item_id=playlist_item_id, @@ -166,6 +164,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): # Use centralized offline status management from app.service.online_status_manager import online_status_manager + await online_status_manager.set_user_offline(user_id) if state.room_id != 0 and state.room_id in self.rooms: @@ -173,9 +172,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): room = server_room.room user = next((u for u in room.users if u.user_id == user_id), None) if user is not None: - await self.make_user_leave( - self.get_client_by_id(str(user_id)), server_room, user - ) + await self.make_user_leave(self.get_client_by_id(str(user_id)), server_room, user) async def on_client_connect(self, client: Client) -> None: """Track online users when connecting to multiplayer hub""" @@ -183,6 +180,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): # Use centralized online status management from app.service.online_status_manager import online_status_manager + await online_status_manager.set_user_online(client.user_id, "multiplayer") def _ensure_in_room(self, client: Client) -> ServerMultiplayerRoom: @@ -212,9 +210,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): type=room.settings.match_type, queue_mode=room.settings.queue_mode, auto_skip=room.settings.auto_skip, - auto_start_duration=int( - room.settings.auto_start_duration.total_seconds() - ), + auto_start_duration=int(room.settings.auto_start_duration.total_seconds()), host_id=client.user_id, status=RoomStatus.IDLE, ) @@ -231,26 +227,20 @@ class MultiplayerHub(Hub[MultiplayerClientState]): await session.commit() await session.refresh(channel) await session.refresh(db_room) - room.channel_id = channel.channel_id # pyright: ignore[reportAttributeAccessIssue] + room.channel_id = channel.channel_id db_room.channel_id = channel.channel_id item = room.playlist[0] item.owner_id = client.user_id room.room_id = db_room.id starts_at = db_room.starts_at or datetime.now(UTC) - beatmap_exists = await session.exec( - select(exists().where(col(Beatmap.id) == item.beatmap_id)) - ) + beatmap_exists = await session.exec(select(exists().where(col(Beatmap.id) == item.beatmap_id))) if not beatmap_exists.one(): fetcher = await get_fetcher() try: - await Beatmap.get_or_fetch( - session, fetcher, bid=item.beatmap_id - ) + await Beatmap.get_or_fetch(session, fetcher, bid=item.beatmap_id) except HTTPError: - raise InvokeException( - "Failed to fetch beatmap, please retry later" - ) + raise InvokeException("Failed to fetch beatmap, please retry later") await Playlist.add_to_db(item, room.room_id, session) server_room = ServerMultiplayerRoom( @@ -262,9 +252,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): self.rooms[room.room_id] = server_room await server_room.set_handler() await self.event_logger.room_created(room.room_id, client.user_id) - return await self.JoinRoomWithPassword( - client, room.room_id, room.settings.password - ) + return await self.JoinRoomWithPassword(client, room.room_id, room.settings.password) async def JoinRoom(self, client: Client, room_id: int): return self.JoinRoomWithPassword(client, room_id, "") @@ -350,9 +338,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): beatmap_availability, ) - async def ChangeBeatmapAvailability( - self, client: Client, beatmap_availability: BeatmapAvailability - ): + async def ChangeBeatmapAvailability(self, client: Client, beatmap_availability: BeatmapAvailability): server_room = self._ensure_in_room(client) room = server_room.room user = next((u for u in room.users if u.user_id == client.user_id), None) @@ -371,10 +357,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): user = next((u for u in room.users if u.user_id == client.user_id), None) if user is None: raise InvokeException("You are not in this room") - logger.info( - f"[MultiplayerHub] {client.user_id} adding " - f"beatmap {item.beatmap_id} to room {room.room_id}" - ) + logger.info(f"[MultiplayerHub] {client.user_id} adding beatmap {item.beatmap_id} to room {room.room_id}") await server_room.queue.add_item( item, user, @@ -388,10 +371,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): if user is None: raise InvokeException("You are not in this room") - logger.info( - f"[MultiplayerHub] {client.user_id} editing " - f"item {item.id} in room {room.room_id}" - ) + logger.info(f"[MultiplayerHub] {client.user_id} editing item {item.id} in room {room.room_id}") await server_room.queue.edit_item( item, user, @@ -405,10 +385,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): if user is None: raise InvokeException("You are not in this room") - logger.info( - f"[MultiplayerHub] {client.user_id} removing " - f"item {item_id} from room {room.room_id}" - ) + logger.info(f"[MultiplayerHub] {client.user_id} removing item {item_id} from room {room.room_id}") await server_room.queue.remove_item( item_id, user, @@ -424,9 +401,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): type=room.room.settings.match_type, queue_mode=room.room.settings.queue_mode, auto_skip=room.room.settings.auto_skip, - auto_start_duration=int( - room.room.settings.auto_start_duration.total_seconds() - ), + auto_start_duration=int(room.room.settings.auto_start_duration.total_seconds()), host_id=room.room.host.user_id if room.room.host else None, ) ) @@ -456,9 +431,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): item_id, ) - async def playlist_changed( - self, room: ServerMultiplayerRoom, item: PlaylistItem, beatmap_changed: bool - ): + async def playlist_changed(self, room: ServerMultiplayerRoom, item: PlaylistItem, beatmap_changed: bool): if item.id == room.room.settings.playlist_item_id: await self.validate_styles(room) await self.unready_all_users(room, beatmap_changed) @@ -468,9 +441,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): item, ) - async def ChangeUserStyle( - self, client: Client, beatmap_id: int | None, ruleset_id: int | None - ): + async def ChangeUserStyle(self, client: Client, beatmap_id: int | None, ruleset_id: int | None): server_room = self._ensure_in_room(client) room = server_room.room user = next((u for u in room.users if u.user_id == client.user_id), None) @@ -496,9 +467,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): ) async with with_db() as session: try: - beatmap = await Beatmap.get_or_fetch( - session, fetcher, bid=room.queue.current_item.beatmap_id - ) + beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=room.queue.current_item.beatmap_id) except HTTPError: raise InvokeException("Current item beatmap not found") beatmap_ids = ( @@ -518,11 +487,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): if beatmap_id is not None and user_beatmap is None: beatmap_id = None beatmap_ruleset = user_beatmap[1] if user_beatmap else beatmap.mode - if ( - ruleset_id is not None - and beatmap_ruleset != GameMode.OSU - and ruleset_id != beatmap_ruleset - ): + if ruleset_id is not None and beatmap_ruleset != GameMode.OSU and ruleset_id != beatmap_ruleset: ruleset_id = None await self.change_user_style( beatmap_id, @@ -532,9 +497,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): ) for user in room.room.users: - is_valid, valid_mods = room.queue.current_item.validate_user_mods( - user, user.mods - ) + is_valid, valid_mods = room.queue.current_item.validate_user_mods(user, user.mods) if not is_valid: await self.change_user_mods(valid_mods, room, user) @@ -553,34 +516,24 @@ class MultiplayerHub(Hub[MultiplayerClientState]): raise InvokeException("Current item does not allow free user styles.") async with with_db() as session: - item_beatmap = await session.get( - Beatmap, room.queue.current_item.beatmap_id - ) + item_beatmap = await session.get(Beatmap, room.queue.current_item.beatmap_id) if item_beatmap is None: raise InvokeException("Item beatmap not found") - user_beatmap = ( - item_beatmap - if beatmap_id is None - else await session.get(Beatmap, beatmap_id) - ) + user_beatmap = item_beatmap if beatmap_id is None else await session.get(Beatmap, beatmap_id) if user_beatmap is None: raise InvokeException("Invalid beatmap selected.") if user_beatmap.beatmapset_id != item_beatmap.beatmapset_id: - raise InvokeException( - "Selected beatmap is not from the same beatmap set." - ) + raise InvokeException("Selected beatmap is not from the same beatmap set.") if ( ruleset_id is not None and user_beatmap.mode != GameMode.OSU and ruleset_id != int(user_beatmap.mode) ): - raise InvokeException( - "Selected ruleset is not supported for the given beatmap." - ) + raise InvokeException("Selected ruleset is not supported for the given beatmap.") user.beatmap_id = beatmap_id user.ruleset_id = ruleset_id @@ -608,16 +561,10 @@ class MultiplayerHub(Hub[MultiplayerClientState]): room: ServerMultiplayerRoom, user: MultiplayerRoomUser, ): - is_valid, valid_mods = room.queue.current_item.validate_user_mods( - user, new_mods - ) + is_valid, valid_mods = room.queue.current_item.validate_user_mods(user, new_mods) if not is_valid: - incompatible_mods = [ - mod["acronym"] for mod in new_mods if mod not in valid_mods - ] - raise InvokeException( - f"Incompatible mods were selected: {','.join(incompatible_mods)}" - ) + incompatible_mods = [mod["acronym"] for mod in new_mods if mod not in valid_mods] + raise InvokeException(f"Incompatible mods were selected: {','.join(incompatible_mods)}") if user.mods == valid_mods: return @@ -640,16 +587,12 @@ class MultiplayerHub(Hub[MultiplayerClientState]): match new: case MultiplayerUserState.IDLE: if old.is_playing: - raise InvokeException( - "Cannot return to idle without aborting gameplay." - ) + raise InvokeException("Cannot return to idle without aborting gameplay.") case MultiplayerUserState.READY: if old != MultiplayerUserState.IDLE: raise InvokeException(f"Cannot change state from {old} to {new}") if room.queue.current_item.expired: - raise InvokeException( - "Cannot ready up while all items have been played." - ) + raise InvokeException("Cannot ready up while all items have been played.") case MultiplayerUserState.WAITING_FOR_LOAD: raise InvokeException(f"Cannot change state from {old} to {new}") case MultiplayerUserState.LOADED: @@ -688,9 +631,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): MultiplayerRoomState.PLAYING, ) ): - raise InvokeException( - f"Cannot change state from {old} to {new}" - ) + raise InvokeException(f"Cannot change state from {old} to {new}") case _: raise InvokeException(f"Invalid state transition from {old} to {new}") @@ -713,9 +654,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): if not user.state.is_playing: return - logger.info( - f"[MultiplayerHub] User {user.user_id} changing state from {user.state} to {state}" - ) + logger.info(f"[MultiplayerHub] User {user.user_id} changing state from {user.state} to {state}") await self.validate_user_stare( server_room, @@ -737,10 +676,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): user: MultiplayerRoomUser, state: MultiplayerUserState, ): - logger.info( - f"[MultiplayerHub] {user.user_id}'s state " - f"changed from {user.state} to {state}" - ) + logger.info(f"[MultiplayerHub] {user.user_id}'s state changed from {user.state} to {state}") user.state = state await self.broadcast_group_call( self.group_id(room.room.room_id), @@ -760,23 +696,17 @@ class MultiplayerHub(Hub[MultiplayerClientState]): # If switching to spectating during gameplay, immediately request load if room_state == MultiplayerRoomState.WAITING_FOR_LOAD: - logger.info( - f"[MultiplayerHub] Spectator {user.user_id} joining during load phase" - ) + logger.info(f"[MultiplayerHub] Spectator {user.user_id} joining during load phase") await self.call_noblock(client, "LoadRequested") elif room_state == MultiplayerRoomState.PLAYING: - logger.info( - f"[MultiplayerHub] Spectator {user.user_id} joining during active gameplay" - ) + logger.info(f"[MultiplayerHub] Spectator {user.user_id} joining during active gameplay") await self.call_noblock(client, "LoadRequested") - + # Also sync the spectator with current game state await self._send_current_gameplay_state_to_spectator(client, room) - async def _send_current_gameplay_state_to_spectator( - self, client: Client, room: ServerMultiplayerRoom - ): + async def _send_current_gameplay_state_to_spectator(self, client: Client, room: ServerMultiplayerRoom): """ Send current gameplay state information to a newly joined spectator. This helps spectators sync with ongoing gameplay. @@ -794,27 +724,20 @@ class MultiplayerHub(Hub[MultiplayerClientState]): room_user.user_id, room_user.state, ) - + # If the room is in OPEN state but we have users in RESULTS state, # this means the game just finished and we should send ResultsReady - if (room.room.state == MultiplayerRoomState.OPEN and - any(u.state == MultiplayerUserState.RESULTS for u in room.room.users)): - logger.debug( - f"[MultiplayerHub] Sending ResultsReady to new spectator {client.user_id}" - ) + if room.room.state == MultiplayerRoomState.OPEN and any( + u.state == MultiplayerUserState.RESULTS for u in room.room.users + ): + logger.debug(f"[MultiplayerHub] Sending ResultsReady to new spectator {client.user_id}") await self.call_noblock(client, "ResultsReady") - logger.debug( - f"[MultiplayerHub] Sent current gameplay state to spectator {client.user_id}" - ) + logger.debug(f"[MultiplayerHub] Sent current gameplay state to spectator {client.user_id}") except Exception as e: - logger.error( - f"[MultiplayerHub] Failed to send gameplay state to spectator {client.user_id}: {e}" - ) + logger.error(f"[MultiplayerHub] Failed to send gameplay state to spectator {client.user_id}: {e}") - async def _send_room_state_to_new_user( - self, client: Client, room: ServerMultiplayerRoom - ): + async def _send_room_state_to_new_user(self, client: Client, room: ServerMultiplayerRoom): """ Send complete room state to a newly joined user. Critical for spectators joining ongoing games. @@ -847,28 +770,21 @@ class MultiplayerHub(Hub[MultiplayerClientState]): # Critical fix: If room is OPEN but has users in RESULTS state, # send ResultsReady to new joiners (including spectators) - if (room.room.state == MultiplayerRoomState.OPEN and - any(u.state == MultiplayerUserState.RESULTS for u in room.room.users)): - logger.info( - f"[MultiplayerHub] Sending ResultsReady to newly joined user {client.user_id}" - ) + if room.room.state == MultiplayerRoomState.OPEN and any( + u.state == MultiplayerUserState.RESULTS for u in room.room.users + ): + logger.info(f"[MultiplayerHub] Sending ResultsReady to newly joined user {client.user_id}") await self.call_noblock(client, "ResultsReady") # Critical addition: Send current playing users to SpectatorHub for cross-hub sync # This ensures spectators can watch multiplayer players properly await self._sync_with_spectator_hub(client, room) - logger.debug( - f"[MultiplayerHub] Sent complete room state to new user {client.user_id}" - ) + logger.debug(f"[MultiplayerHub] Sent complete room state to new user {client.user_id}") except Exception as e: - logger.error( - f"[MultiplayerHub] Failed to send room state to user {client.user_id}: {e}" - ) + logger.error(f"[MultiplayerHub] Failed to send room state to user {client.user_id}: {e}") - async def _sync_with_spectator_hub( - self, client: Client, room: ServerMultiplayerRoom - ): + async def _sync_with_spectator_hub(self, client: Client, room: ServerMultiplayerRoom): """ Sync with SpectatorHub to ensure cross-hub spectating works properly. This is crucial for users watching multiplayer players from other pages. @@ -893,13 +809,16 @@ class MultiplayerHub(Hub[MultiplayerClientState]): f"[MultiplayerHub] Synced spectator state for user {room_user.user_id} " f"to new client {client.user_id}" ) - + # Critical addition: Notify SpectatorHub about users in RESULTS state elif room_user.state == MultiplayerUserState.RESULTS: # Create a synthetic finished state for cross-hub spectating try: - from app.models.spectator_hub import SpectatedUserState, SpectatorState - + from app.models.spectator_hub import ( + SpectatedUserState, + SpectatorState, + ) + finished_state = SpectatorState( beatmap_id=room.queue.current_item.beatmap_id, ruleset_id=room_user.ruleset_id or 0, @@ -919,9 +838,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): f"to client {client.user_id}" ) except Exception as e: - logger.debug( - f"[MultiplayerHub] Failed to create synthetic finished state: {e}" - ) + logger.debug(f"[MultiplayerHub] Failed to create synthetic finished state: {e}") except Exception as e: logger.debug(f"[MultiplayerHub] Failed to sync with SpectatorHub: {e}") @@ -933,75 +850,55 @@ class MultiplayerHub(Hub[MultiplayerClientState]): if room.room.settings.auto_start_enabled: if ( not room.queue.current_item.expired - and any( - u.state == MultiplayerUserState.READY - for u in room.room.users - ) + and any(u.state == MultiplayerUserState.READY for u in room.room.users) and not any( - isinstance(countdown, MatchStartCountdown) - for countdown in room.room.active_countdowns + isinstance(countdown, MatchStartCountdown) for countdown in room.room.active_countdowns ) ): await room.start_countdown( - MatchStartCountdown( - time_remaining=room.room.settings.auto_start_duration - ), + MatchStartCountdown(time_remaining=room.room.settings.auto_start_duration), self.start_match, ) case MultiplayerRoomState.WAITING_FOR_LOAD: - played_count = len( - [True for user in room.room.users if user.state.is_playing] - ) + played_count = len([True for user in room.room.users if user.state.is_playing]) ready_count = len( - [ - True - for user in room.room.users - if user.state == MultiplayerUserState.READY_FOR_GAMEPLAY - ] + [True for user in room.room.users if user.state == MultiplayerUserState.READY_FOR_GAMEPLAY] ) if played_count == ready_count: await self.start_gameplay(room) case MultiplayerRoomState.PLAYING: - if all( - u.state != MultiplayerUserState.PLAYING for u in room.room.users - ): + if all(u.state != MultiplayerUserState.PLAYING for u in room.room.users): any_user_finished_playing = False - + # Handle finished players first for u in filter( lambda u: u.state == MultiplayerUserState.FINISHED_PLAY, room.room.users, ): any_user_finished_playing = True - await self.change_user_state( - room, u, MultiplayerUserState.RESULTS - ) - + await self.change_user_state(room, u, MultiplayerUserState.RESULTS) + # Critical fix: Handle spectators who should also see results # Move spectators to RESULTS state so they can see the results screen for u in filter( lambda u: u.state == MultiplayerUserState.SPECTATING, room.room.users, ): - logger.debug( - f"[MultiplayerHub] Moving spectator {u.user_id} to RESULTS state" - ) - await self.change_user_state( - room, u, MultiplayerUserState.RESULTS - ) - + logger.debug(f"[MultiplayerHub] Moving spectator {u.user_id} to RESULTS state") + await self.change_user_state(room, u, MultiplayerUserState.RESULTS) + await self.change_room_state(room, MultiplayerRoomState.OPEN) - + # Send ResultsReady to all room members await self.broadcast_group_call( self.group_id(room.room.room_id), "ResultsReady", ) - + # Critical addition: Notify SpectatorHub about finished games # This ensures cross-hub spectating works properly await self._notify_spectator_hub_game_ended(room) - + if any_user_finished_playing: await self.event_logger.game_completed( room.room.room_id, @@ -1014,13 +911,8 @@ class MultiplayerHub(Hub[MultiplayerClientState]): ) await room.queue.finish_current_item() - async def change_room_state( - self, room: ServerMultiplayerRoom, state: MultiplayerRoomState - ): - logger.debug( - f"[MultiplayerHub] Room {room.room.room_id} state " - f"changed from {room.room.state} to {state}" - ) + async def change_room_state(self, room: ServerMultiplayerRoom, state: MultiplayerRoomState): + logger.debug(f"[MultiplayerHub] Room {room.room.room_id} state changed from {room.room.state} to {state}") room.room.state = state await self.broadcast_group_call( self.group_id(room.room.room_id), @@ -1064,10 +956,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): u for u in room.room.users if u.availability.state == DownloadState.LOCALLY_AVAILABLE - and ( - u.state == MultiplayerUserState.READY - or u.state == MultiplayerUserState.IDLE - ) + and (u.state == MultiplayerUserState.READY or u.state == MultiplayerUserState.IDLE) ] for u in ready_users: await self.change_user_state(room, u, MultiplayerUserState.WAITING_FOR_LOAD) @@ -1080,9 +969,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): "LoadRequested", ) await room.start_countdown( - ForceGameplayStartCountdown( - time_remaining=timedelta(seconds=GAMEPLAY_LOAD_TIMEOUT) - ), + ForceGameplayStartCountdown(time_remaining=timedelta(seconds=GAMEPLAY_LOAD_TIMEOUT)), self.start_gameplay, ) await self.event_logger.game_started( @@ -1133,9 +1020,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): else: await room.queue.finish_current_item() - async def send_match_event( - self, room: ServerMultiplayerRoom, event: MatchServerEvent - ): + async def send_match_event(self, room: ServerMultiplayerRoom, event: MatchServerEvent): await self.broadcast_group_call( self.group_id(room.room.room_id), "MatchEvent", @@ -1183,24 +1068,16 @@ class MultiplayerHub(Hub[MultiplayerClientState]): await self.end_room(room) return await self.update_room_state(room) - if ( - len(room.room.users) != 0 - and room.room.host - and room.room.host.user_id == user.user_id - ): + if len(room.room.users) != 0 and room.room.host and room.room.host.user_id == user.user_id: next_host = room.room.users[0] await self.set_host(room, next_host) if kicked: if client: await self.call_noblock(client, "UserKicked", user) - await self.broadcast_group_call( - self.group_id(room.room.room_id), "UserKicked", user - ) + await self.broadcast_group_call(self.group_id(room.room.room_id), "UserKicked", user) else: - await self.broadcast_group_call( - self.group_id(room.room.room_id), "UserLeft", user - ) + await self.broadcast_group_call(self.group_id(room.room.room_id), "UserLeft", user) async def end_room(self, room: ServerMultiplayerRoom): assert room.room.host @@ -1214,9 +1091,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): type=room.room.settings.match_type, queue_mode=room.room.settings.queue_mode, auto_skip=room.room.settings.auto_skip, - auto_start_duration=int( - room.room.settings.auto_start_duration.total_seconds() - ), + auto_start_duration=int(room.room.settings.auto_start_duration.total_seconds()), host_id=room.room.host.user_id, ) ) @@ -1262,10 +1137,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): ) target_client = self.get_client_by_id(str(user.user_id)) await self.make_user_leave(target_client, server_room, user, kicked=True) - logger.info( - f"[MultiplayerHub] {user.user_id} was kicked from room {room.room_id}" - f"by {client.user_id}" - ) + logger.info(f"[MultiplayerHub] {user.user_id} was kicked from room {room.room_id}by {client.user_id}") async def set_host(self, room: ServerMultiplayerRoom, user: MultiplayerRoomUser): room.room.host = user @@ -1289,10 +1161,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): new_host.user_id, ) await self.set_host(server_room, new_host) - logger.info( - f"[MultiplayerHub] {client.user_id} transferred host to {new_host.user_id}" - f" in room {room.room_id}" - ) + logger.info(f"[MultiplayerHub] {client.user_id} transferred host to {new_host.user_id} in room {room.room_id}") async def AbortGameplay(self, client: Client): server_room = self._ensure_in_room(client) @@ -1316,10 +1185,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): room = server_room.room self._ensure_host(client, server_room) - if ( - room.state != MultiplayerRoomState.PLAYING - and room.state != MultiplayerRoomState.WAITING_FOR_LOAD - ): + if room.state != MultiplayerRoomState.PLAYING and room.state != MultiplayerRoomState.WAITING_FOR_LOAD: raise InvokeException("Cannot abort a match that hasn't started.") await asyncio.gather( @@ -1335,13 +1201,9 @@ class MultiplayerHub(Hub[MultiplayerClientState]): GameplayAbortReason.HOST_ABORTED, ) await self.update_room_state(server_room) - logger.info( - f"[MultiplayerHub] {client.user_id} aborted match in room {room.room_id}" - ) + logger.info(f"[MultiplayerHub] {client.user_id} aborted match in room {room.room_id}") - async def change_user_match_state( - self, room: ServerMultiplayerRoom, user: MultiplayerRoomUser - ): + async def change_user_match_state(self, room: ServerMultiplayerRoom, user: MultiplayerRoomUser): await self.broadcast_group_call( self.group_id(room.room.room_id), "MatchUserStateChanged", @@ -1402,10 +1264,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): ) if countdown is None: return - if ( - isinstance(countdown, MatchStartCountdown) - and room.settings.auto_start_enabled - ) or isinstance( + if (isinstance(countdown, MatchStartCountdown) and room.settings.auto_start_enabled) or isinstance( countdown, (ForceGameplayStartCountdown | ServerShuttingDownCountdown) ): raise InvokeException("Cannot stop the requested countdown") @@ -1447,25 +1306,16 @@ class MultiplayerHub(Hub[MultiplayerClientState]): raise InvokeException("User already invited") if db_user.is_restricted: raise InvokeException("User is restricted") - if ( - inviter_relationship - and inviter_relationship.type == RelationshipType.BLOCK - ): + if inviter_relationship and inviter_relationship.type == RelationshipType.BLOCK: raise InvokeException("Cannot perform action due to user being blocked") - if ( - target_relationship - and target_relationship.type == RelationshipType.BLOCK - ): + if target_relationship and target_relationship.type == RelationshipType.BLOCK: raise InvokeException("Cannot perform action due to user being blocked") if ( db_user.pm_friends_only and target_relationship is not None and target_relationship.type != RelationshipType.FOLLOW ): - raise InvokeException( - "Cannot perform action " - "because user has disabled non-friend communications" - ) + raise InvokeException("Cannot perform action because user has disabled non-friend communications") target_client = self.get_client_by_id(str(user_id)) if target_client is None: @@ -1478,9 +1328,7 @@ class MultiplayerHub(Hub[MultiplayerClientState]): room.settings.password, ) - async def unready_all_users( - self, room: ServerMultiplayerRoom, reset_beatmap_availability: bool - ): + async def unready_all_users(self, room: ServerMultiplayerRoom, reset_beatmap_availability: bool): await asyncio.gather( *[ self.change_user_state( @@ -1512,8 +1360,8 @@ class MultiplayerHub(Hub[MultiplayerClientState]): """ try: # Import here to avoid circular imports - from app.signalr.hub import SpectatorHubs from app.models.spectator_hub import SpectatedUserState, SpectatorState + from app.signalr.hub import SpectatorHubs # For each user who finished the game, notify SpectatorHub for room_user in room.room.users: @@ -1534,13 +1382,9 @@ class MultiplayerHub(Hub[MultiplayerClientState]): room_user.user_id, finished_state, ) - - logger.debug( - f"[MultiplayerHub] Notified SpectatorHub that user {room_user.user_id} finished game" - ) + + logger.debug(f"[MultiplayerHub] Notified SpectatorHub that user {room_user.user_id} finished game") except Exception as e: - logger.debug( - f"[MultiplayerHub] Failed to notify SpectatorHub about game end: {e}" - ) + logger.debug(f"[MultiplayerHub] Failed to notify SpectatorHub about game end: {e}") # This is not critical, so we don't raise the exception diff --git a/app/signalr/hub/spectator.py b/app/signalr/hub/spectator.py index b83585e..fce16de 100644 --- a/app/signalr/hub/spectator.py +++ b/app/signalr/hub/spectator.py @@ -31,7 +31,7 @@ from app.models.spectator_hub import ( StoreClientState, StoreScore, ) -from app.utils import unix_timestamp_to_windows +from app.utils import bg_tasks, unix_timestamp_to_windows from .hub import Client, Hub @@ -111,20 +111,14 @@ async def save_replay( last_time = 0 for frame in frames: time = round(frame.time) - frame_strs.append( - f"{time - last_time}|{frame.mouse_x or 0.0}" - f"|{frame.mouse_y or 0.0}|{frame.button_state}" - ) + frame_strs.append(f"{time - last_time}|{frame.mouse_x or 0.0}|{frame.mouse_y or 0.0}|{frame.button_state}") last_time = time frame_strs.append("-12345|0|0|0") - compressed = lzma.compress( - ",".join(frame_strs).encode("ascii"), format=lzma.FORMAT_ALONE - ) + compressed = lzma.compress(",".join(frame_strs).encode("ascii"), format=lzma.FORMAT_ALONE) data.extend(struct.pack(" None: """ @@ -198,6 +185,7 @@ class SpectatorHub(Hub[StoreClientState]): # Use centralized online status management from app.service.online_status_manager import online_status_manager + await online_status_manager.set_user_online(client.user_id, "spectator") # Send all current player states to the new client @@ -208,17 +196,13 @@ class SpectatorHub(Hub[StoreClientState]): active_states.append((user_id, store.state)) if active_states: - logger.debug( - f"[SpectatorHub] Sending {len(active_states)} active player states to {client.user_id}" - ) + logger.debug(f"[SpectatorHub] Sending {len(active_states)} active player states to {client.user_id}") # Send states sequentially to avoid overwhelming the client for user_id, state in active_states: try: await self.call_noblock(client, "UserBeganPlaying", user_id, state) except Exception as e: - logger.debug( - f"[SpectatorHub] Failed to send state for user {user_id}: {e}" - ) + logger.debug(f"[SpectatorHub] Failed to send state for user {user_id}: {e}") # Also sync with MultiplayerHub for cross-hub spectating await self._sync_with_multiplayer_hub(client) @@ -236,10 +220,7 @@ class SpectatorHub(Hub[StoreClientState]): for room_id, server_room in MultiplayerHubs.rooms.items(): for room_user in server_room.room.users: # Send state for users who are playing or in results - if ( - room_user.state.is_playing - and room_user.user_id not in self.state - ): + if room_user.state.is_playing and room_user.user_id not in self.state: # Create a synthetic SpectatorState for multiplayer players # This helps with cross-hub spectating try: @@ -261,13 +242,12 @@ class SpectatorHub(Hub[StoreClientState]): f"[SpectatorHub] Sent synthetic multiplayer state for user {room_user.user_id}" ) except Exception as e: - logger.debug( - f"[SpectatorHub] Failed to create synthetic state: {e}" - ) - + logger.debug(f"[SpectatorHub] Failed to create synthetic state: {e}") + # Critical addition: Notify about finished players in multiplayer games elif ( - hasattr(room_user.state, 'name') and room_user.state.name == 'RESULTS' + hasattr(room_user.state, "name") + and room_user.state.name == "RESULTS" and room_user.user_id not in self.state ): try: @@ -286,21 +266,15 @@ class SpectatorHub(Hub[StoreClientState]): room_user.user_id, finished_state, ) - logger.debug( - f"[SpectatorHub] Sent synthetic finished state for user {room_user.user_id}" - ) + logger.debug(f"[SpectatorHub] Sent synthetic finished state for user {room_user.user_id}") except Exception as e: - logger.debug( - f"[SpectatorHub] Failed to create synthetic finished state: {e}" - ) + logger.debug(f"[SpectatorHub] Failed to create synthetic finished state: {e}") except Exception as e: logger.debug(f"[SpectatorHub] Failed to sync with MultiplayerHub: {e}") # This is not critical, so we don't raise the exception - async def BeginPlaySession( - self, client: Client, score_token: int, state: SpectatorState - ) -> None: + async def BeginPlaySession(self, client: Client, score_token: int, state: SpectatorState) -> None: user_id = int(client.connection_id) store = self.get_or_create_state(client) if store.state is not None: @@ -312,14 +286,10 @@ class SpectatorHub(Hub[StoreClientState]): async with with_db() as session: async with session.begin(): try: - beatmap = await Beatmap.get_or_fetch( - session, fetcher, bid=state.beatmap_id - ) + beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=state.beatmap_id) except HTTPError: raise InvokeException(f"Beatmap {state.beatmap_id} not found.") - user = ( - await session.exec(select(User).where(User.id == user_id)) - ).first() + user = (await session.exec(select(User).where(User.id == user_id))).first() if not user: return name = user.username @@ -342,8 +312,8 @@ class SpectatorHub(Hub[StoreClientState]): from app.router.v2.stats import add_playing_user from app.service.online_status_manager import online_status_manager - asyncio.create_task(add_playing_user(user_id)) - + bg_tasks.add_task(add_playing_user, user_id) + # Critical fix: Maintain metadata online presence during gameplay # This ensures the user appears online while playing await online_status_manager.refresh_user_online_status(user_id, "playing") @@ -367,6 +337,7 @@ class SpectatorHub(Hub[StoreClientState]): # Critical fix: Refresh online status during active gameplay # This prevents users from appearing offline while playing from app.service.online_status_manager import online_status_manager + await online_status_manager.refresh_user_online_status(user_id, "playing_active") header = frame_data.header @@ -377,15 +348,13 @@ class SpectatorHub(Hub[StoreClientState]): score_info.statistics = header.statistics store.score.replay_frames.extend(frame_data.frames) - await self.broadcast_group_call( - self.group_id(user_id), "UserSentFrames", user_id, frame_data - ) + await self.broadcast_group_call(self.group_id(user_id), "UserSentFrames", user_id, frame_data) async def EndPlaySession(self, client: Client, state: SpectatorState) -> None: user_id = int(client.connection_id) store = self.get_or_create_state(client) score = store.score - + # Early return if no active session if ( score is None @@ -398,19 +367,19 @@ class SpectatorHub(Hub[StoreClientState]): try: # Process score if conditions are met - if ( - settings.enable_all_beatmap_leaderboard - and store.beatmap_status.has_leaderboard() - ) and any(k.is_hit() and v > 0 for k, v in score.score_info.statistics.items()): + if (settings.enable_all_beatmap_leaderboard and store.beatmap_status.has_leaderboard()) and any( + k.is_hit() and v > 0 for k, v in score.score_info.statistics.items() + ): await self._process_score(store, client) - + # End the play session and notify watchers await self._end_session(user_id, state, store) # Remove from playing user tracking from app.router.v2.stats import remove_playing_user - asyncio.create_task(remove_playing_user(user_id)) - + + bg_tasks.add_task(remove_playing_user, user_id) + finally: # CRITICAL FIX: Always clear state in finally block to ensure cleanup # This matches the official C# implementation pattern @@ -439,9 +408,9 @@ class SpectatorHub(Hub[StoreClientState]): ) result = await session.exec( select(Score) - .options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType] + .options(joinedload(Score.beatmap)) .where( - Score.id == sub_query, + Score.id == sub_query.scalar_subquery(), Score.user_id == user_id, ) ) @@ -472,18 +441,12 @@ class SpectatorHub(Hub[StoreClientState]): frames=store.score.replay_frames, ) - async def _end_session( - self, user_id: int, state: SpectatorState, store: StoreClientState - ) -> None: + async def _end_session(self, user_id: int, state: SpectatorState, store: StoreClientState) -> None: async def _add_failtime(): async with with_db() as session: failtime = await session.get(FailTime, state.beatmap_id) total_length = ( - await session.exec( - select(Beatmap.total_length).where( - Beatmap.id == state.beatmap_id - ) - ) + await session.exec(select(Beatmap.total_length).where(Beatmap.id == state.beatmap_id)) ).one() index = clamp(round((exit_time / total_length) * 100), 0, 99) if failtime is not None: @@ -495,7 +458,8 @@ class SpectatorHub(Hub[StoreClientState]): elif state.state == SpectatedUserState.Quit: resp.exit[index] += 1 - new_failtime = FailTime.from_resp(state.beatmap_id, resp) # pyright: ignore[reportArgumentType] + assert state.beatmap_id + new_failtime = FailTime.from_resp(state.beatmap_id, resp) if failtime is not None: await session.merge(new_failtime) else: @@ -527,9 +491,7 @@ class SpectatorHub(Hub[StoreClientState]): if state.state == SpectatedUserState.Playing: state.state = SpectatedUserState.Quit - logger.debug( - f"[SpectatorHub] Changed state from Playing to Quit for user {user_id}" - ) + logger.debug(f"[SpectatorHub] Changed state from Playing to Quit for user {user_id}") # Calculate exit time safely exit_time = 0 @@ -558,10 +520,7 @@ class SpectatorHub(Hub[StoreClientState]): self.tasks.add(task) task.add_done_callback(self.tasks.discard) - logger.info( - f"[SpectatorHub] {user_id} finished playing {state.beatmap_id} " - f"with {state.state}" - ) + logger.info(f"[SpectatorHub] {user_id} finished playing {state.beatmap_id} with {state.state}") await self.broadcast_group_call( self.group_id(user_id), "UserFinishedPlaying", @@ -585,9 +544,7 @@ class SpectatorHub(Hub[StoreClientState]): # CRITICAL FIX: Only send state if user is actually playing # Don't send state for finished/quit games if target_store.state.state == SpectatedUserState.Playing: - logger.debug( - f"[SpectatorHub] {target_id} is currently playing, sending state" - ) + logger.debug(f"[SpectatorHub] {target_id} is currently playing, sending state") # Send current state to the watcher immediately await self.call_noblock( client, @@ -613,25 +570,17 @@ class SpectatorHub(Hub[StoreClientState]): # Get watcher's username and notify the target user try: async with with_db() as session: - username = ( - await session.exec(select(User.username).where(User.id == user_id)) - ).first() + username = (await session.exec(select(User.username).where(User.id == user_id))).first() if not username: - logger.warning( - f"[SpectatorHub] Could not find username for user {user_id}" - ) + logger.warning(f"[SpectatorHub] Could not find username for user {user_id}") return # Notify target user that someone started watching if (target_client := self.get_client_by_id(str(target_id))) is not None: # Create watcher info array (matches official format) watcher_info = [[user_id, username]] - await self.call_noblock( - target_client, "UserStartedWatching", watcher_info - ) - logger.debug( - f"[SpectatorHub] Notified {target_id} that {username} started watching" - ) + await self.call_noblock(target_client, "UserStartedWatching", watcher_info) + logger.debug(f"[SpectatorHub] Notified {target_id} that {username} started watching") except Exception as e: logger.error(f"[SpectatorHub] Error notifying target user {target_id}: {e}") @@ -654,10 +603,6 @@ class SpectatorHub(Hub[StoreClientState]): # Notify target user that watcher stopped watching if (target_client := self.get_client_by_id(str(target_id))) is not None: await self.call_noblock(target_client, "UserEndedWatching", user_id) - logger.debug( - f"[SpectatorHub] Notified {target_id} that {user_id} stopped watching" - ) + logger.debug(f"[SpectatorHub] Notified {target_id} that {user_id} stopped watching") else: - logger.debug( - f"[SpectatorHub] Target user {target_id} not found for end watching notification" - ) + logger.debug(f"[SpectatorHub] Target user {target_id} not found for end watching notification") diff --git a/app/signalr/packet.py b/app/signalr/packet.py index 8949f4b..4c0acdd 100644 --- a/app/signalr/packet.py +++ b/app/signalr/packet.py @@ -100,10 +100,7 @@ class MsgpackProtocol: elif issubclass(typ, datetime.timedelta): return int(v.total_seconds() * 10_000_000) elif isinstance(v, dict): - return { - cls.serialize_msgpack(k): cls.serialize_msgpack(value) - for k, value in v.items() - } + return {cls.serialize_msgpack(k): cls.serialize_msgpack(value) for k, value in v.items()} elif issubclass(typ, Enum): list_ = list(typ) return list_.index(v) if v in list_ else v.value @@ -113,9 +110,7 @@ class MsgpackProtocol: def serialize_to_list(cls, value: BaseModel) -> list[Any]: values = [] for field, info in value.__class__.model_fields.items(): - metadata = next( - (m for m in info.metadata if isinstance(m, SignalRMeta)), None - ) + metadata = next((m for m in info.metadata if isinstance(m, SignalRMeta)), None) if metadata and metadata.member_ignore: continue values.append(cls.serialize_msgpack(v=getattr(value, field))) @@ -130,9 +125,7 @@ class MsgpackProtocol: d = {} i = 0 for field, info in typ.model_fields.items(): - metadata = next( - (m for m in info.metadata if isinstance(m, SignalRMeta)), None - ) + metadata = next((m for m in info.metadata if isinstance(m, SignalRMeta)), None) if metadata and metadata.member_ignore: continue anno = info.annotation @@ -224,10 +217,7 @@ class MsgpackProtocol: return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v) elif get_origin(typ) is dict: return { - cls.validate_object(k, get_args(typ)[0]): cls.validate_object( - v, get_args(typ)[1] - ) - for k, v in v.items() + cls.validate_object(k, get_args(typ)[0]): cls.validate_object(v, get_args(typ)[1]) for k, v in v.items() } elif (origin := get_origin(typ)) is Union or origin is UnionType: args = get_args(typ) @@ -242,13 +232,8 @@ class MsgpackProtocol: # except `X (Other Type) | None` if NoneType in args and v is None: return None - if not all( - issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args - ): - raise ValueError( - f"Cannot validate {v} to {typ}, " - "only SignalRUnionMessage subclasses are supported" - ) + if not all(issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args): + raise ValueError(f"Cannot validate {v} to {typ}, only SignalRUnionMessage subclasses are supported") union_type = v[0] for arg in args: assert issubclass(arg, SignalRUnionMessage) @@ -267,9 +252,7 @@ class MsgpackProtocol: ] ) if packet.arguments is not None: - payload.append( - [MsgpackProtocol.serialize_msgpack(arg) for arg in packet.arguments] - ) + payload.append([MsgpackProtocol.serialize_msgpack(arg) for arg in packet.arguments]) if packet.stream_ids is not None: payload.append(packet.stream_ids) elif isinstance(packet, CompletionPacket): @@ -282,9 +265,7 @@ class MsgpackProtocol: [ packet.invocation_id, result_kind, - packet.error - or MsgpackProtocol.serialize_msgpack(packet.result) - or None, + packet.error or MsgpackProtocol.serialize_msgpack(packet.result) or None, ] ) elif isinstance(packet, ClosePacket): @@ -307,10 +288,7 @@ class JSONProtocol: if issubclass(typ, BaseModel): return cls.serialize_model(v, in_union) elif isinstance(v, dict): - return { - cls.serialize_to_json(k, True): cls.serialize_to_json(value) - for k, value in v.items() - } + return {cls.serialize_to_json(k, True): cls.serialize_to_json(value) for k, value in v.items()} elif isinstance(v, list): return [cls.serialize_to_json(item) for item in v] elif isinstance(v, datetime.datetime): @@ -333,9 +311,7 @@ class JSONProtocol: d = {} is_union = issubclass(v.__class__, SignalRUnionMessage) for field, info in v.__class__.model_fields.items(): - metadata = next( - (m for m in info.metadata if isinstance(m, SignalRMeta)), None - ) + metadata = next((m for m in info.metadata if isinstance(m, SignalRMeta)), None) if metadata and metadata.json_ignore: continue name = ( @@ -358,14 +334,10 @@ class JSONProtocol: return d @staticmethod - def process_object( - v: Any, typ: type[BaseModel], from_union: bool = False - ) -> dict[str, Any]: + def process_object(v: Any, typ: type[BaseModel], from_union: bool = False) -> dict[str, Any]: d = {} for field, info in typ.model_fields.items(): - metadata = next( - (m for m in info.metadata if isinstance(m, SignalRMeta)), None - ) + metadata = next((m for m in info.metadata if isinstance(m, SignalRMeta)), None) if metadata and metadata.json_ignore: continue name = ( @@ -435,9 +407,7 @@ class JSONProtocol: # d.hh:mm:ss parts = v.split(":") if len(parts) == 3: - return datetime.timedelta( - hours=int(parts[0]), minutes=int(parts[1]), seconds=int(parts[2]) - ) + return datetime.timedelta(hours=int(parts[0]), minutes=int(parts[1]), seconds=int(parts[2])) elif len(parts) == 2: return datetime.timedelta(minutes=int(parts[0]), seconds=int(parts[1])) elif len(parts) == 1: @@ -449,10 +419,7 @@ class JSONProtocol: return list_[v] if isinstance(v, int) and 0 <= v < len(list_) else typ(v) elif get_origin(typ) is dict: return { - cls.validate_object(k, get_args(typ)[0]): cls.validate_object( - v, get_args(typ)[1] - ) - for k, v in v.items() + cls.validate_object(k, get_args(typ)[0]): cls.validate_object(v, get_args(typ)[1]) for k, v in v.items() } elif (origin := get_origin(typ)) is Union or origin is UnionType: args = get_args(typ) @@ -467,13 +434,8 @@ class JSONProtocol: # except `X (Other Type) | None` if NoneType in args and v is None: return None - if not all( - issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args - ): - raise ValueError( - f"Cannot validate {v} to {typ}, " - "only SignalRUnionMessage subclasses are supported" - ) + if not all(issubclass(arg, SignalRUnionMessage) or arg is NoneType for arg in args): + raise ValueError(f"Cannot validate {v} to {typ}, only SignalRUnionMessage subclasses are supported") # https://github.com/ppy/osu/blob/98acd9/osu.Game/Online/SignalRDerivedTypeWorkaroundJsonConverter.cs union_type = v["$dtype"] for arg in args: @@ -498,9 +460,7 @@ class JSONProtocol: if packet.invocation_id is not None: payload["invocationId"] = packet.invocation_id if packet.arguments is not None: - payload["arguments"] = [ - JSONProtocol.serialize_to_json(arg) for arg in packet.arguments - ] + payload["arguments"] = [JSONProtocol.serialize_to_json(arg) for arg in packet.arguments] if packet.stream_ids is not None: payload["streamIds"] = packet.stream_ids elif isinstance(packet, CompletionPacket): diff --git a/app/signalr/router.py b/app/signalr/router.py index 8601279..55b08fc 100644 --- a/app/signalr/router.py +++ b/app/signalr/router.py @@ -56,11 +56,9 @@ async def connect( return try: async for session in factory(): - if ( - user := await get_current_user( - session, SecurityScopes(scopes=["*"]), token_pw=token - ) - ) is None or str(user.id) != user_id: + if (user := await get_current_user(session, SecurityScopes(scopes=["*"]), token_pw=token)) is None or str( + user.id + ) != user_id: await websocket.close(code=1008) return except HTTPException: diff --git a/app/signalr/store.py b/app/signalr/store.py index 008da03..3d5591a 100644 --- a/app/signalr/store.py +++ b/app/signalr/store.py @@ -19,9 +19,7 @@ class ResultStore: self._seq = (self._seq + 1) % sys.maxsize return str(s) - def add_result( - self, invocation_id: str, result: Any, error: str | None = None - ) -> None: + def add_result(self, invocation_id: str, result: Any, error: str | None = None) -> None: if isinstance(invocation_id, str) and invocation_id.isdecimal(): if future := self._futures.get(invocation_id): future.set_result((result, error)) diff --git a/app/signalr/utils.py b/app/signalr/utils.py index 1bf84be..d7d23cf 100644 --- a/app/signalr/utils.py +++ b/app/signalr/utils.py @@ -13,9 +13,7 @@ if sys.version_info < (3, 12, 4): else: def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any: - return cast(Any, type_)._evaluate( - globalns, localns, type_params=(), recursive_guard=set() - ) + return cast(Any, type_)._evaluate(globalns, localns, type_params=(), recursive_guard=set()) def get_annotation(param: inspect.Parameter, globalns: dict[str, Any]) -> Any: diff --git a/app/utils.py b/app/utils.py index 031b0cc..ca64f23 100644 --- a/app/utils.py +++ b/app/utils.py @@ -1,7 +1,12 @@ from __future__ import annotations +import asyncio +from collections.abc import Awaitable, Callable, Sequence from datetime import datetime +import functools +import inspect from io import BytesIO +from typing import Any, ParamSpec, TypeVar from fastapi import HTTPException from PIL import Image @@ -151,75 +156,117 @@ def check_image(content: bytes, size: int, width: int, height: int) -> None: def simplify_user_agent(user_agent: str | None, max_length: int = 200) -> str | None: """ 简化 User-Agent 字符串,只保留 osu! 和关键设备系统信息浏览器 - + Args: user_agent: 原始 User-Agent 字符串 max_length: 最大长度限制 - + Returns: 简化后的 User-Agent 字符串,或 None """ import re - + if not user_agent: return None - + # 如果长度在限制内,直接返回 if len(user_agent) <= max_length: return user_agent - + # 提取操作系统信息 os_info = "" os_patterns = [ - r'(Windows[^;)]*)', - r'(Mac OS[^;)]*)', - r'(Linux[^;)]*)', - r'(Android[^;)]*)', - r'(iOS[^;)]*)', - r'(iPhone[^;)]*)', - r'(iPad[^;)]*)' + r"(Windows[^;)]*)", + r"(Mac OS[^;)]*)", + r"(Linux[^;)]*)", + r"(Android[^;)]*)", + r"(iOS[^;)]*)", + r"(iPhone[^;)]*)", + r"(iPad[^;)]*)", ] - + for pattern in os_patterns: match = re.search(pattern, user_agent, re.IGNORECASE) if match: os_info = match.group(1).strip() break - + # 提取浏览器信息 browser_info = "" browser_patterns = [ - r'(osu![^)]*)', # osu! 客户端 - r'(Chrome/[\d.]+)', - r'(Firefox/[\d.]+)', - r'(Safari/[\d.]+)', - r'(Edge/[\d.]+)', - r'(Opera/[\d.]+)' + r"(osu![^)]*)", # osu! 客户端 + r"(Chrome/[\d.]+)", + r"(Firefox/[\d.]+)", + r"(Safari/[\d.]+)", + r"(Edge/[\d.]+)", + r"(Opera/[\d.]+)", ] - + for pattern in browser_patterns: match = re.search(pattern, user_agent, re.IGNORECASE) if match: browser_info = match.group(1).strip() # 如果找到了 osu! 客户端,优先使用 - if 'osu!' in browser_info.lower(): + if "osu!" in browser_info.lower(): break - + # 构建简化的 User-Agent parts = [] if os_info: parts.append(os_info) if browser_info: parts.append(browser_info) - + if parts: - simplified = '; '.join(parts) + simplified = "; ".join(parts) else: # 如果没有识别到关键信息,截断原始字符串 - simplified = user_agent[:max_length-3] + "..." - + simplified = user_agent[: max_length - 3] + "..." + # 确保不超过最大长度 if len(simplified) > max_length: - simplified = simplified[:max_length-3] + "..." - + simplified = simplified[: max_length - 3] + "..." + return simplified + + +# https://github.com/encode/starlette/blob/master/starlette/_utils.py +T = TypeVar("T") +AwaitableCallable = Callable[..., Awaitable[T]] + + +def is_async_callable(obj: Any) -> bool: + while isinstance(obj, functools.partial): + obj = obj.func + + return inspect.iscoroutinefunction(obj) + + +P = ParamSpec("P") + + +async def run_in_threadpool(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + func = functools.partial(func, *args, **kwargs) + return await asyncio.get_event_loop().run_in_executor(None, func) + + +class BackgroundTasks: + def __init__(self, tasks: Sequence[asyncio.Task] | None = None): + self.tasks = set(tasks) if tasks else set() + + def add_task(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None: + if is_async_callable(func): + coro = func(*args, **kwargs) + else: + coro = run_in_threadpool(func, *args, **kwargs) + task = asyncio.create_task(coro) + self.tasks.add(task) + task.add_done_callback(self.tasks.discard) + + def stop(self) -> None: + for task in self.tasks: + task.cancel() + self.tasks.clear() + + +bg_tasks = BackgroundTasks() diff --git a/main.py b/main.py index 2be5c37..3031213 100644 --- a/main.py +++ b/main.py @@ -23,7 +23,10 @@ from app.router import ( ) from app.router.redirect import redirect_router from app.scheduler.cache_scheduler import start_cache_scheduler, stop_cache_scheduler -from app.scheduler.database_cleanup_scheduler import start_database_cleanup_scheduler, stop_database_cleanup_scheduler +from app.scheduler.database_cleanup_scheduler import ( + start_database_cleanup_scheduler, + stop_database_cleanup_scheduler, +) from app.service.beatmap_download_service import download_service from app.service.calculate_all_user_rank import calculate_user_rank from app.service.create_banchobot import create_banchobot @@ -32,11 +35,12 @@ from app.service.email_queue import start_email_processor, stop_email_processor from app.service.geoip_scheduler import schedule_geoip_updates from app.service.init_geoip import init_geoip from app.service.load_achievements import load_achievements +from app.service.online_status_maintenance import schedule_online_status_maintenance from app.service.osu_rx_statistics import create_rx_statistics from app.service.recalculate import recalculate from app.service.redis_message_system import redis_message_system from app.service.stats_scheduler import start_stats_scheduler, stop_stats_scheduler -from app.service.online_status_maintenance import schedule_online_status_maintenance +from app.utils import bg_tasks # 检查 New Relic 配置文件是否存在,如果存在则初始化 New Relic newrelic_config_path = os.path.join(os.path.dirname(__file__), "newrelic.ini") @@ -52,9 +56,7 @@ if os.path.exists(newrelic_config_path): newrelic.agent.initialize(newrelic_config_path, environment) logger.info(f"[NewRelic] Enabled, environment: {environment}") except ImportError: - logger.warning( - "[NewRelic] Config file found but 'newrelic' package is not installed" - ) + logger.warning("[NewRelic] Config file found but 'newrelic' package is not installed") except Exception as e: logger.error(f"[NewRelic] Initialization failed: {e}") else: @@ -90,6 +92,7 @@ async def lifespan(app: FastAPI): load_achievements() # on shutdown yield + bg_tasks.stop() stop_scheduler() redis_message_system.stop() # 停止 Redis 消息系统 stop_stats_scheduler() # 停止统计调度器 @@ -179,14 +182,10 @@ async def http_exception_handler(requst: Request, exc: HTTPException): if settings.secret_key == "your_jwt_secret_here": - logger.warning( - "jwt_secret_key is unset. Your server is unsafe. " - "Use this command to generate: openssl rand -hex 32" - ) + logger.warning("jwt_secret_key is unset. Your server is unsafe. Use this command to generate: openssl rand -hex 32") if settings.osu_web_client_secret == "your_osu_web_client_secret_here": logger.warning( - "osu_web_client_secret is unset. Your server is unsafe. " - "Use this command to generate: openssl rand -hex 40" + "osu_web_client_secret is unset. Your server is unsafe. Use this command to generate: openssl rand -hex 40" ) if __name__ == "__main__": diff --git a/migrations/env.py b/migrations/env.py index 825cde6..7fbcb1b 100644 --- a/migrations/env.py +++ b/migrations/env.py @@ -60,9 +60,7 @@ def run_migrations_offline() -> None: def do_run_migrations(connection: Connection) -> None: - context.configure( - connection=connection, target_metadata=target_metadata, compare_type=True - ) + context.configure(connection=connection, target_metadata=target_metadata, compare_type=True) with context.begin_transaction(): context.run_migrations() diff --git a/migrations/versions/0f96348cdfd2_add_email_verification_tables.py b/migrations/versions/0f96348cdfd2_add_email_verification_tables.py index 0fd86c6..c71f03a 100644 --- a/migrations/versions/0f96348cdfd2_add_email_verification_tables.py +++ b/migrations/versions/0f96348cdfd2_add_email_verification_tables.py @@ -5,6 +5,7 @@ Revises: e96a649e18ca Create Date: 2025-08-22 07:26:59.129564 """ + from __future__ import annotations from collections.abc import Sequence @@ -38,7 +39,7 @@ def upgrade() -> None: sa.Index("ix_email_verifications_user_id", "user_id"), sa.Index("ix_email_verifications_email", "email"), ) - + # 创建登录会话表 op.create_table( "login_sessions", diff --git a/migrations/versions/198227d190b8_user_add_events.py b/migrations/versions/198227d190b8_user_add_events.py index 790e1cd..6f3d619 100644 --- a/migrations/versions/198227d190b8_user_add_events.py +++ b/migrations/versions/198227d190b8_user_add_events.py @@ -53,9 +53,7 @@ def upgrade() -> None: ), sa.PrimaryKeyConstraint("id"), ) - op.create_index( - op.f("ix_user_events_user_id"), "user_events", ["user_id"], unique=False - ) + op.create_index(op.f("ix_user_events_user_id"), "user_events", ["user_id"], unique=False) # ### end Alembic commands ### diff --git a/migrations/versions/19cdc9ce4dcb_gamemode_add_osurx_osupp.py b/migrations/versions/19cdc9ce4dcb_gamemode_add_osurx_osupp.py index 3cd82c7..e06d44b 100644 --- a/migrations/versions/19cdc9ce4dcb_gamemode_add_osurx_osupp.py +++ b/migrations/versions/19cdc9ce4dcb_gamemode_add_osurx_osupp.py @@ -26,51 +26,37 @@ def upgrade() -> None: op.alter_column( "lazer_users", "playmode", - type_=sa.Enum( - "OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode" - ), + type_=sa.Enum("OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode"), ) op.alter_column( "beatmaps", "mode", - type_=sa.Enum( - "OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode" - ), + type_=sa.Enum("OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode"), ) op.alter_column( "lazer_user_statistics", "mode", - type_=sa.Enum( - "OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode" - ), + type_=sa.Enum("OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode"), ) op.alter_column( "score_tokens", "ruleset_id", - type_=sa.Enum( - "OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode" - ), + type_=sa.Enum("OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode"), ) op.alter_column( "scores", "gamemode", - type_=sa.Enum( - "OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode" - ), + type_=sa.Enum("OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode"), ) op.alter_column( "best_scores", "gamemode", - type_=sa.Enum( - "OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode" - ), + type_=sa.Enum("OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode"), ) op.alter_column( "total_score_best_scores", "gamemode", - type_=sa.Enum( - "OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode" - ), + type_=sa.Enum("OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode"), ) # ### end Alembic commands ### diff --git a/migrations/versions/2dcd04d3f4dc_fix_user_login_log_table_name.py b/migrations/versions/2dcd04d3f4dc_fix_user_login_log_table_name.py index 1a3e46e..52cc20e 100644 --- a/migrations/versions/2dcd04d3f4dc_fix_user_login_log_table_name.py +++ b/migrations/versions/2dcd04d3f4dc_fix_user_login_log_table_name.py @@ -29,39 +29,19 @@ def upgrade() -> None: "user_login_log", sa.Column("id", sa.Integer(), nullable=False), sa.Column("user_id", sa.Integer(), nullable=False), - sa.Column( - "ip_address", sqlmodel.sql.sqltypes.AutoString(length=45), nullable=False - ), - sa.Column( - "user_agent", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True - ), + sa.Column("ip_address", sqlmodel.sql.sqltypes.AutoString(length=45), nullable=False), + sa.Column("user_agent", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True), sa.Column("login_time", sa.DateTime(), nullable=False), - sa.Column( - "country_code", sqlmodel.sql.sqltypes.AutoString(length=2), nullable=True - ), - sa.Column( - "country_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True - ), - sa.Column( - "city_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True - ), - sa.Column( - "latitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True - ), - sa.Column( - "longitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True - ), - sa.Column( - "time_zone", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=True - ), + sa.Column("country_code", sqlmodel.sql.sqltypes.AutoString(length=2), nullable=True), + sa.Column("country_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True), + sa.Column("city_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True), + sa.Column("latitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True), + sa.Column("longitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True), + sa.Column("time_zone", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=True), sa.Column("asn", sa.Integer(), nullable=True), - sa.Column( - "organization", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=True - ), + sa.Column("organization", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=True), sa.Column("login_success", sa.Boolean(), nullable=False), - sa.Column( - "login_method", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False - ), + sa.Column("login_method", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False), sa.Column("notes", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True), sa.PrimaryKeyConstraint("id"), ) @@ -71,9 +51,7 @@ def upgrade() -> None: ["ip_address"], unique=False, ) - op.create_index( - op.f("ix_user_login_log_user_id"), "user_login_log", ["user_id"], unique=False - ) + op.create_index(op.f("ix_user_login_log_user_id"), "user_login_log", ["user_id"], unique=False) op.drop_index(op.f("ix_userloginlog_ip_address"), table_name="userloginlog") op.drop_index(op.f("ix_userloginlog_user_id"), table_name="userloginlog") op.drop_table("userloginlog") @@ -111,12 +89,8 @@ def downgrade() -> None: mysql_default_charset="utf8mb4", mysql_engine="InnoDB", ) - op.create_index( - op.f("ix_userloginlog_user_id"), "userloginlog", ["user_id"], unique=False - ) - op.create_index( - op.f("ix_userloginlog_ip_address"), "userloginlog", ["ip_address"], unique=False - ) + op.create_index(op.f("ix_userloginlog_user_id"), "userloginlog", ["user_id"], unique=False) + op.create_index(op.f("ix_userloginlog_ip_address"), "userloginlog", ["ip_address"], unique=False) op.drop_index(op.f("ix_user_login_log_user_id"), table_name="user_login_log") op.drop_index(op.f("ix_user_login_log_ip_address"), table_name="user_login_log") op.drop_table("user_login_log") diff --git a/migrations/versions/3eef4794ded1_add_user_login_log_table.py b/migrations/versions/3eef4794ded1_add_user_login_log_table.py index 020d276..2228548 100644 --- a/migrations/versions/3eef4794ded1_add_user_login_log_table.py +++ b/migrations/versions/3eef4794ded1_add_user_login_log_table.py @@ -28,48 +28,24 @@ def upgrade() -> None: "userloginlog", sa.Column("id", sa.Integer(), nullable=False), sa.Column("user_id", sa.Integer(), nullable=False), - sa.Column( - "ip_address", sqlmodel.sql.sqltypes.AutoString(length=45), nullable=False - ), - sa.Column( - "user_agent", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True - ), + sa.Column("ip_address", sqlmodel.sql.sqltypes.AutoString(length=45), nullable=False), + sa.Column("user_agent", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True), sa.Column("login_time", sa.DateTime(), nullable=False), - sa.Column( - "country_code", sqlmodel.sql.sqltypes.AutoString(length=2), nullable=True - ), - sa.Column( - "country_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True - ), - sa.Column( - "city_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True - ), - sa.Column( - "latitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True - ), - sa.Column( - "longitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True - ), - sa.Column( - "time_zone", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=True - ), + sa.Column("country_code", sqlmodel.sql.sqltypes.AutoString(length=2), nullable=True), + sa.Column("country_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True), + sa.Column("city_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True), + sa.Column("latitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True), + sa.Column("longitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True), + sa.Column("time_zone", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=True), sa.Column("asn", sa.Integer(), nullable=True), - sa.Column( - "organization", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=True - ), + sa.Column("organization", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=True), sa.Column("login_success", sa.Boolean(), nullable=False), - sa.Column( - "login_method", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False - ), + sa.Column("login_method", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False), sa.Column("notes", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True), sa.PrimaryKeyConstraint("id"), ) - op.create_index( - op.f("ix_userloginlog_ip_address"), "userloginlog", ["ip_address"], unique=False - ) - op.create_index( - op.f("ix_userloginlog_user_id"), "userloginlog", ["user_id"], unique=False - ) + op.create_index(op.f("ix_userloginlog_ip_address"), "userloginlog", ["ip_address"], unique=False) + op.create_index(op.f("ix_userloginlog_user_id"), "userloginlog", ["user_id"], unique=False) # ### end Alembic commands ### diff --git a/migrations/versions/4f46c43d8601_notification_add_notification.py b/migrations/versions/4f46c43d8601_notification_add_notification.py index 15fbc57..f718c47 100644 --- a/migrations/versions/4f46c43d8601_notification_add_notification.py +++ b/migrations/versions/4f46c43d8601_notification_add_notification.py @@ -58,9 +58,7 @@ def upgrade() -> None: ), nullable=False, ), - sa.Column( - "category", sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False - ), + sa.Column("category", sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False), sa.Column("created_at", sa.DateTime(), nullable=True), sa.Column("object_type", sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column("object_id", sa.BigInteger(), nullable=True), @@ -68,16 +66,10 @@ def upgrade() -> None: sa.Column("details", sa.JSON(), nullable=True), sa.PrimaryKeyConstraint("id"), ) - op.create_index( - op.f("ix_notifications_category"), "notifications", ["category"], unique=False - ) + op.create_index(op.f("ix_notifications_category"), "notifications", ["category"], unique=False) op.create_index(op.f("ix_notifications_id"), "notifications", ["id"], unique=False) - op.create_index( - op.f("ix_notifications_name"), "notifications", ["name"], unique=False - ) - op.create_index( - op.f("ix_notifications_object_id"), "notifications", ["object_id"], unique=False - ) + op.create_index(op.f("ix_notifications_name"), "notifications", ["name"], unique=False) + op.create_index(op.f("ix_notifications_object_id"), "notifications", ["object_id"], unique=False) op.create_index( op.f("ix_notifications_object_type"), "notifications", @@ -106,9 +98,7 @@ def upgrade() -> None: ), sa.PrimaryKeyConstraint("id"), ) - op.create_index( - op.f("ix_user_notifications_id"), "user_notifications", ["id"], unique=False - ) + op.create_index(op.f("ix_user_notifications_id"), "user_notifications", ["id"], unique=False) op.create_index( op.f("ix_user_notifications_is_read"), "user_notifications", diff --git a/migrations/versions/59c9a0827de0_beatmap_add_indexes.py b/migrations/versions/59c9a0827de0_beatmap_add_indexes.py index b1a2fbb..9fab9c4 100644 --- a/migrations/versions/59c9a0827de0_beatmap_add_indexes.py +++ b/migrations/versions/59c9a0827de0_beatmap_add_indexes.py @@ -22,18 +22,14 @@ depends_on: str | Sequence[str] | None = None def upgrade() -> None: """Upgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### - op.create_index( - op.f("ix_beatmaps_beatmap_status"), "beatmaps", ["beatmap_status"], unique=False - ) + op.create_index(op.f("ix_beatmaps_beatmap_status"), "beatmaps", ["beatmap_status"], unique=False) op.create_index( op.f("ix_beatmaps_difficulty_rating"), "beatmaps", ["difficulty_rating"], unique=False, ) - op.create_index( - op.f("ix_beatmaps_last_updated"), "beatmaps", ["last_updated"], unique=False - ) + op.create_index(op.f("ix_beatmaps_last_updated"), "beatmaps", ["last_updated"], unique=False) op.create_index(op.f("ix_beatmaps_user_id"), "beatmaps", ["user_id"], unique=False) op.create_index(op.f("ix_beatmaps_version"), "beatmaps", ["version"], unique=False) op.create_index( @@ -54,48 +50,32 @@ def upgrade() -> None: ["beatmap_status"], unique=False, ) - op.create_index( - op.f("ix_beatmapsets_creator"), "beatmapsets", ["creator"], unique=False - ) + op.create_index(op.f("ix_beatmapsets_creator"), "beatmapsets", ["creator"], unique=False) op.create_index( op.f("ix_beatmapsets_last_updated"), "beatmapsets", ["last_updated"], unique=False, ) - op.create_index( - op.f("ix_beatmapsets_play_count"), "beatmapsets", ["play_count"], unique=False - ) - op.create_index( - op.f("ix_beatmapsets_ranked_date"), "beatmapsets", ["ranked_date"], unique=False - ) - op.create_index( - op.f("ix_beatmapsets_storyboard"), "beatmapsets", ["storyboard"], unique=False - ) + op.create_index(op.f("ix_beatmapsets_play_count"), "beatmapsets", ["play_count"], unique=False) + op.create_index(op.f("ix_beatmapsets_ranked_date"), "beatmapsets", ["ranked_date"], unique=False) + op.create_index(op.f("ix_beatmapsets_storyboard"), "beatmapsets", ["storyboard"], unique=False) op.create_index( op.f("ix_beatmapsets_submitted_date"), "beatmapsets", ["submitted_date"], unique=False, ) - op.create_index( - op.f("ix_beatmapsets_title"), "beatmapsets", ["title"], unique=False - ) + op.create_index(op.f("ix_beatmapsets_title"), "beatmapsets", ["title"], unique=False) op.create_index( op.f("ix_beatmapsets_title_unicode"), "beatmapsets", ["title_unicode"], unique=False, ) - op.create_index( - op.f("ix_beatmapsets_track_id"), "beatmapsets", ["track_id"], unique=False - ) - op.create_index( - op.f("ix_beatmapsets_user_id"), "beatmapsets", ["user_id"], unique=False - ) - op.create_index( - op.f("ix_beatmapsets_video"), "beatmapsets", ["video"], unique=False - ) + op.create_index(op.f("ix_beatmapsets_track_id"), "beatmapsets", ["track_id"], unique=False) + op.create_index(op.f("ix_beatmapsets_user_id"), "beatmapsets", ["user_id"], unique=False) + op.create_index(op.f("ix_beatmapsets_video"), "beatmapsets", ["video"], unique=False) # ### end Alembic commands ### diff --git a/migrations/versions/5b76689f6e4b_increase_the_length_limit_of_the_user_.py b/migrations/versions/5b76689f6e4b_increase_the_length_limit_of_the_user_.py index 8a26293..fdcec41 100644 --- a/migrations/versions/5b76689f6e4b_increase_the_length_limit_of_the_user_.py +++ b/migrations/versions/5b76689f6e4b_increase_the_length_limit_of_the_user_.py @@ -5,6 +5,7 @@ Revises: 65e7dc8d5905 Create Date: 2025-08-22 15:14:59.242274 """ + from __future__ import annotations from collections.abc import Sequence @@ -23,18 +24,24 @@ depends_on: str | Sequence[str] | None = None def upgrade() -> None: """Upgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### - op.alter_column("login_sessions", "user_agent", - existing_type=mysql.VARCHAR(length=255), - type_=sa.String(length=250), - existing_nullable=True) + op.alter_column( + "login_sessions", + "user_agent", + existing_type=mysql.VARCHAR(length=255), + type_=sa.String(length=250), + existing_nullable=True, + ) # ### end Alembic commands ### def downgrade() -> None: """Downgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### - op.alter_column("login_sessions", "user_agent", - existing_type=sa.String(length=250), - type_=mysql.VARCHAR(length=255), - existing_nullable=True) + op.alter_column( + "login_sessions", + "user_agent", + existing_type=sa.String(length=250), + type_=mysql.VARCHAR(length=255), + existing_nullable=True, + ) # ### end Alembic commands ### diff --git a/migrations/versions/65e7dc8d5905_team_add_team_request_table.py b/migrations/versions/65e7dc8d5905_team_add_team_request_table.py index 997b9e5..fda8536 100644 --- a/migrations/versions/65e7dc8d5905_team_add_team_request_table.py +++ b/migrations/versions/65e7dc8d5905_team_add_team_request_table.py @@ -40,9 +40,7 @@ def upgrade() -> None: ), sa.PrimaryKeyConstraint("user_id", "team_id"), ) - op.alter_column( - "team_members", "user_id", existing_type=mysql.BIGINT(), nullable=False - ) + op.alter_column("team_members", "user_id", existing_type=mysql.BIGINT(), nullable=False) op.drop_index(op.f("ix_team_members_id"), table_name="team_members") op.drop_column("team_members", "id") op.add_column( @@ -79,8 +77,6 @@ def downgrade() -> None: sa.Column("id", mysql.INTEGER(), autoincrement=True, nullable=False), ) op.create_index(op.f("ix_team_members_id"), "team_members", ["id"], unique=False) - op.alter_column( - "team_members", "user_id", existing_type=mysql.BIGINT(), nullable=True - ) + op.alter_column("team_members", "user_id", existing_type=mysql.BIGINT(), nullable=True) op.drop_table("team_requests") # ### end Alembic commands ### diff --git a/migrations/versions/7e9d5e012d37_auth_add_v1_keys_table.py b/migrations/versions/7e9d5e012d37_auth_add_v1_keys_table.py index b410ee7..e5bdc88 100644 --- a/migrations/versions/7e9d5e012d37_auth_add_v1_keys_table.py +++ b/migrations/versions/7e9d5e012d37_auth_add_v1_keys_table.py @@ -38,9 +38,7 @@ def upgrade() -> None: ) op.create_index(op.f("ix_v1_api_keys_key"), "v1_api_keys", ["key"], unique=False) op.create_index(op.f("ix_v1_api_keys_name"), "v1_api_keys", ["name"], unique=False) - op.create_index( - op.f("ix_v1_api_keys_owner_id"), "v1_api_keys", ["owner_id"], unique=False - ) + op.create_index(op.f("ix_v1_api_keys_owner_id"), "v1_api_keys", ["owner_id"], unique=False) # ### end Alembic commands ### diff --git a/migrations/versions/8bab62d764a5_statistics_remove_level_progress.py b/migrations/versions/8bab62d764a5_statistics_remove_level_progress.py index 5ae4215..d799722 100644 --- a/migrations/versions/8bab62d764a5_statistics_remove_level_progress.py +++ b/migrations/versions/8bab62d764a5_statistics_remove_level_progress.py @@ -40,9 +40,7 @@ def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.add_column( "lazer_user_statistics", - sa.Column( - "level_progress", mysql.INTEGER(), autoincrement=False, nullable=False - ), + sa.Column("level_progress", mysql.INTEGER(), autoincrement=False, nullable=False), ) op.alter_column( "lazer_user_statistics", diff --git a/migrations/versions/951a2188e691_score_add_rx_for_taiko_catch.py b/migrations/versions/951a2188e691_score_add_rx_for_taiko_catch.py index 3d8c252..e165b0f 100644 --- a/migrations/versions/951a2188e691_score_add_rx_for_taiko_catch.py +++ b/migrations/versions/951a2188e691_score_add_rx_for_taiko_catch.py @@ -26,73 +26,55 @@ def upgrade() -> None: op.alter_column( "beatmaps", "mode", - existing_type=mysql.ENUM( - "OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", "TAIKORX", "FRUITSRX" - ), + existing_type=mysql.ENUM("OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", "TAIKORX", "FRUITSRX"), nullable=False, ) op.alter_column( "best_scores", "gamemode", - existing_type=mysql.ENUM( - "OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", "TAIKORX", "FRUITSRX" - ), + existing_type=mysql.ENUM("OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", "TAIKORX", "FRUITSRX"), nullable=False, ) op.alter_column( "lazer_user_statistics", "mode", - existing_type=mysql.ENUM( - "OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", "TAIKORX", "FRUITSRX" - ), + existing_type=mysql.ENUM("OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", "TAIKORX", "FRUITSRX"), nullable=False, ) op.alter_column( "lazer_users", "playmode", - existing_type=mysql.ENUM( - "OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", "TAIKORX", "FRUITSRX" - ), + existing_type=mysql.ENUM("OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", "TAIKORX", "FRUITSRX"), nullable=False, ) op.alter_column( "score_tokens", "ruleset_id", - existing_type=mysql.ENUM( - "OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", "TAIKORX", "FRUITSRX" - ), + existing_type=mysql.ENUM("OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", "TAIKORX", "FRUITSRX"), nullable=False, ) op.alter_column( "scores", "gamemode", - existing_type=mysql.ENUM( - "OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", "TAIKORX", "FRUITSRX" - ), + existing_type=mysql.ENUM("OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", "TAIKORX", "FRUITSRX"), nullable=False, ) op.alter_column( "total_score_best_scores", "gamemode", - existing_type=mysql.ENUM( - "OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", "TAIKORX", "FRUITSRX" - ), + existing_type=mysql.ENUM("OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", "TAIKORX", "FRUITSRX"), nullable=False, ) op.alter_column( "rank_history", "mode", - existing_type=mysql.ENUM( - "OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", "TAIKORX", "FRUITSRX" - ), + existing_type=mysql.ENUM("OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", "TAIKORX", "FRUITSRX"), nullable=False, ) op.alter_column( "rank_top", "mode", - existing_type=mysql.ENUM( - "OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", "TAIKORX", "FRUITSRX" - ), + existing_type=mysql.ENUM("OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", "TAIKORX", "FRUITSRX"), nullable=False, ) # ### end Alembic commands ### @@ -146,17 +128,13 @@ def downgrade() -> None: op.alter_column( "rank_top", "mode", - existing_type=mysql.ENUM( - "OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", "TAIKORX", "FRUITSRX" - ), + existing_type=mysql.ENUM("OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", "TAIKORX", "FRUITSRX"), nullable=False, ) op.alter_column( "rank_top", "mode", - existing_type=mysql.ENUM( - "OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", "TAIKORX", "FRUITSRX" - ), + existing_type=mysql.ENUM("OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", "TAIKORX", "FRUITSRX"), nullable=False, ) # ### end Alembic commands ### diff --git a/migrations/versions/9aa4f7c06824_playlist_best_scores_remove_foreign_key_.py b/migrations/versions/9aa4f7c06824_playlist_best_scores_remove_foreign_key_.py index a7381e3..37cb1e0 100644 --- a/migrations/versions/9aa4f7c06824_playlist_best_scores_remove_foreign_key_.py +++ b/migrations/versions/9aa4f7c06824_playlist_best_scores_remove_foreign_key_.py @@ -22,9 +22,7 @@ depends_on: str | Sequence[str] | None = None def upgrade() -> None: """Upgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### - op.drop_constraint( - op.f("playlist_best_scores_ibfk_1"), "playlist_best_scores", type_="foreignkey" - ) + op.drop_constraint(op.f("playlist_best_scores_ibfk_1"), "playlist_best_scores", type_="foreignkey") # ### end Alembic commands ### diff --git a/migrations/versions/9f6b27e8ea51_add_table_banned_beatmaps.py b/migrations/versions/9f6b27e8ea51_add_table_banned_beatmaps.py index 202d517..9568433 100644 --- a/migrations/versions/9f6b27e8ea51_add_table_banned_beatmaps.py +++ b/migrations/versions/9f6b27e8ea51_add_table_banned_beatmaps.py @@ -35,9 +35,7 @@ def upgrade() -> None: ["beatmap_id"], unique=False, ) - op.create_index( - op.f("ix_banned_beatmaps_id"), "banned_beatmaps", ["id"], unique=False - ) + op.create_index(op.f("ix_banned_beatmaps_id"), "banned_beatmaps", ["id"], unique=False) # ### end Alembic commands ### diff --git a/migrations/versions/a8669ba11e96_auth_support_custom_client.py b/migrations/versions/a8669ba11e96_auth_support_custom_client.py index 0a765a2..91b0367 100644 --- a/migrations/versions/a8669ba11e96_auth_support_custom_client.py +++ b/migrations/versions/a8669ba11e96_auth_support_custom_client.py @@ -36,22 +36,16 @@ def upgrade() -> None: ), sa.PrimaryKeyConstraint("client_id"), ) - op.create_index( - op.f("ix_oauth_clients_client_id"), "oauth_clients", ["client_id"], unique=False - ) + op.create_index(op.f("ix_oauth_clients_client_id"), "oauth_clients", ["client_id"], unique=False) op.create_index( op.f("ix_oauth_clients_client_secret"), "oauth_clients", ["client_secret"], unique=False, ) - op.create_index( - op.f("ix_oauth_clients_owner_id"), "oauth_clients", ["owner_id"], unique=False - ) + op.create_index(op.f("ix_oauth_clients_owner_id"), "oauth_clients", ["owner_id"], unique=False) op.add_column("oauth_tokens", sa.Column("client_id", sa.Integer(), nullable=False)) - op.create_index( - op.f("ix_oauth_tokens_client_id"), "oauth_tokens", ["client_id"], unique=False - ) + op.create_index(op.f("ix_oauth_tokens_client_id"), "oauth_tokens", ["client_id"], unique=False) # ### end Alembic commands ### diff --git a/migrations/versions/aa582c13f905_count_add_replays_watched_counts.py b/migrations/versions/aa582c13f905_count_add_replays_watched_counts.py index e470345..33c727c 100644 --- a/migrations/versions/aa582c13f905_count_add_replays_watched_counts.py +++ b/migrations/versions/aa582c13f905_count_add_replays_watched_counts.py @@ -77,14 +77,8 @@ def downgrade() -> None: "replays_watched_counts", type_="foreignkey", ) - op.drop_index( - op.f("ix_replays_watched_counts_year"), table_name="replays_watched_counts" - ) - op.drop_index( - op.f("ix_replays_watched_counts_user_id"), table_name="replays_watched_counts" - ) - op.drop_index( - op.f("ix_replays_watched_counts_month"), table_name="replays_watched_counts" - ) + op.drop_index(op.f("ix_replays_watched_counts_year"), table_name="replays_watched_counts") + op.drop_index(op.f("ix_replays_watched_counts_user_id"), table_name="replays_watched_counts") + op.drop_index(op.f("ix_replays_watched_counts_month"), table_name="replays_watched_counts") op.drop_table("replays_watched_counts") # ### end Alembic commands ### diff --git a/migrations/versions/b6a304d96a2d_user_support_rank.py b/migrations/versions/b6a304d96a2d_user_support_rank.py index 9a93489..347e238 100644 --- a/migrations/versions/b6a304d96a2d_user_support_rank.py +++ b/migrations/versions/b6a304d96a2d_user_support_rank.py @@ -30,9 +30,7 @@ def upgrade() -> None: sa.Column("user_id", sa.BigInteger(), nullable=True), sa.Column( "mode", - sa.Enum( - "OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode" - ), + sa.Enum("OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode"), nullable=False, ), sa.Column("rank", sa.Integer(), nullable=False), @@ -43,21 +41,15 @@ def upgrade() -> None: ), sa.PrimaryKeyConstraint("id"), ) - op.create_index( - op.f("ix_rank_history_date"), "rank_history", ["date"], unique=False - ) - op.create_index( - op.f("ix_rank_history_user_id"), "rank_history", ["user_id"], unique=False - ) + op.create_index(op.f("ix_rank_history_date"), "rank_history", ["date"], unique=False) + op.create_index(op.f("ix_rank_history_user_id"), "rank_history", ["user_id"], unique=False) op.create_table( "rank_top", sa.Column("id", sa.BigInteger(), nullable=False), sa.Column("user_id", sa.BigInteger(), nullable=True), sa.Column( "mode", - sa.Enum( - "OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode" - ), + sa.Enum("OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP", name="gamemode"), nullable=False, ), sa.Column("rank", sa.Integer(), nullable=False), @@ -84,9 +76,7 @@ def upgrade() -> None: ) op.drop_column("lazer_user_statistics", "country_rank") op.drop_column("lazer_user_statistics", "global_rank") - op.create_index( - op.f("ix_oauth_clients_name"), "oauth_clients", ["name"], unique=False - ) + op.create_index(op.f("ix_oauth_clients_name"), "oauth_clients", ["name"], unique=False) # ### end Alembic commands ### @@ -102,12 +92,8 @@ def downgrade() -> None: "lazer_user_statistics", sa.Column("country_rank", mysql.INTEGER(), autoincrement=False, nullable=True), ) - op.drop_index( - op.f("ix_lazer_user_statistics_pp"), table_name="lazer_user_statistics" - ) - op.drop_index( - op.f("ix_lazer_user_statistics_mode"), table_name="lazer_user_statistics" - ) + op.drop_index(op.f("ix_lazer_user_statistics_pp"), table_name="lazer_user_statistics") + op.drop_index(op.f("ix_lazer_user_statistics_mode"), table_name="lazer_user_statistics") op.drop_table("rank_top") op.drop_table("rank_history") # ### end Alembic commands ### diff --git a/migrations/versions/ce29ef0a5674_beatmap_make_max_combo_nullable.py b/migrations/versions/ce29ef0a5674_beatmap_make_max_combo_nullable.py index b651a28..87b6613 100644 --- a/migrations/versions/ce29ef0a5674_beatmap_make_max_combo_nullable.py +++ b/migrations/versions/ce29ef0a5674_beatmap_make_max_combo_nullable.py @@ -23,16 +23,12 @@ depends_on: str | Sequence[str] | None = None def upgrade() -> None: """Upgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### - op.alter_column( - "beatmaps", "max_combo", existing_type=mysql.INTEGER(), nullable=True - ) + op.alter_column("beatmaps", "max_combo", existing_type=mysql.INTEGER(), nullable=True) # ### end Alembic commands ### def downgrade() -> None: """Downgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### - op.alter_column( - "beatmaps", "max_combo", existing_type=mysql.INTEGER(), nullable=False - ) + op.alter_column("beatmaps", "max_combo", existing_type=mysql.INTEGER(), nullable=False) # ### end Alembic commands ### diff --git a/migrations/versions/d103d442dc24_add_password_reset_table.py b/migrations/versions/d103d442dc24_add_password_reset_table.py index b348afa..9e15a12 100644 --- a/migrations/versions/d103d442dc24_add_password_reset_table.py +++ b/migrations/versions/d103d442dc24_add_password_reset_table.py @@ -5,6 +5,7 @@ Revises: 0f96348cdfd2 Create Date: 2025-08-22 08:27:58.468119 """ + from __future__ import annotations from collections.abc import Sequence @@ -12,7 +13,6 @@ from collections.abc import Sequence from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import mysql -import sqlmodel # revision identifiers, used by Alembic. revision: str = "d103d442dc24" @@ -24,179 +24,271 @@ depends_on: str | Sequence[str] | None = None def upgrade() -> None: """Upgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### - + # 安全创建 password_resets 表(如果不存在) try: - op.create_table("password_resets", - sa.Column("id", sa.Integer(), nullable=False), - sa.Column("user_id", sa.BigInteger(), nullable=False), - sa.Column("email", sa.String(255), nullable=False), - sa.Column("reset_code", sa.String(8), nullable=False), - sa.Column("created_at", sa.DateTime(), nullable=False), - sa.Column("expires_at", sa.DateTime(), nullable=False), - sa.Column("is_used", sa.Boolean(), nullable=False), - sa.Column("used_at", sa.DateTime(), nullable=True), - sa.Column("ip_address", sa.String(255), nullable=True), - sa.Column("user_agent", sa.String(255), nullable=True), - sa.ForeignKeyConstraint(["user_id"], ["lazer_users.id"], ), - sa.PrimaryKeyConstraint("id") + op.create_table( + "password_resets", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.BigInteger(), nullable=False), + sa.Column("email", sa.String(255), nullable=False), + sa.Column("reset_code", sa.String(8), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("expires_at", sa.DateTime(), nullable=False), + sa.Column("is_used", sa.Boolean(), nullable=False), + sa.Column("used_at", sa.DateTime(), nullable=True), + sa.Column("ip_address", sa.String(255), nullable=True), + sa.Column("user_agent", sa.String(255), nullable=True), + sa.ForeignKeyConstraint( + ["user_id"], + ["lazer_users.id"], + ), + sa.PrimaryKeyConstraint("id"), ) except Exception: # 如果表已存在,继续执行 pass - + # 安全创建索引 try: op.create_index(op.f("ix_password_resets_email"), "password_resets", ["email"], unique=False) except Exception: pass - + try: - op.create_index(op.f("ix_password_resets_user_id"), "password_resets", ["user_id"], unique=False) + op.create_index( + op.f("ix_password_resets_user_id"), + "password_resets", + ["user_id"], + unique=False, + ) except Exception: pass - + # 安全删除 two_factor_auth 表 - 先删除表(这会自动删除外键约束和索引) try: op.drop_table("two_factor_auth") except Exception: # 如果表不存在,继续执行 pass - + # 安全删除 user_ip_history 表 try: op.drop_table("user_ip_history") except Exception: # 如果表不存在,继续执行 pass - + # 安全删除 session_verification 表 try: op.drop_table("session_verification") except Exception: # 如果表不存在,继续执行 pass - op.alter_column("beatmapsets", "nsfw", - existing_type=mysql.TINYINT(display_width=1), - nullable=True) - op.alter_column("beatmapsets", "spotlight", - existing_type=mysql.TINYINT(display_width=1), - nullable=True) - op.alter_column("beatmapsets", "video", - existing_type=mysql.TINYINT(display_width=1), - nullable=True) - op.alter_column("beatmapsets", "can_be_hyped", - existing_type=mysql.TINYINT(display_width=1), - nullable=True) - op.alter_column("beatmapsets", "discussion_locked", - existing_type=mysql.TINYINT(display_width=1), - nullable=True) - op.alter_column("beatmapsets", "storyboard", - existing_type=mysql.TINYINT(display_width=1), - nullable=True) - op.alter_column("beatmapsets", "download_disabled", - existing_type=mysql.TINYINT(display_width=1), - nullable=True) - + op.alter_column( + "beatmapsets", + "nsfw", + existing_type=mysql.TINYINT(display_width=1), + nullable=True, + ) + op.alter_column( + "beatmapsets", + "spotlight", + existing_type=mysql.TINYINT(display_width=1), + nullable=True, + ) + op.alter_column( + "beatmapsets", + "video", + existing_type=mysql.TINYINT(display_width=1), + nullable=True, + ) + op.alter_column( + "beatmapsets", + "can_be_hyped", + existing_type=mysql.TINYINT(display_width=1), + nullable=True, + ) + op.alter_column( + "beatmapsets", + "discussion_locked", + existing_type=mysql.TINYINT(display_width=1), + nullable=True, + ) + op.alter_column( + "beatmapsets", + "storyboard", + existing_type=mysql.TINYINT(display_width=1), + nullable=True, + ) + op.alter_column( + "beatmapsets", + "download_disabled", + existing_type=mysql.TINYINT(display_width=1), + nullable=True, + ) + # 安全删除索引 try: op.drop_index(op.f("uq_user_achievement"), table_name="lazer_user_achievements") except Exception: # 如果索引不存在或有外键约束,继续执行 pass - - op.alter_column("scores", "has_replay", - existing_type=mysql.TINYINT(display_width=1), - nullable=True) - op.alter_column("scores", "passed", - existing_type=mysql.TINYINT(display_width=1), - nullable=True) - op.alter_column("scores", "preserve", - existing_type=mysql.TINYINT(display_width=1), - nullable=True) + + op.alter_column( + "scores", + "has_replay", + existing_type=mysql.TINYINT(display_width=1), + nullable=True, + ) + op.alter_column("scores", "passed", existing_type=mysql.TINYINT(display_width=1), nullable=True) + op.alter_column( + "scores", + "preserve", + existing_type=mysql.TINYINT(display_width=1), + nullable=True, + ) # ### end Alembic commands ### def downgrade() -> None: """Downgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### - op.alter_column("scores", "preserve", - existing_type=mysql.TINYINT(display_width=1), - nullable=False) - op.alter_column("scores", "passed", - existing_type=mysql.TINYINT(display_width=1), - nullable=False) - op.alter_column("scores", "has_replay", - existing_type=mysql.TINYINT(display_width=1), - nullable=False) - op.create_index(op.f("uq_user_achievement"), "lazer_user_achievements", ["user_id", "achievement_id"], unique=True) - op.alter_column("beatmapsets", "download_disabled", - existing_type=mysql.TINYINT(display_width=1), - nullable=False) - op.alter_column("beatmapsets", "storyboard", - existing_type=mysql.TINYINT(display_width=1), - nullable=False) - op.alter_column("beatmapsets", "discussion_locked", - existing_type=mysql.TINYINT(display_width=1), - nullable=False) - op.alter_column("beatmapsets", "can_be_hyped", - existing_type=mysql.TINYINT(display_width=1), - nullable=False) - op.alter_column("beatmapsets", "video", - existing_type=mysql.TINYINT(display_width=1), - nullable=False) - op.alter_column("beatmapsets", "spotlight", - existing_type=mysql.TINYINT(display_width=1), - nullable=False) - op.alter_column("beatmapsets", "nsfw", - existing_type=mysql.TINYINT(display_width=1), - nullable=False) - op.create_table("session_verification", - sa.Column("id", mysql.BIGINT(), autoincrement=True, nullable=False), - sa.Column("user_id", mysql.BIGINT(), autoincrement=False, nullable=False), - sa.Column("session_id", mysql.VARCHAR(length=255), nullable=False), - sa.Column("ip_address", mysql.VARCHAR(length=45), nullable=False), - sa.Column("is_verified", mysql.TINYINT(display_width=1), autoincrement=False, nullable=False), - sa.Column("verified_at", mysql.DATETIME(), nullable=True), - sa.Column("expires_at", mysql.DATETIME(), nullable=False), - sa.ForeignKeyConstraint(["user_id"], ["lazer_users.id"], name=op.f("session_verification_ibfk_1")), - sa.PrimaryKeyConstraint("id"), - mysql_collate="utf8mb4_0900_ai_ci", - mysql_default_charset="utf8mb4", - mysql_engine="InnoDB" + op.alter_column( + "scores", + "preserve", + existing_type=mysql.TINYINT(display_width=1), + nullable=False, ) - op.create_index(op.f("ix_session_verification_user_id"), "session_verification", ["user_id"], unique=False) - op.create_table("user_ip_history", - sa.Column("id", mysql.BIGINT(), autoincrement=True, nullable=False), - sa.Column("user_id", mysql.BIGINT(), autoincrement=False, nullable=False), - sa.Column("ip_address", mysql.VARCHAR(length=45), nullable=False), - sa.Column("first_seen", mysql.DATETIME(), nullable=False), - sa.Column("last_seen", mysql.DATETIME(), nullable=False), - sa.Column("usage_count", mysql.INTEGER(), autoincrement=False, nullable=False), - sa.Column("is_trusted", mysql.TINYINT(display_width=1), autoincrement=False, nullable=False), - sa.ForeignKeyConstraint(["user_id"], ["lazer_users.id"], name=op.f("user_ip_history_ibfk_1")), - sa.PrimaryKeyConstraint("id"), - mysql_collate="utf8mb4_0900_ai_ci", - mysql_default_charset="utf8mb4", - mysql_engine="InnoDB" + op.alter_column("scores", "passed", existing_type=mysql.TINYINT(display_width=1), nullable=False) + op.alter_column( + "scores", + "has_replay", + existing_type=mysql.TINYINT(display_width=1), + nullable=False, + ) + op.create_index( + op.f("uq_user_achievement"), + "lazer_user_achievements", + ["user_id", "achievement_id"], + unique=True, + ) + op.alter_column( + "beatmapsets", + "download_disabled", + existing_type=mysql.TINYINT(display_width=1), + nullable=False, + ) + op.alter_column( + "beatmapsets", + "storyboard", + existing_type=mysql.TINYINT(display_width=1), + nullable=False, + ) + op.alter_column( + "beatmapsets", + "discussion_locked", + existing_type=mysql.TINYINT(display_width=1), + nullable=False, + ) + op.alter_column( + "beatmapsets", + "can_be_hyped", + existing_type=mysql.TINYINT(display_width=1), + nullable=False, + ) + op.alter_column( + "beatmapsets", + "video", + existing_type=mysql.TINYINT(display_width=1), + nullable=False, + ) + op.alter_column( + "beatmapsets", + "spotlight", + existing_type=mysql.TINYINT(display_width=1), + nullable=False, + ) + op.alter_column( + "beatmapsets", + "nsfw", + existing_type=mysql.TINYINT(display_width=1), + nullable=False, + ) + op.create_table( + "session_verification", + sa.Column("id", mysql.BIGINT(), autoincrement=True, nullable=False), + sa.Column("user_id", mysql.BIGINT(), autoincrement=False, nullable=False), + sa.Column("session_id", mysql.VARCHAR(length=255), nullable=False), + sa.Column("ip_address", mysql.VARCHAR(length=45), nullable=False), + sa.Column( + "is_verified", + mysql.TINYINT(display_width=1), + autoincrement=False, + nullable=False, + ), + sa.Column("verified_at", mysql.DATETIME(), nullable=True), + sa.Column("expires_at", mysql.DATETIME(), nullable=False), + sa.ForeignKeyConstraint(["user_id"], ["lazer_users.id"], name=op.f("session_verification_ibfk_1")), + sa.PrimaryKeyConstraint("id"), + mysql_collate="utf8mb4_0900_ai_ci", + mysql_default_charset="utf8mb4", + mysql_engine="InnoDB", + ) + op.create_index( + op.f("ix_session_verification_user_id"), + "session_verification", + ["user_id"], + unique=False, + ) + op.create_table( + "user_ip_history", + sa.Column("id", mysql.BIGINT(), autoincrement=True, nullable=False), + sa.Column("user_id", mysql.BIGINT(), autoincrement=False, nullable=False), + sa.Column("ip_address", mysql.VARCHAR(length=45), nullable=False), + sa.Column("first_seen", mysql.DATETIME(), nullable=False), + sa.Column("last_seen", mysql.DATETIME(), nullable=False), + sa.Column("usage_count", mysql.INTEGER(), autoincrement=False, nullable=False), + sa.Column( + "is_trusted", + mysql.TINYINT(display_width=1), + autoincrement=False, + nullable=False, + ), + sa.ForeignKeyConstraint(["user_id"], ["lazer_users.id"], name=op.f("user_ip_history_ibfk_1")), + sa.PrimaryKeyConstraint("id"), + mysql_collate="utf8mb4_0900_ai_ci", + mysql_default_charset="utf8mb4", + mysql_engine="InnoDB", ) op.create_index(op.f("ix_user_ip_history_user_id"), "user_ip_history", ["user_id"], unique=False) - op.create_index(op.f("ix_user_ip_history_ip_address"), "user_ip_history", ["ip_address"], unique=False) + op.create_index( + op.f("ix_user_ip_history_ip_address"), + "user_ip_history", + ["ip_address"], + unique=False, + ) op.create_index(op.f("ix_user_ip_history_id"), "user_ip_history", ["id"], unique=False) - op.create_table("two_factor_auth", - sa.Column("id", mysql.BIGINT(), autoincrement=True, nullable=False), - sa.Column("user_id", mysql.BIGINT(), autoincrement=False, nullable=False), - sa.Column("verification_code", mysql.VARCHAR(length=8), nullable=False), - sa.Column("expires_at", mysql.DATETIME(), nullable=False), - sa.Column("is_used", mysql.TINYINT(display_width=1), autoincrement=False, nullable=False), - sa.Column("ip_address", mysql.VARCHAR(length=45), nullable=False), - sa.Column("trigger_reason", mysql.VARCHAR(length=50), nullable=False), - sa.Column("verified_at", mysql.DATETIME(), nullable=True), - sa.ForeignKeyConstraint(["user_id"], ["lazer_users.id"], name=op.f("two_factor_auth_ibfk_1")), - sa.PrimaryKeyConstraint("id"), - mysql_collate="utf8mb4_0900_ai_ci", - mysql_default_charset="utf8mb4", - mysql_engine="InnoDB" + op.create_table( + "two_factor_auth", + sa.Column("id", mysql.BIGINT(), autoincrement=True, nullable=False), + sa.Column("user_id", mysql.BIGINT(), autoincrement=False, nullable=False), + sa.Column("verification_code", mysql.VARCHAR(length=8), nullable=False), + sa.Column("expires_at", mysql.DATETIME(), nullable=False), + sa.Column( + "is_used", + mysql.TINYINT(display_width=1), + autoincrement=False, + nullable=False, + ), + sa.Column("ip_address", mysql.VARCHAR(length=45), nullable=False), + sa.Column("trigger_reason", mysql.VARCHAR(length=50), nullable=False), + sa.Column("verified_at", mysql.DATETIME(), nullable=True), + sa.ForeignKeyConstraint(["user_id"], ["lazer_users.id"], name=op.f("two_factor_auth_ibfk_1")), + sa.PrimaryKeyConstraint("id"), + mysql_collate="utf8mb4_0900_ai_ci", + mysql_default_charset="utf8mb4", + mysql_engine="InnoDB", ) op.create_index(op.f("ix_two_factor_auth_user_id"), "two_factor_auth", ["user_id"], unique=False) op.create_index(op.f("ix_two_factor_auth_id"), "two_factor_auth", ["id"], unique=False) diff --git a/migrations/versions/dd33d89aa2c2_chat_add_chat.py b/migrations/versions/dd33d89aa2c2_chat_add_chat.py index 5c95336..973c380 100644 --- a/migrations/versions/dd33d89aa2c2_chat_add_chat.py +++ b/migrations/versions/dd33d89aa2c2_chat_add_chat.py @@ -61,12 +61,8 @@ def upgrade() -> None: ["description"], unique=False, ) - op.create_index( - op.f("ix_chat_channels_name"), "chat_channels", ["name"], unique=False - ) - op.create_index( - op.f("ix_chat_channels_type"), "chat_channels", ["type"], unique=False - ) + op.create_index(op.f("ix_chat_channels_name"), "chat_channels", ["name"], unique=False) + op.create_index(op.f("ix_chat_channels_type"), "chat_channels", ["type"], unique=False) op.create_table( "chat_messages", sa.Column("channel_id", sa.Integer(), nullable=False), @@ -102,15 +98,9 @@ def upgrade() -> None: ["message_id"], unique=False, ) - op.create_index( - op.f("ix_chat_messages_sender_id"), "chat_messages", ["sender_id"], unique=False - ) - op.create_index( - op.f("ix_chat_messages_timestamp"), "chat_messages", ["timestamp"], unique=False - ) - op.create_index( - op.f("ix_chat_messages_type"), "chat_messages", ["type"], unique=False - ) + op.create_index(op.f("ix_chat_messages_sender_id"), "chat_messages", ["sender_id"], unique=False) + op.create_index(op.f("ix_chat_messages_timestamp"), "chat_messages", ["timestamp"], unique=False) + op.create_index(op.f("ix_chat_messages_type"), "chat_messages", ["type"], unique=False) op.create_table( "chat_silence_users", sa.Column("id", sa.Integer(), nullable=False), diff --git a/migrations/versions/df9f725a077c_room_add_channel_id.py b/migrations/versions/df9f725a077c_room_add_channel_id.py index 85e7d93..1125cba 100644 --- a/migrations/versions/df9f725a077c_room_add_channel_id.py +++ b/migrations/versions/df9f725a077c_room_add_channel_id.py @@ -24,15 +24,9 @@ depends_on: str | Sequence[str] | None = None def upgrade() -> None: """Upgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### - op.alter_column( - "chat_silence_users", "user_id", existing_type=mysql.BIGINT(), nullable=True - ) - op.alter_column( - "chat_silence_users", "banned_at", existing_type=mysql.DATETIME(), nullable=True - ) - op.create_index( - op.f("ix_chat_silence_users_id"), "chat_silence_users", ["id"], unique=False - ) + op.alter_column("chat_silence_users", "user_id", existing_type=mysql.BIGINT(), nullable=True) + op.alter_column("chat_silence_users", "banned_at", existing_type=mysql.DATETIME(), nullable=True) + op.create_index(op.f("ix_chat_silence_users_id"), "chat_silence_users", ["id"], unique=False) op.add_column("rooms", sa.Column("channel_id", sa.Integer(), nullable=True)) # ### end Alembic commands ### @@ -48,7 +42,5 @@ def downgrade() -> None: existing_type=mysql.DATETIME(), nullable=False, ) - op.alter_column( - "chat_silence_users", "user_id", existing_type=mysql.BIGINT(), nullable=False - ) + op.alter_column("chat_silence_users", "user_id", existing_type=mysql.BIGINT(), nullable=False) # ### end Alembic commands ### diff --git a/migrations/versions/e96a649e18ca_achievement_remove_primary_key_.py b/migrations/versions/e96a649e18ca_achievement_remove_primary_key_.py index b55a663..bd02c7e 100644 --- a/migrations/versions/e96a649e18ca_achievement_remove_primary_key_.py +++ b/migrations/versions/e96a649e18ca_achievement_remove_primary_key_.py @@ -23,20 +23,14 @@ depends_on: str | Sequence[str] | None = None def upgrade() -> None: """Upgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### - op.drop_constraint( - constraint_name="PRIMARY", table_name="lazer_user_achievements", type_="primary" - ) - op.create_primary_key( - "pk_lazer_user_achievements", "lazer_user_achievements", ["id"] - ) + op.drop_constraint(constraint_name="PRIMARY", table_name="lazer_user_achievements", type_="primary") + op.create_primary_key("pk_lazer_user_achievements", "lazer_user_achievements", ["id"]) op.create_index( "ix_lazer_user_achievements_achievement_id", "lazer_user_achievements", ["achievement_id"], ) - op.create_unique_constraint( - "uq_user_achievement", "lazer_user_achievements", ["user_id", "achievement_id"] - ) + op.create_unique_constraint("uq_user_achievement", "lazer_user_achievements", ["user_id", "achievement_id"]) op.alter_column( "lazer_user_achievements", "id", @@ -63,14 +57,8 @@ def downgrade() -> None: nullable=False, existing_server_default=None, ) - op.drop_constraint( - constraint_name="PRIMARY", table_name="lazer_user_achievements", type_="primary" - ) + op.drop_constraint(constraint_name="PRIMARY", table_name="lazer_user_achievements", type_="primary") op.drop_constraint("uq_user_achievement", "lazer_user_achievements", type_="unique") - op.drop_index( - "ix_lazer_user_achievements_achievement_id", "lazer_user_achievements" - ) - op.create_primary_key( - "PRIMARY", "lazer_user_achievements", ["id", "achievement_id"] - ) + op.drop_index("ix_lazer_user_achievements_achievement_id", "lazer_user_achievements") + op.create_primary_key("PRIMARY", "lazer_user_achievements", ["id", "achievement_id"]) # ### end Alembic commands ### diff --git a/migrations/versions/f785165a5c0b_convert_event_event_payload_from_str_to_.py b/migrations/versions/f785165a5c0b_convert_event_event_payload_from_str_to_.py index 48827be..a70efbb 100644 --- a/migrations/versions/f785165a5c0b_convert_event_event_payload_from_str_to_.py +++ b/migrations/versions/f785165a5c0b_convert_event_event_payload_from_str_to_.py @@ -48,9 +48,7 @@ def upgrade() -> None: existing_type=mysql.ENUM("OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP"), nullable=False, ) - op.alter_column( - "monthly_playcounts", "count", existing_type=mysql.INTEGER(), nullable=False - ) + op.alter_column("monthly_playcounts", "count", existing_type=mysql.INTEGER(), nullable=False) op.alter_column( "score_tokens", "ruleset_id", @@ -107,9 +105,7 @@ def downgrade() -> None: existing_type=mysql.ENUM("OSU", "TAIKO", "FRUITS", "MANIA", "OSURX", "OSUAP"), nullable=True, ) - op.alter_column( - "monthly_playcounts", "count", existing_type=mysql.INTEGER(), nullable=True - ) + op.alter_column("monthly_playcounts", "count", existing_type=mysql.INTEGER(), nullable=True) op.alter_column( "lazer_users", "playmode", diff --git a/migrations/versions/fdb3822a30ba_init.py b/migrations/versions/fdb3822a30ba_init.py index 55ace06..e15b293 100644 --- a/migrations/versions/fdb3822a30ba_init.py +++ b/migrations/versions/fdb3822a30ba_init.py @@ -114,15 +114,11 @@ def upgrade() -> None: sa.Column("nominations_current", sa.Integer(), nullable=False), sa.Column("hype_current", sa.Integer(), nullable=False), sa.Column("hype_required", sa.Integer(), nullable=False), - sa.Column( - "availability_info", sqlmodel.sql.sqltypes.AutoString(), nullable=True - ), + sa.Column("availability_info", sqlmodel.sql.sqltypes.AutoString(), nullable=True), sa.Column("download_disabled", sa.Boolean(), nullable=False), sa.PrimaryKeyConstraint("id"), ) - op.create_index( - op.f("ix_beatmapsets_artist"), "beatmapsets", ["artist"], unique=False - ) + op.create_index(op.f("ix_beatmapsets_artist"), "beatmapsets", ["artist"], unique=False) op.create_index( op.f("ix_beatmapsets_artist_unicode"), "beatmapsets", @@ -133,18 +129,14 @@ def upgrade() -> None: op.create_table( "lazer_users", sa.Column("avatar_url", sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column( - "country_code", sqlmodel.sql.sqltypes.AutoString(length=2), nullable=False - ), + sa.Column("country_code", sqlmodel.sql.sqltypes.AutoString(length=2), nullable=False), sa.Column("is_active", sa.Boolean(), nullable=False), sa.Column("is_bot", sa.Boolean(), nullable=False), sa.Column("is_supporter", sa.Boolean(), nullable=False), sa.Column("last_visit", sa.DateTime(timezone=True), nullable=True), sa.Column("pm_friends_only", sa.Boolean(), nullable=False), sa.Column("profile_colour", sqlmodel.sql.sqltypes.AutoString(), nullable=True), - sa.Column( - "username", sqlmodel.sql.sqltypes.AutoString(length=32), nullable=False - ), + sa.Column("username", sqlmodel.sql.sqltypes.AutoString(length=32), nullable=False), sa.Column("page", sa.JSON(), nullable=True), sa.Column("previous_usernames", sa.JSON(), nullable=True), sa.Column("support_level", sa.Integer(), nullable=False), @@ -179,13 +171,9 @@ def upgrade() -> None: sa.Column("is_qat", sa.Boolean(), nullable=False), sa.Column("is_bng", sa.Boolean(), nullable=False), sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False), - sa.Column( - "email", sqlmodel.sql.sqltypes.AutoString(length=254), nullable=False - ), + sa.Column("email", sqlmodel.sql.sqltypes.AutoString(length=254), nullable=False), sa.Column("priv", sa.Integer(), nullable=False), - sa.Column( - "pw_bcrypt", sqlmodel.sql.sqltypes.AutoString(length=60), nullable=False - ), + sa.Column("pw_bcrypt", sqlmodel.sql.sqltypes.AutoString(length=60), nullable=False), sa.Column("silence_end_at", sa.DateTime(timezone=True), nullable=True), sa.Column("donor_end_at", sa.DateTime(timezone=True), nullable=True), sa.PrimaryKeyConstraint("id"), @@ -198,19 +186,13 @@ def upgrade() -> None: ) op.create_index(op.f("ix_lazer_users_email"), "lazer_users", ["email"], unique=True) op.create_index(op.f("ix_lazer_users_id"), "lazer_users", ["id"], unique=False) - op.create_index( - op.f("ix_lazer_users_username"), "lazer_users", ["username"], unique=True - ) + op.create_index(op.f("ix_lazer_users_username"), "lazer_users", ["username"], unique=True) op.create_table( "teams", sa.Column("id", sa.Integer(), nullable=False), sa.Column("name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=False), - sa.Column( - "short_name", sqlmodel.sql.sqltypes.AutoString(length=10), nullable=False - ), - sa.Column( - "flag_url", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True - ), + sa.Column("short_name", sqlmodel.sql.sqltypes.AutoString(length=10), nullable=False), + sa.Column("flag_url", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True), sa.Column("created_at", sa.DateTime(), nullable=True), sa.PrimaryKeyConstraint("id"), ) @@ -263,12 +245,8 @@ def upgrade() -> None: ), sa.PrimaryKeyConstraint("id"), ) - op.create_index( - op.f("ix_beatmaps_beatmapset_id"), "beatmaps", ["beatmapset_id"], unique=False - ) - op.create_index( - op.f("ix_beatmaps_checksum"), "beatmaps", ["checksum"], unique=False - ) + op.create_index(op.f("ix_beatmaps_beatmapset_id"), "beatmaps", ["beatmapset_id"], unique=False) + op.create_index(op.f("ix_beatmaps_checksum"), "beatmaps", ["checksum"], unique=False) op.create_index(op.f("ix_beatmaps_id"), "beatmaps", ["id"], unique=False) op.create_table( "daily_challenge_stats", @@ -409,27 +387,19 @@ def upgrade() -> None: ["user_id"], unique=False, ) - op.create_index( - op.f("ix_monthly_playcounts_year"), "monthly_playcounts", ["year"], unique=False - ) + op.create_index(op.f("ix_monthly_playcounts_year"), "monthly_playcounts", ["year"], unique=False) op.create_table( "oauth_tokens", sa.Column("id", sa.Integer(), nullable=False), sa.Column("user_id", sa.BigInteger(), nullable=True), - sa.Column( - "access_token", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=False - ), + sa.Column("access_token", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=False), sa.Column( "refresh_token", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=False, ), - sa.Column( - "token_type", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=False - ), - sa.Column( - "scope", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=False - ), + sa.Column("token_type", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=False), + sa.Column("scope", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=False), sa.Column("expires_at", sa.DateTime(), nullable=True), sa.Column("created_at", sa.DateTime(), nullable=True), sa.ForeignKeyConstraint( @@ -441,17 +411,13 @@ def upgrade() -> None: sa.UniqueConstraint("refresh_token"), ) op.create_index(op.f("ix_oauth_tokens_id"), "oauth_tokens", ["id"], unique=False) - op.create_index( - op.f("ix_oauth_tokens_user_id"), "oauth_tokens", ["user_id"], unique=False - ) + op.create_index(op.f("ix_oauth_tokens_user_id"), "oauth_tokens", ["user_id"], unique=False) op.create_table( "relationship", sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False), sa.Column("user_id", sa.BigInteger(), nullable=True), sa.Column("target_id", sa.BigInteger(), nullable=True), - sa.Column( - "type", sa.Enum("FOLLOW", "BLOCK", name="relationshiptype"), nullable=False - ), + sa.Column("type", sa.Enum("FOLLOW", "BLOCK", name="relationshiptype"), nullable=False), sa.ForeignKeyConstraint( ["target_id"], ["lazer_users.id"], @@ -462,12 +428,8 @@ def upgrade() -> None: ), sa.PrimaryKeyConstraint("id"), ) - op.create_index( - op.f("ix_relationship_target_id"), "relationship", ["target_id"], unique=False - ) - op.create_index( - op.f("ix_relationship_user_id"), "relationship", ["user_id"], unique=False - ) + op.create_index(op.f("ix_relationship_target_id"), "relationship", ["target_id"], unique=False) + op.create_index(op.f("ix_relationship_user_id"), "relationship", ["user_id"], unique=False) op.create_table( "rooms", sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), @@ -495,16 +457,12 @@ def upgrade() -> None: ), sa.Column( "queue_mode", - sa.Enum( - "HOST_ONLY", "ALL_PLAYERS", "ALL_PLAYERS_ROUND_ROBIN", name="queuemode" - ), + sa.Enum("HOST_ONLY", "ALL_PLAYERS", "ALL_PLAYERS_ROUND_ROBIN", name="queuemode"), nullable=False, ), sa.Column("auto_skip", sa.Boolean(), nullable=False), sa.Column("auto_start_duration", sa.Integer(), nullable=False), - sa.Column( - "status", sa.Enum("IDLE", "PLAYING", name="roomstatus"), nullable=False - ), + sa.Column("status", sa.Enum("IDLE", "PLAYING", name="roomstatus"), nullable=False), sa.Column("id", sa.Integer(), nullable=False), sa.Column("host_id", sa.BigInteger(), nullable=True), sa.ForeignKeyConstraint( @@ -559,9 +517,7 @@ def upgrade() -> None: ), sa.PrimaryKeyConstraint("id"), ) - op.create_index( - op.f("ix_user_account_history_id"), "user_account_history", ["id"], unique=False - ) + op.create_index(op.f("ix_user_account_history_id"), "user_account_history", ["id"], unique=False) op.create_index( op.f("ix_user_account_history_user_id"), "user_account_history", @@ -654,9 +610,7 @@ def upgrade() -> None: ["event_type"], unique=False, ) - op.create_index( - op.f("ix_multiplayer_events_id"), "multiplayer_events", ["id"], unique=False - ) + op.create_index(op.f("ix_multiplayer_events_id"), "multiplayer_events", ["id"], unique=False) op.create_index( op.f("ix_multiplayer_events_room_id"), "multiplayer_events", @@ -714,12 +668,8 @@ def upgrade() -> None: ), sa.PrimaryKeyConstraint("db_id"), ) - op.create_index( - op.f("ix_room_playlists_db_id"), "room_playlists", ["db_id"], unique=False - ) - op.create_index( - op.f("ix_room_playlists_id"), "room_playlists", ["id"], unique=False - ) + op.create_index(op.f("ix_room_playlists_db_id"), "room_playlists", ["db_id"], unique=False) + op.create_index(op.f("ix_room_playlists_id"), "room_playlists", ["id"], unique=False) op.create_table( "score_tokens", sa.Column("score_id", sa.BigInteger(), nullable=True), @@ -754,9 +704,7 @@ def upgrade() -> None: op.create_table( "scores", sa.Column("accuracy", sa.Float(), nullable=False), - sa.Column( - "map_md5", sqlmodel.sql.sqltypes.AutoString(length=32), nullable=False - ), + sa.Column("map_md5", sqlmodel.sql.sqltypes.AutoString(length=32), nullable=False), sa.Column("build_id", sa.Integer(), nullable=True), sa.Column("classic_total_score", sa.BigInteger(), nullable=True), sa.Column("ended_at", sa.DateTime(), nullable=True), @@ -805,9 +753,7 @@ def upgrade() -> None: ), sa.PrimaryKeyConstraint("id"), ) - op.create_index( - op.f("ix_scores_beatmap_id"), "scores", ["beatmap_id"], unique=False - ) + op.create_index(op.f("ix_scores_beatmap_id"), "scores", ["beatmap_id"], unique=False) op.create_index(op.f("ix_scores_gamemode"), "scores", ["gamemode"], unique=False) op.create_index(op.f("ix_scores_map_md5"), "scores", ["map_md5"], unique=False) op.create_index(op.f("ix_scores_user_id"), "scores", ["user_id"], unique=False) @@ -837,15 +783,9 @@ def upgrade() -> None: ), sa.PrimaryKeyConstraint("score_id"), ) - op.create_index( - op.f("ix_best_scores_beatmap_id"), "best_scores", ["beatmap_id"], unique=False - ) - op.create_index( - op.f("ix_best_scores_gamemode"), "best_scores", ["gamemode"], unique=False - ) - op.create_index( - op.f("ix_best_scores_user_id"), "best_scores", ["user_id"], unique=False - ) + op.create_index(op.f("ix_best_scores_beatmap_id"), "best_scores", ["beatmap_id"], unique=False) + op.create_index(op.f("ix_best_scores_gamemode"), "best_scores", ["gamemode"], unique=False) + op.create_index(op.f("ix_best_scores_user_id"), "best_scores", ["user_id"], unique=False) op.create_table( "playlist_best_scores", sa.Column("user_id", sa.BigInteger(), nullable=True), @@ -945,9 +885,7 @@ def upgrade() -> None: def downgrade() -> None: """Downgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### - op.drop_index( - op.f("ix_total_score_best_scores_user_id"), table_name="total_score_best_scores" - ) + op.drop_index(op.f("ix_total_score_best_scores_user_id"), table_name="total_score_best_scores") op.drop_index( op.f("ix_total_score_best_scores_gamemode"), table_name="total_score_best_scores", @@ -957,15 +895,9 @@ def downgrade() -> None: table_name="total_score_best_scores", ) op.drop_table("total_score_best_scores") - op.drop_index( - op.f("ix_playlist_best_scores_user_id"), table_name="playlist_best_scores" - ) - op.drop_index( - op.f("ix_playlist_best_scores_room_id"), table_name="playlist_best_scores" - ) - op.drop_index( - op.f("ix_playlist_best_scores_playlist_id"), table_name="playlist_best_scores" - ) + op.drop_index(op.f("ix_playlist_best_scores_user_id"), table_name="playlist_best_scores") + op.drop_index(op.f("ix_playlist_best_scores_room_id"), table_name="playlist_best_scores") + op.drop_index(op.f("ix_playlist_best_scores_playlist_id"), table_name="playlist_best_scores") op.drop_table("playlist_best_scores") op.drop_index(op.f("ix_best_scores_user_id"), table_name="best_scores") op.drop_index(op.f("ix_best_scores_gamemode"), table_name="best_scores") @@ -983,34 +915,18 @@ def downgrade() -> None: op.drop_index(op.f("ix_room_playlists_db_id"), table_name="room_playlists") op.drop_table("room_playlists") op.drop_table("room_participated_users") - op.drop_index( - op.f("ix_multiplayer_events_user_id"), table_name="multiplayer_events" - ) - op.drop_index( - op.f("ix_multiplayer_events_room_id"), table_name="multiplayer_events" - ) + op.drop_index(op.f("ix_multiplayer_events_user_id"), table_name="multiplayer_events") + op.drop_index(op.f("ix_multiplayer_events_room_id"), table_name="multiplayer_events") op.drop_index(op.f("ix_multiplayer_events_id"), table_name="multiplayer_events") - op.drop_index( - op.f("ix_multiplayer_events_event_type"), table_name="multiplayer_events" - ) + op.drop_index(op.f("ix_multiplayer_events_event_type"), table_name="multiplayer_events") op.drop_table("multiplayer_events") - op.drop_index( - op.f("ix_item_attempts_count_user_id"), table_name="item_attempts_count" - ) - op.drop_index( - op.f("ix_item_attempts_count_room_id"), table_name="item_attempts_count" - ) + op.drop_index(op.f("ix_item_attempts_count_user_id"), table_name="item_attempts_count") + op.drop_index(op.f("ix_item_attempts_count_room_id"), table_name="item_attempts_count") op.drop_table("item_attempts_count") - op.drop_index( - op.f("ix_beatmap_playcounts_user_id"), table_name="beatmap_playcounts" - ) - op.drop_index( - op.f("ix_beatmap_playcounts_beatmap_id"), table_name="beatmap_playcounts" - ) + op.drop_index(op.f("ix_beatmap_playcounts_user_id"), table_name="beatmap_playcounts") + op.drop_index(op.f("ix_beatmap_playcounts_beatmap_id"), table_name="beatmap_playcounts") op.drop_table("beatmap_playcounts") - op.drop_index( - op.f("ix_user_account_history_user_id"), table_name="user_account_history" - ) + op.drop_index(op.f("ix_user_account_history_user_id"), table_name="user_account_history") op.drop_index(op.f("ix_user_account_history_id"), table_name="user_account_history") op.drop_table("user_account_history") op.drop_index(op.f("ix_team_members_id"), table_name="team_members") @@ -1027,29 +943,17 @@ def downgrade() -> None: op.drop_index(op.f("ix_oauth_tokens_id"), table_name="oauth_tokens") op.drop_table("oauth_tokens") op.drop_index(op.f("ix_monthly_playcounts_year"), table_name="monthly_playcounts") - op.drop_index( - op.f("ix_monthly_playcounts_user_id"), table_name="monthly_playcounts" - ) + op.drop_index(op.f("ix_monthly_playcounts_user_id"), table_name="monthly_playcounts") op.drop_index(op.f("ix_monthly_playcounts_month"), table_name="monthly_playcounts") op.drop_table("monthly_playcounts") - op.drop_index( - op.f("ix_lazer_user_statistics_user_id"), table_name="lazer_user_statistics" - ) + op.drop_index(op.f("ix_lazer_user_statistics_user_id"), table_name="lazer_user_statistics") op.drop_table("lazer_user_statistics") - op.drop_index( - op.f("ix_lazer_user_achievements_id"), table_name="lazer_user_achievements" - ) + op.drop_index(op.f("ix_lazer_user_achievements_id"), table_name="lazer_user_achievements") op.drop_table("lazer_user_achievements") - op.drop_index( - op.f("ix_favourite_beatmapset_user_id"), table_name="favourite_beatmapset" - ) - op.drop_index( - op.f("ix_favourite_beatmapset_beatmapset_id"), table_name="favourite_beatmapset" - ) + op.drop_index(op.f("ix_favourite_beatmapset_user_id"), table_name="favourite_beatmapset") + op.drop_index(op.f("ix_favourite_beatmapset_beatmapset_id"), table_name="favourite_beatmapset") op.drop_table("favourite_beatmapset") - op.drop_index( - op.f("ix_daily_challenge_stats_user_id"), table_name="daily_challenge_stats" - ) + op.drop_index(op.f("ix_daily_challenge_stats_user_id"), table_name="daily_challenge_stats") op.drop_table("daily_challenge_stats") op.drop_index(op.f("ix_beatmaps_id"), table_name="beatmaps") op.drop_index(op.f("ix_beatmaps_checksum"), table_name="beatmaps") diff --git a/pyproject.toml b/pyproject.toml index 7ef93e5..1f45d3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ authors = [{ name = "GooGuTeam" }] [tool.ruff] -line-length = 88 +line-length = 120 target-version = "py312" [tool.ruff.format] @@ -92,10 +92,10 @@ pythonVersion = "3.12" pythonPlatform = "All" typeCheckingMode = "standard" -reportShadowedImports = false disableBytesTypePromotions = true reportIncompatibleMethodOverride = false reportIncompatibleVariableOverride = false +exclude = ["migrations/", ".venv/", "venv/"] [tool.uv.workspace] members = [ @@ -114,6 +114,7 @@ cache-keys = [{file = "pyproject.toml"}, {file = "packages/msgpack_lazer_api/Car dev = [ "maturin>=1.9.2", "pre-commit>=4.2.0", + "pyright>=1.1.404", "ruff>=0.12.4", "types-aioboto3[aioboto3,essential]>=15.0.0", ] diff --git a/test_spectator_buffer.py b/test_spectator_buffer.py deleted file mode 100644 index 3f47917..0000000 --- a/test_spectator_buffer.py +++ /dev/null @@ -1,84 +0,0 @@ -""" -观战缓冲区测试脚本 -用于验证观战同步和缓冲区功能是否正常工作 -""" - -import asyncio -import json -from datetime import UTC, datetime - -from app.signalr.hub.spectator_buffer import SpectatorStateManager, spectator_state_manager -from app.models.spectator_hub import SpectatorState, SpectatedUserState - -async def test_spectator_buffer(): - """测试观战缓冲区功能""" - print("=== 观战缓冲区测试开始 ===") - - # 模拟用户1开始游戏 - user1_id = 100 - user1_state = SpectatorState( - beatmap_id=123456, - ruleset_id=0, - mods=[], - state=SpectatedUserState.Playing, - maximum_statistics={} - ) - - await spectator_state_manager.handle_user_began_playing(user1_id, user1_state, { - 'beatmap_checksum': 'test_checksum', - 'score_token': 12345, - 'username': 'TestUser1', - 'started_at': datetime.now(UTC).timestamp() - }) - print(f"✓ 用户 {user1_id} 开始游戏 (谱面: {user1_state.beatmap_id})") - - # 模拟多人游戏同步 - multiplayer_data = { - 'room_id': 10, - 'beatmap_id': 789012, # 不同的谱面ID - 'ruleset_id': 1, # 不同的模式 - 'mods': [], - 'state': 'PLAYING', - 'is_multiplayer': True - } - - user2_id = 200 - await spectator_state_manager.sync_with_multiplayer(user2_id, multiplayer_data) - print(f"✓ 用户 {user2_id} 多人游戏同步 (谱面: {multiplayer_data['beatmap_id']}, 模式: {multiplayer_data['ruleset_id']})") - - # 模拟观战者开始观看 - spectator_id = 300 - catchup_bundle = await spectator_state_manager.handle_spectator_start_watching(spectator_id, user1_id) - print(f"✓ 观战者 {spectator_id} 开始观看用户 {user1_id}") - - if catchup_bundle: - print(f" - 追赶数据包包含: {list(catchup_bundle.keys())}") - if 'state' in catchup_bundle: - state = catchup_bundle['state'] - print(f" - 谱面ID: {state.beatmap_id}, 模式: {state.ruleset_id}") - - # 检查缓冲区统计 - stats = spectator_state_manager.get_buffer_stats() - print(f"✓ 缓冲区统计: {stats}") - - # 验证状态同步 - user1_buffered = spectator_state_manager.buffer.get_user_state(user1_id) - user2_buffered = spectator_state_manager.buffer.get_user_state(user2_id) - - if user1_buffered: - print(f"✓ 用户1缓冲状态: 谱面={user1_buffered.beatmap_id}, 模式={user1_buffered.ruleset_id}") - - if user2_buffered: - print(f"✓ 用户2缓冲状态: 谱面={user2_buffered.beatmap_id}, 模式={user2_buffered.ruleset_id}") - - # 验证不同谱面的处理 - if user1_buffered and user2_buffered: - if user1_buffered.beatmap_id != user2_buffered.beatmap_id: - print("✓ 不同用户的不同谱面已正确处理") - else: - print("⚠️ 用户谱面同步可能存在问题") - - print("=== 观战缓冲区测试完成 ===") - -if __name__ == "__main__": - asyncio.run(test_spectator_buffer()) diff --git a/tools/add_daily_challenge.py b/tools/add_daily_challenge.py index 6ae1761..1fa5fb4 100644 --- a/tools/add_daily_challenge.py +++ b/tools/add_daily_challenge.py @@ -63,15 +63,11 @@ async def main(): if not ruleset_inp: ruleset_inp = str(int(beatmap.mode)) elif not ruleset_inp.isdigit(): - ruleset_inp = input( - f"Invalid input. Enter ruleset ID ({int(beatmap.mode)}) >>> " - ) + ruleset_inp = input(f"Invalid input. Enter ruleset ID ({int(beatmap.mode)}) >>> ") continue ruleset_id = int(ruleset_inp) if beatmap.mode != GameMode.OSU and ruleset_id != int(beatmap.mode): - ruleset_inp = input( - f"Invalid input. Enter ruleset ID ({int(beatmap.mode)}) >>> " - ) + ruleset_inp = input(f"Invalid input. Enter ruleset ID ({int(beatmap.mode)}) >>> ") continue break diff --git a/uv.lock b/uv.lock index ac18584..1cfb608 100644 --- a/uv.lock +++ b/uv.lock @@ -566,6 +566,7 @@ dependencies = [ dev = [ { name = "maturin" }, { name = "pre-commit" }, + { name = "pyright" }, { name = "ruff" }, { name = "types-aioboto3", extra = ["aioboto3", "essential"] }, ] @@ -605,6 +606,7 @@ requires-dist = [ dev = [ { name = "maturin", specifier = ">=1.9.2" }, { name = "pre-commit", specifier = ">=4.2.0" }, + { name = "pyright", specifier = ">=1.1.404" }, { name = "ruff", specifier = ">=0.12.4" }, { name = "types-aioboto3", extras = ["aioboto3", "essential"], specifier = ">=15.0.0" }, ] @@ -1228,6 +1230,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0c/94/e4181a1f6286f545507528c78016e00065ea913276888db2262507693ce5/PyMySQL-1.1.1-py3-none-any.whl", hash = "sha256:4de15da4c61dc132f4fb9ab763063e693d521a80fd0e87943b9a453dd4c19d6c", size = 44972, upload-time = "2024-05-21T11:03:41.216Z" }, ] +[[package]] +name = "pyright" +version = "1.1.404" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nodeenv" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e2/6e/026be64c43af681d5632722acd100b06d3d39f383ec382ff50a71a6d5bce/pyright-1.1.404.tar.gz", hash = "sha256:455e881a558ca6be9ecca0b30ce08aa78343ecc031d37a198ffa9a7a1abeb63e", size = 4065679, upload-time = "2025-08-20T18:46:14.029Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/84/30/89aa7f7d7a875bbb9a577d4b1dc5a3e404e3d2ae2657354808e905e358e0/pyright-1.1.404-py3-none-any.whl", hash = "sha256:c7b7ff1fdb7219c643079e4c3e7d4125f0dafcc19d253b47e898d130ea426419", size = 5902951, upload-time = "2025-08-20T18:46:12.096Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0"