Merge branch 'main' into feat/multiplayer-api
This commit is contained in:
@@ -1,15 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.signalr import signalr_router as signalr_router
|
||||
|
||||
from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401
|
||||
beatmap,
|
||||
beatmapset,
|
||||
me,
|
||||
relationship,
|
||||
score,
|
||||
user,
|
||||
)
|
||||
from .api_router import router as api_router
|
||||
from .auth import router as auth_router
|
||||
from .fetcher import fetcher_router as fetcher_router
|
||||
from .signalr import signalr_router as signalr_router
|
||||
|
||||
__all__ = ["api_router", "auth_router", "fetcher_router", "signalr_router"]
|
||||
|
||||
@@ -1,39 +1,264 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
import re
|
||||
|
||||
from app.auth import (
|
||||
authenticate_user,
|
||||
create_access_token,
|
||||
generate_refresh_token,
|
||||
get_password_hash,
|
||||
get_token_by_refresh_token,
|
||||
store_token,
|
||||
)
|
||||
from app.config import settings
|
||||
from app.database import User as DBUser
|
||||
from app.dependencies import get_db
|
||||
from app.models.oauth import TokenResponse, OAuthErrorResponse
|
||||
from app.models.oauth import (
|
||||
OAuthErrorResponse,
|
||||
RegistrationRequestErrors,
|
||||
TokenResponse,
|
||||
UserRegistrationErrors,
|
||||
)
|
||||
|
||||
from fastapi import APIRouter, Depends, Form
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
def create_oauth_error_response(error: str, description: str, hint: str, status_code: int = 400):
|
||||
def create_oauth_error_response(
|
||||
error: str, description: str, hint: str, status_code: int = 400
|
||||
):
|
||||
"""创建标准的 OAuth 错误响应"""
|
||||
error_data = OAuthErrorResponse(
|
||||
error=error,
|
||||
error_description=description,
|
||||
hint=hint,
|
||||
message=description
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content=error_data.model_dump()
|
||||
error=error, error_description=description, hint=hint, message=description
|
||||
)
|
||||
return JSONResponse(status_code=status_code, content=error_data.model_dump())
|
||||
|
||||
|
||||
def validate_username(username: str) -> list[str]:
|
||||
"""验证用户名"""
|
||||
errors = []
|
||||
|
||||
if not username:
|
||||
errors.append("Username is required")
|
||||
return errors
|
||||
|
||||
if len(username) < 3:
|
||||
errors.append("Username must be at least 3 characters long")
|
||||
|
||||
if len(username) > 15:
|
||||
errors.append("Username must be at most 15 characters long")
|
||||
|
||||
# 检查用户名格式(只允许字母、数字、下划线、连字符)
|
||||
if not re.match(r"^[a-zA-Z0-9_-]+$", username):
|
||||
errors.append(
|
||||
"Username can only contain letters, numbers, underscores, and hyphens"
|
||||
)
|
||||
|
||||
# 检查是否以数字开头
|
||||
if username[0].isdigit():
|
||||
errors.append("Username cannot start with a number")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def validate_email(email: str) -> list[str]:
|
||||
"""验证邮箱"""
|
||||
errors = []
|
||||
|
||||
if not email:
|
||||
errors.append("Email is required")
|
||||
return errors
|
||||
|
||||
# 基本的邮箱格式验证
|
||||
email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
|
||||
if not re.match(email_pattern, email):
|
||||
errors.append("Please enter a valid email address")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def validate_password(password: str) -> list[str]:
|
||||
"""验证密码"""
|
||||
errors = []
|
||||
|
||||
if not password:
|
||||
errors.append("Password is required")
|
||||
return errors
|
||||
|
||||
if len(password) < 8:
|
||||
errors.append("Password must be at least 8 characters long")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
router = APIRouter(tags=["osu! OAuth 认证"])
|
||||
|
||||
|
||||
@router.post("/users")
|
||||
async def register_user(
|
||||
user_username: str = Form(..., alias="user[username]"),
|
||||
user_email: str = Form(..., alias="user[user_email]"),
|
||||
user_password: str = Form(..., alias="user[password]"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""用户注册接口 - 匹配 osu! 客户端的注册请求"""
|
||||
|
||||
username_errors = validate_username(user_username)
|
||||
email_errors = validate_email(user_email)
|
||||
password_errors = validate_password(user_password)
|
||||
|
||||
result = await db.exec(select(DBUser).where(DBUser.name == user_username))
|
||||
existing_user = result.first()
|
||||
if existing_user:
|
||||
username_errors.append("Username is already taken")
|
||||
|
||||
result = await db.exec(select(DBUser).where(DBUser.email == user_email))
|
||||
existing_email = result.first()
|
||||
if existing_email:
|
||||
email_errors.append("Email is already taken")
|
||||
|
||||
if username_errors or email_errors or password_errors:
|
||||
errors = RegistrationRequestErrors(
|
||||
user=UserRegistrationErrors(
|
||||
username=username_errors,
|
||||
user_email=email_errors,
|
||||
password=password_errors,
|
||||
)
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=422, content={"form_error": errors.model_dump()}
|
||||
)
|
||||
|
||||
try:
|
||||
# 创建新用户
|
||||
from datetime import datetime
|
||||
import time
|
||||
|
||||
new_user = DBUser(
|
||||
name=user_username,
|
||||
safe_name=user_username.lower(), # 安全用户名(小写)
|
||||
email=user_email,
|
||||
pw_bcrypt=get_password_hash(user_password),
|
||||
priv=1, # 普通用户权限
|
||||
country="CN", # 默认国家
|
||||
creation_time=int(time.time()),
|
||||
latest_activity=int(time.time()),
|
||||
preferred_mode=0, # 默认模式
|
||||
play_style=0, # 默认游戏风格
|
||||
)
|
||||
|
||||
db.add(new_user)
|
||||
await db.commit()
|
||||
await db.refresh(new_user)
|
||||
|
||||
# 保存用户ID,因为会话可能会关闭
|
||||
user_id = new_user.id
|
||||
|
||||
if user_id <= 2:
|
||||
await db.rollback()
|
||||
try:
|
||||
from sqlalchemy import text
|
||||
|
||||
# 确保 AUTO_INCREMENT 值从3开始(ID=1是BanchoBot,ID=2预留给ppy)
|
||||
await db.execute(text("ALTER TABLE users AUTO_INCREMENT = 3"))
|
||||
await db.commit()
|
||||
|
||||
# 重新创建用户
|
||||
new_user = DBUser(
|
||||
name=user_username,
|
||||
safe_name=user_username.lower(),
|
||||
email=user_email,
|
||||
pw_bcrypt=get_password_hash(user_password),
|
||||
priv=1,
|
||||
country="CN",
|
||||
creation_time=int(time.time()),
|
||||
latest_activity=int(time.time()),
|
||||
preferred_mode=0,
|
||||
play_style=0,
|
||||
)
|
||||
|
||||
db.add(new_user)
|
||||
await db.commit()
|
||||
await db.refresh(new_user)
|
||||
user_id = new_user.id
|
||||
|
||||
# 最终检查ID是否有效
|
||||
if user_id <= 2:
|
||||
await db.rollback()
|
||||
errors = RegistrationRequestErrors(
|
||||
message=(
|
||||
"Failed to create account with valid ID. "
|
||||
"Please contact support."
|
||||
)
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=500, content={"form_error": errors.model_dump()}
|
||||
)
|
||||
|
||||
except Exception as fix_error:
|
||||
await db.rollback()
|
||||
print(f"Failed to fix AUTO_INCREMENT: {fix_error}")
|
||||
errors = RegistrationRequestErrors(
|
||||
message="Failed to create account with valid ID. Please try again."
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=500, content={"form_error": errors.model_dump()}
|
||||
)
|
||||
|
||||
# 创建默认的 lazer_profile
|
||||
from app.database.user import LazerUserProfile
|
||||
|
||||
lazer_profile = LazerUserProfile(
|
||||
user_id=user_id,
|
||||
is_active=True,
|
||||
is_bot=False,
|
||||
is_deleted=False,
|
||||
is_online=True,
|
||||
is_supporter=False,
|
||||
is_restricted=False,
|
||||
session_verified=False,
|
||||
has_supported=False,
|
||||
pm_friends_only=False,
|
||||
default_group="default",
|
||||
join_date=datetime.utcnow(),
|
||||
playmode="osu",
|
||||
support_level=0,
|
||||
max_blocks=50,
|
||||
max_friends=250,
|
||||
post_count=0,
|
||||
)
|
||||
|
||||
db.add(lazer_profile)
|
||||
await db.commit()
|
||||
|
||||
# 返回成功响应
|
||||
return JSONResponse(
|
||||
status_code=201,
|
||||
content={"message": "Account created successfully", "user_id": user_id},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
# 打印详细错误信息用于调试
|
||||
print(f"Registration error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
# 返回通用错误
|
||||
errors = RegistrationRequestErrors(
|
||||
message="An error occurred while creating your account. Please try again."
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=500, content={"form_error": errors.model_dump()}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/oauth/token", response_model=TokenResponse)
|
||||
async def oauth_token(
|
||||
grant_type: str = Form(...),
|
||||
@@ -53,9 +278,13 @@ async def oauth_token(
|
||||
):
|
||||
return create_oauth_error_response(
|
||||
error="invalid_client",
|
||||
description="Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method).",
|
||||
description=(
|
||||
"Client authentication failed (e.g., unknown client, "
|
||||
"no client authentication included, "
|
||||
"or unsupported authentication method)."
|
||||
),
|
||||
hint="Invalid client credentials",
|
||||
status_code=401
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
if grant_type == "password":
|
||||
@@ -63,8 +292,12 @@ async def oauth_token(
|
||||
if not username or not password:
|
||||
return create_oauth_error_response(
|
||||
error="invalid_request",
|
||||
description="The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed.",
|
||||
hint="Username and password required"
|
||||
description=(
|
||||
"The request is missing a required parameter, includes an "
|
||||
"invalid parameter value, "
|
||||
"includes a parameter more than once, or is otherwise malformed."
|
||||
),
|
||||
hint="Username and password required",
|
||||
)
|
||||
|
||||
# 验证用户
|
||||
@@ -72,8 +305,14 @@ async def oauth_token(
|
||||
if not user:
|
||||
return create_oauth_error_response(
|
||||
error="invalid_grant",
|
||||
description="The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client.",
|
||||
hint="Incorrect sign in"
|
||||
description=(
|
||||
"The provided authorization grant (e.g., authorization code, "
|
||||
"resource owner credentials) "
|
||||
"or refresh token is invalid, expired, revoked, "
|
||||
"does not match the redirection URI used in "
|
||||
"the authorization request, or was issued to another client."
|
||||
),
|
||||
hint="Incorrect sign in",
|
||||
)
|
||||
|
||||
# 生成令牌
|
||||
@@ -105,8 +344,12 @@ async def oauth_token(
|
||||
if not refresh_token:
|
||||
return create_oauth_error_response(
|
||||
error="invalid_request",
|
||||
description="The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed.",
|
||||
hint="Refresh token required"
|
||||
description=(
|
||||
"The request is missing a required parameter, "
|
||||
"includes an invalid parameter value, "
|
||||
"includes a parameter more than once, or is otherwise malformed."
|
||||
),
|
||||
hint="Refresh token required",
|
||||
)
|
||||
|
||||
# 验证刷新令牌
|
||||
@@ -114,8 +357,14 @@ async def oauth_token(
|
||||
if not token_record:
|
||||
return create_oauth_error_response(
|
||||
error="invalid_grant",
|
||||
description="The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client.",
|
||||
hint="Invalid refresh token"
|
||||
description=(
|
||||
"The provided authorization grant (e.g., authorization code, "
|
||||
"resource owner credentials) or refresh token is "
|
||||
"invalid, expired, revoked, "
|
||||
"does not match the redirection URI used "
|
||||
"in the authorization request, or was issued to another client."
|
||||
),
|
||||
hint="Invalid refresh token",
|
||||
)
|
||||
|
||||
# 生成新的访问令牌
|
||||
@@ -145,6 +394,9 @@ async def oauth_token(
|
||||
else:
|
||||
return create_oauth_error_response(
|
||||
error="unsupported_grant_type",
|
||||
description="The authorization grant type is not supported by the authorization server.",
|
||||
hint="Unsupported grant type"
|
||||
description=(
|
||||
"The authorization grant type is not supported "
|
||||
"by the authorization server."
|
||||
),
|
||||
hint="Unsupported grant type",
|
||||
)
|
||||
|
||||
@@ -16,7 +16,10 @@ from app.dependencies.user import get_current_user
|
||||
from app.fetcher import Fetcher
|
||||
from app.models.beatmap import BeatmapAttributes
|
||||
from app.models.mods import APIMod, int_to_mods
|
||||
from app.models.score import INT_TO_MODE, GameMode
|
||||
from app.models.score import (
|
||||
INT_TO_MODE,
|
||||
GameMode,
|
||||
)
|
||||
from app.utils import calculate_beatmap_attribute
|
||||
|
||||
from .api_router import router
|
||||
@@ -31,6 +34,31 @@ from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@router.get("/beatmaps/lookup", tags=["beatmap"], response_model=BeatmapResp)
|
||||
async def lookup_beatmap(
|
||||
id: int | None = Query(default=None, alias="id"),
|
||||
md5: str | None = Query(default=None, alias="checksum"),
|
||||
filename: str | None = Query(default=None, alias="filename"),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
if id is None and md5 is None and filename is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="At least one of 'id', 'checksum', or 'filename' must be provided.",
|
||||
)
|
||||
try:
|
||||
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=id, md5=md5)
|
||||
except HTTPError:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
|
||||
if beatmap is None:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
|
||||
return BeatmapResp.from_db(beatmap)
|
||||
|
||||
|
||||
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
|
||||
async def get_beatmap(
|
||||
bid: int,
|
||||
@@ -39,7 +67,7 @@ async def get_beatmap(
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
try:
|
||||
beatmap = await Beatmap.get_or_fetch(db, bid, fetcher)
|
||||
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid)
|
||||
return BeatmapResp.from_db(beatmap)
|
||||
except HTTPError:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
@@ -119,7 +147,7 @@ async def get_beatmap_attributes(
|
||||
if ruleset_id is not None and ruleset is None:
|
||||
ruleset = INT_TO_MODE[ruleset_id]
|
||||
if ruleset is None:
|
||||
beatmap_db = await Beatmap.get_or_fetch(db, beatmap, fetcher)
|
||||
beatmap_db = await Beatmap.get_or_fetch(db, fetcher, beatmap)
|
||||
ruleset = beatmap_db.mode
|
||||
key = (
|
||||
f"beatmap:{beatmap}:{ruleset}:"
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from app.database import User as DBUser
|
||||
from app.database.relationship import Relationship, RelationshipResp, RelationshipType
|
||||
from app.dependencies.database import get_db
|
||||
@@ -9,21 +7,23 @@ from app.dependencies.user import get_current_user
|
||||
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Query
|
||||
from fastapi import Depends, HTTPException, Query, Request
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@router.get("/{type}", tags=["relationship"], response_model=list[RelationshipResp])
|
||||
@router.get("/friends", tags=["relationship"], response_model=list[RelationshipResp])
|
||||
@router.get("/blocks", tags=["relationship"], response_model=list[RelationshipResp])
|
||||
async def get_relationship(
|
||||
type: Literal["friends", "blocks"],
|
||||
request: Request,
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if type == "friends":
|
||||
relationship_type = RelationshipType.FOLLOW
|
||||
else:
|
||||
relationship_type = RelationshipType.BLOCK
|
||||
relationship_type = (
|
||||
RelationshipType.FOLLOW
|
||||
if request.url.path.endswith("/friends")
|
||||
else RelationshipType.BLOCK
|
||||
)
|
||||
relationships = await db.exec(
|
||||
select(Relationship).where(
|
||||
Relationship.user_id == current_user.id,
|
||||
@@ -33,17 +33,19 @@ async def get_relationship(
|
||||
return [await RelationshipResp.from_db(db, rel) for rel in relationships]
|
||||
|
||||
|
||||
@router.post("/{type}", tags=["relationship"], response_model=RelationshipResp)
|
||||
@router.post("/friends", tags=["relationship"], response_model=RelationshipResp)
|
||||
@router.post("/blocks", tags=["relationship"])
|
||||
async def add_relationship(
|
||||
type: Literal["friends", "blocks"],
|
||||
request: Request,
|
||||
target: int = Query(),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if type == "blocks":
|
||||
relationship_type = RelationshipType.BLOCK
|
||||
else:
|
||||
relationship_type = RelationshipType.FOLLOW
|
||||
relationship_type = (
|
||||
RelationshipType.FOLLOW
|
||||
if request.url.path.endswith("/friends")
|
||||
else RelationshipType.BLOCK
|
||||
)
|
||||
if target == current_user.id:
|
||||
raise HTTPException(422, "Cannot add relationship to yourself")
|
||||
relationship = (
|
||||
@@ -78,18 +80,22 @@ async def add_relationship(
|
||||
await db.delete(target_relationship)
|
||||
await db.commit()
|
||||
await db.refresh(relationship)
|
||||
return await RelationshipResp.from_db(db, relationship)
|
||||
if relationship.type == RelationshipType.FOLLOW:
|
||||
return await RelationshipResp.from_db(db, relationship)
|
||||
|
||||
|
||||
@router.delete("/{type}/{target}", tags=["relationship"])
|
||||
@router.delete("/friends/{target}", tags=["relationship"])
|
||||
@router.delete("/blocks/{target}", tags=["relationship"])
|
||||
async def delete_relationship(
|
||||
type: Literal["friends", "blocks"],
|
||||
request: Request,
|
||||
target: int,
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
relationship_type = (
|
||||
RelationshipType.BLOCK if type == "blocks" else RelationshipType.FOLLOW
|
||||
RelationshipType.BLOCK
|
||||
if "/blocks/" in request.url.path
|
||||
else RelationshipType.FOLLOW
|
||||
)
|
||||
relationship = (
|
||||
await db.exec(
|
||||
|
||||
@@ -1,20 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
|
||||
from app.database import (
|
||||
Beatmap,
|
||||
User as DBUser,
|
||||
)
|
||||
from app.database.beatmapset import Beatmapset
|
||||
from app.database.score import Score, ScoreResp
|
||||
from app.database.score_token import ScoreToken, ScoreTokenResp
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.models.score import (
|
||||
INT_TO_MODE,
|
||||
GameMode,
|
||||
HitResult,
|
||||
Rank,
|
||||
SoloScoreSubmissionInfo,
|
||||
)
|
||||
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Query
|
||||
from fastapi import Depends, Form, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel import col, select, true
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@@ -29,7 +37,7 @@ class BeatmapScores(BaseModel):
|
||||
async def get_beatmap_scores(
|
||||
beatmap: int,
|
||||
legacy_only: bool = Query(None), # TODO:加入对这个参数的查询
|
||||
mode: str = Query(None),
|
||||
mode: GameMode | None = Query(None),
|
||||
# mods: List[APIMod] = Query(None), # TODO:加入指定MOD的查询
|
||||
type: str = Query(None),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
@@ -42,29 +50,28 @@ async def get_beatmap_scores(
|
||||
|
||||
all_scores = (
|
||||
await db.exec(
|
||||
select(Score).where(Score.beatmap_id == beatmap)
|
||||
# .where(Score.mods == mods if mods else True)
|
||||
Score.select_clause_unique(
|
||||
Score.beatmap_id == beatmap,
|
||||
col(Score.passed).is_(True),
|
||||
Score.gamemode == mode if mode is not None else true(),
|
||||
)
|
||||
)
|
||||
).all()
|
||||
|
||||
user_score = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.options(
|
||||
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
|
||||
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
|
||||
.selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
Score.select_clause_unique(
|
||||
Score.beatmap_id == beatmap,
|
||||
Score.user_id == current_user.id,
|
||||
col(Score.passed).is_(True),
|
||||
Score.gamemode == mode if mode is not None else true(),
|
||||
)
|
||||
.where(Score.beatmap_id == beatmap)
|
||||
.where(Score.user_id == current_user.id)
|
||||
)
|
||||
).first()
|
||||
|
||||
return BeatmapScores(
|
||||
scores=[ScoreResp.from_db(score) for score in all_scores],
|
||||
userScore=ScoreResp.from_db(user_score) if user_score else None,
|
||||
scores=[await ScoreResp.from_db(db, score) for score in all_scores],
|
||||
userScore=await ScoreResp.from_db(db, user_score) if user_score else None,
|
||||
)
|
||||
|
||||
|
||||
@@ -93,18 +100,13 @@ async def get_user_beatmap_score(
|
||||
)
|
||||
user_score = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.options(
|
||||
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
|
||||
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
|
||||
.selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
Score.select_clause()
|
||||
.where(
|
||||
Score.gamemode == mode if mode is not None else True,
|
||||
Score.beatmap_id == beatmap,
|
||||
Score.user_id == user,
|
||||
)
|
||||
.where(Score.gamemode == mode if mode is not None else True)
|
||||
.where(Score.beatmap_id == beatmap)
|
||||
.where(Score.user_id == user)
|
||||
.order_by(col(Score.classic_total_score).desc())
|
||||
.order_by(col(Score.total_score).desc())
|
||||
)
|
||||
).first()
|
||||
|
||||
@@ -115,7 +117,7 @@ async def get_user_beatmap_score(
|
||||
else:
|
||||
return BeatmapUserScore(
|
||||
position=user_score.position if user_score.position is not None else 0,
|
||||
score=ScoreResp.from_db(user_score),
|
||||
score=await ScoreResp.from_db(db, user_score),
|
||||
)
|
||||
|
||||
|
||||
@@ -138,19 +140,114 @@ async def get_user_all_beatmap_scores(
|
||||
)
|
||||
all_user_scores = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.options(
|
||||
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
|
||||
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
|
||||
.selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
Score.select_clause()
|
||||
.where(
|
||||
Score.gamemode == ruleset if ruleset is not None else True,
|
||||
Score.beatmap_id == beatmap,
|
||||
Score.user_id == user,
|
||||
)
|
||||
.where(Score.gamemode == ruleset if ruleset is not None else True)
|
||||
.where(Score.beatmap_id == beatmap)
|
||||
.where(Score.user_id == user)
|
||||
.order_by(col(Score.classic_total_score).desc())
|
||||
)
|
||||
).all()
|
||||
|
||||
return [ScoreResp.from_db(score) for score in all_user_scores]
|
||||
return [await ScoreResp.from_db(db, score) for score in all_user_scores]
|
||||
|
||||
|
||||
@router.post(
|
||||
"/beatmaps/{beatmap}/solo/scores", tags=["beatmap"], response_model=ScoreTokenResp
|
||||
)
|
||||
async def create_solo_score(
|
||||
beatmap: int,
|
||||
version_hash: str = Form(""),
|
||||
beatmap_hash: str = Form(),
|
||||
ruleset_id: int = Form(..., ge=0, le=3),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
async with db:
|
||||
score_token = ScoreToken(
|
||||
user_id=current_user.id,
|
||||
beatmap_id=beatmap,
|
||||
ruleset_id=INT_TO_MODE[ruleset_id],
|
||||
)
|
||||
db.add(score_token)
|
||||
await db.commit()
|
||||
await db.refresh(score_token)
|
||||
return ScoreTokenResp.from_db(score_token)
|
||||
|
||||
|
||||
@router.put(
|
||||
"/beatmaps/{beatmap}/solo/scores/{token}",
|
||||
tags=["beatmap"],
|
||||
response_model=ScoreResp,
|
||||
)
|
||||
async def submit_solo_score(
|
||||
beatmap: int,
|
||||
token: int,
|
||||
info: SoloScoreSubmissionInfo,
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if not info.passed:
|
||||
info.rank = Rank.F
|
||||
async with db:
|
||||
score_token = (
|
||||
await db.exec(
|
||||
select(ScoreToken)
|
||||
.options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType]
|
||||
.where(ScoreToken.id == token, ScoreToken.user_id == current_user.id)
|
||||
)
|
||||
).first()
|
||||
if not score_token or score_token.user_id != current_user.id:
|
||||
raise HTTPException(status_code=404, detail="Score token not found")
|
||||
if score_token.score_id:
|
||||
score = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType]
|
||||
.where(
|
||||
Score.id == score_token.score_id,
|
||||
Score.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not score:
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
else:
|
||||
score = Score(
|
||||
accuracy=info.accuracy,
|
||||
max_combo=info.max_combo,
|
||||
# maximum_statistics=info.maximum_statistics,
|
||||
mods=info.mods,
|
||||
passed=info.passed,
|
||||
rank=info.rank,
|
||||
total_score=info.total_score,
|
||||
total_score_without_mods=info.total_score_without_mods,
|
||||
beatmap_id=beatmap,
|
||||
ended_at=datetime.datetime.now(datetime.UTC),
|
||||
gamemode=INT_TO_MODE[info.ruleset_id],
|
||||
started_at=score_token.created_at,
|
||||
user_id=current_user.id,
|
||||
preserve=info.passed,
|
||||
map_md5=score_token.beatmap.checksum,
|
||||
has_replay=False,
|
||||
pp=info.pp,
|
||||
type="solo",
|
||||
n300=info.statistics.get(HitResult.GREAT, 0),
|
||||
n100=info.statistics.get(HitResult.OK, 0),
|
||||
n50=info.statistics.get(HitResult.MEH, 0),
|
||||
nmiss=info.statistics.get(HitResult.MISS, 0),
|
||||
ngeki=info.statistics.get(HitResult.PERFECT, 0),
|
||||
nkatu=info.statistics.get(HitResult.GOOD, 0),
|
||||
)
|
||||
db.add(score)
|
||||
await db.commit()
|
||||
await db.refresh(score)
|
||||
score_id = score.id
|
||||
score_token.score_id = score_id
|
||||
await db.commit()
|
||||
score = (
|
||||
await db.exec(Score.select_clause().where(Score.id == score_id))
|
||||
).first()
|
||||
assert score is not None
|
||||
return await ScoreResp.from_db(db, score)
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .router import router as signalr_router
|
||||
|
||||
__all__ = ["signalr_router"]
|
||||
@@ -1,10 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class SignalRException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InvokeException(SignalRException):
|
||||
def __init__(self, message: str) -> None:
|
||||
self.message = message
|
||||
@@ -1,15 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .hub import Hub
|
||||
from .metadata import MetadataHub
|
||||
from .multiplayer import MultiplayerHub
|
||||
from .spectator import SpectatorHub
|
||||
|
||||
SpectatorHubs = SpectatorHub()
|
||||
MultiplayerHubs = MultiplayerHub()
|
||||
MetadataHubs = MetadataHub()
|
||||
Hubs: dict[str, Hub] = {
|
||||
"spectator": SpectatorHubs,
|
||||
"multiplayer": MultiplayerHubs,
|
||||
"metadata": MetadataHubs,
|
||||
}
|
||||
@@ -1,211 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from app.config import settings
|
||||
from app.router.signalr.exception import InvokeException
|
||||
from app.router.signalr.packet import (
|
||||
PacketType,
|
||||
ResultKind,
|
||||
encode_varint,
|
||||
parse_packet,
|
||||
)
|
||||
from app.router.signalr.store import ResultStore
|
||||
from app.router.signalr.utils import get_signature
|
||||
|
||||
from fastapi import WebSocket
|
||||
import msgpack
|
||||
from pydantic import BaseModel
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
|
||||
|
||||
class Client:
|
||||
def __init__(
|
||||
self, connection_id: str, connection_token: str, connection: WebSocket
|
||||
) -> None:
|
||||
self.connection_id = connection_id
|
||||
self.connection_token = connection_token
|
||||
self.connection = connection
|
||||
self._listen_task: asyncio.Task | None = None
|
||||
self._ping_task: asyncio.Task | None = None
|
||||
self._store = ResultStore()
|
||||
|
||||
async def send_packet(self, type: PacketType, packet: list[Any]):
|
||||
packet.insert(0, type.value)
|
||||
payload = msgpack.packb(packet)
|
||||
length = encode_varint(len(payload))
|
||||
await self.connection.send_bytes(length + payload)
|
||||
|
||||
async def _ping(self):
|
||||
while True:
|
||||
try:
|
||||
await self.send_packet(PacketType.PING, [])
|
||||
await asyncio.sleep(settings.SIGNALR_PING_INTERVAL)
|
||||
except WebSocketDisconnect:
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Error in ping task for {self.connection_id}: {e}")
|
||||
break
|
||||
|
||||
|
||||
class Hub:
|
||||
def __init__(self) -> None:
|
||||
self.clients: dict[str, Client] = {}
|
||||
self.waited_clients: dict[str, int] = {}
|
||||
self.tasks: set[asyncio.Task] = set()
|
||||
|
||||
def add_waited_client(self, connection_token: str, timestamp: int) -> None:
|
||||
self.waited_clients[connection_token] = timestamp
|
||||
|
||||
def add_client(
|
||||
self, connection_id: str, connection_token: str, connection: WebSocket
|
||||
) -> Client:
|
||||
if connection_token in self.clients:
|
||||
raise ValueError(
|
||||
f"Client with connection token {connection_token} already exists."
|
||||
)
|
||||
if connection_token in self.waited_clients:
|
||||
if (
|
||||
self.waited_clients[connection_token]
|
||||
< time.time() - settings.SIGNALR_NEGOTIATE_TIMEOUT
|
||||
):
|
||||
raise TimeoutError(f"Connection {connection_id} has waited too long.")
|
||||
del self.waited_clients[connection_token]
|
||||
client = Client(connection_id, connection_token, connection)
|
||||
self.clients[connection_token] = client
|
||||
task = asyncio.create_task(client._ping())
|
||||
self.tasks.add(task)
|
||||
client._ping_task = task
|
||||
return client
|
||||
|
||||
async def remove_client(self, connection_id: str) -> None:
|
||||
if client := self.clients.get(connection_id):
|
||||
del self.clients[connection_id]
|
||||
if client._listen_task:
|
||||
client._listen_task.cancel()
|
||||
if client._ping_task:
|
||||
client._ping_task.cancel()
|
||||
await client.connection.close()
|
||||
|
||||
async def send_packet(self, client: Client, type: PacketType, packet: list[Any]):
|
||||
await client.send_packet(type, packet)
|
||||
|
||||
async def _listen_client(self, client: Client) -> None:
|
||||
jump = False
|
||||
while not jump:
|
||||
try:
|
||||
message = await client.connection.receive_bytes()
|
||||
packet_type, packet_data = parse_packet(message)
|
||||
task = asyncio.create_task(
|
||||
self._handle_packet(client, packet_type, packet_data)
|
||||
)
|
||||
self.tasks.add(task)
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
except WebSocketDisconnect as e:
|
||||
if e.code == 1005:
|
||||
continue
|
||||
print(
|
||||
f"Client {client.connection_id} disconnected: {e.code}, {e.reason}"
|
||||
)
|
||||
jump = True
|
||||
except Exception as e:
|
||||
print(f"Error in client {client.connection_id}: {e}")
|
||||
jump = True
|
||||
await self.remove_client(client.connection_id)
|
||||
|
||||
async def _handle_packet(
|
||||
self, client: Client, type: PacketType, packet: list[Any]
|
||||
) -> None:
|
||||
match type:
|
||||
case PacketType.PING:
|
||||
...
|
||||
case PacketType.INVOCATION:
|
||||
invocation_id: str | None = packet[1] # pyright: ignore[reportRedeclaration]
|
||||
target: str = packet[2]
|
||||
args: list[Any] | None = packet[3]
|
||||
if args is None:
|
||||
args = []
|
||||
# streams: list[str] | None = packet[4] # TODO: stream support
|
||||
code = ResultKind.VOID
|
||||
result = None
|
||||
try:
|
||||
result = await self.invoke_method(client, target, args)
|
||||
if result is not None:
|
||||
code = ResultKind.HAS_VALUE
|
||||
except InvokeException as e:
|
||||
code = ResultKind.ERROR
|
||||
result = e.message
|
||||
|
||||
except Exception as e:
|
||||
code = ResultKind.ERROR
|
||||
result = str(e)
|
||||
|
||||
packet = [
|
||||
{}, # header
|
||||
invocation_id,
|
||||
code.value,
|
||||
]
|
||||
if result is not None:
|
||||
packet.append(result)
|
||||
if invocation_id is not None:
|
||||
await client.send_packet(
|
||||
PacketType.COMPLETION,
|
||||
packet,
|
||||
)
|
||||
case PacketType.COMPLETION:
|
||||
invocation_id: str = packet[1]
|
||||
code: ResultKind = ResultKind(packet[2])
|
||||
result: Any = packet[3] if len(packet) > 3 else None
|
||||
client._store.add_result(invocation_id, code, result)
|
||||
|
||||
async def invoke_method(self, client: Client, method: str, args: list[Any]) -> Any:
|
||||
method_ = getattr(self, method, None)
|
||||
call_params = []
|
||||
if not method_:
|
||||
raise InvokeException(f"Method '{method}' not found in hub.")
|
||||
signature = get_signature(method_)
|
||||
for name, param in signature.parameters.items():
|
||||
if name == "self" or param.annotation is Client:
|
||||
continue
|
||||
if issubclass(param.annotation, BaseModel):
|
||||
call_params.append(param.annotation.model_validate(args.pop(0)))
|
||||
else:
|
||||
call_params.append(args.pop(0))
|
||||
return await method_(client, *call_params)
|
||||
|
||||
async def call(self, client: Client, method: str, *args: Any) -> Any:
|
||||
invocation_id = client._store.get_invocation_id()
|
||||
await client.send_packet(
|
||||
PacketType.INVOCATION,
|
||||
[
|
||||
{}, # header
|
||||
invocation_id,
|
||||
method,
|
||||
list(args),
|
||||
None, # streams
|
||||
],
|
||||
)
|
||||
r = await client._store.fetch(invocation_id, None)
|
||||
if r[0] == ResultKind.HAS_VALUE:
|
||||
return r[1]
|
||||
if r[0] == ResultKind.ERROR:
|
||||
raise InvokeException(r[1])
|
||||
return None
|
||||
|
||||
async def call_noblock(self, client: Client, method: str, *args: Any) -> None:
|
||||
await client.send_packet(
|
||||
PacketType.INVOCATION,
|
||||
[
|
||||
{}, # header
|
||||
None, # invocation_id
|
||||
method,
|
||||
list(args),
|
||||
None, # streams
|
||||
],
|
||||
)
|
||||
return None
|
||||
|
||||
def __contains__(self, item: str) -> bool:
|
||||
return item in self.clients or item in self.waited_clients
|
||||
@@ -1,6 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .hub import Hub
|
||||
|
||||
|
||||
class MetadataHub(Hub): ...
|
||||
@@ -1,6 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .hub import Hub
|
||||
|
||||
|
||||
class MultiplayerHub(Hub): ...
|
||||
@@ -1,15 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.models.spectator_hub import FrameDataBundle, SpectatorState
|
||||
|
||||
from .hub import Client, Hub
|
||||
|
||||
|
||||
class SpectatorHub(Hub):
|
||||
async def BeginPlaySession(
|
||||
self, client: Client, score_token: int, state: SpectatorState
|
||||
) -> None: ...
|
||||
|
||||
async def SendFrameData(
|
||||
self, client: Client, frame_data: FrameDataBundle
|
||||
) -> None: ...
|
||||
@@ -1,56 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import IntEnum
|
||||
from typing import Any
|
||||
|
||||
import msgpack
|
||||
|
||||
SEP = b"\x1e"
|
||||
|
||||
|
||||
class PacketType(IntEnum):
|
||||
INVOCATION = 1
|
||||
STREAM_ITEM = 2
|
||||
COMPLETION = 3
|
||||
STREAM_INVOCATION = 4
|
||||
CANCEL_INVOCATION = 5
|
||||
PING = 6
|
||||
CLOSE = 7
|
||||
|
||||
|
||||
class ResultKind(IntEnum):
|
||||
ERROR = 1
|
||||
VOID = 2
|
||||
HAS_VALUE = 3
|
||||
|
||||
|
||||
def parse_packet(data: bytes) -> tuple[PacketType, list[Any]]:
|
||||
length, offset = decode_varint(data)
|
||||
message_data = data[offset : offset + length]
|
||||
unpacked = msgpack.unpackb(message_data, raw=False)
|
||||
return PacketType(unpacked[0]), unpacked[1:]
|
||||
|
||||
|
||||
def encode_varint(value: int) -> bytes:
|
||||
result = []
|
||||
while value >= 0x80:
|
||||
result.append((value & 0x7F) | 0x80)
|
||||
value >>= 7
|
||||
result.append(value & 0x7F)
|
||||
return bytes(result)
|
||||
|
||||
|
||||
def decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]:
|
||||
result = 0
|
||||
shift = 0
|
||||
pos = offset
|
||||
|
||||
while pos < len(data):
|
||||
byte = data[pos]
|
||||
result |= (byte & 0x7F) << shift
|
||||
pos += 1
|
||||
if (byte & 0x80) == 0:
|
||||
break
|
||||
shift += 7
|
||||
|
||||
return result, pos
|
||||
@@ -1,91 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Literal
|
||||
import uuid
|
||||
|
||||
from app.database import User as DBUser
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.user import get_current_user_by_token
|
||||
from app.models.signalr import NegotiateResponse, Transport
|
||||
from app.router.signalr.packet import SEP
|
||||
|
||||
from .hub import Hubs
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, Query, WebSocket
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/{hub}/negotiate", response_model=NegotiateResponse)
|
||||
async def negotiate(
|
||||
hub: Literal["spectator", "multiplayer", "metadata"],
|
||||
negotiate_version: int = Query(1, alias="negotiateVersion"),
|
||||
user: DBUser = Depends(get_current_user),
|
||||
):
|
||||
connectionId = str(user.id)
|
||||
connectionToken = f"{connectionId}:{uuid.uuid4()}"
|
||||
Hubs[hub].add_waited_client(
|
||||
connection_token=connectionToken,
|
||||
timestamp=int(time.time()),
|
||||
)
|
||||
return NegotiateResponse(
|
||||
connectionId=connectionId,
|
||||
connectionToken=connectionToken,
|
||||
negotiateVersion=negotiate_version,
|
||||
availableTransports=[Transport(transport="WebSockets")],
|
||||
)
|
||||
|
||||
|
||||
@router.websocket("/{hub}")
|
||||
async def connect(
|
||||
hub: Literal["spectator", "multiplayer", "metadata"],
|
||||
websocket: WebSocket,
|
||||
id: str,
|
||||
authorization: str = Header(...),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
token = authorization[7:]
|
||||
user_id = id.split(":")[0]
|
||||
hub_ = Hubs[hub]
|
||||
if id not in hub_:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
if (user := await get_current_user_by_token(token, db)) is None or str(
|
||||
user.id
|
||||
) != user_id:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
await websocket.accept()
|
||||
|
||||
# handshake
|
||||
handshake = await websocket.receive_bytes()
|
||||
handshake_payload = json.loads(handshake[:-1])
|
||||
error = ""
|
||||
if (protocol := handshake_payload.get("protocol")) != "messagepack" or (
|
||||
handshake_payload.get("version")
|
||||
) != 1:
|
||||
error = f"Requested protocol '{protocol}' is not available."
|
||||
|
||||
client = None
|
||||
try:
|
||||
client = hub_.add_client(
|
||||
connection_id=user_id,
|
||||
connection_token=id,
|
||||
connection=websocket,
|
||||
)
|
||||
except TimeoutError:
|
||||
error = f"Connection {id} has waited too long."
|
||||
except ValueError as e:
|
||||
error = str(e)
|
||||
payload = {"error": error} if error else {}
|
||||
|
||||
# finish handshake
|
||||
await websocket.send_bytes(json.dumps(payload).encode() + SEP)
|
||||
if error or not client:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
await hub_._listen_client(client)
|
||||
@@ -1,45 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from typing import Any, Literal
|
||||
|
||||
from app.router.signalr.packet import ResultKind
|
||||
|
||||
|
||||
class ResultStore:
|
||||
def __init__(self) -> None:
|
||||
self._seq: int = 1
|
||||
self._futures: dict[str, asyncio.Future] = {}
|
||||
|
||||
@property
|
||||
def current_invocation_id(self) -> int:
|
||||
return self._seq
|
||||
|
||||
def get_invocation_id(self) -> str:
|
||||
s = self._seq
|
||||
self._seq = (self._seq + 1) % sys.maxsize
|
||||
return str(s)
|
||||
|
||||
def add_result(
|
||||
self, invocation_id: str, type: ResultKind, result: dict[str, Any] | None
|
||||
) -> None:
|
||||
if isinstance(invocation_id, str) and invocation_id.isdecimal():
|
||||
if future := self._futures.get(invocation_id):
|
||||
future.set_result((type, result))
|
||||
|
||||
async def fetch(
|
||||
self,
|
||||
invocation_id: str,
|
||||
timeout: float | None, # noqa: ASYNC109
|
||||
) -> (
|
||||
tuple[Literal[ResultKind.ERROR], str]
|
||||
| tuple[Literal[ResultKind.VOID], None]
|
||||
| tuple[Literal[ResultKind.HAS_VALUE], Any]
|
||||
):
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self._futures[invocation_id] = future
|
||||
try:
|
||||
return await asyncio.wait_for(future, timeout)
|
||||
finally:
|
||||
del self._futures[invocation_id]
|
||||
@@ -1,48 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import inspect
|
||||
from typing import Any, ForwardRef, cast
|
||||
|
||||
|
||||
# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L56-L66
|
||||
def evaluate_forwardref(
|
||||
type_: ForwardRef,
|
||||
globalns: Any,
|
||||
localns: Any,
|
||||
) -> Any:
|
||||
# Even though it is the right signature for python 3.9,
|
||||
# mypy complains with
|
||||
# `error: Too many arguments for "_evaluate" of
|
||||
# "ForwardRef"` hence the cast...
|
||||
return cast(Any, type_)._evaluate(
|
||||
globalns,
|
||||
localns,
|
||||
set(),
|
||||
)
|
||||
|
||||
|
||||
def get_annotation(param: inspect.Parameter, globalns: dict[str, Any]) -> Any:
|
||||
annotation = param.annotation
|
||||
if isinstance(annotation, str):
|
||||
annotation = ForwardRef(annotation)
|
||||
try:
|
||||
annotation = evaluate_forwardref(annotation, globalns, globalns)
|
||||
except Exception:
|
||||
return inspect.Parameter.empty
|
||||
return annotation
|
||||
|
||||
|
||||
def get_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||
signature = inspect.signature(call)
|
||||
globalns = getattr(call, "__globals__", {})
|
||||
typed_params = [
|
||||
inspect.Parameter(
|
||||
name=param.name,
|
||||
kind=param.kind,
|
||||
default=param.default,
|
||||
annotation=get_annotation(param, globalns),
|
||||
)
|
||||
for param in signature.parameters.values()
|
||||
]
|
||||
return inspect.Signature(typed_params)
|
||||
75
app/router/user.py
Normal file
75
app/router/user.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from app.database import (
|
||||
User as DBUser,
|
||||
)
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies.database import get_db
|
||||
from app.models.score import INT_TO_MODE
|
||||
from app.models.user import (
|
||||
User as ApiUser,
|
||||
)
|
||||
from app.utils import convert_db_user_to_api_user
|
||||
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import col
|
||||
|
||||
|
||||
@router.get("/users/{user}/{ruleset}", response_model=ApiUser)
|
||||
@router.get("/users/{user}", response_model=ApiUser)
|
||||
async def get_user_info_default(
|
||||
user: str,
|
||||
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
searched_user = (
|
||||
await session.exec(
|
||||
DBUser.all_select_clause().where(
|
||||
DBUser.id == int(user)
|
||||
if user.isdigit()
|
||||
else DBUser.name == user.removeprefix("@")
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not searched_user:
|
||||
raise HTTPException(404, detail="User not found")
|
||||
return await convert_db_user_to_api_user(searched_user, ruleset=ruleset)
|
||||
|
||||
|
||||
class BatchUserResponse(BaseModel):
|
||||
users: list[ApiUser]
|
||||
|
||||
|
||||
@router.get("/users", response_model=BatchUserResponse)
|
||||
@router.get("/users/lookup", response_model=BatchUserResponse)
|
||||
async def get_users(
|
||||
user_ids: list[int] = Query(default_factory=list, alias="ids[]"),
|
||||
include_variant_statistics: bool = Query(default=False), # TODO
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if user_ids:
|
||||
searched_users = (
|
||||
await session.exec(
|
||||
DBUser.all_select_clause().limit(50).where(col(DBUser.id).in_(user_ids))
|
||||
)
|
||||
).all()
|
||||
else:
|
||||
searched_users = (
|
||||
await session.exec(DBUser.all_select_clause().limit(50))
|
||||
).all()
|
||||
return BatchUserResponse(
|
||||
users=[
|
||||
await convert_db_user_to_api_user(
|
||||
searched_user, ruleset=INT_TO_MODE[current_user.preferred_mode].value
|
||||
)
|
||||
for searched_user in searched_users
|
||||
]
|
||||
)
|
||||
Reference in New Issue
Block a user