feat(oauth): support client credentials grant
This commit is contained in:
@@ -5,6 +5,7 @@ from typing import Annotated
|
||||
from app.auth import get_token_by_access_token
|
||||
from app.config import settings
|
||||
from app.database import User
|
||||
from app.models.oauth import OAuth2ClientCredentialsBearer
|
||||
|
||||
from .database import get_db
|
||||
|
||||
@@ -47,6 +48,16 @@ oauth2_code = OAuth2AuthorizationCodeBearer(
|
||||
scheme_name="Authorization Code Grant",
|
||||
)
|
||||
|
||||
oauth2_client_credentials = OAuth2ClientCredentialsBearer(
|
||||
tokenUrl="oauth/token",
|
||||
refreshUrl="oauth/token",
|
||||
scopes={
|
||||
"public": "允许读取公开数据。",
|
||||
},
|
||||
description="osu! OAuth 认证 (客户端凭证流)",
|
||||
scheme_name="Client Credentials Grant",
|
||||
)
|
||||
|
||||
|
||||
async def get_client_user(
|
||||
token: Annotated[str, Depends(oauth2_password)],
|
||||
@@ -67,9 +78,12 @@ async def get_current_user(
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
token_pw: Annotated[str | None, Depends(oauth2_password)] = None,
|
||||
token_code: Annotated[str | None, Depends(oauth2_code)] = None,
|
||||
token_client_credentials: Annotated[
|
||||
str | None, Depends(oauth2_client_credentials)
|
||||
] = None,
|
||||
) -> User:
|
||||
"""获取当前认证用户"""
|
||||
token = token_pw or token_code
|
||||
token = token_pw or token_code or token_client_credentials
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
# OAuth 相关模型
|
||||
from __future__ import annotations
|
||||
# OAuth 相关模型 # noqa: I002
|
||||
from typing import Annotated, Any, cast
|
||||
from typing_extensions import Doc
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.openapi.models import OAuthFlows
|
||||
from fastapi.security import OAuth2
|
||||
from fastapi.security.utils import get_authorization_scheme_param
|
||||
from pydantic import BaseModel
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
class TokenRequest(BaseModel):
|
||||
@@ -56,3 +62,106 @@ class RegistrationRequestErrors(BaseModel):
|
||||
message: str | None = None
|
||||
redirect: str | None = None
|
||||
user: UserRegistrationErrors | None = None
|
||||
|
||||
|
||||
class OAuth2ClientCredentialsBearer(OAuth2):
|
||||
def __init__(
|
||||
self,
|
||||
tokenUrl: Annotated[
|
||||
str,
|
||||
Doc(
|
||||
"""
|
||||
The URL to obtain the OAuth2 token.
|
||||
"""
|
||||
),
|
||||
],
|
||||
refreshUrl: Annotated[
|
||||
str | None,
|
||||
Doc(
|
||||
"""
|
||||
The URL to refresh the token and obtain a new one.
|
||||
"""
|
||||
),
|
||||
] = None,
|
||||
scheme_name: Annotated[
|
||||
str | None,
|
||||
Doc(
|
||||
"""
|
||||
Security scheme name.
|
||||
|
||||
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
|
||||
"""
|
||||
),
|
||||
] = None,
|
||||
scopes: Annotated[
|
||||
dict[str, str] | None,
|
||||
Doc(
|
||||
"""
|
||||
The OAuth2 scopes that would be required by the *path operations* that
|
||||
use this dependency.
|
||||
"""
|
||||
),
|
||||
] = None,
|
||||
description: Annotated[
|
||||
str | None,
|
||||
Doc(
|
||||
"""
|
||||
Security scheme description.
|
||||
|
||||
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
|
||||
"""
|
||||
),
|
||||
] = None,
|
||||
auto_error: Annotated[
|
||||
bool,
|
||||
Doc(
|
||||
"""
|
||||
By default, if no HTTP Authorization header is provided, required for
|
||||
OAuth2 authentication, it will automatically cancel the request and
|
||||
send the client an error.
|
||||
|
||||
If `auto_error` is set to `False`, when the HTTP Authorization header
|
||||
is not available, instead of erroring out, the dependency result will
|
||||
be `None`.
|
||||
|
||||
This is useful when you want to have optional authentication.
|
||||
|
||||
It is also useful when you want to have authentication that can be
|
||||
provided in one of multiple optional ways (for example, with OAuth2
|
||||
or in a cookie).
|
||||
"""
|
||||
),
|
||||
] = True,
|
||||
):
|
||||
if not scopes:
|
||||
scopes = {}
|
||||
flows = OAuthFlows(
|
||||
clientCredentials=cast(
|
||||
Any,
|
||||
{
|
||||
"tokenUrl": tokenUrl,
|
||||
"refreshUrl": refreshUrl,
|
||||
"scopes": scopes,
|
||||
},
|
||||
)
|
||||
)
|
||||
super().__init__(
|
||||
flows=flows,
|
||||
scheme_name=scheme_name,
|
||||
description=description,
|
||||
auto_error=auto_error,
|
||||
)
|
||||
|
||||
async def __call__(self, request: Request) -> str | None:
|
||||
authorization = request.headers.get("Authorization")
|
||||
scheme, param = get_authorization_scheme_param(authorization)
|
||||
if not authorization or scheme.lower() != "bearer":
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
else:
|
||||
return None # pragma: nocover
|
||||
return param
|
||||
|
||||
@@ -428,12 +428,48 @@ async def oauth_token(
|
||||
refresh_token=refresh_token_str,
|
||||
scope=" ".join(scopes),
|
||||
)
|
||||
else:
|
||||
elif grant_type == "client_credentials":
|
||||
if client is None:
|
||||
return create_oauth_error_response(
|
||||
error="unsupported_grant_type",
|
||||
error="invalid_client",
|
||||
description=(
|
||||
"The authorization grant type is not supported "
|
||||
"by the authorization server."
|
||||
"Client authentication failed (e.g., unknown client, "
|
||||
"no client authentication included, "
|
||||
"or unsupported authentication method)."
|
||||
),
|
||||
hint="Unsupported grant type",
|
||||
hint="Invalid client credentials",
|
||||
status_code=401,
|
||||
)
|
||||
elif scopes != ["public"]:
|
||||
return create_oauth_error_response(
|
||||
error="invalid_scope",
|
||||
description="The requested scope is invalid, unknown, or malformed.",
|
||||
hint="Scope must be 'public'",
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# 生成令牌
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
access_token = create_access_token(
|
||||
data={"sub": "3"}, expires_delta=access_token_expires
|
||||
)
|
||||
refresh_token_str = generate_refresh_token()
|
||||
|
||||
# 存储令牌
|
||||
await store_token(
|
||||
db,
|
||||
3,
|
||||
client_id,
|
||||
scopes,
|
||||
access_token,
|
||||
refresh_token_str,
|
||||
settings.access_token_expire_minutes * 60,
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
token_type="Bearer",
|
||||
expires_in=settings.access_token_expire_minutes * 60,
|
||||
refresh_token=refresh_token_str,
|
||||
scope=" ".join(scopes),
|
||||
)
|
||||
|
||||
@@ -37,6 +37,7 @@ async def create_oauth_app(
|
||||
if next_id < 10:
|
||||
await session.execute(text("ALTER TABLE oauth_clients AUTO_INCREMENT = 10"))
|
||||
await session.commit()
|
||||
await session.refresh(current_user)
|
||||
|
||||
oauth_client = OAuthClient(
|
||||
name=name,
|
||||
|
||||
Reference in New Issue
Block a user