refactor(database): migrate to sqlmodel
This commit is contained in:
49
app/auth.py
49
app/auth.py
@@ -4,7 +4,6 @@ from datetime import datetime, timedelta
|
||||
import hashlib
|
||||
import secrets
|
||||
import string
|
||||
from typing import Optional
|
||||
|
||||
from app.config import settings
|
||||
from app.database import (
|
||||
@@ -15,7 +14,7 @@ from app.database import (
|
||||
import bcrypt
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlmodel import Session, select
|
||||
|
||||
# 密码哈希上下文
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
@@ -71,7 +70,7 @@ def get_password_hash(password: str) -> str:
|
||||
return pw_bcrypt.decode()
|
||||
|
||||
|
||||
def authenticate_user_legacy(db: Session, name: str, password: str) -> Optional[DBUser]:
|
||||
def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser | None:
|
||||
"""
|
||||
验证用户身份 - 使用类似 from_login 的逻辑
|
||||
"""
|
||||
@@ -79,12 +78,13 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> Optional[
|
||||
pw_md5 = hashlib.md5(password.encode()).hexdigest()
|
||||
|
||||
# 2. 根据用户名查找用户
|
||||
user = db.query(DBUser).filter(DBUser.name == name).first()
|
||||
statement = select(DBUser).where(DBUser.name == name)
|
||||
user = db.exec(statement).first()
|
||||
if not user:
|
||||
return None
|
||||
|
||||
# 3. 验证密码
|
||||
if not (user.pw_bcrypt is None and user.pw_bcrypt != ""):
|
||||
# 3. 验证密码 - 修复逻辑错误
|
||||
if user.pw_bcrypt is None or user.pw_bcrypt == "":
|
||||
return None
|
||||
|
||||
# 4. 检查缓存
|
||||
@@ -107,12 +107,12 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> Optional[
|
||||
return None
|
||||
|
||||
|
||||
def authenticate_user(db: Session, username: str, password: str) -> Optional[DBUser]:
|
||||
def authenticate_user(db: Session, username: str, password: str) -> DBUser | None:
|
||||
"""验证用户身份"""
|
||||
return authenticate_user_legacy(db, username, password)
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
|
||||
"""创建访问令牌"""
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
@@ -136,7 +136,7 @@ def generate_refresh_token() -> str:
|
||||
return "".join(secrets.choice(characters) for _ in range(length))
|
||||
|
||||
|
||||
def verify_token(token: str) -> Optional[dict]:
|
||||
def verify_token(token: str) -> dict | None:
|
||||
"""验证访问令牌"""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
@@ -154,7 +154,10 @@ def store_token(
|
||||
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
|
||||
|
||||
# 删除用户的旧令牌
|
||||
db.query(OAuthToken).filter(OAuthToken.user_id == user_id).delete()
|
||||
statement = select(OAuthToken).where(OAuthToken.user_id == user_id)
|
||||
old_tokens = db.exec(statement).all()
|
||||
for token in old_tokens:
|
||||
db.delete(token)
|
||||
|
||||
# 创建新令牌记录
|
||||
token_record = OAuthToken(
|
||||
@@ -169,25 +172,19 @@ def store_token(
|
||||
return token_record
|
||||
|
||||
|
||||
def get_token_by_access_token(db: Session, access_token: str) -> Optional[OAuthToken]:
|
||||
def get_token_by_access_token(db: Session, access_token: str) -> OAuthToken | None:
|
||||
"""根据访问令牌获取令牌记录"""
|
||||
return (
|
||||
db.query(OAuthToken)
|
||||
.filter(
|
||||
OAuthToken.access_token == access_token,
|
||||
OAuthToken.expires_at > datetime.utcnow(),
|
||||
)
|
||||
.first()
|
||||
statement = select(OAuthToken).where(
|
||||
OAuthToken.access_token == access_token,
|
||||
OAuthToken.expires_at > datetime.utcnow(),
|
||||
)
|
||||
return db.exec(statement).first()
|
||||
|
||||
|
||||
def get_token_by_refresh_token(db: Session, refresh_token: str) -> Optional[OAuthToken]:
|
||||
def get_token_by_refresh_token(db: Session, refresh_token: str) -> OAuthToken | None:
|
||||
"""根据刷新令牌获取令牌记录"""
|
||||
return (
|
||||
db.query(OAuthToken)
|
||||
.filter(
|
||||
OAuthToken.refresh_token == refresh_token,
|
||||
OAuthToken.expires_at > datetime.utcnow(),
|
||||
)
|
||||
.first()
|
||||
statement = select(OAuthToken).where(
|
||||
OAuthToken.refresh_token == refresh_token,
|
||||
OAuthToken.expires_at > datetime.utcnow(),
|
||||
)
|
||||
return db.exec(statement).first()
|
||||
|
||||
Reference in New Issue
Block a user