refactor(database): use asyncio

This commit is contained in:
MingxuanGame
2025-07-25 20:43:50 +08:00
parent 2e1489c6d4
commit f347b680b2
21 changed files with 296 additions and 536 deletions

View File

@@ -14,7 +14,8 @@ from app.database import (
import bcrypt
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlmodel import Session, select
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
# 密码哈希上下文
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
@@ -70,7 +71,9 @@ def get_password_hash(password: str) -> str:
return pw_bcrypt.decode()
def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser | None:
async def authenticate_user_legacy(
db: AsyncSession, name: str, password: str
) -> DBUser | None:
"""
验证用户身份 - 使用类似 from_login 的逻辑
"""
@@ -79,7 +82,7 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser |
# 2. 根据用户名查找用户
statement = select(DBUser).where(DBUser.name == name)
user = db.exec(statement).first()
user = (await db.exec(statement)).first()
if not user:
return None
@@ -107,9 +110,11 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser |
return None
def authenticate_user(db: Session, username: str, password: str) -> DBUser | None:
async def authenticate_user(
db: AsyncSession, username: str, password: str
) -> DBUser | None:
"""验证用户身份"""
return authenticate_user_legacy(db, username, password)
return await authenticate_user_legacy(db, username, password)
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
@@ -147,24 +152,28 @@ def verify_token(token: str) -> dict | None:
return None
def store_token(
db: Session, user_id: int, access_token: str, refresh_token: str, expires_in: int
async def store_token(
db: AsyncSession,
user_id: int,
access_token: str,
refresh_token: str,
expires_in: int,
) -> OAuthToken:
"""存储令牌到数据库"""
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
# 删除用户的旧令牌
statement = select(OAuthToken).where(OAuthToken.user_id == user_id)
old_tokens = db.exec(statement).all()
old_tokens = (await db.exec(statement)).all()
for token in old_tokens:
db.delete(token)
await db.delete(token)
# 检查是否有重复的 access_token
duplicate_token = db.exec(
select(OAuthToken).where(OAuthToken.access_token == access_token)
duplicate_token = (
await db.exec(select(OAuthToken).where(OAuthToken.access_token == access_token))
).first()
if duplicate_token:
db.delete(duplicate_token)
await db.delete(duplicate_token)
# 创建新令牌记录
token_record = OAuthToken(
@@ -174,24 +183,28 @@ def store_token(
expires_at=expires_at,
)
db.add(token_record)
db.commit()
db.refresh(token_record)
await db.commit()
await db.refresh(token_record)
return token_record
def get_token_by_access_token(db: Session, access_token: str) -> OAuthToken | None:
async def get_token_by_access_token(
db: AsyncSession, access_token: str
) -> OAuthToken | None:
"""根据访问令牌获取令牌记录"""
statement = select(OAuthToken).where(
OAuthToken.access_token == access_token,
OAuthToken.expires_at > datetime.utcnow(),
)
return db.exec(statement).first()
return (await db.exec(statement)).first()
def get_token_by_refresh_token(db: Session, refresh_token: str) -> OAuthToken | None:
async def get_token_by_refresh_token(
db: AsyncSession, refresh_token: str
) -> OAuthToken | None:
"""根据刷新令牌获取令牌记录"""
statement = select(OAuthToken).where(
OAuthToken.refresh_token == refresh_token,
OAuthToken.expires_at > datetime.utcnow(),
)
return db.exec(statement).first()
return (await db.exec(statement)).first()

View File

@@ -10,7 +10,7 @@ load_dotenv()
class Settings:
# 数据库设置
DATABASE_URL: str = os.getenv(
"DATABASE_URL", "mysql+pymysql://root:password@localhost:3306/osu_api"
"DATABASE_URL", "mysql+aiomysql://root:password@localhost:3306/osu_api"
)
REDIS_URL: str = os.getenv("REDIS_URL", "redis://localhost:6379/0")

View File

@@ -15,7 +15,7 @@ class User(SQLModel, table=True):
__tablename__ = "users" # pyright: ignore[reportAssignmentType]
# 主键
id: int | None = Field(default=None, primary_key=True, index=True)
id: int = Field(default=None, primary_key=True, index=True, nullable=False)
# 基本信息(匹配 migrations 中的结构)
name: str = Field(max_length=32, unique=True, index=True) # 用户名

View File

@@ -1,2 +1,4 @@
from __future__ import annotations
from .database import get_db as get_db
from .user import get_current_user as get_current_user
from .user import get_current_user as get_current_user

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
from sqlmodel import Session, create_engine
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
try:
import redis
@@ -9,7 +11,7 @@ except ImportError:
from app.config import settings
# 数据库引擎
engine = create_engine(settings.DATABASE_URL)
engine = create_async_engine(settings.DATABASE_URL)
# Redis 连接
if redis:
@@ -19,11 +21,16 @@ else:
# 数据库依赖
def get_db():
with Session(engine) as session:
async def get_db():
async with AsyncSession(engine) as session:
yield session
async def create_tables():
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
# Redis 依赖
def get_redis():
return redis_client

View File

@@ -9,14 +9,16 @@ from .database import get_db
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlmodel import Session, select
from sqlalchemy.orm import joinedload, selectinload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
security = HTTPBearer()
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security),
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_db),
) -> DBUser:
"""获取当前认证用户"""
token = credentials.credentials
@@ -27,9 +29,31 @@ async def get_current_user(
return user
async def get_current_user_by_token(token: str, db: Session) -> DBUser | None:
token_record = get_token_by_access_token(db, token)
async def get_current_user_by_token(token: str, db: AsyncSession) -> DBUser | None:
token_record = await get_token_by_access_token(db, token)
if not token_record:
return None
user = db.exec(select(DBUser).where(DBUser.id == token_record.user_id)).first()
user = (
await db.exec(
select(DBUser)
.options(
joinedload(DBUser.lazer_profile), # pyright: ignore[reportArgumentType]
joinedload(DBUser.lazer_counts), # pyright: ignore[reportArgumentType]
joinedload(DBUser.daily_challenge_stats), # pyright: ignore[reportArgumentType]
joinedload(DBUser.avatar), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_statistics), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_achievements), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_profile_sections), # pyright: ignore[reportArgumentType]
selectinload(DBUser.statistics), # pyright: ignore[reportArgumentType]
selectinload(DBUser.team_membership), # pyright: ignore[reportArgumentType]
selectinload(DBUser.rank_history), # pyright: ignore[reportArgumentType]
selectinload(DBUser.active_banners), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_badges), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_monthly_playcounts), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_previous_usernames), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_replays_watched), # pyright: ignore[reportArgumentType]
)
.where(DBUser.id == token_record.user_id)
)
).first()
return user

View File

@@ -1,4 +1,5 @@
from __future__ import annotations
from fastapi import APIRouter
router = APIRouter()

View File

@@ -14,7 +14,7 @@ from app.dependencies import get_db
from app.models.oauth import TokenResponse
from fastapi import APIRouter, Depends, Form, HTTPException
from sqlmodel import Session
from sqlmodel.ext.asyncio.session import AsyncSession
router = APIRouter(tags=["osu! OAuth 认证"])
@@ -28,7 +28,7 @@ async def oauth_token(
username: str | None = Form(None),
password: str | None = Form(None),
refresh_token: str | None = Form(None),
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_db),
):
"""OAuth 令牌端点"""
# 验证客户端凭据
@@ -46,7 +46,7 @@ async def oauth_token(
)
# 验证用户
user = authenticate_user(db, username, password)
user = await authenticate_user(db, username, password)
if not user:
raise HTTPException(status_code=401, detail="Invalid username or password")
@@ -58,9 +58,9 @@ async def oauth_token(
refresh_token_str = generate_refresh_token()
# 存储令牌
store_token(
await store_token(
db,
getattr(user, "id"),
user.id,
access_token,
refresh_token_str,
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
@@ -80,7 +80,7 @@ async def oauth_token(
raise HTTPException(status_code=400, detail="Refresh token required")
# 验证刷新令牌
token_record = get_token_by_refresh_token(db, refresh_token)
token_record =await get_token_by_refresh_token(db, refresh_token)
if not token_record:
raise HTTPException(status_code=401, detail="Invalid refresh token")
@@ -92,10 +92,9 @@ async def oauth_token(
new_refresh_token = generate_refresh_token()
# 更新令牌
user_id = int(getattr(token_record, "user_id"))
store_token(
await store_token(
db,
user_id,
token_record.user_id,
access_token,
new_refresh_token,
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,

View File

@@ -5,6 +5,7 @@ from app.database import (
BeatmapResp,
User as DBUser,
)
from app.database.beatmapset import Beatmapset
from app.dependencies.database import get_db
from app.dependencies.user import get_current_user
@@ -12,16 +13,24 @@ from .api_router import router
from fastapi import Depends, HTTPException, Query
from pydantic import BaseModel
from sqlmodel import Session, col, select
from sqlalchemy.orm import joinedload
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
async def get_beatmap(
bid: int,
current_user: DBUser = Depends(get_current_user),
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_db),
):
beatmap = db.exec(select(Beatmap).where(Beatmap.id == bid)).first()
beatmap = (
await db.exec(
select(Beatmap)
.options(joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
.where(Beatmap.id == bid)
)
).first()
if not beatmap:
raise HTTPException(status_code=404, detail="Beatmap not found")
return BeatmapResp.from_db(beatmap)
@@ -36,16 +45,30 @@ class BatchGetResp(BaseModel):
async def batch_get_beatmaps(
b_ids: list[int] = Query(alias="id", default_factory=list),
current_user: DBUser = Depends(get_current_user),
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_db),
):
if not b_ids:
# select 50 beatmaps by last_updated
beatmaps = db.exec(
select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
beatmaps = (
await db.exec(
select(Beatmap)
.options(
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
)
.order_by(col(Beatmap.last_updated).desc())
.limit(50)
)
).all()
else:
beatmaps = db.exec(
select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50)
beatmaps = (
await db.exec(
select(Beatmap)
.options(
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
)
.where(col(Beatmap.id).in_(b_ids))
.limit(50)
)
).all()
return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps])

View File

@@ -11,16 +11,24 @@ from app.dependencies.user import get_current_user
from .api_router import router
from fastapi import Depends, HTTPException
from sqlmodel import Session, select
from sqlalchemy.orm import selectinload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp)
async def get_beatmapset(
sid: int,
current_user: DBUser = Depends(get_current_user),
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_db),
):
beatmapset = db.exec(select(Beatmapset).where(Beatmapset.id == sid)).first()
beatmapset = (
await db.exec(
select(Beatmapset)
.options(selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
.where(Beatmapset.id == sid)
)
).first()
if not beatmapset:
raise HTTPException(status_code=404, detail="Beatmapset not found")
return BeatmapsetResp.from_db(beatmapset)

View File

@@ -5,7 +5,7 @@ from typing import Literal
from app.database import (
User as DBUser,
)
from app.dependencies import get_current_user, get_db
from app.dependencies import get_current_user
from app.models.user import (
User as ApiUser,
)
@@ -14,7 +14,6 @@ from app.utils import convert_db_user_to_api_user
from .api_router import router
from fastapi import Depends
from sqlalchemy.orm import Session
@router.get("/me/{ruleset}", response_model=ApiUser)
@@ -22,9 +21,8 @@ from sqlalchemy.orm import Session
async def get_user_info_default(
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
current_user: DBUser = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取当前用户信息默认使用osu模式"""
# 默认使用osu模式
api_user = convert_db_user_to_api_user(current_user, ruleset, db)
api_user = await convert_db_user_to_api_user(current_user, ruleset)
return api_user

View File

@@ -1 +1,3 @@
from __future__ import annotations
from .router import router as signalr_router

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import json
from logging import info
import time
from typing import Literal
import uuid
@@ -9,15 +8,14 @@ import uuid
from app.database import User as DBUser
from app.dependencies import get_current_user
from app.dependencies.database import get_db
from app.dependencies.user import get_current_user_by_token, security
from app.dependencies.user import get_current_user_by_token
from app.models.signalr import NegotiateResponse, Transport
from app.router.signalr.packet import SEP
from .hub import Hubs
from fastapi import APIRouter, Depends, Header, Query, WebSocket
from fastapi.security import HTTPAuthorizationCredentials
from sqlalchemy.orm import Session
from sqlmodel.ext.asyncio.session import AsyncSession
router = APIRouter()
@@ -48,7 +46,7 @@ async def connect(
websocket: WebSocket,
id: str,
authorization: str = Header(...),
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_db),
):
token = authorization[7:]
user_id = id.split(":")[0]

View File

@@ -1,5 +1,8 @@
from __future__ import annotations
from collections.abc import Callable
import inspect
from typing import Any, Callable, ForwardRef, cast
from typing import Any, ForwardRef, cast
# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L56-L66

View File

@@ -23,11 +23,9 @@ from app.models.user import (
UserAchievement,
)
from sqlalchemy.orm import Session
def convert_db_user_to_api_user(
db_user: DBUser, ruleset: str = "osu", db_session: Session | None = None
async def convert_db_user_to_api_user(
db_user: DBUser, ruleset: str = "osu"
) -> User:
"""将数据库用户模型转换为API用户模型使用 Lazer 表)"""