diff --git a/app/auth.py b/app/auth.py index dbd5f27..b250844 100644 --- a/app/auth.py +++ b/app/auth.py @@ -14,7 +14,8 @@ from app.database import ( import bcrypt from jose import JWTError, jwt from passlib.context import CryptContext -from sqlmodel import Session, select +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession # 密码哈希上下文 pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") @@ -70,7 +71,9 @@ def get_password_hash(password: str) -> str: return pw_bcrypt.decode() -def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser | None: +async def authenticate_user_legacy( + db: AsyncSession, name: str, password: str +) -> DBUser | None: """ 验证用户身份 - 使用类似 from_login 的逻辑 """ @@ -79,7 +82,7 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser | # 2. 根据用户名查找用户 statement = select(DBUser).where(DBUser.name == name) - user = db.exec(statement).first() + user = (await db.exec(statement)).first() if not user: return None @@ -107,9 +110,11 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser | return None -def authenticate_user(db: Session, username: str, password: str) -> DBUser | None: +async def authenticate_user( + db: AsyncSession, username: str, password: str +) -> DBUser | None: """验证用户身份""" - return authenticate_user_legacy(db, username, password) + return await authenticate_user_legacy(db, username, password) def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str: @@ -147,24 +152,28 @@ def verify_token(token: str) -> dict | None: return None -def store_token( - db: Session, user_id: int, access_token: str, refresh_token: str, expires_in: int +async def store_token( + db: AsyncSession, + user_id: int, + access_token: str, + refresh_token: str, + expires_in: int, ) -> OAuthToken: """存储令牌到数据库""" expires_at = datetime.utcnow() + timedelta(seconds=expires_in) # 删除用户的旧令牌 statement = select(OAuthToken).where(OAuthToken.user_id == user_id) - old_tokens = db.exec(statement).all() + old_tokens = (await db.exec(statement)).all() for token in old_tokens: - db.delete(token) + await db.delete(token) # 检查是否有重复的 access_token - duplicate_token = db.exec( - select(OAuthToken).where(OAuthToken.access_token == access_token) + duplicate_token = ( + await db.exec(select(OAuthToken).where(OAuthToken.access_token == access_token)) ).first() if duplicate_token: - db.delete(duplicate_token) + await db.delete(duplicate_token) # 创建新令牌记录 token_record = OAuthToken( @@ -174,24 +183,28 @@ def store_token( expires_at=expires_at, ) db.add(token_record) - db.commit() - db.refresh(token_record) + await db.commit() + await db.refresh(token_record) return token_record -def get_token_by_access_token(db: Session, access_token: str) -> OAuthToken | None: +async def get_token_by_access_token( + db: AsyncSession, access_token: str +) -> OAuthToken | None: """根据访问令牌获取令牌记录""" statement = select(OAuthToken).where( OAuthToken.access_token == access_token, OAuthToken.expires_at > datetime.utcnow(), ) - return db.exec(statement).first() + return (await db.exec(statement)).first() -def get_token_by_refresh_token(db: Session, refresh_token: str) -> OAuthToken | None: +async def get_token_by_refresh_token( + db: AsyncSession, refresh_token: str +) -> OAuthToken | None: """根据刷新令牌获取令牌记录""" statement = select(OAuthToken).where( OAuthToken.refresh_token == refresh_token, OAuthToken.expires_at > datetime.utcnow(), ) - return db.exec(statement).first() + return (await db.exec(statement)).first() diff --git a/app/config.py b/app/config.py index f902fe6..6f3aa25 100644 --- a/app/config.py +++ b/app/config.py @@ -10,7 +10,7 @@ load_dotenv() class Settings: # 数据库设置 DATABASE_URL: str = os.getenv( - "DATABASE_URL", "mysql+pymysql://root:password@localhost:3306/osu_api" + "DATABASE_URL", "mysql+aiomysql://root:password@localhost:3306/osu_api" ) REDIS_URL: str = os.getenv("REDIS_URL", "redis://localhost:6379/0") diff --git a/app/database/user.py b/app/database/user.py index facab63..d1a24fb 100644 --- a/app/database/user.py +++ b/app/database/user.py @@ -15,7 +15,7 @@ class User(SQLModel, table=True): __tablename__ = "users" # pyright: ignore[reportAssignmentType] # 主键 - id: int | None = Field(default=None, primary_key=True, index=True) + id: int = Field(default=None, primary_key=True, index=True, nullable=False) # 基本信息(匹配 migrations 中的结构) name: str = Field(max_length=32, unique=True, index=True) # 用户名 diff --git a/app/dependencies/__init__.py b/app/dependencies/__init__.py index f36fe81..cdcce5a 100644 --- a/app/dependencies/__init__.py +++ b/app/dependencies/__init__.py @@ -1,2 +1,4 @@ +from __future__ import annotations + from .database import get_db as get_db -from .user import get_current_user as get_current_user \ No newline at end of file +from .user import get_current_user as get_current_user diff --git a/app/dependencies/database.py b/app/dependencies/database.py index c297c1c..288660c 100644 --- a/app/dependencies/database.py +++ b/app/dependencies/database.py @@ -1,6 +1,8 @@ from __future__ import annotations -from sqlmodel import Session, create_engine +from sqlalchemy.ext.asyncio import create_async_engine +from sqlmodel import SQLModel +from sqlmodel.ext.asyncio.session import AsyncSession try: import redis @@ -9,7 +11,7 @@ except ImportError: from app.config import settings # 数据库引擎 -engine = create_engine(settings.DATABASE_URL) +engine = create_async_engine(settings.DATABASE_URL) # Redis 连接 if redis: @@ -19,11 +21,16 @@ else: # 数据库依赖 -def get_db(): - with Session(engine) as session: +async def get_db(): + async with AsyncSession(engine) as session: yield session +async def create_tables(): + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) + + # Redis 依赖 def get_redis(): return redis_client diff --git a/app/dependencies/user.py b/app/dependencies/user.py index 580fbe0..5c3b396 100644 --- a/app/dependencies/user.py +++ b/app/dependencies/user.py @@ -9,14 +9,16 @@ from .database import get_db from fastapi import Depends, HTTPException from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer -from sqlmodel import Session, select +from sqlalchemy.orm import joinedload, selectinload +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession security = HTTPBearer() async def get_current_user( credentials: HTTPAuthorizationCredentials = Depends(security), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ) -> DBUser: """获取当前认证用户""" token = credentials.credentials @@ -27,9 +29,31 @@ async def get_current_user( return user -async def get_current_user_by_token(token: str, db: Session) -> DBUser | None: - token_record = get_token_by_access_token(db, token) +async def get_current_user_by_token(token: str, db: AsyncSession) -> DBUser | None: + token_record = await get_token_by_access_token(db, token) if not token_record: return None - user = db.exec(select(DBUser).where(DBUser.id == token_record.user_id)).first() + user = ( + await db.exec( + select(DBUser) + .options( + joinedload(DBUser.lazer_profile), # pyright: ignore[reportArgumentType] + joinedload(DBUser.lazer_counts), # pyright: ignore[reportArgumentType] + joinedload(DBUser.daily_challenge_stats), # pyright: ignore[reportArgumentType] + joinedload(DBUser.avatar), # pyright: ignore[reportArgumentType] + selectinload(DBUser.lazer_statistics), # pyright: ignore[reportArgumentType] + selectinload(DBUser.lazer_achievements), # pyright: ignore[reportArgumentType] + selectinload(DBUser.lazer_profile_sections), # pyright: ignore[reportArgumentType] + selectinload(DBUser.statistics), # pyright: ignore[reportArgumentType] + selectinload(DBUser.team_membership), # pyright: ignore[reportArgumentType] + selectinload(DBUser.rank_history), # pyright: ignore[reportArgumentType] + selectinload(DBUser.active_banners), # pyright: ignore[reportArgumentType] + selectinload(DBUser.lazer_badges), # pyright: ignore[reportArgumentType] + selectinload(DBUser.lazer_monthly_playcounts), # pyright: ignore[reportArgumentType] + selectinload(DBUser.lazer_previous_usernames), # pyright: ignore[reportArgumentType] + selectinload(DBUser.lazer_replays_watched), # pyright: ignore[reportArgumentType] + ) + .where(DBUser.id == token_record.user_id) + ) + ).first() return user diff --git a/app/router/api_router.py b/app/router/api_router.py index e6f2f82..6a3e356 100644 --- a/app/router/api_router.py +++ b/app/router/api_router.py @@ -1,4 +1,5 @@ +from __future__ import annotations + from fastapi import APIRouter - router = APIRouter() diff --git a/app/router/auth.py b/app/router/auth.py index 838a75e..a1277e3 100644 --- a/app/router/auth.py +++ b/app/router/auth.py @@ -14,7 +14,7 @@ from app.dependencies import get_db from app.models.oauth import TokenResponse from fastapi import APIRouter, Depends, Form, HTTPException -from sqlmodel import Session +from sqlmodel.ext.asyncio.session import AsyncSession router = APIRouter(tags=["osu! OAuth 认证"]) @@ -28,7 +28,7 @@ async def oauth_token( username: str | None = Form(None), password: str | None = Form(None), refresh_token: str | None = Form(None), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): """OAuth 令牌端点""" # 验证客户端凭据 @@ -46,7 +46,7 @@ async def oauth_token( ) # 验证用户 - user = authenticate_user(db, username, password) + user = await authenticate_user(db, username, password) if not user: raise HTTPException(status_code=401, detail="Invalid username or password") @@ -58,9 +58,9 @@ async def oauth_token( refresh_token_str = generate_refresh_token() # 存储令牌 - store_token( + await store_token( db, - getattr(user, "id"), + user.id, access_token, refresh_token_str, settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, @@ -80,7 +80,7 @@ async def oauth_token( raise HTTPException(status_code=400, detail="Refresh token required") # 验证刷新令牌 - token_record = get_token_by_refresh_token(db, refresh_token) + token_record =await get_token_by_refresh_token(db, refresh_token) if not token_record: raise HTTPException(status_code=401, detail="Invalid refresh token") @@ -92,10 +92,9 @@ async def oauth_token( new_refresh_token = generate_refresh_token() # 更新令牌 - user_id = int(getattr(token_record, "user_id")) - store_token( + await store_token( db, - user_id, + token_record.user_id, access_token, new_refresh_token, settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, diff --git a/app/router/beatmap.py b/app/router/beatmap.py index d232d4d..c02f79f 100644 --- a/app/router/beatmap.py +++ b/app/router/beatmap.py @@ -5,6 +5,7 @@ from app.database import ( BeatmapResp, User as DBUser, ) +from app.database.beatmapset import Beatmapset from app.dependencies.database import get_db from app.dependencies.user import get_current_user @@ -12,16 +13,24 @@ from .api_router import router from fastapi import Depends, HTTPException, Query from pydantic import BaseModel -from sqlmodel import Session, col, select +from sqlalchemy.orm import joinedload +from sqlmodel import col, select +from sqlmodel.ext.asyncio.session import AsyncSession @router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp) async def get_beatmap( bid: int, current_user: DBUser = Depends(get_current_user), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): - beatmap = db.exec(select(Beatmap).where(Beatmap.id == bid)).first() + beatmap = ( + await db.exec( + select(Beatmap) + .options(joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType] + .where(Beatmap.id == bid) + ) + ).first() if not beatmap: raise HTTPException(status_code=404, detail="Beatmap not found") return BeatmapResp.from_db(beatmap) @@ -36,16 +45,30 @@ class BatchGetResp(BaseModel): async def batch_get_beatmaps( b_ids: list[int] = Query(alias="id", default_factory=list), current_user: DBUser = Depends(get_current_user), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): if not b_ids: # select 50 beatmaps by last_updated - beatmaps = db.exec( - select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50) + beatmaps = ( + await db.exec( + select(Beatmap) + .options( + joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType] + ) + .order_by(col(Beatmap.last_updated).desc()) + .limit(50) + ) ).all() else: - beatmaps = db.exec( - select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50) + beatmaps = ( + await db.exec( + select(Beatmap) + .options( + joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType] + ) + .where(col(Beatmap.id).in_(b_ids)) + .limit(50) + ) ).all() return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps]) diff --git a/app/router/beatmapset.py b/app/router/beatmapset.py index 68569cc..eceb19b 100644 --- a/app/router/beatmapset.py +++ b/app/router/beatmapset.py @@ -11,16 +11,24 @@ from app.dependencies.user import get_current_user from .api_router import router from fastapi import Depends, HTTPException -from sqlmodel import Session, select +from sqlalchemy.orm import selectinload +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession @router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp) async def get_beatmapset( sid: int, current_user: DBUser = Depends(get_current_user), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): - beatmapset = db.exec(select(Beatmapset).where(Beatmapset.id == sid)).first() + beatmapset = ( + await db.exec( + select(Beatmapset) + .options(selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType] + .where(Beatmapset.id == sid) + ) + ).first() if not beatmapset: raise HTTPException(status_code=404, detail="Beatmapset not found") return BeatmapsetResp.from_db(beatmapset) diff --git a/app/router/me.py b/app/router/me.py index 82e43a1..93dcbdc 100644 --- a/app/router/me.py +++ b/app/router/me.py @@ -5,7 +5,7 @@ from typing import Literal from app.database import ( User as DBUser, ) -from app.dependencies import get_current_user, get_db +from app.dependencies import get_current_user from app.models.user import ( User as ApiUser, ) @@ -14,7 +14,6 @@ from app.utils import convert_db_user_to_api_user from .api_router import router from fastapi import Depends -from sqlalchemy.orm import Session @router.get("/me/{ruleset}", response_model=ApiUser) @@ -22,9 +21,8 @@ from sqlalchemy.orm import Session async def get_user_info_default( ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu", current_user: DBUser = Depends(get_current_user), - db: Session = Depends(get_db), ): """获取当前用户信息(默认使用osu模式)""" # 默认使用osu模式 - api_user = convert_db_user_to_api_user(current_user, ruleset, db) + api_user = await convert_db_user_to_api_user(current_user, ruleset) return api_user diff --git a/app/router/signalr/__init__.py b/app/router/signalr/__init__.py index 4094d6a..881b00e 100644 --- a/app/router/signalr/__init__.py +++ b/app/router/signalr/__init__.py @@ -1 +1,3 @@ +from __future__ import annotations + from .router import router as signalr_router diff --git a/app/router/signalr/router.py b/app/router/signalr/router.py index 0894341..bce1611 100644 --- a/app/router/signalr/router.py +++ b/app/router/signalr/router.py @@ -1,7 +1,6 @@ from __future__ import annotations import json -from logging import info import time from typing import Literal import uuid @@ -9,15 +8,14 @@ import uuid from app.database import User as DBUser from app.dependencies import get_current_user from app.dependencies.database import get_db -from app.dependencies.user import get_current_user_by_token, security +from app.dependencies.user import get_current_user_by_token from app.models.signalr import NegotiateResponse, Transport from app.router.signalr.packet import SEP from .hub import Hubs from fastapi import APIRouter, Depends, Header, Query, WebSocket -from fastapi.security import HTTPAuthorizationCredentials -from sqlalchemy.orm import Session +from sqlmodel.ext.asyncio.session import AsyncSession router = APIRouter() @@ -48,7 +46,7 @@ async def connect( websocket: WebSocket, id: str, authorization: str = Header(...), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): token = authorization[7:] user_id = id.split(":")[0] diff --git a/app/router/signalr/utils.py b/app/router/signalr/utils.py index bce3ae9..02d08c2 100644 --- a/app/router/signalr/utils.py +++ b/app/router/signalr/utils.py @@ -1,5 +1,8 @@ +from __future__ import annotations + +from collections.abc import Callable import inspect -from typing import Any, Callable, ForwardRef, cast +from typing import Any, ForwardRef, cast # https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L56-L66 diff --git a/app/utils.py b/app/utils.py index 89c5d86..103e695 100644 --- a/app/utils.py +++ b/app/utils.py @@ -23,11 +23,9 @@ from app.models.user import ( UserAchievement, ) -from sqlalchemy.orm import Session - -def convert_db_user_to_api_user( - db_user: DBUser, ruleset: str = "osu", db_session: Session | None = None +async def convert_db_user_to_api_user( + db_user: DBUser, ruleset: str = "osu" ) -> User: """将数据库用户模型转换为API用户模型(使用 Lazer 表)""" diff --git a/create_sample_data.py b/create_sample_data.py index 5758cd8..c65c01d 100644 --- a/create_sample_data.py +++ b/create_sample_data.py @@ -5,419 +5,84 @@ osu! API 模拟服务器的示例数据填充脚本 from __future__ import annotations +import asyncio from datetime import datetime import time from app.auth import get_password_hash from app.database import ( - DailyChallengeStats, - LazerUserAchievement, - LazerUserStatistics, - RankHistory, User, ) -from app.dependencies.database import engine, get_db +from app.dependencies.database import create_tables, engine -from sqlmodel import SQLModel - -# 创建所有表 -SQLModel.metadata.create_all(bind=engine) +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession -def create_sample_user(): +async def create_sample_user(): """创建示例用户数据""" - with next(get_db()) as db: - # 检查用户是否已存在 - from sqlmodel import select + async with AsyncSession(engine) as session: + async with session.begin(): - statement = select(User).where(User.name == "Googujiang") - existing_user = db.exec(statement).first() - if existing_user: - print("示例用户已存在,跳过创建") - return existing_user + # 检查用户是否已存在 + statement = select(User).where(User.name == "Googujiang") + result = await session.execute(statement) + existing_user = result.scalars().first() + if existing_user: + print("示例用户已存在,跳过创建") + return existing_user - # 当前时间戳 - current_timestamp = int(time.time()) - join_timestamp = int(datetime(2019, 11, 29, 17, 23, 13).timestamp()) - last_visit_timestamp = int(datetime(2025, 7, 18, 16, 31, 29).timestamp()) + # 当前时间戳 + current_timestamp = int(time.time()) + join_timestamp = int(datetime(2019, 11, 29, 17, 23, 13).timestamp()) + last_visit_timestamp = int(datetime(2025, 7, 18, 16, 31, 29).timestamp()) - # 创建用户 - user = User( - name="Googujiang", - safe_name="googujiang", # 安全用户名(小写) - email="googujiang@example.com", - priv=1, # 默认权限 - pw_bcrypt=get_password_hash("password123"), # 使用新的哈希方式 - country="JP", - silence_end=0, - donor_end=0, - creation_time=join_timestamp, - latest_activity=last_visit_timestamp, - clan_id=0, - clan_priv=0, - preferred_mode=0, # 0 = osu! - play_style=0, - custom_badge_name=None, - custom_badge_icon=None, - userpage_content="「世界に忘れられた」", - api_key=None, - # # 兼容性字段 - # avatar_url="https://a.ppy.sh/15651670?1732362658.jpeg", - # cover_url="https://assets.ppy.sh/user-profile-covers/15651670/0fc7b77adef39765a570e7f535bc383e5a848850d41a8943f8857984330b8bc6.jpeg", - # has_supported=True, - # interests="「世界に忘れられた」", - # location="咕谷国", - # website="https://gmoe.cc", - # playstyle=["mouse", "keyboard", "tablet"], - # profile_order=[ - # "me", - # "recent_activity", - # "top_ranks", - # "medals", - # "historical", - # "beatmaps", - # "kudosu", - # ], - # beatmap_playcounts_count=3306, - # favourite_beatmapset_count=15, - # follower_count=98, - # graveyard_beatmapset_count=7, - # mapping_follower_count=1, - # previous_usernames=["hehejun"], - # monthly_playcounts=[ - # {"start_date": "2019-11-01", "count": 43}, - # {"start_date": "2020-04-01", "count": 216}, - # {"start_date": "2020-05-01", "count": 656}, - # {"start_date": "2020-07-01", "count": 158}, - # {"start_date": "2020-08-01", "count": 174}, - # {"start_date": "2020-10-01", "count": 13}, - # {"start_date": "2020-11-01", "count": 52}, - # {"start_date": "2020-12-01", "count": 140}, - # {"start_date": "2021-01-01", "count": 359}, - # {"start_date": "2021-02-01", "count": 452}, - # {"start_date": "2021-03-01", "count": 77}, - # {"start_date": "2021-04-01", "count": 114}, - # {"start_date": "2021-05-01", "count": 270}, - # {"start_date": "2021-06-01", "count": 148}, - # {"start_date": "2021-07-01", "count": 246}, - # {"start_date": "2021-08-01", "count": 56}, - # {"start_date": "2021-09-01", "count": 136}, - # {"start_date": "2021-10-01", "count": 45}, - # {"start_date": "2021-11-01", "count": 98}, - # {"start_date": "2021-12-01", "count": 54}, - # {"start_date": "2022-01-01", "count": 88}, - # {"start_date": "2022-02-01", "count": 45}, - # {"start_date": "2022-03-01", "count": 6}, - # {"start_date": "2022-04-01", "count": 54}, - # {"start_date": "2022-05-01", "count": 105}, - # {"start_date": "2022-06-01", "count": 37}, - # {"start_date": "2022-07-01", "count": 88}, - # {"start_date": "2022-08-01", "count": 7}, - # {"start_date": "2022-09-01", "count": 9}, - # {"start_date": "2022-10-01", "count": 6}, - # {"start_date": "2022-11-01", "count": 2}, - # {"start_date": "2022-12-01", "count": 16}, - # {"start_date": "2023-01-01", "count": 7}, - # {"start_date": "2023-04-01", "count": 16}, - # {"start_date": "2023-05-01", "count": 3}, - # {"start_date": "2023-06-01", "count": 8}, - # {"start_date": "2023-07-01", "count": 23}, - # {"start_date": "2023-08-01", "count": 3}, - # {"start_date": "2023-09-01", "count": 1}, - # {"start_date": "2023-10-01", "count": 25}, - # {"start_date": "2023-11-01", "count": 160}, - # {"start_date": "2023-12-01", "count": 306}, - # {"start_date": "2024-01-01", "count": 735}, - # {"start_date": "2024-02-01", "count": 420}, - # {"start_date": "2024-03-01", "count": 549}, - # {"start_date": "2024-04-01", "count": 466}, - # {"start_date": "2024-05-01", "count": 333}, - # {"start_date": "2024-06-01", "count": 1126}, - # {"start_date": "2024-07-01", "count": 534}, - # {"start_date": "2024-08-01", "count": 280}, - # {"start_date": "2024-09-01", "count": 116}, - # {"start_date": "2024-10-01", "count": 120}, - # {"start_date": "2024-11-01", "count": 332}, - # {"start_date": "2024-12-01", "count": 243}, - # {"start_date": "2025-01-01", "count": 122}, - # {"start_date": "2025-02-01", "count": 379}, - # {"start_date": "2025-03-01", "count": 278}, - # {"start_date": "2025-04-01", "count": 296}, - # {"start_date": "2025-05-01", "count": 964}, - # {"start_date": "2025-06-01", "count": 821}, - # {"start_date": "2025-07-01", "count": 230}, - # ], - ) + # 创建用户 + user = User( + name="Googujiang", + safe_name="googujiang", # 安全用户名(小写) + email="googujiang@example.com", + priv=1, # 默认权限 + pw_bcrypt=get_password_hash("password123"), # 使用新的哈希方式 + country="JP", + silence_end=0, + donor_end=0, + creation_time=join_timestamp, + latest_activity=last_visit_timestamp, + clan_id=0, + clan_priv=0, + preferred_mode=0, # 0 = osu! + play_style=0, + custom_badge_name=None, + custom_badge_icon=None, + userpage_content="「世界に忘れられた」", + api_key=None, + ) - db.add(user) - db.commit() - db.refresh(user) + session.add(user) + await session.commit() + await session.refresh(user) - # 确保用户ID存在 - if user.id is None: - raise ValueError("User ID is None after saving to database") + # 确保用户ID存在 + if user.id is None: + raise ValueError("User ID is None after saving to database") - # 创建 osu! 模式统计 - osu_stats = LazerUserStatistics( - user_id=user.id, - mode="osu", - count_100=276274, - count_300=1932068, - count_50=32776, - count_miss=111064, - level_current=97, - level_progress=96, - global_rank=298026, - country_rank=11221, - pp=2826.26, - ranked_score=4415081049, - hit_accuracy=95.7168, - play_count=12711, - play_time=836529, - total_score=12390140370, - total_hits=2241118, - maximum_combo=1859, - replays_watched_by_others=0, - is_ranked=True, - grade_ss=14, - grade_ssh=3, - grade_s=322, - grade_sh=11, - grade_a=757, - rank_highest=295701, - rank_highest_updated_at=datetime(2025, 7, 2, 17, 30, 21), - ) - - # 创建 taiko 模式统计 - taiko_stats = LazerUserStatistics( - user_id=user.id, - mode="taiko", - count_100=160, - count_300=154, - count_50=0, - count_miss=480, - level_current=2, - level_progress=49, - global_rank=None, - pp=0, - ranked_score=0, - hit_accuracy=0, - play_count=6, - play_time=217, - total_score=79301, - total_hits=314, - maximum_combo=0, - replays_watched_by_others=0, - is_ranked=False, - ) - - # 创建 fruits 模式统计 - fruits_stats = LazerUserStatistics( - user_id=user.id, - mode="fruits", - count_100=109, - count_300=1613, - count_50=1861, - count_miss=328, - level_current=6, - level_progress=14, - global_rank=None, - pp=0, - ranked_score=343854, - hit_accuracy=89.4779, - play_count=19, - play_time=669, - total_score=1362651, - total_hits=3583, - maximum_combo=75, - replays_watched_by_others=0, - is_ranked=False, - grade_a=1, - ) - - # 创建 mania 模式统计 - mania_stats = LazerUserStatistics( - user_id=user.id, - mode="mania", - count_100=7867, - count_300=12104, - count_50=991, - count_miss=2951, - level_current=12, - level_progress=89, - global_rank=660670, - pp=25.3784, - ranked_score=3812295, - hit_accuracy=77.9316, - play_count=85, - play_time=4834, - total_score=13454470, - total_hits=20962, - maximum_combo=573, - replays_watched_by_others=0, - is_ranked=True, - grade_a=1, - ) - - db.add_all([osu_stats, taiko_stats, fruits_stats, mania_stats]) - - # 创建每日挑战统计 - daily_challenge = DailyChallengeStats( - user_id=user.id, - daily_streak_best=1, - daily_streak_current=0, - last_update=datetime(2025, 6, 21, 0, 0, 0), - last_weekly_streak=datetime(2025, 6, 19, 0, 0, 0), - playcount=1, - top_10p_placements=0, - top_50p_placements=0, - weekly_streak_best=1, - weekly_streak_current=0, - ) - - db.add(daily_challenge) - - # 创建排名历史 (最近90天的数据) - rank_data = [ - 322806, - 323092, - 323341, - 323616, - 323853, - 324106, - 324378, - 324676, - 324958, - 325254, - 325492, - 325780, - 326075, - 326356, - 326586, - 326845, - 327067, - 327286, - 327526, - 327778, - 328039, - 328347, - 328631, - 328858, - 329323, - 329557, - 329809, - 329911, - 330188, - 330425, - 330650, - 330881, - 331068, - 331325, - 331575, - 331816, - 332061, - 328959, - 315648, - 315881, - 308784, - 309023, - 309252, - 309433, - 309537, - 309364, - 309548, - 308957, - 309182, - 309426, - 309607, - 309831, - 310054, - 310269, - 310485, - 310714, - 310956, - 310924, - 311125, - 311203, - 311422, - 311640, - 303091, - 303309, - 303500, - 303691, - 303758, - 303750, - 303957, - 299867, - 300088, - 300273, - 300457, - 295799, - 295976, - 296153, - 296350, - 296566, - 296756, - 296933, - 297141, - 297314, - 297480, - 297114, - 297296, - 297480, - 297645, - 297815, - 297993, - 298026, - ] - - rank_history = RankHistory(user_id=user.id, mode="osu", rank_data=rank_data) - - db.add(rank_history) - - # 创建一些成就 - achievements = [ - LazerUserAchievement( - user_id=user.id, - achievement_id=336, - achieved_at=datetime(2025, 6, 21, 19, 6, 32), - ), - LazerUserAchievement( - user_id=user.id, - achievement_id=319, - achieved_at=datetime(2025, 6, 1, 0, 52, 0), - ), - LazerUserAchievement( - user_id=user.id, - achievement_id=222, - achieved_at=datetime(2025, 5, 28, 12, 24, 37), - ), - LazerUserAchievement( - user_id=user.id, - achievement_id=38, - achieved_at=datetime(2024, 7, 5, 15, 43, 23), - ), - LazerUserAchievement( - user_id=user.id, - achievement_id=67, - achieved_at=datetime(2024, 6, 24, 5, 6, 44), - ), - ] - - db.add_all(achievements) - - db.commit() - print(f"成功创建示例用户: {user.name} (ID: {user.id})") - print(f"安全用户名: {user.safe_name}") - print(f"邮箱: {user.email}") - print(f"国家: {user.country}") - return user + print(f"成功创建示例用户: {user.name} (ID: {user.id})") + print(f"安全用户名: {user.safe_name}") + print(f"邮箱: {user.email}") + print(f"国家: {user.country}") + return user -if __name__ == "__main__": +async def main(): print("开始创建示例数据...") - user = create_sample_user() + await create_tables() + user = await create_sample_user() print("示例数据创建完成!") print(f"用户名: {user.name}") print("密码: password123") print("现在您可以使用这些凭据来测试API了。") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/main.py b/main.py index 6e5cb9d..ce4d222 100644 --- a/main.py +++ b/main.py @@ -1,24 +1,31 @@ from __future__ import annotations +from contextlib import asynccontextmanager from datetime import datetime from app.config import settings -from app.dependencies.database import engine +from app.dependencies.database import create_tables from app.router import api_router, auth_router, signalr_router from fastapi import FastAPI -from sqlmodel import SQLModel # 注意: 表结构现在通过 migrations 管理,不再自动创建 # 如需创建表,请运行: python quick_sync.py -app = FastAPI(title="osu! API 模拟服务器", version="1.0.0") + +@asynccontextmanager +async def lifespan(app: FastAPI): + # on startup + await create_tables() + # on shutdown + yield + + +app = FastAPI(title="osu! API 模拟服务器", version="1.0.0", lifespan=lifespan) app.include_router(api_router, prefix="/api/v2") app.include_router(signalr_router, prefix="/signalr") app.include_router(auth_router) -SQLModel.metadata.create_all(bind=engine) - @app.get("/") async def root(): diff --git a/pyproject.toml b/pyproject.toml index 050918b..54c6850 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,7 @@ description = "Add your description here" readme = "README.md" requires-python = ">=3.11" dependencies = [ + "aiomysql>=0.2.0", "alembic>=1.12.1", "bcrypt>=4.1.2", "cryptography>=41.0.7", @@ -12,7 +13,6 @@ dependencies = [ "msgpack>=1.1.1", "passlib[bcrypt]>=1.7.4", "pydantic[email]>=2.5.0", - "pymysql>=1.1.0", "python-dotenv>=1.0.0", "python-jose[cryptography]>=3.3.0", "python-multipart>=0.0.6", diff --git a/requirements.txt b/requirements.txt index 2aa8b01..de329af 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,6 @@ fastapi==0.104.1 uvicorn[standard]==0.24.0 sqlalchemy==2.0.23 alembic==1.12.1 -pymysql==1.1.0 cryptography==41.0.7 redis==5.0.1 python-jose[cryptography]==3.3.0 @@ -11,3 +10,4 @@ python-multipart==0.0.6 pydantic[email]==2.5.0 python-dotenv==1.0.0 bcrypt==4.1.2 +aiomysql==0.2.0 diff --git a/test_lazer.py b/test_lazer.py index 31ebb3c..627325d 100644 --- a/test_lazer.py +++ b/test_lazer.py @@ -12,108 +12,106 @@ import sys sys.path.append(os.path.dirname(os.path.dirname(__file__))) from app.database import User -from app.dependencies.database import get_db +from app.dependencies.database import engine from app.utils import convert_db_user_to_api_user from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession -def test_lazer_tables(): +async def test_lazer_tables(): """测试 lazer 表的基本功能""" print("测试 Lazer API 表支持...") - # 获取数据库会话 - db_gen = get_db() - db = next(db_gen) + async with AsyncSession(engine) as session: + async with session.begin(): + try: + # 测试查询用户 + statement = select(User) + result = await session.execute(statement) + user = result.scalars().first() + if not user: + print("❌ 没有找到用户,请先同步数据") + return False - try: - # 测试查询用户 - statement = select(User) - user = db.exec(statement).first() - if not user: - print("❌ 没有找到用户,请先同步数据") - return False + print(f"✓ 找到用户: {user.name} (ID: {user.id})") - print(f"✓ 找到用户: {user.name} (ID: {user.id})") + # 测试 lazer 资料 + if user.lazer_profile: + print( + f"✓ 用户有 lazer 资料: 支持者={user.lazer_profile.is_supporter}" + ) + else: + print("⚠ 用户没有 lazer 资料,将使用默认值") - # 测试 lazer 资料 - if user.lazer_profile: - print(f"✓ 用户有 lazer 资料: 支持者={user.lazer_profile.is_supporter}") - else: - print("⚠ 用户没有 lazer 资料,将使用默认值") + # 测试 lazer 统计 + osu_stats = None + for stat in user.lazer_statistics: + if stat.mode == "osu": + osu_stats = stat + break - # 测试 lazer 统计 - osu_stats = None - for stat in user.lazer_statistics: - if stat.mode == "osu": - osu_stats = stat - break + if osu_stats: + print( + f"✓ 用户有 osu! 统计: PP={osu_stats.pp}, " + f"游戏次数={osu_stats.play_count}" + ) + else: + print("⚠ 用户没有 osu! 统计,将使用默认值") - if osu_stats: - print( - f"✓ 用户有 osu! 统计: PP={osu_stats.pp}, " - f"游戏次数={osu_stats.play_count}" - ) - else: - print("⚠ 用户没有 osu! 统计,将使用默认值") + # 测试转换为 API 格式 + api_user = convert_db_user_to_api_user(user, "osu") + print("✓ 成功转换为 API 用户格式") + print(f" - 用户名: {api_user.username}") + print(f" - 国家: {api_user.country_code}") + print(f" - PP: {api_user.statistics.pp}") + print(f" - 是否支持者: {api_user.is_supporter}") - # 测试转换为 API 格式 - api_user = convert_db_user_to_api_user(user, "osu", db) - print("✓ 成功转换为 API 用户格式") - print(f" - 用户名: {api_user.username}") - print(f" - 国家: {api_user.country_code}") - print(f" - PP: {api_user.statistics.pp}") - print(f" - 是否支持者: {api_user.is_supporter}") + return True - return True + except Exception as e: + print(f"❌ 测试失败: {e}") + import traceback - except Exception as e: - print(f"❌ 测试失败: {e}") - import traceback - - traceback.print_exc() - return False - finally: - db.close() + traceback.print_exc() + return False -def test_authentication(): +async def test_authentication(): """测试认证功能""" print("\n测试认证功能...") - db_gen = get_db() - db = next(db_gen) + async with AsyncSession(engine) as session: + async with session.begin(): + try: + # 尝试认证第一个用户 + statement = select(User) + result = await session.execute(statement) + user = result.scalars().first() + if not user: + print("❌ 没有用户进行认证测试") + return False - try: - # 尝试认证第一个用户 - statement = select(User) - user = db.exec(statement).first() - if not user: - print("❌ 没有用户进行认证测试") - return False + print(f"✓ 测试用户: {user.name}") + print("⚠ 注意: 实际密码认证需要正确的密码") - print(f"✓ 测试用户: {user.name}") - print("⚠ 注意: 实际密码认证需要正确的密码") + return True - return True - - except Exception as e: - print(f"❌ 认证测试失败: {e}") - return False - finally: - db.close() + except Exception as e: + print(f"❌ 认证测试失败: {e}") + return False -def main(): +async def main(): """主测试函数""" print("Lazer API 系统测试") print("=" * 40) # 测试表连接 - success1 = test_lazer_tables() + success1 = await test_lazer_tables() # 测试认证 - success2 = test_authentication() + success2 = await test_authentication() print("\n" + "=" * 40) if success1 and success2: @@ -130,4 +128,6 @@ def main(): if __name__ == "__main__": - main() + import asyncio + + asyncio.run(main()) diff --git a/uv.lock b/uv.lock index 580869c..f3155e1 100644 --- a/uv.lock +++ b/uv.lock @@ -2,6 +2,18 @@ version = 1 revision = 2 requires-python = ">=3.11" +[[package]] +name = "aiomysql" +version = "0.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pymysql" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/67/76/2c5b55e4406a1957ffdfd933a94c2517455291c97d2b81cec6813754791a/aiomysql-0.2.0.tar.gz", hash = "sha256:558b9c26d580d08b8c5fd1be23c5231ce3aeff2dadad989540fee740253deb67", size = 114706, upload-time = "2023-06-11T19:57:53.608Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/42/87/c982ee8b333c85b8ae16306387d703a1fcdfc81a2f3f15a24820ab1a512d/aiomysql-0.2.0-py3-none-any.whl", hash = "sha256:b7c26da0daf23a5ec5e0b133c03d20657276e4eae9b73e040b72787f6f6ade0a", size = 44215, upload-time = "2023-06-11T19:57:51.09Z" }, +] + [[package]] name = "alembic" version = "1.16.4" @@ -462,6 +474,7 @@ name = "osu-lazer-api" version = "0.1.0" source = { virtual = "." } dependencies = [ + { name = "aiomysql" }, { name = "alembic" }, { name = "bcrypt" }, { name = "cryptography" }, @@ -469,7 +482,6 @@ dependencies = [ { name = "msgpack" }, { name = "passlib", extra = ["bcrypt"] }, { name = "pydantic", extra = ["email"] }, - { name = "pymysql" }, { name = "python-dotenv" }, { name = "python-jose", extra = ["cryptography"] }, { name = "python-multipart" }, @@ -487,6 +499,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "aiomysql", specifier = ">=0.2.0" }, { name = "alembic", specifier = ">=1.12.1" }, { name = "bcrypt", specifier = ">=4.1.2" }, { name = "cryptography", specifier = ">=41.0.7" }, @@ -494,7 +507,6 @@ requires-dist = [ { name = "msgpack", specifier = ">=1.1.1" }, { name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" }, { name = "pydantic", extras = ["email"], specifier = ">=2.5.0" }, - { name = "pymysql", specifier = ">=1.1.0" }, { name = "python-dotenv", specifier = ">=1.0.0" }, { name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" }, { name = "python-multipart", specifier = ">=0.0.6" },