refactor(project): make pyright & ruff happy

This commit is contained in:
MingxuanGame
2025-08-22 08:21:52 +00:00
parent 3b1d7a2234
commit 598fcc8b38
157 changed files with 2382 additions and 4590 deletions

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401
from . import ( # noqa: F401
beatmap,
beatmapset,
me,

View File

@@ -40,18 +40,13 @@ class BatchGetResp(BaseModel):
tags=["谱面"],
name="查询单个谱面",
response_model=BeatmapResp,
description=(
"根据谱面 ID / MD5 / 文件名 查询单个谱面。"
"至少提供 id / checksum / filename 之一。"
),
description=("根据谱面 ID / MD5 / 文件名 查询单个谱面。至少提供 id / checksum / filename 之一。"),
)
async def lookup_beatmap(
db: Database,
id: int | None = Query(default=None, alias="id", description="谱面 ID"),
md5: str | None = Query(default=None, alias="checksum", description="谱面文件 MD5"),
filename: str | None = Query(
default=None, alias="filename", description="谱面文件名"
),
filename: str | None = Query(default=None, alias="filename", description="谱面文件名"),
current_user: User = Security(get_current_user, scopes=["public"]),
fetcher: Fetcher = Depends(get_fetcher),
):
@@ -96,43 +91,23 @@ async def get_beatmap(
tags=["谱面"],
name="批量获取谱面",
response_model=BatchGetResp,
description=(
"批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。"
"为空时按最近更新时间返回。"
),
description=("批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。为空时按最近更新时间返回。"),
)
async def batch_get_beatmaps(
db: Database,
beatmap_ids: list[int] = Query(
alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"
),
beatmap_ids: list[int] = Query(alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"),
current_user: User = Security(get_current_user, scopes=["public"]),
fetcher: Fetcher = Depends(get_fetcher),
):
if not beatmap_ids:
beatmaps = (
await db.exec(
select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
)
).all()
beatmaps = (await db.exec(select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50))).all()
else:
beatmaps = list(
(
await db.exec(
select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)).limit(50)
)
).all()
)
not_found_beatmaps = [
bid for bid in beatmap_ids if bid not in [bm.id for bm in beatmaps]
]
beatmaps = list((await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)).limit(50))).all())
not_found_beatmaps = [bid for bid in beatmap_ids if bid not in [bm.id for bm in beatmaps]]
beatmaps.extend(
beatmap
for beatmap in await asyncio.gather(
*[
Beatmap.get_or_fetch(db, fetcher, bid=bid)
for bid in not_found_beatmaps
],
*[Beatmap.get_or_fetch(db, fetcher, bid=bid) for bid in not_found_beatmaps],
return_exceptions=True,
)
if isinstance(beatmap, Beatmap)
@@ -140,12 +115,7 @@ async def batch_get_beatmaps(
for beatmap in beatmaps:
await db.refresh(beatmap)
return BatchGetResp(
beatmaps=[
await BeatmapResp.from_db(bm, session=db, user=current_user)
for bm in beatmaps
]
)
return BatchGetResp(beatmaps=[await BeatmapResp.from_db(bm, session=db, user=current_user) for bm in beatmaps])
@router.post(
@@ -163,12 +133,8 @@ async def get_beatmap_attributes(
default_factory=list,
description="Mods 列表;可为整型位掩码(单元素)或 JSON/简称",
),
ruleset: GameMode | None = Query(
default=None, description="指定 ruleset;为空则使用谱面自身模式"
),
ruleset_id: int | None = Query(
default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3
),
ruleset: GameMode | None = Query(default=None, description="指定 ruleset为空则使用谱面自身模式"),
ruleset_id: int | None = Query(default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3),
redis: Redis = Depends(get_redis),
fetcher: Fetcher = Depends(get_fetcher),
):
@@ -187,16 +153,11 @@ async def get_beatmap_attributes(
if ruleset is None:
beatmap_db = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
ruleset = beatmap_db.mode
key = (
f"beatmap:{beatmap_id}:{ruleset}:"
f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
)
key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
if await redis.exists(key):
return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
try:
return await calculate_beatmap_attributes(
beatmap_id, ruleset, mods_, redis, fetcher
)
return await calculate_beatmap_attributes(beatmap_id, ruleset, mods_, redis, fetcher)
except HTTPStatusError:
raise HTTPException(status_code=404, detail="Beatmap not found")
except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue]

View File

@@ -35,9 +35,7 @@ from sqlmodel import exists, select
async def _save_to_db(sets: SearchBeatmapsetsResp):
async with with_db() as session:
for s in sets.beatmapsets:
if not (
await session.exec(select(exists()).where(Beatmapset.id == s.id))
).first():
if not (await session.exec(select(exists()).where(Beatmapset.id == s.id))).first():
await Beatmapset.from_resp(session, s)
@@ -117,9 +115,7 @@ async def lookup_beatmapset(
fetcher: Fetcher = Depends(get_fetcher),
):
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
resp = await BeatmapsetResp.from_db(
beatmap.beatmapset, session=db, user=current_user
)
resp = await BeatmapsetResp.from_db(beatmap.beatmapset, session=db, user=current_user)
return resp
@@ -138,9 +134,7 @@ async def get_beatmapset(
):
try:
beatmapset = await Beatmapset.get_or_fetch(db, fetcher, beatmapset_id)
return await BeatmapsetResp.from_db(
beatmapset, session=db, include=["recent_favourites"], user=current_user
)
return await BeatmapsetResp.from_db(beatmapset, session=db, include=["recent_favourites"], user=current_user)
except HTTPError:
raise HTTPException(status_code=404, detail="Beatmapset not found")
@@ -165,9 +159,7 @@ async def download_beatmapset(
country_code = geo_info.get("country_iso", "")
# 优先使用IP地理位置判断如果获取失败则回退到用户账户的国家代码
is_china = country_code == "CN" or (
not country_code and current_user.country_code == "CN"
)
is_china = country_code == "CN" or (not country_code and current_user.country_code == "CN")
try:
# 使用负载均衡服务获取下载URL
@@ -179,13 +171,10 @@ async def download_beatmapset(
# 如果负载均衡服务失败,回退到原有逻辑
if is_china:
return RedirectResponse(
f"https://dl.sayobot.cn/beatmaps/download/"
f"{'novideo' if no_video else 'full'}/{beatmapset_id}"
f"https://dl.sayobot.cn/beatmaps/download/{'novideo' if no_video else 'full'}/{beatmapset_id}"
)
else:
return RedirectResponse(
f"https://api.nerinyan.moe/d/{beatmapset_id}?noVideo={no_video}"
)
return RedirectResponse(f"https://api.nerinyan.moe/d/{beatmapset_id}?noVideo={no_video}")
@router.post(
@@ -197,12 +186,9 @@ async def download_beatmapset(
async def favourite_beatmapset(
db: Database,
beatmapset_id: int = Path(..., description="谱面集 ID"),
action: Literal["favourite", "unfavourite"] = Form(
description="操作类型favourite 收藏 / unfavourite 取消收藏"
),
action: Literal["favourite", "unfavourite"] = Form(description="操作类型favourite 收藏 / unfavourite 取消收藏"),
current_user: User = Security(get_client_user),
):
assert current_user.id is not None
existing_favourite = (
await db.exec(
select(FavouriteBeatmapset).where(
@@ -212,15 +198,11 @@ async def favourite_beatmapset(
)
).first()
if (action == "favourite" and existing_favourite) or (
action == "unfavourite" and not existing_favourite
):
if (action == "favourite" and existing_favourite) or (action == "unfavourite" and not existing_favourite):
return
if action == "favourite":
favourite = FavouriteBeatmapset(
user_id=current_user.id, beatmapset_id=beatmapset_id
)
favourite = FavouriteBeatmapset(user_id=current_user.id, beatmapset_id=beatmapset_id)
db.add(favourite)
else:
await db.delete(existing_favourite)

View File

@@ -4,8 +4,8 @@ from app.database import User
from app.database.lazer_user import ALL_INCLUDED
from app.dependencies import get_current_user
from app.dependencies.database import Database
from app.models.score import GameMode
from app.models.api_me import APIMe
from app.models.score import GameMode
from .router import router

View File

@@ -33,6 +33,4 @@ class BackgroundsResp(BaseModel):
description="获取当前季节背景图列表。",
)
async def get_seasonal_backgrounds():
return BackgroundsResp(
backgrounds=[Background(url=url) for url in settings.seasonal_backgrounds]
)
return BackgroundsResp(backgrounds=[Background(url=url) for url in settings.seasonal_backgrounds])

View File

@@ -12,7 +12,7 @@ from app.service.ranking_cache_service import get_ranking_cache_service
from .router import router
from fastapi import Path, Query, Security
from fastapi import BackgroundTasks, Path, Query, Security
from pydantic import BaseModel
from sqlmodel import col, select
@@ -38,6 +38,7 @@ class CountryResponse(BaseModel):
)
async def get_country_ranking(
session: Database,
background_tasks: BackgroundTasks,
ruleset: GameMode = Path(..., description="指定 ruleset"),
page: int = Query(1, ge=1, description="页码"), # TODO
current_user: User = Security(get_current_user, scopes=["public"]),
@@ -51,9 +52,7 @@ async def get_country_ranking(
if cached_data:
# 从缓存返回数据
return CountryResponse(
ranking=[CountryStatistics.model_validate(item) for item in cached_data]
)
return CountryResponse(ranking=[CountryStatistics.model_validate(item) for item in cached_data])
# 缓存未命中,从数据库查询
response = CountryResponse(ranking=[])
@@ -105,14 +104,15 @@ async def get_country_ranking(
# 异步缓存数据(不等待完成)
cache_data = [item.model_dump() for item in current_page_data]
cache_task = cache_service.cache_country_ranking(
ruleset, cache_data, page, ttl=settings.ranking_cache_expire_minutes * 60
)
# 创建后台任务来缓存数据
import asyncio
asyncio.create_task(cache_task)
background_tasks.add_task(
cache_service.cache_country_ranking,
ruleset,
cache_data,
page,
ttl=settings.ranking_cache_expire_minutes * 60,
)
# 返回当前页的结果
response.ranking = current_page_data
@@ -132,10 +132,9 @@ class TopUsersResponse(BaseModel):
)
async def get_user_ranking(
session: Database,
background_tasks: BackgroundTasks,
ruleset: GameMode = Path(..., description="指定 ruleset"),
type: Literal["performance", "score"] = Path(
..., description="排名类型performance 表现分 / score 计分成绩总分"
),
type: Literal["performance", "score"] = Path(..., description="排名类型performance 表现分 / score 计分成绩总分"),
country: str | None = Query(None, description="国家代码"),
page: int = Query(1, ge=1, description="页码"),
current_user: User = Security(get_current_user, scopes=["public"]),
@@ -149,9 +148,7 @@ async def get_user_ranking(
if cached_data:
# 从缓存返回数据
return TopUsersResponse(
ranking=[UserStatisticsResp.model_validate(item) for item in cached_data]
)
return TopUsersResponse(ranking=[UserStatisticsResp.model_validate(item) for item in cached_data])
# 缓存未命中,从数据库查询
wheres = [
@@ -169,25 +166,22 @@ async def get_user_ranking(
wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
statistics_list = await session.exec(
select(UserStatistics)
.where(*wheres)
.order_by(order_by)
.limit(50)
.offset(50 * (page - 1))
select(UserStatistics).where(*wheres).order_by(order_by).limit(50).offset(50 * (page - 1))
)
# 转换为响应格式
ranking_data = []
for statistics in statistics_list:
user_stats_resp = await UserStatisticsResp.from_db(
statistics, session, None, include
)
user_stats_resp = await UserStatisticsResp.from_db(statistics, session, None, include)
ranking_data.append(user_stats_resp)
# 异步缓存数据(不等待完成)
# 使用配置文件中的TTL设置
cache_data = [item.model_dump() for item in ranking_data]
cache_task = cache_service.cache_ranking(
# 创建后台任务来缓存数据
background_tasks.add_task(
cache_service.cache_ranking,
ruleset,
type,
cache_data,
@@ -196,139 +190,134 @@ async def get_user_ranking(
ttl=settings.ranking_cache_expire_minutes * 60,
)
# 创建后台任务来缓存数据
import asyncio
asyncio.create_task(cache_task)
resp = TopUsersResponse(ranking=ranking_data)
return resp
""" @router.post(
"/rankings/cache/refresh",
name="刷新排行榜缓存",
description="手动刷新排行榜缓存(管理员功能)",
tags=["排行榜", "管理"],
)
async def refresh_ranking_cache(
session: Database,
ruleset: GameMode | None = Query(None, description="指定要刷新的游戏模式,不指定则刷新所有"),
type: Literal["performance", "score"] | None = Query(None, description="指定要刷新的排名类型,不指定则刷新所有"),
country: str | None = Query(None, description="指定要刷新的国家,不指定则刷新所有"),
include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
):
redis = get_redis()
cache_service = get_ranking_cache_service(redis)
if ruleset and type:
# 刷新特定的用户排行榜
await cache_service.refresh_ranking_cache(session, ruleset, type, country)
message = f"Refreshed ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
# 如果请求刷新地区排行榜
if include_country_ranking and not country: # 地区排行榜不依赖于国家参数
await cache_service.refresh_country_ranking_cache(session, ruleset)
message += f" and country ranking for {ruleset}"
return {"message": message}
elif ruleset:
# 刷新特定游戏模式的所有排行榜
ranking_types: list[Literal["performance", "score"]] = ["performance", "score"]
for ranking_type in ranking_types:
await cache_service.refresh_ranking_cache(session, ruleset, ranking_type, country)
if include_country_ranking:
await cache_service.refresh_country_ranking_cache(session, ruleset)
return {"message": f"Refreshed all ranking caches for {ruleset}"}
else:
# 刷新所有排行榜
await cache_service.refresh_all_rankings(session)
return {"message": "Refreshed all ranking caches"}
# @router.post(
# "/rankings/cache/refresh",
# name="刷新排行榜缓存",
# description="手动刷新排行榜缓存(管理员功能)",
# tags=["排行榜", "管理"],
# )
# async def refresh_ranking_cache(
# session: Database,
# ruleset: GameMode | None = Query(None, description="指定要刷新的游戏模式,不指定则刷新所有"),
# type: Literal["performance", "score"] | None = Query(None, description="指定要刷新的排名类型,不指定则刷新所有"),
# country: str | None = Query(None, description="指定要刷新的国家,不指定则刷新所有"),
# include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
# ):
# redis = get_redis()
# cache_service = get_ranking_cache_service(redis)
# if ruleset and type:
# # 刷新特定的用户排行榜
# await cache_service.refresh_ranking_cache(session, ruleset, type, country)
# message = f"Refreshed ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
# # 如果请求刷新地区排行榜
# if include_country_ranking and not country: # 地区排行榜不依赖于国家参数
# await cache_service.refresh_country_ranking_cache(session, ruleset)
# message += f" and country ranking for {ruleset}"
# return {"message": message}
# elif ruleset:
# # 刷新特定游戏模式的所有排行榜
# ranking_types: list[Literal["performance", "score"]] = ["performance", "score"]
# for ranking_type in ranking_types:
# await cache_service.refresh_ranking_cache(session, ruleset, ranking_type, country)
# if include_country_ranking:
# await cache_service.refresh_country_ranking_cache(session, ruleset)
# return {"message": f"Refreshed all ranking caches for {ruleset}"}
# else:
# # 刷新所有排行榜
# await cache_service.refresh_all_rankings(session)
# return {"message": "Refreshed all ranking caches"}
@router.post(
"/rankings/{ruleset}/country/cache/refresh",
name="刷新地区排行榜缓存",
description="手动刷新地区排行榜缓存(管理员功能)",
tags=["排行榜", "管理"],
)
async def refresh_country_ranking_cache(
session: Database,
ruleset: GameMode = Path(..., description="指定要刷新的游戏模式"),
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
):
redis = get_redis()
cache_service = get_ranking_cache_service(redis)
await cache_service.refresh_country_ranking_cache(session, ruleset)
return {"message": f"Refreshed country ranking cache for {ruleset}"}
# @router.post(
# "/rankings/{ruleset}/country/cache/refresh",
# name="刷新地区排行榜缓存",
# description="手动刷新地区排行榜缓存(管理员功能)",
# tags=["排行榜", "管理"],
# )
# async def refresh_country_ranking_cache(
# session: Database,
# ruleset: GameMode = Path(..., description="指定要刷新的游戏模式"),
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
# ):
# redis = get_redis()
# cache_service = get_ranking_cache_service(redis)
# await cache_service.refresh_country_ranking_cache(session, ruleset)
# return {"message": f"Refreshed country ranking cache for {ruleset}"}
@router.delete(
"/rankings/cache",
name="清除排行榜缓存",
description="清除排行榜缓存(管理员功能)",
tags=["排行榜", "管理"],
)
async def clear_ranking_cache(
ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
type: Literal["performance", "score"] | None = Query(None, description="指定要清除的排名类型,不指定则清除所有"),
country: str | None = Query(None, description="指定要清除的国家,不指定则清除所有"),
include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
):
redis = get_redis()
cache_service = get_ranking_cache_service(redis)
await cache_service.invalidate_cache(ruleset, type, country, include_country_ranking)
if ruleset and type:
message = f"Cleared ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
if include_country_ranking:
message += " and country ranking"
return {"message": message}
else:
message = "Cleared all ranking caches"
if include_country_ranking:
message += " including country rankings"
return {"message": message}
# @router.delete(
# "/rankings/cache",
# name="清除排行榜缓存",
# description="清除排行榜缓存(管理员功能)",
# tags=["排行榜", "管理"],
# )
# async def clear_ranking_cache(
# ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
# type: Literal["performance", "score"] | None = Query(None, description="指定要清除的排名类型,不指定则清除所有"),
# country: str | None = Query(None, description="指定要清除的国家,不指定则清除所有"),
# include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
# ):
# redis = get_redis()
# cache_service = get_ranking_cache_service(redis)
# await cache_service.invalidate_cache(ruleset, type, country, include_country_ranking)
# if ruleset and type:
# message = f"Cleared ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
# if include_country_ranking:
# message += " and country ranking"
# return {"message": message}
# else:
# message = "Cleared all ranking caches"
# if include_country_ranking:
# message += " including country rankings"
# return {"message": message}
@router.delete(
"/rankings/{ruleset}/country/cache",
name="清除地区排行榜缓存",
description="清除地区排行榜缓存(管理员功能)",
tags=["排行榜", "管理"],
)
async def clear_country_ranking_cache(
ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
):
redis = get_redis()
cache_service = get_ranking_cache_service(redis)
await cache_service.invalidate_country_cache(ruleset)
if ruleset:
return {"message": f"Cleared country ranking cache for {ruleset}"}
else:
return {"message": "Cleared all country ranking caches"}
# @router.delete(
# "/rankings/{ruleset}/country/cache",
# name="清除地区排行榜缓存",
# description="清除地区排行榜缓存(管理员功能)",
# tags=["排行榜", "管理"],
# )
# async def clear_country_ranking_cache(
# ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
# ):
# redis = get_redis()
# cache_service = get_ranking_cache_service(redis)
# await cache_service.invalidate_country_cache(ruleset)
# if ruleset:
# return {"message": f"Cleared country ranking cache for {ruleset}"}
# else:
# return {"message": "Cleared all country ranking caches"}
@router.get(
"/rankings/cache/stats",
name="获取排行榜缓存统计",
description="获取排行榜缓存统计信息(管理员功能)",
tags=["排行榜", "管理"],
)
async def get_ranking_cache_stats(
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
):
redis = get_redis()
cache_service = get_ranking_cache_service(redis)
stats = await cache_service.get_cache_stats()
return stats """
# @router.get(
# "/rankings/cache/stats",
# name="获取排行榜缓存统计",
# description="获取排行榜缓存统计信息(管理员功能)",
# tags=["排行榜", "管理"],
# )
# async def get_ranking_cache_stats(
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
# ):
# redis = get_redis()
# cache_service = get_ranking_cache_service(redis)
# stats = await cache_service.get_cache_stats()
# return stats

View File

@@ -30,11 +30,7 @@ async def get_relationship(
request: Request,
current_user: User = Security(get_current_user, scopes=["friends.read"]),
):
relationship_type = (
RelationshipType.FOLLOW
if request.url.path.endswith("/friends")
else RelationshipType.BLOCK
)
relationship_type = RelationshipType.FOLLOW if request.url.path.endswith("/friends") else RelationshipType.BLOCK
relationships = await db.exec(
select(Relationship).where(
Relationship.user_id == current_user.id,
@@ -71,12 +67,7 @@ async def add_relationship(
target: int = Query(description="目标用户 ID"),
current_user: User = Security(get_client_user),
):
assert current_user.id is not None
relationship_type = (
RelationshipType.FOLLOW
if request.url.path.endswith("/friends")
else RelationshipType.BLOCK
)
relationship_type = RelationshipType.FOLLOW if request.url.path.endswith("/friends") else RelationshipType.BLOCK
if target == current_user.id:
raise HTTPException(422, "Cannot add relationship to yourself")
relationship = (
@@ -120,11 +111,8 @@ async def add_relationship(
Relationship.target_id == target,
)
)
).first()
assert relationship, "Relationship should exist after commit"
return AddFriendResp(
user_relation=await RelationshipResp.from_db(db, relationship)
)
).one()
return AddFriendResp(user_relation=await RelationshipResp.from_db(db, relationship))
@router.delete(
@@ -145,11 +133,7 @@ async def delete_relationship(
target: int = Path(..., description="目标用户 ID"),
current_user: User = Security(get_client_user),
):
relationship_type = (
RelationshipType.BLOCK
if "/blocks/" in request.url.path
else RelationshipType.FOLLOW
)
relationship_type = RelationshipType.BLOCK if "/blocks/" in request.url.path else RelationshipType.FOLLOW
relationship = (
await db.exec(
select(Relationship).where(

View File

@@ -39,17 +39,11 @@ async def get_all_rooms(
db: Database,
mode: Literal["open", "ended", "participated", "owned", None] = Query(
default="open",
description=(
"房间模式open 当前开放 / ended 已经结束 / "
"participated 参与过 / owned 自己创建的房间"
),
description=("房间模式open 当前开放 / ended 已经结束 / participated 参与过 / owned 自己创建的房间"),
),
category: RoomCategory = Query(
RoomCategory.NORMAL,
description=(
"房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间"
" / DAILY_CHALLENGE 每日挑战"
),
description=("房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战"),
),
status: RoomStatus | None = Query(None, description="房间状态(可选)"),
current_user: User = Security(get_current_user, scopes=["public"]),
@@ -60,10 +54,7 @@ async def get_all_rooms(
if status is not None:
where_clauses.append(col(Room.status) == status)
if mode == "open":
where_clauses.append(
(col(Room.ends_at).is_(None))
| (col(Room.ends_at) > now.replace(tzinfo=UTC))
)
where_clauses.append((col(Room.ends_at).is_(None)) | (col(Room.ends_at) > now.replace(tzinfo=UTC)))
if category == RoomCategory.REALTIME:
where_clauses.append(col(Room.id).in_(MultiplayerHubs.rooms.keys()))
if mode == "participated":
@@ -76,10 +67,7 @@ async def get_all_rooms(
if mode == "owned":
where_clauses.append(col(Room.host_id) == current_user.id)
if mode == "ended":
where_clauses.append(
(col(Room.ends_at).is_not(None))
& (col(Room.ends_at) < now.replace(tzinfo=UTC))
)
where_clauses.append((col(Room.ends_at).is_not(None)) & (col(Room.ends_at) < now.replace(tzinfo=UTC)))
db_rooms = (
(
@@ -97,11 +85,7 @@ async def get_all_rooms(
resp = await RoomResp.from_db(room, db)
if category == RoomCategory.REALTIME:
mp_room = MultiplayerHubs.rooms.get(room.id)
resp.has_password = (
bool(mp_room.room.settings.password.strip())
if mp_room is not None
else False
)
resp.has_password = bool(mp_room.room.settings.password.strip()) if mp_room is not None else False
resp.category = RoomCategory.NORMAL
resp_list.append(resp)
@@ -115,9 +99,7 @@ class APICreatedRoom(RoomResp):
error: str = ""
async def _participate_room(
room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis
):
async def _participate_room(room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis):
participated_user = (
await session.exec(
select(RoomParticipatedUser).where(
@@ -154,7 +136,6 @@ async def create_room(
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
):
assert current_user.id is not None
user_id = current_user.id
db_room = await create_playlist_room_from_api(db, room, user_id)
await _participate_room(db_room.id, user_id, db_room, db, redis)
@@ -177,10 +158,7 @@ async def get_room(
room_id: int = Path(..., description="房间 ID"),
category: str = Query(
default="",
description=(
"房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间"
" / DAILY_CHALLENGE 每日挑战 (可选)"
),
description=("房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战 (可选)"),
),
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
@@ -188,9 +166,7 @@ async def get_room(
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is None:
raise HTTPException(404, "Room not found")
resp = await RoomResp.from_db(
db_room, include=["current_user_score"], session=db, user=current_user
)
resp = await RoomResp.from_db(db_room, include=["current_user_score"], session=db, user=current_user)
return resp
@@ -400,7 +376,6 @@ async def get_room_events(
for score in scores:
user_ids.add(score.user_id)
beatmap_ids.add(score.beatmap_id)
assert event.id is not None
first_event_id = min(first_event_id, event.id)
last_event_id = max(last_event_id, event.id)
@@ -416,16 +391,12 @@ async def get_room_events(
users = await db.exec(select(User).where(col(User.id).in_(user_ids)))
user_resps = [await UserResp.from_db(user, db) for user in users]
beatmaps = await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)))
beatmap_resps = [
await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps
]
beatmap_resps = [await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps]
beatmapset_resps = {}
for beatmap_resp in beatmap_resps:
beatmapset_resps[beatmap_resp.beatmapset_id] = beatmap_resp.beatmapset
playlist_items_resps = [
await PlaylistResp.from_db(item) for item in playlist_items.values()
]
playlist_items_resps = [await PlaylistResp.from_db(item) for item in playlist_items.values()]
return RoomEvents(
beatmaps=beatmap_resps,

View File

@@ -104,11 +104,7 @@ async def submit_score(
if not info.passed:
info.rank = Rank.F
score_token = (
await db.exec(
select(ScoreToken)
.options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType]
.where(ScoreToken.id == token)
)
await db.exec(select(ScoreToken).options(joinedload(ScoreToken.beatmap)).where(ScoreToken.id == token))
).first()
if not score_token or score_token.user_id != user_id:
raise HTTPException(status_code=404, detail="Score token not found")
@@ -138,10 +134,7 @@ async def submit_score(
except HTTPError:
raise HTTPException(status_code=404, detail="Beatmap not found")
has_pp = db_beatmap.beatmap_status.has_pp() | settings.enable_all_beatmap_pp
has_leaderboard = (
db_beatmap.beatmap_status.has_leaderboard()
| settings.enable_all_beatmap_leaderboard
)
has_leaderboard = db_beatmap.beatmap_status.has_leaderboard() | settings.enable_all_beatmap_leaderboard
beatmap_length = db_beatmap.total_length
score = await process_score(
current_user,
@@ -167,21 +160,11 @@ async def submit_score(
has_pp,
has_leaderboard,
)
score = (
await db.exec(
select(Score)
.options(joinedload(Score.user)) # pyright: ignore[reportArgumentType]
.where(Score.id == score_id)
)
).first()
assert score is not None
score = (await db.exec(select(Score).options(joinedload(Score.user)).where(Score.id == score_id))).one()
resp = await ScoreResp.from_db(db, score)
total_users = (await db.exec(select(func.count()).select_from(User))).first()
assert total_users is not None
if resp.rank_global is not None and resp.rank_global <= min(
math.ceil(float(total_users) * 0.01), 50
):
total_users = (await db.exec(select(func.count()).select_from(User))).one()
if resp.rank_global is not None and resp.rank_global <= min(math.ceil(float(total_users) * 0.01), 50):
rank_event = Event(
created_at=datetime.now(UTC),
type=EventType.RANK,
@@ -207,9 +190,7 @@ async def submit_score(
score_gamemode = score.gamemode
if user_id is not None:
background_task.add_task(
_refresh_user_cache_background, redis, user_id, score_gamemode
)
background_task.add_task(_refresh_user_cache_background, redis, user_id, score_gamemode)
background_task.add_task(process_user_achievement, resp.id)
return resp
@@ -225,9 +206,7 @@ async def _refresh_user_cache_background(redis: Redis, user_id: int, mode: GameM
# 创建独立的数据库会话
session = AsyncSession(engine)
try:
await user_cache_service.refresh_user_cache_on_score_submit(
session, user_id, mode
)
await user_cache_service.refresh_user_cache_on_score_submit(session, user_id, mode)
finally:
await session.close()
except Exception as e:
@@ -280,22 +259,16 @@ async def get_beatmap_scores(
beatmap_id: int = Path(description="谱面 ID"),
mode: GameMode = Query(description="指定 auleset"),
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
mods: list[str] = Query(
default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)"
),
mods: list[str] = Query(default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)"),
type: LeaderboardType = Query(
LeaderboardType.GLOBAL,
description=(
"排行榜类型GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队"
),
description=("排行榜类型GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队"),
),
current_user: User = Security(get_current_user, scopes=["public"]),
limit: int = Query(50, ge=1, le=200, description="返回条数 (1-200)"),
):
if legacy_only:
raise HTTPException(
status_code=404, detail="this server only contains lazer scores"
)
raise HTTPException(status_code=404, detail="this server only contains lazer scores")
all_scores, user_score, count = await get_leaderboard(
db,
@@ -310,9 +283,7 @@ async def get_beatmap_scores(
user_score_resp = await ScoreResp.from_db(db, user_score) if user_score else None
resp = BeatmapScores(
scores=[await ScoreResp.from_db(db, score) for score in all_scores],
user_score=BeatmapUserScore(
score=user_score_resp, position=user_score_resp.rank_global or 0
)
user_score=BeatmapUserScore(score=user_score_resp, position=user_score_resp.rank_global or 0)
if user_score_resp
else None,
score_count=count,
@@ -342,9 +313,7 @@ async def get_user_beatmap_score(
current_user: User = Security(get_current_user, scopes=["public"]),
):
if legacy_only:
raise HTTPException(
status_code=404, detail="This server only contains non-legacy scores"
)
raise HTTPException(status_code=404, detail="This server only contains non-legacy scores")
user_score = (
await db.exec(
select(Score)
@@ -386,9 +355,7 @@ async def get_user_all_beatmap_scores(
current_user: User = Security(get_current_user, scopes=["public"]),
):
if legacy_only:
raise HTTPException(
status_code=404, detail="This server only contains non-legacy scores"
)
raise HTTPException(status_code=404, detail="This server only contains non-legacy scores")
all_user_scores = (
await db.exec(
select(Score)
@@ -420,7 +387,6 @@ async def create_solo_score(
ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"),
current_user: User = Security(get_client_user),
):
assert current_user.id is not None
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
@@ -454,10 +420,7 @@ async def submit_solo_score(
redis: Redis = Depends(get_redis),
fetcher=Depends(get_fetcher),
):
assert current_user.id is not None
return await submit_score(
background_task, info, beatmap_id, token, current_user, db, redis, fetcher
)
return await submit_score(background_task, info, beatmap_id, token, current_user, db, redis, fetcher)
@router.post(
@@ -478,7 +441,6 @@ async def create_playlist_score(
version_hash: str = Form("", description="谱面版本哈希"),
current_user: User = Security(get_client_user),
):
assert current_user.id is not None
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
@@ -488,26 +450,16 @@ async def create_playlist_score(
db_room_time = room.ends_at.replace(tzinfo=UTC) if room.ends_at else None
if db_room_time and db_room_time < datetime.now(UTC).replace(tzinfo=UTC):
raise HTTPException(status_code=400, detail="Room has ended")
item = (
await session.exec(
select(Playlist).where(
Playlist.id == playlist_id, Playlist.room_id == room_id
)
)
).first()
item = (await session.exec(select(Playlist).where(Playlist.id == playlist_id, Playlist.room_id == room_id))).first()
if not item:
raise HTTPException(status_code=404, detail="Playlist not found")
# validate
if not item.freestyle:
if item.ruleset_id != ruleset_id:
raise HTTPException(
status_code=400, detail="Ruleset mismatch in playlist item"
)
raise HTTPException(status_code=400, detail="Ruleset mismatch in playlist item")
if item.beatmap_id != beatmap_id:
raise HTTPException(
status_code=400, detail="Beatmap ID mismatch in playlist item"
)
raise HTTPException(status_code=400, detail="Beatmap ID mismatch in playlist item")
agg = await session.exec(
select(ItemAttemptsCount).where(
ItemAttemptsCount.room_id == room_id,
@@ -523,9 +475,7 @@ async def create_playlist_score(
if item.expired:
raise HTTPException(status_code=400, detail="Playlist item has expired")
if item.played_at:
raise HTTPException(
status_code=400, detail="Playlist item has already been played"
)
raise HTTPException(status_code=400, detail="Playlist item has already been played")
# 这里应该不用验证mod了吧。。。
background_task.add_task(_preload_beatmap_for_pp_calculation, beatmap_id)
score_token = ScoreToken(
@@ -557,18 +507,10 @@ async def submit_playlist_score(
redis: Redis = Depends(get_redis),
fetcher: Fetcher = Depends(get_fetcher),
):
assert current_user.id is not None
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
item = (
await session.exec(
select(Playlist).where(
Playlist.id == playlist_id, Playlist.room_id == room_id
)
)
).first()
item = (await session.exec(select(Playlist).where(Playlist.id == playlist_id, Playlist.room_id == room_id))).first()
if not item:
raise HTTPException(status_code=404, detail="Playlist item not found")
room = await session.get(Room, room_id)
@@ -621,9 +563,7 @@ async def index_playlist_scores(
room_id: int,
playlist_id: int,
limit: int = Query(50, ge=1, le=50, description="返回条数 (1-50)"),
cursor: int = Query(
2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)"
),
cursor: int = Query(2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)"),
current_user: User = Security(get_current_user, scopes=["public"]),
):
# 立即获取用户ID避免懒加载问题
@@ -693,9 +633,6 @@ async def show_playlist_score(
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
):
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
room = await session.get(Room, room_id)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
@@ -715,9 +652,7 @@ async def show_playlist_score(
)
)
).first()
if completed_players := await redis.get(
f"multiplayer:{room_id}:gameplay:players"
):
if completed_players := await redis.get(f"multiplayer:{room_id}:gameplay:players"):
completed = completed_players == "0"
if score_record and completed:
break
@@ -784,9 +719,7 @@ async def get_user_playlist_score(
raise HTTPException(status_code=404, detail="Score not found")
resp = await ScoreResp.from_db(session, score_record.score)
resp.position = await get_position(
room_id, playlist_id, score_record.score_id, session
)
resp.position = await get_position(room_id, playlist_id, score_record.score_id, session)
return resp
@@ -850,11 +783,7 @@ async def unpin_score(
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
score_record = (
await db.exec(
select(Score).where(Score.id == score_id, Score.user_id == user_id)
)
).first()
score_record = (await db.exec(select(Score).where(Score.id == score_id, Score.user_id == user_id))).first()
if not score_record:
raise HTTPException(status_code=404, detail="Score not found")
@@ -878,10 +807,7 @@ async def unpin_score(
"/score-pins/{score_id}/reorder",
status_code=204,
name="调整置顶成绩顺序",
description=(
"**客户端专属**\n调整已置顶成绩的展示顺序。"
"仅提供 after_score_id 或 before_score_id 之一。"
),
description=("**客户端专属**\n调整已置顶成绩的展示顺序。仅提供 after_score_id 或 before_score_id 之一。"),
tags=["成绩"],
)
async def reorder_score_pin(
@@ -894,11 +820,7 @@ async def reorder_score_pin(
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
score_record = (
await db.exec(
select(Score).where(Score.id == score_id, Score.user_id == user_id)
)
).first()
score_record = (await db.exec(select(Score).where(Score.id == score_id, Score.user_id == user_id))).first()
if not score_record:
raise HTTPException(status_code=404, detail="Score not found")
@@ -908,8 +830,7 @@ async def reorder_score_pin(
if (after_score_id is None) == (before_score_id is None):
raise HTTPException(
status_code=400,
detail="Either after_score_id or before_score_id "
"must be provided (but not both)",
detail="Either after_score_id or before_score_id must be provided (but not both)",
)
all_pinned_scores = (
@@ -927,9 +848,7 @@ async def reorder_score_pin(
target_order = None
reference_score_id = after_score_id or before_score_id
reference_score = next(
(s for s in all_pinned_scores if s.id == reference_score_id), None
)
reference_score = next((s for s in all_pinned_scores if s.id == reference_score_id), None)
if not reference_score:
detail = "After score not found" if after_score_id else "Before score not found"
raise HTTPException(status_code=404, detail=detail)
@@ -951,9 +870,7 @@ async def reorder_score_pin(
if current_order < s.pinned_order <= target_order and s.id != score_id:
updates.append((s.id, s.pinned_order - 1))
if after_score_id:
final_target = (
target_order - 1 if target_order > current_order else target_order
)
final_target = target_order - 1 if target_order > current_order else target_order
else:
final_target = target_order
else:
@@ -964,9 +881,7 @@ async def reorder_score_pin(
for score_id, new_order in updates:
await db.exec(select(Score).where(Score.id == score_id))
score_to_update = (
await db.exec(select(Score).where(Score.id == score_id))
).first()
score_to_update = (await db.exec(select(Score).where(Score.id == score_id))).first()
if score_to_update:
score_to_update.pinned_order = new_order

View File

@@ -4,34 +4,29 @@
from __future__ import annotations
from datetime import datetime, UTC
from typing import Annotated
from app.auth import authenticate_user
from app.config import settings
from app.database import User
from app.dependencies import get_current_user
from app.dependencies.database import Database, get_redis
from app.dependencies.geoip import GeoIPHelper, get_geoip_helper
from app.database.email_verification import EmailVerification, LoginSession
from app.service.email_verification_service import (
EmailVerificationService,
LoginSessionService
EmailVerificationService,
LoginSessionService,
)
from app.service.login_log_service import LoginLogService
from app.models.extended_auth import ExtendedTokenResponse
from fastapi import Form, Depends, Request, HTTPException, status, Security
from fastapi.responses import JSONResponse, Response
from pydantic import BaseModel
from redis.asyncio import Redis
from sqlmodel import select
from .router import router
from fastapi import Depends, Form, HTTPException, Request, Security, status
from fastapi.responses import Response
from pydantic import BaseModel
from redis.asyncio import Redis
class SessionReissueResponse(BaseModel):
"""重新发送验证码响应"""
success: bool
message: str
@@ -40,39 +35,35 @@ class SessionReissueResponse(BaseModel):
"/session/verify",
name="验证会话",
description="验证邮件验证码并完成会话认证",
status_code=204
status_code=204,
)
async def verify_session(
request: Request,
db: Database,
redis: Annotated[Redis, Depends(get_redis)],
verification_key: str = Form(..., description="8位邮件验证码"),
current_user: User = Security(get_current_user)
current_user: User = Security(get_current_user),
) -> Response:
"""
验证邮件验证码并完成会话认证
对应 osu! 的 session/verify 接口
成功时返回 204 No Content失败时返回 401 Unauthorized
"""
try:
from app.dependencies.geoip import get_client_ip
ip_address = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "Unknown")
ip_address = get_client_ip(request) # noqa: F841
user_agent = request.headers.get("User-Agent", "Unknown") # noqa: F841
# 从当前认证用户获取信息
user_id = current_user.id
if not user_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户未认证"
)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户未认证")
# 验证邮件验证码
success, message = await EmailVerificationService.verify_code(
db, redis, user_id, verification_key
)
success, message = await EmailVerificationService.verify_code(db, redis, user_id, verification_key)
if success:
# 记录成功的邮件验证
await LoginLogService.record_login(
@@ -81,9 +72,9 @@ async def verify_session(
request=request,
login_method="email_verification",
login_success=True,
notes=f"邮件验证成功"
notes="邮件验证成功",
)
# 返回 204 No Content 表示验证成功
return Response(status_code=status.HTTP_204_NO_CONTENT)
else:
@@ -93,83 +84,69 @@ async def verify_session(
request=request,
attempted_username=current_user.username,
login_method="email_verification",
notes=f"邮件验证失败: {message}"
notes=f"邮件验证失败: {message}",
)
# 返回 401 Unauthorized 表示验证失败
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=message
)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=message)
except ValueError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的用户会话"
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="验证过程中发生错误"
)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的用户会话")
except Exception:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="验证过程中发生错误")
@router.post(
"/session/verify/reissue",
name="重新发送验证码",
description="重新发送邮件验证码",
response_model=SessionReissueResponse
response_model=SessionReissueResponse,
)
async def reissue_verification_code(
request: Request,
db: Database,
redis: Annotated[Redis, Depends(get_redis)],
current_user: User = Security(get_current_user)
current_user: User = Security(get_current_user),
) -> SessionReissueResponse:
"""
重新发送邮件验证码
对应 osu! 的 session/verify/reissue 接口
"""
try:
from app.dependencies.geoip import get_client_ip
ip_address = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "Unknown")
# 从当前认证用户获取信息
user_id = current_user.id
if not user_id:
return SessionReissueResponse(
success=False,
message="用户未认证"
)
return SessionReissueResponse(success=False, message="用户未认证")
# 重新发送验证码
success, message = await EmailVerificationService.resend_verification_code(
db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent
db,
redis,
user_id,
current_user.username,
current_user.email,
ip_address,
user_agent,
)
return SessionReissueResponse(
success=success,
message=message
)
return SessionReissueResponse(success=success, message=message)
except ValueError:
return SessionReissueResponse(
success=False,
message="无效的用户会话"
)
except Exception as e:
return SessionReissueResponse(
success=False,
message="重新发送过程中发生错误"
)
return SessionReissueResponse(success=False, message="无效的用户会话")
except Exception:
return SessionReissueResponse(success=False, message="重新发送过程中发生错误")
@router.post(
"/session/check-new-location",
name="检查新位置登录",
description="检查登录是否来自新位置(内部接口)"
description="检查登录是否来自新位置(内部接口)",
)
async def check_new_location(
request: Request,
@@ -183,22 +160,21 @@ async def check_new_location(
"""
try:
from app.dependencies.geoip import get_client_ip
ip_address = get_client_ip(request)
geo_info = geoip.lookup(ip_address)
country_code = geo_info.get("country_iso", "XX")
is_new_location = await LoginSessionService.check_new_location(
db, user_id, ip_address, country_code
)
is_new_location = await LoginSessionService.check_new_location(db, user_id, ip_address, country_code)
return {
"is_new_location": is_new_location,
"ip_address": ip_address,
"country_code": country_code
"country_code": country_code,
}
except Exception as e:
return {
"is_new_location": True, # 出错时默认为新位置
"error": str(e)
"error": str(e),
}

View File

@@ -1,73 +1,80 @@
from __future__ import annotations
import asyncio
from datetime import datetime, timedelta
import json
from typing import Any
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
import json
from app.dependencies.database import get_redis, get_redis_message
from app.log import logger
from app.utils import bg_tasks
from .router import router
from fastapi import APIRouter
from pydantic import BaseModel
# Redis key constants
REDIS_ONLINE_USERS_KEY = "server:online_users"
REDIS_PLAYING_USERS_KEY = "server:playing_users"
REDIS_PLAYING_USERS_KEY = "server:playing_users"
REDIS_REGISTERED_USERS_KEY = "server:registered_users"
REDIS_ONLINE_HISTORY_KEY = "server:online_history"
# 线程池用于同步Redis操作
_executor = ThreadPoolExecutor(max_workers=2)
async def _redis_exec(func, *args, **kwargs):
"""在线程池中执行同步Redis操作"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(_executor, func, *args, **kwargs)
class ServerStats(BaseModel):
"""服务器统计信息响应模型"""
registered_users: int
online_users: int
playing_users: int
timestamp: datetime
class OnlineHistoryPoint(BaseModel):
"""在线历史数据点"""
timestamp: datetime
online_count: int
playing_count: int
class OnlineHistoryResponse(BaseModel):
"""24小时在线历史响应模型"""
history: list[OnlineHistoryPoint]
current_stats: ServerStats
@router.get("/stats", response_model=ServerStats, tags=["统计"])
async def get_server_stats() -> ServerStats:
"""
获取服务器实时统计信息
返回服务器注册用户数、在线用户数、正在游玩用户数等实时统计信息
"""
redis = get_redis()
try:
# 并行获取所有统计数据
registered_count, online_count, playing_count = await asyncio.gather(
_get_registered_users_count(redis),
_get_online_users_count(redis),
_get_playing_users_count(redis)
_get_playing_users_count(redis),
)
return ServerStats(
registered_users=registered_count,
online_users=online_count,
playing_users=playing_count,
timestamp=datetime.utcnow()
timestamp=datetime.utcnow(),
)
except Exception as e:
logger.error(f"Error getting server stats: {e}")
@@ -76,14 +83,15 @@ async def get_server_stats() -> ServerStats:
registered_users=0,
online_users=0,
playing_users=0,
timestamp=datetime.utcnow()
timestamp=datetime.utcnow(),
)
@router.get("/stats/history", response_model=OnlineHistoryResponse, tags=["统计"])
async def get_online_history() -> OnlineHistoryResponse:
"""
获取最近24小时在线统计历史
返回过去24小时内每小时的在线用户数和游玩用户数统计
包含当前实时数据作为最新数据点
"""
@@ -92,80 +100,80 @@ async def get_online_history() -> OnlineHistoryResponse:
redis_sync = get_redis_message()
history_data = await _redis_exec(redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1)
history_points = []
# 处理历史数据
for data in history_data:
try:
point_data = json.loads(data)
# 只保留基本字段
history_points.append(OnlineHistoryPoint(
timestamp=datetime.fromisoformat(point_data["timestamp"]),
online_count=point_data["online_count"],
playing_count=point_data["playing_count"]
))
history_points.append(
OnlineHistoryPoint(
timestamp=datetime.fromisoformat(point_data["timestamp"]),
online_count=point_data["online_count"],
playing_count=point_data["playing_count"],
)
)
except (json.JSONDecodeError, KeyError, ValueError) as e:
logger.warning(f"Invalid history data point: {data}, error: {e}")
continue
# 获取当前实时统计信息
current_stats = await get_server_stats()
# 如果历史数据为空或者最新数据超过15分钟添加当前数据点
if not history_points or (
history_points and
(current_stats.timestamp - max(history_points, key=lambda x: x.timestamp).timestamp).total_seconds() > 15 * 60
history_points
and (current_stats.timestamp - max(history_points, key=lambda x: x.timestamp).timestamp).total_seconds()
> 15 * 60
):
history_points.append(OnlineHistoryPoint(
timestamp=current_stats.timestamp,
online_count=current_stats.online_users,
playing_count=current_stats.playing_users
))
history_points.append(
OnlineHistoryPoint(
timestamp=current_stats.timestamp,
online_count=current_stats.online_users,
playing_count=current_stats.playing_users,
)
)
# 按时间排序(最新的在前)
history_points.sort(key=lambda x: x.timestamp, reverse=True)
# 限制到最多48个数据点24小时
history_points = history_points[:48]
return OnlineHistoryResponse(
history=history_points,
current_stats=current_stats
)
return OnlineHistoryResponse(history=history_points, current_stats=current_stats)
except Exception as e:
logger.error(f"Error getting online history: {e}")
# 返回空历史和当前状态
current_stats = await get_server_stats()
return OnlineHistoryResponse(
history=[],
current_stats=current_stats
)
return OnlineHistoryResponse(history=[], current_stats=current_stats)
@router.get("/stats/debug", tags=["统计"])
async def get_stats_debug_info():
"""
获取统计系统调试信息
用于调试时间对齐和区间统计问题
"""
try:
from app.service.enhanced_interval_stats import EnhancedIntervalStatsManager
current_time = datetime.utcnow()
current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
interval_stats = await EnhancedIntervalStatsManager.get_current_interval_stats()
# 获取Redis中的实际数据
redis_sync = get_redis_message()
online_key = f"server:interval_online_users:{current_interval.interval_key}"
playing_key = f"server:interval_playing_users:{current_interval.interval_key}"
online_users_raw = await _redis_exec(redis_sync.smembers, online_key)
playing_users_raw = await _redis_exec(redis_sync.smembers, playing_key)
online_users = [int(uid.decode() if isinstance(uid, bytes) else uid) for uid in online_users_raw]
playing_users = [int(uid.decode() if isinstance(uid, bytes) else uid) for uid in playing_users_raw]
return {
"current_time": current_time.isoformat(),
"current_interval": {
@@ -175,28 +183,29 @@ async def get_stats_debug_info():
"is_current": current_interval.is_current(),
"minutes_remaining": int((current_interval.end_time - current_time).total_seconds() / 60),
"seconds_remaining": int((current_interval.end_time - current_time).total_seconds()),
"progress_percentage": round((1 - (current_interval.end_time - current_time).total_seconds() / (30 * 60)) * 100, 1)
"progress_percentage": round(
(1 - (current_interval.end_time - current_time).total_seconds() / (30 * 60)) * 100,
1,
),
},
"interval_statistics": interval_stats.to_dict() if interval_stats else None,
"redis_data": {
"online_users": online_users,
"playing_users": playing_users,
"online_count": len(online_users),
"playing_count": len(playing_users)
"playing_count": len(playing_users),
},
"system_status": {
"stats_system": "enhanced_interval_stats",
"data_alignment": "30_minute_boundaries",
"real_time_updates": True,
"auto_24h_fill": True
}
"auto_24h_fill": True,
},
}
except Exception as e:
logger.error(f"Error getting debug info: {e}")
return {
"error": "Failed to retrieve debug information",
"message": str(e)
}
return {"error": "Failed to retrieve debug information", "message": str(e)}
async def _get_registered_users_count(redis) -> int:
"""获取注册用户总数(从缓存)"""
@@ -207,6 +216,7 @@ async def _get_registered_users_count(redis) -> int:
logger.error(f"Error getting registered users count: {e}")
return 0
async def _get_online_users_count(redis) -> int:
"""获取当前在线用户数"""
try:
@@ -216,6 +226,7 @@ async def _get_online_users_count(redis) -> int:
logger.error(f"Error getting online users count: {e}")
return 0
async def _get_playing_users_count(redis) -> int:
"""获取当前游玩用户数"""
try:
@@ -225,27 +236,28 @@ async def _get_playing_users_count(redis) -> int:
logger.error(f"Error getting playing users count: {e}")
return 0
# 统计更新功能
async def update_registered_users_count() -> None:
"""更新注册用户数缓存"""
from app.dependencies.database import with_db
from app.database import User
from app.const import BANCHOBOT_ID
from sqlmodel import select, func
from app.database import User
from app.dependencies.database import with_db
from sqlmodel import func, select
redis = get_redis()
try:
async with with_db() as db:
# 排除机器人用户BANCHOBOT_ID
result = await db.exec(
select(func.count()).select_from(User).where(User.id != BANCHOBOT_ID)
)
result = await db.exec(select(func.count()).select_from(User).where(User.id != BANCHOBOT_ID))
count = result.first()
await redis.set(REDIS_REGISTERED_USERS_KEY, count or 0, ex=300) # 5分钟过期
logger.debug(f"Updated registered users count: {count}")
except Exception as e:
logger.error(f"Error updating registered users count: {e}")
async def add_online_user(user_id: int) -> None:
"""添加在线用户"""
redis_sync = get_redis_message()
@@ -257,14 +269,20 @@ async def add_online_user(user_id: int) -> None:
if ttl <= 0: # -1表示永不过期-2表示不存在0表示已过期
await redis_async.expire(REDIS_ONLINE_USERS_KEY, 3 * 3600) # 3小时过期
logger.debug(f"Added online user {user_id}")
# 立即更新当前区间统计
from app.service.enhanced_interval_stats import update_user_activity_in_interval
asyncio.create_task(update_user_activity_in_interval(user_id, is_playing=False))
bg_tasks.add_task(
update_user_activity_in_interval,
user_id,
is_playing=False,
)
except Exception as e:
logger.error(f"Error adding online user {user_id}: {e}")
async def remove_online_user(user_id: int) -> None:
"""移除在线用户"""
redis_sync = get_redis_message()
@@ -274,6 +292,7 @@ async def remove_online_user(user_id: int) -> None:
except Exception as e:
logger.error(f"Error removing online user {user_id}: {e}")
async def add_playing_user(user_id: int) -> None:
"""添加游玩用户"""
redis_sync = get_redis_message()
@@ -285,14 +304,16 @@ async def add_playing_user(user_id: int) -> None:
if ttl <= 0: # -1表示永不过期-2表示不存在0表示已过期
await redis_async.expire(REDIS_PLAYING_USERS_KEY, 3 * 3600) # 3小时过期
logger.debug(f"Added playing user {user_id}")
# 立即更新当前区间统计
from app.service.enhanced_interval_stats import update_user_activity_in_interval
asyncio.create_task(update_user_activity_in_interval(user_id, is_playing=True))
bg_tasks.add_task(update_user_activity_in_interval, user_id, is_playing=True)
except Exception as e:
logger.error(f"Error adding playing user {user_id}: {e}")
async def remove_playing_user(user_id: int) -> None:
"""移除游玩用户"""
redis_sync = get_redis_message()
@@ -301,6 +322,7 @@ async def remove_playing_user(user_id: int) -> None:
except Exception as e:
logger.error(f"Error removing playing user {user_id}: {e}")
async def record_hourly_stats() -> None:
"""记录统计数据 - 简化版本主要作为fallback使用"""
redis_sync = get_redis_message()
@@ -308,24 +330,27 @@ async def record_hourly_stats() -> None:
try:
# 先确保Redis连接正常
await redis_async.ping()
online_count = await _get_online_users_count(redis_async)
playing_count = await _get_playing_users_count(redis_async)
current_time = datetime.utcnow()
history_point = {
"timestamp": current_time.isoformat(),
"online_count": online_count,
"playing_count": playing_count
"playing_count": playing_count,
}
# 添加到历史记录
await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point))
# 只保留48个数据点24小时每30分钟一个点
await _redis_exec(redis_sync.ltrim, REDIS_ONLINE_HISTORY_KEY, 0, 47)
# 设置过期时间为26小时确保有足够缓冲
await redis_async.expire(REDIS_ONLINE_HISTORY_KEY, 26 * 3600)
logger.info(f"Recorded fallback stats: online={online_count}, playing={playing_count} at {current_time.strftime('%H:%M:%S')}")
logger.info(
f"Recorded fallback stats: online={online_count}, playing={playing_count} "
f"at {current_time.strftime('%H:%M:%S')}"
)
except Exception as e:
logger.error(f"Error recording fallback stats: {e}")

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import asyncio
from datetime import UTC, datetime, timedelta
from typing import Literal
@@ -26,7 +25,7 @@ from app.service.user_cache_service import get_user_cache_service
from .router import router
from fastapi import HTTPException, Path, Query, Security
from fastapi import BackgroundTasks, HTTPException, Path, Query, Security
from pydantic import BaseModel
from sqlmodel import exists, false, select
from sqlmodel.sql.expression import col
@@ -47,13 +46,10 @@ class BatchUserResponse(BaseModel):
@router.get("/users/lookup/", response_model=BatchUserResponse, include_in_schema=False)
async def get_users(
session: Database,
user_ids: list[int] = Query(
default_factory=list, alias="ids[]", description="要查询的用户 ID 列表"
),
background_task: BackgroundTasks,
user_ids: list[int] = Query(default_factory=list, alias="ids[]", description="要查询的用户 ID 列表"),
# current_user: User = Security(get_current_user, scopes=["public"]),
include_variant_statistics: bool = Query(
default=False, description="是否包含各模式的统计信息"
), # TODO: future use
include_variant_statistics: bool = Query(default=False, description="是否包含各模式的统计信息"), # TODO: future use
):
redis = get_redis()
cache_service = get_user_cache_service(redis)
@@ -72,11 +68,7 @@ async def get_users(
# 查询未缓存的用户
if uncached_user_ids:
searched_users = (
await session.exec(
select(User).where(col(User.id).in_(uncached_user_ids))
)
).all()
searched_users = (await session.exec(select(User).where(col(User.id).in_(uncached_user_ids)))).all()
# 将查询到的用户添加到缓存并返回
for searched_user in searched_users:
@@ -88,7 +80,7 @@ async def get_users(
)
cached_users.append(user_resp)
# 异步缓存,不阻塞响应
asyncio.create_task(cache_service.cache_user(user_resp))
background_task.add_task(cache_service.cache_user, user_resp)
return BatchUserResponse(users=cached_users)
else:
@@ -103,7 +95,7 @@ async def get_users(
)
users.append(user_resp)
# 异步缓存
asyncio.create_task(cache_service.cache_user(user_resp))
background_task.add_task(cache_service.cache_user, user_resp)
return BatchUserResponse(users=users)
@@ -117,6 +109,7 @@ async def get_users(
)
async def get_user_info_ruleset(
session: Database,
background_task: BackgroundTasks,
user_id: str = Path(description="用户 ID 或用户名"),
ruleset: GameMode | None = Path(description="指定 ruleset"),
# current_user: User = Security(get_current_user, scopes=["public"]),
@@ -134,9 +127,7 @@ async def get_user_info_ruleset(
searched_user = (
await session.exec(
select(User).where(
User.id == int(user_id)
if user_id.isdigit()
else User.username == user_id.removeprefix("@")
User.id == int(user_id) if user_id.isdigit() else User.username == user_id.removeprefix("@")
)
)
).first()
@@ -151,7 +142,7 @@ async def get_user_info_ruleset(
)
# 异步缓存结果
asyncio.create_task(cache_service.cache_user(user_resp, ruleset))
background_task.add_task(cache_service.cache_user, user_resp, ruleset)
return user_resp
@@ -165,6 +156,7 @@ async def get_user_info_ruleset(
tags=["用户"],
)
async def get_user_info(
background_task: BackgroundTasks,
session: Database,
user_id: str = Path(description="用户 ID 或用户名"),
# current_user: User = Security(get_current_user, scopes=["public"]),
@@ -182,9 +174,7 @@ async def get_user_info(
searched_user = (
await session.exec(
select(User).where(
User.id == int(user_id)
if user_id.isdigit()
else User.username == user_id.removeprefix("@")
User.id == int(user_id) if user_id.isdigit() else User.username == user_id.removeprefix("@")
)
)
).first()
@@ -198,7 +188,7 @@ async def get_user_info(
)
# 异步缓存结果
asyncio.create_task(cache_service.cache_user(user_resp))
background_task.add_task(cache_service.cache_user, user_resp)
return user_resp
@@ -212,6 +202,7 @@ async def get_user_info(
)
async def get_user_beatmapsets(
session: Database,
background_task: BackgroundTasks,
user_id: int = Path(description="用户 ID"),
type: BeatmapsetType = Path(description="谱面集类型"),
current_user: User = Security(get_current_user, scopes=["public"]),
@@ -222,9 +213,7 @@ async def get_user_beatmapsets(
cache_service = get_user_cache_service(redis)
# 先尝试从缓存获取
cached_result = await cache_service.get_user_beatmapsets_from_cache(
user_id, type.value, limit, offset
)
cached_result = await cache_service.get_user_beatmapsets_from_cache(user_id, type.value, limit, offset)
if cached_result is not None:
# 根据类型恢复对象
if type == BeatmapsetType.MOST_PLAYED:
@@ -253,10 +242,7 @@ async def get_user_beatmapsets(
raise HTTPException(404, detail="User not found")
favourites = await user.awaitable_attrs.favourite_beatmapsets
resp = [
await BeatmapsetResp.from_db(
favourite.beatmapset, session=session, user=user
)
for favourite in favourites
await BeatmapsetResp.from_db(favourite.beatmapset, session=session, user=user) for favourite in favourites
]
elif type == BeatmapsetType.MOST_PLAYED:
@@ -267,25 +253,18 @@ async def get_user_beatmapsets(
.limit(limit)
.offset(offset)
)
resp = [
await BeatmapPlaycountsResp.from_db(most_played_beatmap)
for most_played_beatmap in most_played
]
resp = [await BeatmapPlaycountsResp.from_db(most_played_beatmap) for most_played_beatmap in most_played]
else:
raise HTTPException(400, detail="Invalid beatmapset type")
# 异步缓存结果
async def cache_beatmapsets():
try:
await cache_service.cache_user_beatmapsets(
user_id, type.value, resp, limit, offset
)
await cache_service.cache_user_beatmapsets(user_id, type.value, resp, limit, offset)
except Exception as e:
logger.error(
f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}"
)
logger.error(f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}")
asyncio.create_task(cache_beatmapsets())
background_task.add_task(cache_beatmapsets)
return resp
@@ -299,18 +278,14 @@ async def get_user_beatmapsets(
)
async def get_user_scores(
session: Database,
background_task: BackgroundTasks,
user_id: int = Path(description="用户 ID"),
type: Literal["best", "recent", "firsts", "pinned"] = Path(
description=(
"成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩"
" / firsts 第一名成绩 / pinned 置顶成绩"
)
description=("成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩 / firsts 第一名成绩 / pinned 置顶成绩")
),
legacy_only: bool = Query(False, description="是否只查询 Stable 成绩"),
include_fails: bool = Query(False, description="是否包含失败的成绩"),
mode: GameMode | None = Query(
None, description="指定 ruleset (可选,默认为用户主模式)"
),
mode: GameMode | None = Query(None, description="指定 ruleset (可选,默认为用户主模式)"),
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
offset: int = Query(0, ge=0, description="偏移量"),
current_user: User = Security(get_current_user, scopes=["public"]),
@@ -320,9 +295,7 @@ async def get_user_scores(
# 先尝试从缓存获取对于recent类型使用较短的缓存时间
cache_expire = 30 if type == "recent" else settings.user_scores_cache_expire_seconds
cached_scores = await cache_service.get_user_scores_from_cache(
user_id, type, mode, limit, offset
)
cached_scores = await cache_service.get_user_scores_from_cache(user_id, type, mode, limit, offset)
if cached_scores is not None:
return cached_scores
@@ -332,9 +305,7 @@ async def get_user_scores(
gamemode = mode or db_user.playmode
order_by = None
where_clause = (col(Score.user_id) == db_user.id) & (
col(Score.gamemode) == gamemode
)
where_clause = (col(Score.user_id) == db_user.id) & (col(Score.gamemode) == gamemode)
if not include_fails:
where_clause &= col(Score.passed).is_(True)
if type == "pinned":
@@ -351,13 +322,7 @@ async def get_user_scores(
where_clause &= false()
scores = (
await session.exec(
select(Score)
.where(where_clause)
.order_by(order_by)
.limit(limit)
.offset(offset)
)
await session.exec(select(Score).where(where_clause).order_by(order_by).limit(limit).offset(offset))
).all()
if not scores:
return []
@@ -371,18 +336,14 @@ async def get_user_scores(
]
# 异步缓存结果
asyncio.create_task(
cache_service.cache_user_scores(
user_id, type, score_responses, mode, limit, offset, cache_expire
)
background_task.add_task(
cache_service.cache_user_scores, user_id, type, score_responses, mode, limit, offset, cache_expire
)
return score_responses
@router.get(
"/users/{user}/recent_activity", tags=["用户"], response_model=list[EventResp]
)
@router.get("/users/{user}/recent_activity", tags=["用户"], response_model=list[EventResp])
async def get_user_events(
session: Database,
user: int,