chore(linter): update ruff rules
@@ -32,11 +32,9 @@ async def process_streak(
     ).first()
     if not stats:
         return False
-    if streak <= stats.daily_streak_best < next_streak:
-        return True
-    elif next_streak == 0 and stats.daily_streak_best >= streak:
-        return True
-    return False
+    return bool(
+        streak <= stats.daily_streak_best < next_streak or (next_streak == 0 and stats.daily_streak_best >= streak)
+    )
 
 
 MEDALS = {
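Note on the pattern applied throughout these medal hunks: Python chains comparisons, so `a <= b < c` means `a <= b and b < c`, and wrapping the or-expression in `bool()` keeps the declared bool return type. A minimal standalone sketch (illustrative values, not from the codebase):

    streak, best, next_streak = 5, 7, 10
    # Chained comparison evaluates both bounds against the middle term.
    assert (streak <= best < next_streak) == (streak <= best and best < next_streak)
    # bool() normalizes the or-expression's result to a plain bool.
    result = bool(streak <= best < next_streak or (next_streak == 0 and best >= streak))
    assert result is True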
@@ -68,9 +68,7 @@ async def to_the_core(
     if ("Nightcore" not in beatmap.beatmapset.title) and "Nightcore" not in beatmap.beatmapset.artist:
         return False
     mods_ = mod_to_save(score.mods)
-    if "DT" not in mods_ or "NC" not in mods_:
-        return False
-    return True
+    return not ("DT" not in mods_ or "NC" not in mods_)
 
 
 async def wysi(
@@ -83,9 +81,7 @@ async def wysi(
         return False
     if str(round(score.accuracy, ndigits=4))[3:] != "727":
         return False
-    if "xi" not in beatmap.beatmapset.artist:
-        return False
-    return True
+    return "xi" in beatmap.beatmapset.artist
 
 
 async def prepared(
@@ -97,9 +93,7 @@ async def prepared(
     if score.rank != Rank.X and score.rank != Rank.XH:
         return False
     mods_ = mod_to_save(score.mods)
-    if "NF" not in mods_:
-        return False
-    return True
+    return "NF" in mods_
 
 
 async def reckless_adandon(
@@ -117,9 +111,7 @@ async def reckless_adandon(
     redis = get_redis()
    mods_ = score.mods.copy()
     attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
-    if attribute.star_rating < 3:
-        return False
-    return True
+    return not attribute.star_rating < 3
 
 
 async def lights_out(
@@ -413,11 +405,10 @@ async def by_the_skin_of_the_teeth(
         return False
 
     for mod in score.mods:
-        if mod.get("acronym") == "AC":
-            if "settings" in mod and "minimum_accuracy" in mod["settings"]:
-                target_accuracy = mod["settings"]["minimum_accuracy"]
-                if isinstance(target_accuracy, int | float):
-                    return abs(score.accuracy - float(target_accuracy)) < 0.0001
+        if mod.get("acronym") == "AC" and "settings" in mod and "minimum_accuracy" in mod["settings"]:
+            target_accuracy = mod["settings"]["minimum_accuracy"]
+            if isinstance(target_accuracy, int | float):
+                return abs(score.accuracy - float(target_accuracy)) < 0.0001
     return False
 
 
@@ -19,9 +19,7 @@ async def process_mod(
         return False
     if not beatmap.beatmap_status.has_leaderboard():
         return False
-    if len(score.mods) != 1 or score.mods[0]["acronym"] != mod:
-        return False
-    return True
+    return not (len(score.mods) != 1 or score.mods[0]["acronym"] != mod)
 
 
 async def process_category_mod(
@@ -22,11 +22,7 @@ async def process_combo(
         return False
     if next_combo != 0 and combo >= next_combo:
         return False
-    if combo <= score.max_combo < next_combo:
-        return True
-    elif next_combo == 0 and score.max_combo >= combo:
-        return True
-    return False
+    return bool(combo <= score.max_combo < next_combo or (next_combo == 0 and score.max_combo >= combo))
 
 
 MEDALS: Medals = {
@@ -35,11 +35,7 @@ async def process_playcount(
     ).first()
     if not stats:
         return False
-    if pc <= stats.play_count < next_pc:
-        return True
-    elif next_pc == 0 and stats.play_count >= pc:
-        return True
-    return False
+    return bool(pc <= stats.play_count < next_pc or (next_pc == 0 and stats.play_count >= pc))
 
 
 MEDALS: Medals = {
@@ -47,9 +47,7 @@ async def process_skill(
     attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
     if attribute.star_rating < star or attribute.star_rating >= star + 1:
         return False
-    if type == "fc" and not score.is_perfect_combo:
-        return False
-    return True
+    return not (type == "fc" and not score.is_perfect_combo)
 
 
 MEDALS: Medals = {
@@ -35,11 +35,7 @@ async def process_tth(
     ).first()
     if not stats:
         return False
-    if tth <= stats.total_hits < next_tth:
-        return True
-    elif next_tth == 0 and stats.play_count >= tth:
-        return True
-    return False
+    return bool(tth <= stats.total_hits < next_tth or (next_tth == 0 and stats.play_count >= tth))
 
 
 MEDALS: Medals = {
app/auth.py
@@ -69,7 +69,7 @@ def verify_password_legacy(plain_password: str, bcrypt_hash: str) -> bool:
     2. MD5 hash -> bcrypt verification
     """
     # 1. Convert the plaintext password to MD5
-    pw_md5 = hashlib.md5(plain_password.encode()).hexdigest().encode()
+    pw_md5 = hashlib.md5(plain_password.encode()).hexdigest().encode()  # noqa: S324
 
     # 2. Check the cache
     if bcrypt_hash in bcrypt_cache:
@@ -103,7 +103,7 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
 def get_password_hash(password: str) -> str:
     """Generate the password hash - using osu!'s scheme"""
     # 1. Plaintext password -> MD5
-    pw_md5 = hashlib.md5(password.encode()).hexdigest().encode()
+    pw_md5 = hashlib.md5(password.encode()).hexdigest().encode()  # noqa: S324
     # 2. MD5 -> bcrypt
     pw_bcrypt = bcrypt.hashpw(pw_md5, bcrypt.gensalt())
     return pw_bcrypt.decode()
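For context on the S324 suppressions above: the legacy scheme hashes the plaintext to MD5 first and then bcrypts the MD5 digest (osu!'s historical layout), so MD5 is an input transform rather than the protection layer, hence the noqa instead of a rewrite. A minimal sketch of the round trip, assuming the `bcrypt` package:

    import hashlib

    import bcrypt

    def hash_password(password: str) -> str:
        pw_md5 = hashlib.md5(password.encode()).hexdigest().encode()  # noqa: S324 - legacy osu! layout
        return bcrypt.hashpw(pw_md5, bcrypt.gensalt()).decode()

    def check_password(password: str, stored: str) -> bool:
        pw_md5 = hashlib.md5(password.encode()).hexdigest().encode()  # noqa: S324
        return bcrypt.checkpw(pw_md5, stored.encode())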
@@ -114,7 +114,7 @@ async def authenticate_user_legacy(db: AsyncSession, name: str, password: str) -
     Authenticate the user - using logic similar to from_login
     """
     # 1. Convert the plaintext password to MD5
-    pw_md5 = hashlib.md5(password.encode()).hexdigest()
+    pw_md5 = hashlib.md5(password.encode()).hexdigest()  # noqa: S324
 
     # 2. Look up the user by username
     user = None
@@ -325,12 +325,7 @@ def _generate_totp_account_label(user: User) -> str:
 
     Choose the username or email based on configuration, and append server info to make the label more descriptive
     """
-    if settings.totp_use_username_in_label:
-        # Use the username as the primary identifier
-        primary_identifier = user.username
-    else:
-        # Use the email as the identifier
-        primary_identifier = user.email
+    primary_identifier = user.username if settings.totp_use_username_in_label else user.email
 
     # If a service name is configured, add it to the label to distinguish entries in the authenticator
     if settings.totp_service_name:
@@ -419,9 +419,8 @@ def too_dense(hit_objects: list[HitObject], per_1s: int, per_10s: int) -> bool:
         if len(hit_objects) > i + per_1s:
             if hit_objects[i + per_1s].start_time - hit_objects[i].start_time < 1000:
                 return True
-        elif len(hit_objects) > i + per_10s:
-            if hit_objects[i + per_10s].start_time - hit_objects[i].start_time < 10000:
-                return True
+        elif len(hit_objects) > i + per_10s and hit_objects[i + per_10s].start_time - hit_objects[i].start_time < 10000:
+            return True
     return False
 
 
@@ -448,10 +447,7 @@ def slider_is_sus(hit_objects: list[HitObject]) -> bool:
 
 
 def is_2b(hit_objects: list[HitObject]) -> bool:
-    for i in range(0, len(hit_objects) - 1):
-        if hit_objects[i] == hit_objects[i + 1].start_time:
-            return True
-    return False
+    return any(hit_objects[i] == hit_objects[i + 1].start_time for i in range(0, len(hit_objects) - 1))
 
 
 def is_suspicious_beatmap(content: str) -> bool:
@@ -217,7 +217,7 @@ STORAGE_SETTINGS='{
     # Server settings
     host: Annotated[
         str,
-        Field(default="0.0.0.0", description="服务器监听地址"),
+        Field(default="0.0.0.0", description="服务器监听地址"),  # noqa: S104
         "服务器设置",
     ]
     port: Annotated[
@@ -609,26 +609,26 @@ STORAGE_SETTINGS='{
     ]
 
     @field_validator("fetcher_scopes", mode="before")
     @classmethod
     def validate_fetcher_scopes(cls, v: Any) -> list[str]:
         if isinstance(v, str):
             return v.split(",")
         return v
 
     @field_validator("storage_settings", mode="after")
     @classmethod
     def validate_storage_settings(
         cls,
         v: LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings,
         info: ValidationInfo,
     ) -> LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings:
-        if info.data.get("storage_service") == StorageServiceType.CLOUDFLARE_R2:
-            if not isinstance(v, CloudflareR2Settings):
-                raise ValueError("When storage_service is 'r2', storage_settings must be CloudflareR2Settings")
-        elif info.data.get("storage_service") == StorageServiceType.LOCAL:
-            if not isinstance(v, LocalStorageSettings):
-                raise ValueError("When storage_service is 'local', storage_settings must be LocalStorageSettings")
-        elif info.data.get("storage_service") == StorageServiceType.AWS_S3:
-            if not isinstance(v, AWSS3StorageSettings):
-                raise ValueError("When storage_service is 's3', storage_settings must be AWSS3StorageSettings")
+        service = info.data.get("storage_service")
+        if service == StorageServiceType.CLOUDFLARE_R2 and not isinstance(v, CloudflareR2Settings):
+            raise ValueError("When storage_service is 'r2', storage_settings must be CloudflareR2Settings")
+        if service == StorageServiceType.LOCAL and not isinstance(v, LocalStorageSettings):
+            raise ValueError("When storage_service is 'local', storage_settings must be LocalStorageSettings")
+        if service == StorageServiceType.AWS_S3 and not isinstance(v, AWSS3StorageSettings):
+            raise ValueError("When storage_service is 's3', storage_settings must be AWSS3StorageSettings")
         return v
 
 
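The flattened validator above reads `storage_service` once and turns the elif chain into independent guard clauses; behavior is unchanged because the three branches are mutually exclusive. A reduced sketch of the same shape (simplified types and hypothetical field names, not the project's real settings):

    from pydantic import BaseModel, ValidationInfo, field_validator

    class StorageConfig(BaseModel):
        storage_service: str
        storage_settings: dict

        @field_validator("storage_settings", mode="after")
        @classmethod
        def check_settings(cls, v: dict, info: ValidationInfo) -> dict:
            # Read the sibling field once instead of once per branch.
            service = info.data.get("storage_service")
            if service == "r2" and "account_id" not in v:
                raise ValueError("When storage_service is 'r2', settings must include account_id")
            return v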
@@ -71,10 +71,10 @@ class Beatmap(BeatmapBase, table=True):
     failtimes: FailTime | None = Relationship(back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"})
 
     @classmethod
-    async def from_resp_no_save(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
+    async def from_resp_no_save(cls, _session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
         d = resp.model_dump()
         del d["beatmapset"]
-        beatmap = Beatmap.model_validate(
+        beatmap = cls.model_validate(
             {
                 **d,
                 "beatmapset_id": resp.beatmapset_id,
@@ -90,8 +90,7 @@ class Beatmap(BeatmapBase, table=True):
         if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
             session.add(beatmap)
             await session.commit()
-        beatmap = (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one()
-        return beatmap
+        return (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one()
 
     @classmethod
     async def from_resp_batch(cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0) -> list["Beatmap"]:
@@ -250,7 +249,7 @@ async def calculate_beatmap_attributes(
     redis: Redis,
     fetcher: "Fetcher",
 ):
-    key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
+    key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.sha256(str(mods_).encode()).hexdigest()}:attributes"
     if await redis.exists(key):
         return BeatmapAttributes.model_validate_json(await redis.get(key))
     resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
@@ -130,7 +130,7 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
     favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset")
 
     @classmethod
-    async def from_resp_no_save(cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0) -> "Beatmapset":
+    async def from_resp_no_save(cls, resp: "BeatmapsetResp") -> "Beatmapset":
         d = resp.model_dump()
         if resp.nominations:
             d["nominations_required"] = resp.nominations.required
@@ -158,10 +158,15 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
         return beatmapset
 
     @classmethod
-    async def from_resp(cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0) -> "Beatmapset":
+    async def from_resp(
+        cls,
+        session: AsyncSession,
+        resp: "BeatmapsetResp",
+        from_: int = 0,
+    ) -> "Beatmapset":
         from .beatmap import Beatmap
 
-        beatmapset = await cls.from_resp_no_save(session, resp, from_=from_)
+        beatmapset = await cls.from_resp_no_save(resp)
         if not (await session.exec(select(exists()).where(Beatmapset.id == resp.id))).first():
             session.add(beatmapset)
             await session.commit()
@@ -105,17 +105,11 @@ class ChatChannelResp(ChatChannelBase):
             )
         ).first()
 
-        last_msg = await redis.get(f"chat:{channel.channel_id}:last_msg")
-        if last_msg and last_msg.isdigit():
-            last_msg = int(last_msg)
-        else:
-            last_msg = None
+        last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
+        last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None
 
-        last_read_id = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
-        if last_read_id and last_read_id.isdigit():
-            last_read_id = int(last_read_id)
-        else:
-            last_read_id = last_msg
+        last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
+        last_read_id = int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else last_msg
 
         if silence is not None:
             attribute = ChatUserAttributes(
@@ -520,12 +520,11 @@ async def _score_where(
             wheres.append(col(BestScore.user).has(col(User.country_code) == user.country_code))
         else:
             return None
-    elif type == LeaderboardType.TEAM:
-        if user:
-            team_membership = await user.awaitable_attrs.team_membership
-            if team_membership:
-                team_id = team_membership.team_id
-                wheres.append(col(BestScore.user).has(col(User.team_membership).has(TeamMember.team_id == team_id)))
+    elif type == LeaderboardType.TEAM and user:
+        team_membership = await user.awaitable_attrs.team_membership
+        if team_membership:
+            team_id = team_membership.team_id
+            wheres.append(col(BestScore.user).has(col(User.team_membership).has(TeamMember.team_id == team_id)))
     if mods:
         if user and user.is_supporter:
             wheres.append(
@@ -256,8 +256,6 @@ class UserResp(UserBase):
         session: AsyncSession,
         include: list[str] = [],
         ruleset: GameMode | None = None,
-        *,
-        token_id: int | None = None,
     ) -> "UserResp":
         from app.dependencies.database import get_redis
 
@@ -310,16 +308,16 @@ class UserResp(UserBase):
                 ).all()
             ]
 
-        if "team" in include:
-            if team_membership := await obj.awaitable_attrs.team_membership:
-                u.team = team_membership.team
+        if "team" in include and (team_membership := await obj.awaitable_attrs.team_membership):
+            u.team = team_membership.team
 
         if "account_history" in include:
             u.account_history = [UserAccountHistoryResp.from_db(ah) for ah in await obj.awaitable_attrs.account_history]
 
-        if "daily_challenge_user_stats":
-            if daily_challenge_stats := await obj.awaitable_attrs.daily_challenge_stats:
-                u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(daily_challenge_stats)
+        if "daily_challenge_user_stats" in include and (
+            daily_challenge_stats := await obj.awaitable_attrs.daily_challenge_stats
+        ):
+            u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(daily_challenge_stats)
 
         if "statistics" in include:
             current_stattistics = None
@@ -443,7 +441,7 @@ class MeResp(UserResp):
         from app.dependencies.database import get_redis
         from app.service.verification_service import LoginSessionService
 
-        u = await super().from_db(obj, session, ALL_INCLUDED, ruleset, token_id=token_id)
+        u = await super().from_db(obj, session, ALL_INCLUDED, ruleset)
         u.session_verified = (
             not await LoginSessionService.check_is_need_verification(session, user_id=obj.id, token_id=token_id)
             if token_id
@@ -7,7 +7,7 @@ from fastapi.exceptions import RequestValidationError
 from pydantic import BaseModel, ValidationError
 
 
-def BodyOrForm[T: BaseModel](model: type[T]):
+def BodyOrForm[T: BaseModel](model: type[T]):  # noqa: N802
     async def dependency(
         request: Request,
     ) -> T:
@@ -119,10 +119,7 @@ async def get_client_user(
     if verify_method is None:
         # Smart selection of the verification method (prefer TOTP when available)
         totp_key = await user.awaitable_attrs.totp_key
-        if totp_key is not None and api_version >= SUPPORT_TOTP_VERIFICATION_VER:
-            verify_method = "totp"
-        else:
-            verify_method = "mail"
+        verify_method = "totp" if totp_key is not None and api_version >= SUPPORT_TOTP_VERIFICATION_VER else "mail"
 
         # Store the chosen verification method in Redis to avoid re-selecting it
         if api_version >= 20250913:
app/exceptions/__init__.py (new file)
@@ -116,7 +116,7 @@ class BeatmapsetFetcher(BaseFetcher):
 
         # Serialize to JSON and generate an MD5 hash
         cache_json = json.dumps(cache_data, sort_keys=True, separators=(",", ":"))
-        cache_hash = hashlib.md5(cache_json.encode()).hexdigest()
+        cache_hash = hashlib.md5(cache_json.encode(), usedforsecurity=False).hexdigest()
 
         logger.opt(colors=True).debug(f"<blue>[CacheKey]</blue> Query: {cache_data}, Hash: {cache_hash}")
 
@@ -160,10 +160,10 @@ class BeatmapsetFetcher(BaseFetcher):
                 cached_data = json.loads(cached_result)
                 return SearchBeatmapsetsResp.model_validate(cached_data)
             except Exception as e:
-                logger.opt(colors=True).warning(f"Cache data invalid, fetching from API: {e}")
+                logger.warning(f"Cache data invalid, fetching from API: {e}")
 
         # Cache miss; fetch from the API
-        logger.opt(colors=True).debug("Cache miss, fetching from API")
+        logger.debug("Cache miss, fetching from API")
 
         params = query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True)
 
@@ -203,7 +203,7 @@ class BeatmapsetFetcher(BaseFetcher):
             try:
                 await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=1)
             except RateLimitError:
-                logger.opt(colors=True).info("Prefetch skipped due to rate limit")
+                logger.info("Prefetch skipped due to rate limit")
 
         bg_tasks.add_task(delayed_prefetch)
 
@@ -227,14 +227,14 @@ class BeatmapsetFetcher(BaseFetcher):
             # Request the next page with the current cursor
             next_query = query.model_copy()
 
-            logger.opt(colors=True).debug(f"Prefetching page {page + 1}")
+            logger.debug(f"Prefetching page {page + 1}")
 
             # Generate the cache key for the next page
             next_cache_key = self._generate_cache_key(next_query, cursor)
 
             # Check whether it is already cached
             if await redis_client.exists(next_cache_key):
-                logger.opt(colors=True).debug(f"Page {page + 1} already cached")
+                logger.debug(f"Page {page + 1} already cached")
                 # Try to get the cursor from the cache and continue prefetching
                 cached_data = await redis_client.get(next_cache_key)
                 if cached_data:
@@ -244,7 +244,7 @@ class BeatmapsetFetcher(BaseFetcher):
                             cursor = data["cursor"]
                             continue
                     except Exception:
-                        pass
+                        logger.warning("Failed to parse cached data for cursor")
                 break
 
             # Add a delay between prefetched pages to avoid burst requests
@@ -279,18 +279,18 @@ class BeatmapsetFetcher(BaseFetcher):
                     ex=prefetch_ttl,
                 )
 
-                logger.opt(colors=True).debug(f"Prefetched page {page + 1} (TTL: {prefetch_ttl}s)")
+                logger.debug(f"Prefetched page {page + 1} (TTL: {prefetch_ttl}s)")
 
             except RateLimitError:
-                logger.opt(colors=True).info("Prefetch stopped due to rate limit")
+                logger.info("Prefetch stopped due to rate limit")
             except Exception as e:
-                logger.opt(colors=True).warning(f"Prefetch failed: {e}")
+                logger.warning(f"Prefetch failed: {e}")
 
     async def warmup_homepage_cache(self, redis_client: redis.Redis) -> None:
         """Warm up the homepage cache"""
         homepage_queries = self._get_homepage_queries()
 
-        logger.opt(colors=True).info(f"Starting homepage cache warmup ({len(homepage_queries)} queries)")
+        logger.info(f"Starting homepage cache warmup ({len(homepage_queries)} queries)")
 
         for i, (query, cursor) in enumerate(homepage_queries):
             try:
@@ -302,7 +302,7 @@ class BeatmapsetFetcher(BaseFetcher):
 
                 # Check whether it is already cached
                 if await redis_client.exists(cache_key):
-                    logger.opt(colors=True).debug(f"Query {query.sort} already cached")
+                    logger.debug(f"Query {query.sort} already cached")
                     continue
 
                 # Request and cache
@@ -325,15 +325,15 @@ class BeatmapsetFetcher(BaseFetcher):
                     ex=cache_ttl,
                 )
 
-                logger.opt(colors=True).info(f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)")
+                logger.info(f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)")
 
                 if api_response.get("cursor"):
                     try:
                         await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=2)
                     except RateLimitError:
-                        logger.opt(colors=True).info(f"Warmup prefetch skipped for {query.sort} due to rate limit")
+                        logger.info(f"Warmup prefetch skipped for {query.sort} due to rate limit")
 
             except RateLimitError:
-                logger.opt(colors=True).warning(f"Warmup skipped for {query.sort} due to rate limit")
+                logger.warning(f"Warmup skipped for {query.sort} due to rate limit")
             except Exception as e:
-                logger.opt(colors=True).error(f"Failed to warmup cache for {query.sort}: {e}")
+                logger.error(f"Failed to warmup cache for {query.sort}: {e}")
app/helpers/__init__.py (new file)
@@ -1,19 +1,39 @@
 """
-GeoLite2 Helper Class
+GeoLite2 Helper Class (asynchronous)
 """
 
+from __future__ import annotations
+
+import asyncio
+from contextlib import suppress
 import os
+from pathlib import Path
 import shutil
 import tarfile
 import tempfile
 import time
+from typing import Any, Required, TypedDict
 
 from app.log import logger
 
+import aiofiles
 import httpx
 import maxminddb
 
 
+class GeoIPLookupResult(TypedDict, total=False):
+    ip: Required[str]
+    country_iso: str
+    country_name: str
+    city_name: str
+    latitude: str
+    longitude: str
+    time_zone: str
+    postal_code: str
+    asn: int | None
+    organization: str
+
+
 BASE_URL = "https://download.maxmind.com/app/geoip_download"
 EDITIONS = {
     "City": "GeoLite2-City",
@@ -25,161 +45,184 @@ EDITIONS = {
 class GeoIPHelper:
     def __init__(
         self,
-        dest_dir="./geoip",
-        license_key=None,
-        editions=None,
-        max_age_days=8,
-        timeout=60.0,
+        dest_dir: str | Path = Path("./geoip"),
+        license_key: str | None = None,
+        editions: list[str] | None = None,
+        max_age_days: int = 8,
+        timeout: float = 60.0,
     ):
-        self.dest_dir = dest_dir
+        self.dest_dir = Path(dest_dir).expanduser()
         self.license_key = license_key or os.getenv("MAXMIND_LICENSE_KEY")
-        self.editions = editions or ["City", "ASN"]
+        self.editions = list(editions or ["City", "ASN"])
         self.max_age_days = max_age_days
         self.timeout = timeout
-        self._readers = {}
+        self._readers: dict[str, maxminddb.Reader] = {}
+        self._update_lock = asyncio.Lock()
 
     @staticmethod
-    def _safe_extract(tar: tarfile.TarFile, path: str):
-        base = Path(path).resolve()
-        for m in tar.getmembers():
-            target = (base / m.name).resolve()
-            if not str(target).startswith(str(base)):
+    def _safe_extract(tar: tarfile.TarFile, path: Path) -> None:
+        base = path.resolve()
+        for member in tar.getmembers():
+            target = (base / member.name).resolve()
+            if not target.is_relative_to(base):  # py312
                 raise RuntimeError("Unsafe path in tar file")
-        tar.extractall(path=path, filter="data")
+        tar.extractall(path=base, filter="data")
 
-    def _download_and_extract(self, edition_id: str) -> str:
-        """
-        Download and extract the mmdb file into dest_dir, keeping only the .mmdb
-        - Follow 302 redirects
-        - Stream-download to a temporary file
-        - The temporary directory is cleaned up automatically on exit
-        """
+    @staticmethod
+    def _as_mapping(value: Any) -> dict[str, Any]:
+        return value if isinstance(value, dict) else {}
+
+    @staticmethod
+    def _as_str(value: Any, default: str = "") -> str:
+        if isinstance(value, str):
+            return value
+        if value is None:
+            return default
+        return str(value)
+
+    @staticmethod
+    def _as_int(value: Any) -> int | None:
+        return value if isinstance(value, int) else None
+
+    @staticmethod
+    def _extract_tarball(src: Path, dest: Path) -> None:
+        with tarfile.open(src, "r:gz") as tar:
+            GeoIPHelper._safe_extract(tar, dest)
+
+    @staticmethod
+    def _find_mmdb(root: Path) -> Path | None:
+        for candidate in root.rglob("*.mmdb"):
+            return candidate
+        return None
+
+    def _latest_file_sync(self, edition_id: str) -> Path | None:
+        directory = self.dest_dir
+        if not directory.is_dir():
+            return None
+        candidates = list(directory.glob(f"{edition_id}*.mmdb"))
+        if not candidates:
+            return None
+        return max(candidates, key=lambda p: p.stat().st_mtime)
+
+    async def _latest_file(self, edition_id: str) -> Path | None:
+        return await asyncio.to_thread(self._latest_file_sync, edition_id)
+
+    async def _download_and_extract(self, edition_id: str) -> Path:
         if not self.license_key:
             raise ValueError("MaxMind License Key is missing. Please configure it via env MAXMIND_LICENSE_KEY.")
 
         url = f"{BASE_URL}?edition_id={edition_id}&license_key={self.license_key}&suffix=tar.gz"
+        tmp_dir = Path(await asyncio.to_thread(tempfile.mkdtemp))
 
-        with httpx.Client(follow_redirects=True, timeout=self.timeout) as client:
-            with client.stream("GET", url) as resp:
+        try:
+            tgz_path = tmp_dir / "db.tgz"
+            async with (
+                httpx.AsyncClient(follow_redirects=True, timeout=self.timeout) as client,
+                client.stream("GET", url) as resp,
+            ):
                 resp.raise_for_status()
-                with tempfile.TemporaryDirectory() as tmpd:
-                    tgz_path = os.path.join(tmpd, "db.tgz")
-                    # Stream the download to disk
-                    with open(tgz_path, "wb") as f:
-                        for chunk in resp.iter_bytes():
-                            if chunk:
-                                f.write(chunk)
+                async with aiofiles.open(tgz_path, "wb") as download_file:
+                    async for chunk in resp.aiter_bytes():
+                        if chunk:
+                            await download_file.write(chunk)
 
-                    # Extract and keep only the .mmdb
-                    with tarfile.open(tgz_path, "r:gz") as tar:
-                        # Safety-check, then extract
-                        self._safe_extract(tar, tmpd)
+            await asyncio.to_thread(self._extract_tarball, tgz_path, tmp_dir)
+            mmdb_path = await asyncio.to_thread(self._find_mmdb, tmp_dir)
+            if mmdb_path is None:
+                raise RuntimeError("未在压缩包中找到 .mmdb 文件")
 
-                    # Recursively locate the .mmdb
-                    mmdb_path = None
-                    for root, _, files in os.walk(tmpd):
-                        for fn in files:
-                            if fn.endswith(".mmdb"):
-                                mmdb_path = os.path.join(root, fn)
-                                break
-                        if mmdb_path:
-                            break
+            await asyncio.to_thread(self.dest_dir.mkdir, parents=True, exist_ok=True)
+            dst = self.dest_dir / mmdb_path.name
+            await asyncio.to_thread(shutil.move, mmdb_path, dst)
+            return dst
+        finally:
+            await asyncio.to_thread(shutil.rmtree, tmp_dir, ignore_errors=True)
 
-                    if not mmdb_path:
-                        raise RuntimeError("未在压缩包中找到 .mmdb 文件")
+    async def update(self, force: bool = False) -> None:
+        async with self._update_lock:
+            for edition in self.editions:
+                edition_id = EDITIONS[edition]
+                path = await self._latest_file(edition_id)
+                need_download = force or path is None
 
-                    os.makedirs(self.dest_dir, exist_ok=True)
-                    dst = os.path.join(self.dest_dir, os.path.basename(mmdb_path))
-                    shutil.move(mmdb_path, dst)
-                    return dst
-
-    def _latest_file(self, edition_id: str):
-        if not os.path.isdir(self.dest_dir):
-            return None
-        files = [
-            os.path.join(self.dest_dir, f)
-            for f in os.listdir(self.dest_dir)
-            if f.startswith(edition_id) and f.endswith(".mmdb")
-        ]
-        return max(files, key=os.path.getmtime) if files else None
-
-    def update(self, force=False):
-        from app.log import logger
-
-        for ed in self.editions:
-            eid = EDITIONS[ed]
-            path = self._latest_file(eid)
-            need = force or not path
-
-            if path:
-                age_days = (time.time() - os.path.getmtime(path)) / 86400
-                if age_days >= self.max_age_days:
-                    need = True
-                    logger.info(
-                        f"{eid} database is {age_days:.1f} days old "
-                        f"(max: {self.max_age_days}), will download new version"
-                    )
+                if path:
+                    mtime = await asyncio.to_thread(path.stat)
+                    age_days = (time.time() - mtime.st_mtime) / 86400
+                    if age_days >= self.max_age_days:
+                        need_download = True
+                        logger.info(
+                            f"{edition_id} database is {age_days:.1f} days old "
+                            f"(max: {self.max_age_days}), will download new version"
+                        )
+                    else:
+                        logger.info(
+                            f"{edition_id} database is {age_days:.1f} days old, still fresh (max: {self.max_age_days})"
+                        )
                 else:
-                    logger.info(f"{eid} database is {age_days:.1f} days old, still fresh (max: {self.max_age_days})")
-            else:
-                logger.info(f"{eid} database not found, will download")
+                    logger.info(f"{edition_id} database not found, will download")
 
-            if need:
-                logger.info(f"Downloading {eid} database...")
-                path = self._download_and_extract(eid)
-                logger.info(f"{eid} database downloaded successfully")
-            else:
-                logger.info(f"Using existing {eid} database")
+                if need_download:
+                    logger.info(f"Downloading {edition_id} database...")
+                    path = await self._download_and_extract(edition_id)
+                    logger.info(f"{edition_id} database downloaded successfully")
+                else:
+                    logger.info(f"Using existing {edition_id} database")
 
-            old = self._readers.get(ed)
-            if old:
-                try:
-                    old.close()
-                except Exception:
-                    pass
-            if path is not None:
-                self._readers[ed] = maxminddb.open_database(path)
+                old_reader = self._readers.get(edition)
+                if old_reader:
+                    with suppress(Exception):
+                        old_reader.close()
+                if path is not None:
+                    self._readers[edition] = maxminddb.open_database(str(path))
 
-    def lookup(self, ip: str):
-        res = {"ip": ip}
-        # City
-        city_r = self._readers.get("City")
-        if city_r:
-            data = city_r.get(ip)
-            if data:
-                country = data.get("country") or {}
-                res["country_iso"] = country.get("iso_code") or ""
-                res["country_name"] = (country.get("names") or {}).get("en", "")
-                city = data.get("city") or {}
-                res["city_name"] = (city.get("names") or {}).get("en", "")
-                loc = data.get("location") or {}
-                res["latitude"] = str(loc.get("latitude") or "")
-                res["longitude"] = str(loc.get("longitude") or "")
-                res["time_zone"] = str(loc.get("time_zone") or "")
-                postal = data.get("postal") or {}
-                if "code" in postal:
-                    res["postal_code"] = postal["code"]
-        # ASN
-        asn_r = self._readers.get("ASN")
-        if asn_r:
-            data = asn_r.get(ip)
-            if data:
-                res["asn"] = data.get("autonomous_system_number")
-                res["organization"] = data.get("autonomous_system_organization")
+    def lookup(self, ip: str) -> GeoIPLookupResult:
+        res: GeoIPLookupResult = {"ip": ip}
+        city_reader = self._readers.get("City")
+        if city_reader:
+            data = city_reader.get(ip)
+            if isinstance(data, dict):
+                country = self._as_mapping(data.get("country"))
+                res["country_iso"] = self._as_str(country.get("iso_code"))
+                country_names = self._as_mapping(country.get("names"))
+                res["country_name"] = self._as_str(country_names.get("en"))
 
+                city = self._as_mapping(data.get("city"))
+                city_names = self._as_mapping(city.get("names"))
+                res["city_name"] = self._as_str(city_names.get("en"))
+
+                location = self._as_mapping(data.get("location"))
+                latitude = location.get("latitude")
+                longitude = location.get("longitude")
+                res["latitude"] = str(latitude) if latitude is not None else ""
+                res["longitude"] = str(longitude) if longitude is not None else ""
+                res["time_zone"] = self._as_str(location.get("time_zone"))
+
+                postal = self._as_mapping(data.get("postal"))
+                postal_code = postal.get("code")
+                if postal_code is not None:
+                    res["postal_code"] = self._as_str(postal_code)
+
+        asn_reader = self._readers.get("ASN")
+        if asn_reader:
+            data = asn_reader.get(ip)
+            if isinstance(data, dict):
+                res["asn"] = self._as_int(data.get("autonomous_system_number"))
+                res["organization"] = self._as_str(data.get("autonomous_system_organization"), default="")
         return res
 
-    def close(self):
-        for r in self._readers.values():
-            try:
-                r.close()
-            except Exception:
-                pass
+    def close(self) -> None:
+        for reader in self._readers.values():
+            with suppress(Exception):
+                reader.close()
         self._readers = {}
 
 
 if __name__ == "__main__":
     # Example usage
-    geo = GeoIPHelper(dest_dir="./geoip", license_key="")
-    geo.update()
-    print(geo.lookup("8.8.8.8"))
-    geo.close()
+    async def _demo() -> None:
+        geo = GeoIPHelper(dest_dir="./geoip", license_key="")
+        await geo.update()
+        print(geo.lookup("8.8.8.8"))
+        geo.close()
+
+    asyncio.run(_demo())
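The rewritten helper consistently pushes blocking filesystem and archive work through `asyncio.to_thread` so the event loop stays responsive; a minimal sketch of that pattern in isolation (paths are illustrative):

    import asyncio
    from pathlib import Path

    def scan_mmdb(directory: Path) -> list[Path]:
        # Blocking disk I/O is fine inside a worker thread.
        return sorted(directory.glob("*.mmdb"))

    async def main() -> None:
        files = await asyncio.to_thread(scan_mmdb, Path("./geoip"))
        print(files)

    asyncio.run(main())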
@@ -97,9 +97,7 @@ class InterceptHandler(logging.Handler):
             status_color = "green"
         elif 300 <= status < 400:
             status_color = "yellow"
-        elif 400 <= status < 500:
-            status_color = "red"
-        elif 500 <= status < 600:
+        elif 400 <= status < 500 or 500 <= status < 600:
             status_color = "red"
 
         return (
@@ -82,7 +82,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
             return await call_next(request)
 
         # Start the verification flow
-        return await self._initiate_verification(request, session_state)
+        return await self._initiate_verification(session_state)
 
     def _should_skip_verification(self, request: Request) -> bool:
         """Check whether verification should be skipped"""
@@ -93,10 +93,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
             return True
 
         # Skip non-API requests
-        if not path.startswith("/api/"):
-            return True
-
-        return False
+        return bool(not path.startswith("/api/"))
 
     def _requires_verification(self, request: Request, user: User) -> bool:
         """Check whether verification is required"""
@@ -177,7 +174,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
             logger.error(f"Error getting session state: {e}")
             return None
 
-    async def _initiate_verification(self, request: Request, state: SessionState) -> Response:
+    async def _initiate_verification(self, state: SessionState) -> Response:
         """Start the verification flow"""
         try:
             method = await state.get_method()
@@ -11,7 +11,7 @@ class ExtendedTokenResponse(BaseModel):
     """Extended token response supporting second-factor verification state"""
 
     access_token: str | None = None
-    token_type: str = "Bearer"
+    token_type: str = "Bearer"  # noqa: S105
     expires_in: int | None = None
     refresh_token: str | None = None
     scope: str | None = None
@@ -20,14 +20,3 @@ class ExtendedTokenResponse(BaseModel):
     requires_second_factor: bool = False
     verification_message: str | None = None
     user_id: int | None = None  # user ID used for second-factor verification
-
-
-class SessionState(BaseModel):
-    """Session state"""
-
-    user_id: int
-    username: str
-    email: str
-    requires_verification: bool
-    session_token: str | None = None
-    verification_sent: bool = False
@@ -1,3 +1,4 @@
+# ruff: noqa: ARG002
 from __future__ import annotations
 
 from abc import abstractmethod
@@ -22,7 +22,7 @@ class TokenRequest(BaseModel):
 
 class TokenResponse(BaseModel):
     access_token: str
-    token_type: str = "Bearer"
+    token_type: str = "Bearer"  # noqa: S105
     expires_in: int
     refresh_token: str
     scope: str = "*"
@@ -67,7 +67,7 @@ class RegistrationRequestErrors(BaseModel):
 class OAuth2ClientCredentialsBearer(OAuth2):
     def __init__(
         self,
-        tokenUrl: Annotated[
+        tokenUrl: Annotated[  # noqa: N803
             str,
             Doc(
                 """
@@ -75,7 +75,7 @@ class OAuth2ClientCredentialsBearer(OAuth2):
                 """
             ),
         ],
-        refreshUrl: Annotated[
+        refreshUrl: Annotated[  # noqa: N803
             str | None,
             Doc(
                 """
@@ -46,10 +46,10 @@ class PlayerStatsResponse(BaseModel):
 class PlayerEventItem(BaseModel):
     """Player event item"""
 
-    userId: int
+    userId: int  # noqa: N815
     name: str
-    mapId: int | None = None
-    setId: int | None = None
+    mapId: int | None = None  # noqa: N815
+    setId: int | None = None  # noqa: N815
     artist: str | None = None
     title: str | None = None
     version: str | None = None
@@ -88,7 +88,7 @@ class PlayerInfo(BaseModel):
     custom_badge_icon: str
     custom_badge_color: str
     userpage_content: str
-    recentFailed: int
+    recentFailed: int  # noqa: N815
     social_discord: str | None = None
     social_youtube: str | None = None
     social_twitter: str | None = None
@@ -126,21 +126,22 @@ async def register_user(
 
     try:
         # Get the client IP and look up its geolocation
-        country_code = "CN"  # default country code
+        country_code = None  # default country code
 
         try:
            # Look up the IP's geolocation
             geo_info = geoip.lookup(client_ip)
-            if geo_info and geo_info.get("country_iso"):
-                country_code = geo_info["country_iso"]
+            if geo_info and (country_code := geo_info.get("country_iso")):
                 logger.info(f"User {user_username} registering from {client_ip}, country: {country_code}")
             else:
                 logger.warning(f"Could not determine country for IP {client_ip}")
         except Exception as e:
             logger.warning(f"GeoIP lookup failed for {client_ip}: {e}")
+        if country_code is None:
+            country_code = "CN"
 
         # Create the new user
-        # Ensure AUTO_INCREMENT starts at 3 (ID=1 is BanchoBot, ID=2 is reserved for ppy)
+        # Ensure AUTO_INCREMENT starts at 3 (ID=2 is BanchoBot)
         result = await db.execute(
             text(
                 "SELECT AUTO_INCREMENT FROM information_schema.TABLES "
@@ -157,7 +158,7 @@ async def register_user(
             email=user_email,
             pw_bcrypt=get_password_hash(user_password),
             priv=1,  # regular user privileges
-            country_code=country_code,  # set the country from the IP's geolocation
+            country_code=country_code,
             join_date=utcnow(),
             last_visit=utcnow(),
             is_supporter=settings.enable_supporter_for_all_users,
@@ -386,7 +387,7 @@ async def oauth_token(
 
     return TokenResponse(
         access_token=access_token,
-        token_type="Bearer",
+        token_type="Bearer",  # noqa: S106
         expires_in=settings.access_token_expire_minutes * 60,
         refresh_token=refresh_token_str,
         scope=scope,
@@ -439,7 +440,7 @@ async def oauth_token(
         )
         return TokenResponse(
             access_token=access_token,
-            token_type="Bearer",
+            token_type="Bearer",  # noqa: S106
             expires_in=settings.access_token_expire_minutes * 60,
             refresh_token=new_refresh_token,
             scope=scope,
@@ -509,7 +510,7 @@ async def oauth_token(
 
         return TokenResponse(
             access_token=access_token,
-            token_type="Bearer",
+            token_type="Bearer",  # noqa: S106
             expires_in=settings.access_token_expire_minutes * 60,
             refresh_token=refresh_token_str,
             scope=" ".join(scopes),
@@ -554,7 +555,7 @@ async def oauth_token(
 
         return TokenResponse(
             access_token=access_token,
-            token_type="Bearer",
+            token_type="Bearer",  # noqa: S106
             expires_in=settings.access_token_expire_minutes * 60,
             refresh_token=refresh_token_str,
             scope=" ".join(scopes),
@@ -130,7 +130,7 @@ def _coerce_playlist_item(item_data: dict[str, Any], default_order: int, host_us
         "allowed_mods": item_data.get("allowed_mods", []),
         "expired": bool(item_data.get("expired", False)),
         "playlist_order": item_data.get("playlist_order", default_order),
-        "played_at": item_data.get("played_at", None),
+        "played_at": item_data.get("played_at"),
         "freestyle": bool(item_data.get("freestyle", True)),
         "beatmap_checksum": item_data.get("beatmap_checksum", ""),
         "star_rating": item_data.get("star_rating", 0.0),
@@ -157,10 +157,7 @@ async def _help(user: User, args: list[str], _session: AsyncSession, channel: Ch
 
 @bot.command("roll")
 def _roll(user: User, args: list[str], _session: AsyncSession, channel: ChatChannel) -> str:
-    if len(args) > 0 and args[0].isdigit():
-        r = random.randint(1, int(args[0]))
-    else:
-        r = random.randint(1, 100)
+    r = random.randint(1, int(args[0])) if len(args) > 0 and args[0].isdigit() else random.randint(1, 100)
     return f"{user.username} rolls {r} point(s)"
 
 
@@ -179,10 +176,7 @@ async def _stats(user: User, args: list[str], session: AsyncSession, channel: Ch
     if gamemode is None:
         subquery = select(func.max(Score.id)).where(Score.user_id == target_user.id).scalar_subquery()
         last_score = (await session.exec(select(Score).where(Score.id == subquery))).first()
-        if last_score is not None:
-            gamemode = last_score.gamemode
-        else:
-            gamemode = target_user.playmode
+        gamemode = last_score.gamemode if last_score is not None else target_user.playmode
 
     statistics = (
         await session.exec(
@@ -313,10 +313,7 @@ async def chat_websocket(
     # Prefer the token from the query parameters; both token and access_token parameter names are supported
     auth_token = token or access_token
     if not auth_token and authorization:
-        if authorization.startswith("Bearer "):
-            auth_token = authorization[7:]
-        else:
-            auth_token = authorization
+        auth_token = authorization.removeprefix("Bearer ")
 
     if not auth_token:
         await websocket.close(code=1008, reason="Missing authentication token")
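`str.removeprefix` (Python 3.9+) covers both branches of the deleted if/else in one call, since it returns the string unchanged when the prefix is absent:

    assert "Bearer abc123".removeprefix("Bearer ") == "abc123"
    assert "abc123".removeprefix("Bearer ") == "abc123"  # no prefix: returned as-is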
@@ -10,7 +10,7 @@ from fastapi.responses import RedirectResponse
 redirect_router = APIRouter(include_in_schema=False)
 
 
-@redirect_router.get("/users/{path:path}")
+@redirect_router.get("/users/{path:path}")  # noqa: FAST003
 @redirect_router.get("/teams/{team_id}")
 @redirect_router.get("/u/{user_id}")
 @redirect_router.get("/b/{beatmap_id}")
@@ -168,10 +168,7 @@ async def get_beatmaps(
     elif beatmapset_id is not None:
         beatmapset = await Beatmapset.get_or_fetch(session, fetcher, beatmapset_id)
         await beatmapset.awaitable_attrs.beatmaps
-        if len(beatmapset.beatmaps) > limit:
-            beatmaps = beatmapset.beatmaps[:limit]
-        else:
-            beatmaps = beatmapset.beatmaps
+        beatmaps = beatmapset.beatmaps[:limit] if len(beatmapset.beatmaps) > limit else beatmapset.beatmaps
     elif user is not None:
         where = Beatmapset.user_id == user if type == "id" or user.isdigit() else Beatmapset.creator == user
         beatmapsets = (await session.exec(select(Beatmapset).where(where))).all()
@@ -158,7 +158,10 @@ async def get_beatmap_attributes(
     if ruleset is None:
         beatmap_db = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
         ruleset = beatmap_db.mode
-    key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
+    key = (
+        f"beatmap:{beatmap_id}:{ruleset}:"
+        f"{hashlib.md5(str(mods_).encode(), usedforsecurity=False).hexdigest()}:attributes"
+    )
     if await redis.exists(key):
         return BeatmapAttributes.model_validate_json(await redis.get(key))  # pyright: ignore[reportArgumentType]
     try:
@@ -46,7 +46,6 @@ async def _save_to_db(sets: SearchBeatmapsetsResp):
     response_model=SearchBeatmapsetsResp,
 )
 async def search_beatmapset(
     db: Database,
     query: Annotated[SearchQueryModel, Query(...)],
-    request: Request,
     background_tasks: BackgroundTasks,
@@ -104,7 +103,7 @@ async def search_beatmapset(
     if cached_result:
         sets = SearchBeatmapsetsResp(**cached_result)
         # Handle asset proxying
-        processed_sets = await process_response_assets(sets, request)
+        processed_sets = await process_response_assets(sets)
         return processed_sets
 
     try:
@@ -115,7 +114,7 @@ async def search_beatmapset(
         await cache_service.cache_search_result(query_hash, cursor_hash, sets.model_dump())
 
         # Handle asset proxying
-        processed_sets = await process_response_assets(sets, request)
+        processed_sets = await process_response_assets(sets)
         return processed_sets
     except HTTPError as e:
         raise HTTPException(status_code=500, detail=str(e)) from e
@@ -140,7 +139,7 @@ async def lookup_beatmapset(
     cached_resp = await cache_service.get_beatmap_lookup_from_cache(beatmap_id)
     if cached_resp:
         # Handle asset proxying
-        processed_resp = await process_response_assets(cached_resp, request)
+        processed_resp = await process_response_assets(cached_resp)
         return processed_resp
 
     try:
@@ -151,7 +150,7 @@ async def lookup_beatmapset(
         await cache_service.cache_beatmap_lookup(beatmap_id, resp)
 
         # Handle asset proxying
-        processed_resp = await process_response_assets(resp, request)
+        processed_resp = await process_response_assets(resp)
         return processed_resp
     except HTTPError as exc:
         raise HTTPException(status_code=404, detail="Beatmap not found") from exc
@@ -176,7 +175,7 @@ async def get_beatmapset(
     cached_resp = await cache_service.get_beatmapset_from_cache(beatmapset_id)
     if cached_resp:
         # Handle asset proxying
-        processed_resp = await process_response_assets(cached_resp, request)
+        processed_resp = await process_response_assets(cached_resp)
         return processed_resp
 
     try:
@@ -187,7 +186,7 @@ async def get_beatmapset(
         await cache_service.cache_beatmapset(resp)
 
         # Handle asset proxying
-        processed_resp = await process_response_assets(resp, request)
+        processed_resp = await process_response_assets(resp)
         return processed_resp
     except HTTPError as exc:
         raise HTTPException(status_code=404, detail="Beatmapset not found") from exc
@@ -166,7 +166,6 @@ async def get_room(
     db: Database,
     room_id: Annotated[int, Path(..., description="房间 ID")],
     current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
     redis: Redis,
     category: Annotated[
         str,
         Query(
@@ -847,10 +847,7 @@ async def reorder_score_pin(
         detail = "After score not found" if after_score_id else "Before score not found"
         raise HTTPException(status_code=404, detail=detail)
 
-    if after_score_id:
-        target_order = reference_score.pinned_order + 1
-    else:
-        target_order = reference_score.pinned_order
+    target_order = reference_score.pinned_order + 1 if after_score_id else reference_score.pinned_order
 
     current_order = score_record.pinned_order
 
@@ -40,7 +40,7 @@ class SessionReissueResponse(BaseModel):
     message: str
 
 
-class VerifyFailed(Exception):
+class VerifyFailedError(Exception):
     def __init__(self, message: str, reason: str | None = None, should_reissue: bool = False):
         super().__init__(message)
         self.reason = reason
@@ -93,10 +93,7 @@ async def verify_session(
             # Smart selection of the verification method (following the osu-web implementation)
             # Force email verification for older API versions or users without TOTP
             # print(api_version, totp_key)
-            if api_version < 20240101 or totp_key is None:
-                verify_method = "mail"
-            else:
-                verify_method = "totp"
+            verify_method = "mail" if api_version < 20240101 or totp_key is None else "totp"
             await LoginSessionService.set_login_method(user_id, token_id, verify_method, redis)
             login_method = verify_method
 
@@ -109,7 +106,7 @@ async def verify_session(
                     db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent
                 )
                 verify_method = "mail"
-                raise VerifyFailed("用户TOTP已被删除,已切换到邮件验证")
+                raise VerifyFailedError("用户TOTP已被删除,已切换到邮件验证")
             # If email verification is not enabled, treat authentication as passed
             # Normally this branch is never reached
 
@@ -120,16 +117,16 @@ async def verify_session(
             else:
                 # Record the detailed verification failure reason (following osu-web's error handling)
                 if len(verification_key) != 6:
-                    raise VerifyFailed("TOTP验证码长度错误,应为6位数字", reason="incorrect_length")
+                    raise VerifyFailedError("TOTP验证码长度错误,应为6位数字", reason="incorrect_length")
                 elif not verification_key.isdigit():
-                    raise VerifyFailed("TOTP验证码格式错误,应为纯数字", reason="incorrect_format")
+                    raise VerifyFailedError("TOTP验证码格式错误,应为纯数字", reason="incorrect_format")
                 else:
                     # Possibly a wrong key or a replay attack
-                    raise VerifyFailed("TOTP 验证失败,请检查验证码是否正确且未过期", reason="incorrect_key")
+                    raise VerifyFailedError("TOTP 验证失败,请检查验证码是否正确且未过期", reason="incorrect_key")
         else:
             success, message = await EmailVerificationService.verify_email_code(db, redis, user_id, verification_key)
             if not success:
-                raise VerifyFailed(f"邮件验证失败: {message}")
+                raise VerifyFailedError(f"邮件验证失败: {message}")
 
         await LoginLogService.record_login(
             db=db,
@@ -144,7 +141,7 @@ async def verify_session(
         await db.commit()
         return Response(status_code=status.HTTP_204_NO_CONTENT)
 
-    except VerifyFailed as e:
+    except VerifyFailedError as e:
         await LoginLogService.record_failed_login(
             db=db,
             request=request,
@@ -171,7 +168,9 @@ async def verify_session(
                     )
                     error_response["reissued"] = True
                 except Exception:
-                    pass  # ignore failures when resending the verification email
+                    log("Verification").exception(
+                        f"Failed to resend verification email to user {current_user.id} (token: {token_id})"
+                    )
 
         return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=error_response)
 
@@ -44,9 +44,7 @@ async def check_user_can_vote(user: User, beatmap_id: int, session: AsyncSession
             .where(col(Score.beatmap).has(col(Beatmap.mode) == Score.gamemode))
         )
     ).first()
-    if user_beatmap_score is None:
-        return False
-    return True
+    return user_beatmap_score is not None
 
 
 @router.put(
@@ -75,10 +73,9 @@ async def vote_beatmap_tags(
                 .where(BeatmapTagVote.user_id == current_user.id)
             )
         ).first()
-        if previous_votes is None:
-            if check_user_can_vote(current_user, beatmap_id, session):
-                new_vote = BeatmapTagVote(tag_id=tag_id, beatmap_id=beatmap_id, user_id=current_user.id)
-                session.add(new_vote)
+        if previous_votes is None and check_user_can_vote(current_user, beatmap_id, session):
+            new_vote = BeatmapTagVote(tag_id=tag_id, beatmap_id=beatmap_id, user_id=current_user.id)
+            session.add(new_vote)
         await session.commit()
     except ValueError:
         raise HTTPException(400, "Tag is not found")
@@ -91,7 +91,7 @@ async def get_users(
 
         # Handle asset proxying
         response = BatchUserResponse(users=cached_users)
-        processed_response = await process_response_assets(response, request)
+        processed_response = await process_response_assets(response)
         return processed_response
     else:
         searched_users = (await session.exec(select(User).limit(50))).all()
@@ -109,7 +109,7 @@ async def get_users(
 
     # Handle asset proxying
     response = BatchUserResponse(users=users)
-    processed_response = await process_response_assets(response, request)
+    processed_response = await process_response_assets(response)
     return processed_response
 
 
@@ -240,7 +240,7 @@ async def get_user_info(
     cached_user = await cache_service.get_user_from_cache(user_id_int)
     if cached_user:
         # Handle asset proxying
-        processed_user = await process_response_assets(cached_user, request)
+        processed_user = await process_response_assets(cached_user)
         return processed_user
 
     searched_user = (
@@ -263,7 +263,7 @@ async def get_user_info(
     background_task.add_task(cache_service.cache_user, user_resp)
 
     # Handle asset proxying
-    processed_user = await process_response_assets(user_resp, request)
+    processed_user = await process_response_assets(user_resp)
     return processed_user
 
 
@@ -381,7 +381,7 @@ async def get_user_scores(
         user_id, type, include_fails, mode, limit, offset, is_legacy_api
     )
     if cached_scores is not None:
-        processed_scores = await process_response_assets(cached_scores, request)
+        processed_scores = await process_response_assets(cached_scores)
         return processed_scores
 
     db_user = await session.get(User, user_id)
@@ -438,5 +438,5 @@ async def get_user_scores(
     )
 
     # Handle asset proxying
-    processed_scores = await process_response_assets(score_responses, request)
+    processed_scores = await process_response_assets(score_responses)
     return processed_scores
@@ -12,7 +12,7 @@ from app.service.asset_proxy_service import get_asset_proxy_service
     from fastapi import Request
 
 
-async def process_response_assets(data: Any, request: Request) -> Any:
+async def process_response_assets(data: Any) -> Any:
     """
     Process asset URLs in the response data according to configuration
 
@@ -72,7 +72,7 @@ def asset_proxy_response(func):
 
     # If a request object exists and asset proxying is enabled, process the response
     if request and settings.enable_asset_proxy and should_process_asset_proxy(request.url.path):
-        result = await process_response_assets(result, request)
+        result = await process_response_assets(result)
 
     return result
 
@@ -113,6 +113,7 @@ class BeatmapCacheService:
                     if size:
                         total_size += size
                 except Exception:
+                    logger.debug(f"Failed to get size for key {key}")
                     continue
 
             return {
@@ -36,11 +36,8 @@ def safe_json_dumps(data) -> str:
 
 def generate_hash(data) -> str:
     """Generate the MD5 hash of the data"""
-    if isinstance(data, str):
-        content = data
-    else:
-        content = safe_json_dumps(data)
-    return hashlib.md5(content.encode()).hexdigest()
+    content = data if isinstance(data, str) else safe_json_dumps(data)
+    return hashlib.md5(content.encode(), usedforsecurity=False).hexdigest()
 
 
 class BeatmapsetCacheService:
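`usedforsecurity=False` (available since Python 3.9) declares the digest a non-cryptographic fingerprint, which satisfies ruff's S324 without a noqa comment:

    import hashlib

    # Cache-key fingerprint, not a credential, so MD5 is acceptable here.
    cache_hash = hashlib.md5(b"payload", usedforsecurity=False).hexdigest()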
@@ -110,9 +110,7 @@ class ProcessingBeatmapset:
         changed_beatmaps = []
         for bm in self.beatmapset.beatmaps:
             saved = next((s for s in self.record.beatmaps if s["beatmap_id"] == bm.id), None)
-            if not saved:
-                changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_ADDED))
-            elif saved["is_deleted"]:
+            if not saved or saved["is_deleted"]:
                 changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_ADDED))
             elif saved["md5"] != bm.checksum:
                 changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_UPDATED))
@@ -285,7 +283,7 @@ class BeatmapsetUpdateService:
     async def _process_changed_beatmapset(self, beatmapset: BeatmapsetResp):
         async with with_db() as session:
             db_beatmapset = await session.get(Beatmapset, beatmapset.id)
-            new_beatmapset = await Beatmapset.from_resp_no_save(session, beatmapset)
+            new_beatmapset = await Beatmapset.from_resp_no_save(beatmapset)
             if db_beatmapset:
                 await session.merge(new_beatmapset)
                 await session.commit()
@@ -356,5 +354,7 @@ def init_beatmapset_update_service(fetcher: "Fetcher") -> BeatmapsetUpdateServic
|
||||
|
||||
|
||||
def get_beatmapset_update_service() -> BeatmapsetUpdateService:
|
||||
if service is None:
|
||||
raise ValueError("BeatmapsetUpdateService is not initialized")
|
||||
assert service is not None, "BeatmapsetUpdateService is not initialized"
|
||||
return service
|
||||
|
||||
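One caveat on the hunk above: `assert` narrows the Optional for type checkers,
but assert statements are stripped when Python runs with `-O`, so the guard
only holds in non-optimized runs. A sketch of the alternative that survives
`-O`, using a hypothetical module-level singleton:

    service: object | None = None  # hypothetical singleton, set by an init function

    def get_service() -> object:
        if service is None:  # an explicit raise is kept even under `python -O`
            raise RuntimeError("service is not initialized")
        return service
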
@@ -128,7 +128,11 @@ class LoginLogService:
login_success=False,
login_method=login_method,
user_agent=user_agent,
notes=f"Failed login attempt: {attempted_username}" if attempted_username else "Failed login attempt",
notes=(
f"Failed login attempt on user {attempted_username}: {notes}"
if attempted_username
else "Failed login attempt"
),
)


@@ -120,7 +120,7 @@ class PasswordResetService:
await redis.delete(reset_code_key)
await redis.delete(rate_limit_key)
except Exception:
pass
logger.warning("Failed to clean up Redis data after error")
logger.exception("Redis operation failed")
return False, "Service temporarily unavailable, please try again later"


@@ -593,10 +593,7 @@ class RankingCacheService:
async def invalidate_country_cache(self, ruleset: GameMode | None = None) -> None:
"""Invalidate the country ranking cache"""
try:
if ruleset:
pattern = f"country_ranking:{ruleset}:*"
else:
pattern = "country_ranking:*"
pattern = f"country_ranking:{ruleset}:*" if ruleset else "country_ranking:*"

keys = await self.redis.keys(pattern)
if keys:

@@ -608,10 +605,7 @@ class RankingCacheService:
async def invalidate_team_cache(self, ruleset: GameMode | None = None) -> None:
"""Invalidate the team ranking cache"""
try:
if ruleset:
pattern = f"team_ranking:{ruleset}:*"
else:
pattern = "team_ranking:*"
pattern = f"team_ranking:{ruleset}:*" if ruleset else "team_ranking:*"

keys = await self.redis.keys(pattern)
if keys:

@@ -637,6 +631,7 @@ class RankingCacheService:
if size:
total_size += size
except Exception:
logger.warning(f"Failed to get memory usage for key {key}")
continue

return {

0  app/service/subscribers/__init__.py  Normal file
@@ -35,19 +35,19 @@ class ChatSubscriber(RedisSubscriber):
self.add_handler(ON_NOTIFICATION, self.on_notification)
self.start()

async def on_join_room(self, c: str, s: str):
async def on_join_room(self, c: str, s: str): # noqa: ARG002
channel_id, user_id = s.split(":")
if self.chat_server is None:
return
await self.chat_server.join_room_channel(int(channel_id), int(user_id))

async def on_leave_room(self, c: str, s: str):
async def on_leave_room(self, c: str, s: str): # noqa: ARG002
channel_id, user_id = s.split(":")
if self.chat_server is None:
return
await self.chat_server.leave_room_channel(int(channel_id), int(user_id))

async def on_notification(self, c: str, s: str):
async def on_notification(self, c: str, s: str): # noqa: ARG002
try:
detail = TypeAdapter(NotificationDetails).validate_json(s)
except ValueError:

@@ -357,6 +357,7 @@ class UserCacheService:
if size:
total_size += size
except Exception:
logger.warning(f"Failed to get memory usage for key {key}")
continue

return {

@@ -288,10 +288,6 @@ This email was sent automatically, please do not reply.
redis: Redis,
user_id: int,
code: str,
ip_address: str | None = None,
user_agent: str | None = None,
client_id: int | None = None,
country_code: str | None = None,
) -> tuple[bool, str]:
"""Verify the email verification code"""
try:

@@ -41,7 +41,7 @@ async def warmup_cache() -> None:
logger.info("Beatmap cache warmup completed successfully")

except Exception as e:
logger.error("Beatmap cache warmup failed: %s", e)
logger.error(f"Beatmap cache warmup failed: {e}")


async def refresh_ranking_cache() -> None:
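This and the following hunks swap `%s`-style logging calls for f-strings. If
the logger is loguru (suggested by `logger.success` in a later hunk), the old
calls never interpolated at all, since loguru formats with `str.format`-style
braces rather than `%` placeholders. A sketch of the three styles:

    from loguru import logger

    err = ValueError("boom")
    logger.error("warmup failed: %s", err)   # loguru leaves "%s" literal, err unused
    logger.error(f"warmup failed: {err}")    # eager f-string, the style adopted here
    logger.error("warmup failed: {}", err)   # loguru-native deferred formatting
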
@@ -59,7 +59,7 @@ async def refresh_ranking_cache() -> None:
logger.info("Ranking cache refresh completed successfully")

except Exception as e:
logger.error("Ranking cache refresh failed: %s", e)
logger.error(f"Ranking cache refresh failed: {e}")


async def schedule_user_cache_preload_task() -> None:

@@ -93,14 +93,14 @@ async def schedule_user_cache_preload_task() -> None:
if active_user_ids:
user_ids = [row[0] for row in active_user_ids]
await cache_service.preload_user_cache(session, user_ids)
logger.info("Preloaded cache for %s active users", len(user_ids))
logger.info(f"Preloaded cache for {len(user_ids)} active users")
else:
logger.info("No active users found for cache preload")

logger.info("User cache preload task completed successfully")

except Exception as e:
logger.error("User cache preload task failed: %s", e)
logger.error(f"User cache preload task failed: {e}")


async def schedule_user_cache_warmup_task() -> None:

@@ -131,18 +131,18 @@ async def schedule_user_cache_warmup_task() -> None:
if top_users:
user_ids = list(top_users)
await cache_service.preload_user_cache(session, user_ids)
logger.info("Warmed cache for top 100 users in %s", mode)
logger.info(f"Warmed cache for top 100 users in {mode}")

await asyncio.sleep(1)

except Exception as e:
logger.error("Failed to warm cache for %s: %s", mode, e)
logger.error(f"Failed to warm cache for {mode}: {e}")
continue

logger.info("User cache warmup task completed successfully")

except Exception as e:
logger.error("User cache warmup task failed: %s", e)
logger.error(f"User cache warmup task failed: {e}")


async def schedule_user_cache_cleanup_task() -> None:

@@ -155,11 +155,11 @@ async def schedule_user_cache_cleanup_task() -> None:
cache_service = get_user_cache_service(redis)
stats = await cache_service.get_cache_stats()

logger.info("User cache stats: %s", stats)
logger.info(f"User cache stats: {stats}")
logger.info("User cache cleanup task completed successfully")

except Exception as e:
logger.error("User cache cleanup task failed: %s", e)
logger.error(f"User cache cleanup task failed: {e}")


async def warmup_user_cache() -> None:

@@ -167,7 +167,7 @@ async def warmup_user_cache() -> None:
try:
await schedule_user_cache_warmup_task()
except Exception as e:
logger.error("User cache warmup failed: %s", e)
logger.error(f"User cache warmup failed: {e}")


async def preload_user_cache() -> None:

@@ -175,7 +175,7 @@ async def preload_user_cache() -> None:
try:
await schedule_user_cache_preload_task()
except Exception as e:
logger.error("User cache preload failed: %s", e)
logger.error(f"User cache preload failed: {e}")


async def cleanup_user_cache() -> None:

@@ -183,7 +183,7 @@ async def cleanup_user_cache() -> None:
try:
await schedule_user_cache_cleanup_task()
except Exception as e:
logger.error("User cache cleanup failed: %s", e)
logger.error(f"User cache cleanup failed: {e}")


def register_cache_jobs() -> None:

@@ -5,8 +5,6 @@ Periodically update the MaxMind GeoIP database

from __future__ import annotations

import asyncio

from app.config import settings
from app.dependencies.geoip import get_geoip_helper
from app.dependencies.scheduler import get_scheduler

@@ -28,14 +26,10 @@ async def update_geoip_database():
try:
logger.info("Starting scheduled GeoIP database update...")
geoip = get_geoip_helper()

# Run the synchronous update method in a background thread
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, lambda: geoip.update(force=False))

await geoip.update(force=False)
logger.info("Scheduled GeoIP database update completed successfully")
except Exception as e:
logger.error(f"Scheduled GeoIP database update failed: {e}")
except Exception as exc:
logger.error(f"Scheduled GeoIP database update failed: {exc}")


async def init_geoip():

@@ -45,13 +39,8 @@ async def init_geoip():
try:
geoip = get_geoip_helper()
logger.info("Initializing GeoIP database...")

# Run the synchronous update method in a background thread
# force=False means only download if files don't exist or are expired
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, lambda: geoip.update(force=False))

await geoip.update(force=False)
logger.info("GeoIP database initialization completed")
except Exception as e:
logger.error(f"GeoIP database initialization failed: {e}")
except Exception as exc:
logger.error(f"GeoIP database initialization failed: {exc}")
# Do not raise an exception to avoid blocking application startup

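The two GeoIP hunks drop the executor indirection and await `geoip.update`
directly, which assumes the helper's update method is now a coroutine. For a
still-blocking function, `asyncio.to_thread` is the modern equivalent of the
removed pattern; a sketch with a hypothetical blocking `update`:

    import asyncio
    import time

    def update(force: bool = False) -> None:
        time.sleep(0.1)  # stand-in for a blocking database download

    async def main() -> None:
        # equivalent of loop.run_in_executor(None, lambda: update(force=False))
        await asyncio.to_thread(update, force=False)

    asyncio.run(main())
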
@@ -16,7 +16,7 @@ async def create_rx_statistics():
async with with_db() as session:
users = (await session.exec(select(User.id))).all()
total_users = len(users)
logger.info("Ensuring RX/AP statistics exist for %s users", total_users)
logger.info(f"Ensuring RX/AP statistics exist for {total_users} users")
rx_created = 0
ap_created = 0
for i in users:

@@ -57,7 +57,5 @@ async def create_rx_statistics():
await session.commit()
if rx_created or ap_created:
logger.success(
"Created %s RX statistics rows and %s AP statistics rows during backfill",
rx_created,
ap_created,
f"Created {rx_created} RX statistics rows and {ap_created} AP statistics rows during backfill"
)

@@ -258,10 +258,7 @@ class BackgroundTasks:
self.tasks = set(tasks) if tasks else set()

def add_task(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None:
if is_async_callable(func):
coro = func(*args, **kwargs)
else:
coro = run_in_threadpool(func, *args, **kwargs)
coro = func(*args, **kwargs) if is_async_callable(func) else run_in_threadpool(func, *args, **kwargs)
task = asyncio.create_task(coro)
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)

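The add_task body above is the standard fire-and-forget pattern: the set holds
a strong reference so a running task cannot be garbage-collected mid-flight,
and the done-callback evicts it once finished. A self-contained sketch:

    import asyncio

    tasks: set[asyncio.Task] = set()

    async def job() -> None:
        await asyncio.sleep(0.1)

    async def main() -> None:
        task = asyncio.create_task(job())
        tasks.add(task)                        # keep a strong reference
        task.add_done_callback(tasks.discard)  # drop it when the task completes
        await asyncio.sleep(0.2)               # let the background task finish

    asyncio.run(main())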