From 7817b7c59ae88962fb19ac0eb8c2bf13851a68bc Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Wed, 13 Aug 2025 14:12:29 +0000 Subject: [PATCH] feat(oauth): support client credentials grant --- app/dependencies/user.py | 16 ++++- app/models/oauth.py | 113 +++++++++++++++++++++++++++++++++++- app/router/auth.py | 52 ++++++++++++++--- app/router/private/oauth.py | 1 + 4 files changed, 171 insertions(+), 11 deletions(-) diff --git a/app/dependencies/user.py b/app/dependencies/user.py index c85ab57..284e5d2 100644 --- a/app/dependencies/user.py +++ b/app/dependencies/user.py @@ -5,6 +5,7 @@ from typing import Annotated from app.auth import get_token_by_access_token from app.config import settings from app.database import User +from app.models.oauth import OAuth2ClientCredentialsBearer from .database import get_db @@ -47,6 +48,16 @@ oauth2_code = OAuth2AuthorizationCodeBearer( scheme_name="Authorization Code Grant", ) +oauth2_client_credentials = OAuth2ClientCredentialsBearer( + tokenUrl="oauth/token", + refreshUrl="oauth/token", + scopes={ + "public": "允许读取公开数据。", + }, + description="osu! OAuth 认证 (客户端凭证流)", + scheme_name="Client Credentials Grant", +) + async def get_client_user( token: Annotated[str, Depends(oauth2_password)], @@ -67,9 +78,12 @@ async def get_current_user( db: Annotated[AsyncSession, Depends(get_db)], token_pw: Annotated[str | None, Depends(oauth2_password)] = None, token_code: Annotated[str | None, Depends(oauth2_code)] = None, + token_client_credentials: Annotated[ + str | None, Depends(oauth2_client_credentials) + ] = None, ) -> User: """获取当前认证用户""" - token = token_pw or token_code + token = token_pw or token_code or token_client_credentials if not token: raise HTTPException(status_code=401, detail="Not authenticated") diff --git a/app/models/oauth.py b/app/models/oauth.py index 6665965..f3db41f 100644 --- a/app/models/oauth.py +++ b/app/models/oauth.py @@ -1,7 +1,13 @@ -# OAuth 相关模型 -from __future__ import annotations +# OAuth 相关模型 # noqa: I002 +from typing import Annotated, Any, cast +from typing_extensions import Doc +from fastapi import HTTPException, Request +from fastapi.openapi.models import OAuthFlows +from fastapi.security import OAuth2 +from fastapi.security.utils import get_authorization_scheme_param from pydantic import BaseModel +from starlette.status import HTTP_401_UNAUTHORIZED class TokenRequest(BaseModel): @@ -56,3 +62,106 @@ class RegistrationRequestErrors(BaseModel): message: str | None = None redirect: str | None = None user: UserRegistrationErrors | None = None + + +class OAuth2ClientCredentialsBearer(OAuth2): + def __init__( + self, + tokenUrl: Annotated[ + str, + Doc( + """ + The URL to obtain the OAuth2 token. + """ + ), + ], + refreshUrl: Annotated[ + str | None, + Doc( + """ + The URL to refresh the token and obtain a new one. + """ + ), + ] = None, + scheme_name: Annotated[ + str | None, + Doc( + """ + Security scheme name. + + It will be included in the generated OpenAPI (e.g. visible at `/docs`). + """ + ), + ] = None, + scopes: Annotated[ + dict[str, str] | None, + Doc( + """ + The OAuth2 scopes that would be required by the *path operations* that + use this dependency. + """ + ), + ] = None, + description: Annotated[ + str | None, + Doc( + """ + Security scheme description. + + It will be included in the generated OpenAPI (e.g. visible at `/docs`). + """ + ), + ] = None, + auto_error: Annotated[ + bool, + Doc( + """ + By default, if no HTTP Authorization header is provided, required for + OAuth2 authentication, it will automatically cancel the request and + send the client an error. + + If `auto_error` is set to `False`, when the HTTP Authorization header + is not available, instead of erroring out, the dependency result will + be `None`. + + This is useful when you want to have optional authentication. + + It is also useful when you want to have authentication that can be + provided in one of multiple optional ways (for example, with OAuth2 + or in a cookie). + """ + ), + ] = True, + ): + if not scopes: + scopes = {} + flows = OAuthFlows( + clientCredentials=cast( + Any, + { + "tokenUrl": tokenUrl, + "refreshUrl": refreshUrl, + "scopes": scopes, + }, + ) + ) + super().__init__( + flows=flows, + scheme_name=scheme_name, + description=description, + auto_error=auto_error, + ) + + async def __call__(self, request: Request) -> str | None: + authorization = request.headers.get("Authorization") + scheme, param = get_authorization_scheme_param(authorization) + if not authorization or scheme.lower() != "bearer": + if self.auto_error: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={"WWW-Authenticate": "Bearer"}, + ) + else: + return None # pragma: nocover + return param diff --git a/app/router/auth.py b/app/router/auth.py index 7251232..2466662 100644 --- a/app/router/auth.py +++ b/app/router/auth.py @@ -428,12 +428,48 @@ async def oauth_token( refresh_token=refresh_token_str, scope=" ".join(scopes), ) - else: - return create_oauth_error_response( - error="unsupported_grant_type", - description=( - "The authorization grant type is not supported " - "by the authorization server." - ), - hint="Unsupported grant type", + elif grant_type == "client_credentials": + if client is None: + return create_oauth_error_response( + error="invalid_client", + description=( + "Client authentication failed (e.g., unknown client, " + "no client authentication included, " + "or unsupported authentication method)." + ), + hint="Invalid client credentials", + status_code=401, + ) + elif scopes != ["public"]: + return create_oauth_error_response( + error="invalid_scope", + description="The requested scope is invalid, unknown, or malformed.", + hint="Scope must be 'public'", + status_code=400, + ) + + # 生成令牌 + access_token_expires = timedelta(minutes=settings.access_token_expire_minutes) + access_token = create_access_token( + data={"sub": "3"}, expires_delta=access_token_expires + ) + refresh_token_str = generate_refresh_token() + + # 存储令牌 + await store_token( + db, + 3, + client_id, + scopes, + access_token, + refresh_token_str, + settings.access_token_expire_minutes * 60, + ) + + return TokenResponse( + access_token=access_token, + token_type="Bearer", + expires_in=settings.access_token_expire_minutes * 60, + refresh_token=refresh_token_str, + scope=" ".join(scopes), ) diff --git a/app/router/private/oauth.py b/app/router/private/oauth.py index 88f8587..d3950e3 100644 --- a/app/router/private/oauth.py +++ b/app/router/private/oauth.py @@ -37,6 +37,7 @@ async def create_oauth_app( if next_id < 10: await session.execute(text("ALTER TABLE oauth_clients AUTO_INCREMENT = 10")) await session.commit() + await session.refresh(current_user) oauth_client = OAuthClient( name=name,