refactor(project): make pyright & ruff happy
This commit is contained in:
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
import re
|
||||
from typing import Literal, Union
|
||||
from typing import Literal
|
||||
|
||||
from app.auth import (
|
||||
authenticate_user,
|
||||
@@ -22,19 +22,19 @@ from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.geoip import get_client_ip, get_geoip_helper
|
||||
from app.helpers.geoip_helper import GeoIPHelper
|
||||
from app.log import logger
|
||||
from app.models.extended_auth import ExtendedTokenResponse
|
||||
from app.models.oauth import (
|
||||
OAuthErrorResponse,
|
||||
RegistrationRequestErrors,
|
||||
TokenResponse,
|
||||
UserRegistrationErrors,
|
||||
)
|
||||
from app.models.extended_auth import ExtendedTokenResponse
|
||||
from app.models.score import GameMode
|
||||
from app.service.login_log_service import LoginLogService
|
||||
from app.service.email_verification_service import (
|
||||
EmailVerificationService,
|
||||
LoginSessionService
|
||||
LoginSessionService,
|
||||
)
|
||||
from app.service.login_log_service import LoginLogService
|
||||
from app.service.password_reset_service import password_reset_service
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, Request
|
||||
@@ -44,13 +44,9 @@ from sqlalchemy import text
|
||||
from sqlmodel import select
|
||||
|
||||
|
||||
def create_oauth_error_response(
|
||||
error: str, description: str, hint: str, status_code: int = 400
|
||||
):
|
||||
def create_oauth_error_response(error: str, description: str, hint: str, status_code: int = 400):
|
||||
"""创建标准的 OAuth 错误响应"""
|
||||
error_data = OAuthErrorResponse(
|
||||
error=error, error_description=description, hint=hint, message=description
|
||||
)
|
||||
error_data = OAuthErrorResponse(error=error, error_description=description, hint=hint, message=description)
|
||||
return JSONResponse(status_code=status_code, content=error_data.model_dump())
|
||||
|
||||
|
||||
@@ -123,9 +119,7 @@ async def register_user(
|
||||
)
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=422, content={"form_error": errors.model_dump()}
|
||||
)
|
||||
return JSONResponse(status_code=422, content={"form_error": errors.model_dump()})
|
||||
|
||||
try:
|
||||
# 获取客户端 IP 并查询地理位置
|
||||
@@ -137,10 +131,7 @@ async def register_user(
|
||||
geo_info = geoip.lookup(client_ip)
|
||||
if geo_info and geo_info.get("country_iso"):
|
||||
country_code = geo_info["country_iso"]
|
||||
logger.info(
|
||||
f"User {user_username} registering from "
|
||||
f"{client_ip}, country: {country_code}"
|
||||
)
|
||||
logger.info(f"User {user_username} registering from {client_ip}, country: {country_code}")
|
||||
else:
|
||||
logger.warning(f"Could not determine country for IP {client_ip}")
|
||||
except Exception as e:
|
||||
@@ -148,7 +139,7 @@ async def register_user(
|
||||
|
||||
# 创建新用户
|
||||
# 确保 AUTO_INCREMENT 值从3开始(ID=1是BanchoBot,ID=2预留给ppy)
|
||||
result = await db.execute( # pyright: ignore[reportDeprecated]
|
||||
result = await db.execute(
|
||||
text(
|
||||
"SELECT AUTO_INCREMENT FROM information_schema.TABLES "
|
||||
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'lazer_users'"
|
||||
@@ -173,7 +164,6 @@ async def register_user(
|
||||
db.add(new_user)
|
||||
await db.commit()
|
||||
await db.refresh(new_user)
|
||||
assert new_user.id is not None, "New user ID should not be None"
|
||||
for i in [GameMode.OSU, GameMode.TAIKO, GameMode.FRUITS, GameMode.MANIA]:
|
||||
statistics = UserStatistics(mode=i, user_id=new_user.id)
|
||||
db.add(statistics)
|
||||
@@ -193,36 +183,30 @@ async def register_user(
|
||||
logger.exception(f"Registration error for user {user_username}")
|
||||
|
||||
# 返回通用错误
|
||||
errors = RegistrationRequestErrors(
|
||||
message="An error occurred while creating your account. Please try again."
|
||||
)
|
||||
errors = RegistrationRequestErrors(message="An error occurred while creating your account. Please try again.")
|
||||
|
||||
return JSONResponse(
|
||||
status_code=500, content={"form_error": errors.model_dump()}
|
||||
)
|
||||
return JSONResponse(status_code=500, content={"form_error": errors.model_dump()})
|
||||
|
||||
|
||||
@router.post(
|
||||
"/oauth/token",
|
||||
response_model=Union[TokenResponse, ExtendedTokenResponse],
|
||||
response_model=TokenResponse | ExtendedTokenResponse,
|
||||
name="获取访问令牌",
|
||||
description="OAuth 令牌端点,支持密码、刷新令牌和授权码三种授权方式。",
|
||||
)
|
||||
async def oauth_token(
|
||||
db: Database,
|
||||
request: Request,
|
||||
grant_type: Literal[
|
||||
"authorization_code", "refresh_token", "password", "client_credentials"
|
||||
] = Form(..., description="授权类型:密码/刷新令牌/授权码/客户端凭证"),
|
||||
grant_type: Literal["authorization_code", "refresh_token", "password", "client_credentials"] = Form(
|
||||
..., description="授权类型:密码/刷新令牌/授权码/客户端凭证"
|
||||
),
|
||||
client_id: int = Form(..., description="客户端 ID"),
|
||||
client_secret: str = Form(..., description="客户端密钥"),
|
||||
code: str | None = Form(None, description="授权码(仅授权码模式需要)"),
|
||||
scope: str = Form("*", description="权限范围(空格分隔,默认为 '*')"),
|
||||
username: str | None = Form(None, description="用户名(仅密码模式需要)"),
|
||||
password: str | None = Form(None, description="密码(仅密码模式需要)"),
|
||||
refresh_token: str | None = Form(
|
||||
None, description="刷新令牌(仅刷新令牌模式需要)"
|
||||
),
|
||||
refresh_token: str | None = Form(None, description="刷新令牌(仅刷新令牌模式需要)"),
|
||||
redis: Redis = Depends(get_redis),
|
||||
geoip: GeoIPHelper = Depends(get_geoip_helper),
|
||||
):
|
||||
@@ -303,37 +287,33 @@ async def oauth_token(
|
||||
await db.refresh(user)
|
||||
|
||||
# 获取用户信息和客户端信息
|
||||
user_id = getattr(user, "id")
|
||||
assert user_id is not None, "User ID should not be None after authentication"
|
||||
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
user_id = user.id
|
||||
|
||||
ip_address = get_client_ip(request)
|
||||
user_agent = request.headers.get("User-Agent", "")
|
||||
|
||||
|
||||
# 获取国家代码
|
||||
geo_info = geoip.lookup(ip_address)
|
||||
country_code = geo_info.get("country_iso", "XX")
|
||||
|
||||
|
||||
# 检查是否为新位置登录
|
||||
is_new_location = await LoginSessionService.check_new_location(
|
||||
db, user_id, ip_address, country_code
|
||||
)
|
||||
|
||||
is_new_location = await LoginSessionService.check_new_location(db, user_id, ip_address, country_code)
|
||||
|
||||
# 创建登录会话记录
|
||||
login_session = await LoginSessionService.create_session(
|
||||
login_session = await LoginSessionService.create_session( # noqa: F841
|
||||
db, redis, user_id, ip_address, user_agent, country_code, is_new_location
|
||||
)
|
||||
|
||||
|
||||
# 如果是新位置登录,需要邮件验证
|
||||
if is_new_location and settings.enable_email_verification:
|
||||
# 刷新用户对象以确保属性已加载
|
||||
await db.refresh(user)
|
||||
|
||||
|
||||
# 发送邮件验证码
|
||||
verification_sent = await EmailVerificationService.send_verification_email(
|
||||
db, redis, user_id, user.username, user.email, ip_address, user_agent
|
||||
)
|
||||
|
||||
|
||||
# 记录需要二次验证的登录尝试
|
||||
await LoginLogService.record_login(
|
||||
db=db,
|
||||
@@ -343,14 +323,16 @@ async def oauth_token(
|
||||
login_method="password_pending_verification",
|
||||
notes=f"新位置登录,需要邮件验证 - IP: {ip_address}, 国家: {country_code}",
|
||||
)
|
||||
|
||||
|
||||
if not verification_sent:
|
||||
# 邮件发送失败,记录错误
|
||||
logger.error(f"[Auth] Failed to send email verification code for user {user_id}")
|
||||
elif is_new_location and not settings.enable_email_verification:
|
||||
# 新位置登录但邮件验证功能被禁用,直接标记会话为已验证
|
||||
await LoginSessionService.mark_session_verified(db, user_id)
|
||||
logger.debug(f"[Auth] New location login detected but email verification disabled, auto-verifying user {user_id}")
|
||||
logger.debug(
|
||||
f"[Auth] New location login detected but email verification disabled, auto-verifying user {user_id}"
|
||||
)
|
||||
else:
|
||||
# 不是新位置登录,正常登录
|
||||
await LoginLogService.record_login(
|
||||
@@ -361,20 +343,17 @@ async def oauth_token(
|
||||
login_method="password",
|
||||
notes=f"正常登录 - IP: {ip_address}, 国家: {country_code}",
|
||||
)
|
||||
|
||||
|
||||
# 无论是否新位置登录,都返回正常的token
|
||||
# session_verified状态通过/me接口的session_verified字段来体现
|
||||
|
||||
# 生成令牌
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
# 获取用户ID,避免触发延迟加载
|
||||
access_token = create_access_token(
|
||||
data={"sub": str(user_id)}, expires_delta=access_token_expires
|
||||
)
|
||||
access_token = create_access_token(data={"sub": str(user_id)}, expires_delta=access_token_expires)
|
||||
refresh_token_str = generate_refresh_token()
|
||||
|
||||
# 存储令牌
|
||||
assert user_id
|
||||
await store_token(
|
||||
db,
|
||||
user_id,
|
||||
@@ -423,9 +402,7 @@ async def oauth_token(
|
||||
|
||||
# 生成新的访问令牌
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
access_token = create_access_token(
|
||||
data={"sub": str(token_record.user_id)}, expires_delta=access_token_expires
|
||||
)
|
||||
access_token = create_access_token(data={"sub": str(token_record.user_id)}, expires_delta=access_token_expires)
|
||||
new_refresh_token = generate_refresh_token()
|
||||
|
||||
# 更新令牌
|
||||
@@ -489,17 +466,11 @@ async def oauth_token(
|
||||
|
||||
# 生成令牌
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
# 重新查询只获取ID,避免触发延迟加载
|
||||
id_result = await db.exec(select(User.id).where(User.username == username))
|
||||
user_id = id_result.first()
|
||||
|
||||
access_token = create_access_token(
|
||||
data={"sub": str(user_id)}, expires_delta=access_token_expires
|
||||
)
|
||||
user_id = user.id
|
||||
access_token = create_access_token(data={"sub": str(user_id)}, expires_delta=access_token_expires)
|
||||
refresh_token_str = generate_refresh_token()
|
||||
|
||||
# 存储令牌
|
||||
assert user_id
|
||||
await store_token(
|
||||
db,
|
||||
user_id,
|
||||
@@ -539,9 +510,7 @@ async def oauth_token(
|
||||
|
||||
# 生成令牌
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
access_token = create_access_token(
|
||||
data={"sub": "3"}, expires_delta=access_token_expires
|
||||
)
|
||||
access_token = create_access_token(data={"sub": "3"}, expires_delta=access_token_expires)
|
||||
refresh_token_str = generate_refresh_token()
|
||||
|
||||
# 存储令牌
|
||||
@@ -567,7 +536,7 @@ async def oauth_token(
|
||||
@router.post(
|
||||
"/password-reset/request",
|
||||
name="请求密码重置",
|
||||
description="通过邮箱请求密码重置验证码"
|
||||
description="通过邮箱请求密码重置验证码",
|
||||
)
|
||||
async def request_password_reset(
|
||||
request: Request,
|
||||
@@ -578,42 +547,26 @@ async def request_password_reset(
|
||||
请求密码重置
|
||||
"""
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
|
||||
|
||||
# 获取客户端信息
|
||||
ip_address = get_client_ip(request)
|
||||
user_agent = request.headers.get("User-Agent", "")
|
||||
|
||||
|
||||
# 请求密码重置
|
||||
success, message = await password_reset_service.request_password_reset(
|
||||
email=email.lower().strip(),
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
redis=redis
|
||||
redis=redis,
|
||||
)
|
||||
|
||||
|
||||
if success:
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"success": True,
|
||||
"message": message
|
||||
}
|
||||
)
|
||||
return JSONResponse(status_code=200, content={"success": True, "message": message})
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"success": False,
|
||||
"error": message
|
||||
}
|
||||
)
|
||||
return JSONResponse(status_code=400, content={"success": False, "error": message})
|
||||
|
||||
|
||||
@router.post(
|
||||
"/password-reset/reset",
|
||||
name="重置密码",
|
||||
description="使用验证码重置密码"
|
||||
)
|
||||
@router.post("/password-reset/reset", name="重置密码", description="使用验证码重置密码")
|
||||
async def reset_password(
|
||||
request: Request,
|
||||
email: str = Form(..., description="邮箱地址"),
|
||||
@@ -625,32 +578,20 @@ async def reset_password(
|
||||
重置密码
|
||||
"""
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
|
||||
|
||||
# 获取客户端信息
|
||||
ip_address = get_client_ip(request)
|
||||
|
||||
|
||||
# 重置密码
|
||||
success, message = await password_reset_service.reset_password(
|
||||
email=email.lower().strip(),
|
||||
reset_code=reset_code.strip(),
|
||||
new_password=new_password,
|
||||
ip_address=ip_address,
|
||||
redis=redis
|
||||
redis=redis,
|
||||
)
|
||||
|
||||
|
||||
if success:
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"success": True,
|
||||
"message": message
|
||||
}
|
||||
)
|
||||
return JSONResponse(status_code=200, content={"success": True, "message": message})
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"success": False,
|
||||
"error": message
|
||||
}
|
||||
)
|
||||
return JSONResponse(status_code=400, content={"success": False, "error": message})
|
||||
|
||||
@@ -43,9 +43,9 @@ async def get_notifications(
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
if settings.server_url is not None:
|
||||
notification_endpoint = f"{settings.server_url}notification-server".replace(
|
||||
"http://", "ws://"
|
||||
).replace("https://", "wss://")
|
||||
notification_endpoint = f"{settings.server_url}notification-server".replace("http://", "ws://").replace(
|
||||
"https://", "wss://"
|
||||
)
|
||||
else:
|
||||
notification_endpoint = "/notification-server"
|
||||
query = select(UserNotification).where(
|
||||
@@ -96,21 +96,15 @@ async def _get_notifications(
|
||||
query = base_query.where(UserNotification.notification_id == identity.id)
|
||||
if identity.object_id is not None:
|
||||
query = base_query.where(
|
||||
col(UserNotification.notification).has(
|
||||
col(Notification.object_id) == identity.object_id
|
||||
)
|
||||
col(UserNotification.notification).has(col(Notification.object_id) == identity.object_id)
|
||||
)
|
||||
if identity.object_type is not None:
|
||||
query = base_query.where(
|
||||
col(UserNotification.notification).has(
|
||||
col(Notification.object_type) == identity.object_type
|
||||
)
|
||||
col(UserNotification.notification).has(col(Notification.object_type) == identity.object_type)
|
||||
)
|
||||
if identity.category is not None:
|
||||
query = base_query.where(
|
||||
col(UserNotification.notification).has(
|
||||
col(Notification.category) == identity.category
|
||||
)
|
||||
col(UserNotification.notification).has(col(Notification.category) == identity.category)
|
||||
)
|
||||
result.update({n.notification_id: n for n in await session.exec(query)})
|
||||
return list(result.values())
|
||||
@@ -134,7 +128,6 @@ async def mark_notifications_as_read(
|
||||
for user_notification in user_notifications:
|
||||
user_notification.is_read = True
|
||||
|
||||
assert current_user.id
|
||||
await server.send_event(
|
||||
current_user.id,
|
||||
ChatEvent(
|
||||
|
||||
@@ -91,9 +91,7 @@ class Bot:
|
||||
if reply:
|
||||
await self._send_reply(user, channel, reply, session)
|
||||
|
||||
async def _send_message(
|
||||
self, channel: ChatChannel, content: str, session: AsyncSession
|
||||
) -> None:
|
||||
async def _send_message(self, channel: ChatChannel, content: str, session: AsyncSession) -> None:
|
||||
bot = await session.get(User, self.bot_user_id)
|
||||
if bot is None:
|
||||
return
|
||||
@@ -101,7 +99,6 @@ class Bot:
|
||||
if channel_id is None:
|
||||
return
|
||||
|
||||
assert bot.id is not None
|
||||
msg = ChatMessage(
|
||||
channel_id=channel_id,
|
||||
content=content,
|
||||
@@ -115,9 +112,7 @@ class Bot:
|
||||
resp = await ChatMessageResp.from_db(msg, session, bot)
|
||||
await server.send_message_to_channel(resp)
|
||||
|
||||
async def _ensure_pm_channel(
|
||||
self, user: User, session: AsyncSession
|
||||
) -> ChatChannel | None:
|
||||
async def _ensure_pm_channel(self, user: User, session: AsyncSession) -> ChatChannel | None:
|
||||
user_id = user.id
|
||||
if user_id is None:
|
||||
return None
|
||||
@@ -160,9 +155,7 @@ bot = Bot()
|
||||
|
||||
|
||||
@bot.command("help")
|
||||
async def _help(
|
||||
user: User, args: list[str], _session: AsyncSession, channel: ChatChannel
|
||||
) -> str:
|
||||
async def _help(user: User, args: list[str], _session: AsyncSession, channel: ChatChannel) -> str:
|
||||
cmds = sorted(bot._handlers.keys())
|
||||
if args:
|
||||
target = args[0].lower()
|
||||
@@ -175,9 +168,7 @@ async def _help(
|
||||
|
||||
|
||||
@bot.command("roll")
|
||||
def _roll(
|
||||
user: User, args: list[str], _session: AsyncSession, channel: ChatChannel
|
||||
) -> str:
|
||||
def _roll(user: User, args: list[str], _session: AsyncSession, channel: ChatChannel) -> str:
|
||||
if len(args) > 0 and args[0].isdigit():
|
||||
r = random.randint(1, int(args[0]))
|
||||
else:
|
||||
@@ -186,13 +177,9 @@ def _roll(
|
||||
|
||||
|
||||
@bot.command("stats")
|
||||
async def _stats(
|
||||
user: User, args: list[str], session: AsyncSession, channel: ChatChannel
|
||||
) -> str:
|
||||
async def _stats(user: User, args: list[str], session: AsyncSession, channel: ChatChannel) -> str:
|
||||
if len(args) >= 1:
|
||||
target_user = (
|
||||
await session.exec(select(User).where(User.username == args[0]))
|
||||
).first()
|
||||
target_user = (await session.exec(select(User).where(User.username == args[0]))).first()
|
||||
if not target_user:
|
||||
return f"User '{args[0]}' not found."
|
||||
else:
|
||||
@@ -202,14 +189,8 @@ async def _stats(
|
||||
if len(args) >= 2:
|
||||
gamemode = GameMode.parse(args[1].upper())
|
||||
if gamemode is None:
|
||||
subquery = (
|
||||
select(func.max(Score.id))
|
||||
.where(Score.user_id == target_user.id)
|
||||
.scalar_subquery()
|
||||
)
|
||||
last_score = (
|
||||
await session.exec(select(Score).where(Score.id == subquery))
|
||||
).first()
|
||||
subquery = select(func.max(Score.id)).where(Score.user_id == target_user.id).scalar_subquery()
|
||||
last_score = (await session.exec(select(Score).where(Score.id == subquery))).first()
|
||||
if last_score is not None:
|
||||
gamemode = last_score.gamemode
|
||||
else:
|
||||
@@ -295,9 +276,7 @@ async def _mp_host(
|
||||
return "Usage: !mp host <username>"
|
||||
|
||||
username = args[0]
|
||||
user_id = (
|
||||
await session.exec(select(User.id).where(User.username == username))
|
||||
).first()
|
||||
user_id = (await session.exec(select(User.id).where(User.username == username))).first()
|
||||
if not user_id:
|
||||
return f"User '{username}' not found."
|
||||
|
||||
@@ -362,24 +341,18 @@ async def _mp_team(
|
||||
if team is None:
|
||||
return "Invalid team colour. Use 'red' or 'blue'."
|
||||
|
||||
user_id = (
|
||||
await session.exec(select(User.id).where(User.username == username))
|
||||
).first()
|
||||
user_id = (await session.exec(select(User.id).where(User.username == username))).first()
|
||||
if not user_id:
|
||||
return f"User '{username}' not found."
|
||||
user_client = MultiplayerHubs.get_client_by_id(str(user_id))
|
||||
if not user_client:
|
||||
return f"User '{username}' is not in the room."
|
||||
if (
|
||||
user_client.user_id != signalr_client.user_id
|
||||
and room.room.host.user_id != signalr_client.user_id
|
||||
):
|
||||
assert room.room.host
|
||||
if user_client.user_id != signalr_client.user_id and room.room.host.user_id != signalr_client.user_id:
|
||||
return "You are not allowed to change other users' teams."
|
||||
|
||||
try:
|
||||
await MultiplayerHubs.SendMatchRequest(
|
||||
user_client, ChangeTeamRequest(team_id=team)
|
||||
)
|
||||
await MultiplayerHubs.SendMatchRequest(user_client, ChangeTeamRequest(team_id=team))
|
||||
return ""
|
||||
except InvokeException as e:
|
||||
return e.message
|
||||
@@ -414,9 +387,7 @@ async def _mp_kick(
|
||||
return "Usage: !mp kick <username>"
|
||||
|
||||
username = args[0]
|
||||
user_id = (
|
||||
await session.exec(select(User.id).where(User.username == username))
|
||||
).first()
|
||||
user_id = (await session.exec(select(User.id).where(User.username == username))).first()
|
||||
if not user_id:
|
||||
return f"User '{username}' not found."
|
||||
|
||||
@@ -456,10 +427,7 @@ async def _mp_map(
|
||||
try:
|
||||
beatmap = await Beatmap.get_or_fetch(session, await get_fetcher(), bid=map_id)
|
||||
if beatmap.mode != GameMode.OSU and playmode and playmode != beatmap.mode:
|
||||
return (
|
||||
f"Cannot convert to {playmode.value}. "
|
||||
f"Original mode is {beatmap.mode.value}."
|
||||
)
|
||||
return f"Cannot convert to {playmode.value}. Original mode is {beatmap.mode.value}."
|
||||
except HTTPError:
|
||||
return "Beatmap not found"
|
||||
|
||||
@@ -530,9 +498,7 @@ async def _mp_mods(
|
||||
if freestyle:
|
||||
item.allowed_mods = []
|
||||
elif freemod:
|
||||
item.allowed_mods = get_available_mods(
|
||||
current_item.ruleset_id, required_mods
|
||||
)
|
||||
item.allowed_mods = get_available_mods(current_item.ruleset_id, required_mods)
|
||||
else:
|
||||
item.allowed_mods = allowed_mods
|
||||
item.required_mods = required_mods
|
||||
@@ -601,14 +567,9 @@ async def _score(
|
||||
include_fail: bool = False,
|
||||
gamemode: GameMode | None = None,
|
||||
) -> str:
|
||||
q = (
|
||||
select(Score)
|
||||
.where(Score.user_id == user_id)
|
||||
.order_by(col(Score.id).desc())
|
||||
.options(joinedload(Score.beatmap))
|
||||
)
|
||||
q = select(Score).where(Score.user_id == user_id).order_by(col(Score.id).desc()).options(joinedload(Score.beatmap))
|
||||
if not include_fail:
|
||||
q = q.where(Score.passed.is_(True))
|
||||
q = q.where(col(Score.passed).is_(True))
|
||||
if gamemode is not None:
|
||||
q = q.where(Score.gamemode == gamemode)
|
||||
|
||||
@@ -619,17 +580,13 @@ async def _score(
|
||||
result = f"""{score.beatmap.beatmapset.title} [{score.beatmap.version}] ({score.gamemode.name.lower()})
|
||||
Played at {score.started_at}
|
||||
{score.pp:.2f}pp {score.accuracy:.2%} {",".join(mod_to_save(score.mods))} {score.rank.name.upper()}
|
||||
Great: {score.n300}, Good: {score.n100}, Meh: {score.n50}, Miss: {score.nmiss}""" # noqa: E501
|
||||
Great: {score.n300}, Good: {score.n100}, Meh: {score.n50}, Miss: {score.nmiss}"""
|
||||
if score.gamemode == GameMode.MANIA:
|
||||
keys = next(
|
||||
(mod["acronym"] for mod in score.mods if mod["acronym"].endswith("K")), None
|
||||
)
|
||||
keys = next((mod["acronym"] for mod in score.mods if mod["acronym"].endswith("K")), None)
|
||||
if keys is None:
|
||||
keys = f"{int(score.beatmap.cs)}K"
|
||||
p_d_g = f"{score.ngeki / score.n300:.2f}:1" if score.n300 > 0 else "inf:1"
|
||||
result += (
|
||||
f"\nKeys: {keys}, Perfect: {score.ngeki}, Ok: {score.nkatu}, P/G: {p_d_g}"
|
||||
)
|
||||
result += f"\nKeys: {keys}, Perfect: {score.ngeki}, Ok: {score.nkatu}, P/G: {p_d_g}"
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@@ -38,27 +38,18 @@ class UpdateResponse(BaseModel):
|
||||
)
|
||||
async def get_update(
|
||||
session: Database,
|
||||
history_since: int | None = Query(
|
||||
None, description="获取自此禁言 ID 之后的禁言记录"
|
||||
),
|
||||
history_since: int | None = Query(None, description="获取自此禁言 ID 之后的禁言记录"),
|
||||
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
|
||||
includes: list[str] = Query(
|
||||
["presence", "silences"], alias="includes[]", description="要包含的更新类型"
|
||||
),
|
||||
includes: list[str] = Query(["presence", "silences"], alias="includes[]", description="要包含的更新类型"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
resp = UpdateResponse()
|
||||
if "presence" in includes:
|
||||
assert current_user.id
|
||||
channel_ids = server.get_user_joined_channel(current_user.id)
|
||||
for channel_id in channel_ids:
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
|
||||
)
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
|
||||
if db_channel:
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_type = db_channel.type
|
||||
@@ -69,34 +60,20 @@ async def get_update(
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel_id, [])
|
||||
if channel_type != ChannelType.PUBLIC
|
||||
else None,
|
||||
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
|
||||
)
|
||||
)
|
||||
if "silences" in includes:
|
||||
if history_since:
|
||||
silences = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(col(SilenceUser.id) > history_since)
|
||||
)
|
||||
).all()
|
||||
resp.silences.extend(
|
||||
[UserSilenceResp.from_db(silence) for silence in silences]
|
||||
)
|
||||
silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.id) > history_since))).all()
|
||||
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
|
||||
elif since:
|
||||
msg = await session.get(ChatMessage, since)
|
||||
if msg:
|
||||
silences = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(
|
||||
col(SilenceUser.banned_at) > msg.timestamp
|
||||
)
|
||||
)
|
||||
await session.exec(select(SilenceUser).where(col(SilenceUser.banned_at) > msg.timestamp))
|
||||
).all()
|
||||
resp.silences.extend(
|
||||
[UserSilenceResp.from_db(silence) for silence in silences]
|
||||
)
|
||||
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
|
||||
return resp
|
||||
|
||||
|
||||
@@ -115,15 +92,9 @@ async def join_channel(
|
||||
):
|
||||
# 使用明确的查询避免延迟加载
|
||||
if channel.isdigit():
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
@@ -145,15 +116,9 @@ async def leave_channel(
|
||||
):
|
||||
# 使用明确的查询避免延迟加载
|
||||
if channel.isdigit():
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
@@ -173,27 +138,20 @@ async def get_channel_list(
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
channels = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC)
|
||||
)
|
||||
).all()
|
||||
channels = (await session.exec(select(ChatChannel).where(ChatChannel.type == ChannelType.PUBLIC))).all()
|
||||
results = []
|
||||
for channel in channels:
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_id = channel.channel_id
|
||||
channel_type = channel.type
|
||||
|
||||
assert channel_id is not None
|
||||
results.append(
|
||||
await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel_id, [])
|
||||
if channel_type != ChannelType.PUBLIC
|
||||
else None,
|
||||
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
|
||||
)
|
||||
)
|
||||
return results
|
||||
@@ -219,15 +177,9 @@ async def get_channel(
|
||||
):
|
||||
# 使用明确的查询避免延迟加载
|
||||
if channel.isdigit():
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
@@ -237,8 +189,6 @@ async def get_channel(
|
||||
channel_type = db_channel.type
|
||||
channel_name = db_channel.name
|
||||
|
||||
assert channel_id is not None
|
||||
|
||||
users = []
|
||||
if channel_type == ChannelType.PM:
|
||||
user_ids = channel_name.split("_")[1:]
|
||||
@@ -259,9 +209,7 @@ async def get_channel(
|
||||
session,
|
||||
current_user,
|
||||
redis,
|
||||
server.channels.get(channel_id, [])
|
||||
if channel_type != ChannelType.PUBLIC
|
||||
else None,
|
||||
server.channels.get(channel_id, []) if channel_type != ChannelType.PUBLIC else None,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -284,9 +232,7 @@ class CreateChannelReq(BaseModel):
|
||||
raise ValueError("target_id must be set for PM channels")
|
||||
else:
|
||||
if self.target_ids is None or self.channel is None or self.message is None:
|
||||
raise ValueError(
|
||||
"target_ids, channel, and message must be set for ANNOUNCE channels"
|
||||
)
|
||||
raise ValueError("target_ids, channel, and message must be set for ANNOUNCE channels")
|
||||
return self
|
||||
|
||||
|
||||
@@ -312,24 +258,20 @@ async def create_channel(
|
||||
raise HTTPException(status_code=403, detail=block)
|
||||
|
||||
channel = await ChatChannel.get_pm_channel(
|
||||
current_user.id, # pyright: ignore[reportArgumentType]
|
||||
current_user.id,
|
||||
req.target_id, # pyright: ignore[reportArgumentType]
|
||||
session,
|
||||
)
|
||||
channel_name = f"pm_{current_user.id}_{req.target_id}"
|
||||
else:
|
||||
channel_name = req.channel.name if req.channel else "Unnamed Channel"
|
||||
result = await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.name == channel_name)
|
||||
)
|
||||
result = await session.exec(select(ChatChannel).where(ChatChannel.name == channel_name))
|
||||
channel = result.first()
|
||||
|
||||
if channel is None:
|
||||
channel = ChatChannel(
|
||||
name=channel_name,
|
||||
description=req.channel.description
|
||||
if req.channel
|
||||
else "Private message channel",
|
||||
description=req.channel.description if req.channel else "Private message channel",
|
||||
type=ChannelType.PM if req.type == "PM" else ChannelType.ANNOUNCE,
|
||||
)
|
||||
session.add(channel)
|
||||
@@ -340,16 +282,13 @@ async def create_channel(
|
||||
await session.refresh(target) # pyright: ignore[reportPossiblyUnboundVariable]
|
||||
await server.batch_join_channel([target, current_user], channel, session) # pyright: ignore[reportPossiblyUnboundVariable]
|
||||
else:
|
||||
target_users = await session.exec(
|
||||
select(User).where(col(User.id).in_(req.target_ids or []))
|
||||
)
|
||||
target_users = await session.exec(select(User).where(col(User.id).in_(req.target_ids or [])))
|
||||
await server.batch_join_channel([*target_users, current_user], channel, session)
|
||||
|
||||
await server.join_channel(current_user, channel, session)
|
||||
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_id = channel.channel_id
|
||||
assert channel_id
|
||||
|
||||
return await ChatChannelResp.from_db(
|
||||
channel,
|
||||
|
||||
@@ -41,33 +41,19 @@ class KeepAliveResp(BaseModel):
|
||||
)
|
||||
async def keep_alive(
|
||||
session: Database,
|
||||
history_since: int | None = Query(
|
||||
None, description="获取自此禁言 ID 之后的禁言记录"
|
||||
),
|
||||
history_since: int | None = Query(None, description="获取自此禁言 ID 之后的禁言记录"),
|
||||
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
):
|
||||
resp = KeepAliveResp()
|
||||
if history_since:
|
||||
silences = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(col(SilenceUser.id) > history_since)
|
||||
)
|
||||
).all()
|
||||
silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.id) > history_since))).all()
|
||||
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
|
||||
elif since:
|
||||
msg = await session.get(ChatMessage, since)
|
||||
if msg:
|
||||
silences = (
|
||||
await session.exec(
|
||||
select(SilenceUser).where(
|
||||
col(SilenceUser.banned_at) > msg.timestamp
|
||||
)
|
||||
)
|
||||
).all()
|
||||
resp.silences.extend(
|
||||
[UserSilenceResp.from_db(silence) for silence in silences]
|
||||
)
|
||||
silences = (await session.exec(select(SilenceUser).where(col(SilenceUser.banned_at) > msg.timestamp))).all()
|
||||
resp.silences.extend([UserSilenceResp.from_db(silence) for silence in silences])
|
||||
|
||||
return resp
|
||||
|
||||
@@ -93,15 +79,9 @@ async def send_message(
|
||||
):
|
||||
# 使用明确的查询来获取 channel,避免延迟加载
|
||||
if channel.isdigit():
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
@@ -111,9 +91,6 @@ async def send_message(
|
||||
channel_type = db_channel.type
|
||||
channel_name = db_channel.name
|
||||
|
||||
assert channel_id is not None
|
||||
assert current_user.id
|
||||
|
||||
# 使用 Redis 消息系统发送消息 - 立即返回
|
||||
resp = await redis_message_system.send_message(
|
||||
channel_id=channel_id,
|
||||
@@ -125,9 +102,7 @@ async def send_message(
|
||||
|
||||
# 立即广播消息给所有客户端
|
||||
is_bot_command = req.message.startswith("!")
|
||||
await server.send_message_to_channel(
|
||||
resp, is_bot_command and channel_type == ChannelType.PUBLIC
|
||||
)
|
||||
await server.send_message_to_channel(resp, is_bot_command and channel_type == ChannelType.PUBLIC)
|
||||
|
||||
# 处理机器人命令
|
||||
if is_bot_command:
|
||||
@@ -147,14 +122,10 @@ async def send_message(
|
||||
if channel_type == ChannelType.PM:
|
||||
user_ids = channel_name.split("_")[1:]
|
||||
await server.new_private_notification(
|
||||
ChannelMessage.init(
|
||||
temp_msg, current_user, [int(u) for u in user_ids], channel_type
|
||||
)
|
||||
ChannelMessage.init(temp_msg, current_user, [int(u) for u in user_ids], channel_type)
|
||||
)
|
||||
elif channel_type == ChannelType.TEAM:
|
||||
await server.new_private_notification(
|
||||
ChannelMessageTeam.init(temp_msg, current_user)
|
||||
)
|
||||
await server.new_private_notification(ChannelMessageTeam.init(temp_msg, current_user))
|
||||
|
||||
return resp
|
||||
|
||||
@@ -176,22 +147,15 @@ async def get_message(
|
||||
):
|
||||
# 使用明确的查询获取 channel,避免延迟加载
|
||||
if channel.isdigit():
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
|
||||
# 提取必要的属性避免惰性加载
|
||||
channel_id = db_channel.channel_id
|
||||
assert channel_id is not None
|
||||
|
||||
# 使用 Redis 消息系统获取消息
|
||||
try:
|
||||
@@ -230,23 +194,15 @@ async def mark_as_read(
|
||||
):
|
||||
# 使用明确的查询获取 channel,避免延迟加载
|
||||
if channel.isdigit():
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == int(channel))
|
||||
)
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
|
||||
else:
|
||||
db_channel = (
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.name == channel))
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.name == channel))).first()
|
||||
|
||||
if db_channel is None:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
|
||||
# 立即提取需要的属性
|
||||
channel_id = db_channel.channel_id
|
||||
assert channel_id
|
||||
assert current_user.id
|
||||
await server.mark_as_read(channel_id, current_user.id, message)
|
||||
|
||||
|
||||
@@ -283,7 +239,6 @@ async def create_new_pm(
|
||||
if not is_can_pm:
|
||||
raise HTTPException(status_code=403, detail=block)
|
||||
|
||||
assert user_id
|
||||
channel = await ChatChannel.get_pm_channel(user_id, req.target_id, session)
|
||||
if channel is None:
|
||||
channel = ChatChannel(
|
||||
@@ -297,7 +252,6 @@ async def create_new_pm(
|
||||
await session.refresh(target)
|
||||
await session.refresh(current_user)
|
||||
|
||||
assert channel.channel_id
|
||||
await server.batch_join_channel([target, current_user], channel, session)
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel, session, current_user, redis, server.channels[channel.channel_id]
|
||||
|
||||
@@ -17,6 +17,7 @@ from app.log import logger
|
||||
from app.models.chat import ChatEvent
|
||||
from app.models.notification import NotificationDetail
|
||||
from app.service.subscribers.chat import ChatSubscriber
|
||||
from app.utils import bg_tasks
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, WebSocket, WebSocketDisconnect
|
||||
from fastapi.security import SecurityScopes
|
||||
@@ -37,20 +38,11 @@ class ChatServer:
|
||||
self.ChatSubscriber.chat_server = self
|
||||
self._subscribed = False
|
||||
|
||||
def _add_task(self, task):
|
||||
task = asyncio.create_task(task)
|
||||
self.tasks.add(task)
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
|
||||
def connect(self, user_id: int, client: WebSocket):
|
||||
self.connect_client[user_id] = client
|
||||
|
||||
def get_user_joined_channel(self, user_id: int) -> list[int]:
|
||||
return [
|
||||
channel_id
|
||||
for channel_id, users in self.channels.items()
|
||||
if user_id in users
|
||||
]
|
||||
return [channel_id for channel_id, users in self.channels.items() if user_id in users]
|
||||
|
||||
async def disconnect(self, user: User, session: AsyncSession):
|
||||
user_id = user.id
|
||||
@@ -61,9 +53,7 @@ class ChatServer:
|
||||
channel.remove(user_id)
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
|
||||
)
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))
|
||||
).first()
|
||||
if db_channel:
|
||||
await self.leave_channel(user, db_channel, session)
|
||||
@@ -93,11 +83,10 @@ class ChatServer:
|
||||
async def mark_as_read(self, channel_id: int, user_id: int, message_id: int):
|
||||
await self.redis.set(f"chat:{channel_id}:last_read:{user_id}", message_id)
|
||||
|
||||
async def send_message_to_channel(
|
||||
self, message: ChatMessageResp, is_bot_command: bool = False
|
||||
):
|
||||
async def send_message_to_channel(self, message: ChatMessageResp, is_bot_command: bool = False):
|
||||
logger.info(
|
||||
f"Sending message to channel {message.channel_id}, message_id: {message.message_id}, is_bot_command: {is_bot_command}"
|
||||
f"Sending message to channel {message.channel_id}, message_id: "
|
||||
f"{message.message_id}, is_bot_command: {is_bot_command}"
|
||||
)
|
||||
|
||||
event = ChatEvent(
|
||||
@@ -106,62 +95,44 @@ class ChatServer:
|
||||
)
|
||||
if is_bot_command:
|
||||
logger.info(f"Sending bot command to user {message.sender_id}")
|
||||
self._add_task(self.send_event(message.sender_id, event))
|
||||
bg_tasks.add_task(self.send_event, message.sender_id, event)
|
||||
else:
|
||||
# 总是广播消息,无论是临时ID还是真实ID
|
||||
logger.info(
|
||||
f"Broadcasting message to all users in channel {message.channel_id}"
|
||||
)
|
||||
self._add_task(
|
||||
self.broadcast(
|
||||
message.channel_id,
|
||||
event,
|
||||
)
|
||||
logger.info(f"Broadcasting message to all users in channel {message.channel_id}")
|
||||
bg_tasks.add_task(
|
||||
self.broadcast,
|
||||
message.channel_id,
|
||||
event,
|
||||
)
|
||||
|
||||
# 只有真实消息 ID(正数且非零)才进行标记已读和设置最后消息
|
||||
# Redis 消息系统生成的ID都是正数,所以这里应该都能正常处理
|
||||
if message.message_id and message.message_id > 0:
|
||||
await self.mark_as_read(
|
||||
message.channel_id, message.sender_id, message.message_id
|
||||
)
|
||||
await self.redis.set(
|
||||
f"chat:{message.channel_id}:last_msg", message.message_id
|
||||
)
|
||||
logger.info(
|
||||
f"Updated last message ID for channel {message.channel_id} to {message.message_id}"
|
||||
)
|
||||
await self.mark_as_read(message.channel_id, message.sender_id, message.message_id)
|
||||
await self.redis.set(f"chat:{message.channel_id}:last_msg", message.message_id)
|
||||
logger.info(f"Updated last message ID for channel {message.channel_id} to {message.message_id}")
|
||||
else:
|
||||
logger.debug(
|
||||
f"Skipping last message update for message ID: {message.message_id}"
|
||||
)
|
||||
logger.debug(f"Skipping last message update for message ID: {message.message_id}")
|
||||
|
||||
async def batch_join_channel(
|
||||
self, users: list[User], channel: ChatChannel, session: AsyncSession
|
||||
):
|
||||
async def batch_join_channel(self, users: list[User], channel: ChatChannel, session: AsyncSession):
|
||||
channel_id = channel.channel_id
|
||||
assert channel_id is not None
|
||||
|
||||
not_joined = []
|
||||
|
||||
if channel_id not in self.channels:
|
||||
self.channels[channel_id] = []
|
||||
for user in users:
|
||||
assert user.id is not None
|
||||
if user.id not in self.channels[channel_id]:
|
||||
self.channels[channel_id].append(user.id)
|
||||
not_joined.append(user)
|
||||
|
||||
for user in not_joined:
|
||||
assert user.id is not None
|
||||
channel_resp = await ChatChannelResp.from_db(
|
||||
channel,
|
||||
session,
|
||||
user,
|
||||
self.redis,
|
||||
self.channels[channel_id]
|
||||
if channel.type != ChannelType.PUBLIC
|
||||
else None,
|
||||
self.channels[channel_id] if channel.type != ChannelType.PUBLIC else None,
|
||||
)
|
||||
await self.send_event(
|
||||
user.id,
|
||||
@@ -171,13 +142,9 @@ class ChatServer:
|
||||
),
|
||||
)
|
||||
|
||||
async def join_channel(
|
||||
self, user: User, channel: ChatChannel, session: AsyncSession
|
||||
) -> ChatChannelResp:
|
||||
async def join_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> ChatChannelResp:
|
||||
user_id = user.id
|
||||
channel_id = channel.channel_id
|
||||
assert channel_id is not None
|
||||
assert user_id is not None
|
||||
|
||||
if channel_id not in self.channels:
|
||||
self.channels[channel_id] = []
|
||||
@@ -202,13 +169,9 @@ class ChatServer:
|
||||
|
||||
return channel_resp
|
||||
|
||||
async def leave_channel(
|
||||
self, user: User, channel: ChatChannel, session: AsyncSession
|
||||
) -> None:
|
||||
async def leave_channel(self, user: User, channel: ChatChannel, session: AsyncSession) -> None:
|
||||
user_id = user.id
|
||||
channel_id = channel.channel_id
|
||||
assert channel_id is not None
|
||||
assert user_id is not None
|
||||
|
||||
if channel_id in self.channels and user_id in self.channels[channel_id]:
|
||||
self.channels[channel_id].remove(user_id)
|
||||
@@ -221,9 +184,7 @@ class ChatServer:
|
||||
session,
|
||||
user,
|
||||
self.redis,
|
||||
self.channels.get(channel_id)
|
||||
if channel.type != ChannelType.PUBLIC
|
||||
else None,
|
||||
self.channels.get(channel_id) if channel.type != ChannelType.PUBLIC else None,
|
||||
)
|
||||
await self.send_event(
|
||||
user_id,
|
||||
@@ -236,11 +197,7 @@ class ChatServer:
|
||||
async def join_room_channel(self, channel_id: int, user_id: int):
|
||||
async with with_db() as session:
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
|
||||
)
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
|
||||
if db_channel is None:
|
||||
return
|
||||
|
||||
@@ -253,11 +210,7 @@ class ChatServer:
|
||||
async def leave_room_channel(self, channel_id: int, user_id: int):
|
||||
async with with_db() as session:
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (
|
||||
await session.exec(
|
||||
select(ChatChannel).where(ChatChannel.channel_id == channel_id)
|
||||
)
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == channel_id))).first()
|
||||
if db_channel is None:
|
||||
return
|
||||
|
||||
@@ -270,13 +223,7 @@ class ChatServer:
|
||||
async def new_private_notification(self, detail: NotificationDetail):
|
||||
async with with_db() as session:
|
||||
id = await insert_notification(session, detail)
|
||||
users = (
|
||||
await session.exec(
|
||||
select(UserNotification).where(
|
||||
UserNotification.notification_id == id
|
||||
)
|
||||
)
|
||||
).all()
|
||||
users = (await session.exec(select(UserNotification).where(UserNotification.notification_id == id))).all()
|
||||
for user_notification in users:
|
||||
data = user_notification.notification.model_dump()
|
||||
data["is_read"] = user_notification.is_read
|
||||
@@ -308,9 +255,7 @@ async def _listen_stop(ws: WebSocket, user_id: int, factory: DBFactory):
|
||||
await ws.close(code=1000)
|
||||
break
|
||||
except WebSocketDisconnect as e:
|
||||
logger.info(
|
||||
f"[NotificationServer] Client {user_id} disconnected: {e.code}, {e.reason}"
|
||||
)
|
||||
logger.info(f"[NotificationServer] Client {user_id} disconnected: {e.code}, {e.reason}")
|
||||
except RuntimeError as e:
|
||||
if "disconnect message" in str(e):
|
||||
logger.info(f"[NotificationServer] Client {user_id} closed the connection.")
|
||||
@@ -332,11 +277,7 @@ async def chat_websocket(
|
||||
|
||||
async for session in factory():
|
||||
token = authorization[7:]
|
||||
if (
|
||||
user := await get_current_user(
|
||||
session, SecurityScopes(scopes=["chat.read"]), token_pw=token
|
||||
)
|
||||
) is None:
|
||||
if (user := await get_current_user(session, SecurityScopes(scopes=["chat.read"]), token_pw=token)) is None:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
@@ -346,12 +287,9 @@ async def chat_websocket(
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
user_id = user.id
|
||||
assert user_id
|
||||
server.connect(user_id, websocket)
|
||||
# 使用明确的查询避免延迟加载
|
||||
db_channel = (
|
||||
await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))
|
||||
).first()
|
||||
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == 1))).first()
|
||||
if db_channel is not None:
|
||||
await server.join_channel(user, db_channel, session)
|
||||
|
||||
|
||||
@@ -2,22 +2,20 @@
|
||||
密码重置管理接口
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from redis.asyncio import Redis
|
||||
from __future__ import annotations
|
||||
|
||||
from app.dependencies.database import get_redis
|
||||
from app.service.password_reset_service import password_reset_service
|
||||
from app.log import logger
|
||||
from app.service.password_reset_service import password_reset_service
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
from redis.asyncio import Redis
|
||||
|
||||
router = APIRouter(prefix="/admin/password-reset", tags=["密码重置管理"])
|
||||
|
||||
|
||||
@router.get(
|
||||
"/status/{email}",
|
||||
name="查询重置状态",
|
||||
description="查询指定邮箱的密码重置状态"
|
||||
)
|
||||
@router.get("/status/{email}", name="查询重置状态", description="查询指定邮箱的密码重置状态")
|
||||
async def get_password_reset_status(
|
||||
email: str,
|
||||
redis: Redis = Depends(get_redis),
|
||||
@@ -25,28 +23,16 @@ async def get_password_reset_status(
|
||||
"""查询密码重置状态"""
|
||||
try:
|
||||
info = await password_reset_service.get_reset_code_info(email, redis)
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"success": True,
|
||||
"data": info
|
||||
}
|
||||
)
|
||||
return JSONResponse(status_code=200, content={"success": True, "data": info})
|
||||
except Exception as e:
|
||||
logger.error(f"[Admin] Failed to get password reset status for {email}: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"error": "获取状态失败"
|
||||
}
|
||||
)
|
||||
return JSONResponse(status_code=500, content={"success": False, "error": "获取状态失败"})
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/cleanup/{email}",
|
||||
name="清理重置数据",
|
||||
description="强制清理指定邮箱的密码重置数据"
|
||||
description="强制清理指定邮箱的密码重置数据",
|
||||
)
|
||||
async def force_cleanup_reset(
|
||||
email: str,
|
||||
@@ -55,38 +41,23 @@ async def force_cleanup_reset(
|
||||
"""强制清理密码重置数据"""
|
||||
try:
|
||||
success = await password_reset_service.force_cleanup_user_reset(email, redis)
|
||||
|
||||
|
||||
if success:
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"success": True,
|
||||
"message": f"已清理邮箱 {email} 的重置数据"
|
||||
}
|
||||
content={"success": True, "message": f"已清理邮箱 {email} 的重置数据"},
|
||||
)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"error": "清理失败"
|
||||
}
|
||||
)
|
||||
return JSONResponse(status_code=500, content={"success": False, "error": "清理失败"})
|
||||
except Exception as e:
|
||||
logger.error(f"[Admin] Failed to cleanup password reset for {email}: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"error": "清理操作失败"
|
||||
}
|
||||
)
|
||||
return JSONResponse(status_code=500, content={"success": False, "error": "清理操作失败"})
|
||||
|
||||
|
||||
@router.post(
|
||||
"/cleanup/expired",
|
||||
name="清理过期验证码",
|
||||
description="清理所有过期的密码重置验证码"
|
||||
description="清理所有过期的密码重置验证码",
|
||||
)
|
||||
async def cleanup_expired_codes(
|
||||
redis: Redis = Depends(get_redis),
|
||||
@@ -99,25 +70,15 @@ async def cleanup_expired_codes(
|
||||
content={
|
||||
"success": True,
|
||||
"message": f"已清理 {count} 个过期的验证码",
|
||||
"cleaned_count": count
|
||||
}
|
||||
"cleaned_count": count,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Admin] Failed to cleanup expired codes: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"error": "清理操作失败"
|
||||
}
|
||||
)
|
||||
return JSONResponse(status_code=500, content={"success": False, "error": "清理操作失败"})
|
||||
|
||||
|
||||
@router.get(
|
||||
"/stats",
|
||||
name="重置统计",
|
||||
description="获取密码重置的统计信息"
|
||||
)
|
||||
@router.get("/stats", name="重置统计", description="获取密码重置的统计信息")
|
||||
async def get_reset_statistics(
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
@@ -126,53 +87,42 @@ async def get_reset_statistics(
|
||||
# 获取所有重置相关的键
|
||||
reset_keys = await redis.keys("password_reset:code:*")
|
||||
rate_limit_keys = await redis.keys("password_reset:rate_limit:*")
|
||||
|
||||
|
||||
active_resets = 0
|
||||
used_resets = 0
|
||||
active_rate_limits = 0
|
||||
|
||||
|
||||
# 统计活跃重置
|
||||
for key in reset_keys:
|
||||
data_str = await redis.get(key)
|
||||
if data_str:
|
||||
try:
|
||||
import json
|
||||
|
||||
data = json.loads(data_str)
|
||||
if data.get("used", False):
|
||||
used_resets += 1
|
||||
else:
|
||||
active_resets += 1
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# 统计频率限制
|
||||
for key in rate_limit_keys:
|
||||
ttl = await redis.ttl(key)
|
||||
if ttl > 0:
|
||||
active_rate_limits += 1
|
||||
|
||||
|
||||
stats = {
|
||||
"total_reset_codes": len(reset_keys),
|
||||
"active_resets": active_resets,
|
||||
"used_resets": used_resets,
|
||||
"active_rate_limits": active_rate_limits,
|
||||
"total_rate_limit_keys": len(rate_limit_keys)
|
||||
"total_rate_limit_keys": len(rate_limit_keys),
|
||||
}
|
||||
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"success": True,
|
||||
"data": stats
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return JSONResponse(status_code=200, content={"success": True, "data": stats})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Admin] Failed to get reset statistics: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"error": "获取统计信息失败"
|
||||
}
|
||||
)
|
||||
return JSONResponse(status_code=500, content={"success": False, "error": "获取统计信息失败"})
|
||||
|
||||
@@ -26,7 +26,7 @@ async def create_oauth_app(
|
||||
redirect_uris: list[str] = Body(..., description="允许的重定向 URI 列表"),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
result = await session.execute( # pyright: ignore[reportDeprecated]
|
||||
result = await session.execute(
|
||||
text(
|
||||
"SELECT AUTO_INCREMENT FROM information_schema.TABLES "
|
||||
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'oauth_clients'"
|
||||
@@ -84,9 +84,7 @@ async def get_user_oauth_apps(
|
||||
session: Database,
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
oauth_apps = await session.exec(
|
||||
select(OAuthClient).where(OAuthClient.owner_id == current_user.id)
|
||||
)
|
||||
oauth_apps = await session.exec(select(OAuthClient).where(OAuthClient.owner_id == current_user.id))
|
||||
return [
|
||||
{
|
||||
"name": app.name,
|
||||
@@ -113,13 +111,9 @@ async def delete_oauth_app(
|
||||
if not oauth_client:
|
||||
raise HTTPException(status_code=404, detail="OAuth app not found")
|
||||
if oauth_client.owner_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Forbidden: Not the owner of this app"
|
||||
)
|
||||
raise HTTPException(status_code=403, detail="Forbidden: Not the owner of this app")
|
||||
|
||||
tokens = await session.exec(
|
||||
select(OAuthToken).where(OAuthToken.client_id == client_id)
|
||||
)
|
||||
tokens = await session.exec(select(OAuthToken).where(OAuthToken.client_id == client_id))
|
||||
for token in tokens:
|
||||
await session.delete(token)
|
||||
|
||||
@@ -144,9 +138,7 @@ async def update_oauth_app(
|
||||
if not oauth_client:
|
||||
raise HTTPException(status_code=404, detail="OAuth app not found")
|
||||
if oauth_client.owner_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Forbidden: Not the owner of this app"
|
||||
)
|
||||
raise HTTPException(status_code=403, detail="Forbidden: Not the owner of this app")
|
||||
|
||||
oauth_client.name = name
|
||||
oauth_client.description = description
|
||||
@@ -176,14 +168,10 @@ async def refresh_secret(
|
||||
if not oauth_client:
|
||||
raise HTTPException(status_code=404, detail="OAuth app not found")
|
||||
if oauth_client.owner_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Forbidden: Not the owner of this app"
|
||||
)
|
||||
raise HTTPException(status_code=403, detail="Forbidden: Not the owner of this app")
|
||||
|
||||
oauth_client.client_secret = secrets.token_hex()
|
||||
tokens = await session.exec(
|
||||
select(OAuthToken).where(OAuthToken.client_id == client_id)
|
||||
)
|
||||
tokens = await session.exec(select(OAuthToken).where(OAuthToken.client_id == client_id))
|
||||
for token in tokens:
|
||||
await session.delete(token)
|
||||
|
||||
@@ -215,9 +203,7 @@ async def generate_oauth_code(
|
||||
raise HTTPException(status_code=404, detail="OAuth app not found")
|
||||
|
||||
if redirect_uri not in client.redirect_uris:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Redirect URI not allowed for this client"
|
||||
)
|
||||
raise HTTPException(status_code=403, detail="Redirect URI not allowed for this client")
|
||||
|
||||
code = secrets.token_urlsafe(80)
|
||||
await redis.hset( # pyright: ignore[reportGeneralTypeIssues]
|
||||
|
||||
@@ -50,12 +50,8 @@ async def check_user_relationship(
|
||||
)
|
||||
).first()
|
||||
|
||||
is_followed = bool(
|
||||
target_relationship and target_relationship.type == RelationshipType.FOLLOW
|
||||
)
|
||||
is_following = bool(
|
||||
my_relationship and my_relationship.type == RelationshipType.FOLLOW
|
||||
)
|
||||
is_followed = bool(target_relationship and target_relationship.type == RelationshipType.FOLLOW)
|
||||
is_following = bool(my_relationship and my_relationship.type == RelationshipType.FOLLOW)
|
||||
|
||||
return CheckResponse(
|
||||
is_followed=is_followed,
|
||||
|
||||
@@ -40,16 +40,13 @@ async def create_team(
|
||||
支持的图片格式: PNG、JPEG、GIF
|
||||
"""
|
||||
user_id = current_user.id
|
||||
assert user_id
|
||||
if (await current_user.awaitable_attrs.team_membership) is not None:
|
||||
raise HTTPException(status_code=403, detail="You are already in a team")
|
||||
|
||||
is_existed = (await session.exec(select(exists()).where(Team.name == name))).first()
|
||||
if is_existed:
|
||||
raise HTTPException(status_code=409, detail="Name already exists")
|
||||
is_existed = (
|
||||
await session.exec(select(exists()).where(Team.short_name == short_name))
|
||||
).first()
|
||||
is_existed = (await session.exec(select(exists()).where(Team.short_name == short_name))).first()
|
||||
if is_existed:
|
||||
raise HTTPException(status_code=409, detail="Short name already exists")
|
||||
|
||||
@@ -101,7 +98,6 @@ async def update_team(
|
||||
"""
|
||||
team = await session.get(Team, team_id)
|
||||
user_id = current_user.id
|
||||
assert user_id
|
||||
if not team:
|
||||
raise HTTPException(status_code=404, detail="Team not found")
|
||||
if team.leader_id != user_id:
|
||||
@@ -110,9 +106,7 @@ async def update_team(
|
||||
is_existed = (await session.exec(select(exists()).where(Team.name == name))).first()
|
||||
if is_existed:
|
||||
raise HTTPException(status_code=409, detail="Name already exists")
|
||||
is_existed = (
|
||||
await session.exec(select(exists()).where(Team.short_name == short_name))
|
||||
).first()
|
||||
is_existed = (await session.exec(select(exists()).where(Team.short_name == short_name))).first()
|
||||
if is_existed:
|
||||
raise HTTPException(status_code=409, detail="Short name already exists")
|
||||
|
||||
@@ -132,20 +126,12 @@ async def update_team(
|
||||
team.cover_url = await storage.get_file_url(storage_path)
|
||||
|
||||
if leader_id is not None:
|
||||
if not (
|
||||
await session.exec(select(exists()).where(User.id == leader_id))
|
||||
).first():
|
||||
if not (await session.exec(select(exists()).where(User.id == leader_id))).first():
|
||||
raise HTTPException(status_code=404, detail="Leader not found")
|
||||
if not (
|
||||
await session.exec(
|
||||
select(TeamMember).where(
|
||||
TeamMember.user_id == leader_id, TeamMember.team_id == team.id
|
||||
)
|
||||
)
|
||||
await session.exec(select(TeamMember).where(TeamMember.user_id == leader_id, TeamMember.team_id == team.id))
|
||||
).first():
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Leader is not a member of the team"
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="Leader is not a member of the team")
|
||||
team.leader_id = leader_id
|
||||
|
||||
await session.commit()
|
||||
@@ -166,9 +152,7 @@ async def delete_team(
|
||||
if team.leader_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="You are not the team leader")
|
||||
|
||||
team_members = await session.exec(
|
||||
select(TeamMember).where(TeamMember.team_id == team_id)
|
||||
)
|
||||
team_members = await session.exec(select(TeamMember).where(TeamMember.team_id == team_id))
|
||||
for member in team_members:
|
||||
await session.delete(member)
|
||||
|
||||
@@ -186,15 +170,10 @@ async def get_team(
|
||||
session: Database,
|
||||
team_id: int = Path(..., description="战队 ID"),
|
||||
):
|
||||
members = (
|
||||
await session.exec(select(TeamMember).where(TeamMember.team_id == team_id))
|
||||
).all()
|
||||
members = (await session.exec(select(TeamMember).where(TeamMember.team_id == team_id))).all()
|
||||
return TeamQueryResp(
|
||||
team=members[0].team,
|
||||
members=[
|
||||
await UserResp.from_db(m.user, session, include=BASE_INCLUDES)
|
||||
for m in members
|
||||
],
|
||||
members=[await UserResp.from_db(m.user, session, include=BASE_INCLUDES) for m in members],
|
||||
)
|
||||
|
||||
|
||||
@@ -213,15 +192,11 @@ async def request_join_team(
|
||||
|
||||
if (
|
||||
await session.exec(
|
||||
select(exists()).where(
|
||||
TeamRequest.team_id == team_id, TeamRequest.user_id == current_user.id
|
||||
)
|
||||
select(exists()).where(TeamRequest.team_id == team_id, TeamRequest.user_id == current_user.id)
|
||||
)
|
||||
).first():
|
||||
raise HTTPException(status_code=409, detail="Join request already exists")
|
||||
team_request = TeamRequest(
|
||||
user_id=current_user.id, team_id=team_id, requested_at=datetime.now(UTC)
|
||||
)
|
||||
team_request = TeamRequest(user_id=current_user.id, team_id=team_id, requested_at=datetime.now(UTC))
|
||||
session.add(team_request)
|
||||
await session.commit()
|
||||
await session.refresh(team_request)
|
||||
@@ -229,9 +204,7 @@ async def request_join_team(
|
||||
|
||||
|
||||
@router.post("/team/{team_id}/{user_id}/request", name="接受加入请求", status_code=204)
|
||||
@router.delete(
|
||||
"/team/{team_id}/{user_id}/request", name="拒绝加入请求", status_code=204
|
||||
)
|
||||
@router.delete("/team/{team_id}/{user_id}/request", name="拒绝加入请求", status_code=204)
|
||||
async def handle_request(
|
||||
req: Request,
|
||||
session: Database,
|
||||
@@ -247,11 +220,7 @@ async def handle_request(
|
||||
raise HTTPException(status_code=403, detail="You are not the team leader")
|
||||
|
||||
team_request = (
|
||||
await session.exec(
|
||||
select(TeamRequest).where(
|
||||
TeamRequest.team_id == team_id, TeamRequest.user_id == user_id
|
||||
)
|
||||
)
|
||||
await session.exec(select(TeamRequest).where(TeamRequest.team_id == team_id, TeamRequest.user_id == user_id))
|
||||
).first()
|
||||
if not team_request:
|
||||
raise HTTPException(status_code=404, detail="Join request not found")
|
||||
@@ -261,16 +230,10 @@ async def handle_request(
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
if req.method == "POST":
|
||||
if (
|
||||
await session.exec(select(exists()).where(TeamMember.user_id == user_id))
|
||||
).first():
|
||||
raise HTTPException(
|
||||
status_code=409, detail="User is already a member of the team"
|
||||
)
|
||||
if (await session.exec(select(exists()).where(TeamMember.user_id == user_id))).first():
|
||||
raise HTTPException(status_code=409, detail="User is already a member of the team")
|
||||
|
||||
session.add(
|
||||
TeamMember(user_id=user_id, team_id=team_id, joined_at=datetime.now(UTC))
|
||||
)
|
||||
session.add(TeamMember(user_id=user_id, team_id=team_id, joined_at=datetime.now(UTC)))
|
||||
|
||||
await server.new_private_notification(TeamApplicationAccept.init(team_request))
|
||||
else:
|
||||
@@ -294,19 +257,13 @@ async def kick_member(
|
||||
raise HTTPException(status_code=403, detail="You are not the team leader")
|
||||
|
||||
team_member = (
|
||||
await session.exec(
|
||||
select(TeamMember).where(
|
||||
TeamMember.team_id == team_id, TeamMember.user_id == user_id
|
||||
)
|
||||
)
|
||||
await session.exec(select(TeamMember).where(TeamMember.team_id == team_id, TeamMember.user_id == user_id))
|
||||
).first()
|
||||
if not team_member:
|
||||
raise HTTPException(status_code=404, detail="User is not a member of the team")
|
||||
|
||||
if team.leader_id == current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You cannot leave because you are the team leader"
|
||||
)
|
||||
raise HTTPException(status_code=403, detail="You cannot leave because you are the team leader")
|
||||
|
||||
await session.delete(team_member)
|
||||
await session.commit()
|
||||
|
||||
@@ -35,10 +35,7 @@ async def user_rename(
|
||||
返回:
|
||||
- 成功: None
|
||||
"""
|
||||
assert current_user is not None
|
||||
samename_user = (
|
||||
await session.exec(select(User).where(User.username == new_name))
|
||||
).first()
|
||||
samename_user = (await session.exec(select(User).where(User.username == new_name))).first()
|
||||
if samename_user:
|
||||
raise HTTPException(409, "Username Exisits")
|
||||
errors = validate_username(new_name)
|
||||
|
||||
@@ -106,9 +106,7 @@ class V1Beatmap(AllStrModel):
|
||||
await session.exec(
|
||||
select(func.count())
|
||||
.select_from(FavouriteBeatmapset)
|
||||
.where(
|
||||
FavouriteBeatmapset.beatmapset_id == db_beatmap.beatmapset.id
|
||||
)
|
||||
.where(FavouriteBeatmapset.beatmapset_id == db_beatmap.beatmapset.id)
|
||||
)
|
||||
).one(),
|
||||
rating=0, # TODO
|
||||
@@ -154,12 +152,8 @@ async def get_beatmaps(
|
||||
beatmapset_id: int | None = Query(None, alias="s", description="谱面集 ID"),
|
||||
beatmap_id: int | None = Query(None, alias="b", description="谱面 ID"),
|
||||
user: str | None = Query(None, alias="u", description="谱师"),
|
||||
type: Literal["string", "id"] | None = Query(
|
||||
None, description="用户类型:string 用户名称 / id 用户 ID"
|
||||
),
|
||||
ruleset_id: int | None = Query(
|
||||
None, alias="m", description="Ruleset ID", ge=0, le=3
|
||||
), # TODO
|
||||
type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"),
|
||||
ruleset_id: int | None = Query(None, alias="m", description="Ruleset ID", ge=0, le=3), # TODO
|
||||
convert: bool = Query(False, alias="a", description="转谱"), # TODO
|
||||
checksum: str | None = Query(None, alias="h", description="谱面文件 MD5"),
|
||||
limit: int = Query(500, ge=1, le=500, description="返回结果数量限制"),
|
||||
@@ -181,11 +175,7 @@ async def get_beatmaps(
|
||||
else:
|
||||
beatmaps = beatmapset.beatmaps
|
||||
elif user is not None:
|
||||
where = (
|
||||
Beatmapset.user_id == user
|
||||
if type == "id" or user.isdigit()
|
||||
else Beatmapset.creator == user
|
||||
)
|
||||
where = Beatmapset.user_id == user if type == "id" or user.isdigit() else Beatmapset.creator == user
|
||||
beatmapsets = (await session.exec(select(Beatmapset).where(where))).all()
|
||||
for beatmapset in beatmapsets:
|
||||
if len(beatmaps) >= limit:
|
||||
@@ -193,11 +183,7 @@ async def get_beatmaps(
|
||||
beatmaps.extend(beatmapset.beatmaps)
|
||||
elif since is not None:
|
||||
beatmapsets = (
|
||||
await session.exec(
|
||||
select(Beatmapset)
|
||||
.where(col(Beatmapset.ranked_date) > since)
|
||||
.limit(limit)
|
||||
)
|
||||
await session.exec(select(Beatmapset).where(col(Beatmapset.ranked_date) > since).limit(limit))
|
||||
).all()
|
||||
for beatmapset in beatmapsets:
|
||||
if len(beatmaps) >= limit:
|
||||
@@ -214,11 +200,7 @@ async def get_beatmaps(
|
||||
redis,
|
||||
fetcher,
|
||||
)
|
||||
results.append(
|
||||
await V1Beatmap.from_db(
|
||||
session, beatmap, attrs.aim_difficulty, attrs.speed_difficulty
|
||||
)
|
||||
)
|
||||
results.append(await V1Beatmap.from_db(session, beatmap, attrs.aim_difficulty, attrs.speed_difficulty))
|
||||
continue
|
||||
except Exception:
|
||||
...
|
||||
|
||||
@@ -41,9 +41,7 @@ async def download_replay(
|
||||
ge=0,
|
||||
),
|
||||
score_id: int | None = Query(None, alias="s", description="成绩 ID"),
|
||||
type: Literal["string", "id"] | None = Query(
|
||||
None, description="用户类型:string 用户名称 / id 用户 ID"
|
||||
),
|
||||
type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"),
|
||||
mods: int = Query(0, description="成绩的 MOD"),
|
||||
storage_service: StorageService = Depends(get_storage_service),
|
||||
):
|
||||
@@ -58,13 +56,9 @@ async def download_replay(
|
||||
await session.exec(
|
||||
select(Score).where(
|
||||
Score.beatmap_id == beatmap,
|
||||
Score.user_id == user
|
||||
if type == "id" or user.isdigit()
|
||||
else col(Score.user).has(username=user),
|
||||
Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user),
|
||||
Score.mods == mods_,
|
||||
Score.gamemode == GameMode.from_int_extra(ruleset_id)
|
||||
if ruleset_id is not None
|
||||
else True,
|
||||
Score.gamemode == GameMode.from_int_extra(ruleset_id) if ruleset_id is not None else True,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
@@ -73,10 +67,7 @@ async def download_replay(
|
||||
except KeyError:
|
||||
raise HTTPException(status_code=400, detail="Invalid request")
|
||||
|
||||
filepath = (
|
||||
f"replays/{score_record.id}_{score_record.beatmap_id}"
|
||||
f"_{score_record.user_id}_lazer_replay.osr"
|
||||
)
|
||||
filepath = f"replays/{score_record.id}_{score_record.beatmap_id}_{score_record.user_id}_lazer_replay.osr"
|
||||
if not await storage_service.is_exists(filepath):
|
||||
raise HTTPException(status_code=404, detail="Replay file not found")
|
||||
|
||||
@@ -100,6 +91,4 @@ async def download_replay(
|
||||
await session.commit()
|
||||
|
||||
data = await storage_service.read_file(filepath)
|
||||
return ReplayModel(
|
||||
content=base64.b64encode(data).decode("utf-8"), encoding="base64"
|
||||
)
|
||||
return ReplayModel(content=base64.b64encode(data).decode("utf-8"), encoding="base64")
|
||||
|
||||
@@ -8,9 +8,7 @@ from app.dependencies.user import v1_authorize
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel, field_serializer
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/v1", dependencies=[Depends(v1_authorize)], tags=["V1 API"]
|
||||
)
|
||||
router = APIRouter(prefix="/api/v1", dependencies=[Depends(v1_authorize)], tags=["V1 API"])
|
||||
|
||||
|
||||
class AllStrModel(BaseModel):
|
||||
|
||||
@@ -70,9 +70,7 @@ async def get_user_best(
|
||||
session: Database,
|
||||
user: str = Query(..., alias="u", description="用户"),
|
||||
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
|
||||
type: Literal["string", "id"] | None = Query(
|
||||
None, description="用户类型:string 用户名称 / id 用户 ID"
|
||||
),
|
||||
type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"),
|
||||
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
|
||||
):
|
||||
try:
|
||||
@@ -80,9 +78,7 @@ async def get_user_best(
|
||||
await session.exec(
|
||||
select(Score)
|
||||
.where(
|
||||
Score.user_id == user
|
||||
if type == "id" or user.isdigit()
|
||||
else col(Score.user).has(username=user),
|
||||
Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user),
|
||||
Score.gamemode == GameMode.from_int_extra(ruleset_id),
|
||||
exists().where(col(PPBestScore.score_id) == Score.id),
|
||||
)
|
||||
@@ -106,9 +102,7 @@ async def get_user_recent(
|
||||
session: Database,
|
||||
user: str = Query(..., alias="u", description="用户"),
|
||||
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
|
||||
type: Literal["string", "id"] | None = Query(
|
||||
None, description="用户类型:string 用户名称 / id 用户 ID"
|
||||
),
|
||||
type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"),
|
||||
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
|
||||
):
|
||||
try:
|
||||
@@ -116,9 +110,7 @@ async def get_user_recent(
|
||||
await session.exec(
|
||||
select(Score)
|
||||
.where(
|
||||
Score.user_id == user
|
||||
if type == "id" or user.isdigit()
|
||||
else col(Score.user).has(username=user),
|
||||
Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user),
|
||||
Score.gamemode == GameMode.from_int_extra(ruleset_id),
|
||||
Score.ended_at > datetime.now(UTC) - timedelta(hours=24),
|
||||
)
|
||||
@@ -143,9 +135,7 @@ async def get_scores(
|
||||
user: str | None = Query(None, alias="u", description="用户"),
|
||||
beatmap_id: int = Query(alias="b", description="谱面 ID"),
|
||||
ruleset_id: int = Query(0, alias="m", description="Ruleset ID", ge=0),
|
||||
type: Literal["string", "id"] | None = Query(
|
||||
None, description="用户类型:string 用户名称 / id 用户 ID"
|
||||
),
|
||||
type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"),
|
||||
limit: int = Query(10, ge=1, le=100, description="返回的成绩数量"),
|
||||
mods: int = Query(0, description="成绩的 MOD"),
|
||||
):
|
||||
@@ -157,9 +147,7 @@ async def get_scores(
|
||||
.where(
|
||||
Score.gamemode == GameMode.from_int_extra(ruleset_id),
|
||||
Score.beatmap_id == beatmap_id,
|
||||
Score.user_id == user
|
||||
if type == "id" or user.isdigit()
|
||||
else col(Score.user).has(username=user),
|
||||
Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user),
|
||||
)
|
||||
.options(joinedload(Score.beatmap))
|
||||
.order_by(col(Score.classic_total_score).desc())
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
@@ -13,7 +12,7 @@ from app.service.user_cache_service import get_user_cache_service
|
||||
|
||||
from .router import AllStrModel, router
|
||||
|
||||
from fastapi import HTTPException, Query
|
||||
from fastapi import BackgroundTasks, HTTPException, Query
|
||||
from sqlmodel import select
|
||||
|
||||
|
||||
@@ -49,9 +48,7 @@ class V1User(AllStrModel):
|
||||
return f"v1_user:{user_id}"
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
cls, session: Database, db_user: User, ruleset: GameMode | None = None
|
||||
) -> "V1User":
|
||||
async def from_db(cls, session: Database, db_user: User, ruleset: GameMode | None = None) -> "V1User":
|
||||
# 确保 user_id 不为 None
|
||||
if db_user.id is None:
|
||||
raise ValueError("User ID cannot be None")
|
||||
@@ -63,9 +60,7 @@ class V1User(AllStrModel):
|
||||
current_statistics = i
|
||||
break
|
||||
if current_statistics:
|
||||
statistics = await UserStatisticsResp.from_db(
|
||||
current_statistics, session, db_user.country_code
|
||||
)
|
||||
statistics = await UserStatisticsResp.from_db(current_statistics, session, db_user.country_code)
|
||||
else:
|
||||
statistics = None
|
||||
return cls(
|
||||
@@ -78,9 +73,7 @@ class V1User(AllStrModel):
|
||||
playcount=statistics.play_count if statistics else 0,
|
||||
ranked_score=statistics.ranked_score if statistics else 0,
|
||||
total_score=statistics.total_score if statistics else 0,
|
||||
pp_rank=statistics.global_rank
|
||||
if statistics and statistics.global_rank
|
||||
else 0,
|
||||
pp_rank=statistics.global_rank if statistics and statistics.global_rank else 0,
|
||||
level=current_statistics.level_current if current_statistics else 0,
|
||||
pp_raw=statistics.pp if statistics else 0.0,
|
||||
accuracy=statistics.hit_accuracy if statistics else 0,
|
||||
@@ -91,9 +84,7 @@ class V1User(AllStrModel):
|
||||
count_rank_a=current_statistics.grade_a if current_statistics else 0,
|
||||
country=db_user.country_code,
|
||||
total_seconds_played=statistics.play_time if statistics else 0,
|
||||
pp_country_rank=statistics.country_rank
|
||||
if statistics and statistics.country_rank
|
||||
else 0,
|
||||
pp_country_rank=statistics.country_rank if statistics and statistics.country_rank else 0,
|
||||
events=[], # TODO
|
||||
)
|
||||
|
||||
@@ -106,14 +97,11 @@ class V1User(AllStrModel):
|
||||
)
|
||||
async def get_user(
|
||||
session: Database,
|
||||
background_tasks: BackgroundTasks,
|
||||
user: str = Query(..., alias="u", description="用户"),
|
||||
ruleset_id: int | None = Query(None, alias="m", description="Ruleset ID", ge=0),
|
||||
type: Literal["string", "id"] | None = Query(
|
||||
None, description="用户类型:string 用户名称 / id 用户 ID"
|
||||
),
|
||||
event_days: int = Query(
|
||||
default=1, ge=1, le=31, description="从现在起所有事件的最大天数"
|
||||
),
|
||||
type: Literal["string", "id"] | None = Query(None, description="用户类型:string 用户名称 / id 用户 ID"),
|
||||
event_days: int = Query(default=1, ge=1, le=31, description="从现在起所有事件的最大天数"),
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_user_cache_service(redis)
|
||||
@@ -131,9 +119,7 @@ async def get_user(
|
||||
if is_id_query:
|
||||
try:
|
||||
user_id_for_cache = int(user)
|
||||
cached_v1_user = await cache_service.get_v1_user_from_cache(
|
||||
user_id_for_cache, ruleset
|
||||
)
|
||||
cached_v1_user = await cache_service.get_v1_user_from_cache(user_id_for_cache, ruleset)
|
||||
if cached_v1_user:
|
||||
return [V1User(**cached_v1_user)]
|
||||
except (ValueError, TypeError):
|
||||
@@ -158,9 +144,7 @@ async def get_user(
|
||||
# 异步缓存结果(如果有用户ID)
|
||||
if db_user.id is not None:
|
||||
user_data = v1_user.model_dump()
|
||||
asyncio.create_task(
|
||||
cache_service.cache_v1_user(user_data, db_user.id, ruleset)
|
||||
)
|
||||
background_tasks.add_task(cache_service.cache_v1_user, user_data, db_user.id, ruleset)
|
||||
|
||||
return [v1_user]
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401
|
||||
from . import ( # noqa: F401
|
||||
beatmap,
|
||||
beatmapset,
|
||||
me,
|
||||
|
||||
@@ -40,18 +40,13 @@ class BatchGetResp(BaseModel):
|
||||
tags=["谱面"],
|
||||
name="查询单个谱面",
|
||||
response_model=BeatmapResp,
|
||||
description=(
|
||||
"根据谱面 ID / MD5 / 文件名 查询单个谱面。"
|
||||
"至少提供 id / checksum / filename 之一。"
|
||||
),
|
||||
description=("根据谱面 ID / MD5 / 文件名 查询单个谱面。至少提供 id / checksum / filename 之一。"),
|
||||
)
|
||||
async def lookup_beatmap(
|
||||
db: Database,
|
||||
id: int | None = Query(default=None, alias="id", description="谱面 ID"),
|
||||
md5: str | None = Query(default=None, alias="checksum", description="谱面文件 MD5"),
|
||||
filename: str | None = Query(
|
||||
default=None, alias="filename", description="谱面文件名"
|
||||
),
|
||||
filename: str | None = Query(default=None, alias="filename", description="谱面文件名"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
@@ -96,43 +91,23 @@ async def get_beatmap(
|
||||
tags=["谱面"],
|
||||
name="批量获取谱面",
|
||||
response_model=BatchGetResp,
|
||||
description=(
|
||||
"批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。"
|
||||
"为空时按最近更新时间返回。"
|
||||
),
|
||||
description=("批量获取谱面。若不提供 ids[],按最近更新时间返回最多 50 条。为空时按最近更新时间返回。"),
|
||||
)
|
||||
async def batch_get_beatmaps(
|
||||
db: Database,
|
||||
beatmap_ids: list[int] = Query(
|
||||
alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"
|
||||
),
|
||||
beatmap_ids: list[int] = Query(alias="ids[]", default_factory=list, description="谱面 ID 列表 (最多 50 个)"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
if not beatmap_ids:
|
||||
beatmaps = (
|
||||
await db.exec(
|
||||
select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
|
||||
)
|
||||
).all()
|
||||
beatmaps = (await db.exec(select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50))).all()
|
||||
else:
|
||||
beatmaps = list(
|
||||
(
|
||||
await db.exec(
|
||||
select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)).limit(50)
|
||||
)
|
||||
).all()
|
||||
)
|
||||
not_found_beatmaps = [
|
||||
bid for bid in beatmap_ids if bid not in [bm.id for bm in beatmaps]
|
||||
]
|
||||
beatmaps = list((await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)).limit(50))).all())
|
||||
not_found_beatmaps = [bid for bid in beatmap_ids if bid not in [bm.id for bm in beatmaps]]
|
||||
beatmaps.extend(
|
||||
beatmap
|
||||
for beatmap in await asyncio.gather(
|
||||
*[
|
||||
Beatmap.get_or_fetch(db, fetcher, bid=bid)
|
||||
for bid in not_found_beatmaps
|
||||
],
|
||||
*[Beatmap.get_or_fetch(db, fetcher, bid=bid) for bid in not_found_beatmaps],
|
||||
return_exceptions=True,
|
||||
)
|
||||
if isinstance(beatmap, Beatmap)
|
||||
@@ -140,12 +115,7 @@ async def batch_get_beatmaps(
|
||||
for beatmap in beatmaps:
|
||||
await db.refresh(beatmap)
|
||||
|
||||
return BatchGetResp(
|
||||
beatmaps=[
|
||||
await BeatmapResp.from_db(bm, session=db, user=current_user)
|
||||
for bm in beatmaps
|
||||
]
|
||||
)
|
||||
return BatchGetResp(beatmaps=[await BeatmapResp.from_db(bm, session=db, user=current_user) for bm in beatmaps])
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -163,12 +133,8 @@ async def get_beatmap_attributes(
|
||||
default_factory=list,
|
||||
description="Mods 列表;可为整型位掩码(单元素)或 JSON/简称",
|
||||
),
|
||||
ruleset: GameMode | None = Query(
|
||||
default=None, description="指定 ruleset;为空则使用谱面自身模式"
|
||||
),
|
||||
ruleset_id: int | None = Query(
|
||||
default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3
|
||||
),
|
||||
ruleset: GameMode | None = Query(default=None, description="指定 ruleset;为空则使用谱面自身模式"),
|
||||
ruleset_id: int | None = Query(default=None, description="以数字指定 ruleset (与 ruleset 二选一)", ge=0, le=3),
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
@@ -187,16 +153,11 @@ async def get_beatmap_attributes(
|
||||
if ruleset is None:
|
||||
beatmap_db = await Beatmap.get_or_fetch(db, fetcher, beatmap_id)
|
||||
ruleset = beatmap_db.mode
|
||||
key = (
|
||||
f"beatmap:{beatmap_id}:{ruleset}:"
|
||||
f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
|
||||
)
|
||||
key = f"beatmap:{beatmap_id}:{ruleset}:{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
|
||||
if await redis.exists(key):
|
||||
return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
|
||||
try:
|
||||
return await calculate_beatmap_attributes(
|
||||
beatmap_id, ruleset, mods_, redis, fetcher
|
||||
)
|
||||
return await calculate_beatmap_attributes(beatmap_id, ruleset, mods_, redis, fetcher)
|
||||
except HTTPStatusError:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
@@ -35,9 +35,7 @@ from sqlmodel import exists, select
|
||||
async def _save_to_db(sets: SearchBeatmapsetsResp):
|
||||
async with with_db() as session:
|
||||
for s in sets.beatmapsets:
|
||||
if not (
|
||||
await session.exec(select(exists()).where(Beatmapset.id == s.id))
|
||||
).first():
|
||||
if not (await session.exec(select(exists()).where(Beatmapset.id == s.id))).first():
|
||||
await Beatmapset.from_resp(session, s)
|
||||
|
||||
|
||||
@@ -117,9 +115,7 @@ async def lookup_beatmapset(
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=beatmap_id)
|
||||
resp = await BeatmapsetResp.from_db(
|
||||
beatmap.beatmapset, session=db, user=current_user
|
||||
)
|
||||
resp = await BeatmapsetResp.from_db(beatmap.beatmapset, session=db, user=current_user)
|
||||
return resp
|
||||
|
||||
|
||||
@@ -138,9 +134,7 @@ async def get_beatmapset(
|
||||
):
|
||||
try:
|
||||
beatmapset = await Beatmapset.get_or_fetch(db, fetcher, beatmapset_id)
|
||||
return await BeatmapsetResp.from_db(
|
||||
beatmapset, session=db, include=["recent_favourites"], user=current_user
|
||||
)
|
||||
return await BeatmapsetResp.from_db(beatmapset, session=db, include=["recent_favourites"], user=current_user)
|
||||
except HTTPError:
|
||||
raise HTTPException(status_code=404, detail="Beatmapset not found")
|
||||
|
||||
@@ -165,9 +159,7 @@ async def download_beatmapset(
|
||||
country_code = geo_info.get("country_iso", "")
|
||||
|
||||
# 优先使用IP地理位置判断,如果获取失败则回退到用户账户的国家代码
|
||||
is_china = country_code == "CN" or (
|
||||
not country_code and current_user.country_code == "CN"
|
||||
)
|
||||
is_china = country_code == "CN" or (not country_code and current_user.country_code == "CN")
|
||||
|
||||
try:
|
||||
# 使用负载均衡服务获取下载URL
|
||||
@@ -179,13 +171,10 @@ async def download_beatmapset(
|
||||
# 如果负载均衡服务失败,回退到原有逻辑
|
||||
if is_china:
|
||||
return RedirectResponse(
|
||||
f"https://dl.sayobot.cn/beatmaps/download/"
|
||||
f"{'novideo' if no_video else 'full'}/{beatmapset_id}"
|
||||
f"https://dl.sayobot.cn/beatmaps/download/{'novideo' if no_video else 'full'}/{beatmapset_id}"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
f"https://api.nerinyan.moe/d/{beatmapset_id}?noVideo={no_video}"
|
||||
)
|
||||
return RedirectResponse(f"https://api.nerinyan.moe/d/{beatmapset_id}?noVideo={no_video}")
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -197,12 +186,9 @@ async def download_beatmapset(
|
||||
async def favourite_beatmapset(
|
||||
db: Database,
|
||||
beatmapset_id: int = Path(..., description="谱面集 ID"),
|
||||
action: Literal["favourite", "unfavourite"] = Form(
|
||||
description="操作类型:favourite 收藏 / unfavourite 取消收藏"
|
||||
),
|
||||
action: Literal["favourite", "unfavourite"] = Form(description="操作类型:favourite 收藏 / unfavourite 取消收藏"),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
assert current_user.id is not None
|
||||
existing_favourite = (
|
||||
await db.exec(
|
||||
select(FavouriteBeatmapset).where(
|
||||
@@ -212,15 +198,11 @@ async def favourite_beatmapset(
|
||||
)
|
||||
).first()
|
||||
|
||||
if (action == "favourite" and existing_favourite) or (
|
||||
action == "unfavourite" and not existing_favourite
|
||||
):
|
||||
if (action == "favourite" and existing_favourite) or (action == "unfavourite" and not existing_favourite):
|
||||
return
|
||||
|
||||
if action == "favourite":
|
||||
favourite = FavouriteBeatmapset(
|
||||
user_id=current_user.id, beatmapset_id=beatmapset_id
|
||||
)
|
||||
favourite = FavouriteBeatmapset(user_id=current_user.id, beatmapset_id=beatmapset_id)
|
||||
db.add(favourite)
|
||||
else:
|
||||
await db.delete(existing_favourite)
|
||||
|
||||
@@ -4,8 +4,8 @@ from app.database import User
|
||||
from app.database.lazer_user import ALL_INCLUDED
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies.database import Database
|
||||
from app.models.score import GameMode
|
||||
from app.models.api_me import APIMe
|
||||
from app.models.score import GameMode
|
||||
|
||||
from .router import router
|
||||
|
||||
|
||||
@@ -33,6 +33,4 @@ class BackgroundsResp(BaseModel):
|
||||
description="获取当前季节背景图列表。",
|
||||
)
|
||||
async def get_seasonal_backgrounds():
|
||||
return BackgroundsResp(
|
||||
backgrounds=[Background(url=url) for url in settings.seasonal_backgrounds]
|
||||
)
|
||||
return BackgroundsResp(backgrounds=[Background(url=url) for url in settings.seasonal_backgrounds])
|
||||
|
||||
@@ -12,7 +12,7 @@ from app.service.ranking_cache_service import get_ranking_cache_service
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Path, Query, Security
|
||||
from fastapi import BackgroundTasks, Path, Query, Security
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import col, select
|
||||
|
||||
@@ -38,6 +38,7 @@ class CountryResponse(BaseModel):
|
||||
)
|
||||
async def get_country_ranking(
|
||||
session: Database,
|
||||
background_tasks: BackgroundTasks,
|
||||
ruleset: GameMode = Path(..., description="指定 ruleset"),
|
||||
page: int = Query(1, ge=1, description="页码"), # TODO
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
@@ -51,9 +52,7 @@ async def get_country_ranking(
|
||||
|
||||
if cached_data:
|
||||
# 从缓存返回数据
|
||||
return CountryResponse(
|
||||
ranking=[CountryStatistics.model_validate(item) for item in cached_data]
|
||||
)
|
||||
return CountryResponse(ranking=[CountryStatistics.model_validate(item) for item in cached_data])
|
||||
|
||||
# 缓存未命中,从数据库查询
|
||||
response = CountryResponse(ranking=[])
|
||||
@@ -105,14 +104,15 @@ async def get_country_ranking(
|
||||
|
||||
# 异步缓存数据(不等待完成)
|
||||
cache_data = [item.model_dump() for item in current_page_data]
|
||||
cache_task = cache_service.cache_country_ranking(
|
||||
ruleset, cache_data, page, ttl=settings.ranking_cache_expire_minutes * 60
|
||||
)
|
||||
|
||||
# 创建后台任务来缓存数据
|
||||
import asyncio
|
||||
|
||||
asyncio.create_task(cache_task)
|
||||
background_tasks.add_task(
|
||||
cache_service.cache_country_ranking,
|
||||
ruleset,
|
||||
cache_data,
|
||||
page,
|
||||
ttl=settings.ranking_cache_expire_minutes * 60,
|
||||
)
|
||||
|
||||
# 返回当前页的结果
|
||||
response.ranking = current_page_data
|
||||
@@ -132,10 +132,9 @@ class TopUsersResponse(BaseModel):
|
||||
)
|
||||
async def get_user_ranking(
|
||||
session: Database,
|
||||
background_tasks: BackgroundTasks,
|
||||
ruleset: GameMode = Path(..., description="指定 ruleset"),
|
||||
type: Literal["performance", "score"] = Path(
|
||||
..., description="排名类型:performance 表现分 / score 计分成绩总分"
|
||||
),
|
||||
type: Literal["performance", "score"] = Path(..., description="排名类型:performance 表现分 / score 计分成绩总分"),
|
||||
country: str | None = Query(None, description="国家代码"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
@@ -149,9 +148,7 @@ async def get_user_ranking(
|
||||
|
||||
if cached_data:
|
||||
# 从缓存返回数据
|
||||
return TopUsersResponse(
|
||||
ranking=[UserStatisticsResp.model_validate(item) for item in cached_data]
|
||||
)
|
||||
return TopUsersResponse(ranking=[UserStatisticsResp.model_validate(item) for item in cached_data])
|
||||
|
||||
# 缓存未命中,从数据库查询
|
||||
wheres = [
|
||||
@@ -169,25 +166,22 @@ async def get_user_ranking(
|
||||
wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
|
||||
|
||||
statistics_list = await session.exec(
|
||||
select(UserStatistics)
|
||||
.where(*wheres)
|
||||
.order_by(order_by)
|
||||
.limit(50)
|
||||
.offset(50 * (page - 1))
|
||||
select(UserStatistics).where(*wheres).order_by(order_by).limit(50).offset(50 * (page - 1))
|
||||
)
|
||||
|
||||
# 转换为响应格式
|
||||
ranking_data = []
|
||||
for statistics in statistics_list:
|
||||
user_stats_resp = await UserStatisticsResp.from_db(
|
||||
statistics, session, None, include
|
||||
)
|
||||
user_stats_resp = await UserStatisticsResp.from_db(statistics, session, None, include)
|
||||
ranking_data.append(user_stats_resp)
|
||||
|
||||
# 异步缓存数据(不等待完成)
|
||||
# 使用配置文件中的TTL设置
|
||||
cache_data = [item.model_dump() for item in ranking_data]
|
||||
cache_task = cache_service.cache_ranking(
|
||||
# 创建后台任务来缓存数据
|
||||
|
||||
background_tasks.add_task(
|
||||
cache_service.cache_ranking,
|
||||
ruleset,
|
||||
type,
|
||||
cache_data,
|
||||
@@ -196,139 +190,134 @@ async def get_user_ranking(
|
||||
ttl=settings.ranking_cache_expire_minutes * 60,
|
||||
)
|
||||
|
||||
# 创建后台任务来缓存数据
|
||||
import asyncio
|
||||
|
||||
asyncio.create_task(cache_task)
|
||||
|
||||
resp = TopUsersResponse(ranking=ranking_data)
|
||||
return resp
|
||||
|
||||
|
||||
""" @router.post(
|
||||
"/rankings/cache/refresh",
|
||||
name="刷新排行榜缓存",
|
||||
description="手动刷新排行榜缓存(管理员功能)",
|
||||
tags=["排行榜", "管理"],
|
||||
)
|
||||
async def refresh_ranking_cache(
|
||||
session: Database,
|
||||
ruleset: GameMode | None = Query(None, description="指定要刷新的游戏模式,不指定则刷新所有"),
|
||||
type: Literal["performance", "score"] | None = Query(None, description="指定要刷新的排名类型,不指定则刷新所有"),
|
||||
country: str | None = Query(None, description="指定要刷新的国家,不指定则刷新所有"),
|
||||
include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
|
||||
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
if ruleset and type:
|
||||
# 刷新特定的用户排行榜
|
||||
await cache_service.refresh_ranking_cache(session, ruleset, type, country)
|
||||
message = f"Refreshed ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
|
||||
|
||||
# 如果请求刷新地区排行榜
|
||||
if include_country_ranking and not country: # 地区排行榜不依赖于国家参数
|
||||
await cache_service.refresh_country_ranking_cache(session, ruleset)
|
||||
message += f" and country ranking for {ruleset}"
|
||||
|
||||
return {"message": message}
|
||||
elif ruleset:
|
||||
# 刷新特定游戏模式的所有排行榜
|
||||
ranking_types: list[Literal["performance", "score"]] = ["performance", "score"]
|
||||
for ranking_type in ranking_types:
|
||||
await cache_service.refresh_ranking_cache(session, ruleset, ranking_type, country)
|
||||
|
||||
if include_country_ranking:
|
||||
await cache_service.refresh_country_ranking_cache(session, ruleset)
|
||||
|
||||
return {"message": f"Refreshed all ranking caches for {ruleset}"}
|
||||
else:
|
||||
# 刷新所有排行榜
|
||||
await cache_service.refresh_all_rankings(session)
|
||||
return {"message": "Refreshed all ranking caches"}
|
||||
# @router.post(
|
||||
# "/rankings/cache/refresh",
|
||||
# name="刷新排行榜缓存",
|
||||
# description="手动刷新排行榜缓存(管理员功能)",
|
||||
# tags=["排行榜", "管理"],
|
||||
# )
|
||||
# async def refresh_ranking_cache(
|
||||
# session: Database,
|
||||
# ruleset: GameMode | None = Query(None, description="指定要刷新的游戏模式,不指定则刷新所有"),
|
||||
# type: Literal["performance", "score"] | None = Query(None, description="指定要刷新的排名类型,不指定则刷新所有"),
|
||||
# country: str | None = Query(None, description="指定要刷新的国家,不指定则刷新所有"),
|
||||
# include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
|
||||
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
# ):
|
||||
# redis = get_redis()
|
||||
# cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
# if ruleset and type:
|
||||
# # 刷新特定的用户排行榜
|
||||
# await cache_service.refresh_ranking_cache(session, ruleset, type, country)
|
||||
# message = f"Refreshed ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
|
||||
|
||||
# # 如果请求刷新地区排行榜
|
||||
# if include_country_ranking and not country: # 地区排行榜不依赖于国家参数
|
||||
# await cache_service.refresh_country_ranking_cache(session, ruleset)
|
||||
# message += f" and country ranking for {ruleset}"
|
||||
|
||||
# return {"message": message}
|
||||
# elif ruleset:
|
||||
# # 刷新特定游戏模式的所有排行榜
|
||||
# ranking_types: list[Literal["performance", "score"]] = ["performance", "score"]
|
||||
# for ranking_type in ranking_types:
|
||||
# await cache_service.refresh_ranking_cache(session, ruleset, ranking_type, country)
|
||||
|
||||
# if include_country_ranking:
|
||||
# await cache_service.refresh_country_ranking_cache(session, ruleset)
|
||||
|
||||
# return {"message": f"Refreshed all ranking caches for {ruleset}"}
|
||||
# else:
|
||||
# # 刷新所有排行榜
|
||||
# await cache_service.refresh_all_rankings(session)
|
||||
# return {"message": "Refreshed all ranking caches"}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/rankings/{ruleset}/country/cache/refresh",
|
||||
name="刷新地区排行榜缓存",
|
||||
description="手动刷新地区排行榜缓存(管理员功能)",
|
||||
tags=["排行榜", "管理"],
|
||||
)
|
||||
async def refresh_country_ranking_cache(
|
||||
session: Database,
|
||||
ruleset: GameMode = Path(..., description="指定要刷新的游戏模式"),
|
||||
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
await cache_service.refresh_country_ranking_cache(session, ruleset)
|
||||
return {"message": f"Refreshed country ranking cache for {ruleset}"}
|
||||
# @router.post(
|
||||
# "/rankings/{ruleset}/country/cache/refresh",
|
||||
# name="刷新地区排行榜缓存",
|
||||
# description="手动刷新地区排行榜缓存(管理员功能)",
|
||||
# tags=["排行榜", "管理"],
|
||||
# )
|
||||
# async def refresh_country_ranking_cache(
|
||||
# session: Database,
|
||||
# ruleset: GameMode = Path(..., description="指定要刷新的游戏模式"),
|
||||
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
# ):
|
||||
# redis = get_redis()
|
||||
# cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
# await cache_service.refresh_country_ranking_cache(session, ruleset)
|
||||
# return {"message": f"Refreshed country ranking cache for {ruleset}"}
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/rankings/cache",
|
||||
name="清除排行榜缓存",
|
||||
description="清除排行榜缓存(管理员功能)",
|
||||
tags=["排行榜", "管理"],
|
||||
)
|
||||
async def clear_ranking_cache(
|
||||
ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
|
||||
type: Literal["performance", "score"] | None = Query(None, description="指定要清除的排名类型,不指定则清除所有"),
|
||||
country: str | None = Query(None, description="指定要清除的国家,不指定则清除所有"),
|
||||
include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
|
||||
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
await cache_service.invalidate_cache(ruleset, type, country, include_country_ranking)
|
||||
|
||||
if ruleset and type:
|
||||
message = f"Cleared ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
|
||||
if include_country_ranking:
|
||||
message += " and country ranking"
|
||||
return {"message": message}
|
||||
else:
|
||||
message = "Cleared all ranking caches"
|
||||
if include_country_ranking:
|
||||
message += " including country rankings"
|
||||
return {"message": message}
|
||||
# @router.delete(
|
||||
# "/rankings/cache",
|
||||
# name="清除排行榜缓存",
|
||||
# description="清除排行榜缓存(管理员功能)",
|
||||
# tags=["排行榜", "管理"],
|
||||
# )
|
||||
# async def clear_ranking_cache(
|
||||
# ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
|
||||
# type: Literal["performance", "score"] | None = Query(None, description="指定要清除的排名类型,不指定则清除所有"),
|
||||
# country: str | None = Query(None, description="指定要清除的国家,不指定则清除所有"),
|
||||
# include_country_ranking: bool = Query(True, description="是否包含地区排行榜"),
|
||||
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
# ):
|
||||
# redis = get_redis()
|
||||
# cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
# await cache_service.invalidate_cache(ruleset, type, country, include_country_ranking)
|
||||
|
||||
# if ruleset and type:
|
||||
# message = f"Cleared ranking cache for {ruleset}:{type}" + (f" in {country}" if country else "")
|
||||
# if include_country_ranking:
|
||||
# message += " and country ranking"
|
||||
# return {"message": message}
|
||||
# else:
|
||||
# message = "Cleared all ranking caches"
|
||||
# if include_country_ranking:
|
||||
# message += " including country rankings"
|
||||
# return {"message": message}
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/rankings/{ruleset}/country/cache",
|
||||
name="清除地区排行榜缓存",
|
||||
description="清除地区排行榜缓存(管理员功能)",
|
||||
tags=["排行榜", "管理"],
|
||||
)
|
||||
async def clear_country_ranking_cache(
|
||||
ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
|
||||
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
await cache_service.invalidate_country_cache(ruleset)
|
||||
|
||||
if ruleset:
|
||||
return {"message": f"Cleared country ranking cache for {ruleset}"}
|
||||
else:
|
||||
return {"message": "Cleared all country ranking caches"}
|
||||
# @router.delete(
|
||||
# "/rankings/{ruleset}/country/cache",
|
||||
# name="清除地区排行榜缓存",
|
||||
# description="清除地区排行榜缓存(管理员功能)",
|
||||
# tags=["排行榜", "管理"],
|
||||
# )
|
||||
# async def clear_country_ranking_cache(
|
||||
# ruleset: GameMode | None = Query(None, description="指定要清除的游戏模式,不指定则清除所有"),
|
||||
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
# ):
|
||||
# redis = get_redis()
|
||||
# cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
# await cache_service.invalidate_country_cache(ruleset)
|
||||
|
||||
# if ruleset:
|
||||
# return {"message": f"Cleared country ranking cache for {ruleset}"}
|
||||
# else:
|
||||
# return {"message": "Cleared all country ranking caches"}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/rankings/cache/stats",
|
||||
name="获取排行榜缓存统计",
|
||||
description="获取排行榜缓存统计信息(管理员功能)",
|
||||
tags=["排行榜", "管理"],
|
||||
)
|
||||
async def get_ranking_cache_stats(
|
||||
current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
stats = await cache_service.get_cache_stats()
|
||||
return stats """
|
||||
# @router.get(
|
||||
# "/rankings/cache/stats",
|
||||
# name="获取排行榜缓存统计",
|
||||
# description="获取排行榜缓存统计信息(管理员功能)",
|
||||
# tags=["排行榜", "管理"],
|
||||
# )
|
||||
# async def get_ranking_cache_stats(
|
||||
# current_user: User = Security(get_current_user, scopes=["admin"]), # 需要管理员权限
|
||||
# ):
|
||||
# redis = get_redis()
|
||||
# cache_service = get_ranking_cache_service(redis)
|
||||
|
||||
# stats = await cache_service.get_cache_stats()
|
||||
# return stats
|
||||
|
||||
@@ -30,11 +30,7 @@ async def get_relationship(
|
||||
request: Request,
|
||||
current_user: User = Security(get_current_user, scopes=["friends.read"]),
|
||||
):
|
||||
relationship_type = (
|
||||
RelationshipType.FOLLOW
|
||||
if request.url.path.endswith("/friends")
|
||||
else RelationshipType.BLOCK
|
||||
)
|
||||
relationship_type = RelationshipType.FOLLOW if request.url.path.endswith("/friends") else RelationshipType.BLOCK
|
||||
relationships = await db.exec(
|
||||
select(Relationship).where(
|
||||
Relationship.user_id == current_user.id,
|
||||
@@ -71,12 +67,7 @@ async def add_relationship(
|
||||
target: int = Query(description="目标用户 ID"),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
assert current_user.id is not None
|
||||
relationship_type = (
|
||||
RelationshipType.FOLLOW
|
||||
if request.url.path.endswith("/friends")
|
||||
else RelationshipType.BLOCK
|
||||
)
|
||||
relationship_type = RelationshipType.FOLLOW if request.url.path.endswith("/friends") else RelationshipType.BLOCK
|
||||
if target == current_user.id:
|
||||
raise HTTPException(422, "Cannot add relationship to yourself")
|
||||
relationship = (
|
||||
@@ -120,11 +111,8 @@ async def add_relationship(
|
||||
Relationship.target_id == target,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
assert relationship, "Relationship should exist after commit"
|
||||
return AddFriendResp(
|
||||
user_relation=await RelationshipResp.from_db(db, relationship)
|
||||
)
|
||||
).one()
|
||||
return AddFriendResp(user_relation=await RelationshipResp.from_db(db, relationship))
|
||||
|
||||
|
||||
@router.delete(
|
||||
@@ -145,11 +133,7 @@ async def delete_relationship(
|
||||
target: int = Path(..., description="目标用户 ID"),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
relationship_type = (
|
||||
RelationshipType.BLOCK
|
||||
if "/blocks/" in request.url.path
|
||||
else RelationshipType.FOLLOW
|
||||
)
|
||||
relationship_type = RelationshipType.BLOCK if "/blocks/" in request.url.path else RelationshipType.FOLLOW
|
||||
relationship = (
|
||||
await db.exec(
|
||||
select(Relationship).where(
|
||||
|
||||
@@ -39,17 +39,11 @@ async def get_all_rooms(
|
||||
db: Database,
|
||||
mode: Literal["open", "ended", "participated", "owned", None] = Query(
|
||||
default="open",
|
||||
description=(
|
||||
"房间模式:open 当前开放 / ended 已经结束 / "
|
||||
"participated 参与过 / owned 自己创建的房间"
|
||||
),
|
||||
description=("房间模式:open 当前开放 / ended 已经结束 / participated 参与过 / owned 自己创建的房间"),
|
||||
),
|
||||
category: RoomCategory = Query(
|
||||
RoomCategory.NORMAL,
|
||||
description=(
|
||||
"房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间"
|
||||
" / DAILY_CHALLENGE 每日挑战"
|
||||
),
|
||||
description=("房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战"),
|
||||
),
|
||||
status: RoomStatus | None = Query(None, description="房间状态(可选)"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
@@ -60,10 +54,7 @@ async def get_all_rooms(
|
||||
if status is not None:
|
||||
where_clauses.append(col(Room.status) == status)
|
||||
if mode == "open":
|
||||
where_clauses.append(
|
||||
(col(Room.ends_at).is_(None))
|
||||
| (col(Room.ends_at) > now.replace(tzinfo=UTC))
|
||||
)
|
||||
where_clauses.append((col(Room.ends_at).is_(None)) | (col(Room.ends_at) > now.replace(tzinfo=UTC)))
|
||||
if category == RoomCategory.REALTIME:
|
||||
where_clauses.append(col(Room.id).in_(MultiplayerHubs.rooms.keys()))
|
||||
if mode == "participated":
|
||||
@@ -76,10 +67,7 @@ async def get_all_rooms(
|
||||
if mode == "owned":
|
||||
where_clauses.append(col(Room.host_id) == current_user.id)
|
||||
if mode == "ended":
|
||||
where_clauses.append(
|
||||
(col(Room.ends_at).is_not(None))
|
||||
& (col(Room.ends_at) < now.replace(tzinfo=UTC))
|
||||
)
|
||||
where_clauses.append((col(Room.ends_at).is_not(None)) & (col(Room.ends_at) < now.replace(tzinfo=UTC)))
|
||||
|
||||
db_rooms = (
|
||||
(
|
||||
@@ -97,11 +85,7 @@ async def get_all_rooms(
|
||||
resp = await RoomResp.from_db(room, db)
|
||||
if category == RoomCategory.REALTIME:
|
||||
mp_room = MultiplayerHubs.rooms.get(room.id)
|
||||
resp.has_password = (
|
||||
bool(mp_room.room.settings.password.strip())
|
||||
if mp_room is not None
|
||||
else False
|
||||
)
|
||||
resp.has_password = bool(mp_room.room.settings.password.strip()) if mp_room is not None else False
|
||||
resp.category = RoomCategory.NORMAL
|
||||
resp_list.append(resp)
|
||||
|
||||
@@ -115,9 +99,7 @@ class APICreatedRoom(RoomResp):
|
||||
error: str = ""
|
||||
|
||||
|
||||
async def _participate_room(
|
||||
room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis
|
||||
):
|
||||
async def _participate_room(room_id: int, user_id: int, db_room: Room, session: AsyncSession, redis: Redis):
|
||||
participated_user = (
|
||||
await session.exec(
|
||||
select(RoomParticipatedUser).where(
|
||||
@@ -154,7 +136,6 @@ async def create_room(
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
assert current_user.id is not None
|
||||
user_id = current_user.id
|
||||
db_room = await create_playlist_room_from_api(db, room, user_id)
|
||||
await _participate_room(db_room.id, user_id, db_room, db, redis)
|
||||
@@ -177,10 +158,7 @@ async def get_room(
|
||||
room_id: int = Path(..., description="房间 ID"),
|
||||
category: str = Query(
|
||||
default="",
|
||||
description=(
|
||||
"房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间"
|
||||
" / DAILY_CHALLENGE 每日挑战 (可选)"
|
||||
),
|
||||
description=("房间分类:NORMAL 普通歌单模式房间 / REALTIME 多人游戏房间 / DAILY_CHALLENGE 每日挑战 (可选)"),
|
||||
),
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
@@ -188,9 +166,7 @@ async def get_room(
|
||||
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
|
||||
if db_room is None:
|
||||
raise HTTPException(404, "Room not found")
|
||||
resp = await RoomResp.from_db(
|
||||
db_room, include=["current_user_score"], session=db, user=current_user
|
||||
)
|
||||
resp = await RoomResp.from_db(db_room, include=["current_user_score"], session=db, user=current_user)
|
||||
return resp
|
||||
|
||||
|
||||
@@ -400,7 +376,6 @@ async def get_room_events(
|
||||
for score in scores:
|
||||
user_ids.add(score.user_id)
|
||||
beatmap_ids.add(score.beatmap_id)
|
||||
assert event.id is not None
|
||||
first_event_id = min(first_event_id, event.id)
|
||||
last_event_id = max(last_event_id, event.id)
|
||||
|
||||
@@ -416,16 +391,12 @@ async def get_room_events(
|
||||
users = await db.exec(select(User).where(col(User.id).in_(user_ids)))
|
||||
user_resps = [await UserResp.from_db(user, db) for user in users]
|
||||
beatmaps = await db.exec(select(Beatmap).where(col(Beatmap.id).in_(beatmap_ids)))
|
||||
beatmap_resps = [
|
||||
await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps
|
||||
]
|
||||
beatmap_resps = [await BeatmapResp.from_db(beatmap, session=db) for beatmap in beatmaps]
|
||||
beatmapset_resps = {}
|
||||
for beatmap_resp in beatmap_resps:
|
||||
beatmapset_resps[beatmap_resp.beatmapset_id] = beatmap_resp.beatmapset
|
||||
|
||||
playlist_items_resps = [
|
||||
await PlaylistResp.from_db(item) for item in playlist_items.values()
|
||||
]
|
||||
playlist_items_resps = [await PlaylistResp.from_db(item) for item in playlist_items.values()]
|
||||
|
||||
return RoomEvents(
|
||||
beatmaps=beatmap_resps,
|
||||
|
||||
@@ -104,11 +104,7 @@ async def submit_score(
|
||||
if not info.passed:
|
||||
info.rank = Rank.F
|
||||
score_token = (
|
||||
await db.exec(
|
||||
select(ScoreToken)
|
||||
.options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType]
|
||||
.where(ScoreToken.id == token)
|
||||
)
|
||||
await db.exec(select(ScoreToken).options(joinedload(ScoreToken.beatmap)).where(ScoreToken.id == token))
|
||||
).first()
|
||||
if not score_token or score_token.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="Score token not found")
|
||||
@@ -138,10 +134,7 @@ async def submit_score(
|
||||
except HTTPError:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
has_pp = db_beatmap.beatmap_status.has_pp() | settings.enable_all_beatmap_pp
|
||||
has_leaderboard = (
|
||||
db_beatmap.beatmap_status.has_leaderboard()
|
||||
| settings.enable_all_beatmap_leaderboard
|
||||
)
|
||||
has_leaderboard = db_beatmap.beatmap_status.has_leaderboard() | settings.enable_all_beatmap_leaderboard
|
||||
beatmap_length = db_beatmap.total_length
|
||||
score = await process_score(
|
||||
current_user,
|
||||
@@ -167,21 +160,11 @@ async def submit_score(
|
||||
has_pp,
|
||||
has_leaderboard,
|
||||
)
|
||||
score = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.options(joinedload(Score.user)) # pyright: ignore[reportArgumentType]
|
||||
.where(Score.id == score_id)
|
||||
)
|
||||
).first()
|
||||
assert score is not None
|
||||
score = (await db.exec(select(Score).options(joinedload(Score.user)).where(Score.id == score_id))).one()
|
||||
|
||||
resp = await ScoreResp.from_db(db, score)
|
||||
total_users = (await db.exec(select(func.count()).select_from(User))).first()
|
||||
assert total_users is not None
|
||||
if resp.rank_global is not None and resp.rank_global <= min(
|
||||
math.ceil(float(total_users) * 0.01), 50
|
||||
):
|
||||
total_users = (await db.exec(select(func.count()).select_from(User))).one()
|
||||
if resp.rank_global is not None and resp.rank_global <= min(math.ceil(float(total_users) * 0.01), 50):
|
||||
rank_event = Event(
|
||||
created_at=datetime.now(UTC),
|
||||
type=EventType.RANK,
|
||||
@@ -207,9 +190,7 @@ async def submit_score(
|
||||
score_gamemode = score.gamemode
|
||||
|
||||
if user_id is not None:
|
||||
background_task.add_task(
|
||||
_refresh_user_cache_background, redis, user_id, score_gamemode
|
||||
)
|
||||
background_task.add_task(_refresh_user_cache_background, redis, user_id, score_gamemode)
|
||||
background_task.add_task(process_user_achievement, resp.id)
|
||||
return resp
|
||||
|
||||
@@ -225,9 +206,7 @@ async def _refresh_user_cache_background(redis: Redis, user_id: int, mode: GameM
|
||||
# 创建独立的数据库会话
|
||||
session = AsyncSession(engine)
|
||||
try:
|
||||
await user_cache_service.refresh_user_cache_on_score_submit(
|
||||
session, user_id, mode
|
||||
)
|
||||
await user_cache_service.refresh_user_cache_on_score_submit(session, user_id, mode)
|
||||
finally:
|
||||
await session.close()
|
||||
except Exception as e:
|
||||
@@ -280,22 +259,16 @@ async def get_beatmap_scores(
|
||||
beatmap_id: int = Path(description="谱面 ID"),
|
||||
mode: GameMode = Query(description="指定 auleset"),
|
||||
legacy_only: bool = Query(None, description="是否只查询 Stable 分数"),
|
||||
mods: list[str] = Query(
|
||||
default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)"
|
||||
),
|
||||
mods: list[str] = Query(default_factory=set, alias="mods[]", description="筛选使用的 Mods (可选,多值)"),
|
||||
type: LeaderboardType = Query(
|
||||
LeaderboardType.GLOBAL,
|
||||
description=(
|
||||
"排行榜类型:GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队"
|
||||
),
|
||||
description=("排行榜类型:GLOBAL 全局 / COUNTRY 国家 / FRIENDS 好友 / TEAM 战队"),
|
||||
),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
limit: int = Query(50, ge=1, le=200, description="返回条数 (1-200)"),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="this server only contains lazer scores"
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="this server only contains lazer scores")
|
||||
|
||||
all_scores, user_score, count = await get_leaderboard(
|
||||
db,
|
||||
@@ -310,9 +283,7 @@ async def get_beatmap_scores(
|
||||
user_score_resp = await ScoreResp.from_db(db, user_score) if user_score else None
|
||||
resp = BeatmapScores(
|
||||
scores=[await ScoreResp.from_db(db, score) for score in all_scores],
|
||||
user_score=BeatmapUserScore(
|
||||
score=user_score_resp, position=user_score_resp.rank_global or 0
|
||||
)
|
||||
user_score=BeatmapUserScore(score=user_score_resp, position=user_score_resp.rank_global or 0)
|
||||
if user_score_resp
|
||||
else None,
|
||||
score_count=count,
|
||||
@@ -342,9 +313,7 @@ async def get_user_beatmap_score(
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="This server only contains non-legacy scores"
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="This server only contains non-legacy scores")
|
||||
user_score = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
@@ -386,9 +355,7 @@ async def get_user_all_beatmap_scores(
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="This server only contains non-legacy scores"
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="This server only contains non-legacy scores")
|
||||
all_user_scores = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
@@ -420,7 +387,6 @@ async def create_solo_score(
|
||||
ruleset_id: int = Form(..., ge=0, le=3, description="ruleset 数字 ID (0-3)"),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
assert current_user.id is not None
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
|
||||
@@ -454,10 +420,7 @@ async def submit_solo_score(
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher=Depends(get_fetcher),
|
||||
):
|
||||
assert current_user.id is not None
|
||||
return await submit_score(
|
||||
background_task, info, beatmap_id, token, current_user, db, redis, fetcher
|
||||
)
|
||||
return await submit_score(background_task, info, beatmap_id, token, current_user, db, redis, fetcher)
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -478,7 +441,6 @@ async def create_playlist_score(
|
||||
version_hash: str = Form("", description="谱面版本哈希"),
|
||||
current_user: User = Security(get_client_user),
|
||||
):
|
||||
assert current_user.id is not None
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
|
||||
@@ -488,26 +450,16 @@ async def create_playlist_score(
|
||||
db_room_time = room.ends_at.replace(tzinfo=UTC) if room.ends_at else None
|
||||
if db_room_time and db_room_time < datetime.now(UTC).replace(tzinfo=UTC):
|
||||
raise HTTPException(status_code=400, detail="Room has ended")
|
||||
item = (
|
||||
await session.exec(
|
||||
select(Playlist).where(
|
||||
Playlist.id == playlist_id, Playlist.room_id == room_id
|
||||
)
|
||||
)
|
||||
).first()
|
||||
item = (await session.exec(select(Playlist).where(Playlist.id == playlist_id, Playlist.room_id == room_id))).first()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Playlist not found")
|
||||
|
||||
# validate
|
||||
if not item.freestyle:
|
||||
if item.ruleset_id != ruleset_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Ruleset mismatch in playlist item"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Ruleset mismatch in playlist item")
|
||||
if item.beatmap_id != beatmap_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Beatmap ID mismatch in playlist item"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Beatmap ID mismatch in playlist item")
|
||||
agg = await session.exec(
|
||||
select(ItemAttemptsCount).where(
|
||||
ItemAttemptsCount.room_id == room_id,
|
||||
@@ -523,9 +475,7 @@ async def create_playlist_score(
|
||||
if item.expired:
|
||||
raise HTTPException(status_code=400, detail="Playlist item has expired")
|
||||
if item.played_at:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Playlist item has already been played"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Playlist item has already been played")
|
||||
# 这里应该不用验证mod了吧。。。
|
||||
background_task.add_task(_preload_beatmap_for_pp_calculation, beatmap_id)
|
||||
score_token = ScoreToken(
|
||||
@@ -557,18 +507,10 @@ async def submit_playlist_score(
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
assert current_user.id is not None
|
||||
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
|
||||
item = (
|
||||
await session.exec(
|
||||
select(Playlist).where(
|
||||
Playlist.id == playlist_id, Playlist.room_id == room_id
|
||||
)
|
||||
)
|
||||
).first()
|
||||
item = (await session.exec(select(Playlist).where(Playlist.id == playlist_id, Playlist.room_id == room_id))).first()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Playlist item not found")
|
||||
room = await session.get(Room, room_id)
|
||||
@@ -621,9 +563,7 @@ async def index_playlist_scores(
|
||||
room_id: int,
|
||||
playlist_id: int,
|
||||
limit: int = Query(50, ge=1, le=50, description="返回条数 (1-50)"),
|
||||
cursor: int = Query(
|
||||
2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)"
|
||||
),
|
||||
cursor: int = Query(2000000, alias="cursor[total_score]", description="分页游标(上一页最低分)"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
):
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
@@ -693,9 +633,6 @@ async def show_playlist_score(
|
||||
current_user: User = Security(get_client_user),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
|
||||
room = await session.get(Room, room_id)
|
||||
if not room:
|
||||
raise HTTPException(status_code=404, detail="Room not found")
|
||||
@@ -715,9 +652,7 @@ async def show_playlist_score(
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if completed_players := await redis.get(
|
||||
f"multiplayer:{room_id}:gameplay:players"
|
||||
):
|
||||
if completed_players := await redis.get(f"multiplayer:{room_id}:gameplay:players"):
|
||||
completed = completed_players == "0"
|
||||
if score_record and completed:
|
||||
break
|
||||
@@ -784,9 +719,7 @@ async def get_user_playlist_score(
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
|
||||
resp = await ScoreResp.from_db(session, score_record.score)
|
||||
resp.position = await get_position(
|
||||
room_id, playlist_id, score_record.score_id, session
|
||||
)
|
||||
resp.position = await get_position(room_id, playlist_id, score_record.score_id, session)
|
||||
return resp
|
||||
|
||||
|
||||
@@ -850,11 +783,7 @@ async def unpin_score(
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
|
||||
score_record = (
|
||||
await db.exec(
|
||||
select(Score).where(Score.id == score_id, Score.user_id == user_id)
|
||||
)
|
||||
).first()
|
||||
score_record = (await db.exec(select(Score).where(Score.id == score_id, Score.user_id == user_id))).first()
|
||||
if not score_record:
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
|
||||
@@ -878,10 +807,7 @@ async def unpin_score(
|
||||
"/score-pins/{score_id}/reorder",
|
||||
status_code=204,
|
||||
name="调整置顶成绩顺序",
|
||||
description=(
|
||||
"**客户端专属**\n调整已置顶成绩的展示顺序。"
|
||||
"仅提供 after_score_id 或 before_score_id 之一。"
|
||||
),
|
||||
description=("**客户端专属**\n调整已置顶成绩的展示顺序。仅提供 after_score_id 或 before_score_id 之一。"),
|
||||
tags=["成绩"],
|
||||
)
|
||||
async def reorder_score_pin(
|
||||
@@ -894,11 +820,7 @@ async def reorder_score_pin(
|
||||
# 立即获取用户ID,避免懒加载问题
|
||||
user_id = current_user.id
|
||||
|
||||
score_record = (
|
||||
await db.exec(
|
||||
select(Score).where(Score.id == score_id, Score.user_id == user_id)
|
||||
)
|
||||
).first()
|
||||
score_record = (await db.exec(select(Score).where(Score.id == score_id, Score.user_id == user_id))).first()
|
||||
if not score_record:
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
|
||||
@@ -908,8 +830,7 @@ async def reorder_score_pin(
|
||||
if (after_score_id is None) == (before_score_id is None):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Either after_score_id or before_score_id "
|
||||
"must be provided (but not both)",
|
||||
detail="Either after_score_id or before_score_id must be provided (but not both)",
|
||||
)
|
||||
|
||||
all_pinned_scores = (
|
||||
@@ -927,9 +848,7 @@ async def reorder_score_pin(
|
||||
target_order = None
|
||||
reference_score_id = after_score_id or before_score_id
|
||||
|
||||
reference_score = next(
|
||||
(s for s in all_pinned_scores if s.id == reference_score_id), None
|
||||
)
|
||||
reference_score = next((s for s in all_pinned_scores if s.id == reference_score_id), None)
|
||||
if not reference_score:
|
||||
detail = "After score not found" if after_score_id else "Before score not found"
|
||||
raise HTTPException(status_code=404, detail=detail)
|
||||
@@ -951,9 +870,7 @@ async def reorder_score_pin(
|
||||
if current_order < s.pinned_order <= target_order and s.id != score_id:
|
||||
updates.append((s.id, s.pinned_order - 1))
|
||||
if after_score_id:
|
||||
final_target = (
|
||||
target_order - 1 if target_order > current_order else target_order
|
||||
)
|
||||
final_target = target_order - 1 if target_order > current_order else target_order
|
||||
else:
|
||||
final_target = target_order
|
||||
else:
|
||||
@@ -964,9 +881,7 @@ async def reorder_score_pin(
|
||||
|
||||
for score_id, new_order in updates:
|
||||
await db.exec(select(Score).where(Score.id == score_id))
|
||||
score_to_update = (
|
||||
await db.exec(select(Score).where(Score.id == score_id))
|
||||
).first()
|
||||
score_to_update = (await db.exec(select(Score).where(Score.id == score_id))).first()
|
||||
if score_to_update:
|
||||
score_to_update.pinned_order = new_order
|
||||
|
||||
|
||||
@@ -4,34 +4,29 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, UTC
|
||||
from typing import Annotated
|
||||
|
||||
from app.auth import authenticate_user
|
||||
from app.config import settings
|
||||
from app.database import User
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.geoip import GeoIPHelper, get_geoip_helper
|
||||
from app.database.email_verification import EmailVerification, LoginSession
|
||||
from app.service.email_verification_service import (
|
||||
EmailVerificationService,
|
||||
LoginSessionService
|
||||
EmailVerificationService,
|
||||
LoginSessionService,
|
||||
)
|
||||
from app.service.login_log_service import LoginLogService
|
||||
from app.models.extended_auth import ExtendedTokenResponse
|
||||
|
||||
from fastapi import Form, Depends, Request, HTTPException, status, Security
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import select
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Depends, Form, HTTPException, Request, Security, status
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
|
||||
|
||||
class SessionReissueResponse(BaseModel):
|
||||
"""重新发送验证码响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
@@ -40,39 +35,35 @@ class SessionReissueResponse(BaseModel):
|
||||
"/session/verify",
|
||||
name="验证会话",
|
||||
description="验证邮件验证码并完成会话认证",
|
||||
status_code=204
|
||||
status_code=204,
|
||||
)
|
||||
async def verify_session(
|
||||
request: Request,
|
||||
db: Database,
|
||||
redis: Annotated[Redis, Depends(get_redis)],
|
||||
verification_key: str = Form(..., description="8位邮件验证码"),
|
||||
current_user: User = Security(get_current_user)
|
||||
current_user: User = Security(get_current_user),
|
||||
) -> Response:
|
||||
"""
|
||||
验证邮件验证码并完成会话认证
|
||||
|
||||
|
||||
对应 osu! 的 session/verify 接口
|
||||
成功时返回 204 No Content,失败时返回 401 Unauthorized
|
||||
"""
|
||||
try:
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
ip_address = get_client_ip(request)
|
||||
user_agent = request.headers.get("User-Agent", "Unknown")
|
||||
|
||||
|
||||
ip_address = get_client_ip(request) # noqa: F841
|
||||
user_agent = request.headers.get("User-Agent", "Unknown") # noqa: F841
|
||||
|
||||
# 从当前认证用户获取信息
|
||||
user_id = current_user.id
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="用户未认证"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户未认证")
|
||||
|
||||
# 验证邮件验证码
|
||||
success, message = await EmailVerificationService.verify_code(
|
||||
db, redis, user_id, verification_key
|
||||
)
|
||||
|
||||
success, message = await EmailVerificationService.verify_code(db, redis, user_id, verification_key)
|
||||
|
||||
if success:
|
||||
# 记录成功的邮件验证
|
||||
await LoginLogService.record_login(
|
||||
@@ -81,9 +72,9 @@ async def verify_session(
|
||||
request=request,
|
||||
login_method="email_verification",
|
||||
login_success=True,
|
||||
notes=f"邮件验证成功"
|
||||
notes="邮件验证成功",
|
||||
)
|
||||
|
||||
|
||||
# 返回 204 No Content 表示验证成功
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
else:
|
||||
@@ -93,83 +84,69 @@ async def verify_session(
|
||||
request=request,
|
||||
attempted_username=current_user.username,
|
||||
login_method="email_verification",
|
||||
notes=f"邮件验证失败: {message}"
|
||||
notes=f"邮件验证失败: {message}",
|
||||
)
|
||||
|
||||
|
||||
# 返回 401 Unauthorized 表示验证失败
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=message
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=message)
|
||||
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="无效的用户会话"
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="验证过程中发生错误"
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的用户会话")
|
||||
except Exception:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="验证过程中发生错误")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/session/verify/reissue",
|
||||
name="重新发送验证码",
|
||||
description="重新发送邮件验证码",
|
||||
response_model=SessionReissueResponse
|
||||
response_model=SessionReissueResponse,
|
||||
)
|
||||
async def reissue_verification_code(
|
||||
request: Request,
|
||||
db: Database,
|
||||
redis: Annotated[Redis, Depends(get_redis)],
|
||||
current_user: User = Security(get_current_user)
|
||||
current_user: User = Security(get_current_user),
|
||||
) -> SessionReissueResponse:
|
||||
"""
|
||||
重新发送邮件验证码
|
||||
|
||||
|
||||
对应 osu! 的 session/verify/reissue 接口
|
||||
"""
|
||||
try:
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
|
||||
ip_address = get_client_ip(request)
|
||||
user_agent = request.headers.get("User-Agent", "Unknown")
|
||||
|
||||
|
||||
# 从当前认证用户获取信息
|
||||
user_id = current_user.id
|
||||
if not user_id:
|
||||
return SessionReissueResponse(
|
||||
success=False,
|
||||
message="用户未认证"
|
||||
)
|
||||
|
||||
return SessionReissueResponse(success=False, message="用户未认证")
|
||||
|
||||
# 重新发送验证码
|
||||
success, message = await EmailVerificationService.resend_verification_code(
|
||||
db, redis, user_id, current_user.username, current_user.email, ip_address, user_agent
|
||||
db,
|
||||
redis,
|
||||
user_id,
|
||||
current_user.username,
|
||||
current_user.email,
|
||||
ip_address,
|
||||
user_agent,
|
||||
)
|
||||
|
||||
return SessionReissueResponse(
|
||||
success=success,
|
||||
message=message
|
||||
)
|
||||
|
||||
|
||||
return SessionReissueResponse(success=success, message=message)
|
||||
|
||||
except ValueError:
|
||||
return SessionReissueResponse(
|
||||
success=False,
|
||||
message="无效的用户会话"
|
||||
)
|
||||
except Exception as e:
|
||||
return SessionReissueResponse(
|
||||
success=False,
|
||||
message="重新发送过程中发生错误"
|
||||
)
|
||||
return SessionReissueResponse(success=False, message="无效的用户会话")
|
||||
except Exception:
|
||||
return SessionReissueResponse(success=False, message="重新发送过程中发生错误")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/session/check-new-location",
|
||||
name="检查新位置登录",
|
||||
description="检查登录是否来自新位置(内部接口)"
|
||||
description="检查登录是否来自新位置(内部接口)",
|
||||
)
|
||||
async def check_new_location(
|
||||
request: Request,
|
||||
@@ -183,22 +160,21 @@ async def check_new_location(
|
||||
"""
|
||||
try:
|
||||
from app.dependencies.geoip import get_client_ip
|
||||
|
||||
ip_address = get_client_ip(request)
|
||||
geo_info = geoip.lookup(ip_address)
|
||||
country_code = geo_info.get("country_iso", "XX")
|
||||
|
||||
is_new_location = await LoginSessionService.check_new_location(
|
||||
db, user_id, ip_address, country_code
|
||||
)
|
||||
|
||||
|
||||
is_new_location = await LoginSessionService.check_new_location(db, user_id, ip_address, country_code)
|
||||
|
||||
return {
|
||||
"is_new_location": is_new_location,
|
||||
"ip_address": ip_address,
|
||||
"country_code": country_code
|
||||
"country_code": country_code,
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"is_new_location": True, # 出错时默认为新位置
|
||||
"error": str(e)
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
@@ -1,73 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
import json
|
||||
from typing import Any
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
from app.dependencies.database import get_redis, get_redis_message
|
||||
from app.log import logger
|
||||
from app.utils import bg_tasks
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Redis key constants
|
||||
REDIS_ONLINE_USERS_KEY = "server:online_users"
|
||||
REDIS_PLAYING_USERS_KEY = "server:playing_users"
|
||||
REDIS_PLAYING_USERS_KEY = "server:playing_users"
|
||||
REDIS_REGISTERED_USERS_KEY = "server:registered_users"
|
||||
REDIS_ONLINE_HISTORY_KEY = "server:online_history"
|
||||
|
||||
# 线程池用于同步Redis操作
|
||||
_executor = ThreadPoolExecutor(max_workers=2)
|
||||
|
||||
|
||||
async def _redis_exec(func, *args, **kwargs):
|
||||
"""在线程池中执行同步Redis操作"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(_executor, func, *args, **kwargs)
|
||||
|
||||
|
||||
class ServerStats(BaseModel):
|
||||
"""服务器统计信息响应模型"""
|
||||
|
||||
registered_users: int
|
||||
online_users: int
|
||||
playing_users: int
|
||||
timestamp: datetime
|
||||
|
||||
|
||||
class OnlineHistoryPoint(BaseModel):
|
||||
"""在线历史数据点"""
|
||||
|
||||
timestamp: datetime
|
||||
online_count: int
|
||||
playing_count: int
|
||||
|
||||
|
||||
class OnlineHistoryResponse(BaseModel):
|
||||
"""24小时在线历史响应模型"""
|
||||
|
||||
history: list[OnlineHistoryPoint]
|
||||
current_stats: ServerStats
|
||||
|
||||
|
||||
@router.get("/stats", response_model=ServerStats, tags=["统计"])
|
||||
async def get_server_stats() -> ServerStats:
|
||||
"""
|
||||
获取服务器实时统计信息
|
||||
|
||||
|
||||
返回服务器注册用户数、在线用户数、正在游玩用户数等实时统计信息
|
||||
"""
|
||||
redis = get_redis()
|
||||
|
||||
|
||||
try:
|
||||
# 并行获取所有统计数据
|
||||
registered_count, online_count, playing_count = await asyncio.gather(
|
||||
_get_registered_users_count(redis),
|
||||
_get_online_users_count(redis),
|
||||
_get_playing_users_count(redis)
|
||||
_get_playing_users_count(redis),
|
||||
)
|
||||
|
||||
|
||||
return ServerStats(
|
||||
registered_users=registered_count,
|
||||
online_users=online_count,
|
||||
playing_users=playing_count,
|
||||
timestamp=datetime.utcnow()
|
||||
timestamp=datetime.utcnow(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting server stats: {e}")
|
||||
@@ -76,14 +83,15 @@ async def get_server_stats() -> ServerStats:
|
||||
registered_users=0,
|
||||
online_users=0,
|
||||
playing_users=0,
|
||||
timestamp=datetime.utcnow()
|
||||
timestamp=datetime.utcnow(),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stats/history", response_model=OnlineHistoryResponse, tags=["统计"])
|
||||
async def get_online_history() -> OnlineHistoryResponse:
|
||||
"""
|
||||
获取最近24小时在线统计历史
|
||||
|
||||
|
||||
返回过去24小时内每小时的在线用户数和游玩用户数统计,
|
||||
包含当前实时数据作为最新数据点
|
||||
"""
|
||||
@@ -92,80 +100,80 @@ async def get_online_history() -> OnlineHistoryResponse:
|
||||
redis_sync = get_redis_message()
|
||||
history_data = await _redis_exec(redis_sync.lrange, REDIS_ONLINE_HISTORY_KEY, 0, -1)
|
||||
history_points = []
|
||||
|
||||
|
||||
# 处理历史数据
|
||||
for data in history_data:
|
||||
try:
|
||||
point_data = json.loads(data)
|
||||
# 只保留基本字段
|
||||
history_points.append(OnlineHistoryPoint(
|
||||
timestamp=datetime.fromisoformat(point_data["timestamp"]),
|
||||
online_count=point_data["online_count"],
|
||||
playing_count=point_data["playing_count"]
|
||||
))
|
||||
history_points.append(
|
||||
OnlineHistoryPoint(
|
||||
timestamp=datetime.fromisoformat(point_data["timestamp"]),
|
||||
online_count=point_data["online_count"],
|
||||
playing_count=point_data["playing_count"],
|
||||
)
|
||||
)
|
||||
except (json.JSONDecodeError, KeyError, ValueError) as e:
|
||||
logger.warning(f"Invalid history data point: {data}, error: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# 获取当前实时统计信息
|
||||
current_stats = await get_server_stats()
|
||||
|
||||
|
||||
# 如果历史数据为空或者最新数据超过15分钟,添加当前数据点
|
||||
if not history_points or (
|
||||
history_points and
|
||||
(current_stats.timestamp - max(history_points, key=lambda x: x.timestamp).timestamp).total_seconds() > 15 * 60
|
||||
history_points
|
||||
and (current_stats.timestamp - max(history_points, key=lambda x: x.timestamp).timestamp).total_seconds()
|
||||
> 15 * 60
|
||||
):
|
||||
history_points.append(OnlineHistoryPoint(
|
||||
timestamp=current_stats.timestamp,
|
||||
online_count=current_stats.online_users,
|
||||
playing_count=current_stats.playing_users
|
||||
))
|
||||
|
||||
history_points.append(
|
||||
OnlineHistoryPoint(
|
||||
timestamp=current_stats.timestamp,
|
||||
online_count=current_stats.online_users,
|
||||
playing_count=current_stats.playing_users,
|
||||
)
|
||||
)
|
||||
|
||||
# 按时间排序(最新的在前)
|
||||
history_points.sort(key=lambda x: x.timestamp, reverse=True)
|
||||
|
||||
|
||||
# 限制到最多48个数据点(24小时)
|
||||
history_points = history_points[:48]
|
||||
|
||||
return OnlineHistoryResponse(
|
||||
history=history_points,
|
||||
current_stats=current_stats
|
||||
)
|
||||
|
||||
return OnlineHistoryResponse(history=history_points, current_stats=current_stats)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting online history: {e}")
|
||||
# 返回空历史和当前状态
|
||||
current_stats = await get_server_stats()
|
||||
return OnlineHistoryResponse(
|
||||
history=[],
|
||||
current_stats=current_stats
|
||||
)
|
||||
return OnlineHistoryResponse(history=[], current_stats=current_stats)
|
||||
|
||||
|
||||
@router.get("/stats/debug", tags=["统计"])
|
||||
async def get_stats_debug_info():
|
||||
"""
|
||||
获取统计系统调试信息
|
||||
|
||||
|
||||
用于调试时间对齐和区间统计问题
|
||||
"""
|
||||
try:
|
||||
from app.service.enhanced_interval_stats import EnhancedIntervalStatsManager
|
||||
|
||||
|
||||
current_time = datetime.utcnow()
|
||||
current_interval = await EnhancedIntervalStatsManager.get_current_interval_info()
|
||||
interval_stats = await EnhancedIntervalStatsManager.get_current_interval_stats()
|
||||
|
||||
|
||||
# 获取Redis中的实际数据
|
||||
redis_sync = get_redis_message()
|
||||
|
||||
|
||||
online_key = f"server:interval_online_users:{current_interval.interval_key}"
|
||||
playing_key = f"server:interval_playing_users:{current_interval.interval_key}"
|
||||
|
||||
|
||||
online_users_raw = await _redis_exec(redis_sync.smembers, online_key)
|
||||
playing_users_raw = await _redis_exec(redis_sync.smembers, playing_key)
|
||||
|
||||
|
||||
online_users = [int(uid.decode() if isinstance(uid, bytes) else uid) for uid in online_users_raw]
|
||||
playing_users = [int(uid.decode() if isinstance(uid, bytes) else uid) for uid in playing_users_raw]
|
||||
|
||||
|
||||
return {
|
||||
"current_time": current_time.isoformat(),
|
||||
"current_interval": {
|
||||
@@ -175,28 +183,29 @@ async def get_stats_debug_info():
|
||||
"is_current": current_interval.is_current(),
|
||||
"minutes_remaining": int((current_interval.end_time - current_time).total_seconds() / 60),
|
||||
"seconds_remaining": int((current_interval.end_time - current_time).total_seconds()),
|
||||
"progress_percentage": round((1 - (current_interval.end_time - current_time).total_seconds() / (30 * 60)) * 100, 1)
|
||||
"progress_percentage": round(
|
||||
(1 - (current_interval.end_time - current_time).total_seconds() / (30 * 60)) * 100,
|
||||
1,
|
||||
),
|
||||
},
|
||||
"interval_statistics": interval_stats.to_dict() if interval_stats else None,
|
||||
"redis_data": {
|
||||
"online_users": online_users,
|
||||
"playing_users": playing_users,
|
||||
"online_count": len(online_users),
|
||||
"playing_count": len(playing_users)
|
||||
"playing_count": len(playing_users),
|
||||
},
|
||||
"system_status": {
|
||||
"stats_system": "enhanced_interval_stats",
|
||||
"data_alignment": "30_minute_boundaries",
|
||||
"real_time_updates": True,
|
||||
"auto_24h_fill": True
|
||||
}
|
||||
"auto_24h_fill": True,
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting debug info: {e}")
|
||||
return {
|
||||
"error": "Failed to retrieve debug information",
|
||||
"message": str(e)
|
||||
}
|
||||
return {"error": "Failed to retrieve debug information", "message": str(e)}
|
||||
|
||||
|
||||
async def _get_registered_users_count(redis) -> int:
|
||||
"""获取注册用户总数(从缓存)"""
|
||||
@@ -207,6 +216,7 @@ async def _get_registered_users_count(redis) -> int:
|
||||
logger.error(f"Error getting registered users count: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
async def _get_online_users_count(redis) -> int:
|
||||
"""获取当前在线用户数"""
|
||||
try:
|
||||
@@ -216,6 +226,7 @@ async def _get_online_users_count(redis) -> int:
|
||||
logger.error(f"Error getting online users count: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
async def _get_playing_users_count(redis) -> int:
|
||||
"""获取当前游玩用户数"""
|
||||
try:
|
||||
@@ -225,27 +236,28 @@ async def _get_playing_users_count(redis) -> int:
|
||||
logger.error(f"Error getting playing users count: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
# 统计更新功能
|
||||
async def update_registered_users_count() -> None:
|
||||
"""更新注册用户数缓存"""
|
||||
from app.dependencies.database import with_db
|
||||
from app.database import User
|
||||
from app.const import BANCHOBOT_ID
|
||||
from sqlmodel import select, func
|
||||
|
||||
from app.database import User
|
||||
from app.dependencies.database import with_db
|
||||
|
||||
from sqlmodel import func, select
|
||||
|
||||
redis = get_redis()
|
||||
try:
|
||||
async with with_db() as db:
|
||||
# 排除机器人用户(BANCHOBOT_ID)
|
||||
result = await db.exec(
|
||||
select(func.count()).select_from(User).where(User.id != BANCHOBOT_ID)
|
||||
)
|
||||
result = await db.exec(select(func.count()).select_from(User).where(User.id != BANCHOBOT_ID))
|
||||
count = result.first()
|
||||
await redis.set(REDIS_REGISTERED_USERS_KEY, count or 0, ex=300) # 5分钟过期
|
||||
logger.debug(f"Updated registered users count: {count}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating registered users count: {e}")
|
||||
|
||||
|
||||
async def add_online_user(user_id: int) -> None:
|
||||
"""添加在线用户"""
|
||||
redis_sync = get_redis_message()
|
||||
@@ -257,14 +269,20 @@ async def add_online_user(user_id: int) -> None:
|
||||
if ttl <= 0: # -1表示永不过期,-2表示不存在,0表示已过期
|
||||
await redis_async.expire(REDIS_ONLINE_USERS_KEY, 3 * 3600) # 3小时过期
|
||||
logger.debug(f"Added online user {user_id}")
|
||||
|
||||
|
||||
# 立即更新当前区间统计
|
||||
from app.service.enhanced_interval_stats import update_user_activity_in_interval
|
||||
asyncio.create_task(update_user_activity_in_interval(user_id, is_playing=False))
|
||||
|
||||
|
||||
bg_tasks.add_task(
|
||||
update_user_activity_in_interval,
|
||||
user_id,
|
||||
is_playing=False,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding online user {user_id}: {e}")
|
||||
|
||||
|
||||
async def remove_online_user(user_id: int) -> None:
|
||||
"""移除在线用户"""
|
||||
redis_sync = get_redis_message()
|
||||
@@ -274,6 +292,7 @@ async def remove_online_user(user_id: int) -> None:
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing online user {user_id}: {e}")
|
||||
|
||||
|
||||
async def add_playing_user(user_id: int) -> None:
|
||||
"""添加游玩用户"""
|
||||
redis_sync = get_redis_message()
|
||||
@@ -285,14 +304,16 @@ async def add_playing_user(user_id: int) -> None:
|
||||
if ttl <= 0: # -1表示永不过期,-2表示不存在,0表示已过期
|
||||
await redis_async.expire(REDIS_PLAYING_USERS_KEY, 3 * 3600) # 3小时过期
|
||||
logger.debug(f"Added playing user {user_id}")
|
||||
|
||||
|
||||
# 立即更新当前区间统计
|
||||
from app.service.enhanced_interval_stats import update_user_activity_in_interval
|
||||
asyncio.create_task(update_user_activity_in_interval(user_id, is_playing=True))
|
||||
|
||||
|
||||
bg_tasks.add_task(update_user_activity_in_interval, user_id, is_playing=True)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding playing user {user_id}: {e}")
|
||||
|
||||
|
||||
async def remove_playing_user(user_id: int) -> None:
|
||||
"""移除游玩用户"""
|
||||
redis_sync = get_redis_message()
|
||||
@@ -301,6 +322,7 @@ async def remove_playing_user(user_id: int) -> None:
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing playing user {user_id}: {e}")
|
||||
|
||||
|
||||
async def record_hourly_stats() -> None:
|
||||
"""记录统计数据 - 简化版本,主要作为fallback使用"""
|
||||
redis_sync = get_redis_message()
|
||||
@@ -308,24 +330,27 @@ async def record_hourly_stats() -> None:
|
||||
try:
|
||||
# 先确保Redis连接正常
|
||||
await redis_async.ping()
|
||||
|
||||
|
||||
online_count = await _get_online_users_count(redis_async)
|
||||
playing_count = await _get_playing_users_count(redis_async)
|
||||
|
||||
|
||||
current_time = datetime.utcnow()
|
||||
history_point = {
|
||||
"timestamp": current_time.isoformat(),
|
||||
"online_count": online_count,
|
||||
"playing_count": playing_count
|
||||
"playing_count": playing_count,
|
||||
}
|
||||
|
||||
|
||||
# 添加到历史记录
|
||||
await _redis_exec(redis_sync.lpush, REDIS_ONLINE_HISTORY_KEY, json.dumps(history_point))
|
||||
# 只保留48个数据点(24小时,每30分钟一个点)
|
||||
await _redis_exec(redis_sync.ltrim, REDIS_ONLINE_HISTORY_KEY, 0, 47)
|
||||
# 设置过期时间为26小时,确保有足够缓冲
|
||||
await redis_async.expire(REDIS_ONLINE_HISTORY_KEY, 26 * 3600)
|
||||
|
||||
logger.info(f"Recorded fallback stats: online={online_count}, playing={playing_count} at {current_time.strftime('%H:%M:%S')}")
|
||||
|
||||
logger.info(
|
||||
f"Recorded fallback stats: online={online_count}, playing={playing_count} "
|
||||
f"at {current_time.strftime('%H:%M:%S')}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error recording fallback stats: {e}")
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Literal
|
||||
|
||||
@@ -26,7 +25,7 @@ from app.service.user_cache_service import get_user_cache_service
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import HTTPException, Path, Query, Security
|
||||
from fastapi import BackgroundTasks, HTTPException, Path, Query, Security
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import exists, false, select
|
||||
from sqlmodel.sql.expression import col
|
||||
@@ -47,13 +46,10 @@ class BatchUserResponse(BaseModel):
|
||||
@router.get("/users/lookup/", response_model=BatchUserResponse, include_in_schema=False)
|
||||
async def get_users(
|
||||
session: Database,
|
||||
user_ids: list[int] = Query(
|
||||
default_factory=list, alias="ids[]", description="要查询的用户 ID 列表"
|
||||
),
|
||||
background_task: BackgroundTasks,
|
||||
user_ids: list[int] = Query(default_factory=list, alias="ids[]", description="要查询的用户 ID 列表"),
|
||||
# current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
include_variant_statistics: bool = Query(
|
||||
default=False, description="是否包含各模式的统计信息"
|
||||
), # TODO: future use
|
||||
include_variant_statistics: bool = Query(default=False, description="是否包含各模式的统计信息"), # TODO: future use
|
||||
):
|
||||
redis = get_redis()
|
||||
cache_service = get_user_cache_service(redis)
|
||||
@@ -72,11 +68,7 @@ async def get_users(
|
||||
|
||||
# 查询未缓存的用户
|
||||
if uncached_user_ids:
|
||||
searched_users = (
|
||||
await session.exec(
|
||||
select(User).where(col(User.id).in_(uncached_user_ids))
|
||||
)
|
||||
).all()
|
||||
searched_users = (await session.exec(select(User).where(col(User.id).in_(uncached_user_ids)))).all()
|
||||
|
||||
# 将查询到的用户添加到缓存并返回
|
||||
for searched_user in searched_users:
|
||||
@@ -88,7 +80,7 @@ async def get_users(
|
||||
)
|
||||
cached_users.append(user_resp)
|
||||
# 异步缓存,不阻塞响应
|
||||
asyncio.create_task(cache_service.cache_user(user_resp))
|
||||
background_task.add_task(cache_service.cache_user, user_resp)
|
||||
|
||||
return BatchUserResponse(users=cached_users)
|
||||
else:
|
||||
@@ -103,7 +95,7 @@ async def get_users(
|
||||
)
|
||||
users.append(user_resp)
|
||||
# 异步缓存
|
||||
asyncio.create_task(cache_service.cache_user(user_resp))
|
||||
background_task.add_task(cache_service.cache_user, user_resp)
|
||||
|
||||
return BatchUserResponse(users=users)
|
||||
|
||||
@@ -117,6 +109,7 @@ async def get_users(
|
||||
)
|
||||
async def get_user_info_ruleset(
|
||||
session: Database,
|
||||
background_task: BackgroundTasks,
|
||||
user_id: str = Path(description="用户 ID 或用户名"),
|
||||
ruleset: GameMode | None = Path(description="指定 ruleset"),
|
||||
# current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
@@ -134,9 +127,7 @@ async def get_user_info_ruleset(
|
||||
searched_user = (
|
||||
await session.exec(
|
||||
select(User).where(
|
||||
User.id == int(user_id)
|
||||
if user_id.isdigit()
|
||||
else User.username == user_id.removeprefix("@")
|
||||
User.id == int(user_id) if user_id.isdigit() else User.username == user_id.removeprefix("@")
|
||||
)
|
||||
)
|
||||
).first()
|
||||
@@ -151,7 +142,7 @@ async def get_user_info_ruleset(
|
||||
)
|
||||
|
||||
# 异步缓存结果
|
||||
asyncio.create_task(cache_service.cache_user(user_resp, ruleset))
|
||||
background_task.add_task(cache_service.cache_user, user_resp, ruleset)
|
||||
|
||||
return user_resp
|
||||
|
||||
@@ -165,6 +156,7 @@ async def get_user_info_ruleset(
|
||||
tags=["用户"],
|
||||
)
|
||||
async def get_user_info(
|
||||
background_task: BackgroundTasks,
|
||||
session: Database,
|
||||
user_id: str = Path(description="用户 ID 或用户名"),
|
||||
# current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
@@ -182,9 +174,7 @@ async def get_user_info(
|
||||
searched_user = (
|
||||
await session.exec(
|
||||
select(User).where(
|
||||
User.id == int(user_id)
|
||||
if user_id.isdigit()
|
||||
else User.username == user_id.removeprefix("@")
|
||||
User.id == int(user_id) if user_id.isdigit() else User.username == user_id.removeprefix("@")
|
||||
)
|
||||
)
|
||||
).first()
|
||||
@@ -198,7 +188,7 @@ async def get_user_info(
|
||||
)
|
||||
|
||||
# 异步缓存结果
|
||||
asyncio.create_task(cache_service.cache_user(user_resp))
|
||||
background_task.add_task(cache_service.cache_user, user_resp)
|
||||
|
||||
return user_resp
|
||||
|
||||
@@ -212,6 +202,7 @@ async def get_user_info(
|
||||
)
|
||||
async def get_user_beatmapsets(
|
||||
session: Database,
|
||||
background_task: BackgroundTasks,
|
||||
user_id: int = Path(description="用户 ID"),
|
||||
type: BeatmapsetType = Path(description="谱面集类型"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
@@ -222,9 +213,7 @@ async def get_user_beatmapsets(
|
||||
cache_service = get_user_cache_service(redis)
|
||||
|
||||
# 先尝试从缓存获取
|
||||
cached_result = await cache_service.get_user_beatmapsets_from_cache(
|
||||
user_id, type.value, limit, offset
|
||||
)
|
||||
cached_result = await cache_service.get_user_beatmapsets_from_cache(user_id, type.value, limit, offset)
|
||||
if cached_result is not None:
|
||||
# 根据类型恢复对象
|
||||
if type == BeatmapsetType.MOST_PLAYED:
|
||||
@@ -253,10 +242,7 @@ async def get_user_beatmapsets(
|
||||
raise HTTPException(404, detail="User not found")
|
||||
favourites = await user.awaitable_attrs.favourite_beatmapsets
|
||||
resp = [
|
||||
await BeatmapsetResp.from_db(
|
||||
favourite.beatmapset, session=session, user=user
|
||||
)
|
||||
for favourite in favourites
|
||||
await BeatmapsetResp.from_db(favourite.beatmapset, session=session, user=user) for favourite in favourites
|
||||
]
|
||||
|
||||
elif type == BeatmapsetType.MOST_PLAYED:
|
||||
@@ -267,25 +253,18 @@ async def get_user_beatmapsets(
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
resp = [
|
||||
await BeatmapPlaycountsResp.from_db(most_played_beatmap)
|
||||
for most_played_beatmap in most_played
|
||||
]
|
||||
resp = [await BeatmapPlaycountsResp.from_db(most_played_beatmap) for most_played_beatmap in most_played]
|
||||
else:
|
||||
raise HTTPException(400, detail="Invalid beatmapset type")
|
||||
|
||||
# 异步缓存结果
|
||||
async def cache_beatmapsets():
|
||||
try:
|
||||
await cache_service.cache_user_beatmapsets(
|
||||
user_id, type.value, resp, limit, offset
|
||||
)
|
||||
await cache_service.cache_user_beatmapsets(user_id, type.value, resp, limit, offset)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}"
|
||||
)
|
||||
logger.error(f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}")
|
||||
|
||||
asyncio.create_task(cache_beatmapsets())
|
||||
background_task.add_task(cache_beatmapsets)
|
||||
|
||||
return resp
|
||||
|
||||
@@ -299,18 +278,14 @@ async def get_user_beatmapsets(
|
||||
)
|
||||
async def get_user_scores(
|
||||
session: Database,
|
||||
background_task: BackgroundTasks,
|
||||
user_id: int = Path(description="用户 ID"),
|
||||
type: Literal["best", "recent", "firsts", "pinned"] = Path(
|
||||
description=(
|
||||
"成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩"
|
||||
" / firsts 第一名成绩 / pinned 置顶成绩"
|
||||
)
|
||||
description=("成绩类型: best 最好成绩 / recent 最近 24h 游玩成绩 / firsts 第一名成绩 / pinned 置顶成绩")
|
||||
),
|
||||
legacy_only: bool = Query(False, description="是否只查询 Stable 成绩"),
|
||||
include_fails: bool = Query(False, description="是否包含失败的成绩"),
|
||||
mode: GameMode | None = Query(
|
||||
None, description="指定 ruleset (可选,默认为用户主模式)"
|
||||
),
|
||||
mode: GameMode | None = Query(None, description="指定 ruleset (可选,默认为用户主模式)"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回条数 (1-1000)"),
|
||||
offset: int = Query(0, ge=0, description="偏移量"),
|
||||
current_user: User = Security(get_current_user, scopes=["public"]),
|
||||
@@ -320,9 +295,7 @@ async def get_user_scores(
|
||||
|
||||
# 先尝试从缓存获取(对于recent类型使用较短的缓存时间)
|
||||
cache_expire = 30 if type == "recent" else settings.user_scores_cache_expire_seconds
|
||||
cached_scores = await cache_service.get_user_scores_from_cache(
|
||||
user_id, type, mode, limit, offset
|
||||
)
|
||||
cached_scores = await cache_service.get_user_scores_from_cache(user_id, type, mode, limit, offset)
|
||||
if cached_scores is not None:
|
||||
return cached_scores
|
||||
|
||||
@@ -332,9 +305,7 @@ async def get_user_scores(
|
||||
|
||||
gamemode = mode or db_user.playmode
|
||||
order_by = None
|
||||
where_clause = (col(Score.user_id) == db_user.id) & (
|
||||
col(Score.gamemode) == gamemode
|
||||
)
|
||||
where_clause = (col(Score.user_id) == db_user.id) & (col(Score.gamemode) == gamemode)
|
||||
if not include_fails:
|
||||
where_clause &= col(Score.passed).is_(True)
|
||||
if type == "pinned":
|
||||
@@ -351,13 +322,7 @@ async def get_user_scores(
|
||||
where_clause &= false()
|
||||
|
||||
scores = (
|
||||
await session.exec(
|
||||
select(Score)
|
||||
.where(where_clause)
|
||||
.order_by(order_by)
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
await session.exec(select(Score).where(where_clause).order_by(order_by).limit(limit).offset(offset))
|
||||
).all()
|
||||
if not scores:
|
||||
return []
|
||||
@@ -371,18 +336,14 @@ async def get_user_scores(
|
||||
]
|
||||
|
||||
# 异步缓存结果
|
||||
asyncio.create_task(
|
||||
cache_service.cache_user_scores(
|
||||
user_id, type, score_responses, mode, limit, offset, cache_expire
|
||||
)
|
||||
background_task.add_task(
|
||||
cache_service.cache_user_scores, user_id, type, score_responses, mode, limit, offset, cache_expire
|
||||
)
|
||||
|
||||
return score_responses
|
||||
|
||||
|
||||
@router.get(
|
||||
"/users/{user}/recent_activity", tags=["用户"], response_model=list[EventResp]
|
||||
)
|
||||
@router.get("/users/{user}/recent_activity", tags=["用户"], response_model=list[EventResp])
|
||||
async def get_user_events(
|
||||
session: Database,
|
||||
user: int,
|
||||
|
||||
Reference in New Issue
Block a user