refactor(project): make pyright & ruff happy
This commit is contained in:
@@ -65,9 +65,7 @@ async def to_the_core(
|
||||
# using either of the mods specified: DT, NC
|
||||
if not score.passed:
|
||||
return False
|
||||
if (
|
||||
"Nightcore" not in beatmap.beatmapset.title
|
||||
) and "Nightcore" not in beatmap.beatmapset.artist:
|
||||
if ("Nightcore" not in beatmap.beatmapset.title) and "Nightcore" not in beatmap.beatmapset.artist:
|
||||
return False
|
||||
mods_ = mod_to_save(score.mods)
|
||||
if "DT" not in mods_ or "NC" not in mods_:
|
||||
@@ -118,9 +116,7 @@ async def reckless_adandon(
|
||||
fetcher = await get_fetcher()
|
||||
redis = get_redis()
|
||||
mods_ = score.mods.copy()
|
||||
attribute = await calculate_beatmap_attributes(
|
||||
beatmap.id, score.gamemode, mods_, redis, fetcher
|
||||
)
|
||||
attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
|
||||
if attribute.star_rating < 3:
|
||||
return False
|
||||
return True
|
||||
@@ -186,9 +182,7 @@ async def slow_and_steady(
|
||||
fetcher = await get_fetcher()
|
||||
redis = get_redis()
|
||||
mods_ = score.mods.copy()
|
||||
attribute = await calculate_beatmap_attributes(
|
||||
beatmap.id, score.gamemode, mods_, redis, fetcher
|
||||
)
|
||||
attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
|
||||
return attribute.star_rating >= 3
|
||||
|
||||
|
||||
@@ -218,9 +212,7 @@ async def sognare(
|
||||
mods_ = mod_to_save(score.mods)
|
||||
if "HT" not in mods_:
|
||||
return False
|
||||
return (
|
||||
beatmap.beatmapset.artist == "LeaF" and beatmap.beatmapset.title == "Evanescent"
|
||||
)
|
||||
return beatmap.beatmapset.artist == "LeaF" and beatmap.beatmapset.title == "Evanescent"
|
||||
|
||||
|
||||
async def realtor_extraordinaire(
|
||||
@@ -234,10 +226,7 @@ async def realtor_extraordinaire(
|
||||
mods_ = mod_to_save(score.mods)
|
||||
if not ("DT" in mods_ or "NC" in mods_) or "HR" not in mods_:
|
||||
return False
|
||||
return (
|
||||
beatmap.beatmapset.artist == "cYsmix"
|
||||
and beatmap.beatmapset.title == "House With Legs"
|
||||
)
|
||||
return beatmap.beatmapset.artist == "cYsmix" and beatmap.beatmapset.title == "House With Legs"
|
||||
|
||||
|
||||
async def impeccable(
|
||||
@@ -255,9 +244,7 @@ async def impeccable(
|
||||
fetcher = await get_fetcher()
|
||||
redis = get_redis()
|
||||
mods_ = score.mods.copy()
|
||||
attribute = await calculate_beatmap_attributes(
|
||||
beatmap.id, score.gamemode, mods_, redis, fetcher
|
||||
)
|
||||
attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
|
||||
return attribute.star_rating >= 4
|
||||
|
||||
|
||||
@@ -274,18 +261,14 @@ async def aeon(
|
||||
mods_ = mod_to_save(score.mods)
|
||||
if "FL" not in mods_ or "HD" not in mods_ or "HT" not in mods_:
|
||||
return False
|
||||
if not beatmap.beatmapset.ranked_date or beatmap.beatmapset.ranked_date > datetime(
|
||||
2012, 1, 1
|
||||
):
|
||||
if not beatmap.beatmapset.ranked_date or beatmap.beatmapset.ranked_date > datetime(2012, 1, 1):
|
||||
return False
|
||||
if beatmap.total_length < 180:
|
||||
return False
|
||||
fetcher = await get_fetcher()
|
||||
redis = get_redis()
|
||||
mods_ = score.mods.copy()
|
||||
attribute = await calculate_beatmap_attributes(
|
||||
beatmap.id, score.gamemode, mods_, redis, fetcher
|
||||
)
|
||||
attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
|
||||
return attribute.star_rating >= 4
|
||||
|
||||
|
||||
@@ -297,10 +280,7 @@ async def quick_maths(
|
||||
# Get exactly 34 misses on any difficulty of Function Phantom - Variable.
|
||||
if score.nmiss != 34:
|
||||
return False
|
||||
return (
|
||||
beatmap.beatmapset.artist == "Function Phantom"
|
||||
and beatmap.beatmapset.title == "Variable"
|
||||
)
|
||||
return beatmap.beatmapset.artist == "Function Phantom" and beatmap.beatmapset.title == "Variable"
|
||||
|
||||
|
||||
async def kaleidoscope(
|
||||
@@ -328,8 +308,7 @@ async def valediction(
|
||||
return (
|
||||
score.passed
|
||||
and beatmap.beatmapset.artist == "a_hisa"
|
||||
and beatmap.beatmapset.title
|
||||
== "Alexithymia | Lupinus | Tokei no Heya to Seishin Sekai"
|
||||
and beatmap.beatmapset.title == "Alexithymia | Lupinus | Tokei no Heya to Seishin Sekai"
|
||||
and score.accuracy >= 0.9
|
||||
)
|
||||
|
||||
@@ -342,9 +321,7 @@ async def right_on_time(
|
||||
# Submit a score on Kola Kid - timer on the first minute of any hour
|
||||
if not score.passed:
|
||||
return False
|
||||
if not (
|
||||
beatmap.beatmapset.artist == "Kola Kid" and beatmap.beatmapset.title == "timer"
|
||||
):
|
||||
if not (beatmap.beatmapset.artist == "Kola Kid" and beatmap.beatmapset.title == "timer"):
|
||||
return False
|
||||
return score.ended_at.minute == 0
|
||||
|
||||
@@ -361,9 +338,7 @@ async def not_again(
|
||||
return False
|
||||
if score.accuracy < 0.99:
|
||||
return False
|
||||
return (
|
||||
beatmap.beatmapset.artist == "ARForest" and beatmap.beatmapset.title == "Regret"
|
||||
)
|
||||
return beatmap.beatmapset.artist == "ARForest" and beatmap.beatmapset.title == "Regret"
|
||||
|
||||
|
||||
async def deliberation(
|
||||
@@ -377,18 +352,13 @@ async def deliberation(
|
||||
mods_ = mod_to_save(score.mods)
|
||||
if "HT" not in mods_:
|
||||
return False
|
||||
if (
|
||||
not beatmap.beatmap_status.has_pp()
|
||||
and beatmap.beatmap_status != BeatmapRankStatus.LOVED
|
||||
):
|
||||
if not beatmap.beatmap_status.has_pp() and beatmap.beatmap_status != BeatmapRankStatus.LOVED:
|
||||
return False
|
||||
|
||||
fetcher = await get_fetcher()
|
||||
redis = get_redis()
|
||||
mods_copy = score.mods.copy()
|
||||
attribute = await calculate_beatmap_attributes(
|
||||
beatmap.id, score.gamemode, mods_copy, redis, fetcher
|
||||
)
|
||||
attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_copy, redis, fetcher)
|
||||
return attribute.star_rating >= 6
|
||||
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ MEDALS: Medals = {
|
||||
Achievement(
|
||||
id=93,
|
||||
name="Sweet Rave Party",
|
||||
desc="Founded in the fine tradition of changing things that were just fine as they were.", # noqa: E501
|
||||
desc="Founded in the fine tradition of changing things that were just fine as they were.",
|
||||
assets_id="all-intro-nightcore",
|
||||
): partial(process_mod, "NC"),
|
||||
Achievement(
|
||||
|
||||
@@ -16,11 +16,7 @@ async def process_combo(
|
||||
score: Score,
|
||||
beatmap: Beatmap,
|
||||
) -> bool:
|
||||
if (
|
||||
not score.passed
|
||||
or not beatmap.beatmap_status.has_pp()
|
||||
or score.gamemode != GameMode.OSU
|
||||
):
|
||||
if not score.passed or not beatmap.beatmap_status.has_pp() or score.gamemode != GameMode.OSU:
|
||||
return False
|
||||
if combo < 1:
|
||||
return False
|
||||
|
||||
@@ -44,9 +44,7 @@ async def process_skill(
|
||||
redis = get_redis()
|
||||
mods_ = score.mods.copy()
|
||||
mods_.sort(key=lambda x: x["acronym"])
|
||||
attribute = await calculate_beatmap_attributes(
|
||||
beatmap.id, score.gamemode, mods_, redis, fetcher
|
||||
)
|
||||
attribute = await calculate_beatmap_attributes(beatmap.id, score.gamemode, mods_, redis, fetcher)
|
||||
if attribute.star_rating < star or attribute.star_rating >= star + 1:
|
||||
return False
|
||||
if type == "fc" and not score.is_perfect_combo:
|
||||
|
||||
46
app/auth.py
46
app/auth.py
@@ -43,9 +43,7 @@ def validate_username(username: str) -> list[str]:
|
||||
|
||||
# 检查用户名格式(只允许字母、数字、下划线、连字符)
|
||||
if not re.match(r"^[a-zA-Z0-9_-]+$", username):
|
||||
errors.append(
|
||||
"Username can only contain letters, numbers, underscores, and hyphens"
|
||||
)
|
||||
errors.append("Username can only contain letters, numbers, underscores, and hyphens")
|
||||
|
||||
# 检查是否以数字开头
|
||||
if username[0].isdigit():
|
||||
@@ -104,9 +102,7 @@ def get_password_hash(password: str) -> str:
|
||||
return pw_bcrypt.decode()
|
||||
|
||||
|
||||
async def authenticate_user_legacy(
|
||||
db: AsyncSession, name: str, password: str
|
||||
) -> User | None:
|
||||
async def authenticate_user_legacy(db: AsyncSession, name: str, password: str) -> User | None:
|
||||
"""
|
||||
验证用户身份 - 使用类似 from_login 的逻辑
|
||||
"""
|
||||
@@ -145,9 +141,7 @@ async def authenticate_user_legacy(
|
||||
return None
|
||||
|
||||
|
||||
async def authenticate_user(
|
||||
db: AsyncSession, username: str, password: str
|
||||
) -> User | None:
|
||||
async def authenticate_user(db: AsyncSession, username: str, password: str) -> User | None:
|
||||
"""验证用户身份"""
|
||||
return await authenticate_user_legacy(db, username, password)
|
||||
|
||||
@@ -158,14 +152,10 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None) -> s
|
||||
if expires_delta:
|
||||
expire = datetime.now(UTC) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(UTC) + timedelta(
|
||||
minutes=settings.access_token_expire_minutes
|
||||
)
|
||||
expire = datetime.now(UTC) + timedelta(minutes=settings.access_token_expire_minutes)
|
||||
|
||||
to_encode.update({"exp": expire, "random": secrets.token_hex(16)})
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode, settings.secret_key, algorithm=settings.algorithm
|
||||
)
|
||||
encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
@@ -178,20 +168,20 @@ def generate_refresh_token() -> str:
|
||||
|
||||
async def invalidate_user_tokens(db: AsyncSession, user_id: int) -> int:
|
||||
"""使指定用户的所有令牌失效
|
||||
|
||||
|
||||
返回删除的令牌数量
|
||||
"""
|
||||
# 使用 select 先获取所有令牌
|
||||
stmt = select(OAuthToken).where(OAuthToken.user_id == user_id)
|
||||
result = await db.exec(stmt)
|
||||
tokens = result.all()
|
||||
|
||||
|
||||
# 逐个删除令牌
|
||||
count = 0
|
||||
for token in tokens:
|
||||
await db.delete(token)
|
||||
count += 1
|
||||
|
||||
|
||||
# 提交更改
|
||||
await db.commit()
|
||||
return count
|
||||
@@ -200,9 +190,7 @@ async def invalidate_user_tokens(db: AsyncSession, user_id: int) -> int:
|
||||
def verify_token(token: str) -> dict | None:
|
||||
"""验证访问令牌"""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.secret_key, algorithms=[settings.algorithm]
|
||||
)
|
||||
payload = jwt.decode(token, settings.secret_key, algorithms=[settings.algorithm])
|
||||
return payload
|
||||
except JWTError:
|
||||
return None
|
||||
@@ -221,17 +209,13 @@ async def store_token(
|
||||
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
|
||||
|
||||
# 删除用户的旧令牌
|
||||
statement = select(OAuthToken).where(
|
||||
OAuthToken.user_id == user_id, OAuthToken.client_id == client_id
|
||||
)
|
||||
statement = select(OAuthToken).where(OAuthToken.user_id == user_id, OAuthToken.client_id == client_id)
|
||||
old_tokens = (await db.exec(statement)).all()
|
||||
for token in old_tokens:
|
||||
await db.delete(token)
|
||||
|
||||
# 检查是否有重复的 access_token
|
||||
duplicate_token = (
|
||||
await db.exec(select(OAuthToken).where(OAuthToken.access_token == access_token))
|
||||
).first()
|
||||
duplicate_token = (await db.exec(select(OAuthToken).where(OAuthToken.access_token == access_token))).first()
|
||||
if duplicate_token:
|
||||
await db.delete(duplicate_token)
|
||||
|
||||
@@ -250,9 +234,7 @@ async def store_token(
|
||||
return token_record
|
||||
|
||||
|
||||
async def get_token_by_access_token(
|
||||
db: AsyncSession, access_token: str
|
||||
) -> OAuthToken | None:
|
||||
async def get_token_by_access_token(db: AsyncSession, access_token: str) -> OAuthToken | None:
|
||||
"""根据访问令牌获取令牌记录"""
|
||||
statement = select(OAuthToken).where(
|
||||
OAuthToken.access_token == access_token,
|
||||
@@ -261,9 +243,7 @@ async def get_token_by_access_token(
|
||||
return (await db.exec(statement)).first()
|
||||
|
||||
|
||||
async def get_token_by_refresh_token(
|
||||
db: AsyncSession, refresh_token: str
|
||||
) -> OAuthToken | None:
|
||||
async def get_token_by_refresh_token(db: AsyncSession, refresh_token: str) -> OAuthToken | None:
|
||||
"""根据刷新令牌获取令牌记录"""
|
||||
statement = select(OAuthToken).where(
|
||||
OAuthToken.refresh_token == refresh_token,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
import math
|
||||
@@ -67,11 +68,7 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
|
||||
|
||||
if settings.suspicious_score_check:
|
||||
beatmap_banned = (
|
||||
await session.exec(
|
||||
select(exists()).where(
|
||||
col(BannedBeatmaps.beatmap_id) == score.beatmap_id
|
||||
)
|
||||
)
|
||||
await session.exec(select(exists()).where(col(BannedBeatmaps.beatmap_id) == score.beatmap_id))
|
||||
).first()
|
||||
if beatmap_banned:
|
||||
return 0
|
||||
@@ -82,12 +79,9 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
|
||||
logger.warning(f"Beatmap {score.beatmap_id} is suspicious, banned")
|
||||
return 0
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Error checking if beatmap {score.beatmap_id} is suspicious"
|
||||
)
|
||||
logger.exception(f"Error checking if beatmap {score.beatmap_id} is suspicious")
|
||||
|
||||
# 使用线程池执行计算密集型操作以避免阻塞事件循环
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
@@ -118,9 +112,7 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
|
||||
pp = attrs.pp
|
||||
|
||||
# mrekk bp1: 2048pp; ppy-sb top1 rxbp1: 2198pp
|
||||
if settings.suspicious_score_check and (
|
||||
(attrs.difficulty.stars > 25 and score.accuracy < 0.8) or pp > 2300
|
||||
):
|
||||
if settings.suspicious_score_check and ((attrs.difficulty.stars > 25 and score.accuracy < 0.8) or pp > 2300):
|
||||
logger.warning(
|
||||
f"User {score.user_id} played {score.beatmap_id} "
|
||||
f"(star={attrs.difficulty.stars}) with {pp=} "
|
||||
@@ -131,9 +123,7 @@ async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> f
|
||||
return pp
|
||||
|
||||
|
||||
async def pre_fetch_and_calculate_pp(
|
||||
score: "Score", beatmap_id: int, session: AsyncSession, redis, fetcher
|
||||
) -> float:
|
||||
async def pre_fetch_and_calculate_pp(score: "Score", beatmap_id: int, session: AsyncSession, redis, fetcher) -> float:
|
||||
"""
|
||||
优化版PP计算:预先获取beatmap文件并使用缓存
|
||||
"""
|
||||
@@ -144,9 +134,7 @@ async def pre_fetch_and_calculate_pp(
|
||||
# 快速检查是否被封禁
|
||||
if settings.suspicious_score_check:
|
||||
beatmap_banned = (
|
||||
await session.exec(
|
||||
select(exists()).where(col(BannedBeatmaps.beatmap_id) == beatmap_id)
|
||||
)
|
||||
await session.exec(select(exists()).where(col(BannedBeatmaps.beatmap_id) == beatmap_id))
|
||||
).first()
|
||||
if beatmap_banned:
|
||||
return 0
|
||||
@@ -202,9 +190,7 @@ async def batch_calculate_pp(
|
||||
banned_beatmaps = set()
|
||||
if settings.suspicious_score_check:
|
||||
banned_results = await session.exec(
|
||||
select(BannedBeatmaps.beatmap_id).where(
|
||||
col(BannedBeatmaps.beatmap_id).in_(unique_beatmap_ids)
|
||||
)
|
||||
select(BannedBeatmaps.beatmap_id).where(col(BannedBeatmaps.beatmap_id).in_(unique_beatmap_ids))
|
||||
)
|
||||
banned_beatmaps = set(banned_results.all())
|
||||
|
||||
@@ -380,9 +366,7 @@ def calculate_score_to_level(total_score: int) -> float:
|
||||
level = 0.0
|
||||
|
||||
while remaining_score > 0:
|
||||
next_level_requirement = to_next_level[
|
||||
min(len(to_next_level) - 1, round(level))
|
||||
]
|
||||
next_level_requirement = to_next_level[min(len(to_next_level) - 1, round(level))]
|
||||
level += min(1, remaining_score / next_level_requirement)
|
||||
remaining_score -= next_level_requirement
|
||||
|
||||
@@ -417,9 +401,7 @@ class Threshold(int, Enum):
|
||||
NOTE_POSX_THRESHOLD = 512 # x: [-512,512]
|
||||
NOTE_POSY_THRESHOLD = 384 # y: [-384,384]
|
||||
|
||||
POS_ERROR_THRESHOLD = (
|
||||
1280 * 50
|
||||
) # 超过这么多个物件(包括滑条控制点)的位置有问题就毙掉
|
||||
POS_ERROR_THRESHOLD = 1280 * 50 # 超过这么多个物件(包括滑条控制点)的位置有问题就毙掉
|
||||
|
||||
SLIDER_REPEAT_THRESHOLD = 5000
|
||||
|
||||
@@ -469,10 +451,7 @@ def is_2b(hit_objects: list[HitObject]) -> bool:
|
||||
def is_suspicious_beatmap(content: str) -> bool:
|
||||
osufile = OsuFile(content=content.encode("utf-8")).parse_file()
|
||||
|
||||
if (
|
||||
osufile.hit_objects[-1].start_time - osufile.hit_objects[0].start_time
|
||||
> 24 * 60 * 60 * 1000
|
||||
):
|
||||
if osufile.hit_objects[-1].start_time - osufile.hit_objects[0].start_time > 24 * 60 * 60 * 1000:
|
||||
return True
|
||||
if osufile.mode == int(GameMode.TAIKO):
|
||||
if len(osufile.hit_objects) > Threshold.TAIKO_THRESHOLD:
|
||||
|
||||
@@ -124,14 +124,10 @@ class Settings(BaseSettings):
|
||||
smtp_password: str = ""
|
||||
from_email: str = "noreply@example.com"
|
||||
from_name: str = "osu! server"
|
||||
|
||||
|
||||
# 邮件验证功能开关
|
||||
enable_email_verification: bool = Field(
|
||||
default=True, description="是否启用邮件验证功能"
|
||||
)
|
||||
enable_email_sending: bool = Field(
|
||||
default=False, description="是否真实发送邮件(False时仅模拟发送)"
|
||||
)
|
||||
enable_email_verification: bool = Field(default=True, description="是否启用邮件验证功能")
|
||||
enable_email_sending: bool = Field(default=False, description="是否真实发送邮件(False时仅模拟发送)")
|
||||
|
||||
# Sentry 配置
|
||||
sentry_dsn: HttpUrl | None = None
|
||||
@@ -143,12 +139,8 @@ class Settings(BaseSettings):
|
||||
geoip_update_hour: int = 2 # 每周更新的小时数(0-23)
|
||||
|
||||
# 游戏设置
|
||||
enable_rx: bool = Field(
|
||||
default=False, validation_alias=AliasChoices("enable_rx", "enable_osu_rx")
|
||||
)
|
||||
enable_ap: bool = Field(
|
||||
default=False, validation_alias=AliasChoices("enable_ap", "enable_osu_ap")
|
||||
)
|
||||
enable_rx: bool = Field(default=False, validation_alias=AliasChoices("enable_rx", "enable_osu_rx"))
|
||||
enable_ap: bool = Field(default=False, validation_alias=AliasChoices("enable_ap", "enable_osu_ap"))
|
||||
enable_all_mods_pp: bool = False
|
||||
enable_supporter_for_all_users: bool = False
|
||||
enable_all_beatmap_leaderboard: bool = False
|
||||
@@ -189,9 +181,7 @@ class Settings(BaseSettings):
|
||||
|
||||
# 存储设置
|
||||
storage_service: StorageServiceType = StorageServiceType.LOCAL
|
||||
storage_settings: (
|
||||
LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings
|
||||
) = LocalStorageSettings()
|
||||
storage_settings: LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings = LocalStorageSettings()
|
||||
|
||||
@field_validator("fetcher_scopes", mode="before")
|
||||
def validate_fetcher_scopes(cls, v: Any) -> list[str]:
|
||||
@@ -207,22 +197,13 @@ class Settings(BaseSettings):
|
||||
) -> LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings:
|
||||
if info.data.get("storage_service") == StorageServiceType.CLOUDFLARE_R2:
|
||||
if not isinstance(v, CloudflareR2Settings):
|
||||
raise ValueError(
|
||||
"When storage_service is 'r2', "
|
||||
"storage_settings must be CloudflareR2Settings"
|
||||
)
|
||||
raise ValueError("When storage_service is 'r2', storage_settings must be CloudflareR2Settings")
|
||||
elif info.data.get("storage_service") == StorageServiceType.LOCAL:
|
||||
if not isinstance(v, LocalStorageSettings):
|
||||
raise ValueError(
|
||||
"When storage_service is 'local', "
|
||||
"storage_settings must be LocalStorageSettings"
|
||||
)
|
||||
raise ValueError("When storage_service is 'local', storage_settings must be LocalStorageSettings")
|
||||
elif info.data.get("storage_service") == StorageServiceType.AWS_S3:
|
||||
if not isinstance(v, AWSS3StorageSettings):
|
||||
raise ValueError(
|
||||
"When storage_service is 's3', "
|
||||
"storage_settings must be AWSS3StorageSettings"
|
||||
)
|
||||
raise ValueError("When storage_service is 's3', storage_settings must be AWSS3StorageSettings")
|
||||
return v
|
||||
|
||||
|
||||
|
||||
@@ -28,18 +28,14 @@ if TYPE_CHECKING:
|
||||
|
||||
class UserAchievementBase(SQLModel, UTCBaseModel):
|
||||
achievement_id: int
|
||||
achieved_at: datetime = Field(
|
||||
default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True))
|
||||
)
|
||||
achieved_at: datetime = Field(default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True)))
|
||||
|
||||
|
||||
class UserAchievement(UserAchievementBase, table=True):
|
||||
__tablename__ = "lazer_user_achievements" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "lazer_user_achievements"
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id")), exclude=True
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")), exclude=True)
|
||||
user: "User" = Relationship(back_populates="achievement")
|
||||
|
||||
|
||||
@@ -56,11 +52,7 @@ async def process_achievements(session: AsyncSession, redis: Redis, score_id: in
|
||||
if not score:
|
||||
return
|
||||
achieved = (
|
||||
await session.exec(
|
||||
select(UserAchievement.achievement_id).where(
|
||||
UserAchievement.user_id == score.user_id
|
||||
)
|
||||
)
|
||||
await session.exec(select(UserAchievement.achievement_id).where(UserAchievement.user_id == score.user_id))
|
||||
).all()
|
||||
not_achieved = {k: v for k, v in MEDALS.items() if k.id not in achieved}
|
||||
result: list[Achievement] = []
|
||||
@@ -78,9 +70,7 @@ async def process_achievements(session: AsyncSession, redis: Redis, score_id: in
|
||||
)
|
||||
await redis.publish(
|
||||
"chat:notification",
|
||||
UserAchievementUnlock.init(
|
||||
r, score.user_id, score.gamemode
|
||||
).model_dump_json(),
|
||||
UserAchievementUnlock.init(r, score.user_id, score.gamemode).model_dump_json(),
|
||||
)
|
||||
event = Event(
|
||||
created_at=now,
|
||||
|
||||
@@ -20,42 +20,34 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class OAuthToken(UTCBaseModel, SQLModel, table=True):
|
||||
__tablename__ = "oauth_tokens" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "oauth_tokens"
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
client_id: int = Field(index=True)
|
||||
access_token: str = Field(max_length=500, unique=True)
|
||||
refresh_token: str = Field(max_length=500, unique=True)
|
||||
token_type: str = Field(default="Bearer", max_length=20)
|
||||
scope: str = Field(default="*", max_length=100)
|
||||
expires_at: datetime = Field(sa_column=Column(DateTime))
|
||||
created_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, sa_column=Column(DateTime))
|
||||
|
||||
user: "User" = Relationship()
|
||||
|
||||
|
||||
class OAuthClient(SQLModel, table=True):
|
||||
__tablename__ = "oauth_clients" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "oauth_clients"
|
||||
name: str = Field(max_length=100, index=True)
|
||||
description: str = Field(sa_column=Column(Text), default="")
|
||||
client_id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
client_secret: str = Field(default_factory=secrets.token_hex, index=True)
|
||||
redirect_uris: list[str] = Field(default_factory=list, sa_column=Column(JSON))
|
||||
owner_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
|
||||
|
||||
class V1APIKeys(SQLModel, table=True):
|
||||
__tablename__ = "v1_api_keys" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "v1_api_keys"
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
name: str = Field(max_length=100, index=True)
|
||||
key: str = Field(default_factory=secrets.token_hex, index=True)
|
||||
owner_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
|
||||
@@ -60,17 +60,13 @@ class BeatmapBase(SQLModel):
|
||||
|
||||
|
||||
class Beatmap(BeatmapBase, table=True):
|
||||
__tablename__ = "beatmaps" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "beatmaps"
|
||||
id: int = Field(primary_key=True, index=True)
|
||||
beatmapset_id: int = Field(foreign_key="beatmapsets.id", index=True)
|
||||
beatmap_status: BeatmapRankStatus = Field(index=True)
|
||||
# optional
|
||||
beatmapset: Beatmapset = Relationship(
|
||||
back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"}
|
||||
)
|
||||
failtimes: FailTime | None = Relationship(
|
||||
back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"}
|
||||
)
|
||||
beatmapset: Beatmapset = Relationship(back_populates="beatmaps", sa_relationship_kwargs={"lazy": "joined"})
|
||||
failtimes: FailTime | None = Relationship(back_populates="beatmap", sa_relationship_kwargs={"lazy": "joined"})
|
||||
|
||||
@classmethod
|
||||
async def from_resp(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
|
||||
@@ -84,21 +80,15 @@ class Beatmap(BeatmapBase, table=True):
|
||||
"beatmap_status": BeatmapRankStatus(resp.ranked),
|
||||
}
|
||||
)
|
||||
if not (
|
||||
await session.exec(select(exists()).where(Beatmap.id == resp.id))
|
||||
).first():
|
||||
if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
|
||||
session.add(beatmap)
|
||||
await session.commit()
|
||||
beatmap = (
|
||||
await session.exec(select(Beatmap).where(Beatmap.id == resp.id))
|
||||
).first()
|
||||
beatmap = (await session.exec(select(Beatmap).where(Beatmap.id == resp.id))).first()
|
||||
assert beatmap is not None, "Beatmap should not be None after commit"
|
||||
return beatmap
|
||||
|
||||
@classmethod
|
||||
async def from_resp_batch(
|
||||
cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0
|
||||
) -> list["Beatmap"]:
|
||||
async def from_resp_batch(cls, session: AsyncSession, inp: list["BeatmapResp"], from_: int = 0) -> list["Beatmap"]:
|
||||
beatmaps = []
|
||||
for resp in inp:
|
||||
if resp.id == from_:
|
||||
@@ -113,9 +103,7 @@ class Beatmap(BeatmapBase, table=True):
|
||||
"beatmap_status": BeatmapRankStatus(resp.ranked),
|
||||
}
|
||||
)
|
||||
if not (
|
||||
await session.exec(select(exists()).where(Beatmap.id == resp.id))
|
||||
).first():
|
||||
if not (await session.exec(select(exists()).where(Beatmap.id == resp.id))).first():
|
||||
session.add(beatmap)
|
||||
beatmaps.append(beatmap)
|
||||
await session.commit()
|
||||
@@ -130,17 +118,11 @@ class Beatmap(BeatmapBase, table=True):
|
||||
md5: str | None = None,
|
||||
) -> "Beatmap":
|
||||
beatmap = (
|
||||
await session.exec(
|
||||
select(Beatmap).where(
|
||||
Beatmap.id == bid if bid is not None else Beatmap.checksum == md5
|
||||
)
|
||||
)
|
||||
await session.exec(select(Beatmap).where(Beatmap.id == bid if bid is not None else Beatmap.checksum == md5))
|
||||
).first()
|
||||
if not beatmap:
|
||||
resp = await fetcher.get_beatmap(bid, md5)
|
||||
r = await session.exec(
|
||||
select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id)
|
||||
)
|
||||
r = await session.exec(select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id))
|
||||
if not r.first():
|
||||
set_resp = await fetcher.get_beatmapset(resp.beatmapset_id)
|
||||
await Beatmapset.from_resp(session, set_resp, from_=resp.id)
|
||||
@@ -178,10 +160,7 @@ class BeatmapResp(BeatmapBase):
|
||||
if query_mode is not None and beatmap.mode != query_mode:
|
||||
beatmap_["convert"] = True
|
||||
beatmap_["is_scoreable"] = beatmap_status.has_leaderboard()
|
||||
if (
|
||||
settings.enable_all_beatmap_leaderboard
|
||||
and not beatmap_status.has_leaderboard()
|
||||
):
|
||||
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
|
||||
beatmap_["ranked"] = BeatmapRankStatus.APPROVED.value
|
||||
beatmap_["status"] = BeatmapRankStatus.APPROVED.name.lower()
|
||||
else:
|
||||
@@ -189,9 +168,7 @@ class BeatmapResp(BeatmapBase):
|
||||
beatmap_["ranked"] = beatmap_status.value
|
||||
beatmap_["mode_int"] = int(beatmap.mode)
|
||||
if not from_set:
|
||||
beatmap_["beatmapset"] = await BeatmapsetResp.from_db(
|
||||
beatmap.beatmapset, session=session, user=user
|
||||
)
|
||||
beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset, session=session, user=user)
|
||||
if beatmap.failtimes is not None:
|
||||
beatmap_["failtimes"] = FailTimeResp.from_db(beatmap.failtimes)
|
||||
else:
|
||||
@@ -218,7 +195,7 @@ class BeatmapResp(BeatmapBase):
|
||||
|
||||
|
||||
class BannedBeatmaps(SQLModel, table=True):
|
||||
__tablename__ = "banned_beatmaps" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "banned_beatmaps"
|
||||
id: int | None = Field(primary_key=True, index=True, default=None)
|
||||
beatmap_id: int = Field(index=True)
|
||||
|
||||
@@ -230,15 +207,10 @@ async def calculate_beatmap_attributes(
|
||||
redis: Redis,
|
||||
fetcher: "Fetcher",
|
||||
):
|
||||
key = (
|
||||
f"beatmap:{beatmap_id}:{ruleset}:"
|
||||
f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
|
||||
)
|
||||
key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
|
||||
if await redis.exists(key):
|
||||
return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
|
||||
return BeatmapAttributes.model_validate_json(await redis.get(key))
|
||||
resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
|
||||
attr = await asyncio.get_event_loop().run_in_executor(
|
||||
None, calculate_beatmap_attribute, resp, ruleset, mods_
|
||||
)
|
||||
attr = await asyncio.get_event_loop().run_in_executor(None, calculate_beatmap_attribute, resp, ruleset, mods_)
|
||||
await redis.set(key, attr.model_dump_json())
|
||||
return attr
|
||||
|
||||
@@ -23,15 +23,13 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class BeatmapPlaycounts(AsyncAttrs, SQLModel, table=True):
|
||||
__tablename__ = "beatmap_playcounts" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "beatmap_playcounts"
|
||||
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, primary_key=True, autoincrement=True),
|
||||
)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
|
||||
playcount: int = Field(default=0)
|
||||
|
||||
@@ -59,9 +57,7 @@ class BeatmapPlaycountsResp(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
async def process_beatmap_playcount(
|
||||
session: AsyncSession, user_id: int, beatmap_id: int
|
||||
) -> None:
|
||||
async def process_beatmap_playcount(session: AsyncSession, user_id: int, beatmap_id: int) -> None:
|
||||
existing_playcount = (
|
||||
await session.exec(
|
||||
select(BeatmapPlaycounts).where(
|
||||
@@ -89,7 +85,5 @@ async def process_beatmap_playcount(
|
||||
}
|
||||
session.add(playcount_event)
|
||||
else:
|
||||
new_playcount = BeatmapPlaycounts(
|
||||
user_id=user_id, beatmap_id=beatmap_id, playcount=1
|
||||
)
|
||||
new_playcount = BeatmapPlaycounts(user_id=user_id, beatmap_id=beatmap_id, playcount=1)
|
||||
session.add(new_playcount)
|
||||
|
||||
@@ -86,9 +86,7 @@ class BeatmapsetBase(SQLModel):
|
||||
|
||||
# optional
|
||||
# converts: list[Beatmap] = Relationship(back_populates="beatmapset")
|
||||
current_nominations: list[BeatmapNomination] | None = Field(
|
||||
None, sa_column=Column(JSON)
|
||||
)
|
||||
current_nominations: list[BeatmapNomination] | None = Field(None, sa_column=Column(JSON))
|
||||
description: BeatmapDescription | None = Field(default=None, sa_column=Column(JSON))
|
||||
# TODO: discussions: list[BeatmapsetDiscussion] = None
|
||||
# TODO: current_user_attributes: Optional[CurrentUserAttributes] = None
|
||||
@@ -105,22 +103,18 @@ class BeatmapsetBase(SQLModel):
|
||||
can_be_hyped: bool = Field(default=False, sa_column=Column(Boolean))
|
||||
discussion_locked: bool = Field(default=False, sa_column=Column(Boolean))
|
||||
last_updated: datetime = Field(sa_column=Column(DateTime, index=True))
|
||||
ranked_date: datetime | None = Field(
|
||||
default=None, sa_column=Column(DateTime, index=True)
|
||||
)
|
||||
ranked_date: datetime | None = Field(default=None, sa_column=Column(DateTime, index=True))
|
||||
storyboard: bool = Field(default=False, sa_column=Column(Boolean, index=True))
|
||||
submitted_date: datetime = Field(sa_column=Column(DateTime, index=True))
|
||||
tags: str = Field(default="", sa_column=Column(Text))
|
||||
|
||||
|
||||
class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
|
||||
__tablename__ = "beatmapsets" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "beatmapsets"
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
id: int = Field(default=None, primary_key=True, index=True)
|
||||
# Beatmapset
|
||||
beatmap_status: BeatmapRankStatus = Field(
|
||||
default=BeatmapRankStatus.GRAVEYARD, index=True
|
||||
)
|
||||
beatmap_status: BeatmapRankStatus = Field(default=BeatmapRankStatus.GRAVEYARD, index=True)
|
||||
|
||||
# optional
|
||||
beatmaps: list["Beatmap"] = Relationship(back_populates="beatmapset")
|
||||
@@ -137,9 +131,7 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
|
||||
favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset")
|
||||
|
||||
@classmethod
|
||||
async def from_resp(
|
||||
cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0
|
||||
) -> "Beatmapset":
|
||||
async def from_resp(cls, session: AsyncSession, resp: "BeatmapsetResp", from_: int = 0) -> "Beatmapset":
|
||||
from .beatmap import Beatmap
|
||||
|
||||
d = resp.model_dump()
|
||||
@@ -167,18 +159,14 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
|
||||
"download_disabled": resp.availability.download_disabled or False,
|
||||
}
|
||||
)
|
||||
if not (
|
||||
await session.exec(select(exists()).where(Beatmapset.id == resp.id))
|
||||
).first():
|
||||
if not (await session.exec(select(exists()).where(Beatmapset.id == resp.id))).first():
|
||||
session.add(beatmapset)
|
||||
await session.commit()
|
||||
await Beatmap.from_resp_batch(session, resp.beatmaps, from_=from_)
|
||||
return beatmapset
|
||||
|
||||
@classmethod
|
||||
async def get_or_fetch(
|
||||
cls, session: AsyncSession, fetcher: "Fetcher", sid: int
|
||||
) -> "Beatmapset":
|
||||
async def get_or_fetch(cls, session: AsyncSession, fetcher: "Fetcher", sid: int) -> "Beatmapset":
|
||||
beatmapset = await session.get(Beatmapset, sid)
|
||||
if not beatmapset:
|
||||
resp = await fetcher.get_beatmapset(sid)
|
||||
@@ -227,13 +215,9 @@ class BeatmapsetResp(BeatmapsetBase):
|
||||
@model_validator(mode="after")
|
||||
def fix_genre_language(self) -> Self:
|
||||
if self.genre is None:
|
||||
self.genre = BeatmapTranslationText(
|
||||
name=Genre(self.genre_id).name, id=self.genre_id
|
||||
)
|
||||
self.genre = BeatmapTranslationText(name=Genre(self.genre_id).name, id=self.genre_id)
|
||||
if self.language is None:
|
||||
self.language = BeatmapTranslationText(
|
||||
name=Language(self.language_id).name, id=self.language_id
|
||||
)
|
||||
self.language = BeatmapTranslationText(name=Language(self.language_id).name, id=self.language_id)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
@@ -252,9 +236,7 @@ class BeatmapsetResp(BeatmapsetBase):
|
||||
await BeatmapResp.from_db(beatmap, from_set=True, session=session)
|
||||
for beatmap in await beatmapset.awaitable_attrs.beatmaps
|
||||
],
|
||||
"hype": BeatmapHype(
|
||||
current=beatmapset.hype_current, required=beatmapset.hype_required
|
||||
),
|
||||
"hype": BeatmapHype(current=beatmapset.hype_current, required=beatmapset.hype_required),
|
||||
"availability": BeatmapAvailability(
|
||||
more_information=beatmapset.availability_info,
|
||||
download_disabled=beatmapset.download_disabled,
|
||||
@@ -282,10 +264,7 @@ class BeatmapsetResp(BeatmapsetBase):
|
||||
update["ratings"] = []
|
||||
|
||||
beatmap_status = beatmapset.beatmap_status
|
||||
if (
|
||||
settings.enable_all_beatmap_leaderboard
|
||||
and not beatmap_status.has_leaderboard()
|
||||
):
|
||||
if settings.enable_all_beatmap_leaderboard and not beatmap_status.has_leaderboard():
|
||||
update["status"] = BeatmapRankStatus.APPROVED.name.lower()
|
||||
update["ranked"] = BeatmapRankStatus.APPROVED.value
|
||||
else:
|
||||
@@ -295,9 +274,7 @@ class BeatmapsetResp(BeatmapsetBase):
|
||||
if session and user:
|
||||
existing_favourite = (
|
||||
await session.exec(
|
||||
select(FavouriteBeatmapset).where(
|
||||
FavouriteBeatmapset.beatmapset_id == beatmapset.id
|
||||
)
|
||||
select(FavouriteBeatmapset).where(FavouriteBeatmapset.beatmapset_id == beatmapset.id)
|
||||
)
|
||||
).first()
|
||||
update["has_favourited"] = existing_favourite is not None
|
||||
|
||||
@@ -20,13 +20,9 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class BestScore(SQLModel, table=True):
|
||||
__tablename__ = "total_score_best_scores" # pyright: ignore[reportAssignmentType]
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
score_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
|
||||
)
|
||||
__tablename__: str = "total_score_best_scores"
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
score_id: int = Field(sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True))
|
||||
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
|
||||
gamemode: GameMode = Field(index=True)
|
||||
total_score: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
|
||||
@@ -51,30 +51,22 @@ class ChatChannelBase(SQLModel):
|
||||
|
||||
|
||||
class ChatChannel(ChatChannelBase, table=True):
|
||||
__tablename__ = "chat_channels" # pyright: ignore[reportAssignmentType]
|
||||
channel_id: int | None = Field(primary_key=True, index=True, default=None)
|
||||
__tablename__: str = "chat_channels"
|
||||
channel_id: int = Field(primary_key=True, index=True, default=None)
|
||||
|
||||
@classmethod
|
||||
async def get(
|
||||
cls, channel: str | int, session: AsyncSession
|
||||
) -> "ChatChannel | None":
|
||||
async def get(cls, channel: str | int, session: AsyncSession) -> "ChatChannel | None":
|
||||
if isinstance(channel, int) or channel.isdigit():
|
||||
# 使用查询而不是 get() 来确保对象完全加载
|
||||
result = await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
result = await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))
|
||||
channel_ = result.first()
|
||||
if channel_ is not None:
|
||||
return channel_
|
||||
result = await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.name == channel)
|
||||
)
|
||||
result = await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
return result.first()
|
||||
|
||||
@classmethod
|
||||
async def get_pm_channel(
|
||||
cls, user1: int, user2: int, session: AsyncSession
|
||||
) -> "ChatChannel | None":
|
||||
async def get_pm_channel(cls, user1: int, user2: int, session: AsyncSession) -> "ChatChannel | None":
|
||||
channel = await cls.get(f"pm_{user1}_{user2}", session)
|
||||
if channel is None:
|
||||
channel = await cls.get(f"pm_{user2}_{user1}", session)
|
||||
@@ -153,18 +145,13 @@ class ChatChannelResp(ChatChannelBase):
|
||||
.limit(10)
|
||||
)
|
||||
).all()
|
||||
c.recent_messages = [
|
||||
await ChatMessageResp.from_db(msg, session, user) for msg in messages
|
||||
]
|
||||
c.recent_messages = [await ChatMessageResp.from_db(msg, session, user) for msg in messages]
|
||||
c.recent_messages.reverse()
|
||||
|
||||
if c.type == ChannelType.PM and users and len(users) == 2:
|
||||
target_user_id = next(u for u in users if u != user.id)
|
||||
target_name = await session.exec(
|
||||
select(User.username).where(User.id == target_user_id)
|
||||
)
|
||||
target_name = await session.exec(select(User.username).where(User.id == target_user_id))
|
||||
c.name = target_name.one()
|
||||
assert user.id
|
||||
c.users = [target_user_id, user.id]
|
||||
return c
|
||||
|
||||
@@ -181,19 +168,15 @@ class MessageType(str, Enum):
|
||||
class ChatMessageBase(UTCBaseModel, SQLModel):
|
||||
channel_id: int = Field(index=True, foreign_key="chat_channels.channel_id")
|
||||
content: str = Field(sa_column=Column(VARCHAR(1000)))
|
||||
message_id: int | None = Field(index=True, primary_key=True, default=None)
|
||||
sender_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
timestamp: datetime = Field(
|
||||
sa_column=Column(DateTime, index=True), default=datetime.now(UTC)
|
||||
)
|
||||
message_id: int = Field(index=True, primary_key=True, default=None)
|
||||
sender_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
timestamp: datetime = Field(sa_column=Column(DateTime, index=True), default=datetime.now(UTC))
|
||||
type: MessageType = Field(default=MessageType.PLAIN, index=True, exclude=True)
|
||||
uuid: str | None = Field(default=None)
|
||||
|
||||
|
||||
class ChatMessage(ChatMessageBase, table=True):
|
||||
__tablename__ = "chat_messages" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "chat_messages"
|
||||
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
|
||||
channel: ChatChannel = Relationship()
|
||||
|
||||
@@ -211,9 +194,7 @@ class ChatMessageResp(ChatMessageBase):
|
||||
if user:
|
||||
m.sender = await UserResp.from_db(user, session, RANKING_INCLUDES)
|
||||
else:
|
||||
m.sender = await UserResp.from_db(
|
||||
db_message.user, session, RANKING_INCLUDES
|
||||
)
|
||||
m.sender = await UserResp.from_db(db_message.user, session, RANKING_INCLUDES)
|
||||
return m
|
||||
|
||||
|
||||
@@ -221,17 +202,13 @@ class ChatMessageResp(ChatMessageBase):
|
||||
|
||||
|
||||
class SilenceUser(UTCBaseModel, SQLModel, table=True):
|
||||
__tablename__ = "chat_silence_users" # pyright: ignore[reportAssignmentType]
|
||||
id: int | None = Field(primary_key=True, default=None, index=True)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
__tablename__: str = "chat_silence_users"
|
||||
id: int = Field(primary_key=True, default=None, index=True)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
channel_id: int = Field(foreign_key="chat_channels.channel_id", index=True)
|
||||
until: datetime | None = Field(sa_column=Column(DateTime, index=True), default=None)
|
||||
reason: str | None = Field(default=None, sa_column=Column(VARCHAR(255), index=True))
|
||||
banned_at: datetime = Field(
|
||||
sa_column=Column(DateTime, index=True), default=datetime.now(UTC)
|
||||
)
|
||||
banned_at: datetime = Field(sa_column=Column(DateTime, index=True), default=datetime.now(UTC))
|
||||
|
||||
|
||||
class UserSilenceResp(SQLModel):
|
||||
@@ -240,7 +217,6 @@ class UserSilenceResp(SQLModel):
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_silence: SilenceUser) -> "UserSilenceResp":
|
||||
assert db_silence.id is not None
|
||||
return cls(
|
||||
id=db_silence.id,
|
||||
user_id=db_silence.user_id,
|
||||
|
||||
@@ -21,28 +21,24 @@ class CountBase(SQLModel):
|
||||
|
||||
|
||||
class MonthlyPlaycounts(CountBase, table=True):
|
||||
__tablename__ = "monthly_playcounts" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "monthly_playcounts"
|
||||
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, primary_key=True, autoincrement=True),
|
||||
)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
user: "User" = Relationship(back_populates="monthly_playcounts")
|
||||
|
||||
|
||||
class ReplayWatchedCount(CountBase, table=True):
|
||||
__tablename__ = "replays_watched_counts" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "replays_watched_counts"
|
||||
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, primary_key=True, autoincrement=True),
|
||||
)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
user: "User" = Relationship(back_populates="replays_watched_counts")
|
||||
|
||||
|
||||
|
||||
@@ -24,9 +24,7 @@ class DailyChallengeStatsBase(SQLModel, UTCBaseModel):
|
||||
daily_streak_best: int = Field(default=0)
|
||||
daily_streak_current: int = Field(default=0)
|
||||
last_update: datetime | None = Field(default=None, sa_column=Column(DateTime))
|
||||
last_weekly_streak: datetime | None = Field(
|
||||
default=None, sa_column=Column(DateTime)
|
||||
)
|
||||
last_weekly_streak: datetime | None = Field(default=None, sa_column=Column(DateTime))
|
||||
playcount: int = Field(default=0)
|
||||
top_10p_placements: int = Field(default=0)
|
||||
top_50p_placements: int = Field(default=0)
|
||||
@@ -35,7 +33,7 @@ class DailyChallengeStatsBase(SQLModel, UTCBaseModel):
|
||||
|
||||
|
||||
class DailyChallengeStats(DailyChallengeStatsBase, table=True):
|
||||
__tablename__ = "daily_challenge_stats" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "daily_challenge_stats"
|
||||
|
||||
user_id: int | None = Field(
|
||||
default=None,
|
||||
@@ -61,9 +59,7 @@ class DailyChallengeStatsResp(DailyChallengeStatsBase):
|
||||
return cls.model_validate(obj)
|
||||
|
||||
|
||||
async def process_daily_challenge_score(
|
||||
session: AsyncSession, user_id: int, room_id: int
|
||||
):
|
||||
async def process_daily_challenge_score(session: AsyncSession, user_id: int, room_id: int):
|
||||
from .playlist_best_score import PlaylistBestScore
|
||||
|
||||
score = (
|
||||
|
||||
@@ -4,16 +4,17 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, UTC
|
||||
from sqlmodel import SQLModel, Field
|
||||
from sqlalchemy import Column, BigInteger, ForeignKey
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import BigInteger, Column, ForeignKey
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class EmailVerification(SQLModel, table=True):
|
||||
"""邮件验证记录"""
|
||||
|
||||
|
||||
__tablename__: str = "email_verifications"
|
||||
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True))
|
||||
email: str = Field(index=True)
|
||||
@@ -28,9 +29,9 @@ class EmailVerification(SQLModel, table=True):
|
||||
|
||||
class LoginSession(SQLModel, table=True):
|
||||
"""登录会话记录"""
|
||||
|
||||
|
||||
__tablename__: str = "login_sessions"
|
||||
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True))
|
||||
session_token: str = Field(unique=True, index=True) # 会话令牌
|
||||
|
||||
@@ -36,17 +36,13 @@ class EventType(str, Enum):
|
||||
|
||||
class EventBase(SQLModel):
|
||||
id: int = Field(default=None, primary_key=True)
|
||||
created_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), default=datetime.now(UTC))
|
||||
)
|
||||
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), default=datetime.now(UTC)))
|
||||
type: EventType
|
||||
event_payload: dict = Field(
|
||||
exclude=True, default_factory=dict, sa_column=Column(JSON)
|
||||
)
|
||||
event_payload: dict = Field(exclude=True, default_factory=dict, sa_column=Column(JSON))
|
||||
|
||||
|
||||
class Event(EventBase, table=True):
|
||||
__tablename__ = "user_events" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "user_events"
|
||||
user_id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True),
|
||||
|
||||
@@ -16,8 +16,8 @@ FAILTIME_STRUCT = Struct("<100i")
|
||||
|
||||
|
||||
class FailTime(SQLModel, table=True):
|
||||
__tablename__ = "failtime" # pyright: ignore[reportAssignmentType]
|
||||
beatmap_id: int = Field(primary_key=True, index=True, foreign_key="beatmaps.id")
|
||||
__tablename__: str = "failtime"
|
||||
beatmap_id: int = Field(primary_key=True, foreign_key="beatmaps.id")
|
||||
exit: bytes = Field(sa_column=Column(VARBINARY(400), nullable=False))
|
||||
fail: bytes = Field(sa_column=Column(VARBINARY(400), nullable=False))
|
||||
|
||||
@@ -41,12 +41,8 @@ class FailTime(SQLModel, table=True):
|
||||
|
||||
|
||||
class FailTimeResp(BaseModel):
|
||||
exit: list[int] = Field(
|
||||
default_factory=lambda: list(FAILTIME_STRUCT.unpack(b"\x00" * 400))
|
||||
)
|
||||
fail: list[int] = Field(
|
||||
default_factory=lambda: list(FAILTIME_STRUCT.unpack(b"\x00" * 400))
|
||||
)
|
||||
exit: list[int] = Field(default_factory=lambda: list(FAILTIME_STRUCT.unpack(b"\x00" * 400)))
|
||||
fail: list[int] = Field(default_factory=lambda: list(FAILTIME_STRUCT.unpack(b"\x00" * 400)))
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, failtime: FailTime) -> "FailTimeResp":
|
||||
|
||||
@@ -16,7 +16,7 @@ from sqlmodel import (
|
||||
|
||||
|
||||
class FavouriteBeatmapset(AsyncAttrs, SQLModel, table=True):
|
||||
__tablename__ = "favourite_beatmapset" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "favourite_beatmapset"
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, autoincrement=True, primary_key=True),
|
||||
|
||||
@@ -75,9 +75,7 @@ class UserBase(UTCBaseModel, SQLModel):
|
||||
is_active: bool = True
|
||||
is_bot: bool = False
|
||||
is_supporter: bool = False
|
||||
last_visit: datetime | None = Field(
|
||||
default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True))
|
||||
)
|
||||
last_visit: datetime | None = Field(default=datetime.now(UTC), sa_column=Column(DateTime(timezone=True)))
|
||||
pm_friends_only: bool = False
|
||||
profile_colour: str | None = None
|
||||
username: str = Field(max_length=32, unique=True, index=True)
|
||||
@@ -90,9 +88,7 @@ class UserBase(UTCBaseModel, SQLModel):
|
||||
is_restricted: bool = False
|
||||
# blocks
|
||||
cover: UserProfileCover = Field(
|
||||
default=UserProfileCover(
|
||||
url="https://assets.ppy.sh/user-profile-covers/default.jpeg"
|
||||
),
|
||||
default=UserProfileCover(url="https://assets.ppy.sh/user-profile-covers/default.jpeg"),
|
||||
sa_column=Column(JSON),
|
||||
)
|
||||
beatmap_playcounts_count: int = 0
|
||||
@@ -150,9 +146,9 @@ class UserBase(UTCBaseModel, SQLModel):
|
||||
|
||||
|
||||
class User(AsyncAttrs, UserBase, table=True):
|
||||
__tablename__ = "lazer_users" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "lazer_users"
|
||||
|
||||
id: int | None = Field(
|
||||
id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True),
|
||||
)
|
||||
@@ -160,16 +156,10 @@ class User(AsyncAttrs, UserBase, table=True):
|
||||
statistics: list[UserStatistics] = Relationship()
|
||||
achievement: list[UserAchievement] = Relationship(back_populates="user")
|
||||
team_membership: TeamMember | None = Relationship(back_populates="user")
|
||||
daily_challenge_stats: DailyChallengeStats | None = Relationship(
|
||||
back_populates="user"
|
||||
)
|
||||
daily_challenge_stats: DailyChallengeStats | None = Relationship(back_populates="user")
|
||||
monthly_playcounts: list[MonthlyPlaycounts] = Relationship(back_populates="user")
|
||||
replays_watched_counts: list[ReplayWatchedCount] = Relationship(
|
||||
back_populates="user"
|
||||
)
|
||||
favourite_beatmapsets: list["FavouriteBeatmapset"] = Relationship(
|
||||
back_populates="user"
|
||||
)
|
||||
replays_watched_counts: list[ReplayWatchedCount] = Relationship(back_populates="user")
|
||||
favourite_beatmapsets: list["FavouriteBeatmapset"] = Relationship(back_populates="user")
|
||||
rank_history: list[RankHistory] = Relationship(
|
||||
back_populates="user",
|
||||
)
|
||||
@@ -178,16 +168,10 @@ class User(AsyncAttrs, UserBase, table=True):
|
||||
email: str = Field(max_length=254, unique=True, index=True, exclude=True)
|
||||
priv: int = Field(default=1, exclude=True)
|
||||
pw_bcrypt: str = Field(max_length=60, exclude=True)
|
||||
silence_end_at: datetime | None = Field(
|
||||
default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
|
||||
)
|
||||
donor_end_at: datetime | None = Field(
|
||||
default=None, sa_column=Column(DateTime(timezone=True)), exclude=True
|
||||
)
|
||||
silence_end_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True)), exclude=True)
|
||||
donor_end_at: datetime | None = Field(default=None, sa_column=Column(DateTime(timezone=True)), exclude=True)
|
||||
|
||||
async def is_user_can_pm(
|
||||
self, from_user: "User", session: AsyncSession
|
||||
) -> tuple[bool, str]:
|
||||
async def is_user_can_pm(self, from_user: "User", session: AsyncSession) -> tuple[bool, str]:
|
||||
from .relationship import Relationship, RelationshipType
|
||||
|
||||
from_relationship = (
|
||||
@@ -200,13 +184,10 @@ class User(AsyncAttrs, UserBase, table=True):
|
||||
).first()
|
||||
if from_relationship and from_relationship.type == RelationshipType.BLOCK:
|
||||
return False, "You have blocked the target user."
|
||||
if from_user.pm_friends_only and (
|
||||
not from_relationship or from_relationship.type != RelationshipType.FOLLOW
|
||||
):
|
||||
if from_user.pm_friends_only and (not from_relationship or from_relationship.type != RelationshipType.FOLLOW):
|
||||
return (
|
||||
False,
|
||||
"You have disabled non-friend communications "
|
||||
"and target user is not your friend.",
|
||||
"You have disabled non-friend communications and target user is not your friend.",
|
||||
)
|
||||
|
||||
relationship = (
|
||||
@@ -219,9 +200,7 @@ class User(AsyncAttrs, UserBase, table=True):
|
||||
).first()
|
||||
if relationship and relationship.type == RelationshipType.BLOCK:
|
||||
return False, "Target user has blocked you."
|
||||
if self.pm_friends_only and (
|
||||
not relationship or relationship.type != RelationshipType.FOLLOW
|
||||
):
|
||||
if self.pm_friends_only and (not relationship or relationship.type != RelationshipType.FOLLOW):
|
||||
return False, "Target user has disabled non-friend communications"
|
||||
return True, ""
|
||||
|
||||
@@ -288,9 +267,7 @@ class UserResp(UserBase):
|
||||
u = cls.model_validate(obj.model_dump())
|
||||
u.id = obj.id
|
||||
u.default_group = "bot" if u.is_bot else "default"
|
||||
u.country = Country(
|
||||
code=obj.country_code, name=COUNTRIES.get(obj.country_code, "Unknown")
|
||||
)
|
||||
u.country = Country(code=obj.country_code, name=COUNTRIES.get(obj.country_code, "Unknown"))
|
||||
u.follower_count = (
|
||||
await session.exec(
|
||||
select(func.count())
|
||||
@@ -314,9 +291,7 @@ class UserResp(UserBase):
|
||||
redis = get_redis()
|
||||
u.is_online = await redis.exists(f"metadata:online:{obj.id}")
|
||||
u.cover_url = (
|
||||
obj.cover.get(
|
||||
"url", "https://assets.ppy.sh/user-profile-covers/default.jpeg"
|
||||
)
|
||||
obj.cover.get("url", "https://assets.ppy.sh/user-profile-covers/default.jpeg")
|
||||
if obj.cover
|
||||
else "https://assets.ppy.sh/user-profile-covers/default.jpeg"
|
||||
)
|
||||
@@ -335,22 +310,15 @@ class UserResp(UserBase):
|
||||
]
|
||||
|
||||
if "team" in include:
|
||||
if await obj.awaitable_attrs.team_membership:
|
||||
assert obj.team_membership
|
||||
u.team = obj.team_membership.team
|
||||
if team_membership := await obj.awaitable_attrs.team_membership:
|
||||
u.team = team_membership.team
|
||||
|
||||
if "account_history" in include:
|
||||
u.account_history = [
|
||||
UserAccountHistoryResp.from_db(ah)
|
||||
for ah in await obj.awaitable_attrs.account_history
|
||||
]
|
||||
u.account_history = [UserAccountHistoryResp.from_db(ah) for ah in await obj.awaitable_attrs.account_history]
|
||||
|
||||
if "daily_challenge_user_stats":
|
||||
if await obj.awaitable_attrs.daily_challenge_stats:
|
||||
assert obj.daily_challenge_stats
|
||||
u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(
|
||||
obj.daily_challenge_stats
|
||||
)
|
||||
if daily_challenge_stats := await obj.awaitable_attrs.daily_challenge_stats:
|
||||
u.daily_challenge_user_stats = DailyChallengeStatsResp.from_db(daily_challenge_stats)
|
||||
|
||||
if "statistics" in include:
|
||||
current_stattistics = None
|
||||
@@ -359,59 +327,40 @@ class UserResp(UserBase):
|
||||
current_stattistics = i
|
||||
break
|
||||
u.statistics = (
|
||||
await UserStatisticsResp.from_db(
|
||||
current_stattistics, session, obj.country_code
|
||||
)
|
||||
await UserStatisticsResp.from_db(current_stattistics, session, obj.country_code)
|
||||
if current_stattistics
|
||||
else None
|
||||
)
|
||||
|
||||
if "statistics_rulesets" in include:
|
||||
u.statistics_rulesets = {
|
||||
i.mode.value: await UserStatisticsResp.from_db(
|
||||
i, session, obj.country_code
|
||||
)
|
||||
i.mode.value: await UserStatisticsResp.from_db(i, session, obj.country_code)
|
||||
for i in await obj.awaitable_attrs.statistics
|
||||
}
|
||||
|
||||
if "monthly_playcounts" in include:
|
||||
u.monthly_playcounts = [
|
||||
CountResp.from_db(pc)
|
||||
for pc in await obj.awaitable_attrs.monthly_playcounts
|
||||
]
|
||||
u.monthly_playcounts = [CountResp.from_db(pc) for pc in await obj.awaitable_attrs.monthly_playcounts]
|
||||
if len(u.monthly_playcounts) == 1:
|
||||
d = u.monthly_playcounts[0].start_date
|
||||
u.monthly_playcounts.insert(
|
||||
0, CountResp(start_date=d - timedelta(days=20), count=0)
|
||||
)
|
||||
u.monthly_playcounts.insert(0, CountResp(start_date=d - timedelta(days=20), count=0))
|
||||
|
||||
if "replays_watched_counts" in include:
|
||||
u.replay_watched_counts = [
|
||||
CountResp.from_db(rwc)
|
||||
for rwc in await obj.awaitable_attrs.replays_watched_counts
|
||||
CountResp.from_db(rwc) for rwc in await obj.awaitable_attrs.replays_watched_counts
|
||||
]
|
||||
if len(u.replay_watched_counts) == 1:
|
||||
d = u.replay_watched_counts[0].start_date
|
||||
u.replay_watched_counts.insert(
|
||||
0, CountResp(start_date=d - timedelta(days=20), count=0)
|
||||
)
|
||||
u.replay_watched_counts.insert(0, CountResp(start_date=d - timedelta(days=20), count=0))
|
||||
|
||||
if "achievements" in include:
|
||||
u.user_achievements = [
|
||||
UserAchievementResp.from_db(ua)
|
||||
for ua in await obj.awaitable_attrs.achievement
|
||||
]
|
||||
u.user_achievements = [UserAchievementResp.from_db(ua) for ua in await obj.awaitable_attrs.achievement]
|
||||
if "rank_history" in include:
|
||||
rank_history = await RankHistoryResp.from_db(session, obj.id, ruleset)
|
||||
if len(rank_history.data) != 0:
|
||||
u.rank_history = rank_history
|
||||
|
||||
rank_top = (
|
||||
await session.exec(
|
||||
select(RankTop).where(
|
||||
RankTop.user_id == obj.id, RankTop.mode == ruleset
|
||||
)
|
||||
)
|
||||
await session.exec(select(RankTop).where(RankTop.user_id == obj.id, RankTop.mode == ruleset))
|
||||
).first()
|
||||
if rank_top:
|
||||
u.rank_highest = (
|
||||
@@ -425,9 +374,7 @@ class UserResp(UserBase):
|
||||
|
||||
u.favourite_beatmapset_count = (
|
||||
await session.exec(
|
||||
select(func.count())
|
||||
.select_from(FavouriteBeatmapset)
|
||||
.where(FavouriteBeatmapset.user_id == obj.id)
|
||||
select(func.count()).select_from(FavouriteBeatmapset).where(FavouriteBeatmapset.user_id == obj.id)
|
||||
)
|
||||
).one()
|
||||
u.scores_pinned_count = (
|
||||
@@ -478,17 +425,19 @@ class UserResp(UserBase):
|
||||
# 检查会话验证状态
|
||||
# 如果邮件验证功能被禁用,则始终设置 session_verified 为 true
|
||||
from app.config import settings
|
||||
|
||||
if not settings.enable_email_verification:
|
||||
u.session_verified = True
|
||||
else:
|
||||
# 如果用户有未验证的登录会话,则设置 session_verified 为 false
|
||||
from .email_verification import LoginSession
|
||||
|
||||
unverified_session = (
|
||||
await session.exec(
|
||||
select(LoginSession).where(
|
||||
LoginSession.user_id == obj.id,
|
||||
LoginSession.is_verified == False,
|
||||
LoginSession.expires_at > datetime.now(UTC)
|
||||
col(LoginSession.is_verified).is_(False),
|
||||
LoginSession.expires_at > datetime.now(UTC),
|
||||
)
|
||||
)
|
||||
).first()
|
||||
|
||||
@@ -30,8 +30,8 @@ class MultiplayerEventBase(SQLModel, UTCBaseModel):
|
||||
|
||||
|
||||
class MultiplayerEvent(MultiplayerEventBase, table=True):
|
||||
__tablename__ = "multiplayer_events" # pyright: ignore[reportAssignmentType]
|
||||
id: int | None = Field(
|
||||
__tablename__: str = "multiplayer_events"
|
||||
id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True),
|
||||
)
|
||||
|
||||
@@ -17,7 +17,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class Notification(SQLModel, table=True):
|
||||
__tablename__ = "notifications" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "notifications"
|
||||
|
||||
id: int = Field(primary_key=True, index=True, default=None)
|
||||
name: NotificationName = Field(index=True)
|
||||
@@ -30,7 +30,7 @@ class Notification(SQLModel, table=True):
|
||||
|
||||
|
||||
class UserNotification(SQLModel, table=True):
|
||||
__tablename__ = "user_notifications" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "user_notifications"
|
||||
id: int = Field(
|
||||
sa_column=Column(
|
||||
BigInteger,
|
||||
@@ -40,9 +40,7 @@ class UserNotification(SQLModel, table=True):
|
||||
default=None,
|
||||
)
|
||||
notification_id: int = Field(index=True, foreign_key="notifications.id")
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
is_read: bool = Field(index=True)
|
||||
|
||||
notification: Notification = Relationship(sa_relationship_kwargs={"lazy": "joined"})
|
||||
|
||||
@@ -4,16 +4,17 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, UTC
|
||||
from sqlmodel import SQLModel, Field
|
||||
from sqlalchemy import Column, BigInteger, ForeignKey
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import BigInteger, Column, ForeignKey
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class PasswordReset(SQLModel, table=True):
|
||||
"""密码重置记录"""
|
||||
|
||||
|
||||
__tablename__: str = "password_resets"
|
||||
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False, index=True))
|
||||
email: str = Field(index=True)
|
||||
|
||||
@@ -21,16 +21,14 @@ class ItemAttemptsCountBase(SQLModel):
|
||||
room_id: int = Field(foreign_key="rooms.id", index=True)
|
||||
attempts: int = Field(default=0)
|
||||
completed: int = Field(default=0)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
accuracy: float = 0.0
|
||||
pp: float = 0
|
||||
total_score: int = 0
|
||||
|
||||
|
||||
class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True):
|
||||
__tablename__ = "item_attempts_count" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "item_attempts_count"
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
|
||||
user: User = Relationship()
|
||||
@@ -63,9 +61,7 @@ class ItemAttemptsCount(AsyncAttrs, ItemAttemptsCountBase, table=True):
|
||||
self.pp = sum(score.score.pp for score in playlist_scores)
|
||||
self.completed = len([score for score in playlist_scores if score.score.passed])
|
||||
self.accuracy = (
|
||||
sum(score.score.accuracy for score in playlist_scores) / self.completed
|
||||
if self.completed > 0
|
||||
else 0.0
|
||||
sum(score.score.accuracy for score in playlist_scores) / self.completed if self.completed > 0 else 0.0
|
||||
)
|
||||
await session.commit()
|
||||
await session.refresh(self)
|
||||
|
||||
@@ -21,14 +21,10 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class PlaylistBestScore(SQLModel, table=True):
|
||||
__tablename__ = "playlist_best_scores" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "playlist_best_scores"
|
||||
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
score_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
score_id: int = Field(sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True))
|
||||
room_id: int = Field(foreign_key="rooms.id", index=True)
|
||||
playlist_id: int = Field(index=True)
|
||||
total_score: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
|
||||
@@ -50,7 +50,7 @@ class PlaylistBase(SQLModel, UTCBaseModel):
|
||||
|
||||
|
||||
class Playlist(PlaylistBase, table=True):
|
||||
__tablename__ = "room_playlists" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "room_playlists"
|
||||
db_id: int = Field(default=None, primary_key=True, index=True, exclude=True)
|
||||
room_id: int = Field(foreign_key="rooms.id", exclude=True)
|
||||
|
||||
@@ -63,16 +63,12 @@ class Playlist(PlaylistBase, table=True):
|
||||
|
||||
@classmethod
|
||||
async def get_next_id_for_room(cls, room_id: int, session: AsyncSession) -> int:
|
||||
stmt = select(func.coalesce(func.max(cls.id), -1) + 1).where(
|
||||
cls.room_id == room_id
|
||||
)
|
||||
stmt = select(func.coalesce(func.max(cls.id), -1) + 1).where(cls.room_id == room_id)
|
||||
result = await session.exec(stmt)
|
||||
return result.one()
|
||||
|
||||
@classmethod
|
||||
async def from_hub(
|
||||
cls, playlist: PlaylistItem, room_id: int, session: AsyncSession
|
||||
) -> "Playlist":
|
||||
async def from_hub(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession) -> "Playlist":
|
||||
next_id = await cls.get_next_id_for_room(room_id, session=session)
|
||||
return cls(
|
||||
id=next_id,
|
||||
@@ -90,9 +86,7 @@ class Playlist(PlaylistBase, table=True):
|
||||
|
||||
@classmethod
|
||||
async def update(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession):
|
||||
db_playlist = await session.exec(
|
||||
select(cls).where(cls.id == playlist.id, cls.room_id == room_id)
|
||||
)
|
||||
db_playlist = await session.exec(select(cls).where(cls.id == playlist.id, cls.room_id == room_id))
|
||||
db_playlist = db_playlist.first()
|
||||
if db_playlist is None:
|
||||
raise ValueError("Playlist item not found")
|
||||
@@ -108,9 +102,7 @@ class Playlist(PlaylistBase, table=True):
|
||||
await session.commit()
|
||||
|
||||
@classmethod
|
||||
async def add_to_db(
|
||||
cls, playlist: PlaylistItem, room_id: int, session: AsyncSession
|
||||
):
|
||||
async def add_to_db(cls, playlist: PlaylistItem, room_id: int, session: AsyncSession):
|
||||
db_playlist = await cls.from_hub(playlist, room_id, session)
|
||||
session.add(db_playlist)
|
||||
await session.commit()
|
||||
@@ -119,9 +111,7 @@ class Playlist(PlaylistBase, table=True):
|
||||
|
||||
@classmethod
|
||||
async def delete_item(cls, item_id: int, room_id: int, session: AsyncSession):
|
||||
db_playlist = await session.exec(
|
||||
select(cls).where(cls.id == item_id, cls.room_id == room_id)
|
||||
)
|
||||
db_playlist = await session.exec(select(cls).where(cls.id == item_id, cls.room_id == room_id))
|
||||
db_playlist = db_playlist.first()
|
||||
if db_playlist is None:
|
||||
raise ValueError("Playlist item not found")
|
||||
@@ -133,9 +123,7 @@ class PlaylistResp(PlaylistBase):
|
||||
beatmap: BeatmapResp | None = None
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls, playlist: Playlist, include: list[str] = []
|
||||
) -> "PlaylistResp":
|
||||
async def from_db(cls, playlist: Playlist, include: list[str] = []) -> "PlaylistResp":
|
||||
data = playlist.model_dump()
|
||||
if "beatmap" in include:
|
||||
data["beatmap"] = await BeatmapResp.from_db(playlist.beatmap)
|
||||
|
||||
@@ -20,13 +20,9 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class PPBestScore(SQLModel, table=True):
|
||||
__tablename__ = "best_scores" # pyright: ignore[reportAssignmentType]
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
score_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True)
|
||||
)
|
||||
__tablename__: str = "best_scores"
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
score_id: int = Field(sa_column=Column(BigInteger, ForeignKey("scores.id"), primary_key=True))
|
||||
beatmap_id: int = Field(foreign_key="beatmaps.id", index=True)
|
||||
gamemode: GameMode = Field(index=True)
|
||||
pp: float = Field(
|
||||
|
||||
@@ -26,12 +26,10 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class RankHistory(SQLModel, table=True):
|
||||
__tablename__ = "rank_history" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "rank_history"
|
||||
|
||||
id: int | None = Field(default=None, sa_column=Column(BigInteger, primary_key=True))
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
mode: GameMode
|
||||
rank: int
|
||||
date: dt = Field(
|
||||
@@ -43,12 +41,10 @@ class RankHistory(SQLModel, table=True):
|
||||
|
||||
|
||||
class RankTop(SQLModel, table=True):
|
||||
__tablename__ = "rank_top" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "rank_top"
|
||||
|
||||
id: int | None = Field(default=None, sa_column=Column(BigInteger, primary_key=True))
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
mode: GameMode
|
||||
rank: int
|
||||
date: dt = Field(
|
||||
@@ -62,9 +58,7 @@ class RankHistoryResp(BaseModel):
|
||||
data: list[int]
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls, session: AsyncSession, user_id: int, mode: GameMode
|
||||
) -> "RankHistoryResp":
|
||||
async def from_db(cls, session: AsyncSession, user_id: int, mode: GameMode) -> "RankHistoryResp":
|
||||
results = (
|
||||
await session.exec(
|
||||
select(RankHistory)
|
||||
|
||||
@@ -21,7 +21,7 @@ class RelationshipType(str, Enum):
|
||||
|
||||
|
||||
class Relationship(SQLModel, table=True):
|
||||
__tablename__ = "relationship" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "relationship"
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(BigInteger, autoincrement=True, primary_key=True),
|
||||
@@ -59,9 +59,7 @@ class RelationshipResp(BaseModel):
|
||||
type: RelationshipType
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls, session: AsyncSession, relationship: Relationship
|
||||
) -> "RelationshipResp":
|
||||
async def from_db(cls, session: AsyncSession, relationship: Relationship) -> "RelationshipResp":
|
||||
target_relationship = (
|
||||
await session.exec(
|
||||
select(Relationship).where(
|
||||
|
||||
@@ -58,11 +58,9 @@ class RoomBase(SQLModel, UTCBaseModel):
|
||||
|
||||
|
||||
class Room(AsyncAttrs, RoomBase, table=True):
|
||||
__tablename__ = "rooms" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "rooms"
|
||||
id: int = Field(default=None, primary_key=True, index=True)
|
||||
host_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
host_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
|
||||
host: User = Relationship()
|
||||
playlist: list[Playlist] = Relationship(
|
||||
@@ -109,12 +107,8 @@ class RoomResp(RoomBase):
|
||||
if not playlist.expired:
|
||||
stats.count_active += 1
|
||||
rulesets.add(playlist.ruleset_id)
|
||||
difficulty_range.min = min(
|
||||
difficulty_range.min, playlist.beatmap.difficulty_rating
|
||||
)
|
||||
difficulty_range.max = max(
|
||||
difficulty_range.max, playlist.beatmap.difficulty_rating
|
||||
)
|
||||
difficulty_range.min = min(difficulty_range.min, playlist.beatmap.difficulty_rating)
|
||||
difficulty_range.max = max(difficulty_range.max, playlist.beatmap.difficulty_rating)
|
||||
resp.playlist.append(await PlaylistResp.from_db(playlist, ["beatmap"]))
|
||||
stats.ruleset_ids = list(rulesets)
|
||||
resp.playlist_item_stats = stats
|
||||
@@ -137,13 +131,9 @@ class RoomResp(RoomBase):
|
||||
include=["statistics"],
|
||||
)
|
||||
)
|
||||
resp.host = await UserResp.from_db(
|
||||
await room.awaitable_attrs.host, session, include=["statistics"]
|
||||
)
|
||||
resp.host = await UserResp.from_db(await room.awaitable_attrs.host, session, include=["statistics"])
|
||||
if "current_user_score" in include and user:
|
||||
resp.current_user_score = await PlaylistAggregateScore.from_db(
|
||||
room.id, user.id, session
|
||||
)
|
||||
resp.current_user_score = await PlaylistAggregateScore.from_db(room.id, user.id, session)
|
||||
return resp
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -18,22 +18,16 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class RoomParticipatedUser(AsyncAttrs, SQLModel, table=True):
|
||||
__tablename__ = "room_participated_users" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "room_participated_users"
|
||||
|
||||
id: int | None = Field(
|
||||
default=None, sa_column=Column(BigInteger, primary_key=True, autoincrement=True)
|
||||
)
|
||||
id: int | None = Field(default=None, sa_column=Column(BigInteger, primary_key=True, autoincrement=True))
|
||||
room_id: int = Field(sa_column=Column(ForeignKey("rooms.id"), nullable=False))
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False)
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), nullable=False))
|
||||
joined_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=False),
|
||||
default=datetime.now(UTC),
|
||||
)
|
||||
left_at: datetime | None = Field(
|
||||
sa_column=Column(DateTime(timezone=True), nullable=True), default=None
|
||||
)
|
||||
left_at: datetime | None = Field(sa_column=Column(DateTime(timezone=True), nullable=True), default=None)
|
||||
|
||||
room: "Room" = Relationship()
|
||||
user: "User" = Relationship()
|
||||
|
||||
@@ -47,9 +47,9 @@ from .score_token import ScoreToken
|
||||
|
||||
from pydantic import field_serializer, field_validator
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy import Boolean, Column, ColumnExpressionArgument, DateTime
|
||||
from sqlalchemy import Boolean, Column, ColumnExpressionArgument, DateTime, TextClause
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import Mapped, aliased
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlmodel import (
|
||||
JSON,
|
||||
@@ -76,9 +76,7 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
|
||||
accuracy: float
|
||||
map_md5: str = Field(max_length=32, index=True)
|
||||
build_id: int | None = Field(default=None)
|
||||
classic_total_score: int | None = Field(
|
||||
default=0, sa_column=Column(BigInteger)
|
||||
) # solo_score
|
||||
classic_total_score: int | None = Field(default=0, sa_column=Column(BigInteger)) # solo_score
|
||||
ended_at: datetime = Field(sa_column=Column(DateTime))
|
||||
has_replay: bool = Field(sa_column=Column(Boolean))
|
||||
max_combo: int
|
||||
@@ -91,14 +89,10 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
|
||||
room_id: int | None = Field(default=None) # multiplayer
|
||||
started_at: datetime = Field(sa_column=Column(DateTime))
|
||||
total_score: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
total_score_without_mods: int = Field(
|
||||
default=0, sa_column=Column(BigInteger), exclude=True
|
||||
)
|
||||
total_score_without_mods: int = Field(default=0, sa_column=Column(BigInteger), exclude=True)
|
||||
type: str
|
||||
beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
|
||||
maximum_statistics: ScoreStatistics = Field(
|
||||
sa_column=Column(JSON), default_factory=dict
|
||||
)
|
||||
maximum_statistics: ScoreStatistics = Field(sa_column=Column(JSON), default_factory=dict)
|
||||
|
||||
@field_validator("maximum_statistics", mode="before")
|
||||
@classmethod
|
||||
@@ -147,10 +141,8 @@ class ScoreBase(AsyncAttrs, SQLModel, UTCBaseModel):
|
||||
|
||||
|
||||
class Score(ScoreBase, table=True):
|
||||
__tablename__ = "scores" # pyright: ignore[reportAssignmentType]
|
||||
id: int | None = Field(
|
||||
default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True)
|
||||
)
|
||||
__tablename__: str = "scores"
|
||||
id: int = Field(default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True))
|
||||
user_id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
@@ -193,8 +185,8 @@ class Score(ScoreBase, table=True):
|
||||
return str(v)
|
||||
|
||||
# optional
|
||||
beatmap: Beatmap = Relationship()
|
||||
user: User = Relationship(sa_relationship_kwargs={"lazy": "joined"})
|
||||
beatmap: Mapped[Beatmap] = Relationship()
|
||||
user: Mapped[User] = Relationship(sa_relationship_kwargs={"lazy": "joined"})
|
||||
|
||||
@property
|
||||
def is_perfect_combo(self) -> bool:
|
||||
@@ -205,11 +197,7 @@ class Score(ScoreBase, table=True):
|
||||
*where_clauses: ColumnExpressionArgument[bool] | bool,
|
||||
) -> SelectOfScalar["Score"]:
|
||||
rownum = (
|
||||
func.row_number()
|
||||
.over(
|
||||
partition_by=col(Score.user_id), order_by=col(Score.total_score).desc()
|
||||
)
|
||||
.label("rn")
|
||||
func.row_number().over(partition_by=col(Score.user_id), order_by=col(Score.total_score).desc()).label("rn")
|
||||
)
|
||||
subq = select(Score, rownum).where(*where_clauses).subquery()
|
||||
best = aliased(Score, subq, adapt_on_names=True)
|
||||
@@ -296,12 +284,9 @@ class ScoreResp(ScoreBase):
|
||||
await session.refresh(score)
|
||||
|
||||
s = cls.model_validate(score.model_dump())
|
||||
assert score.id
|
||||
await score.awaitable_attrs.beatmap
|
||||
s.beatmap = await BeatmapResp.from_db(score.beatmap)
|
||||
s.beatmapset = await BeatmapsetResp.from_db(
|
||||
score.beatmap.beatmapset, session=session, user=score.user
|
||||
)
|
||||
s.beatmapset = await BeatmapsetResp.from_db(score.beatmap.beatmapset, session=session, user=score.user)
|
||||
s.is_perfect_combo = s.max_combo == s.beatmap.max_combo
|
||||
s.legacy_perfect = s.max_combo == s.beatmap.max_combo
|
||||
s.ruleset_id = int(score.gamemode)
|
||||
@@ -371,11 +356,7 @@ class ScoreAround(SQLModel):
|
||||
|
||||
async def get_best_id(session: AsyncSession, score_id: int) -> None:
|
||||
rownum = (
|
||||
func.row_number()
|
||||
.over(
|
||||
partition_by=col(PPBestScore.user_id), order_by=col(PPBestScore.pp).desc()
|
||||
)
|
||||
.label("rn")
|
||||
func.row_number().over(partition_by=col(PPBestScore.user_id), order_by=col(PPBestScore.pp).desc()).label("rn")
|
||||
)
|
||||
subq = select(PPBestScore, rownum).subquery()
|
||||
stmt = select(subq.c.rn).where(subq.c.score_id == score_id)
|
||||
@@ -389,8 +370,8 @@ async def _score_where(
|
||||
mode: GameMode,
|
||||
mods: list[str] | None = None,
|
||||
user: User | None = None,
|
||||
) -> list[ColumnElement[bool]] | None:
|
||||
wheres = [
|
||||
) -> list[ColumnElement[bool] | TextClause] | None:
|
||||
wheres: list[ColumnElement[bool] | TextClause] = [
|
||||
col(BestScore.beatmap_id) == beatmap,
|
||||
col(BestScore.gamemode) == mode,
|
||||
]
|
||||
@@ -410,9 +391,7 @@ async def _score_where(
|
||||
return None
|
||||
elif type == LeaderboardType.COUNTRY:
|
||||
if user and user.is_supporter:
|
||||
wheres.append(
|
||||
col(BestScore.user).has(col(User.country_code) == user.country_code)
|
||||
)
|
||||
wheres.append(col(BestScore.user).has(col(User.country_code) == user.country_code))
|
||||
else:
|
||||
return None
|
||||
elif type == LeaderboardType.TEAM:
|
||||
@@ -420,18 +399,14 @@ async def _score_where(
|
||||
team_membership = await user.awaitable_attrs.team_membership
|
||||
if team_membership:
|
||||
team_id = team_membership.team_id
|
||||
wheres.append(
|
||||
col(BestScore.user).has(
|
||||
col(User.team_membership).has(TeamMember.team_id == team_id)
|
||||
)
|
||||
)
|
||||
wheres.append(col(BestScore.user).has(col(User.team_membership).has(TeamMember.team_id == team_id)))
|
||||
if mods:
|
||||
if user and user.is_supporter:
|
||||
wheres.append(
|
||||
text(
|
||||
"JSON_CONTAINS(total_score_best_scores.mods, :w)"
|
||||
" AND JSON_CONTAINS(:w, total_score_best_scores.mods)"
|
||||
).params(w=json.dumps(mods)) # pyright: ignore[reportArgumentType]
|
||||
).params(w=json.dumps(mods))
|
||||
)
|
||||
else:
|
||||
return None
|
||||
@@ -654,18 +629,14 @@ def calculate_playtime(score: Score, beatmap_length: int) -> tuple[int, bool]:
|
||||
+ (score.nsmall_tick_hit or 0)
|
||||
)
|
||||
total_obj = 0
|
||||
for statistics, count in (
|
||||
score.maximum_statistics.items() if score.maximum_statistics else {}
|
||||
):
|
||||
for statistics, count in score.maximum_statistics.items() if score.maximum_statistics else {}:
|
||||
if not isinstance(statistics, HitResult):
|
||||
statistics = HitResult(statistics)
|
||||
if statistics.is_scorable():
|
||||
total_obj += count
|
||||
|
||||
return total_length, score.passed or (
|
||||
total_length > 8
|
||||
and score.total_score >= 5000
|
||||
and total_obj_hited >= min(0.1 * total_obj, 20)
|
||||
total_length > 8 and score.total_score >= 5000 and total_obj_hited >= min(0.1 * total_obj, 20)
|
||||
)
|
||||
|
||||
|
||||
@@ -678,12 +649,8 @@ async def process_user(
|
||||
ranked: bool = False,
|
||||
has_leaderboard: bool = False,
|
||||
):
|
||||
assert user.id
|
||||
assert score.id
|
||||
mod_for_save = mod_to_save(score.mods)
|
||||
previous_score_best = await get_user_best_score_in_beatmap(
|
||||
session, score.beatmap_id, user.id, score.gamemode
|
||||
)
|
||||
previous_score_best = await get_user_best_score_in_beatmap(session, score.beatmap_id, user.id, score.gamemode)
|
||||
previous_score_best_mod = await get_user_best_score_with_mod_in_beatmap(
|
||||
session, score.beatmap_id, user.id, mod_for_save, score.gamemode
|
||||
)
|
||||
@@ -698,9 +665,7 @@ async def process_user(
|
||||
)
|
||||
).first()
|
||||
if mouthly_playcount is None:
|
||||
mouthly_playcount = MonthlyPlaycounts(
|
||||
user_id=user.id, year=date.today().year, month=date.today().month
|
||||
)
|
||||
mouthly_playcount = MonthlyPlaycounts(user_id=user.id, year=date.today().year, month=date.today().month)
|
||||
add_to_db = True
|
||||
statistics = None
|
||||
for i in await user.awaitable_attrs.statistics:
|
||||
@@ -708,17 +673,11 @@ async def process_user(
|
||||
statistics = i
|
||||
break
|
||||
if statistics is None:
|
||||
raise ValueError(
|
||||
f"User {user.id} does not have statistics for mode {score.gamemode.value}"
|
||||
)
|
||||
raise ValueError(f"User {user.id} does not have statistics for mode {score.gamemode.value}")
|
||||
|
||||
# pc, pt, tth, tts
|
||||
statistics.total_score += score.total_score
|
||||
difference = (
|
||||
score.total_score - previous_score_best.total_score
|
||||
if previous_score_best
|
||||
else score.total_score
|
||||
)
|
||||
difference = score.total_score - previous_score_best.total_score if previous_score_best else score.total_score
|
||||
if difference > 0 and score.passed and ranked:
|
||||
match score.rank:
|
||||
case Rank.X:
|
||||
@@ -746,11 +705,8 @@ async def process_user(
|
||||
statistics.ranked_score += difference
|
||||
statistics.level_current = calculate_score_to_level(statistics.total_score)
|
||||
statistics.maximum_combo = max(statistics.maximum_combo, score.max_combo)
|
||||
new_score_position = await get_score_position_by_user(
|
||||
session, score.beatmap_id, user, score.gamemode
|
||||
)
|
||||
new_score_position = await get_score_position_by_user(session, score.beatmap_id, user, score.gamemode)
|
||||
total_users = await session.exec(select(func.count()).select_from(User))
|
||||
assert total_users is not None
|
||||
score_range = min(50, math.ceil(float(total_users.one()) * 0.01))
|
||||
if new_score_position <= score_range and new_score_position > 0:
|
||||
# Get the scores that might be displaced
|
||||
@@ -774,11 +730,7 @@ async def process_user(
|
||||
)
|
||||
|
||||
# If this score was previously in top positions but now pushed out
|
||||
if (
|
||||
i < score_range
|
||||
and displaced_position > score_range
|
||||
and displaced_position is not None
|
||||
):
|
||||
if i < score_range and displaced_position > score_range and displaced_position is not None:
|
||||
# Create rank lost event for the displaced user
|
||||
rank_lost_event = Event(
|
||||
created_at=datetime.now(UTC),
|
||||
@@ -814,10 +766,7 @@ async def process_user(
|
||||
)
|
||||
|
||||
# 情况3: 有最佳分数记录和该mod组合的记录,且是同一个记录,更新得分更高的情况
|
||||
elif (
|
||||
previous_score_best.score_id == previous_score_best_mod.score_id
|
||||
and difference > 0
|
||||
):
|
||||
elif previous_score_best.score_id == previous_score_best_mod.score_id and difference > 0:
|
||||
previous_score_best.total_score = score.total_score
|
||||
previous_score_best.rank = score.rank
|
||||
previous_score_best.score_id = score.id
|
||||
@@ -847,9 +796,7 @@ async def process_user(
|
||||
statistics.count_300 += score.n300 + score.ngeki
|
||||
statistics.count_50 += score.n50
|
||||
statistics.count_miss += score.nmiss
|
||||
statistics.total_hits += (
|
||||
score.n300 + score.n100 + score.n50 + score.ngeki + score.nkatu
|
||||
)
|
||||
statistics.total_hits += score.n300 + score.n100 + score.n50 + score.ngeki + score.nkatu
|
||||
|
||||
if score.passed and ranked:
|
||||
with session.no_autoflush:
|
||||
@@ -885,7 +832,6 @@ async def process_score(
|
||||
item_id: int | None = None,
|
||||
room_id: int | None = None,
|
||||
) -> Score:
|
||||
assert user.id
|
||||
can_get_pp = info.passed and ranked and mods_can_get_pp(info.ruleset_id, info.mods)
|
||||
gamemode = GameMode.from_int(info.ruleset_id).to_special_mode(info.mods)
|
||||
score = Score(
|
||||
@@ -922,20 +868,15 @@ async def process_score(
|
||||
if can_get_pp:
|
||||
from app.calculator import pre_fetch_and_calculate_pp
|
||||
|
||||
pp = await pre_fetch_and_calculate_pp(
|
||||
score, beatmap_id, session, redis, fetcher
|
||||
)
|
||||
pp = await pre_fetch_and_calculate_pp(score, beatmap_id, session, redis, fetcher)
|
||||
score.pp = pp
|
||||
session.add(score)
|
||||
user_id = user.id
|
||||
await session.commit()
|
||||
await session.refresh(score)
|
||||
if can_get_pp and score.pp != 0:
|
||||
previous_pp_best = await get_user_best_pp_in_beatmap(
|
||||
session, beatmap_id, user_id, score.gamemode
|
||||
)
|
||||
previous_pp_best = await get_user_best_pp_in_beatmap(session, beatmap_id, user_id, score.gamemode)
|
||||
if previous_pp_best is None or score.pp > previous_pp_best.pp:
|
||||
assert score.id
|
||||
best_score = PPBestScore(
|
||||
user_id=user_id,
|
||||
score_id=score.id,
|
||||
|
||||
@@ -7,6 +7,7 @@ from .beatmap import Beatmap
|
||||
from .lazer_user import User
|
||||
|
||||
from sqlalchemy import Column, DateTime, Index
|
||||
from sqlalchemy.orm import Mapped
|
||||
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
|
||||
|
||||
|
||||
@@ -14,16 +15,12 @@ class ScoreTokenBase(SQLModel, UTCBaseModel):
|
||||
score_id: int | None = Field(sa_column=Column(BigInteger), default=None)
|
||||
ruleset_id: GameMode
|
||||
playlist_item_id: int | None = Field(default=None) # playlist
|
||||
created_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
updated_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, sa_column=Column(DateTime))
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow, sa_column=Column(DateTime))
|
||||
|
||||
|
||||
class ScoreToken(ScoreTokenBase, table=True):
|
||||
__tablename__ = "score_tokens" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "score_tokens"
|
||||
__table_args__ = (Index("idx_user_playlist", "user_id", "playlist_item_id"),)
|
||||
|
||||
id: int | None = Field(
|
||||
@@ -37,8 +34,8 @@ class ScoreToken(ScoreTokenBase, table=True):
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id")))
|
||||
beatmap_id: int = Field(foreign_key="beatmaps.id")
|
||||
user: User = Relationship()
|
||||
beatmap: Beatmap = Relationship()
|
||||
user: Mapped[User] = Relationship()
|
||||
beatmap: Mapped[Beatmap] = Relationship()
|
||||
|
||||
|
||||
class ScoreTokenResp(ScoreTokenBase):
|
||||
|
||||
@@ -58,7 +58,7 @@ class UserStatisticsBase(SQLModel):
|
||||
|
||||
|
||||
class UserStatistics(AsyncAttrs, UserStatisticsBase, table=True):
|
||||
__tablename__ = "lazer_user_statistics" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "lazer_user_statistics"
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(
|
||||
default=None,
|
||||
@@ -123,9 +123,7 @@ class UserStatisticsResp(UserStatisticsBase):
|
||||
if "user" in include:
|
||||
from .lazer_user import RANKING_INCLUDES, UserResp
|
||||
|
||||
user = await UserResp.from_db(
|
||||
await obj.awaitable_attrs.user, session, include=RANKING_INCLUDES
|
||||
)
|
||||
user = await UserResp.from_db(await obj.awaitable_attrs.user, session, include=RANKING_INCLUDES)
|
||||
s.user = user
|
||||
user_country = user.country_code
|
||||
|
||||
@@ -149,9 +147,7 @@ class UserStatisticsResp(UserStatisticsBase):
|
||||
return s
|
||||
|
||||
|
||||
async def get_rank(
|
||||
session: AsyncSession, statistics: UserStatistics, country: str | None = None
|
||||
) -> int | None:
|
||||
async def get_rank(session: AsyncSession, statistics: UserStatistics, country: str | None = None) -> int | None:
|
||||
from .lazer_user import User
|
||||
|
||||
query = select(
|
||||
@@ -168,9 +164,7 @@ async def get_rank(
|
||||
|
||||
subq = query.subquery()
|
||||
|
||||
result = await session.exec(
|
||||
select(subq.c.rank).where(subq.c.user_id == statistics.user_id)
|
||||
)
|
||||
result = await session.exec(select(subq.c.rank).where(subq.c.user_id == statistics.user_id))
|
||||
|
||||
rank = result.first()
|
||||
if rank is None:
|
||||
|
||||
@@ -11,9 +11,9 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class Team(SQLModel, UTCBaseModel, table=True):
|
||||
__tablename__ = "teams" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "teams"
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
id: int = Field(default=None, primary_key=True, index=True)
|
||||
name: str = Field(max_length=100)
|
||||
short_name: str = Field(max_length=10)
|
||||
flag_url: str | None = Field(default=None)
|
||||
@@ -26,34 +26,22 @@ class Team(SQLModel, UTCBaseModel, table=True):
|
||||
|
||||
|
||||
class TeamMember(SQLModel, UTCBaseModel, table=True):
|
||||
__tablename__ = "team_members" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "team_members"
|
||||
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True)
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True))
|
||||
team_id: int = Field(foreign_key="teams.id")
|
||||
joined_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
joined_at: datetime = Field(default_factory=datetime.utcnow, sa_column=Column(DateTime))
|
||||
|
||||
user: "User" = Relationship(
|
||||
back_populates="team_membership", sa_relationship_kwargs={"lazy": "joined"}
|
||||
)
|
||||
team: "Team" = Relationship(
|
||||
back_populates="members", sa_relationship_kwargs={"lazy": "joined"}
|
||||
)
|
||||
user: "User" = Relationship(back_populates="team_membership", sa_relationship_kwargs={"lazy": "joined"})
|
||||
team: "Team" = Relationship(back_populates="members", sa_relationship_kwargs={"lazy": "joined"})
|
||||
|
||||
|
||||
class TeamRequest(SQLModel, UTCBaseModel, table=True):
|
||||
__tablename__ = "team_requests" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "team_requests"
|
||||
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True)
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), primary_key=True))
|
||||
team_id: int = Field(foreign_key="teams.id", primary_key=True)
|
||||
requested_at: datetime = Field(
|
||||
default=datetime.now(UTC), sa_column=Column(DateTime)
|
||||
)
|
||||
requested_at: datetime = Field(default=datetime.now(UTC), sa_column=Column(DateTime))
|
||||
|
||||
user: "User" = Relationship(sa_relationship_kwargs={"lazy": "joined"})
|
||||
team: "Team" = Relationship(sa_relationship_kwargs={"lazy": "joined"})
|
||||
|
||||
@@ -22,7 +22,7 @@ class UserAccountHistoryBase(SQLModel, UTCBaseModel):
|
||||
|
||||
|
||||
class UserAccountHistory(UserAccountHistoryBase, table=True):
|
||||
__tablename__ = "user_account_history" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "user_account_history"
|
||||
|
||||
id: int | None = Field(
|
||||
sa_column=Column(
|
||||
@@ -32,9 +32,7 @@ class UserAccountHistory(UserAccountHistoryBase, table=True):
|
||||
primary_key=True,
|
||||
)
|
||||
)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
|
||||
|
||||
class UserAccountHistoryResp(UserAccountHistoryBase):
|
||||
|
||||
@@ -10,27 +10,17 @@ from sqlmodel import Field, SQLModel
|
||||
class UserLoginLog(SQLModel, table=True):
|
||||
"""User login log table"""
|
||||
|
||||
__tablename__ = "user_login_log" # pyright: ignore[reportAssignmentType]
|
||||
__tablename__: str = "user_login_log"
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, description="Record ID")
|
||||
user_id: int = Field(index=True, description="User ID")
|
||||
ip_address: str = Field(
|
||||
max_length=45, index=True, description="IP address (supports IPv4 and IPv6)"
|
||||
)
|
||||
user_agent: str | None = Field(
|
||||
default=None, max_length=500, description="User agent information"
|
||||
)
|
||||
login_time: datetime = Field(
|
||||
default_factory=datetime.utcnow, description="Login time"
|
||||
)
|
||||
ip_address: str = Field(max_length=45, index=True, description="IP address (supports IPv4 and IPv6)")
|
||||
user_agent: str | None = Field(default=None, max_length=500, description="User agent information")
|
||||
login_time: datetime = Field(default_factory=datetime.utcnow, description="Login time")
|
||||
|
||||
# GeoIP information
|
||||
country_code: str | None = Field(
|
||||
default=None, max_length=2, description="Country code"
|
||||
)
|
||||
country_name: str | None = Field(
|
||||
default=None, max_length=100, description="Country name"
|
||||
)
|
||||
country_code: str | None = Field(default=None, max_length=2, description="Country code")
|
||||
country_name: str | None = Field(default=None, max_length=100, description="Country name")
|
||||
city_name: str | None = Field(default=None, max_length=100, description="City name")
|
||||
latitude: str | None = Field(default=None, max_length=20, description="Latitude")
|
||||
longitude: str | None = Field(default=None, max_length=20, description="Longitude")
|
||||
@@ -38,22 +28,14 @@ class UserLoginLog(SQLModel, table=True):
|
||||
|
||||
# ASN information
|
||||
asn: int | None = Field(default=None, description="Autonomous System Number")
|
||||
organization: str | None = Field(
|
||||
default=None, max_length=200, description="Organization name"
|
||||
)
|
||||
organization: str | None = Field(default=None, max_length=200, description="Organization name")
|
||||
|
||||
# Login status
|
||||
login_success: bool = Field(
|
||||
default=True, description="Whether the login was successful"
|
||||
)
|
||||
login_method: str = Field(
|
||||
max_length=50, description="Login method (password/oauth/etc.)"
|
||||
)
|
||||
login_success: bool = Field(default=True, description="Whether the login was successful")
|
||||
login_method: str = Field(max_length=50, description="Login method (password/oauth/etc.)")
|
||||
|
||||
# Additional information
|
||||
notes: str | None = Field(
|
||||
default=None, max_length=500, description="Additional notes"
|
||||
)
|
||||
notes: str | None = Field(default=None, max_length=500, description="Additional notes")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
@@ -40,15 +40,11 @@ engine = create_async_engine(
|
||||
redis_client = redis.from_url(settings.redis_url, decode_responses=True)
|
||||
|
||||
# Redis 消息缓存连接 (db1) - 使用同步客户端在线程池中执行
|
||||
redis_message_client = sync_redis.from_url(
|
||||
settings.redis_url, decode_responses=True, db=1
|
||||
)
|
||||
redis_message_client = sync_redis.from_url(settings.redis_url, decode_responses=True, db=1)
|
||||
|
||||
|
||||
# 数据库依赖
|
||||
db_session_context: ContextVar[AsyncSession | None] = ContextVar(
|
||||
"db_session_context", default=None
|
||||
)
|
||||
db_session_context: ContextVar[AsyncSession | None] = ContextVar("db_session_context", default=None)
|
||||
|
||||
|
||||
async def get_db():
|
||||
|
||||
@@ -25,7 +25,5 @@ async def get_fetcher() -> Fetcher:
|
||||
if refresh_token:
|
||||
fetcher.refresh_token = str(refresh_token)
|
||||
if not fetcher.access_token or not fetcher.refresh_token:
|
||||
logger.opt(colors=True).info(
|
||||
f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>"
|
||||
)
|
||||
logger.opt(colors=True).info(f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>")
|
||||
return fetcher
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC
|
||||
from typing import cast
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
|
||||
@@ -16,7 +17,7 @@ def get_scheduler() -> AsyncIOScheduler:
|
||||
global scheduler
|
||||
if scheduler is None:
|
||||
init_scheduler()
|
||||
return scheduler # pyright: ignore[reportReturnType]
|
||||
return cast(AsyncIOScheduler, scheduler)
|
||||
|
||||
|
||||
def start_scheduler():
|
||||
|
||||
@@ -70,9 +70,7 @@ async def v1_authorize(
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=401, detail="Missing API key")
|
||||
|
||||
api_key_record = (
|
||||
await db.exec(select(V1APIKeys).where(V1APIKeys.key == api_key))
|
||||
).first()
|
||||
api_key_record = (await db.exec(select(V1APIKeys).where(V1APIKeys.key == api_key))).first()
|
||||
if not api_key_record:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
@@ -98,9 +96,7 @@ async def get_current_user(
|
||||
security_scopes: SecurityScopes,
|
||||
token_pw: Annotated[str | None, Depends(oauth2_password)] = None,
|
||||
token_code: Annotated[str | None, Depends(oauth2_code)] = None,
|
||||
token_client_credentials: Annotated[
|
||||
str | None, Depends(oauth2_client_credentials)
|
||||
] = None,
|
||||
token_client_credentials: Annotated[str | None, Depends(oauth2_client_credentials)] = None,
|
||||
) -> User:
|
||||
"""获取当前认证用户"""
|
||||
token = token_pw or token_code or token_client_credentials
|
||||
@@ -119,9 +115,7 @@ async def get_current_user(
|
||||
if not is_client:
|
||||
for scope in security_scopes.scopes:
|
||||
if scope not in token_record.scope.split(","):
|
||||
raise HTTPException(
|
||||
status_code=403, detail=f"Insufficient scope: {scope}"
|
||||
)
|
||||
raise HTTPException(status_code=403, detail=f"Insufficient scope: {scope}")
|
||||
|
||||
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
|
||||
if not user:
|
||||
|
||||
@@ -121,14 +121,10 @@ class BaseFetcher:
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
if attempt < max_retries:
|
||||
logger.warning(
|
||||
f"Request failed (attempt {attempt + 1}/{max_retries + 1}): {e}, retrying..."
|
||||
)
|
||||
logger.warning(f"Request failed (attempt {attempt + 1}/{max_retries + 1}): {e}, retrying...")
|
||||
continue
|
||||
else:
|
||||
logger.error(
|
||||
f"Request failed after {max_retries + 1} attempts: {e}"
|
||||
)
|
||||
logger.error(f"Request failed after {max_retries + 1} attempts: {e}")
|
||||
break
|
||||
|
||||
# 如果所有重试都失败了
|
||||
@@ -196,13 +192,9 @@ class BaseFetcher:
|
||||
f"fetcher:refresh_token:{self.client_id}",
|
||||
self.refresh_token,
|
||||
)
|
||||
logger.info(
|
||||
f"Successfully refreshed access token for client {self.client_id}"
|
||||
)
|
||||
logger.info(f"Successfully refreshed access token for client {self.client_id}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to refresh access token for client {self.client_id}: {e}"
|
||||
)
|
||||
logger.error(f"Failed to refresh access token for client {self.client_id}: {e}")
|
||||
# 清除无效的 token,要求重新授权
|
||||
self.access_token = ""
|
||||
self.refresh_token = ""
|
||||
@@ -210,9 +202,7 @@ class BaseFetcher:
|
||||
redis = get_redis()
|
||||
await redis.delete(f"fetcher:access_token:{self.client_id}")
|
||||
await redis.delete(f"fetcher:refresh_token:{self.client_id}")
|
||||
logger.warning(
|
||||
f"Cleared invalid tokens. Please re-authorize: {self.authorize_url}"
|
||||
)
|
||||
logger.warning(f"Cleared invalid tokens. Please re-authorize: {self.authorize_url}")
|
||||
raise
|
||||
|
||||
async def _trigger_reauthorization(self) -> None:
|
||||
@@ -237,8 +227,7 @@ class BaseFetcher:
|
||||
await redis.delete(f"fetcher:refresh_token:{self.client_id}")
|
||||
|
||||
logger.warning(
|
||||
f"All tokens cleared for client {self.client_id}. "
|
||||
f"Please re-authorize using: {self.authorize_url}"
|
||||
f"All tokens cleared for client {self.client_id}. Please re-authorize using: {self.authorize_url}"
|
||||
)
|
||||
|
||||
def reset_auth_retry_count(self) -> None:
|
||||
|
||||
@@ -7,18 +7,14 @@ from ._base import BaseFetcher
|
||||
|
||||
|
||||
class BeatmapFetcher(BaseFetcher):
|
||||
async def get_beatmap(
|
||||
self, beatmap_id: int | None = None, beatmap_checksum: str | None = None
|
||||
) -> BeatmapResp:
|
||||
async def get_beatmap(self, beatmap_id: int | None = None, beatmap_checksum: str | None = None) -> BeatmapResp:
|
||||
if beatmap_id:
|
||||
params = {"id": beatmap_id}
|
||||
elif beatmap_checksum:
|
||||
params = {"checksum": beatmap_checksum}
|
||||
else:
|
||||
raise ValueError("Either beatmap_id or beatmap_checksum must be provided.")
|
||||
logger.opt(colors=True).debug(
|
||||
f"<blue>[BeatmapFetcher]</blue> get_beatmap: <y>{params}</y>"
|
||||
)
|
||||
logger.opt(colors=True).debug(f"<blue>[BeatmapFetcher]</blue> get_beatmap: <y>{params}</y>")
|
||||
|
||||
return BeatmapResp.model_validate(
|
||||
await self.request_api(
|
||||
|
||||
@@ -18,9 +18,7 @@ class BeatmapRawFetcher(BaseFetcher):
|
||||
async def get_beatmap_raw(self, beatmap_id: int) -> str:
|
||||
for url in urls:
|
||||
req_url = url.format(beatmap_id=beatmap_id)
|
||||
logger.opt(colors=True).debug(
|
||||
f"<blue>[BeatmapRawFetcher]</blue> get_beatmap_raw: <y>{req_url}</y>"
|
||||
)
|
||||
logger.opt(colors=True).debug(f"<blue>[BeatmapRawFetcher]</blue> get_beatmap_raw: <y>{req_url}</y>")
|
||||
resp = await self._request(req_url)
|
||||
if resp.status_code >= 400:
|
||||
continue
|
||||
@@ -34,9 +32,7 @@ class BeatmapRawFetcher(BaseFetcher):
|
||||
)
|
||||
return response
|
||||
|
||||
async def get_or_fetch_beatmap_raw(
|
||||
self, redis: redis.Redis, beatmap_id: int
|
||||
) -> str:
|
||||
async def get_or_fetch_beatmap_raw(self, redis: redis.Redis, beatmap_id: int) -> str:
|
||||
from app.config import settings
|
||||
|
||||
cache_key = f"beatmap:{beatmap_id}:raw"
|
||||
@@ -48,7 +44,7 @@ class BeatmapRawFetcher(BaseFetcher):
|
||||
if content:
|
||||
# 延长缓存时间
|
||||
await redis.expire(cache_key, cache_expire)
|
||||
return content # pyright: ignore[reportReturnType]
|
||||
return content
|
||||
|
||||
# 获取并缓存
|
||||
raw = await self.get_beatmap_raw(beatmap_id)
|
||||
|
||||
@@ -10,6 +10,7 @@ from app.helpers.rate_limiter import osu_api_rate_limiter
|
||||
from app.log import logger
|
||||
from app.models.beatmap import SearchQueryModel
|
||||
from app.models.model import Cursor
|
||||
from app.utils import bg_tasks
|
||||
|
||||
from ._base import BaseFetcher
|
||||
|
||||
@@ -81,9 +82,7 @@ class BeatmapsetFetcher(BaseFetcher):
|
||||
cache_json = json.dumps(cache_data, sort_keys=True, separators=(",", ":"))
|
||||
cache_hash = hashlib.md5(cache_json.encode()).hexdigest()
|
||||
|
||||
logger.opt(colors=True).debug(
|
||||
f"<blue>[CacheKey]</blue> Query: {cache_data}, Hash: {cache_hash}"
|
||||
)
|
||||
logger.opt(colors=True).debug(f"<blue>[CacheKey]</blue> Query: {cache_data}, Hash: {cache_hash}")
|
||||
|
||||
return f"beatmapset:search:{cache_hash}"
|
||||
|
||||
@@ -103,22 +102,16 @@ class BeatmapsetFetcher(BaseFetcher):
|
||||
return {}
|
||||
|
||||
async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetResp:
|
||||
logger.opt(colors=True).debug(
|
||||
f"<blue>[BeatmapsetFetcher]</blue> get_beatmapset: <y>{beatmap_set_id}</y>"
|
||||
)
|
||||
logger.opt(colors=True).debug(f"<blue>[BeatmapsetFetcher]</blue> get_beatmapset: <y>{beatmap_set_id}</y>")
|
||||
|
||||
return BeatmapsetResp.model_validate(
|
||||
await self.request_api(
|
||||
f"https://osu.ppy.sh/api/v2/beatmapsets/{beatmap_set_id}"
|
||||
)
|
||||
await self.request_api(f"https://osu.ppy.sh/api/v2/beatmapsets/{beatmap_set_id}")
|
||||
)
|
||||
|
||||
async def search_beatmapset(
|
||||
self, query: SearchQueryModel, cursor: Cursor, redis_client: redis.Redis
|
||||
) -> SearchBeatmapsetsResp:
|
||||
logger.opt(colors=True).debug(
|
||||
f"<blue>[BeatmapsetFetcher]</blue> search_beatmapset: <y>{query}</y>"
|
||||
)
|
||||
logger.opt(colors=True).debug(f"<blue>[BeatmapsetFetcher]</blue> search_beatmapset: <y>{query}</y>")
|
||||
|
||||
# 生成缓存键
|
||||
cache_key = self._generate_cache_key(query, cursor)
|
||||
@@ -126,9 +119,7 @@ class BeatmapsetFetcher(BaseFetcher):
|
||||
# 尝试从缓存获取结果
|
||||
cached_result = await redis_client.get(cache_key)
|
||||
if cached_result:
|
||||
logger.opt(colors=True).debug(
|
||||
f"<green>[BeatmapsetFetcher]</green> Cache hit for key: <y>{cache_key}</y>"
|
||||
)
|
||||
logger.opt(colors=True).debug(f"<green>[BeatmapsetFetcher]</green> Cache hit for key: <y>{cache_key}</y>")
|
||||
try:
|
||||
cached_data = json.loads(cached_result)
|
||||
return SearchBeatmapsetsResp.model_validate(cached_data)
|
||||
@@ -138,13 +129,9 @@ class BeatmapsetFetcher(BaseFetcher):
|
||||
)
|
||||
|
||||
# 缓存未命中,从 API 获取数据
|
||||
logger.opt(colors=True).debug(
|
||||
"<blue>[BeatmapsetFetcher]</blue> Cache miss, fetching from API"
|
||||
)
|
||||
logger.opt(colors=True).debug("<blue>[BeatmapsetFetcher]</blue> Cache miss, fetching from API")
|
||||
|
||||
params = query.model_dump(
|
||||
exclude_none=True, exclude_unset=True, exclude_defaults=True
|
||||
)
|
||||
params = query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True)
|
||||
|
||||
if query.cursor_string:
|
||||
params["cursor_string"] = query.cursor_string
|
||||
@@ -164,39 +151,26 @@ class BeatmapsetFetcher(BaseFetcher):
|
||||
|
||||
# 将结果缓存 15 分钟
|
||||
cache_ttl = 15 * 60 # 15 分钟
|
||||
await redis_client.set(
|
||||
cache_key, json.dumps(api_response, separators=(",", ":")), ex=cache_ttl
|
||||
)
|
||||
await redis_client.set(cache_key, json.dumps(api_response, separators=(",", ":")), ex=cache_ttl)
|
||||
|
||||
logger.opt(colors=True).debug(
|
||||
f"<green>[BeatmapsetFetcher]</green> Cached result for key: "
|
||||
f"<y>{cache_key}</y> (TTL: {cache_ttl}s)"
|
||||
f"<green>[BeatmapsetFetcher]</green> Cached result for key: <y>{cache_key}</y> (TTL: {cache_ttl}s)"
|
||||
)
|
||||
|
||||
resp = SearchBeatmapsetsResp.model_validate(api_response)
|
||||
|
||||
# 智能预取:只在用户明确搜索时才预取,避免过多API请求
|
||||
# 且只在有搜索词或特定条件时预取,避免首页浏览时的过度预取
|
||||
if api_response.get("cursor") and (
|
||||
query.q or query.s != "leaderboard" or cursor
|
||||
):
|
||||
if api_response.get("cursor") and (query.q or query.s != "leaderboard" or cursor):
|
||||
# 在后台预取下1页(减少预取量)
|
||||
import asyncio
|
||||
|
||||
# 不立即创建任务,而是延迟一段时间再预取
|
||||
async def delayed_prefetch():
|
||||
await asyncio.sleep(3.0) # 延迟3秒
|
||||
await self.prefetch_next_pages(
|
||||
query, api_response["cursor"], redis_client, pages=1
|
||||
)
|
||||
await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=1)
|
||||
|
||||
# 创建延迟预取任务
|
||||
task = asyncio.create_task(delayed_prefetch())
|
||||
# 添加到后台任务集合避免被垃圾回收
|
||||
if not hasattr(self, "_background_tasks"):
|
||||
self._background_tasks = set()
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
bg_tasks.add_task(delayed_prefetch)
|
||||
|
||||
return resp
|
||||
|
||||
@@ -218,18 +192,14 @@ class BeatmapsetFetcher(BaseFetcher):
|
||||
# 使用当前 cursor 请求下一页
|
||||
next_query = query.model_copy()
|
||||
|
||||
logger.opt(colors=True).debug(
|
||||
f"<cyan>[BeatmapsetFetcher]</cyan> Prefetching page {page + 1}"
|
||||
)
|
||||
logger.opt(colors=True).debug(f"<cyan>[BeatmapsetFetcher]</cyan> Prefetching page {page + 1}")
|
||||
|
||||
# 生成下一页的缓存键
|
||||
next_cache_key = self._generate_cache_key(next_query, cursor)
|
||||
|
||||
# 检查是否已经缓存
|
||||
if await redis_client.exists(next_cache_key):
|
||||
logger.opt(colors=True).debug(
|
||||
f"<cyan>[BeatmapsetFetcher]</cyan> Page {page + 1} already cached"
|
||||
)
|
||||
logger.opt(colors=True).debug(f"<cyan>[BeatmapsetFetcher]</cyan> Page {page + 1} already cached")
|
||||
# 尝试从缓存获取cursor继续预取
|
||||
cached_data = await redis_client.get(next_cache_key)
|
||||
if cached_data:
|
||||
@@ -247,9 +217,7 @@ class BeatmapsetFetcher(BaseFetcher):
|
||||
await asyncio.sleep(1.5) # 1.5秒延迟
|
||||
|
||||
# 请求下一页数据
|
||||
params = next_query.model_dump(
|
||||
exclude_none=True, exclude_unset=True, exclude_defaults=True
|
||||
)
|
||||
params = next_query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True)
|
||||
|
||||
for k, v in cursor.items():
|
||||
params[f"cursor[{k}]"] = v
|
||||
@@ -277,22 +245,18 @@ class BeatmapsetFetcher(BaseFetcher):
|
||||
)
|
||||
|
||||
logger.opt(colors=True).debug(
|
||||
f"<cyan>[BeatmapsetFetcher]</cyan> Prefetched page {page + 1} "
|
||||
f"(TTL: {prefetch_ttl}s)"
|
||||
f"<cyan>[BeatmapsetFetcher]</cyan> Prefetched page {page + 1} (TTL: {prefetch_ttl}s)"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.opt(colors=True).warning(
|
||||
f"<yellow>[BeatmapsetFetcher]</yellow> Prefetch failed: {e}"
|
||||
)
|
||||
logger.opt(colors=True).warning(f"<yellow>[BeatmapsetFetcher]</yellow> Prefetch failed: {e}")
|
||||
|
||||
async def warmup_homepage_cache(self, redis_client: redis.Redis) -> None:
|
||||
"""预热主页缓存"""
|
||||
homepage_queries = self._get_homepage_queries()
|
||||
|
||||
logger.opt(colors=True).info(
|
||||
f"<magenta>[BeatmapsetFetcher]</magenta> Starting homepage cache warmup "
|
||||
f"({len(homepage_queries)} queries)"
|
||||
f"<magenta>[BeatmapsetFetcher]</magenta> Starting homepage cache warmup ({len(homepage_queries)} queries)"
|
||||
)
|
||||
|
||||
for i, (query, cursor) in enumerate(homepage_queries):
|
||||
@@ -306,15 +270,12 @@ class BeatmapsetFetcher(BaseFetcher):
|
||||
# 检查是否已经缓存
|
||||
if await redis_client.exists(cache_key):
|
||||
logger.opt(colors=True).debug(
|
||||
f"<magenta>[BeatmapsetFetcher]</magenta> "
|
||||
f"Query {query.sort} already cached"
|
||||
f"<magenta>[BeatmapsetFetcher]</magenta> Query {query.sort} already cached"
|
||||
)
|
||||
continue
|
||||
|
||||
# 请求并缓存
|
||||
params = query.model_dump(
|
||||
exclude_none=True, exclude_unset=True, exclude_defaults=True
|
||||
)
|
||||
params = query.model_dump(exclude_none=True, exclude_unset=True, exclude_defaults=True)
|
||||
|
||||
api_response = await self.request_api(
|
||||
"https://osu.ppy.sh/api/v2/beatmapsets/search",
|
||||
@@ -334,17 +295,13 @@ class BeatmapsetFetcher(BaseFetcher):
|
||||
)
|
||||
|
||||
logger.opt(colors=True).info(
|
||||
f"<magenta>[BeatmapsetFetcher]</magenta> "
|
||||
f"Warmed up cache for {query.sort} (TTL: {cache_ttl}s)"
|
||||
f"<magenta>[BeatmapsetFetcher]</magenta> Warmed up cache for {query.sort} (TTL: {cache_ttl}s)"
|
||||
)
|
||||
|
||||
if api_response.get("cursor"):
|
||||
await self.prefetch_next_pages(
|
||||
query, api_response["cursor"], redis_client, pages=2
|
||||
)
|
||||
await self.prefetch_next_pages(query, api_response["cursor"], redis_client, pages=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.opt(colors=True).error(
|
||||
f"<red>[BeatmapsetFetcher]</red> "
|
||||
f"Failed to warmup cache for {query.sort}: {e}"
|
||||
f"<red>[BeatmapsetFetcher]</red> Failed to warmup cache for {query.sort}: {e}"
|
||||
)
|
||||
|
||||
@@ -55,14 +55,9 @@ class GeoIPHelper:
|
||||
- 临时目录退出后自动清理
|
||||
"""
|
||||
if not self.license_key:
|
||||
raise ValueError(
|
||||
"缺少 MaxMind License Key,请传入或设置环境变量 MAXMIND_LICENSE_KEY"
|
||||
)
|
||||
raise ValueError("缺少 MaxMind License Key,请传入或设置环境变量 MAXMIND_LICENSE_KEY")
|
||||
|
||||
url = (
|
||||
f"{BASE_URL}?edition_id={edition_id}&"
|
||||
f"license_key={self.license_key}&suffix=tar.gz"
|
||||
)
|
||||
url = f"{BASE_URL}?edition_id={edition_id}&license_key={self.license_key}&suffix=tar.gz"
|
||||
|
||||
with httpx.Client(follow_redirects=True, timeout=self.timeout) as client:
|
||||
with client.stream("GET", url) as resp:
|
||||
|
||||
@@ -48,8 +48,7 @@ class RateLimiter:
|
||||
|
||||
if wait_time > 0:
|
||||
logger.opt(colors=True).info(
|
||||
f"<yellow>[RateLimiter]</yellow> Rate limit reached, "
|
||||
f"waiting {wait_time:.2f}s"
|
||||
f"<yellow>[RateLimiter]</yellow> Rate limit reached, waiting {wait_time:.2f}s"
|
||||
)
|
||||
await asyncio.sleep(wait_time)
|
||||
current_time = time.time()
|
||||
@@ -107,11 +106,7 @@ class RateLimiter:
|
||||
"max_requests_per_minute": self.max_requests_per_minute,
|
||||
"burst_requests": len(self.burst_times),
|
||||
"burst_limit": self.burst_limit,
|
||||
"next_reset_in_seconds": (
|
||||
60.0 - (current_time - self.request_times[0])
|
||||
if self.request_times
|
||||
else 0.0
|
||||
),
|
||||
"next_reset_in_seconds": (60.0 - (current_time - self.request_times[0]) if self.request_times else 0.0),
|
||||
}
|
||||
|
||||
|
||||
|
||||
22
app/log.py
22
app/log.py
@@ -46,14 +46,10 @@ class InterceptHandler(logging.Handler):
|
||||
color = True
|
||||
else:
|
||||
color = False
|
||||
logger.opt(depth=depth, exception=record.exc_info, colors=color).log(
|
||||
level, message
|
||||
)
|
||||
logger.opt(depth=depth, exception=record.exc_info, colors=color).log(level, message)
|
||||
|
||||
def _format_uvicorn_error_log(self, message: str) -> str:
|
||||
websocket_pattern = (
|
||||
r'(\d+\.\d+\.\d+\.\d+:\d+)\s*-\s*"WebSocket\s+([^"]+)"\s+([\w\[\]]+)'
|
||||
)
|
||||
websocket_pattern = r'(\d+\.\d+\.\d+\.\d+:\d+)\s*-\s*"WebSocket\s+([^"]+)"\s+([\w\[\]]+)'
|
||||
websocket_match = re.search(websocket_pattern, message)
|
||||
|
||||
if websocket_match:
|
||||
@@ -64,14 +60,8 @@ class InterceptHandler(logging.Handler):
|
||||
"[accepted]": "<green>[accepted]</green>",
|
||||
"403": "<red>403 [rejected]</red>",
|
||||
}
|
||||
colored_status = status_colors.get(
|
||||
status.lower(), f"<white>{status}</white>"
|
||||
)
|
||||
return (
|
||||
f'{colored_ip} - "<bold><magenta>WebSocket</magenta> '
|
||||
f'{path}</bold>" '
|
||||
f"{colored_status}"
|
||||
)
|
||||
colored_status = status_colors.get(status.lower(), f"<white>{status}</white>")
|
||||
return f'{colored_ip} - "<bold><magenta>WebSocket</magenta> {path}</bold>" {colored_status}'
|
||||
else:
|
||||
return message
|
||||
|
||||
@@ -121,9 +111,7 @@ logger.remove()
|
||||
logger.add(
|
||||
stdout,
|
||||
colorize=True,
|
||||
format=(
|
||||
"<green>{time:YYYY-MM-DD HH:mm:ss}</green> [<level>{level}</level>] | {message}"
|
||||
),
|
||||
format=("<green>{time:YYYY-MM-DD HH:mm:ss}</green> [<level>{level}</level>] | {message}"),
|
||||
level=settings.log_level,
|
||||
diagnose=settings.debug,
|
||||
)
|
||||
|
||||
@@ -19,17 +19,11 @@ class Achievement(NamedTuple):
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
return (
|
||||
self.medal_url
|
||||
or f"https://assets.ppy.sh/medals/client/{self.assets_id}.png"
|
||||
)
|
||||
return self.medal_url or f"https://assets.ppy.sh/medals/client/{self.assets_id}.png"
|
||||
|
||||
@property
|
||||
def url2x(self) -> str:
|
||||
return (
|
||||
self.medal_url2x
|
||||
or f"https://assets.ppy.sh/medals/client/{self.assets_id}@2x.png"
|
||||
)
|
||||
return self.medal_url2x or f"https://assets.ppy.sh/medals/client/{self.assets_id}@2x.png"
|
||||
|
||||
|
||||
MedalProcessor = Callable[[AsyncSession, "Score", "Beatmap"], Awaitable[bool]]
|
||||
|
||||
@@ -11,7 +11,8 @@ class APIMe(UserResp):
|
||||
"""
|
||||
/me 端点的响应模型
|
||||
对应 osu! 的 APIMe 类型,继承 APIUser(UserResp) 并包含 session_verified 字段
|
||||
|
||||
|
||||
session_verified 字段已经在 UserResp 中定义,这里不需要重复定义
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
@@ -95,11 +95,7 @@ class SearchQueryModel(BaseModel):
|
||||
|
||||
q: str = Field("", description="搜索关键词")
|
||||
c: Annotated[
|
||||
list[
|
||||
Literal[
|
||||
"recommended", "converts", "follows", "spotlights", "featured_artists"
|
||||
]
|
||||
],
|
||||
list[Literal["recommended", "converts", "follows", "spotlights", "featured_artists"]],
|
||||
BeforeValidator(_parse_list),
|
||||
PlainSerializer(lambda x: ".".join(x)),
|
||||
] = Field(
|
||||
@@ -188,12 +184,10 @@ class SearchQueryModel(BaseModel):
|
||||
list[Literal["video", "storyboard"]],
|
||||
BeforeValidator(_parse_list),
|
||||
PlainSerializer(lambda x: ".".join(x)),
|
||||
] = Field(
|
||||
default_factory=list, description=("其他:video 有视频 / storyboard 有故事板")
|
||||
] = Field(default_factory=list, description=("其他:video 有视频 / storyboard 有故事板"))
|
||||
r: Annotated[list[Rank], BeforeValidator(_parse_list), PlainSerializer(lambda x: ".".join(x))] = Field(
|
||||
default_factory=list, description="成绩"
|
||||
)
|
||||
r: Annotated[
|
||||
list[Rank], BeforeValidator(_parse_list), PlainSerializer(lambda x: ".".join(x))
|
||||
] = Field(default_factory=list, description="成绩")
|
||||
played: bool = Field(
|
||||
default=False,
|
||||
description="玩过",
|
||||
|
||||
@@ -9,12 +9,13 @@ from pydantic import BaseModel
|
||||
|
||||
class ExtendedTokenResponse(BaseModel):
|
||||
"""扩展的令牌响应,支持二次验证状态"""
|
||||
|
||||
access_token: str | None = None
|
||||
token_type: str = "Bearer"
|
||||
expires_in: int | None = None
|
||||
refresh_token: str | None = None
|
||||
scope: str | None = None
|
||||
|
||||
|
||||
# 二次验证相关字段
|
||||
requires_second_factor: bool = False
|
||||
verification_message: str | None = None
|
||||
@@ -23,6 +24,7 @@ class ExtendedTokenResponse(BaseModel):
|
||||
|
||||
class SessionState(BaseModel):
|
||||
"""会话状态"""
|
||||
|
||||
user_id: int
|
||||
username: str
|
||||
email: str
|
||||
|
||||
@@ -145,9 +145,7 @@ class MultiplayerPlaylistItemStats(BaseModel):
|
||||
|
||||
class MultiplayerRoomStats(BaseModel):
|
||||
room_id: int
|
||||
playlist_item_stats: dict[int, MultiplayerPlaylistItemStats] = Field(
|
||||
default_factory=dict
|
||||
)
|
||||
playlist_item_stats: dict[int, MultiplayerPlaylistItemStats] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class MultiplayerRoomScoreSetEvent(BaseModel):
|
||||
|
||||
@@ -174,11 +174,7 @@ def mods_can_get_pp(ruleset_id: int, mods: list[APIMod]) -> bool:
|
||||
return True
|
||||
ranked_mods = RANKED_MODS[ruleset_id]
|
||||
for mod in mods:
|
||||
if (
|
||||
app_settings.enable_rx
|
||||
and mod["acronym"] == "RX"
|
||||
and ruleset_id in {0, 1, 2}
|
||||
):
|
||||
if app_settings.enable_rx and mod["acronym"] == "RX" and ruleset_id in {0, 1, 2}:
|
||||
continue
|
||||
if app_settings.enable_ap and mod["acronym"] == "AP" and ruleset_id == 0:
|
||||
continue
|
||||
@@ -251,10 +247,7 @@ def get_available_mods(ruleset_id: int, required_mods: list[APIMod]) -> list[API
|
||||
if mod_acronym in incompatible_mods:
|
||||
continue
|
||||
|
||||
if any(
|
||||
required_acronym in mod_data["IncompatibleMods"]
|
||||
for required_acronym in required_mod_acronyms
|
||||
):
|
||||
if any(required_acronym in mod_data["IncompatibleMods"] for required_acronym in required_mod_acronyms):
|
||||
continue
|
||||
|
||||
if mod_data.get("UserPlayable", False):
|
||||
|
||||
@@ -121,32 +121,21 @@ class PlaylistItem(BaseModel):
|
||||
star_rating: float
|
||||
freestyle: bool
|
||||
|
||||
def _validate_mod_for_ruleset(
|
||||
self, mod: APIMod, ruleset_key: int, context: str = "mod"
|
||||
) -> None:
|
||||
def _validate_mod_for_ruleset(self, mod: APIMod, ruleset_key: int, context: str = "mod") -> None:
|
||||
typed_ruleset_key = cast(Literal[0, 1, 2, 3], ruleset_key)
|
||||
|
||||
# Check if mod is valid for ruleset
|
||||
if (
|
||||
typed_ruleset_key not in API_MODS
|
||||
or mod["acronym"] not in API_MODS[typed_ruleset_key]
|
||||
):
|
||||
raise InvokeException(
|
||||
f"{context} {mod['acronym']} is invalid for this ruleset"
|
||||
)
|
||||
if typed_ruleset_key not in API_MODS or mod["acronym"] not in API_MODS[typed_ruleset_key]:
|
||||
raise InvokeException(f"{context} {mod['acronym']} is invalid for this ruleset")
|
||||
|
||||
mod_settings = API_MODS[typed_ruleset_key][mod["acronym"]]
|
||||
|
||||
# Check if mod is unplayable in multiplayer
|
||||
if mod_settings.get("UserPlayable", True) is False:
|
||||
raise InvokeException(
|
||||
f"{context} {mod['acronym']} is not playable by users"
|
||||
)
|
||||
raise InvokeException(f"{context} {mod['acronym']} is not playable by users")
|
||||
|
||||
if mod_settings.get("ValidForMultiplayer", True) is False:
|
||||
raise InvokeException(
|
||||
f"{context} {mod['acronym']} is not valid for multiplayer"
|
||||
)
|
||||
raise InvokeException(f"{context} {mod['acronym']} is not valid for multiplayer")
|
||||
|
||||
def _check_mod_compatibility(self, mods: list[APIMod], ruleset_key: int) -> None:
|
||||
from typing import Literal, cast
|
||||
@@ -159,10 +148,7 @@ class PlaylistItem(BaseModel):
|
||||
incompatible = set(mod1_settings.get("IncompatibleMods", []))
|
||||
for mod2 in mods[i + 1 :]:
|
||||
if mod2["acronym"] in incompatible:
|
||||
raise InvokeException(
|
||||
f"Mods {mod1['acronym']} and "
|
||||
f"{mod2['acronym']} are incompatible"
|
||||
)
|
||||
raise InvokeException(f"Mods {mod1['acronym']} and {mod2['acronym']} are incompatible")
|
||||
|
||||
def _check_required_allowed_compatibility(self, ruleset_key: int) -> None:
|
||||
from typing import Literal, cast
|
||||
@@ -178,10 +164,7 @@ class PlaylistItem(BaseModel):
|
||||
conflicting_allowed = allowed_acronyms & incompatible
|
||||
if conflicting_allowed:
|
||||
conflict_list = ", ".join(conflicting_allowed)
|
||||
raise InvokeException(
|
||||
f"Required mod {req_acronym} conflicts with "
|
||||
f"allowed mods: {conflict_list}"
|
||||
)
|
||||
raise InvokeException(f"Required mod {req_acronym} conflicts with allowed mods: {conflict_list}")
|
||||
|
||||
def validate_playlist_item_mods(self) -> None:
|
||||
ruleset_key = cast(Literal[0, 1, 2, 3], self.ruleset_id)
|
||||
@@ -219,10 +202,7 @@ class PlaylistItem(BaseModel):
|
||||
|
||||
# Check if mods are valid for the ruleset
|
||||
for mod in proposed_mods:
|
||||
if (
|
||||
ruleset_key not in API_MODS
|
||||
or mod["acronym"] not in API_MODS[ruleset_key]
|
||||
):
|
||||
if ruleset_key not in API_MODS or mod["acronym"] not in API_MODS[ruleset_key]:
|
||||
all_proposed_valid = False
|
||||
continue
|
||||
valid_mods.append(mod)
|
||||
@@ -252,9 +232,7 @@ class PlaylistItem(BaseModel):
|
||||
|
||||
# Check compatibility with required mods
|
||||
required_mod_acronyms = {mod["acronym"] for mod in self.required_mods}
|
||||
all_mod_acronyms = {
|
||||
mod["acronym"] for mod in final_valid_mods
|
||||
} | required_mod_acronyms
|
||||
all_mod_acronyms = {mod["acronym"] for mod in final_valid_mods} | required_mod_acronyms
|
||||
|
||||
# Check for incompatibility between required and user mods
|
||||
filtered_valid_mods = []
|
||||
@@ -288,9 +266,7 @@ class PlaylistItem(BaseModel):
|
||||
class _MultiplayerCountdown(SignalRUnionMessage):
|
||||
id: int = 0
|
||||
time_remaining: timedelta
|
||||
is_exclusive: Annotated[
|
||||
bool, Field(default=True), SignalRMeta(member_ignore=True)
|
||||
] = True
|
||||
is_exclusive: Annotated[bool, Field(default=True), SignalRMeta(member_ignore=True)] = True
|
||||
|
||||
|
||||
class MatchStartCountdown(_MultiplayerCountdown):
|
||||
@@ -305,17 +281,13 @@ class ServerShuttingDownCountdown(_MultiplayerCountdown):
|
||||
union_type: ClassVar[Literal[2]] = 2
|
||||
|
||||
|
||||
MultiplayerCountdown = (
|
||||
MatchStartCountdown | ForceGameplayStartCountdown | ServerShuttingDownCountdown
|
||||
)
|
||||
MultiplayerCountdown = MatchStartCountdown | ForceGameplayStartCountdown | ServerShuttingDownCountdown
|
||||
|
||||
|
||||
class MultiplayerRoomUser(BaseModel):
|
||||
user_id: int
|
||||
state: MultiplayerUserState = MultiplayerUserState.IDLE
|
||||
availability: BeatmapAvailability = BeatmapAvailability(
|
||||
state=DownloadState.UNKNOWN, download_progress=None
|
||||
)
|
||||
availability: BeatmapAvailability = BeatmapAvailability(state=DownloadState.UNKNOWN, download_progress=None)
|
||||
mods: list[APIMod] = Field(default_factory=list)
|
||||
match_state: MatchUserState | None = None
|
||||
ruleset_id: int | None = None # freestyle
|
||||
@@ -358,9 +330,7 @@ class MultiplayerRoom(BaseModel):
|
||||
expired=item.expired,
|
||||
playlist_order=item.playlist_order,
|
||||
played_at=item.played_at,
|
||||
star_rating=item.beatmap.difficulty_rating
|
||||
if item.beatmap is not None
|
||||
else 0.0,
|
||||
star_rating=item.beatmap.difficulty_rating if item.beatmap is not None else 0.0,
|
||||
freestyle=item.freestyle,
|
||||
)
|
||||
)
|
||||
@@ -425,9 +395,7 @@ class MultiplayerQueue:
|
||||
user_item_groups[item.owner_id] = []
|
||||
user_item_groups[item.owner_id].append(item)
|
||||
|
||||
max_items = max(
|
||||
(len(items) for items in user_item_groups.values()), default=0
|
||||
)
|
||||
max_items = max((len(items) for items in user_item_groups.values()), default=0)
|
||||
|
||||
for i in range(max_items):
|
||||
current_set = []
|
||||
@@ -436,20 +404,13 @@ class MultiplayerQueue:
|
||||
current_set.append(items[i])
|
||||
|
||||
if is_first_set:
|
||||
current_set.sort(
|
||||
key=lambda item: (item.playlist_order, item.id)
|
||||
)
|
||||
current_set.sort(key=lambda item: (item.playlist_order, item.id))
|
||||
ordered_active_items.extend(current_set)
|
||||
first_set_order_by_user_id = {
|
||||
item.owner_id: idx
|
||||
for idx, item in enumerate(ordered_active_items)
|
||||
item.owner_id: idx for idx, item in enumerate(ordered_active_items)
|
||||
}
|
||||
else:
|
||||
current_set.sort(
|
||||
key=lambda item: first_set_order_by_user_id.get(
|
||||
item.owner_id, 0
|
||||
)
|
||||
)
|
||||
current_set.sort(key=lambda item: first_set_order_by_user_id.get(item.owner_id, 0))
|
||||
ordered_active_items.extend(current_set)
|
||||
|
||||
is_first_set = False
|
||||
@@ -464,9 +425,7 @@ class MultiplayerQueue:
|
||||
continue
|
||||
item.playlist_order = idx
|
||||
await Playlist.update(item, self.room.room_id, session)
|
||||
await self.hub.playlist_changed(
|
||||
self.server_room, item, beatmap_changed=False
|
||||
)
|
||||
await self.hub.playlist_changed(self.server_room, item, beatmap_changed=False)
|
||||
|
||||
async def update_current_item(self):
|
||||
upcoming_items = self.upcoming_items
|
||||
@@ -494,16 +453,7 @@ class MultiplayerQueue:
|
||||
raise InvokeException("You are not the host")
|
||||
|
||||
limit = HOST_LIMIT if is_host else PER_USER_LIMIT
|
||||
if (
|
||||
len(
|
||||
[
|
||||
True
|
||||
for u in self.room.playlist
|
||||
if u.owner_id == user.user_id and not u.expired
|
||||
]
|
||||
)
|
||||
>= limit
|
||||
):
|
||||
if len([True for u in self.room.playlist if u.owner_id == user.user_id and not u.expired]) >= limit:
|
||||
raise InvokeException(f"You can only have {limit} items in the queue")
|
||||
|
||||
if item.freestyle and len(item.allowed_mods) > 0:
|
||||
@@ -512,9 +462,7 @@ class MultiplayerQueue:
|
||||
async with with_db() as session:
|
||||
fetcher = await get_fetcher()
|
||||
async with session:
|
||||
beatmap = await Beatmap.get_or_fetch(
|
||||
session, fetcher, bid=item.beatmap_id
|
||||
)
|
||||
beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=item.beatmap_id)
|
||||
if beatmap is None:
|
||||
raise InvokeException("Beatmap not found")
|
||||
if item.beatmap_checksum != beatmap.checksum:
|
||||
@@ -538,29 +486,19 @@ class MultiplayerQueue:
|
||||
async with with_db() as session:
|
||||
fetcher = await get_fetcher()
|
||||
async with session:
|
||||
beatmap = await Beatmap.get_or_fetch(
|
||||
session, fetcher, bid=item.beatmap_id
|
||||
)
|
||||
beatmap = await Beatmap.get_or_fetch(session, fetcher, bid=item.beatmap_id)
|
||||
if item.beatmap_checksum != beatmap.checksum:
|
||||
raise InvokeException("Checksum mismatch")
|
||||
|
||||
existing_item = next(
|
||||
(i for i in self.room.playlist if i.id == item.id), None
|
||||
)
|
||||
existing_item = next((i for i in self.room.playlist if i.id == item.id), None)
|
||||
if existing_item is None:
|
||||
raise InvokeException(
|
||||
"Attempted to change an item that doesn't exist"
|
||||
)
|
||||
raise InvokeException("Attempted to change an item that doesn't exist")
|
||||
|
||||
if existing_item.owner_id != user.user_id and self.room.host != user:
|
||||
raise InvokeException(
|
||||
"Attempted to change an item which is not owned by the user"
|
||||
)
|
||||
raise InvokeException("Attempted to change an item which is not owned by the user")
|
||||
|
||||
if existing_item.expired:
|
||||
raise InvokeException(
|
||||
"Attempted to change an item which has already been played"
|
||||
)
|
||||
raise InvokeException("Attempted to change an item which has already been played")
|
||||
|
||||
item.validate_playlist_item_mods()
|
||||
item.owner_id = user.user_id
|
||||
@@ -578,8 +516,7 @@ class MultiplayerQueue:
|
||||
await self.hub.playlist_changed(
|
||||
self.server_room,
|
||||
item,
|
||||
beatmap_changed=item.beatmap_checksum
|
||||
!= existing_item.beatmap_checksum,
|
||||
beatmap_changed=item.beatmap_checksum != existing_item.beatmap_checksum,
|
||||
)
|
||||
|
||||
async def remove_item(self, playlist_item_id: int, user: MultiplayerRoomUser):
|
||||
@@ -600,14 +537,10 @@ class MultiplayerQueue:
|
||||
raise InvokeException("The only item in the room cannot be removed")
|
||||
|
||||
if item.owner_id != user.user_id and self.room.host != user:
|
||||
raise InvokeException(
|
||||
"Attempted to remove an item which is not owned by the user"
|
||||
)
|
||||
raise InvokeException("Attempted to remove an item which is not owned by the user")
|
||||
|
||||
if item.expired:
|
||||
raise InvokeException(
|
||||
"Attempted to remove an item which has already been played"
|
||||
)
|
||||
raise InvokeException("Attempted to remove an item which has already been played")
|
||||
|
||||
async with with_db() as session:
|
||||
await Playlist.delete_item(item.id, self.room.room_id, session)
|
||||
@@ -668,9 +601,7 @@ class CountdownInfo:
|
||||
def __init__(self, countdown: MultiplayerCountdown):
|
||||
self.countdown = countdown
|
||||
self.duration = (
|
||||
countdown.time_remaining
|
||||
if countdown.time_remaining > timedelta(seconds=0)
|
||||
else timedelta(seconds=0)
|
||||
countdown.time_remaining if countdown.time_remaining > timedelta(seconds=0) else timedelta(seconds=0)
|
||||
)
|
||||
|
||||
|
||||
@@ -704,9 +635,7 @@ class MatchTypeHandler(ABC):
|
||||
async def handle_join(self, user: MultiplayerRoomUser): ...
|
||||
|
||||
@abstractmethod
|
||||
async def handle_request(
|
||||
self, user: MultiplayerRoomUser, request: MatchRequest
|
||||
): ...
|
||||
async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest): ...
|
||||
|
||||
@abstractmethod
|
||||
async def handle_leave(self, user: MultiplayerRoomUser): ...
|
||||
@@ -723,9 +652,7 @@ class HeadToHeadHandler(MatchTypeHandler):
|
||||
await self.hub.change_user_match_state(self.room, user)
|
||||
|
||||
@override
|
||||
async def handle_request(
|
||||
self, user: MultiplayerRoomUser, request: MatchRequest
|
||||
): ...
|
||||
async def handle_request(self, user: MultiplayerRoomUser, request: MatchRequest): ...
|
||||
|
||||
@override
|
||||
async def handle_leave(self, user: MultiplayerRoomUser): ...
|
||||
@@ -762,9 +689,7 @@ class TeamVersusHandler(MatchTypeHandler):
|
||||
|
||||
team_counts = defaultdict(int)
|
||||
for user in self.room.room.users:
|
||||
if user.match_state is not None and isinstance(
|
||||
user.match_state, TeamVersusUserState
|
||||
):
|
||||
if user.match_state is not None and isinstance(user.match_state, TeamVersusUserState):
|
||||
team_counts[user.match_state.team_id] += 1
|
||||
|
||||
if team_counts:
|
||||
@@ -798,9 +723,7 @@ class TeamVersusHandler(MatchTypeHandler):
|
||||
def get_details(self) -> MatchStartedEventDetail:
|
||||
teams: dict[int, Literal["blue", "red"]] = {}
|
||||
for user in self.room.room.users:
|
||||
if user.match_state is not None and isinstance(
|
||||
user.match_state, TeamVersusUserState
|
||||
):
|
||||
if user.match_state is not None and isinstance(user.match_state, TeamVersusUserState):
|
||||
teams[user.user_id] = "blue" if user.match_state.team_id == 1 else "red"
|
||||
detail = MatchStartedEventDetail(room_type="team_versus", team=teams)
|
||||
return detail
|
||||
@@ -843,9 +766,7 @@ class ServerMultiplayerRoom:
|
||||
self._tracked_countdown = {}
|
||||
|
||||
async def set_handler(self):
|
||||
self.match_type_handler = MATCH_TYPE_HANDLERS[self.room.settings.match_type](
|
||||
self
|
||||
)
|
||||
self.match_type_handler = MATCH_TYPE_HANDLERS[self.room.settings.match_type](self)
|
||||
for i in self.room.users:
|
||||
await self.match_type_handler.handle_join(i)
|
||||
|
||||
@@ -871,9 +792,7 @@ class ServerMultiplayerRoom:
|
||||
info = CountdownInfo(countdown)
|
||||
self.room.active_countdowns.append(info.countdown)
|
||||
self._tracked_countdown[countdown.id] = info
|
||||
await self.hub.send_match_event(
|
||||
self, CountdownStartedEvent(countdown=info.countdown)
|
||||
)
|
||||
await self.hub.send_match_event(self, CountdownStartedEvent(countdown=info.countdown))
|
||||
info.task = asyncio.create_task(_countdown_task(self))
|
||||
|
||||
async def stop_countdown(self, countdown: MultiplayerCountdown):
|
||||
|
||||
@@ -53,7 +53,7 @@ class NotificationName(str, Enum):
|
||||
NotificationName.BEATMAP_OWNER_CHANGE: "beatmap_owner_change",
|
||||
NotificationName.BEATMAPSET_DISCUSSION_LOCK: "beatmapset_discussion",
|
||||
NotificationName.BEATMAPSET_DISCUSSION_POST_NEW: "beatmapset_discussion",
|
||||
NotificationName.BEATMAPSET_DISCUSSION_QUALIFIED_PROBLEM: "beatmapset_problem", # noqa: E501
|
||||
NotificationName.BEATMAPSET_DISCUSSION_QUALIFIED_PROBLEM: "beatmapset_problem",
|
||||
NotificationName.BEATMAPSET_DISCUSSION_REVIEW_NEW: "beatmapset_discussion",
|
||||
NotificationName.BEATMAPSET_DISCUSSION_UNLOCK: "beatmapset_discussion",
|
||||
NotificationName.BEATMAPSET_DISQUALIFY: "beatmapset_state",
|
||||
@@ -164,17 +164,11 @@ class ChannelMessageTeam(ChannelMessageBase):
|
||||
from app.database import TeamMember
|
||||
|
||||
user_team_id = (
|
||||
await session.exec(
|
||||
select(TeamMember.team_id).where(TeamMember.user_id == self._user.id)
|
||||
)
|
||||
await session.exec(select(TeamMember.team_id).where(TeamMember.user_id == self._user.id))
|
||||
).first()
|
||||
if not user_team_id:
|
||||
return []
|
||||
user_ids = (
|
||||
await session.exec(
|
||||
select(TeamMember.user_id).where(TeamMember.team_id == user_team_id)
|
||||
)
|
||||
).all()
|
||||
user_ids = (await session.exec(select(TeamMember.user_id).where(TeamMember.team_id == user_team_id))).all()
|
||||
return list(user_ids)
|
||||
|
||||
|
||||
|
||||
@@ -197,9 +197,7 @@ class SoloScoreSubmissionInfo(BaseModel):
|
||||
# check incompatible mods
|
||||
for mod in mods:
|
||||
if mod["acronym"] in incompatible_mods:
|
||||
raise ValueError(
|
||||
f"Mod {mod['acronym']} is incompatible with other mods"
|
||||
)
|
||||
raise ValueError(f"Mod {mod['acronym']} is incompatible with other mods")
|
||||
setting_mods = API_MODS[info.data["ruleset_id"]].get(mod["acronym"])
|
||||
if not setting_mods:
|
||||
raise ValueError(f"Invalid mod: {mod['acronym']}")
|
||||
|
||||
@@ -22,9 +22,7 @@ class SignalRUnionMessage(BaseModel):
|
||||
|
||||
class Transport(BaseModel):
|
||||
transport: str
|
||||
transfer_formats: list[str] = Field(
|
||||
default_factory=lambda: ["Binary", "Text"], alias="transferFormats"
|
||||
)
|
||||
transfer_formats: list[str] = Field(default_factory=lambda: ["Binary", "Text"], alias="transferFormats")
|
||||
|
||||
|
||||
class NegotiateResponse(BaseModel):
|
||||
|
||||
@@ -89,9 +89,7 @@ class LegacyReplayFrame(BaseModel):
|
||||
mouse_y: float | None = None
|
||||
button_state: int
|
||||
|
||||
header: Annotated[
|
||||
FrameHeader | None, Field(default=None), SignalRMeta(member_ignore=True)
|
||||
]
|
||||
header: Annotated[FrameHeader | None, Field(default=None), SignalRMeta(member_ignore=True)]
|
||||
|
||||
|
||||
class FrameDataBundle(BaseModel):
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
import re
|
||||
from typing import Literal, Union
|
||||
from typing import Literal
|
||||
|
||||
from app.auth import (
|
||||
authenticate_user,
|
||||
@@ -22,19 +22,19 @@ from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.geoip import get_client_ip, get_geoip_helper
|
||||
from app.helpers.geoip_helper import GeoIPHelper
|
||||
from app.log import logger
|
||||
from app.models.extended_auth import ExtendedTokenResponse
|
||||
from app.models.oauth import (
|
||||
OAuthErrorResponse,
|
||||
RegistrationRequestErrors,
|
||||
TokenResponse,
|
||||
UserRegistrationErrors,
|
||||
)
|
||||
from app.models.extended_auth import ExtendedTokenResponse
|
||||
from app.models.score import GameMode
|
||||
from app.service.login_log_service import LoginLogService
|
||||
from app.service.email_verification_service import (
|
||||
EmailVerificationService,
|
||||
LoginSessionService
|
||||
LoginSessionService,
|
||||
)
|
||||
from app.service.login_log_service import LoginLogService
|
||||
from app.service.password_reset_service import password_reset_service
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, Request
|
||||
@@ -44,13 +44,9 @@ from sqlalchemy import text
|
||||
from sqlmodel import select
|
||||
|
||||
|
||||
def create_oauth_error_response(
|
||||
error: str, description: str, hint: str, status_code: int = 400
|
||||
):
|
||||
def create_oauth_error_response(error: str, description: str, hint: str, status_code: int = 400):
|
||||
"""创建标准的 OAuth 错误响应"""
|
||||
error_data = OAuthErrorResponse(
|
||||
error=error, error_description=description, hint=hint, message=description
|
||||
)
|
||||
error_data = OAuthErrorResponse(error=error, error_description=description, hint=hint, message=description)
|
||||
return JSONResponse(status_code=status_code, content=error_data.model_dump())
|
||||
|
||||
|
||||
@@ -123,9 +119,7 @@ async def register_user(
|
||||
)
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=422, content={"form_error": errors.model_dump()}
|
||||
)
|
||||
return JSONResponse(status_code=422, content={"form_error": errors.model_dump()})
|
||||
|
||||
try:
|
||||
# 获取客户端 IP 并查询地理位置
|
||||
@@ -137,10 +131,7 @@ async def register_user(
|
||||
geo_info = geoip.lookup(client_ip)
|
||||
if geo_info and geo_info.get("country_iso"):
|
||||
country_code = geo_info["country_iso"]
|
||||
logger.info(
|
||||
f"User {user_username} registering from "
|
||||
f"{client_ip}, country: {country_code}"
|
||||
)
|
||||
logger.info(f"User {user_username} registering from {client_ip}, country: {country_code}")
|
||||
else:
|
||||
logger.warning(f"Could not determine country for IP {client_ip}")
|
||||
except Exception as e:
|
||||
@@ -148,7 +139,7 @@ async def register_user(
|
||||
|
||||
# 创建新用户
|
||||
# 确保 AUTO_INCREMENT 值从3开始(ID=1是BanchoBot,ID=2预留给ppy)
|
||||
result = await db.execute( # pyright: ignore[reportDeprecated]
|
||||
result = await db.execute(
|
||||
text(
|
||||
"SELECT AUTO_INCREMENT FROM information_schema.TABLES "
|
||||
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'lazer_users'"
|
||||
@@ -173,7 +164,6 @@ async def register_user(
|
||||
db.add(new_user)
|
||||
await db.commit()
|
||||
await db.refresh(new_user)
|
||||
assert new_user.id is not None, "New user ID should not be None"
|
||||
for i in [GameMode.OSU, GameMode.TAIKO, GameMode.FRUITS, GameMode.MANIA]:
|
||||
statistics = UserStatistics(mode=i, user_id=new_user.id)
|
||||
db.add(statistics)
|
||||
@@ -193,36 +183,30 @@ async def register_user(
|
||||
logger.exception(f"Registration error for user {user_username}")
|
||||
|
||||
# 返回通用错误
|
||||
errors = RegistrationRequestErrors(
|
||||
message="An error occurred while creating your account. Please try again."
|
||||
)
|
||||
errors = RegistrationRequestErrors(message="An error occurred while creating your account. Please try again.")
|
||||
|
||||
return JSONResponse(
|
||||
status_code=500, content={"form_error": errors.model_dump()}
|
||||
)
|
||||
return JSONResponse(status_code=500, content={"form_error": errors.model_dump()})
|
||||
|
||||
|
||||
@router.post(
|
||||
"/oauth/token",
|
||||
response_model=Union[TokenResponse, ExtendedTokenResponse],
|
||||
response_model=TokenResponse | ExtendedTokenResponse,
|
||||
name="获取访问令牌",
|
||||
description="OAuth 令牌端点,支持密码、刷新令牌和授权码三种授权方式。",
|
||||
)
|
||||
async def oauth_token(
|
||||
db: Database,
|
||||
request: Request,
|
||||
grant_type: Literal[
|
||||
"authorization_code", "refresh_token", "password", "client_credentials"
|
||||
] = Form(..., description="授权类型:密码/刷新令牌/授权码/客户端凭证"),
|
||||
grant_type: Literal["authorization_code", "refresh_token", "password", "client_credentials"] = Form(
|
||||
..., description="授权类型:密码/刷新令牌/授权码/客户端凭证"
|
||||
),
|
||||
client_id: int = Form(..., description="客户端 ID"),
|
||||
client_secret: str = Form(..., description="客户端密钥"),
|
||||
code: str | None = Form(None, description="授权码(仅授权码模式需要)"),
|
||||
scope: str = Form("*", description="权限范围(空格分隔,默认为 '*')"),
|
||||
username: str | None = Form(None, description="用户名(仅密码模式需要)"),
|
||||
password: str | None = Form(None, description="密码(仅密码模式需要)"),
|
||||
refresh_token: str | None = Form(
|
||||
None, description="刷新令牌(仅刷新令牌模式需要)"
|
||||
),
|
||||
refresh_token: str | None = Form(None, description="刷新令牌(仅刷新令牌模式需要)"),
|
||||
redis: Redis = Depends(get_redis),
|
||||
geoip: GeoIPHelper = Depends(get_geoip_helper),
|
||||
):
|
||||
@@ -303,37 +287,33 @@ async def oauth_token(
|
||||
await db.refresh(user)
|
||||
|
||||
# 获取用户信息和客户端信息
|
||||
user_id = getattr(user, "id")
|
||||
assert user_id is not None, "User ID should not be None after authentication"
|
||||
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
user_id = user.id
|
||||
|
||||
ip_address = get_client_ip(request)
|
||||
user_agent = request.headers.get("User-Agent", "")
|
||||
|
||||
|
||||
# 获取国家代码
|
||||
geo_info = geoip.lookup(ip_address)
|
||||
country_code = geo_info.get("country_iso", "XX")
|
||||
|
||||
|
||||
# 检查是否为新位置登录
|
||||
is_new_location = await LoginSessionService.check_new_location(
|
||||
db, user_id, ip_address, country_code
|
||||
)
|
||||
|
||||
is_new_location = await LoginSessionService.check_new_location(db, user_id, ip_address, country_code)
|
||||
|
||||
# 创建登录会话记录
|
||||
login_session = await LoginSessionService.create_session(
|
||||
login_session = await LoginSessionService.create_session( # noqa: F841
|
||||
db, redis, user_id, ip_address, user_agent, country_code, is_new_location
|
||||
)
|
||||
|
||||
|
||||
# 如果是新位置登录,需要邮件验证
|
||||
if is_new_location and settings.enable_email_verification:
|
||||
# 刷新用户对象以确保属性已加载
|
||||
await db.refresh(user)
|
||||
|
||||
|
||||
# 发送邮件验证码
|
||||
verification_sent = await EmailVerificationService.send_verification_email(
|
||||
db, redis, user_id, user.username, user.email, ip_address, user_agent
|
||||
)
|
||||
|
||||
|
||||
# 记录需要二次验证的登录尝试
|
||||
await LoginLogService.record_login(
|
||||
db=db,
|
||||
@@ -343,14 +323,16 @@ async def oauth_token(
|
||||
login_method="password_pending_verification",
|
||||
notes=f"新位置登录,需要邮件验证 - IP: {ip_address}, 国家: {country_code}",
|
||||
)
|
||||
|
||||
|
||||
if not verification_sent:
|
||||
# 邮件发送失败,记录错误
|
||||
logger.error(f"[Auth] Failed to send email verification code for user {user_id}")
|
||||
elif is_new_location and not settings.enable_email_verification:
|
||||
# 新位置登录但邮件验证功能被禁用,直接标记会话为已验证
|
||||
await LoginSessionService.mark_session_verified(db, user_id)
|
||||
logger.debug(f"[Auth] New location login detected but email verification disabled, auto-verifying user {user_id}")
|
||||
logger.debug(
|
||||
f"[Auth] New location login detected but email verification disabled, auto-verifying user {user_id}"
|
||||
)
|
||||
else:
|
||||
# 不是新位置登录,正常登录
|
||||
await LoginLogService.record_login(
|
||||
@@ -361,20 +343,17 @@ async def oauth_token(
|
||||
login_method="password",
|
||||
notes=f"正常登录 - IP: {ip_address}, 国家: {country_code}",
|
||||
)
|
||||
|
||||
|
||||
# 无论是否新位置登录,都返回正常的token
|
||||
# session_verified状态通过/me接口的session_verified字段来体现
|
||||
|
||||
# 生成令牌
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
# 获取用户ID,避免触发延迟加载
|
||||
access_token = create_access_token(
|
||||
data={"sub": str(user_id)}, expires_delta=access_token_expires
|
||||
)
|
||||
access_token = create_access_token(data={"sub": str(user_id)}, expires_delta=access_token_expires)
|
||||
refresh_token_str = generate_refresh_token()
|
||||
|
||||
# 存储令牌
|
||||
assert user_id
|
||||
await store_token(
|
||||
db,
|
||||
user_id,
|
||||
@@ -423,9 +402,7 @@ async def oauth_token(
|
||||
|
||||
# 生成新的访问令牌
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
access_token = create_access_token(
|
||||
data={"sub": str(token_record.user_id)}, expires_delta=access_token_expires
|
||||
)
|
||||
access_token = create_access_token(data={"sub": str(token_record.user_id)}, expires_delta=access_token_expires)
|
||||
new_refresh_token = generate_refresh_token()
|
||||
|
||||
# 更新令牌
|
||||
@@ -489,17 +466,11 @@ async def oauth_token(
|
||||
|
||||
# 生成令牌
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
# 重新查询只获取ID,避免触发延迟加载
|
||||
id_result = await db.exec(select(User.id).where(User.username == username))
|
||||
user_id = id_result.first()
|
||||
|
||||
access_token = create_access_token(
|
||||
data={"sub": str(user_id)}, expires_delta=access_token_expires
|
||||
)
|
||||
user_id = user.id
|
||||
access_token = create_access_token(data={"sub": str(user_id)}, expires_delta=access_token_expires)
|
||||
refresh_token_str = generate_refresh_token()
|
||||
|
||||
# 存储令牌
|
||||
assert user_id
|
||||
await store_token(
|
||||
db,
|
||||
user_id,
|
||||
@@ -539,9 +510,7 @@ async def oauth_token(
|
||||
|
||||
# 生成令牌
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
access_token = create_access_token(
|
||||
data={"sub": "3"}, expires_delta=access_token_expires
|
||||
)
|
||||
access_token = create_access_token(data={"sub": "3"}, expires_delta=access_token_expires)
|
||||
refresh_token_str = generate_refresh_token()
|
||||
|
||||
# 存储令牌
|
||||
@@ -567,7 +536,7 @@ async def oauth_token(
|
||||
@router.post(
|
||||
"/password-reset/request",
|
||||
name="请求密码重置",
|
||||
description="通过邮箱请求密码重置验证码"
|
||||
description="通过邮箱请求密码重置验证码",
|
||||
)
|
||||
async def request_password_reset(
|
||||
request: Request,
|
||||
@@ -578,42 +547,26 @@ async def request_password_reset(
|
||||
请求密码重置
|
||||
"""
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
|
||||
|
||||
# 获取客户端信息
|
||||
ip_address = get_client_ip(request)
|
||||
user_agent = request.headers.get("User-Agent", "")
|
||||
|
||||
|
||||
# 请求密码重置
|
||||
success, message = await password_reset_service.request_password_reset(
|
||||
email=email.lower().strip(),
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
redis=redis
|
||||
redis=redis,
|
||||
)
|
||||
|
||||
|
||||
if success:
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"success": True,
|
||||
"message": message
|
||||
}
|
||||
)
|
||||
return JSONResponse(status_code=200, content={"success": True, "message": message})
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"success": False,
|
||||
"error": message
|
||||
}
|
||||
)
|
||||
return JSONResponse(status_code=400, content={"success": False, "error": message})
|
||||
|
||||
|
||||
@router.post(
|
||||
"/password-reset/reset",
|
||||
name="重置密码",
|
||||
description="使用验证码重置密码"
|
||||
)
|
||||
@router.post("/password-reset/reset", name="重置密码", description="使用验证码重置密码")
|
||||
async def reset_password(
|
||||
request: Request,
|
||||
email: str = Form(..., description="邮箱地址"),
|
||||
@@ -625,32 +578,20 @@ async def reset_password(
|
||||
重置密码
|
||||
"""
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
|
||||
|
||||
# 获取客户端信息
|
||||
ip_address = get_client_ip(request)
|
||||
|
||||
|
||||
# 重置密码
|
||||
success, message = await password_reset_service.reset_password(
|
||||
email=email.lower().strip(),
|
||||
reset_code=reset_code.strip(),
|
||||
new_password=new_password,
|
||||
ip_address=ip_address,
|
||||
redis=redis
|
||||
redis=redis,
|
||||
)
|
||||
|
||||
|
||||
if success:
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"success": True,
|
||||
"message": message
|
||||
}
|
||||
)
|
||||
return JSONResponse(status_code=200, content={"success": True, "message": message})
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"success": False,
|
||||
"error": message
|
||||
}
|
||||
)
|
||||
return JSONResponse(status_code=400, content={"success": False, "error": message})
|
||||
|
||||
@@ -43,9 +43,9 @@ async def get_notifications(
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
if settings.server_url is not None:
|
||||
notification_endpoint = f"{settings.server_url}notification-server".replace(
|
||||
"http://", "ws://"
|
||||
).replace("https://", "wss://")
|
||||
notification_endpoint = f"{settings.server_url}notification-server".replace("http://", "ws://").replace(
|
||||
"https://", "wss://"
|
||||
)
|
||||
else:
|
||||
notification_endpoint = "/notification-server"
|
||||
query = select(UserNotification).where(
|
||||
@@ -96,21 +96,15 @@ async def _get_notifications(
|
||||
query = base_query.where(UserNotification.notification_id == identity.id)
|
||||
if identity.object_id is not None:
|
||||
query = base_query.where(
|
||||
col(UserNotification.notification).has(
|
||||
col(Notification.object_id) == identity.object_id
|
||||
)
|
||||
col(UserNotification.notification).has(col(Notification.object_id) == identity.object_id)
|
||||
)
|
||||
if identity.object_type is not None:
|
||||
query = base_query.where(
|
||||
col(UserNotification.notification).has(
|
||||
col(Notification.object_type) == identity.object_type
|
||||
)
|
||||
col(UserNotification.notification).has(col(Notification.object_type) == identity.object_type)
|
||||
)
|
||||
if identity.category is not None:
|
||||
query = base_query.where(
|
||||
col(UserNotification.notification).has(
|
||||
col(Notification.category) == identity.category
|
||||
)
|
||||
col(UserNotification.notification).has(col(Notification.category) == identity.category)
|
||||
)
|
||||
result.update({n.notification_id: n for n in await session.exec(query)})
|
||||
return list(result.values())
|
||||
@@ -134,7 +128,6 @@ async def mark_notifications_as_read(
|
||||
for user_notification in user_notifications:
|
||||
user_notification.is_read = True
|
||||
|
||||
assert current_user.id
|
||||
await server.send_event(
|
||||
current_user.id,
|
||||
ChatEvent(
|
||||
|
||||
@@ -91,9 +91,7 @@ class Bot:
|
||||
if reply:
|
||||
await self._send_reply(user, channel, reply, session)
|
||||
|
||||
async def _send_message(
|
||||
self, channel: ChatChannel, content: str, session: AsyncSession
|
||||
) -> None:
|
||||
async def _send_message(self, channel: ChatChannel, content: str, session: AsyncSession) -> None:
|
||||
bot = await session.get(User, self.bot_user_id)
|
||||
if bot is None:
|
||||
return
|
||||
@@ -101,7 +99,6 @@ class Bot:
|
||||
if channel_id is None:
|
||||
return
|
||||
|
||||
assert bot.id is not None
|
||||
msg = ChatMessage(
|
||||
channel_id=channel_id,
|
||||
content=content,
|
||||
@@ -115,9 +112,7 @@ class Bot:
|
||||
resp = await ChatMessageResp.from_db(msg, session, bot)
|
||||
await server.send_message_to_channel(resp)
|
||||
|
||||
async def _ensure_pm_channel(
|
||||
self, user: User, session: AsyncSession
|
||||
) -> ChatChannel | None:
|
||||
async def _ensure_pm_channel(self, user: User, session: AsyncSession) -> ChatChannel | None:
|
||||
user_id = user.id
|
||||
if user_id is None:
|
||||
return None
|
||||
@@ -160,9 +155,7 @@ bot = Bot()
|
||||
|
||||
|
||||
@bot.command("help")
|
||||
async def _help(
|
||||
user: User, args: list[str], _session: AsyncSession, channel: ChatChannel
|
||||
) -> str:
|
||||
async def _help(user: User, args: list[str], _session: AsyncSession, channel: ChatChannel) -> str:
|
||||
cmds = sorted(bot._handlers.keys())
|
||||
if args:
|
||||
target = args[0].lower()
|
||||
@@ -175,9 +168,7 @@ async def _help(
|
||||
|
||||
|
||||
@bot.command("roll")
|
||||
def _roll(
|
||||
user: User, args: list[str], _session: AsyncSession, channel: ChatChannel
|
||||
) -> str:
|
||||
def _roll(user: User, args: list[str], _session: AsyncSession, channel: ChatChannel) -> str:
|
||||
if len(args) > 0 and args[0].isdigit():
|
||||
r = random.randint(1, int(args[0]))
|
||||
else:
|
||||
@@ -186,13 +177,9 @@ def _roll(
|
||||
|
||||
|
||||
@bot.command("stats")
|
||||
async def _stats(
|
||||
user: User, args: list[str], session: AsyncSession, channel: ChatChannel
|
||||
) -> str:
|
||||
async def _stats(user: User, args: list[str], session: AsyncSession, channel: ChatChannel) -> str:
|
||||
if len(args) >= 1:
|
||||
target_user = (
|
||||
await session.exec(select(User).where(User.username == args[0]))
|
||||
).first()
|
||||
target_user = (await session.exec(select(User).where(User.username == args[0]))).first()
|
||||
if not target_user:
|
||||
return f"User '{args[0]}' not found."
|
||||
else:
|
||||
@@ -202,14 +189,8 @@ async def _stats(
|
||||
if len(args) >= 2:
|
||||
gamemode = GameMode.parse(args[1].upper())
|
||||
if gamemode is None:
|
||||
subquery = (
|
||||
select(func.max(Score.id))
|
||||
.where(Score.user_id == target_user.id)
|
||||
.scalar_subquery()
|
||||
)
|
||||
last_score = (
|
||||
await session.exec(select(Score).where(Score.id == subquery))
|
||||
).first()
|
||||
subquery = select(func.max(Score.id)).where(Score.user_id == target_user.id).scalar_subquery()
|
||||
last_score = (await session.exec(select(Score).where(Score.id == subquery))).first()
|
||||
if last_score is not None:
|
||||
gamemode = last_score.gamemode
|
||||
else:
|
||||
@@ -295,9 +276,7 @@ async def _mp_host(
|
||||
return "Usage: !mp host <username>"
|
||||
|
||||
username = args[0]
|
||||
user_id = (
|
||||
await session.exec(select(User.id).where(User.username == username))
|
||||
).first()
|
||||
user_id = (await session.exec(select(User.id).where(User.username == username))).first()
|
||||
if not user_id:
|
||||
return f"User '{username}' not found."
|
||||
|
||||
@@ -362,24 +341,18 @@ async def _mp_team(
|
||||
if team is None:
|
||||
return "Invalid team colour. Use 'red' or 'blue'."
|
||||
|
||||
user_id = (
|
||||
await session.exec(select(User.id).where(User.username == username))
|
||||
).first()
|
||||
user_id = (await session.exec(select(User.id).where(User.username == username))).first()
|
||||
if not user_id:
|
||||
return f"User '{username}' not found."
|
||||
user_client = MultiplayerHubs.get_client_by_id(str(user_id))
|
||||
if not user_client:
|
||||
return f"User '{username}' is not in the room."
|
||||
if (
|
||||
user_client.user_id != signalr_client.user_id
|
||||
and room.room.host.user_id != signalr_client.user_id
|
||||
):
|
||||
assert room.room.host
|
||||
if user_client.user_id != signalr_client.user_id and room.room.host.user_id != signalr_client.user_id:
|
||||
return "You are not allowed to change other users' teams."
|
||||
|
||||
try:
|
||||
await MultiplayerHubs.SendMatchRequest(
|
||||
user_client, ChangeTeamRequest(team_id=team)
|
||||
)
|
||||
await MultiplayerHubs.SendMatchRequest(user_client, ChangeTeamRequest(team_id=team))
|
||||
return ""
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
@@ -414,9 +387,7 @@ async def _mp_kick(
|
||||
return "Usage: !mp kick <username>"
|
||||
|
||||
username = args[0]
|
||||
user_id = (
|
||||
await session.exec(select(User.id).where(User.username == username))
|
||||
).first()
|
||||
user_id = (await session.exec(select(User.id).where(User.username == username))).first()
|
||||
if not user_id:
|
||||
return f"User '{username}' not found."
|
||||
|
||||
@@ -456,10 +427,7 @@ async def _mp_map(
|
||||
try:
|
||||
beatmap = await Beatmap.get_or_fetch(session, await get_fetcher(), bid=map_id)
|
||||
if beatmap.mode != GameMode.OSU and playmode and playmode != beatmap.mode:
|
||||
return (
|
||||
f"Cannot convert to {playmode.value}. "
|
||||
f"Original mode is {beatmap.mode.value}."
|
||||
)
|
||||
return f"Cannot convert to {playmode.value}. Original mode is {beatmap.mode.value}."
|
||||
except HTTPError:
|
||||
return "Beatmap not found"
|
||||
|
||||
@@ -530,9 +498,7 @@ async def _mp_mods(
|
||||
if freestyle:
|
||||
item.allowed_mods = []
|
||||
elif freemod:
|
||||
item.allowed_mods = get_available_mods(
|
||||
current_item.ruleset_id, required_mods
|
||||
)
|
||||
item.allowed_mods = get_available_mods(current_item.ruleset_id, required_mods)
|
||||
else:
|
||||
item.allowed_mods = allowed_mods
|
||||
item.required_mods = required_mods
|
||||
@@ -601,14 +567,9 @@ async def _score(
|
||||
include_fail: bool = False,
|
||||
gamemode: GameMode | None = None,
|
||||
) -> str:
|
||||
q = (
|
||||
select(Score)
|
||||
.where(Score.user_id == user_id)
|
||||
.order_by(col(Score.id).desc())
|
||||
.options(joinedload(Score.beatmap))
|
||||
)
|
||||
q = select(Score).where(Score.user_id == user_id).order_by(col(Score.id).desc()).options(joinedload(Score.beatmap))
|
||||
if not include_fail:
|
||||
q = q.where(Score.passed.is_(True))
|
||||
q = q.where(col(Score.passed).is_(True))
|
||||
if gamemode is not None:
|
||||
q = q.where(Score.gamemode == gamemode)
|
||||
|
||||
@@ -619,17 +580,13 @@ async def _score(
|
||||
result = f"""{score.beatmap.beatmapset.title} [{score.beatmap.version}] ({score.gamemode.name.lower()})
|
||||
Played at {score.started_at}
|
||||
{score.pp:.2f}pp {score.accuracy:.2%} {",".join(mod_to_save(score.mods))} {score.rank.name.upper()}
|
||||
Great: {score.n300}, Good: {score.n100}, Meh: {score.n50}, Miss: {score.nmiss}""" # noqa: E501
|
||||
Great: {score.n300}, Good: {score.n100}, Meh: {score.n50}, Miss: {score.nmiss}"""
|
||||
if score.gamemode == GameMode.MANIA:
|
||||
keys = next(
|
||||
(mod["acronym"] for mod in score.mods if mod["acronym"].endswith("K")), None
|
||||
)
|
||||
keys = next((mod["acronym"] for mod in score.mods if mod["acronym"].endswith("K")), None)
|
||||
if keys is None:
|
||||
keys = f"{int(score.beatmap.cs)}K"
|
||||
p_d_g = f"{score.ngeki / score.n300:.2f}:1" if score.n300 > 0 else "inf:1"
|
||||
result += (
|
||||
f"\nKeys: {keys}, Perfect: {score.ngeki}, Ok: {score.nkatu}, P/G: {p_d_g}"
|
||||
)
|
||||
result += f"\nKeys: {keys}, Perfect: {score.ngeki}, Ok: {score.nkatu}, P/G: {p_d_g}"
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@@ -38,27 +38,18 @@ class UpdateResponse(BaseModel):
|
||||
)
|
||||
async def get_update(
|
||||
session: Database,
|
||||
history_since: int | None = Query(
|
||||
None, description="获取自此禁言 ID 之后的禁言记录"
|
||||
),
|
||||
history_since: int | None = Query(None, description="获取自此禁言 ID 之后的禁言记录"),
|
||||
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
|
||||
includes: list[str] = Query(
|
||||
["presence", "silences"], alias="includes[]", description="要包含的更新类型"
|
||||
),
|
||||
includes: list[str] = Query(["presence", "silences"], alias="includes[]", description="要包含的更新类型"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
resp = UpdateResponse()
|
||||
if "presence" in includes:
|
||||
assert current_user.id
|
||||
channel_ids = server.get_user_joined_channel(current_user.id)
|
||||
for channel_id in channel_ids:
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
|
||||
)
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
|
||||
if db_channel:
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_type = db_channel.type
|
||||
@@ -69,34 +60,20 @@ async def get_update(
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel_id, [])
|
||||
if channel_type != ChannelType.PUBLIC
|
||||
else None,
|
||||
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
|
||||
)
|
||||
)
|
||||
if "silences" in includes:
|
||||
if history_since:
|
||||
silences = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(col(SilenceUser.id) > history_since)
|
||||
)
|
||||
).all()
|
||||
resp.silences.extend(
|
||||
[UserSilenceResp.from_db(silence) for silence in silences]
|
||||
)
|
||||
silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.id) > history_since))).all()
|
||||
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
|
||||
elif since:
|
||||
msg = await session.get(ChatMessage, since)
|
||||
if msg:
|
||||
silences = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(
|
||||
col(SilenceUser.banned_at) > msg.timestamp
|
||||
)
|
||||
)
|
||||
await session.exec(select(SilenceUser).where(col(SilenceUser.banned_at) > msg.timestamp))
|
||||
).all()
|
||||
resp.silences.extend(
|
||||
[UserSilenceResp.from_db(silence) for silence in silences]
|
||||
)
|
||||
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
|
||||
return resp
|
||||
|
||||
|
||||
@@ -115,15 +92,9 @@ async def join_channel(
|
||||
):
|
||||
# 使用明确的查询避免延迟加载
|
||||
if channel.isdigit():
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
@@ -145,15 +116,9 @@ async def leave_channel(
|
||||
):
|
||||
# 使用明确的查询避免延迟加载
|
||||
if channel.isdigit():
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
@@ -173,27 +138,20 @@ async def get_channel_list(
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
channels = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC)
|
||||
)
|
||||
).all()
|
||||
channels = (await session.exec(select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC))).all()
|
||||
results = []
|
||||
for channel in channels:
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_id = channel.channel_id
|
||||
channel_type = channel.type
|
||||
|
||||
assert channel_id is not None
|
||||
results.append(
|
||||
await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel_id, [])
|
||||
if channel_type != ChannelType.PUBLIC
|
||||
else None,
|
||||
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
|
||||
)
|
||||
)
|
||||
return results
|
||||
@@ -219,15 +177,9 @@ async def get_channel(
|
||||
):
|
||||
# 使用明确的查询避免延迟加载
|
||||
if channel.isdigit():
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
@@ -237,8 +189,6 @@ async def get_channel(
|
||||
channel_type = db_channel.type
|
||||
channel_name = db_channel.name
|
||||
|
||||
assert channel_id is not None
|
||||
|
||||
users = []
|
||||
if channel_type == ChannelType.PM:
|
||||
user_ids = channel_name.split("_")[1:]
|
||||
@@ -259,9 +209,7 @@ async def get_channel(
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel_id, [])
|
||||
if channel_type != ChannelType.PUBLIC
|
||||
else None,
|
||||
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -284,9 +232,7 @@ class CreateChannelReq(BaseModel):
|
||||
raise ValueError("target_id must be set for PM channels")
|
||||
else:
|
||||
if self.target_ids is None or self.channel is None or self.message is None:
|
||||
raise ValueError(
|
||||
"target_ids, channel, and message must be set for ANNOUNCE channels"
|
||||
)
|
||||
raise ValueError("target_ids, channel, and message must be set for ANNOUNCE channels")
|
||||
return self
|
||||
|
||||
|
||||
@@ -312,24 +258,20 @@ async def create_channel(
|
||||
raise HTTPException(status_code=403, detail=block)
|
||||
|
||||
channel = await ChatChannel.get_pm_channel(
|
||||
current_user.id, # pyright: ignore[reportArgumentType]
|
||||
current_user.id,
|
||||
req.target_id, # pyright: ignore[reportArgumentType]
|
||||
session,
|
||||
)
|
||||
channel_name = f"pm_{current_user.id}_{req.target_id}"
|
||||
else:
|
||||
channel_name = req.channel.name if req.channel else "Unnamed Channel"
|
||||
result = await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.name == channel_name)
|
||||
)
|
||||
result = await session.exec(select(ChatChannel).where(ChatChannel.name == channel_name))
|
||||
channel = result.first()
|
||||
|
||||
if channel is None:
|
||||
channel = ChatChannel(
|
||||
name=channel_name,
|
||||
description=req.channel.description
|
||||
if req.channel
|
||||
else "Private message channel",
|
||||
description=req.channel.description if req.channel else "Private message channel",
|
||||
type=ChannelType.PM if req.type == "PM" else ChannelType.ANNOUNCE,
|
||||
)
|
||||
session.add(channel)
|
||||
@@ -340,16 +282,13 @@ async def create_channel(
|
||||
await session.refresh(target) # pyright: ignore[reportPossiblyUnboundVariable]
|
||||
await server.batch_join_channel([target, current_user], channel, session) # pyright: ignore[reportPossiblyUnboundVariable]
|
||||
else:
|
||||
target_users = await session.exec(
|
||||
select(User).where(col(User.id).in_(req.target_ids or []))
|
||||
)
|
||||
target_users = await session.exec(select(User).where(col(User.id).in_(req.target_ids or [])))
|
||||
await server.batch_join_channel([*target_users, current_user], channel, session)
|
||||
|
||||
await server.join_channel(current_user, channel, session)
|
||||
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_id = channel.channel_id
|
||||
assert channel_id
|
||||
|
||||
return await ChatChannelResp.from_db(
|
||||
channel,
|
||||
|
||||
@@ -41,33 +41,19 @@ class KeepAliveResp(BaseModel):
|
||||
)
|
||||
async def keep_alive(
|
||||
session: Database,
|
||||
history_since: int | None = Query(
|
||||
None, description="获取自此禁言 ID 之后的禁言记录"
|
||||
),
|
||||
history_since: int | None = Query(None, description="获取自此禁言 ID 之后的禁言记录"),
|
||||
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
):
|
||||
resp = KeepAliveResp()
|
||||
if history_since:
|
||||
silences = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(col(SilenceUser.id) > history_since)
|
||||
)
|
||||
).all()
|
||||
silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.id) > history_since))).all()
|
||||
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
|
||||
elif since:
|
||||
msg = await session.get(ChatMessage, since)
|
||||
if msg:
|
||||
silences = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(
|
||||
col(SilenceUser.banned_at) > msg.timestamp
|
||||
)
|
||||
)
|
||||
).all()
|
||||
resp.silences.extend(
|
||||
[UserSilenceResp.from_db(silence) for silence in silences]
|
||||
)
|
||||
silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.banned_at) > msg.timestamp))).all()
|
||||
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
|
||||
|
||||
return resp
|
||||
|
||||
@@ -93,15 +79,9 @@ async def send_message(
|
||||
):
|
||||
# 使用明确的查询来获取 channel,避免延迟加载
|
||||
if channel.isdigit():
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
@@ -111,9 +91,6 @@ async def send_message(
|
||||
channel_type = db_channel.type
|
||||
channel_name = db_channel.name
|
||||
|
||||
assert channel_id is not None
|
||||
assert current_user.id
|
||||
|
||||
# 使用 Redis 消息系统发送消息 - 立即返回
|
||||
resp = await redis_message_system.send_message(
|
||||
channel_id=channel_id,
|
||||
@@ -125,9 +102,7 @@ async def send_message(
|
||||
|
||||
# 立即广播消息给所有客户端
|
||||
is_bot_command = req.message.startswith("!")
|
||||
await server.send_message_to_channel(
|
||||
resp, is_bot_command and channel_type == ChannelType.PUBLIC
|
||||
)
|
||||
await server.send_message_to_channel(resp, is_bot_command and channel_type == ChannelType.PUBLIC)
|
||||
|
||||
# 处理机器人命令
|
||||
if is_bot_command:
|
||||
@@ -147,14 +122,10 @@ async def send_message(
|
||||
if channel_type == ChannelType.PM:
|
||||
user_ids = channel_name.split("_")[1:]
|
||||
await server.new_private_notification(
|
||||
ChannelMessage.init(
|
||||
temp_msg, current_user, [int(u) for u in user_ids], channel_type
|
||||
)
|
||||
ChannelMessage.init(temp_msg, current_user, [int(u) for u in user_ids], channel_type)
|
||||
)
|
||||
elif channel_type == ChannelType.TEAM:
|
||||
await server.new_private_notification(
|
||||
ChannelMessageTeam.init(temp_msg, current_user)
|
||||
)
|
||||
await server.new_private_notification(ChannelMessageTeam.init(temp_msg, current_user))
|
||||
|
||||
return resp
|
||||
|
||||
@@ -176,22 +147,15 @@ async def get_message(
|
||||
):
|
||||
# 使用明确的查询获取 channel,避免延迟加载
|
||||
if channel.isdigit():
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_id = db_channel.channel_id
|
||||
assert channel_id is not None
|
||||
|
||||
# 使用 Redis 消息系统获取消息
|
||||
try:
|
||||
@@ -230,23 +194,15 @@ async def mark_as_read(
|
||||
):
|
||||
# 使用明确的查询获取 channel,避免延迟加载
|
||||
if channel.isdigit():
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
|
||||
# 立即提取需要的属性
|
||||
channel_id = db_channel.channel_id
|
||||
assert channel_id
|
||||
assert current_user.id
|
||||
await server.mark_as_read(channel_id, current_user.id, message)
|
||||
|
||||
|
||||
@@ -283,7 +239,6 @@ async def create_new_pm(
|
||||
if not is_can_pm:
|
||||
raise HTTPException(status_code=403, detail=block)
|
||||
|
||||
assert user_id
|
||||
channel = await ChatChannel.get_pm_channel(user_id, req.target_id, session)
|
||||
if channel is None:
|
||||
channel = ChatChannel(
|
||||
@@ -297,7 +252,6 @@ async def create_new_pm(
|
||||
await session.refresh(target)
|
||||
await session.refresh(current_user)
|
||||
|
||||
assert channel.channel_id
|
||||
await server.batch_join_channel([target, current_user], channel, session)
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel, session, current_user, redis, server.channels[channel.channel_id]
|
||||
|
||||
@@ -17,6 +17,7 @@ from app.log import logger
|
||||
from app.models.chat import ChatEvent
|
||||
from app.models.notification import NotificationDetail
|
||||
from app.service.subscribers.chat import ChatSubscriber
|
||||
from app.utils import bg_tasks
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect
|
||||
from fastapi.security import SecurityScopes
|
||||
@@ -37,20 +38,11 @@ class ChatServer:
|
||||
self.ChatSubscriber.chat_server = self
|
||||
self._subscribed = False
|
||||
|
||||
def _add_task(self, task):
|
||||
task = asyncio.create_task(task)
|
||||
self.tasks.add(task)
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
|
||||
def connect(self, user_id: int, client: WebSocket):
|
||||
self.connect_client[user_id] = client
|
||||
|
||||
def get_user_joined_channel(self, user_id: int) -> list[int]:
|
||||
return [
|
||||
channel_id
|
||||
for channel_id, users in self.channels.items()
|
||||
if user_id in users
|
||||
]
|
||||
return [channel_id for channel_id, users in self.channels.items() if user_id in users]
|
||||
|
||||
async def disconnect(self, user: User, session: AsyncSession):
|
||||
user_id = user.id
|
||||
@@ -61,9 +53,7 @@ class ChatServer:
|
||||
channel.remove(user_id)
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
|
||||
)
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))
|
||||
).first()
|
||||
if db_channel:
|
||||
await self.leave_channel(user, db_channel, session)
|
||||
@@ -93,11 +83,10 @@ class ChatServer:
|
||||
async def mark_as_read(self, channel_id: int, user_id: int, message_id: int):
|
||||
await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id)
|
||||
|
||||
async def send_message_to_channel(
|
||||
self, message: ChatMessageResp, is_bot_command: bool = False
|
||||
):
|
||||
async def send_message_to_channel(self, message: ChatMessageResp, is_bot_command: bool = False):
|
||||
logger.info(
|
||||
f"Sending message to channel {message.channel_id}, message_id: {message.message_id}, is_bot_command: {is_bot_command}"
|
||||
f"Sending message to channel {message.channel_id}, message_id: "
|
||||
f"{message.message_id}, is_bot_command: {is_bot_command}"
|
||||
)
|
||||
|
||||
event = ChatEvent(
|
||||
@@ -106,62 +95,44 @@ class ChatServer:
|
||||
)
|
||||
if is_bot_command:
|
||||
logger.info(f"Sending bot command to user {message.sender_id}")
|
||||
self._add_task(self.send_event(message.sender_id, event))
|
||||
bg_tasks.add_task(self.send_event, message.sender_id, event)
|
||||
else:
|
||||
# 总是广播消息,无论是临时ID还是真实ID
|
||||
logger.info(
|
||||
f"Broadcasting message to all users in channel {message.channel_id}"
|
||||
)
|
||||
self._add_task(
|
||||
self.broadcast(
|
||||
message.channel_id,
|
||||
event,
|
||||
)
|
||||
logger.info(f"Broadcasting message to all users in channel {message.channel_id}")
|
||||
bg_tasks.add_task(
|
||||
self.broadcast,
|
||||
message.channel_id,
|
||||
event,
|
||||
)
|
||||
|
||||
# 只有真实消息 ID(正数且非零)才进行标记已读和设置最后消息
|
||||
# Redis 消息系统生成的ID都是正数,所以这里应该都能正常处理
|
||||
if message.message_id and message.message_id > 0:
|
||||
await self.mark_as_read(
|
||||
message.channel_id, message.sender_id, message.message_id
|
||||
)
|
||||
await self.redis.set(
|
||||
f"chat:{message.channel_id}:last_msg", message.message_id
|
||||
)
|
||||
logger.info(
|
||||
f"Updated last message ID for channel {message.channel_id} to {message.message_id}"
|
||||
)
|
||||
await self.mark_as_read(message.channel_id, message.sender_id, message.message_id)
|
||||
await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
|
||||
logger.info(f"Updated last message ID for channel {message.channel_id} to {message.message_id}")
|
||||
else:
|
||||
logger.debug(
|
||||
f"Skipping last message update for message ID: {message.message_id}"
|
||||
)
|
||||
logger.debug(f"Skipping last message update for message ID: {message.message_id}")
|
||||
|
||||
async def batch_join_channel(
|
||||
self, users: list[User], channel: ChatChannel, session: AsyncSession
|
||||
):
|
||||
async def batch_join_channel(self, users: list[User], channel: ChatChannel, session: AsyncSession):
|
||||
channel_id = channel.channel_id
|
||||
assert channel_id is not None
|
||||
|
||||
not_joined = []
|
||||
|
||||
if channel_id not in self.channels:
|
||||
self.channels[channel_id] = []
|
||||
for user in users:
|
||||
assert user.id is not None
|
||||
if user.id not in self.channels[channel_id]:
|
||||
self.channels[channel_id].append(user.id)
|
||||
not_joined.append(user)
|
||||
|
||||
for user in not_joined:
|
||||
assert user.id is not None
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
user,
|
||||
self.redis,
|
||||
self.channels[channel_id]
|
||||
if channel.type != ChannelType.PUBLIC
|
||||
else None,
|
||||
self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
|
||||
)
|
||||
await self.send_event(
|
||||
user.id,
|
||||
@@ -171,13 +142,9 @@ class ChatServer:
|
||||
),
|
||||
)
|
||||
|
||||
async def join_channel(
|
||||
self, user: User, channel: ChatChannel, session: AsyncSession
|
||||
) -> ChatChannelResp:
|
||||
async def join_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> ChatChannelResp:
|
||||
user_id = user.id
|
||||
channel_id = channel.channel_id
|
||||
assert channel_id is not None
|
||||
assert user_id is not None
|
||||
|
||||
if channel_id not in self.channels:
|
||||
self.channels[channel_id] = []
|
||||
@@ -202,13 +169,9 @@ class ChatServer:
|
||||
|
||||
return channel_resp
|
||||
|
||||
async def leave_channel(
|
||||
self, user: User, channel: ChatChannel, session: AsyncSession
|
||||
) -> None:
|
||||
async def leave_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> None:
|
||||
user_id = user.id
|
||||
channel_id = channel.channel_id
|
||||
assert channel_id is not None
|
||||
assert user_id is not None
|
||||
|
||||
if channel_id in self.channels and user_id in self.channels[channel_id]:
|
||||
self.channels[channel_id].remove(user_id)
|
||||
@@ -221,9 +184,7 @@ class ChatServer:
|
||||
session,
|
||||
user,
|
||||
self.redis,
|
||||
self.channels.get(channel_id)
|
||||
if channel.type != ChannelType.PUBLIC
|
||||
else None,
|
||||
self.channels.get(channel_id) if channel.type != ChannelType.PUBLIC else None,
|
||||
)
|
||||
await self.send_event(
|
||||
user_id,
|
||||
@@ -236,11 +197,7 @@ class ChatServer:
|
||||
async def join_room_channel(self, channel_id: int, user_id: int):
|
||||
async with with_db() as session:
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
|
||||
)
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
|
||||
if db_channel is None:
|
||||
return
|
||||
|
||||
@@ -253,11 +210,7 @@ class ChatServer:
|
||||
async def leave_room_channel(self, channel_id: int, user_id: int):
|
||||
async with with_db() as session:
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
|
||||
)
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
|
||||
if db_channel is None:
|
||||
return
|
||||
|
||||
@@ -270,13 +223,7 @@ class ChatServer:
|
||||
async def new_private_notification(self, detail: NotificationDetail):
|
||||
async with with_db() as session:
|
||||
id = await insert_notification(session, detail)
|
||||
users = (
|
||||
await session.exec(
|
||||
select(UserNotification).where(
|
||||
UserNotification.notification_id == id
|
||||
)
|
||||
)
|
||||
).all()
|
||||
users = (await session.exec(select(UserNotification).where(UserNotification.notification_id == id))).all()
|
||||
for user_notification in users:
|
||||
data = user_notification.notification.model_dump()
|
||||
data["is_read"] = user_notification.is_read
|
||||
@@ -308,9 +255,7 @@ async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory):
|
||||
await ws.close(code=1000)
|
||||
break
|
||||
except WebSocketDisconnect as e:
|
||||
logger.info(
|
||||
f"[NotificationServer] Client {user_id} disconnected: {e.code}, {e.reason}"
|
||||
)
|
||||
logger.info(f"[NotificationServer] Client {user_id} disconnected: {e.code}, {e.reason}")
|
||||
except RuntimeError as e:
|
||||
if "disconnect message" in str(e):
|
||||
logger.info(f"[NotificationServer] Client {user_id} closed the connection.")
|
||||
@@ -332,11 +277,7 @@ async def chat_websocket(
|
||||
|
||||
async for session in factory():
|
||||
token = authorization[7:]
|
||||
if (
|
||||
user := await get_current_user(
|
||||
session, SecurityScopes(scopes=["chat.read"]), token_pw=token
|
||||
)
|
||||
) is None:
|
||||
if (user := await get_current_user(session, SecurityScopes(scopes=["chat.read"]), token_pw=token)) is None:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
@@ -346,12 +287,9 @@ async def chat_websocket(
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
user_id = user.id
|
||||
assert user_id
|
||||
server.connect(user_id, websocket)
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))).first()
|
||||
if db_channel is not None:
|
||||
await server.join_channel(user, db_channel, session)
|
||||
|
||||
|
||||
@@ -2,22 +2,20 @@
|
||||
密码重置管理接口
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from redis.asyncio import Redis
|
||||
from __future__ import annotations
|
||||
|
||||
from app.dependencies.database import get_redis
|
||||
from app.service.password_reset_service import password_reset_service
|
||||
from app.log import logger
|
||||
from app.service.password_reset_service import password_reset_service
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
from redis.asyncio import Redis
|
||||
|
||||
router = APIRouter(prefix="/admin/password-reset", tags=["密码重置管理"])
|
||||
|
||||
|
||||
@router.get(
|
||||
"/status/{email}",
|
||||
name="查询重置状态",
|
||||
description="查询指定邮箱的密码重置状态"
|
||||
)
|
||||
@router.get("/status/{email}", name="查询重置状态", description="查询指定邮箱的密码重置状态")
|
||||
async def get_password_reset_status(
|
||||
email: str,
|
||||
redis: Redis = Depends(get_redis),
|
||||
@@ -25,28 +23,16 @@ async def get_password_reset_status(
|
||||
"""查询密码重置状态"""
|
||||
try:
|
||||
info = await password_reset_service.get_reset_code_info(email, redis)
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"success": True,
|
||||
"data": info
|
||||
}
|
||||
)
|
||||
return JSONResponse(status_code=200, content={"success": True, "data": info})
|
||||
except Exception as e:
|
||||
logger.error(f"[Admin] Failed to get password reset status for {email}: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"error": "获取状态失败"
|
||||
}
|
||||
)
|
||||
return JSONResponse(status_code=500, content={"success": False, "error": "获取状态失败"})
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/cleanup/{email}",
|
||||
name="清理重置数据",
|
||||
description="强制清理指定邮箱的密码重置数据"
|
||||
description="强制清理指定邮箱的密码重置数据",
|
||||
)
|
||||
async def force_cleanup_reset(
|
||||
email: str,
|
||||
@@ -55,38 +41,23 @@ async def force_cleanup_reset(
|
||||
"""强制清理密码重置数据"""
|
||||
try:
|
||||
success = await password_reset_service.force_cleanup_user_reset(email, redis)
|
||||
|
||||
|
||||
if success:
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"success": True,
|
||||
"message": f"已清理邮箱 {email} 的重置数据"
|
||||
}
|
||||
content={"success": True, "message": f"已清理邮箱 {email} 的重置数据"},
|
||||
)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"error": "清理失败"
|
||||
}
|
||||
)
|
||||
return JSONResponse(status_code=500, content={"success": False, "error": "清理失败"})
|
||||
except Exception as e:
|
||||
logger.error(f"[Admin] Failed to cleanup password reset for {email}: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"error": "清理操作失败"
|
||||
}
|
||||
)
|
||||
return JSONResponse(status_code=500, content={"success": False, "error": "清理操作失败"})
|
||||
|
||||
|
||||
@router.post(
|
||||
"/cleanup/expired",
|
||||
name="清理过期验证码",
|
||||
description="清理所有过期的密码重置验证码"
|
||||
description="清理所有过期的密码重置验证码",
|
||||
)
|
||||
async def cleanup_expired_codes(
|
||||
redis: Redis = Depends(get_redis),
|
||||
@@ -99,25 +70,15 @@ async def cleanup_expired_codes(
|
||||
content={
|
||||
"success": True,
|
||||
"message": f"已清理 {count} 个过期的验证码",
|
||||
"cleaned_count": count
|
||||
}
|
||||
"cleaned_count": count,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Admin] Failed to cleanup expired codes: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"error": "清理操作失败"
|
||||
}
|
||||
)
|
||||
return JSONResponse(status_code=500, content={"success": False, "error": "清理操作失败"})
|
||||
|
||||
|
||||
@router.get(
|
||||
"/stats",
|
||||
name="重置统计",
|
||||
description="获取密码重置的统计信息"
|
||||
)
|
||||
@router.get("/stats", name="重置统计", description="获取密码重置的统计信息")
|
||||
async def get_reset_statistics(
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
@@ -126,53 +87,42 @@ async def get_reset_statistics(
|
||||
# 获取所有重置相关的键
|
||||
reset_keys = await redis.keys("password_reset:code:*")
|
||||
rate_limit_keys = await redis.keys("password_reset:rate_limit:*")
|
||||
|
||||
|
||||
active_resets = 0
|
||||
used_resets = 0
|
||||
active_rate_limits = 0
|
||||
|
||||
|
||||
# 统计活跃重置
|
||||
for key in reset_keys:
|
||||
data_str = await redis.get(key)
|
||||
if data_str:
|
||||
try:
|
||||
import json
|
||||
|
||||
data = json.loads(data_str)
|
||||
if data.get("used", False):
|
||||
used_resets += 1
|
||||
else:
|
||||
active_resets += 1
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# 统计频率限制
|
||||
for key in rate_limit_keys:
|
||||
ttl = await redis.ttl(key)
|
||||
if ttl > 0:
|
||||
active_rate_limits += 1
|
||||
|
||||
|
||||
stats = {
|
||||
"total_reset_codes": len(reset_keys),
|
||||
"active_resets": active_resets,
|
||||
"used_resets": used_resets,
|
||||
"active_rate_limits": active_rate_limits,
|
||||
"total_rate_limit_keys": len(rate_limit_keys)
|
||||
"total_rate_limit_keys": len(rate_limit_keys),
|
||||
}
|
||||
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"success": True,
|
||||
"data": stats
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return JSONResponse(status_code=200, content={"success": True, "data": stats})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Admin] Failed to get reset statistics: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"error": "获取统计信息失败"
|
||||
}
|
||||
)
|
||||
return JSONResponse(status_code=500, content={"success": False, "error": "获取统计信息失败"})
|
||||
|
||||
@@ -26,7 +26,7 @@ async def create_oauth_app(
|
||||
redirect_uris: list[str] = Body(..., description="允许的重定向 URI 列表"),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
result = await session.execute( # pyright: ignore[reportDeprecated]
|
||||
result = await session.execute(
|
||||
text(
|
||||
"SELECT AUTO_INCREMENT FROM information_schema.TABLES "
|
||||
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'oauth_clients'"
|
||||
@@ -84,9 +84,7 @@ async def get_user_oauth_apps(
|
||||
session: Database,
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
oauth_apps = await session.exec(
|
||||
select(OAuthClient).where(OAuthClient.owner_id == current_user.id)
|
||||
)
|
||||
oauth_apps = await session.exec(select(OAuthClient).where(OAuthClient.owner_id == current_user.id))
|
||||
return [
|
||||
{
|
||||
"name": app.name,
|
||||
@@ -113,13 +111,9 @@ async def delete_oauth_app(
|
||||
if not oauth_client:
|
||||
raise HTTPException(status_code=404, detail="OAuth app not found")
|
||||
if oauth_client.owner_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Forbidden: Not the owner of this app"
|
||||
)
|
||||
raise HTTPException(status_code=403, detail="Forbidden: Not the owner of this app")
|
||||
|
||||
tokens = await session.exec(
|
||||
select(OAuthToken).where(OAuthToken.client_id == client_id)
|
||||
)
|
||||
tokens = await session.exec(select(OAuthToken).where(OAuthToken.client_id == client_id))
|
||||
for token in tokens:
|
||||
await session.delete(token)
|
||||
|
||||
@@ -144,9 +138,7 @@ async def update_oauth_app(
|
||||
if not oauth_client:
|
||||
raise HTTPException(status_code=404, detail="OAuth app not found")
|
||||
if oauth_client.owner_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Forbidden: Not the owner of this app"
|
||||
)
|
||||
raise HTTPException(status_code=403, detail="Forbidden: Not the owner of this app")
|
||||
|
||||
oauth_client.name = name
|
||||
oauth_client.description = description
|
||||
@@ -176,14 +168,10 @@ async def refresh_secret(
|
||||
if not oauth_client:
|
||||
raise HTTPException(status_code=404, detail="OAuth app not found")
|
||||
if oauth_client.owner_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Forbidden: Not the owner of this app"
|
||||
)
|
||||
raise HTTPException(status_code=403, detail="Forbidden: Not the owner of this app")
|
||||
|
||||
oauth_client.client_secret = secrets.token_hex()
|
||||
tokens = await session.exec(
|
||||
select(OAuthToken).where(OAuthToken.client_id == client_id)
|
||||
)
|
||||
tokens = await session.exec(select(OAuthToken).where(OAuthToken.client_id == client_id))
|
||||
for token in tokens:
|
||||
await session.delete(token)
|
||||
|
||||
@@ -215,9 +203,7 @@ async def generate_oauth_code(
|
||||
raise HTTPException(status_code=404, detail="OAuth app not found")
|
||||
|
||||
if redirect_uri not in client.redirect_uris:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Redirect URI not allowed for this client"
|
||||
)
|
||||
raise HTTPException(status_code=403, detail="Redirect URI not allowed for this client")
|
||||
|
||||
code = secrets.token_urlsafe(80)
|
||||
await redis.hset( # pyright: ignore[reportGeneralTypeIssues]
|
||||
|
||||
@@ -50,12 +50,8 @@ async def check_user_relationship(
|
||||
)
|
||||
).first()
|
||||
|
||||
is_followed = bool(
|
||||
target_relationship and target_relationship.type == RelationshipType.FOLLOW
|
||||
)
|
||||
is_following = bool(
|
||||
my_relationship and my_relationship.type == RelationshipType.FOLLOW
|
||||
)
|
||||
is_followed = bool(target_relationship and target_relationship.type == RelationshipType.FOLLOW)
|
||||
is_following = bool(my_relationship and my_relationship.type == RelationshipType.FOLLOW)
|
||||
|
||||
return CheckResponse(
|
||||
is_followed=is_followed,
|
||||
|
||||
@@ -40,16 +40,13 @@ async def create_team(
|
||||
支持的图片格式: PNG、JPEG、GIF
|
||||
"""
|
||||
user_id = current_user.id
|
||||
assert user_id
|
||||
if (await current_user.awaitable_attrs.team_membership) is not None:
|
||||
raise HTTPException(status_code=403, detail="You are already in a team")
|
||||
|
||||
is_existed = (await session.exec(select(exists()).where(Team.name == name))).first()
|
||||
if is_existed:
|
||||
raise HTTPException(status_code=409, detail="Name already exists")
|
||||
is_existed = (
|
||||
await session.exec(select(exists()).where(Team.short_name == short_name))
|
||||
).first()
|
||||
is_existed = (await session.exec(select(exists()).where(Team.short_name == short_name))).first()
|
||||
if is_existed:
|
||||
raise HTTPException(status_code=409, detail="Short name already exists")
|
||||
|
||||
@@ -101,7 +98,6 @@ async def update_team(
|
||||
"""
|
||||
team = await session.get(Team, team_id)
|
||||
user_id = current_user.id
|
||||
assert user_id
|
||||
if not team:
|
||||
raise HTTPException(status_code=404, detail="Team not found")
|
||||
if team.leader_id != user_id:
|
||||
@@ -110,9 +106,7 @@ async def update_team(
|
||||
is_existed = (await session.exec(select(exists()).where(Team.name == name))).first()
|
||||
if is_existed:
|
||||
raise HTTPException(status_code=409, detail="Name already exists")
|
||||
is_existed = (
|
||||
await session.exec(select(exists()).where(Team.short_name == short_name))
|
||||
).first()
|
||||
is_existed = (await session.exec(select(exists()).where(Team.short_name == short_name))).first()
|
||||
if is_existed:
|
||||
raise HTTPException(status_code=409, detail="Short name already exists")
|
||||
|
||||
@@ -132,20 +126,12 @@ async def update_team(
|
||||
team.cover_url = await storage.get_file_url(storage_path)
|
||||
|
||||
if leader_id is not None:
|
||||
if not (
|
||||
await session.exec(select(exists()).where(User.id == leader_id))
|
||||
).first():
|
||||
if not (await session.exec(select(exists()).where(User.id == leader_id))).first():
|
||||
raise HTTPException(status_code=404, detail="Leader not found")
|
||||
if not (
|
||||
await session.exec(
|
||||
select(TeamMember).where(
|
||||
TeamMember.user_id == leader_id, TeamMember.team_id == team.id
|
||||
)
|
||||
)
|
||||
await session.exec(select(TeamMember).where(TeamMember.user_id == leader_id, TeamMember.team_id == team.id))
|
||||
).first():
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Leader is not a member of the team"
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="Leader is not a member of the team")
|
||||
team.leader_id = leader_id
|
||||
|
||||
await session.commit()
|
||||
@@ -166,9 +152,7 @@ async def delete_team(
|
||||
if team.leader_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="You are not the team leader")
|
||||
|
||||
team_members = await session.exec(
|
||||
select(TeamMember).where(TeamMember.team_id == team_id)
|
||||
)
|
||||
team_members = await session.exec(select(TeamMember).where(TeamMember.team_id == team_id))
|
||||
for member in team_members:
|
||||
await session.delete(member)
|
||||
|
||||
@@ -186,15 +170,10 @@ async def get_team(
|
||||
session: Database,
|
||||
team_id: int = Path(..., description="战队 ID"),
|
||||
):
|
||||
members = (
|
||||
await session.exec(select(TeamMember).where(TeamMember.team_id == team_id))
|
||||
).all()
|
||||
members = (await session.exec(select(TeamMember).where(TeamMember.team_id == team_id))).all()
|
||||
return TeamQueryResp(
|
||||
team=members[0].team,
|
||||
members=[
|
||||
await UserResp.from_db(m.user, session, include=BASE_INCLUDES)
|
||||
for m in members
|
||||
],
|
||||
members=[await UserResp.from_db(m.user, session, include=BASE_INCLUDES) for m in members],
|
||||
)
|
||||
|
||||
|
||||
@@ -213,15 +192,11 @@ async def request_join_team(
|
||||
|
||||
if (
|
||||
await session.exec(
|
||||
select(exists()).where(
|
||||
TeamRequest.team_id == team_id, TeamRequest.user_id == current_user.id
|
||||
)
|
||||
select(exists()).where(TeamRequest.team_id == team_id, TeamRequest.user_id == current_user.id)
|
||||
)
|
||||
).first():
|
||||
raise HTTPException(status_code=409, detail="Join request already exists")
|
||||
team_request = TeamRequest(
|
||||
user_id=current_user.id, team_id=team_id, requested_at=datetime.now(UTC)
|
||||
)
|
||||
team_request = TeamRequest(user_id=current_user.id, team_id=team_id, requested_at=datetime.now(UTC))
|
||||
session.add(team_request)
|
||||
await session.commit()
|
||||
await session.refresh(team_request)
|
||||
@@ -229,9 +204,7 @@ async def request_join_team(
|
||||
|
||||
|
||||
@router.post("/team/{team_id}/{user_id}/request", name="接受加入请求", status_code=204)
|
||||
@router.delete(
|
||||
"/team/{team_id}/{user_id}/request", name="拒绝加入请求", status_code=204
|
||||
)
|
||||
@router.delete("/team/{team_id}/{user_id}/request", name="拒绝加入请求", status_code=204)
|
||||
async def handle_request(
|
||||
req: Request,
|
||||
session: Database,
|
||||
@@ -247,11 +220,7 @@ async def handle_request(
|
||||
raise HTTPException(status_code=403, detail="You are not the team leader")
|
||||
|
||||
team_request = (
|
||||
await session.exec(
|
||||
select(TeamRequest).where(
|
||||
TeamRequest.team_id == team_id, TeamRequest.user_id == user_id
|
||||
)
|
||||
)
|
||||
await session.exec(select(TeamRequest).where(TeamRequest.team_id == team_id, TeamRequest.user_id == user_id))
|
||||
).first()
|
||||
if not team_request:
|
||||
raise HTTPException(status_code=404, detail="Join request not found")
|
||||
@@ -261,16 +230,10 @@ async def handle_request(
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
if req.method == "POST":
|
||||
if (
|
||||
await session.exec(select(exists()).where(TeamMember.user_id == user_id))
|
||||
).first():
|
||||
raise HTTPException(
|
||||
status_code=409, detail="User is already a member of the team"
|
||||
)
|
||||
if (await session.exec(select(exists()).where(TeamMember.user_id == user_id))).first():
|
||||
raise HTTPException(status_code=409, detail="User is already a member of the team")
|
||||
|
||||
session.add(
|
||||
TeamMember(user_id=user_id, team_id=team_id, joined_at=datetime.now(UTC))
|
||||
)
|
||||
session.add(TeamMember(user_id=user_id, team_id=team_id, joined_at=datetime.now(UTC)))
|
||||
|
||||
await server.new_private_notification(TeamApplicationAccept.init(team_request))
|
||||
else:
|
||||
@@ -294,19 +257,13 @@ async def kick_member(
|
||||
raise HTTPException(status_code=403, detail="You are not the team leader")
|
||||
|
||||
team_member = (
|
||||
await session.exec(
|
||||
select(TeamMember).where(
|
||||
TeamMember.team_id == team_id, TeamMember.user_id == user_id
|
||||
)
|
||||
)
|
||||
await session.exec(select(TeamMember).where(TeamMember.team_id == team_id, TeamMember.user_id == user_id))
|
||||
).first()
|
||||
if not team_member:
|
||||
raise HTTPException(status_code=404, detail="User is not a member of the team")
|
||||
|
||||
if team.leader_id == current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You cannot leave because you are the team leader"
|
||||
)
|
||||
raise HTTPException(status_code=403, detail="You cannot leave because you are the team leader")
|
||||
|
||||
await session.delete(team_member)
|
||||
await session.commit()
|
||||
|
||||
@@ -35,10 +35,7 @@ async def user_rename(
|
||||
返回:
|
||||
- 成功: None
|
||||
"""
|
||||
assert current_user is not None
|
||||
samename_user = (
|
||||
await session.exec(select(User).where(User.username == new_name))
|
||||
).first()
|
||||
samename_user = (await session.exec(select(User).where(User.username == new_name))).first()
|
||||
if samename_user:
|
||||
raise HTTPException(409, "Username Exisits")
|
||||
errors = validate_username(new_name)
|
||||
|
||||
@@ -106,9 +106,7 @@ class V1Beatmap(AllStrModel):
|
||||
await session.exec(
|
||||
select(func.count())
|
||||
.select_from(FavouriteBeatmapset)
|
||||
.where(
|
||||
FavouriteBeatmapset.beatmapset_id == db_beatmap.beatmapset.id
|
||||
)
|
||||
.where(FavouriteBeatmapset.beatmapset_id == db_beatmap.beatmapset.id)
|
||||
)
|
||||
).one(),
|
||||
rating=0, # TODO
|
||||
@@ -154,12 +152,8 @@ async def get_beatmaps(
|
||||
beatmapset_id: int | None = Query(None, alias="s", description="谱面集 ID"),
|
||||
beatmap_id: int | None = Query(None, alias="b", description="谱面 ID"),
|
||||
user: str | None = Query(None, alias="u", description="谱师"),
|
||||
type: Literal["string", "id"] | None = Query(
|
||||
None, description="用户类型:string 用户名称 / id 用户 ID"
|
||||
),
|
||||
ruleset_id: int | None = Query(
|
||||
None, alias="m", description="Ruleset ID", ge=0, le=3
|
||||
), # TODO
|
||||
type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"),
|
||||
ruleset_id: int | None = Query(None, alias="m", description="Ruleset ID", ge=0, le=3), # TODO
|
||||
convert: bool = Query(False, alias="a", description="转谱"), # TODO
|
||||
checksum: str | None = Query(None, alias="h", description="谱面文件 MD5"),
|
||||
limit: int = Query(500, ge=1, le=500, description="返回结果数量限制"),
|
||||
@@ -181,11 +175,7 @@ async def get_beatmaps(
|
||||
else:
|
||||
beatmaps = beatmapset.beatmaps
|
||||
elif user is not None:
|
||||
where = (
|
||||
Beatmapset.user_id == user
|
||||
if type == "id" or user.isdigit()
|
||||
else Beatmapset.creator == user
|
||||
)
|
||||
where = Beatmapset.user_id == user if type == "id" or user.isdigit() else Beatmapset.creator == user
|
||||
beatmapsets = (await session.exec(select(Beatmapset).where(where))).all()
|
||||
for beatmapset in beatmapsets:
|
||||
if len(beatmaps) >= limit:
|
||||
@@ -193,11 +183,7 @@ async def get_beatmaps(
|
||||
beatmaps.extend(beatmapset.beatmaps)
|
||||
elif since is not None:
|
||||
beatmapsets = (
|
||||
await session.exec(
|
||||
select(Beatmapset)
|
||||
.where(col(Beatmapset.ranked_date) > since)
|
||||
.limit(limit)
|
||||
)
|
||||
await session.exec(select(Beatmapset).where(col(Beatmapset.ranked_date) > since).limit(limit))
|
||||
).all()
|
||||
for beatmapset in beatmapsets:
|
||||
if len(beatmaps) >= limit:
|
||||
@@ -214,11 +200,7 @@ async def get_beatmaps(
|
||||
redis,
|
||||
fetcher,
|
||||
)
|
||||
results.append(
|
||||
await V1Beatmap.from_db(
|
||||
session, beatmap, attrs.aim_difficulty, attrs.speed_difficulty
|
||||
)
|
||||
)
|
||||
results.append(await V1Beatmap.from_db(session, beatmap, attrs.aim_difficulty, attrs.speed_difficulty))
|
||||
continue
|
||||
except Exception:
|
||||
...
|
||||
|
||||
@@ -41,9 +41,7 @@ async def download_replay(
|
||||
ge=0,
|
||||
),
|
||||
score_id: int | None = Query(None, alias="s", description="成绩 ID"),
|
||||
type: Literal["string", "id"] | None = Query(
|
||||
None, description="用户类型:string 用户名称 / id 用户 ID"
|
||||
),
|
||||
type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"),
|
||||
mods: int = Query(0, description="成绩的 MOD"),
|
||||
storage_service: StorageService = Depends(get_storage_service),
|
||||
):
|
||||
@@ -58,13 +56,9 @@ async def download_replay(
|
||||
await session.exec(
|
||||
select(Score).where(
|
||||
Score.beatmap_id == beatmap,
|
||||
Score.user_id == user
|
||||
if type == "id" or user.isdigit()
|
||||
else col(Score.user).has(username=user),
|
||||
Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user),
|
||||
Score.mods == mods_,
|
||||
Score.gamemode == GameMode.from_int_extra(ruleset_id)
|
||||
if ruleset_id is not None
|
||||
else True,
|
||||
Score.gamemode == GameMode.from_int_extra(ruleset_id) if ruleset_id is not None else True,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
@@ -73,10 +67,7 @@ async def download_replay(
|
||||
except KeyError:
|
||||
raise HTTPException(status_code=400, detail="Invalid request")
|
||||
|
||||
filepath = (
|
||||
f"replays/{score_record.id}_{score_record.beatmap_id}"
|
||||
f"_{score_record.user_id}_lazer_replay.osr"
|
||||
)
|
||||
filepath = f"replays/{score_record.id}_{score_record.beatmap_id}_{score_record.user_id}_lazer_replay.osr"
|
||||
if not await storage_service.is_exists(filepath):
|
||||
raise HTTPException(status_code=404, detail="Replay file not found")
|
||||
|
||||
@@ -100,6 +91,4 @@ async def download_replay(
|
||||
await session.commit()
|
||||
|
||||
data = await storage_service.read_file(filepath)
|
||||
return ReplayModel(
|
||||
content=base64.b64encode(data).decode("utf-8"), encoding="base64"
|
||||
)
|
||||
return ReplayModel(content=base64.b64encode(data).decode("utf-8"), encoding="base64")
|
||||
|
||||
@@ -8,9 +8,7 @@ from app.dependencies.user import v1_authorize
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel, field_serializer
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/v1", dependencies=[Depends(v1_authorize)], tags=["V1 API"]
|
||||
)
|
||||
router = APIRouter(prefix="/api/v1", dependencies=[Depends(v1_authorize)], tags=["V1 API"])
|
||||
|
||||
|
||||
class AllStrModel(BaseModel):
|
||||
|
||||
@@ -70,9 +70,7 @@ async def get_user_best(
|
||||
session: Database,
|
||||
user: str = Query(..., alias="u", description="用户"),
|
||||
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
|
||||
type: Literal["string", "id"] | None = Query(
|
||||
None, description="用户类型:string 用户名称 / id 用户 ID"
|
||||
),
|
||||
type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"),
|
||||
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
|
||||
):
|
||||
try:
|
||||
@@ -80,9 +78,7 @@ async def get_user_best(
|
||||
await session.exec(
|
||||
select(Score)
|
||||
.where(
|
||||
Score.user_id == user
|
||||
if type == "id" or user.isdigit()
|
||||
else col(Score.user).has(username=user),
|
||||
Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user),
|
||||
Score.gamemode == GameMode.from_int_extra(ruleset_id),
|
||||
exists().where(col(PPBestScore.score_id) == Score.id),
|
||||
)
|
||||
@@ -106,9 +102,7 @@ async def get_user_recent(
|
||||
session: Database,
|
||||
user: str = Query(..., alias="u", description="用户"),
|
||||
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
|
||||
type: Literal["string", "id"] | None = Query(
|
||||
None, description="用户类型:string 用户名称 / id 用户 ID"
|
||||
),
|
||||
type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"),
|
||||
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
|
||||
):
|
||||
try:
|
||||
@@ -116,9 +110,7 @@ async def get_user_recent(
|
||||
await session.exec(
|
||||
select(Score)
|
||||
.where(
|
||||
Score.user_id == user
|
||||
if type == "id" or user.isdigit()
|
||||
else col(Score.user).has(username=user),
|
||||
Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user),
|
||||
Score.gamemode == GameMode.from_int_extra(ruleset_id),
|
||||
Score.ended_at > datetime.now(UTC) - timedelta(hours=24),
|
||||
)
|
||||
@@ -143,9 +135,7 @@ async def get_scores(
|
||||
user: str | None = Query(None, alias="u", description="用户"),
|
||||
beatmap_id: int = Query(alias="b", description="谱面 ID"),
|
||||
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
|
||||
type: Literal["string", "id"] | None = Query(
|
||||
None, description="用户类型:string 用户名称 / id 用户 ID"
|
||||
),
|
||||
type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"),
|
||||
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
|
||||
mods: int = Query(0, description="成绩的 MOD"),
|
||||
):
|
||||
@@ -157,9 +147,7 @@ async def get_scores(
|
||||
.where(
|
||||
Score.gamemode == GameMode.from_int_extra(ruleset_id),
|
||||
Score.beatmap_id == beatmap_id,
|
||||
Score.user_id == user
|
||||
if type == "id" or user.isdigit()
|
||||
else col(Score.user).has(username=user),
|
||||
Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user),
|
||||
)
|
||||
.options(joinedload(Score.beatmap))
|
||||
.order_by(col(Score.classic_total_score).desc())
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
@@ -13,7 +12,7 @@ from app.service.user_cache_service import get_user_cache_service
|
||||
|
||||
from .router import AllStrModel, router
|
||||
|
||||
from fastapi import HTTPException, Query
|
||||
from fastapi import BackgroundTasks, HTTPException, Query
|
||||
from sqlmodel import select
|
||||
|
||||
|
||||
@@ -49,9 +48,7 @@ class V1User(AllStrModel):
|
||||
return f"v1_user:{user_id}"
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls, session: Database, db_user: User, ruleset: GameMode | None = None
|
||||
) -> "V1User":
|
||||
async def from_db(cls, session: Database, db_user: User, ruleset: GameMode | None = None) -> "V1User":
|
||||
# 确保 user_id 不为 None
|
||||
if db_user.id is None:
|
||||
raise ValueError("User ID cannot be None")
|
||||
@@ -63,9 +60,7 @@ class V1User(AllStrModel):
|
||||
current_statistics = i
|
||||
break
|
||||
if current_statistics:
|
||||
statistics = await UserStatisticsResp.from_db(
|
||||
current_statistics, session, db_user.country_code
|
||||
)
|
||||
statistics = await UserStatisticsResp.from_db(current_statistics, session, db_user.country_code)
|
||||
else:
|
||||
statistics = None
|
||||
return cls(
|
||||
@@ -78,9 +73,7 @@ class V1User(AllStrModel):
|
||||
playcount=statistics.play_count if statistics else 0,
|
||||
ranked_score=statistics.ranked_score if statistics else 0,
|
||||
total_score=statistics.total_score if statistics else 0,
|
||||
pp_rank=statistics.global_rank
|
||||
if statistics and statistics.global_rank
|
||||
else 0,
|
||||
pp_rank=statistics.global_rank if statistics and statistics.global_rank else 0,
|
||||
level=current_statistics.level_current if current_statistics else 0,
|
||||
pp_raw=statistics.pp if statistics else 0.0,
|
||||
accuracy=statistics.hit_accuracy if statistics else 0,
|
||||
@@ -91,9 +84,7 @@ class V1User(AllStrModel):
|
||||
count_rank_a=current_statistics.grade_a if current_statistics else 0,
|
||||
country=db_user.country_code,
|
||||
total_seconds_played=statistics.play_time if statistics else 0,
|
||||
pp_country_rank=statistics.country_rank
|
||||
if statistics and statistics.country_rank
|
||||
else 0,
|
||||
pp_country_rank=statistics.country_rank if statistics and statistics.country_rank else 0,
|
||||
events=[], # TODO
|
||||
)
|
||||
|
||||
@@ -106,14 +97,11 @@ class V1User(AllStrModel):
|
||||
)
|
||||
async def get_user(
|
||||
session: Database,
|
||||
background_tasks: BackgroundTasks,
|
||||
user: str = Query(..., alias="u", description="用户"),
|
||||
ruleset_id: int | None = Query(None, alias="m", description="Ruleset ID", ge=0),
|
||||
type: Literal["string", "id"] | None = Query(
|
||||
None, description="用户类型:string 用户名称 / id 用户 ID"
|
||||
),
|
||||
event_days: int = Query(
|
||||
default=1, ge=1, le=31, description="从现在起所有事件的最大天数"
|
||||
),
|
||||
type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"),
|
||||
event_days: int = Query(default=1, ge=1, le=31, description="从现在起所有事件的最大天数"),
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_user_cache_service(redis)
|
||||
@@ -131,9 +119,7 @@ async def get_user(
|
||||
if is_id_query:
|
||||
try:
|
||||
user_id_for_cache = int(user)
|
||||
cached_v1_user = await cache_service.get_v1_user_from_cache(
|
||||
user_id_for_cache, ruleset
|
||||
)
|
||||
cached_v1_user = await cache_service.get_v1_user_from_cache(user_id_for_cache, ruleset)
|
||||
if cached_v1_user:
|
||||
return [V1User(**cached_v1_user)]
|
||||
except (ValueError, TypeError):
|
||||
@@ -158,9 +144,7 @@ async def get_user(
|
||||
# 异步缓存结果(如果有用户ID)
|
||||
if db_user.id is not None:
|
||||
user_data = v1_user.model_dump()
|
||||
asyncio.create_task(
|
||||
cache_service.cache_v1_user(user_data, db_user.id, ruleset)
|
||||
)
|
||||
background_tasks.add_task(cache_service.cache_v1_user, user_data, db_user.id, ruleset)
|
||||
|
||||
return [v1_user]
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401
|
||||
from . import ( # noqa: F401
|
||||
beatmap,
|
||||
beatmapset,
|
||||
me,
|
||||
|
||||
@@ -40,18 +40,13 @@ class BatchGetResp(BaseModel):
|
||||
tags=["谱面"],
|
||||
name="查询单个谱面",
|
||||
response_model=BeatmapResp,
|
||||
description=(
|
||||
"根据谱面 ID / MD5 / 文件名 查询单个谱面。"
|
||||
"至少提供 id / checksum / filename 之一。"
|
||||
),
|
||||
description=("根据谱面 ID / MD5 / 文件名 查询单个谱面。至少提供 id / checksum / filename 之一。"),
|
||||
)
|
||||
async def lookup_beatmap(
|
||||
db: Database,
|
||||
id: int | None = Query(default=None, alias="id", description="谱面 ID"),
|
||||
md5: str | None = Query(default=None, alias="checksum", description="谱面文件 MD5"),
|
||||
filename: str | None = Query(
|
||||
default=None, alias="filename", description="谱面文件名"
|
||||
),
|
||||
filename: str | None = Query(default=None, alias="filename", description="谱面文件名"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
@@ -96,43 +91,23 @@ async def get_beatmap(
|
||||
tags=["谱面"],
|
||||
name="批量获取谱面",
|
||||
response_model=BatchGetResp,
|
||||
description=(
|
||||
"批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。"
|
||||
"为空时按最近更新时间返回。"
|
||||
),
|
||||
description=("批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。为空时按最近更新时间返回。"),
|
||||
)
|
||||
async def batch_get_beatmaps(
|
||||
db: Database,
|
||||
beatmap_ids: list[int] = Query(
|
||||
alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"
|
||||
),
|
||||
beatmap_ids: list[int] = Query(alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
if not beatmap_ids:
|
||||
beatmaps = (
|
||||
await db.exec(
|
||||
select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
|
||||
)
|
||||
).all()
|
||||
beatmaps = (await db.exec(select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50))).all()
|
||||
else:
|
||||
beatmaps = list(
|
||||
(
|
||||
await db.exec(
|
||||
select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)).limit(50)
|
||||
)
|
||||
).all()
|
||||
)
|
||||
not_found_beatmaps = [
|
||||
bid for bid in beatmap_ids if bid not in [bm.id for bm in beatmaps]
|
||||
]
|
||||
beatmaps = list((await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)).limit(50))).all())
|
||||
not_found_beatmaps = [bid for bid in beatmap_ids if bid not in [bm.id for bm in beatmaps]]
|
||||
beatmaps.extend(
|
||||
beatmap
|
||||
for beatmap in await asyncio.gather(
|
||||
*[
|
||||
Beatmap.get_or_fetch(db, fetcher, bid=bid)
|
||||
for bid in not_found_beatmaps
|
||||
],
|
||||
*[Beatmap.get_or_fetch(db, fetcher, bid=bid) for bid in not_found_beatmaps],
|
||||
return_exceptions=True,
|
||||
)
|
||||
if isinstance(beatmap, Beatmap)
|
||||
@@ -140,12 +115,7 @@ async def batch_get_beatmaps(
|
||||
for beatmap in beatmaps:
|
||||
await db.refresh(beatmap)
|
||||
|
||||
return BatchGetResp(
|
||||
beatmaps=[
|
||||
await BeatmapResp.from_db(bm, session=db, user=current_user)
|
||||
for bm in beatmaps
|
||||
]
|
||||
)
|
||||
return BatchGetResp(beatmaps=[await BeatmapResp.from_db(bm, session=db, user=current_user) for bm in beatmaps])
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -163,12 +133,8 @@ async def get_beatmap_attributes(
|
||||
default_factory=list,
|
||||
description="Mods 列表;可为整型位掩码(单元素)或 JSON/简称",
|
||||
),
|
||||
ruleset: GameMode | None = Query(
|
||||
default=None, description="指定 ruleset;为空则使用谱面自身模式"
|
||||
),
|
||||
ruleset_id: int | None = Query(
|
||||
default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3
|
||||
),
|
||||
ruleset: GameMode | None = Query(default=None, description="指定 ruleset;为空则使用谱面自身模式"),
|
||||
ruleset_id: int | None = Query(default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3),
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
@@ -187,16 +153,11 @@ async def get_beatmap_attributes(
|
||||
if ruleset is None:
|
||||
beatmap_db = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
|
||||
ruleset = beatmap_db.mode
|
||||
key = (
|
||||
f"beatmap:{beatmap_id}:{ruleset}:"
|
||||
f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
|
||||
)
|
||||
key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
|
||||
if await redis.exists(key):
|
||||
return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
|
||||
try:
|
||||
return await calculate_beatmap_attributes(
|
||||
beatmap_id, ruleset, mods_, redis, fetcher
|
||||
)
|
||||
return await calculate_beatmap_attributes(beatmap_id, ruleset, mods_, redis, fetcher)
|
||||
except HTTPStatusError:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
@@ -35,9 +35,7 @@ from sqlmodel import exists, select
|
||||
async def _save_to_db(sets: SearchBeatmapsetsResp):
|
||||
async with with_db() as session:
|
||||
for s in sets.beatmapsets:
|
||||
if not (
|
||||
await session.exec(select(exists()).where(Beatmapset.id == s.id))
|
||||
).first():
|
||||
if not (await session.exec(select(exists()).where(Beatmapset.id == s.id))).first():
|
||||
await Beatmapset.from_resp(session, s)
|
||||
|
||||
|
||||
@@ -117,9 +115,7 @@ async def lookup_beatmapset(
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
|
||||
resp = await BeatmapsetResp.from_db(
|
||||
beatmap.beatmapset, session=db, user=current_user
|
||||
)
|
||||
resp = await BeatmapsetResp.from_db(beatmap.beatmapset, session=db, user=current_user)
|
||||
return resp
|
||||
|
||||
|
||||
@@ -138,9 +134,7 @@ async def get_beatmapset(
|
||||
):
|
||||
try:
|
||||
beatmapset = await Beatmapset.get_or_fetch(db, fetcher, beatmapset_id)
|
||||
return await BeatmapsetResp.from_db(
|
||||
beatmapset, session=db, include=["recent_favourites"], user=current_user
|
||||
)
|
||||
return await BeatmapsetResp.from_db(beatmapset, session=db, include=["recent_favourites"], user=current_user)
|
||||
except HTTPError:
|
||||
raise HTTPException(status_code=404, detail="Beatmapset not found")
|
||||
|
||||
@@ -165,9 +159,7 @@ async def download_beatmapset(
|
||||
country_code = geo_info.get("country_iso", "")
|
||||
|
||||
# 优先使用IP地理位置判断,如果获取失败则回退到用户账户的国家代码
|
||||
is_china = country_code == "CN" or (
|
||||
not country_code and current_user.country_code == "CN"
|
||||
)
|
||||
is_china = country_code == "CN" or (not country_code and current_user.country_code == "CN")
|
||||
|
||||
try:
|
||||
# 使用负载均衡服务获取下载URL
|
||||
@@ -179,13 +171,10 @@ async def download_beatmapset(
|
||||
# 如果负载均衡服务失败,回退到原有逻辑
|
||||
if is_china:
|
||||
return RedirectResponse(
|
||||
f"https://dl.sayobot.cn/beatmaps/download/"
|
||||
f"{'novideo' if no_video else 'full'}/{beatmapset_id}"
|
||||
f"https://dl.sayobot.cn/beatmaps/download/{'novideo' if no_video else 'full'}/{beatmapset_id}"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
f"https://api.nerinyan.moe/d/{beatmapset_id}?noVideo={no_video}"
|
||||
)
|
||||
return RedirectResponse(f"https://api.nerinyan.moe/d/{beatmapset_id}?noVideo={no_video}")
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -197,12 +186,9 @@ async def download_beatmapset(
|
||||
async def favourite_beatmapset(
|
||||
db: Database,
|
||||
beatmapset_id: int = Path(..., description="谱面集 ID"),
|
||||
action: Literal["favourite", "unfavourite"] = Form(
|
||||
description="操作类型:favourite 收藏 / unfavourite 取消收藏"
|
||||
),
|
||||
action: Literal["favourite", "unfavourite"] = Form(description="操作类型:favourite 收藏 / unfavourite 取消收藏"),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
assert current_user.id is not None
|
||||
existing_favourite = (
|
||||
await db.exec(
|
||||
select(FavouriteBeatmapset).where(
|
||||
@@ -212,15 +198,11 @@ async def favourite_beatmapset(
|
||||
)
|
||||
).first()
|
||||
|
||||
if (action == "favourite" and existing_favourite) or (
|
||||
action == "unfavourite" and not existing_favourite
|
||||
):
|
||||
if (action == "favourite" and existing_favourite) or (action == "unfavourite" and not existing_favourite):
|
||||
return
|
||||
|
||||
if action == "favourite":
|
||||
favourite = FavouriteBeatmapset(
|
||||
user_id=current_user.id, beatmapset_id=beatmapset_id
|
||||
)
|
||||
favourite = FavouriteBeatmapset(user_id=current_user.id, beatmapset_id=beatmapset_id)
|
||||
db.add(favourite)
|
||||
else:
|
||||
await db.delete(existing_favourite)
|
||||
|
||||
@@ -4,8 +4,8 @@ from app.database import User
|
||||
from app.database.lazer_user import ALL_INCLUDED
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies.database import Database
|
||||
from app.models.score import GameMode
|
||||
from app.models.api_me import APIMe
|
||||
from app.models.score import GameMode
|
||||
|
||||
from .router import router
|
||||
|
||||
|
||||
@@ -33,6 +33,4 @@ class BackgroundsResp(BaseModel):
|
||||
description="获取当前季节背景图列表。",
|
||||
)
|
||||
async def get_seasonal_backgrounds():
|
||||
return BackgroundsResp(
|
||||
backgrounds=[Background(url=url) for url in settings.seasonal_backgrounds]
|
||||
)
|
||||
return BackgroundsResp(backgrounds=[Background(url=url) for url in settings.seasonal_backgrounds])
|
||||
|
||||
@@ -12,7 +12,7 @@ from app.service.ranking_cache_service import get_ranking_cache_service
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Path, Query, Security
|
||||
from fastapi import BackgroundTasks, Path, Query, Security
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import col, select
|
||||
|
||||
@@ -38,6 +38,7 @@ class CountryResponse(BaseModel):
|
||||
)
|
||||
async def get_country_ranking(
|
||||
session: Database,
|
||||
background_tasks: BackgroundTasks,
|
||||
ruleset: GameMode = Path(..., description="指定 ruleset"),
|
||||
page: int = Query(1, ge=1, description="页码"), # TODO
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
@@ -51,9 +52,7 @@ async def get_country_ranking(
|
||||
|
||||
if cached_data:
|
||||
# 从缓存返回数据
|
||||
return CountryResponse(
|
||||
ranking=[CountryStatistics.model_validate(item) for item in cached_data]
|
||||
)
|
||||
return CountryResponse(ranking=[CountryStatistics.model_validate(item) for item in cached_data])
|
||||
|
||||
# 缓存未命中,从数据库查询
|
||||
response = CountryResponse(ranking=[])
|
||||
@@ -105,14 +104,15 @@ async def get_country_ranking(
|
||||
|
||||
# 异步缓存数据(不等待完成)
|
||||
cache_data = [item.model_dump() for item in current_page_data]
|
||||
cache_task = cache_service.cache_country_ranking(
|
||||
ruleset, cache_data, page, ttl=settings.ranking_cache_expire_minutes * 60
|
||||
)
|
||||
|
||||
# 创建后台任务来缓存数据
|
||||
import asyncio
|
||||
|
||||
asyncio.create_task(cache_task)
|
||||
background_tasks.add_task(
|
||||
cache_service.cache_country_ranking,
|
||||
ruleset,
|
||||
cache_data,
|
||||
page,
|
||||
ttl=settings.ranking_cache_expire_minutes * 60,
|
||||
)
|
||||
|
||||
# 返回当前页的结果
|
||||
response.ranking = current_page_data
|
||||
@@ -132,10 +132,9 @@ class TopUsersResponse(BaseModel):
|
||||
)
|
||||
async def get_user_ranking(
|
||||
session: Database,
|
||||
background_tasks: BackgroundTasks,
|
||||
ruleset: GameMode = Path(..., description="指定 ruleset"),
|
||||
type: Literal["performance", "score"] = Path(
|
||||
..., description="排名类型:performance 表现分 / score 计分成绩总分"
|
||||
),
|
||||
type: Literal["performance", "score"] = Path(..., description="排名类型:performance 表现分 / score 计分成绩总分"),
|
||||
country: str | None = Query(None, description="国家代码"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
@@ -149,9 +148,7 @@ async def get_user_ranking(
|
||||
|
||||
if cached_data:
|
||||
# 从缓存返回数据
|
||||
return TopUsersResponse(
|
||||
ranking=[UserStatisticsResp.model_validate(item) for item in cached_data]
|
||||
)
|
||||
return TopUsersResponse(ranking=[UserStatisticsResp.model_validate(item) for item in cached_data])
|
||||
|
||||
# 缓存未命中,从数据库查询
|
||||
wheres = [
|
||||
@@ -169,25 +166,22 @@ async def get_user_ranking(
|
||||
wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
|
||||
|
||||
statistics_list = await session.exec(
|
||||
select(UserStatistics)
|
||||
.where(*wheres)
|
||||
.order_by(order_by)
|
||||
.limit(50)
|
||||
.offset(50 * (page - 1))
|
||||
select(UserStatistics).where(*wheres).order_by(order_by).limit(50).offset(50 * (page - 1))
|
||||
)
|
||||
|
||||
# 转换为响应格式
|
||||
ranking_data = []
|
||||
for statistics in statistics_list:
|
||||
user_stats_resp = await UserStatisticsResp.from_db(
|
||||
statistics, session, None, include
|
||||
)
|
||||
user_stats_resp = await UserStatisticsResp.from_db(statistics, session, None, include)
|
||||
ranking_data.append(user_stats_resp)
|
||||
|
||||
# 异步缓存数据(不等待完成)
|
||||
# 使用配置文件中的TTL设置
|
||||
cache_data = [item.model_dump() for item in ranking_data]
|
||||
cache_task = cache_service.cache_ranking(
|
||||
# 创建后台任务来缓存数据
|
||||
|
||||
background_tasks.add_task(
|
||||
cache_service.cache_ranking,
|
||||
ruleset,
|
||||
type,
|
||||
cache_data,
|
||||
@@ -196,139 +190,134 @@ async def get_user_ranking(
|
||||
ttl=settings.ranking_cache_expire_minutes * 60,
|
||||
)
|
||||
|
||||
# 创建后台任务来缓存数据
|
||||
import asyncio
|
||||
|
||||
asyncio.create_task(cache_task)
|
||||
|
||||
resp = TopUsersResponse(ranking=ranking_data)
|
||||
return resp
|
||||
|
||||
|
||||
""" @router.post(
|
||||
"/rankings/cache/refresh",
|
||||
name="刷新排行榜缓存",
|
||||
description="手动刷新排行榜缓存(管理员功能)",
|
||||
tags=["排行榜", "管理"],
|
||||
)
|
||||
async def refresh_ranking_cache(
|
||||
session: Database,
|
||||
ruleset: GameMode | None = Query(None, description="指定要刷新的游戏模式,不指定则刷新所有"),
|
||||
type: Literal["performance", "score"] | None = Query(None, description="指定要刷新的排名类型,不指定则刷新所有"),
|
||||
country: str | None = Query(None, description="指定要刷新的国家,不指定则刷新所有"),
|
||||
include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
|
||||
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
if ruleset and type:
|
||||
# 刷新特定的用户排行榜
|
||||
await cache_service.refresh_ranking_cache(session, ruleset, type, country)
|
||||
message = f"Refreshed ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
|
||||
|
||||
# 如果请求刷新地区排行榜
|
||||
if include_country_ranking and not country: # 地区排行榜不依赖于国家参数
|
||||
await cache_service.refresh_country_ranking_cache(session, ruleset)
|
||||
message += f" and country ranking for {ruleset}"
|
||||
|
||||
return {"message": message}
|
||||
elif ruleset:
|
||||
# 刷新特定游戏模式的所有排行榜
|
||||
ranking_types: list[Literal["performance", "score"]] = ["performance", "score"]
|
||||
for ranking_type in ranking_types:
|
||||
await cache_service.refresh_ranking_cache(session, ruleset, ranking_type, country)
|
||||
|
||||
if include_country_ranking:
|
||||
await cache_service.refresh_country_ranking_cache(session, ruleset)
|
||||
|
||||
return {"message": f"Refreshed all ranking caches for {ruleset}"}
|
||||
else:
|
||||
# 刷新所有排行榜
|
||||
await cache_service.refresh_all_rankings(session)
|
||||
return {"message": "Refreshed all ranking caches"}
|
||||
# @router.post(
|
||||
# "/rankings/cache/refresh",
|
||||
# name="刷新排行榜缓存",
|
||||
# description="手动刷新排行榜缓存(管理员功能)",
|
||||
# tags=["排行榜", "管理"],
|
||||
# )
|
||||
# async def refresh_ranking_cache(
|
||||
# session: Database,
|
||||
# ruleset: GameMode | None = Query(None, description="指定要刷新的游戏模式,不指定则刷新所有"),
|
||||
# type: Literal["performance", "score"] | None = Query(None, description="指定要刷新的排名类型,不指定则刷新所有"),
|
||||
# country: str | None = Query(None, description="指定要刷新的国家,不指定则刷新所有"),
|
||||
# include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
|
||||
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
# ):
|
||||
# redis = get_redis()
|
||||
# cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
# if ruleset and type:
|
||||
# # 刷新特定的用户排行榜
|
||||
# await cache_service.refresh_ranking_cache(session, ruleset, type, country)
|
||||
# message = f"Refreshed ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
|
||||
|
||||
# # 如果请求刷新地区排行榜
|
||||
# if include_country_ranking and not country: # 地区排行榜不依赖于国家参数
|
||||
# await cache_service.refresh_country_ranking_cache(session, ruleset)
|
||||
# message += f" and country ranking for {ruleset}"
|
||||
|
||||
# return {"message": message}
|
||||
# elif ruleset:
|
||||
# # 刷新特定游戏模式的所有排行榜
|
||||
# ranking_types: list[Literal["performance", "score"]] = ["performance", "score"]
|
||||
# for ranking_type in ranking_types:
|
||||
# await cache_service.refresh_ranking_cache(session, ruleset, ranking_type, country)
|
||||
|
||||
# if include_country_ranking:
|
||||
# await cache_service.refresh_country_ranking_cache(session, ruleset)
|
||||
|
||||
# return {"message": f"Refreshed all ranking caches for {ruleset}"}
|
||||
# else:
|
||||
# # 刷新所有排行榜
|
||||
# await cache_service.refresh_all_rankings(session)
|
||||
# return {"message": "Refreshed all ranking caches"}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/rankings/{ruleset}/country/cache/refresh",
|
||||
name="刷新地区排行榜缓存",
|
||||
description="手动刷新地区排行榜缓存(管理员功能)",
|
||||
tags=["排行榜", "管理"],
|
||||
)
|
||||
async def refresh_country_ranking_cache(
|
||||
session: Database,
|
||||
ruleset: GameMode = Path(..., description="指定要刷新的游戏模式"),
|
||||
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
await cache_service.refresh_country_ranking_cache(session, ruleset)
|
||||
return {"message": f"Refreshed country ranking cache for {ruleset}"}
|
||||
# @router.post(
|
||||
# "/rankings/{ruleset}/country/cache/refresh",
|
||||
# name="刷新地区排行榜缓存",
|
||||
# description="手动刷新地区排行榜缓存(管理员功能)",
|
||||
# tags=["排行榜", "管理"],
|
||||
# )
|
||||
# async def refresh_country_ranking_cache(
|
||||
# session: Database,
|
||||
# ruleset: GameMode = Path(..., description="指定要刷新的游戏模式"),
|
||||
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
# ):
|
||||
# redis = get_redis()
|
||||
# cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
# await cache_service.refresh_country_ranking_cache(session, ruleset)
|
||||
# return {"message": f"Refreshed country ranking cache for {ruleset}"}
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/rankings/cache",
|
||||
name="清除排行榜缓存",
|
||||
description="清除排行榜缓存(管理员功能)",
|
||||
tags=["排行榜", "管理"],
|
||||
)
|
||||
async def clear_ranking_cache(
|
||||
ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
|
||||
type: Literal["performance", "score"] | None = Query(None, description="指定要清除的排名类型,不指定则清除所有"),
|
||||
country: str | None = Query(None, description="指定要清除的国家,不指定则清除所有"),
|
||||
include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
|
||||
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
await cache_service.invalidate_cache(ruleset, type, country, include_country_ranking)
|
||||
|
||||
if ruleset and type:
|
||||
message = f"Cleared ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
|
||||
if include_country_ranking:
|
||||
message += " and country ranking"
|
||||
return {"message": message}
|
||||
else:
|
||||
message = "Cleared all ranking caches"
|
||||
if include_country_ranking:
|
||||
message += " including country rankings"
|
||||
return {"message": message}
|
||||
# @router.delete(
|
||||
# "/rankings/cache",
|
||||
# name="清除排行榜缓存",
|
||||
# description="清除排行榜缓存(管理员功能)",
|
||||
# tags=["排行榜", "管理"],
|
||||
# )
|
||||
# async def clear_ranking_cache(
|
||||
# ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
|
||||
# type: Literal["performance", "score"] | None = Query(None, description="指定要清除的排名类型,不指定则清除所有"),
|
||||
# country: str | None = Query(None, description="指定要清除的国家,不指定则清除所有"),
|
||||
# include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
|
||||
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
# ):
|
||||
# redis = get_redis()
|
||||
# cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
# await cache_service.invalidate_cache(ruleset, type, country, include_country_ranking)
|
||||
|
||||
# if ruleset and type:
|
||||
# message = f"Cleared ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
|
||||
# if include_country_ranking:
|
||||
# message += " and country ranking"
|
||||
# return {"message": message}
|
||||
# else:
|
||||
# message = "Cleared all ranking caches"
|
||||
# if include_country_ranking:
|
||||
# message += " including country rankings"
|
||||
# return {"message": message}
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/rankings/{ruleset}/country/cache",
|
||||
name="清除地区排行榜缓存",
|
||||
description="清除地区排行榜缓存(管理员功能)",
|
||||
tags=["排行榜", "管理"],
|
||||
)
|
||||
async def clear_country_ranking_cache(
|
||||
ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
|
||||
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
await cache_service.invalidate_country_cache(ruleset)
|
||||
|
||||
if ruleset:
|
||||
return {"message": f"Cleared country ranking cache for {ruleset}"}
|
||||
else:
|
||||
return {"message": "Cleared all country ranking caches"}
|
||||
# @router.delete(
|
||||
# "/rankings/{ruleset}/country/cache",
|
||||
# name="清除地区排行榜缓存",
|
||||
# description="清除地区排行榜缓存(管理员功能)",
|
||||
# tags=["排行榜", "管理"],
|
||||
# )
|
||||
# async def clear_country_ranking_cache(
|
||||
# ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
|
||||
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
# ):
|
||||
# redis = get_redis()
|
||||
# cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
# await cache_service.invalidate_country_cache(ruleset)
|
||||
|
||||
# if ruleset:
|
||||
# return {"message": f"Cleared country ranking cache for {ruleset}"}
|
||||
# else:
|
||||
# return {"message": "Cleared all country ranking caches"}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rankings/cache/stats",
|
||||
name="获取排行榜缓存统计",
|
||||
description="获取排行榜缓存统计信息(管理员功能)",
|
||||
tags=["排行榜", "管理"],
|
||||
)
|
||||
async def get_ranking_cache_stats(
|
||||
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
stats = await cache_service.get_cache_stats()
|
||||
return stats """
|
||||
# @router.get(
|
||||
# "/rankings/cache/stats",
|
||||
# name="获取排行榜缓存统计",
|
||||
# description="获取排行榜缓存统计信息(管理员功能)",
|
||||
# tags=["排行榜", "管理"],
|
||||
# )
|
||||
# async def get_ranking_cache_stats(
|
||||
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
# ):
|
||||
# redis = get_redis()
|
||||
# cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
# stats = await cache_service.get_cache_stats()
|
||||
# return stats
|
||||
|
||||
@@ -30,11 +30,7 @@ async def get_relationship(
|
||||
request: Request,
|
||||
current_user: User = Security(get_current_user, scopes=["friends.read"]),
|
||||
):
|
||||
relationship_type = (
|
||||
RelationshipType.FOLLOW
|
||||
if request.url.path.endswith("/friends")
|
||||
else RelationshipType.BLOCK
|
||||
)
|
||||
relationship_type = RelationshipType.FOLLOW if request.url.path.endswith("/friends") else RelationshipType.BLOCK
|
||||
relationships = await db.exec(
|
||||
select(Relationship).where(
|
||||
Relationship.user_id == current_user.id,
|
||||
@@ -71,12 +67,7 @@ async def add_relationship(
|
||||
target: int = Query(description="目标用户 ID"),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
assert current_user.id is not None
|
||||
relationship_type = (
|
||||
RelationshipType.FOLLOW
|
||||
if request.url.path.endswith("/friends")
|
||||
else RelationshipType.BLOCK
|
||||
)
|
||||
relationship_type = RelationshipType.FOLLOW if request.url.path.endswith("/friends") else RelationshipType.BLOCK
|
||||
if target == current_user.id:
|
||||
raise HTTPException(422, "Cannot add relationship to yourself")
|
||||
relationship = (
|
||||
@@ -120,11 +111,8 @@ async def add_relationship(
|
||||
Relationship.target_id == target,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
assert relationship, "Relationship should exist after commit"
|
||||
return AddFriendResp(
|
||||
user_relation=await RelationshipResp.from_db(db, relationship)
|
||||
)
|
||||
).one()
|
||||
return AddFriendResp(user_relation=await RelationshipResp.from_db(db, relationship))
|
||||
|
||||
|
||||
@router.delete(
|
||||
@@ -145,11 +133,7 @@ async def delete_relationship(
|
||||
target: int = Path(..., description="目标用户 ID"),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
relationship_type = (
|
||||
RelationshipType.BLOCK
|
||||
if "/blocks/" in request.url.path
|
||||
else RelationshipType.FOLLOW
|
||||
)
|
||||
relationship_type = RelationshipType.BLOCK if "/blocks/" in request.url.path else RelationshipType.FOLLOW
|
||||
relationship = (
|
||||
await db.exec(
|
||||
select(Relationship).where(
|
||||
|
||||
@@ -39,17 +39,11 @@ async def get_all_rooms(
|
||||
db: Database,
|
||||
mode: Literal["open", "ended", "participated", "owned", None] = Query(
|
||||
default="open",
|
||||
description=(
|
||||
"房间模式:open 当前开放 / ended 已经结束 / "
|
||||
"participated 参与过 / owned 自己创建的房间"
|
||||
),
|
||||
description=("房间模式:open 当前开放 / ended 已经结束 / participated 参与过 / owned 自己创建的房间"),
|
||||
),
|
||||
category: RoomCategory = Query(
|
||||
RoomCategory.NORMAL,
|
||||
description=(
|
||||
"房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间"
|
||||
" / DAILY_CHALLENGE 每日挑战"
|
||||
),
|
||||
description=("房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战"),
|
||||
),
|
||||
status: RoomStatus | None = Query(None, description="房间状态(可选)"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
@@ -60,10 +54,7 @@ async def get_all_rooms(
|
||||
if status is not None:
|
||||
where_clauses.append(col(Room.status) == status)
|
||||
if mode == "open":
|
||||
where_clauses.append(
|
||||
(col(Room.ends_at).is_(None))
|
||||
| (col(Room.ends_at) > now.replace(tzinfo=UTC))
|
||||
)
|
||||
where_clauses.append((col(Room.ends_at).is_(None)) | (col(Room.ends_at) > now.replace(tzinfo=UTC)))
|
||||
if category == RoomCategory.REALTIME:
|
||||
where_clauses.append(col(Room.id).in_(MultiplayerHubs.rooms.keys()))
|
||||
if mode == "participated":
|
||||
@@ -76,10 +67,7 @@ async def get_all_rooms(
|
||||
if mode == "owned":
|
||||
where_clauses.append(col(Room.host_id) == current_user.id)
|
||||
if mode == "ended":
|
||||
where_clauses.append(
|
||||
(col(Room.ends_at).is_not(None))
|
||||
& (col(Room.ends_at) < now.replace(tzinfo=UTC))
|
||||
)
|
||||
where_clauses.append((col(Room.ends_at).is_not(None)) & (col(Room.ends_at) < now.replace(tzinfo=UTC)))
|
||||
|
||||
db_rooms = (
|
||||
(
|
||||
@@ -97,11 +85,7 @@ async def get_all_rooms(
|
||||
resp = await RoomResp.from_db(room, db)
|
||||
if category == RoomCategory.REALTIME:
|
||||
mp_room = MultiplayerHubs.rooms.get(room.id)
|
||||
resp.has_password = (
|
||||
bool(mp_room.room.settings.password.strip())
|
||||
if mp_room is not None
|
||||
else False
|
||||
)
|
||||
resp.has_password = bool(mp_room.room.settings.password.strip()) if mp_room is not None else False
|
||||
resp.category = RoomCategory.NORMAL
|
||||
resp_list.append(resp)
|
||||
|
||||
@@ -115,9 +99,7 @@ class APICreatedRoom(RoomResp):
|
||||
error: str = ""
|
||||
|
||||
|
||||
async def _participate_room(
|
||||
room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis
|
||||
):
|
||||
async def _participate_room(room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis):
|
||||
participated_user = (
|
||||
await session.exec(
|
||||
select(RoomParticipatedUser).where(
|
||||
@@ -154,7 +136,6 @@ async def create_room(
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
assert current_user.id is not None
|
||||
user_id = current_user.id
|
||||
db_room = await create_playlist_room_from_api(db, room, user_id)
|
||||
await _participate_room(db_room.id, user_id, db_room, db, redis)
|
||||
@@ -177,10 +158,7 @@ async def get_room(
|
||||
room_id: int = Path(..., description="房间 ID"),
|
||||
category: str = Query(
|
||||
default="",
|
||||
description=(
|
||||
"房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间"
|
||||
" / DAILY_CHALLENGE 每日挑战 (可选)"
|
||||
),
|
||||
description=("房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战 (可选)"),
|
||||
),
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
@@ -188,9 +166,7 @@ async def get_room(
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||
if db_room is None:
|
||||
raise HTTPException(404, "Room not found")
|
||||
resp = await RoomResp.from_db(
|
||||
db_room, include=["current_user_score"], session=db, user=current_user
|
||||
)
|
||||
resp = await RoomResp.from_db(db_room, include=["current_user_score"], session=db, user=current_user)
|
||||
return resp
|
||||
|
||||
|
||||
@@ -400,7 +376,6 @@ async def get_room_events(
|
||||
for score in scores:
|
||||
user_ids.add(score.user_id)
|
||||
beatmap_ids.add(score.beatmap_id)
|
||||
assert event.id is not None
|
||||
first_event_id = min(first_event_id, event.id)
|
||||
last_event_id = max(last_event_id, event.id)
|
||||
|
||||
@@ -416,16 +391,12 @@ async def get_room_events(
|
||||
users = await db.exec(select(User).where(col(User.id).in_(user_ids)))
|
||||
user_resps = [await UserResp.from_db(user, db) for user in users]
|
||||
beatmaps = await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)))
|
||||
beatmap_resps = [
|
||||
await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps
|
||||
]
|
||||
beatmap_resps = [await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps]
|
||||
beatmapset_resps = {}
|
||||
for beatmap_resp in beatmap_resps:
|
||||
beatmapset_resps[beatmap_resp.beatmapset_id] = beatmap_resp.beatmapset
|
||||
|
||||
playlist_items_resps = [
|
||||
await PlaylistResp.from_db(item) for item in playlist_items.values()
|
||||
]
|
||||
playlist_items_resps = [await PlaylistResp.from_db(item) for item in playlist_items.values()]
|
||||
|
||||
return RoomEvents(
|
||||
beatmaps=beatmap_resps,
|
||||
|
||||
@@ -104,11 +104,7 @@ async def submit_score(
|
||||
if not info.passed:
|
||||
info.rank = Rank.F
|
||||
score_token = (
|
||||
await db.exec(
|
||||
select(ScoreToken)
|
||||
.options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType]
|
||||
.where(ScoreToken.id == token)
|
||||
)
|
||||
await db.exec(select(ScoreToken).options(joinedload(ScoreToken.beatmap)).where(ScoreToken.id == token))
|
||||
).first()
|
||||
if not score_token or score_token.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="Score token not found")
|
||||
@@ -138,10 +134,7 @@ async def submit_score(
|
||||
except HTTPError:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
has_pp = db_beatmap.beatmap_status.has_pp() | settings.enable_all_beatmap_pp
|
||||
has_leaderboard = (
|
||||
db_beatmap.beatmap_status.has_leaderboard()
|
||||
| settings.enable_all_beatmap_leaderboard
|
||||
)
|
||||
has_leaderboard = db_beatmap.beatmap_status.has_leaderboard() | settings.enable_all_beatmap_leaderboard
|
||||
beatmap_length = db_beatmap.total_length
|
||||
score = await process_score(
|
||||
current_user,
|
||||
@@ -167,21 +160,11 @@ async def submit_score(
|
||||
has_pp,
|
||||
has_leaderboard,
|
||||
)
|
||||
score = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.options(joinedload(Score.user)) # pyright: ignore[reportArgumentType]
|
||||
.where(Score.id == score_id)
|
||||
)
|
||||
).first()
|
||||
assert score is not None
|
||||
score = (await db.exec(select(Score).options(joinedload(Score.user)).where(Score.id == score_id))).one()
|
||||
|
||||
resp = await ScoreResp.from_db(db, score)
|
||||
total_users = (await db.exec(select(func.count()).select_from(User))).first()
|
||||
assert total_users is not None
|
||||
if resp.rank_global is not None and resp.rank_global <= min(
|
||||
math.ceil(float(total_users) * 0.01), 50
|
||||
):
|
||||
total_users = (await db.exec(select(func.count()).select_from(User))).one()
|
||||
if resp.rank_global is not None and resp.rank_global <= min(math.ceil(float(total_users) * 0.01), 50):
|
||||
rank_event = Event(
|
||||
created_at=datetime.now(UTC),
|
||||
type=EventType.RANK,
|
||||
@@ -207,9 +190,7 @@ async def submit_score(
|
||||
score_gamemode = score.gamemode
|
||||
|
||||
if user_id is not None:
|
||||
background_task.add_task(
|
||||
_refresh_user_cache_background, redis, user_id, score_gamemode
|
||||
)
|
||||
background_task.add_task(_refresh_user_cache_background, redis, user_id, score_gamemode)
|
||||
background_task.add_task(process_user_achievement, resp.id)
|
||||
return resp
|
||||
|
||||
@@ -225,9 +206,7 @@ async def _refresh_user_cache_background(redis: Redis, user_id: int, mode: GameM
|
||||
# 创建独立的数据库会话
|
||||
session = AsyncSession(engine)
|
||||
try:
|
||||
await user_cache_service.refresh_user_cache_on_score_submit(
|
||||
session, user_id, mode
|
||||
)
|
||||
await user_cache_service.refresh_user_cache_on_score_submit(session, user_id, mode)
|
||||
finally:
|
||||
await session.close()
|
||||
except Exception as e:
|
||||
@@ -280,22 +259,16 @@ async def get_beatmap_scores(
|
||||
beatmap_id: int = Path(description="谱面 ID"),
|
||||
mode: GameMode = Query(description="指定 auleset"),
|
||||
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
|
||||
mods: list[str] = Query(
|
||||
default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)"
|
||||
),
|
||||
mods: list[str] = Query(default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)"),
|
||||
type: LeaderboardType = Query(
|
||||
LeaderboardType.GLOBAL,
|
||||
description=(
|
||||
"排行榜类型:GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队"
|
||||
),
|
||||
description=("排行榜类型:GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队"),
|
||||
),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
limit: int = Query(50, ge=1, le=200, description="返回条数 (1-200)"),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="this server only contains lazer scores"
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="this server only contains lazer scores")
|
||||
|
||||
all_scores, user_score, count = await get_leaderboard(
|
||||
db,
|
||||
@@ -310,9 +283,7 @@ async def get_beatmap_scores(
|
||||
user_score_resp = await ScoreResp.from_db(db, user_score) if user_score else None
|
||||
resp = BeatmapScores(
|
||||
scores=[await ScoreResp.from_db(db, score) for score in all_scores],
|
||||
user_score=BeatmapUserScore(
|
||||
score=user_score_resp, position=user_score_resp.rank_global or 0
|
||||
)
|
||||
user_score=BeatmapUserScore(score=user_score_resp, position=user_score_resp.rank_global or 0)
|
||||
if user_score_resp
|
||||
else None,
|
||||
score_count=count,
|
||||
@@ -342,9 +313,7 @@ async def get_user_beatmap_score(
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="This server only contains non-legacy scores"
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="This server only contains non-legacy scores")
|
||||
user_score = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
@@ -386,9 +355,7 @@ async def get_user_all_beatmap_scores(
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="This server only contains non-legacy scores"
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="This server only contains non-legacy scores")
|
||||
all_user_scores = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
@@ -420,7 +387,6 @@ async def create_solo_score(
|
||||
ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
assert current_user.id is not None
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
|
||||
@@ -454,10 +420,7 @@ async def submit_solo_score(
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher=Depends(get_fetcher),
|
||||
):
|
||||
assert current_user.id is not None
|
||||
return await submit_score(
|
||||
background_task, info, beatmap_id, token, current_user, db, redis, fetcher
|
||||
)
|
||||
return await submit_score(background_task, info, beatmap_id, token, current_user, db, redis, fetcher)
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -478,7 +441,6 @@ async def create_playlist_score(
|
||||
version_hash: str = Form("", description="谱面版本哈希"),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
assert current_user.id is not None
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
|
||||
@@ -488,26 +450,16 @@ async def create_playlist_score(
|
||||
db_room_time = room.ends_at.replace(tzinfo=UTC) if room.ends_at else None
|
||||
if db_room_time and db_room_time < datetime.now(UTC).replace(tzinfo=UTC):
|
||||
raise HTTPException(status_code=400, detail="Room has ended")
|
||||
item = (
|
||||
await session.exec(
|
||||
select(Playlist).where(
|
||||
Playlist.id == playlist_id, Playlist.room_id == room_id
|
||||
)
|
||||
)
|
||||
).first()
|
||||
item = (await session.exec(select(Playlist).where(Playlist.id == playlist_id, Playlist.room_id == room_id))).first()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Playlist not found")
|
||||
|
||||
# validate
|
||||
if not item.freestyle:
|
||||
if item.ruleset_id != ruleset_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Ruleset mismatch in playlist item"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Ruleset mismatch in playlist item")
|
||||
if item.beatmap_id != beatmap_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Beatmap ID mismatch in playlist item"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Beatmap ID mismatch in playlist item")
|
||||
agg = await session.exec(
|
||||
select(ItemAttemptsCount).where(
|
||||
ItemAttemptsCount.room_id == room_id,
|
||||
@@ -523,9 +475,7 @@ async def create_playlist_score(
|
||||
if item.expired:
|
||||
raise HTTPException(status_code=400, detail="Playlist item has expired")
|
||||
if item.played_at:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Playlist item has already been played"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Playlist item has already been played")
|
||||
# 这里应该不用验证mod了吧。。。
|
||||
background_task.add_task(_preload_beatmap_for_pp_calculation, beatmap_id)
|
||||
score_token = ScoreToken(
|
||||
@@ -557,18 +507,10 @@ async def submit_playlist_score(
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
assert current_user.id is not None
|
||||
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
|
||||
item = (
|
||||
await session.exec(
|
||||
select(Playlist).where(
|
||||
Playlist.id == playlist_id, Playlist.room_id == room_id
|
||||
)
|
||||
)
|
||||
).first()
|
||||
item = (await session.exec(select(Playlist).where(Playlist.id == playlist_id, Playlist.room_id == room_id))).first()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Playlist item not found")
|
||||
room = await session.get(Room, room_id)
|
||||
@@ -621,9 +563,7 @@ async def index_playlist_scores(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
limit: int = Query(50, ge=1, le=50, description="返回条数 (1-50)"),
|
||||
cursor: int = Query(
|
||||
2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)"
|
||||
),
|
||||
cursor: int = Query(2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
):
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
@@ -693,9 +633,6 @@ async def show_playlist_score(
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
|
||||
room = await session.get(Room, room_id)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
@@ -715,9 +652,7 @@ async def show_playlist_score(
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if completed_players := await redis.get(
|
||||
f"multiplayer:{room_id}:gameplay:players"
|
||||
):
|
||||
if completed_players := await redis.get(f"multiplayer:{room_id}:gameplay:players"):
|
||||
completed = completed_players == "0"
|
||||
if score_record and completed:
|
||||
break
|
||||
@@ -784,9 +719,7 @@ async def get_user_playlist_score(
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
|
||||
resp = await ScoreResp.from_db(session, score_record.score)
|
||||
resp.position = await get_position(
|
||||
room_id, playlist_id, score_record.score_id, session
|
||||
)
|
||||
resp.position = await get_position(room_id, playlist_id, score_record.score_id, session)
|
||||
return resp
|
||||
|
||||
|
||||
@@ -850,11 +783,7 @@ async def unpin_score(
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
|
||||
score_record = (
|
||||
await db.exec(
|
||||
select(Score).where(Score.id == score_id, Score.user_id == user_id)
|
||||
)
|
||||
).first()
|
||||
score_record = (await db.exec(select(Score).where(Score.id == score_id, Score.user_id == user_id))).first()
|
||||
if not score_record:
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
|
||||
@@ -878,10 +807,7 @@ async def unpin_score(
|
||||
"/score-pins/{score_id}/reorder",
|
||||
status_code=204,
|
||||
name="调整置顶成绩顺序",
|
||||
description=(
|
||||
"**客户端专属**\n调整已置顶成绩的展示顺序。"
|
||||
"仅提供 after_score_id 或 before_score_id 之一。"
|
||||
),
|
||||
description=("**客户端专属**\n调整已置顶成绩的展示顺序。仅提供 after_score_id 或 before_score_id 之一。"),
|
||||
tags=["成绩"],
|
||||
)
|
||||
async def reorder_score_pin(
|
||||
@@ -894,11 +820,7 @@ async def reorder_score_pin(
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
|
||||
score_record = (
|
||||
await db.exec(
|
||||
select(Score).where(Score.id == score_id, Score.user_id == user_id)
|
||||
)
|
||||
).first()
|
||||
score_record = (await db.exec(select(Score).where(Score.id == score_id, Score.user_id == user_id))).first()
|
||||
if not score_record:
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
|
||||
@@ -908,8 +830,7 @@ async def reorder_score_pin(
|
||||
if (after_score_id is None) == (before_score_id is None):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Either after_score_id or before_score_id "
|
||||
"must be provided (but not both)",
|
||||
detail="Either after_score_id or before_score_id must be provided (but not both)",
|
||||
)
|
||||
|
||||
all_pinned_scores = (
|
||||
@@ -927,9 +848,7 @@ async def reorder_score_pin(
|
||||
target_order = None
|
||||
reference_score_id = after_score_id or before_score_id
|
||||
|
||||
reference_score = next(
|
||||
(s for s in all_pinned_scores if s.id == reference_score_id), None
|
||||
)
|
||||
reference_score = next((s for s in all_pinned_scores if s.id == reference_score_id), None)
|
||||
if not reference_score:
|
||||
detail = "After score not found" if after_score_id else "Before score not found"
|
||||
raise HTTPException(status_code=404, detail=detail)
|
||||
@@ -951,9 +870,7 @@ async def reorder_score_pin(
|
||||
if current_order < s.pinned_order <= target_order and s.id != score_id:
|
||||
updates.append((s.id, s.pinned_order - 1))
|
||||
if after_score_id:
|
||||
final_target = (
|
||||
target_order - 1 if target_order > current_order else target_order
|
||||
)
|
||||
final_target = target_order - 1 if target_order > current_order else target_order
|
||||
else:
|
||||
final_target = target_order
|
||||
else:
|
||||
@@ -964,9 +881,7 @@ async def reorder_score_pin(
|
||||
|
||||
for score_id, new_order in updates:
|
||||
await db.exec(select(Score).where(Score.id == score_id))
|
||||
score_to_update = (
|
||||
await db.exec(select(Score).where(Score.id == score_id))
|
||||
).first()
|
||||
score_to_update = (await db.exec(select(Score).where(Score.id == score_id))).first()
|
||||
if score_to_update:
|
||||
score_to_update.pinned_order = new_order
|
||||
|
||||
|
||||
@@ -4,34 +4,29 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, UTC
|
||||
from typing import Annotated
|
||||
|
||||
from app.auth import authenticate_user
|
||||
from app.config import settings
|
||||
from app.database import User
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.geoip import GeoIPHelper, get_geoip_helper
|
||||
from app.database.email_verification import EmailVerification, LoginSession
|
||||
from app.service.email_verification_service import (
|
||||
EmailVerificationService,
|
||||
LoginSessionService
|
||||
EmailVerificationService,
|
||||
LoginSessionService,
|
||||
)
|
||||
from app.service.login_log_service import LoginLogService
|
||||
from app.models.extended_auth import ExtendedTokenResponse
|
||||
|
||||
from fastapi import Form, Depends, Request, HTTPException, status, Security
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import select
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Depends, Form, HTTPException, Request, Security, status
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
|
||||
|
||||
class SessionReissueResponse(BaseModel):
|
||||
"""重新发送验证码响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
@@ -40,39 +35,35 @@ class SessionReissueResponse(BaseModel):
|
||||
"/session/verify",
|
||||
name="验证会话",
|
||||
description="验证邮件验证码并完成会话认证",
|
||||
status_code=204
|
||||
status_code=204,
|
||||
)
|
||||
async def verify_session(
|
||||
request: Request,
|
||||
db: Database,
|
||||
redis: Annotated[Redis, Depends(get_redis)],
|
||||
verification_key: str = Form(..., description="8位邮件验证码"),
|
||||
current_user: User = Security(get_current_user)
|
||||
current_user: User = Security(get_current_user),
|
||||
) -> Response:
|
||||
"""
|
||||
验证邮件验证码并完成会话认证
|
||||
|
||||
|
||||
对应 osu! 的 session/verify 接口
|
||||
成功时返回 204 No Content,失败时返回 401 Unauthorized
|
||||
"""
|
||||
try:
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
ip_address = get_client_ip(request)
|
||||
user_agent = request.headers.get("User-Agent", "Unknown")
|
||||
|
||||
|
||||
ip_address = get_client_ip(request) # noqa: F841
|
||||
user_agent = request.headers.get("User-Agent", "Unknown") # noqa: F841
|
||||
|
||||
# 从当前认证用户获取信息
|
||||
user_id = current_user.id
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="用户未认证"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户未认证")
|
||||
|
||||
# 验证邮件验证码
|
||||
success, message = await EmailVerificationService.verify_code(
|
||||
db, redis, user_id, verification_key
|
||||
)
|
||||
|
||||
success, message = await EmailVerificationService.verify_code(db, redis, user_id, verification_key)
|
||||
|
||||
if success:
|
||||
# 记录成功的邮件验证
|
||||
await LoginLogService.record_login(
|
||||
@@ -81,9 +72,9 @@ async def verify_session(
|
||||
request=request,
|
||||
login_method="email_verification",
|
||||
login_success=True,
|
||||
notes=f"邮件验证成功"
|
||||
notes="邮件验证成功",
|
||||
)
|
||||
|
||||
|
||||
# 返回 204 No Content 表示验证成功
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
else:
|
||||
@@ -93,83 +84,69 @@ async def verify_session(
|
||||
request=request,
|
||||
attempted_username=current_user.username,
|
||||
login_method="email_verification",
|
||||
notes=f"邮件验证失败: {message}"
|
||||
notes=f"邮件验证失败: {message}",
|
||||
)
|
||||
|
||||
|
||||
# 返回 401 Unauthorized 表示验证失败
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=message
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=message)
|
||||
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="无效的用户会话"
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="验证过程中发生错误"
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的用户会话")
|
||||
except Exception:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="验证过程中发生错误")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/session/verify/reissue",
|
||||
name="重新发送验证码",
|
||||
description="重新发送邮件验证码",
|
||||
response_model=SessionReissueResponse
|
||||
response_model=SessionReissueResponse,
|
||||
)
|
||||
async def reissue_verification_code(
|
||||
request: Request,
|
||||
db: Database,
|
||||
redis: Annotated[Redis, Depends(get_redis)],
|
||||
current_user: User = Security(get_current_user)
|
||||
current_user: User = Security(get_current_user),
|
||||
) -> SessionReissueResponse:
|
||||
"""
|
||||
重新发送邮件验证码
|
||||
|
||||
|
||||
对应 osu! 的 session/verify/reissue 接口
|
||||
"""
|
||||
try:
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
|
||||
ip_address = get_client_ip(request)
|
||||
user_agent = request.headers.get("User-Agent", "Unknown")
|
||||
|
||||
|
||||
# 从当前认证用户获取信息
|
||||
user_id = current_user.id
|
||||
if not user_id:
|
||||
return SessionReissueResponse(
|
||||
success=False,
|
||||
message="用户未认证"
|
||||
)
|
||||
|
||||
return SessionReissueResponse(success=False, message="用户未认证")
|
||||
|
||||
# 重新发送验证码
|
||||
success, message = await EmailVerificationService.resend_verification_code(
|
||||
db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent
|
||||
db,
|
||||
redis,
|
||||
user_id,
|
||||
current_user.username,
|
||||
current_user.email,
|
||||
ip_address,
|
||||
user_agent,
|
||||
)
|
||||
|
||||
return SessionReissueResponse(
|
||||
success=success,
|
||||
message=message
|
||||
)
|
||||
|
||||
|
||||
return SessionReissueResponse(success=success, message=message)
|
||||
|
||||
except ValueError:
|
||||
return SessionReissueResponse(
|
||||
success=False,
|
||||
message="无效的用户会话"
|
||||
)
|
||||
except Exception as e:
|
||||
return SessionReissueResponse(
|
||||
success=False,
|
||||
message="重新发送过程中发生错误"
|
||||
)
|
||||
return SessionReissueResponse(success=False, message="无效的用户会话")
|
||||
except Exception:
|
||||
return SessionReissueResponse(success=False, message="重新发送过程中发生错误")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/session/check-new-location",
|
||||
name="检查新位置登录",
|
||||
description="检查登录是否来自新位置(内部接口)"
|
||||
description="检查登录是否来自新位置(内部接口)",
|
||||
)
|
||||
async def check_new_location(
|
||||
request: Request,
|
||||
@@ -183,22 +160,21 @@ async def check_new_location(
|
||||
"""
|
||||
try:
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
|
||||
ip_address = get_client_ip(request)
|
||||
geo_info = geoip.lookup(ip_address)
|
||||
country_code = geo_info.get("country_iso", "XX")
|
||||
|
||||
is_new_location = await LoginSessionService.check_new_location(
|
||||
db, user_id, ip_address, country_code
|
||||
)
|
||||
|
||||
|
||||
is_new_location = await LoginSessionService.check_new_location(db, user_id, ip_address, country_code)
|
||||
|
||||
return {
|
||||
"is_new_location": is_new_location,
|
||||
"ip_address": ip_address,
|
||||
"country_code": country_code
|
||||
"country_code": country_code,
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"is_new_location": True, # 出错时默认为新位置
|
||||
"error": str(e)
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
@@ -1,73 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
import json
|
||||
from typing import Any
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
from app.dependencies.database import get_redis, get_redis_message
|
||||
from app.log import logger
|
||||
from app.utils import bg_tasks
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Redis key constants
|
||||
REDIS_ONLINE_USERS_KEY = "server:online_users"
|
||||
REDIS_PLAYING_USERS_KEY = "server:playing_users"
|
||||
REDIS_PLAYING_USERS_KEY = "server:playing_users"
|
||||
REDIS_REGISTERED_USERS_KEY = "server:registered_users"
|
||||
REDIS_ONLINE_HISTORY_KEY = "server:online_history"
|
||||
|
||||
# 线程池用于同步Redis操作
|
||||
_executor = ThreadPoolExecutor(max_workers=2)
|
||||
|
||||
|
||||
async def _redis_exec(func, *args, **kwargs):
|
||||
"""在线程池中执行同步Redis操作"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(_executor, func, *args, **kwargs)
|
||||
|
||||
|
||||
class ServerStats(BaseModel):
|
||||
"""服务器统计信息响应模型"""
|
||||
|
||||
registered_users: int
|
||||
online_users: int
|
||||
playing_users: int
|
||||
timestamp: datetime
|
||||
|
||||
|
||||
class OnlineHistoryPoint(BaseModel):
|
||||
"""在线历史数据点"""
|
||||
|
||||
timestamp: datetime
|
||||
online_count: int
|
||||
playing_count: int
|
||||
|
||||
|
||||
class OnlineHistoryResponse(BaseModel):
|
||||
"""24小时在线历史响应模型"""
|
||||
|
||||
history: list[OnlineHistoryPoint]
|
||||
current_stats: ServerStats
|
||||
|
||||
|
||||
@router.get("/stats", response_model=ServerStats, tags=["统计"])
|
||||
async def get_server_stats() -> ServerStats:
|
||||
"""
|
||||
获取服务器实时统计信息
|
||||
|
||||
|
||||
返回服务器注册用户数、在线用户数、正在游玩用户数等实时统计信息
|
||||
"""
|
||||
redis = get_redis()
|
||||
|
||||
|
||||
try:
|
||||
# 并行获取所有统计数据
|
||||
registered_count, online_count, playing_count = await asyncio.gather(
|
||||
_get_registered_users_count(redis),
|
||||
_get_online_users_count(redis),
|
||||
_get_playing_users_count(redis)
|
||||
_get_playing_users_count(redis),
|
||||
)
|
||||
|
||||
|
||||
return ServerStats(
|
||||
registered_users=registered_count,
|
||||
online_users=online_count,
|
||||
playing_users=playing_count,
|
||||
timestamp=datetime.utcnow()
|
||||
timestamp=datetime.utcnow(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting server stats: {e}")
|
||||
@@ -76,14 +83,15 @@ async def get_server_stats() -> ServerStats:
|
||||
registered_users=0,
|
||||
online_users=0,
|
||||
playing_users=0,
|
||||
timestamp=datetime.utcnow()
|
||||
timestamp=datetime.utcnow(),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stats/history", response_model=OnlineHistoryResponse, tags=["统计"])
|
||||
async def get_online_history() -> OnlineHistoryResponse:
|
||||
"""
|
||||
获取最近24小时在线统计历史
|
||||
|
||||
|
||||
返回过去24小时内每小时的在线用户数和游玩用户数统计,
|
||||
包含当前实时数据作为最新数据点
|
||||
"""
|
||||
@@ -92,80 +100,80 @@ async def get_online_history() -> OnlineHistoryResponse:
|
||||
redis_sync = get_redis_message()
|
||||
history_data = await _redis_exec(redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1)
|
||||
history_points = []
|
||||
|
||||
|
||||
# 处理历史数据
|
||||
for data in history_data:
|
||||
try:
|
||||
point_data = json.loads(data)
|
||||
# 只保留基本字段
|
||||
history_points.append(OnlineHistoryPoint(
|
||||
timestamp=datetime.fromisoformat(point_data["timestamp"]),
|
||||
online_count=point_data["online_count"],
|
||||
playing_count=point_data["playing_count"]
|
||||
))
|
||||
history_points.append(
|
||||
OnlineHistoryPoint(
|
||||
timestamp=datetime.fromisoformat(point_data["timestamp"]),
|
||||
online_count=point_data["online_count"],
|
||||
playing_count=point_data["playing_count"],
|
||||
)
|
||||
)
|
||||
except (json.JSONDecodeError, KeyError, ValueError) as e:
|
||||
logger.warning(f"Invalid history data point: {data}, error: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# 获取当前实时统计信息
|
||||
current_stats = await get_server_stats()
|
||||
|
||||
|
||||
# 如果历史数据为空或者最新数据超过15分钟,添加当前数据点
|
||||
if not history_points or (
|
||||
history_points and
|
||||
(current_stats.timestamp - max(history_points, key=lambda x: x.timestamp).timestamp).total_seconds() > 15 * 60
|
||||
history_points
|
||||
and (current_stats.timestamp - max(history_points, key=lambda x: x.timestamp).timestamp).total_seconds()
|
||||
> 15 * 60
|
||||
):
|
||||
history_points.append(OnlineHistoryPoint(
|
||||
timestamp=current_stats.timestamp,
|
||||
online_count=current_stats.online_users,
|
||||
playing_count=current_stats.playing_users
|
||||
))
|
||||
|
||||
history_points.append(
|
||||
OnlineHistoryPoint(
|
||||
timestamp=current_stats.timestamp,
|
||||
online_count=current_stats.online_users,
|
||||
playing_count=current_stats.playing_users,
|
||||
)
|
||||
)
|
||||
|
||||
# 按时间排序(最新的在前)
|
||||
history_points.sort(key=lambda x: x.timestamp, reverse=True)
|
||||
|
||||
|
||||
# 限制到最多48个数据点(24小时)
|
||||
history_points = history_points[:48]
|
||||
|
||||
return OnlineHistoryResponse(
|
||||
history=history_points,
|
||||
current_stats=current_stats
|
||||
)
|
||||
|
||||
return OnlineHistoryResponse(history=history_points, current_stats=current_stats)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting online history: {e}")
|
||||
# 返回空历史和当前状态
|
||||
current_stats = await get_server_stats()
|
||||
return OnlineHistoryResponse(
|
||||
history=[],
|
||||
current_stats=current_stats
|
||||
)
|
||||
return OnlineHistoryResponse(history=[], current_stats=current_stats)
|
||||
|
||||
|
||||
@router.get("/stats/debug", tags=["统计"])
|
||||
async def get_stats_debug_info():
|
||||
"""
|
||||
获取统计系统调试信息
|
||||
|
||||
|
||||
用于调试时间对齐和区间统计问题
|
||||
"""
|
||||
try:
|
||||
from app.service.enhanced_interval_stats import EnhancedIntervalStatsManager
|
||||
|
||||
|
||||
current_time = datetime.utcnow()
|
||||
current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
|
||||
interval_stats = await EnhancedIntervalStatsManager.get_current_interval_stats()
|
||||
|
||||
|
||||
# 获取Redis中的实际数据
|
||||
redis_sync = get_redis_message()
|
||||
|
||||
|
||||
online_key = f"server:interval_online_users:{current_interval.interval_key}"
|
||||
playing_key = f"server:interval_playing_users:{current_interval.interval_key}"
|
||||
|
||||
|
||||
online_users_raw = await _redis_exec(redis_sync.smembers, online_key)
|
||||
playing_users_raw = await _redis_exec(redis_sync.smembers, playing_key)
|
||||
|
||||
|
||||
online_users = [int(uid.decode() if isinstance(uid, bytes) else uid) for uid in online_users_raw]
|
||||
playing_users = [int(uid.decode() if isinstance(uid, bytes) else uid) for uid in playing_users_raw]
|
||||
|
||||
|
||||
return {
|
||||
"current_time": current_time.isoformat(),
|
||||
"current_interval": {
|
||||
@@ -175,28 +183,29 @@ async def get_stats_debug_info():
|
||||
"is_current": current_interval.is_current(),
|
||||
"minutes_remaining": int((current_interval.end_time - current_time).total_seconds() / 60),
|
||||
"seconds_remaining": int((current_interval.end_time - current_time).total_seconds()),
|
||||
"progress_percentage": round((1 - (current_interval.end_time - current_time).total_seconds() / (30 * 60)) * 100, 1)
|
||||
"progress_percentage": round(
|
||||
(1 - (current_interval.end_time - current_time).total_seconds() / (30 * 60)) * 100,
|
||||
1,
|
||||
),
|
||||
},
|
||||
"interval_statistics": interval_stats.to_dict() if interval_stats else None,
|
||||
"redis_data": {
|
||||
"online_users": online_users,
|
||||
"playing_users": playing_users,
|
||||
"online_count": len(online_users),
|
||||
"playing_count": len(playing_users)
|
||||
"playing_count": len(playing_users),
|
||||
},
|
||||
"system_status": {
|
||||
"stats_system": "enhanced_interval_stats",
|
||||
"data_alignment": "30_minute_boundaries",
|
||||
"real_time_updates": True,
|
||||
"auto_24h_fill": True
|
||||
}
|
||||
"auto_24h_fill": True,
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting debug info: {e}")
|
||||
return {
|
||||
"error": "Failed to retrieve debug information",
|
||||
"message": str(e)
|
||||
}
|
||||
return {"error": "Failed to retrieve debug information", "message": str(e)}
|
||||
|
||||
|
||||
async def _get_registered_users_count(redis) -> int:
|
||||
"""获取注册用户总数(从缓存)"""
|
||||
@@ -207,6 +216,7 @@ async def _get_registered_users_count(redis) -> int:
|
||||
logger.error(f"Error getting registered users count: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
async def _get_online_users_count(redis) -> int:
|
||||
"""获取当前在线用户数"""
|
||||
try:
|
||||
@@ -216,6 +226,7 @@ async def _get_online_users_count(redis) -> int:
|
||||
logger.error(f"Error getting online users count: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
async def _get_playing_users_count(redis) -> int:
|
||||
"""获取当前游玩用户数"""
|
||||
try:
|
||||
@@ -225,27 +236,28 @@ async def _get_playing_users_count(redis) -> int:
|
||||
logger.error(f"Error getting playing users count: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
# 统计更新功能
|
||||
async def update_registered_users_count() -> None:
|
||||
"""更新注册用户数缓存"""
|
||||
from app.dependencies.database import with_db
|
||||
from app.database import User
|
||||
from app.const import BANCHOBOT_ID
|
||||
from sqlmodel import select, func
|
||||
|
||||
from app.database import User
|
||||
from app.dependencies.database import with_db
|
||||
|
||||
from sqlmodel import func, select
|
||||
|
||||
redis = get_redis()
|
||||
try:
|
||||
async with with_db() as db:
|
||||
# 排除机器人用户(BANCHOBOT_ID)
|
||||
result = await db.exec(
|
||||
select(func.count()).select_from(User).where(User.id != BANCHOBOT_ID)
|
||||
)
|
||||
result = await db.exec(select(func.count()).select_from(User).where(User.id != BANCHOBOT_ID))
|
||||
count = result.first()
|
||||
await redis.set(REDIS_REGISTERED_USERS_KEY, count or 0, ex=300) # 5分钟过期
|
||||
logger.debug(f"Updated registered users count: {count}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating registered users count: {e}")
|
||||
|
||||
|
||||
async def add_online_user(user_id: int) -> None:
|
||||
"""添加在线用户"""
|
||||
redis_sync = get_redis_message()
|
||||
@@ -257,14 +269,20 @@ async def add_online_user(user_id: int) -> None:
|
||||
if ttl <= 0: # -1表示永不过期,-2表示不存在,0表示已过期
|
||||
await redis_async.expire(REDIS_ONLINE_USERS_KEY, 3 * 3600) # 3小时过期
|
||||
logger.debug(f"Added online user {user_id}")
|
||||
|
||||
|
||||
# 立即更新当前区间统计
|
||||
from app.service.enhanced_interval_stats import update_user_activity_in_interval
|
||||
asyncio.create_task(update_user_activity_in_interval(user_id, is_playing=False))
|
||||
|
||||
|
||||
bg_tasks.add_task(
|
||||
update_user_activity_in_interval,
|
||||
user_id,
|
||||
is_playing=False,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding online user {user_id}: {e}")
|
||||
|
||||
|
||||
async def remove_online_user(user_id: int) -> None:
|
||||
"""移除在线用户"""
|
||||
redis_sync = get_redis_message()
|
||||
@@ -274,6 +292,7 @@ async def remove_online_user(user_id: int) -> None:
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing online user {user_id}: {e}")
|
||||
|
||||
|
||||
async def add_playing_user(user_id: int) -> None:
|
||||
"""添加游玩用户"""
|
||||
redis_sync = get_redis_message()
|
||||
@@ -285,14 +304,16 @@ async def add_playing_user(user_id: int) -> None:
|
||||
if ttl <= 0: # -1表示永不过期,-2表示不存在,0表示已过期
|
||||
await redis_async.expire(REDIS_PLAYING_USERS_KEY, 3 * 3600) # 3小时过期
|
||||
logger.debug(f"Added playing user {user_id}")
|
||||
|
||||
|
||||
# 立即更新当前区间统计
|
||||
from app.service.enhanced_interval_stats import update_user_activity_in_interval
|
||||
asyncio.create_task(update_user_activity_in_interval(user_id, is_playing=True))
|
||||
|
||||
|
||||
bg_tasks.add_task(update_user_activity_in_interval, user_id, is_playing=True)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding playing user {user_id}: {e}")
|
||||
|
||||
|
||||
async def remove_playing_user(user_id: int) -> None:
|
||||
"""移除游玩用户"""
|
||||
redis_sync = get_redis_message()
|
||||
@@ -301,6 +322,7 @@ async def remove_playing_user(user_id: int) -> None:
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing playing user {user_id}: {e}")
|
||||
|
||||
|
||||
async def record_hourly_stats() -> None:
|
||||
"""记录统计数据 - 简化版本,主要作为fallback使用"""
|
||||
redis_sync = get_redis_message()
|
||||
@@ -308,24 +330,27 @@ async def record_hourly_stats() -> None:
|
||||
try:
|
||||
# 先确保Redis连接正常
|
||||
await redis_async.ping()
|
||||
|
||||
|
||||
online_count = await _get_online_users_count(redis_async)
|
||||
playing_count = await _get_playing_users_count(redis_async)
|
||||
|
||||
|
||||
current_time = datetime.utcnow()
|
||||
history_point = {
|
||||
"timestamp": current_time.isoformat(),
|
||||
"online_count": online_count,
|
||||
"playing_count": playing_count
|
||||
"playing_count": playing_count,
|
||||
}
|
||||
|
||||
|
||||
# 添加到历史记录
|
||||
await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point))
|
||||
# 只保留48个数据点(24小时,每30分钟一个点)
|
||||
await _redis_exec(redis_sync.ltrim, REDIS_ONLINE_HISTORY_KEY, 0, 47)
|
||||
# 设置过期时间为26小时,确保有足够缓冲
|
||||
await redis_async.expire(REDIS_ONLINE_HISTORY_KEY, 26 * 3600)
|
||||
|
||||
logger.info(f"Recorded fallback stats: online={online_count}, playing={playing_count} at {current_time.strftime('%H:%M:%S')}")
|
||||
|
||||
logger.info(
|
||||
f"Recorded fallback stats: online={online_count}, playing={playing_count} "
|
||||
f"at {current_time.strftime('%H:%M:%S')}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error recording fallback stats: {e}")
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Literal
|
||||
|
||||
@@ -26,7 +25,7 @@ from app.service.user_cache_service import get_user_cache_service
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import HTTPException, Path, Query, Security
|
||||
from fastapi import BackgroundTasks, HTTPException, Path, Query, Security
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import exists, false, select
|
||||
from sqlmodel.sql.expression import col
|
||||
@@ -47,13 +46,10 @@ class BatchUserResponse(BaseModel):
|
||||
@router.get("/users/lookup/", response_model=BatchUserResponse, include_in_schema=False)
|
||||
async def get_users(
|
||||
session: Database,
|
||||
user_ids: list[int] = Query(
|
||||
default_factory=list, alias="ids[]", description="要查询的用户 ID 列表"
|
||||
),
|
||||
background_task: BackgroundTasks,
|
||||
user_ids: list[int] = Query(default_factory=list, alias="ids[]", description="要查询的用户 ID 列表"),
|
||||
# current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
include_variant_statistics: bool = Query(
|
||||
default=False, description="是否包含各模式的统计信息"
|
||||
), # TODO: future use
|
||||
include_variant_statistics: bool = Query(default=False, description="是否包含各模式的统计信息"), # TODO: future use
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_user_cache_service(redis)
|
||||
@@ -72,11 +68,7 @@ async def get_users(
|
||||
|
||||
# 查询未缓存的用户
|
||||
if uncached_user_ids:
|
||||
searched_users = (
|
||||
await session.exec(
|
||||
select(User).where(col(User.id).in_(uncached_user_ids))
|
||||
)
|
||||
).all()
|
||||
searched_users = (await session.exec(select(User).where(col(User.id).in_(uncached_user_ids)))).all()
|
||||
|
||||
# 将查询到的用户添加到缓存并返回
|
||||
for searched_user in searched_users:
|
||||
@@ -88,7 +80,7 @@ async def get_users(
|
||||
)
|
||||
cached_users.append(user_resp)
|
||||
# 异步缓存,不阻塞响应
|
||||
asyncio.create_task(cache_service.cache_user(user_resp))
|
||||
background_task.add_task(cache_service.cache_user, user_resp)
|
||||
|
||||
return BatchUserResponse(users=cached_users)
|
||||
else:
|
||||
@@ -103,7 +95,7 @@ async def get_users(
|
||||
)
|
||||
users.append(user_resp)
|
||||
# 异步缓存
|
||||
asyncio.create_task(cache_service.cache_user(user_resp))
|
||||
background_task.add_task(cache_service.cache_user, user_resp)
|
||||
|
||||
return BatchUserResponse(users=users)
|
||||
|
||||
@@ -117,6 +109,7 @@ async def get_users(
|
||||
)
|
||||
async def get_user_info_ruleset(
|
||||
session: Database,
|
||||
background_task: BackgroundTasks,
|
||||
user_id: str = Path(description="用户 ID 或用户名"),
|
||||
ruleset: GameMode | None = Path(description="指定 ruleset"),
|
||||
# current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
@@ -134,9 +127,7 @@ async def get_user_info_ruleset(
|
||||
searched_user = (
|
||||
await session.exec(
|
||||
select(User).where(
|
||||
User.id == int(user_id)
|
||||
if user_id.isdigit()
|
||||
else User.username == user_id.removeprefix("@")
|
||||
User.id == int(user_id) if user_id.isdigit() else User.username == user_id.removeprefix("@")
|
||||
)
|
||||
)
|
||||
).first()
|
||||
@@ -151,7 +142,7 @@ async def get_user_info_ruleset(
|
||||
)
|
||||
|
||||
# 异步缓存结果
|
||||
asyncio.create_task(cache_service.cache_user(user_resp, ruleset))
|
||||
background_task.add_task(cache_service.cache_user, user_resp, ruleset)
|
||||
|
||||
return user_resp
|
||||
|
||||
@@ -165,6 +156,7 @@ async def get_user_info_ruleset(
|
||||
tags=["用户"],
|
||||
)
|
||||
async def get_user_info(
|
||||
background_task: BackgroundTasks,
|
||||
session: Database,
|
||||
user_id: str = Path(description="用户 ID 或用户名"),
|
||||
# current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
@@ -182,9 +174,7 @@ async def get_user_info(
|
||||
searched_user = (
|
||||
await session.exec(
|
||||
select(User).where(
|
||||
User.id == int(user_id)
|
||||
if user_id.isdigit()
|
||||
else User.username == user_id.removeprefix("@")
|
||||
User.id == int(user_id) if user_id.isdigit() else User.username == user_id.removeprefix("@")
|
||||
)
|
||||
)
|
||||
).first()
|
||||
@@ -198,7 +188,7 @@ async def get_user_info(
|
||||
)
|
||||
|
||||
# 异步缓存结果
|
||||
asyncio.create_task(cache_service.cache_user(user_resp))
|
||||
background_task.add_task(cache_service.cache_user, user_resp)
|
||||
|
||||
return user_resp
|
||||
|
||||
@@ -212,6 +202,7 @@ async def get_user_info(
|
||||
)
|
||||
async def get_user_beatmapsets(
|
||||
session: Database,
|
||||
background_task: BackgroundTasks,
|
||||
user_id: int = Path(description="用户 ID"),
|
||||
type: BeatmapsetType = Path(description="谱面集类型"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
@@ -222,9 +213,7 @@ async def get_user_beatmapsets(
|
||||
cache_service = get_user_cache_service(redis)
|
||||
|
||||
# 先尝试从缓存获取
|
||||
cached_result = await cache_service.get_user_beatmapsets_from_cache(
|
||||
user_id, type.value, limit, offset
|
||||
)
|
||||
cached_result = await cache_service.get_user_beatmapsets_from_cache(user_id, type.value, limit, offset)
|
||||
if cached_result is not None:
|
||||
# 根据类型恢复对象
|
||||
if type == BeatmapsetType.MOST_PLAYED:
|
||||
@@ -253,10 +242,7 @@ async def get_user_beatmapsets(
|
||||
raise HTTPException(404, detail="User not found")
|
||||
favourites = await user.awaitable_attrs.favourite_beatmapsets
|
||||
resp = [
|
||||
await BeatmapsetResp.from_db(
|
||||
favourite.beatmapset, session=session, user=user
|
||||
)
|
||||
for favourite in favourites
|
||||
await BeatmapsetResp.from_db(favourite.beatmapset, session=session, user=user) for favourite in favourites
|
||||
]
|
||||
|
||||
elif type == BeatmapsetType.MOST_PLAYED:
|
||||
@@ -267,25 +253,18 @@ async def get_user_beatmapsets(
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
resp = [
|
||||
await BeatmapPlaycountsResp.from_db(most_played_beatmap)
|
||||
for most_played_beatmap in most_played
|
||||
]
|
||||
resp = [await BeatmapPlaycountsResp.from_db(most_played_beatmap) for most_played_beatmap in most_played]
|
||||
else:
|
||||
raise HTTPException(400, detail="Invalid beatmapset type")
|
||||
|
||||
# 异步缓存结果
|
||||
async def cache_beatmapsets():
|
||||
try:
|
||||
await cache_service.cache_user_beatmapsets(
|
||||
user_id, type.value, resp, limit, offset
|
||||
)
|
||||
await cache_service.cache_user_beatmapsets(user_id, type.value, resp, limit, offset)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}"
|
||||
)
|
||||
logger.error(f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}")
|
||||
|
||||
asyncio.create_task(cache_beatmapsets())
|
||||
background_task.add_task(cache_beatmapsets)
|
||||
|
||||
return resp
|
||||
|
||||
@@ -299,18 +278,14 @@ async def get_user_beatmapsets(
|
||||
)
|
||||
async def get_user_scores(
|
||||
session: Database,
|
||||
background_task: BackgroundTasks,
|
||||
user_id: int = Path(description="用户 ID"),
|
||||
type: Literal["best", "recent", "firsts", "pinned"] = Path(
|
||||
description=(
|
||||
"成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩"
|
||||
" / firsts 第一名成绩 / pinned 置顶成绩"
|
||||
)
|
||||
description=("成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩 / firsts 第一名成绩 / pinned 置顶成绩")
|
||||
),
|
||||
legacy_only: bool = Query(False, description="是否只查询 Stable 成绩"),
|
||||
include_fails: bool = Query(False, description="是否包含失败的成绩"),
|
||||
mode: GameMode | None = Query(
|
||||
None, description="指定 ruleset (可选,默认为用户主模式)"
|
||||
),
|
||||
mode: GameMode | None = Query(None, description="指定 ruleset (可选,默认为用户主模式)"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
|
||||
offset: int = Query(0, ge=0, description="偏移量"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
@@ -320,9 +295,7 @@ async def get_user_scores(
|
||||
|
||||
# 先尝试从缓存获取(对于recent类型使用较短的缓存时间)
|
||||
cache_expire = 30 if type == "recent" else settings.user_scores_cache_expire_seconds
|
||||
cached_scores = await cache_service.get_user_scores_from_cache(
|
||||
user_id, type, mode, limit, offset
|
||||
)
|
||||
cached_scores = await cache_service.get_user_scores_from_cache(user_id, type, mode, limit, offset)
|
||||
if cached_scores is not None:
|
||||
return cached_scores
|
||||
|
||||
@@ -332,9 +305,7 @@ async def get_user_scores(
|
||||
|
||||
gamemode = mode or db_user.playmode
|
||||
order_by = None
|
||||
where_clause = (col(Score.user_id) == db_user.id) & (
|
||||
col(Score.gamemode) == gamemode
|
||||
)
|
||||
where_clause = (col(Score.user_id) == db_user.id) & (col(Score.gamemode) == gamemode)
|
||||
if not include_fails:
|
||||
where_clause &= col(Score.passed).is_(True)
|
||||
if type == "pinned":
|
||||
@@ -351,13 +322,7 @@ async def get_user_scores(
|
||||
where_clause &= false()
|
||||
|
||||
scores = (
|
||||
await session.exec(
|
||||
select(Score)
|
||||
.where(where_clause)
|
||||
.order_by(order_by)
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
await session.exec(select(Score).where(where_clause).order_by(order_by).limit(limit).offset(offset))
|
||||
).all()
|
||||
if not scores:
|
||||
return []
|
||||
@@ -371,18 +336,14 @@ async def get_user_scores(
|
||||
]
|
||||
|
||||
# 异步缓存结果
|
||||
asyncio.create_task(
|
||||
cache_service.cache_user_scores(
|
||||
user_id, type, score_responses, mode, limit, offset, cache_expire
|
||||
)
|
||||
background_task.add_task(
|
||||
cache_service.cache_user_scores, user_id, type, score_responses, mode, limit, offset, cache_expire
|
||||
)
|
||||
|
||||
return score_responses
|
||||
|
||||
|
||||
@router.get(
|
||||
"/users/{user}/recent_activity", tags=["用户"], response_model=list[EventResp]
|
||||
)
|
||||
@router.get("/users/{user}/recent_activity", tags=["用户"], response_model=list[EventResp])
|
||||
async def get_user_events(
|
||||
session: Database,
|
||||
user: int,
|
||||
|
||||
@@ -59,9 +59,7 @@ class CacheScheduler:
|
||||
# 从配置文件获取间隔设置
|
||||
check_interval = 5 * 60 # 5分钟检查间隔
|
||||
beatmap_cache_interval = 30 * 60 # 30分钟beatmap缓存间隔
|
||||
ranking_cache_interval = (
|
||||
settings.ranking_cache_refresh_interval_minutes * 60
|
||||
) # 从配置读取
|
||||
ranking_cache_interval = settings.ranking_cache_refresh_interval_minutes * 60 # 从配置读取
|
||||
user_cache_interval = 15 * 60 # 15分钟用户缓存预加载间隔
|
||||
user_cleanup_interval = 60 * 60 # 60分钟用户缓存清理间隔
|
||||
|
||||
|
||||
@@ -5,9 +5,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
from app.config import settings
|
||||
from app.dependencies.database import engine
|
||||
from app.log import logger
|
||||
from app.service.database_cleanup_service import DatabaseCleanupService
|
||||
@@ -51,16 +49,16 @@ class DatabaseCleanupScheduler:
|
||||
try:
|
||||
# 每小时运行一次清理
|
||||
await asyncio.sleep(3600) # 3600秒 = 1小时
|
||||
|
||||
|
||||
if not self.running:
|
||||
break
|
||||
|
||||
|
||||
await self._run_cleanup()
|
||||
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Database cleanup scheduler error: {str(e)}")
|
||||
logger.error(f"Database cleanup scheduler error: {e!s}")
|
||||
# 发生错误后等待5分钟再继续
|
||||
await asyncio.sleep(300)
|
||||
|
||||
@@ -69,20 +67,20 @@ class DatabaseCleanupScheduler:
|
||||
try:
|
||||
async with AsyncSession(engine) as db:
|
||||
logger.debug("Starting scheduled database cleanup...")
|
||||
|
||||
|
||||
# 清理过期的验证码
|
||||
expired_codes = await DatabaseCleanupService.cleanup_expired_verification_codes(db)
|
||||
|
||||
|
||||
# 清理过期的登录会话
|
||||
expired_sessions = await DatabaseCleanupService.cleanup_expired_login_sessions(db)
|
||||
|
||||
|
||||
# 只在有清理记录时输出总结
|
||||
total_cleaned = expired_codes + expired_sessions
|
||||
if total_cleaned > 0:
|
||||
logger.debug(f"Scheduled cleanup completed - codes: {expired_codes}, sessions: {expired_sessions}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during scheduled database cleanup: {str(e)}")
|
||||
logger.error(f"Error during scheduled database cleanup: {e!s}")
|
||||
|
||||
async def run_manual_cleanup(self):
|
||||
"""手动运行完整清理"""
|
||||
@@ -95,7 +93,7 @@ class DatabaseCleanupScheduler:
|
||||
logger.debug(f"Manual cleanup completed, total records cleaned: {total}")
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"Error during manual database cleanup: {str(e)}")
|
||||
logger.error(f"Error during manual database cleanup: {e!s}")
|
||||
return {}
|
||||
|
||||
|
||||
|
||||
@@ -63,10 +63,7 @@ class BeatmapCacheService:
|
||||
if preload_tasks:
|
||||
results = await asyncio.gather(*preload_tasks, return_exceptions=True)
|
||||
success_count = sum(1 for r in results if r is True)
|
||||
logger.info(
|
||||
f"Preloaded {success_count}/{len(preload_tasks)} "
|
||||
f"beatmaps successfully"
|
||||
)
|
||||
logger.info(f"Preloaded {success_count}/{len(preload_tasks)} beatmaps successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during beatmap preloading: {e}")
|
||||
@@ -119,9 +116,7 @@ class BeatmapCacheService:
|
||||
|
||||
return {
|
||||
"cached_beatmaps": len(keys),
|
||||
"estimated_total_size_mb": (
|
||||
round(total_size / 1024 / 1024, 2) if total_size > 0 else 0
|
||||
),
|
||||
"estimated_total_size_mb": (round(total_size / 1024 / 1024, 2) if total_size > 0 else 0),
|
||||
"preloading": self._preloading,
|
||||
}
|
||||
except Exception as e:
|
||||
@@ -155,9 +150,7 @@ def get_beatmap_cache_service(redis: Redis, fetcher: "Fetcher") -> BeatmapCacheS
|
||||
return _cache_service
|
||||
|
||||
|
||||
async def schedule_preload_task(
|
||||
session: AsyncSession, redis: Redis, fetcher: "Fetcher"
|
||||
):
|
||||
async def schedule_preload_task(session: AsyncSession, redis: Redis, fetcher: "Fetcher"):
|
||||
"""
|
||||
定时预加载任务
|
||||
"""
|
||||
|
||||
@@ -192,22 +192,16 @@ class BeatmapDownloadService:
|
||||
healthy_endpoints.sort(key=lambda x: x.priority)
|
||||
return healthy_endpoints
|
||||
|
||||
def get_download_url(
|
||||
self, beatmapset_id: int, no_video: bool, is_china: bool
|
||||
) -> str:
|
||||
def get_download_url(self, beatmapset_id: int, no_video: bool, is_china: bool) -> str:
|
||||
"""获取下载URL,带负载均衡和故障转移"""
|
||||
healthy_endpoints = self.get_healthy_endpoints(is_china)
|
||||
|
||||
if not healthy_endpoints:
|
||||
# 如果没有健康的端点,记录错误并回退到所有端点中优先级最高的
|
||||
logger.error(f"No healthy endpoints available for is_china={is_china}")
|
||||
endpoints = (
|
||||
self.china_endpoints if is_china else self.international_endpoints
|
||||
)
|
||||
endpoints = self.china_endpoints if is_china else self.international_endpoints
|
||||
if not endpoints:
|
||||
raise HTTPException(
|
||||
status_code=503, detail="No download endpoints available"
|
||||
)
|
||||
raise HTTPException(status_code=503, detail="No download endpoints available")
|
||||
endpoint = min(endpoints, key=lambda x: x.priority)
|
||||
else:
|
||||
# 使用第一个健康的端点(已按优先级排序)
|
||||
@@ -218,9 +212,7 @@ class BeatmapDownloadService:
|
||||
video_type = "novideo" if no_video else "full"
|
||||
return endpoint.url_template.format(type=video_type, sid=beatmapset_id)
|
||||
elif endpoint.name == "Nerinyan":
|
||||
return endpoint.url_template.format(
|
||||
sid=beatmapset_id, no_video="true" if no_video else "false"
|
||||
)
|
||||
return endpoint.url_template.format(sid=beatmapset_id, no_video="true" if no_video else "false")
|
||||
elif endpoint.name == "OsuDirect":
|
||||
# osu.direct 似乎没有no_video参数,直接使用基础URL
|
||||
return endpoint.url_template.format(sid=beatmapset_id)
|
||||
@@ -239,9 +231,7 @@ class BeatmapDownloadService:
|
||||
for name, status in self.endpoint_status.items():
|
||||
status_info["endpoints"][name] = {
|
||||
"healthy": status.is_healthy,
|
||||
"last_check": status.last_check.isoformat()
|
||||
if status.last_check
|
||||
else None,
|
||||
"last_check": status.last_check.isoformat() if status.last_check else None,
|
||||
"consecutive_failures": status.consecutive_failures,
|
||||
"last_error": status.last_error,
|
||||
"priority": status.endpoint.priority,
|
||||
|
||||
@@ -11,9 +11,7 @@ from app.models.score import GameMode
|
||||
from sqlmodel import col, exists, select, update
|
||||
|
||||
|
||||
@get_scheduler().scheduled_job(
|
||||
"cron", hour=0, minute=0, second=0, id="calculate_user_rank"
|
||||
)
|
||||
@get_scheduler().scheduled_job("cron", hour=0, minute=0, second=0, id="calculate_user_rank")
|
||||
async def calculate_user_rank(is_today: bool = False):
|
||||
today = datetime.now(UTC).date()
|
||||
target_date = today if is_today else today - timedelta(days=1)
|
||||
|
||||
@@ -11,9 +11,7 @@ from sqlmodel import exists, select
|
||||
|
||||
async def create_banchobot():
|
||||
async with with_db() as session:
|
||||
is_exist = (
|
||||
await session.exec(select(exists()).where(User.id == BANCHOBOT_ID))
|
||||
).first()
|
||||
is_exist = (await session.exec(select(exists()).where(User.id == BANCHOBOT_ID))).first()
|
||||
if not is_exist:
|
||||
banchobot = User(
|
||||
username="BanchoBot",
|
||||
|
||||
@@ -82,8 +82,7 @@ async def daily_challenge_job():
|
||||
|
||||
if beatmap is None or ruleset_id is None:
|
||||
logger.warning(
|
||||
f"[DailyChallenge] Missing required data for daily challenge {now}."
|
||||
" Will try again in 5 minutes."
|
||||
f"[DailyChallenge] Missing required data for daily challenge {now}. Will try again in 5 minutes."
|
||||
)
|
||||
get_scheduler().add_job(
|
||||
daily_challenge_job,
|
||||
@@ -104,9 +103,7 @@ async def daily_challenge_job():
|
||||
else:
|
||||
allowed_mods_list = get_available_mods(ruleset_id_int, required_mods_list)
|
||||
|
||||
next_day = (now + timedelta(days=1)).replace(
|
||||
hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
next_day = (now + timedelta(days=1)).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
room = await create_daily_challenge_room(
|
||||
beatmap=beatmap_int,
|
||||
ruleset_id=ruleset_id_int,
|
||||
@@ -114,24 +111,13 @@ async def daily_challenge_job():
|
||||
allowed_mods=allowed_mods_list,
|
||||
duration=int((next_day - now - timedelta(minutes=2)).total_seconds() / 60),
|
||||
)
|
||||
await MetadataHubs.broadcast_call(
|
||||
"DailyChallengeUpdated", DailyChallengeInfo(room_id=room.id)
|
||||
)
|
||||
logger.success(
|
||||
"[DailyChallenge] Added today's daily challenge: "
|
||||
f"{beatmap=}, {ruleset_id=}, {required_mods=}"
|
||||
)
|
||||
await MetadataHubs.broadcast_call("DailyChallengeUpdated", DailyChallengeInfo(room_id=room.id))
|
||||
logger.success(f"[DailyChallenge] Added today's daily challenge: {beatmap=}, {ruleset_id=}, {required_mods=}")
|
||||
return
|
||||
except (ValueError, json.JSONDecodeError) as e:
|
||||
logger.warning(
|
||||
f"[DailyChallenge] Error processing daily challenge data: {e}"
|
||||
" Will try again in 5 minutes."
|
||||
)
|
||||
logger.warning(f"[DailyChallenge] Error processing daily challenge data: {e} Will try again in 5 minutes.")
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"[DailyChallenge] Unexpected error in daily challenge job: {e}"
|
||||
" Will try again in 5 minutes."
|
||||
)
|
||||
logger.exception(f"[DailyChallenge] Unexpected error in daily challenge job: {e} Will try again in 5 minutes.")
|
||||
get_scheduler().add_job(
|
||||
daily_challenge_job,
|
||||
"date",
|
||||
@@ -139,9 +125,7 @@ async def daily_challenge_job():
|
||||
)
|
||||
|
||||
|
||||
@get_scheduler().scheduled_job(
|
||||
"cron", hour=0, minute=1, second=0, id="daily_challenge_last_top"
|
||||
)
|
||||
@get_scheduler().scheduled_job("cron", hour=0, minute=1, second=0, id="daily_challenge_last_top")
|
||||
async def process_daily_challenge_top():
|
||||
async with with_db() as session:
|
||||
now = datetime.now(UTC)
|
||||
@@ -182,11 +166,7 @@ async def process_daily_challenge_top():
|
||||
await session.commit()
|
||||
del s
|
||||
|
||||
user_ids = (
|
||||
await session.exec(
|
||||
select(User.id).where(col(User.id).not_in(participated_users))
|
||||
)
|
||||
).all()
|
||||
user_ids = (await session.exec(select(User.id).where(col(User.id).not_in(participated_users)))).all()
|
||||
for id in user_ids:
|
||||
stats = await session.get(DailyChallengeStats, id)
|
||||
if stats is None: # not execute
|
||||
|
||||
@@ -4,14 +4,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, UTC, timedelta
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from app.database.email_verification import EmailVerification, LoginSession
|
||||
from app.log import logger
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy import and_
|
||||
|
||||
|
||||
class DatabaseCleanupService:
|
||||
@@ -21,211 +20,207 @@ class DatabaseCleanupService:
|
||||
async def cleanup_expired_verification_codes(db: AsyncSession) -> int:
|
||||
"""
|
||||
清理过期的邮件验证码
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
|
||||
|
||||
Returns:
|
||||
int: 清理的记录数
|
||||
"""
|
||||
try:
|
||||
# 查找过期的验证码记录
|
||||
current_time = datetime.now(UTC)
|
||||
|
||||
stmt = select(EmailVerification).where(
|
||||
EmailVerification.expires_at < current_time
|
||||
)
|
||||
|
||||
stmt = select(EmailVerification).where(EmailVerification.expires_at < current_time)
|
||||
result = await db.exec(stmt)
|
||||
expired_codes = result.all()
|
||||
|
||||
|
||||
# 删除过期的记录
|
||||
deleted_count = 0
|
||||
for code in expired_codes:
|
||||
await db.delete(code)
|
||||
deleted_count += 1
|
||||
|
||||
|
||||
await db.commit()
|
||||
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired email verification codes")
|
||||
|
||||
|
||||
return deleted_count
|
||||
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"[Cleanup Service] Error cleaning expired verification codes: {str(e)}")
|
||||
logger.error(f"[Cleanup Service] Error cleaning expired verification codes: {e!s}")
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
async def cleanup_expired_login_sessions(db: AsyncSession) -> int:
|
||||
"""
|
||||
清理过期的登录会话
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
|
||||
|
||||
Returns:
|
||||
int: 清理的记录数
|
||||
"""
|
||||
try:
|
||||
# 查找过期的登录会话记录
|
||||
current_time = datetime.now(UTC)
|
||||
|
||||
stmt = select(LoginSession).where(
|
||||
LoginSession.expires_at < current_time
|
||||
)
|
||||
|
||||
stmt = select(LoginSession).where(LoginSession.expires_at < current_time)
|
||||
result = await db.exec(stmt)
|
||||
expired_sessions = result.all()
|
||||
|
||||
|
||||
# 删除过期的记录
|
||||
deleted_count = 0
|
||||
for session in expired_sessions:
|
||||
await db.delete(session)
|
||||
deleted_count += 1
|
||||
|
||||
|
||||
await db.commit()
|
||||
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} expired login sessions")
|
||||
|
||||
|
||||
return deleted_count
|
||||
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"[Cleanup Service] Error cleaning expired login sessions: {str(e)}")
|
||||
logger.error(f"[Cleanup Service] Error cleaning expired login sessions: {e!s}")
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
async def cleanup_old_used_verification_codes(db: AsyncSession, days_old: int = 7) -> int:
|
||||
"""
|
||||
清理旧的已使用验证码记录
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
days_old: 清理多少天前的已使用记录,默认7天
|
||||
|
||||
|
||||
Returns:
|
||||
int: 清理的记录数
|
||||
"""
|
||||
try:
|
||||
# 查找指定天数前的已使用验证码记录
|
||||
cutoff_time = datetime.now(UTC) - timedelta(days=days_old)
|
||||
|
||||
stmt = select(EmailVerification).where(
|
||||
EmailVerification.is_used == True
|
||||
)
|
||||
|
||||
stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True))
|
||||
result = await db.exec(stmt)
|
||||
all_used_codes = result.all()
|
||||
|
||||
|
||||
# 筛选出过期的记录
|
||||
old_used_codes = [
|
||||
code for code in all_used_codes
|
||||
if code.used_at and code.used_at < cutoff_time
|
||||
]
|
||||
|
||||
old_used_codes = [code for code in all_used_codes if code.used_at and code.used_at < cutoff_time]
|
||||
|
||||
# 删除旧的已使用记录
|
||||
deleted_count = 0
|
||||
for code in old_used_codes:
|
||||
await db.delete(code)
|
||||
deleted_count += 1
|
||||
|
||||
|
||||
await db.commit()
|
||||
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} used verification codes older than {days_old} days")
|
||||
|
||||
logger.debug(
|
||||
f"[Cleanup Service] Cleaned up {deleted_count} used verification codes older than {days_old} days"
|
||||
)
|
||||
|
||||
return deleted_count
|
||||
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"[Cleanup Service] Error cleaning old used verification codes: {str(e)}")
|
||||
logger.error(f"[Cleanup Service] Error cleaning old used verification codes: {e!s}")
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
async def cleanup_old_verified_sessions(db: AsyncSession, days_old: int = 30) -> int:
|
||||
"""
|
||||
清理旧的已验证会话记录
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
days_old: 清理多少天前的已验证记录,默认30天
|
||||
|
||||
|
||||
Returns:
|
||||
int: 清理的记录数
|
||||
"""
|
||||
try:
|
||||
# 查找指定天数前的已验证会话记录
|
||||
cutoff_time = datetime.now(UTC) - timedelta(days=days_old)
|
||||
|
||||
stmt = select(LoginSession).where(
|
||||
LoginSession.is_verified == True
|
||||
)
|
||||
|
||||
stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
|
||||
result = await db.exec(stmt)
|
||||
all_verified_sessions = result.all()
|
||||
|
||||
|
||||
# 筛选出过期的记录
|
||||
old_verified_sessions = [
|
||||
session for session in all_verified_sessions
|
||||
session
|
||||
for session in all_verified_sessions
|
||||
if session.verified_at and session.verified_at < cutoff_time
|
||||
]
|
||||
|
||||
|
||||
# 删除旧的已验证记录
|
||||
deleted_count = 0
|
||||
for session in old_verified_sessions:
|
||||
await db.delete(session)
|
||||
deleted_count += 1
|
||||
|
||||
|
||||
await db.commit()
|
||||
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.debug(f"[Cleanup Service] Cleaned up {deleted_count} verified sessions older than {days_old} days")
|
||||
|
||||
logger.debug(
|
||||
f"[Cleanup Service] Cleaned up {deleted_count} verified sessions older than {days_old} days"
|
||||
)
|
||||
|
||||
return deleted_count
|
||||
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"[Cleanup Service] Error cleaning old verified sessions: {str(e)}")
|
||||
logger.error(f"[Cleanup Service] Error cleaning old verified sessions: {e!s}")
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
async def run_full_cleanup(db: AsyncSession) -> dict[str, int]:
|
||||
"""
|
||||
运行完整的清理流程
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
|
||||
|
||||
Returns:
|
||||
dict: 各项清理的结果统计
|
||||
"""
|
||||
results = {}
|
||||
|
||||
|
||||
# 清理过期的验证码
|
||||
results["expired_verification_codes"] = await DatabaseCleanupService.cleanup_expired_verification_codes(db)
|
||||
|
||||
|
||||
# 清理过期的登录会话
|
||||
results["expired_login_sessions"] = await DatabaseCleanupService.cleanup_expired_login_sessions(db)
|
||||
|
||||
|
||||
# 清理7天前的已使用验证码
|
||||
results["old_used_verification_codes"] = await DatabaseCleanupService.cleanup_old_used_verification_codes(db, 7)
|
||||
|
||||
|
||||
# 清理30天前的已验证会话
|
||||
results["old_verified_sessions"] = await DatabaseCleanupService.cleanup_old_verified_sessions(db, 30)
|
||||
|
||||
|
||||
total_cleaned = sum(results.values())
|
||||
if total_cleaned > 0:
|
||||
logger.debug(f"[Cleanup Service] Full cleanup completed, total cleaned: {total_cleaned} records - {results}")
|
||||
|
||||
logger.debug(
|
||||
f"[Cleanup Service] Full cleanup completed, total cleaned: {total_cleaned} records - {results}"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
async def get_cleanup_statistics(db: AsyncSession) -> dict[str, int]:
|
||||
"""
|
||||
获取清理统计信息
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
|
||||
|
||||
Returns:
|
||||
dict: 统计信息
|
||||
"""
|
||||
@@ -233,57 +228,54 @@ class DatabaseCleanupService:
|
||||
current_time = datetime.now(UTC)
|
||||
cutoff_7_days = current_time - timedelta(days=7)
|
||||
cutoff_30_days = current_time - timedelta(days=30)
|
||||
|
||||
|
||||
# 统计过期的验证码数量
|
||||
expired_codes_stmt = select(EmailVerification).where(
|
||||
EmailVerification.expires_at < current_time
|
||||
)
|
||||
expired_codes_stmt = select(EmailVerification).where(EmailVerification.expires_at < current_time)
|
||||
expired_codes_result = await db.exec(expired_codes_stmt)
|
||||
expired_codes_count = len(expired_codes_result.all())
|
||||
|
||||
|
||||
# 统计过期的登录会话数量
|
||||
expired_sessions_stmt = select(LoginSession).where(
|
||||
LoginSession.expires_at < current_time
|
||||
)
|
||||
expired_sessions_stmt = select(LoginSession).where(LoginSession.expires_at < current_time)
|
||||
expired_sessions_result = await db.exec(expired_sessions_stmt)
|
||||
expired_sessions_count = len(expired_sessions_result.all())
|
||||
|
||||
|
||||
# 统计7天前的已使用验证码数量
|
||||
old_used_codes_stmt = select(EmailVerification).where(
|
||||
EmailVerification.is_used == True
|
||||
)
|
||||
old_used_codes_stmt = select(EmailVerification).where(col(EmailVerification.is_used).is_(True))
|
||||
old_used_codes_result = await db.exec(old_used_codes_stmt)
|
||||
all_used_codes = old_used_codes_result.all()
|
||||
old_used_codes_count = len([
|
||||
code for code in all_used_codes
|
||||
if code.used_at and code.used_at < cutoff_7_days
|
||||
])
|
||||
|
||||
# 统计30天前的已验证会话数量
|
||||
old_verified_sessions_stmt = select(LoginSession).where(
|
||||
LoginSession.is_verified == True
|
||||
old_used_codes_count = len(
|
||||
[code for code in all_used_codes if code.used_at and code.used_at < cutoff_7_days]
|
||||
)
|
||||
|
||||
# 统计30天前的已验证会话数量
|
||||
old_verified_sessions_stmt = select(LoginSession).where(col(LoginSession.is_verified).is_(True))
|
||||
old_verified_sessions_result = await db.exec(old_verified_sessions_stmt)
|
||||
all_verified_sessions = old_verified_sessions_result.all()
|
||||
old_verified_sessions_count = len([
|
||||
session for session in all_verified_sessions
|
||||
if session.verified_at and session.verified_at < cutoff_30_days
|
||||
])
|
||||
|
||||
old_verified_sessions_count = len(
|
||||
[
|
||||
session
|
||||
for session in all_verified_sessions
|
||||
if session.verified_at and session.verified_at < cutoff_30_days
|
||||
]
|
||||
)
|
||||
|
||||
return {
|
||||
"expired_verification_codes": expired_codes_count,
|
||||
"expired_login_sessions": expired_sessions_count,
|
||||
"old_used_verification_codes": old_used_codes_count,
|
||||
"old_verified_sessions": old_verified_sessions_count,
|
||||
"total_cleanable": expired_codes_count + expired_sessions_count + old_used_codes_count + old_verified_sessions_count
|
||||
"total_cleanable": expired_codes_count
|
||||
+ expired_sessions_count
|
||||
+ old_used_codes_count
|
||||
+ old_verified_sessions_count,
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Cleanup Service] Error getting cleanup statistics: {str(e)}")
|
||||
logger.error(f"[Cleanup Service] Error getting cleanup statistics: {e!s}")
|
||||
return {
|
||||
"expired_verification_codes": 0,
|
||||
"expired_login_sessions": 0,
|
||||
"old_used_verification_codes": 0,
|
||||
"old_verified_sessions": 0,
|
||||
"total_cleanable": 0
|
||||
"total_cleanable": 0,
|
||||
}
|
||||
|
||||
@@ -8,17 +8,18 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
from datetime import datetime
|
||||
import json
|
||||
import uuid
|
||||
import smtplib
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from typing import Dict, Any, Optional
|
||||
import redis as sync_redis # 添加同步Redis导入
|
||||
from email.mime.text import MIMEText
|
||||
import json
|
||||
import smtplib
|
||||
from typing import Any
|
||||
import uuid
|
||||
|
||||
from app.config import settings
|
||||
from app.dependencies.database import redis_message_client # 使用同步Redis客户端
|
||||
from app.log import logger
|
||||
from app.utils import bg_tasks # 添加同步Redis导入
|
||||
|
||||
import redis as sync_redis
|
||||
|
||||
|
||||
class EmailQueue:
|
||||
@@ -30,14 +31,14 @@ class EmailQueue:
|
||||
self._processing = False
|
||||
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=2)
|
||||
self._retry_limit = 3 # 重试次数限制
|
||||
|
||||
|
||||
# 邮件配置
|
||||
self.smtp_server = getattr(settings, 'smtp_server', 'localhost')
|
||||
self.smtp_port = getattr(settings, 'smtp_port', 587)
|
||||
self.smtp_username = getattr(settings, 'smtp_username', '')
|
||||
self.smtp_password = getattr(settings, 'smtp_password', '')
|
||||
self.from_email = getattr(settings, 'from_email', 'noreply@example.com')
|
||||
self.from_name = getattr(settings, 'from_name', 'osu! server')
|
||||
self.smtp_server = getattr(settings, "smtp_server", "localhost")
|
||||
self.smtp_port = getattr(settings, "smtp_port", 587)
|
||||
self.smtp_username = getattr(settings, "smtp_username", "")
|
||||
self.smtp_password = getattr(settings, "smtp_password", "")
|
||||
self.from_email = getattr(settings, "from_email", "noreply@example.com")
|
||||
self.from_name = getattr(settings, "from_name", "osu! server")
|
||||
|
||||
async def _run_in_executor(self, func, *args):
|
||||
"""在线程池中运行同步操作"""
|
||||
@@ -48,7 +49,7 @@ class EmailQueue:
|
||||
"""启动邮件处理任务"""
|
||||
if not self._processing:
|
||||
self._processing = True
|
||||
asyncio.create_task(self._process_email_queue())
|
||||
bg_tasks.add_task(self._process_email_queue)
|
||||
logger.info("Email queue processing started")
|
||||
|
||||
async def stop_processing(self):
|
||||
@@ -56,27 +57,29 @@ class EmailQueue:
|
||||
self._processing = False
|
||||
logger.info("Email queue processing stopped")
|
||||
|
||||
async def enqueue_email(self,
|
||||
to_email: str,
|
||||
subject: str,
|
||||
content: str,
|
||||
html_content: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None) -> str:
|
||||
async def enqueue_email(
|
||||
self,
|
||||
to_email: str,
|
||||
subject: str,
|
||||
content: str,
|
||||
html_content: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
将邮件加入队列等待发送
|
||||
|
||||
|
||||
Args:
|
||||
to_email: 收件人邮箱地址
|
||||
subject: 邮件主题
|
||||
content: 邮件纯文本内容
|
||||
html_content: 邮件HTML内容(如果有)
|
||||
metadata: 额外元数据(如密码重置ID等)
|
||||
|
||||
|
||||
Returns:
|
||||
邮件任务ID
|
||||
"""
|
||||
email_id = str(uuid.uuid4())
|
||||
|
||||
|
||||
email_data = {
|
||||
"id": email_id,
|
||||
"to_email": to_email,
|
||||
@@ -86,125 +89,117 @@ class EmailQueue:
|
||||
"metadata": json.dumps(metadata) if metadata else "{}",
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"status": "pending", # pending, sending, sent, failed
|
||||
"retry_count": "0"
|
||||
"retry_count": "0",
|
||||
}
|
||||
|
||||
|
||||
# 将邮件数据存入Redis
|
||||
await self._run_in_executor(
|
||||
lambda: self.redis.hset(f"email:{email_id}", mapping=email_data)
|
||||
)
|
||||
|
||||
await self._run_in_executor(lambda: self.redis.hset(f"email:{email_id}", mapping=email_data))
|
||||
|
||||
# 设置24小时过期(防止数据堆积)
|
||||
await self._run_in_executor(
|
||||
self.redis.expire, f"email:{email_id}", 86400
|
||||
)
|
||||
|
||||
await self._run_in_executor(self.redis.expire, f"email:{email_id}", 86400)
|
||||
|
||||
# 加入发送队列
|
||||
await self._run_in_executor(
|
||||
self.redis.lpush, "email_queue", email_id
|
||||
)
|
||||
|
||||
await self._run_in_executor(self.redis.lpush, "email_queue", email_id)
|
||||
|
||||
logger.info(f"Email enqueued with id: {email_id} to {to_email}")
|
||||
return email_id
|
||||
|
||||
async def get_email_status(self, email_id: str) -> Dict[str, Any]:
|
||||
async def get_email_status(self, email_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
获取邮件发送状态
|
||||
|
||||
|
||||
Args:
|
||||
email_id: 邮件任务ID
|
||||
|
||||
|
||||
Returns:
|
||||
邮件任务状态信息
|
||||
"""
|
||||
email_data = await self._run_in_executor(
|
||||
self.redis.hgetall, f"email:{email_id}"
|
||||
)
|
||||
|
||||
email_data = await self._run_in_executor(self.redis.hgetall, f"email:{email_id}")
|
||||
|
||||
# 解码Redis返回的字节数据
|
||||
if email_data:
|
||||
return {
|
||||
k.decode("utf-8") if isinstance(k, bytes) else k:
|
||||
v.decode("utf-8") if isinstance(v, bytes) else v
|
||||
k.decode("utf-8") if isinstance(k, bytes) else k: v.decode("utf-8") if isinstance(v, bytes) else v
|
||||
for k, v in email_data.items()
|
||||
}
|
||||
|
||||
|
||||
return {"status": "not_found"}
|
||||
|
||||
async def _process_email_queue(self):
|
||||
"""处理邮件队列"""
|
||||
logger.info("Starting email queue processor")
|
||||
|
||||
|
||||
while self._processing:
|
||||
try:
|
||||
# 从队列获取邮件ID
|
||||
def brpop_operation():
|
||||
return self.redis.brpop(["email_queue"], timeout=5)
|
||||
|
||||
|
||||
result = await self._run_in_executor(brpop_operation)
|
||||
|
||||
|
||||
if not result:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
|
||||
# 解包返回结果(列表名和值)
|
||||
queue_name, email_id = result
|
||||
if isinstance(email_id, bytes):
|
||||
email_id = email_id.decode("utf-8")
|
||||
|
||||
|
||||
# 获取邮件数据
|
||||
email_data = await self.get_email_status(email_id)
|
||||
if email_data.get("status") == "not_found":
|
||||
logger.warning(f"Email data not found for id: {email_id}")
|
||||
continue
|
||||
|
||||
|
||||
# 更新状态为发送中
|
||||
await self._run_in_executor(
|
||||
self.redis.hset, f"email:{email_id}", "status", "sending"
|
||||
)
|
||||
|
||||
await self._run_in_executor(self.redis.hset, f"email:{email_id}", "status", "sending")
|
||||
|
||||
# 尝试发送邮件
|
||||
success = await self._send_email(email_data)
|
||||
|
||||
|
||||
if success:
|
||||
# 更新状态为已发送
|
||||
await self._run_in_executor(self.redis.hset, f"email:{email_id}", "status", "sent")
|
||||
await self._run_in_executor(
|
||||
self.redis.hset, f"email:{email_id}", "status", "sent"
|
||||
)
|
||||
await self._run_in_executor(
|
||||
self.redis.hset, f"email:{email_id}", "sent_at", datetime.now().isoformat()
|
||||
self.redis.hset,
|
||||
f"email:{email_id}",
|
||||
"sent_at",
|
||||
datetime.now().isoformat(),
|
||||
)
|
||||
logger.info(f"Email {email_id} sent successfully to {email_data.get('to_email')}")
|
||||
else:
|
||||
# 计算重试次数
|
||||
retry_count = int(email_data.get("retry_count", "0")) + 1
|
||||
|
||||
|
||||
if retry_count <= self._retry_limit:
|
||||
# 重新入队,稍后重试
|
||||
await self._run_in_executor(
|
||||
self.redis.hset, f"email:{email_id}", "retry_count", str(retry_count)
|
||||
self.redis.hset,
|
||||
f"email:{email_id}",
|
||||
"retry_count",
|
||||
str(retry_count),
|
||||
)
|
||||
await self._run_in_executor(self.redis.hset, f"email:{email_id}", "status", "pending")
|
||||
await self._run_in_executor(
|
||||
self.redis.hset, f"email:{email_id}", "status", "pending"
|
||||
self.redis.hset,
|
||||
f"email:{email_id}",
|
||||
"last_retry",
|
||||
datetime.now().isoformat(),
|
||||
)
|
||||
await self._run_in_executor(
|
||||
self.redis.hset, f"email:{email_id}", "last_retry", datetime.now().isoformat()
|
||||
)
|
||||
|
||||
|
||||
# 延迟重试(使用指数退避)
|
||||
delay = 60 * (2 ** (retry_count - 1)) # 1分钟,2分钟,4分钟...
|
||||
|
||||
|
||||
# 创建延迟任务
|
||||
asyncio.create_task(self._delayed_retry(email_id, delay))
|
||||
|
||||
bg_tasks.add_task(self._delayed_retry, email_id, delay)
|
||||
|
||||
logger.warning(f"Email {email_id} will be retried in {delay} seconds (attempt {retry_count})")
|
||||
else:
|
||||
# 超过重试次数,标记为失败
|
||||
await self._run_in_executor(
|
||||
self.redis.hset, f"email:{email_id}", "status", "failed"
|
||||
)
|
||||
await self._run_in_executor(self.redis.hset, f"email:{email_id}", "status", "failed")
|
||||
logger.error(f"Email {email_id} failed after {retry_count} attempts")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing email queue: {e}")
|
||||
await asyncio.sleep(5) # 出错后等待5秒
|
||||
@@ -212,53 +207,51 @@ class EmailQueue:
|
||||
async def _delayed_retry(self, email_id: str, delay: int):
|
||||
"""延迟重试发送邮件"""
|
||||
await asyncio.sleep(delay)
|
||||
await self._run_in_executor(
|
||||
self.redis.lpush, "email_queue", email_id
|
||||
)
|
||||
await self._run_in_executor(self.redis.lpush, "email_queue", email_id)
|
||||
logger.info(f"Re-queued email {email_id} for retry after {delay} seconds")
|
||||
|
||||
async def _send_email(self, email_data: Dict[str, Any]) -> bool:
|
||||
async def _send_email(self, email_data: dict[str, Any]) -> bool:
|
||||
"""
|
||||
实际发送邮件
|
||||
|
||||
|
||||
Args:
|
||||
email_data: 邮件数据
|
||||
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
try:
|
||||
# 如果邮件发送功能被禁用,则只记录日志
|
||||
if not getattr(settings, 'enable_email_sending', True):
|
||||
if not getattr(settings, "enable_email_sending", True):
|
||||
logger.info(f"[Mock Email] Would send to {email_data.get('to_email')}: {email_data.get('subject')}")
|
||||
return True
|
||||
|
||||
|
||||
# 创建邮件
|
||||
msg = MIMEMultipart('alternative')
|
||||
msg['From'] = f"{self.from_name} <{self.from_email}>"
|
||||
msg['To'] = email_data.get('to_email', '')
|
||||
msg['Subject'] = email_data.get('subject', '')
|
||||
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg["From"] = f"{self.from_name} <{self.from_email}>"
|
||||
msg["To"] = email_data.get("to_email", "")
|
||||
msg["Subject"] = email_data.get("subject", "")
|
||||
|
||||
# 添加纯文本内容
|
||||
content = email_data.get('content', '')
|
||||
content = email_data.get("content", "")
|
||||
if content:
|
||||
msg.attach(MIMEText(content, 'plain', 'utf-8'))
|
||||
|
||||
msg.attach(MIMEText(content, "plain", "utf-8"))
|
||||
|
||||
# 添加HTML内容(如果有)
|
||||
html_content = email_data.get('html_content', '')
|
||||
html_content = email_data.get("html_content", "")
|
||||
if html_content:
|
||||
msg.attach(MIMEText(html_content, 'html', 'utf-8'))
|
||||
|
||||
msg.attach(MIMEText(html_content, "html", "utf-8"))
|
||||
|
||||
# 发送邮件
|
||||
with smtplib.SMTP(self.smtp_server, self.smtp_port) as server:
|
||||
if self.smtp_username and self.smtp_password:
|
||||
server.starttls()
|
||||
server.login(self.smtp_username, self.smtp_password)
|
||||
|
||||
|
||||
server.send_message(msg)
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send email: {e}")
|
||||
return False
|
||||
@@ -267,10 +260,12 @@ class EmailQueue:
|
||||
# 全局邮件队列实例
|
||||
email_queue = EmailQueue()
|
||||
|
||||
|
||||
# 在应用启动时调用
|
||||
async def start_email_processor():
|
||||
await email_queue.start_processing()
|
||||
|
||||
|
||||
# 在应用关闭时调用
|
||||
async def stop_email_processor():
|
||||
await email_queue.stop_processing()
|
||||
|
||||
@@ -4,13 +4,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import smtplib
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
import secrets
|
||||
import smtplib
|
||||
import string
|
||||
from datetime import datetime, UTC, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from app.config import settings
|
||||
from app.log import logger
|
||||
@@ -18,28 +16,28 @@ from app.log import logger
|
||||
|
||||
class EmailService:
|
||||
"""邮件发送服务"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.smtp_server = getattr(settings, 'smtp_server', 'localhost')
|
||||
self.smtp_port = getattr(settings, 'smtp_port', 587)
|
||||
self.smtp_username = getattr(settings, 'smtp_username', '')
|
||||
self.smtp_password = getattr(settings, 'smtp_password', '')
|
||||
self.from_email = getattr(settings, 'from_email', 'noreply@example.com')
|
||||
self.from_name = getattr(settings, 'from_name', 'osu! server')
|
||||
|
||||
self.smtp_server = getattr(settings, "smtp_server", "localhost")
|
||||
self.smtp_port = getattr(settings, "smtp_port", 587)
|
||||
self.smtp_username = getattr(settings, "smtp_username", "")
|
||||
self.smtp_password = getattr(settings, "smtp_password", "")
|
||||
self.from_email = getattr(settings, "from_email", "noreply@example.com")
|
||||
self.from_name = getattr(settings, "from_name", "osu! server")
|
||||
|
||||
def generate_verification_code(self) -> str:
|
||||
"""生成8位验证码"""
|
||||
# 只使用数字,避免混淆
|
||||
return ''.join(secrets.choice(string.digits) for _ in range(8))
|
||||
|
||||
return "".join(secrets.choice(string.digits) for _ in range(8))
|
||||
|
||||
async def send_verification_email(self, email: str, code: str, username: str) -> bool:
|
||||
"""发送验证邮件"""
|
||||
try:
|
||||
msg = MIMEMultipart()
|
||||
msg['From'] = f"{self.from_name} <{self.from_email}>"
|
||||
msg['To'] = email
|
||||
msg['Subject'] = "邮箱验证 - Email Verification"
|
||||
|
||||
msg["From"] = f"{self.from_name} <{self.from_email}>"
|
||||
msg["To"] = email
|
||||
msg["Subject"] = "邮箱验证 - Email Verification"
|
||||
|
||||
# HTML 邮件内容
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
@@ -101,15 +99,15 @@ class EmailService:
|
||||
<h1>osu! 邮箱验证</h1>
|
||||
<p>Email Verification</p>
|
||||
</div>
|
||||
|
||||
|
||||
<div class="content">
|
||||
<h2>你好 {username}!</h2>
|
||||
<p>感谢你注册我们的 osu! 服务器。为了完成账户验证,请输入以下验证码:</p>
|
||||
|
||||
|
||||
<div class="code">{code}</div>
|
||||
|
||||
|
||||
<p>这个验证码将在 <strong>10 分钟后过期</strong>。</p>
|
||||
|
||||
|
||||
<div class="warning">
|
||||
<strong>注意:</strong>
|
||||
<ul>
|
||||
@@ -118,19 +116,19 @@ class EmailService:
|
||||
<li>验证码只能使用一次</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
|
||||
<p>如果你有任何问题,请联系我们的支持团队。</p>
|
||||
|
||||
|
||||
<hr style="border: none; border-top: 1px solid #ddd; margin: 20px 0;">
|
||||
|
||||
|
||||
<h3>Hello {username}!</h3>
|
||||
<p>Thank you for registering on our osu! server. To complete your account verification, please enter the following verification code:</p>
|
||||
|
||||
|
||||
<p>This verification code will expire in <strong>10 minutes</strong>.</p>
|
||||
|
||||
|
||||
<p><strong>Important:</strong> Do not share this verification code with anyone. If you did not request this code, please ignore this email.</p>
|
||||
</div>
|
||||
|
||||
|
||||
<div class="footer">
|
||||
<p>© 2025 g0v0! Private Server. 此邮件由系统自动发送,请勿回复。</p>
|
||||
<p>This email was sent automatically, please do not reply.</p>
|
||||
@@ -138,26 +136,26 @@ class EmailService:
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
msg.attach(MIMEText(html_content, 'html', 'utf-8'))
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
msg.attach(MIMEText(html_content, "html", "utf-8"))
|
||||
|
||||
# 发送邮件
|
||||
if not settings.enable_email_sending:
|
||||
# 邮件发送功能禁用时只记录日志,不实际发送
|
||||
logger.info(f"[Email Verification] Mock sending verification code to {email}: {code}")
|
||||
return True
|
||||
|
||||
|
||||
with smtplib.SMTP(self.smtp_server, self.smtp_port) as server:
|
||||
if self.smtp_username and self.smtp_password:
|
||||
server.starttls()
|
||||
server.login(self.smtp_username, self.smtp_password)
|
||||
|
||||
|
||||
server.send_message(msg)
|
||||
|
||||
|
||||
logger.info(f"[Email Verification] Successfully sent verification code to {email}")
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Email Verification] Failed to send email: {e}")
|
||||
return False
|
||||
|
||||
@@ -4,40 +4,38 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
import secrets
|
||||
import string
|
||||
from datetime import datetime, UTC, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from app.database.email_verification import EmailVerification, LoginSession
|
||||
from app.service.email_service import email_service
|
||||
from app.service.email_queue import email_queue # 导入邮件队列
|
||||
from app.log import logger
|
||||
from app.config import settings
|
||||
from app.database.email_verification import EmailVerification, LoginSession
|
||||
from app.log import logger
|
||||
from app.service.email_queue import email_queue # 导入邮件队列
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel import select
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class EmailVerificationService:
|
||||
"""邮件验证服务"""
|
||||
|
||||
|
||||
@staticmethod
|
||||
def generate_verification_code() -> str:
|
||||
"""生成8位验证码"""
|
||||
return ''.join(secrets.choice(string.digits) for _ in range(8))
|
||||
|
||||
return "".join(secrets.choice(string.digits) for _ in range(8))
|
||||
|
||||
@staticmethod
|
||||
async def send_verification_email_via_queue(email: str, code: str, username: str, user_id: int) -> bool:
|
||||
"""使用邮件队列发送验证邮件
|
||||
|
||||
|
||||
Args:
|
||||
email: 接收验证码的邮箱地址
|
||||
code: 验证码
|
||||
username: 用户名
|
||||
user_id: 用户ID
|
||||
|
||||
|
||||
Returns:
|
||||
是否成功将邮件加入队列
|
||||
"""
|
||||
@@ -103,15 +101,15 @@ class EmailVerificationService:
|
||||
<h1>osu! 邮箱验证</h1>
|
||||
<p>Email Verification</p>
|
||||
</div>
|
||||
|
||||
|
||||
<div class="content">
|
||||
<h2>你好 {username}!</h2>
|
||||
<p>请使用以下验证码验证您的账户:</p>
|
||||
|
||||
|
||||
<div class="code">{code}</div>
|
||||
|
||||
|
||||
<p>验证码将在 <strong>10 分钟内有效</strong>。</p>
|
||||
|
||||
|
||||
<div class="warning">
|
||||
<p><strong>重要提示:</strong></p>
|
||||
<ul>
|
||||
@@ -120,17 +118,17 @@ class EmailVerificationService:
|
||||
<li>为了账户安全,请勿在其他网站使用相同的密码</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
|
||||
<hr style="border: none; border-top: 1px solid #ddd; margin: 20px 0;">
|
||||
|
||||
|
||||
<h3>Hello {username}!</h3>
|
||||
<p>Please use the following verification code to verify your account:</p>
|
||||
|
||||
|
||||
<p>This verification code will be valid for <strong>10 minutes</strong>.</p>
|
||||
|
||||
|
||||
<p><strong>Important:</strong> Do not share this verification code with anyone. If you did not request this code, please ignore this email.</p>
|
||||
</div>
|
||||
|
||||
|
||||
<div class="footer">
|
||||
<p>© 2025 g0v0! Private Server. 此邮件由系统自动发送,请勿回复。</p>
|
||||
<p>This email was sent automatically, please do not reply.</p>
|
||||
@@ -138,8 +136,8 @@ class EmailVerificationService:
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
# 纯文本备用内容
|
||||
plain_content = f"""
|
||||
你好 {username}!
|
||||
@@ -162,34 +160,30 @@ This verification code will be valid for 10 minutes.
|
||||
© 2025 g0v0! Private Server. 此邮件由系统自动发送,请勿回复。
|
||||
This email was sent automatically, please do not reply.
|
||||
"""
|
||||
|
||||
|
||||
# 将邮件加入队列
|
||||
subject = "邮箱验证 - Email Verification"
|
||||
metadata = {
|
||||
"type": "email_verification",
|
||||
"user_id": user_id,
|
||||
"code": code
|
||||
}
|
||||
|
||||
metadata = {"type": "email_verification", "user_id": user_id, "code": code}
|
||||
|
||||
await email_queue.enqueue_email(
|
||||
to_email=email,
|
||||
subject=subject,
|
||||
content=plain_content,
|
||||
html_content=html_content,
|
||||
metadata=metadata
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Email Verification] Failed to enqueue email: {e}")
|
||||
return False
|
||||
|
||||
|
||||
@staticmethod
|
||||
def generate_session_token() -> str:
|
||||
"""生成会话令牌"""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def create_verification_record(
|
||||
db: AsyncSession,
|
||||
@@ -197,27 +191,27 @@ This email was sent automatically, please do not reply.
|
||||
user_id: int,
|
||||
email: str,
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None
|
||||
user_agent: str | None = None,
|
||||
) -> tuple[EmailVerification, str]:
|
||||
"""创建邮件验证记录"""
|
||||
|
||||
|
||||
# 检查是否有未过期的验证码
|
||||
existing_result = await db.exec(
|
||||
select(EmailVerification).where(
|
||||
EmailVerification.user_id == user_id,
|
||||
EmailVerification.is_used == False,
|
||||
EmailVerification.expires_at > datetime.now(UTC)
|
||||
col(EmailVerification.is_used).is_(False),
|
||||
EmailVerification.expires_at > datetime.now(UTC),
|
||||
)
|
||||
)
|
||||
existing = existing_result.first()
|
||||
|
||||
|
||||
if existing:
|
||||
# 如果有未过期的验证码,直接返回
|
||||
return existing, existing.verification_code
|
||||
|
||||
|
||||
# 生成新的验证码
|
||||
code = EmailVerificationService.generate_verification_code()
|
||||
|
||||
|
||||
# 创建验证记录
|
||||
verification = EmailVerification(
|
||||
user_id=user_id,
|
||||
@@ -225,23 +219,23 @@ This email was sent automatically, please do not reply.
|
||||
verification_code=code,
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10), # 10分钟过期
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
|
||||
db.add(verification)
|
||||
await db.commit()
|
||||
await db.refresh(verification)
|
||||
|
||||
|
||||
# 存储到 Redis(用于快速验证)
|
||||
await redis.setex(
|
||||
f"email_verification:{user_id}:{code}",
|
||||
600, # 10分钟过期
|
||||
str(verification.id) if verification.id else "0"
|
||||
str(verification.id) if verification.id else "0",
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"[Email Verification] Created verification code for user {user_id}: {code}")
|
||||
return verification, code
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def send_verification_email(
|
||||
db: AsyncSession,
|
||||
@@ -250,7 +244,7 @@ This email was sent automatically, please do not reply.
|
||||
username: str,
|
||||
email: str,
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None
|
||||
user_agent: str | None = None,
|
||||
) -> bool:
|
||||
"""发送验证邮件"""
|
||||
try:
|
||||
@@ -258,33 +252,38 @@ This email was sent automatically, please do not reply.
|
||||
if not settings.enable_email_verification:
|
||||
logger.debug(f"[Email Verification] Email verification is disabled, skipping for user {user_id}")
|
||||
return True # 返回成功,但不执行验证流程
|
||||
|
||||
|
||||
# 创建验证记录
|
||||
verification, code = await EmailVerificationService.create_verification_record(
|
||||
(
|
||||
verification,
|
||||
code,
|
||||
) = await EmailVerificationService.create_verification_record(
|
||||
db, redis, user_id, email, ip_address, user_agent
|
||||
)
|
||||
|
||||
|
||||
# 使用邮件队列发送验证邮件
|
||||
success = await EmailVerificationService.send_verification_email_via_queue(email, code, username, user_id)
|
||||
|
||||
|
||||
if success:
|
||||
logger.info(f"[Email Verification] Successfully enqueued verification email to {email} (user: {username})")
|
||||
logger.info(
|
||||
f"[Email Verification] Successfully enqueued verification email to {email} (user: {username})"
|
||||
)
|
||||
return True
|
||||
else:
|
||||
logger.error(f"[Email Verification] Failed to enqueue verification email: {email} (user: {username})")
|
||||
return False
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Email Verification] Exception during sending verification email: {e}")
|
||||
return False
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def verify_code(
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
user_id: int,
|
||||
code: str,
|
||||
ip_address: str | None = None
|
||||
ip_address: str | None = None,
|
||||
) -> tuple[bool, str]:
|
||||
"""验证验证码"""
|
||||
try:
|
||||
@@ -294,46 +293,46 @@ This email was sent automatically, please do not reply.
|
||||
# 仍然标记登录会话为已验证
|
||||
await LoginSessionService.mark_session_verified(db, user_id)
|
||||
return True, "验证成功(邮件验证功能已禁用)"
|
||||
|
||||
|
||||
# 先从 Redis 检查
|
||||
verification_id = await redis.get(f"email_verification:{user_id}:{code}")
|
||||
if not verification_id:
|
||||
return False, "验证码无效或已过期"
|
||||
|
||||
|
||||
# 从数据库获取验证记录
|
||||
result = await db.exec(
|
||||
select(EmailVerification).where(
|
||||
EmailVerification.id == int(verification_id),
|
||||
EmailVerification.user_id == user_id,
|
||||
EmailVerification.verification_code == code,
|
||||
EmailVerification.is_used == False,
|
||||
EmailVerification.expires_at > datetime.now(UTC)
|
||||
col(EmailVerification.is_used).is_(False),
|
||||
EmailVerification.expires_at > datetime.now(UTC),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
verification = result.first()
|
||||
if not verification:
|
||||
return False, "验证码无效或已过期"
|
||||
|
||||
|
||||
# 标记为已使用
|
||||
verification.is_used = True
|
||||
verification.used_at = datetime.now(UTC)
|
||||
|
||||
|
||||
# 同时更新对应的登录会话状态
|
||||
await LoginSessionService.mark_session_verified(db, user_id)
|
||||
|
||||
|
||||
await db.commit()
|
||||
|
||||
|
||||
# 删除 Redis 记录
|
||||
await redis.delete(f"email_verification:{user_id}:{code}")
|
||||
|
||||
|
||||
logger.info(f"[Email Verification] User {user_id} verification code verified successfully")
|
||||
return True, "验证成功"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Email Verification] Exception during verification code validation: {e}")
|
||||
return False, "验证过程中发生错误"
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def resend_verification_code(
|
||||
db: AsyncSession,
|
||||
@@ -342,7 +341,7 @@ This email was sent automatically, please do not reply.
|
||||
username: str,
|
||||
email: str,
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None
|
||||
user_agent: str | None = None,
|
||||
) -> tuple[bool, str]:
|
||||
"""重新发送验证码"""
|
||||
try:
|
||||
@@ -350,25 +349,25 @@ This email was sent automatically, please do not reply.
|
||||
if not settings.enable_email_verification:
|
||||
logger.debug(f"[Email Verification] Email verification is disabled, skipping resend for user {user_id}")
|
||||
return True, "验证码已发送(邮件验证功能已禁用)"
|
||||
|
||||
|
||||
# 检查重发频率限制(60秒内只能发送一次)
|
||||
rate_limit_key = f"email_verification_rate_limit:{user_id}"
|
||||
if await redis.get(rate_limit_key):
|
||||
return False, "请等待60秒后再重新发送"
|
||||
|
||||
|
||||
# 设置频率限制
|
||||
await redis.setex(rate_limit_key, 60, "1")
|
||||
|
||||
|
||||
# 生成新的验证码
|
||||
success = await EmailVerificationService.send_verification_email(
|
||||
db, redis, user_id, username, email, ip_address, user_agent
|
||||
)
|
||||
|
||||
|
||||
if success:
|
||||
return True, "验证码已重新发送"
|
||||
else:
|
||||
return False, "重新发送失败,请稍后再试"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Email Verification] Exception during resending verification code: {e}")
|
||||
return False, "重新发送过程中发生错误"
|
||||
@@ -376,7 +375,7 @@ This email was sent automatically, please do not reply.
|
||||
|
||||
class LoginSessionService:
|
||||
"""登录会话服务"""
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def create_session(
|
||||
db: AsyncSession,
|
||||
@@ -385,47 +384,40 @@ class LoginSessionService:
|
||||
ip_address: str,
|
||||
user_agent: str | None = None,
|
||||
country_code: str | None = None,
|
||||
is_new_location: bool = False
|
||||
is_new_location: bool = False,
|
||||
) -> LoginSession:
|
||||
"""创建登录会话"""
|
||||
from app.utils import simplify_user_agent
|
||||
|
||||
|
||||
session_token = EmailVerificationService.generate_session_token()
|
||||
|
||||
# 简化 User-Agent 字符串
|
||||
simplified_user_agent = simplify_user_agent(user_agent, max_length=250)
|
||||
|
||||
|
||||
session = LoginSession(
|
||||
user_id=user_id,
|
||||
session_token=session_token,
|
||||
ip_address=ip_address,
|
||||
user_agent=simplified_user_agent,
|
||||
user_agent=None,
|
||||
country_code=country_code,
|
||||
is_new_location=is_new_location,
|
||||
expires_at=datetime.now(UTC) + timedelta(hours=24), # 24小时过期
|
||||
is_verified=not is_new_location # 新位置需要验证
|
||||
is_verified=not is_new_location, # 新位置需要验证
|
||||
)
|
||||
|
||||
|
||||
db.add(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
|
||||
|
||||
# 存储到 Redis
|
||||
await redis.setex(
|
||||
f"login_session:{session_token}",
|
||||
86400, # 24小时
|
||||
user_id
|
||||
user_id,
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"[Login Session] Created session for user {user_id} (new location: {is_new_location})")
|
||||
return session
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def verify_session(
|
||||
db: AsyncSession,
|
||||
redis: Redis,
|
||||
session_token: str,
|
||||
verification_code: str
|
||||
db: AsyncSession, redis: Redis, session_token: str, verification_code: str
|
||||
) -> tuple[bool, str]:
|
||||
"""验证会话(通过邮件验证码)"""
|
||||
try:
|
||||
@@ -433,98 +425,89 @@ class LoginSessionService:
|
||||
user_id = await redis.get(f"login_session:{session_token}")
|
||||
if not user_id:
|
||||
return False, "会话无效或已过期"
|
||||
|
||||
|
||||
user_id = int(user_id)
|
||||
|
||||
|
||||
# 验证邮件验证码
|
||||
success, message = await EmailVerificationService.verify_code(
|
||||
db, redis, user_id, verification_code
|
||||
)
|
||||
|
||||
success, message = await EmailVerificationService.verify_code(db, redis, user_id, verification_code)
|
||||
|
||||
if not success:
|
||||
return False, message
|
||||
|
||||
|
||||
# 更新会话状态
|
||||
result = await db.exec(
|
||||
select(LoginSession).where(
|
||||
LoginSession.session_token == session_token,
|
||||
LoginSession.user_id == user_id,
|
||||
LoginSession.is_verified == False
|
||||
col(LoginSession.is_verified).is_(False),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
session = result.first()
|
||||
if session:
|
||||
session.is_verified = True
|
||||
session.verified_at = datetime.now(UTC)
|
||||
await db.commit()
|
||||
|
||||
|
||||
logger.info(f"[Login Session] User {user_id} session verification successful")
|
||||
return True, "会话验证成功"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Login Session] Exception during session verification: {e}")
|
||||
return False, "验证过程中发生错误"
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def check_new_location(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
ip_address: str,
|
||||
country_code: str | None = None
|
||||
db: AsyncSession, user_id: int, ip_address: str, country_code: str | None = None
|
||||
) -> bool:
|
||||
"""检查是否为新位置登录"""
|
||||
try:
|
||||
# 查看过去30天内是否有相同IP或相同国家的登录记录
|
||||
thirty_days_ago = datetime.now(UTC) - timedelta(days=30)
|
||||
|
||||
|
||||
result = await db.exec(
|
||||
select(LoginSession).where(
|
||||
LoginSession.user_id == user_id,
|
||||
LoginSession.created_at > thirty_days_ago,
|
||||
(LoginSession.ip_address == ip_address) |
|
||||
(LoginSession.country_code == country_code)
|
||||
(LoginSession.ip_address == ip_address) | (LoginSession.country_code == country_code),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
existing_sessions = result.all()
|
||||
|
||||
|
||||
# 如果有历史记录,则不是新位置
|
||||
return len(existing_sessions) == 0
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Login Session] Exception during new location check: {e}")
|
||||
# 出错时默认为新位置(更安全)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def mark_session_verified(
|
||||
db: AsyncSession,
|
||||
user_id: int
|
||||
) -> bool:
|
||||
async def mark_session_verified(db: AsyncSession, user_id: int) -> bool:
|
||||
"""标记用户的未验证会话为已验证"""
|
||||
try:
|
||||
# 查找用户所有未验证且未过期的会话
|
||||
result = await db.exec(
|
||||
select(LoginSession).where(
|
||||
LoginSession.user_id == user_id,
|
||||
LoginSession.is_verified == False,
|
||||
LoginSession.expires_at > datetime.now(UTC)
|
||||
col(LoginSession.is_verified).is_(False),
|
||||
LoginSession.expires_at > datetime.now(UTC),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
sessions = result.all()
|
||||
|
||||
|
||||
# 标记所有会话为已验证
|
||||
for session in sessions:
|
||||
session.is_verified = True
|
||||
session.verified_at = datetime.now(UTC)
|
||||
|
||||
|
||||
if sessions:
|
||||
logger.info(f"[Login Session] Marked {len(sessions)} session(s) as verified for user {user_id}")
|
||||
|
||||
|
||||
return len(sessions) > 0
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Login Session] Exception during marking sessions as verified: {e}")
|
||||
return False
|
||||
|
||||
@@ -117,14 +117,10 @@ class EnhancedIntervalStatsManager:
|
||||
@staticmethod
|
||||
async def get_current_interval_info() -> IntervalInfo:
|
||||
"""获取当前区间信息"""
|
||||
start_time, end_time = (
|
||||
EnhancedIntervalStatsManager.get_current_interval_boundaries()
|
||||
)
|
||||
start_time, end_time = EnhancedIntervalStatsManager.get_current_interval_boundaries()
|
||||
interval_key = EnhancedIntervalStatsManager.generate_interval_key(start_time)
|
||||
|
||||
return IntervalInfo(
|
||||
start_time=start_time, end_time=end_time, interval_key=interval_key
|
||||
)
|
||||
return IntervalInfo(start_time=start_time, end_time=end_time, interval_key=interval_key)
|
||||
|
||||
@staticmethod
|
||||
async def initialize_current_interval() -> None:
|
||||
@@ -133,9 +129,7 @@ class EnhancedIntervalStatsManager:
|
||||
redis_async = get_redis()
|
||||
|
||||
try:
|
||||
current_interval = (
|
||||
await EnhancedIntervalStatsManager.get_current_interval_info()
|
||||
)
|
||||
current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
|
||||
|
||||
# 存储当前区间信息
|
||||
await _redis_exec(
|
||||
@@ -147,9 +141,7 @@ class EnhancedIntervalStatsManager:
|
||||
|
||||
# 初始化区间用户集合(如果不存在)
|
||||
online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}"
|
||||
playing_key = (
|
||||
f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
|
||||
)
|
||||
playing_key = f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
|
||||
|
||||
# 设置过期时间为35分钟
|
||||
await redis_async.expire(online_key, 35 * 60)
|
||||
@@ -179,7 +171,8 @@ class EnhancedIntervalStatsManager:
|
||||
await EnhancedIntervalStatsManager._ensure_24h_history_exists()
|
||||
|
||||
logger.info(
|
||||
f"Initialized interval stats for {current_interval.start_time.strftime('%H:%M')} - {current_interval.end_time.strftime('%H:%M')}"
|
||||
f"Initialized interval stats for {current_interval.start_time.strftime('%H:%M')}"
|
||||
f" - {current_interval.end_time.strftime('%H:%M')}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -193,42 +186,32 @@ class EnhancedIntervalStatsManager:
|
||||
|
||||
try:
|
||||
# 检查现有历史数据数量
|
||||
history_length = await _redis_exec(
|
||||
redis_sync.llen, REDIS_ONLINE_HISTORY_KEY
|
||||
)
|
||||
history_length = await _redis_exec(redis_sync.llen, REDIS_ONLINE_HISTORY_KEY)
|
||||
|
||||
if history_length < 48: # 少于48个数据点(24小时*2)
|
||||
logger.info(
|
||||
f"History has only {history_length} points, filling with zeros for 24h"
|
||||
)
|
||||
logger.info(f"History has only {history_length} points, filling with zeros for 24h")
|
||||
|
||||
# 计算需要填充的数据点数量
|
||||
needed_points = 48 - history_length
|
||||
|
||||
# 从当前时间往前推,创建缺失的时间点(都填充为0)
|
||||
current_time = datetime.utcnow()
|
||||
current_interval_start, _ = (
|
||||
EnhancedIntervalStatsManager.get_current_interval_boundaries()
|
||||
)
|
||||
current_time = datetime.utcnow() # noqa: F841
|
||||
current_interval_start, _ = EnhancedIntervalStatsManager.get_current_interval_boundaries()
|
||||
|
||||
# 从当前区间开始往前推,创建历史数据点(确保时间对齐到30分钟边界)
|
||||
fill_points = []
|
||||
for i in range(needed_points):
|
||||
# 每次往前推30分钟,确保时间对齐
|
||||
point_time = current_interval_start - timedelta(
|
||||
minutes=30 * (i + 1)
|
||||
)
|
||||
point_time = current_interval_start - timedelta(minutes=30 * (i + 1))
|
||||
|
||||
# 确保时间对齐到30分钟边界
|
||||
aligned_minute = (point_time.minute // 30) * 30
|
||||
point_time = point_time.replace(
|
||||
minute=aligned_minute, second=0, microsecond=0
|
||||
)
|
||||
point_time = point_time.replace(minute=aligned_minute, second=0, microsecond=0)
|
||||
|
||||
history_point = {
|
||||
"timestamp": point_time.isoformat(),
|
||||
"online_count": 0,
|
||||
"playing_count": 0
|
||||
"playing_count": 0,
|
||||
}
|
||||
fill_points.append(json.dumps(history_point))
|
||||
|
||||
@@ -238,9 +221,7 @@ class EnhancedIntervalStatsManager:
|
||||
temp_key = f"{REDIS_ONLINE_HISTORY_KEY}_temp"
|
||||
if history_length > 0:
|
||||
# 复制现有数据到临时key
|
||||
existing_data = await _redis_exec(
|
||||
redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1
|
||||
)
|
||||
existing_data = await _redis_exec(redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1)
|
||||
if existing_data:
|
||||
for data in existing_data:
|
||||
await _redis_exec(redis_sync.rpush, temp_key, data)
|
||||
@@ -250,19 +231,13 @@ class EnhancedIntervalStatsManager:
|
||||
|
||||
# 先添加填充数据(最旧的)
|
||||
for point in reversed(fill_points): # 反向添加,最旧的在最后
|
||||
await _redis_exec(
|
||||
redis_sync.rpush, REDIS_ONLINE_HISTORY_KEY, point
|
||||
)
|
||||
await _redis_exec(redis_sync.rpush, REDIS_ONLINE_HISTORY_KEY, point)
|
||||
|
||||
# 再添加原有数据(较新的)
|
||||
if history_length > 0:
|
||||
existing_data = await _redis_exec(
|
||||
redis_sync.lrange, temp_key, 0, -1
|
||||
)
|
||||
existing_data = await _redis_exec(redis_sync.lrange, temp_key, 0, -1)
|
||||
for data in existing_data:
|
||||
await _redis_exec(
|
||||
redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, data
|
||||
)
|
||||
await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, data)
|
||||
|
||||
# 清理临时key
|
||||
await redis_async.delete(temp_key)
|
||||
@@ -273,9 +248,7 @@ class EnhancedIntervalStatsManager:
|
||||
# 设置过期时间
|
||||
await redis_async.expire(REDIS_ONLINE_HISTORY_KEY, 26 * 3600)
|
||||
|
||||
logger.info(
|
||||
f"Filled {len(fill_points)} historical data points with zeros"
|
||||
)
|
||||
logger.info(f"Filled {len(fill_points)} historical data points with zeros")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error ensuring 24h history exists: {e}")
|
||||
@@ -287,9 +260,7 @@ class EnhancedIntervalStatsManager:
|
||||
redis_async = get_redis()
|
||||
|
||||
try:
|
||||
current_interval = (
|
||||
await EnhancedIntervalStatsManager.get_current_interval_info()
|
||||
)
|
||||
current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
|
||||
|
||||
# 添加到区间在线用户集合
|
||||
online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}"
|
||||
@@ -298,9 +269,7 @@ class EnhancedIntervalStatsManager:
|
||||
|
||||
# 如果用户在游玩,也添加到游玩用户集合
|
||||
if is_playing:
|
||||
playing_key = (
|
||||
f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
|
||||
)
|
||||
playing_key = f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
|
||||
await _redis_exec(redis_sync.sadd, playing_key, str(user_id))
|
||||
await redis_async.expire(playing_key, 35 * 60)
|
||||
|
||||
@@ -308,7 +277,8 @@ class EnhancedIntervalStatsManager:
|
||||
await EnhancedIntervalStatsManager._update_interval_stats()
|
||||
|
||||
logger.debug(
|
||||
f"Added user {user_id} to current interval {current_interval.start_time.strftime('%H:%M')}-{current_interval.end_time.strftime('%H:%M')}"
|
||||
f"Added user {user_id} to current interval {current_interval.start_time.strftime('%H:%M')}"
|
||||
f"-{current_interval.end_time.strftime('%H:%M')}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -321,15 +291,11 @@ class EnhancedIntervalStatsManager:
|
||||
redis_async = get_redis()
|
||||
|
||||
try:
|
||||
current_interval = (
|
||||
await EnhancedIntervalStatsManager.get_current_interval_info()
|
||||
)
|
||||
current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
|
||||
|
||||
# 获取区间内独特用户数
|
||||
online_key = f"{INTERVAL_ONLINE_USERS_KEY}:{current_interval.interval_key}"
|
||||
playing_key = (
|
||||
f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
|
||||
)
|
||||
playing_key = f"{INTERVAL_PLAYING_USERS_KEY}:{current_interval.interval_key}"
|
||||
|
||||
unique_online = await _redis_exec(redis_sync.scard, online_key)
|
||||
unique_playing = await _redis_exec(redis_sync.scard, playing_key)
|
||||
@@ -339,16 +305,12 @@ class EnhancedIntervalStatsManager:
|
||||
current_playing = await _get_playing_users_count(redis_async)
|
||||
|
||||
# 获取现有统计数据
|
||||
existing_data = await _redis_exec(
|
||||
redis_sync.get, current_interval.interval_key
|
||||
)
|
||||
existing_data = await _redis_exec(redis_sync.get, current_interval.interval_key)
|
||||
if existing_data:
|
||||
stats = IntervalStats.from_dict(json.loads(existing_data))
|
||||
# 更新峰值
|
||||
stats.peak_online_count = max(stats.peak_online_count, current_online)
|
||||
stats.peak_playing_count = max(
|
||||
stats.peak_playing_count, current_playing
|
||||
)
|
||||
stats.peak_playing_count = max(stats.peak_playing_count, current_playing)
|
||||
stats.total_samples += 1
|
||||
else:
|
||||
# 创建新的统计记录
|
||||
@@ -377,7 +339,8 @@ class EnhancedIntervalStatsManager:
|
||||
await redis_async.expire(current_interval.interval_key, 35 * 60)
|
||||
|
||||
logger.debug(
|
||||
f"Updated interval stats: online={unique_online}, playing={unique_playing}, peak_online={stats.peak_online_count}, peak_playing={stats.peak_playing_count}"
|
||||
f"Updated interval stats: online={unique_online}, playing={unique_playing}, "
|
||||
f"peak_online={stats.peak_online_count}, peak_playing={stats.peak_playing_count}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -395,21 +358,21 @@ class EnhancedIntervalStatsManager:
|
||||
# 上一个区间开始时间是当前区间开始时间减去30分钟
|
||||
previous_start = current_start - timedelta(minutes=30)
|
||||
previous_end = current_start # 上一个区间的结束时间就是当前区间的开始时间
|
||||
|
||||
|
||||
interval_key = EnhancedIntervalStatsManager.generate_interval_key(previous_start)
|
||||
|
||||
|
||||
previous_interval = IntervalInfo(
|
||||
start_time=previous_start,
|
||||
end_time=previous_end,
|
||||
interval_key=interval_key
|
||||
interval_key=interval_key,
|
||||
)
|
||||
|
||||
# 获取最终统计数据
|
||||
stats_data = await _redis_exec(
|
||||
redis_sync.get, previous_interval.interval_key
|
||||
)
|
||||
stats_data = await _redis_exec(redis_sync.get, previous_interval.interval_key)
|
||||
if not stats_data:
|
||||
logger.warning(f"No interval stats found to finalize for {previous_interval.start_time.strftime('%H:%M')}")
|
||||
logger.warning(
|
||||
f"No interval stats found to finalize for {previous_interval.start_time.strftime('%H:%M')}"
|
||||
)
|
||||
return None
|
||||
|
||||
stats = IntervalStats.from_dict(json.loads(stats_data))
|
||||
@@ -418,13 +381,11 @@ class EnhancedIntervalStatsManager:
|
||||
history_point = {
|
||||
"timestamp": previous_interval.start_time.isoformat(),
|
||||
"online_count": stats.unique_online_users,
|
||||
"playing_count": stats.unique_playing_users
|
||||
"playing_count": stats.unique_playing_users,
|
||||
}
|
||||
|
||||
# 添加到历史记录
|
||||
await _redis_exec(
|
||||
redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point)
|
||||
)
|
||||
await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point))
|
||||
# 只保留48个数据点(24小时,每30分钟一个点)
|
||||
await _redis_exec(redis_sync.ltrim, REDIS_ONLINE_HISTORY_KEY, 0, 47)
|
||||
# 设置过期时间为26小时,确保有足够缓冲
|
||||
@@ -452,12 +413,8 @@ class EnhancedIntervalStatsManager:
|
||||
redis_sync = get_redis_message()
|
||||
|
||||
try:
|
||||
current_interval = (
|
||||
await EnhancedIntervalStatsManager.get_current_interval_info()
|
||||
)
|
||||
stats_data = await _redis_exec(
|
||||
redis_sync.get, current_interval.interval_key
|
||||
)
|
||||
current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
|
||||
stats_data = await _redis_exec(redis_sync.get, current_interval.interval_key)
|
||||
|
||||
if stats_data:
|
||||
return IntervalStats.from_dict(json.loads(stats_data))
|
||||
@@ -506,8 +463,6 @@ class EnhancedIntervalStatsManager:
|
||||
|
||||
|
||||
# 便捷函数,用于替换现有的统计更新函数
|
||||
async def update_user_activity_in_interval(
|
||||
user_id: int, is_playing: bool = False
|
||||
) -> None:
|
||||
async def update_user_activity_in_interval(user_id: int, is_playing: bool = False) -> None:
|
||||
"""用户活动时更新区间统计(在登录、开始游玩等时调用)"""
|
||||
await EnhancedIntervalStatsManager.add_user_to_interval(user_id, is_playing)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user