refactor(project): make pyright & ruff happy
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Literal
|
||||
|
||||
@@ -26,7 +25,7 @@ from app.service.user_cache_service import get_user_cache_service
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import HTTPException, Path, Query, Security
|
||||
from fastapi import BackgroundTasks, HTTPException, Path, Query, Security
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import exists, false, select
|
||||
from sqlmodel.sql.expression import col
|
||||
@@ -47,13 +46,10 @@ class BatchUserResponse(BaseModel):
|
||||
@router.get("/users/lookup/", response_model=BatchUserResponse, include_in_schema=False)
|
||||
async def get_users(
|
||||
session: Database,
|
||||
user_ids: list[int] = Query(
|
||||
default_factory=list, alias="ids[]", description="要查询的用户 ID 列表"
|
||||
),
|
||||
background_task: BackgroundTasks,
|
||||
user_ids: list[int] = Query(default_factory=list, alias="ids[]", description="要查询的用户 ID 列表"),
|
||||
# current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
include_variant_statistics: bool = Query(
|
||||
default=False, description="是否包含各模式的统计信息"
|
||||
), # TODO: future use
|
||||
include_variant_statistics: bool = Query(default=False, description="是否包含各模式的统计信息"), # TODO: future use
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_user_cache_service(redis)
|
||||
@@ -72,11 +68,7 @@ async def get_users(
|
||||
|
||||
# 查询未缓存的用户
|
||||
if uncached_user_ids:
|
||||
searched_users = (
|
||||
await session.exec(
|
||||
select(User).where(col(User.id).in_(uncached_user_ids))
|
||||
)
|
||||
).all()
|
||||
searched_users = (await session.exec(select(User).where(col(User.id).in_(uncached_user_ids)))).all()
|
||||
|
||||
# 将查询到的用户添加到缓存并返回
|
||||
for searched_user in searched_users:
|
||||
@@ -88,7 +80,7 @@ async def get_users(
|
||||
)
|
||||
cached_users.append(user_resp)
|
||||
# 异步缓存,不阻塞响应
|
||||
asyncio.create_task(cache_service.cache_user(user_resp))
|
||||
background_task.add_task(cache_service.cache_user, user_resp)
|
||||
|
||||
return BatchUserResponse(users=cached_users)
|
||||
else:
|
||||
@@ -103,7 +95,7 @@ async def get_users(
|
||||
)
|
||||
users.append(user_resp)
|
||||
# 异步缓存
|
||||
asyncio.create_task(cache_service.cache_user(user_resp))
|
||||
background_task.add_task(cache_service.cache_user, user_resp)
|
||||
|
||||
return BatchUserResponse(users=users)
|
||||
|
||||
@@ -117,6 +109,7 @@ async def get_users(
|
||||
)
|
||||
async def get_user_info_ruleset(
|
||||
session: Database,
|
||||
background_task: BackgroundTasks,
|
||||
user_id: str = Path(description="用户 ID 或用户名"),
|
||||
ruleset: GameMode | None = Path(description="指定 ruleset"),
|
||||
# current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
@@ -134,9 +127,7 @@ async def get_user_info_ruleset(
|
||||
searched_user = (
|
||||
await session.exec(
|
||||
select(User).where(
|
||||
User.id == int(user_id)
|
||||
if user_id.isdigit()
|
||||
else User.username == user_id.removeprefix("@")
|
||||
User.id == int(user_id) if user_id.isdigit() else User.username == user_id.removeprefix("@")
|
||||
)
|
||||
)
|
||||
).first()
|
||||
@@ -151,7 +142,7 @@ async def get_user_info_ruleset(
|
||||
)
|
||||
|
||||
# 异步缓存结果
|
||||
asyncio.create_task(cache_service.cache_user(user_resp, ruleset))
|
||||
background_task.add_task(cache_service.cache_user, user_resp, ruleset)
|
||||
|
||||
return user_resp
|
||||
|
||||
@@ -165,6 +156,7 @@ async def get_user_info_ruleset(
|
||||
tags=["用户"],
|
||||
)
|
||||
async def get_user_info(
|
||||
background_task: BackgroundTasks,
|
||||
session: Database,
|
||||
user_id: str = Path(description="用户 ID 或用户名"),
|
||||
# current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
@@ -182,9 +174,7 @@ async def get_user_info(
|
||||
searched_user = (
|
||||
await session.exec(
|
||||
select(User).where(
|
||||
User.id == int(user_id)
|
||||
if user_id.isdigit()
|
||||
else User.username == user_id.removeprefix("@")
|
||||
User.id == int(user_id) if user_id.isdigit() else User.username == user_id.removeprefix("@")
|
||||
)
|
||||
)
|
||||
).first()
|
||||
@@ -198,7 +188,7 @@ async def get_user_info(
|
||||
)
|
||||
|
||||
# 异步缓存结果
|
||||
asyncio.create_task(cache_service.cache_user(user_resp))
|
||||
background_task.add_task(cache_service.cache_user, user_resp)
|
||||
|
||||
return user_resp
|
||||
|
||||
@@ -212,6 +202,7 @@ async def get_user_info(
|
||||
)
|
||||
async def get_user_beatmapsets(
|
||||
session: Database,
|
||||
background_task: BackgroundTasks,
|
||||
user_id: int = Path(description="用户 ID"),
|
||||
type: BeatmapsetType = Path(description="谱面集类型"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
@@ -222,9 +213,7 @@ async def get_user_beatmapsets(
|
||||
cache_service = get_user_cache_service(redis)
|
||||
|
||||
# 先尝试从缓存获取
|
||||
cached_result = await cache_service.get_user_beatmapsets_from_cache(
|
||||
user_id, type.value, limit, offset
|
||||
)
|
||||
cached_result = await cache_service.get_user_beatmapsets_from_cache(user_id, type.value, limit, offset)
|
||||
if cached_result is not None:
|
||||
# 根据类型恢复对象
|
||||
if type == BeatmapsetType.MOST_PLAYED:
|
||||
@@ -253,10 +242,7 @@ async def get_user_beatmapsets(
|
||||
raise HTTPException(404, detail="User not found")
|
||||
favourites = await user.awaitable_attrs.favourite_beatmapsets
|
||||
resp = [
|
||||
await BeatmapsetResp.from_db(
|
||||
favourite.beatmapset, session=session, user=user
|
||||
)
|
||||
for favourite in favourites
|
||||
await BeatmapsetResp.from_db(favourite.beatmapset, session=session, user=user) for favourite in favourites
|
||||
]
|
||||
|
||||
elif type == BeatmapsetType.MOST_PLAYED:
|
||||
@@ -267,25 +253,18 @@ async def get_user_beatmapsets(
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
resp = [
|
||||
await BeatmapPlaycountsResp.from_db(most_played_beatmap)
|
||||
for most_played_beatmap in most_played
|
||||
]
|
||||
resp = [await BeatmapPlaycountsResp.from_db(most_played_beatmap) for most_played_beatmap in most_played]
|
||||
else:
|
||||
raise HTTPException(400, detail="Invalid beatmapset type")
|
||||
|
||||
# 异步缓存结果
|
||||
async def cache_beatmapsets():
|
||||
try:
|
||||
await cache_service.cache_user_beatmapsets(
|
||||
user_id, type.value, resp, limit, offset
|
||||
)
|
||||
await cache_service.cache_user_beatmapsets(user_id, type.value, resp, limit, offset)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}"
|
||||
)
|
||||
logger.error(f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}")
|
||||
|
||||
asyncio.create_task(cache_beatmapsets())
|
||||
background_task.add_task(cache_beatmapsets)
|
||||
|
||||
return resp
|
||||
|
||||
@@ -299,18 +278,14 @@ async def get_user_beatmapsets(
|
||||
)
|
||||
async def get_user_scores(
|
||||
session: Database,
|
||||
background_task: BackgroundTasks,
|
||||
user_id: int = Path(description="用户 ID"),
|
||||
type: Literal["best", "recent", "firsts", "pinned"] = Path(
|
||||
description=(
|
||||
"成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩"
|
||||
" / firsts 第一名成绩 / pinned 置顶成绩"
|
||||
)
|
||||
description=("成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩 / firsts 第一名成绩 / pinned 置顶成绩")
|
||||
),
|
||||
legacy_only: bool = Query(False, description="是否只查询 Stable 成绩"),
|
||||
include_fails: bool = Query(False, description="是否包含失败的成绩"),
|
||||
mode: GameMode | None = Query(
|
||||
None, description="指定 ruleset (可选,默认为用户主模式)"
|
||||
),
|
||||
mode: GameMode | None = Query(None, description="指定 ruleset (可选,默认为用户主模式)"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
|
||||
offset: int = Query(0, ge=0, description="偏移量"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
@@ -320,9 +295,7 @@ async def get_user_scores(
|
||||
|
||||
# 先尝试从缓存获取(对于recent类型使用较短的缓存时间)
|
||||
cache_expire = 30 if type == "recent" else settings.user_scores_cache_expire_seconds
|
||||
cached_scores = await cache_service.get_user_scores_from_cache(
|
||||
user_id, type, mode, limit, offset
|
||||
)
|
||||
cached_scores = await cache_service.get_user_scores_from_cache(user_id, type, mode, limit, offset)
|
||||
if cached_scores is not None:
|
||||
return cached_scores
|
||||
|
||||
@@ -332,9 +305,7 @@ async def get_user_scores(
|
||||
|
||||
gamemode = mode or db_user.playmode
|
||||
order_by = None
|
||||
where_clause = (col(Score.user_id) == db_user.id) & (
|
||||
col(Score.gamemode) == gamemode
|
||||
)
|
||||
where_clause = (col(Score.user_id) == db_user.id) & (col(Score.gamemode) == gamemode)
|
||||
if not include_fails:
|
||||
where_clause &= col(Score.passed).is_(True)
|
||||
if type == "pinned":
|
||||
@@ -351,13 +322,7 @@ async def get_user_scores(
|
||||
where_clause &= false()
|
||||
|
||||
scores = (
|
||||
await session.exec(
|
||||
select(Score)
|
||||
.where(where_clause)
|
||||
.order_by(order_by)
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
await session.exec(select(Score).where(where_clause).order_by(order_by).limit(limit).offset(offset))
|
||||
).all()
|
||||
if not scores:
|
||||
return []
|
||||
@@ -371,18 +336,14 @@ async def get_user_scores(
|
||||
]
|
||||
|
||||
# 异步缓存结果
|
||||
asyncio.create_task(
|
||||
cache_service.cache_user_scores(
|
||||
user_id, type, score_responses, mode, limit, offset, cache_expire
|
||||
)
|
||||
background_task.add_task(
|
||||
cache_service.cache_user_scores, user_id, type, score_responses, mode, limit, offset, cache_expire
|
||||
)
|
||||
|
||||
return score_responses
|
||||
|
||||
|
||||
@router.get(
|
||||
"/users/{user}/recent_activity", tags=["用户"], response_model=list[EventResp]
|
||||
)
|
||||
@router.get("/users/{user}/recent_activity", tags=["用户"], response_model=list[EventResp])
|
||||
async def get_user_events(
|
||||
session: Database,
|
||||
user: int,
|
||||
|
||||
Reference in New Issue
Block a user