@@ -16,6 +16,7 @@ from sqlmodel import (
|
||||
Relationship,
|
||||
SQLModel,
|
||||
Text,
|
||||
text,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -40,14 +41,20 @@ class OAuthToken(UTCBaseModel, SQLModel, table=True):
|
||||
login_session: LoginSession | None = Relationship(back_populates="token", passive_deletes=True)
|
||||
|
||||
|
||||
class OAuthClient(SQLModel, table=True):
|
||||
class OAuthClient(UTCBaseModel, SQLModel, table=True):
|
||||
__tablename__: str = "oauth_clients"
|
||||
name: str = Field(max_length=100, index=True)
|
||||
description: str = Field(sa_column=Column(Text), default="")
|
||||
client_id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
client_secret: str = Field(default_factory=secrets.token_hex, index=True)
|
||||
client_secret: str = Field(default_factory=secrets.token_hex, index=True, exclude=True)
|
||||
redirect_uris: list[str] = Field(default_factory=list, sa_column=Column(JSON))
|
||||
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
|
||||
owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True), exclude=True)
|
||||
|
||||
created_at: datetime = Field(default_factory=utcnow, sa_column=Column(DateTime))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=utcnow,
|
||||
sa_column=Column(DateTime, onupdate=text("CURRENT_TIMESTAMP")),
|
||||
)
|
||||
|
||||
|
||||
class V1APIKeys(SQLModel, table=True):
|
||||
|
||||
@@ -46,9 +46,8 @@ async def create_oauth_app(
|
||||
await session.commit()
|
||||
await session.refresh(oauth_client)
|
||||
return {
|
||||
"client_id": oauth_client.client_id,
|
||||
"client_secret": oauth_client.client_secret,
|
||||
"redirect_uris": oauth_client.redirect_uris,
|
||||
**oauth_client.model_dump(exclude={"client_secret"}),
|
||||
}
|
||||
|
||||
|
||||
@@ -57,6 +56,7 @@ async def create_oauth_app(
|
||||
name="获取 OAuth 应用信息",
|
||||
description="通过客户端 ID 获取 OAuth 应用的详细信息",
|
||||
tags=["osu! OAuth 认证", "g0v0 API"],
|
||||
response_model=OAuthClient,
|
||||
)
|
||||
async def get_oauth_app(
|
||||
session: Database,
|
||||
@@ -66,12 +66,7 @@ async def get_oauth_app(
|
||||
oauth_app = await session.get(OAuthClient, client_id)
|
||||
if not oauth_app:
|
||||
raise HTTPException(status_code=404, detail="OAuth app not found")
|
||||
return {
|
||||
"name": oauth_app.name,
|
||||
"description": oauth_app.description,
|
||||
"redirect_uris": oauth_app.redirect_uris,
|
||||
"client_id": oauth_app.client_id,
|
||||
}
|
||||
return oauth_app
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -79,21 +74,14 @@ async def get_oauth_app(
|
||||
name="获取用户的 OAuth 应用列表",
|
||||
description="获取当前用户创建的所有 OAuth 应用程序",
|
||||
tags=["osu! OAuth 认证", "g0v0 API"],
|
||||
response_model=list[OAuthClient],
|
||||
)
|
||||
async def get_user_oauth_apps(
|
||||
session: Database,
|
||||
current_user: ClientUser,
|
||||
):
|
||||
oauth_apps = await session.exec(select(OAuthClient).where(OAuthClient.owner_id == current_user.id))
|
||||
return [
|
||||
{
|
||||
"name": app.name,
|
||||
"description": app.description,
|
||||
"redirect_uris": app.redirect_uris,
|
||||
"client_id": app.client_id,
|
||||
}
|
||||
for app in oauth_apps
|
||||
]
|
||||
return oauth_apps.all()
|
||||
|
||||
|
||||
@router.delete(
|
||||
@@ -150,9 +138,8 @@ async def update_oauth_app(
|
||||
await session.refresh(oauth_client)
|
||||
|
||||
return {
|
||||
"client_id": oauth_client.client_id,
|
||||
"client_secret": oauth_client.client_secret,
|
||||
"redirect_uris": oauth_client.redirect_uris,
|
||||
**oauth_client.model_dump(exclude={"client_secret"}),
|
||||
}
|
||||
|
||||
|
||||
@@ -182,9 +169,8 @@ async def refresh_secret(
|
||||
await session.refresh(oauth_client)
|
||||
|
||||
return {
|
||||
"client_id": oauth_client.client_id,
|
||||
"client_secret": oauth_client.client_secret,
|
||||
"redirect_uris": oauth_client.redirect_uris,
|
||||
**oauth_client.model_dump(exclude={"client_secret"}),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
"""oauth-client: add date
|
||||
|
||||
Revision ID: 57641cb601f4
|
||||
Revises: 23707640303c
|
||||
Create Date: 2025-11-23 13:46:55.654967
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "57641cb601f4"
|
||||
down_revision: str | Sequence[str] | None = "23707640303c"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("oauth_clients", sa.Column("created_at", sa.DateTime(), nullable=True))
|
||||
op.add_column(
|
||||
"oauth_clients", sa.Column("updated_at", sa.DateTime(), nullable=True, onupdate=sa.text("CURRENT_TIMESTAMP"))
|
||||
)
|
||||
op.execute(
|
||||
"UPDATE oauth_clients SET created_at = NOW(), updated_at = NOW() WHERE created_at IS NULL OR updated_at IS NULL"
|
||||
)
|
||||
op.alter_column("oauth_clients", "created_at", existing_type=sa.DateTime(), nullable=False)
|
||||
op.alter_column("oauth_clients", "updated_at", existing_type=sa.DateTime(), nullable=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("oauth_clients", "updated_at")
|
||||
op.drop_column("oauth_clients", "created_at")
|
||||
# ### end Alembic commands ###
|
||||
Reference in New Issue
Block a user