refactor(database): use asyncio

This commit is contained in:
MingxuanGame
2025-07-25 20:43:50 +08:00
parent 2e1489c6d4
commit f347b680b2
21 changed files with 296 additions and 536 deletions

View File

@@ -14,7 +14,8 @@ from app.database import (
import bcrypt
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlmodel import Session, select
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
# 密码哈希上下文
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
@@ -70,7 +71,9 @@ def get_password_hash(password: str) -> str:
return pw_bcrypt.decode()
def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser | None:
async def authenticate_user_legacy(
db: AsyncSession, name: str, password: str
) -> DBUser | None:
"""
验证用户身份 - 使用类似 from_login 的逻辑
"""
@@ -79,7 +82,7 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser |
# 2. 根据用户名查找用户
statement = select(DBUser).where(DBUser.name == name)
user = db.exec(statement).first()
user = (await db.exec(statement)).first()
if not user:
return None
@@ -107,9 +110,11 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser |
return None
def authenticate_user(db: Session, username: str, password: str) -> DBUser | None:
async def authenticate_user(
db: AsyncSession, username: str, password: str
) -> DBUser | None:
"""验证用户身份"""
return authenticate_user_legacy(db, username, password)
return await authenticate_user_legacy(db, username, password)
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
@@ -147,24 +152,28 @@ def verify_token(token: str) -> dict | None:
return None
def store_token(
db: Session, user_id: int, access_token: str, refresh_token: str, expires_in: int
async def store_token(
db: AsyncSession,
user_id: int,
access_token: str,
refresh_token: str,
expires_in: int,
) -> OAuthToken:
"""存储令牌到数据库"""
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
# 删除用户的旧令牌
statement = select(OAuthToken).where(OAuthToken.user_id == user_id)
old_tokens = db.exec(statement).all()
old_tokens = (await db.exec(statement)).all()
for token in old_tokens:
db.delete(token)
await db.delete(token)
# 检查是否有重复的 access_token
duplicate_token = db.exec(
select(OAuthToken).where(OAuthToken.access_token == access_token)
duplicate_token = (
await db.exec(select(OAuthToken).where(OAuthToken.access_token == access_token))
).first()
if duplicate_token:
db.delete(duplicate_token)
await db.delete(duplicate_token)
# 创建新令牌记录
token_record = OAuthToken(
@@ -174,24 +183,28 @@ def store_token(
expires_at=expires_at,
)
db.add(token_record)
db.commit()
db.refresh(token_record)
await db.commit()
await db.refresh(token_record)
return token_record
def get_token_by_access_token(db: Session, access_token: str) -> OAuthToken | None:
async def get_token_by_access_token(
db: AsyncSession, access_token: str
) -> OAuthToken | None:
"""根据访问令牌获取令牌记录"""
statement = select(OAuthToken).where(
OAuthToken.access_token == access_token,
OAuthToken.expires_at > datetime.utcnow(),
)
return db.exec(statement).first()
return (await db.exec(statement)).first()
def get_token_by_refresh_token(db: Session, refresh_token: str) -> OAuthToken | None:
async def get_token_by_refresh_token(
db: AsyncSession, refresh_token: str
) -> OAuthToken | None:
"""根据刷新令牌获取令牌记录"""
statement = select(OAuthToken).where(
OAuthToken.refresh_token == refresh_token,
OAuthToken.expires_at > datetime.utcnow(),
)
return db.exec(statement).first()
return (await db.exec(statement)).first()

View File

@@ -10,7 +10,7 @@ load_dotenv()
class Settings:
# 数据库设置
DATABASE_URL: str = os.getenv(
"DATABASE_URL", "mysql+pymysql://root:password@localhost:3306/osu_api"
"DATABASE_URL", "mysql+aiomysql://root:password@localhost:3306/osu_api"
)
REDIS_URL: str = os.getenv("REDIS_URL", "redis://localhost:6379/0")

View File

@@ -15,7 +15,7 @@ class User(SQLModel, table=True):
__tablename__ = "users" # pyright: ignore[reportAssignmentType]
# 主键
id: int | None = Field(default=None, primary_key=True, index=True)
id: int = Field(default=None, primary_key=True, index=True, nullable=False)
# 基本信息(匹配 migrations 中的结构)
name: str = Field(max_length=32, unique=True, index=True) # 用户名

View File

@@ -1,2 +1,4 @@
from __future__ import annotations
from .database import get_db as get_db
from .user import get_current_user as get_current_user
from .user import get_current_user as get_current_user

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
from sqlmodel import Session, create_engine
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
try:
import redis
@@ -9,7 +11,7 @@ except ImportError:
from app.config import settings
# 数据库引擎
engine = create_engine(settings.DATABASE_URL)
engine = create_async_engine(settings.DATABASE_URL)
# Redis 连接
if redis:
@@ -19,11 +21,16 @@ else:
# 数据库依赖
def get_db():
with Session(engine) as session:
async def get_db():
async with AsyncSession(engine) as session:
yield session
async def create_tables():
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
# Redis 依赖
def get_redis():
return redis_client

View File

@@ -9,14 +9,16 @@ from .database import get_db
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlmodel import Session, select
from sqlalchemy.orm import joinedload, selectinload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
security = HTTPBearer()
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security),
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_db),
) -> DBUser:
"""获取当前认证用户"""
token = credentials.credentials
@@ -27,9 +29,31 @@ async def get_current_user(
return user
async def get_current_user_by_token(token: str, db: Session) -> DBUser | None:
token_record = get_token_by_access_token(db, token)
async def get_current_user_by_token(token: str, db: AsyncSession) -> DBUser | None:
token_record = await get_token_by_access_token(db, token)
if not token_record:
return None
user = db.exec(select(DBUser).where(DBUser.id == token_record.user_id)).first()
user = (
await db.exec(
select(DBUser)
.options(
joinedload(DBUser.lazer_profile), # pyright: ignore[reportArgumentType]
joinedload(DBUser.lazer_counts), # pyright: ignore[reportArgumentType]
joinedload(DBUser.daily_challenge_stats), # pyright: ignore[reportArgumentType]
joinedload(DBUser.avatar), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_statistics), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_achievements), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_profile_sections), # pyright: ignore[reportArgumentType]
selectinload(DBUser.statistics), # pyright: ignore[reportArgumentType]
selectinload(DBUser.team_membership), # pyright: ignore[reportArgumentType]
selectinload(DBUser.rank_history), # pyright: ignore[reportArgumentType]
selectinload(DBUser.active_banners), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_badges), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_monthly_playcounts), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_previous_usernames), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_replays_watched), # pyright: ignore[reportArgumentType]
)
.where(DBUser.id == token_record.user_id)
)
).first()
return user

View File

@@ -1,4 +1,5 @@
from __future__ import annotations
from fastapi import APIRouter
router = APIRouter()

View File

@@ -14,7 +14,7 @@ from app.dependencies import get_db
from app.models.oauth import TokenResponse
from fastapi import APIRouter, Depends, Form, HTTPException
from sqlmodel import Session
from sqlmodel.ext.asyncio.session import AsyncSession
router = APIRouter(tags=["osu! OAuth 认证"])
@@ -28,7 +28,7 @@ async def oauth_token(
username: str | None = Form(None),
password: str | None = Form(None),
refresh_token: str | None = Form(None),
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_db),
):
"""OAuth 令牌端点"""
# 验证客户端凭据
@@ -46,7 +46,7 @@ async def oauth_token(
)
# 验证用户
user = authenticate_user(db, username, password)
user = await authenticate_user(db, username, password)
if not user:
raise HTTPException(status_code=401, detail="Invalid username or password")
@@ -58,9 +58,9 @@ async def oauth_token(
refresh_token_str = generate_refresh_token()
# 存储令牌
store_token(
await store_token(
db,
getattr(user, "id"),
user.id,
access_token,
refresh_token_str,
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
@@ -80,7 +80,7 @@ async def oauth_token(
raise HTTPException(status_code=400, detail="Refresh token required")
# 验证刷新令牌
token_record = get_token_by_refresh_token(db, refresh_token)
token_record =await get_token_by_refresh_token(db, refresh_token)
if not token_record:
raise HTTPException(status_code=401, detail="Invalid refresh token")
@@ -92,10 +92,9 @@ async def oauth_token(
new_refresh_token = generate_refresh_token()
# 更新令牌
user_id = int(getattr(token_record, "user_id"))
store_token(
await store_token(
db,
user_id,
token_record.user_id,
access_token,
new_refresh_token,
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,

View File

@@ -5,6 +5,7 @@ from app.database import (
BeatmapResp,
User as DBUser,
)
from app.database.beatmapset import Beatmapset
from app.dependencies.database import get_db
from app.dependencies.user import get_current_user
@@ -12,16 +13,24 @@ from .api_router import router
from fastapi import Depends, HTTPException, Query
from pydantic import BaseModel
from sqlmodel import Session, col, select
from sqlalchemy.orm import joinedload
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
async def get_beatmap(
bid: int,
current_user: DBUser = Depends(get_current_user),
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_db),
):
beatmap = db.exec(select(Beatmap).where(Beatmap.id == bid)).first()
beatmap = (
await db.exec(
select(Beatmap)
.options(joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
.where(Beatmap.id == bid)
)
).first()
if not beatmap:
raise HTTPException(status_code=404, detail="Beatmap not found")
return BeatmapResp.from_db(beatmap)
@@ -36,16 +45,30 @@ class BatchGetResp(BaseModel):
async def batch_get_beatmaps(
b_ids: list[int] = Query(alias="id", default_factory=list),
current_user: DBUser = Depends(get_current_user),
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_db),
):
if not b_ids:
# select 50 beatmaps by last_updated
beatmaps = db.exec(
select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
beatmaps = (
await db.exec(
select(Beatmap)
.options(
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
)
.order_by(col(Beatmap.last_updated).desc())
.limit(50)
)
).all()
else:
beatmaps = db.exec(
select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50)
beatmaps = (
await db.exec(
select(Beatmap)
.options(
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
)
.where(col(Beatmap.id).in_(b_ids))
.limit(50)
)
).all()
return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps])

View File

@@ -11,16 +11,24 @@ from app.dependencies.user import get_current_user
from .api_router import router
from fastapi import Depends, HTTPException
from sqlmodel import Session, select
from sqlalchemy.orm import selectinload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp)
async def get_beatmapset(
sid: int,
current_user: DBUser = Depends(get_current_user),
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_db),
):
beatmapset = db.exec(select(Beatmapset).where(Beatmapset.id == sid)).first()
beatmapset = (
await db.exec(
select(Beatmapset)
.options(selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
.where(Beatmapset.id == sid)
)
).first()
if not beatmapset:
raise HTTPException(status_code=404, detail="Beatmapset not found")
return BeatmapsetResp.from_db(beatmapset)

View File

@@ -5,7 +5,7 @@ from typing import Literal
from app.database import (
User as DBUser,
)
from app.dependencies import get_current_user, get_db
from app.dependencies import get_current_user
from app.models.user import (
User as ApiUser,
)
@@ -14,7 +14,6 @@ from app.utils import convert_db_user_to_api_user
from .api_router import router
from fastapi import Depends
from sqlalchemy.orm import Session
@router.get("/me/{ruleset}", response_model=ApiUser)
@@ -22,9 +21,8 @@ from sqlalchemy.orm import Session
async def get_user_info_default(
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
current_user: DBUser = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取当前用户信息默认使用osu模式"""
# 默认使用osu模式
api_user = convert_db_user_to_api_user(current_user, ruleset, db)
api_user = await convert_db_user_to_api_user(current_user, ruleset)
return api_user

View File

@@ -1 +1,3 @@
from __future__ import annotations
from .router import router as signalr_router

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import json
from logging import info
import time
from typing import Literal
import uuid
@@ -9,15 +8,14 @@ import uuid
from app.database import User as DBUser
from app.dependencies import get_current_user
from app.dependencies.database import get_db
from app.dependencies.user import get_current_user_by_token, security
from app.dependencies.user import get_current_user_by_token
from app.models.signalr import NegotiateResponse, Transport
from app.router.signalr.packet import SEP
from .hub import Hubs
from fastapi import APIRouter, Depends, Header, Query, WebSocket
from fastapi.security import HTTPAuthorizationCredentials
from sqlalchemy.orm import Session
from sqlmodel.ext.asyncio.session import AsyncSession
router = APIRouter()
@@ -48,7 +46,7 @@ async def connect(
websocket: WebSocket,
id: str,
authorization: str = Header(...),
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_db),
):
token = authorization[7:]
user_id = id.split(":")[0]

View File

@@ -1,5 +1,8 @@
from __future__ import annotations
from collections.abc import Callable
import inspect
from typing import Any, Callable, ForwardRef, cast
from typing import Any, ForwardRef, cast
# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L56-L66

View File

@@ -23,11 +23,9 @@ from app.models.user import (
UserAchievement,
)
from sqlalchemy.orm import Session
def convert_db_user_to_api_user(
db_user: DBUser, ruleset: str = "osu", db_session: Session | None = None
async def convert_db_user_to_api_user(
db_user: DBUser, ruleset: str = "osu"
) -> User:
"""将数据库用户模型转换为API用户模型使用 Lazer 表)"""

View File

@@ -5,419 +5,84 @@ osu! API 模拟服务器的示例数据填充脚本
from __future__ import annotations
import asyncio
from datetime import datetime
import time
from app.auth import get_password_hash
from app.database import (
DailyChallengeStats,
LazerUserAchievement,
LazerUserStatistics,
RankHistory,
User,
)
from app.dependencies.database import engine, get_db
from app.dependencies.database import create_tables, engine
from sqlmodel import SQLModel
# 创建所有表
SQLModel.metadata.create_all(bind=engine)
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
def create_sample_user():
async def create_sample_user():
"""创建示例用户数据"""
with next(get_db()) as db:
# 检查用户是否已存在
from sqlmodel import select
async with AsyncSession(engine) as session:
async with session.begin():
statement = select(User).where(User.name == "Googujiang")
existing_user = db.exec(statement).first()
if existing_user:
print("示例用户已存在,跳过创建")
return existing_user
# 检查用户是否已存在
statement = select(User).where(User.name == "Googujiang")
result = await session.execute(statement)
existing_user = result.scalars().first()
if existing_user:
print("示例用户已存在,跳过创建")
return existing_user
# 当前时间戳
current_timestamp = int(time.time())
join_timestamp = int(datetime(2019, 11, 29, 17, 23, 13).timestamp())
last_visit_timestamp = int(datetime(2025, 7, 18, 16, 31, 29).timestamp())
# 当前时间戳
current_timestamp = int(time.time())
join_timestamp = int(datetime(2019, 11, 29, 17, 23, 13).timestamp())
last_visit_timestamp = int(datetime(2025, 7, 18, 16, 31, 29).timestamp())
# 创建用户
user = User(
name="Googujiang",
safe_name="googujiang", # 安全用户名(小写)
email="googujiang@example.com",
priv=1, # 默认权限
pw_bcrypt=get_password_hash("password123"), # 使用新的哈希方式
country="JP",
silence_end=0,
donor_end=0,
creation_time=join_timestamp,
latest_activity=last_visit_timestamp,
clan_id=0,
clan_priv=0,
preferred_mode=0, # 0 = osu!
play_style=0,
custom_badge_name=None,
custom_badge_icon=None,
userpage_content="「世界に忘れられた」",
api_key=None,
# # 兼容性字段
# avatar_url="https://a.ppy.sh/15651670?1732362658.jpeg",
# cover_url="https://assets.ppy.sh/user-profile-covers/15651670/0fc7b77adef39765a570e7f535bc383e5a848850d41a8943f8857984330b8bc6.jpeg",
# has_supported=True,
# interests="「世界に忘れられた」",
# location="咕谷国",
# website="https://gmoe.cc",
# playstyle=["mouse", "keyboard", "tablet"],
# profile_order=[
# "me",
# "recent_activity",
# "top_ranks",
# "medals",
# "historical",
# "beatmaps",
# "kudosu",
# ],
# beatmap_playcounts_count=3306,
# favourite_beatmapset_count=15,
# follower_count=98,
# graveyard_beatmapset_count=7,
# mapping_follower_count=1,
# previous_usernames=["hehejun"],
# monthly_playcounts=[
# {"start_date": "2019-11-01", "count": 43},
# {"start_date": "2020-04-01", "count": 216},
# {"start_date": "2020-05-01", "count": 656},
# {"start_date": "2020-07-01", "count": 158},
# {"start_date": "2020-08-01", "count": 174},
# {"start_date": "2020-10-01", "count": 13},
# {"start_date": "2020-11-01", "count": 52},
# {"start_date": "2020-12-01", "count": 140},
# {"start_date": "2021-01-01", "count": 359},
# {"start_date": "2021-02-01", "count": 452},
# {"start_date": "2021-03-01", "count": 77},
# {"start_date": "2021-04-01", "count": 114},
# {"start_date": "2021-05-01", "count": 270},
# {"start_date": "2021-06-01", "count": 148},
# {"start_date": "2021-07-01", "count": 246},
# {"start_date": "2021-08-01", "count": 56},
# {"start_date": "2021-09-01", "count": 136},
# {"start_date": "2021-10-01", "count": 45},
# {"start_date": "2021-11-01", "count": 98},
# {"start_date": "2021-12-01", "count": 54},
# {"start_date": "2022-01-01", "count": 88},
# {"start_date": "2022-02-01", "count": 45},
# {"start_date": "2022-03-01", "count": 6},
# {"start_date": "2022-04-01", "count": 54},
# {"start_date": "2022-05-01", "count": 105},
# {"start_date": "2022-06-01", "count": 37},
# {"start_date": "2022-07-01", "count": 88},
# {"start_date": "2022-08-01", "count": 7},
# {"start_date": "2022-09-01", "count": 9},
# {"start_date": "2022-10-01", "count": 6},
# {"start_date": "2022-11-01", "count": 2},
# {"start_date": "2022-12-01", "count": 16},
# {"start_date": "2023-01-01", "count": 7},
# {"start_date": "2023-04-01", "count": 16},
# {"start_date": "2023-05-01", "count": 3},
# {"start_date": "2023-06-01", "count": 8},
# {"start_date": "2023-07-01", "count": 23},
# {"start_date": "2023-08-01", "count": 3},
# {"start_date": "2023-09-01", "count": 1},
# {"start_date": "2023-10-01", "count": 25},
# {"start_date": "2023-11-01", "count": 160},
# {"start_date": "2023-12-01", "count": 306},
# {"start_date": "2024-01-01", "count": 735},
# {"start_date": "2024-02-01", "count": 420},
# {"start_date": "2024-03-01", "count": 549},
# {"start_date": "2024-04-01", "count": 466},
# {"start_date": "2024-05-01", "count": 333},
# {"start_date": "2024-06-01", "count": 1126},
# {"start_date": "2024-07-01", "count": 534},
# {"start_date": "2024-08-01", "count": 280},
# {"start_date": "2024-09-01", "count": 116},
# {"start_date": "2024-10-01", "count": 120},
# {"start_date": "2024-11-01", "count": 332},
# {"start_date": "2024-12-01", "count": 243},
# {"start_date": "2025-01-01", "count": 122},
# {"start_date": "2025-02-01", "count": 379},
# {"start_date": "2025-03-01", "count": 278},
# {"start_date": "2025-04-01", "count": 296},
# {"start_date": "2025-05-01", "count": 964},
# {"start_date": "2025-06-01", "count": 821},
# {"start_date": "2025-07-01", "count": 230},
# ],
)
# 创建用户
user = User(
name="Googujiang",
safe_name="googujiang", # 安全用户名(小写)
email="googujiang@example.com",
priv=1, # 默认权限
pw_bcrypt=get_password_hash("password123"), # 使用新的哈希方式
country="JP",
silence_end=0,
donor_end=0,
creation_time=join_timestamp,
latest_activity=last_visit_timestamp,
clan_id=0,
clan_priv=0,
preferred_mode=0, # 0 = osu!
play_style=0,
custom_badge_name=None,
custom_badge_icon=None,
userpage_content="「世界に忘れられた」",
api_key=None,
)
db.add(user)
db.commit()
db.refresh(user)
session.add(user)
await session.commit()
await session.refresh(user)
# 确保用户ID存在
if user.id is None:
raise ValueError("User ID is None after saving to database")
# 确保用户ID存在
if user.id is None:
raise ValueError("User ID is None after saving to database")
# 创建 osu! 模式统计
osu_stats = LazerUserStatistics(
user_id=user.id,
mode="osu",
count_100=276274,
count_300=1932068,
count_50=32776,
count_miss=111064,
level_current=97,
level_progress=96,
global_rank=298026,
country_rank=11221,
pp=2826.26,
ranked_score=4415081049,
hit_accuracy=95.7168,
play_count=12711,
play_time=836529,
total_score=12390140370,
total_hits=2241118,
maximum_combo=1859,
replays_watched_by_others=0,
is_ranked=True,
grade_ss=14,
grade_ssh=3,
grade_s=322,
grade_sh=11,
grade_a=757,
rank_highest=295701,
rank_highest_updated_at=datetime(2025, 7, 2, 17, 30, 21),
)
# 创建 taiko 模式统计
taiko_stats = LazerUserStatistics(
user_id=user.id,
mode="taiko",
count_100=160,
count_300=154,
count_50=0,
count_miss=480,
level_current=2,
level_progress=49,
global_rank=None,
pp=0,
ranked_score=0,
hit_accuracy=0,
play_count=6,
play_time=217,
total_score=79301,
total_hits=314,
maximum_combo=0,
replays_watched_by_others=0,
is_ranked=False,
)
# 创建 fruits 模式统计
fruits_stats = LazerUserStatistics(
user_id=user.id,
mode="fruits",
count_100=109,
count_300=1613,
count_50=1861,
count_miss=328,
level_current=6,
level_progress=14,
global_rank=None,
pp=0,
ranked_score=343854,
hit_accuracy=89.4779,
play_count=19,
play_time=669,
total_score=1362651,
total_hits=3583,
maximum_combo=75,
replays_watched_by_others=0,
is_ranked=False,
grade_a=1,
)
# 创建 mania 模式统计
mania_stats = LazerUserStatistics(
user_id=user.id,
mode="mania",
count_100=7867,
count_300=12104,
count_50=991,
count_miss=2951,
level_current=12,
level_progress=89,
global_rank=660670,
pp=25.3784,
ranked_score=3812295,
hit_accuracy=77.9316,
play_count=85,
play_time=4834,
total_score=13454470,
total_hits=20962,
maximum_combo=573,
replays_watched_by_others=0,
is_ranked=True,
grade_a=1,
)
db.add_all([osu_stats, taiko_stats, fruits_stats, mania_stats])
# 创建每日挑战统计
daily_challenge = DailyChallengeStats(
user_id=user.id,
daily_streak_best=1,
daily_streak_current=0,
last_update=datetime(2025, 6, 21, 0, 0, 0),
last_weekly_streak=datetime(2025, 6, 19, 0, 0, 0),
playcount=1,
top_10p_placements=0,
top_50p_placements=0,
weekly_streak_best=1,
weekly_streak_current=0,
)
db.add(daily_challenge)
# 创建排名历史 (最近90天的数据)
rank_data = [
322806,
323092,
323341,
323616,
323853,
324106,
324378,
324676,
324958,
325254,
325492,
325780,
326075,
326356,
326586,
326845,
327067,
327286,
327526,
327778,
328039,
328347,
328631,
328858,
329323,
329557,
329809,
329911,
330188,
330425,
330650,
330881,
331068,
331325,
331575,
331816,
332061,
328959,
315648,
315881,
308784,
309023,
309252,
309433,
309537,
309364,
309548,
308957,
309182,
309426,
309607,
309831,
310054,
310269,
310485,
310714,
310956,
310924,
311125,
311203,
311422,
311640,
303091,
303309,
303500,
303691,
303758,
303750,
303957,
299867,
300088,
300273,
300457,
295799,
295976,
296153,
296350,
296566,
296756,
296933,
297141,
297314,
297480,
297114,
297296,
297480,
297645,
297815,
297993,
298026,
]
rank_history = RankHistory(user_id=user.id, mode="osu", rank_data=rank_data)
db.add(rank_history)
# 创建一些成就
achievements = [
LazerUserAchievement(
user_id=user.id,
achievement_id=336,
achieved_at=datetime(2025, 6, 21, 19, 6, 32),
),
LazerUserAchievement(
user_id=user.id,
achievement_id=319,
achieved_at=datetime(2025, 6, 1, 0, 52, 0),
),
LazerUserAchievement(
user_id=user.id,
achievement_id=222,
achieved_at=datetime(2025, 5, 28, 12, 24, 37),
),
LazerUserAchievement(
user_id=user.id,
achievement_id=38,
achieved_at=datetime(2024, 7, 5, 15, 43, 23),
),
LazerUserAchievement(
user_id=user.id,
achievement_id=67,
achieved_at=datetime(2024, 6, 24, 5, 6, 44),
),
]
db.add_all(achievements)
db.commit()
print(f"成功创建示例用户: {user.name} (ID: {user.id})")
print(f"安全用户名: {user.safe_name}")
print(f"邮箱: {user.email}")
print(f"国家: {user.country}")
return user
print(f"成功创建示例用户: {user.name} (ID: {user.id})")
print(f"安全用户名: {user.safe_name}")
print(f"邮箱: {user.email}")
print(f"国家: {user.country}")
return user
if __name__ == "__main__":
async def main():
print("开始创建示例数据...")
user = create_sample_user()
await create_tables()
user = await create_sample_user()
print("示例数据创建完成!")
print(f"用户名: {user.name}")
print("密码: password123")
print("现在您可以使用这些凭据来测试API了。")
if __name__ == "__main__":
asyncio.run(main())

17
main.py
View File

@@ -1,24 +1,31 @@
from __future__ import annotations
from contextlib import asynccontextmanager
from datetime import datetime
from app.config import settings
from app.dependencies.database import engine
from app.dependencies.database import create_tables
from app.router import api_router, auth_router, signalr_router
from fastapi import FastAPI
from sqlmodel import SQLModel
# 注意: 表结构现在通过 migrations 管理,不再自动创建
# 如需创建表,请运行: python quick_sync.py
app = FastAPI(title="osu! API 模拟服务器", version="1.0.0")
@asynccontextmanager
async def lifespan(app: FastAPI):
# on startup
await create_tables()
# on shutdown
yield
app = FastAPI(title="osu! API 模拟服务器", version="1.0.0", lifespan=lifespan)
app.include_router(api_router, prefix="/api/v2")
app.include_router(signalr_router, prefix="/signalr")
app.include_router(auth_router)
SQLModel.metadata.create_all(bind=engine)
@app.get("/")
async def root():

View File

@@ -5,6 +5,7 @@ description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"aiomysql>=0.2.0",
"alembic>=1.12.1",
"bcrypt>=4.1.2",
"cryptography>=41.0.7",
@@ -12,7 +13,6 @@ dependencies = [
"msgpack>=1.1.1",
"passlib[bcrypt]>=1.7.4",
"pydantic[email]>=2.5.0",
"pymysql>=1.1.0",
"python-dotenv>=1.0.0",
"python-jose[cryptography]>=3.3.0",
"python-multipart>=0.0.6",

View File

@@ -2,7 +2,6 @@ fastapi==0.104.1
uvicorn[standard]==0.24.0
sqlalchemy==2.0.23
alembic==1.12.1
pymysql==1.1.0
cryptography==41.0.7
redis==5.0.1
python-jose[cryptography]==3.3.0
@@ -11,3 +10,4 @@ python-multipart==0.0.6
pydantic[email]==2.5.0
python-dotenv==1.0.0
bcrypt==4.1.2
aiomysql==0.2.0

View File

@@ -12,108 +12,106 @@ import sys
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from app.database import User
from app.dependencies.database import get_db
from app.dependencies.database import engine
from app.utils import convert_db_user_to_api_user
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
def test_lazer_tables():
async def test_lazer_tables():
"""测试 lazer 表的基本功能"""
print("测试 Lazer API 表支持...")
# 获取数据库会话
db_gen = get_db()
db = next(db_gen)
async with AsyncSession(engine) as session:
async with session.begin():
try:
# 测试查询用户
statement = select(User)
result = await session.execute(statement)
user = result.scalars().first()
if not user:
print("❌ 没有找到用户,请先同步数据")
return False
try:
# 测试查询用户
statement = select(User)
user = db.exec(statement).first()
if not user:
print("❌ 没有找到用户,请先同步数据")
return False
print(f"✓ 找到用户: {user.name} (ID: {user.id})")
print(f"✓ 找到用户: {user.name} (ID: {user.id})")
# 测试 lazer 资料
if user.lazer_profile:
print(
f"✓ 用户有 lazer 资料: 支持者={user.lazer_profile.is_supporter}"
)
else:
print("⚠ 用户没有 lazer 资料,将使用默认值")
# 测试 lazer 资料
if user.lazer_profile:
print(f"✓ 用户有 lazer 资料: 支持者={user.lazer_profile.is_supporter}")
else:
print("⚠ 用户没有 lazer 资料,将使用默认值")
# 测试 lazer 统计
osu_stats = None
for stat in user.lazer_statistics:
if stat.mode == "osu":
osu_stats = stat
break
# 测试 lazer 统计
osu_stats = None
for stat in user.lazer_statistics:
if stat.mode == "osu":
osu_stats = stat
break
if osu_stats:
print(
f"✓ 用户有 osu! 统计: PP={osu_stats.pp}, "
f"游戏次数={osu_stats.play_count}"
)
else:
print("⚠ 用户没有 osu! 统计,将使用默认值")
if osu_stats:
print(
f"✓ 用户有 osu! 统计: PP={osu_stats.pp}, "
f"游戏次数={osu_stats.play_count}"
)
else:
print("⚠ 用户没有 osu! 统计,将使用默认值")
# 测试转换为 API 格式
api_user = convert_db_user_to_api_user(user, "osu")
print("✓ 成功转换为 API 用户格式")
print(f" - 用户名: {api_user.username}")
print(f" - 国家: {api_user.country_code}")
print(f" - PP: {api_user.statistics.pp}")
print(f" - 是否支持者: {api_user.is_supporter}")
# 测试转换为 API 格式
api_user = convert_db_user_to_api_user(user, "osu", db)
print("✓ 成功转换为 API 用户格式")
print(f" - 用户名: {api_user.username}")
print(f" - 国家: {api_user.country_code}")
print(f" - PP: {api_user.statistics.pp}")
print(f" - 是否支持者: {api_user.is_supporter}")
return True
return True
except Exception as e:
print(f"❌ 测试失败: {e}")
import traceback
except Exception as e:
print(f"❌ 测试失败: {e}")
import traceback
traceback.print_exc()
return False
finally:
db.close()
traceback.print_exc()
return False
def test_authentication():
async def test_authentication():
"""测试认证功能"""
print("\n测试认证功能...")
db_gen = get_db()
db = next(db_gen)
async with AsyncSession(engine) as session:
async with session.begin():
try:
# 尝试认证第一个用户
statement = select(User)
result = await session.execute(statement)
user = result.scalars().first()
if not user:
print("❌ 没有用户进行认证测试")
return False
try:
# 尝试认证第一个用户
statement = select(User)
user = db.exec(statement).first()
if not user:
print("❌ 没有用户进行认证测试")
return False
print(f"✓ 测试用户: {user.name}")
print("⚠ 注意: 实际密码认证需要正确的密码")
print(f"✓ 测试用户: {user.name}")
print("⚠ 注意: 实际密码认证需要正确的密码")
return True
return True
except Exception as e:
print(f"❌ 认证测试失败: {e}")
return False
finally:
db.close()
except Exception as e:
print(f"❌ 认证测试失败: {e}")
return False
def main():
async def main():
"""主测试函数"""
print("Lazer API 系统测试")
print("=" * 40)
# 测试表连接
success1 = test_lazer_tables()
success1 = await test_lazer_tables()
# 测试认证
success2 = test_authentication()
success2 = await test_authentication()
print("\n" + "=" * 40)
if success1 and success2:
@@ -130,4 +128,6 @@ def main():
if __name__ == "__main__":
main()
import asyncio
asyncio.run(main())

16
uv.lock generated
View File

@@ -2,6 +2,18 @@ version = 1
revision = 2
requires-python = ">=3.11"
[[package]]
name = "aiomysql"
version = "0.2.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pymysql" },
]
sdist = { url = "https://files.pythonhosted.org/packages/67/76/2c5b55e4406a1957ffdfd933a94c2517455291c97d2b81cec6813754791a/aiomysql-0.2.0.tar.gz", hash = "sha256:558b9c26d580d08b8c5fd1be23c5231ce3aeff2dadad989540fee740253deb67", size = 114706, upload-time = "2023-06-11T19:57:53.608Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/42/87/c982ee8b333c85b8ae16306387d703a1fcdfc81a2f3f15a24820ab1a512d/aiomysql-0.2.0-py3-none-any.whl", hash = "sha256:b7c26da0daf23a5ec5e0b133c03d20657276e4eae9b73e040b72787f6f6ade0a", size = 44215, upload-time = "2023-06-11T19:57:51.09Z" },
]
[[package]]
name = "alembic"
version = "1.16.4"
@@ -462,6 +474,7 @@ name = "osu-lazer-api"
version = "0.1.0"
source = { virtual = "." }
dependencies = [
{ name = "aiomysql" },
{ name = "alembic" },
{ name = "bcrypt" },
{ name = "cryptography" },
@@ -469,7 +482,6 @@ dependencies = [
{ name = "msgpack" },
{ name = "passlib", extra = ["bcrypt"] },
{ name = "pydantic", extra = ["email"] },
{ name = "pymysql" },
{ name = "python-dotenv" },
{ name = "python-jose", extra = ["cryptography"] },
{ name = "python-multipart" },
@@ -487,6 +499,7 @@ dev = [
[package.metadata]
requires-dist = [
{ name = "aiomysql", specifier = ">=0.2.0" },
{ name = "alembic", specifier = ">=1.12.1" },
{ name = "bcrypt", specifier = ">=4.1.2" },
{ name = "cryptography", specifier = ">=41.0.7" },
@@ -494,7 +507,6 @@ requires-dist = [
{ name = "msgpack", specifier = ">=1.1.1" },
{ name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" },
{ name = "pydantic", extras = ["email"], specifier = ">=2.5.0" },
{ name = "pymysql", specifier = ">=1.1.0" },
{ name = "python-dotenv", specifier = ">=1.0.0" },
{ name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" },
{ name = "python-multipart", specifier = ">=0.0.6" },