refactor(database): use asyncio
This commit is contained in:
49
app/auth.py
49
app/auth.py
@@ -14,7 +14,8 @@ from app.database import (
|
||||
import bcrypt
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from sqlmodel import Session, select
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
# 密码哈希上下文
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
@@ -70,7 +71,9 @@ def get_password_hash(password: str) -> str:
|
||||
return pw_bcrypt.decode()
|
||||
|
||||
|
||||
def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser | None:
|
||||
async def authenticate_user_legacy(
|
||||
db: AsyncSession, name: str, password: str
|
||||
) -> DBUser | None:
|
||||
"""
|
||||
验证用户身份 - 使用类似 from_login 的逻辑
|
||||
"""
|
||||
@@ -79,7 +82,7 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser |
|
||||
|
||||
# 2. 根据用户名查找用户
|
||||
statement = select(DBUser).where(DBUser.name == name)
|
||||
user = db.exec(statement).first()
|
||||
user = (await db.exec(statement)).first()
|
||||
if not user:
|
||||
return None
|
||||
|
||||
@@ -107,9 +110,11 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser |
|
||||
return None
|
||||
|
||||
|
||||
def authenticate_user(db: Session, username: str, password: str) -> DBUser | None:
|
||||
async def authenticate_user(
|
||||
db: AsyncSession, username: str, password: str
|
||||
) -> DBUser | None:
|
||||
"""验证用户身份"""
|
||||
return authenticate_user_legacy(db, username, password)
|
||||
return await authenticate_user_legacy(db, username, password)
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
|
||||
@@ -147,24 +152,28 @@ def verify_token(token: str) -> dict | None:
|
||||
return None
|
||||
|
||||
|
||||
def store_token(
|
||||
db: Session, user_id: int, access_token: str, refresh_token: str, expires_in: int
|
||||
async def store_token(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
access_token: str,
|
||||
refresh_token: str,
|
||||
expires_in: int,
|
||||
) -> OAuthToken:
|
||||
"""存储令牌到数据库"""
|
||||
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
|
||||
|
||||
# 删除用户的旧令牌
|
||||
statement = select(OAuthToken).where(OAuthToken.user_id == user_id)
|
||||
old_tokens = db.exec(statement).all()
|
||||
old_tokens = (await db.exec(statement)).all()
|
||||
for token in old_tokens:
|
||||
db.delete(token)
|
||||
await db.delete(token)
|
||||
|
||||
# 检查是否有重复的 access_token
|
||||
duplicate_token = db.exec(
|
||||
select(OAuthToken).where(OAuthToken.access_token == access_token)
|
||||
duplicate_token = (
|
||||
await db.exec(select(OAuthToken).where(OAuthToken.access_token == access_token))
|
||||
).first()
|
||||
if duplicate_token:
|
||||
db.delete(duplicate_token)
|
||||
await db.delete(duplicate_token)
|
||||
|
||||
# 创建新令牌记录
|
||||
token_record = OAuthToken(
|
||||
@@ -174,24 +183,28 @@ def store_token(
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(token_record)
|
||||
db.commit()
|
||||
db.refresh(token_record)
|
||||
await db.commit()
|
||||
await db.refresh(token_record)
|
||||
return token_record
|
||||
|
||||
|
||||
def get_token_by_access_token(db: Session, access_token: str) -> OAuthToken | None:
|
||||
async def get_token_by_access_token(
|
||||
db: AsyncSession, access_token: str
|
||||
) -> OAuthToken | None:
|
||||
"""根据访问令牌获取令牌记录"""
|
||||
statement = select(OAuthToken).where(
|
||||
OAuthToken.access_token == access_token,
|
||||
OAuthToken.expires_at > datetime.utcnow(),
|
||||
)
|
||||
return db.exec(statement).first()
|
||||
return (await db.exec(statement)).first()
|
||||
|
||||
|
||||
def get_token_by_refresh_token(db: Session, refresh_token: str) -> OAuthToken | None:
|
||||
async def get_token_by_refresh_token(
|
||||
db: AsyncSession, refresh_token: str
|
||||
) -> OAuthToken | None:
|
||||
"""根据刷新令牌获取令牌记录"""
|
||||
statement = select(OAuthToken).where(
|
||||
OAuthToken.refresh_token == refresh_token,
|
||||
OAuthToken.expires_at > datetime.utcnow(),
|
||||
)
|
||||
return db.exec(statement).first()
|
||||
return (await db.exec(statement)).first()
|
||||
|
||||
@@ -10,7 +10,7 @@ load_dotenv()
|
||||
class Settings:
|
||||
# 数据库设置
|
||||
DATABASE_URL: str = os.getenv(
|
||||
"DATABASE_URL", "mysql+pymysql://root:password@localhost:3306/osu_api"
|
||||
"DATABASE_URL", "mysql+aiomysql://root:password@localhost:3306/osu_api"
|
||||
)
|
||||
REDIS_URL: str = os.getenv("REDIS_URL", "redis://localhost:6379/0")
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ class User(SQLModel, table=True):
|
||||
__tablename__ = "users" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
# 主键
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
id: int = Field(default=None, primary_key=True, index=True, nullable=False)
|
||||
|
||||
# 基本信息(匹配 migrations 中的结构)
|
||||
name: str = Field(max_length=32, unique=True, index=True) # 用户名
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .database import get_db as get_db
|
||||
from .user import get_current_user as get_current_user
|
||||
from .user import get_current_user as get_current_user
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlmodel import Session, create_engine
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlmodel import SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
try:
|
||||
import redis
|
||||
@@ -9,7 +11,7 @@ except ImportError:
|
||||
from app.config import settings
|
||||
|
||||
# 数据库引擎
|
||||
engine = create_engine(settings.DATABASE_URL)
|
||||
engine = create_async_engine(settings.DATABASE_URL)
|
||||
|
||||
# Redis 连接
|
||||
if redis:
|
||||
@@ -19,11 +21,16 @@ else:
|
||||
|
||||
|
||||
# 数据库依赖
|
||||
def get_db():
|
||||
with Session(engine) as session:
|
||||
async def get_db():
|
||||
async with AsyncSession(engine) as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def create_tables():
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
|
||||
# Redis 依赖
|
||||
def get_redis():
|
||||
return redis_client
|
||||
|
||||
@@ -9,14 +9,16 @@ from .database import get_db
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from sqlmodel import Session, select
|
||||
from sqlalchemy.orm import joinedload, selectinload
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> DBUser:
|
||||
"""获取当前认证用户"""
|
||||
token = credentials.credentials
|
||||
@@ -27,9 +29,31 @@ async def get_current_user(
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_user_by_token(token: str, db: Session) -> DBUser | None:
|
||||
token_record = get_token_by_access_token(db, token)
|
||||
async def get_current_user_by_token(token: str, db: AsyncSession) -> DBUser | None:
|
||||
token_record = await get_token_by_access_token(db, token)
|
||||
if not token_record:
|
||||
return None
|
||||
user = db.exec(select(DBUser).where(DBUser.id == token_record.user_id)).first()
|
||||
user = (
|
||||
await db.exec(
|
||||
select(DBUser)
|
||||
.options(
|
||||
joinedload(DBUser.lazer_profile), # pyright: ignore[reportArgumentType]
|
||||
joinedload(DBUser.lazer_counts), # pyright: ignore[reportArgumentType]
|
||||
joinedload(DBUser.daily_challenge_stats), # pyright: ignore[reportArgumentType]
|
||||
joinedload(DBUser.avatar), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.lazer_statistics), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.lazer_achievements), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.lazer_profile_sections), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.statistics), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.team_membership), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.rank_history), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.active_banners), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.lazer_badges), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.lazer_monthly_playcounts), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.lazer_previous_usernames), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.lazer_replays_watched), # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
.where(DBUser.id == token_record.user_id)
|
||||
)
|
||||
).first()
|
||||
return user
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -14,7 +14,7 @@ from app.dependencies import get_db
|
||||
from app.models.oauth import TokenResponse
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException
|
||||
from sqlmodel import Session
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
router = APIRouter(tags=["osu! OAuth 认证"])
|
||||
|
||||
@@ -28,7 +28,7 @@ async def oauth_token(
|
||||
username: str | None = Form(None),
|
||||
password: str | None = Form(None),
|
||||
refresh_token: str | None = Form(None),
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""OAuth 令牌端点"""
|
||||
# 验证客户端凭据
|
||||
@@ -46,7 +46,7 @@ async def oauth_token(
|
||||
)
|
||||
|
||||
# 验证用户
|
||||
user = authenticate_user(db, username, password)
|
||||
user = await authenticate_user(db, username, password)
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Invalid username or password")
|
||||
|
||||
@@ -58,9 +58,9 @@ async def oauth_token(
|
||||
refresh_token_str = generate_refresh_token()
|
||||
|
||||
# 存储令牌
|
||||
store_token(
|
||||
await store_token(
|
||||
db,
|
||||
getattr(user, "id"),
|
||||
user.id,
|
||||
access_token,
|
||||
refresh_token_str,
|
||||
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
@@ -80,7 +80,7 @@ async def oauth_token(
|
||||
raise HTTPException(status_code=400, detail="Refresh token required")
|
||||
|
||||
# 验证刷新令牌
|
||||
token_record = get_token_by_refresh_token(db, refresh_token)
|
||||
token_record =await get_token_by_refresh_token(db, refresh_token)
|
||||
if not token_record:
|
||||
raise HTTPException(status_code=401, detail="Invalid refresh token")
|
||||
|
||||
@@ -92,10 +92,9 @@ async def oauth_token(
|
||||
new_refresh_token = generate_refresh_token()
|
||||
|
||||
# 更新令牌
|
||||
user_id = int(getattr(token_record, "user_id"))
|
||||
store_token(
|
||||
await store_token(
|
||||
db,
|
||||
user_id,
|
||||
token_record.user_id,
|
||||
access_token,
|
||||
new_refresh_token,
|
||||
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
|
||||
@@ -5,6 +5,7 @@ from app.database import (
|
||||
BeatmapResp,
|
||||
User as DBUser,
|
||||
)
|
||||
from app.database.beatmapset import Beatmapset
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.user import get_current_user
|
||||
|
||||
@@ -12,16 +13,24 @@ from .api_router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import Session, col, select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
|
||||
async def get_beatmap(
|
||||
bid: int,
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
beatmap = db.exec(select(Beatmap).where(Beatmap.id == bid)).first()
|
||||
beatmap = (
|
||||
await db.exec(
|
||||
select(Beatmap)
|
||||
.options(joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
|
||||
.where(Beatmap.id == bid)
|
||||
)
|
||||
).first()
|
||||
if not beatmap:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
return BeatmapResp.from_db(beatmap)
|
||||
@@ -36,16 +45,30 @@ class BatchGetResp(BaseModel):
|
||||
async def batch_get_beatmaps(
|
||||
b_ids: list[int] = Query(alias="id", default_factory=list),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if not b_ids:
|
||||
# select 50 beatmaps by last_updated
|
||||
beatmaps = db.exec(
|
||||
select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
|
||||
beatmaps = (
|
||||
await db.exec(
|
||||
select(Beatmap)
|
||||
.options(
|
||||
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
.order_by(col(Beatmap.last_updated).desc())
|
||||
.limit(50)
|
||||
)
|
||||
).all()
|
||||
else:
|
||||
beatmaps = db.exec(
|
||||
select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50)
|
||||
beatmaps = (
|
||||
await db.exec(
|
||||
select(Beatmap)
|
||||
.options(
|
||||
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
.where(col(Beatmap.id).in_(b_ids))
|
||||
.limit(50)
|
||||
)
|
||||
).all()
|
||||
|
||||
return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps])
|
||||
|
||||
@@ -11,16 +11,24 @@ from app.dependencies.user import get_current_user
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from sqlmodel import Session, select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp)
|
||||
async def get_beatmapset(
|
||||
sid: int,
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
beatmapset = db.exec(select(Beatmapset).where(Beatmapset.id == sid)).first()
|
||||
beatmapset = (
|
||||
await db.exec(
|
||||
select(Beatmapset)
|
||||
.options(selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
|
||||
.where(Beatmapset.id == sid)
|
||||
)
|
||||
).first()
|
||||
if not beatmapset:
|
||||
raise HTTPException(status_code=404, detail="Beatmapset not found")
|
||||
return BeatmapsetResp.from_db(beatmapset)
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Literal
|
||||
from app.database import (
|
||||
User as DBUser,
|
||||
)
|
||||
from app.dependencies import get_current_user, get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user import (
|
||||
User as ApiUser,
|
||||
)
|
||||
@@ -14,7 +14,6 @@ from app.utils import convert_db_user_to_api_user
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
@router.get("/me/{ruleset}", response_model=ApiUser)
|
||||
@@ -22,9 +21,8 @@ from sqlalchemy.orm import Session
|
||||
async def get_user_info_default(
|
||||
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取当前用户信息(默认使用osu模式)"""
|
||||
# 默认使用osu模式
|
||||
api_user = convert_db_user_to_api_user(current_user, ruleset, db)
|
||||
api_user = await convert_db_user_to_api_user(current_user, ruleset)
|
||||
return api_user
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .router import router as signalr_router
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from logging import info
|
||||
import time
|
||||
from typing import Literal
|
||||
import uuid
|
||||
@@ -9,15 +8,14 @@ import uuid
|
||||
from app.database import User as DBUser
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.user import get_current_user_by_token, security
|
||||
from app.dependencies.user import get_current_user_by_token
|
||||
from app.models.signalr import NegotiateResponse, Transport
|
||||
from app.router.signalr.packet import SEP
|
||||
|
||||
from .hub import Hubs
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, Query, WebSocket
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -48,7 +46,7 @@ async def connect(
|
||||
websocket: WebSocket,
|
||||
id: str,
|
||||
authorization: str = Header(...),
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
token = authorization[7:]
|
||||
user_id = id.split(":")[0]
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import inspect
|
||||
from typing import Any, Callable, ForwardRef, cast
|
||||
from typing import Any, ForwardRef, cast
|
||||
|
||||
|
||||
# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L56-L66
|
||||
|
||||
@@ -23,11 +23,9 @@ from app.models.user import (
|
||||
UserAchievement,
|
||||
)
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
def convert_db_user_to_api_user(
|
||||
db_user: DBUser, ruleset: str = "osu", db_session: Session | None = None
|
||||
async def convert_db_user_to_api_user(
|
||||
db_user: DBUser, ruleset: str = "osu"
|
||||
) -> User:
|
||||
"""将数据库用户模型转换为API用户模型(使用 Lazer 表)"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user