chore(linter): update ruff rules
@@ -32,11 +32,9 @@ async def process_streak(
     ).first()
     if not stats:
         return False
-    if streak <= stats.daily_streak_best < next_streak:
-        return True
-    elif next_streak == 0 and stats.daily_streak_best >= streak:
-        return True
-    return False
+    return bool(
+        streak <= stats.daily_streak_best < next_streak or (next_streak == 0 and stats.daily_streak_best >= streak)
+    )


 MEDALS = {
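Review note: this commit applies the same ruff SIM103-style rewrite (collapse a return-`True`/`False` ladder into one boolean expression) across all of the medal processors below. A minimal sketch of the equivalence, using standalone functions with made-up arguments rather than the repo's models:

```python
def old_style(streak: int, best: int, next_streak: int) -> bool:
    if streak <= best < next_streak:
        return True
    elif next_streak == 0 and best >= streak:
        return True
    return False


def new_style(streak: int, best: int, next_streak: int) -> bool:
    return bool(streak <= best < next_streak or (next_streak == 0 and best >= streak))


# The two agree on every input; spot-check the "final tier" (next_streak == 0) case:
assert old_style(10, 15, 0) is new_style(10, 15, 0) is True
assert old_style(10, 5, 20) is new_style(10, 5, 20) is False
```

The chained comparison already yields a `bool`, so the outer `bool(...)` is purely defensive.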
@@ -68,9 +68,7 @@ async def to_the_core(
     if ("Nightcore" not in beatmap.beatmapset.title) and "Nightcore" not in beatmap.beatmapset.artist:
         return False
     mods_ = mod_to_save(score.mods)
-    if "DT" not in mods_ or "NC" not in mods_:
-        return False
-    return True
+    return not ("DT" not in mods_ or "NC" not in mods_)


 async def wysi(
@@ -83,9 +81,7 @@ async def wysi(
         return False
     if str(round(score.accuracy, ndigits=4))[3:] != "727":
         return False
-    if "xi" not in beatmap.beatmapset.artist:
-        return False
-    return True
+    return "xi" in beatmap.beatmapset.artist


 async def prepared(
@@ -97,9 +93,7 @@ async def prepared(
     if score.rank != Rank.X and score.rank != Rank.XH:
         return False
     mods_ = mod_to_save(score.mods)
-    if "NF" not in mods_:
-        return False
-    return True
+    return "NF" in mods_


 async def reckless_adandon(
@@ -117,9 +111,7 @@ async def reckless_adandon(
     redis = get_redis()
     mods_ = score.mods.copy()
     attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
-    if attribute.star_rating < 3:
-        return False
-    return True
+    return not attribute.star_rating < 3


 async def lights_out(
@@ -413,11 +405,10 @@ async def by_the_skin_of_the_teeth(
         return False

     for mod in score.mods:
-        if mod.get("acronym") == "AC":
-            if "settings" in mod and "minimum_accuracy" in mod["settings"]:
-                target_accuracy = mod["settings"]["minimum_accuracy"]
-                if isinstance(target_accuracy, int | float):
-                    return abs(score.accuracy - float(target_accuracy)) < 0.0001
+        if mod.get("acronym") == "AC" and "settings" in mod and "minimum_accuracy" in mod["settings"]:
+            target_accuracy = mod["settings"]["minimum_accuracy"]
+            if isinstance(target_accuracy, int | float):
+                return abs(score.accuracy - float(target_accuracy)) < 0.0001
     return False

@@ -19,9 +19,7 @@ async def process_mod(
         return False
     if not beatmap.beatmap_status.has_leaderboard():
         return False
-    if len(score.mods) != 1 or score.mods[0]["acronym"] != mod:
-        return False
-    return True
+    return not (len(score.mods) != 1 or score.mods[0]["acronym"] != mod)


 async def process_category_mod(
@@ -22,11 +22,7 @@ async def process_combo(
         return False
     if next_combo != 0 and combo >= next_combo:
         return False
-    if combo <= score.max_combo < next_combo:
-        return True
-    elif next_combo == 0 and score.max_combo >= combo:
-        return True
-    return False
+    return bool(combo <= score.max_combo < next_combo or (next_combo == 0 and score.max_combo >= combo))


 MEDALS: Medals = {
@@ -35,11 +35,7 @@ async def process_playcount(
     ).first()
     if not stats:
         return False
-    if pc <= stats.play_count < next_pc:
-        return True
-    elif next_pc == 0 and stats.play_count >= pc:
-        return True
-    return False
+    return bool(pc <= stats.play_count < next_pc or (next_pc == 0 and stats.play_count >= pc))


 MEDALS: Medals = {
@@ -47,9 +47,7 @@ async def process_skill(
     attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
     if attribute.star_rating < star or attribute.star_rating >= star + 1:
         return False
-    if type == "fc" and not score.is_perfect_combo:
-        return False
-    return True
+    return not (type == "fc" and not score.is_perfect_combo)


 MEDALS: Medals = {
@@ -35,11 +35,7 @@ async def process_tth(
     ).first()
     if not stats:
         return False
-    if tth <= stats.total_hits < next_tth:
-        return True
-    elif next_tth == 0 and stats.play_count >= tth:
-        return True
-    return False
+    return bool(tth <= stats.total_hits < next_tth or (next_tth == 0 and stats.play_count >= tth))


 MEDALS: Medals = {
app/auth.py (13 changes)
@@ -69,7 +69,7 @@ def verify_password_legacy(plain_password: str, bcrypt_hash: str) -> bool:
     2. MD5 hash -> bcrypt verification
     """
     # 1. Convert the plaintext password to MD5
-    pw_md5 = hashlib.md5(plain_password.encode()).hexdigest().encode()
+    pw_md5 = hashlib.md5(plain_password.encode()).hexdigest().encode()  # noqa: S324

     # 2. Check the cache
     if bcrypt_hash in bcrypt_cache:
@@ -103,7 +103,7 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
 def get_password_hash(password: str) -> str:
     """Generate a password hash, the osu! way"""
     # 1. Plaintext password -> MD5
-    pw_md5 = hashlib.md5(password.encode()).hexdigest().encode()
+    pw_md5 = hashlib.md5(password.encode()).hexdigest().encode()  # noqa: S324
     # 2. MD5 -> bcrypt
     pw_bcrypt = bcrypt.hashpw(pw_md5, bcrypt.gensalt())
     return pw_bcrypt.decode()
@@ -114,7 +114,7 @@ async def authenticate_user_legacy(db: AsyncSession, name: str, password: str) -
     Verify the user's identity, using logic similar to from_login
     """
     # 1. Convert the plaintext password to MD5
-    pw_md5 = hashlib.md5(password.encode()).hexdigest()
+    pw_md5 = hashlib.md5(password.encode()).hexdigest()  # noqa: S324

     # 2. Look up the user by username
     user = None
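Review note on the `# noqa: S324` suppressions: ruff flags `hashlib.md5` as a weak hash, but in this legacy scheme the MD5 digest is only an intermediate step feeding bcrypt, not the stored hash itself. A sketch of the two ways to satisfy the rule; the second form (`usedforsecurity=False`, Python 3.9+) is an alternative, not what these call sites use:

```python
import hashlib

password = "correct horse battery staple"

# As in this commit: keep the call, silence the rule inline.
pw_md5 = hashlib.md5(password.encode()).hexdigest()  # noqa: S324

# Alternative: declare non-security intent explicitly; ruff accepts this too.
pw_md5_alt = hashlib.md5(password.encode(), usedforsecurity=False).hexdigest()

assert pw_md5 == pw_md5_alt
```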
@@ -325,12 +325,7 @@ def _generate_totp_account_label(user: User) -> str:

     Choose the username or the email based on configuration, and add server info to make the label more descriptive
     """
-    if settings.totp_use_username_in_label:
-        # Use the username as the primary identifier
-        primary_identifier = user.username
-    else:
-        # Use the email as the identifier
-        primary_identifier = user.email
+    primary_identifier = user.username if settings.totp_use_username_in_label else user.email

     # If a service name is configured, add it to the label to distinguish it in the authenticator
     if settings.totp_service_name:
@@ -419,9 +419,8 @@ def too_dense(hit_objects: list[HitObject], per_1s: int, per_10s: int) -> bool:
         if len(hit_objects) > i + per_1s:
             if hit_objects[i + per_1s].start_time - hit_objects[i].start_time < 1000:
                 return True
-        elif len(hit_objects) > i + per_10s:
-            if hit_objects[i + per_10s].start_time - hit_objects[i].start_time < 10000:
-                return True
+        elif len(hit_objects) > i + per_10s and hit_objects[i + per_10s].start_time - hit_objects[i].start_time < 10000:
+            return True
     return False


@@ -448,10 +447,7 @@ def slider_is_sus(hit_objects: list[HitObject]) -> bool:


 def is_2b(hit_objects: list[HitObject]) -> bool:
-    for i in range(0, len(hit_objects) - 1):
-        if hit_objects[i] == hit_objects[i + 1].start_time:
-            return True
-    return False
+    return any(hit_objects[i] == hit_objects[i + 1].start_time for i in range(0, len(hit_objects) - 1))


 def is_suspicious_beatmap(content: str) -> bool:
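The `is_2b` rewrite is ruff's SIM110 pattern: a loop that returns `True` on the first hit and `False` at the end becomes `any()` over a generator, which short-circuits identically. A self-contained sketch, simplified to plain timestamps instead of the repo's `HitObject` type:

```python
def is_2b_loop(times: list[int]) -> bool:
    for i in range(0, len(times) - 1):
        if times[i] == times[i + 1]:
            return True
    return False


def is_2b_any(times: list[int]) -> bool:
    return any(times[i] == times[i + 1] for i in range(0, len(times) - 1))


# A doubled timestamp (two objects starting together) is detected by both:
assert is_2b_loop([10, 20, 20, 30]) is is_2b_any([10, 20, 20, 30]) is True
```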
@@ -217,7 +217,7 @@ STORAGE_SETTINGS='{
     # Server settings
     host: Annotated[
         str,
-        Field(default="0.0.0.0", description="Server listen address"),
+        Field(default="0.0.0.0", description="Server listen address"),  # noqa: S104
         "Server settings",
     ]
     port: Annotated[
@@ -609,26 +609,26 @@ STORAGE_SETTINGS='{
     ]

     @field_validator("fetcher_scopes", mode="before")
+    @classmethod
     def validate_fetcher_scopes(cls, v: Any) -> list[str]:
         if isinstance(v, str):
             return v.split(",")
         return v

     @field_validator("storage_settings", mode="after")
+    @classmethod
     def validate_storage_settings(
         cls,
         v: LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings,
         info: ValidationInfo,
     ) -> LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings:
-        if info.data.get("storage_service") == StorageServiceType.CLOUDFLARE_R2:
-            if not isinstance(v, CloudflareR2Settings):
-                raise ValueError("When storage_service is 'r2', storage_settings must be CloudflareR2Settings")
-        elif info.data.get("storage_service") == StorageServiceType.LOCAL:
-            if not isinstance(v, LocalStorageSettings):
-                raise ValueError("When storage_service is 'local', storage_settings must be LocalStorageSettings")
-        elif info.data.get("storage_service") == StorageServiceType.AWS_S3:
-            if not isinstance(v, AWSS3StorageSettings):
-                raise ValueError("When storage_service is 's3', storage_settings must be AWSS3StorageSettings")
+        service = info.data.get("storage_service")
+        if service == StorageServiceType.CLOUDFLARE_R2 and not isinstance(v, CloudflareR2Settings):
+            raise ValueError("When storage_service is 'r2', storage_settings must be CloudflareR2Settings")
+        if service == StorageServiceType.LOCAL and not isinstance(v, LocalStorageSettings):
+            raise ValueError("When storage_service is 'local', storage_settings must be LocalStorageSettings")
+        if service == StorageServiceType.AWS_S3 and not isinstance(v, AWSS3StorageSettings):
+            raise ValueError("When storage_service is 's3', storage_settings must be AWSS3StorageSettings")
         return v


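Adding `@classmethod` beneath `@field_validator` is pydantic v2's documented decorator order; it makes the implicit first argument explicitly the class and keeps linters from treating `cls` as a misnamed `self`. A minimal sketch with a hypothetical model, not the repo's settings class:

```python
from typing import Any

from pydantic import BaseModel, field_validator


class FetcherSettings(BaseModel):
    fetcher_scopes: list[str] = ["public"]

    @field_validator("fetcher_scopes", mode="before")
    @classmethod
    def validate_fetcher_scopes(cls, v: Any) -> Any:
        # Accept a comma-separated string (e.g. from an env var) as well as a list.
        if isinstance(v, str):
            return v.split(",")
        return v


print(FetcherSettings(fetcher_scopes="chat.read,identify").fetcher_scopes)
# ['chat.read', 'identify']
```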
@@ -71,10 +71,10 @@ class Beatmap(BeatmapBase, table=True):
     failtimes: FailTime | None = Relationship(back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"})

     @classmethod
-    async def from_resp_no_save(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
+    async def from_resp_no_save(cls, _session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
         d = resp.model_dump()
         del d["beatmapset"]
-        beatmap = Beatmap.model_validate(
+        beatmap = cls.model_validate(
             {
                 **d,
                 "beatmapset_id": resp.beatmapset_id,
@@ -90,8 +90,7 @@ class Beatmap(BeatmapBase, table=True):
         if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
             session.add(beatmap)
             await session.commit()
-        beatmap = (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one()
-        return beatmap
+        return (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).one()

     @classmethod
     async def from_resp_batch(cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0) -> list["Beatmap"]:
@@ -250,7 +249,7 @@ async def calculate_beatmap_attributes(
     redis: Redis,
     fetcher: "Fetcher",
 ):
-    key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
+    key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.sha256(str(mods_).encode()).hexdigest()}:attributes"
     if await redis.exists(key):
         return BeatmapAttributes.model_validate_json(await redis.get(key))
     resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
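Switching the cache-key digest from MD5 to SHA-256 lengthens the key segment from 32 to 64 hex characters, so previously cached entries simply miss and get recomputed; nothing needs migrating. Shape of the key, with illustrative values:

```python
import hashlib

beatmap_id, ruleset = 1001682, "osu"
mods_ = [{"acronym": "DT"}]

digest = hashlib.sha256(str(mods_).encode()).hexdigest()
key = f"beatmap:{beatmap_id}:{ruleset}:{digest}:attributes"
print(key)  # beatmap:1001682:osu:<64 hex chars>:attributes
```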
@@ -130,7 +130,7 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
     favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset")

     @classmethod
-    async def from_resp_no_save(cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0) -> "Beatmapset":
+    async def from_resp_no_save(cls, resp: "BeatmapsetResp") -> "Beatmapset":
         d = resp.model_dump()
         if resp.nominations:
             d["nominations_required"] = resp.nominations.required
@@ -158,10 +158,15 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
         return beatmapset

     @classmethod
-    async def from_resp(cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0) -> "Beatmapset":
+    async def from_resp(
+        cls,
+        session: AsyncSession,
+        resp: "BeatmapsetResp",
+        from_: int = 0,
+    ) -> "Beatmapset":
         from .beatmap import Beatmap

-        beatmapset = await cls.from_resp_no_save(session, resp, from_=from_)
+        beatmapset = await cls.from_resp_no_save(resp)
         if not (await session.exec(select(exists()).where(Beatmapset.id == resp.id))).first():
             session.add(beatmapset)
             await session.commit()
@@ -105,17 +105,11 @@ class ChatChannelResp(ChatChannelBase):
             )
         ).first()

-        last_msg = await redis.get(f"chat:{channel.channel_id}:last_msg")
-        if last_msg and last_msg.isdigit():
-            last_msg = int(last_msg)
-        else:
-            last_msg = None
+        last_msg_raw = await redis.get(f"chat:{channel.channel_id}:last_msg")
+        last_msg = int(last_msg_raw) if last_msg_raw and last_msg_raw.isdigit() else None

-        last_read_id = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
-        if last_read_id and last_read_id.isdigit():
-            last_read_id = int(last_read_id)
-        else:
-            last_read_id = last_msg
+        last_read_id_raw = await redis.get(f"chat:{channel.channel_id}:last_read:{user.id}")
+        last_read_id = int(last_read_id_raw) if last_read_id_raw and last_read_id_raw.isdigit() else last_msg

         if silence is not None:
             attribute = ChatUserAttributes(
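The `_raw` renames keep each name single-typed: the value Redis returns (a `str` or `None` here, assuming the client is configured with `decode_responses=True`) is never reassigned to an `int`. The guard in isolation; `isdigit()` also rejects junk values, falling through to the default:

```python
def parse_last_msg(raw: str | None) -> int | None:
    return int(raw) if raw and raw.isdigit() else None


assert parse_last_msg("42") == 42
assert parse_last_msg("not-a-number") is None
assert parse_last_msg(None) is None
```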
@@ -520,12 +520,11 @@ async def _score_where(
             wheres.append(col(BestScore.user).has(col(User.country_code) == user.country_code))
         else:
             return None
-    elif type == LeaderboardType.TEAM:
-        if user:
-            team_membership = await user.awaitable_attrs.team_membership
-            if team_membership:
-                team_id = team_membership.team_id
-                wheres.append(col(BestScore.user).has(col(User.team_membership).has(TeamMember.team_id == team_id)))
+    elif type == LeaderboardType.TEAM and user:
+        team_membership = await user.awaitable_attrs.team_membership
+        if team_membership:
+            team_id = team_membership.team_id
+            wheres.append(col(BestScore.user).has(col(User.team_membership).has(TeamMember.team_id == team_id)))
     if mods:
         if user and user.is_supporter:
             wheres.append(
@@ -256,8 +256,6 @@ class UserResp(UserBase):
         session: AsyncSession,
         include: list[str] = [],
         ruleset: GameMode | None = None,
-        *,
-        token_id: int | None = None,
     ) -> "UserResp":
         from app.dependencies.database import get_redis

@@ -310,16 +308,16 @@ class UserResp(UserBase):
             ).all()
         ]

-        if "team" in include:
-            if team_membership := await obj.awaitable_attrs.team_membership:
-                u.team = team_membership.team
+        if "team" in include and (team_membership := await obj.awaitable_attrs.team_membership):
+            u.team = team_membership.team

         if "account_history" in include:
             u.account_history = [UserAccountHistoryResp.from_db(ah) for ah in await obj.awaitable_attrs.account_history]

-        if "daily_challenge_user_stats":
-            if daily_challenge_stats := await obj.awaitable_attrs.daily_challenge_stats:
-                u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(daily_challenge_stats)
+        if "daily_challenge_user_stats" in include and (
+            daily_challenge_stats := await obj.awaitable_attrs.daily_challenge_stats
+        ):
+            u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(daily_challenge_stats)

         if "statistics" in include:
             current_stattistics = None
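The walrus-in-condition rewrite relies on `and` short-circuiting: the awaited attribute is only fetched when the `include` check passes, and the bound name stays in scope for the body. A standalone sketch with a plain dict standing in for the ORM object:

```python
include = ["team"]
obj = {"team_membership": {"team": "Sandbox"}}

if "team" in include and (team_membership := obj.get("team_membership")):
    print(team_membership["team"])  # runs only when both conditions hold
```

It also fixes the old `if "daily_challenge_user_stats":` line, which tested a non-empty string literal (always true) instead of membership in `include`.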
@@ -443,7 +441,7 @@ class MeResp(UserResp):
         from app.dependencies.database import get_redis
         from app.service.verification_service import LoginSessionService

-        u = await super().from_db(obj, session, ALL_INCLUDED, ruleset, token_id=token_id)
+        u = await super().from_db(obj, session, ALL_INCLUDED, ruleset)
         u.session_verified = (
             not await LoginSessionService.check_is_need_verification(session, user_id=obj.id, token_id=token_id)
             if token_id
@@ -7,7 +7,7 @@ from fastapi.exceptions import RequestValidationError
 from pydantic import BaseModel, ValidationError


-def BodyOrForm[T: BaseModel](model: type[T]):
+def BodyOrForm[T: BaseModel](model: type[T]):  # noqa: N802
     async def dependency(
         request: Request,
     ) -> T:
@@ -119,10 +119,7 @@ async def get_client_user(
     if verify_method is None:
         # Pick the verification method automatically (prefer TOTP when available)
         totp_key = await user.awaitable_attrs.totp_key
-        if totp_key is not None and api_version >= SUPPORT_TOTP_VERIFICATION_VER:
-            verify_method = "totp"
-        else:
-            verify_method = "mail"
+        verify_method = "totp" if totp_key is not None and api_version >= SUPPORT_TOTP_VERIFICATION_VER else "mail"

         # Store the chosen verification method in Redis to avoid re-selecting
         if api_version >= 20250913:
app/exceptions/__init__.py (0 changes, new file)
@@ -116,7 +116,7 @@ class BeatmapsetFetcher(BaseFetcher):

         # Serialize to JSON and generate an MD5 hash
         cache_json = json.dumps(cache_data, sort_keys=True, separators=(",", ":"))
-        cache_hash = hashlib.md5(cache_json.encode()).hexdigest()
+        cache_hash = hashlib.md5(cache_json.encode(), usedforsecurity=False).hexdigest()

         logger.opt(colors=True).debug(f"<blue>[CacheKey]</blue> Query: {cache_data}, Hash: {cache_hash}")

@@ -160,10 +160,10 @@ class BeatmapsetFetcher(BaseFetcher):
                 cached_data = json.loads(cached_result)
                 return SearchBeatmapsetsResp.model_validate(cached_data)
             except Exception as e:
-                logger.opt(colors=True).warning(f"Cache data invalid, fetching from API: {e}")
+                logger.warning(f"Cache data invalid, fetching from API: {e}")

         # Cache miss, fetch from the API
-        logger.opt(colors=True).debug("Cache miss, fetching from API")
+        logger.debug("Cache miss, fetching from API")

         params = query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True)

@@ -203,7 +203,7 @@ class BeatmapsetFetcher(BaseFetcher):
             try:
                 await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=1)
             except RateLimitError:
-                logger.opt(colors=True).info("Prefetch skipped due to rate limit")
+                logger.info("Prefetch skipped due to rate limit")

         bg_tasks.add_task(delayed_prefetch)

@@ -227,14 +227,14 @@ class BeatmapsetFetcher(BaseFetcher):
             # Request the next page with the current cursor
             next_query = query.model_copy()

-            logger.opt(colors=True).debug(f"Prefetching page {page + 1}")
+            logger.debug(f"Prefetching page {page + 1}")

             # Generate the cache key for the next page
             next_cache_key = self._generate_cache_key(next_query, cursor)

             # Check whether it is already cached
             if await redis_client.exists(next_cache_key):
-                logger.opt(colors=True).debug(f"Page {page + 1} already cached")
+                logger.debug(f"Page {page + 1} already cached")
                 # Try to get the cursor from cache and keep prefetching
                 cached_data = await redis_client.get(next_cache_key)
                 if cached_data:
@@ -244,7 +244,7 @@ class BeatmapsetFetcher(BaseFetcher):
                         cursor = data["cursor"]
                         continue
                     except Exception:
-                        pass
+                        logger.warning("Failed to parse cached data for cursor")
                 break

             # Add a delay between prefetched pages to avoid bursts of requests
@@ -279,18 +279,18 @@ class BeatmapsetFetcher(BaseFetcher):
                     ex=prefetch_ttl,
                 )

-                logger.opt(colors=True).debug(f"Prefetched page {page + 1} (TTL: {prefetch_ttl}s)")
+                logger.debug(f"Prefetched page {page + 1} (TTL: {prefetch_ttl}s)")

             except RateLimitError:
-                logger.opt(colors=True).info("Prefetch stopped due to rate limit")
+                logger.info("Prefetch stopped due to rate limit")
             except Exception as e:
-                logger.opt(colors=True).warning(f"Prefetch failed: {e}")
+                logger.warning(f"Prefetch failed: {e}")

     async def warmup_homepage_cache(self, redis_client: redis.Redis) -> None:
         """Warm up the homepage cache"""
         homepage_queries = self._get_homepage_queries()

-        logger.opt(colors=True).info(f"Starting homepage cache warmup ({len(homepage_queries)} queries)")
+        logger.info(f"Starting homepage cache warmup ({len(homepage_queries)} queries)")

         for i, (query, cursor) in enumerate(homepage_queries):
             try:
@@ -302,7 +302,7 @@ class BeatmapsetFetcher(BaseFetcher):

                 # Check whether it is already cached
                 if await redis_client.exists(cache_key):
-                    logger.opt(colors=True).debug(f"Query {query.sort} already cached")
+                    logger.debug(f"Query {query.sort} already cached")
                     continue

                 # Request and cache
@@ -325,15 +325,15 @@ class BeatmapsetFetcher(BaseFetcher):
                     ex=cache_ttl,
                 )

-                logger.opt(colors=True).info(f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)")
+                logger.info(f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)")

                 if api_response.get("cursor"):
                     try:
                         await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=2)
                     except RateLimitError:
-                        logger.opt(colors=True).info(f"Warmup prefetch skipped for {query.sort} due to rate limit")
+                        logger.info(f"Warmup prefetch skipped for {query.sort} due to rate limit")

             except RateLimitError:
-                logger.opt(colors=True).warning(f"Warmup skipped for {query.sort} due to rate limit")
+                logger.warning(f"Warmup skipped for {query.sort} due to rate limit")
             except Exception as e:
-                logger.opt(colors=True).error(f"Failed to warmup cache for {query.sort}: {e}")
+                logger.error(f"Failed to warmup cache for {query.sort}: {e}")
app/helpers/__init__.py (0 changes, new file)
@@ -1,19 +1,39 @@
 """
-GeoLite2 Helper Class
+GeoLite2 Helper Class (asynchronous)
 """

 from __future__ import annotations

+import asyncio
+from contextlib import suppress
 import os
 from pathlib import Path
 import shutil
 import tarfile
 import tempfile
 import time
+from typing import Any, Required, TypedDict
+
+from app.log import logger
+
+import aiofiles
 import httpx
 import maxminddb


+class GeoIPLookupResult(TypedDict, total=False):
+    ip: Required[str]
+    country_iso: str
+    country_name: str
+    city_name: str
+    latitude: str
+    longitude: str
+    time_zone: str
+    postal_code: str
+    asn: int | None
+    organization: str
+
+
 BASE_URL = "https://download.maxmind.com/app/geoip_download"
 EDITIONS = {
     "City": "GeoLite2-City",
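`GeoIPLookupResult` combines `total=False` (every key optional) with `typing.Required` (Python 3.11+, or `typing_extensions` on older interpreters) to pin down the one field that must always be present. Sketch of how a type checker reads it:

```python
from typing import Required, TypedDict


class GeoIPLookupResult(TypedDict, total=False):
    ip: Required[str]
    country_iso: str
    asn: int | None


ok: GeoIPLookupResult = {"ip": "8.8.8.8"}  # fine: optional keys omitted
richer: GeoIPLookupResult = {"ip": "1.1.1.1", "country_iso": "AU"}
# {"country_iso": "AU"} alone would be rejected: Required[str] makes "ip" mandatory.
```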
@@ -25,161 +45,184 @@ EDITIONS = {
 class GeoIPHelper:
     def __init__(
         self,
-        dest_dir="./geoip",
-        license_key=None,
-        editions=None,
-        max_age_days=8,
-        timeout=60.0,
+        dest_dir: str | Path = Path("./geoip"),
+        license_key: str | None = None,
+        editions: list[str] | None = None,
+        max_age_days: int = 8,
+        timeout: float = 60.0,
     ):
-        self.dest_dir = dest_dir
+        self.dest_dir = Path(dest_dir).expanduser()
         self.license_key = license_key or os.getenv("MAXMIND_LICENSE_KEY")
-        self.editions = editions or ["City", "ASN"]
+        self.editions = list(editions or ["City", "ASN"])
         self.max_age_days = max_age_days
         self.timeout = timeout
-        self._readers = {}
+        self._readers: dict[str, maxminddb.Reader] = {}
+        self._update_lock = asyncio.Lock()

     @staticmethod
-    def _safe_extract(tar: tarfile.TarFile, path: str):
-        base = Path(path).resolve()
-        for m in tar.getmembers():
-            target = (base / m.name).resolve()
-            if not str(target).startswith(str(base)):
+    def _safe_extract(tar: tarfile.TarFile, path: Path) -> None:
+        base = path.resolve()
+        for member in tar.getmembers():
+            target = (base / member.name).resolve()
+            if not target.is_relative_to(base):  # py312
                 raise RuntimeError("Unsafe path in tar file")
-        tar.extractall(path=path, filter="data")
+        tar.extractall(path=base, filter="data")

-    def _download_and_extract(self, edition_id: str) -> str:
-        """
-        Download and extract the mmdb file into dest_dir, keeping only the .mmdb
-        - Follows 302 redirects
-        - Streams the download into a temporary file
-        - The temporary directory is cleaned up automatically on exit
-        """
+    @staticmethod
+    def _as_mapping(value: Any) -> dict[str, Any]:
+        return value if isinstance(value, dict) else {}
+
+    @staticmethod
+    def _as_str(value: Any, default: str = "") -> str:
+        if isinstance(value, str):
+            return value
+        if value is None:
+            return default
+        return str(value)
+
+    @staticmethod
+    def _as_int(value: Any) -> int | None:
+        return value if isinstance(value, int) else None
+
+    @staticmethod
+    def _extract_tarball(src: Path, dest: Path) -> None:
+        with tarfile.open(src, "r:gz") as tar:
+            GeoIPHelper._safe_extract(tar, dest)
+
+    @staticmethod
+    def _find_mmdb(root: Path) -> Path | None:
+        for candidate in root.rglob("*.mmdb"):
+            return candidate
+        return None
+
+    def _latest_file_sync(self, edition_id: str) -> Path | None:
+        directory = self.dest_dir
+        if not directory.is_dir():
+            return None
+        candidates = list(directory.glob(f"{edition_id}*.mmdb"))
+        if not candidates:
+            return None
+        return max(candidates, key=lambda p: p.stat().st_mtime)
+
+    async def _latest_file(self, edition_id: str) -> Path | None:
+        return await asyncio.to_thread(self._latest_file_sync, edition_id)
+
+    async def _download_and_extract(self, edition_id: str) -> Path:
         if not self.license_key:
             raise ValueError("MaxMind License Key is missing. Please configure it via env MAXMIND_LICENSE_KEY.")

         url = f"{BASE_URL}?edition_id={edition_id}&license_key={self.license_key}&suffix=tar.gz"
+        tmp_dir = Path(await asyncio.to_thread(tempfile.mkdtemp))

-        with httpx.Client(follow_redirects=True, timeout=self.timeout) as client:
-            with client.stream("GET", url) as resp:
+        try:
+            tgz_path = tmp_dir / "db.tgz"
+            async with (
+                httpx.AsyncClient(follow_redirects=True, timeout=self.timeout) as client,
+                client.stream("GET", url) as resp,
+            ):
                 resp.raise_for_status()
-                with tempfile.TemporaryDirectory() as tmpd:
-                    tgz_path = os.path.join(tmpd, "db.tgz")
-                    # Streamed write
-                    with open(tgz_path, "wb") as f:
-                        for chunk in resp.iter_bytes():
-                            if chunk:
-                                f.write(chunk)
+                async with aiofiles.open(tgz_path, "wb") as download_file:
+                    async for chunk in resp.aiter_bytes():
+                        if chunk:
+                            await download_file.write(chunk)

-                    # Extract and move only the .mmdb
-                    with tarfile.open(tgz_path, "r:gz") as tar:
-                        # Safety check, then extract
-                        self._safe_extract(tar, tmpd)
-
-                    # Recursively find the .mmdb
-                    mmdb_path = None
-                    for root, _, files in os.walk(tmpd):
-                        for fn in files:
-                            if fn.endswith(".mmdb"):
-                                mmdb_path = os.path.join(root, fn)
-                                break
-                        if mmdb_path:
-                            break
-
-                    if not mmdb_path:
-                        raise RuntimeError("No .mmdb file found in the archive")
-
-                    os.makedirs(self.dest_dir, exist_ok=True)
-                    dst = os.path.join(self.dest_dir, os.path.basename(mmdb_path))
-                    shutil.move(mmdb_path, dst)
-                    return dst
+            await asyncio.to_thread(self._extract_tarball, tgz_path, tmp_dir)
+            mmdb_path = await asyncio.to_thread(self._find_mmdb, tmp_dir)
+            if mmdb_path is None:
+                raise RuntimeError("No .mmdb file found in the archive")

-    def _latest_file(self, edition_id: str):
-        if not os.path.isdir(self.dest_dir):
-            return None
-        files = [
-            os.path.join(self.dest_dir, f)
-            for f in os.listdir(self.dest_dir)
-            if f.startswith(edition_id) and f.endswith(".mmdb")
-        ]
-        return max(files, key=os.path.getmtime) if files else None
+            await asyncio.to_thread(self.dest_dir.mkdir, parents=True, exist_ok=True)
+            dst = self.dest_dir / mmdb_path.name
+            await asyncio.to_thread(shutil.move, mmdb_path, dst)
+            return dst
+        finally:
+            await asyncio.to_thread(shutil.rmtree, tmp_dir, ignore_errors=True)

-    def update(self, force=False):
-        from app.log import logger
-
-        for ed in self.editions:
-            eid = EDITIONS[ed]
-            path = self._latest_file(eid)
-            need = force or not path
-
-            if path:
-                age_days = (time.time() - os.path.getmtime(path)) / 86400
-                if age_days >= self.max_age_days:
-                    need = True
-                    logger.info(
-                        f"{eid} database is {age_days:.1f} days old "
-                        f"(max: {self.max_age_days}), will download new version"
-                    )
-                else:
-                    logger.info(f"{eid} database is {age_days:.1f} days old, still fresh (max: {self.max_age_days})")
-            else:
-                logger.info(f"{eid} database not found, will download")
-
-            if need:
-                logger.info(f"Downloading {eid} database...")
-                path = self._download_and_extract(eid)
-                logger.info(f"{eid} database downloaded successfully")
-            else:
-                logger.info(f"Using existing {eid} database")
-
-            old = self._readers.get(ed)
-            if old:
-                try:
-                    old.close()
-                except Exception:
-                    pass
-            if path is not None:
-                self._readers[ed] = maxminddb.open_database(path)
+    async def update(self, force: bool = False) -> None:
+        async with self._update_lock:
+            for edition in self.editions:
+                edition_id = EDITIONS[edition]
+                path = await self._latest_file(edition_id)
+                need_download = force or path is None
+
+                if path:
+                    mtime = await asyncio.to_thread(path.stat)
+                    age_days = (time.time() - mtime.st_mtime) / 86400
+                    if age_days >= self.max_age_days:
+                        need_download = True
+                        logger.info(
+                            f"{edition_id} database is {age_days:.1f} days old "
+                            f"(max: {self.max_age_days}), will download new version"
+                        )
+                    else:
+                        logger.info(
+                            f"{edition_id} database is {age_days:.1f} days old, still fresh (max: {self.max_age_days})"
+                        )
+                else:
+                    logger.info(f"{edition_id} database not found, will download")
+
+                if need_download:
+                    logger.info(f"Downloading {edition_id} database...")
+                    path = await self._download_and_extract(edition_id)
+                    logger.info(f"{edition_id} database downloaded successfully")
+                else:
+                    logger.info(f"Using existing {edition_id} database")
+
+                old_reader = self._readers.get(edition)
+                if old_reader:
+                    with suppress(Exception):
+                        old_reader.close()
+                if path is not None:
+                    self._readers[edition] = maxminddb.open_database(str(path))

-    def lookup(self, ip: str):
-        res = {"ip": ip}
-        # City
-        city_r = self._readers.get("City")
-        if city_r:
-            data = city_r.get(ip)
-            if data:
-                country = data.get("country") or {}
-                res["country_iso"] = country.get("iso_code") or ""
-                res["country_name"] = (country.get("names") or {}).get("en", "")
-                city = data.get("city") or {}
-                res["city_name"] = (city.get("names") or {}).get("en", "")
-                loc = data.get("location") or {}
-                res["latitude"] = str(loc.get("latitude") or "")
-                res["longitude"] = str(loc.get("longitude") or "")
-                res["time_zone"] = str(loc.get("time_zone") or "")
-                postal = data.get("postal") or {}
-                if "code" in postal:
-                    res["postal_code"] = postal["code"]
-        # ASN
-        asn_r = self._readers.get("ASN")
-        if asn_r:
-            data = asn_r.get(ip)
-            if data:
-                res["asn"] = data.get("autonomous_system_number")
-                res["organization"] = data.get("autonomous_system_organization")
+    def lookup(self, ip: str) -> GeoIPLookupResult:
+        res: GeoIPLookupResult = {"ip": ip}
+        city_reader = self._readers.get("City")
+        if city_reader:
+            data = city_reader.get(ip)
+            if isinstance(data, dict):
+                country = self._as_mapping(data.get("country"))
+                res["country_iso"] = self._as_str(country.get("iso_code"))
+                country_names = self._as_mapping(country.get("names"))
+                res["country_name"] = self._as_str(country_names.get("en"))
+
+                city = self._as_mapping(data.get("city"))
+                city_names = self._as_mapping(city.get("names"))
+                res["city_name"] = self._as_str(city_names.get("en"))
+
+                location = self._as_mapping(data.get("location"))
+                latitude = location.get("latitude")
+                longitude = location.get("longitude")
+                res["latitude"] = str(latitude) if latitude is not None else ""
+                res["longitude"] = str(longitude) if longitude is not None else ""
+                res["time_zone"] = self._as_str(location.get("time_zone"))
+
+                postal = self._as_mapping(data.get("postal"))
+                postal_code = postal.get("code")
+                if postal_code is not None:
+                    res["postal_code"] = self._as_str(postal_code)
+
+        asn_reader = self._readers.get("ASN")
+        if asn_reader:
+            data = asn_reader.get(ip)
+            if isinstance(data, dict):
+                res["asn"] = self._as_int(data.get("autonomous_system_number"))
+                res["organization"] = self._as_str(data.get("autonomous_system_organization"), default="")
         return res

-    def close(self):
-        for r in self._readers.values():
-            try:
-                r.close()
-            except Exception:
-                pass
+    def close(self) -> None:
+        for reader in self._readers.values():
+            with suppress(Exception):
+                reader.close()
         self._readers = {}


 if __name__ == "__main__":
-    # Example usage
-    geo = GeoIPHelper(dest_dir="./geoip", license_key="")
-    geo.update()
-    print(geo.lookup("8.8.8.8"))
-    geo.close()
+    async def _demo() -> None:
+        geo = GeoIPHelper(dest_dir="./geoip", license_key="")
+        await geo.update()
+        print(geo.lookup("8.8.8.8"))
+        geo.close()
+
+    asyncio.run(_demo())
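The core of the async rewrite above: blocking filesystem and archive work goes through `asyncio.to_thread`, and concurrent `update()` calls are serialized by an `asyncio.Lock`. Reduced to a runnable sketch (names are illustrative, not from the repo):

```python
import asyncio
import time


class Refresher:
    def __init__(self) -> None:
        self._lock = asyncio.Lock()

    def _blocking_refresh(self) -> str:
        time.sleep(0.1)  # stands in for the download/extract work
        return "refreshed"

    async def update(self) -> str:
        async with self._lock:  # at most one refresh in flight
            return await asyncio.to_thread(self._blocking_refresh)


print(asyncio.run(Refresher().update()))
```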
@@ -97,9 +97,7 @@ class InterceptHandler(logging.Handler):
             status_color = "green"
         elif 300 <= status < 400:
             status_color = "yellow"
-        elif 400 <= status < 500:
-            status_color = "red"
-        elif 500 <= status < 600:
+        elif 400 <= status < 500 or 500 <= status < 600:
             status_color = "red"

         return (
@@ -82,7 +82,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
             return await call_next(request)

         # Start the verification flow
-        return await self._initiate_verification(request, session_state)
+        return await self._initiate_verification(session_state)

     def _should_skip_verification(self, request: Request) -> bool:
         """Check whether verification should be skipped"""
@@ -93,10 +93,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
             return True

         # Skip non-API requests
-        if not path.startswith("/api/"):
-            return True
-
-        return False
+        return bool(not path.startswith("/api/"))

     def _requires_verification(self, request: Request, user: User) -> bool:
         """Check whether verification is required"""
@@ -177,7 +174,7 @@ class VerifySessionMiddleware(BaseHTTPMiddleware):
             logger.error(f"Error getting session state: {e}")
             return None

-    async def _initiate_verification(self, request: Request, state: SessionState) -> Response:
+    async def _initiate_verification(self, state: SessionState) -> Response:
         """Start the verification flow"""
         try:
             method = await state.get_method()
@@ -11,7 +11,7 @@ class ExtendedTokenResponse(BaseModel):
     """Extended token response, supporting second-factor verification state"""

     access_token: str | None = None
-    token_type: str = "Bearer"
+    token_type: str = "Bearer"  # noqa: S105
     expires_in: int | None = None
     refresh_token: str | None = None
     scope: str | None = None
@@ -20,14 +20,3 @@ class ExtendedTokenResponse(BaseModel):
     requires_second_factor: bool = False
     verification_message: str | None = None
     user_id: int | None = None  # user ID for second-factor verification
-
-
-class SessionState(BaseModel):
-    """Session state"""
-
-    user_id: int
-    username: str
-    email: str
-    requires_verification: bool
-    session_token: str | None = None
-    verification_sent: bool = False
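`S105` ("possible hardcoded password") triggers on assignments whose target name looks secret-like, including `token_type`. `"Bearer"` is the RFC 6750 scheme name, not a credential, so these are false positives and the inline `noqa` is the usual fix. Sketch of the annotated field:

```python
from pydantic import BaseModel


class TokenResponse(BaseModel):
    access_token: str
    # "Bearer" is the RFC 6750 scheme name, not a credential.
    token_type: str = "Bearer"  # noqa: S105
    expires_in: int
```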
@@ -1,3 +1,4 @@
+# ruff: noqa: ARG002
 from __future__ import annotations

 from abc import abstractmethod
@@ -22,7 +22,7 @@ class TokenRequest(BaseModel):

 class TokenResponse(BaseModel):
     access_token: str
-    token_type: str = "Bearer"
+    token_type: str = "Bearer"  # noqa: S105
     expires_in: int
     refresh_token: str
     scope: str = "*"
@@ -67,7 +67,7 @@ class RegistrationRequestErrors(BaseModel):
 class OAuth2ClientCredentialsBearer(OAuth2):
     def __init__(
         self,
-        tokenUrl: Annotated[
+        tokenUrl: Annotated[  # noqa: N803
             str,
             Doc(
                 """
@@ -75,7 +75,7 @@ class OAuth2ClientCredentialsBearer(OAuth2):
                 """
             ),
         ],
-        refreshUrl: Annotated[
+        refreshUrl: Annotated[  # noqa: N803
             str | None,
             Doc(
                 """
@@ -46,10 +46,10 @@ class PlayerStatsResponse(BaseModel):
 class PlayerEventItem(BaseModel):
     """Player event item"""

-    userId: int
+    userId: int  # noqa: N815
     name: str
-    mapId: int | None = None
-    setId: int | None = None
+    mapId: int | None = None  # noqa: N815
+    setId: int | None = None  # noqa: N815
     artist: str | None = None
     title: str | None = None
     version: str | None = None
@@ -88,7 +88,7 @@ class PlayerInfo(BaseModel):
     custom_badge_icon: str
     custom_badge_color: str
     userpage_content: str
-    recentFailed: int
+    recentFailed: int  # noqa: N815
     social_discord: str | None = None
     social_youtube: str | None = None
     social_twitter: str | None = None
@@ -126,21 +126,22 @@ async def register_user(

     try:
         # Get the client IP and look up its geolocation
-        country_code = "CN"  # default country code
+        country_code = None  # default country code

         try:
             # Look up the IP geolocation
             geo_info = geoip.lookup(client_ip)
-            if geo_info and geo_info.get("country_iso"):
-                country_code = geo_info["country_iso"]
+            if geo_info and (country_code := geo_info.get("country_iso")):
                 logger.info(f"User {user_username} registering from {client_ip}, country: {country_code}")
             else:
                 logger.warning(f"Could not determine country for IP {client_ip}")
         except Exception as e:
             logger.warning(f"GeoIP lookup failed for {client_ip}: {e}")
+        if country_code is None:
+            country_code = "CN"

         # Create the new user
-        # Ensure the AUTO_INCREMENT value starts at 3 (ID=1 is BanchoBot, ID=2 is reserved for ppy)
+        # Ensure the AUTO_INCREMENT value starts at 3 (ID=2 is BanchoBot)
         result = await db.execute(
             text(
                 "SELECT AUTO_INCREMENT FROM information_schema.TABLES "
@@ -157,7 +158,7 @@ async def register_user(
             email=user_email,
             pw_bcrypt=get_password_hash(user_password),
             priv=1,  # regular user privileges
-            country_code=country_code,  # set the country from the IP geolocation
+            country_code=country_code,
             join_date=utcnow(),
             last_visit=utcnow(),
             is_supporter=settings.enable_supporter_for_all_users,
@@ -386,7 +387,7 @@ async def oauth_token(

     return TokenResponse(
         access_token=access_token,
-        token_type="Bearer",
+        token_type="Bearer",  # noqa: S106
         expires_in=settings.access_token_expire_minutes * 60,
         refresh_token=refresh_token_str,
         scope=scope,
@@ -439,7 +440,7 @@ async def oauth_token(
         )
         return TokenResponse(
             access_token=access_token,
-            token_type="Bearer",
+            token_type="Bearer",  # noqa: S106
             expires_in=settings.access_token_expire_minutes * 60,
             refresh_token=new_refresh_token,
             scope=scope,
@@ -509,7 +510,7 @@ async def oauth_token(

     return TokenResponse(
         access_token=access_token,
-        token_type="Bearer",
+        token_type="Bearer",  # noqa: S106
         expires_in=settings.access_token_expire_minutes * 60,
         refresh_token=refresh_token_str,
         scope=" ".join(scopes),
@@ -554,7 +555,7 @@ async def oauth_token(

     return TokenResponse(
         access_token=access_token,
-        token_type="Bearer",
+        token_type="Bearer",  # noqa: S106
         expires_in=settings.access_token_expire_minutes * 60,
         refresh_token=refresh_token_str,
         scope=" ".join(scopes),
@@ -130,7 +130,7 @@ def _coerce_playlist_item(item_data: dict[str, Any], default_order: int, host_us
         "allowed_mods": item_data.get("allowed_mods", []),
         "expired": bool(item_data.get("expired", False)),
         "playlist_order": item_data.get("playlist_order", default_order),
-        "played_at": item_data.get("played_at", None),
+        "played_at": item_data.get("played_at"),
         "freestyle": bool(item_data.get("freestyle", True)),
         "beatmap_checksum": item_data.get("beatmap_checksum", ""),
         "star_rating": item_data.get("star_rating", 0.0),
@@ -157,10 +157,7 @@ async def _help(user: User, args: list[str], _session: AsyncSession, channel: Ch

 @bot.command("roll")
 def _roll(user: User, args: list[str], _session: AsyncSession, channel: ChatChannel) -> str:
-    if len(args) > 0 and args[0].isdigit():
-        r = random.randint(1, int(args[0]))
-    else:
-        r = random.randint(1, 100)
+    r = random.randint(1, int(args[0])) if len(args) > 0 and args[0].isdigit() else random.randint(1, 100)
     return f"{user.username} rolls {r} point(s)"


@@ -179,10 +176,7 @@ async def _stats(user: User, args: list[str], session: AsyncSession, channel: Ch
     if gamemode is None:
         subquery = select(func.max(Score.id)).where(Score.user_id == target_user.id).scalar_subquery()
         last_score = (await session.exec(select(Score).where(Score.id == subquery))).first()
-        if last_score is not None:
-            gamemode = last_score.gamemode
-        else:
-            gamemode = target_user.playmode
+        gamemode = last_score.gamemode if last_score is not None else target_user.playmode

     statistics = (
         await session.exec(
@@ -313,10 +313,7 @@ async def chat_websocket(
     # Prefer the token from query parameters; both token and access_token names are supported
     auth_token = token or access_token
     if not auth_token and authorization:
-        if authorization.startswith("Bearer "):
-            auth_token = authorization[7:]
-        else:
-            auth_token = authorization
+        auth_token = authorization.removeprefix("Bearer ")

     if not auth_token:
         await websocket.close(code=1008, reason="Missing authentication token")
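`str.removeprefix` (Python 3.9+) covers both removed branches in one call: it strips `"Bearer "` when present and returns the string unchanged otherwise, exactly what the old `if`/`else` did, minus the magic `[7:]` offset:

```python
assert "Bearer abc123".removeprefix("Bearer ") == "abc123"
assert "abc123".removeprefix("Bearer ") == "abc123"  # no prefix: unchanged
```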
@@ -10,7 +10,7 @@ from fastapi.responses import RedirectResponse
 redirect_router = APIRouter(include_in_schema=False)


-@redirect_router.get("/users/{path:path}")
+@redirect_router.get("/users/{path:path}")  # noqa: FAST003
 @redirect_router.get("/teams/{team_id}")
 @redirect_router.get("/u/{user_id}")
 @redirect_router.get("/b/{beatmap_id}")
@@ -168,10 +168,7 @@ async def get_beatmaps(
     elif beatmapset_id is not None:
         beatmapset = await Beatmapset.get_or_fetch(session, fetcher, beatmapset_id)
         await beatmapset.awaitable_attrs.beatmaps
-        if len(beatmapset.beatmaps) > limit:
-            beatmaps = beatmapset.beatmaps[:limit]
-        else:
-            beatmaps = beatmapset.beatmaps
+        beatmaps = beatmapset.beatmaps[:limit] if len(beatmapset.beatmaps) > limit else beatmapset.beatmaps
     elif user is not None:
         where = Beatmapset.user_id == user if type == "id" or user.isdigit() else Beatmapset.creator == user
         beatmapsets = (await session.exec(select(Beatmapset).where(where))).all()
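Since list slices clamp to the sequence length, beatmapset.beatmaps[:limit] alone would return the same result without the length check; the conditional simply mirrors the original control flow:

    assert [1, 2, 3][:10] == [1, 2, 3]  # slicing past the end never raises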
@@ -158,7 +158,10 @@ async def get_beatmap_attributes(
     if ruleset is None:
         beatmap_db = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
         ruleset = beatmap_db.mode
-    key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
+    key = (
+        f"beatmap:{beatmap_id}:{ruleset}:"
+        f"{hashlib.md5(str(mods_).encode(), usedforsecurity=False).hexdigest()}:attributes"
+    )
     if await redis.exists(key):
         return BeatmapAttributes.model_validate_json(await redis.get(key))  # pyright: ignore[reportArgumentType]
     try:
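hashlib's usedforsecurity=False flag (Python 3.9+) marks the MD5 call as non-cryptographic, which keeps flake8-bandit's insecure-hash rule (S324, selected here via "S") quiet for cache-key digests. A minimal sketch of the pattern, with hypothetical values:

    import hashlib

    mods = ["DT", "HD"]  # hypothetical mod list
    digest = hashlib.md5(str(mods).encode(), usedforsecurity=False).hexdigest()
    key = f"beatmap:53:osu:{digest}:attributes"  # hypothetical beatmap id and ruleset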
@@ -46,7 +46,6 @@ async def _save_to_db(sets: SearchBeatmapsetsResp):
     response_model=SearchBeatmapsetsResp,
 )
 async def search_beatmapset(
-    db: Database,
     query: Annotated[SearchQueryModel, Query(...)],
     request: Request,
     background_tasks: BackgroundTasks,
@@ -104,7 +103,7 @@ async def search_beatmapset(
     if cached_result:
         sets = SearchBeatmapsetsResp(**cached_result)
         # Handle asset proxying
-        processed_sets = await process_response_assets(sets, request)
+        processed_sets = await process_response_assets(sets)
         return processed_sets
 
     try:
@@ -115,7 +114,7 @@ async def search_beatmapset(
         await cache_service.cache_search_result(query_hash, cursor_hash, sets.model_dump())
 
         # Handle asset proxying
-        processed_sets = await process_response_assets(sets, request)
+        processed_sets = await process_response_assets(sets)
         return processed_sets
     except HTTPError as e:
         raise HTTPException(status_code=500, detail=str(e)) from e
@@ -140,7 +139,7 @@ async def lookup_beatmapset(
     cached_resp = await cache_service.get_beatmap_lookup_from_cache(beatmap_id)
     if cached_resp:
         # Handle asset proxying
-        processed_resp = await process_response_assets(cached_resp, request)
+        processed_resp = await process_response_assets(cached_resp)
         return processed_resp
 
     try:
@@ -151,7 +150,7 @@ async def lookup_beatmapset(
         await cache_service.cache_beatmap_lookup(beatmap_id, resp)
 
         # Handle asset proxying
-        processed_resp = await process_response_assets(resp, request)
+        processed_resp = await process_response_assets(resp)
         return processed_resp
     except HTTPError as exc:
         raise HTTPException(status_code=404, detail="Beatmap not found") from exc
@@ -176,7 +175,7 @@ async def get_beatmapset(
     cached_resp = await cache_service.get_beatmapset_from_cache(beatmapset_id)
     if cached_resp:
         # Handle asset proxying
-        processed_resp = await process_response_assets(cached_resp, request)
+        processed_resp = await process_response_assets(cached_resp)
         return processed_resp
 
     try:
@@ -187,7 +186,7 @@ async def get_beatmapset(
         await cache_service.cache_beatmapset(resp)
 
         # Handle asset proxying
-        processed_resp = await process_response_assets(resp, request)
+        processed_resp = await process_response_assets(resp)
         return processed_resp
     except HTTPError as exc:
         raise HTTPException(status_code=404, detail="Beatmapset not found") from exc
@@ -166,7 +166,6 @@ async def get_room(
     db: Database,
     room_id: Annotated[int, Path(..., description="房间 ID")],
     current_user: Annotated[User, Security(get_current_user, scopes=["public"])],
-    redis: Redis,
     category: Annotated[
         str,
         Query(
@@ -847,10 +847,7 @@ async def reorder_score_pin(
         detail = "After score not found" if after_score_id else "Before score not found"
         raise HTTPException(status_code=404, detail=detail)
 
-    if after_score_id:
-        target_order = reference_score.pinned_order + 1
-    else:
-        target_order = reference_score.pinned_order
+    target_order = reference_score.pinned_order + 1 if after_score_id else reference_score.pinned_order
 
     current_order = score_record.pinned_order
 
@@ -40,7 +40,7 @@ class SessionReissueResponse(BaseModel):
     message: str
 
 
-class VerifyFailed(Exception):
+class VerifyFailedError(Exception):
     def __init__(self, message: str, reason: str | None = None, should_reissue: bool = False):
         super().__init__(message)
         self.reason = reason
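The rename follows pep8-naming's N818 (error-suffix-on-exception-name), one of the newly selected "N" rules. If the old name had to remain importable elsewhere, a hypothetical compatibility alias (not part of this commit) would suffice:

    class VerifyFailedError(Exception):
        ...

    VerifyFailed = VerifyFailedError  # hypothetical alias for backwards compatibility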
@@ -93,10 +93,7 @@ async def verify_session(
         # Intelligently choose the verification method (based on the osu-web implementation)
         # Force email verification when the API version is old or the user has not set up TOTP
         # print(api_version, totp_key)
-        if api_version < 20240101 or totp_key is None:
-            verify_method = "mail"
-        else:
-            verify_method = "totp"
+        verify_method = "mail" if api_version < 20240101 or totp_key is None else "totp"
         await LoginSessionService.set_login_method(user_id, token_id, verify_method, redis)
         login_method = verify_method
 
@@ -109,7 +106,7 @@ async def verify_session(
                 db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent
             )
             verify_method = "mail"
-            raise VerifyFailed("用户TOTP已被删除,已切换到邮件验证")
+            raise VerifyFailedError("用户TOTP已被删除,已切换到邮件验证")
         # If email verification is not enabled, treat the authentication as passed
         # Normally this code path is never reached
 
@@ -120,16 +117,16 @@ async def verify_session(
             else:
                 # Record the detailed verification failure reason (based on osu-web's error handling)
                 if len(verification_key) != 6:
-                    raise VerifyFailed("TOTP验证码长度错误,应为6位数字", reason="incorrect_length")
+                    raise VerifyFailedError("TOTP验证码长度错误,应为6位数字", reason="incorrect_length")
                 elif not verification_key.isdigit():
-                    raise VerifyFailed("TOTP验证码格式错误,应为纯数字", reason="incorrect_format")
+                    raise VerifyFailedError("TOTP验证码格式错误,应为纯数字", reason="incorrect_format")
                 else:
                     # Possibly a wrong key or a replay attack
-                    raise VerifyFailed("TOTP 验证失败,请检查验证码是否正确且未过期", reason="incorrect_key")
+                    raise VerifyFailedError("TOTP 验证失败,请检查验证码是否正确且未过期", reason="incorrect_key")
         else:
             success, message = await EmailVerificationService.verify_email_code(db, redis, user_id, verification_key)
             if not success:
-                raise VerifyFailed(f"邮件验证失败: {message}")
+                raise VerifyFailedError(f"邮件验证失败: {message}")
 
         await LoginLogService.record_login(
             db=db,
@@ -144,7 +141,7 @@ async def verify_session(
         await db.commit()
         return Response(status_code=status.HTTP_204_NO_CONTENT)
 
-    except VerifyFailed as e:
+    except VerifyFailedError as e:
         await LoginLogService.record_failed_login(
             db=db,
             request=request,
@@ -171,7 +168,9 @@ async def verify_session(
                 )
                 error_response["reissued"] = True
             except Exception:
-                pass  # Ignore errors from failing to resend the email
+                log("Verification").exception(
+                    f"Failed to resend verification email to user {current_user.id} (token: {token_id})"
+                )
 
         return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=error_response)
 
@@ -44,9 +44,7 @@ async def check_user_can_vote(user: User, beatmap_id: int, session: AsyncSession
             .where(col(Score.beatmap).has(col(Beatmap.mode) == Score.gamemode))
         )
     ).first()
-    if user_beatmap_score is None:
-        return False
-    return True
+    return user_beatmap_score is not None
 
 
 @router.put(
@@ -75,10 +73,9 @@ async def vote_beatmap_tags(
                 .where(BeatmapTagVote.user_id == current_user.id)
             )
         ).first()
-        if previous_votes is None:
-            if check_user_can_vote(current_user, beatmap_id, session):
-                new_vote = BeatmapTagVote(tag_id=tag_id, beatmap_id=beatmap_id, user_id=current_user.id)
-                session.add(new_vote)
+        if previous_votes is None and check_user_can_vote(current_user, beatmap_id, session):
+            new_vote = BeatmapTagVote(tag_id=tag_id, beatmap_id=beatmap_id, user_id=current_user.id)
+            session.add(new_vote)
         await session.commit()
     except ValueError:
         raise HTTPException(400, "Tag is not found")
@@ -91,7 +91,7 @@ async def get_users(
 
         # Handle asset proxying
         response = BatchUserResponse(users=cached_users)
-        processed_response = await process_response_assets(response, request)
+        processed_response = await process_response_assets(response)
         return processed_response
     else:
         searched_users = (await session.exec(select(User).limit(50))).all()
@@ -109,7 +109,7 @@ async def get_users(
 
     # Handle asset proxying
     response = BatchUserResponse(users=users)
-    processed_response = await process_response_assets(response, request)
+    processed_response = await process_response_assets(response)
     return processed_response
 
 
@@ -240,7 +240,7 @@ async def get_user_info(
     cached_user = await cache_service.get_user_from_cache(user_id_int)
     if cached_user:
         # Handle asset proxying
-        processed_user = await process_response_assets(cached_user, request)
+        processed_user = await process_response_assets(cached_user)
         return processed_user
 
     searched_user = (
@@ -263,7 +263,7 @@ async def get_user_info(
     background_task.add_task(cache_service.cache_user, user_resp)
 
     # Handle asset proxying
-    processed_user = await process_response_assets(user_resp, request)
+    processed_user = await process_response_assets(user_resp)
     return processed_user
 
 
@@ -381,7 +381,7 @@ async def get_user_scores(
         user_id, type, include_fails, mode, limit, offset, is_legacy_api
     )
     if cached_scores is not None:
-        processed_scores = await process_response_assets(cached_scores, request)
+        processed_scores = await process_response_assets(cached_scores)
         return processed_scores
 
     db_user = await session.get(User, user_id)
@@ -438,5 +438,5 @@ async def get_user_scores(
     )
 
     # Handle asset proxying
-    processed_scores = await process_response_assets(score_responses, request)
+    processed_scores = await process_response_assets(score_responses)
     return processed_scores
@@ -12,7 +12,7 @@ from app.service.asset_proxy_service import get_asset_proxy_service
 from fastapi import Request
 
 
-async def process_response_assets(data: Any, request: Request) -> Any:
+async def process_response_assets(data: Any) -> Any:
     """
     Process asset URLs in the response data according to the configuration
 
@@ -72,7 +72,7 @@ def asset_proxy_response(func):
 
         # Process the response if there is a request object and asset proxying is enabled
         if request and settings.enable_asset_proxy and should_process_asset_proxy(request.url.path):
-            result = await process_response_assets(result, request)
+            result = await process_response_assets(result)
 
         return result
 
@@ -113,6 +113,7 @@ class BeatmapCacheService:
                 if size:
                     total_size += size
             except Exception:
+                logger.debug(f"Failed to get size for key {key}")
                 continue
 
         return {
@@ -36,11 +36,8 @@ def safe_json_dumps(data) -> str:
 
 def generate_hash(data) -> str:
     """Generate the MD5 hash of the given data"""
-    if isinstance(data, str):
-        content = data
-    else:
-        content = safe_json_dumps(data)
-    return hashlib.md5(content.encode()).hexdigest()
+    content = data if isinstance(data, str) else safe_json_dumps(data)
+    return hashlib.md5(content.encode(), usedforsecurity=False).hexdigest()
 
 
 class BeatmapsetCacheService:
@@ -110,9 +110,7 @@ class ProcessingBeatmapset:
         changed_beatmaps = []
         for bm in self.beatmapset.beatmaps:
             saved = next((s for s in self.record.beatmaps if s["beatmap_id"] == bm.id), None)
-            if not saved:
-                changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_ADDED))
-            elif saved["is_deleted"]:
+            if not saved or saved["is_deleted"]:
                 changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_ADDED))
             elif saved["md5"] != bm.checksum:
                 changed_beatmaps.append(ChangedBeatmap(bm.id, BeatmapChangeType.MAP_UPDATED))
@@ -285,7 +283,7 @@ class BeatmapsetUpdateService:
     async def _process_changed_beatmapset(self, beatmapset: BeatmapsetResp):
         async with with_db() as session:
             db_beatmapset = await session.get(Beatmapset, beatmapset.id)
-            new_beatmapset = await Beatmapset.from_resp_no_save(session, beatmapset)
+            new_beatmapset = await Beatmapset.from_resp_no_save(beatmapset)
             if db_beatmapset:
                 await session.merge(new_beatmapset)
                 await session.commit()
@@ -356,5 +354,7 @@ def init_beatmapset_update_service(fetcher: "Fetcher") -> BeatmapsetUpdateServic
 
 
 def get_beatmapset_update_service() -> BeatmapsetUpdateService:
+    if service is None:
+        raise ValueError("BeatmapsetUpdateService is not initialized")
     assert service is not None, "BeatmapsetUpdateService is not initialized"
     return service
@@ -128,7 +128,11 @@ class LoginLogService:
             login_success=False,
             login_method=login_method,
             user_agent=user_agent,
-            notes=f"Failed login attempt: {attempted_username}" if attempted_username else "Failed login attempt",
+            notes=(
+                f"Failed login attempt on user {attempted_username}: {notes}"
+                if attempted_username
+                else "Failed login attempt"
+            ),
         )
 
 
@@ -120,7 +120,7 @@ class PasswordResetService:
                 await redis.delete(reset_code_key)
                 await redis.delete(rate_limit_key)
             except Exception:
-                pass
+                logger.warning("Failed to clean up Redis data after error")
             logger.exception("Redis operation failed")
             return False, "服务暂时不可用,请稍后重试"
 
@@ -593,10 +593,7 @@ class RankingCacheService:
     async def invalidate_country_cache(self, ruleset: GameMode | None = None) -> None:
        """Invalidate the country ranking cache"""
        try:
-            if ruleset:
-                pattern = f"country_ranking:{ruleset}:*"
-            else:
-                pattern = "country_ranking:*"
+            pattern = f"country_ranking:{ruleset}:*" if ruleset else "country_ranking:*"
 
             keys = await self.redis.keys(pattern)
             if keys:
@@ -608,10 +605,7 @@ class RankingCacheService:
     async def invalidate_team_cache(self, ruleset: GameMode | None = None) -> None:
        """Invalidate the team ranking cache"""
        try:
-            if ruleset:
-                pattern = f"team_ranking:{ruleset}:*"
-            else:
-                pattern = "team_ranking:*"
+            pattern = f"team_ranking:{ruleset}:*" if ruleset else "team_ranking:*"
 
             keys = await self.redis.keys(pattern)
             if keys:
@@ -637,6 +631,7 @@ class RankingCacheService:
                 if size:
                     total_size += size
             except Exception:
+                logger.warning(f"Failed to get memory usage for key {key}")
                 continue
 
         return {
app/service/subscribers/__init__.py (new empty file)
@@ -35,19 +35,19 @@ class ChatSubscriber(RedisSubscriber):
         self.add_handler(ON_NOTIFICATION, self.on_notification)
         self.start()
 
-    async def on_join_room(self, c: str, s: str):
+    async def on_join_room(self, c: str, s: str):  # noqa: ARG002
         channel_id, user_id = s.split(":")
         if self.chat_server is None:
             return
         await self.chat_server.join_room_channel(int(channel_id), int(user_id))
 
-    async def on_leave_room(self, c: str, s: str):
+    async def on_leave_room(self, c: str, s: str):  # noqa: ARG002
         channel_id, user_id = s.split(":")
         if self.chat_server is None:
             return
         await self.chat_server.leave_room_channel(int(channel_id), int(user_id))
 
-    async def on_notification(self, c: str, s: str):
+    async def on_notification(self, c: str, s: str):  # noqa: ARG002
         try:
             detail = TypeAdapter(NotificationDetails).validate_json(s)
         except ValueError:
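ARG002 (flake8-unused-arguments) flags unused method arguments such as the channel name c in these fixed-signature callbacks, which the commit silences with noqa. Ruff also skips arguments matching its dummy-variable pattern (an underscore prefix by default), so a hypothetical alternative without the comment would be:

    async def on_join_room(self, _c: str, s: str):  # underscore prefix marks the argument as intentionally unused
        ...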
@@ -357,6 +357,7 @@ class UserCacheService:
                 if size:
                     total_size += size
             except Exception:
+                logger.warning(f"Failed to get memory usage for key {key}")
                 continue
 
         return {
@@ -288,10 +288,6 @@ This email was sent automatically, please do not reply.
         redis: Redis,
         user_id: int,
         code: str,
-        ip_address: str | None = None,
-        user_agent: str | None = None,
-        client_id: int | None = None,
-        country_code: str | None = None,
     ) -> tuple[bool, str]:
         """Verify the email verification code"""
         try:
@@ -41,7 +41,7 @@ async def warmup_cache() -> None:
         logger.info("Beatmap cache warmup completed successfully")
 
     except Exception as e:
-        logger.error("Beatmap cache warmup failed: %s", e)
+        logger.error(f"Beatmap cache warmup failed: {e}")
 
 
 async def refresh_ranking_cache() -> None:
@@ -59,7 +59,7 @@ async def refresh_ranking_cache() -> None:
         logger.info("Ranking cache refresh completed successfully")
 
     except Exception as e:
-        logger.error("Ranking cache refresh failed: %s", e)
+        logger.error(f"Ranking cache refresh failed: {e}")
 
 
 async def schedule_user_cache_preload_task() -> None:
@@ -93,14 +93,14 @@ async def schedule_user_cache_preload_task() -> None:
             if active_user_ids:
                 user_ids = [row[0] for row in active_user_ids]
                 await cache_service.preload_user_cache(session, user_ids)
-                logger.info("Preloaded cache for %s active users", len(user_ids))
+                logger.info(f"Preloaded cache for {len(user_ids)} active users")
             else:
                 logger.info("No active users found for cache preload")
 
         logger.info("User cache preload task completed successfully")
 
     except Exception as e:
-        logger.error("User cache preload task failed: %s", e)
+        logger.error(f"User cache preload task failed: {e}")
 
 
 async def schedule_user_cache_warmup_task() -> None:
@@ -131,18 +131,18 @@ async def schedule_user_cache_warmup_task() -> None:
                 if top_users:
                     user_ids = list(top_users)
                     await cache_service.preload_user_cache(session, user_ids)
-                    logger.info("Warmed cache for top 100 users in %s", mode)
+                    logger.info(f"Warmed cache for top 100 users in {mode}")
 
                 await asyncio.sleep(1)
 
             except Exception as e:
-                logger.error("Failed to warm cache for %s: %s", mode, e)
+                logger.error(f"Failed to warm cache for {mode}: {e}")
                 continue
 
         logger.info("User cache warmup task completed successfully")
 
     except Exception as e:
-        logger.error("User cache warmup task failed: %s", e)
+        logger.error(f"User cache warmup task failed: {e}")
 
 
 async def schedule_user_cache_cleanup_task() -> None:
@@ -155,11 +155,11 @@ async def schedule_user_cache_cleanup_task() -> None:
         cache_service = get_user_cache_service(redis)
         stats = await cache_service.get_cache_stats()
 
-        logger.info("User cache stats: %s", stats)
+        logger.info(f"User cache stats: {stats}")
         logger.info("User cache cleanup task completed successfully")
 
     except Exception as e:
-        logger.error("User cache cleanup task failed: %s", e)
+        logger.error(f"User cache cleanup task failed: {e}")
 
 
 async def warmup_user_cache() -> None:
@@ -167,7 +167,7 @@ async def warmup_user_cache() -> None:
     try:
         await schedule_user_cache_warmup_task()
     except Exception as e:
-        logger.error("User cache warmup failed: %s", e)
+        logger.error(f"User cache warmup failed: {e}")
 
 
 async def preload_user_cache() -> None:
@@ -175,7 +175,7 @@ async def preload_user_cache() -> None:
     try:
         await schedule_user_cache_preload_task()
     except Exception as e:
-        logger.error("User cache preload failed: %s", e)
+        logger.error(f"User cache preload failed: {e}")
 
 
 async def cleanup_user_cache() -> None:
@@ -183,7 +183,7 @@ async def cleanup_user_cache() -> None:
     try:
         await schedule_user_cache_cleanup_task()
     except Exception as e:
-        logger.error("User cache cleanup failed: %s", e)
+        logger.error(f"User cache cleanup failed: {e}")
 
 
 def register_cache_jobs() -> None:
@@ -5,8 +5,6 @@ Periodically update the MaxMind GeoIP database
 
 from __future__ import annotations
 
-import asyncio
-
 from app.config import settings
 from app.dependencies.geoip import get_geoip_helper
 from app.dependencies.scheduler import get_scheduler
@@ -28,14 +26,10 @@ async def update_geoip_database():
     try:
         logger.info("Starting scheduled GeoIP database update...")
         geoip = get_geoip_helper()
-        # Run the synchronous update method in a background thread
-        loop = asyncio.get_event_loop()
-        await loop.run_in_executor(None, lambda: geoip.update(force=False))
+        await geoip.update(force=False)
 
         logger.info("Scheduled GeoIP database update completed successfully")
-    except Exception as e:
-        logger.error(f"Scheduled GeoIP database update failed: {e}")
+    except Exception as exc:
+        logger.error(f"Scheduled GeoIP database update failed: {exc}")
 
 
 async def init_geoip():
@@ -45,13 +39,8 @@ async def init_geoip():
     try:
         geoip = get_geoip_helper()
         logger.info("Initializing GeoIP database...")
-        # Run the synchronous update method in a background thread
-        # force=False means only download if files don't exist or are expired
-        loop = asyncio.get_event_loop()
-        await loop.run_in_executor(None, lambda: geoip.update(force=False))
+        await geoip.update(force=False)
 
         logger.info("GeoIP database initialization completed")
-    except Exception as e:
-        logger.error(f"GeoIP database initialization failed: {e}")
+    except Exception as exc:
+        logger.error(f"GeoIP database initialization failed: {exc}")
         # Do not raise an exception to avoid blocking application startup
@@ -16,7 +16,7 @@ async def create_rx_statistics():
     async with with_db() as session:
         users = (await session.exec(select(User.id))).all()
         total_users = len(users)
-        logger.info("Ensuring RX/AP statistics exist for %s users", total_users)
+        logger.info(f"Ensuring RX/AP statistics exist for {total_users} users")
         rx_created = 0
         ap_created = 0
         for i in users:
@@ -57,7 +57,5 @@ async def create_rx_statistics():
         await session.commit()
         if rx_created or ap_created:
             logger.success(
-                "Created %s RX statistics rows and %s AP statistics rows during backfill",
-                rx_created,
-                ap_created,
+                f"Created {rx_created} RX statistics rows and {ap_created} AP statistics rows during backfill"
             )
@@ -258,10 +258,7 @@ class BackgroundTasks:
         self.tasks = set(tasks) if tasks else set()
 
     def add_task(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None:
-        if is_async_callable(func):
-            coro = func(*args, **kwargs)
-        else:
-            coro = run_in_threadpool(func, *args, **kwargs)
+        coro = func(*args, **kwargs) if is_async_callable(func) else run_in_threadpool(func, *args, **kwargs)
         task = asyncio.create_task(coro)
         self.tasks.add(task)
         task.add_done_callback(self.tasks.discard)
main.py (13 changed lines)
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from contextlib import asynccontextmanager
+import json
 from pathlib import Path
 
 from app.config import settings
@@ -50,7 +51,7 @@ import sentry_sdk
 
 
 @asynccontextmanager
-async def lifespan(app: FastAPI):
+async def lifespan(app: FastAPI):  # noqa: ARG001
     # on startup
     init_mods()
     init_ranked_mods()
@@ -223,26 +224,26 @@ async def health_check():
 
 
 @app.exception_handler(RequestValidationError)
-async def validation_exception_handler(request: Request, exc: RequestValidationError):
+async def validation_exception_handler(request: Request, exc: RequestValidationError):  # noqa: ARG001
     return JSONResponse(
         status_code=422,
         content={
-            "error": exc.errors(),
+            "error": json.dumps(exc.errors()),
         },
     )
 
 
 @app.exception_handler(HTTPException)
-async def http_exception_handler(requst: Request, exc: HTTPException):
+async def http_exception_handler(request: Request, exc: HTTPException):  # noqa: ARG001
     return JSONResponse(status_code=exc.status_code, content={"error": exc.detail})
 
 
-if settings.secret_key == "your_jwt_secret_here":
+if settings.secret_key == "your_jwt_secret_here":  # noqa: S105
     system_logger("Security").opt(colors=True).warning(
         "<y>jwt_secret_key</y> is unset. Your server is unsafe. "
         "Use this command to generate: <blue>openssl rand -hex 32</blue>."
     )
-if settings.osu_web_client_secret == "your_osu_web_client_secret_here":
+if settings.osu_web_client_secret == "your_osu_web_client_secret_here":  # noqa: S105
     system_logger("Security").opt(colors=True).warning(
         "<y>osu_web_client_secret</y> is unset. Your server is unsafe. "
         "Use this command to generate: <blue>openssl rand -hex 40</blue>."
@@ -1,3 +1,4 @@
+# ruff: noqa
 """add_password_reset_table
 
 Revision ID: d103d442dc24
@@ -55,12 +55,20 @@ select = [
     "ASYNC", # flake8-async
     "C4", # flake8-comprehensions
     "T10", # flake8-debugger
-    # "T20", # flake8-print
     "PYI", # flake8-pyi
     "PT", # flake8-pytest-style
     "Q", # flake8-quotes
     "TID", # flake8-tidy-imports
     "RUF", # Ruff-specific rules
+    "FAST", # FastAPI
+    "YTT", # flake8-2020
+    "S", # flake8-bandit
+    "INP", # flake8-no-pep420
+    "SIM", # flake8-simplify
+    "ARG", # flake8-unused-arguments
+    "PTH", # flake8-use-pathlib
+    "N", # pep8-naming
+    "FURB" # refurb
 ]
 ignore = [
     "E402", # module-import-not-at-top-of-file
@@ -68,10 +76,17 @@ ignore = [
     "RUF001", # ambiguous-unicode-character-string
     "RUF002", # ambiguous-unicode-character-docstring
     "RUF003", # ambiguous-unicode-character-comment
+    "S101", # assert
+    "S311", # suspicious-non-cryptographic-random-usage
 ]
 
 [tool.ruff.lint.extend-per-file-ignores]
 "app/database/**/*.py" = ["I002"]
+"tools/*.py" = ["PTH", "INP001"]
+"migrations/**/*.py" = ["INP001"]
+".github/**/*.py" = ["INP001"]
+"app/achievements/*.py" = ["INP001", "ARG"]
+"app/router/**/*.py" = ["ARG001"]
 
 [tool.ruff.lint.isort]
 force-sort-within-sections = true
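With the expanded selection in place, the tree can be checked (and auto-fixed where ruff offers a fix) from the repository root; the per-file ignores above apply by path:

    ruff check .        # lint using [tool.ruff.lint] from pyproject.toml
    ruff check . --fix  # additionally apply available automatic fixes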
@@ -163,13 +163,19 @@ async def main():
 
     # Show specific changes
     changes = []
-    if "scorerank" in original_payload and "scorerank" in fixed_payload:
-        if original_payload["scorerank"] != fixed_payload["scorerank"]:
-            changes.append(f"scorerank: {original_payload['scorerank']} → {fixed_payload['scorerank']}")
+    if (
+        "scorerank" in original_payload
+        and "scorerank" in fixed_payload
+        and original_payload["scorerank"] != fixed_payload["scorerank"]
+    ):
+        changes.append(f"scorerank: {original_payload['scorerank']} → {fixed_payload['scorerank']}")
 
-    if "mode" in original_payload and "mode" in fixed_payload:
-        if original_payload["mode"] != fixed_payload["mode"]:
-            changes.append(f"mode: {original_payload['mode']} → {fixed_payload['mode']}")
+    if (
+        "mode" in original_payload
+        and "mode" in fixed_payload
+        and original_payload["mode"] != fixed_payload["mode"]
+    ):
+        changes.append(f"mode: {original_payload['mode']} → {fixed_payload['mode']}")
 
     if changes:
         print(f"  Changes: {', '.join(changes)}")