diff --git a/app/database/auth.py b/app/database/auth.py index 443ba8c..962ef06 100644 --- a/app/database/auth.py +++ b/app/database/auth.py @@ -16,6 +16,7 @@ from sqlmodel import ( Relationship, SQLModel, Text, + text, ) if TYPE_CHECKING: @@ -40,14 +41,20 @@ class OAuthToken(UTCBaseModel, SQLModel, table=True): login_session: LoginSession | None = Relationship(back_populates="token", passive_deletes=True) -class OAuthClient(SQLModel, table=True): +class OAuthClient(UTCBaseModel, SQLModel, table=True): __tablename__: str = "oauth_clients" name: str = Field(max_length=100, index=True) description: str = Field(sa_column=Column(Text), default="") client_id: int | None = Field(default=None, primary_key=True, index=True) - client_secret: str = Field(default_factory=secrets.token_hex, index=True) + client_secret: str = Field(default_factory=secrets.token_hex, index=True, exclude=True) redirect_uris: list[str] = Field(default_factory=list, sa_column=Column(JSON)) - owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) + owner_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True), exclude=True) + + created_at: datetime = Field(default_factory=utcnow, sa_column=Column(DateTime)) + updated_at: datetime = Field( + default_factory=utcnow, + sa_column=Column(DateTime, onupdate=text("CURRENT_TIMESTAMP")), + ) class V1APIKeys(SQLModel, table=True): diff --git a/app/router/private/oauth.py b/app/router/private/oauth.py index 8c4c664..9e0af89 100644 --- a/app/router/private/oauth.py +++ b/app/router/private/oauth.py @@ -46,9 +46,8 @@ async def create_oauth_app( await session.commit() await session.refresh(oauth_client) return { - "client_id": oauth_client.client_id, "client_secret": oauth_client.client_secret, - "redirect_uris": oauth_client.redirect_uris, + **oauth_client.model_dump(exclude={"client_secret"}), } @@ -57,6 +56,7 @@ async def create_oauth_app( name="获取 OAuth 应用信息", description="通过客户端 ID 获取 OAuth 应用的详细信息", tags=["osu! OAuth 认证", "g0v0 API"], + response_model=OAuthClient, ) async def get_oauth_app( session: Database, @@ -66,12 +66,7 @@ async def get_oauth_app( oauth_app = await session.get(OAuthClient, client_id) if not oauth_app: raise HTTPException(status_code=404, detail="OAuth app not found") - return { - "name": oauth_app.name, - "description": oauth_app.description, - "redirect_uris": oauth_app.redirect_uris, - "client_id": oauth_app.client_id, - } + return oauth_app @router.get( @@ -79,21 +74,14 @@ async def get_oauth_app( name="获取用户的 OAuth 应用列表", description="获取当前用户创建的所有 OAuth 应用程序", tags=["osu! OAuth 认证", "g0v0 API"], + response_model=list[OAuthClient], ) async def get_user_oauth_apps( session: Database, current_user: ClientUser, ): oauth_apps = await session.exec(select(OAuthClient).where(OAuthClient.owner_id == current_user.id)) - return [ - { - "name": app.name, - "description": app.description, - "redirect_uris": app.redirect_uris, - "client_id": app.client_id, - } - for app in oauth_apps - ] + return oauth_apps.all() @router.delete( @@ -150,9 +138,8 @@ async def update_oauth_app( await session.refresh(oauth_client) return { - "client_id": oauth_client.client_id, "client_secret": oauth_client.client_secret, - "redirect_uris": oauth_client.redirect_uris, + **oauth_client.model_dump(exclude={"client_secret"}), } @@ -182,9 +169,8 @@ async def refresh_secret( await session.refresh(oauth_client) return { - "client_id": oauth_client.client_id, "client_secret": oauth_client.client_secret, - "redirect_uris": oauth_client.redirect_uris, + **oauth_client.model_dump(exclude={"client_secret"}), } diff --git a/migrations/versions/2025-11-23_57641cb601f4_oauth_client_add_date.py b/migrations/versions/2025-11-23_57641cb601f4_oauth_client_add_date.py new file mode 100644 index 0000000..d214745 --- /dev/null +++ b/migrations/versions/2025-11-23_57641cb601f4_oauth_client_add_date.py @@ -0,0 +1,41 @@ +"""oauth-client: add date + +Revision ID: 57641cb601f4 +Revises: 23707640303c +Create Date: 2025-11-23 13:46:55.654967 + +""" + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "57641cb601f4" +down_revision: str | Sequence[str] | None = "23707640303c" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("oauth_clients", sa.Column("created_at", sa.DateTime(), nullable=True)) + op.add_column( + "oauth_clients", sa.Column("updated_at", sa.DateTime(), nullable=True, onupdate=sa.text("CURRENT_TIMESTAMP")) + ) + op.execute( + "UPDATE oauth_clients SET created_at = NOW(), updated_at = NOW() WHERE created_at IS NULL OR updated_at IS NULL" + ) + op.alter_column("oauth_clients", "created_at", existing_type=sa.DateTime(), nullable=False) + op.alter_column("oauth_clients", "updated_at", existing_type=sa.DateTime(), nullable=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("oauth_clients", "updated_at") + op.drop_column("oauth_clients", "created_at") + # ### end Alembic commands ###