refactor(database): use asyncio

This commit is contained in:
MingxuanGame
2025-07-25 20:43:50 +08:00
parent 2e1489c6d4
commit f347b680b2
21 changed files with 296 additions and 536 deletions

View File

@@ -1,4 +1,5 @@
from __future__ import annotations
from fastapi import APIRouter
router = APIRouter()

View File

@@ -14,7 +14,7 @@ from app.dependencies import get_db
from app.models.oauth import TokenResponse
from fastapi import APIRouter, Depends, Form, HTTPException
from sqlmodel import Session
from sqlmodel.ext.asyncio.session import AsyncSession
router = APIRouter(tags=["osu! OAuth 认证"])
@@ -28,7 +28,7 @@ async def oauth_token(
username: str | None = Form(None),
password: str | None = Form(None),
refresh_token: str | None = Form(None),
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_db),
):
"""OAuth 令牌端点"""
# 验证客户端凭据
@@ -46,7 +46,7 @@ async def oauth_token(
)
# 验证用户
user = authenticate_user(db, username, password)
user = await authenticate_user(db, username, password)
if not user:
raise HTTPException(status_code=401, detail="Invalid username or password")
@@ -58,9 +58,9 @@ async def oauth_token(
refresh_token_str = generate_refresh_token()
# 存储令牌
store_token(
await store_token(
db,
getattr(user, "id"),
user.id,
access_token,
refresh_token_str,
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
@@ -80,7 +80,7 @@ async def oauth_token(
raise HTTPException(status_code=400, detail="Refresh token required")
# 验证刷新令牌
token_record = get_token_by_refresh_token(db, refresh_token)
token_record =await get_token_by_refresh_token(db, refresh_token)
if not token_record:
raise HTTPException(status_code=401, detail="Invalid refresh token")
@@ -92,10 +92,9 @@ async def oauth_token(
new_refresh_token = generate_refresh_token()
# 更新令牌
user_id = int(getattr(token_record, "user_id"))
store_token(
await store_token(
db,
user_id,
token_record.user_id,
access_token,
new_refresh_token,
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,

View File

@@ -5,6 +5,7 @@ from app.database import (
BeatmapResp,
User as DBUser,
)
from app.database.beatmapset import Beatmapset
from app.dependencies.database import get_db
from app.dependencies.user import get_current_user
@@ -12,16 +13,24 @@ from .api_router import router
from fastapi import Depends, HTTPException, Query
from pydantic import BaseModel
from sqlmodel import Session, col, select
from sqlalchemy.orm import joinedload
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
async def get_beatmap(
bid: int,
current_user: DBUser = Depends(get_current_user),
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_db),
):
beatmap = db.exec(select(Beatmap).where(Beatmap.id == bid)).first()
beatmap = (
await db.exec(
select(Beatmap)
.options(joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
.where(Beatmap.id == bid)
)
).first()
if not beatmap:
raise HTTPException(status_code=404, detail="Beatmap not found")
return BeatmapResp.from_db(beatmap)
@@ -36,16 +45,30 @@ class BatchGetResp(BaseModel):
async def batch_get_beatmaps(
b_ids: list[int] = Query(alias="id", default_factory=list),
current_user: DBUser = Depends(get_current_user),
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_db),
):
if not b_ids:
# select 50 beatmaps by last_updated
beatmaps = db.exec(
select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
beatmaps = (
await db.exec(
select(Beatmap)
.options(
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
)
.order_by(col(Beatmap.last_updated).desc())
.limit(50)
)
).all()
else:
beatmaps = db.exec(
select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50)
beatmaps = (
await db.exec(
select(Beatmap)
.options(
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
)
.where(col(Beatmap.id).in_(b_ids))
.limit(50)
)
).all()
return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps])

View File

@@ -11,16 +11,24 @@ from app.dependencies.user import get_current_user
from .api_router import router
from fastapi import Depends, HTTPException
from sqlmodel import Session, select
from sqlalchemy.orm import selectinload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp)
async def get_beatmapset(
sid: int,
current_user: DBUser = Depends(get_current_user),
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_db),
):
beatmapset = db.exec(select(Beatmapset).where(Beatmapset.id == sid)).first()
beatmapset = (
await db.exec(
select(Beatmapset)
.options(selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
.where(Beatmapset.id == sid)
)
).first()
if not beatmapset:
raise HTTPException(status_code=404, detail="Beatmapset not found")
return BeatmapsetResp.from_db(beatmapset)

View File

@@ -5,7 +5,7 @@ from typing import Literal
from app.database import (
User as DBUser,
)
from app.dependencies import get_current_user, get_db
from app.dependencies import get_current_user
from app.models.user import (
User as ApiUser,
)
@@ -14,7 +14,6 @@ from app.utils import convert_db_user_to_api_user
from .api_router import router
from fastapi import Depends
from sqlalchemy.orm import Session
@router.get("/me/{ruleset}", response_model=ApiUser)
@@ -22,9 +21,8 @@ from sqlalchemy.orm import Session
async def get_user_info_default(
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
current_user: DBUser = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取当前用户信息默认使用osu模式"""
# 默认使用osu模式
api_user = convert_db_user_to_api_user(current_user, ruleset, db)
api_user = await convert_db_user_to_api_user(current_user, ruleset)
return api_user

View File

@@ -1 +1,3 @@
from __future__ import annotations
from .router import router as signalr_router

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import json
from logging import info
import time
from typing import Literal
import uuid
@@ -9,15 +8,14 @@ import uuid
from app.database import User as DBUser
from app.dependencies import get_current_user
from app.dependencies.database import get_db
from app.dependencies.user import get_current_user_by_token, security
from app.dependencies.user import get_current_user_by_token
from app.models.signalr import NegotiateResponse, Transport
from app.router.signalr.packet import SEP
from .hub import Hubs
from fastapi import APIRouter, Depends, Header, Query, WebSocket
from fastapi.security import HTTPAuthorizationCredentials
from sqlalchemy.orm import Session
from sqlmodel.ext.asyncio.session import AsyncSession
router = APIRouter()
@@ -48,7 +46,7 @@ async def connect(
websocket: WebSocket,
id: str,
authorization: str = Header(...),
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_db),
):
token = authorization[7:]
user_id = id.split(":")[0]

View File

@@ -1,5 +1,8 @@
from __future__ import annotations
from collections.abc import Callable
import inspect
from typing import Any, Callable, ForwardRef, cast
from typing import Any, ForwardRef, cast
# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L56-L66