refactor(database): use asyncio
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -14,7 +14,7 @@ from app.dependencies import get_db
|
||||
from app.models.oauth import TokenResponse
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException
|
||||
from sqlmodel import Session
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
router = APIRouter(tags=["osu! OAuth 认证"])
|
||||
|
||||
@@ -28,7 +28,7 @@ async def oauth_token(
|
||||
username: str | None = Form(None),
|
||||
password: str | None = Form(None),
|
||||
refresh_token: str | None = Form(None),
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""OAuth 令牌端点"""
|
||||
# 验证客户端凭据
|
||||
@@ -46,7 +46,7 @@ async def oauth_token(
|
||||
)
|
||||
|
||||
# 验证用户
|
||||
user = authenticate_user(db, username, password)
|
||||
user = await authenticate_user(db, username, password)
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Invalid username or password")
|
||||
|
||||
@@ -58,9 +58,9 @@ async def oauth_token(
|
||||
refresh_token_str = generate_refresh_token()
|
||||
|
||||
# 存储令牌
|
||||
store_token(
|
||||
await store_token(
|
||||
db,
|
||||
getattr(user, "id"),
|
||||
user.id,
|
||||
access_token,
|
||||
refresh_token_str,
|
||||
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
@@ -80,7 +80,7 @@ async def oauth_token(
|
||||
raise HTTPException(status_code=400, detail="Refresh token required")
|
||||
|
||||
# 验证刷新令牌
|
||||
token_record = get_token_by_refresh_token(db, refresh_token)
|
||||
token_record =await get_token_by_refresh_token(db, refresh_token)
|
||||
if not token_record:
|
||||
raise HTTPException(status_code=401, detail="Invalid refresh token")
|
||||
|
||||
@@ -92,10 +92,9 @@ async def oauth_token(
|
||||
new_refresh_token = generate_refresh_token()
|
||||
|
||||
# 更新令牌
|
||||
user_id = int(getattr(token_record, "user_id"))
|
||||
store_token(
|
||||
await store_token(
|
||||
db,
|
||||
user_id,
|
||||
token_record.user_id,
|
||||
access_token,
|
||||
new_refresh_token,
|
||||
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
|
||||
@@ -5,6 +5,7 @@ from app.database import (
|
||||
BeatmapResp,
|
||||
User as DBUser,
|
||||
)
|
||||
from app.database.beatmapset import Beatmapset
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.user import get_current_user
|
||||
|
||||
@@ -12,16 +13,24 @@ from .api_router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import Session, col, select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
|
||||
async def get_beatmap(
|
||||
bid: int,
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
beatmap = db.exec(select(Beatmap).where(Beatmap.id == bid)).first()
|
||||
beatmap = (
|
||||
await db.exec(
|
||||
select(Beatmap)
|
||||
.options(joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
|
||||
.where(Beatmap.id == bid)
|
||||
)
|
||||
).first()
|
||||
if not beatmap:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
return BeatmapResp.from_db(beatmap)
|
||||
@@ -36,16 +45,30 @@ class BatchGetResp(BaseModel):
|
||||
async def batch_get_beatmaps(
|
||||
b_ids: list[int] = Query(alias="id", default_factory=list),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if not b_ids:
|
||||
# select 50 beatmaps by last_updated
|
||||
beatmaps = db.exec(
|
||||
select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
|
||||
beatmaps = (
|
||||
await db.exec(
|
||||
select(Beatmap)
|
||||
.options(
|
||||
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
.order_by(col(Beatmap.last_updated).desc())
|
||||
.limit(50)
|
||||
)
|
||||
).all()
|
||||
else:
|
||||
beatmaps = db.exec(
|
||||
select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50)
|
||||
beatmaps = (
|
||||
await db.exec(
|
||||
select(Beatmap)
|
||||
.options(
|
||||
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
.where(col(Beatmap.id).in_(b_ids))
|
||||
.limit(50)
|
||||
)
|
||||
).all()
|
||||
|
||||
return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps])
|
||||
|
||||
@@ -11,16 +11,24 @@ from app.dependencies.user import get_current_user
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from sqlmodel import Session, select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp)
|
||||
async def get_beatmapset(
|
||||
sid: int,
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
beatmapset = db.exec(select(Beatmapset).where(Beatmapset.id == sid)).first()
|
||||
beatmapset = (
|
||||
await db.exec(
|
||||
select(Beatmapset)
|
||||
.options(selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
|
||||
.where(Beatmapset.id == sid)
|
||||
)
|
||||
).first()
|
||||
if not beatmapset:
|
||||
raise HTTPException(status_code=404, detail="Beatmapset not found")
|
||||
return BeatmapsetResp.from_db(beatmapset)
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Literal
|
||||
from app.database import (
|
||||
User as DBUser,
|
||||
)
|
||||
from app.dependencies import get_current_user, get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user import (
|
||||
User as ApiUser,
|
||||
)
|
||||
@@ -14,7 +14,6 @@ from app.utils import convert_db_user_to_api_user
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
@router.get("/me/{ruleset}", response_model=ApiUser)
|
||||
@@ -22,9 +21,8 @@ from sqlalchemy.orm import Session
|
||||
async def get_user_info_default(
|
||||
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取当前用户信息(默认使用osu模式)"""
|
||||
# 默认使用osu模式
|
||||
api_user = convert_db_user_to_api_user(current_user, ruleset, db)
|
||||
api_user = await convert_db_user_to_api_user(current_user, ruleset)
|
||||
return api_user
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .router import router as signalr_router
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from logging import info
|
||||
import time
|
||||
from typing import Literal
|
||||
import uuid
|
||||
@@ -9,15 +8,14 @@ import uuid
|
||||
from app.database import User as DBUser
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.user import get_current_user_by_token, security
|
||||
from app.dependencies.user import get_current_user_by_token
|
||||
from app.models.signalr import NegotiateResponse, Transport
|
||||
from app.router.signalr.packet import SEP
|
||||
|
||||
from .hub import Hubs
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, Query, WebSocket
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -48,7 +46,7 @@ async def connect(
|
||||
websocket: WebSocket,
|
||||
id: str,
|
||||
authorization: str = Header(...),
|
||||
db: Session = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
token = authorization[7:]
|
||||
user_id = id.split(":")[0]
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import inspect
|
||||
from typing import Any, Callable, ForwardRef, cast
|
||||
from typing import Any, ForwardRef, cast
|
||||
|
||||
|
||||
# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L56-L66
|
||||
|
||||
Reference in New Issue
Block a user