Merge branch 'main' into score-database-model
This commit is contained in:
16
.pre-commit-config.yaml
Normal file
16
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,16 @@
|
||||
default_install_hook_types: [pre-commit, prepare-commit-msg]
|
||||
ci:
|
||||
autofix_commit_msg: "chore(deps): auto fix by pre-commit hooks"
|
||||
autofix_prs: true
|
||||
autoupdate_branch: master
|
||||
autoupdate_schedule: monthly
|
||||
autoupdate_commit_msg: "chore(deps): auto update by pre-commit hooks"
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.12.2
|
||||
hooks:
|
||||
- id: ruff-check
|
||||
args: [--fix]
|
||||
stages: [pre-commit]
|
||||
- id: ruff-format
|
||||
stages: [pre-commit]
|
||||
49
app/auth.py
49
app/auth.py
@@ -14,7 +14,8 @@ from app.database import (
|
||||
import bcrypt
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from sqlmodel import Session, select
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
# 密码哈希上下文
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
@@ -70,7 +71,9 @@ def get_password_hash(password: str) -> str:
|
||||
return pw_bcrypt.decode()
|
||||
|
||||
|
||||
def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser | None:
|
||||
async def authenticate_user_legacy(
|
||||
db: AsyncSession, name: str, password: str
|
||||
) -> DBUser | None:
|
||||
"""
|
||||
验证用户身份 - 使用类似 from_login 的逻辑
|
||||
"""
|
||||
@@ -79,7 +82,7 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser |
|
||||
|
||||
# 2. 根据用户名查找用户
|
||||
statement = select(DBUser).where(DBUser.name == name)
|
||||
user = db.exec(statement).first()
|
||||
user = (await db.exec(statement)).first()
|
||||
if not user:
|
||||
return None
|
||||
|
||||
@@ -107,9 +110,11 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser |
|
||||
return None
|
||||
|
||||
|
||||
def authenticate_user(db: Session, username: str, password: str) -> DBUser | None:
|
||||
async def authenticate_user(
|
||||
db: AsyncSession, username: str, password: str
|
||||
) -> DBUser | None:
|
||||
"""验证用户身份"""
|
||||
return authenticate_user_legacy(db, username, password)
|
||||
return await authenticate_user_legacy(db, username, password)
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
|
||||
@@ -147,24 +152,28 @@ def verify_token(token: str) -> dict | None:
|
||||
return None
|
||||
|
||||
|
||||
def store_token(
|
||||
db: Session, user_id: int, access_token: str, refresh_token: str, expires_in: int
|
||||
async def store_token(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
access_token: str,
|
||||
refresh_token: str,
|
||||
expires_in: int,
|
||||
) -> OAuthToken:
|
||||
"""存储令牌到数据库"""
|
||||
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
|
||||
|
||||
# 删除用户的旧令牌
|
||||
statement = select(OAuthToken).where(OAuthToken.user_id == user_id)
|
||||
old_tokens = db.exec(statement).all()
|
||||
old_tokens = (await db.exec(statement)).all()
|
||||
for token in old_tokens:
|
||||
db.delete(token)
|
||||
await db.delete(token)
|
||||
|
||||
# 检查是否有重复的 access_token
|
||||
duplicate_token = db.exec(
|
||||
select(OAuthToken).where(OAuthToken.access_token == access_token)
|
||||
duplicate_token = (
|
||||
await db.exec(select(OAuthToken).where(OAuthToken.access_token == access_token))
|
||||
).first()
|
||||
if duplicate_token:
|
||||
db.delete(duplicate_token)
|
||||
await db.delete(duplicate_token)
|
||||
|
||||
# 创建新令牌记录
|
||||
token_record = OAuthToken(
|
||||
@@ -174,24 +183,28 @@ def store_token(
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(token_record)
|
||||
db.commit()
|
||||
db.refresh(token_record)
|
||||
await db.commit()
|
||||
await db.refresh(token_record)
|
||||
return token_record
|
||||
|
||||
|
||||
def get_token_by_access_token(db: Session, access_token: str) -> OAuthToken | None:
|
||||
async def get_token_by_access_token(
|
||||
db: AsyncSession, access_token: str
|
||||
) -> OAuthToken | None:
|
||||
"""根据访问令牌获取令牌记录"""
|
||||
statement = select(OAuthToken).where(
|
||||
OAuthToken.access_token == access_token,
|
||||
OAuthToken.expires_at > datetime.utcnow(),
|
||||
)
|
||||
return db.exec(statement).first()
|
||||
return (await db.exec(statement)).first()
|
||||
|
||||
|
||||
def get_token_by_refresh_token(db: Session, refresh_token: str) -> OAuthToken | None:
|
||||
async def get_token_by_refresh_token(
|
||||
db: AsyncSession, refresh_token: str
|
||||
) -> OAuthToken | None:
|
||||
"""根据刷新令牌获取令牌记录"""
|
||||
statement = select(OAuthToken).where(
|
||||
OAuthToken.refresh_token == refresh_token,
|
||||
OAuthToken.expires_at > datetime.utcnow(),
|
||||
)
|
||||
return db.exec(statement).first()
|
||||
return (await db.exec(statement)).first()
|
||||
|
||||
@@ -10,7 +10,7 @@ load_dotenv()
|
||||
class Settings:
|
||||
# 数据库设置
|
||||
DATABASE_URL: str = os.getenv(
|
||||
"DATABASE_URL", "mysql+pymysql://root:password@localhost:3306/osu_api"
|
||||
"DATABASE_URL", "mysql+aiomysql://root:password@localhost:3306/osu_api"
|
||||
)
|
||||
REDIS_URL: str = os.getenv("REDIS_URL", "redis://localhost:6379/0")
|
||||
|
||||
|
||||
@@ -7,6 +7,8 @@ from sqlmodel import Field, Relationship, SQLModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
|
||||
class Team(SQLModel, table=True):
|
||||
__tablename__ = "teams" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ class User(SQLModel, table=True):
|
||||
__tablename__ = "users" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
# 主键
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
id: int = Field(default=None, primary_key=True, index=True, nullable=False)
|
||||
|
||||
# 基本信息(匹配 migrations 中的结构)
|
||||
name: str = Field(max_length=32, unique=True, index=True) # 用户名
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .database import get_db as get_db
|
||||
from .user import get_current_user as get_current_user
|
||||
from .user import get_current_user as get_current_user
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlmodel import Session, create_engine
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlmodel import SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
try:
|
||||
import redis
|
||||
@@ -9,7 +11,7 @@ except ImportError:
|
||||
from app.config import settings
|
||||
|
||||
# 数据库引擎
|
||||
engine = create_engine(settings.DATABASE_URL)
|
||||
engine = create_async_engine(settings.DATABASE_URL)
|
||||
|
||||
# Redis 连接
|
||||
if redis:
|
||||
@@ -19,11 +21,16 @@ else:
|
||||
|
||||
|
||||
# 数据库依赖
|
||||
def get_db():
|
||||
with Session(engine) as session:
|
||||
async def get_db():
|
||||
async with AsyncSession(engine) as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def create_tables():
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
|
||||
# Redis 依赖
|
||||
def get_redis():
|
||||
return redis_client
|
||||
|
||||
@@ -9,14 +9,16 @@ from .database import get_db
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from sqlmodel import Session, select
|
||||
from sqlalchemy.orm import joinedload, selectinload
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> DBUser:
|
||||
"""获取当前认证用户"""
|
||||
token = credentials.credentials
|
||||
@@ -27,9 +29,31 @@ async def get_current_user(
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_user_by_token(token: str, db: Session) -> DBUser | None:
|
||||
token_record = get_token_by_access_token(db, token)
|
||||
async def get_current_user_by_token(token: str, db: AsyncSession) -> DBUser | None:
|
||||
token_record = await get_token_by_access_token(db, token)
|
||||
if not token_record:
|
||||
return None
|
||||
user = db.exec(select(DBUser).where(DBUser.id == token_record.user_id)).first()
|
||||
user = (
|
||||
await db.exec(
|
||||
select(DBUser)
|
||||
.options(
|
||||
joinedload(DBUser.lazer_profile), # pyright: ignore[reportArgumentType]
|
||||
joinedload(DBUser.lazer_counts), # pyright: ignore[reportArgumentType]
|
||||
joinedload(DBUser.daily_challenge_stats), # pyright: ignore[reportArgumentType]
|
||||
joinedload(DBUser.avatar), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.lazer_statistics), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.lazer_achievements), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.lazer_profile_sections), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.statistics), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.team_membership), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.rank_history), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.active_banners), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.lazer_badges), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.lazer_monthly_playcounts), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.lazer_previous_usernames), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.lazer_replays_watched), # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
.where(DBUser.id == token_record.user_id)
|
||||
)
|
||||
).first()
|
||||
return user
|
||||
|
||||
@@ -4,14 +4,14 @@ import datetime
|
||||
from enum import IntEnum
|
||||
from typing import Any
|
||||
|
||||
import msgpack
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from .signalr import MessagePackArrayModel
|
||||
from .score import (
|
||||
APIMod as APIModBase,
|
||||
HitResult,
|
||||
)
|
||||
from .signalr import MessagePackArrayModel
|
||||
|
||||
import msgpack
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
|
||||
class APIMod(APIModBase, MessagePackArrayModel): ...
|
||||
|
||||
@@ -87,7 +87,7 @@ class UserAchievement(BaseModel):
|
||||
return LazerUserAchievement(
|
||||
user_id=user_id,
|
||||
achievement_id=self.achievement_id,
|
||||
achieved_at=self.achieved_at
|
||||
achieved_at=self.achieved_at,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -14,7 +14,7 @@ from app.dependencies import get_db
|
||||
from app.models.oauth import TokenResponse
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException
|
||||
from sqlmodel import Session
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
router = APIRouter(tags=["osu! OAuth 认证"])
|
||||
|
||||
@@ -28,7 +28,7 @@ async def oauth_token(
|
||||
username: str | None = Form(None),
|
||||
password: str | None = Form(None),
|
||||
refresh_token: str | None = Form(None),
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""OAuth 令牌端点"""
|
||||
# 验证客户端凭据
|
||||
@@ -46,7 +46,7 @@ async def oauth_token(
|
||||
)
|
||||
|
||||
# 验证用户
|
||||
user = authenticate_user(db, username, password)
|
||||
user = await authenticate_user(db, username, password)
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Invalid username or password")
|
||||
|
||||
@@ -58,9 +58,9 @@ async def oauth_token(
|
||||
refresh_token_str = generate_refresh_token()
|
||||
|
||||
# 存储令牌
|
||||
store_token(
|
||||
await store_token(
|
||||
db,
|
||||
getattr(user, "id"),
|
||||
user.id,
|
||||
access_token,
|
||||
refresh_token_str,
|
||||
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
@@ -80,7 +80,7 @@ async def oauth_token(
|
||||
raise HTTPException(status_code=400, detail="Refresh token required")
|
||||
|
||||
# 验证刷新令牌
|
||||
token_record = get_token_by_refresh_token(db, refresh_token)
|
||||
token_record = await get_token_by_refresh_token(db, refresh_token)
|
||||
if not token_record:
|
||||
raise HTTPException(status_code=401, detail="Invalid refresh token")
|
||||
|
||||
@@ -92,10 +92,9 @@ async def oauth_token(
|
||||
new_refresh_token = generate_refresh_token()
|
||||
|
||||
# 更新令牌
|
||||
user_id = int(getattr(token_record, "user_id"))
|
||||
store_token(
|
||||
await store_token(
|
||||
db,
|
||||
user_id,
|
||||
token_record.user_id,
|
||||
access_token,
|
||||
new_refresh_token,
|
||||
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
|
||||
@@ -5,6 +5,7 @@ from app.database import (
|
||||
BeatmapResp,
|
||||
User as DBUser,
|
||||
)
|
||||
from app.database.beatmapset import Beatmapset
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.user import get_current_user
|
||||
|
||||
@@ -12,16 +13,24 @@ from .api_router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import Session, col, select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
|
||||
async def get_beatmap(
|
||||
bid: int,
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
beatmap = db.exec(select(Beatmap).where(Beatmap.id == bid)).first()
|
||||
beatmap = (
|
||||
await db.exec(
|
||||
select(Beatmap)
|
||||
.options(joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
|
||||
.where(Beatmap.id == bid)
|
||||
)
|
||||
).first()
|
||||
if not beatmap:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
return BeatmapResp.from_db(beatmap)
|
||||
@@ -36,16 +45,30 @@ class BatchGetResp(BaseModel):
|
||||
async def batch_get_beatmaps(
|
||||
b_ids: list[int] = Query(alias="id", default_factory=list),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if not b_ids:
|
||||
# select 50 beatmaps by last_updated
|
||||
beatmaps = db.exec(
|
||||
select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
|
||||
beatmaps = (
|
||||
await db.exec(
|
||||
select(Beatmap)
|
||||
.options(
|
||||
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
.order_by(col(Beatmap.last_updated).desc())
|
||||
.limit(50)
|
||||
)
|
||||
).all()
|
||||
else:
|
||||
beatmaps = db.exec(
|
||||
select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50)
|
||||
beatmaps = (
|
||||
await db.exec(
|
||||
select(Beatmap)
|
||||
.options(
|
||||
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
.where(col(Beatmap.id).in_(b_ids))
|
||||
.limit(50)
|
||||
)
|
||||
).all()
|
||||
|
||||
return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps])
|
||||
|
||||
@@ -11,16 +11,24 @@ from app.dependencies.user import get_current_user
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from sqlmodel import Session, select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp)
|
||||
async def get_beatmapset(
|
||||
sid: int,
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
beatmapset = db.exec(select(Beatmapset).where(Beatmapset.id == sid)).first()
|
||||
beatmapset = (
|
||||
await db.exec(
|
||||
select(Beatmapset)
|
||||
.options(selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
|
||||
.where(Beatmapset.id == sid)
|
||||
)
|
||||
).first()
|
||||
if not beatmapset:
|
||||
raise HTTPException(status_code=404, detail="Beatmapset not found")
|
||||
return BeatmapsetResp.from_db(beatmapset)
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Literal
|
||||
from app.database import (
|
||||
User as DBUser,
|
||||
)
|
||||
from app.dependencies import get_current_user, get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user import (
|
||||
User as ApiUser,
|
||||
)
|
||||
@@ -14,7 +14,6 @@ from app.utils import convert_db_user_to_api_user
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
@router.get("/me/{ruleset}", response_model=ApiUser)
|
||||
@@ -22,9 +21,8 @@ from sqlalchemy.orm import Session
|
||||
async def get_user_info_default(
|
||||
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取当前用户信息(默认使用osu模式)"""
|
||||
# 默认使用osu模式
|
||||
api_user = convert_db_user_to_api_user(current_user, ruleset, db)
|
||||
api_user = await convert_db_user_to_api_user(current_user, ruleset)
|
||||
return api_user
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .router import router as signalr_router
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class SignalRException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InvokeException(SignalRException):
|
||||
def __init__(self, message: str) -> None:
|
||||
self.message = message
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .hub import Hub
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .hub import Hub
|
||||
|
||||
|
||||
class MultiplayerHub(Hub): ...
|
||||
class MultiplayerHub(Hub): ...
|
||||
|
||||
@@ -8,10 +8,8 @@ from .hub import Client, Hub
|
||||
class SpectatorHub(Hub):
|
||||
async def BeginPlaySession(
|
||||
self, client: Client, score_token: int, state: SpectatorState
|
||||
) -> None:
|
||||
...
|
||||
) -> None: ...
|
||||
|
||||
async def SendFrameData(
|
||||
self, client: Client, frame_data: FrameDataBundle
|
||||
) -> None:
|
||||
...
|
||||
) -> None: ...
|
||||
|
||||
@@ -4,7 +4,6 @@ from enum import IntEnum
|
||||
from typing import Any
|
||||
|
||||
import msgpack
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
SEP = b"\x1e"
|
||||
|
||||
@@ -18,6 +17,7 @@ class PacketType(IntEnum):
|
||||
PING = 6
|
||||
CLOSE = 7
|
||||
|
||||
|
||||
class ResultKind(IntEnum):
|
||||
ERROR = 1
|
||||
VOID = 2
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from logging import info
|
||||
import time
|
||||
from typing import Literal
|
||||
import uuid
|
||||
@@ -9,15 +8,14 @@ import uuid
|
||||
from app.database import User as DBUser
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.user import get_current_user_by_token, security
|
||||
from app.dependencies.user import get_current_user_by_token
|
||||
from app.models.signalr import NegotiateResponse, Transport
|
||||
from app.router.signalr.packet import SEP
|
||||
|
||||
from .hub import Hubs
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, Query, WebSocket
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -48,7 +46,7 @@ async def connect(
|
||||
websocket: WebSocket,
|
||||
id: str,
|
||||
authorization: str = Header(...),
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
token = authorization[7:]
|
||||
user_id = id.split(":")[0]
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import inspect
|
||||
from typing import Any, Callable, ForwardRef, cast
|
||||
from typing import Any, ForwardRef, cast
|
||||
|
||||
|
||||
# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L56-L66
|
||||
|
||||
39
app/utils.py
39
app/utils.py
@@ -11,7 +11,6 @@ from app.database import (
|
||||
from app.models.user import (
|
||||
Country,
|
||||
Cover,
|
||||
DailyChallengeStats,
|
||||
GradeCounts,
|
||||
Kudosu,
|
||||
Level,
|
||||
@@ -23,12 +22,8 @@ from app.models.user import (
|
||||
UserAchievement,
|
||||
)
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
def convert_db_user_to_api_user(
|
||||
db_user: DBUser, ruleset: str = "osu", db_session: Session | None = None
|
||||
) -> User:
|
||||
async def convert_db_user_to_api_user(db_user: DBUser, ruleset: str = "osu") -> User:
|
||||
"""将数据库用户模型转换为API用户模型(使用 Lazer 表)"""
|
||||
|
||||
# 从db_user获取基本字段值
|
||||
@@ -73,7 +68,7 @@ def convert_db_user_to_api_user(
|
||||
kudosu = Kudosu(available=0, total=0)
|
||||
|
||||
# 获取计数信息
|
||||
counts = LazerUserCounts(user_id=user_id)
|
||||
# counts = LazerUserCounts(user_id=user_id)
|
||||
|
||||
# 转换统计信息
|
||||
statistics = Statistics(
|
||||
@@ -178,21 +173,21 @@ def convert_db_user_to_api_user(
|
||||
rank_history = RankHistory(mode=ruleset, data=rank_history_data)
|
||||
|
||||
# 转换每日挑战统计
|
||||
daily_challenge_stats = None
|
||||
if db_user.daily_challenge_stats:
|
||||
dcs = db_user.daily_challenge_stats
|
||||
daily_challenge_stats = DailyChallengeStats(
|
||||
daily_streak_best=dcs.daily_streak_best,
|
||||
daily_streak_current=dcs.daily_streak_current,
|
||||
last_update=dcs.last_update,
|
||||
last_weekly_streak=dcs.last_weekly_streak,
|
||||
playcount=dcs.playcount,
|
||||
top_10p_placements=dcs.top_10p_placements,
|
||||
top_50p_placements=dcs.top_50p_placements,
|
||||
user_id=dcs.user_id,
|
||||
weekly_streak_best=dcs.weekly_streak_best,
|
||||
weekly_streak_current=dcs.weekly_streak_current,
|
||||
)
|
||||
# daily_challenge_stats = None
|
||||
# if db_user.daily_challenge_stats:
|
||||
# dcs = db_user.daily_challenge_stats
|
||||
# daily_challenge_stats = DailyChallengeStats(
|
||||
# daily_streak_best=dcs.daily_streak_best,
|
||||
# daily_streak_current=dcs.daily_streak_current,
|
||||
# last_update=dcs.last_update,
|
||||
# last_weekly_streak=dcs.last_weekly_streak,
|
||||
# playcount=dcs.playcount,
|
||||
# top_10p_placements=dcs.top_10p_placements,
|
||||
# top_50p_placements=dcs.top_50p_placements,
|
||||
# user_id=dcs.user_id,
|
||||
# weekly_streak_best=dcs.weekly_streak_best,
|
||||
# weekly_streak_current=dcs.weekly_streak_current,
|
||||
# )
|
||||
|
||||
# 转换最高排名
|
||||
rank_highest = None
|
||||
|
||||
@@ -5,419 +5,82 @@ osu! API 模拟服务器的示例数据填充脚本
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
import time
|
||||
|
||||
from app.auth import get_password_hash
|
||||
from app.database import (
|
||||
DailyChallengeStats,
|
||||
LazerUserAchievement,
|
||||
LazerUserStatistics,
|
||||
RankHistory,
|
||||
User,
|
||||
)
|
||||
from app.dependencies.database import engine, get_db
|
||||
from app.dependencies.database import create_tables, engine
|
||||
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
# 创建所有表
|
||||
SQLModel.metadata.create_all(bind=engine)
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
def create_sample_user():
|
||||
async def create_sample_user():
|
||||
"""创建示例用户数据"""
|
||||
with next(get_db()) as db:
|
||||
# 检查用户是否已存在
|
||||
from sqlmodel import select
|
||||
async with AsyncSession(engine) as session:
|
||||
async with session.begin():
|
||||
# 检查用户是否已存在
|
||||
statement = select(User).where(User.name == "Googujiang")
|
||||
result = await session.execute(statement)
|
||||
existing_user = result.scalars().first()
|
||||
if existing_user:
|
||||
print("示例用户已存在,跳过创建")
|
||||
return existing_user
|
||||
|
||||
statement = select(User).where(User.name == "Googujiang")
|
||||
existing_user = db.exec(statement).first()
|
||||
if existing_user:
|
||||
print("示例用户已存在,跳过创建")
|
||||
return existing_user
|
||||
# 当前时间戳
|
||||
# current_timestamp = int(time.time())
|
||||
join_timestamp = int(datetime(2019, 11, 29, 17, 23, 13).timestamp())
|
||||
last_visit_timestamp = int(datetime(2025, 7, 18, 16, 31, 29).timestamp())
|
||||
|
||||
# 当前时间戳
|
||||
current_timestamp = int(time.time())
|
||||
join_timestamp = int(datetime(2019, 11, 29, 17, 23, 13).timestamp())
|
||||
last_visit_timestamp = int(datetime(2025, 7, 18, 16, 31, 29).timestamp())
|
||||
# 创建用户
|
||||
user = User(
|
||||
name="Googujiang",
|
||||
safe_name="googujiang", # 安全用户名(小写)
|
||||
email="googujiang@example.com",
|
||||
priv=1, # 默认权限
|
||||
pw_bcrypt=get_password_hash("password123"), # 使用新的哈希方式
|
||||
country="JP",
|
||||
silence_end=0,
|
||||
donor_end=0,
|
||||
creation_time=join_timestamp,
|
||||
latest_activity=last_visit_timestamp,
|
||||
clan_id=0,
|
||||
clan_priv=0,
|
||||
preferred_mode=0, # 0 = osu!
|
||||
play_style=0,
|
||||
custom_badge_name=None,
|
||||
custom_badge_icon=None,
|
||||
userpage_content="「世界に忘れられた」",
|
||||
api_key=None,
|
||||
)
|
||||
|
||||
# 创建用户
|
||||
user = User(
|
||||
name="Googujiang",
|
||||
safe_name="googujiang", # 安全用户名(小写)
|
||||
email="googujiang@example.com",
|
||||
priv=1, # 默认权限
|
||||
pw_bcrypt=get_password_hash("password123"), # 使用新的哈希方式
|
||||
country="JP",
|
||||
silence_end=0,
|
||||
donor_end=0,
|
||||
creation_time=join_timestamp,
|
||||
latest_activity=last_visit_timestamp,
|
||||
clan_id=0,
|
||||
clan_priv=0,
|
||||
preferred_mode=0, # 0 = osu!
|
||||
play_style=0,
|
||||
custom_badge_name=None,
|
||||
custom_badge_icon=None,
|
||||
userpage_content="「世界に忘れられた」",
|
||||
api_key=None,
|
||||
# # 兼容性字段
|
||||
# avatar_url="https://a.ppy.sh/15651670?1732362658.jpeg",
|
||||
# cover_url="https://assets.ppy.sh/user-profile-covers/15651670/0fc7b77adef39765a570e7f535bc383e5a848850d41a8943f8857984330b8bc6.jpeg",
|
||||
# has_supported=True,
|
||||
# interests="「世界に忘れられた」",
|
||||
# location="咕谷国",
|
||||
# website="https://gmoe.cc",
|
||||
# playstyle=["mouse", "keyboard", "tablet"],
|
||||
# profile_order=[
|
||||
# "me",
|
||||
# "recent_activity",
|
||||
# "top_ranks",
|
||||
# "medals",
|
||||
# "historical",
|
||||
# "beatmaps",
|
||||
# "kudosu",
|
||||
# ],
|
||||
# beatmap_playcounts_count=3306,
|
||||
# favourite_beatmapset_count=15,
|
||||
# follower_count=98,
|
||||
# graveyard_beatmapset_count=7,
|
||||
# mapping_follower_count=1,
|
||||
# previous_usernames=["hehejun"],
|
||||
# monthly_playcounts=[
|
||||
# {"start_date": "2019-11-01", "count": 43},
|
||||
# {"start_date": "2020-04-01", "count": 216},
|
||||
# {"start_date": "2020-05-01", "count": 656},
|
||||
# {"start_date": "2020-07-01", "count": 158},
|
||||
# {"start_date": "2020-08-01", "count": 174},
|
||||
# {"start_date": "2020-10-01", "count": 13},
|
||||
# {"start_date": "2020-11-01", "count": 52},
|
||||
# {"start_date": "2020-12-01", "count": 140},
|
||||
# {"start_date": "2021-01-01", "count": 359},
|
||||
# {"start_date": "2021-02-01", "count": 452},
|
||||
# {"start_date": "2021-03-01", "count": 77},
|
||||
# {"start_date": "2021-04-01", "count": 114},
|
||||
# {"start_date": "2021-05-01", "count": 270},
|
||||
# {"start_date": "2021-06-01", "count": 148},
|
||||
# {"start_date": "2021-07-01", "count": 246},
|
||||
# {"start_date": "2021-08-01", "count": 56},
|
||||
# {"start_date": "2021-09-01", "count": 136},
|
||||
# {"start_date": "2021-10-01", "count": 45},
|
||||
# {"start_date": "2021-11-01", "count": 98},
|
||||
# {"start_date": "2021-12-01", "count": 54},
|
||||
# {"start_date": "2022-01-01", "count": 88},
|
||||
# {"start_date": "2022-02-01", "count": 45},
|
||||
# {"start_date": "2022-03-01", "count": 6},
|
||||
# {"start_date": "2022-04-01", "count": 54},
|
||||
# {"start_date": "2022-05-01", "count": 105},
|
||||
# {"start_date": "2022-06-01", "count": 37},
|
||||
# {"start_date": "2022-07-01", "count": 88},
|
||||
# {"start_date": "2022-08-01", "count": 7},
|
||||
# {"start_date": "2022-09-01", "count": 9},
|
||||
# {"start_date": "2022-10-01", "count": 6},
|
||||
# {"start_date": "2022-11-01", "count": 2},
|
||||
# {"start_date": "2022-12-01", "count": 16},
|
||||
# {"start_date": "2023-01-01", "count": 7},
|
||||
# {"start_date": "2023-04-01", "count": 16},
|
||||
# {"start_date": "2023-05-01", "count": 3},
|
||||
# {"start_date": "2023-06-01", "count": 8},
|
||||
# {"start_date": "2023-07-01", "count": 23},
|
||||
# {"start_date": "2023-08-01", "count": 3},
|
||||
# {"start_date": "2023-09-01", "count": 1},
|
||||
# {"start_date": "2023-10-01", "count": 25},
|
||||
# {"start_date": "2023-11-01", "count": 160},
|
||||
# {"start_date": "2023-12-01", "count": 306},
|
||||
# {"start_date": "2024-01-01", "count": 735},
|
||||
# {"start_date": "2024-02-01", "count": 420},
|
||||
# {"start_date": "2024-03-01", "count": 549},
|
||||
# {"start_date": "2024-04-01", "count": 466},
|
||||
# {"start_date": "2024-05-01", "count": 333},
|
||||
# {"start_date": "2024-06-01", "count": 1126},
|
||||
# {"start_date": "2024-07-01", "count": 534},
|
||||
# {"start_date": "2024-08-01", "count": 280},
|
||||
# {"start_date": "2024-09-01", "count": 116},
|
||||
# {"start_date": "2024-10-01", "count": 120},
|
||||
# {"start_date": "2024-11-01", "count": 332},
|
||||
# {"start_date": "2024-12-01", "count": 243},
|
||||
# {"start_date": "2025-01-01", "count": 122},
|
||||
# {"start_date": "2025-02-01", "count": 379},
|
||||
# {"start_date": "2025-03-01", "count": 278},
|
||||
# {"start_date": "2025-04-01", "count": 296},
|
||||
# {"start_date": "2025-05-01", "count": 964},
|
||||
# {"start_date": "2025-06-01", "count": 821},
|
||||
# {"start_date": "2025-07-01", "count": 230},
|
||||
# ],
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
# 确保用户ID存在
|
||||
if user.id is None:
|
||||
raise ValueError("User ID is None after saving to database")
|
||||
|
||||
# 确保用户ID存在
|
||||
if user.id is None:
|
||||
raise ValueError("User ID is None after saving to database")
|
||||
|
||||
# 创建 osu! 模式统计
|
||||
osu_stats = LazerUserStatistics(
|
||||
user_id=user.id,
|
||||
mode="osu",
|
||||
count_100=276274,
|
||||
count_300=1932068,
|
||||
count_50=32776,
|
||||
count_miss=111064,
|
||||
level_current=97,
|
||||
level_progress=96,
|
||||
global_rank=298026,
|
||||
country_rank=11221,
|
||||
pp=2826.26,
|
||||
ranked_score=4415081049,
|
||||
hit_accuracy=95.7168,
|
||||
play_count=12711,
|
||||
play_time=836529,
|
||||
total_score=12390140370,
|
||||
total_hits=2241118,
|
||||
maximum_combo=1859,
|
||||
replays_watched_by_others=0,
|
||||
is_ranked=True,
|
||||
grade_ss=14,
|
||||
grade_ssh=3,
|
||||
grade_s=322,
|
||||
grade_sh=11,
|
||||
grade_a=757,
|
||||
rank_highest=295701,
|
||||
rank_highest_updated_at=datetime(2025, 7, 2, 17, 30, 21),
|
||||
)
|
||||
|
||||
# 创建 taiko 模式统计
|
||||
taiko_stats = LazerUserStatistics(
|
||||
user_id=user.id,
|
||||
mode="taiko",
|
||||
count_100=160,
|
||||
count_300=154,
|
||||
count_50=0,
|
||||
count_miss=480,
|
||||
level_current=2,
|
||||
level_progress=49,
|
||||
global_rank=None,
|
||||
pp=0,
|
||||
ranked_score=0,
|
||||
hit_accuracy=0,
|
||||
play_count=6,
|
||||
play_time=217,
|
||||
total_score=79301,
|
||||
total_hits=314,
|
||||
maximum_combo=0,
|
||||
replays_watched_by_others=0,
|
||||
is_ranked=False,
|
||||
)
|
||||
|
||||
# 创建 fruits 模式统计
|
||||
fruits_stats = LazerUserStatistics(
|
||||
user_id=user.id,
|
||||
mode="fruits",
|
||||
count_100=109,
|
||||
count_300=1613,
|
||||
count_50=1861,
|
||||
count_miss=328,
|
||||
level_current=6,
|
||||
level_progress=14,
|
||||
global_rank=None,
|
||||
pp=0,
|
||||
ranked_score=343854,
|
||||
hit_accuracy=89.4779,
|
||||
play_count=19,
|
||||
play_time=669,
|
||||
total_score=1362651,
|
||||
total_hits=3583,
|
||||
maximum_combo=75,
|
||||
replays_watched_by_others=0,
|
||||
is_ranked=False,
|
||||
grade_a=1,
|
||||
)
|
||||
|
||||
# 创建 mania 模式统计
|
||||
mania_stats = LazerUserStatistics(
|
||||
user_id=user.id,
|
||||
mode="mania",
|
||||
count_100=7867,
|
||||
count_300=12104,
|
||||
count_50=991,
|
||||
count_miss=2951,
|
||||
level_current=12,
|
||||
level_progress=89,
|
||||
global_rank=660670,
|
||||
pp=25.3784,
|
||||
ranked_score=3812295,
|
||||
hit_accuracy=77.9316,
|
||||
play_count=85,
|
||||
play_time=4834,
|
||||
total_score=13454470,
|
||||
total_hits=20962,
|
||||
maximum_combo=573,
|
||||
replays_watched_by_others=0,
|
||||
is_ranked=True,
|
||||
grade_a=1,
|
||||
)
|
||||
|
||||
db.add_all([osu_stats, taiko_stats, fruits_stats, mania_stats])
|
||||
|
||||
# 创建每日挑战统计
|
||||
daily_challenge = DailyChallengeStats(
|
||||
user_id=user.id,
|
||||
daily_streak_best=1,
|
||||
daily_streak_current=0,
|
||||
last_update=datetime(2025, 6, 21, 0, 0, 0),
|
||||
last_weekly_streak=datetime(2025, 6, 19, 0, 0, 0),
|
||||
playcount=1,
|
||||
top_10p_placements=0,
|
||||
top_50p_placements=0,
|
||||
weekly_streak_best=1,
|
||||
weekly_streak_current=0,
|
||||
)
|
||||
|
||||
db.add(daily_challenge)
|
||||
|
||||
# 创建排名历史 (最近90天的数据)
|
||||
rank_data = [
|
||||
322806,
|
||||
323092,
|
||||
323341,
|
||||
323616,
|
||||
323853,
|
||||
324106,
|
||||
324378,
|
||||
324676,
|
||||
324958,
|
||||
325254,
|
||||
325492,
|
||||
325780,
|
||||
326075,
|
||||
326356,
|
||||
326586,
|
||||
326845,
|
||||
327067,
|
||||
327286,
|
||||
327526,
|
||||
327778,
|
||||
328039,
|
||||
328347,
|
||||
328631,
|
||||
328858,
|
||||
329323,
|
||||
329557,
|
||||
329809,
|
||||
329911,
|
||||
330188,
|
||||
330425,
|
||||
330650,
|
||||
330881,
|
||||
331068,
|
||||
331325,
|
||||
331575,
|
||||
331816,
|
||||
332061,
|
||||
328959,
|
||||
315648,
|
||||
315881,
|
||||
308784,
|
||||
309023,
|
||||
309252,
|
||||
309433,
|
||||
309537,
|
||||
309364,
|
||||
309548,
|
||||
308957,
|
||||
309182,
|
||||
309426,
|
||||
309607,
|
||||
309831,
|
||||
310054,
|
||||
310269,
|
||||
310485,
|
||||
310714,
|
||||
310956,
|
||||
310924,
|
||||
311125,
|
||||
311203,
|
||||
311422,
|
||||
311640,
|
||||
303091,
|
||||
303309,
|
||||
303500,
|
||||
303691,
|
||||
303758,
|
||||
303750,
|
||||
303957,
|
||||
299867,
|
||||
300088,
|
||||
300273,
|
||||
300457,
|
||||
295799,
|
||||
295976,
|
||||
296153,
|
||||
296350,
|
||||
296566,
|
||||
296756,
|
||||
296933,
|
||||
297141,
|
||||
297314,
|
||||
297480,
|
||||
297114,
|
||||
297296,
|
||||
297480,
|
||||
297645,
|
||||
297815,
|
||||
297993,
|
||||
298026,
|
||||
]
|
||||
|
||||
rank_history = RankHistory(user_id=user.id, mode="osu", rank_data=rank_data)
|
||||
|
||||
db.add(rank_history)
|
||||
|
||||
# 创建一些成就
|
||||
achievements = [
|
||||
LazerUserAchievement(
|
||||
user_id=user.id,
|
||||
achievement_id=336,
|
||||
achieved_at=datetime(2025, 6, 21, 19, 6, 32),
|
||||
),
|
||||
LazerUserAchievement(
|
||||
user_id=user.id,
|
||||
achievement_id=319,
|
||||
achieved_at=datetime(2025, 6, 1, 0, 52, 0),
|
||||
),
|
||||
LazerUserAchievement(
|
||||
user_id=user.id,
|
||||
achievement_id=222,
|
||||
achieved_at=datetime(2025, 5, 28, 12, 24, 37),
|
||||
),
|
||||
LazerUserAchievement(
|
||||
user_id=user.id,
|
||||
achievement_id=38,
|
||||
achieved_at=datetime(2024, 7, 5, 15, 43, 23),
|
||||
),
|
||||
LazerUserAchievement(
|
||||
user_id=user.id,
|
||||
achievement_id=67,
|
||||
achieved_at=datetime(2024, 6, 24, 5, 6, 44),
|
||||
),
|
||||
]
|
||||
|
||||
db.add_all(achievements)
|
||||
|
||||
db.commit()
|
||||
print(f"成功创建示例用户: {user.name} (ID: {user.id})")
|
||||
print(f"安全用户名: {user.safe_name}")
|
||||
print(f"邮箱: {user.email}")
|
||||
print(f"国家: {user.country}")
|
||||
return user
|
||||
print(f"成功创建示例用户: {user.name} (ID: {user.id})")
|
||||
print(f"安全用户名: {user.safe_name}")
|
||||
print(f"邮箱: {user.email}")
|
||||
print(f"国家: {user.country}")
|
||||
return user
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
async def main():
|
||||
print("开始创建示例数据...")
|
||||
user = create_sample_user()
|
||||
await create_tables()
|
||||
user = await create_sample_user()
|
||||
print("示例数据创建完成!")
|
||||
print(f"用户名: {user.name}")
|
||||
print("密码: password123")
|
||||
print("现在您可以使用这些凭据来测试API了。")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
17
main.py
17
main.py
@@ -1,24 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
|
||||
from app.config import settings
|
||||
from app.dependencies.database import engine
|
||||
from app.dependencies.database import create_tables
|
||||
from app.router import api_router, auth_router, signalr_router
|
||||
|
||||
from fastapi import FastAPI
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
# 注意: 表结构现在通过 migrations 管理,不再自动创建
|
||||
# 如需创建表,请运行: python quick_sync.py
|
||||
|
||||
app = FastAPI(title="osu! API 模拟服务器", version="1.0.0")
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# on startup
|
||||
await create_tables()
|
||||
# on shutdown
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(title="osu! API 模拟服务器", version="1.0.0", lifespan=lifespan)
|
||||
app.include_router(api_router, prefix="/api/v2")
|
||||
app.include_router(signalr_router, prefix="/signalr")
|
||||
app.include_router(auth_router)
|
||||
|
||||
SQLModel.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
|
||||
@@ -5,6 +5,7 @@ description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
"aiomysql>=0.2.0",
|
||||
"alembic>=1.12.1",
|
||||
"bcrypt>=4.1.2",
|
||||
"cryptography>=41.0.7",
|
||||
@@ -12,7 +13,6 @@ dependencies = [
|
||||
"msgpack>=1.1.1",
|
||||
"passlib[bcrypt]>=1.7.4",
|
||||
"pydantic[email]>=2.5.0",
|
||||
"pymysql>=1.1.0",
|
||||
"python-dotenv>=1.0.0",
|
||||
"python-jose[cryptography]>=3.3.0",
|
||||
"python-multipart>=0.0.6",
|
||||
@@ -85,5 +85,6 @@ reportIncompatibleVariableOverride = false
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"msgpack-types>=0.5.0",
|
||||
"pre-commit>=4.2.0",
|
||||
"ruff>=0.12.4",
|
||||
]
|
||||
|
||||
@@ -11,7 +11,7 @@ python-multipart==0.0.6
|
||||
pydantic[email]~=2.11.7
|
||||
python-dotenv~=1.1.1
|
||||
bcrypt~=4.3.0
|
||||
|
||||
msgpack~=1.1.1
|
||||
sqlmodel~=0.0.24
|
||||
starlette~=0.47.2
|
||||
starlette~=0.47.2
|
||||
aiomysql==0.2.0
|
||||
|
||||
140
test_lazer.py
140
test_lazer.py
@@ -12,108 +12,106 @@ import sys
|
||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
from app.database import User
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.database import engine
|
||||
from app.utils import convert_db_user_to_api_user
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
def test_lazer_tables():
|
||||
async def test_lazer_tables():
|
||||
"""测试 lazer 表的基本功能"""
|
||||
print("测试 Lazer API 表支持...")
|
||||
|
||||
# 获取数据库会话
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
async with AsyncSession(engine) as session:
|
||||
async with session.begin():
|
||||
try:
|
||||
# 测试查询用户
|
||||
statement = select(User)
|
||||
result = await session.execute(statement)
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
print("❌ 没有找到用户,请先同步数据")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 测试查询用户
|
||||
statement = select(User)
|
||||
user = db.exec(statement).first()
|
||||
if not user:
|
||||
print("❌ 没有找到用户,请先同步数据")
|
||||
return False
|
||||
print(f"✓ 找到用户: {user.name} (ID: {user.id})")
|
||||
|
||||
print(f"✓ 找到用户: {user.name} (ID: {user.id})")
|
||||
# 测试 lazer 资料
|
||||
if user.lazer_profile:
|
||||
print(
|
||||
f"✓ 用户有 lazer 资料: 支持者={user.lazer_profile.is_supporter}"
|
||||
)
|
||||
else:
|
||||
print("⚠ 用户没有 lazer 资料,将使用默认值")
|
||||
|
||||
# 测试 lazer 资料
|
||||
if user.lazer_profile:
|
||||
print(f"✓ 用户有 lazer 资料: 支持者={user.lazer_profile.is_supporter}")
|
||||
else:
|
||||
print("⚠ 用户没有 lazer 资料,将使用默认值")
|
||||
# 测试 lazer 统计
|
||||
osu_stats = None
|
||||
for stat in user.lazer_statistics:
|
||||
if stat.mode == "osu":
|
||||
osu_stats = stat
|
||||
break
|
||||
|
||||
# 测试 lazer 统计
|
||||
osu_stats = None
|
||||
for stat in user.lazer_statistics:
|
||||
if stat.mode == "osu":
|
||||
osu_stats = stat
|
||||
break
|
||||
if osu_stats:
|
||||
print(
|
||||
f"✓ 用户有 osu! 统计: PP={osu_stats.pp}, "
|
||||
f"游戏次数={osu_stats.play_count}"
|
||||
)
|
||||
else:
|
||||
print("⚠ 用户没有 osu! 统计,将使用默认值")
|
||||
|
||||
if osu_stats:
|
||||
print(
|
||||
f"✓ 用户有 osu! 统计: PP={osu_stats.pp}, "
|
||||
f"游戏次数={osu_stats.play_count}"
|
||||
)
|
||||
else:
|
||||
print("⚠ 用户没有 osu! 统计,将使用默认值")
|
||||
# 测试转换为 API 格式
|
||||
api_user = convert_db_user_to_api_user(user, "osu")
|
||||
print("✓ 成功转换为 API 用户格式")
|
||||
print(f" - 用户名: {api_user.username}")
|
||||
print(f" - 国家: {api_user.country_code}")
|
||||
print(f" - PP: {api_user.statistics.pp}")
|
||||
print(f" - 是否支持者: {api_user.is_supporter}")
|
||||
|
||||
# 测试转换为 API 格式
|
||||
api_user = convert_db_user_to_api_user(user, "osu", db)
|
||||
print("✓ 成功转换为 API 用户格式")
|
||||
print(f" - 用户名: {api_user.username}")
|
||||
print(f" - 国家: {api_user.country_code}")
|
||||
print(f" - PP: {api_user.statistics.pp}")
|
||||
print(f" - 是否支持者: {api_user.is_supporter}")
|
||||
return True
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"❌ 测试失败: {e}")
|
||||
import traceback
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 测试失败: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def test_authentication():
|
||||
async def test_authentication():
|
||||
"""测试认证功能"""
|
||||
print("\n测试认证功能...")
|
||||
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
async with AsyncSession(engine) as session:
|
||||
async with session.begin():
|
||||
try:
|
||||
# 尝试认证第一个用户
|
||||
statement = select(User)
|
||||
result = await session.execute(statement)
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
print("❌ 没有用户进行认证测试")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 尝试认证第一个用户
|
||||
statement = select(User)
|
||||
user = db.exec(statement).first()
|
||||
if not user:
|
||||
print("❌ 没有用户进行认证测试")
|
||||
return False
|
||||
print(f"✓ 测试用户: {user.name}")
|
||||
print("⚠ 注意: 实际密码认证需要正确的密码")
|
||||
|
||||
print(f"✓ 测试用户: {user.name}")
|
||||
print("⚠ 注意: 实际密码认证需要正确的密码")
|
||||
return True
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 认证测试失败: {e}")
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
print(f"❌ 认证测试失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
async def main():
|
||||
"""主测试函数"""
|
||||
print("Lazer API 系统测试")
|
||||
print("=" * 40)
|
||||
|
||||
# 测试表连接
|
||||
success1 = test_lazer_tables()
|
||||
success1 = await test_lazer_tables()
|
||||
|
||||
# 测试认证
|
||||
success2 = test_authentication()
|
||||
success2 = await test_authentication()
|
||||
|
||||
print("\n" + "=" * 40)
|
||||
if success1 and success2:
|
||||
@@ -130,4 +128,6 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
Reference in New Issue
Block a user