feat(private-api): support OAuth client operations
This commit is contained in:
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
|
||||
from app.models.model import UTCBaseModel
|
||||
|
||||
from sqlalchemy import Column, DateTime
|
||||
from sqlmodel import JSON, BigInteger, Field, ForeignKey, Relationship, SQLModel
|
||||
from sqlmodel import JSON, BigInteger, Field, ForeignKey, Relationship, SQLModel, Text
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .lazer_user import User
|
||||
@@ -33,6 +33,8 @@ class OAuthToken(UTCBaseModel, SQLModel, table=True):
|
||||
|
||||
class OAuthClient(SQLModel, table=True):
|
||||
__tablename__ = "oauth_clients" # pyright: ignore[reportAssignmentType]
|
||||
name: str = Field(max_length=100, index=True)
|
||||
description: str = Field(sa_column=Column(Text), default="")
|
||||
client_id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
client_secret: str = Field(default_factory=secrets.token_hex, index=True)
|
||||
redirect_uris: list[str] = Field(default_factory=list, sa_column=Column(JSON))
|
||||
|
||||
174
app/router/private/oauth.py
Normal file
174
app/router/private/oauth.py
Normal file
@@ -0,0 +1,174 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
|
||||
from app.database.auth import OAuthClient, OAuthToken
|
||||
from app.dependencies.database import get_db, get_redis
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Body, Depends, HTTPException
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import select, text
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@router.post("/oauth-app/create", tags=["OAuth"])
|
||||
async def create_oauth_app(
|
||||
name: str = Body(..., max_length=100),
|
||||
description: str = Body(""),
|
||||
redirect_uris: list[str] = Body(...),
|
||||
owner_id: int = Body(...),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await session.execute( # pyright: ignore[reportDeprecated]
|
||||
text(
|
||||
"SELECT AUTO_INCREMENT FROM information_schema.TABLES "
|
||||
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'oauth_clients'"
|
||||
)
|
||||
)
|
||||
next_id = result.one()[0]
|
||||
if next_id < 10:
|
||||
await session.execute(text("ALTER TABLE oauth_clients AUTO_INCREMENT = 10"))
|
||||
await session.commit()
|
||||
|
||||
oauth_client = OAuthClient(
|
||||
name=name,
|
||||
description=description,
|
||||
redirect_uris=redirect_uris,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
session.add(oauth_client)
|
||||
await session.commit()
|
||||
await session.refresh(oauth_client)
|
||||
return {
|
||||
"client_id": oauth_client.client_id,
|
||||
"client_secret": oauth_client.client_secret,
|
||||
"redirect_uris": oauth_client.redirect_uris,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/oauth-apps/{client_id}", tags=["OAuth"])
|
||||
async def get_oauth_app(client_id: int, session: AsyncSession = Depends(get_db)):
|
||||
oauth_app = await session.get(OAuthClient, client_id)
|
||||
if not oauth_app:
|
||||
raise HTTPException(status_code=404, detail="OAuth app not found")
|
||||
return {
|
||||
"name": oauth_app.name,
|
||||
"description": oauth_app.description,
|
||||
"redirect_uris": oauth_app.redirect_uris,
|
||||
"client_id": oauth_app.client_id,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/oauth-apps/user/{owner_id}", tags=["OAuth"])
|
||||
async def get_user_oauth_apps(owner_id: int, session: AsyncSession = Depends(get_db)):
|
||||
oauth_apps = await session.exec(
|
||||
select(OAuthClient).where(OAuthClient.owner_id == owner_id)
|
||||
)
|
||||
return [
|
||||
{
|
||||
"name": app.name,
|
||||
"description": app.description,
|
||||
"redirect_uris": app.redirect_uris,
|
||||
"client_id": app.client_id,
|
||||
}
|
||||
for app in oauth_apps
|
||||
]
|
||||
|
||||
|
||||
@router.delete("/oauth-app/{client_id}", tags=["OAuth"], status_code=204)
|
||||
async def delete_oauth_app(
|
||||
client_id: int,
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
oauth_client = await session.get(OAuthClient, client_id)
|
||||
if not oauth_client:
|
||||
raise HTTPException(status_code=404, detail="OAuth app not found")
|
||||
|
||||
tokens = await session.exec(
|
||||
select(OAuthToken).where(OAuthToken.client_id == client_id)
|
||||
)
|
||||
for token in tokens:
|
||||
await session.delete(token)
|
||||
|
||||
await session.delete(oauth_client)
|
||||
await session.commit()
|
||||
|
||||
|
||||
@router.patch("/oauth-app/{client_id}", tags=["OAuth"])
|
||||
async def update_oauth_app(
|
||||
client_id: int,
|
||||
name: str = Body(..., max_length=100),
|
||||
description: str = Body(""),
|
||||
redirect_uris: list[str] = Body(...),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
oauth_client = await session.get(OAuthClient, client_id)
|
||||
if not oauth_client:
|
||||
raise HTTPException(status_code=404, detail="OAuth app not found")
|
||||
|
||||
oauth_client.name = name
|
||||
oauth_client.description = description
|
||||
oauth_client.redirect_uris = redirect_uris
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(oauth_client)
|
||||
|
||||
return {
|
||||
"client_id": oauth_client.client_id,
|
||||
"client_secret": oauth_client.client_secret,
|
||||
"redirect_uris": oauth_client.redirect_uris,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/oauth-app/{client_id}/refresh", tags=["OAuth"])
|
||||
async def refresh_secret(
|
||||
client_id: int,
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
oauth_client = await session.get(OAuthClient, client_id)
|
||||
if not oauth_client:
|
||||
raise HTTPException(status_code=404, detail="OAuth app not found")
|
||||
|
||||
oauth_client.client_secret = secrets.token_hex()
|
||||
tokens = await session.exec(
|
||||
select(OAuthToken).where(OAuthToken.client_id == client_id)
|
||||
)
|
||||
for token in tokens:
|
||||
await session.delete(token)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(oauth_client)
|
||||
|
||||
return {
|
||||
"client_id": oauth_client.client_id,
|
||||
"client_secret": oauth_client.client_secret,
|
||||
"redirect_uris": oauth_client.redirect_uris,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/oauth-app/{client_id}/code")
|
||||
async def generate_oauth_code(
|
||||
client_id: int,
|
||||
user_id: int = Body(...),
|
||||
redirect_uri: str = Body(...),
|
||||
scopes: list[str] = Body(...),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
client = await session.get(OAuthClient, client_id)
|
||||
if not client:
|
||||
raise HTTPException(status_code=404, detail="OAuth app not found")
|
||||
|
||||
if redirect_uri not in client.redirect_uris:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Redirect URI not allowed for this client"
|
||||
)
|
||||
|
||||
code = secrets.token_urlsafe(80)
|
||||
await redis.hset( # pyright: ignore[reportGeneralTypeIssues]
|
||||
f"oauth:code:{client_id}:{code}",
|
||||
mapping={"user_id": user_id, "scopes": ",".join(scopes)},
|
||||
)
|
||||
await redis.expire(f"oauth:code:{client_id}:{code}", 300)
|
||||
@@ -0,0 +1,41 @@
|
||||
"""auth: add `name` & `description` for OAuth client
|
||||
|
||||
Revision ID: 749bb2c2c33a
|
||||
Revises: a8669ba11e96
|
||||
Create Date: 2025-08-12 09:29:12.085060
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "749bb2c2c33a"
|
||||
down_revision: str | Sequence[str] | None = "a8669ba11e96"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
"oauth_clients",
|
||||
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=False),
|
||||
)
|
||||
op.add_column("oauth_clients", sa.Column("description", sa.Text(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f("ix_oauth_clients_name"), table_name="oauth_clients")
|
||||
op.drop_column("oauth_clients", "description")
|
||||
op.drop_column("oauth_clients", "name")
|
||||
# ### end Alembic commands ###
|
||||
Reference in New Issue
Block a user