refactor(database): use asyncio

This commit is contained in:
MingxuanGame
2025-07-25 20:43:50 +08:00
parent 2e1489c6d4
commit f347b680b2
21 changed files with 296 additions and 536 deletions

View File

@@ -14,7 +14,8 @@ from app.database import (
import bcrypt import bcrypt
from jose import JWTError, jwt from jose import JWTError, jwt
from passlib.context import CryptContext from passlib.context import CryptContext
from sqlmodel import Session, select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
# 密码哈希上下文 # 密码哈希上下文
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
@@ -70,7 +71,9 @@ def get_password_hash(password: str) -> str:
return pw_bcrypt.decode() return pw_bcrypt.decode()
def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser | None: async def authenticate_user_legacy(
db: AsyncSession, name: str, password: str
) -> DBUser | None:
""" """
验证用户身份 - 使用类似 from_login 的逻辑 验证用户身份 - 使用类似 from_login 的逻辑
""" """
@@ -79,7 +82,7 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser |
# 2. 根据用户名查找用户 # 2. 根据用户名查找用户
statement = select(DBUser).where(DBUser.name == name) statement = select(DBUser).where(DBUser.name == name)
user = db.exec(statement).first() user = (await db.exec(statement)).first()
if not user: if not user:
return None return None
@@ -107,9 +110,11 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser |
return None return None
def authenticate_user(db: Session, username: str, password: str) -> DBUser | None: async def authenticate_user(
db: AsyncSession, username: str, password: str
) -> DBUser | None:
"""验证用户身份""" """验证用户身份"""
return authenticate_user_legacy(db, username, password) return await authenticate_user_legacy(db, username, password)
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str: def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
@@ -147,24 +152,28 @@ def verify_token(token: str) -> dict | None:
return None return None
def store_token( async def store_token(
db: Session, user_id: int, access_token: str, refresh_token: str, expires_in: int db: AsyncSession,
user_id: int,
access_token: str,
refresh_token: str,
expires_in: int,
) -> OAuthToken: ) -> OAuthToken:
"""存储令牌到数据库""" """存储令牌到数据库"""
expires_at = datetime.utcnow() + timedelta(seconds=expires_in) expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
# 删除用户的旧令牌 # 删除用户的旧令牌
statement = select(OAuthToken).where(OAuthToken.user_id == user_id) statement = select(OAuthToken).where(OAuthToken.user_id == user_id)
old_tokens = db.exec(statement).all() old_tokens = (await db.exec(statement)).all()
for token in old_tokens: for token in old_tokens:
db.delete(token) await db.delete(token)
# 检查是否有重复的 access_token # 检查是否有重复的 access_token
duplicate_token = db.exec( duplicate_token = (
select(OAuthToken).where(OAuthToken.access_token == access_token) await db.exec(select(OAuthToken).where(OAuthToken.access_token == access_token))
).first() ).first()
if duplicate_token: if duplicate_token:
db.delete(duplicate_token) await db.delete(duplicate_token)
# 创建新令牌记录 # 创建新令牌记录
token_record = OAuthToken( token_record = OAuthToken(
@@ -174,24 +183,28 @@ def store_token(
expires_at=expires_at, expires_at=expires_at,
) )
db.add(token_record) db.add(token_record)
db.commit() await db.commit()
db.refresh(token_record) await db.refresh(token_record)
return token_record return token_record
def get_token_by_access_token(db: Session, access_token: str) -> OAuthToken | None: async def get_token_by_access_token(
db: AsyncSession, access_token: str
) -> OAuthToken | None:
"""根据访问令牌获取令牌记录""" """根据访问令牌获取令牌记录"""
statement = select(OAuthToken).where( statement = select(OAuthToken).where(
OAuthToken.access_token == access_token, OAuthToken.access_token == access_token,
OAuthToken.expires_at > datetime.utcnow(), OAuthToken.expires_at > datetime.utcnow(),
) )
return db.exec(statement).first() return (await db.exec(statement)).first()
def get_token_by_refresh_token(db: Session, refresh_token: str) -> OAuthToken | None: async def get_token_by_refresh_token(
db: AsyncSession, refresh_token: str
) -> OAuthToken | None:
"""根据刷新令牌获取令牌记录""" """根据刷新令牌获取令牌记录"""
statement = select(OAuthToken).where( statement = select(OAuthToken).where(
OAuthToken.refresh_token == refresh_token, OAuthToken.refresh_token == refresh_token,
OAuthToken.expires_at > datetime.utcnow(), OAuthToken.expires_at > datetime.utcnow(),
) )
return db.exec(statement).first() return (await db.exec(statement)).first()

View File

@@ -10,7 +10,7 @@ load_dotenv()
class Settings: class Settings:
# 数据库设置 # 数据库设置
DATABASE_URL: str = os.getenv( DATABASE_URL: str = os.getenv(
"DATABASE_URL", "mysql+pymysql://root:password@localhost:3306/osu_api" "DATABASE_URL", "mysql+aiomysql://root:password@localhost:3306/osu_api"
) )
REDIS_URL: str = os.getenv("REDIS_URL", "redis://localhost:6379/0") REDIS_URL: str = os.getenv("REDIS_URL", "redis://localhost:6379/0")

View File

@@ -15,7 +15,7 @@ class User(SQLModel, table=True):
__tablename__ = "users" # pyright: ignore[reportAssignmentType] __tablename__ = "users" # pyright: ignore[reportAssignmentType]
# 主键 # 主键
id: int | None = Field(default=None, primary_key=True, index=True) id: int = Field(default=None, primary_key=True, index=True, nullable=False)
# 基本信息(匹配 migrations 中的结构) # 基本信息(匹配 migrations 中的结构)
name: str = Field(max_length=32, unique=True, index=True) # 用户名 name: str = Field(max_length=32, unique=True, index=True) # 用户名

View File

@@ -1,2 +1,4 @@
from __future__ import annotations
from .database import get_db as get_db from .database import get_db as get_db
from .user import get_current_user as get_current_user from .user import get_current_user as get_current_user

View File

@@ -1,6 +1,8 @@
from __future__ import annotations from __future__ import annotations
from sqlmodel import Session, create_engine from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
try: try:
import redis import redis
@@ -9,7 +11,7 @@ except ImportError:
from app.config import settings from app.config import settings
# 数据库引擎 # 数据库引擎
engine = create_engine(settings.DATABASE_URL) engine = create_async_engine(settings.DATABASE_URL)
# Redis 连接 # Redis 连接
if redis: if redis:
@@ -19,11 +21,16 @@ else:
# 数据库依赖 # 数据库依赖
def get_db(): async def get_db():
with Session(engine) as session: async with AsyncSession(engine) as session:
yield session yield session
async def create_tables():
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
# Redis 依赖 # Redis 依赖
def get_redis(): def get_redis():
return redis_client return redis_client

View File

@@ -9,14 +9,16 @@ from .database import get_db
from fastapi import Depends, HTTPException from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlmodel import Session, select from sqlalchemy.orm import joinedload, selectinload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
security = HTTPBearer() security = HTTPBearer()
async def get_current_user( async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security), credentials: HTTPAuthorizationCredentials = Depends(security),
db: Session = Depends(get_db), db: AsyncSession = Depends(get_db),
) -> DBUser: ) -> DBUser:
"""获取当前认证用户""" """获取当前认证用户"""
token = credentials.credentials token = credentials.credentials
@@ -27,9 +29,31 @@ async def get_current_user(
return user return user
async def get_current_user_by_token(token: str, db: Session) -> DBUser | None: async def get_current_user_by_token(token: str, db: AsyncSession) -> DBUser | None:
token_record = get_token_by_access_token(db, token) token_record = await get_token_by_access_token(db, token)
if not token_record: if not token_record:
return None return None
user = db.exec(select(DBUser).where(DBUser.id == token_record.user_id)).first() user = (
await db.exec(
select(DBUser)
.options(
joinedload(DBUser.lazer_profile), # pyright: ignore[reportArgumentType]
joinedload(DBUser.lazer_counts), # pyright: ignore[reportArgumentType]
joinedload(DBUser.daily_challenge_stats), # pyright: ignore[reportArgumentType]
joinedload(DBUser.avatar), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_statistics), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_achievements), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_profile_sections), # pyright: ignore[reportArgumentType]
selectinload(DBUser.statistics), # pyright: ignore[reportArgumentType]
selectinload(DBUser.team_membership), # pyright: ignore[reportArgumentType]
selectinload(DBUser.rank_history), # pyright: ignore[reportArgumentType]
selectinload(DBUser.active_banners), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_badges), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_monthly_playcounts), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_previous_usernames), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_replays_watched), # pyright: ignore[reportArgumentType]
)
.where(DBUser.id == token_record.user_id)
)
).first()
return user return user

View File

@@ -1,4 +1,5 @@
from __future__ import annotations
from fastapi import APIRouter from fastapi import APIRouter
router = APIRouter() router = APIRouter()

View File

@@ -14,7 +14,7 @@ from app.dependencies import get_db
from app.models.oauth import TokenResponse from app.models.oauth import TokenResponse
from fastapi import APIRouter, Depends, Form, HTTPException from fastapi import APIRouter, Depends, Form, HTTPException
from sqlmodel import Session from sqlmodel.ext.asyncio.session import AsyncSession
router = APIRouter(tags=["osu! OAuth 认证"]) router = APIRouter(tags=["osu! OAuth 认证"])
@@ -28,7 +28,7 @@ async def oauth_token(
username: str | None = Form(None), username: str | None = Form(None),
password: str | None = Form(None), password: str | None = Form(None),
refresh_token: str | None = Form(None), refresh_token: str | None = Form(None),
db: Session = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
"""OAuth 令牌端点""" """OAuth 令牌端点"""
# 验证客户端凭据 # 验证客户端凭据
@@ -46,7 +46,7 @@ async def oauth_token(
) )
# 验证用户 # 验证用户
user = authenticate_user(db, username, password) user = await authenticate_user(db, username, password)
if not user: if not user:
raise HTTPException(status_code=401, detail="Invalid username or password") raise HTTPException(status_code=401, detail="Invalid username or password")
@@ -58,9 +58,9 @@ async def oauth_token(
refresh_token_str = generate_refresh_token() refresh_token_str = generate_refresh_token()
# 存储令牌 # 存储令牌
store_token( await store_token(
db, db,
getattr(user, "id"), user.id,
access_token, access_token,
refresh_token_str, refresh_token_str,
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
@@ -80,7 +80,7 @@ async def oauth_token(
raise HTTPException(status_code=400, detail="Refresh token required") raise HTTPException(status_code=400, detail="Refresh token required")
# 验证刷新令牌 # 验证刷新令牌
token_record = get_token_by_refresh_token(db, refresh_token) token_record =await get_token_by_refresh_token(db, refresh_token)
if not token_record: if not token_record:
raise HTTPException(status_code=401, detail="Invalid refresh token") raise HTTPException(status_code=401, detail="Invalid refresh token")
@@ -92,10 +92,9 @@ async def oauth_token(
new_refresh_token = generate_refresh_token() new_refresh_token = generate_refresh_token()
# 更新令牌 # 更新令牌
user_id = int(getattr(token_record, "user_id")) await store_token(
store_token(
db, db,
user_id, token_record.user_id,
access_token, access_token,
new_refresh_token, new_refresh_token,
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,

View File

@@ -5,6 +5,7 @@ from app.database import (
BeatmapResp, BeatmapResp,
User as DBUser, User as DBUser,
) )
from app.database.beatmapset import Beatmapset
from app.dependencies.database import get_db from app.dependencies.database import get_db
from app.dependencies.user import get_current_user from app.dependencies.user import get_current_user
@@ -12,16 +13,24 @@ from .api_router import router
from fastapi import Depends, HTTPException, Query from fastapi import Depends, HTTPException, Query
from pydantic import BaseModel from pydantic import BaseModel
from sqlmodel import Session, col, select from sqlalchemy.orm import joinedload
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp) @router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
async def get_beatmap( async def get_beatmap(
bid: int, bid: int,
current_user: DBUser = Depends(get_current_user), current_user: DBUser = Depends(get_current_user),
db: Session = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
beatmap = db.exec(select(Beatmap).where(Beatmap.id == bid)).first() beatmap = (
await db.exec(
select(Beatmap)
.options(joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
.where(Beatmap.id == bid)
)
).first()
if not beatmap: if not beatmap:
raise HTTPException(status_code=404, detail="Beatmap not found") raise HTTPException(status_code=404, detail="Beatmap not found")
return BeatmapResp.from_db(beatmap) return BeatmapResp.from_db(beatmap)
@@ -36,16 +45,30 @@ class BatchGetResp(BaseModel):
async def batch_get_beatmaps( async def batch_get_beatmaps(
b_ids: list[int] = Query(alias="id", default_factory=list), b_ids: list[int] = Query(alias="id", default_factory=list),
current_user: DBUser = Depends(get_current_user), current_user: DBUser = Depends(get_current_user),
db: Session = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
if not b_ids: if not b_ids:
# select 50 beatmaps by last_updated # select 50 beatmaps by last_updated
beatmaps = db.exec( beatmaps = (
select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50) await db.exec(
select(Beatmap)
.options(
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
)
.order_by(col(Beatmap.last_updated).desc())
.limit(50)
)
).all() ).all()
else: else:
beatmaps = db.exec( beatmaps = (
select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50) await db.exec(
select(Beatmap)
.options(
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
)
.where(col(Beatmap.id).in_(b_ids))
.limit(50)
)
).all() ).all()
return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps]) return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps])

View File

@@ -11,16 +11,24 @@ from app.dependencies.user import get_current_user
from .api_router import router from .api_router import router
from fastapi import Depends, HTTPException from fastapi import Depends, HTTPException
from sqlmodel import Session, select from sqlalchemy.orm import selectinload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp) @router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp)
async def get_beatmapset( async def get_beatmapset(
sid: int, sid: int,
current_user: DBUser = Depends(get_current_user), current_user: DBUser = Depends(get_current_user),
db: Session = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
beatmapset = db.exec(select(Beatmapset).where(Beatmapset.id == sid)).first() beatmapset = (
await db.exec(
select(Beatmapset)
.options(selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
.where(Beatmapset.id == sid)
)
).first()
if not beatmapset: if not beatmapset:
raise HTTPException(status_code=404, detail="Beatmapset not found") raise HTTPException(status_code=404, detail="Beatmapset not found")
return BeatmapsetResp.from_db(beatmapset) return BeatmapsetResp.from_db(beatmapset)

View File

@@ -5,7 +5,7 @@ from typing import Literal
from app.database import ( from app.database import (
User as DBUser, User as DBUser,
) )
from app.dependencies import get_current_user, get_db from app.dependencies import get_current_user
from app.models.user import ( from app.models.user import (
User as ApiUser, User as ApiUser,
) )
@@ -14,7 +14,6 @@ from app.utils import convert_db_user_to_api_user
from .api_router import router from .api_router import router
from fastapi import Depends from fastapi import Depends
from sqlalchemy.orm import Session
@router.get("/me/{ruleset}", response_model=ApiUser) @router.get("/me/{ruleset}", response_model=ApiUser)
@@ -22,9 +21,8 @@ from sqlalchemy.orm import Session
async def get_user_info_default( async def get_user_info_default(
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu", ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
current_user: DBUser = Depends(get_current_user), current_user: DBUser = Depends(get_current_user),
db: Session = Depends(get_db),
): ):
"""获取当前用户信息默认使用osu模式""" """获取当前用户信息默认使用osu模式"""
# 默认使用osu模式 # 默认使用osu模式
api_user = convert_db_user_to_api_user(current_user, ruleset, db) api_user = await convert_db_user_to_api_user(current_user, ruleset)
return api_user return api_user

View File

@@ -1 +1,3 @@
from __future__ import annotations
from .router import router as signalr_router from .router import router as signalr_router

View File

@@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
import json import json
from logging import info
import time import time
from typing import Literal from typing import Literal
import uuid import uuid
@@ -9,15 +8,14 @@ import uuid
from app.database import User as DBUser from app.database import User as DBUser
from app.dependencies import get_current_user from app.dependencies import get_current_user
from app.dependencies.database import get_db from app.dependencies.database import get_db
from app.dependencies.user import get_current_user_by_token, security from app.dependencies.user import get_current_user_by_token
from app.models.signalr import NegotiateResponse, Transport from app.models.signalr import NegotiateResponse, Transport
from app.router.signalr.packet import SEP from app.router.signalr.packet import SEP
from .hub import Hubs from .hub import Hubs
from fastapi import APIRouter, Depends, Header, Query, WebSocket from fastapi import APIRouter, Depends, Header, Query, WebSocket
from fastapi.security import HTTPAuthorizationCredentials from sqlmodel.ext.asyncio.session import AsyncSession
from sqlalchemy.orm import Session
router = APIRouter() router = APIRouter()
@@ -48,7 +46,7 @@ async def connect(
websocket: WebSocket, websocket: WebSocket,
id: str, id: str,
authorization: str = Header(...), authorization: str = Header(...),
db: Session = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
token = authorization[7:] token = authorization[7:]
user_id = id.split(":")[0] user_id = id.split(":")[0]

View File

@@ -1,5 +1,8 @@
from __future__ import annotations
from collections.abc import Callable
import inspect import inspect
from typing import Any, Callable, ForwardRef, cast from typing import Any, ForwardRef, cast
# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L56-L66 # https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L56-L66

View File

@@ -23,11 +23,9 @@ from app.models.user import (
UserAchievement, UserAchievement,
) )
from sqlalchemy.orm import Session
async def convert_db_user_to_api_user(
def convert_db_user_to_api_user( db_user: DBUser, ruleset: str = "osu"
db_user: DBUser, ruleset: str = "osu", db_session: Session | None = None
) -> User: ) -> User:
"""将数据库用户模型转换为API用户模型使用 Lazer 表)""" """将数据库用户模型转换为API用户模型使用 Lazer 表)"""

View File

@@ -5,419 +5,84 @@ osu! API 模拟服务器的示例数据填充脚本
from __future__ import annotations from __future__ import annotations
import asyncio
from datetime import datetime from datetime import datetime
import time import time
from app.auth import get_password_hash from app.auth import get_password_hash
from app.database import ( from app.database import (
DailyChallengeStats,
LazerUserAchievement,
LazerUserStatistics,
RankHistory,
User, User,
) )
from app.dependencies.database import engine, get_db from app.dependencies.database import create_tables, engine
from sqlmodel import SQLModel from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
# 创建所有表
SQLModel.metadata.create_all(bind=engine)
def create_sample_user(): async def create_sample_user():
"""创建示例用户数据""" """创建示例用户数据"""
with next(get_db()) as db: async with AsyncSession(engine) as session:
# 检查用户是否已存在 async with session.begin():
from sqlmodel import select
statement = select(User).where(User.name == "Googujiang") # 检查用户是否已存在
existing_user = db.exec(statement).first() statement = select(User).where(User.name == "Googujiang")
if existing_user: result = await session.execute(statement)
print("示例用户已存在,跳过创建") existing_user = result.scalars().first()
return existing_user if existing_user:
print("示例用户已存在,跳过创建")
return existing_user
# 当前时间戳 # 当前时间戳
current_timestamp = int(time.time()) current_timestamp = int(time.time())
join_timestamp = int(datetime(2019, 11, 29, 17, 23, 13).timestamp()) join_timestamp = int(datetime(2019, 11, 29, 17, 23, 13).timestamp())
last_visit_timestamp = int(datetime(2025, 7, 18, 16, 31, 29).timestamp()) last_visit_timestamp = int(datetime(2025, 7, 18, 16, 31, 29).timestamp())
# 创建用户 # 创建用户
user = User( user = User(
name="Googujiang", name="Googujiang",
safe_name="googujiang", # 安全用户名(小写) safe_name="googujiang", # 安全用户名(小写)
email="googujiang@example.com", email="googujiang@example.com",
priv=1, # 默认权限 priv=1, # 默认权限
pw_bcrypt=get_password_hash("password123"), # 使用新的哈希方式 pw_bcrypt=get_password_hash("password123"), # 使用新的哈希方式
country="JP", country="JP",
silence_end=0, silence_end=0,
donor_end=0, donor_end=0,
creation_time=join_timestamp, creation_time=join_timestamp,
latest_activity=last_visit_timestamp, latest_activity=last_visit_timestamp,
clan_id=0, clan_id=0,
clan_priv=0, clan_priv=0,
preferred_mode=0, # 0 = osu! preferred_mode=0, # 0 = osu!
play_style=0, play_style=0,
custom_badge_name=None, custom_badge_name=None,
custom_badge_icon=None, custom_badge_icon=None,
userpage_content="「世界に忘れられた」", userpage_content="「世界に忘れられた」",
api_key=None, api_key=None,
# # 兼容性字段 )
# avatar_url="https://a.ppy.sh/15651670?1732362658.jpeg",
# cover_url="https://assets.ppy.sh/user-profile-covers/15651670/0fc7b77adef39765a570e7f535bc383e5a848850d41a8943f8857984330b8bc6.jpeg",
# has_supported=True,
# interests="「世界に忘れられた」",
# location="咕谷国",
# website="https://gmoe.cc",
# playstyle=["mouse", "keyboard", "tablet"],
# profile_order=[
# "me",
# "recent_activity",
# "top_ranks",
# "medals",
# "historical",
# "beatmaps",
# "kudosu",
# ],
# beatmap_playcounts_count=3306,
# favourite_beatmapset_count=15,
# follower_count=98,
# graveyard_beatmapset_count=7,
# mapping_follower_count=1,
# previous_usernames=["hehejun"],
# monthly_playcounts=[
# {"start_date": "2019-11-01", "count": 43},
# {"start_date": "2020-04-01", "count": 216},
# {"start_date": "2020-05-01", "count": 656},
# {"start_date": "2020-07-01", "count": 158},
# {"start_date": "2020-08-01", "count": 174},
# {"start_date": "2020-10-01", "count": 13},
# {"start_date": "2020-11-01", "count": 52},
# {"start_date": "2020-12-01", "count": 140},
# {"start_date": "2021-01-01", "count": 359},
# {"start_date": "2021-02-01", "count": 452},
# {"start_date": "2021-03-01", "count": 77},
# {"start_date": "2021-04-01", "count": 114},
# {"start_date": "2021-05-01", "count": 270},
# {"start_date": "2021-06-01", "count": 148},
# {"start_date": "2021-07-01", "count": 246},
# {"start_date": "2021-08-01", "count": 56},
# {"start_date": "2021-09-01", "count": 136},
# {"start_date": "2021-10-01", "count": 45},
# {"start_date": "2021-11-01", "count": 98},
# {"start_date": "2021-12-01", "count": 54},
# {"start_date": "2022-01-01", "count": 88},
# {"start_date": "2022-02-01", "count": 45},
# {"start_date": "2022-03-01", "count": 6},
# {"start_date": "2022-04-01", "count": 54},
# {"start_date": "2022-05-01", "count": 105},
# {"start_date": "2022-06-01", "count": 37},
# {"start_date": "2022-07-01", "count": 88},
# {"start_date": "2022-08-01", "count": 7},
# {"start_date": "2022-09-01", "count": 9},
# {"start_date": "2022-10-01", "count": 6},
# {"start_date": "2022-11-01", "count": 2},
# {"start_date": "2022-12-01", "count": 16},
# {"start_date": "2023-01-01", "count": 7},
# {"start_date": "2023-04-01", "count": 16},
# {"start_date": "2023-05-01", "count": 3},
# {"start_date": "2023-06-01", "count": 8},
# {"start_date": "2023-07-01", "count": 23},
# {"start_date": "2023-08-01", "count": 3},
# {"start_date": "2023-09-01", "count": 1},
# {"start_date": "2023-10-01", "count": 25},
# {"start_date": "2023-11-01", "count": 160},
# {"start_date": "2023-12-01", "count": 306},
# {"start_date": "2024-01-01", "count": 735},
# {"start_date": "2024-02-01", "count": 420},
# {"start_date": "2024-03-01", "count": 549},
# {"start_date": "2024-04-01", "count": 466},
# {"start_date": "2024-05-01", "count": 333},
# {"start_date": "2024-06-01", "count": 1126},
# {"start_date": "2024-07-01", "count": 534},
# {"start_date": "2024-08-01", "count": 280},
# {"start_date": "2024-09-01", "count": 116},
# {"start_date": "2024-10-01", "count": 120},
# {"start_date": "2024-11-01", "count": 332},
# {"start_date": "2024-12-01", "count": 243},
# {"start_date": "2025-01-01", "count": 122},
# {"start_date": "2025-02-01", "count": 379},
# {"start_date": "2025-03-01", "count": 278},
# {"start_date": "2025-04-01", "count": 296},
# {"start_date": "2025-05-01", "count": 964},
# {"start_date": "2025-06-01", "count": 821},
# {"start_date": "2025-07-01", "count": 230},
# ],
)
db.add(user) session.add(user)
db.commit() await session.commit()
db.refresh(user) await session.refresh(user)
# 确保用户ID存在 # 确保用户ID存在
if user.id is None: if user.id is None:
raise ValueError("User ID is None after saving to database") raise ValueError("User ID is None after saving to database")
# 创建 osu! 模式统计 print(f"成功创建示例用户: {user.name} (ID: {user.id})")
osu_stats = LazerUserStatistics( print(f"安全用户名: {user.safe_name}")
user_id=user.id, print(f"邮箱: {user.email}")
mode="osu", print(f"国家: {user.country}")
count_100=276274, return user
count_300=1932068,
count_50=32776,
count_miss=111064,
level_current=97,
level_progress=96,
global_rank=298026,
country_rank=11221,
pp=2826.26,
ranked_score=4415081049,
hit_accuracy=95.7168,
play_count=12711,
play_time=836529,
total_score=12390140370,
total_hits=2241118,
maximum_combo=1859,
replays_watched_by_others=0,
is_ranked=True,
grade_ss=14,
grade_ssh=3,
grade_s=322,
grade_sh=11,
grade_a=757,
rank_highest=295701,
rank_highest_updated_at=datetime(2025, 7, 2, 17, 30, 21),
)
# 创建 taiko 模式统计
taiko_stats = LazerUserStatistics(
user_id=user.id,
mode="taiko",
count_100=160,
count_300=154,
count_50=0,
count_miss=480,
level_current=2,
level_progress=49,
global_rank=None,
pp=0,
ranked_score=0,
hit_accuracy=0,
play_count=6,
play_time=217,
total_score=79301,
total_hits=314,
maximum_combo=0,
replays_watched_by_others=0,
is_ranked=False,
)
# 创建 fruits 模式统计
fruits_stats = LazerUserStatistics(
user_id=user.id,
mode="fruits",
count_100=109,
count_300=1613,
count_50=1861,
count_miss=328,
level_current=6,
level_progress=14,
global_rank=None,
pp=0,
ranked_score=343854,
hit_accuracy=89.4779,
play_count=19,
play_time=669,
total_score=1362651,
total_hits=3583,
maximum_combo=75,
replays_watched_by_others=0,
is_ranked=False,
grade_a=1,
)
# 创建 mania 模式统计
mania_stats = LazerUserStatistics(
user_id=user.id,
mode="mania",
count_100=7867,
count_300=12104,
count_50=991,
count_miss=2951,
level_current=12,
level_progress=89,
global_rank=660670,
pp=25.3784,
ranked_score=3812295,
hit_accuracy=77.9316,
play_count=85,
play_time=4834,
total_score=13454470,
total_hits=20962,
maximum_combo=573,
replays_watched_by_others=0,
is_ranked=True,
grade_a=1,
)
db.add_all([osu_stats, taiko_stats, fruits_stats, mania_stats])
# 创建每日挑战统计
daily_challenge = DailyChallengeStats(
user_id=user.id,
daily_streak_best=1,
daily_streak_current=0,
last_update=datetime(2025, 6, 21, 0, 0, 0),
last_weekly_streak=datetime(2025, 6, 19, 0, 0, 0),
playcount=1,
top_10p_placements=0,
top_50p_placements=0,
weekly_streak_best=1,
weekly_streak_current=0,
)
db.add(daily_challenge)
# 创建排名历史 (最近90天的数据)
rank_data = [
322806,
323092,
323341,
323616,
323853,
324106,
324378,
324676,
324958,
325254,
325492,
325780,
326075,
326356,
326586,
326845,
327067,
327286,
327526,
327778,
328039,
328347,
328631,
328858,
329323,
329557,
329809,
329911,
330188,
330425,
330650,
330881,
331068,
331325,
331575,
331816,
332061,
328959,
315648,
315881,
308784,
309023,
309252,
309433,
309537,
309364,
309548,
308957,
309182,
309426,
309607,
309831,
310054,
310269,
310485,
310714,
310956,
310924,
311125,
311203,
311422,
311640,
303091,
303309,
303500,
303691,
303758,
303750,
303957,
299867,
300088,
300273,
300457,
295799,
295976,
296153,
296350,
296566,
296756,
296933,
297141,
297314,
297480,
297114,
297296,
297480,
297645,
297815,
297993,
298026,
]
rank_history = RankHistory(user_id=user.id, mode="osu", rank_data=rank_data)
db.add(rank_history)
# 创建一些成就
achievements = [
LazerUserAchievement(
user_id=user.id,
achievement_id=336,
achieved_at=datetime(2025, 6, 21, 19, 6, 32),
),
LazerUserAchievement(
user_id=user.id,
achievement_id=319,
achieved_at=datetime(2025, 6, 1, 0, 52, 0),
),
LazerUserAchievement(
user_id=user.id,
achievement_id=222,
achieved_at=datetime(2025, 5, 28, 12, 24, 37),
),
LazerUserAchievement(
user_id=user.id,
achievement_id=38,
achieved_at=datetime(2024, 7, 5, 15, 43, 23),
),
LazerUserAchievement(
user_id=user.id,
achievement_id=67,
achieved_at=datetime(2024, 6, 24, 5, 6, 44),
),
]
db.add_all(achievements)
db.commit()
print(f"成功创建示例用户: {user.name} (ID: {user.id})")
print(f"安全用户名: {user.safe_name}")
print(f"邮箱: {user.email}")
print(f"国家: {user.country}")
return user
if __name__ == "__main__": async def main():
print("开始创建示例数据...") print("开始创建示例数据...")
user = create_sample_user() await create_tables()
user = await create_sample_user()
print("示例数据创建完成!") print("示例数据创建完成!")
print(f"用户名: {user.name}") print(f"用户名: {user.name}")
print("密码: password123") print("密码: password123")
print("现在您可以使用这些凭据来测试API了。") print("现在您可以使用这些凭据来测试API了。")
if __name__ == "__main__":
asyncio.run(main())

17
main.py
View File

@@ -1,24 +1,31 @@
from __future__ import annotations from __future__ import annotations
from contextlib import asynccontextmanager
from datetime import datetime from datetime import datetime
from app.config import settings from app.config import settings
from app.dependencies.database import engine from app.dependencies.database import create_tables
from app.router import api_router, auth_router, signalr_router from app.router import api_router, auth_router, signalr_router
from fastapi import FastAPI from fastapi import FastAPI
from sqlmodel import SQLModel
# 注意: 表结构现在通过 migrations 管理,不再自动创建 # 注意: 表结构现在通过 migrations 管理,不再自动创建
# 如需创建表,请运行: python quick_sync.py # 如需创建表,请运行: python quick_sync.py
app = FastAPI(title="osu! API 模拟服务器", version="1.0.0")
@asynccontextmanager
async def lifespan(app: FastAPI):
# on startup
await create_tables()
# on shutdown
yield
app = FastAPI(title="osu! API 模拟服务器", version="1.0.0", lifespan=lifespan)
app.include_router(api_router, prefix="/api/v2") app.include_router(api_router, prefix="/api/v2")
app.include_router(signalr_router, prefix="/signalr") app.include_router(signalr_router, prefix="/signalr")
app.include_router(auth_router) app.include_router(auth_router)
SQLModel.metadata.create_all(bind=engine)
@app.get("/") @app.get("/")
async def root(): async def root():

View File

@@ -5,6 +5,7 @@ description = "Add your description here"
readme = "README.md" readme = "README.md"
requires-python = ">=3.11" requires-python = ">=3.11"
dependencies = [ dependencies = [
"aiomysql>=0.2.0",
"alembic>=1.12.1", "alembic>=1.12.1",
"bcrypt>=4.1.2", "bcrypt>=4.1.2",
"cryptography>=41.0.7", "cryptography>=41.0.7",
@@ -12,7 +13,6 @@ dependencies = [
"msgpack>=1.1.1", "msgpack>=1.1.1",
"passlib[bcrypt]>=1.7.4", "passlib[bcrypt]>=1.7.4",
"pydantic[email]>=2.5.0", "pydantic[email]>=2.5.0",
"pymysql>=1.1.0",
"python-dotenv>=1.0.0", "python-dotenv>=1.0.0",
"python-jose[cryptography]>=3.3.0", "python-jose[cryptography]>=3.3.0",
"python-multipart>=0.0.6", "python-multipart>=0.0.6",

View File

@@ -2,7 +2,6 @@ fastapi==0.104.1
uvicorn[standard]==0.24.0 uvicorn[standard]==0.24.0
sqlalchemy==2.0.23 sqlalchemy==2.0.23
alembic==1.12.1 alembic==1.12.1
pymysql==1.1.0
cryptography==41.0.7 cryptography==41.0.7
redis==5.0.1 redis==5.0.1
python-jose[cryptography]==3.3.0 python-jose[cryptography]==3.3.0
@@ -11,3 +10,4 @@ python-multipart==0.0.6
pydantic[email]==2.5.0 pydantic[email]==2.5.0
python-dotenv==1.0.0 python-dotenv==1.0.0
bcrypt==4.1.2 bcrypt==4.1.2
aiomysql==0.2.0

View File

@@ -12,108 +12,106 @@ import sys
sys.path.append(os.path.dirname(os.path.dirname(__file__))) sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from app.database import User from app.database import User
from app.dependencies.database import get_db from app.dependencies.database import engine
from app.utils import convert_db_user_to_api_user from app.utils import convert_db_user_to_api_user
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
def test_lazer_tables(): async def test_lazer_tables():
"""测试 lazer 表的基本功能""" """测试 lazer 表的基本功能"""
print("测试 Lazer API 表支持...") print("测试 Lazer API 表支持...")
# 获取数据库会话 async with AsyncSession(engine) as session:
db_gen = get_db() async with session.begin():
db = next(db_gen) try:
# 测试查询用户
statement = select(User)
result = await session.execute(statement)
user = result.scalars().first()
if not user:
print("❌ 没有找到用户,请先同步数据")
return False
try: print(f"✓ 找到用户: {user.name} (ID: {user.id})")
# 测试查询用户
statement = select(User)
user = db.exec(statement).first()
if not user:
print("❌ 没有找到用户,请先同步数据")
return False
print(f"✓ 找到用户: {user.name} (ID: {user.id})") # 测试 lazer 资料
if user.lazer_profile:
print(
f"✓ 用户有 lazer 资料: 支持者={user.lazer_profile.is_supporter}"
)
else:
print("⚠ 用户没有 lazer 资料,将使用默认值")
# 测试 lazer 资料 # 测试 lazer 统计
if user.lazer_profile: osu_stats = None
print(f"✓ 用户有 lazer 资料: 支持者={user.lazer_profile.is_supporter}") for stat in user.lazer_statistics:
else: if stat.mode == "osu":
print("⚠ 用户没有 lazer 资料,将使用默认值") osu_stats = stat
break
# 测试 lazer 统计 if osu_stats:
osu_stats = None print(
for stat in user.lazer_statistics: f"✓ 用户有 osu! 统计: PP={osu_stats.pp}, "
if stat.mode == "osu": f"游戏次数={osu_stats.play_count}"
osu_stats = stat )
break else:
print("⚠ 用户没有 osu! 统计,将使用默认值")
if osu_stats: # 测试转换为 API 格式
print( api_user = convert_db_user_to_api_user(user, "osu")
f"✓ 用户有 osu! 统计: PP={osu_stats.pp}, " print("✓ 成功转换为 API 用户格式")
f"游戏次数={osu_stats.play_count}" print(f" - 用户名: {api_user.username}")
) print(f" - 国家: {api_user.country_code}")
else: print(f" - PP: {api_user.statistics.pp}")
print("⚠ 用户没有 osu! 统计,将使用默认值") print(f" - 是否支持者: {api_user.is_supporter}")
# 测试转换为 API 格式 return True
api_user = convert_db_user_to_api_user(user, "osu", db)
print("✓ 成功转换为 API 用户格式")
print(f" - 用户名: {api_user.username}")
print(f" - 国家: {api_user.country_code}")
print(f" - PP: {api_user.statistics.pp}")
print(f" - 是否支持者: {api_user.is_supporter}")
return True except Exception as e:
print(f"❌ 测试失败: {e}")
import traceback
except Exception as e: traceback.print_exc()
print(f"❌ 测试失败: {e}") return False
import traceback
traceback.print_exc()
return False
finally:
db.close()
def test_authentication(): async def test_authentication():
"""测试认证功能""" """测试认证功能"""
print("\n测试认证功能...") print("\n测试认证功能...")
db_gen = get_db() async with AsyncSession(engine) as session:
db = next(db_gen) async with session.begin():
try:
# 尝试认证第一个用户
statement = select(User)
result = await session.execute(statement)
user = result.scalars().first()
if not user:
print("❌ 没有用户进行认证测试")
return False
try: print(f"✓ 测试用户: {user.name}")
# 尝试认证第一个用户 print("⚠ 注意: 实际密码认证需要正确的密码")
statement = select(User)
user = db.exec(statement).first()
if not user:
print("❌ 没有用户进行认证测试")
return False
print(f"✓ 测试用户: {user.name}") return True
print("⚠ 注意: 实际密码认证需要正确的密码")
return True except Exception as e:
print(f"❌ 认证测试失败: {e}")
except Exception as e: return False
print(f"❌ 认证测试失败: {e}")
return False
finally:
db.close()
def main(): async def main():
"""主测试函数""" """主测试函数"""
print("Lazer API 系统测试") print("Lazer API 系统测试")
print("=" * 40) print("=" * 40)
# 测试表连接 # 测试表连接
success1 = test_lazer_tables() success1 = await test_lazer_tables()
# 测试认证 # 测试认证
success2 = test_authentication() success2 = await test_authentication()
print("\n" + "=" * 40) print("\n" + "=" * 40)
if success1 and success2: if success1 and success2:
@@ -130,4 +128,6 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() import asyncio
asyncio.run(main())

16
uv.lock generated
View File

@@ -2,6 +2,18 @@ version = 1
revision = 2 revision = 2
requires-python = ">=3.11" requires-python = ">=3.11"
[[package]]
name = "aiomysql"
version = "0.2.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pymysql" },
]
sdist = { url = "https://files.pythonhosted.org/packages/67/76/2c5b55e4406a1957ffdfd933a94c2517455291c97d2b81cec6813754791a/aiomysql-0.2.0.tar.gz", hash = "sha256:558b9c26d580d08b8c5fd1be23c5231ce3aeff2dadad989540fee740253deb67", size = 114706, upload-time = "2023-06-11T19:57:53.608Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/42/87/c982ee8b333c85b8ae16306387d703a1fcdfc81a2f3f15a24820ab1a512d/aiomysql-0.2.0-py3-none-any.whl", hash = "sha256:b7c26da0daf23a5ec5e0b133c03d20657276e4eae9b73e040b72787f6f6ade0a", size = 44215, upload-time = "2023-06-11T19:57:51.09Z" },
]
[[package]] [[package]]
name = "alembic" name = "alembic"
version = "1.16.4" version = "1.16.4"
@@ -462,6 +474,7 @@ name = "osu-lazer-api"
version = "0.1.0" version = "0.1.0"
source = { virtual = "." } source = { virtual = "." }
dependencies = [ dependencies = [
{ name = "aiomysql" },
{ name = "alembic" }, { name = "alembic" },
{ name = "bcrypt" }, { name = "bcrypt" },
{ name = "cryptography" }, { name = "cryptography" },
@@ -469,7 +482,6 @@ dependencies = [
{ name = "msgpack" }, { name = "msgpack" },
{ name = "passlib", extra = ["bcrypt"] }, { name = "passlib", extra = ["bcrypt"] },
{ name = "pydantic", extra = ["email"] }, { name = "pydantic", extra = ["email"] },
{ name = "pymysql" },
{ name = "python-dotenv" }, { name = "python-dotenv" },
{ name = "python-jose", extra = ["cryptography"] }, { name = "python-jose", extra = ["cryptography"] },
{ name = "python-multipart" }, { name = "python-multipart" },
@@ -487,6 +499,7 @@ dev = [
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "aiomysql", specifier = ">=0.2.0" },
{ name = "alembic", specifier = ">=1.12.1" }, { name = "alembic", specifier = ">=1.12.1" },
{ name = "bcrypt", specifier = ">=4.1.2" }, { name = "bcrypt", specifier = ">=4.1.2" },
{ name = "cryptography", specifier = ">=41.0.7" }, { name = "cryptography", specifier = ">=41.0.7" },
@@ -494,7 +507,6 @@ requires-dist = [
{ name = "msgpack", specifier = ">=1.1.1" }, { name = "msgpack", specifier = ">=1.1.1" },
{ name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" }, { name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" },
{ name = "pydantic", extras = ["email"], specifier = ">=2.5.0" }, { name = "pydantic", extras = ["email"], specifier = ">=2.5.0" },
{ name = "pymysql", specifier = ">=1.1.0" },
{ name = "python-dotenv", specifier = ">=1.0.0" }, { name = "python-dotenv", specifier = ">=1.0.0" },
{ name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" }, { name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" },
{ name = "python-multipart", specifier = ">=0.0.6" }, { name = "python-multipart", specifier = ">=0.0.6" },