# Source: g0v0-server/app/auth.py
# Snapshot: 2025-10-07 13:07:14 +00:00 (450 lines, 15 KiB, Python)
#
# Note from the hosting viewer: this file contains Unicode characters that
# might be confused with other characters ("ambiguous Unicode characters").
# If this is intentional, the warning can be safely ignored.
from datetime import timedelta
import hashlib
import re
import secrets
import string
from app.config import settings
from app.const import BACKUP_CODE_LENGTH
from app.database import (
OAuthToken,
User,
)
from app.database.auth import TotpKeys
from app.log import log
from app.models.totp import FinishStatus, StartCreateTotpKeyResp
from app.utils import utcnow
import bcrypt
from jose import JWTError, jwt
from passlib.context import CryptContext
import pyotp
from redis.asyncio import Redis
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
# Password hashing context: plain bcrypt, used by verify_password as the
# fallback verifier after the osu!-style scheme.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])
# In-process bcrypt cache (simulates an application-state cache): maps a bcrypt
# hash that has been verified successfully to the MD5 hex digest (bytes) that
# matched it. NOTE(review): nothing in this module ever evicts entries.
bcrypt_cache: dict[str, bytes] = {}
logger = log("Auth")
def validate_username(username: str) -> list[str]:
    """Validate a username and return a list of human-readable errors.

    An empty list means the username is acceptable. Checks, in order:

    - the username is present (a falsy value short-circuits with one error);
    - it is 3-15 characters long;
    - it contains only ASCII letters, digits, underscores and hyphens;
    - it does not start with a digit;
    - its lowercased form is not in ``settings.banned_name``.
    """
    errors: list[str] = []
    if not username:
        errors.append("Username is required")
        return errors

    if len(username) < 3:
        errors.append("Username must be at least 3 characters long")
    if len(username) > 15:
        errors.append("Username must be at most 15 characters long")

    # fullmatch instead of match("^...$"): "$" also matches just before a
    # trailing newline, so e.g. "abc\n" used to pass this check.
    if not re.fullmatch(r"[a-zA-Z0-9_-]+", username):
        errors.append("Username can only contain letters, numbers, underscores, and hyphens")

    if username[0].isdigit():
        errors.append("Username cannot start with a number")

    # NOTE(review): assumes banned_name entries are stored lowercase — confirm.
    if username.lower() in settings.banned_name:
        errors.append("This username is not allowed")

    return errors
def validate_password(password: str) -> list[str]:
    """Validate a password and return a list of human-readable errors.

    An empty list means the password is acceptable. The only rules are that
    a password is given and that it is at least 8 characters long.
    """
    if not password:
        return ["Password is required"]
    problems: list[str] = []
    if len(password) < 8:
        problems.append("Password must be at least 8 characters long")
    return problems
def verify_password_legacy(plain_password: str, bcrypt_hash: str) -> bool:
    """Verify a password using osu!'s scheme: bcrypt over the MD5 hex digest.

    1. plain password -> MD5 hex digest (encoded to bytes)
    2. MD5 digest -> checked with bcrypt against ``bcrypt_hash``

    Successful checks are memoised in the module-level ``bcrypt_cache``, so a
    repeated check against the same hash skips the bcrypt round.

    Returns:
        True if the password matches. False on a mismatch, or if bcrypt raised
        (e.g. a malformed hash), in which case the error is logged.
    """
    pw_md5 = hashlib.md5(plain_password.encode()).hexdigest().encode()  # noqa: S324

    cached = bcrypt_cache.get(bcrypt_hash)
    if cached is not None:
        # Constant-time comparison, so the cached fast path does not reveal
        # through timing how much of the digest matched.
        return secrets.compare_digest(cached, pw_md5)

    try:
        is_valid = bcrypt.checkpw(pw_md5, bcrypt_hash.encode())
    except Exception:
        logger.exception("Password verification error")
        return False

    if is_valid:
        bcrypt_cache[bcrypt_hash] = pw_md5
    return is_valid
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Verify a password against a stored hash (backward compatible).

    The osu!-style scheme (bcrypt over the password's MD5 digest) is tried
    first via ``verify_password_legacy``. If that does not match, the plain
    password is checked against the hash through ``pwd_context`` (standard
    bcrypt).

    Returns:
        True if either scheme matches, otherwise False. If passlib rejects the
        stored hash with ``ValueError`` (unidentifiable or malformed hash), the
        error is logged and treated as a mismatch instead of propagating.
    """
    if verify_password_legacy(plain_password, hashed_password):
        return True
    try:
        return pwd_context.verify(plain_password, hashed_password)
    except ValueError:
        logger.exception("Password verification error")
        return False
def get_password_hash(password: str) -> str:
    """Hash a password osu!-style: bcrypt over the password's MD5 hex digest.

    A new salt from ``bcrypt.gensalt()`` is used on every call. The bcrypt
    hash is returned decoded to ``str``.
    """
    md5_hex = hashlib.md5(password.encode()).hexdigest()  # noqa: S324
    salt = bcrypt.gensalt()
    return bcrypt.hashpw(md5_hex.encode(), salt).decode()
async def authenticate_user_legacy(db: AsyncSession, name: str, password: str) -> User | None:
    """Authenticate a user with logic similar to osu!'s ``from_login``.

    ``name`` is looked up, in order, as a username, then as an email address,
    and then (if it consists only of decimal digits) as a numeric user id. The
    password is checked osu!-style: bcrypt over the MD5 hex digest of the
    plain password. Successful checks are memoised in ``bcrypt_cache``.

    Returns:
        The matching ``User``. None if no user is found, the user has no
        password hash, the password does not match, or bcrypt raised (the
        error is logged).
    """
    pw_md5 = hashlib.md5(password.encode()).hexdigest().encode()  # noqa: S324

    user = (await db.exec(select(User).where(User.username == name))).first()
    if user is None:
        user = (await db.exec(select(User).where(User.email == name))).first()
    # isdecimal() rather than isdigit(): isdigit() accepts characters such as
    # "²" that int() rejects, which made this lookup raise ValueError.
    if user is None and name.isdecimal():
        user = (await db.exec(select(User).where(User.id == int(name)))).first()
    if user is None:
        return None

    # No (or empty) stored hash: the account cannot log in with a password.
    if not user.pw_bcrypt:
        return None

    cached = bcrypt_cache.get(user.pw_bcrypt)
    if cached is not None:
        # Constant-time comparison for the cached fast path.
        return user if secrets.compare_digest(cached, pw_md5) else None

    try:
        is_valid = bcrypt.checkpw(pw_md5, user.pw_bcrypt.encode())
    except Exception:
        logger.exception(f"Authentication error for user {name}")
        return None

    if is_valid:
        bcrypt_cache[user.pw_bcrypt] = pw_md5
        return user
    return None
async def authenticate_user(db: AsyncSession, username: str, password: str) -> User | None:
    """Authenticate a user by login name and password.

    Thin wrapper that delegates to ``authenticate_user_legacy`` and returns
    its result: the matching ``User``, or None.
    """
    user = await authenticate_user_legacy(db, username, password)
    return user
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed. The dict is shallow-copied and never mutated.
        expires_delta: Token lifetime. If None, the lifetime is
            ``settings.access_token_expire_minutes`` minutes. An explicit
            ``timedelta(0)`` is honoured instead of being replaced by the
            default.

    The standard claims ``exp``, ``jti`` (32 hex characters from
    ``secrets.token_hex(16)``) and ``iss`` (``settings.server_url``) are added,
    plus ``aud`` when ``settings.jwt_audience`` is set. The token is signed
    with ``settings.secret_key`` using ``settings.algorithm``.
    """
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.access_token_expire_minutes)

    to_encode = data.copy()
    to_encode.update({"exp": utcnow() + expires_delta, "jti": secrets.token_hex(16)})
    if settings.jwt_audience:
        to_encode["aud"] = settings.jwt_audience
    to_encode["iss"] = str(settings.server_url)

    return jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm)
def generate_refresh_token() -> str:
    """Generate a refresh token of 64 random ASCII letters and digits.

    Characters are drawn with ``secrets.choice``.
    """
    alphabet = string.ascii_letters + string.digits
    picks = [secrets.choice(alphabet) for _ in range(64)]
    return "".join(picks)
async def invalidate_user_tokens(db: AsyncSession, user_id: int) -> int:
    """Delete every OAuth token that belongs to ``user_id`` and commit.

    Returns:
        The number of tokens deleted.
    """
    query = select(OAuthToken).where(OAuthToken.user_id == user_id)
    tokens = (await db.exec(query)).all()
    for token in tokens:
        await db.delete(token)
    await db.commit()
    return len(tokens)
def verify_token(token: str) -> dict | None:
    """Decode and validate a JWT access token.

    The signature is checked with ``settings.secret_key`` and
    ``settings.algorithm``. When ``settings.jwt_audience`` is configured, the
    ``aud`` claim is validated against it.

    Returns:
        The decoded claims, or None if python-jose rejects the token
        (any ``JWTError``).
    """
    try:
        # python-jose rejects any token carrying an "aud" claim unless the
        # expected audience is passed. create_access_token adds "aud" whenever
        # settings.jwt_audience is set, so pass the same value here.
        return jwt.decode(
            token,
            settings.secret_key,
            algorithms=[settings.algorithm],
            audience=settings.jwt_audience or None,
        )
    except JWTError:
        return None
async def store_token(
    db: AsyncSession,
    user_id: int,
    client_id: int,
    scopes: list[str],
    access_token: str,
    refresh_token: str,
    expires_in: int,
    refresh_token_expires_in: int,
    allow_multiple_devices: bool = True,
) -> OAuthToken:
    """Persist a newly issued token pair (multi-device aware) and return the stored row.

    Before inserting, prunes the user's existing tokens for ``client_id``:

    - ``allow_multiple_devices=False`` (single-device mode): every existing
      token for this user/client pair is deleted.
    - ``allow_multiple_devices=True`` (multi-device mode): only expired tokens
      are deleted. Then, if ``settings.max_tokens_per_client`` or more unexpired
      tokens remain, the oldest are deleted so that, counting the new token, at
      most ``max_tokens_per_client`` remain (assuming that setting is >= 1).

    Any existing row with the same ``access_token`` value is also deleted. The
    new row stores ``scopes`` comma-joined, and the session is committed.

    Args:
        expires_in: Access-token lifetime in seconds.
        refresh_token_expires_in: Refresh-token lifetime in seconds.

    Returns:
        The committed and refreshed ``OAuthToken`` row.
    """
    expires_at = utcnow() + timedelta(seconds=expires_in)
    refresh_token_expires_at = utcnow() + timedelta(seconds=refresh_token_expires_in)
    if not allow_multiple_devices:
        # Old behaviour (single-device mode): delete all of the user's tokens for this client.
        statement = select(OAuthToken).where(OAuthToken.user_id == user_id, OAuthToken.client_id == client_id)
        old_tokens = (await db.exec(statement)).all()
        for token in old_tokens:
            await db.delete(token)
    else:
        # New behaviour (multi-device mode): delete only expired tokens and keep valid ones.
        statement = select(OAuthToken).where(
            OAuthToken.user_id == user_id, OAuthToken.client_id == client_id, OAuthToken.expires_at <= utcnow()
        )
        expired_tokens = (await db.exec(statement)).all()
        for token in expired_tokens:
            await db.delete(token)
        # Cap the number of tokens per user per client (prevents unbounded growth).
        max_tokens_per_client = settings.max_tokens_per_client
        # Unexpired tokens, newest first.
        statement = (
            select(OAuthToken)
            .where(OAuthToken.user_id == user_id, OAuthToken.client_id == client_id, OAuthToken.expires_at > utcnow())
            .order_by(col(OAuthToken.created_at).desc())
        )
        active_tokens = (await db.exec(statement)).all()
        if len(active_tokens) >= max_tokens_per_client:
            # Delete the oldest tokens, keeping the newest (max - 1) to leave room for the new one.
            tokens_to_delete = active_tokens[max_tokens_per_client - 1 :]
            for token in tokens_to_delete:
                await db.delete(token)
            logger.info(f"Cleaned up {len(tokens_to_delete)} old tokens for user {user_id}")
    # Delete any existing row that already uses this access_token value.
    duplicate_token = (await db.exec(select(OAuthToken).where(OAuthToken.access_token == access_token))).first()
    if duplicate_token:
        await db.delete(duplicate_token)
    # Create the new token record.
    token_record = OAuthToken(
        user_id=user_id,
        client_id=client_id,
        access_token=access_token,
        scope=",".join(scopes),
        refresh_token=refresh_token,
        expires_at=expires_at,
        refresh_token_expires_at=refresh_token_expires_at,
    )
    db.add(token_record)
    await db.commit()
    await db.refresh(token_record)
    logger.info(f"Created new token for user {user_id}, client {client_id} (multi-device: {allow_multiple_devices})")
    return token_record
async def get_token_by_access_token(db: AsyncSession, access_token: str) -> OAuthToken | None:
    """Return the stored token row for ``access_token`` if it has not expired.

    Only rows whose ``expires_at`` is later than now are considered. Returns
    None when there is no such row.
    """
    now = utcnow()
    query = select(OAuthToken).where(OAuthToken.access_token == access_token, OAuthToken.expires_at > now)
    result = await db.exec(query)
    return result.first()
async def get_token_by_refresh_token(db: AsyncSession, refresh_token: str) -> OAuthToken | None:
    """Return the stored token row for ``refresh_token`` if the refresh token is still valid.

    Only rows whose ``refresh_token_expires_at`` is later than now are
    considered. Returns None when there is no such row.
    """
    now = utcnow()
    query = select(OAuthToken).where(
        OAuthToken.refresh_token == refresh_token,
        OAuthToken.refresh_token_expires_at > now,
    )
    return (await db.exec(query)).first()
async def get_user_by_authorization_code(
    db: AsyncSession, redis: Redis, client_id: int, code: str
) -> tuple[User, list[str]] | None:
    """Redeem an OAuth authorization code for its user and granted scopes.

    The code's data lives in the Redis hash ``oauth:code:{client_id}:{code}``,
    in the fields ``user_id`` and ``scopes`` (comma-separated). Both fields are
    read and deleted in a single MULTI/EXEC transaction, so a code can be
    redeemed at most once, even by concurrent requests.

    Returns:
        ``(user, scopes)`` on success. None if the code is unknown or already
        consumed, or if its user no longer exists.
    """
    key = f"oauth:code:{client_id}:{code}"
    # Atomic read-and-consume. With separate HGET/HDEL round trips, two
    # concurrent requests could both read the code before either deleted it.
    async with redis.pipeline(transaction=True) as pipe:
        pipe.hmget(key, ["user_id", "scopes"])
        pipe.hdel(key, "user_id", "scopes")
        (user_id, scopes), _deleted = await pipe.execute()
    if not user_id or not scopes:
        return None

    statement = select(User).where(User.id == int(user_id))
    user = (await db.exec(statement)).first()
    if user:
        await db.refresh(user)
        return (user, scopes.split(","))
    return None
def totp_redis_key(user: User) -> str:
    """Return the Redis key for the user's pending TOTP setup (keyed by email)."""
    return "totp:setup:{}".format(user.email)
def _generate_totp_account_label(user: User) -> str:
    """Build the account label that authenticator apps show for this TOTP key.

    The label is the username or the email, depending on
    ``settings.totp_use_username_in_label``. When ``settings.totp_service_name``
    is set, it is appended in parentheses so entries from different servers
    can be told apart.
    """
    label = user.username if settings.totp_use_username_in_label else user.email
    service = settings.totp_service_name
    return f"{label} ({service})" if service else label
def _generate_totp_issuer_name() -> str:
    """Return the TOTP issuer name.

    Uses the first non-empty value among ``settings.totp_issuer`` and
    ``settings.totp_service_name``, falling back to "osu! Private Server".
    """
    return settings.totp_issuer or settings.totp_service_name or "osu! Private Server"
async def start_create_totp_key(user: User, redis: Redis) -> StartCreateTotpKeyResp:
    """Begin TOTP enrolment for ``user``.

    Generates a random base32 secret and stores it, with a failure counter of
    0, in the user's setup hash in Redis. The hash expires after 300 seconds.
    Returns the secret plus a provisioning URI built from the account label and
    issuer name.
    """
    secret = pyotp.random_base32()
    key = totp_redis_key(user)
    await redis.hset(key, mapping={"secret": secret, "fails": 0})  # pyright: ignore[reportGeneralTypeIssues]
    await redis.expire(key, 300)

    label = _generate_totp_account_label(user)
    issuer = _generate_totp_issuer_name()
    provisioning_uri = pyotp.totp.TOTP(secret).provisioning_uri(name=label, issuer_name=issuer)
    return StartCreateTotpKeyResp(secret=secret, uri=provisioning_uri)
def verify_totp_key(secret: str, code: str) -> bool:
    """Check a TOTP ``code`` against ``secret`` with pyotp, using ``valid_window=1``.

    No replay protection is applied here.
    """
    totp = pyotp.TOTP(secret)
    return totp.verify(code, valid_window=1)
async def verify_totp_key_with_replay_protection(user_id: int, secret: str, code: str, redis: Redis) -> bool:
    """Verify a TOTP code and reject reuse of the same code (replay protection).

    After a code passes pyotp verification (``valid_window=1``), the Redis key
    ``totp:{user_id}:{code}`` is claimed for 120 seconds, following the
    osu-web implementation. A code whose key is already claimed is rejected.

    Returns:
        True only if the code is valid and has not been used in the last
        120 seconds.
    """
    if not pyotp.TOTP(secret).verify(code, valid_window=1):
        return False
    # SET NX EX claims the key atomically. The previous EXISTS-then-SETEX
    # sequence let two concurrent requests with the same code both succeed.
    claimed = await redis.set(f"totp:{user_id}:{code}", "1", ex=120, nx=True)
    return bool(claimed)
def _generate_backup_codes(count=10, length=BACKUP_CODE_LENGTH) -> list[str]:
    """Generate ``count`` random backup codes, each ``length`` uppercase letters/digits.

    Characters are drawn with ``secrets.choice``.
    """
    charset = string.ascii_uppercase + string.digits

    def one_code() -> str:
        return "".join(secrets.choice(charset) for _ in range(length))

    return [one_code() for _ in range(count)]
async def _store_totp_key(user: User, secret: str, db: AsyncSession) -> list[str]:
    """Persist the confirmed TOTP secret together with freshly generated backup codes.

    Only bcrypt hashes of the backup codes are stored (decoded to ``str``).
    The plain codes are returned to the caller (presumably to show the user
    once). Commits the session.
    """
    plain_codes = _generate_backup_codes()
    hashed_codes = [bcrypt.hashpw(plain.encode(), bcrypt.gensalt()).decode() for plain in plain_codes]
    db.add(TotpKeys(user_id=user.id, secret=secret, backup_keys=hashed_codes))
    await db.commit()
    return plain_codes
async def finish_create_totp_key(
    user: User, code: str, redis: Redis, db: AsyncSession
) -> tuple[FinishStatus, list[str]]:
    """Complete TOTP enrolment by checking a code from the user's authenticator.

    Returns one of:
        - ``(INVALID, [])`` if there is no pending setup (or it lacks
          ``secret``/``fails``);
        - ``(TOO_MANY_ATTEMPTS, [])`` once 3 failures have been recorded;
          the pending setup is discarded;
        - ``(SUCCESS, backup_codes)`` when the code verifies; the pending setup
          is discarded and the secret stored via ``_store_totp_key``;
        - ``(FAILED, [])`` otherwise, after incrementing the failure counter.
    """
    key = totp_redis_key(user)
    pending = await redis.hgetall(key)  # pyright: ignore[reportGeneralTypeIssues]
    if not pending or "secret" not in pending or "fails" not in pending:
        return FinishStatus.INVALID, []

    secret = pending["secret"]
    failures = int(pending["fails"])

    if failures >= 3:
        await redis.delete(key)  # pyright: ignore[reportGeneralTypeIssues]
        return FinishStatus.TOO_MANY_ATTEMPTS, []

    if not verify_totp_key(secret, code):
        await redis.hset(key, "fails", str(failures + 1))  # pyright: ignore[reportGeneralTypeIssues]
        return FinishStatus.FAILED, []

    await redis.delete(key)  # pyright: ignore[reportGeneralTypeIssues]
    backup_codes = await _store_totp_key(user, secret, db)
    return FinishStatus.SUCCESS, backup_codes
async def disable_totp(user: User, db: AsyncSession) -> None:
    """Delete the user's TOTP key row and commit; does nothing if there is none."""
    existing = await db.get(TotpKeys, user.id)
    if not existing:
        return
    await db.delete(existing)
    await db.commit()
def check_totp_backup_code(totp: TotpKeys, code: str) -> bool:
    """Check ``code`` against the bcrypt-hashed backup codes in ``totp.backup_keys``.

    On the first match, that hash is removed by assigning a new list to
    ``totp.backup_keys`` (presumably so the ORM notices the change), and True
    is returned. This function does not commit. Returns False if nothing
    matches.
    """
    candidate = code.encode()
    for index, stored in enumerate(totp.backup_keys):
        if bcrypt.checkpw(candidate, stored.encode()):
            totp.backup_keys = totp.backup_keys[:index] + totp.backup_keys[index + 1 :]
            return True
    return False