refactor(project): make pyright & ruff happy
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401
|
||||
from . import ( # noqa: F401
|
||||
beatmap,
|
||||
beatmapset,
|
||||
me,
|
||||
|
||||
@@ -40,18 +40,13 @@ class BatchGetResp(BaseModel):
|
||||
tags=["谱面"],
|
||||
name="查询单个谱面",
|
||||
response_model=BeatmapResp,
|
||||
description=(
|
||||
"根据谱面 ID / MD5 / 文件名 查询单个谱面。"
|
||||
"至少提供 id / checksum / filename 之一。"
|
||||
),
|
||||
description=("根据谱面 ID / MD5 / 文件名 查询单个谱面。至少提供 id / checksum / filename 之一。"),
|
||||
)
|
||||
async def lookup_beatmap(
|
||||
db: Database,
|
||||
id: int | None = Query(default=None, alias="id", description="谱面 ID"),
|
||||
md5: str | None = Query(default=None, alias="checksum", description="谱面文件 MD5"),
|
||||
filename: str | None = Query(
|
||||
default=None, alias="filename", description="谱面文件名"
|
||||
),
|
||||
filename: str | None = Query(default=None, alias="filename", description="谱面文件名"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
@@ -96,43 +91,23 @@ async def get_beatmap(
|
||||
tags=["谱面"],
|
||||
name="批量获取谱面",
|
||||
response_model=BatchGetResp,
|
||||
description=(
|
||||
"批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。"
|
||||
"为空时按最近更新时间返回。"
|
||||
),
|
||||
description=("批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。为空时按最近更新时间返回。"),
|
||||
)
|
||||
async def batch_get_beatmaps(
|
||||
db: Database,
|
||||
beatmap_ids: list[int] = Query(
|
||||
alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"
|
||||
),
|
||||
beatmap_ids: list[int] = Query(alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
if not beatmap_ids:
|
||||
beatmaps = (
|
||||
await db.exec(
|
||||
select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
|
||||
)
|
||||
).all()
|
||||
beatmaps = (await db.exec(select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50))).all()
|
||||
else:
|
||||
beatmaps = list(
|
||||
(
|
||||
await db.exec(
|
||||
select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)).limit(50)
|
||||
)
|
||||
).all()
|
||||
)
|
||||
not_found_beatmaps = [
|
||||
bid for bid in beatmap_ids if bid not in [bm.id for bm in beatmaps]
|
||||
]
|
||||
beatmaps = list((await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)).limit(50))).all())
|
||||
not_found_beatmaps = [bid for bid in beatmap_ids if bid not in [bm.id for bm in beatmaps]]
|
||||
beatmaps.extend(
|
||||
beatmap
|
||||
for beatmap in await asyncio.gather(
|
||||
*[
|
||||
Beatmap.get_or_fetch(db, fetcher, bid=bid)
|
||||
for bid in not_found_beatmaps
|
||||
],
|
||||
*[Beatmap.get_or_fetch(db, fetcher, bid=bid) for bid in not_found_beatmaps],
|
||||
return_exceptions=True,
|
||||
)
|
||||
if isinstance(beatmap, Beatmap)
|
||||
@@ -140,12 +115,7 @@ async def batch_get_beatmaps(
|
||||
for beatmap in beatmaps:
|
||||
await db.refresh(beatmap)
|
||||
|
||||
return BatchGetResp(
|
||||
beatmaps=[
|
||||
await BeatmapResp.from_db(bm, session=db, user=current_user)
|
||||
for bm in beatmaps
|
||||
]
|
||||
)
|
||||
return BatchGetResp(beatmaps=[await BeatmapResp.from_db(bm, session=db, user=current_user) for bm in beatmaps])
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -163,12 +133,8 @@ async def get_beatmap_attributes(
|
||||
default_factory=list,
|
||||
description="Mods 列表;可为整型位掩码(单元素)或 JSON/简称",
|
||||
),
|
||||
ruleset: GameMode | None = Query(
|
||||
default=None, description="指定 ruleset;为空则使用谱面自身模式"
|
||||
),
|
||||
ruleset_id: int | None = Query(
|
||||
default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3
|
||||
),
|
||||
ruleset: GameMode | None = Query(default=None, description="指定 ruleset;为空则使用谱面自身模式"),
|
||||
ruleset_id: int | None = Query(default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3),
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
@@ -187,16 +153,11 @@ async def get_beatmap_attributes(
|
||||
if ruleset is None:
|
||||
beatmap_db = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
|
||||
ruleset = beatmap_db.mode
|
||||
key = (
|
||||
f"beatmap:{beatmap_id}:{ruleset}:"
|
||||
f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
|
||||
)
|
||||
key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
|
||||
if await redis.exists(key):
|
||||
return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
|
||||
try:
|
||||
return await calculate_beatmap_attributes(
|
||||
beatmap_id, ruleset, mods_, redis, fetcher
|
||||
)
|
||||
return await calculate_beatmap_attributes(beatmap_id, ruleset, mods_, redis, fetcher)
|
||||
except HTTPStatusError:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
@@ -35,9 +35,7 @@ from sqlmodel import exists, select
|
||||
async def _save_to_db(sets: SearchBeatmapsetsResp):
|
||||
async with with_db() as session:
|
||||
for s in sets.beatmapsets:
|
||||
if not (
|
||||
await session.exec(select(exists()).where(Beatmapset.id == s.id))
|
||||
).first():
|
||||
if not (await session.exec(select(exists()).where(Beatmapset.id == s.id))).first():
|
||||
await Beatmapset.from_resp(session, s)
|
||||
|
||||
|
||||
@@ -117,9 +115,7 @@ async def lookup_beatmapset(
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
|
||||
resp = await BeatmapsetResp.from_db(
|
||||
beatmap.beatmapset, session=db, user=current_user
|
||||
)
|
||||
resp = await BeatmapsetResp.from_db(beatmap.beatmapset, session=db, user=current_user)
|
||||
return resp
|
||||
|
||||
|
||||
@@ -138,9 +134,7 @@ async def get_beatmapset(
|
||||
):
|
||||
try:
|
||||
beatmapset = await Beatmapset.get_or_fetch(db, fetcher, beatmapset_id)
|
||||
return await BeatmapsetResp.from_db(
|
||||
beatmapset, session=db, include=["recent_favourites"], user=current_user
|
||||
)
|
||||
return await BeatmapsetResp.from_db(beatmapset, session=db, include=["recent_favourites"], user=current_user)
|
||||
except HTTPError:
|
||||
raise HTTPException(status_code=404, detail="Beatmapset not found")
|
||||
|
||||
@@ -165,9 +159,7 @@ async def download_beatmapset(
|
||||
country_code = geo_info.get("country_iso", "")
|
||||
|
||||
# 优先使用IP地理位置判断,如果获取失败则回退到用户账户的国家代码
|
||||
is_china = country_code == "CN" or (
|
||||
not country_code and current_user.country_code == "CN"
|
||||
)
|
||||
is_china = country_code == "CN" or (not country_code and current_user.country_code == "CN")
|
||||
|
||||
try:
|
||||
# 使用负载均衡服务获取下载URL
|
||||
@@ -179,13 +171,10 @@ async def download_beatmapset(
|
||||
# 如果负载均衡服务失败,回退到原有逻辑
|
||||
if is_china:
|
||||
return RedirectResponse(
|
||||
f"https://dl.sayobot.cn/beatmaps/download/"
|
||||
f"{'novideo' if no_video else 'full'}/{beatmapset_id}"
|
||||
f"https://dl.sayobot.cn/beatmaps/download/{'novideo' if no_video else 'full'}/{beatmapset_id}"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
f"https://api.nerinyan.moe/d/{beatmapset_id}?noVideo={no_video}"
|
||||
)
|
||||
return RedirectResponse(f"https://api.nerinyan.moe/d/{beatmapset_id}?noVideo={no_video}")
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -197,12 +186,9 @@ async def download_beatmapset(
|
||||
async def favourite_beatmapset(
|
||||
db: Database,
|
||||
beatmapset_id: int = Path(..., description="谱面集 ID"),
|
||||
action: Literal["favourite", "unfavourite"] = Form(
|
||||
description="操作类型:favourite 收藏 / unfavourite 取消收藏"
|
||||
),
|
||||
action: Literal["favourite", "unfavourite"] = Form(description="操作类型:favourite 收藏 / unfavourite 取消收藏"),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
assert current_user.id is not None
|
||||
existing_favourite = (
|
||||
await db.exec(
|
||||
select(FavouriteBeatmapset).where(
|
||||
@@ -212,15 +198,11 @@ async def favourite_beatmapset(
|
||||
)
|
||||
).first()
|
||||
|
||||
if (action == "favourite" and existing_favourite) or (
|
||||
action == "unfavourite" and not existing_favourite
|
||||
):
|
||||
if (action == "favourite" and existing_favourite) or (action == "unfavourite" and not existing_favourite):
|
||||
return
|
||||
|
||||
if action == "favourite":
|
||||
favourite = FavouriteBeatmapset(
|
||||
user_id=current_user.id, beatmapset_id=beatmapset_id
|
||||
)
|
||||
favourite = FavouriteBeatmapset(user_id=current_user.id, beatmapset_id=beatmapset_id)
|
||||
db.add(favourite)
|
||||
else:
|
||||
await db.delete(existing_favourite)
|
||||
|
||||
@@ -4,8 +4,8 @@ from app.database import User
|
||||
from app.database.lazer_user import ALL_INCLUDED
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies.database import Database
|
||||
from app.models.score import GameMode
|
||||
from app.models.api_me import APIMe
|
||||
from app.models.score import GameMode
|
||||
|
||||
from .router import router
|
||||
|
||||
|
||||
@@ -33,6 +33,4 @@ class BackgroundsResp(BaseModel):
|
||||
description="获取当前季节背景图列表。",
|
||||
)
|
||||
async def get_seasonal_backgrounds():
|
||||
return BackgroundsResp(
|
||||
backgrounds=[Background(url=url) for url in settings.seasonal_backgrounds]
|
||||
)
|
||||
return BackgroundsResp(backgrounds=[Background(url=url) for url in settings.seasonal_backgrounds])
|
||||
|
||||
@@ -12,7 +12,7 @@ from app.service.ranking_cache_service import get_ranking_cache_service
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Path, Query, Security
|
||||
from fastapi import BackgroundTasks, Path, Query, Security
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import col, select
|
||||
|
||||
@@ -38,6 +38,7 @@ class CountryResponse(BaseModel):
|
||||
)
|
||||
async def get_country_ranking(
|
||||
session: Database,
|
||||
background_tasks: BackgroundTasks,
|
||||
ruleset: GameMode = Path(..., description="指定 ruleset"),
|
||||
page: int = Query(1, ge=1, description="页码"), # TODO
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
@@ -51,9 +52,7 @@ async def get_country_ranking(
|
||||
|
||||
if cached_data:
|
||||
# 从缓存返回数据
|
||||
return CountryResponse(
|
||||
ranking=[CountryStatistics.model_validate(item) for item in cached_data]
|
||||
)
|
||||
return CountryResponse(ranking=[CountryStatistics.model_validate(item) for item in cached_data])
|
||||
|
||||
# 缓存未命中,从数据库查询
|
||||
response = CountryResponse(ranking=[])
|
||||
@@ -105,14 +104,15 @@ async def get_country_ranking(
|
||||
|
||||
# 异步缓存数据(不等待完成)
|
||||
cache_data = [item.model_dump() for item in current_page_data]
|
||||
cache_task = cache_service.cache_country_ranking(
|
||||
ruleset, cache_data, page, ttl=settings.ranking_cache_expire_minutes * 60
|
||||
)
|
||||
|
||||
# 创建后台任务来缓存数据
|
||||
import asyncio
|
||||
|
||||
asyncio.create_task(cache_task)
|
||||
background_tasks.add_task(
|
||||
cache_service.cache_country_ranking,
|
||||
ruleset,
|
||||
cache_data,
|
||||
page,
|
||||
ttl=settings.ranking_cache_expire_minutes * 60,
|
||||
)
|
||||
|
||||
# 返回当前页的结果
|
||||
response.ranking = current_page_data
|
||||
@@ -132,10 +132,9 @@ class TopUsersResponse(BaseModel):
|
||||
)
|
||||
async def get_user_ranking(
|
||||
session: Database,
|
||||
background_tasks: BackgroundTasks,
|
||||
ruleset: GameMode = Path(..., description="指定 ruleset"),
|
||||
type: Literal["performance", "score"] = Path(
|
||||
..., description="排名类型:performance 表现分 / score 计分成绩总分"
|
||||
),
|
||||
type: Literal["performance", "score"] = Path(..., description="排名类型:performance 表现分 / score 计分成绩总分"),
|
||||
country: str | None = Query(None, description="国家代码"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
@@ -149,9 +148,7 @@ async def get_user_ranking(
|
||||
|
||||
if cached_data:
|
||||
# 从缓存返回数据
|
||||
return TopUsersResponse(
|
||||
ranking=[UserStatisticsResp.model_validate(item) for item in cached_data]
|
||||
)
|
||||
return TopUsersResponse(ranking=[UserStatisticsResp.model_validate(item) for item in cached_data])
|
||||
|
||||
# 缓存未命中,从数据库查询
|
||||
wheres = [
|
||||
@@ -169,25 +166,22 @@ async def get_user_ranking(
|
||||
wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
|
||||
|
||||
statistics_list = await session.exec(
|
||||
select(UserStatistics)
|
||||
.where(*wheres)
|
||||
.order_by(order_by)
|
||||
.limit(50)
|
||||
.offset(50 * (page - 1))
|
||||
select(UserStatistics).where(*wheres).order_by(order_by).limit(50).offset(50 * (page - 1))
|
||||
)
|
||||
|
||||
# 转换为响应格式
|
||||
ranking_data = []
|
||||
for statistics in statistics_list:
|
||||
user_stats_resp = await UserStatisticsResp.from_db(
|
||||
statistics, session, None, include
|
||||
)
|
||||
user_stats_resp = await UserStatisticsResp.from_db(statistics, session, None, include)
|
||||
ranking_data.append(user_stats_resp)
|
||||
|
||||
# 异步缓存数据(不等待完成)
|
||||
# 使用配置文件中的TTL设置
|
||||
cache_data = [item.model_dump() for item in ranking_data]
|
||||
cache_task = cache_service.cache_ranking(
|
||||
# 创建后台任务来缓存数据
|
||||
|
||||
background_tasks.add_task(
|
||||
cache_service.cache_ranking,
|
||||
ruleset,
|
||||
type,
|
||||
cache_data,
|
||||
@@ -196,139 +190,134 @@ async def get_user_ranking(
|
||||
ttl=settings.ranking_cache_expire_minutes * 60,
|
||||
)
|
||||
|
||||
# 创建后台任务来缓存数据
|
||||
import asyncio
|
||||
|
||||
asyncio.create_task(cache_task)
|
||||
|
||||
resp = TopUsersResponse(ranking=ranking_data)
|
||||
return resp
|
||||
|
||||
|
||||
""" @router.post(
|
||||
"/rankings/cache/refresh",
|
||||
name="刷新排行榜缓存",
|
||||
description="手动刷新排行榜缓存(管理员功能)",
|
||||
tags=["排行榜", "管理"],
|
||||
)
|
||||
async def refresh_ranking_cache(
|
||||
session: Database,
|
||||
ruleset: GameMode | None = Query(None, description="指定要刷新的游戏模式,不指定则刷新所有"),
|
||||
type: Literal["performance", "score"] | None = Query(None, description="指定要刷新的排名类型,不指定则刷新所有"),
|
||||
country: str | None = Query(None, description="指定要刷新的国家,不指定则刷新所有"),
|
||||
include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
|
||||
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
if ruleset and type:
|
||||
# 刷新特定的用户排行榜
|
||||
await cache_service.refresh_ranking_cache(session, ruleset, type, country)
|
||||
message = f"Refreshed ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
|
||||
|
||||
# 如果请求刷新地区排行榜
|
||||
if include_country_ranking and not country: # 地区排行榜不依赖于国家参数
|
||||
await cache_service.refresh_country_ranking_cache(session, ruleset)
|
||||
message += f" and country ranking for {ruleset}"
|
||||
|
||||
return {"message": message}
|
||||
elif ruleset:
|
||||
# 刷新特定游戏模式的所有排行榜
|
||||
ranking_types: list[Literal["performance", "score"]] = ["performance", "score"]
|
||||
for ranking_type in ranking_types:
|
||||
await cache_service.refresh_ranking_cache(session, ruleset, ranking_type, country)
|
||||
|
||||
if include_country_ranking:
|
||||
await cache_service.refresh_country_ranking_cache(session, ruleset)
|
||||
|
||||
return {"message": f"Refreshed all ranking caches for {ruleset}"}
|
||||
else:
|
||||
# 刷新所有排行榜
|
||||
await cache_service.refresh_all_rankings(session)
|
||||
return {"message": "Refreshed all ranking caches"}
|
||||
# @router.post(
|
||||
# "/rankings/cache/refresh",
|
||||
# name="刷新排行榜缓存",
|
||||
# description="手动刷新排行榜缓存(管理员功能)",
|
||||
# tags=["排行榜", "管理"],
|
||||
# )
|
||||
# async def refresh_ranking_cache(
|
||||
# session: Database,
|
||||
# ruleset: GameMode | None = Query(None, description="指定要刷新的游戏模式,不指定则刷新所有"),
|
||||
# type: Literal["performance", "score"] | None = Query(None, description="指定要刷新的排名类型,不指定则刷新所有"),
|
||||
# country: str | None = Query(None, description="指定要刷新的国家,不指定则刷新所有"),
|
||||
# include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
|
||||
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
# ):
|
||||
# redis = get_redis()
|
||||
# cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
# if ruleset and type:
|
||||
# # 刷新特定的用户排行榜
|
||||
# await cache_service.refresh_ranking_cache(session, ruleset, type, country)
|
||||
# message = f"Refreshed ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
|
||||
|
||||
# # 如果请求刷新地区排行榜
|
||||
# if include_country_ranking and not country: # 地区排行榜不依赖于国家参数
|
||||
# await cache_service.refresh_country_ranking_cache(session, ruleset)
|
||||
# message += f" and country ranking for {ruleset}"
|
||||
|
||||
# return {"message": message}
|
||||
# elif ruleset:
|
||||
# # 刷新特定游戏模式的所有排行榜
|
||||
# ranking_types: list[Literal["performance", "score"]] = ["performance", "score"]
|
||||
# for ranking_type in ranking_types:
|
||||
# await cache_service.refresh_ranking_cache(session, ruleset, ranking_type, country)
|
||||
|
||||
# if include_country_ranking:
|
||||
# await cache_service.refresh_country_ranking_cache(session, ruleset)
|
||||
|
||||
# return {"message": f"Refreshed all ranking caches for {ruleset}"}
|
||||
# else:
|
||||
# # 刷新所有排行榜
|
||||
# await cache_service.refresh_all_rankings(session)
|
||||
# return {"message": "Refreshed all ranking caches"}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/rankings/{ruleset}/country/cache/refresh",
|
||||
name="刷新地区排行榜缓存",
|
||||
description="手动刷新地区排行榜缓存(管理员功能)",
|
||||
tags=["排行榜", "管理"],
|
||||
)
|
||||
async def refresh_country_ranking_cache(
|
||||
session: Database,
|
||||
ruleset: GameMode = Path(..., description="指定要刷新的游戏模式"),
|
||||
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
await cache_service.refresh_country_ranking_cache(session, ruleset)
|
||||
return {"message": f"Refreshed country ranking cache for {ruleset}"}
|
||||
# @router.post(
|
||||
# "/rankings/{ruleset}/country/cache/refresh",
|
||||
# name="刷新地区排行榜缓存",
|
||||
# description="手动刷新地区排行榜缓存(管理员功能)",
|
||||
# tags=["排行榜", "管理"],
|
||||
# )
|
||||
# async def refresh_country_ranking_cache(
|
||||
# session: Database,
|
||||
# ruleset: GameMode = Path(..., description="指定要刷新的游戏模式"),
|
||||
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
# ):
|
||||
# redis = get_redis()
|
||||
# cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
# await cache_service.refresh_country_ranking_cache(session, ruleset)
|
||||
# return {"message": f"Refreshed country ranking cache for {ruleset}"}
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/rankings/cache",
|
||||
name="清除排行榜缓存",
|
||||
description="清除排行榜缓存(管理员功能)",
|
||||
tags=["排行榜", "管理"],
|
||||
)
|
||||
async def clear_ranking_cache(
|
||||
ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
|
||||
type: Literal["performance", "score"] | None = Query(None, description="指定要清除的排名类型,不指定则清除所有"),
|
||||
country: str | None = Query(None, description="指定要清除的国家,不指定则清除所有"),
|
||||
include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
|
||||
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
await cache_service.invalidate_cache(ruleset, type, country, include_country_ranking)
|
||||
|
||||
if ruleset and type:
|
||||
message = f"Cleared ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
|
||||
if include_country_ranking:
|
||||
message += " and country ranking"
|
||||
return {"message": message}
|
||||
else:
|
||||
message = "Cleared all ranking caches"
|
||||
if include_country_ranking:
|
||||
message += " including country rankings"
|
||||
return {"message": message}
|
||||
# @router.delete(
|
||||
# "/rankings/cache",
|
||||
# name="清除排行榜缓存",
|
||||
# description="清除排行榜缓存(管理员功能)",
|
||||
# tags=["排行榜", "管理"],
|
||||
# )
|
||||
# async def clear_ranking_cache(
|
||||
# ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
|
||||
# type: Literal["performance", "score"] | None = Query(None, description="指定要清除的排名类型,不指定则清除所有"),
|
||||
# country: str | None = Query(None, description="指定要清除的国家,不指定则清除所有"),
|
||||
# include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
|
||||
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
# ):
|
||||
# redis = get_redis()
|
||||
# cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
# await cache_service.invalidate_cache(ruleset, type, country, include_country_ranking)
|
||||
|
||||
# if ruleset and type:
|
||||
# message = f"Cleared ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
|
||||
# if include_country_ranking:
|
||||
# message += " and country ranking"
|
||||
# return {"message": message}
|
||||
# else:
|
||||
# message = "Cleared all ranking caches"
|
||||
# if include_country_ranking:
|
||||
# message += " including country rankings"
|
||||
# return {"message": message}
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/rankings/{ruleset}/country/cache",
|
||||
name="清除地区排行榜缓存",
|
||||
description="清除地区排行榜缓存(管理员功能)",
|
||||
tags=["排行榜", "管理"],
|
||||
)
|
||||
async def clear_country_ranking_cache(
|
||||
ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
|
||||
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
await cache_service.invalidate_country_cache(ruleset)
|
||||
|
||||
if ruleset:
|
||||
return {"message": f"Cleared country ranking cache for {ruleset}"}
|
||||
else:
|
||||
return {"message": "Cleared all country ranking caches"}
|
||||
# @router.delete(
|
||||
# "/rankings/{ruleset}/country/cache",
|
||||
# name="清除地区排行榜缓存",
|
||||
# description="清除地区排行榜缓存(管理员功能)",
|
||||
# tags=["排行榜", "管理"],
|
||||
# )
|
||||
# async def clear_country_ranking_cache(
|
||||
# ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
|
||||
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
# ):
|
||||
# redis = get_redis()
|
||||
# cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
# await cache_service.invalidate_country_cache(ruleset)
|
||||
|
||||
# if ruleset:
|
||||
# return {"message": f"Cleared country ranking cache for {ruleset}"}
|
||||
# else:
|
||||
# return {"message": "Cleared all country ranking caches"}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rankings/cache/stats",
|
||||
name="获取排行榜缓存统计",
|
||||
description="获取排行榜缓存统计信息(管理员功能)",
|
||||
tags=["排行榜", "管理"],
|
||||
)
|
||||
async def get_ranking_cache_stats(
|
||||
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
stats = await cache_service.get_cache_stats()
|
||||
return stats """
|
||||
# @router.get(
|
||||
# "/rankings/cache/stats",
|
||||
# name="获取排行榜缓存统计",
|
||||
# description="获取排行榜缓存统计信息(管理员功能)",
|
||||
# tags=["排行榜", "管理"],
|
||||
# )
|
||||
# async def get_ranking_cache_stats(
|
||||
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
# ):
|
||||
# redis = get_redis()
|
||||
# cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
# stats = await cache_service.get_cache_stats()
|
||||
# return stats
|
||||
|
||||
@@ -30,11 +30,7 @@ async def get_relationship(
|
||||
request: Request,
|
||||
current_user: User = Security(get_current_user, scopes=["friends.read"]),
|
||||
):
|
||||
relationship_type = (
|
||||
RelationshipType.FOLLOW
|
||||
if request.url.path.endswith("/friends")
|
||||
else RelationshipType.BLOCK
|
||||
)
|
||||
relationship_type = RelationshipType.FOLLOW if request.url.path.endswith("/friends") else RelationshipType.BLOCK
|
||||
relationships = await db.exec(
|
||||
select(Relationship).where(
|
||||
Relationship.user_id == current_user.id,
|
||||
@@ -71,12 +67,7 @@ async def add_relationship(
|
||||
target: int = Query(description="目标用户 ID"),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
assert current_user.id is not None
|
||||
relationship_type = (
|
||||
RelationshipType.FOLLOW
|
||||
if request.url.path.endswith("/friends")
|
||||
else RelationshipType.BLOCK
|
||||
)
|
||||
relationship_type = RelationshipType.FOLLOW if request.url.path.endswith("/friends") else RelationshipType.BLOCK
|
||||
if target == current_user.id:
|
||||
raise HTTPException(422, "Cannot add relationship to yourself")
|
||||
relationship = (
|
||||
@@ -120,11 +111,8 @@ async def add_relationship(
|
||||
Relationship.target_id == target,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
assert relationship, "Relationship should exist after commit"
|
||||
return AddFriendResp(
|
||||
user_relation=await RelationshipResp.from_db(db, relationship)
|
||||
)
|
||||
).one()
|
||||
return AddFriendResp(user_relation=await RelationshipResp.from_db(db, relationship))
|
||||
|
||||
|
||||
@router.delete(
|
||||
@@ -145,11 +133,7 @@ async def delete_relationship(
|
||||
target: int = Path(..., description="目标用户 ID"),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
relationship_type = (
|
||||
RelationshipType.BLOCK
|
||||
if "/blocks/" in request.url.path
|
||||
else RelationshipType.FOLLOW
|
||||
)
|
||||
relationship_type = RelationshipType.BLOCK if "/blocks/" in request.url.path else RelationshipType.FOLLOW
|
||||
relationship = (
|
||||
await db.exec(
|
||||
select(Relationship).where(
|
||||
|
||||
@@ -39,17 +39,11 @@ async def get_all_rooms(
|
||||
db: Database,
|
||||
mode: Literal["open", "ended", "participated", "owned", None] = Query(
|
||||
default="open",
|
||||
description=(
|
||||
"房间模式:open 当前开放 / ended 已经结束 / "
|
||||
"participated 参与过 / owned 自己创建的房间"
|
||||
),
|
||||
description=("房间模式:open 当前开放 / ended 已经结束 / participated 参与过 / owned 自己创建的房间"),
|
||||
),
|
||||
category: RoomCategory = Query(
|
||||
RoomCategory.NORMAL,
|
||||
description=(
|
||||
"房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间"
|
||||
" / DAILY_CHALLENGE 每日挑战"
|
||||
),
|
||||
description=("房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战"),
|
||||
),
|
||||
status: RoomStatus | None = Query(None, description="房间状态(可选)"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
@@ -60,10 +54,7 @@ async def get_all_rooms(
|
||||
if status is not None:
|
||||
where_clauses.append(col(Room.status) == status)
|
||||
if mode == "open":
|
||||
where_clauses.append(
|
||||
(col(Room.ends_at).is_(None))
|
||||
| (col(Room.ends_at) > now.replace(tzinfo=UTC))
|
||||
)
|
||||
where_clauses.append((col(Room.ends_at).is_(None)) | (col(Room.ends_at) > now.replace(tzinfo=UTC)))
|
||||
if category == RoomCategory.REALTIME:
|
||||
where_clauses.append(col(Room.id).in_(MultiplayerHubs.rooms.keys()))
|
||||
if mode == "participated":
|
||||
@@ -76,10 +67,7 @@ async def get_all_rooms(
|
||||
if mode == "owned":
|
||||
where_clauses.append(col(Room.host_id) == current_user.id)
|
||||
if mode == "ended":
|
||||
where_clauses.append(
|
||||
(col(Room.ends_at).is_not(None))
|
||||
& (col(Room.ends_at) < now.replace(tzinfo=UTC))
|
||||
)
|
||||
where_clauses.append((col(Room.ends_at).is_not(None)) & (col(Room.ends_at) < now.replace(tzinfo=UTC)))
|
||||
|
||||
db_rooms = (
|
||||
(
|
||||
@@ -97,11 +85,7 @@ async def get_all_rooms(
|
||||
resp = await RoomResp.from_db(room, db)
|
||||
if category == RoomCategory.REALTIME:
|
||||
mp_room = MultiplayerHubs.rooms.get(room.id)
|
||||
resp.has_password = (
|
||||
bool(mp_room.room.settings.password.strip())
|
||||
if mp_room is not None
|
||||
else False
|
||||
)
|
||||
resp.has_password = bool(mp_room.room.settings.password.strip()) if mp_room is not None else False
|
||||
resp.category = RoomCategory.NORMAL
|
||||
resp_list.append(resp)
|
||||
|
||||
@@ -115,9 +99,7 @@ class APICreatedRoom(RoomResp):
|
||||
error: str = ""
|
||||
|
||||
|
||||
async def _participate_room(
|
||||
room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis
|
||||
):
|
||||
async def _participate_room(room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis):
|
||||
participated_user = (
|
||||
await session.exec(
|
||||
select(RoomParticipatedUser).where(
|
||||
@@ -154,7 +136,6 @@ async def create_room(
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
assert current_user.id is not None
|
||||
user_id = current_user.id
|
||||
db_room = await create_playlist_room_from_api(db, room, user_id)
|
||||
await _participate_room(db_room.id, user_id, db_room, db, redis)
|
||||
@@ -177,10 +158,7 @@ async def get_room(
|
||||
room_id: int = Path(..., description="房间 ID"),
|
||||
category: str = Query(
|
||||
default="",
|
||||
description=(
|
||||
"房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间"
|
||||
" / DAILY_CHALLENGE 每日挑战 (可选)"
|
||||
),
|
||||
description=("房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战 (可选)"),
|
||||
),
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
@@ -188,9 +166,7 @@ async def get_room(
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||
if db_room is None:
|
||||
raise HTTPException(404, "Room not found")
|
||||
resp = await RoomResp.from_db(
|
||||
db_room, include=["current_user_score"], session=db, user=current_user
|
||||
)
|
||||
resp = await RoomResp.from_db(db_room, include=["current_user_score"], session=db, user=current_user)
|
||||
return resp
|
||||
|
||||
|
||||
@@ -400,7 +376,6 @@ async def get_room_events(
|
||||
for score in scores:
|
||||
user_ids.add(score.user_id)
|
||||
beatmap_ids.add(score.beatmap_id)
|
||||
assert event.id is not None
|
||||
first_event_id = min(first_event_id, event.id)
|
||||
last_event_id = max(last_event_id, event.id)
|
||||
|
||||
@@ -416,16 +391,12 @@ async def get_room_events(
|
||||
users = await db.exec(select(User).where(col(User.id).in_(user_ids)))
|
||||
user_resps = [await UserResp.from_db(user, db) for user in users]
|
||||
beatmaps = await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)))
|
||||
beatmap_resps = [
|
||||
await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps
|
||||
]
|
||||
beatmap_resps = [await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps]
|
||||
beatmapset_resps = {}
|
||||
for beatmap_resp in beatmap_resps:
|
||||
beatmapset_resps[beatmap_resp.beatmapset_id] = beatmap_resp.beatmapset
|
||||
|
||||
playlist_items_resps = [
|
||||
await PlaylistResp.from_db(item) for item in playlist_items.values()
|
||||
]
|
||||
playlist_items_resps = [await PlaylistResp.from_db(item) for item in playlist_items.values()]
|
||||
|
||||
return RoomEvents(
|
||||
beatmaps=beatmap_resps,
|
||||
|
||||
@@ -104,11 +104,7 @@ async def submit_score(
|
||||
if not info.passed:
|
||||
info.rank = Rank.F
|
||||
score_token = (
|
||||
await db.exec(
|
||||
select(ScoreToken)
|
||||
.options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType]
|
||||
.where(ScoreToken.id == token)
|
||||
)
|
||||
await db.exec(select(ScoreToken).options(joinedload(ScoreToken.beatmap)).where(ScoreToken.id == token))
|
||||
).first()
|
||||
if not score_token or score_token.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="Score token not found")
|
||||
@@ -138,10 +134,7 @@ async def submit_score(
|
||||
except HTTPError:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
has_pp = db_beatmap.beatmap_status.has_pp() | settings.enable_all_beatmap_pp
|
||||
has_leaderboard = (
|
||||
db_beatmap.beatmap_status.has_leaderboard()
|
||||
| settings.enable_all_beatmap_leaderboard
|
||||
)
|
||||
has_leaderboard = db_beatmap.beatmap_status.has_leaderboard() | settings.enable_all_beatmap_leaderboard
|
||||
beatmap_length = db_beatmap.total_length
|
||||
score = await process_score(
|
||||
current_user,
|
||||
@@ -167,21 +160,11 @@ async def submit_score(
|
||||
has_pp,
|
||||
has_leaderboard,
|
||||
)
|
||||
score = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.options(joinedload(Score.user)) # pyright: ignore[reportArgumentType]
|
||||
.where(Score.id == score_id)
|
||||
)
|
||||
).first()
|
||||
assert score is not None
|
||||
score = (await db.exec(select(Score).options(joinedload(Score.user)).where(Score.id == score_id))).one()
|
||||
|
||||
resp = await ScoreResp.from_db(db, score)
|
||||
total_users = (await db.exec(select(func.count()).select_from(User))).first()
|
||||
assert total_users is not None
|
||||
if resp.rank_global is not None and resp.rank_global <= min(
|
||||
math.ceil(float(total_users) * 0.01), 50
|
||||
):
|
||||
total_users = (await db.exec(select(func.count()).select_from(User))).one()
|
||||
if resp.rank_global is not None and resp.rank_global <= min(math.ceil(float(total_users) * 0.01), 50):
|
||||
rank_event = Event(
|
||||
created_at=datetime.now(UTC),
|
||||
type=EventType.RANK,
|
||||
@@ -207,9 +190,7 @@ async def submit_score(
|
||||
score_gamemode = score.gamemode
|
||||
|
||||
if user_id is not None:
|
||||
background_task.add_task(
|
||||
_refresh_user_cache_background, redis, user_id, score_gamemode
|
||||
)
|
||||
background_task.add_task(_refresh_user_cache_background, redis, user_id, score_gamemode)
|
||||
background_task.add_task(process_user_achievement, resp.id)
|
||||
return resp
|
||||
|
||||
@@ -225,9 +206,7 @@ async def _refresh_user_cache_background(redis: Redis, user_id: int, mode: GameM
|
||||
# 创建独立的数据库会话
|
||||
session = AsyncSession(engine)
|
||||
try:
|
||||
await user_cache_service.refresh_user_cache_on_score_submit(
|
||||
session, user_id, mode
|
||||
)
|
||||
await user_cache_service.refresh_user_cache_on_score_submit(session, user_id, mode)
|
||||
finally:
|
||||
await session.close()
|
||||
except Exception as e:
|
||||
@@ -280,22 +259,16 @@ async def get_beatmap_scores(
|
||||
beatmap_id: int = Path(description="谱面 ID"),
|
||||
mode: GameMode = Query(description="指定 auleset"),
|
||||
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
|
||||
mods: list[str] = Query(
|
||||
default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)"
|
||||
),
|
||||
mods: list[str] = Query(default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)"),
|
||||
type: LeaderboardType = Query(
|
||||
LeaderboardType.GLOBAL,
|
||||
description=(
|
||||
"排行榜类型:GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队"
|
||||
),
|
||||
description=("排行榜类型:GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队"),
|
||||
),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
limit: int = Query(50, ge=1, le=200, description="返回条数 (1-200)"),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="this server only contains lazer scores"
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="this server only contains lazer scores")
|
||||
|
||||
all_scores, user_score, count = await get_leaderboard(
|
||||
db,
|
||||
@@ -310,9 +283,7 @@ async def get_beatmap_scores(
|
||||
user_score_resp = await ScoreResp.from_db(db, user_score) if user_score else None
|
||||
resp = BeatmapScores(
|
||||
scores=[await ScoreResp.from_db(db, score) for score in all_scores],
|
||||
user_score=BeatmapUserScore(
|
||||
score=user_score_resp, position=user_score_resp.rank_global or 0
|
||||
)
|
||||
user_score=BeatmapUserScore(score=user_score_resp, position=user_score_resp.rank_global or 0)
|
||||
if user_score_resp
|
||||
else None,
|
||||
score_count=count,
|
||||
@@ -342,9 +313,7 @@ async def get_user_beatmap_score(
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="This server only contains non-legacy scores"
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="This server only contains non-legacy scores")
|
||||
user_score = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
@@ -386,9 +355,7 @@ async def get_user_all_beatmap_scores(
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="This server only contains non-legacy scores"
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="This server only contains non-legacy scores")
|
||||
all_user_scores = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
@@ -420,7 +387,6 @@ async def create_solo_score(
|
||||
ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
assert current_user.id is not None
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
|
||||
@@ -454,10 +420,7 @@ async def submit_solo_score(
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher=Depends(get_fetcher),
|
||||
):
|
||||
assert current_user.id is not None
|
||||
return await submit_score(
|
||||
background_task, info, beatmap_id, token, current_user, db, redis, fetcher
|
||||
)
|
||||
return await submit_score(background_task, info, beatmap_id, token, current_user, db, redis, fetcher)
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -478,7 +441,6 @@ async def create_playlist_score(
|
||||
version_hash: str = Form("", description="谱面版本哈希"),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
assert current_user.id is not None
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
|
||||
@@ -488,26 +450,16 @@ async def create_playlist_score(
|
||||
db_room_time = room.ends_at.replace(tzinfo=UTC) if room.ends_at else None
|
||||
if db_room_time and db_room_time < datetime.now(UTC).replace(tzinfo=UTC):
|
||||
raise HTTPException(status_code=400, detail="Room has ended")
|
||||
item = (
|
||||
await session.exec(
|
||||
select(Playlist).where(
|
||||
Playlist.id == playlist_id, Playlist.room_id == room_id
|
||||
)
|
||||
)
|
||||
).first()
|
||||
item = (await session.exec(select(Playlist).where(Playlist.id == playlist_id, Playlist.room_id == room_id))).first()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Playlist not found")
|
||||
|
||||
# validate
|
||||
if not item.freestyle:
|
||||
if item.ruleset_id != ruleset_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Ruleset mismatch in playlist item"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Ruleset mismatch in playlist item")
|
||||
if item.beatmap_id != beatmap_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Beatmap ID mismatch in playlist item"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Beatmap ID mismatch in playlist item")
|
||||
agg = await session.exec(
|
||||
select(ItemAttemptsCount).where(
|
||||
ItemAttemptsCount.room_id == room_id,
|
||||
@@ -523,9 +475,7 @@ async def create_playlist_score(
|
||||
if item.expired:
|
||||
raise HTTPException(status_code=400, detail="Playlist item has expired")
|
||||
if item.played_at:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Playlist item has already been played"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Playlist item has already been played")
|
||||
# 这里应该不用验证mod了吧。。。
|
||||
background_task.add_task(_preload_beatmap_for_pp_calculation, beatmap_id)
|
||||
score_token = ScoreToken(
|
||||
@@ -557,18 +507,10 @@ async def submit_playlist_score(
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
assert current_user.id is not None
|
||||
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
|
||||
item = (
|
||||
await session.exec(
|
||||
select(Playlist).where(
|
||||
Playlist.id == playlist_id, Playlist.room_id == room_id
|
||||
)
|
||||
)
|
||||
).first()
|
||||
item = (await session.exec(select(Playlist).where(Playlist.id == playlist_id, Playlist.room_id == room_id))).first()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Playlist item not found")
|
||||
room = await session.get(Room, room_id)
|
||||
@@ -621,9 +563,7 @@ async def index_playlist_scores(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
limit: int = Query(50, ge=1, le=50, description="返回条数 (1-50)"),
|
||||
cursor: int = Query(
|
||||
2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)"
|
||||
),
|
||||
cursor: int = Query(2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
):
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
@@ -693,9 +633,6 @@ async def show_playlist_score(
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
|
||||
room = await session.get(Room, room_id)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
@@ -715,9 +652,7 @@ async def show_playlist_score(
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if completed_players := await redis.get(
|
||||
f"multiplayer:{room_id}:gameplay:players"
|
||||
):
|
||||
if completed_players := await redis.get(f"multiplayer:{room_id}:gameplay:players"):
|
||||
completed = completed_players == "0"
|
||||
if score_record and completed:
|
||||
break
|
||||
@@ -784,9 +719,7 @@ async def get_user_playlist_score(
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
|
||||
resp = await ScoreResp.from_db(session, score_record.score)
|
||||
resp.position = await get_position(
|
||||
room_id, playlist_id, score_record.score_id, session
|
||||
)
|
||||
resp.position = await get_position(room_id, playlist_id, score_record.score_id, session)
|
||||
return resp
|
||||
|
||||
|
||||
@@ -850,11 +783,7 @@ async def unpin_score(
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
|
||||
score_record = (
|
||||
await db.exec(
|
||||
select(Score).where(Score.id == score_id, Score.user_id == user_id)
|
||||
)
|
||||
).first()
|
||||
score_record = (await db.exec(select(Score).where(Score.id == score_id, Score.user_id == user_id))).first()
|
||||
if not score_record:
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
|
||||
@@ -878,10 +807,7 @@ async def unpin_score(
|
||||
"/score-pins/{score_id}/reorder",
|
||||
status_code=204,
|
||||
name="调整置顶成绩顺序",
|
||||
description=(
|
||||
"**客户端专属**\n调整已置顶成绩的展示顺序。"
|
||||
"仅提供 after_score_id 或 before_score_id 之一。"
|
||||
),
|
||||
description=("**客户端专属**\n调整已置顶成绩的展示顺序。仅提供 after_score_id 或 before_score_id 之一。"),
|
||||
tags=["成绩"],
|
||||
)
|
||||
async def reorder_score_pin(
|
||||
@@ -894,11 +820,7 @@ async def reorder_score_pin(
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
|
||||
score_record = (
|
||||
await db.exec(
|
||||
select(Score).where(Score.id == score_id, Score.user_id == user_id)
|
||||
)
|
||||
).first()
|
||||
score_record = (await db.exec(select(Score).where(Score.id == score_id, Score.user_id == user_id))).first()
|
||||
if not score_record:
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
|
||||
@@ -908,8 +830,7 @@ async def reorder_score_pin(
|
||||
if (after_score_id is None) == (before_score_id is None):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Either after_score_id or before_score_id "
|
||||
"must be provided (but not both)",
|
||||
detail="Either after_score_id or before_score_id must be provided (but not both)",
|
||||
)
|
||||
|
||||
all_pinned_scores = (
|
||||
@@ -927,9 +848,7 @@ async def reorder_score_pin(
|
||||
target_order = None
|
||||
reference_score_id = after_score_id or before_score_id
|
||||
|
||||
reference_score = next(
|
||||
(s for s in all_pinned_scores if s.id == reference_score_id), None
|
||||
)
|
||||
reference_score = next((s for s in all_pinned_scores if s.id == reference_score_id), None)
|
||||
if not reference_score:
|
||||
detail = "After score not found" if after_score_id else "Before score not found"
|
||||
raise HTTPException(status_code=404, detail=detail)
|
||||
@@ -951,9 +870,7 @@ async def reorder_score_pin(
|
||||
if current_order < s.pinned_order <= target_order and s.id != score_id:
|
||||
updates.append((s.id, s.pinned_order - 1))
|
||||
if after_score_id:
|
||||
final_target = (
|
||||
target_order - 1 if target_order > current_order else target_order
|
||||
)
|
||||
final_target = target_order - 1 if target_order > current_order else target_order
|
||||
else:
|
||||
final_target = target_order
|
||||
else:
|
||||
@@ -964,9 +881,7 @@ async def reorder_score_pin(
|
||||
|
||||
for score_id, new_order in updates:
|
||||
await db.exec(select(Score).where(Score.id == score_id))
|
||||
score_to_update = (
|
||||
await db.exec(select(Score).where(Score.id == score_id))
|
||||
).first()
|
||||
score_to_update = (await db.exec(select(Score).where(Score.id == score_id))).first()
|
||||
if score_to_update:
|
||||
score_to_update.pinned_order = new_order
|
||||
|
||||
|
||||
@@ -4,34 +4,29 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, UTC
|
||||
from typing import Annotated
|
||||
|
||||
from app.auth import authenticate_user
|
||||
from app.config import settings
|
||||
from app.database import User
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.geoip import GeoIPHelper, get_geoip_helper
|
||||
from app.database.email_verification import EmailVerification, LoginSession
|
||||
from app.service.email_verification_service import (
|
||||
EmailVerificationService,
|
||||
LoginSessionService
|
||||
EmailVerificationService,
|
||||
LoginSessionService,
|
||||
)
|
||||
from app.service.login_log_service import LoginLogService
|
||||
from app.models.extended_auth import ExtendedTokenResponse
|
||||
|
||||
from fastapi import Form, Depends, Request, HTTPException, status, Security
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import select
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Depends, Form, HTTPException, Request, Security, status
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
|
||||
|
||||
class SessionReissueResponse(BaseModel):
|
||||
"""重新发送验证码响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
@@ -40,39 +35,35 @@ class SessionReissueResponse(BaseModel):
|
||||
"/session/verify",
|
||||
name="验证会话",
|
||||
description="验证邮件验证码并完成会话认证",
|
||||
status_code=204
|
||||
status_code=204,
|
||||
)
|
||||
async def verify_session(
|
||||
request: Request,
|
||||
db: Database,
|
||||
redis: Annotated[Redis, Depends(get_redis)],
|
||||
verification_key: str = Form(..., description="8位邮件验证码"),
|
||||
current_user: User = Security(get_current_user)
|
||||
current_user: User = Security(get_current_user),
|
||||
) -> Response:
|
||||
"""
|
||||
验证邮件验证码并完成会话认证
|
||||
|
||||
|
||||
对应 osu! 的 session/verify 接口
|
||||
成功时返回 204 No Content,失败时返回 401 Unauthorized
|
||||
"""
|
||||
try:
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
ip_address = get_client_ip(request)
|
||||
user_agent = request.headers.get("User-Agent", "Unknown")
|
||||
|
||||
|
||||
ip_address = get_client_ip(request) # noqa: F841
|
||||
user_agent = request.headers.get("User-Agent", "Unknown") # noqa: F841
|
||||
|
||||
# 从当前认证用户获取信息
|
||||
user_id = current_user.id
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="用户未认证"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户未认证")
|
||||
|
||||
# 验证邮件验证码
|
||||
success, message = await EmailVerificationService.verify_code(
|
||||
db, redis, user_id, verification_key
|
||||
)
|
||||
|
||||
success, message = await EmailVerificationService.verify_code(db, redis, user_id, verification_key)
|
||||
|
||||
if success:
|
||||
# 记录成功的邮件验证
|
||||
await LoginLogService.record_login(
|
||||
@@ -81,9 +72,9 @@ async def verify_session(
|
||||
request=request,
|
||||
login_method="email_verification",
|
||||
login_success=True,
|
||||
notes=f"邮件验证成功"
|
||||
notes="邮件验证成功",
|
||||
)
|
||||
|
||||
|
||||
# 返回 204 No Content 表示验证成功
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
else:
|
||||
@@ -93,83 +84,69 @@ async def verify_session(
|
||||
request=request,
|
||||
attempted_username=current_user.username,
|
||||
login_method="email_verification",
|
||||
notes=f"邮件验证失败: {message}"
|
||||
notes=f"邮件验证失败: {message}",
|
||||
)
|
||||
|
||||
|
||||
# 返回 401 Unauthorized 表示验证失败
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=message
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=message)
|
||||
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="无效的用户会话"
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="验证过程中发生错误"
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的用户会话")
|
||||
except Exception:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="验证过程中发生错误")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/session/verify/reissue",
|
||||
name="重新发送验证码",
|
||||
description="重新发送邮件验证码",
|
||||
response_model=SessionReissueResponse
|
||||
response_model=SessionReissueResponse,
|
||||
)
|
||||
async def reissue_verification_code(
|
||||
request: Request,
|
||||
db: Database,
|
||||
redis: Annotated[Redis, Depends(get_redis)],
|
||||
current_user: User = Security(get_current_user)
|
||||
current_user: User = Security(get_current_user),
|
||||
) -> SessionReissueResponse:
|
||||
"""
|
||||
重新发送邮件验证码
|
||||
|
||||
|
||||
对应 osu! 的 session/verify/reissue 接口
|
||||
"""
|
||||
try:
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
|
||||
ip_address = get_client_ip(request)
|
||||
user_agent = request.headers.get("User-Agent", "Unknown")
|
||||
|
||||
|
||||
# 从当前认证用户获取信息
|
||||
user_id = current_user.id
|
||||
if not user_id:
|
||||
return SessionReissueResponse(
|
||||
success=False,
|
||||
message="用户未认证"
|
||||
)
|
||||
|
||||
return SessionReissueResponse(success=False, message="用户未认证")
|
||||
|
||||
# 重新发送验证码
|
||||
success, message = await EmailVerificationService.resend_verification_code(
|
||||
db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent
|
||||
db,
|
||||
redis,
|
||||
user_id,
|
||||
current_user.username,
|
||||
current_user.email,
|
||||
ip_address,
|
||||
user_agent,
|
||||
)
|
||||
|
||||
return SessionReissueResponse(
|
||||
success=success,
|
||||
message=message
|
||||
)
|
||||
|
||||
|
||||
return SessionReissueResponse(success=success, message=message)
|
||||
|
||||
except ValueError:
|
||||
return SessionReissueResponse(
|
||||
success=False,
|
||||
message="无效的用户会话"
|
||||
)
|
||||
except Exception as e:
|
||||
return SessionReissueResponse(
|
||||
success=False,
|
||||
message="重新发送过程中发生错误"
|
||||
)
|
||||
return SessionReissueResponse(success=False, message="无效的用户会话")
|
||||
except Exception:
|
||||
return SessionReissueResponse(success=False, message="重新发送过程中发生错误")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/session/check-new-location",
|
||||
name="检查新位置登录",
|
||||
description="检查登录是否来自新位置(内部接口)"
|
||||
description="检查登录是否来自新位置(内部接口)",
|
||||
)
|
||||
async def check_new_location(
|
||||
request: Request,
|
||||
@@ -183,22 +160,21 @@ async def check_new_location(
|
||||
"""
|
||||
try:
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
|
||||
ip_address = get_client_ip(request)
|
||||
geo_info = geoip.lookup(ip_address)
|
||||
country_code = geo_info.get("country_iso", "XX")
|
||||
|
||||
is_new_location = await LoginSessionService.check_new_location(
|
||||
db, user_id, ip_address, country_code
|
||||
)
|
||||
|
||||
|
||||
is_new_location = await LoginSessionService.check_new_location(db, user_id, ip_address, country_code)
|
||||
|
||||
return {
|
||||
"is_new_location": is_new_location,
|
||||
"ip_address": ip_address,
|
||||
"country_code": country_code
|
||||
"country_code": country_code,
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"is_new_location": True, # 出错时默认为新位置
|
||||
"error": str(e)
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
@@ -1,73 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
import json
|
||||
from typing import Any
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
from app.dependencies.database import get_redis, get_redis_message
|
||||
from app.log import logger
|
||||
from app.utils import bg_tasks
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Redis key constants
|
||||
REDIS_ONLINE_USERS_KEY = "server:online_users"
|
||||
REDIS_PLAYING_USERS_KEY = "server:playing_users"
|
||||
REDIS_PLAYING_USERS_KEY = "server:playing_users"
|
||||
REDIS_REGISTERED_USERS_KEY = "server:registered_users"
|
||||
REDIS_ONLINE_HISTORY_KEY = "server:online_history"
|
||||
|
||||
# 线程池用于同步Redis操作
|
||||
_executor = ThreadPoolExecutor(max_workers=2)
|
||||
|
||||
|
||||
async def _redis_exec(func, *args, **kwargs):
|
||||
"""在线程池中执行同步Redis操作"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(_executor, func, *args, **kwargs)
|
||||
|
||||
|
||||
class ServerStats(BaseModel):
|
||||
"""服务器统计信息响应模型"""
|
||||
|
||||
registered_users: int
|
||||
online_users: int
|
||||
playing_users: int
|
||||
timestamp: datetime
|
||||
|
||||
|
||||
class OnlineHistoryPoint(BaseModel):
|
||||
"""在线历史数据点"""
|
||||
|
||||
timestamp: datetime
|
||||
online_count: int
|
||||
playing_count: int
|
||||
|
||||
|
||||
class OnlineHistoryResponse(BaseModel):
|
||||
"""24小时在线历史响应模型"""
|
||||
|
||||
history: list[OnlineHistoryPoint]
|
||||
current_stats: ServerStats
|
||||
|
||||
|
||||
@router.get("/stats", response_model=ServerStats, tags=["统计"])
|
||||
async def get_server_stats() -> ServerStats:
|
||||
"""
|
||||
获取服务器实时统计信息
|
||||
|
||||
|
||||
返回服务器注册用户数、在线用户数、正在游玩用户数等实时统计信息
|
||||
"""
|
||||
redis = get_redis()
|
||||
|
||||
|
||||
try:
|
||||
# 并行获取所有统计数据
|
||||
registered_count, online_count, playing_count = await asyncio.gather(
|
||||
_get_registered_users_count(redis),
|
||||
_get_online_users_count(redis),
|
||||
_get_playing_users_count(redis)
|
||||
_get_playing_users_count(redis),
|
||||
)
|
||||
|
||||
|
||||
return ServerStats(
|
||||
registered_users=registered_count,
|
||||
online_users=online_count,
|
||||
playing_users=playing_count,
|
||||
timestamp=datetime.utcnow()
|
||||
timestamp=datetime.utcnow(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting server stats: {e}")
|
||||
@@ -76,14 +83,15 @@ async def get_server_stats() -> ServerStats:
|
||||
registered_users=0,
|
||||
online_users=0,
|
||||
playing_users=0,
|
||||
timestamp=datetime.utcnow()
|
||||
timestamp=datetime.utcnow(),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stats/history", response_model=OnlineHistoryResponse, tags=["统计"])
|
||||
async def get_online_history() -> OnlineHistoryResponse:
|
||||
"""
|
||||
获取最近24小时在线统计历史
|
||||
|
||||
|
||||
返回过去24小时内每小时的在线用户数和游玩用户数统计,
|
||||
包含当前实时数据作为最新数据点
|
||||
"""
|
||||
@@ -92,80 +100,80 @@ async def get_online_history() -> OnlineHistoryResponse:
|
||||
redis_sync = get_redis_message()
|
||||
history_data = await _redis_exec(redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1)
|
||||
history_points = []
|
||||
|
||||
|
||||
# 处理历史数据
|
||||
for data in history_data:
|
||||
try:
|
||||
point_data = json.loads(data)
|
||||
# 只保留基本字段
|
||||
history_points.append(OnlineHistoryPoint(
|
||||
timestamp=datetime.fromisoformat(point_data["timestamp"]),
|
||||
online_count=point_data["online_count"],
|
||||
playing_count=point_data["playing_count"]
|
||||
))
|
||||
history_points.append(
|
||||
OnlineHistoryPoint(
|
||||
timestamp=datetime.fromisoformat(point_data["timestamp"]),
|
||||
online_count=point_data["online_count"],
|
||||
playing_count=point_data["playing_count"],
|
||||
)
|
||||
)
|
||||
except (json.JSONDecodeError, KeyError, ValueError) as e:
|
||||
logger.warning(f"Invalid history data point: {data}, error: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# 获取当前实时统计信息
|
||||
current_stats = await get_server_stats()
|
||||
|
||||
|
||||
# 如果历史数据为空或者最新数据超过15分钟,添加当前数据点
|
||||
if not history_points or (
|
||||
history_points and
|
||||
(current_stats.timestamp - max(history_points, key=lambda x: x.timestamp).timestamp).total_seconds() > 15 * 60
|
||||
history_points
|
||||
and (current_stats.timestamp - max(history_points, key=lambda x: x.timestamp).timestamp).total_seconds()
|
||||
> 15 * 60
|
||||
):
|
||||
history_points.append(OnlineHistoryPoint(
|
||||
timestamp=current_stats.timestamp,
|
||||
online_count=current_stats.online_users,
|
||||
playing_count=current_stats.playing_users
|
||||
))
|
||||
|
||||
history_points.append(
|
||||
OnlineHistoryPoint(
|
||||
timestamp=current_stats.timestamp,
|
||||
online_count=current_stats.online_users,
|
||||
playing_count=current_stats.playing_users,
|
||||
)
|
||||
)
|
||||
|
||||
# 按时间排序(最新的在前)
|
||||
history_points.sort(key=lambda x: x.timestamp, reverse=True)
|
||||
|
||||
|
||||
# 限制到最多48个数据点(24小时)
|
||||
history_points = history_points[:48]
|
||||
|
||||
return OnlineHistoryResponse(
|
||||
history=history_points,
|
||||
current_stats=current_stats
|
||||
)
|
||||
|
||||
return OnlineHistoryResponse(history=history_points, current_stats=current_stats)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting online history: {e}")
|
||||
# 返回空历史和当前状态
|
||||
current_stats = await get_server_stats()
|
||||
return OnlineHistoryResponse(
|
||||
history=[],
|
||||
current_stats=current_stats
|
||||
)
|
||||
return OnlineHistoryResponse(history=[], current_stats=current_stats)
|
||||
|
||||
|
||||
@router.get("/stats/debug", tags=["统计"])
|
||||
async def get_stats_debug_info():
|
||||
"""
|
||||
获取统计系统调试信息
|
||||
|
||||
|
||||
用于调试时间对齐和区间统计问题
|
||||
"""
|
||||
try:
|
||||
from app.service.enhanced_interval_stats import EnhancedIntervalStatsManager
|
||||
|
||||
|
||||
current_time = datetime.utcnow()
|
||||
current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
|
||||
interval_stats = await EnhancedIntervalStatsManager.get_current_interval_stats()
|
||||
|
||||
|
||||
# 获取Redis中的实际数据
|
||||
redis_sync = get_redis_message()
|
||||
|
||||
|
||||
online_key = f"server:interval_online_users:{current_interval.interval_key}"
|
||||
playing_key = f"server:interval_playing_users:{current_interval.interval_key}"
|
||||
|
||||
|
||||
online_users_raw = await _redis_exec(redis_sync.smembers, online_key)
|
||||
playing_users_raw = await _redis_exec(redis_sync.smembers, playing_key)
|
||||
|
||||
|
||||
online_users = [int(uid.decode() if isinstance(uid, bytes) else uid) for uid in online_users_raw]
|
||||
playing_users = [int(uid.decode() if isinstance(uid, bytes) else uid) for uid in playing_users_raw]
|
||||
|
||||
|
||||
return {
|
||||
"current_time": current_time.isoformat(),
|
||||
"current_interval": {
|
||||
@@ -175,28 +183,29 @@ async def get_stats_debug_info():
|
||||
"is_current": current_interval.is_current(),
|
||||
"minutes_remaining": int((current_interval.end_time - current_time).total_seconds() / 60),
|
||||
"seconds_remaining": int((current_interval.end_time - current_time).total_seconds()),
|
||||
"progress_percentage": round((1 - (current_interval.end_time - current_time).total_seconds() / (30 * 60)) * 100, 1)
|
||||
"progress_percentage": round(
|
||||
(1 - (current_interval.end_time - current_time).total_seconds() / (30 * 60)) * 100,
|
||||
1,
|
||||
),
|
||||
},
|
||||
"interval_statistics": interval_stats.to_dict() if interval_stats else None,
|
||||
"redis_data": {
|
||||
"online_users": online_users,
|
||||
"playing_users": playing_users,
|
||||
"online_count": len(online_users),
|
||||
"playing_count": len(playing_users)
|
||||
"playing_count": len(playing_users),
|
||||
},
|
||||
"system_status": {
|
||||
"stats_system": "enhanced_interval_stats",
|
||||
"data_alignment": "30_minute_boundaries",
|
||||
"real_time_updates": True,
|
||||
"auto_24h_fill": True
|
||||
}
|
||||
"auto_24h_fill": True,
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting debug info: {e}")
|
||||
return {
|
||||
"error": "Failed to retrieve debug information",
|
||||
"message": str(e)
|
||||
}
|
||||
return {"error": "Failed to retrieve debug information", "message": str(e)}
|
||||
|
||||
|
||||
async def _get_registered_users_count(redis) -> int:
|
||||
"""获取注册用户总数(从缓存)"""
|
||||
@@ -207,6 +216,7 @@ async def _get_registered_users_count(redis) -> int:
|
||||
logger.error(f"Error getting registered users count: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
async def _get_online_users_count(redis) -> int:
|
||||
"""获取当前在线用户数"""
|
||||
try:
|
||||
@@ -216,6 +226,7 @@ async def _get_online_users_count(redis) -> int:
|
||||
logger.error(f"Error getting online users count: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
async def _get_playing_users_count(redis) -> int:
|
||||
"""获取当前游玩用户数"""
|
||||
try:
|
||||
@@ -225,27 +236,28 @@ async def _get_playing_users_count(redis) -> int:
|
||||
logger.error(f"Error getting playing users count: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
# 统计更新功能
|
||||
async def update_registered_users_count() -> None:
|
||||
"""更新注册用户数缓存"""
|
||||
from app.dependencies.database import with_db
|
||||
from app.database import User
|
||||
from app.const import BANCHOBOT_ID
|
||||
from sqlmodel import select, func
|
||||
|
||||
from app.database import User
|
||||
from app.dependencies.database import with_db
|
||||
|
||||
from sqlmodel import func, select
|
||||
|
||||
redis = get_redis()
|
||||
try:
|
||||
async with with_db() as db:
|
||||
# 排除机器人用户(BANCHOBOT_ID)
|
||||
result = await db.exec(
|
||||
select(func.count()).select_from(User).where(User.id != BANCHOBOT_ID)
|
||||
)
|
||||
result = await db.exec(select(func.count()).select_from(User).where(User.id != BANCHOBOT_ID))
|
||||
count = result.first()
|
||||
await redis.set(REDIS_REGISTERED_USERS_KEY, count or 0, ex=300) # 5分钟过期
|
||||
logger.debug(f"Updated registered users count: {count}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating registered users count: {e}")
|
||||
|
||||
|
||||
async def add_online_user(user_id: int) -> None:
|
||||
"""添加在线用户"""
|
||||
redis_sync = get_redis_message()
|
||||
@@ -257,14 +269,20 @@ async def add_online_user(user_id: int) -> None:
|
||||
if ttl <= 0: # -1表示永不过期,-2表示不存在,0表示已过期
|
||||
await redis_async.expire(REDIS_ONLINE_USERS_KEY, 3 * 3600) # 3小时过期
|
||||
logger.debug(f"Added online user {user_id}")
|
||||
|
||||
|
||||
# 立即更新当前区间统计
|
||||
from app.service.enhanced_interval_stats import update_user_activity_in_interval
|
||||
asyncio.create_task(update_user_activity_in_interval(user_id, is_playing=False))
|
||||
|
||||
|
||||
bg_tasks.add_task(
|
||||
update_user_activity_in_interval,
|
||||
user_id,
|
||||
is_playing=False,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding online user {user_id}: {e}")
|
||||
|
||||
|
||||
async def remove_online_user(user_id: int) -> None:
|
||||
"""移除在线用户"""
|
||||
redis_sync = get_redis_message()
|
||||
@@ -274,6 +292,7 @@ async def remove_online_user(user_id: int) -> None:
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing online user {user_id}: {e}")
|
||||
|
||||
|
||||
async def add_playing_user(user_id: int) -> None:
|
||||
"""添加游玩用户"""
|
||||
redis_sync = get_redis_message()
|
||||
@@ -285,14 +304,16 @@ async def add_playing_user(user_id: int) -> None:
|
||||
if ttl <= 0: # -1表示永不过期,-2表示不存在,0表示已过期
|
||||
await redis_async.expire(REDIS_PLAYING_USERS_KEY, 3 * 3600) # 3小时过期
|
||||
logger.debug(f"Added playing user {user_id}")
|
||||
|
||||
|
||||
# 立即更新当前区间统计
|
||||
from app.service.enhanced_interval_stats import update_user_activity_in_interval
|
||||
asyncio.create_task(update_user_activity_in_interval(user_id, is_playing=True))
|
||||
|
||||
|
||||
bg_tasks.add_task(update_user_activity_in_interval, user_id, is_playing=True)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding playing user {user_id}: {e}")
|
||||
|
||||
|
||||
async def remove_playing_user(user_id: int) -> None:
|
||||
"""移除游玩用户"""
|
||||
redis_sync = get_redis_message()
|
||||
@@ -301,6 +322,7 @@ async def remove_playing_user(user_id: int) -> None:
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing playing user {user_id}: {e}")
|
||||
|
||||
|
||||
async def record_hourly_stats() -> None:
|
||||
"""记录统计数据 - 简化版本,主要作为fallback使用"""
|
||||
redis_sync = get_redis_message()
|
||||
@@ -308,24 +330,27 @@ async def record_hourly_stats() -> None:
|
||||
try:
|
||||
# 先确保Redis连接正常
|
||||
await redis_async.ping()
|
||||
|
||||
|
||||
online_count = await _get_online_users_count(redis_async)
|
||||
playing_count = await _get_playing_users_count(redis_async)
|
||||
|
||||
|
||||
current_time = datetime.utcnow()
|
||||
history_point = {
|
||||
"timestamp": current_time.isoformat(),
|
||||
"online_count": online_count,
|
||||
"playing_count": playing_count
|
||||
"playing_count": playing_count,
|
||||
}
|
||||
|
||||
|
||||
# 添加到历史记录
|
||||
await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point))
|
||||
# 只保留48个数据点(24小时,每30分钟一个点)
|
||||
await _redis_exec(redis_sync.ltrim, REDIS_ONLINE_HISTORY_KEY, 0, 47)
|
||||
# 设置过期时间为26小时,确保有足够缓冲
|
||||
await redis_async.expire(REDIS_ONLINE_HISTORY_KEY, 26 * 3600)
|
||||
|
||||
logger.info(f"Recorded fallback stats: online={online_count}, playing={playing_count} at {current_time.strftime('%H:%M:%S')}")
|
||||
|
||||
logger.info(
|
||||
f"Recorded fallback stats: online={online_count}, playing={playing_count} "
|
||||
f"at {current_time.strftime('%H:%M:%S')}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error recording fallback stats: {e}")
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Literal
|
||||
|
||||
@@ -26,7 +25,7 @@ from app.service.user_cache_service import get_user_cache_service
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import HTTPException, Path, Query, Security
|
||||
from fastapi import BackgroundTasks, HTTPException, Path, Query, Security
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import exists, false, select
|
||||
from sqlmodel.sql.expression import col
|
||||
@@ -47,13 +46,10 @@ class BatchUserResponse(BaseModel):
|
||||
@router.get("/users/lookup/", response_model=BatchUserResponse, include_in_schema=False)
|
||||
async def get_users(
|
||||
session: Database,
|
||||
user_ids: list[int] = Query(
|
||||
default_factory=list, alias="ids[]", description="要查询的用户 ID 列表"
|
||||
),
|
||||
background_task: BackgroundTasks,
|
||||
user_ids: list[int] = Query(default_factory=list, alias="ids[]", description="要查询的用户 ID 列表"),
|
||||
# current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
include_variant_statistics: bool = Query(
|
||||
default=False, description="是否包含各模式的统计信息"
|
||||
), # TODO: future use
|
||||
include_variant_statistics: bool = Query(default=False, description="是否包含各模式的统计信息"), # TODO: future use
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_user_cache_service(redis)
|
||||
@@ -72,11 +68,7 @@ async def get_users(
|
||||
|
||||
# 查询未缓存的用户
|
||||
if uncached_user_ids:
|
||||
searched_users = (
|
||||
await session.exec(
|
||||
select(User).where(col(User.id).in_(uncached_user_ids))
|
||||
)
|
||||
).all()
|
||||
searched_users = (await session.exec(select(User).where(col(User.id).in_(uncached_user_ids)))).all()
|
||||
|
||||
# 将查询到的用户添加到缓存并返回
|
||||
for searched_user in searched_users:
|
||||
@@ -88,7 +80,7 @@ async def get_users(
|
||||
)
|
||||
cached_users.append(user_resp)
|
||||
# 异步缓存,不阻塞响应
|
||||
asyncio.create_task(cache_service.cache_user(user_resp))
|
||||
background_task.add_task(cache_service.cache_user, user_resp)
|
||||
|
||||
return BatchUserResponse(users=cached_users)
|
||||
else:
|
||||
@@ -103,7 +95,7 @@ async def get_users(
|
||||
)
|
||||
users.append(user_resp)
|
||||
# 异步缓存
|
||||
asyncio.create_task(cache_service.cache_user(user_resp))
|
||||
background_task.add_task(cache_service.cache_user, user_resp)
|
||||
|
||||
return BatchUserResponse(users=users)
|
||||
|
||||
@@ -117,6 +109,7 @@ async def get_users(
|
||||
)
|
||||
async def get_user_info_ruleset(
|
||||
session: Database,
|
||||
background_task: BackgroundTasks,
|
||||
user_id: str = Path(description="用户 ID 或用户名"),
|
||||
ruleset: GameMode | None = Path(description="指定 ruleset"),
|
||||
# current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
@@ -134,9 +127,7 @@ async def get_user_info_ruleset(
|
||||
searched_user = (
|
||||
await session.exec(
|
||||
select(User).where(
|
||||
User.id == int(user_id)
|
||||
if user_id.isdigit()
|
||||
else User.username == user_id.removeprefix("@")
|
||||
User.id == int(user_id) if user_id.isdigit() else User.username == user_id.removeprefix("@")
|
||||
)
|
||||
)
|
||||
).first()
|
||||
@@ -151,7 +142,7 @@ async def get_user_info_ruleset(
|
||||
)
|
||||
|
||||
# 异步缓存结果
|
||||
asyncio.create_task(cache_service.cache_user(user_resp, ruleset))
|
||||
background_task.add_task(cache_service.cache_user, user_resp, ruleset)
|
||||
|
||||
return user_resp
|
||||
|
||||
@@ -165,6 +156,7 @@ async def get_user_info_ruleset(
|
||||
tags=["用户"],
|
||||
)
|
||||
async def get_user_info(
|
||||
background_task: BackgroundTasks,
|
||||
session: Database,
|
||||
user_id: str = Path(description="用户 ID 或用户名"),
|
||||
# current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
@@ -182,9 +174,7 @@ async def get_user_info(
|
||||
searched_user = (
|
||||
await session.exec(
|
||||
select(User).where(
|
||||
User.id == int(user_id)
|
||||
if user_id.isdigit()
|
||||
else User.username == user_id.removeprefix("@")
|
||||
User.id == int(user_id) if user_id.isdigit() else User.username == user_id.removeprefix("@")
|
||||
)
|
||||
)
|
||||
).first()
|
||||
@@ -198,7 +188,7 @@ async def get_user_info(
|
||||
)
|
||||
|
||||
# 异步缓存结果
|
||||
asyncio.create_task(cache_service.cache_user(user_resp))
|
||||
background_task.add_task(cache_service.cache_user, user_resp)
|
||||
|
||||
return user_resp
|
||||
|
||||
@@ -212,6 +202,7 @@ async def get_user_info(
|
||||
)
|
||||
async def get_user_beatmapsets(
|
||||
session: Database,
|
||||
background_task: BackgroundTasks,
|
||||
user_id: int = Path(description="用户 ID"),
|
||||
type: BeatmapsetType = Path(description="谱面集类型"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
@@ -222,9 +213,7 @@ async def get_user_beatmapsets(
|
||||
cache_service = get_user_cache_service(redis)
|
||||
|
||||
# 先尝试从缓存获取
|
||||
cached_result = await cache_service.get_user_beatmapsets_from_cache(
|
||||
user_id, type.value, limit, offset
|
||||
)
|
||||
cached_result = await cache_service.get_user_beatmapsets_from_cache(user_id, type.value, limit, offset)
|
||||
if cached_result is not None:
|
||||
# 根据类型恢复对象
|
||||
if type == BeatmapsetType.MOST_PLAYED:
|
||||
@@ -253,10 +242,7 @@ async def get_user_beatmapsets(
|
||||
raise HTTPException(404, detail="User not found")
|
||||
favourites = await user.awaitable_attrs.favourite_beatmapsets
|
||||
resp = [
|
||||
await BeatmapsetResp.from_db(
|
||||
favourite.beatmapset, session=session, user=user
|
||||
)
|
||||
for favourite in favourites
|
||||
await BeatmapsetResp.from_db(favourite.beatmapset, session=session, user=user) for favourite in favourites
|
||||
]
|
||||
|
||||
elif type == BeatmapsetType.MOST_PLAYED:
|
||||
@@ -267,25 +253,18 @@ async def get_user_beatmapsets(
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
resp = [
|
||||
await BeatmapPlaycountsResp.from_db(most_played_beatmap)
|
||||
for most_played_beatmap in most_played
|
||||
]
|
||||
resp = [await BeatmapPlaycountsResp.from_db(most_played_beatmap) for most_played_beatmap in most_played]
|
||||
else:
|
||||
raise HTTPException(400, detail="Invalid beatmapset type")
|
||||
|
||||
# 异步缓存结果
|
||||
async def cache_beatmapsets():
|
||||
try:
|
||||
await cache_service.cache_user_beatmapsets(
|
||||
user_id, type.value, resp, limit, offset
|
||||
)
|
||||
await cache_service.cache_user_beatmapsets(user_id, type.value, resp, limit, offset)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}"
|
||||
)
|
||||
logger.error(f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}")
|
||||
|
||||
asyncio.create_task(cache_beatmapsets())
|
||||
background_task.add_task(cache_beatmapsets)
|
||||
|
||||
return resp
|
||||
|
||||
@@ -299,18 +278,14 @@ async def get_user_beatmapsets(
|
||||
)
|
||||
async def get_user_scores(
|
||||
session: Database,
|
||||
background_task: BackgroundTasks,
|
||||
user_id: int = Path(description="用户 ID"),
|
||||
type: Literal["best", "recent", "firsts", "pinned"] = Path(
|
||||
description=(
|
||||
"成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩"
|
||||
" / firsts 第一名成绩 / pinned 置顶成绩"
|
||||
)
|
||||
description=("成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩 / firsts 第一名成绩 / pinned 置顶成绩")
|
||||
),
|
||||
legacy_only: bool = Query(False, description="是否只查询 Stable 成绩"),
|
||||
include_fails: bool = Query(False, description="是否包含失败的成绩"),
|
||||
mode: GameMode | None = Query(
|
||||
None, description="指定 ruleset (可选,默认为用户主模式)"
|
||||
),
|
||||
mode: GameMode | None = Query(None, description="指定 ruleset (可选,默认为用户主模式)"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
|
||||
offset: int = Query(0, ge=0, description="偏移量"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
@@ -320,9 +295,7 @@ async def get_user_scores(
|
||||
|
||||
# 先尝试从缓存获取(对于recent类型使用较短的缓存时间)
|
||||
cache_expire = 30 if type == "recent" else settings.user_scores_cache_expire_seconds
|
||||
cached_scores = await cache_service.get_user_scores_from_cache(
|
||||
user_id, type, mode, limit, offset
|
||||
)
|
||||
cached_scores = await cache_service.get_user_scores_from_cache(user_id, type, mode, limit, offset)
|
||||
if cached_scores is not None:
|
||||
return cached_scores
|
||||
|
||||
@@ -332,9 +305,7 @@ async def get_user_scores(
|
||||
|
||||
gamemode = mode or db_user.playmode
|
||||
order_by = None
|
||||
where_clause = (col(Score.user_id) == db_user.id) & (
|
||||
col(Score.gamemode) == gamemode
|
||||
)
|
||||
where_clause = (col(Score.user_id) == db_user.id) & (col(Score.gamemode) == gamemode)
|
||||
if not include_fails:
|
||||
where_clause &= col(Score.passed).is_(True)
|
||||
if type == "pinned":
|
||||
@@ -351,13 +322,7 @@ async def get_user_scores(
|
||||
where_clause &= false()
|
||||
|
||||
scores = (
|
||||
await session.exec(
|
||||
select(Score)
|
||||
.where(where_clause)
|
||||
.order_by(order_by)
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
await session.exec(select(Score).where(where_clause).order_by(order_by).limit(limit).offset(offset))
|
||||
).all()
|
||||
if not scores:
|
||||
return []
|
||||
@@ -371,18 +336,14 @@ async def get_user_scores(
|
||||
]
|
||||
|
||||
# 异步缓存结果
|
||||
asyncio.create_task(
|
||||
cache_service.cache_user_scores(
|
||||
user_id, type, score_responses, mode, limit, offset, cache_expire
|
||||
)
|
||||
background_task.add_task(
|
||||
cache_service.cache_user_scores, user_id, type, score_responses, mode, limit, offset, cache_expire
|
||||
)
|
||||
|
||||
return score_responses
|
||||
|
||||
|
||||
@router.get(
|
||||
"/users/{user}/recent_activity", tags=["用户"], response_model=list[EventResp]
|
||||
)
|
||||
@router.get("/users/{user}/recent_activity", tags=["用户"], response_model=list[EventResp])
|
||||
async def get_user_events(
|
||||
session: Database,
|
||||
user: int,
|
||||
|
||||
Reference in New Issue
Block a user