171 lines
5.9 KiB
Python
171 lines
5.9 KiB
Python
from typing import Annotated
|
||
|
||
from app.auth import get_token_by_access_token
|
||
from app.config import settings
|
||
from app.const import SUPPORT_TOTP_VERIFICATION_VER
|
||
from app.database import User
|
||
from app.database.auth import OAuthToken, V1APIKeys
|
||
from app.models.oauth import OAuth2ClientCredentialsBearer
|
||
|
||
from .api_version import APIVersion
|
||
from .database import Database, get_redis
|
||
|
||
from fastapi import Depends, HTTPException, Security
|
||
from fastapi.security import (
|
||
APIKeyQuery,
|
||
HTTPBearer,
|
||
OAuth2AuthorizationCodeBearer,
|
||
OAuth2PasswordBearer,
|
||
SecurityScopes,
|
||
)
|
||
from redis.asyncio import Redis
|
||
from sqlmodel import select
|
||
|
||
security = HTTPBearer()
|
||
|
||
|
||
oauth2_password = OAuth2PasswordBearer(
|
||
tokenUrl="oauth/token",
|
||
refreshUrl="oauth/token",
|
||
scopes={"*": "允许访问全部 API。"},
|
||
description="osu!lazer 或网页客户端密码登录认证,具有全部权限",
|
||
scheme_name="Password Grant",
|
||
)
|
||
|
||
oauth2_code = OAuth2AuthorizationCodeBearer(
|
||
authorizationUrl="oauth/authorize",
|
||
tokenUrl="oauth/token",
|
||
refreshUrl="oauth/token",
|
||
scopes={
|
||
"chat.read": "允许代表用户读取聊天消息。",
|
||
"chat.write": "允许代表用户发送聊天消息。",
|
||
"chat.write_manage": ("允许代表用户加入和离开聊天频道。"),
|
||
"delegate": ("允许作为客户端的所有者进行操作;仅适用于客户端凭证授权。"),
|
||
"forum.write": "允许代表用户创建和编辑论坛帖子。",
|
||
"friends.read": "允许读取用户的好友列表。",
|
||
"identify": "允许读取用户的公开资料 (/me)。",
|
||
"public": "允许代表用户读取公开数据。",
|
||
},
|
||
description="osu! OAuth 认证 (授权码认证)",
|
||
scheme_name="Authorization Code Grant",
|
||
)
|
||
|
||
oauth2_client_credentials = OAuth2ClientCredentialsBearer(
|
||
tokenUrl="oauth/token",
|
||
refreshUrl="oauth/token",
|
||
scopes={
|
||
"public": "允许读取公开数据。",
|
||
},
|
||
description="osu! OAuth 认证 (客户端凭证流)",
|
||
scheme_name="Client Credentials Grant",
|
||
)
|
||
|
||
v1_api_key = APIKeyQuery(name="k", scheme_name="V1 API Key", description="v1 API 密钥")
|
||
|
||
|
||
async def v1_authorize(
|
||
db: Database,
|
||
api_key: Annotated[str, Depends(v1_api_key)],
|
||
):
|
||
"""V1 API Key 授权"""
|
||
if not api_key:
|
||
raise HTTPException(status_code=401, detail="Missing API key")
|
||
|
||
api_key_record = (await db.exec(select(V1APIKeys).where(V1APIKeys.key == api_key))).first()
|
||
if not api_key_record:
|
||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||
|
||
|
||
async def get_client_user_and_token(
|
||
db: Database,
|
||
token: Annotated[str, Depends(oauth2_password)],
|
||
) -> tuple[User, OAuthToken]:
|
||
token_record = await get_token_by_access_token(db, token)
|
||
if not token_record:
|
||
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
||
|
||
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
|
||
if not user:
|
||
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
||
|
||
return user, token_record
|
||
|
||
|
||
UserAndToken = tuple[User, OAuthToken]
|
||
|
||
|
||
async def get_client_user_no_verified(user_and_token: UserAndToken = Depends(get_client_user_and_token)):
|
||
return user_and_token[0]
|
||
|
||
|
||
async def get_client_user(
|
||
db: Database,
|
||
redis: Annotated[Redis, Depends(get_redis)],
|
||
api_version: APIVersion,
|
||
user_and_token: UserAndToken = Depends(get_client_user_and_token),
|
||
):
|
||
from app.service.verification_service import LoginSessionService
|
||
|
||
user, token = user_and_token
|
||
|
||
if await LoginSessionService.check_is_need_verification(db, user.id, token.id):
|
||
# 获取当前验证方式
|
||
verify_method = None
|
||
if api_version >= SUPPORT_TOTP_VERIFICATION_VER:
|
||
verify_method = await LoginSessionService.get_login_method(user.id, token.id, redis)
|
||
|
||
if verify_method is None:
|
||
# 智能选择验证方式(有TOTP优先TOTP)
|
||
totp_key = await user.awaitable_attrs.totp_key
|
||
verify_method = "totp" if totp_key is not None and api_version >= SUPPORT_TOTP_VERIFICATION_VER else "mail"
|
||
|
||
# 设置选择的验证方法到Redis中,避免重复选择
|
||
if api_version >= 20250913:
|
||
await LoginSessionService.set_login_method(user.id, token.id, verify_method, redis)
|
||
|
||
# 返回符合 osu! API 标准的错误响应
|
||
error_response = {"error": "User not verified", "method": verify_method}
|
||
raise HTTPException(status_code=401, detail=error_response)
|
||
return user
|
||
|
||
|
||
async def get_current_user_and_token(
|
||
db: Database,
|
||
security_scopes: SecurityScopes,
|
||
token_pw: Annotated[str | None, Depends(oauth2_password)] = None,
|
||
token_code: Annotated[str | None, Depends(oauth2_code)] = None,
|
||
token_client_credentials: Annotated[str | None, Depends(oauth2_client_credentials)] = None,
|
||
) -> UserAndToken:
|
||
"""获取当前认证用户"""
|
||
token = token_pw or token_code or token_client_credentials
|
||
if not token:
|
||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||
|
||
token_record = await get_token_by_access_token(db, token)
|
||
if not token_record:
|
||
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
||
|
||
is_client = token_record.client_id in (
|
||
settings.osu_client_id,
|
||
settings.osu_web_client_id,
|
||
)
|
||
|
||
if not is_client:
|
||
for scope in security_scopes.scopes:
|
||
if scope not in token_record.scope.split(","):
|
||
raise HTTPException(status_code=403, detail=f"Insufficient scope: {scope}")
|
||
|
||
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
|
||
if not user:
|
||
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
||
return user, token_record
|
||
|
||
|
||
async def get_current_user(
|
||
user_and_token: UserAndToken = Depends(get_current_user_and_token),
|
||
) -> User:
|
||
return user_and_token[0]
|
||
|
||
|
||
ClientUser = Annotated[User, Security(get_client_user, scopes=["*"])]
|