From 0cba7e9dd24590d4a803a8de73d1648ccb0e8e6e Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Mon, 28 Jul 2025 15:19:56 +0000 Subject: [PATCH] fix(relationship): add `target` in response --- app/database/relationship.py | 12 ++++++++---- app/database/user.py | 8 ++++++-- app/models/user.py | 17 ++++++++++------- app/router/relationship.py | 19 +++++++++++++++++-- main.py | 4 ++++ 5 files changed, 45 insertions(+), 15 deletions(-) diff --git a/app/database/relationship.py b/app/database/relationship.py index cbf7643..61dc109 100644 --- a/app/database/relationship.py +++ b/app/database/relationship.py @@ -1,6 +1,8 @@ from enum import Enum -from .user import User +from app.models.user import User as APIUser + +from .user import User as DBUser from pydantic import BaseModel from sqlmodel import ( @@ -41,14 +43,14 @@ class Relationship(SQLModel, table=True): ), ) type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False) - target: "User" = SQLRelationship( + target: DBUser = SQLRelationship( sa_relationship_kwargs={"foreign_keys": "[Relationship.target_id]"} ) class RelationshipResp(BaseModel): target_id: int - # FIXME: target: User + target: APIUser mutual: bool = False type: RelationshipType @@ -56,6 +58,8 @@ class RelationshipResp(BaseModel): async def from_db( cls, session: AsyncSession, relationship: Relationship ) -> "RelationshipResp": + from app.utils import convert_db_user_to_api_user + target_relationship = ( await session.exec( select(Relationship).where( @@ -71,7 +75,7 @@ class RelationshipResp(BaseModel): ) return cls( target_id=relationship.target_id, - # target=relationship.target, + target=await convert_db_user_to_api_user(relationship.target), mutual=mutual, type=relationship.type, ) diff --git a/app/database/user.py b/app/database/user.py index 09c268e..a188497 100644 --- a/app/database/user.py +++ b/app/database/user.py @@ -102,8 +102,8 @@ class User(SQLModel, table=True): ) @classmethod - def all_select_clause(cls): - return select(cls).options( + def all_select_option(cls): + return ( joinedload(cls.lazer_profile), # pyright: ignore[reportArgumentType] joinedload(cls.lazer_counts), # pyright: ignore[reportArgumentType] joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType] @@ -121,6 +121,10 @@ class User(SQLModel, table=True): selectinload(cls.lazer_replays_watched), # pyright: ignore[reportArgumentType] ) + @classmethod + def all_select_clause(cls): + return select(cls).options(*cls.all_select_option()) + # ============================================ # Lazer API 专用表模型 diff --git a/app/models/user.py b/app/models/user.py index 42251e9..dd90e47 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -2,16 +2,15 @@ from __future__ import annotations from datetime import datetime from enum import Enum - -from app.database import ( - LazerUserAchievement, - Team as Team, -) +from typing import TYPE_CHECKING from .score import GameMode from pydantic import BaseModel +if TYPE_CHECKING: + from app.database import LazerUserAchievement, Team + class PlayStyle(str, Enum): MOUSE = "mouse" @@ -83,7 +82,11 @@ class UserAchievement(BaseModel): achievement_id: int # 添加数据库模型转换方法 - def to_db_model(self, user_id: int) -> LazerUserAchievement: + def to_db_model(self, user_id: int) -> "LazerUserAchievement": + from app.database import ( + LazerUserAchievement, + ) + return LazerUserAchievement( user_id=user_id, achievement_id=self.achievement_id, @@ -207,5 +210,5 @@ class User(BaseModel): rank_history: RankHistory | None = None rankHistory: RankHistory | None = None # 兼容性别名 replays_watched_counts: list[dict] = [] - team: Team | None = None + team: "Team | None" = None user_achievements: list[UserAchievement] = [] diff --git a/app/router/relationship.py b/app/router/relationship.py index eb8b961..4c4bb99 100644 --- a/app/router/relationship.py +++ b/app/router/relationship.py @@ -8,6 +8,7 @@ from app.dependencies.user import get_current_user from .api_router import router from fastapi import Depends, HTTPException, Query, Request +from sqlalchemy.orm import joinedload from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -25,7 +26,9 @@ async def get_relationship( else RelationshipType.BLOCK ) relationships = await db.exec( - select(Relationship).where( + select(Relationship) + .options(joinedload(Relationship.target).options(*DBUser.all_select_option())) # pyright: ignore[reportArgumentType] + .where( Relationship.user_id == current_user.id, Relationship.type == relationship_type, ) @@ -79,8 +82,20 @@ async def add_relationship( if target_relationship and target_relationship.type == RelationshipType.FOLLOW: await db.delete(target_relationship) await db.commit() - await db.refresh(relationship) if relationship.type == RelationshipType.FOLLOW: + relationship = ( + await db.exec( + select(Relationship) + .where( + Relationship.user_id == current_user.id, + Relationship.target_id == target, + ) + .options( + joinedload(Relationship.target).options(*DBUser.all_select_option()) # pyright: ignore[reportArgumentType] + ) + ) + ).first() + assert relationship, "Relationship should exist after commit" return await RelationshipResp.from_db(db, relationship) diff --git a/main.py b/main.py index 92d4402..526d593 100644 --- a/main.py +++ b/main.py @@ -4,12 +4,16 @@ from contextlib import asynccontextmanager from datetime import datetime from app.config import settings +from app.database import Team # noqa: F401 from app.dependencies.database import create_tables, engine from app.dependencies.fetcher import get_fetcher +from app.models.user import User from app.router import api_router, auth_router, fetcher_router, signalr_router from fastapi import FastAPI +User.model_rebuild() + @asynccontextmanager async def lifespan(app: FastAPI):