feat(developer): support custom OAuth 2.0 client
This commit is contained in:
24
app/auth.py
24
app/auth.py
@@ -15,6 +15,7 @@ from app.log import logger
|
||||
import bcrypt
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -156,6 +157,8 @@ def verify_token(token: str) -> dict | None:
|
||||
async def store_token(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
client_id: int,
|
||||
scopes: list[str],
|
||||
access_token: str,
|
||||
refresh_token: str,
|
||||
expires_in: int,
|
||||
@@ -164,7 +167,9 @@ async def store_token(
|
||||
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
|
||||
|
||||
# 删除用户的旧令牌
|
||||
statement = select(OAuthToken).where(OAuthToken.user_id == user_id)
|
||||
statement = select(OAuthToken).where(
|
||||
OAuthToken.user_id == user_id, OAuthToken.client_id == client_id
|
||||
)
|
||||
old_tokens = (await db.exec(statement)).all()
|
||||
for token in old_tokens:
|
||||
await db.delete(token)
|
||||
@@ -179,7 +184,9 @@ async def store_token(
|
||||
# 创建新令牌记录
|
||||
token_record = OAuthToken(
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
access_token=access_token,
|
||||
scope=",".join(scopes),
|
||||
refresh_token=refresh_token,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
@@ -209,3 +216,18 @@ async def get_token_by_refresh_token(
|
||||
OAuthToken.expires_at > datetime.utcnow(),
|
||||
)
|
||||
return (await db.exec(statement)).first()
|
||||
|
||||
|
||||
async def get_user_by_authorization_code(
|
||||
db: AsyncSession, redis: Redis, client_id: int, code: str
|
||||
) -> tuple[User, list[str]] | None:
|
||||
user_id = await redis.hget(f"oauth:code:{client_id}:{code}", "user_id") # pyright: ignore[reportGeneralTypeIssues]
|
||||
scopes = await redis.hget(f"oauth:code:{client_id}:{code}", "scopes") # pyright: ignore[reportGeneralTypeIssues]
|
||||
if not user_id or not scopes:
|
||||
return None
|
||||
|
||||
await redis.hdel(f"oauth:code:{client_id}:{code}", "user_id", "scopes") # pyright: ignore[reportGeneralTypeIssues]
|
||||
|
||||
statement = select(User).where(User.id == int(user_id))
|
||||
user = (await db.exec(statement)).first()
|
||||
return (user, scopes.split(",")) if user else None
|
||||
|
||||
Reference in New Issue
Block a user