fix(api): avoid relationship api handling all requests

This commit is contained in:
MingxuanGame
2025-07-27 09:04:27 +00:00
parent 3ee95b0e7c
commit 9e44121427

View File

@@ -1,7 +1,5 @@
from __future__ import annotations from __future__ import annotations
from typing import Literal
from app.database import User as DBUser from app.database import User as DBUser
from app.database.relationship import Relationship, RelationshipResp, RelationshipType from app.database.relationship import Relationship, RelationshipResp, RelationshipType
from app.dependencies.database import get_db from app.dependencies.database import get_db
@@ -9,21 +7,23 @@ from app.dependencies.user import get_current_user
from .api_router import router from .api_router import router
from fastapi import Depends, HTTPException, Query from fastapi import Depends, HTTPException, Query, Request
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/{type}", tags=["relationship"], response_model=list[RelationshipResp]) @router.get("/friends", tags=["relationship"], response_model=list[RelationshipResp])
@router.get("/blocks", tags=["relationship"], response_model=list[RelationshipResp])
async def get_relationship( async def get_relationship(
type: Literal["friends", "blocks"], request: Request,
current_user: DBUser = Depends(get_current_user), current_user: DBUser = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
if type == "friends": relationship_type = (
relationship_type = RelationshipType.FOLLOW RelationshipType.FOLLOW
else: if request.url.path.endswith("/friends")
relationship_type = RelationshipType.BLOCK else RelationshipType.BLOCK
)
relationships = await db.exec( relationships = await db.exec(
select(Relationship).where( select(Relationship).where(
Relationship.user_id == current_user.id, Relationship.user_id == current_user.id,
@@ -33,17 +33,19 @@ async def get_relationship(
return [await RelationshipResp.from_db(db, rel) for rel in relationships] return [await RelationshipResp.from_db(db, rel) for rel in relationships]
@router.post("/{type}", tags=["relationship"], response_model=RelationshipResp) @router.post("/friends", tags=["relationship"], response_model=RelationshipResp)
@router.post("/blocks", tags=["relationship"])
async def add_relationship( async def add_relationship(
type: Literal["friends", "blocks"], request: Request,
target: int = Query(), target: int = Query(),
current_user: DBUser = Depends(get_current_user), current_user: DBUser = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
if type == "blocks": relationship_type = (
relationship_type = RelationshipType.BLOCK RelationshipType.FOLLOW
else: if request.url.path.endswith("/friends")
relationship_type = RelationshipType.FOLLOW else RelationshipType.BLOCK
)
if target == current_user.id: if target == current_user.id:
raise HTTPException(422, "Cannot add relationship to yourself") raise HTTPException(422, "Cannot add relationship to yourself")
relationship = ( relationship = (
@@ -78,18 +80,22 @@ async def add_relationship(
await db.delete(target_relationship) await db.delete(target_relationship)
await db.commit() await db.commit()
await db.refresh(relationship) await db.refresh(relationship)
return await RelationshipResp.from_db(db, relationship) if relationship.type == RelationshipType.FOLLOW:
return await RelationshipResp.from_db(db, relationship)
@router.delete("/{type}/{target}", tags=["relationship"]) @router.delete("/friends/{target}", tags=["relationship"])
@router.delete("/blocks/{target}", tags=["relationship"])
async def delete_relationship( async def delete_relationship(
type: Literal["friends", "blocks"], request: Request,
target: int, target: int,
current_user: DBUser = Depends(get_current_user), current_user: DBUser = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
relationship_type = ( relationship_type = (
RelationshipType.BLOCK if type == "blocks" else RelationshipType.FOLLOW RelationshipType.BLOCK
if "/blocks/" in request.url.path
else RelationshipType.FOLLOW
) )
relationship = ( relationship = (
await db.exec( await db.exec(