feat(developer): support custom OAuth 2.0 client

This commit is contained in:
MingxuanGame
2025-08-11 12:33:31 +00:00
parent ee9381d1f0
commit 6e71141146
21 changed files with 380 additions and 82 deletions

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
from datetime import UTC, datetime, timedelta
import re
from typing import Literal
from app.auth import (
authenticate_user,
@@ -9,12 +10,14 @@ from app.auth import (
generate_refresh_token,
get_password_hash,
get_token_by_refresh_token,
get_user_by_authorization_code,
store_token,
)
from app.config import settings
from app.database import DailyChallengeStats, User
from app.database import DailyChallengeStats, OAuthClient, User
from app.database.statistics import UserStatistics
from app.dependencies import get_db
from app.dependencies.database import get_redis
from app.log import logger
from app.models.oauth import (
OAuthErrorResponse,
@@ -26,6 +29,7 @@ from app.models.score import GameMode
from fastapi import APIRouter, Depends, Form
from fastapi.responses import JSONResponse
from redis.asyncio import Redis
from sqlalchemy import text
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -195,21 +199,36 @@ async def register_user(
@router.post("/oauth/token", response_model=TokenResponse)
async def oauth_token(
grant_type: str = Form(...),
client_id: str = Form(...),
grant_type: Literal[
"authorization_code", "refresh_token", "password", "client_credentials"
] = Form(...),
client_id: int = Form(...),
client_secret: str = Form(...),
code: str | None = Form(None),
scope: str = Form("*"),
username: str | None = Form(None),
password: str | None = Form(None),
refresh_token: str | None = Form(None),
db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
):
"""OAuth 令牌端点"""
# 验证客户端凭据
if (
client_id != settings.osu_client_id
or client_secret != settings.osu_client_secret
):
scopes = scope.split(" ")
client = (
await db.exec(
select(OAuthClient).where(
OAuthClient.client_id == client_id,
OAuthClient.client_secret == client_secret,
)
)
).first()
is_game_client = (client_id, client_secret) in [
(settings.osu_client_id, settings.osu_client_secret),
(settings.osu_web_client_id, settings.osu_web_client_secret),
]
if client is None and not is_game_client:
return create_oauth_error_response(
error="invalid_client",
description=(
@@ -222,7 +241,6 @@ async def oauth_token(
)
if grant_type == "password":
# 密码授权流程
if not username or not password:
return create_oauth_error_response(
error="invalid_request",
@@ -233,6 +251,16 @@ async def oauth_token(
),
hint="Username and password required",
)
if scopes != ["*"]:
return create_oauth_error_response(
error="invalid_scope",
description=(
"The requested scope is invalid, unknown, "
"or malformed. The client may not request "
"more than one scope at a time."
),
hint="Only '*' scope is allowed for password grant type",
)
# 验证用户
user = await authenticate_user(db, username, password)
@@ -261,6 +289,8 @@ async def oauth_token(
await store_token(
db,
user.id,
client_id,
scopes,
access_token,
refresh_token_str,
settings.access_token_expire_minutes * 60,
@@ -313,6 +343,8 @@ async def oauth_token(
await store_token(
db,
token_record.user_id,
client_id,
scopes,
access_token,
new_refresh_token,
settings.access_token_expire_minutes * 60,
@@ -325,7 +357,69 @@ async def oauth_token(
refresh_token=new_refresh_token,
scope=scope,
)
elif grant_type == "authorization_code":
if client is None:
return create_oauth_error_response(
error="invalid_client",
description=(
"Client authentication failed (e.g., unknown client, "
"no client authentication included, "
"or unsupported authentication method)."
),
hint="Invalid client credentials",
status_code=401,
)
if not code:
return create_oauth_error_response(
error="invalid_request",
description=(
"The request is missing a required parameter, "
"includes an invalid parameter value, "
"includes a parameter more than once, or is otherwise malformed."
),
hint="Authorization code required",
)
code_result = await get_user_by_authorization_code(db, redis, client_id, code)
if not code_result:
return create_oauth_error_response(
error="invalid_grant",
description=(
"The provided authorization grant (e.g., authorization code, "
"resource owner credentials) or refresh token is invalid, "
"expired, revoked, does not match the redirection URI used in "
"the authorization request, or was issued to another client."
),
hint="Invalid authorization code",
)
user, scopes = code_result
# 生成令牌
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
access_token = create_access_token(
data={"sub": str(user.id)}, expires_delta=access_token_expires
)
refresh_token_str = generate_refresh_token()
# 存储令牌
assert user.id
await store_token(
db,
user.id,
client_id,
scopes,
access_token,
refresh_token_str,
settings.access_token_expire_minutes * 60,
)
return TokenResponse(
access_token=access_token,
token_type="Bearer",
expires_in=settings.access_token_expire_minutes * 60,
refresh_token=refresh_token_str,
scope=" ".join(scopes),
)
else:
return create_oauth_error_response(
error="unsupported_grant_type",