feat(developer): support custom OAuth 2.0 client
This commit is contained in:
@@ -1,34 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from app.auth import get_token_by_access_token
|
||||
from app.config import settings
|
||||
from app.database import User
|
||||
|
||||
from .database import get_db
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from fastapi.security import (
|
||||
HTTPBearer,
|
||||
OAuth2AuthorizationCodeBearer,
|
||||
OAuth2PasswordBearer,
|
||||
SecurityScopes,
|
||||
)
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
oauth2_password = OAuth2PasswordBearer(
|
||||
tokenUrl="oauth/token",
|
||||
scopes={"*": "Allows access to all scopes."},
|
||||
)
|
||||
|
||||
oauth2_code = OAuth2AuthorizationCodeBearer(
|
||||
authorizationUrl="oauth/authorize",
|
||||
tokenUrl="oauth/token",
|
||||
scopes={
|
||||
"chat.read": "Allows read chat messages on a user's behalf.",
|
||||
"chat.write": "Allows sending chat messages on a user's behalf.",
|
||||
"chat.write_manage": (
|
||||
"Allows joining and leaving chat channels on a user's behalf."
|
||||
),
|
||||
"delegate": (
|
||||
"Allows acting as the owner of a client; "
|
||||
"only available for Client Credentials Grant."
|
||||
),
|
||||
"forum.write": "Allows creating and editing forum posts on a user's behalf.",
|
||||
"friends.read": "Allows reading of the user's friend list.",
|
||||
"identify": "Allows reading of the public profile of the user (/me).",
|
||||
"public": "Allows reading of publicly available data on behalf of the user.",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
security_scopes: SecurityScopes,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
token_pw: Annotated[str | None, Depends(oauth2_password)] = None,
|
||||
token_code: Annotated[str | None, Depends(oauth2_code)] = None,
|
||||
) -> User:
|
||||
"""获取当前认证用户"""
|
||||
token = credentials.credentials
|
||||
token = token_pw or token_code
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
user = await get_current_user_by_token(token, db)
|
||||
token_record = await get_token_by_access_token(db, token)
|
||||
if not token_record:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
||||
|
||||
is_client = token_record.client_id in (
|
||||
settings.osu_client_id,
|
||||
settings.osu_web_client_id,
|
||||
)
|
||||
|
||||
if security_scopes.scopes == ["*"]:
|
||||
# client/web only
|
||||
if not token_pw or not is_client:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
elif not is_client:
|
||||
for scope in security_scopes.scopes:
|
||||
if scope not in token_record.scope.split(","):
|
||||
raise HTTPException(
|
||||
status_code=403, detail=f"Insufficient scope: {scope}"
|
||||
)
|
||||
|
||||
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_user_by_token(token: str, db: AsyncSession) -> User | None:
|
||||
token_record = await get_token_by_access_token(db, token)
|
||||
if not token_record:
|
||||
return None
|
||||
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
|
||||
return user
|
||||
|
||||
Reference in New Issue
Block a user