refactor(database): use asyncio
This commit is contained in:
49
app/auth.py
49
app/auth.py
@@ -14,7 +14,8 @@ from app.database import (
|
||||
import bcrypt
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from sqlmodel import Session, select
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
# 密码哈希上下文
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
@@ -70,7 +71,9 @@ def get_password_hash(password: str) -> str:
|
||||
return pw_bcrypt.decode()
|
||||
|
||||
|
||||
def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser | None:
|
||||
async def authenticate_user_legacy(
|
||||
db: AsyncSession, name: str, password: str
|
||||
) -> DBUser | None:
|
||||
"""
|
||||
验证用户身份 - 使用类似 from_login 的逻辑
|
||||
"""
|
||||
@@ -79,7 +82,7 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser |
|
||||
|
||||
# 2. 根据用户名查找用户
|
||||
statement = select(DBUser).where(DBUser.name == name)
|
||||
user = db.exec(statement).first()
|
||||
user = (await db.exec(statement)).first()
|
||||
if not user:
|
||||
return None
|
||||
|
||||
@@ -107,9 +110,11 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser |
|
||||
return None
|
||||
|
||||
|
||||
def authenticate_user(db: Session, username: str, password: str) -> DBUser | None:
|
||||
async def authenticate_user(
|
||||
db: AsyncSession, username: str, password: str
|
||||
) -> DBUser | None:
|
||||
"""验证用户身份"""
|
||||
return authenticate_user_legacy(db, username, password)
|
||||
return await authenticate_user_legacy(db, username, password)
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
|
||||
@@ -147,24 +152,28 @@ def verify_token(token: str) -> dict | None:
|
||||
return None
|
||||
|
||||
|
||||
def store_token(
|
||||
db: Session, user_id: int, access_token: str, refresh_token: str, expires_in: int
|
||||
async def store_token(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
access_token: str,
|
||||
refresh_token: str,
|
||||
expires_in: int,
|
||||
) -> OAuthToken:
|
||||
"""存储令牌到数据库"""
|
||||
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
|
||||
|
||||
# 删除用户的旧令牌
|
||||
statement = select(OAuthToken).where(OAuthToken.user_id == user_id)
|
||||
old_tokens = db.exec(statement).all()
|
||||
old_tokens = (await db.exec(statement)).all()
|
||||
for token in old_tokens:
|
||||
db.delete(token)
|
||||
await db.delete(token)
|
||||
|
||||
# 检查是否有重复的 access_token
|
||||
duplicate_token = db.exec(
|
||||
select(OAuthToken).where(OAuthToken.access_token == access_token)
|
||||
duplicate_token = (
|
||||
await db.exec(select(OAuthToken).where(OAuthToken.access_token == access_token))
|
||||
).first()
|
||||
if duplicate_token:
|
||||
db.delete(duplicate_token)
|
||||
await db.delete(duplicate_token)
|
||||
|
||||
# 创建新令牌记录
|
||||
token_record = OAuthToken(
|
||||
@@ -174,24 +183,28 @@ def store_token(
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(token_record)
|
||||
db.commit()
|
||||
db.refresh(token_record)
|
||||
await db.commit()
|
||||
await db.refresh(token_record)
|
||||
return token_record
|
||||
|
||||
|
||||
def get_token_by_access_token(db: Session, access_token: str) -> OAuthToken | None:
|
||||
async def get_token_by_access_token(
|
||||
db: AsyncSession, access_token: str
|
||||
) -> OAuthToken | None:
|
||||
"""根据访问令牌获取令牌记录"""
|
||||
statement = select(OAuthToken).where(
|
||||
OAuthToken.access_token == access_token,
|
||||
OAuthToken.expires_at > datetime.utcnow(),
|
||||
)
|
||||
return db.exec(statement).first()
|
||||
return (await db.exec(statement)).first()
|
||||
|
||||
|
||||
def get_token_by_refresh_token(db: Session, refresh_token: str) -> OAuthToken | None:
|
||||
async def get_token_by_refresh_token(
|
||||
db: AsyncSession, refresh_token: str
|
||||
) -> OAuthToken | None:
|
||||
"""根据刷新令牌获取令牌记录"""
|
||||
statement = select(OAuthToken).where(
|
||||
OAuthToken.refresh_token == refresh_token,
|
||||
OAuthToken.expires_at > datetime.utcnow(),
|
||||
)
|
||||
return db.exec(statement).first()
|
||||
return (await db.exec(statement)).first()
|
||||
|
||||
Reference in New Issue
Block a user