refactor(project): make pyright & ruff happy

This commit is contained in:
MingxuanGame
2025-08-22 08:21:52 +00:00
parent 3b1d7a2234
commit 598fcc8b38
157 changed files with 2382 additions and 4590 deletions

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from datetime import UTC, datetime, timedelta
import re
from typing import Literal, Union
from typing import Literal
from app.auth import (
authenticate_user,
@@ -22,19 +22,19 @@ from app.dependencies.database import Database, get_redis
from app.dependencies.geoip import get_client_ip, get_geoip_helper
from app.helpers.geoip_helper import GeoIPHelper
from app.log import logger
from app.models.extended_auth import ExtendedTokenResponse
from app.models.oauth import (
OAuthErrorResponse,
RegistrationRequestErrors,
TokenResponse,
UserRegistrationErrors,
)
from app.models.extended_auth import ExtendedTokenResponse
from app.models.score import GameMode
from app.service.login_log_service import LoginLogService
from app.service.email_verification_service import (
EmailVerificationService,
LoginSessionService
LoginSessionService,
)
from app.service.login_log_service import LoginLogService
from app.service.password_reset_service import password_reset_service
from fastapi import APIRouter, Depends, Form, Request
@@ -44,13 +44,9 @@ from sqlalchemy import text
from sqlmodel import select
def create_oauth_error_response(
error: str, description: str, hint: str, status_code: int = 400
):
def create_oauth_error_response(error: str, description: str, hint: str, status_code: int = 400):
"""创建标准的 OAuth 错误响应"""
error_data = OAuthErrorResponse(
error=error, error_description=description, hint=hint, message=description
)
error_data = OAuthErrorResponse(error=error, error_description=description, hint=hint, message=description)
return JSONResponse(status_code=status_code, content=error_data.model_dump())
@@ -123,9 +119,7 @@ async def register_user(
)
)
return JSONResponse(
status_code=422, content={"form_error": errors.model_dump()}
)
return JSONResponse(status_code=422, content={"form_error": errors.model_dump()})
try:
# 获取客户端 IP 并查询地理位置
@@ -137,10 +131,7 @@ async def register_user(
geo_info = geoip.lookup(client_ip)
if geo_info and geo_info.get("country_iso"):
country_code = geo_info["country_iso"]
logger.info(
f"User {user_username} registering from "
f"{client_ip}, country: {country_code}"
)
logger.info(f"User {user_username} registering from {client_ip}, country: {country_code}")
else:
logger.warning(f"Could not determine country for IP {client_ip}")
except Exception as e:
@@ -148,7 +139,7 @@ async def register_user(
# 创建新用户
# 确保 AUTO_INCREMENT 值从3开始ID=1是BanchoBotID=2预留给ppy
result = await db.execute( # pyright: ignore[reportDeprecated]
result = await db.execute(
text(
"SELECT AUTO_INCREMENT FROM information_schema.TABLES "
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'lazer_users'"
@@ -173,7 +164,6 @@ async def register_user(
db.add(new_user)
await db.commit()
await db.refresh(new_user)
assert new_user.id is not None, "New user ID should not be None"
for i in [GameMode.OSU, GameMode.TAIKO, GameMode.FRUITS, GameMode.MANIA]:
statistics = UserStatistics(mode=i, user_id=new_user.id)
db.add(statistics)
@@ -193,36 +183,30 @@ async def register_user(
logger.exception(f"Registration error for user {user_username}")
# 返回通用错误
errors = RegistrationRequestErrors(
message="An error occurred while creating your account. Please try again."
)
errors = RegistrationRequestErrors(message="An error occurred while creating your account. Please try again.")
return JSONResponse(
status_code=500, content={"form_error": errors.model_dump()}
)
return JSONResponse(status_code=500, content={"form_error": errors.model_dump()})
@router.post(
"/oauth/token",
response_model=Union[TokenResponse, ExtendedTokenResponse],
response_model=TokenResponse | ExtendedTokenResponse,
name="获取访问令牌",
description="OAuth 令牌端点,支持密码、刷新令牌和授权码三种授权方式。",
)
async def oauth_token(
db: Database,
request: Request,
grant_type: Literal[
"authorization_code", "refresh_token", "password", "client_credentials"
] = Form(..., description="授权类型:密码/刷新令牌/授权码/客户端凭证"),
grant_type: Literal["authorization_code", "refresh_token", "password", "client_credentials"] = Form(
..., description="授权类型:密码/刷新令牌/授权码/客户端凭证"
),
client_id: int = Form(..., description="客户端 ID"),
client_secret: str = Form(..., description="客户端密钥"),
code: str | None = Form(None, description="授权码(仅授权码模式需要)"),
scope: str = Form("*", description="权限范围(空格分隔,默认为 '*'"),
username: str | None = Form(None, description="用户名(仅密码模式需要)"),
password: str | None = Form(None, description="密码(仅密码模式需要)"),
refresh_token: str | None = Form(
None, description="刷新令牌(仅刷新令牌模式需要)"
),
refresh_token: str | None = Form(None, description="刷新令牌(仅刷新令牌模式需要)"),
redis: Redis = Depends(get_redis),
geoip: GeoIPHelper = Depends(get_geoip_helper),
):
@@ -303,37 +287,33 @@ async def oauth_token(
await db.refresh(user)
# 获取用户信息和客户端信息
user_id = getattr(user, "id")
assert user_id is not None, "User ID should not be None after authentication"
from app.dependencies.geoip import get_client_ip
user_id = user.id
ip_address = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "")
# 获取国家代码
geo_info = geoip.lookup(ip_address)
country_code = geo_info.get("country_iso", "XX")
# 检查是否为新位置登录
is_new_location = await LoginSessionService.check_new_location(
db, user_id, ip_address, country_code
)
is_new_location = await LoginSessionService.check_new_location(db, user_id, ip_address, country_code)
# 创建登录会话记录
login_session = await LoginSessionService.create_session(
login_session = await LoginSessionService.create_session( # noqa: F841
db, redis, user_id, ip_address, user_agent, country_code, is_new_location
)
# 如果是新位置登录,需要邮件验证
if is_new_location and settings.enable_email_verification:
# 刷新用户对象以确保属性已加载
await db.refresh(user)
# 发送邮件验证码
verification_sent = await EmailVerificationService.send_verification_email(
db, redis, user_id, user.username, user.email, ip_address, user_agent
)
# 记录需要二次验证的登录尝试
await LoginLogService.record_login(
db=db,
@@ -343,14 +323,16 @@ async def oauth_token(
login_method="password_pending_verification",
notes=f"新位置登录,需要邮件验证 - IP: {ip_address}, 国家: {country_code}",
)
if not verification_sent:
# 邮件发送失败,记录错误
logger.error(f"[Auth] Failed to send email verification code for user {user_id}")
elif is_new_location and not settings.enable_email_verification:
# 新位置登录但邮件验证功能被禁用,直接标记会话为已验证
await LoginSessionService.mark_session_verified(db, user_id)
logger.debug(f"[Auth] New location login detected but email verification disabled, auto-verifying user {user_id}")
logger.debug(
f"[Auth] New location login detected but email verification disabled, auto-verifying user {user_id}"
)
else:
# 不是新位置登录,正常登录
await LoginLogService.record_login(
@@ -361,20 +343,17 @@ async def oauth_token(
login_method="password",
notes=f"正常登录 - IP: {ip_address}, 国家: {country_code}",
)
# 无论是否新位置登录都返回正常的token
# session_verified状态通过/me接口的session_verified字段来体现
# 生成令牌
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
# 获取用户ID避免触发延迟加载
access_token = create_access_token(
data={"sub": str(user_id)}, expires_delta=access_token_expires
)
access_token = create_access_token(data={"sub": str(user_id)}, expires_delta=access_token_expires)
refresh_token_str = generate_refresh_token()
# 存储令牌
assert user_id
await store_token(
db,
user_id,
@@ -423,9 +402,7 @@ async def oauth_token(
# 生成新的访问令牌
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
access_token = create_access_token(
data={"sub": str(token_record.user_id)}, expires_delta=access_token_expires
)
access_token = create_access_token(data={"sub": str(token_record.user_id)}, expires_delta=access_token_expires)
new_refresh_token = generate_refresh_token()
# 更新令牌
@@ -489,17 +466,11 @@ async def oauth_token(
# 生成令牌
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
# 重新查询只获取ID避免触发延迟加载
id_result = await db.exec(select(User.id).where(User.username == username))
user_id = id_result.first()
access_token = create_access_token(
data={"sub": str(user_id)}, expires_delta=access_token_expires
)
user_id = user.id
access_token = create_access_token(data={"sub": str(user_id)}, expires_delta=access_token_expires)
refresh_token_str = generate_refresh_token()
# 存储令牌
assert user_id
await store_token(
db,
user_id,
@@ -539,9 +510,7 @@ async def oauth_token(
# 生成令牌
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
access_token = create_access_token(
data={"sub": "3"}, expires_delta=access_token_expires
)
access_token = create_access_token(data={"sub": "3"}, expires_delta=access_token_expires)
refresh_token_str = generate_refresh_token()
# 存储令牌
@@ -567,7 +536,7 @@ async def oauth_token(
@router.post(
"/password-reset/request",
name="请求密码重置",
description="通过邮箱请求密码重置验证码"
description="通过邮箱请求密码重置验证码",
)
async def request_password_reset(
request: Request,
@@ -578,42 +547,26 @@ async def request_password_reset(
请求密码重置
"""
from app.dependencies.geoip import get_client_ip
# 获取客户端信息
ip_address = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "")
# 请求密码重置
success, message = await password_reset_service.request_password_reset(
email=email.lower().strip(),
ip_address=ip_address,
user_agent=user_agent,
redis=redis
redis=redis,
)
if success:
return JSONResponse(
status_code=200,
content={
"success": True,
"message": message
}
)
return JSONResponse(status_code=200, content={"success": True, "message": message})
else:
return JSONResponse(
status_code=400,
content={
"success": False,
"error": message
}
)
return JSONResponse(status_code=400, content={"success": False, "error": message})
@router.post(
"/password-reset/reset",
name="重置密码",
description="使用验证码重置密码"
)
@router.post("/password-reset/reset", name="重置密码", description="使用验证码重置密码")
async def reset_password(
request: Request,
email: str = Form(..., description="邮箱地址"),
@@ -625,32 +578,20 @@ async def reset_password(
重置密码
"""
from app.dependencies.geoip import get_client_ip
# 获取客户端信息
ip_address = get_client_ip(request)
# 重置密码
success, message = await password_reset_service.reset_password(
email=email.lower().strip(),
reset_code=reset_code.strip(),
new_password=new_password,
ip_address=ip_address,
redis=redis
redis=redis,
)
if success:
return JSONResponse(
status_code=200,
content={
"success": True,
"message": message
}
)
return JSONResponse(status_code=200, content={"success": True, "message": message})
else:
return JSONResponse(
status_code=400,
content={
"success": False,
"error": message
}
)
return JSONResponse(status_code=400, content={"success": False, "error": message})

View File

@@ -43,9 +43,9 @@ async def get_notifications(
current_user: User = Security(get_client_user),
):
if settings.server_url is not None:
notification_endpoint = f"{settings.server_url}notification-server".replace(
"http://", "ws://"
).replace("https://", "wss://")
notification_endpoint = f"{settings.server_url}notification-server".replace("http://", "ws://").replace(
"https://", "wss://"
)
else:
notification_endpoint = "/notification-server"
query = select(UserNotification).where(
@@ -96,21 +96,15 @@ async def _get_notifications(
query = base_query.where(UserNotification.notification_id == identity.id)
if identity.object_id is not None:
query = base_query.where(
col(UserNotification.notification).has(
col(Notification.object_id) == identity.object_id
)
col(UserNotification.notification).has(col(Notification.object_id) == identity.object_id)
)
if identity.object_type is not None:
query = base_query.where(
col(UserNotification.notification).has(
col(Notification.object_type) == identity.object_type
)
col(UserNotification.notification).has(col(Notification.object_type) == identity.object_type)
)
if identity.category is not None:
query = base_query.where(
col(UserNotification.notification).has(
col(Notification.category) == identity.category
)
col(UserNotification.notification).has(col(Notification.category) == identity.category)
)
result.update({n.notification_id: n for n in await session.exec(query)})
return list(result.values())
@@ -134,7 +128,6 @@ async def mark_notifications_as_read(
for user_notification in user_notifications:
user_notification.is_read = True
assert current_user.id
await server.send_event(
current_user.id,
ChatEvent(

View File

@@ -91,9 +91,7 @@ class Bot:
if reply:
await self._send_reply(user, channel, reply, session)
async def _send_message(
self, channel: ChatChannel, content: str, session: AsyncSession
) -> None:
async def _send_message(self, channel: ChatChannel, content: str, session: AsyncSession) -> None:
bot = await session.get(User, self.bot_user_id)
if bot is None:
return
@@ -101,7 +99,6 @@ class Bot:
if channel_id is None:
return
assert bot.id is not None
msg = ChatMessage(
channel_id=channel_id,
content=content,
@@ -115,9 +112,7 @@ class Bot:
resp = await ChatMessageResp.from_db(msg, session, bot)
await server.send_message_to_channel(resp)
async def _ensure_pm_channel(
self, user: User, session: AsyncSession
) -> ChatChannel | None:
async def _ensure_pm_channel(self, user: User, session: AsyncSession) -> ChatChannel | None:
user_id = user.id
if user_id is None:
return None
@@ -160,9 +155,7 @@ bot = Bot()
@bot.command("help")
async def _help(
user: User, args: list[str], _session: AsyncSession, channel: ChatChannel
) -> str:
async def _help(user: User, args: list[str], _session: AsyncSession, channel: ChatChannel) -> str:
cmds = sorted(bot._handlers.keys())
if args:
target = args[0].lower()
@@ -175,9 +168,7 @@ async def _help(
@bot.command("roll")
def _roll(
user: User, args: list[str], _session: AsyncSession, channel: ChatChannel
) -> str:
def _roll(user: User, args: list[str], _session: AsyncSession, channel: ChatChannel) -> str:
if len(args) > 0 and args[0].isdigit():
r = random.randint(1, int(args[0]))
else:
@@ -186,13 +177,9 @@ def _roll(
@bot.command("stats")
async def _stats(
user: User, args: list[str], session: AsyncSession, channel: ChatChannel
) -> str:
async def _stats(user: User, args: list[str], session: AsyncSession, channel: ChatChannel) -> str:
if len(args) >= 1:
target_user = (
await session.exec(select(User).where(User.username == args[0]))
).first()
target_user = (await session.exec(select(User).where(User.username == args[0]))).first()
if not target_user:
return f"User '{args[0]}' not found."
else:
@@ -202,14 +189,8 @@ async def _stats(
if len(args) >= 2:
gamemode = GameMode.parse(args[1].upper())
if gamemode is None:
subquery = (
select(func.max(Score.id))
.where(Score.user_id == target_user.id)
.scalar_subquery()
)
last_score = (
await session.exec(select(Score).where(Score.id == subquery))
).first()
subquery = select(func.max(Score.id)).where(Score.user_id == target_user.id).scalar_subquery()
last_score = (await session.exec(select(Score).where(Score.id == subquery))).first()
if last_score is not None:
gamemode = last_score.gamemode
else:
@@ -295,9 +276,7 @@ async def _mp_host(
return "Usage: !mp host <username>"
username = args[0]
user_id = (
await session.exec(select(User.id).where(User.username == username))
).first()
user_id = (await session.exec(select(User.id).where(User.username == username))).first()
if not user_id:
return f"User '{username}' not found."
@@ -362,24 +341,18 @@ async def _mp_team(
if team is None:
return "Invalid team colour. Use 'red' or 'blue'."
user_id = (
await session.exec(select(User.id).where(User.username == username))
).first()
user_id = (await session.exec(select(User.id).where(User.username == username))).first()
if not user_id:
return f"User '{username}' not found."
user_client = MultiplayerHubs.get_client_by_id(str(user_id))
if not user_client:
return f"User '{username}' is not in the room."
if (
user_client.user_id != signalr_client.user_id
and room.room.host.user_id != signalr_client.user_id
):
assert room.room.host
if user_client.user_id != signalr_client.user_id and room.room.host.user_id != signalr_client.user_id:
return "You are not allowed to change other users' teams."
try:
await MultiplayerHubs.SendMatchRequest(
user_client, ChangeTeamRequest(team_id=team)
)
await MultiplayerHubs.SendMatchRequest(user_client, ChangeTeamRequest(team_id=team))
return ""
except InvokeException as e:
return e.message
@@ -414,9 +387,7 @@ async def _mp_kick(
return "Usage: !mp kick <username>"
username = args[0]
user_id = (
await session.exec(select(User.id).where(User.username == username))
).first()
user_id = (await session.exec(select(User.id).where(User.username == username))).first()
if not user_id:
return f"User '{username}' not found."
@@ -456,10 +427,7 @@ async def _mp_map(
try:
beatmap = await Beatmap.get_or_fetch(session, await get_fetcher(), bid=map_id)
if beatmap.mode != GameMode.OSU and playmode and playmode != beatmap.mode:
return (
f"Cannot convert to {playmode.value}. "
f"Original mode is {beatmap.mode.value}."
)
return f"Cannot convert to {playmode.value}. Original mode is {beatmap.mode.value}."
except HTTPError:
return "Beatmap not found"
@@ -530,9 +498,7 @@ async def _mp_mods(
if freestyle:
item.allowed_mods = []
elif freemod:
item.allowed_mods = get_available_mods(
current_item.ruleset_id, required_mods
)
item.allowed_mods = get_available_mods(current_item.ruleset_id, required_mods)
else:
item.allowed_mods = allowed_mods
item.required_mods = required_mods
@@ -601,14 +567,9 @@ async def _score(
include_fail: bool = False,
gamemode: GameMode | None = None,
) -> str:
q = (
select(Score)
.where(Score.user_id == user_id)
.order_by(col(Score.id).desc())
.options(joinedload(Score.beatmap))
)
q = select(Score).where(Score.user_id == user_id).order_by(col(Score.id).desc()).options(joinedload(Score.beatmap))
if not include_fail:
q = q.where(Score.passed.is_(True))
q = q.where(col(Score.passed).is_(True))
if gamemode is not None:
q = q.where(Score.gamemode == gamemode)
@@ -619,17 +580,13 @@ async def _score(
result = f"""{score.beatmap.beatmapset.title} [{score.beatmap.version}] ({score.gamemode.name.lower()})
Played at {score.started_at}
{score.pp:.2f}pp {score.accuracy:.2%} {",".join(mod_to_save(score.mods))} {score.rank.name.upper()}
Great: {score.n300}, Good: {score.n100}, Meh: {score.n50}, Miss: {score.nmiss}""" # noqa: E501
Great: {score.n300}, Good: {score.n100}, Meh: {score.n50}, Miss: {score.nmiss}"""
if score.gamemode == GameMode.MANIA:
keys = next(
(mod["acronym"] for mod in score.mods if mod["acronym"].endswith("K")), None
)
keys = next((mod["acronym"] for mod in score.mods if mod["acronym"].endswith("K")), None)
if keys is None:
keys = f"{int(score.beatmap.cs)}K"
p_d_g = f"{score.ngeki / score.n300:.2f}:1" if score.n300 > 0 else "inf:1"
result += (
f"\nKeys: {keys}, Perfect: {score.ngeki}, Ok: {score.nkatu}, P/G: {p_d_g}"
)
result += f"\nKeys: {keys}, Perfect: {score.ngeki}, Ok: {score.nkatu}, P/G: {p_d_g}"
return result

View File

@@ -38,27 +38,18 @@ class UpdateResponse(BaseModel):
)
async def get_update(
session: Database,
history_since: int | None = Query(
None, description="获取自此禁言 ID 之后的禁言记录"
),
history_since: int | None = Query(None, description="获取自此禁言 ID 之后的禁言记录"),
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
includes: list[str] = Query(
["presence", "silences"], alias="includes[]", description="要包含的更新类型"
),
includes: list[str] = Query(["presence", "silences"], alias="includes[]", description="要包含的更新类型"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
redis: Redis = Depends(get_redis),
):
resp = UpdateResponse()
if "presence" in includes:
assert current_user.id
channel_ids = server.get_user_joined_channel(current_user.id)
for channel_id in channel_ids:
# 使用明确的查询避免延迟加载
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
)
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
if db_channel:
# 提取必要的属性避免惰性加载
channel_type = db_channel.type
@@ -69,34 +60,20 @@ async def get_update(
session,
current_user,
redis,
server.channels.get(channel_id, [])
if channel_type != ChannelType.PUBLIC
else None,
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
)
)
if "silences" in includes:
if history_since:
silences = (
await session.exec(
select(SilenceUser).where(col(SilenceUser.id) > history_since)
)
).all()
resp.silences.extend(
[UserSilenceResp.from_db(silence) for silence in silences]
)
silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.id) > history_since))).all()
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
elif since:
msg = await session.get(ChatMessage, since)
if msg:
silences = (
await session.exec(
select(SilenceUser).where(
col(SilenceUser.banned_at) > msg.timestamp
)
)
await session.exec(select(SilenceUser).where(col(SilenceUser.banned_at) > msg.timestamp))
).all()
resp.silences.extend(
[UserSilenceResp.from_db(silence) for silence in silences]
)
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
return resp
@@ -115,15 +92,9 @@ async def join_channel(
):
# 使用明确的查询避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
else:
db_channel = (
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
@@ -145,15 +116,9 @@ async def leave_channel(
):
# 使用明确的查询避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
else:
db_channel = (
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
@@ -173,27 +138,20 @@ async def get_channel_list(
current_user: User = Security(get_current_user, scopes=["chat.read"]),
redis: Redis = Depends(get_redis),
):
channels = (
await session.exec(
select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC)
)
).all()
channels = (await session.exec(select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC))).all()
results = []
for channel in channels:
# 提取必要的属性避免惰性加载
channel_id = channel.channel_id
channel_type = channel.type
assert channel_id is not None
results.append(
await ChatChannelResp.from_db(
channel,
session,
current_user,
redis,
server.channels.get(channel_id, [])
if channel_type != ChannelType.PUBLIC
else None,
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
)
)
return results
@@ -219,15 +177,9 @@ async def get_channel(
):
# 使用明确的查询避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
else:
db_channel = (
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
@@ -237,8 +189,6 @@ async def get_channel(
channel_type = db_channel.type
channel_name = db_channel.name
assert channel_id is not None
users = []
if channel_type == ChannelType.PM:
user_ids = channel_name.split("_")[1:]
@@ -259,9 +209,7 @@ async def get_channel(
session,
current_user,
redis,
server.channels.get(channel_id, [])
if channel_type != ChannelType.PUBLIC
else None,
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
)
)
@@ -284,9 +232,7 @@ class CreateChannelReq(BaseModel):
raise ValueError("target_id must be set for PM channels")
else:
if self.target_ids is None or self.channel is None or self.message is None:
raise ValueError(
"target_ids, channel, and message must be set for ANNOUNCE channels"
)
raise ValueError("target_ids, channel, and message must be set for ANNOUNCE channels")
return self
@@ -312,24 +258,20 @@ async def create_channel(
raise HTTPException(status_code=403, detail=block)
channel = await ChatChannel.get_pm_channel(
current_user.id, # pyright: ignore[reportArgumentType]
current_user.id,
req.target_id, # pyright: ignore[reportArgumentType]
session,
)
channel_name = f"pm_{current_user.id}_{req.target_id}"
else:
channel_name = req.channel.name if req.channel else "Unnamed Channel"
result = await session.exec(
select(ChatChannel).where(ChatChannel.name == channel_name)
)
result = await session.exec(select(ChatChannel).where(ChatChannel.name == channel_name))
channel = result.first()
if channel is None:
channel = ChatChannel(
name=channel_name,
description=req.channel.description
if req.channel
else "Private message channel",
description=req.channel.description if req.channel else "Private message channel",
type=ChannelType.PM if req.type == "PM" else ChannelType.ANNOUNCE,
)
session.add(channel)
@@ -340,16 +282,13 @@ async def create_channel(
await session.refresh(target) # pyright: ignore[reportPossiblyUnboundVariable]
await server.batch_join_channel([target, current_user], channel, session) # pyright: ignore[reportPossiblyUnboundVariable]
else:
target_users = await session.exec(
select(User).where(col(User.id).in_(req.target_ids or []))
)
target_users = await session.exec(select(User).where(col(User.id).in_(req.target_ids or [])))
await server.batch_join_channel([*target_users, current_user], channel, session)
await server.join_channel(current_user, channel, session)
# 提取必要的属性避免惰性加载
channel_id = channel.channel_id
assert channel_id
return await ChatChannelResp.from_db(
channel,

View File

@@ -41,33 +41,19 @@ class KeepAliveResp(BaseModel):
)
async def keep_alive(
session: Database,
history_since: int | None = Query(
None, description="获取自此禁言 ID 之后的禁言记录"
),
history_since: int | None = Query(None, description="获取自此禁言 ID 之后的禁言记录"),
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
current_user: User = Security(get_current_user, scopes=["chat.read"]),
):
resp = KeepAliveResp()
if history_since:
silences = (
await session.exec(
select(SilenceUser).where(col(SilenceUser.id) > history_since)
)
).all()
silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.id) > history_since))).all()
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
elif since:
msg = await session.get(ChatMessage, since)
if msg:
silences = (
await session.exec(
select(SilenceUser).where(
col(SilenceUser.banned_at) > msg.timestamp
)
)
).all()
resp.silences.extend(
[UserSilenceResp.from_db(silence) for silence in silences]
)
silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.banned_at) > msg.timestamp))).all()
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
return resp
@@ -93,15 +79,9 @@ async def send_message(
):
# 使用明确的查询来获取 channel避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
else:
db_channel = (
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
@@ -111,9 +91,6 @@ async def send_message(
channel_type = db_channel.type
channel_name = db_channel.name
assert channel_id is not None
assert current_user.id
# 使用 Redis 消息系统发送消息 - 立即返回
resp = await redis_message_system.send_message(
channel_id=channel_id,
@@ -125,9 +102,7 @@ async def send_message(
# 立即广播消息给所有客户端
is_bot_command = req.message.startswith("!")
await server.send_message_to_channel(
resp, is_bot_command and channel_type == ChannelType.PUBLIC
)
await server.send_message_to_channel(resp, is_bot_command and channel_type == ChannelType.PUBLIC)
# 处理机器人命令
if is_bot_command:
@@ -147,14 +122,10 @@ async def send_message(
if channel_type == ChannelType.PM:
user_ids = channel_name.split("_")[1:]
await server.new_private_notification(
ChannelMessage.init(
temp_msg, current_user, [int(u) for u in user_ids], channel_type
)
ChannelMessage.init(temp_msg, current_user, [int(u) for u in user_ids], channel_type)
)
elif channel_type == ChannelType.TEAM:
await server.new_private_notification(
ChannelMessageTeam.init(temp_msg, current_user)
)
await server.new_private_notification(ChannelMessageTeam.init(temp_msg, current_user))
return resp
@@ -176,22 +147,15 @@ async def get_message(
):
# 使用明确的查询获取 channel避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
else:
db_channel = (
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
# 提取必要的属性避免惰性加载
channel_id = db_channel.channel_id
assert channel_id is not None
# 使用 Redis 消息系统获取消息
try:
@@ -230,23 +194,15 @@ async def mark_as_read(
):
# 使用明确的查询获取 channel避免延迟加载
if channel.isdigit():
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
)
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
else:
db_channel = (
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
if db_channel is None:
raise HTTPException(status_code=404, detail="Channel not found")
# 立即提取需要的属性
channel_id = db_channel.channel_id
assert channel_id
assert current_user.id
await server.mark_as_read(channel_id, current_user.id, message)
@@ -283,7 +239,6 @@ async def create_new_pm(
if not is_can_pm:
raise HTTPException(status_code=403, detail=block)
assert user_id
channel = await ChatChannel.get_pm_channel(user_id, req.target_id, session)
if channel is None:
channel = ChatChannel(
@@ -297,7 +252,6 @@ async def create_new_pm(
await session.refresh(target)
await session.refresh(current_user)
assert channel.channel_id
await server.batch_join_channel([target, current_user], channel, session)
channel_resp = await ChatChannelResp.from_db(
channel, session, current_user, redis, server.channels[channel.channel_id]

View File

@@ -17,6 +17,7 @@ from app.log import logger
from app.models.chat import ChatEvent
from app.models.notification import NotificationDetail
from app.service.subscribers.chat import ChatSubscriber
from app.utils import bg_tasks
from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect
from fastapi.security import SecurityScopes
@@ -37,20 +38,11 @@ class ChatServer:
self.ChatSubscriber.chat_server = self
self._subscribed = False
def _add_task(self, task):
task = asyncio.create_task(task)
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
def connect(self, user_id: int, client: WebSocket):
self.connect_client[user_id] = client
def get_user_joined_channel(self, user_id: int) -> list[int]:
return [
channel_id
for channel_id, users in self.channels.items()
if user_id in users
]
return [channel_id for channel_id, users in self.channels.items() if user_id in users]
async def disconnect(self, user: User, session: AsyncSession):
user_id = user.id
@@ -61,9 +53,7 @@ class ChatServer:
channel.remove(user_id)
# 使用明确的查询避免延迟加载
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
)
await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))
).first()
if db_channel:
await self.leave_channel(user, db_channel, session)
@@ -93,11 +83,10 @@ class ChatServer:
async def mark_as_read(self, channel_id: int, user_id: int, message_id: int):
await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id)
async def send_message_to_channel(
self, message: ChatMessageResp, is_bot_command: bool = False
):
async def send_message_to_channel(self, message: ChatMessageResp, is_bot_command: bool = False):
logger.info(
f"Sending message to channel {message.channel_id}, message_id: {message.message_id}, is_bot_command: {is_bot_command}"
f"Sending message to channel {message.channel_id}, message_id: "
f"{message.message_id}, is_bot_command: {is_bot_command}"
)
event = ChatEvent(
@@ -106,62 +95,44 @@ class ChatServer:
)
if is_bot_command:
logger.info(f"Sending bot command to user {message.sender_id}")
self._add_task(self.send_event(message.sender_id, event))
bg_tasks.add_task(self.send_event, message.sender_id, event)
else:
# 总是广播消息无论是临时ID还是真实ID
logger.info(
f"Broadcasting message to all users in channel {message.channel_id}"
)
self._add_task(
self.broadcast(
message.channel_id,
event,
)
logger.info(f"Broadcasting message to all users in channel {message.channel_id}")
bg_tasks.add_task(
self.broadcast,
message.channel_id,
event,
)
# 只有真实消息 ID正数且非零才进行标记已读和设置最后消息
# Redis 消息系统生成的ID都是正数所以这里应该都能正常处理
if message.message_id and message.message_id > 0:
await self.mark_as_read(
message.channel_id, message.sender_id, message.message_id
)
await self.redis.set(
f"chat:{message.channel_id}:last_msg", message.message_id
)
logger.info(
f"Updated last message ID for channel {message.channel_id} to {message.message_id}"
)
await self.mark_as_read(message.channel_id, message.sender_id, message.message_id)
await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
logger.info(f"Updated last message ID for channel {message.channel_id} to {message.message_id}")
else:
logger.debug(
f"Skipping last message update for message ID: {message.message_id}"
)
logger.debug(f"Skipping last message update for message ID: {message.message_id}")
async def batch_join_channel(
self, users: list[User], channel: ChatChannel, session: AsyncSession
):
async def batch_join_channel(self, users: list[User], channel: ChatChannel, session: AsyncSession):
channel_id = channel.channel_id
assert channel_id is not None
not_joined = []
if channel_id not in self.channels:
self.channels[channel_id] = []
for user in users:
assert user.id is not None
if user.id not in self.channels[channel_id]:
self.channels[channel_id].append(user.id)
not_joined.append(user)
for user in not_joined:
assert user.id is not None
channel_resp = await ChatChannelResp.from_db(
channel,
session,
user,
self.redis,
self.channels[channel_id]
if channel.type != ChannelType.PUBLIC
else None,
self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
)
await self.send_event(
user.id,
@@ -171,13 +142,9 @@ class ChatServer:
),
)
async def join_channel(
self, user: User, channel: ChatChannel, session: AsyncSession
) -> ChatChannelResp:
async def join_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> ChatChannelResp:
user_id = user.id
channel_id = channel.channel_id
assert channel_id is not None
assert user_id is not None
if channel_id not in self.channels:
self.channels[channel_id] = []
@@ -202,13 +169,9 @@ class ChatServer:
return channel_resp
async def leave_channel(
self, user: User, channel: ChatChannel, session: AsyncSession
) -> None:
async def leave_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> None:
user_id = user.id
channel_id = channel.channel_id
assert channel_id is not None
assert user_id is not None
if channel_id in self.channels and user_id in self.channels[channel_id]:
self.channels[channel_id].remove(user_id)
@@ -221,9 +184,7 @@ class ChatServer:
session,
user,
self.redis,
self.channels.get(channel_id)
if channel.type != ChannelType.PUBLIC
else None,
self.channels.get(channel_id) if channel.type != ChannelType.PUBLIC else None,
)
await self.send_event(
user_id,
@@ -236,11 +197,7 @@ class ChatServer:
async def join_room_channel(self, channel_id: int, user_id: int):
async with with_db() as session:
# 使用明确的查询避免延迟加载
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
)
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
if db_channel is None:
return
@@ -253,11 +210,7 @@ class ChatServer:
async def leave_room_channel(self, channel_id: int, user_id: int):
async with with_db() as session:
# 使用明确的查询避免延迟加载
db_channel = (
await session.exec(
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
)
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
if db_channel is None:
return
@@ -270,13 +223,7 @@ class ChatServer:
async def new_private_notification(self, detail: NotificationDetail):
async with with_db() as session:
id = await insert_notification(session, detail)
users = (
await session.exec(
select(UserNotification).where(
UserNotification.notification_id == id
)
)
).all()
users = (await session.exec(select(UserNotification).where(UserNotification.notification_id == id))).all()
for user_notification in users:
data = user_notification.notification.model_dump()
data["is_read"] = user_notification.is_read
@@ -308,9 +255,7 @@ async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory):
await ws.close(code=1000)
break
except WebSocketDisconnect as e:
logger.info(
f"[NotificationServer] Client {user_id} disconnected: {e.code}, {e.reason}"
)
logger.info(f"[NotificationServer] Client {user_id} disconnected: {e.code}, {e.reason}")
except RuntimeError as e:
if "disconnect message" in str(e):
logger.info(f"[NotificationServer] Client {user_id} closed the connection.")
@@ -332,11 +277,7 @@ async def chat_websocket(
async for session in factory():
token = authorization[7:]
if (
user := await get_current_user(
session, SecurityScopes(scopes=["chat.read"]), token_pw=token
)
) is None:
if (user := await get_current_user(session, SecurityScopes(scopes=["chat.read"]), token_pw=token)) is None:
await websocket.close(code=1008)
return
@@ -346,12 +287,9 @@ async def chat_websocket(
await websocket.close(code=1008)
return
user_id = user.id
assert user_id
server.connect(user_id, websocket)
# 使用明确的查询避免延迟加载
db_channel = (
await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))
).first()
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))).first()
if db_channel is not None:
await server.join_channel(user, db_channel, session)

View File

@@ -2,22 +2,20 @@
密码重置管理接口
"""
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from redis.asyncio import Redis
from __future__ import annotations
from app.dependencies.database import get_redis
from app.service.password_reset_service import password_reset_service
from app.log import logger
from app.service.password_reset_service import password_reset_service
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from redis.asyncio import Redis
router = APIRouter(prefix="/admin/password-reset", tags=["密码重置管理"])
@router.get(
"/status/{email}",
name="查询重置状态",
description="查询指定邮箱的密码重置状态"
)
@router.get("/status/{email}", name="查询重置状态", description="查询指定邮箱的密码重置状态")
async def get_password_reset_status(
email: str,
redis: Redis = Depends(get_redis),
@@ -25,28 +23,16 @@ async def get_password_reset_status(
"""查询密码重置状态"""
try:
info = await password_reset_service.get_reset_code_info(email, redis)
return JSONResponse(
status_code=200,
content={
"success": True,
"data": info
}
)
return JSONResponse(status_code=200, content={"success": True, "data": info})
except Exception as e:
logger.error(f"[Admin] Failed to get password reset status for {email}: {e}")
return JSONResponse(
status_code=500,
content={
"success": False,
"error": "获取状态失败"
}
)
return JSONResponse(status_code=500, content={"success": False, "error": "获取状态失败"})
@router.delete(
"/cleanup/{email}",
name="清理重置数据",
description="强制清理指定邮箱的密码重置数据"
description="强制清理指定邮箱的密码重置数据",
)
async def force_cleanup_reset(
email: str,
@@ -55,38 +41,23 @@ async def force_cleanup_reset(
"""强制清理密码重置数据"""
try:
success = await password_reset_service.force_cleanup_user_reset(email, redis)
if success:
return JSONResponse(
status_code=200,
content={
"success": True,
"message": f"已清理邮箱 {email} 的重置数据"
}
content={"success": True, "message": f"已清理邮箱 {email} 的重置数据"},
)
else:
return JSONResponse(
status_code=500,
content={
"success": False,
"error": "清理失败"
}
)
return JSONResponse(status_code=500, content={"success": False, "error": "清理失败"})
except Exception as e:
logger.error(f"[Admin] Failed to cleanup password reset for {email}: {e}")
return JSONResponse(
status_code=500,
content={
"success": False,
"error": "清理操作失败"
}
)
return JSONResponse(status_code=500, content={"success": False, "error": "清理操作失败"})
@router.post(
"/cleanup/expired",
name="清理过期验证码",
description="清理所有过期的密码重置验证码"
description="清理所有过期的密码重置验证码",
)
async def cleanup_expired_codes(
redis: Redis = Depends(get_redis),
@@ -99,25 +70,15 @@ async def cleanup_expired_codes(
content={
"success": True,
"message": f"已清理 {count} 个过期的验证码",
"cleaned_count": count
}
"cleaned_count": count,
},
)
except Exception as e:
logger.error(f"[Admin] Failed to cleanup expired codes: {e}")
return JSONResponse(
status_code=500,
content={
"success": False,
"error": "清理操作失败"
}
)
return JSONResponse(status_code=500, content={"success": False, "error": "清理操作失败"})
@router.get(
"/stats",
name="重置统计",
description="获取密码重置的统计信息"
)
@router.get("/stats", name="重置统计", description="获取密码重置的统计信息")
async def get_reset_statistics(
redis: Redis = Depends(get_redis),
):
@@ -126,53 +87,42 @@ async def get_reset_statistics(
# 获取所有重置相关的键
reset_keys = await redis.keys("password_reset:code:*")
rate_limit_keys = await redis.keys("password_reset:rate_limit:*")
active_resets = 0
used_resets = 0
active_rate_limits = 0
# 统计活跃重置
for key in reset_keys:
data_str = await redis.get(key)
if data_str:
try:
import json
data = json.loads(data_str)
if data.get("used", False):
used_resets += 1
else:
active_resets += 1
except:
except Exception:
pass
# 统计频率限制
for key in rate_limit_keys:
ttl = await redis.ttl(key)
if ttl > 0:
active_rate_limits += 1
stats = {
"total_reset_codes": len(reset_keys),
"active_resets": active_resets,
"used_resets": used_resets,
"active_rate_limits": active_rate_limits,
"total_rate_limit_keys": len(rate_limit_keys)
"total_rate_limit_keys": len(rate_limit_keys),
}
return JSONResponse(
status_code=200,
content={
"success": True,
"data": stats
}
)
return JSONResponse(status_code=200, content={"success": True, "data": stats})
except Exception as e:
logger.error(f"[Admin] Failed to get reset statistics: {e}")
return JSONResponse(
status_code=500,
content={
"success": False,
"error": "获取统计信息失败"
}
)
return JSONResponse(status_code=500, content={"success": False, "error": "获取统计信息失败"})

View File

@@ -26,7 +26,7 @@ async def create_oauth_app(
redirect_uris: list[str] = Body(..., description="允许的重定向 URI 列表"),
current_user: User = Security(get_client_user),
):
result = await session.execute( # pyright: ignore[reportDeprecated]
result = await session.execute(
text(
"SELECT AUTO_INCREMENT FROM information_schema.TABLES "
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'oauth_clients'"
@@ -84,9 +84,7 @@ async def get_user_oauth_apps(
session: Database,
current_user: User = Security(get_client_user),
):
oauth_apps = await session.exec(
select(OAuthClient).where(OAuthClient.owner_id == current_user.id)
)
oauth_apps = await session.exec(select(OAuthClient).where(OAuthClient.owner_id == current_user.id))
return [
{
"name": app.name,
@@ -113,13 +111,9 @@ async def delete_oauth_app(
if not oauth_client:
raise HTTPException(status_code=404, detail="OAuth app not found")
if oauth_client.owner_id != current_user.id:
raise HTTPException(
status_code=403, detail="Forbidden: Not the owner of this app"
)
raise HTTPException(status_code=403, detail="Forbidden: Not the owner of this app")
tokens = await session.exec(
select(OAuthToken).where(OAuthToken.client_id == client_id)
)
tokens = await session.exec(select(OAuthToken).where(OAuthToken.client_id == client_id))
for token in tokens:
await session.delete(token)
@@ -144,9 +138,7 @@ async def update_oauth_app(
if not oauth_client:
raise HTTPException(status_code=404, detail="OAuth app not found")
if oauth_client.owner_id != current_user.id:
raise HTTPException(
status_code=403, detail="Forbidden: Not the owner of this app"
)
raise HTTPException(status_code=403, detail="Forbidden: Not the owner of this app")
oauth_client.name = name
oauth_client.description = description
@@ -176,14 +168,10 @@ async def refresh_secret(
if not oauth_client:
raise HTTPException(status_code=404, detail="OAuth app not found")
if oauth_client.owner_id != current_user.id:
raise HTTPException(
status_code=403, detail="Forbidden: Not the owner of this app"
)
raise HTTPException(status_code=403, detail="Forbidden: Not the owner of this app")
oauth_client.client_secret = secrets.token_hex()
tokens = await session.exec(
select(OAuthToken).where(OAuthToken.client_id == client_id)
)
tokens = await session.exec(select(OAuthToken).where(OAuthToken.client_id == client_id))
for token in tokens:
await session.delete(token)
@@ -215,9 +203,7 @@ async def generate_oauth_code(
raise HTTPException(status_code=404, detail="OAuth app not found")
if redirect_uri not in client.redirect_uris:
raise HTTPException(
status_code=403, detail="Redirect URI not allowed for this client"
)
raise HTTPException(status_code=403, detail="Redirect URI not allowed for this client")
code = secrets.token_urlsafe(80)
await redis.hset( # pyright: ignore[reportGeneralTypeIssues]

View File

@@ -50,12 +50,8 @@ async def check_user_relationship(
)
).first()
is_followed = bool(
target_relationship and target_relationship.type == RelationshipType.FOLLOW
)
is_following = bool(
my_relationship and my_relationship.type == RelationshipType.FOLLOW
)
is_followed = bool(target_relationship and target_relationship.type == RelationshipType.FOLLOW)
is_following = bool(my_relationship and my_relationship.type == RelationshipType.FOLLOW)
return CheckResponse(
is_followed=is_followed,

View File

@@ -40,16 +40,13 @@ async def create_team(
支持的图片格式: PNG、JPEG、GIF
"""
user_id = current_user.id
assert user_id
if (await current_user.awaitable_attrs.team_membership) is not None:
raise HTTPException(status_code=403, detail="You are already in a team")
is_existed = (await session.exec(select(exists()).where(Team.name == name))).first()
if is_existed:
raise HTTPException(status_code=409, detail="Name already exists")
is_existed = (
await session.exec(select(exists()).where(Team.short_name == short_name))
).first()
is_existed = (await session.exec(select(exists()).where(Team.short_name == short_name))).first()
if is_existed:
raise HTTPException(status_code=409, detail="Short name already exists")
@@ -101,7 +98,6 @@ async def update_team(
"""
team = await session.get(Team, team_id)
user_id = current_user.id
assert user_id
if not team:
raise HTTPException(status_code=404, detail="Team not found")
if team.leader_id != user_id:
@@ -110,9 +106,7 @@ async def update_team(
is_existed = (await session.exec(select(exists()).where(Team.name == name))).first()
if is_existed:
raise HTTPException(status_code=409, detail="Name already exists")
is_existed = (
await session.exec(select(exists()).where(Team.short_name == short_name))
).first()
is_existed = (await session.exec(select(exists()).where(Team.short_name == short_name))).first()
if is_existed:
raise HTTPException(status_code=409, detail="Short name already exists")
@@ -132,20 +126,12 @@ async def update_team(
team.cover_url = await storage.get_file_url(storage_path)
if leader_id is not None:
if not (
await session.exec(select(exists()).where(User.id == leader_id))
).first():
if not (await session.exec(select(exists()).where(User.id == leader_id))).first():
raise HTTPException(status_code=404, detail="Leader not found")
if not (
await session.exec(
select(TeamMember).where(
TeamMember.user_id == leader_id, TeamMember.team_id == team.id
)
)
await session.exec(select(TeamMember).where(TeamMember.user_id == leader_id, TeamMember.team_id == team.id))
).first():
raise HTTPException(
status_code=404, detail="Leader is not a member of the team"
)
raise HTTPException(status_code=404, detail="Leader is not a member of the team")
team.leader_id = leader_id
await session.commit()
@@ -166,9 +152,7 @@ async def delete_team(
if team.leader_id != current_user.id:
raise HTTPException(status_code=403, detail="You are not the team leader")
team_members = await session.exec(
select(TeamMember).where(TeamMember.team_id == team_id)
)
team_members = await session.exec(select(TeamMember).where(TeamMember.team_id == team_id))
for member in team_members:
await session.delete(member)
@@ -186,15 +170,10 @@ async def get_team(
session: Database,
team_id: int = Path(..., description="战队 ID"),
):
members = (
await session.exec(select(TeamMember).where(TeamMember.team_id == team_id))
).all()
members = (await session.exec(select(TeamMember).where(TeamMember.team_id == team_id))).all()
return TeamQueryResp(
team=members[0].team,
members=[
await UserResp.from_db(m.user, session, include=BASE_INCLUDES)
for m in members
],
members=[await UserResp.from_db(m.user, session, include=BASE_INCLUDES) for m in members],
)
@@ -213,15 +192,11 @@ async def request_join_team(
if (
await session.exec(
select(exists()).where(
TeamRequest.team_id == team_id, TeamRequest.user_id == current_user.id
)
select(exists()).where(TeamRequest.team_id == team_id, TeamRequest.user_id == current_user.id)
)
).first():
raise HTTPException(status_code=409, detail="Join request already exists")
team_request = TeamRequest(
user_id=current_user.id, team_id=team_id, requested_at=datetime.now(UTC)
)
team_request = TeamRequest(user_id=current_user.id, team_id=team_id, requested_at=datetime.now(UTC))
session.add(team_request)
await session.commit()
await session.refresh(team_request)
@@ -229,9 +204,7 @@ async def request_join_team(
@router.post("/team/{team_id}/{user_id}/request", name="接受加入请求", status_code=204)
@router.delete(
"/team/{team_id}/{user_id}/request", name="拒绝加入请求", status_code=204
)
@router.delete("/team/{team_id}/{user_id}/request", name="拒绝加入请求", status_code=204)
async def handle_request(
req: Request,
session: Database,
@@ -247,11 +220,7 @@ async def handle_request(
raise HTTPException(status_code=403, detail="You are not the team leader")
team_request = (
await session.exec(
select(TeamRequest).where(
TeamRequest.team_id == team_id, TeamRequest.user_id == user_id
)
)
await session.exec(select(TeamRequest).where(TeamRequest.team_id == team_id, TeamRequest.user_id == user_id))
).first()
if not team_request:
raise HTTPException(status_code=404, detail="Join request not found")
@@ -261,16 +230,10 @@ async def handle_request(
raise HTTPException(status_code=404, detail="User not found")
if req.method == "POST":
if (
await session.exec(select(exists()).where(TeamMember.user_id == user_id))
).first():
raise HTTPException(
status_code=409, detail="User is already a member of the team"
)
if (await session.exec(select(exists()).where(TeamMember.user_id == user_id))).first():
raise HTTPException(status_code=409, detail="User is already a member of the team")
session.add(
TeamMember(user_id=user_id, team_id=team_id, joined_at=datetime.now(UTC))
)
session.add(TeamMember(user_id=user_id, team_id=team_id, joined_at=datetime.now(UTC)))
await server.new_private_notification(TeamApplicationAccept.init(team_request))
else:
@@ -294,19 +257,13 @@ async def kick_member(
raise HTTPException(status_code=403, detail="You are not the team leader")
team_member = (
await session.exec(
select(TeamMember).where(
TeamMember.team_id == team_id, TeamMember.user_id == user_id
)
)
await session.exec(select(TeamMember).where(TeamMember.team_id == team_id, TeamMember.user_id == user_id))
).first()
if not team_member:
raise HTTPException(status_code=404, detail="User is not a member of the team")
if team.leader_id == current_user.id:
raise HTTPException(
status_code=403, detail="You cannot leave because you are the team leader"
)
raise HTTPException(status_code=403, detail="You cannot leave because you are the team leader")
await session.delete(team_member)
await session.commit()

View File

@@ -35,10 +35,7 @@ async def user_rename(
返回:
- 成功: None
"""
assert current_user is not None
samename_user = (
await session.exec(select(User).where(User.username == new_name))
).first()
samename_user = (await session.exec(select(User).where(User.username == new_name))).first()
if samename_user:
raise HTTPException(409, "Username Exisits")
errors = validate_username(new_name)

View File

@@ -106,9 +106,7 @@ class V1Beatmap(AllStrModel):
await session.exec(
select(func.count())
.select_from(FavouriteBeatmapset)
.where(
FavouriteBeatmapset.beatmapset_id == db_beatmap.beatmapset.id
)
.where(FavouriteBeatmapset.beatmapset_id == db_beatmap.beatmapset.id)
)
).one(),
rating=0, # TODO
@@ -154,12 +152,8 @@ async def get_beatmaps(
beatmapset_id: int | None = Query(None, alias="s", description="谱面集 ID"),
beatmap_id: int | None = Query(None, alias="b", description="谱面 ID"),
user: str | None = Query(None, alias="u", description="谱师"),
type: Literal["string", "id"] | None = Query(
None, description="用户类型string 用户名称 / id 用户 ID"
),
ruleset_id: int | None = Query(
None, alias="m", description="Ruleset ID", ge=0, le=3
), # TODO
type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"),
ruleset_id: int | None = Query(None, alias="m", description="Ruleset ID", ge=0, le=3), # TODO
convert: bool = Query(False, alias="a", description="转谱"), # TODO
checksum: str | None = Query(None, alias="h", description="谱面文件 MD5"),
limit: int = Query(500, ge=1, le=500, description="返回结果数量限制"),
@@ -181,11 +175,7 @@ async def get_beatmaps(
else:
beatmaps = beatmapset.beatmaps
elif user is not None:
where = (
Beatmapset.user_id == user
if type == "id" or user.isdigit()
else Beatmapset.creator == user
)
where = Beatmapset.user_id == user if type == "id" or user.isdigit() else Beatmapset.creator == user
beatmapsets = (await session.exec(select(Beatmapset).where(where))).all()
for beatmapset in beatmapsets:
if len(beatmaps) >= limit:
@@ -193,11 +183,7 @@ async def get_beatmaps(
beatmaps.extend(beatmapset.beatmaps)
elif since is not None:
beatmapsets = (
await session.exec(
select(Beatmapset)
.where(col(Beatmapset.ranked_date) > since)
.limit(limit)
)
await session.exec(select(Beatmapset).where(col(Beatmapset.ranked_date) > since).limit(limit))
).all()
for beatmapset in beatmapsets:
if len(beatmaps) >= limit:
@@ -214,11 +200,7 @@ async def get_beatmaps(
redis,
fetcher,
)
results.append(
await V1Beatmap.from_db(
session, beatmap, attrs.aim_difficulty, attrs.speed_difficulty
)
)
results.append(await V1Beatmap.from_db(session, beatmap, attrs.aim_difficulty, attrs.speed_difficulty))
continue
except Exception:
...

View File

@@ -41,9 +41,7 @@ async def download_replay(
ge=0,
),
score_id: int | None = Query(None, alias="s", description="成绩 ID"),
type: Literal["string", "id"] | None = Query(
None, description="用户类型string 用户名称 / id 用户 ID"
),
type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"),
mods: int = Query(0, description="成绩的 MOD"),
storage_service: StorageService = Depends(get_storage_service),
):
@@ -58,13 +56,9 @@ async def download_replay(
await session.exec(
select(Score).where(
Score.beatmap_id == beatmap,
Score.user_id == user
if type == "id" or user.isdigit()
else col(Score.user).has(username=user),
Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user),
Score.mods == mods_,
Score.gamemode == GameMode.from_int_extra(ruleset_id)
if ruleset_id is not None
else True,
Score.gamemode == GameMode.from_int_extra(ruleset_id) if ruleset_id is not None else True,
)
)
).first()
@@ -73,10 +67,7 @@ async def download_replay(
except KeyError:
raise HTTPException(status_code=400, detail="Invalid request")
filepath = (
f"replays/{score_record.id}_{score_record.beatmap_id}"
f"_{score_record.user_id}_lazer_replay.osr"
)
filepath = f"replays/{score_record.id}_{score_record.beatmap_id}_{score_record.user_id}_lazer_replay.osr"
if not await storage_service.is_exists(filepath):
raise HTTPException(status_code=404, detail="Replay file not found")
@@ -100,6 +91,4 @@ async def download_replay(
await session.commit()
data = await storage_service.read_file(filepath)
return ReplayModel(
content=base64.b64encode(data).decode("utf-8"), encoding="base64"
)
return ReplayModel(content=base64.b64encode(data).decode("utf-8"), encoding="base64")

View File

@@ -8,9 +8,7 @@ from app.dependencies.user import v1_authorize
from fastapi import APIRouter, Depends
from pydantic import BaseModel, field_serializer
router = APIRouter(
prefix="/api/v1", dependencies=[Depends(v1_authorize)], tags=["V1 API"]
)
router = APIRouter(prefix="/api/v1", dependencies=[Depends(v1_authorize)], tags=["V1 API"])
class AllStrModel(BaseModel):

View File

@@ -70,9 +70,7 @@ async def get_user_best(
session: Database,
user: str = Query(..., alias="u", description="用户"),
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
type: Literal["string", "id"] | None = Query(
None, description="用户类型string 用户名称 / id 用户 ID"
),
type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"),
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
):
try:
@@ -80,9 +78,7 @@ async def get_user_best(
await session.exec(
select(Score)
.where(
Score.user_id == user
if type == "id" or user.isdigit()
else col(Score.user).has(username=user),
Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user),
Score.gamemode == GameMode.from_int_extra(ruleset_id),
exists().where(col(PPBestScore.score_id) == Score.id),
)
@@ -106,9 +102,7 @@ async def get_user_recent(
session: Database,
user: str = Query(..., alias="u", description="用户"),
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
type: Literal["string", "id"] | None = Query(
None, description="用户类型string 用户名称 / id 用户 ID"
),
type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"),
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
):
try:
@@ -116,9 +110,7 @@ async def get_user_recent(
await session.exec(
select(Score)
.where(
Score.user_id == user
if type == "id" or user.isdigit()
else col(Score.user).has(username=user),
Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user),
Score.gamemode == GameMode.from_int_extra(ruleset_id),
Score.ended_at > datetime.now(UTC) - timedelta(hours=24),
)
@@ -143,9 +135,7 @@ async def get_scores(
user: str | None = Query(None, alias="u", description="用户"),
beatmap_id: int = Query(alias="b", description="谱面 ID"),
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
type: Literal["string", "id"] | None = Query(
None, description="用户类型string 用户名称 / id 用户 ID"
),
type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"),
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
mods: int = Query(0, description="成绩的 MOD"),
):
@@ -157,9 +147,7 @@ async def get_scores(
.where(
Score.gamemode == GameMode.from_int_extra(ruleset_id),
Score.beatmap_id == beatmap_id,
Score.user_id == user
if type == "id" or user.isdigit()
else col(Score.user).has(username=user),
Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user),
)
.options(joinedload(Score.beatmap))
.order_by(col(Score.classic_total_score).desc())

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import asyncio
from datetime import datetime
from typing import Literal
@@ -13,7 +12,7 @@ from app.service.user_cache_service import get_user_cache_service
from .router import AllStrModel, router
from fastapi import HTTPException, Query
from fastapi import BackgroundTasks, HTTPException, Query
from sqlmodel import select
@@ -49,9 +48,7 @@ class V1User(AllStrModel):
return f"v1_user:{user_id}"
@classmethod
async def from_db(
cls, session: Database, db_user: User, ruleset: GameMode | None = None
) -> "V1User":
async def from_db(cls, session: Database, db_user: User, ruleset: GameMode | None = None) -> "V1User":
# 确保 user_id 不为 None
if db_user.id is None:
raise ValueError("User ID cannot be None")
@@ -63,9 +60,7 @@ class V1User(AllStrModel):
current_statistics = i
break
if current_statistics:
statistics = await UserStatisticsResp.from_db(
current_statistics, session, db_user.country_code
)
statistics = await UserStatisticsResp.from_db(current_statistics, session, db_user.country_code)
else:
statistics = None
return cls(
@@ -78,9 +73,7 @@ class V1User(AllStrModel):
playcount=statistics.play_count if statistics else 0,
ranked_score=statistics.ranked_score if statistics else 0,
total_score=statistics.total_score if statistics else 0,
pp_rank=statistics.global_rank
if statistics and statistics.global_rank
else 0,
pp_rank=statistics.global_rank if statistics and statistics.global_rank else 0,
level=current_statistics.level_current if current_statistics else 0,
pp_raw=statistics.pp if statistics else 0.0,
accuracy=statistics.hit_accuracy if statistics else 0,
@@ -91,9 +84,7 @@ class V1User(AllStrModel):
count_rank_a=current_statistics.grade_a if current_statistics else 0,
country=db_user.country_code,
total_seconds_played=statistics.play_time if statistics else 0,
pp_country_rank=statistics.country_rank
if statistics and statistics.country_rank
else 0,
pp_country_rank=statistics.country_rank if statistics and statistics.country_rank else 0,
events=[], # TODO
)
@@ -106,14 +97,11 @@ class V1User(AllStrModel):
)
async def get_user(
session: Database,
background_tasks: BackgroundTasks,
user: str = Query(..., alias="u", description="用户"),
ruleset_id: int | None = Query(None, alias="m", description="Ruleset ID", ge=0),
type: Literal["string", "id"] | None = Query(
None, description="用户类型string 用户名称 / id 用户 ID"
),
event_days: int = Query(
default=1, ge=1, le=31, description="从现在起所有事件的最大天数"
),
type: Literal["string", "id"] | None = Query(None, description="用户类型string 用户名称 / id 用户 ID"),
event_days: int = Query(default=1, ge=1, le=31, description="从现在起所有事件的最大天数"),
):
redis = get_redis()
cache_service = get_user_cache_service(redis)
@@ -131,9 +119,7 @@ async def get_user(
if is_id_query:
try:
user_id_for_cache = int(user)
cached_v1_user = await cache_service.get_v1_user_from_cache(
user_id_for_cache, ruleset
)
cached_v1_user = await cache_service.get_v1_user_from_cache(user_id_for_cache, ruleset)
if cached_v1_user:
return [V1User(**cached_v1_user)]
except (ValueError, TypeError):
@@ -158,9 +144,7 @@ async def get_user(
# 异步缓存结果如果有用户ID
if db_user.id is not None:
user_data = v1_user.model_dump()
asyncio.create_task(
cache_service.cache_v1_user(user_data, db_user.id, ruleset)
)
background_tasks.add_task(cache_service.cache_v1_user, user_data, db_user.id, ruleset)
return [v1_user]

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401
from . import ( # noqa: F401
beatmap,
beatmapset,
me,

View File

@@ -40,18 +40,13 @@ class BatchGetResp(BaseModel):
tags=["谱面"],
name="查询单个谱面",
response_model=BeatmapResp,
description=(
"根据谱面 ID / MD5 / 文件名 查询单个谱面。"
"至少提供 id / checksum / filename 之一。"
),
description=("根据谱面 ID / MD5 / 文件名 查询单个谱面。至少提供 id / checksum / filename 之一。"),
)
async def lookup_beatmap(
db: Database,
id: int | None = Query(default=None, alias="id", description="谱面 ID"),
md5: str | None = Query(default=None, alias="checksum", description="谱面文件 MD5"),
filename: str | None = Query(
default=None, alias="filename", description="谱面文件名"
),
filename: str | None = Query(default=None, alias="filename", description="谱面文件名"),
current_user: User = Security(get_current_user, scopes=["public"]),
fetcher: Fetcher = Depends(get_fetcher),
):
@@ -96,43 +91,23 @@ async def get_beatmap(
tags=["谱面"],
name="批量获取谱面",
response_model=BatchGetResp,
description=(
"批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。"
"为空时按最近更新时间返回。"
),
description=("批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。为空时按最近更新时间返回。"),
)
async def batch_get_beatmaps(
db: Database,
beatmap_ids: list[int] = Query(
alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"
),
beatmap_ids: list[int] = Query(alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"),
current_user: User = Security(get_current_user, scopes=["public"]),
fetcher: Fetcher = Depends(get_fetcher),
):
if not beatmap_ids:
beatmaps = (
await db.exec(
select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
)
).all()
beatmaps = (await db.exec(select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50))).all()
else:
beatmaps = list(
(
await db.exec(
select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)).limit(50)
)
).all()
)
not_found_beatmaps = [
bid for bid in beatmap_ids if bid not in [bm.id for bm in beatmaps]
]
beatmaps = list((await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)).limit(50))).all())
not_found_beatmaps = [bid for bid in beatmap_ids if bid not in [bm.id for bm in beatmaps]]
beatmaps.extend(
beatmap
for beatmap in await asyncio.gather(
*[
Beatmap.get_or_fetch(db, fetcher, bid=bid)
for bid in not_found_beatmaps
],
*[Beatmap.get_or_fetch(db, fetcher, bid=bid) for bid in not_found_beatmaps],
return_exceptions=True,
)
if isinstance(beatmap, Beatmap)
@@ -140,12 +115,7 @@ async def batch_get_beatmaps(
for beatmap in beatmaps:
await db.refresh(beatmap)
return BatchGetResp(
beatmaps=[
await BeatmapResp.from_db(bm, session=db, user=current_user)
for bm in beatmaps
]
)
return BatchGetResp(beatmaps=[await BeatmapResp.from_db(bm, session=db, user=current_user) for bm in beatmaps])
@router.post(
@@ -163,12 +133,8 @@ async def get_beatmap_attributes(
default_factory=list,
description="Mods 列表;可为整型位掩码(单元素)或 JSON/简称",
),
ruleset: GameMode | None = Query(
default=None, description="指定 ruleset;为空则使用谱面自身模式"
),
ruleset_id: int | None = Query(
default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3
),
ruleset: GameMode | None = Query(default=None, description="指定 ruleset为空则使用谱面自身模式"),
ruleset_id: int | None = Query(default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3),
redis: Redis = Depends(get_redis),
fetcher: Fetcher = Depends(get_fetcher),
):
@@ -187,16 +153,11 @@ async def get_beatmap_attributes(
if ruleset is None:
beatmap_db = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
ruleset = beatmap_db.mode
key = (
f"beatmap:{beatmap_id}:{ruleset}:"
f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
)
key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
if await redis.exists(key):
return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
try:
return await calculate_beatmap_attributes(
beatmap_id, ruleset, mods_, redis, fetcher
)
return await calculate_beatmap_attributes(beatmap_id, ruleset, mods_, redis, fetcher)
except HTTPStatusError:
raise HTTPException(status_code=404, detail="Beatmap not found")
except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue]

View File

@@ -35,9 +35,7 @@ from sqlmodel import exists, select
async def _save_to_db(sets: SearchBeatmapsetsResp):
async with with_db() as session:
for s in sets.beatmapsets:
if not (
await session.exec(select(exists()).where(Beatmapset.id == s.id))
).first():
if not (await session.exec(select(exists()).where(Beatmapset.id == s.id))).first():
await Beatmapset.from_resp(session, s)
@@ -117,9 +115,7 @@ async def lookup_beatmapset(
fetcher: Fetcher = Depends(get_fetcher),
):
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
resp = await BeatmapsetResp.from_db(
beatmap.beatmapset, session=db, user=current_user
)
resp = await BeatmapsetResp.from_db(beatmap.beatmapset, session=db, user=current_user)
return resp
@@ -138,9 +134,7 @@ async def get_beatmapset(
):
try:
beatmapset = await Beatmapset.get_or_fetch(db, fetcher, beatmapset_id)
return await BeatmapsetResp.from_db(
beatmapset, session=db, include=["recent_favourites"], user=current_user
)
return await BeatmapsetResp.from_db(beatmapset, session=db, include=["recent_favourites"], user=current_user)
except HTTPError:
raise HTTPException(status_code=404, detail="Beatmapset not found")
@@ -165,9 +159,7 @@ async def download_beatmapset(
country_code = geo_info.get("country_iso", "")
# 优先使用IP地理位置判断如果获取失败则回退到用户账户的国家代码
is_china = country_code == "CN" or (
not country_code and current_user.country_code == "CN"
)
is_china = country_code == "CN" or (not country_code and current_user.country_code == "CN")
try:
# 使用负载均衡服务获取下载URL
@@ -179,13 +171,10 @@ async def download_beatmapset(
# 如果负载均衡服务失败,回退到原有逻辑
if is_china:
return RedirectResponse(
f"https://dl.sayobot.cn/beatmaps/download/"
f"{'novideo' if no_video else 'full'}/{beatmapset_id}"
f"https://dl.sayobot.cn/beatmaps/download/{'novideo' if no_video else 'full'}/{beatmapset_id}"
)
else:
return RedirectResponse(
f"https://api.nerinyan.moe/d/{beatmapset_id}?noVideo={no_video}"
)
return RedirectResponse(f"https://api.nerinyan.moe/d/{beatmapset_id}?noVideo={no_video}")
@router.post(
@@ -197,12 +186,9 @@ async def download_beatmapset(
async def favourite_beatmapset(
db: Database,
beatmapset_id: int = Path(..., description="谱面集 ID"),
action: Literal["favourite", "unfavourite"] = Form(
description="操作类型favourite 收藏 / unfavourite 取消收藏"
),
action: Literal["favourite", "unfavourite"] = Form(description="操作类型favourite 收藏 / unfavourite 取消收藏"),
current_user: User = Security(get_client_user),
):
assert current_user.id is not None
existing_favourite = (
await db.exec(
select(FavouriteBeatmapset).where(
@@ -212,15 +198,11 @@ async def favourite_beatmapset(
)
).first()
if (action == "favourite" and existing_favourite) or (
action == "unfavourite" and not existing_favourite
):
if (action == "favourite" and existing_favourite) or (action == "unfavourite" and not existing_favourite):
return
if action == "favourite":
favourite = FavouriteBeatmapset(
user_id=current_user.id, beatmapset_id=beatmapset_id
)
favourite = FavouriteBeatmapset(user_id=current_user.id, beatmapset_id=beatmapset_id)
db.add(favourite)
else:
await db.delete(existing_favourite)

View File

@@ -4,8 +4,8 @@ from app.database import User
from app.database.lazer_user import ALL_INCLUDED
from app.dependencies import get_current_user
from app.dependencies.database import Database
from app.models.score import GameMode
from app.models.api_me import APIMe
from app.models.score import GameMode
from .router import router

View File

@@ -33,6 +33,4 @@ class BackgroundsResp(BaseModel):
description="获取当前季节背景图列表。",
)
async def get_seasonal_backgrounds():
return BackgroundsResp(
backgrounds=[Background(url=url) for url in settings.seasonal_backgrounds]
)
return BackgroundsResp(backgrounds=[Background(url=url) for url in settings.seasonal_backgrounds])

View File

@@ -12,7 +12,7 @@ from app.service.ranking_cache_service import get_ranking_cache_service
from .router import router
from fastapi import Path, Query, Security
from fastapi import BackgroundTasks, Path, Query, Security
from pydantic import BaseModel
from sqlmodel import col, select
@@ -38,6 +38,7 @@ class CountryResponse(BaseModel):
)
async def get_country_ranking(
session: Database,
background_tasks: BackgroundTasks,
ruleset: GameMode = Path(..., description="指定 ruleset"),
page: int = Query(1, ge=1, description="页码"), # TODO
current_user: User = Security(get_current_user, scopes=["public"]),
@@ -51,9 +52,7 @@ async def get_country_ranking(
if cached_data:
# 从缓存返回数据
return CountryResponse(
ranking=[CountryStatistics.model_validate(item) for item in cached_data]
)
return CountryResponse(ranking=[CountryStatistics.model_validate(item) for item in cached_data])
# 缓存未命中,从数据库查询
response = CountryResponse(ranking=[])
@@ -105,14 +104,15 @@ async def get_country_ranking(
# 异步缓存数据(不等待完成)
cache_data = [item.model_dump() for item in current_page_data]
cache_task = cache_service.cache_country_ranking(
ruleset, cache_data, page, ttl=settings.ranking_cache_expire_minutes * 60
)
# 创建后台任务来缓存数据
import asyncio
asyncio.create_task(cache_task)
background_tasks.add_task(
cache_service.cache_country_ranking,
ruleset,
cache_data,
page,
ttl=settings.ranking_cache_expire_minutes * 60,
)
# 返回当前页的结果
response.ranking = current_page_data
@@ -132,10 +132,9 @@ class TopUsersResponse(BaseModel):
)
async def get_user_ranking(
session: Database,
background_tasks: BackgroundTasks,
ruleset: GameMode = Path(..., description="指定 ruleset"),
type: Literal["performance", "score"] = Path(
..., description="排名类型performance 表现分 / score 计分成绩总分"
),
type: Literal["performance", "score"] = Path(..., description="排名类型performance 表现分 / score 计分成绩总分"),
country: str | None = Query(None, description="国家代码"),
page: int = Query(1, ge=1, description="页码"),
current_user: User = Security(get_current_user, scopes=["public"]),
@@ -149,9 +148,7 @@ async def get_user_ranking(
if cached_data:
# 从缓存返回数据
return TopUsersResponse(
ranking=[UserStatisticsResp.model_validate(item) for item in cached_data]
)
return TopUsersResponse(ranking=[UserStatisticsResp.model_validate(item) for item in cached_data])
# 缓存未命中,从数据库查询
wheres = [
@@ -169,25 +166,22 @@ async def get_user_ranking(
wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
statistics_list = await session.exec(
select(UserStatistics)
.where(*wheres)
.order_by(order_by)
.limit(50)
.offset(50 * (page - 1))
select(UserStatistics).where(*wheres).order_by(order_by).limit(50).offset(50 * (page - 1))
)
# 转换为响应格式
ranking_data = []
for statistics in statistics_list:
user_stats_resp = await UserStatisticsResp.from_db(
statistics, session, None, include
)
user_stats_resp = await UserStatisticsResp.from_db(statistics, session, None, include)
ranking_data.append(user_stats_resp)
# 异步缓存数据(不等待完成)
# 使用配置文件中的TTL设置
cache_data = [item.model_dump() for item in ranking_data]
cache_task = cache_service.cache_ranking(
# 创建后台任务来缓存数据
background_tasks.add_task(
cache_service.cache_ranking,
ruleset,
type,
cache_data,
@@ -196,139 +190,134 @@ async def get_user_ranking(
ttl=settings.ranking_cache_expire_minutes * 60,
)
# 创建后台任务来缓存数据
import asyncio
asyncio.create_task(cache_task)
resp = TopUsersResponse(ranking=ranking_data)
return resp
""" @router.post(
"/rankings/cache/refresh",
name="刷新排行榜缓存",
description="手动刷新排行榜缓存(管理员功能)",
tags=["排行榜", "管理"],
)
async def refresh_ranking_cache(
session: Database,
ruleset: GameMode | None = Query(None, description="指定要刷新的游戏模式,不指定则刷新所有"),
type: Literal["performance", "score"] | None = Query(None, description="指定要刷新的排名类型,不指定则刷新所有"),
country: str | None = Query(None, description="指定要刷新的国家,不指定则刷新所有"),
include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
):
redis = get_redis()
cache_service = get_ranking_cache_service(redis)
if ruleset and type:
# 刷新特定的用户排行榜
await cache_service.refresh_ranking_cache(session, ruleset, type, country)
message = f"Refreshed ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
# 如果请求刷新地区排行榜
if include_country_ranking and not country: # 地区排行榜不依赖于国家参数
await cache_service.refresh_country_ranking_cache(session, ruleset)
message += f" and country ranking for {ruleset}"
return {"message": message}
elif ruleset:
# 刷新特定游戏模式的所有排行榜
ranking_types: list[Literal["performance", "score"]] = ["performance", "score"]
for ranking_type in ranking_types:
await cache_service.refresh_ranking_cache(session, ruleset, ranking_type, country)
if include_country_ranking:
await cache_service.refresh_country_ranking_cache(session, ruleset)
return {"message": f"Refreshed all ranking caches for {ruleset}"}
else:
# 刷新所有排行榜
await cache_service.refresh_all_rankings(session)
return {"message": "Refreshed all ranking caches"}
# @router.post(
# "/rankings/cache/refresh",
# name="刷新排行榜缓存",
# description="手动刷新排行榜缓存(管理员功能)",
# tags=["排行榜", "管理"],
# )
# async def refresh_ranking_cache(
# session: Database,
# ruleset: GameMode | None = Query(None, description="指定要刷新的游戏模式,不指定则刷新所有"),
# type: Literal["performance", "score"] | None = Query(None, description="指定要刷新的排名类型,不指定则刷新所有"),
# country: str | None = Query(None, description="指定要刷新的国家,不指定则刷新所有"),
# include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
# ):
# redis = get_redis()
# cache_service = get_ranking_cache_service(redis)
# if ruleset and type:
# # 刷新特定的用户排行榜
# await cache_service.refresh_ranking_cache(session, ruleset, type, country)
# message = f"Refreshed ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
# # 如果请求刷新地区排行榜
# if include_country_ranking and not country: # 地区排行榜不依赖于国家参数
# await cache_service.refresh_country_ranking_cache(session, ruleset)
# message += f" and country ranking for {ruleset}"
# return {"message": message}
# elif ruleset:
# # 刷新特定游戏模式的所有排行榜
# ranking_types: list[Literal["performance", "score"]] = ["performance", "score"]
# for ranking_type in ranking_types:
# await cache_service.refresh_ranking_cache(session, ruleset, ranking_type, country)
# if include_country_ranking:
# await cache_service.refresh_country_ranking_cache(session, ruleset)
# return {"message": f"Refreshed all ranking caches for {ruleset}"}
# else:
# # 刷新所有排行榜
# await cache_service.refresh_all_rankings(session)
# return {"message": "Refreshed all ranking caches"}
@router.post(
"/rankings/{ruleset}/country/cache/refresh",
name="刷新地区排行榜缓存",
description="手动刷新地区排行榜缓存(管理员功能)",
tags=["排行榜", "管理"],
)
async def refresh_country_ranking_cache(
session: Database,
ruleset: GameMode = Path(..., description="指定要刷新的游戏模式"),
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
):
redis = get_redis()
cache_service = get_ranking_cache_service(redis)
await cache_service.refresh_country_ranking_cache(session, ruleset)
return {"message": f"Refreshed country ranking cache for {ruleset}"}
# @router.post(
# "/rankings/{ruleset}/country/cache/refresh",
# name="刷新地区排行榜缓存",
# description="手动刷新地区排行榜缓存(管理员功能)",
# tags=["排行榜", "管理"],
# )
# async def refresh_country_ranking_cache(
# session: Database,
# ruleset: GameMode = Path(..., description="指定要刷新的游戏模式"),
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
# ):
# redis = get_redis()
# cache_service = get_ranking_cache_service(redis)
# await cache_service.refresh_country_ranking_cache(session, ruleset)
# return {"message": f"Refreshed country ranking cache for {ruleset}"}
@router.delete(
"/rankings/cache",
name="清除排行榜缓存",
description="清除排行榜缓存(管理员功能)",
tags=["排行榜", "管理"],
)
async def clear_ranking_cache(
ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
type: Literal["performance", "score"] | None = Query(None, description="指定要清除的排名类型,不指定则清除所有"),
country: str | None = Query(None, description="指定要清除的国家,不指定则清除所有"),
include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
):
redis = get_redis()
cache_service = get_ranking_cache_service(redis)
await cache_service.invalidate_cache(ruleset, type, country, include_country_ranking)
if ruleset and type:
message = f"Cleared ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
if include_country_ranking:
message += " and country ranking"
return {"message": message}
else:
message = "Cleared all ranking caches"
if include_country_ranking:
message += " including country rankings"
return {"message": message}
# @router.delete(
# "/rankings/cache",
# name="清除排行榜缓存",
# description="清除排行榜缓存(管理员功能)",
# tags=["排行榜", "管理"],
# )
# async def clear_ranking_cache(
# ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
# type: Literal["performance", "score"] | None = Query(None, description="指定要清除的排名类型,不指定则清除所有"),
# country: str | None = Query(None, description="指定要清除的国家,不指定则清除所有"),
# include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
# ):
# redis = get_redis()
# cache_service = get_ranking_cache_service(redis)
# await cache_service.invalidate_cache(ruleset, type, country, include_country_ranking)
# if ruleset and type:
# message = f"Cleared ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
# if include_country_ranking:
# message += " and country ranking"
# return {"message": message}
# else:
# message = "Cleared all ranking caches"
# if include_country_ranking:
# message += " including country rankings"
# return {"message": message}
@router.delete(
"/rankings/{ruleset}/country/cache",
name="清除地区排行榜缓存",
description="清除地区排行榜缓存(管理员功能)",
tags=["排行榜", "管理"],
)
async def clear_country_ranking_cache(
ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
):
redis = get_redis()
cache_service = get_ranking_cache_service(redis)
await cache_service.invalidate_country_cache(ruleset)
if ruleset:
return {"message": f"Cleared country ranking cache for {ruleset}"}
else:
return {"message": "Cleared all country ranking caches"}
# @router.delete(
# "/rankings/{ruleset}/country/cache",
# name="清除地区排行榜缓存",
# description="清除地区排行榜缓存(管理员功能)",
# tags=["排行榜", "管理"],
# )
# async def clear_country_ranking_cache(
# ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
# ):
# redis = get_redis()
# cache_service = get_ranking_cache_service(redis)
# await cache_service.invalidate_country_cache(ruleset)
# if ruleset:
# return {"message": f"Cleared country ranking cache for {ruleset}"}
# else:
# return {"message": "Cleared all country ranking caches"}
@router.get(
"/rankings/cache/stats",
name="获取排行榜缓存统计",
description="获取排行榜缓存统计信息(管理员功能)",
tags=["排行榜", "管理"],
)
async def get_ranking_cache_stats(
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
):
redis = get_redis()
cache_service = get_ranking_cache_service(redis)
stats = await cache_service.get_cache_stats()
return stats """
# @router.get(
# "/rankings/cache/stats",
# name="获取排行榜缓存统计",
# description="获取排行榜缓存统计信息(管理员功能)",
# tags=["排行榜", "管理"],
# )
# async def get_ranking_cache_stats(
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
# ):
# redis = get_redis()
# cache_service = get_ranking_cache_service(redis)
# stats = await cache_service.get_cache_stats()
# return stats

View File

@@ -30,11 +30,7 @@ async def get_relationship(
request: Request,
current_user: User = Security(get_current_user, scopes=["friends.read"]),
):
relationship_type = (
RelationshipType.FOLLOW
if request.url.path.endswith("/friends")
else RelationshipType.BLOCK
)
relationship_type = RelationshipType.FOLLOW if request.url.path.endswith("/friends") else RelationshipType.BLOCK
relationships = await db.exec(
select(Relationship).where(
Relationship.user_id == current_user.id,
@@ -71,12 +67,7 @@ async def add_relationship(
target: int = Query(description="目标用户 ID"),
current_user: User = Security(get_client_user),
):
assert current_user.id is not None
relationship_type = (
RelationshipType.FOLLOW
if request.url.path.endswith("/friends")
else RelationshipType.BLOCK
)
relationship_type = RelationshipType.FOLLOW if request.url.path.endswith("/friends") else RelationshipType.BLOCK
if target == current_user.id:
raise HTTPException(422, "Cannot add relationship to yourself")
relationship = (
@@ -120,11 +111,8 @@ async def add_relationship(
Relationship.target_id == target,
)
)
).first()
assert relationship, "Relationship should exist after commit"
return AddFriendResp(
user_relation=await RelationshipResp.from_db(db, relationship)
)
).one()
return AddFriendResp(user_relation=await RelationshipResp.from_db(db, relationship))
@router.delete(
@@ -145,11 +133,7 @@ async def delete_relationship(
target: int = Path(..., description="目标用户 ID"),
current_user: User = Security(get_client_user),
):
relationship_type = (
RelationshipType.BLOCK
if "/blocks/" in request.url.path
else RelationshipType.FOLLOW
)
relationship_type = RelationshipType.BLOCK if "/blocks/" in request.url.path else RelationshipType.FOLLOW
relationship = (
await db.exec(
select(Relationship).where(

View File

@@ -39,17 +39,11 @@ async def get_all_rooms(
db: Database,
mode: Literal["open", "ended", "participated", "owned", None] = Query(
default="open",
description=(
"房间模式open 当前开放 / ended 已经结束 / "
"participated 参与过 / owned 自己创建的房间"
),
description=("房间模式open 当前开放 / ended 已经结束 / participated 参与过 / owned 自己创建的房间"),
),
category: RoomCategory = Query(
RoomCategory.NORMAL,
description=(
"房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间"
" / DAILY_CHALLENGE 每日挑战"
),
description=("房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战"),
),
status: RoomStatus | None = Query(None, description="房间状态(可选)"),
current_user: User = Security(get_current_user, scopes=["public"]),
@@ -60,10 +54,7 @@ async def get_all_rooms(
if status is not None:
where_clauses.append(col(Room.status) == status)
if mode == "open":
where_clauses.append(
(col(Room.ends_at).is_(None))
| (col(Room.ends_at) > now.replace(tzinfo=UTC))
)
where_clauses.append((col(Room.ends_at).is_(None)) | (col(Room.ends_at) > now.replace(tzinfo=UTC)))
if category == RoomCategory.REALTIME:
where_clauses.append(col(Room.id).in_(MultiplayerHubs.rooms.keys()))
if mode == "participated":
@@ -76,10 +67,7 @@ async def get_all_rooms(
if mode == "owned":
where_clauses.append(col(Room.host_id) == current_user.id)
if mode == "ended":
where_clauses.append(
(col(Room.ends_at).is_not(None))
& (col(Room.ends_at) < now.replace(tzinfo=UTC))
)
where_clauses.append((col(Room.ends_at).is_not(None)) & (col(Room.ends_at) < now.replace(tzinfo=UTC)))
db_rooms = (
(
@@ -97,11 +85,7 @@ async def get_all_rooms(
resp = await RoomResp.from_db(room, db)
if category == RoomCategory.REALTIME:
mp_room = MultiplayerHubs.rooms.get(room.id)
resp.has_password = (
bool(mp_room.room.settings.password.strip())
if mp_room is not None
else False
)
resp.has_password = bool(mp_room.room.settings.password.strip()) if mp_room is not None else False
resp.category = RoomCategory.NORMAL
resp_list.append(resp)
@@ -115,9 +99,7 @@ class APICreatedRoom(RoomResp):
error: str = ""
async def _participate_room(
room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis
):
async def _participate_room(room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis):
participated_user = (
await session.exec(
select(RoomParticipatedUser).where(
@@ -154,7 +136,6 @@ async def create_room(
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
):
assert current_user.id is not None
user_id = current_user.id
db_room = await create_playlist_room_from_api(db, room, user_id)
await _participate_room(db_room.id, user_id, db_room, db, redis)
@@ -177,10 +158,7 @@ async def get_room(
room_id: int = Path(..., description="房间 ID"),
category: str = Query(
default="",
description=(
"房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间"
" / DAILY_CHALLENGE 每日挑战 (可选)"
),
description=("房间分类NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战 (可选)"),
),
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
@@ -188,9 +166,7 @@ async def get_room(
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is None:
raise HTTPException(404, "Room not found")
resp = await RoomResp.from_db(
db_room, include=["current_user_score"], session=db, user=current_user
)
resp = await RoomResp.from_db(db_room, include=["current_user_score"], session=db, user=current_user)
return resp
@@ -400,7 +376,6 @@ async def get_room_events(
for score in scores:
user_ids.add(score.user_id)
beatmap_ids.add(score.beatmap_id)
assert event.id is not None
first_event_id = min(first_event_id, event.id)
last_event_id = max(last_event_id, event.id)
@@ -416,16 +391,12 @@ async def get_room_events(
users = await db.exec(select(User).where(col(User.id).in_(user_ids)))
user_resps = [await UserResp.from_db(user, db) for user in users]
beatmaps = await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)))
beatmap_resps = [
await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps
]
beatmap_resps = [await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps]
beatmapset_resps = {}
for beatmap_resp in beatmap_resps:
beatmapset_resps[beatmap_resp.beatmapset_id] = beatmap_resp.beatmapset
playlist_items_resps = [
await PlaylistResp.from_db(item) for item in playlist_items.values()
]
playlist_items_resps = [await PlaylistResp.from_db(item) for item in playlist_items.values()]
return RoomEvents(
beatmaps=beatmap_resps,

View File

@@ -104,11 +104,7 @@ async def submit_score(
if not info.passed:
info.rank = Rank.F
score_token = (
await db.exec(
select(ScoreToken)
.options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType]
.where(ScoreToken.id == token)
)
await db.exec(select(ScoreToken).options(joinedload(ScoreToken.beatmap)).where(ScoreToken.id == token))
).first()
if not score_token or score_token.user_id != user_id:
raise HTTPException(status_code=404, detail="Score token not found")
@@ -138,10 +134,7 @@ async def submit_score(
except HTTPError:
raise HTTPException(status_code=404, detail="Beatmap not found")
has_pp = db_beatmap.beatmap_status.has_pp() | settings.enable_all_beatmap_pp
has_leaderboard = (
db_beatmap.beatmap_status.has_leaderboard()
| settings.enable_all_beatmap_leaderboard
)
has_leaderboard = db_beatmap.beatmap_status.has_leaderboard() | settings.enable_all_beatmap_leaderboard
beatmap_length = db_beatmap.total_length
score = await process_score(
current_user,
@@ -167,21 +160,11 @@ async def submit_score(
has_pp,
has_leaderboard,
)
score = (
await db.exec(
select(Score)
.options(joinedload(Score.user)) # pyright: ignore[reportArgumentType]
.where(Score.id == score_id)
)
).first()
assert score is not None
score = (await db.exec(select(Score).options(joinedload(Score.user)).where(Score.id == score_id))).one()
resp = await ScoreResp.from_db(db, score)
total_users = (await db.exec(select(func.count()).select_from(User))).first()
assert total_users is not None
if resp.rank_global is not None and resp.rank_global <= min(
math.ceil(float(total_users) * 0.01), 50
):
total_users = (await db.exec(select(func.count()).select_from(User))).one()
if resp.rank_global is not None and resp.rank_global <= min(math.ceil(float(total_users) * 0.01), 50):
rank_event = Event(
created_at=datetime.now(UTC),
type=EventType.RANK,
@@ -207,9 +190,7 @@ async def submit_score(
score_gamemode = score.gamemode
if user_id is not None:
background_task.add_task(
_refresh_user_cache_background, redis, user_id, score_gamemode
)
background_task.add_task(_refresh_user_cache_background, redis, user_id, score_gamemode)
background_task.add_task(process_user_achievement, resp.id)
return resp
@@ -225,9 +206,7 @@ async def _refresh_user_cache_background(redis: Redis, user_id: int, mode: GameM
# 创建独立的数据库会话
session = AsyncSession(engine)
try:
await user_cache_service.refresh_user_cache_on_score_submit(
session, user_id, mode
)
await user_cache_service.refresh_user_cache_on_score_submit(session, user_id, mode)
finally:
await session.close()
except Exception as e:
@@ -280,22 +259,16 @@ async def get_beatmap_scores(
beatmap_id: int = Path(description="谱面 ID"),
mode: GameMode = Query(description="指定 auleset"),
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
mods: list[str] = Query(
default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)"
),
mods: list[str] = Query(default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)"),
type: LeaderboardType = Query(
LeaderboardType.GLOBAL,
description=(
"排行榜类型GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队"
),
description=("排行榜类型GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队"),
),
current_user: User = Security(get_current_user, scopes=["public"]),
limit: int = Query(50, ge=1, le=200, description="返回条数 (1-200)"),
):
if legacy_only:
raise HTTPException(
status_code=404, detail="this server only contains lazer scores"
)
raise HTTPException(status_code=404, detail="this server only contains lazer scores")
all_scores, user_score, count = await get_leaderboard(
db,
@@ -310,9 +283,7 @@ async def get_beatmap_scores(
user_score_resp = await ScoreResp.from_db(db, user_score) if user_score else None
resp = BeatmapScores(
scores=[await ScoreResp.from_db(db, score) for score in all_scores],
user_score=BeatmapUserScore(
score=user_score_resp, position=user_score_resp.rank_global or 0
)
user_score=BeatmapUserScore(score=user_score_resp, position=user_score_resp.rank_global or 0)
if user_score_resp
else None,
score_count=count,
@@ -342,9 +313,7 @@ async def get_user_beatmap_score(
current_user: User = Security(get_current_user, scopes=["public"]),
):
if legacy_only:
raise HTTPException(
status_code=404, detail="This server only contains non-legacy scores"
)
raise HTTPException(status_code=404, detail="This server only contains non-legacy scores")
user_score = (
await db.exec(
select(Score)
@@ -386,9 +355,7 @@ async def get_user_all_beatmap_scores(
current_user: User = Security(get_current_user, scopes=["public"]),
):
if legacy_only:
raise HTTPException(
status_code=404, detail="This server only contains non-legacy scores"
)
raise HTTPException(status_code=404, detail="This server only contains non-legacy scores")
all_user_scores = (
await db.exec(
select(Score)
@@ -420,7 +387,6 @@ async def create_solo_score(
ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"),
current_user: User = Security(get_client_user),
):
assert current_user.id is not None
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
@@ -454,10 +420,7 @@ async def submit_solo_score(
redis: Redis = Depends(get_redis),
fetcher=Depends(get_fetcher),
):
assert current_user.id is not None
return await submit_score(
background_task, info, beatmap_id, token, current_user, db, redis, fetcher
)
return await submit_score(background_task, info, beatmap_id, token, current_user, db, redis, fetcher)
@router.post(
@@ -478,7 +441,6 @@ async def create_playlist_score(
version_hash: str = Form("", description="谱面版本哈希"),
current_user: User = Security(get_client_user),
):
assert current_user.id is not None
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
@@ -488,26 +450,16 @@ async def create_playlist_score(
db_room_time = room.ends_at.replace(tzinfo=UTC) if room.ends_at else None
if db_room_time and db_room_time < datetime.now(UTC).replace(tzinfo=UTC):
raise HTTPException(status_code=400, detail="Room has ended")
item = (
await session.exec(
select(Playlist).where(
Playlist.id == playlist_id, Playlist.room_id == room_id
)
)
).first()
item = (await session.exec(select(Playlist).where(Playlist.id == playlist_id, Playlist.room_id == room_id))).first()
if not item:
raise HTTPException(status_code=404, detail="Playlist not found")
# validate
if not item.freestyle:
if item.ruleset_id != ruleset_id:
raise HTTPException(
status_code=400, detail="Ruleset mismatch in playlist item"
)
raise HTTPException(status_code=400, detail="Ruleset mismatch in playlist item")
if item.beatmap_id != beatmap_id:
raise HTTPException(
status_code=400, detail="Beatmap ID mismatch in playlist item"
)
raise HTTPException(status_code=400, detail="Beatmap ID mismatch in playlist item")
agg = await session.exec(
select(ItemAttemptsCount).where(
ItemAttemptsCount.room_id == room_id,
@@ -523,9 +475,7 @@ async def create_playlist_score(
if item.expired:
raise HTTPException(status_code=400, detail="Playlist item has expired")
if item.played_at:
raise HTTPException(
status_code=400, detail="Playlist item has already been played"
)
raise HTTPException(status_code=400, detail="Playlist item has already been played")
# 这里应该不用验证mod了吧。。。
background_task.add_task(_preload_beatmap_for_pp_calculation, beatmap_id)
score_token = ScoreToken(
@@ -557,18 +507,10 @@ async def submit_playlist_score(
redis: Redis = Depends(get_redis),
fetcher: Fetcher = Depends(get_fetcher),
):
assert current_user.id is not None
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
item = (
await session.exec(
select(Playlist).where(
Playlist.id == playlist_id, Playlist.room_id == room_id
)
)
).first()
item = (await session.exec(select(Playlist).where(Playlist.id == playlist_id, Playlist.room_id == room_id))).first()
if not item:
raise HTTPException(status_code=404, detail="Playlist item not found")
room = await session.get(Room, room_id)
@@ -621,9 +563,7 @@ async def index_playlist_scores(
room_id: int,
playlist_id: int,
limit: int = Query(50, ge=1, le=50, description="返回条数 (1-50)"),
cursor: int = Query(
2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)"
),
cursor: int = Query(2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)"),
current_user: User = Security(get_current_user, scopes=["public"]),
):
# 立即获取用户ID避免懒加载问题
@@ -693,9 +633,6 @@ async def show_playlist_score(
current_user: User = Security(get_client_user),
redis: Redis = Depends(get_redis),
):
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
room = await session.get(Room, room_id)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
@@ -715,9 +652,7 @@ async def show_playlist_score(
)
)
).first()
if completed_players := await redis.get(
f"multiplayer:{room_id}:gameplay:players"
):
if completed_players := await redis.get(f"multiplayer:{room_id}:gameplay:players"):
completed = completed_players == "0"
if score_record and completed:
break
@@ -784,9 +719,7 @@ async def get_user_playlist_score(
raise HTTPException(status_code=404, detail="Score not found")
resp = await ScoreResp.from_db(session, score_record.score)
resp.position = await get_position(
room_id, playlist_id, score_record.score_id, session
)
resp.position = await get_position(room_id, playlist_id, score_record.score_id, session)
return resp
@@ -850,11 +783,7 @@ async def unpin_score(
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
score_record = (
await db.exec(
select(Score).where(Score.id == score_id, Score.user_id == user_id)
)
).first()
score_record = (await db.exec(select(Score).where(Score.id == score_id, Score.user_id == user_id))).first()
if not score_record:
raise HTTPException(status_code=404, detail="Score not found")
@@ -878,10 +807,7 @@ async def unpin_score(
"/score-pins/{score_id}/reorder",
status_code=204,
name="调整置顶成绩顺序",
description=(
"**客户端专属**\n调整已置顶成绩的展示顺序。"
"仅提供 after_score_id 或 before_score_id 之一。"
),
description=("**客户端专属**\n调整已置顶成绩的展示顺序。仅提供 after_score_id 或 before_score_id 之一。"),
tags=["成绩"],
)
async def reorder_score_pin(
@@ -894,11 +820,7 @@ async def reorder_score_pin(
# 立即获取用户ID避免懒加载问题
user_id = current_user.id
score_record = (
await db.exec(
select(Score).where(Score.id == score_id, Score.user_id == user_id)
)
).first()
score_record = (await db.exec(select(Score).where(Score.id == score_id, Score.user_id == user_id))).first()
if not score_record:
raise HTTPException(status_code=404, detail="Score not found")
@@ -908,8 +830,7 @@ async def reorder_score_pin(
if (after_score_id is None) == (before_score_id is None):
raise HTTPException(
status_code=400,
detail="Either after_score_id or before_score_id "
"must be provided (but not both)",
detail="Either after_score_id or before_score_id must be provided (but not both)",
)
all_pinned_scores = (
@@ -927,9 +848,7 @@ async def reorder_score_pin(
target_order = None
reference_score_id = after_score_id or before_score_id
reference_score = next(
(s for s in all_pinned_scores if s.id == reference_score_id), None
)
reference_score = next((s for s in all_pinned_scores if s.id == reference_score_id), None)
if not reference_score:
detail = "After score not found" if after_score_id else "Before score not found"
raise HTTPException(status_code=404, detail=detail)
@@ -951,9 +870,7 @@ async def reorder_score_pin(
if current_order < s.pinned_order <= target_order and s.id != score_id:
updates.append((s.id, s.pinned_order - 1))
if after_score_id:
final_target = (
target_order - 1 if target_order > current_order else target_order
)
final_target = target_order - 1 if target_order > current_order else target_order
else:
final_target = target_order
else:
@@ -964,9 +881,7 @@ async def reorder_score_pin(
for score_id, new_order in updates:
await db.exec(select(Score).where(Score.id == score_id))
score_to_update = (
await db.exec(select(Score).where(Score.id == score_id))
).first()
score_to_update = (await db.exec(select(Score).where(Score.id == score_id))).first()
if score_to_update:
score_to_update.pinned_order = new_order

View File

@@ -4,34 +4,29 @@
from __future__ import annotations
from datetime import datetime, UTC
from typing import Annotated
from app.auth import authenticate_user
from app.config import settings
from app.database import User
from app.dependencies import get_current_user
from app.dependencies.database import Database, get_redis
from app.dependencies.geoip import GeoIPHelper, get_geoip_helper
from app.database.email_verification import EmailVerification, LoginSession
from app.service.email_verification_service import (
EmailVerificationService,
LoginSessionService
EmailVerificationService,
LoginSessionService,
)
from app.service.login_log_service import LoginLogService
from app.models.extended_auth import ExtendedTokenResponse
from fastapi import Form, Depends, Request, HTTPException, status, Security
from fastapi.responses import JSONResponse, Response
from pydantic import BaseModel
from redis.asyncio import Redis
from sqlmodel import select
from .router import router
from fastapi import Depends, Form, HTTPException, Request, Security, status
from fastapi.responses import Response
from pydantic import BaseModel
from redis.asyncio import Redis
class SessionReissueResponse(BaseModel):
"""重新发送验证码响应"""
success: bool
message: str
@@ -40,39 +35,35 @@ class SessionReissueResponse(BaseModel):
"/session/verify",
name="验证会话",
description="验证邮件验证码并完成会话认证",
status_code=204
status_code=204,
)
async def verify_session(
request: Request,
db: Database,
redis: Annotated[Redis, Depends(get_redis)],
verification_key: str = Form(..., description="8位邮件验证码"),
current_user: User = Security(get_current_user)
current_user: User = Security(get_current_user),
) -> Response:
"""
验证邮件验证码并完成会话认证
对应 osu! 的 session/verify 接口
成功时返回 204 No Content失败时返回 401 Unauthorized
"""
try:
from app.dependencies.geoip import get_client_ip
ip_address = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "Unknown")
ip_address = get_client_ip(request) # noqa: F841
user_agent = request.headers.get("User-Agent", "Unknown") # noqa: F841
# 从当前认证用户获取信息
user_id = current_user.id
if not user_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户未认证"
)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户未认证")
# 验证邮件验证码
success, message = await EmailVerificationService.verify_code(
db, redis, user_id, verification_key
)
success, message = await EmailVerificationService.verify_code(db, redis, user_id, verification_key)
if success:
# 记录成功的邮件验证
await LoginLogService.record_login(
@@ -81,9 +72,9 @@ async def verify_session(
request=request,
login_method="email_verification",
login_success=True,
notes=f"邮件验证成功"
notes="邮件验证成功",
)
# 返回 204 No Content 表示验证成功
return Response(status_code=status.HTTP_204_NO_CONTENT)
else:
@@ -93,83 +84,69 @@ async def verify_session(
request=request,
attempted_username=current_user.username,
login_method="email_verification",
notes=f"邮件验证失败: {message}"
notes=f"邮件验证失败: {message}",
)
# 返回 401 Unauthorized 表示验证失败
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=message
)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=message)
except ValueError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的用户会话"
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="验证过程中发生错误"
)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的用户会话")
except Exception:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="验证过程中发生错误")
@router.post(
"/session/verify/reissue",
name="重新发送验证码",
description="重新发送邮件验证码",
response_model=SessionReissueResponse
response_model=SessionReissueResponse,
)
async def reissue_verification_code(
request: Request,
db: Database,
redis: Annotated[Redis, Depends(get_redis)],
current_user: User = Security(get_current_user)
current_user: User = Security(get_current_user),
) -> SessionReissueResponse:
"""
重新发送邮件验证码
对应 osu! 的 session/verify/reissue 接口
"""
try:
from app.dependencies.geoip import get_client_ip
ip_address = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "Unknown")
# 从当前认证用户获取信息
user_id = current_user.id
if not user_id:
return SessionReissueResponse(
success=False,
message="用户未认证"
)
return SessionReissueResponse(success=False, message="用户未认证")
# 重新发送验证码
success, message = await EmailVerificationService.resend_verification_code(
db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent
db,
redis,
user_id,
current_user.username,
current_user.email,
ip_address,
user_agent,
)
return SessionReissueResponse(
success=success,
message=message
)
return SessionReissueResponse(success=success, message=message)
except ValueError:
return SessionReissueResponse(
success=False,
message="无效的用户会话"
)
except Exception as e:
return SessionReissueResponse(
success=False,
message="重新发送过程中发生错误"
)
return SessionReissueResponse(success=False, message="无效的用户会话")
except Exception:
return SessionReissueResponse(success=False, message="重新发送过程中发生错误")
@router.post(
"/session/check-new-location",
name="检查新位置登录",
description="检查登录是否来自新位置(内部接口)"
description="检查登录是否来自新位置(内部接口)",
)
async def check_new_location(
request: Request,
@@ -183,22 +160,21 @@ async def check_new_location(
"""
try:
from app.dependencies.geoip import get_client_ip
ip_address = get_client_ip(request)
geo_info = geoip.lookup(ip_address)
country_code = geo_info.get("country_iso", "XX")
is_new_location = await LoginSessionService.check_new_location(
db, user_id, ip_address, country_code
)
is_new_location = await LoginSessionService.check_new_location(db, user_id, ip_address, country_code)
return {
"is_new_location": is_new_location,
"ip_address": ip_address,
"country_code": country_code
"country_code": country_code,
}
except Exception as e:
return {
"is_new_location": True, # 出错时默认为新位置
"error": str(e)
"error": str(e),
}

View File

@@ -1,73 +1,80 @@
from __future__ import annotations
import asyncio
from datetime import datetime, timedelta
import json
from typing import Any
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
import json
from app.dependencies.database import get_redis, get_redis_message
from app.log import logger
from app.utils import bg_tasks
from .router import router
from fastapi import APIRouter
from pydantic import BaseModel
# Redis key constants
REDIS_ONLINE_USERS_KEY = "server:online_users"
REDIS_PLAYING_USERS_KEY = "server:playing_users"
REDIS_PLAYING_USERS_KEY = "server:playing_users"
REDIS_REGISTERED_USERS_KEY = "server:registered_users"
REDIS_ONLINE_HISTORY_KEY = "server:online_history"
# 线程池用于同步Redis操作
_executor = ThreadPoolExecutor(max_workers=2)
async def _redis_exec(func, *args, **kwargs):
"""在线程池中执行同步Redis操作"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(_executor, func, *args, **kwargs)
class ServerStats(BaseModel):
"""服务器统计信息响应模型"""
registered_users: int
online_users: int
playing_users: int
timestamp: datetime
class OnlineHistoryPoint(BaseModel):
"""在线历史数据点"""
timestamp: datetime
online_count: int
playing_count: int
class OnlineHistoryResponse(BaseModel):
"""24小时在线历史响应模型"""
history: list[OnlineHistoryPoint]
current_stats: ServerStats
@router.get("/stats", response_model=ServerStats, tags=["统计"])
async def get_server_stats() -> ServerStats:
"""
获取服务器实时统计信息
返回服务器注册用户数、在线用户数、正在游玩用户数等实时统计信息
"""
redis = get_redis()
try:
# 并行获取所有统计数据
registered_count, online_count, playing_count = await asyncio.gather(
_get_registered_users_count(redis),
_get_online_users_count(redis),
_get_playing_users_count(redis)
_get_playing_users_count(redis),
)
return ServerStats(
registered_users=registered_count,
online_users=online_count,
playing_users=playing_count,
timestamp=datetime.utcnow()
timestamp=datetime.utcnow(),
)
except Exception as e:
logger.error(f"Error getting server stats: {e}")
@@ -76,14 +83,15 @@ async def get_server_stats() -> ServerStats:
registered_users=0,
online_users=0,
playing_users=0,
timestamp=datetime.utcnow()
timestamp=datetime.utcnow(),
)
@router.get("/stats/history", response_model=OnlineHistoryResponse, tags=["统计"])
async def get_online_history() -> OnlineHistoryResponse:
"""
获取最近24小时在线统计历史
返回过去24小时内每小时的在线用户数和游玩用户数统计
包含当前实时数据作为最新数据点
"""
@@ -92,80 +100,80 @@ async def get_online_history() -> OnlineHistoryResponse:
redis_sync = get_redis_message()
history_data = await _redis_exec(redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1)
history_points = []
# 处理历史数据
for data in history_data:
try:
point_data = json.loads(data)
# 只保留基本字段
history_points.append(OnlineHistoryPoint(
timestamp=datetime.fromisoformat(point_data["timestamp"]),
online_count=point_data["online_count"],
playing_count=point_data["playing_count"]
))
history_points.append(
OnlineHistoryPoint(
timestamp=datetime.fromisoformat(point_data["timestamp"]),
online_count=point_data["online_count"],
playing_count=point_data["playing_count"],
)
)
except (json.JSONDecodeError, KeyError, ValueError) as e:
logger.warning(f"Invalid history data point: {data}, error: {e}")
continue
# 获取当前实时统计信息
current_stats = await get_server_stats()
# 如果历史数据为空或者最新数据超过15分钟添加当前数据点
if not history_points or (
history_points and
(current_stats.timestamp - max(history_points, key=lambda x: x.timestamp).timestamp).total_seconds() > 15 * 60
history_points
and (current_stats.timestamp - max(history_points, key=lambda x: x.timestamp).timestamp).total_seconds()
> 15 * 60
):
history_points.append(OnlineHistoryPoint(
timestamp=current_stats.timestamp,
online_count=current_stats.online_users,
playing_count=current_stats.playing_users
))
history_points.append(
OnlineHistoryPoint(
timestamp=current_stats.timestamp,
online_count=current_stats.online_users,
playing_count=current_stats.playing_users,
)
)
# 按时间排序(最新的在前)
history_points.sort(key=lambda x: x.timestamp, reverse=True)
# 限制到最多48个数据点24小时
history_points = history_points[:48]
return OnlineHistoryResponse(
history=history_points,
current_stats=current_stats
)
return OnlineHistoryResponse(history=history_points, current_stats=current_stats)
except Exception as e:
logger.error(f"Error getting online history: {e}")
# 返回空历史和当前状态
current_stats = await get_server_stats()
return OnlineHistoryResponse(
history=[],
current_stats=current_stats
)
return OnlineHistoryResponse(history=[], current_stats=current_stats)
@router.get("/stats/debug", tags=["统计"])
async def get_stats_debug_info():
"""
获取统计系统调试信息
用于调试时间对齐和区间统计问题
"""
try:
from app.service.enhanced_interval_stats import EnhancedIntervalStatsManager
current_time = datetime.utcnow()
current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
interval_stats = await EnhancedIntervalStatsManager.get_current_interval_stats()
# 获取Redis中的实际数据
redis_sync = get_redis_message()
online_key = f"server:interval_online_users:{current_interval.interval_key}"
playing_key = f"server:interval_playing_users:{current_interval.interval_key}"
online_users_raw = await _redis_exec(redis_sync.smembers, online_key)
playing_users_raw = await _redis_exec(redis_sync.smembers, playing_key)
online_users = [int(uid.decode() if isinstance(uid, bytes) else uid) for uid in online_users_raw]
playing_users = [int(uid.decode() if isinstance(uid, bytes) else uid) for uid in playing_users_raw]
return {
"current_time": current_time.isoformat(),
"current_interval": {
@@ -175,28 +183,29 @@ async def get_stats_debug_info():
"is_current": current_interval.is_current(),
"minutes_remaining": int((current_interval.end_time - current_time).total_seconds() / 60),
"seconds_remaining": int((current_interval.end_time - current_time).total_seconds()),
"progress_percentage": round((1 - (current_interval.end_time - current_time).total_seconds() / (30 * 60)) * 100, 1)
"progress_percentage": round(
(1 - (current_interval.end_time - current_time).total_seconds() / (30 * 60)) * 100,
1,
),
},
"interval_statistics": interval_stats.to_dict() if interval_stats else None,
"redis_data": {
"online_users": online_users,
"playing_users": playing_users,
"online_count": len(online_users),
"playing_count": len(playing_users)
"playing_count": len(playing_users),
},
"system_status": {
"stats_system": "enhanced_interval_stats",
"data_alignment": "30_minute_boundaries",
"real_time_updates": True,
"auto_24h_fill": True
}
"auto_24h_fill": True,
},
}
except Exception as e:
logger.error(f"Error getting debug info: {e}")
return {
"error": "Failed to retrieve debug information",
"message": str(e)
}
return {"error": "Failed to retrieve debug information", "message": str(e)}
async def _get_registered_users_count(redis) -> int:
"""获取注册用户总数(从缓存)"""
@@ -207,6 +216,7 @@ async def _get_registered_users_count(redis) -> int:
logger.error(f"Error getting registered users count: {e}")
return 0
async def _get_online_users_count(redis) -> int:
"""获取当前在线用户数"""
try:
@@ -216,6 +226,7 @@ async def _get_online_users_count(redis) -> int:
logger.error(f"Error getting online users count: {e}")
return 0
async def _get_playing_users_count(redis) -> int:
"""获取当前游玩用户数"""
try:
@@ -225,27 +236,28 @@ async def _get_playing_users_count(redis) -> int:
logger.error(f"Error getting playing users count: {e}")
return 0
# 统计更新功能
async def update_registered_users_count() -> None:
"""更新注册用户数缓存"""
from app.dependencies.database import with_db
from app.database import User
from app.const import BANCHOBOT_ID
from sqlmodel import select, func
from app.database import User
from app.dependencies.database import with_db
from sqlmodel import func, select
redis = get_redis()
try:
async with with_db() as db:
# 排除机器人用户BANCHOBOT_ID
result = await db.exec(
select(func.count()).select_from(User).where(User.id != BANCHOBOT_ID)
)
result = await db.exec(select(func.count()).select_from(User).where(User.id != BANCHOBOT_ID))
count = result.first()
await redis.set(REDIS_REGISTERED_USERS_KEY, count or 0, ex=300) # 5分钟过期
logger.debug(f"Updated registered users count: {count}")
except Exception as e:
logger.error(f"Error updating registered users count: {e}")
async def add_online_user(user_id: int) -> None:
"""添加在线用户"""
redis_sync = get_redis_message()
@@ -257,14 +269,20 @@ async def add_online_user(user_id: int) -> None:
if ttl <= 0: # -1表示永不过期-2表示不存在0表示已过期
await redis_async.expire(REDIS_ONLINE_USERS_KEY, 3 * 3600) # 3小时过期
logger.debug(f"Added online user {user_id}")
# 立即更新当前区间统计
from app.service.enhanced_interval_stats import update_user_activity_in_interval
asyncio.create_task(update_user_activity_in_interval(user_id, is_playing=False))
bg_tasks.add_task(
update_user_activity_in_interval,
user_id,
is_playing=False,
)
except Exception as e:
logger.error(f"Error adding online user {user_id}: {e}")
async def remove_online_user(user_id: int) -> None:
"""移除在线用户"""
redis_sync = get_redis_message()
@@ -274,6 +292,7 @@ async def remove_online_user(user_id: int) -> None:
except Exception as e:
logger.error(f"Error removing online user {user_id}: {e}")
async def add_playing_user(user_id: int) -> None:
"""添加游玩用户"""
redis_sync = get_redis_message()
@@ -285,14 +304,16 @@ async def add_playing_user(user_id: int) -> None:
if ttl <= 0: # -1表示永不过期-2表示不存在0表示已过期
await redis_async.expire(REDIS_PLAYING_USERS_KEY, 3 * 3600) # 3小时过期
logger.debug(f"Added playing user {user_id}")
# 立即更新当前区间统计
from app.service.enhanced_interval_stats import update_user_activity_in_interval
asyncio.create_task(update_user_activity_in_interval(user_id, is_playing=True))
bg_tasks.add_task(update_user_activity_in_interval, user_id, is_playing=True)
except Exception as e:
logger.error(f"Error adding playing user {user_id}: {e}")
async def remove_playing_user(user_id: int) -> None:
"""移除游玩用户"""
redis_sync = get_redis_message()
@@ -301,6 +322,7 @@ async def remove_playing_user(user_id: int) -> None:
except Exception as e:
logger.error(f"Error removing playing user {user_id}: {e}")
async def record_hourly_stats() -> None:
"""记录统计数据 - 简化版本主要作为fallback使用"""
redis_sync = get_redis_message()
@@ -308,24 +330,27 @@ async def record_hourly_stats() -> None:
try:
# 先确保Redis连接正常
await redis_async.ping()
online_count = await _get_online_users_count(redis_async)
playing_count = await _get_playing_users_count(redis_async)
current_time = datetime.utcnow()
history_point = {
"timestamp": current_time.isoformat(),
"online_count": online_count,
"playing_count": playing_count
"playing_count": playing_count,
}
# 添加到历史记录
await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point))
# 只保留48个数据点24小时每30分钟一个点
await _redis_exec(redis_sync.ltrim, REDIS_ONLINE_HISTORY_KEY, 0, 47)
# 设置过期时间为26小时确保有足够缓冲
await redis_async.expire(REDIS_ONLINE_HISTORY_KEY, 26 * 3600)
logger.info(f"Recorded fallback stats: online={online_count}, playing={playing_count} at {current_time.strftime('%H:%M:%S')}")
logger.info(
f"Recorded fallback stats: online={online_count}, playing={playing_count} "
f"at {current_time.strftime('%H:%M:%S')}"
)
except Exception as e:
logger.error(f"Error recording fallback stats: {e}")

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import asyncio
from datetime import UTC, datetime, timedelta
from typing import Literal
@@ -26,7 +25,7 @@ from app.service.user_cache_service import get_user_cache_service
from .router import router
from fastapi import HTTPException, Path, Query, Security
from fastapi import BackgroundTasks, HTTPException, Path, Query, Security
from pydantic import BaseModel
from sqlmodel import exists, false, select
from sqlmodel.sql.expression import col
@@ -47,13 +46,10 @@ class BatchUserResponse(BaseModel):
@router.get("/users/lookup/", response_model=BatchUserResponse, include_in_schema=False)
async def get_users(
session: Database,
user_ids: list[int] = Query(
default_factory=list, alias="ids[]", description="要查询的用户 ID 列表"
),
background_task: BackgroundTasks,
user_ids: list[int] = Query(default_factory=list, alias="ids[]", description="要查询的用户 ID 列表"),
# current_user: User = Security(get_current_user, scopes=["public"]),
include_variant_statistics: bool = Query(
default=False, description="是否包含各模式的统计信息"
), # TODO: future use
include_variant_statistics: bool = Query(default=False, description="是否包含各模式的统计信息"), # TODO: future use
):
redis = get_redis()
cache_service = get_user_cache_service(redis)
@@ -72,11 +68,7 @@ async def get_users(
# 查询未缓存的用户
if uncached_user_ids:
searched_users = (
await session.exec(
select(User).where(col(User.id).in_(uncached_user_ids))
)
).all()
searched_users = (await session.exec(select(User).where(col(User.id).in_(uncached_user_ids)))).all()
# 将查询到的用户添加到缓存并返回
for searched_user in searched_users:
@@ -88,7 +80,7 @@ async def get_users(
)
cached_users.append(user_resp)
# 异步缓存,不阻塞响应
asyncio.create_task(cache_service.cache_user(user_resp))
background_task.add_task(cache_service.cache_user, user_resp)
return BatchUserResponse(users=cached_users)
else:
@@ -103,7 +95,7 @@ async def get_users(
)
users.append(user_resp)
# 异步缓存
asyncio.create_task(cache_service.cache_user(user_resp))
background_task.add_task(cache_service.cache_user, user_resp)
return BatchUserResponse(users=users)
@@ -117,6 +109,7 @@ async def get_users(
)
async def get_user_info_ruleset(
session: Database,
background_task: BackgroundTasks,
user_id: str = Path(description="用户 ID 或用户名"),
ruleset: GameMode | None = Path(description="指定 ruleset"),
# current_user: User = Security(get_current_user, scopes=["public"]),
@@ -134,9 +127,7 @@ async def get_user_info_ruleset(
searched_user = (
await session.exec(
select(User).where(
User.id == int(user_id)
if user_id.isdigit()
else User.username == user_id.removeprefix("@")
User.id == int(user_id) if user_id.isdigit() else User.username == user_id.removeprefix("@")
)
)
).first()
@@ -151,7 +142,7 @@ async def get_user_info_ruleset(
)
# 异步缓存结果
asyncio.create_task(cache_service.cache_user(user_resp, ruleset))
background_task.add_task(cache_service.cache_user, user_resp, ruleset)
return user_resp
@@ -165,6 +156,7 @@ async def get_user_info_ruleset(
tags=["用户"],
)
async def get_user_info(
background_task: BackgroundTasks,
session: Database,
user_id: str = Path(description="用户 ID 或用户名"),
# current_user: User = Security(get_current_user, scopes=["public"]),
@@ -182,9 +174,7 @@ async def get_user_info(
searched_user = (
await session.exec(
select(User).where(
User.id == int(user_id)
if user_id.isdigit()
else User.username == user_id.removeprefix("@")
User.id == int(user_id) if user_id.isdigit() else User.username == user_id.removeprefix("@")
)
)
).first()
@@ -198,7 +188,7 @@ async def get_user_info(
)
# 异步缓存结果
asyncio.create_task(cache_service.cache_user(user_resp))
background_task.add_task(cache_service.cache_user, user_resp)
return user_resp
@@ -212,6 +202,7 @@ async def get_user_info(
)
async def get_user_beatmapsets(
session: Database,
background_task: BackgroundTasks,
user_id: int = Path(description="用户 ID"),
type: BeatmapsetType = Path(description="谱面集类型"),
current_user: User = Security(get_current_user, scopes=["public"]),
@@ -222,9 +213,7 @@ async def get_user_beatmapsets(
cache_service = get_user_cache_service(redis)
# 先尝试从缓存获取
cached_result = await cache_service.get_user_beatmapsets_from_cache(
user_id, type.value, limit, offset
)
cached_result = await cache_service.get_user_beatmapsets_from_cache(user_id, type.value, limit, offset)
if cached_result is not None:
# 根据类型恢复对象
if type == BeatmapsetType.MOST_PLAYED:
@@ -253,10 +242,7 @@ async def get_user_beatmapsets(
raise HTTPException(404, detail="User not found")
favourites = await user.awaitable_attrs.favourite_beatmapsets
resp = [
await BeatmapsetResp.from_db(
favourite.beatmapset, session=session, user=user
)
for favourite in favourites
await BeatmapsetResp.from_db(favourite.beatmapset, session=session, user=user) for favourite in favourites
]
elif type == BeatmapsetType.MOST_PLAYED:
@@ -267,25 +253,18 @@ async def get_user_beatmapsets(
.limit(limit)
.offset(offset)
)
resp = [
await BeatmapPlaycountsResp.from_db(most_played_beatmap)
for most_played_beatmap in most_played
]
resp = [await BeatmapPlaycountsResp.from_db(most_played_beatmap) for most_played_beatmap in most_played]
else:
raise HTTPException(400, detail="Invalid beatmapset type")
# 异步缓存结果
async def cache_beatmapsets():
try:
await cache_service.cache_user_beatmapsets(
user_id, type.value, resp, limit, offset
)
await cache_service.cache_user_beatmapsets(user_id, type.value, resp, limit, offset)
except Exception as e:
logger.error(
f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}"
)
logger.error(f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}")
asyncio.create_task(cache_beatmapsets())
background_task.add_task(cache_beatmapsets)
return resp
@@ -299,18 +278,14 @@ async def get_user_beatmapsets(
)
async def get_user_scores(
session: Database,
background_task: BackgroundTasks,
user_id: int = Path(description="用户 ID"),
type: Literal["best", "recent", "firsts", "pinned"] = Path(
description=(
"成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩"
" / firsts 第一名成绩 / pinned 置顶成绩"
)
description=("成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩 / firsts 第一名成绩 / pinned 置顶成绩")
),
legacy_only: bool = Query(False, description="是否只查询 Stable 成绩"),
include_fails: bool = Query(False, description="是否包含失败的成绩"),
mode: GameMode | None = Query(
None, description="指定 ruleset (可选,默认为用户主模式)"
),
mode: GameMode | None = Query(None, description="指定 ruleset (可选,默认为用户主模式)"),
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
offset: int = Query(0, ge=0, description="偏移量"),
current_user: User = Security(get_current_user, scopes=["public"]),
@@ -320,9 +295,7 @@ async def get_user_scores(
# 先尝试从缓存获取对于recent类型使用较短的缓存时间
cache_expire = 30 if type == "recent" else settings.user_scores_cache_expire_seconds
cached_scores = await cache_service.get_user_scores_from_cache(
user_id, type, mode, limit, offset
)
cached_scores = await cache_service.get_user_scores_from_cache(user_id, type, mode, limit, offset)
if cached_scores is not None:
return cached_scores
@@ -332,9 +305,7 @@ async def get_user_scores(
gamemode = mode or db_user.playmode
order_by = None
where_clause = (col(Score.user_id) == db_user.id) & (
col(Score.gamemode) == gamemode
)
where_clause = (col(Score.user_id) == db_user.id) & (col(Score.gamemode) == gamemode)
if not include_fails:
where_clause &= col(Score.passed).is_(True)
if type == "pinned":
@@ -351,13 +322,7 @@ async def get_user_scores(
where_clause &= false()
scores = (
await session.exec(
select(Score)
.where(where_clause)
.order_by(order_by)
.limit(limit)
.offset(offset)
)
await session.exec(select(Score).where(where_clause).order_by(order_by).limit(limit).offset(offset))
).all()
if not scores:
return []
@@ -371,18 +336,14 @@ async def get_user_scores(
]
# 异步缓存结果
asyncio.create_task(
cache_service.cache_user_scores(
user_id, type, score_responses, mode, limit, offset, cache_expire
)
background_task.add_task(
cache_service.cache_user_scores, user_id, type, score_responses, mode, limit, offset, cache_expire
)
return score_responses
@router.get(
"/users/{user}/recent_activity", tags=["用户"], response_model=list[EventResp]
)
@router.get("/users/{user}/recent_activity", tags=["用户"], response_model=list[EventResp])
async def get_user_events(
session: Database,
user: int,