feat(developer): support custom OAuth 2.0 client
This commit is contained in:
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
import re
|
||||
from typing import Literal
|
||||
|
||||
from app.auth import (
|
||||
authenticate_user,
|
||||
@@ -9,12 +10,14 @@ from app.auth import (
|
||||
generate_refresh_token,
|
||||
get_password_hash,
|
||||
get_token_by_refresh_token,
|
||||
get_user_by_authorization_code,
|
||||
store_token,
|
||||
)
|
||||
from app.config import settings
|
||||
from app.database import DailyChallengeStats, User
|
||||
from app.database import DailyChallengeStats, OAuthClient, User
|
||||
from app.database.statistics import UserStatistics
|
||||
from app.dependencies import get_db
|
||||
from app.dependencies.database import get_redis
|
||||
from app.log import logger
|
||||
from app.models.oauth import (
|
||||
OAuthErrorResponse,
|
||||
@@ -26,6 +29,7 @@ from app.models.score import GameMode
|
||||
|
||||
from fastapi import APIRouter, Depends, Form
|
||||
from fastapi.responses import JSONResponse
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy import text
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -195,21 +199,36 @@ async def register_user(
|
||||
|
||||
@router.post("/oauth/token", response_model=TokenResponse)
|
||||
async def oauth_token(
|
||||
grant_type: str = Form(...),
|
||||
client_id: str = Form(...),
|
||||
grant_type: Literal[
|
||||
"authorization_code", "refresh_token", "password", "client_credentials"
|
||||
] = Form(...),
|
||||
client_id: int = Form(...),
|
||||
client_secret: str = Form(...),
|
||||
code: str | None = Form(None),
|
||||
scope: str = Form("*"),
|
||||
username: str | None = Form(None),
|
||||
password: str | None = Form(None),
|
||||
refresh_token: str | None = Form(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
"""OAuth 令牌端点"""
|
||||
# 验证客户端凭据
|
||||
if (
|
||||
client_id != settings.osu_client_id
|
||||
or client_secret != settings.osu_client_secret
|
||||
):
|
||||
scopes = scope.split(" ")
|
||||
|
||||
client = (
|
||||
await db.exec(
|
||||
select(OAuthClient).where(
|
||||
OAuthClient.client_id == client_id,
|
||||
OAuthClient.client_secret == client_secret,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
is_game_client = (client_id, client_secret) in [
|
||||
(settings.osu_client_id, settings.osu_client_secret),
|
||||
(settings.osu_web_client_id, settings.osu_web_client_secret),
|
||||
]
|
||||
|
||||
if client is None and not is_game_client:
|
||||
return create_oauth_error_response(
|
||||
error="invalid_client",
|
||||
description=(
|
||||
@@ -222,7 +241,6 @@ async def oauth_token(
|
||||
)
|
||||
|
||||
if grant_type == "password":
|
||||
# 密码授权流程
|
||||
if not username or not password:
|
||||
return create_oauth_error_response(
|
||||
error="invalid_request",
|
||||
@@ -233,6 +251,16 @@ async def oauth_token(
|
||||
),
|
||||
hint="Username and password required",
|
||||
)
|
||||
if scopes != ["*"]:
|
||||
return create_oauth_error_response(
|
||||
error="invalid_scope",
|
||||
description=(
|
||||
"The requested scope is invalid, unknown, "
|
||||
"or malformed. The client may not request "
|
||||
"more than one scope at a time."
|
||||
),
|
||||
hint="Only '*' scope is allowed for password grant type",
|
||||
)
|
||||
|
||||
# 验证用户
|
||||
user = await authenticate_user(db, username, password)
|
||||
@@ -261,6 +289,8 @@ async def oauth_token(
|
||||
await store_token(
|
||||
db,
|
||||
user.id,
|
||||
client_id,
|
||||
scopes,
|
||||
access_token,
|
||||
refresh_token_str,
|
||||
settings.access_token_expire_minutes * 60,
|
||||
@@ -313,6 +343,8 @@ async def oauth_token(
|
||||
await store_token(
|
||||
db,
|
||||
token_record.user_id,
|
||||
client_id,
|
||||
scopes,
|
||||
access_token,
|
||||
new_refresh_token,
|
||||
settings.access_token_expire_minutes * 60,
|
||||
@@ -325,7 +357,69 @@ async def oauth_token(
|
||||
refresh_token=new_refresh_token,
|
||||
scope=scope,
|
||||
)
|
||||
elif grant_type == "authorization_code":
|
||||
if client is None:
|
||||
return create_oauth_error_response(
|
||||
error="invalid_client",
|
||||
description=(
|
||||
"Client authentication failed (e.g., unknown client, "
|
||||
"no client authentication included, "
|
||||
"or unsupported authentication method)."
|
||||
),
|
||||
hint="Invalid client credentials",
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
if not code:
|
||||
return create_oauth_error_response(
|
||||
error="invalid_request",
|
||||
description=(
|
||||
"The request is missing a required parameter, "
|
||||
"includes an invalid parameter value, "
|
||||
"includes a parameter more than once, or is otherwise malformed."
|
||||
),
|
||||
hint="Authorization code required",
|
||||
)
|
||||
|
||||
code_result = await get_user_by_authorization_code(db, redis, client_id, code)
|
||||
if not code_result:
|
||||
return create_oauth_error_response(
|
||||
error="invalid_grant",
|
||||
description=(
|
||||
"The provided authorization grant (e.g., authorization code, "
|
||||
"resource owner credentials) or refresh token is invalid, "
|
||||
"expired, revoked, does not match the redirection URI used in "
|
||||
"the authorization request, or was issued to another client."
|
||||
),
|
||||
hint="Invalid authorization code",
|
||||
)
|
||||
user, scopes = code_result
|
||||
# 生成令牌
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
access_token = create_access_token(
|
||||
data={"sub": str(user.id)}, expires_delta=access_token_expires
|
||||
)
|
||||
refresh_token_str = generate_refresh_token()
|
||||
|
||||
# 存储令牌
|
||||
assert user.id
|
||||
await store_token(
|
||||
db,
|
||||
user.id,
|
||||
client_id,
|
||||
scopes,
|
||||
access_token,
|
||||
refresh_token_str,
|
||||
settings.access_token_expire_minutes * 60,
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
token_type="Bearer",
|
||||
expires_in=settings.access_token_expire_minutes * 60,
|
||||
refresh_token=refresh_token_str,
|
||||
scope=" ".join(scopes),
|
||||
)
|
||||
else:
|
||||
return create_oauth_error_response(
|
||||
error="unsupported_grant_type",
|
||||
|
||||
Reference in New Issue
Block a user