refactor(database): use asyncio

This commit is contained in:
MingxuanGame
2025-07-25 20:43:50 +08:00
parent 2e1489c6d4
commit f347b680b2
21 changed files with 296 additions and 536 deletions

View File

@@ -14,7 +14,8 @@ from app.database import (
import bcrypt
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlmodel import Session, select
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
# 密码哈希上下文
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
@@ -70,7 +71,9 @@ def get_password_hash(password: str) -> str:
return pw_bcrypt.decode()
def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser | None:
async def authenticate_user_legacy(
db: AsyncSession, name: str, password: str
) -> DBUser | None:
"""
验证用户身份 - 使用类似 from_login 的逻辑
"""
@@ -79,7 +82,7 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser |
# 2. 根据用户名查找用户
statement = select(DBUser).where(DBUser.name == name)
user = db.exec(statement).first()
user = (await db.exec(statement)).first()
if not user:
return None
@@ -107,9 +110,11 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser |
return None
def authenticate_user(db: Session, username: str, password: str) -> DBUser | None:
async def authenticate_user(
db: AsyncSession, username: str, password: str
) -> DBUser | None:
"""验证用户身份"""
return authenticate_user_legacy(db, username, password)
return await authenticate_user_legacy(db, username, password)
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
@@ -147,24 +152,28 @@ def verify_token(token: str) -> dict | None:
return None
def store_token(
db: Session, user_id: int, access_token: str, refresh_token: str, expires_in: int
async def store_token(
db: AsyncSession,
user_id: int,
access_token: str,
refresh_token: str,
expires_in: int,
) -> OAuthToken:
"""存储令牌到数据库"""
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
# 删除用户的旧令牌
statement = select(OAuthToken).where(OAuthToken.user_id == user_id)
old_tokens = db.exec(statement).all()
old_tokens = (await db.exec(statement)).all()
for token in old_tokens:
db.delete(token)
await db.delete(token)
# 检查是否有重复的 access_token
duplicate_token = db.exec(
select(OAuthToken).where(OAuthToken.access_token == access_token)
duplicate_token = (
await db.exec(select(OAuthToken).where(OAuthToken.access_token == access_token))
).first()
if duplicate_token:
db.delete(duplicate_token)
await db.delete(duplicate_token)
# 创建新令牌记录
token_record = OAuthToken(
@@ -174,24 +183,28 @@ def store_token(
expires_at=expires_at,
)
db.add(token_record)
db.commit()
db.refresh(token_record)
await db.commit()
await db.refresh(token_record)
return token_record
def get_token_by_access_token(db: Session, access_token: str) -> OAuthToken | None:
async def get_token_by_access_token(
db: AsyncSession, access_token: str
) -> OAuthToken | None:
"""根据访问令牌获取令牌记录"""
statement = select(OAuthToken).where(
OAuthToken.access_token == access_token,
OAuthToken.expires_at > datetime.utcnow(),
)
return db.exec(statement).first()
return (await db.exec(statement)).first()
def get_token_by_refresh_token(db: Session, refresh_token: str) -> OAuthToken | None:
async def get_token_by_refresh_token(
db: AsyncSession, refresh_token: str
) -> OAuthToken | None:
"""根据刷新令牌获取令牌记录"""
statement = select(OAuthToken).where(
OAuthToken.refresh_token == refresh_token,
OAuthToken.expires_at > datetime.utcnow(),
)
return db.exec(statement).first()
return (await db.exec(statement)).first()