chore(merge): merge branch 'main' into feat/multiplayer-api
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
from datetime import UTC, datetime, timedelta
|
||||
import re
|
||||
|
||||
from app.auth import (
|
||||
@@ -12,17 +12,21 @@ from app.auth import (
|
||||
store_token,
|
||||
)
|
||||
from app.config import settings
|
||||
from app.database import User as DBUser
|
||||
from app.database import DailyChallengeStats, User
|
||||
from app.database.statistics import UserStatistics
|
||||
from app.dependencies import get_db
|
||||
from app.log import logger
|
||||
from app.models.oauth import (
|
||||
OAuthErrorResponse,
|
||||
RegistrationRequestErrors,
|
||||
TokenResponse,
|
||||
UserRegistrationErrors,
|
||||
)
|
||||
from app.models.score import GameMode
|
||||
|
||||
from fastapi import APIRouter, Depends, Form
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy import text
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -110,12 +114,12 @@ async def register_user(
|
||||
email_errors = validate_email(user_email)
|
||||
password_errors = validate_password(user_password)
|
||||
|
||||
result = await db.exec(select(DBUser).where(DBUser.name == user_username))
|
||||
result = await db.exec(select(User).where(User.username == user_username))
|
||||
existing_user = result.first()
|
||||
if existing_user:
|
||||
username_errors.append("Username is already taken")
|
||||
|
||||
result = await db.exec(select(DBUser).where(DBUser.email == user_email))
|
||||
result = await db.exec(select(User).where(User.email == user_email))
|
||||
existing_email = result.first()
|
||||
if existing_email:
|
||||
email_errors.append("Email is already taken")
|
||||
@@ -135,119 +139,41 @@ async def register_user(
|
||||
|
||||
try:
|
||||
# 创建新用户
|
||||
from datetime import datetime
|
||||
import time
|
||||
# 确保 AUTO_INCREMENT 值从3开始(ID=1是BanchoBot,ID=2预留给ppy)
|
||||
result = await db.execute( # pyright: ignore[reportDeprecated]
|
||||
text(
|
||||
"SELECT AUTO_INCREMENT FROM information_schema.TABLES "
|
||||
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'lazer_users'"
|
||||
)
|
||||
)
|
||||
next_id = result.one()[0]
|
||||
if next_id <= 2:
|
||||
await db.execute(text("ALTER TABLE lazer_users AUTO_INCREMENT = 3"))
|
||||
await db.commit()
|
||||
|
||||
new_user = DBUser(
|
||||
name=user_username,
|
||||
safe_name=user_username.lower(), # 安全用户名(小写)
|
||||
new_user = User(
|
||||
username=user_username,
|
||||
email=user_email,
|
||||
pw_bcrypt=get_password_hash(user_password),
|
||||
priv=1, # 普通用户权限
|
||||
country="CN", # 默认国家
|
||||
creation_time=int(time.time()),
|
||||
latest_activity=int(time.time()),
|
||||
preferred_mode=0, # 默认模式
|
||||
play_style=0, # 默认游戏风格
|
||||
country_code="CN", # 默认国家
|
||||
join_date=datetime.now(UTC),
|
||||
last_visit=datetime.now(UTC),
|
||||
)
|
||||
|
||||
db.add(new_user)
|
||||
await db.commit()
|
||||
await db.refresh(new_user)
|
||||
|
||||
# 保存用户ID,因为会话可能会关闭
|
||||
user_id = new_user.id
|
||||
|
||||
if user_id <= 2:
|
||||
await db.rollback()
|
||||
try:
|
||||
from sqlalchemy import text
|
||||
|
||||
# 确保 AUTO_INCREMENT 值从3开始(ID=1是BanchoBot,ID=2预留给ppy)
|
||||
await db.execute(text("ALTER TABLE users AUTO_INCREMENT = 3"))
|
||||
await db.commit()
|
||||
|
||||
# 重新创建用户
|
||||
new_user = DBUser(
|
||||
name=user_username,
|
||||
safe_name=user_username.lower(),
|
||||
email=user_email,
|
||||
pw_bcrypt=get_password_hash(user_password),
|
||||
priv=1,
|
||||
country="CN",
|
||||
creation_time=int(time.time()),
|
||||
latest_activity=int(time.time()),
|
||||
preferred_mode=0,
|
||||
play_style=0,
|
||||
)
|
||||
|
||||
db.add(new_user)
|
||||
await db.commit()
|
||||
await db.refresh(new_user)
|
||||
user_id = new_user.id
|
||||
|
||||
# 最终检查ID是否有效
|
||||
if user_id <= 2:
|
||||
await db.rollback()
|
||||
errors = RegistrationRequestErrors(
|
||||
message=(
|
||||
"Failed to create account with valid ID. "
|
||||
"Please contact support."
|
||||
)
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=500, content={"form_error": errors.model_dump()}
|
||||
)
|
||||
|
||||
except Exception as fix_error:
|
||||
await db.rollback()
|
||||
print(f"Failed to fix AUTO_INCREMENT: {fix_error}")
|
||||
errors = RegistrationRequestErrors(
|
||||
message="Failed to create account with valid ID. Please try again."
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=500, content={"form_error": errors.model_dump()}
|
||||
)
|
||||
|
||||
# 创建默认的 lazer_profile
|
||||
from app.database.user import LazerUserProfile
|
||||
|
||||
lazer_profile = LazerUserProfile(
|
||||
user_id=user_id,
|
||||
is_active=True,
|
||||
is_bot=False,
|
||||
is_deleted=False,
|
||||
is_online=True,
|
||||
is_supporter=False,
|
||||
is_restricted=False,
|
||||
session_verified=False,
|
||||
has_supported=False,
|
||||
pm_friends_only=False,
|
||||
default_group="default",
|
||||
join_date=datetime.utcnow(),
|
||||
playmode="osu",
|
||||
support_level=0,
|
||||
max_blocks=50,
|
||||
max_friends=250,
|
||||
post_count=0,
|
||||
)
|
||||
|
||||
db.add(lazer_profile)
|
||||
assert new_user.id is not None, "New user ID should not be None"
|
||||
for i in GameMode:
|
||||
statistics = UserStatistics(mode=i, user_id=new_user.id)
|
||||
db.add(statistics)
|
||||
daily_challenge_user_stats = DailyChallengeStats(user_id=new_user.id)
|
||||
db.add(daily_challenge_user_stats)
|
||||
await db.commit()
|
||||
|
||||
# 返回成功响应
|
||||
return JSONResponse(
|
||||
status_code=201,
|
||||
content={"message": "Account created successfully", "user_id": user_id},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
await db.rollback()
|
||||
# 打印详细错误信息用于调试
|
||||
print(f"Registration error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
logger.exception(f"Registration error for user {user_username}")
|
||||
|
||||
# 返回通用错误
|
||||
errors = RegistrationRequestErrors(
|
||||
@@ -323,6 +249,7 @@ async def oauth_token(
|
||||
refresh_token_str = generate_refresh_token()
|
||||
|
||||
# 存储令牌
|
||||
assert user.id
|
||||
await store_token(
|
||||
db,
|
||||
user.id,
|
||||
|
||||
@@ -5,12 +5,7 @@ import hashlib
|
||||
import json
|
||||
|
||||
from app.calculator import calculate_beatmap_attribute
|
||||
from app.database import (
|
||||
Beatmap,
|
||||
BeatmapResp,
|
||||
User as DBUser,
|
||||
)
|
||||
from app.database.beatmapset import Beatmapset
|
||||
from app.database import Beatmap, BeatmapResp, User
|
||||
from app.dependencies.database import get_db, get_redis
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.dependencies.user import get_current_user
|
||||
@@ -27,9 +22,8 @@ from .api_router import router
|
||||
from fastapi import Depends, HTTPException, Query
|
||||
from httpx import HTTPError, HTTPStatusError
|
||||
from pydantic import BaseModel
|
||||
from redis import Redis
|
||||
from redis.asyncio import Redis
|
||||
import rosu_pp_py as rosu
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -39,7 +33,7 @@ async def lookup_beatmap(
|
||||
id: int | None = Query(default=None, alias="id"),
|
||||
md5: str | None = Query(default=None, alias="checksum"),
|
||||
filename: str | None = Query(default=None, alias="filename"),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
@@ -56,19 +50,19 @@ async def lookup_beatmap(
|
||||
if beatmap is None:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
|
||||
return BeatmapResp.from_db(beatmap)
|
||||
return await BeatmapResp.from_db(beatmap, session=db, user=current_user)
|
||||
|
||||
|
||||
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
|
||||
async def get_beatmap(
|
||||
bid: int,
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
try:
|
||||
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid)
|
||||
return BeatmapResp.from_db(beatmap)
|
||||
return await BeatmapResp.from_db(beatmap, session=db, user=current_user)
|
||||
except HTTPError:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
|
||||
@@ -81,42 +75,27 @@ class BatchGetResp(BaseModel):
|
||||
@router.get("/beatmaps/", tags=["beatmap"], response_model=BatchGetResp)
|
||||
async def batch_get_beatmaps(
|
||||
b_ids: list[int] = Query(alias="id", default_factory=list),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if not b_ids:
|
||||
# select 50 beatmaps by last_updated
|
||||
beatmaps = (
|
||||
await db.exec(
|
||||
select(Beatmap)
|
||||
.options(
|
||||
joinedload(
|
||||
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
|
||||
).selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
.order_by(col(Beatmap.last_updated).desc())
|
||||
.limit(50)
|
||||
select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
|
||||
)
|
||||
).all()
|
||||
else:
|
||||
beatmaps = (
|
||||
await db.exec(
|
||||
select(Beatmap)
|
||||
.options(
|
||||
joinedload(
|
||||
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
|
||||
).selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
.where(col(Beatmap.id).in_(b_ids))
|
||||
.limit(50)
|
||||
)
|
||||
await db.exec(select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50))
|
||||
).all()
|
||||
|
||||
return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps])
|
||||
return BatchGetResp(
|
||||
beatmaps=[
|
||||
await BeatmapResp.from_db(bm, session=db, user=current_user)
|
||||
for bm in beatmaps
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -126,7 +105,7 @@ async def batch_get_beatmaps(
|
||||
)
|
||||
async def get_beatmap_attributes(
|
||||
beatmap: int,
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_user),
|
||||
mods: list[str] = Query(default_factory=list),
|
||||
ruleset: GameMode | None = Query(default=None),
|
||||
ruleset_id: int | None = Query(default=None),
|
||||
@@ -153,8 +132,8 @@ async def get_beatmap_attributes(
|
||||
f"beatmap:{beatmap}:{ruleset}:"
|
||||
f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
|
||||
)
|
||||
if redis.exists(key):
|
||||
return BeatmapAttributes.model_validate_json(redis.get(key)) # pyright: ignore[reportArgumentType]
|
||||
if await redis.exists(key):
|
||||
return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
|
||||
|
||||
try:
|
||||
resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap)
|
||||
@@ -164,7 +143,7 @@ async def get_beatmap_attributes(
|
||||
)
|
||||
except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue]
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
redis.set(key, attr.model_dump_json())
|
||||
await redis.set(key, attr.model_dump_json())
|
||||
return attr
|
||||
except HTTPStatusError:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.database import (
|
||||
Beatmapset,
|
||||
BeatmapsetResp,
|
||||
User as DBUser,
|
||||
)
|
||||
from typing import Literal
|
||||
|
||||
from app.database import Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.dependencies.user import get_current_user
|
||||
@@ -12,9 +10,9 @@ from app.fetcher import Fetcher
|
||||
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi import Depends, Form, HTTPException, Query
|
||||
from fastapi.responses import RedirectResponse
|
||||
from httpx import HTTPStatusError
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -22,17 +20,11 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp)
|
||||
async def get_beatmapset(
|
||||
sid: int,
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
beatmapset = (
|
||||
await db.exec(
|
||||
select(Beatmapset)
|
||||
.options(selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
|
||||
.where(Beatmapset.id == sid)
|
||||
)
|
||||
).first()
|
||||
beatmapset = (await db.exec(select(Beatmapset).where(Beatmapset.id == sid))).first()
|
||||
if not beatmapset:
|
||||
try:
|
||||
resp = await fetcher.get_beatmapset(sid)
|
||||
@@ -40,5 +32,55 @@ async def get_beatmapset(
|
||||
except HTTPStatusError:
|
||||
raise HTTPException(status_code=404, detail="Beatmapset not found")
|
||||
else:
|
||||
resp = BeatmapsetResp.from_db(beatmapset)
|
||||
resp = await BeatmapsetResp.from_db(
|
||||
beatmapset, session=db, include=["recent_favourites"], user=current_user
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
@router.get("/beatmapsets/{beatmapset}/download", tags=["beatmapset"])
|
||||
async def download_beatmapset(
|
||||
beatmapset: int,
|
||||
no_video: bool = Query(True, alias="noVideo"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
if current_user.country_code == "CN":
|
||||
return RedirectResponse(
|
||||
f"https://txy1.sayobot.cn/beatmaps/download/"
|
||||
f"{'novideo' if no_video else 'full'}/{beatmapset}?server=auto"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
f"https://api.nerinyan.moe/d/{beatmapset}?noVideo={no_video}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/beatmapsets/{beatmapset}/favourites", tags=["beatmapset"])
|
||||
async def favourite_beatmapset(
|
||||
beatmapset: int,
|
||||
action: Literal["favourite", "unfavourite"] = Form(),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
existing_favourite = (
|
||||
await db.exec(
|
||||
select(FavouriteBeatmapset).where(
|
||||
FavouriteBeatmapset.user_id == current_user.id,
|
||||
FavouriteBeatmapset.beatmapset_id == beatmapset,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
|
||||
if action == "favourite" and existing_favourite:
|
||||
raise HTTPException(status_code=400, detail="Already favourited")
|
||||
elif action == "unfavourite" and not existing_favourite:
|
||||
raise HTTPException(status_code=400, detail="Not favourited")
|
||||
|
||||
if action == "favourite":
|
||||
favourite = FavouriteBeatmapset(
|
||||
user_id=current_user.id, beatmapset_id=beatmapset
|
||||
)
|
||||
db.add(favourite)
|
||||
else:
|
||||
await db.delete(existing_favourite)
|
||||
await db.commit()
|
||||
|
||||
@@ -1,28 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from app.database import (
|
||||
User as DBUser,
|
||||
)
|
||||
from app.database import User, UserResp
|
||||
from app.database.lazer_user import ALL_INCLUDED
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user import (
|
||||
User as ApiUser,
|
||||
)
|
||||
from app.utils import convert_db_user_to_api_user
|
||||
from app.dependencies.database import get_db
|
||||
from app.models.score import GameMode
|
||||
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@router.get("/me/{ruleset}", response_model=ApiUser)
|
||||
@router.get("/me/", response_model=ApiUser)
|
||||
@router.get("/me/{ruleset}", response_model=UserResp)
|
||||
@router.get("/me/", response_model=UserResp)
|
||||
async def get_user_info_default(
|
||||
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
ruleset: GameMode | None = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取当前用户信息(默认使用osu模式)"""
|
||||
# 默认使用osu模式
|
||||
api_user = await convert_db_user_to_api_user(current_user, ruleset)
|
||||
return api_user
|
||||
return await UserResp.from_db(
|
||||
current_user,
|
||||
session,
|
||||
ALL_INCLUDED,
|
||||
ruleset,
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@ from app.dependencies.user import get_current_user
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Query, Request
|
||||
from sqlalchemy.orm import joinedload
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -26,17 +26,19 @@ async def get_relationship(
|
||||
else RelationshipType.BLOCK
|
||||
)
|
||||
relationships = await db.exec(
|
||||
select(Relationship)
|
||||
.options(joinedload(Relationship.target).options(*DBUser.all_select_option())) # pyright: ignore[reportArgumentType]
|
||||
.where(
|
||||
select(Relationship).where(
|
||||
Relationship.user_id == current_user.id,
|
||||
Relationship.type == relationship_type,
|
||||
)
|
||||
)
|
||||
return [await RelationshipResp.from_db(db, rel) for rel in relationships]
|
||||
return [await RelationshipResp.from_db(db, rel) for rel in relationships.unique()]
|
||||
|
||||
|
||||
@router.post("/friends", tags=["relationship"], response_model=RelationshipResp)
|
||||
class AddFriendResp(BaseModel):
|
||||
user_relation: RelationshipResp
|
||||
|
||||
|
||||
@router.post("/friends", tags=["relationship"], response_model=AddFriendResp)
|
||||
@router.post("/blocks", tags=["relationship"])
|
||||
async def add_relationship(
|
||||
request: Request,
|
||||
@@ -87,14 +89,10 @@ async def add_relationship(
|
||||
if origin_type == RelationshipType.FOLLOW:
|
||||
relationship = (
|
||||
await db.exec(
|
||||
select(Relationship)
|
||||
.where(
|
||||
select(Relationship).where(
|
||||
Relationship.user_id == current_user_id,
|
||||
Relationship.target_id == target,
|
||||
)
|
||||
.options(
|
||||
joinedload(Relationship.target).options(*DBUser.all_select_option()) # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
).first()
|
||||
assert relationship, "Relationship should exist after commit"
|
||||
|
||||
@@ -6,8 +6,10 @@ from app.dependencies.fetcher import get_fetcher
|
||||
from app.fetcher import Fetcher
|
||||
from app.models.room import MultiplayerRoom, MultiplayerRoomState, Room
|
||||
|
||||
from api_router import router
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Query
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -21,6 +23,7 @@ async def get_all_rooms(
|
||||
), # TODO: 对房间根据分类进行筛选(真的有人用这功能吗)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
all_roomID = (await db.exec(select(RoomIndex))).all()
|
||||
redis = get_redis()
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.database import (
|
||||
User as DBUser,
|
||||
)
|
||||
from app.database.beatmap import Beatmap
|
||||
from app.database.score import Score, ScoreResp, process_score, process_user
|
||||
from app.database.score_token import ScoreToken, ScoreTokenResp
|
||||
from app.database import Beatmap, Score, ScoreResp, ScoreToken, ScoreTokenResp, User
|
||||
from app.database.score import get_leaderboard, process_score, process_user
|
||||
from app.dependencies.database import get_db, get_redis
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.dependencies.user import get_current_user
|
||||
@@ -13,6 +9,7 @@ from app.models.beatmap import BeatmapRankStatus
|
||||
from app.models.score import (
|
||||
INT_TO_MODE,
|
||||
GameMode,
|
||||
LeaderboardType,
|
||||
Rank,
|
||||
SoloScoreSubmissionInfo,
|
||||
)
|
||||
@@ -21,9 +18,9 @@ from .api_router import router
|
||||
|
||||
from fastapi import Depends, Form, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from redis import Redis
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import col, select, true
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@@ -37,44 +34,26 @@ class BeatmapScores(BaseModel):
|
||||
)
|
||||
async def get_beatmap_scores(
|
||||
beatmap: int,
|
||||
mode: GameMode,
|
||||
legacy_only: bool = Query(None), # TODO:加入对这个参数的查询
|
||||
mode: GameMode | None = Query(None),
|
||||
# mods: List[APIMod] = Query(None), # TODO:加入指定MOD的查询
|
||||
type: str = Query(None),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
mods: list[str] = Query(default_factory=set, alias="mods[]"),
|
||||
type: LeaderboardType = Query(LeaderboardType.GLOBAL),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="this server only contains lazer scores"
|
||||
)
|
||||
|
||||
all_scores = (
|
||||
await db.exec(
|
||||
Score.select_clause_unique(
|
||||
Score.beatmap_id == beatmap,
|
||||
col(Score.passed).is_(True),
|
||||
Score.gamemode == mode if mode is not None else true(),
|
||||
)
|
||||
)
|
||||
).all()
|
||||
|
||||
user_score = (
|
||||
await db.exec(
|
||||
Score.select_clause_unique(
|
||||
Score.beatmap_id == beatmap,
|
||||
Score.user_id == current_user.id,
|
||||
col(Score.passed).is_(True),
|
||||
Score.gamemode == mode if mode is not None else true(),
|
||||
)
|
||||
)
|
||||
).first()
|
||||
all_scores, user_score = await get_leaderboard(
|
||||
db, beatmap, mode, type=type, user=current_user, limit=limit, mods=mods
|
||||
)
|
||||
|
||||
return BeatmapScores(
|
||||
scores=[await ScoreResp.from_db(db, score, score.user) for score in all_scores],
|
||||
userScore=await ScoreResp.from_db(db, user_score, user_score.user)
|
||||
if user_score
|
||||
else None,
|
||||
scores=[await ScoreResp.from_db(db, score) for score in all_scores],
|
||||
userScore=await ScoreResp.from_db(db, user_score) if user_score else None,
|
||||
)
|
||||
|
||||
|
||||
@@ -94,7 +73,7 @@ async def get_user_beatmap_score(
|
||||
legacy_only: bool = Query(None),
|
||||
mode: str = Query(None),
|
||||
mods: str = Query(None), # TODO:添加mods筛选
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if legacy_only:
|
||||
@@ -103,7 +82,7 @@ async def get_user_beatmap_score(
|
||||
)
|
||||
user_score = (
|
||||
await db.exec(
|
||||
Score.select_clause(True)
|
||||
select(Score)
|
||||
.where(
|
||||
Score.gamemode == mode if mode is not None else True,
|
||||
Score.beatmap_id == beatmap,
|
||||
@@ -120,7 +99,7 @@ async def get_user_beatmap_score(
|
||||
else:
|
||||
return BeatmapUserScore(
|
||||
position=user_score.position if user_score.position is not None else 0,
|
||||
score=await ScoreResp.from_db(db, user_score, user_score.user),
|
||||
score=await ScoreResp.from_db(db, user_score),
|
||||
)
|
||||
|
||||
|
||||
@@ -134,7 +113,7 @@ async def get_user_all_beatmap_scores(
|
||||
user: int,
|
||||
legacy_only: bool = Query(None),
|
||||
ruleset: str = Query(None),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if legacy_only:
|
||||
@@ -143,7 +122,7 @@ async def get_user_all_beatmap_scores(
|
||||
)
|
||||
all_user_scores = (
|
||||
await db.exec(
|
||||
Score.select_clause()
|
||||
select(Score)
|
||||
.where(
|
||||
Score.gamemode == ruleset if ruleset is not None else True,
|
||||
Score.beatmap_id == beatmap,
|
||||
@@ -153,9 +132,7 @@ async def get_user_all_beatmap_scores(
|
||||
)
|
||||
).all()
|
||||
|
||||
return [
|
||||
await ScoreResp.from_db(db, score, current_user) for score in all_user_scores
|
||||
]
|
||||
return [await ScoreResp.from_db(db, score) for score in all_user_scores]
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -166,9 +143,10 @@ async def create_solo_score(
|
||||
version_hash: str = Form(""),
|
||||
beatmap_hash: str = Form(),
|
||||
ruleset_id: int = Form(..., ge=0, le=3),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
assert current_user.id
|
||||
async with db:
|
||||
score_token = ScoreToken(
|
||||
user_id=current_user.id,
|
||||
@@ -190,7 +168,7 @@ async def submit_solo_score(
|
||||
beatmap: int,
|
||||
token: int,
|
||||
info: SoloScoreSubmissionInfo,
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
fetcher=Depends(get_fetcher),
|
||||
@@ -210,9 +188,7 @@ async def submit_solo_score(
|
||||
if score_token.score_id:
|
||||
score = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType]
|
||||
.where(
|
||||
select(Score).where(
|
||||
Score.id == score_token.score_id,
|
||||
Score.user_id == current_user.id,
|
||||
)
|
||||
@@ -246,8 +222,6 @@ async def submit_solo_score(
|
||||
score_id = score.id
|
||||
score_token.score_id = score_id
|
||||
await process_user(db, current_user, score, ranked)
|
||||
score = (
|
||||
await db.exec(Score.select_clause().where(Score.id == score_id))
|
||||
).first()
|
||||
score = (await db.exec(select(Score).where(Score.id == score_id))).first()
|
||||
assert score is not None
|
||||
return await ScoreResp.from_db(db, score, current_user)
|
||||
return await ScoreResp.from_db(db, score)
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from app.database import User as DBUser
|
||||
from app.database import User, UserResp
|
||||
from app.database.lazer_user import SEARCH_INCLUDED
|
||||
from app.dependencies.database import get_db
|
||||
from app.models.score import INT_TO_MODE
|
||||
from app.models.user import User as ApiUser
|
||||
from app.utils import convert_db_user_to_api_user
|
||||
from app.models.score import GameMode
|
||||
|
||||
from .api_router import router
|
||||
|
||||
@@ -17,28 +14,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import col
|
||||
|
||||
|
||||
# ---------- Shared Utility ----------
|
||||
async def get_user_by_lookup(
|
||||
db: AsyncSession, lookup: str, key: str = "id"
|
||||
) -> DBUser | None:
|
||||
"""根据查找方式获取用户"""
|
||||
if key == "id":
|
||||
try:
|
||||
user_id = int(lookup)
|
||||
result = await db.exec(select(DBUser).where(DBUser.id == user_id))
|
||||
return result.first()
|
||||
except ValueError:
|
||||
return None
|
||||
elif key == "username":
|
||||
result = await db.exec(select(DBUser).where(DBUser.name == lookup))
|
||||
return result.first()
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
# ---------- Batch Users ----------
|
||||
class BatchUserResponse(BaseModel):
|
||||
users: list[ApiUser]
|
||||
users: list[UserResp]
|
||||
|
||||
|
||||
@router.get("/users", response_model=BatchUserResponse)
|
||||
@@ -51,75 +28,44 @@ async def get_users(
|
||||
):
|
||||
if user_ids:
|
||||
searched_users = (
|
||||
await session.exec(
|
||||
DBUser.all_select_clause().limit(50).where(col(DBUser.id).in_(user_ids))
|
||||
)
|
||||
await session.exec(select(User).limit(50).where(col(User.id).in_(user_ids)))
|
||||
).all()
|
||||
else:
|
||||
searched_users = (
|
||||
await session.exec(DBUser.all_select_clause().limit(50))
|
||||
).all()
|
||||
searched_users = (await session.exec(select(User).limit(50))).all()
|
||||
return BatchUserResponse(
|
||||
users=[
|
||||
await convert_db_user_to_api_user(
|
||||
searched_user, ruleset=INT_TO_MODE[searched_user.preferred_mode].value
|
||||
await UserResp.from_db(
|
||||
searched_user,
|
||||
session,
|
||||
include=SEARCH_INCLUDED,
|
||||
)
|
||||
for searched_user in searched_users
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# # ---------- Individual User ----------
|
||||
# @router.get("/users/{user_lookup}/{mode}", response_model=ApiUser)
|
||||
# @router.get("/users/{user_lookup}/{mode}/", response_model=ApiUser)
|
||||
# async def get_user_with_mode(
|
||||
# user_lookup: str,
|
||||
# mode: Literal["osu", "taiko", "fruits", "mania"],
|
||||
# key: Literal["id", "username"] = Query("id"),
|
||||
# current_user: DBUser = Depends(get_current_user),
|
||||
# db: AsyncSession = Depends(get_db),
|
||||
# ):
|
||||
# """获取指定游戏模式的用户信息"""
|
||||
# user = await get_user_by_lookup(db, user_lookup, key)
|
||||
# if not user:
|
||||
# raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# return await convert_db_user_to_api_user(user, mode)
|
||||
|
||||
|
||||
# @router.get("/users/{user_lookup}", response_model=ApiUser)
|
||||
# @router.get("/users/{user_lookup}/", response_model=ApiUser)
|
||||
# async def get_user_default(
|
||||
# user_lookup: str,
|
||||
# key: Literal["id", "username"] = Query("id"),
|
||||
# current_user: DBUser = Depends(get_current_user),
|
||||
# db: AsyncSession = Depends(get_db),
|
||||
# ):
|
||||
# """获取用户信息(默认使用osu模式,但包含所有模式的统计信息)"""
|
||||
# user = await get_user_by_lookup(db, user_lookup, key)
|
||||
# if not user:
|
||||
# raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# return await convert_db_user_to_api_user(user, "osu")
|
||||
|
||||
|
||||
@router.get("/users/{user}/{ruleset}", response_model=ApiUser)
|
||||
@router.get("/users/{user}/", response_model=ApiUser)
|
||||
@router.get("/users/{user}", response_model=ApiUser)
|
||||
@router.get("/users/{user}/{ruleset}", response_model=UserResp)
|
||||
@router.get("/users/{user}/", response_model=UserResp)
|
||||
@router.get("/users/{user}", response_model=UserResp)
|
||||
async def get_user_info(
|
||||
user: str,
|
||||
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
|
||||
ruleset: GameMode | None = None,
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
searched_user = (
|
||||
await session.exec(
|
||||
DBUser.all_select_clause().where(
|
||||
DBUser.id == int(user)
|
||||
select(User).where(
|
||||
User.id == int(user)
|
||||
if user.isdigit()
|
||||
else DBUser.name == user.removeprefix("@")
|
||||
else User.username == user.removeprefix("@")
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not searched_user:
|
||||
raise HTTPException(404, detail="User not found")
|
||||
return await convert_db_user_to_api_user(searched_user, ruleset=ruleset)
|
||||
return await UserResp.from_db(
|
||||
searched_user,
|
||||
session,
|
||||
include=SEARCH_INCLUDED,
|
||||
ruleset=ruleset,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user