refactor(database): migrate to sqlmodel

This commit is contained in:
MingxuanGame
2025-07-24 20:49:07 +08:00
parent 1655bb9f53
commit c43ca883a5
11 changed files with 582 additions and 743 deletions

View File

@@ -4,7 +4,6 @@ from datetime import datetime, timedelta
import hashlib
import secrets
import string
from typing import Optional
from app.config import settings
from app.database import (
@@ -15,7 +14,7 @@ from app.database import (
import bcrypt
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.orm import Session
from sqlmodel import Session, select
# 密码哈希上下文
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
@@ -71,7 +70,7 @@ def get_password_hash(password: str) -> str:
return pw_bcrypt.decode()
def authenticate_user_legacy(db: Session, name: str, password: str) -> Optional[DBUser]:
def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser | None:
"""
验证用户身份 - 使用类似 from_login 的逻辑
"""
@@ -79,12 +78,13 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> Optional[
pw_md5 = hashlib.md5(password.encode()).hexdigest()
# 2. 根据用户名查找用户
user = db.query(DBUser).filter(DBUser.name == name).first()
statement = select(DBUser).where(DBUser.name == name)
user = db.exec(statement).first()
if not user:
return None
# 3. 验证密码
if not (user.pw_bcrypt is None and user.pw_bcrypt != ""):
# 3. 验证密码 - 修复逻辑错误
if user.pw_bcrypt is None or user.pw_bcrypt == "":
return None
# 4. 检查缓存
@@ -107,12 +107,12 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> Optional[
return None
def authenticate_user(db: Session, username: str, password: str) -> Optional[DBUser]:
def authenticate_user(db: Session, username: str, password: str) -> DBUser | None:
"""验证用户身份"""
return authenticate_user_legacy(db, username, password)
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
"""创建访问令牌"""
to_encode = data.copy()
if expires_delta:
@@ -136,7 +136,7 @@ def generate_refresh_token() -> str:
return "".join(secrets.choice(characters) for _ in range(length))
def verify_token(token: str) -> Optional[dict]:
def verify_token(token: str) -> dict | None:
"""验证访问令牌"""
try:
payload = jwt.decode(
@@ -154,7 +154,10 @@ def store_token(
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
# 删除用户的旧令牌
db.query(OAuthToken).filter(OAuthToken.user_id == user_id).delete()
statement = select(OAuthToken).where(OAuthToken.user_id == user_id)
old_tokens = db.exec(statement).all()
for token in old_tokens:
db.delete(token)
# 创建新令牌记录
token_record = OAuthToken(
@@ -169,25 +172,19 @@ def store_token(
return token_record
def get_token_by_access_token(db: Session, access_token: str) -> Optional[OAuthToken]:
def get_token_by_access_token(db: Session, access_token: str) -> OAuthToken | None:
"""根据访问令牌获取令牌记录"""
return (
db.query(OAuthToken)
.filter(
OAuthToken.access_token == access_token,
OAuthToken.expires_at > datetime.utcnow(),
)
.first()
statement = select(OAuthToken).where(
OAuthToken.access_token == access_token,
OAuthToken.expires_at > datetime.utcnow(),
)
return db.exec(statement).first()
def get_token_by_refresh_token(db: Session, refresh_token: str) -> Optional[OAuthToken]:
def get_token_by_refresh_token(db: Session, refresh_token: str) -> OAuthToken | None:
"""根据刷新令牌获取令牌记录"""
return (
db.query(OAuthToken)
.filter(
OAuthToken.refresh_token == refresh_token,
OAuthToken.expires_at > datetime.utcnow(),
)
.first()
statement = select(OAuthToken).where(
OAuthToken.refresh_token == refresh_token,
OAuthToken.expires_at > datetime.utcnow(),
)
return db.exec(statement).first()