feat(oauth): support client credentials grant

This commit is contained in:
MingxuanGame
2025-08-13 14:12:29 +00:00
parent 7a6a548a65
commit 7817b7c59a
4 changed files with 171 additions and 11 deletions

View File

@@ -5,6 +5,7 @@ from typing import Annotated
from app.auth import get_token_by_access_token
from app.config import settings
from app.database import User
from app.models.oauth import OAuth2ClientCredentialsBearer
from .database import get_db
@@ -47,6 +48,16 @@ oauth2_code = OAuth2AuthorizationCodeBearer(
scheme_name="Authorization Code Grant",
)
oauth2_client_credentials = OAuth2ClientCredentialsBearer(
tokenUrl="oauth/token",
refreshUrl="oauth/token",
scopes={
"public": "允许读取公开数据。",
},
description="osu! OAuth 认证 (客户端凭证流)",
scheme_name="Client Credentials Grant",
)
async def get_client_user(
token: Annotated[str, Depends(oauth2_password)],
@@ -67,9 +78,12 @@ async def get_current_user(
db: Annotated[AsyncSession, Depends(get_db)],
token_pw: Annotated[str | None, Depends(oauth2_password)] = None,
token_code: Annotated[str | None, Depends(oauth2_code)] = None,
token_client_credentials: Annotated[
str | None, Depends(oauth2_client_credentials)
] = None,
) -> User:
"""获取当前认证用户"""
token = token_pw or token_code
token = token_pw or token_code or token_client_credentials
if not token:
raise HTTPException(status_code=401, detail="Not authenticated")

View File

@@ -1,7 +1,13 @@
# OAuth 相关模型
from __future__ import annotations
# OAuth 相关模型 # noqa: I002
from typing import Annotated, Any, cast
from typing_extensions import Doc
from fastapi import HTTPException, Request
from fastapi.openapi.models import OAuthFlows
from fastapi.security import OAuth2
from fastapi.security.utils import get_authorization_scheme_param
from pydantic import BaseModel
from starlette.status import HTTP_401_UNAUTHORIZED
class TokenRequest(BaseModel):
@@ -56,3 +62,106 @@ class RegistrationRequestErrors(BaseModel):
message: str | None = None
redirect: str | None = None
user: UserRegistrationErrors | None = None
class OAuth2ClientCredentialsBearer(OAuth2):
def __init__(
self,
tokenUrl: Annotated[
str,
Doc(
"""
The URL to obtain the OAuth2 token.
"""
),
],
refreshUrl: Annotated[
str | None,
Doc(
"""
The URL to refresh the token and obtain a new one.
"""
),
] = None,
scheme_name: Annotated[
str | None,
Doc(
"""
Security scheme name.
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = None,
scopes: Annotated[
dict[str, str] | None,
Doc(
"""
The OAuth2 scopes that would be required by the *path operations* that
use this dependency.
"""
),
] = None,
description: Annotated[
str | None,
Doc(
"""
Security scheme description.
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = None,
auto_error: Annotated[
bool,
Doc(
"""
By default, if no HTTP Authorization header is provided, required for
OAuth2 authentication, it will automatically cancel the request and
send the client an error.
If `auto_error` is set to `False`, when the HTTP Authorization header
is not available, instead of erroring out, the dependency result will
be `None`.
This is useful when you want to have optional authentication.
It is also useful when you want to have authentication that can be
provided in one of multiple optional ways (for example, with OAuth2
or in a cookie).
"""
),
] = True,
):
if not scopes:
scopes = {}
flows = OAuthFlows(
clientCredentials=cast(
Any,
{
"tokenUrl": tokenUrl,
"refreshUrl": refreshUrl,
"scopes": scopes,
},
)
)
super().__init__(
flows=flows,
scheme_name=scheme_name,
description=description,
auto_error=auto_error,
)
async def __call__(self, request: Request) -> str | None:
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
else:
return None # pragma: nocover
return param

View File

@@ -428,12 +428,48 @@ async def oauth_token(
refresh_token=refresh_token_str,
scope=" ".join(scopes),
)
else:
return create_oauth_error_response(
error="unsupported_grant_type",
description=(
"The authorization grant type is not supported "
"by the authorization server."
),
hint="Unsupported grant type",
elif grant_type == "client_credentials":
if client is None:
return create_oauth_error_response(
error="invalid_client",
description=(
"Client authentication failed (e.g., unknown client, "
"no client authentication included, "
"or unsupported authentication method)."
),
hint="Invalid client credentials",
status_code=401,
)
elif scopes != ["public"]:
return create_oauth_error_response(
error="invalid_scope",
description="The requested scope is invalid, unknown, or malformed.",
hint="Scope must be 'public'",
status_code=400,
)
# 生成令牌
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
access_token = create_access_token(
data={"sub": "3"}, expires_delta=access_token_expires
)
refresh_token_str = generate_refresh_token()
# 存储令牌
await store_token(
db,
3,
client_id,
scopes,
access_token,
refresh_token_str,
settings.access_token_expire_minutes * 60,
)
return TokenResponse(
access_token=access_token,
token_type="Bearer",
expires_in=settings.access_token_expire_minutes * 60,
refresh_token=refresh_token_str,
scope=" ".join(scopes),
)

View File

@@ -37,6 +37,7 @@ async def create_oauth_app(
if next_id < 10:
await session.execute(text("ALTER TABLE oauth_clients AUTO_INCREMENT = 10"))
await session.commit()
await session.refresh(current_user)
oauth_client = OAuthClient(
name=name,