feat(oauth): support client credentials grant
This commit is contained in:
@@ -5,6 +5,7 @@ from typing import Annotated
|
|||||||
from app.auth import get_token_by_access_token
|
from app.auth import get_token_by_access_token
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.database import User
|
from app.database import User
|
||||||
|
from app.models.oauth import OAuth2ClientCredentialsBearer
|
||||||
|
|
||||||
from .database import get_db
|
from .database import get_db
|
||||||
|
|
||||||
@@ -47,6 +48,16 @@ oauth2_code = OAuth2AuthorizationCodeBearer(
|
|||||||
scheme_name="Authorization Code Grant",
|
scheme_name="Authorization Code Grant",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
oauth2_client_credentials = OAuth2ClientCredentialsBearer(
|
||||||
|
tokenUrl="oauth/token",
|
||||||
|
refreshUrl="oauth/token",
|
||||||
|
scopes={
|
||||||
|
"public": "允许读取公开数据。",
|
||||||
|
},
|
||||||
|
description="osu! OAuth 认证 (客户端凭证流)",
|
||||||
|
scheme_name="Client Credentials Grant",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def get_client_user(
|
async def get_client_user(
|
||||||
token: Annotated[str, Depends(oauth2_password)],
|
token: Annotated[str, Depends(oauth2_password)],
|
||||||
@@ -67,9 +78,12 @@ async def get_current_user(
|
|||||||
db: Annotated[AsyncSession, Depends(get_db)],
|
db: Annotated[AsyncSession, Depends(get_db)],
|
||||||
token_pw: Annotated[str | None, Depends(oauth2_password)] = None,
|
token_pw: Annotated[str | None, Depends(oauth2_password)] = None,
|
||||||
token_code: Annotated[str | None, Depends(oauth2_code)] = None,
|
token_code: Annotated[str | None, Depends(oauth2_code)] = None,
|
||||||
|
token_client_credentials: Annotated[
|
||||||
|
str | None, Depends(oauth2_client_credentials)
|
||||||
|
] = None,
|
||||||
) -> User:
|
) -> User:
|
||||||
"""获取当前认证用户"""
|
"""获取当前认证用户"""
|
||||||
token = token_pw or token_code
|
token = token_pw or token_code or token_client_credentials
|
||||||
if not token:
|
if not token:
|
||||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,13 @@
|
|||||||
# OAuth 相关模型
|
# OAuth 相关模型 # noqa: I002
|
||||||
from __future__ import annotations
|
from typing import Annotated, Any, cast
|
||||||
|
from typing_extensions import Doc
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Request
|
||||||
|
from fastapi.openapi.models import OAuthFlows
|
||||||
|
from fastapi.security import OAuth2
|
||||||
|
from fastapi.security.utils import get_authorization_scheme_param
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from starlette.status import HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
|
||||||
class TokenRequest(BaseModel):
|
class TokenRequest(BaseModel):
|
||||||
@@ -56,3 +62,106 @@ class RegistrationRequestErrors(BaseModel):
|
|||||||
message: str | None = None
|
message: str | None = None
|
||||||
redirect: str | None = None
|
redirect: str | None = None
|
||||||
user: UserRegistrationErrors | None = None
|
user: UserRegistrationErrors | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth2ClientCredentialsBearer(OAuth2):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenUrl: Annotated[
|
||||||
|
str,
|
||||||
|
Doc(
|
||||||
|
"""
|
||||||
|
The URL to obtain the OAuth2 token.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
],
|
||||||
|
refreshUrl: Annotated[
|
||||||
|
str | None,
|
||||||
|
Doc(
|
||||||
|
"""
|
||||||
|
The URL to refresh the token and obtain a new one.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
scheme_name: Annotated[
|
||||||
|
str | None,
|
||||||
|
Doc(
|
||||||
|
"""
|
||||||
|
Security scheme name.
|
||||||
|
|
||||||
|
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
scopes: Annotated[
|
||||||
|
dict[str, str] | None,
|
||||||
|
Doc(
|
||||||
|
"""
|
||||||
|
The OAuth2 scopes that would be required by the *path operations* that
|
||||||
|
use this dependency.
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
description: Annotated[
|
||||||
|
str | None,
|
||||||
|
Doc(
|
||||||
|
"""
|
||||||
|
Security scheme description.
|
||||||
|
|
||||||
|
It will be included in the generated OpenAPI (e.g. visible at `/docs`).
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
auto_error: Annotated[
|
||||||
|
bool,
|
||||||
|
Doc(
|
||||||
|
"""
|
||||||
|
By default, if no HTTP Authorization header is provided, required for
|
||||||
|
OAuth2 authentication, it will automatically cancel the request and
|
||||||
|
send the client an error.
|
||||||
|
|
||||||
|
If `auto_error` is set to `False`, when the HTTP Authorization header
|
||||||
|
is not available, instead of erroring out, the dependency result will
|
||||||
|
be `None`.
|
||||||
|
|
||||||
|
This is useful when you want to have optional authentication.
|
||||||
|
|
||||||
|
It is also useful when you want to have authentication that can be
|
||||||
|
provided in one of multiple optional ways (for example, with OAuth2
|
||||||
|
or in a cookie).
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
] = True,
|
||||||
|
):
|
||||||
|
if not scopes:
|
||||||
|
scopes = {}
|
||||||
|
flows = OAuthFlows(
|
||||||
|
clientCredentials=cast(
|
||||||
|
Any,
|
||||||
|
{
|
||||||
|
"tokenUrl": tokenUrl,
|
||||||
|
"refreshUrl": refreshUrl,
|
||||||
|
"scopes": scopes,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
super().__init__(
|
||||||
|
flows=flows,
|
||||||
|
scheme_name=scheme_name,
|
||||||
|
description=description,
|
||||||
|
auto_error=auto_error,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def __call__(self, request: Request) -> str | None:
|
||||||
|
authorization = request.headers.get("Authorization")
|
||||||
|
scheme, param = get_authorization_scheme_param(authorization)
|
||||||
|
if not authorization or scheme.lower() != "bearer":
|
||||||
|
if self.auto_error:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Not authenticated",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return None # pragma: nocover
|
||||||
|
return param
|
||||||
|
|||||||
@@ -428,12 +428,48 @@ async def oauth_token(
|
|||||||
refresh_token=refresh_token_str,
|
refresh_token=refresh_token_str,
|
||||||
scope=" ".join(scopes),
|
scope=" ".join(scopes),
|
||||||
)
|
)
|
||||||
else:
|
elif grant_type == "client_credentials":
|
||||||
return create_oauth_error_response(
|
if client is None:
|
||||||
error="unsupported_grant_type",
|
return create_oauth_error_response(
|
||||||
description=(
|
error="invalid_client",
|
||||||
"The authorization grant type is not supported "
|
description=(
|
||||||
"by the authorization server."
|
"Client authentication failed (e.g., unknown client, "
|
||||||
),
|
"no client authentication included, "
|
||||||
hint="Unsupported grant type",
|
"or unsupported authentication method)."
|
||||||
|
),
|
||||||
|
hint="Invalid client credentials",
|
||||||
|
status_code=401,
|
||||||
|
)
|
||||||
|
elif scopes != ["public"]:
|
||||||
|
return create_oauth_error_response(
|
||||||
|
error="invalid_scope",
|
||||||
|
description="The requested scope is invalid, unknown, or malformed.",
|
||||||
|
hint="Scope must be 'public'",
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 生成令牌
|
||||||
|
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||||
|
access_token = create_access_token(
|
||||||
|
data={"sub": "3"}, expires_delta=access_token_expires
|
||||||
|
)
|
||||||
|
refresh_token_str = generate_refresh_token()
|
||||||
|
|
||||||
|
# 存储令牌
|
||||||
|
await store_token(
|
||||||
|
db,
|
||||||
|
3,
|
||||||
|
client_id,
|
||||||
|
scopes,
|
||||||
|
access_token,
|
||||||
|
refresh_token_str,
|
||||||
|
settings.access_token_expire_minutes * 60,
|
||||||
|
)
|
||||||
|
|
||||||
|
return TokenResponse(
|
||||||
|
access_token=access_token,
|
||||||
|
token_type="Bearer",
|
||||||
|
expires_in=settings.access_token_expire_minutes * 60,
|
||||||
|
refresh_token=refresh_token_str,
|
||||||
|
scope=" ".join(scopes),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ async def create_oauth_app(
|
|||||||
if next_id < 10:
|
if next_id < 10:
|
||||||
await session.execute(text("ALTER TABLE oauth_clients AUTO_INCREMENT = 10"))
|
await session.execute(text("ALTER TABLE oauth_clients AUTO_INCREMENT = 10"))
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
await session.refresh(current_user)
|
||||||
|
|
||||||
oauth_client = OAuthClient(
|
oauth_client = OAuthClient(
|
||||||
name=name,
|
name=name,
|
||||||
|
|||||||
Reference in New Issue
Block a user