feat(oauth): support client credentials grant

This commit is contained in:
MingxuanGame
2025-08-13 14:12:29 +00:00
parent 7a6a548a65
commit 7817b7c59a
4 changed files with 171 additions and 11 deletions

View File

@@ -5,6 +5,7 @@ from typing import Annotated
from app.auth import get_token_by_access_token from app.auth import get_token_by_access_token
from app.config import settings from app.config import settings
from app.database import User from app.database import User
from app.models.oauth import OAuth2ClientCredentialsBearer
from .database import get_db from .database import get_db
@@ -47,6 +48,16 @@ oauth2_code = OAuth2AuthorizationCodeBearer(
scheme_name="Authorization Code Grant", scheme_name="Authorization Code Grant",
) )
oauth2_client_credentials = OAuth2ClientCredentialsBearer(
tokenUrl="oauth/token",
refreshUrl="oauth/token",
scopes={
"public": "允许读取公开数据。",
},
description="osu! OAuth 认证 (客户端凭证流)",
scheme_name="Client Credentials Grant",
)
async def get_client_user( async def get_client_user(
token: Annotated[str, Depends(oauth2_password)], token: Annotated[str, Depends(oauth2_password)],
@@ -67,9 +78,12 @@ async def get_current_user(
db: Annotated[AsyncSession, Depends(get_db)], db: Annotated[AsyncSession, Depends(get_db)],
token_pw: Annotated[str | None, Depends(oauth2_password)] = None, token_pw: Annotated[str | None, Depends(oauth2_password)] = None,
token_code: Annotated[str | None, Depends(oauth2_code)] = None, token_code: Annotated[str | None, Depends(oauth2_code)] = None,
token_client_credentials: Annotated[
str | None, Depends(oauth2_client_credentials)
] = None,
) -> User: ) -> User:
"""获取当前认证用户""" """获取当前认证用户"""
token = token_pw or token_code token = token_pw or token_code or token_client_credentials
if not token: if not token:
raise HTTPException(status_code=401, detail="Not authenticated") raise HTTPException(status_code=401, detail="Not authenticated")

View File

@@ -1,7 +1,13 @@
# OAuth 相关模型 # OAuth 相关模型 # noqa: I002
from __future__ import annotations from typing import Annotated, Any, cast
from typing_extensions import Doc
from fastapi import HTTPException, Request
from fastapi.openapi.models import OAuthFlows
from fastapi.security import OAuth2
from fastapi.security.utils import get_authorization_scheme_param
from pydantic import BaseModel from pydantic import BaseModel
from starlette.status import HTTP_401_UNAUTHORIZED
class TokenRequest(BaseModel): class TokenRequest(BaseModel):
@@ -56,3 +62,106 @@ class RegistrationRequestErrors(BaseModel):
message: str | None = None message: str | None = None
redirect: str | None = None redirect: str | None = None
user: UserRegistrationErrors | None = None user: UserRegistrationErrors | None = None
class OAuth2ClientCredentialsBearer(OAuth2):
def __init__(
self,
tokenUrl: Annotated[
str,
Doc(
"""
The URL to obtain the OAuth2 token.
"""
),
],
refreshUrl: Annotated[
str | None,
Doc(
"""
The URL to refresh the token and obtain a new one.
"""
),
] = None,
scheme_name: Annotated[
str | None,
Doc(
"""
Security scheme name.
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = None,
scopes: Annotated[
dict[str, str] | None,
Doc(
"""
The OAuth2 scopes that would be required by the *path operations* that
use this dependency.
"""
),
] = None,
description: Annotated[
str | None,
Doc(
"""
Security scheme description.
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
"""
),
] = None,
auto_error: Annotated[
bool,
Doc(
"""
By default, if no HTTP Authorization header is provided, required for
OAuth2 authentication, it will automatically cancel the request and
send the client an error.
If `auto_error` is set to `False`, when the HTTP Authorization header
is not available, instead of erroring out, the dependency result will
be `None`.
This is useful when you want to have optional authentication.
It is also useful when you want to have authentication that can be
provided in one of multiple optional ways (for example, with OAuth2
or in a cookie).
"""
),
] = True,
):
if not scopes:
scopes = {}
flows = OAuthFlows(
clientCredentials=cast(
Any,
{
"tokenUrl": tokenUrl,
"refreshUrl": refreshUrl,
"scopes": scopes,
},
)
)
super().__init__(
flows=flows,
scheme_name=scheme_name,
description=description,
auto_error=auto_error,
)
async def __call__(self, request: Request) -> str | None:
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
else:
return None # pragma: nocover
return param

View File

@@ -428,12 +428,48 @@ async def oauth_token(
refresh_token=refresh_token_str, refresh_token=refresh_token_str,
scope=" ".join(scopes), scope=" ".join(scopes),
) )
else: elif grant_type == "client_credentials":
return create_oauth_error_response( if client is None:
error="unsupported_grant_type", return create_oauth_error_response(
description=( error="invalid_client",
"The authorization grant type is not supported " description=(
"by the authorization server." "Client authentication failed (e.g., unknown client, "
), "no client authentication included, "
hint="Unsupported grant type", "or unsupported authentication method)."
),
hint="Invalid client credentials",
status_code=401,
)
elif scopes != ["public"]:
return create_oauth_error_response(
error="invalid_scope",
description="The requested scope is invalid, unknown, or malformed.",
hint="Scope must be 'public'",
status_code=400,
)
# 生成令牌
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
access_token = create_access_token(
data={"sub": "3"}, expires_delta=access_token_expires
)
refresh_token_str = generate_refresh_token()
# 存储令牌
await store_token(
db,
3,
client_id,
scopes,
access_token,
refresh_token_str,
settings.access_token_expire_minutes * 60,
)
return TokenResponse(
access_token=access_token,
token_type="Bearer",
expires_in=settings.access_token_expire_minutes * 60,
refresh_token=refresh_token_str,
scope=" ".join(scopes),
) )

View File

@@ -37,6 +37,7 @@ async def create_oauth_app(
if next_id < 10: if next_id < 10:
await session.execute(text("ALTER TABLE oauth_clients AUTO_INCREMENT = 10")) await session.execute(text("ALTER TABLE oauth_clients AUTO_INCREMENT = 10"))
await session.commit() await session.commit()
await session.refresh(current_user)
oauth_client = OAuthClient( oauth_client = OAuthClient(
name=name, name=name,