fix(relationship): add target in response

This commit is contained in:
MingxuanGame
2025-07-28 15:19:56 +00:00
parent e1b1d98c7a
commit 0cba7e9dd2
5 changed files with 45 additions and 15 deletions

View File

@@ -1,6 +1,8 @@
from enum import Enum from enum import Enum
from .user import User from app.models.user import User as APIUser
from .user import User as DBUser
from pydantic import BaseModel from pydantic import BaseModel
from sqlmodel import ( from sqlmodel import (
@@ -41,14 +43,14 @@ class Relationship(SQLModel, table=True):
), ),
) )
type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False) type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False)
target: "User" = SQLRelationship( target: DBUser = SQLRelationship(
sa_relationship_kwargs={"foreign_keys": "[Relationship.target_id]"} sa_relationship_kwargs={"foreign_keys": "[Relationship.target_id]"}
) )
class RelationshipResp(BaseModel): class RelationshipResp(BaseModel):
target_id: int target_id: int
# FIXME: target: User target: APIUser
mutual: bool = False mutual: bool = False
type: RelationshipType type: RelationshipType
@@ -56,6 +58,8 @@ class RelationshipResp(BaseModel):
async def from_db( async def from_db(
cls, session: AsyncSession, relationship: Relationship cls, session: AsyncSession, relationship: Relationship
) -> "RelationshipResp": ) -> "RelationshipResp":
from app.utils import convert_db_user_to_api_user
target_relationship = ( target_relationship = (
await session.exec( await session.exec(
select(Relationship).where( select(Relationship).where(
@@ -71,7 +75,7 @@ class RelationshipResp(BaseModel):
) )
return cls( return cls(
target_id=relationship.target_id, target_id=relationship.target_id,
# target=relationship.target, target=await convert_db_user_to_api_user(relationship.target),
mutual=mutual, mutual=mutual,
type=relationship.type, type=relationship.type,
) )

View File

@@ -102,8 +102,8 @@ class User(SQLModel, table=True):
) )
@classmethod @classmethod
def all_select_clause(cls): def all_select_option(cls):
return select(cls).options( return (
joinedload(cls.lazer_profile), # pyright: ignore[reportArgumentType] joinedload(cls.lazer_profile), # pyright: ignore[reportArgumentType]
joinedload(cls.lazer_counts), # pyright: ignore[reportArgumentType] joinedload(cls.lazer_counts), # pyright: ignore[reportArgumentType]
joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType] joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType]
@@ -121,6 +121,10 @@ class User(SQLModel, table=True):
selectinload(cls.lazer_replays_watched), # pyright: ignore[reportArgumentType] selectinload(cls.lazer_replays_watched), # pyright: ignore[reportArgumentType]
) )
@classmethod
def all_select_clause(cls):
return select(cls).options(*cls.all_select_option())
# ============================================ # ============================================
# Lazer API 专用表模型 # Lazer API 专用表模型

View File

@@ -2,16 +2,15 @@ from __future__ import annotations
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING
from app.database import (
LazerUserAchievement,
Team as Team,
)
from .score import GameMode from .score import GameMode
from pydantic import BaseModel from pydantic import BaseModel
if TYPE_CHECKING:
from app.database import LazerUserAchievement, Team
class PlayStyle(str, Enum): class PlayStyle(str, Enum):
MOUSE = "mouse" MOUSE = "mouse"
@@ -83,7 +82,11 @@ class UserAchievement(BaseModel):
achievement_id: int achievement_id: int
# 添加数据库模型转换方法 # 添加数据库模型转换方法
def to_db_model(self, user_id: int) -> LazerUserAchievement: def to_db_model(self, user_id: int) -> "LazerUserAchievement":
from app.database import (
LazerUserAchievement,
)
return LazerUserAchievement( return LazerUserAchievement(
user_id=user_id, user_id=user_id,
achievement_id=self.achievement_id, achievement_id=self.achievement_id,
@@ -207,5 +210,5 @@ class User(BaseModel):
rank_history: RankHistory | None = None rank_history: RankHistory | None = None
rankHistory: RankHistory | None = None # 兼容性别名 rankHistory: RankHistory | None = None # 兼容性别名
replays_watched_counts: list[dict] = [] replays_watched_counts: list[dict] = []
team: Team | None = None team: "Team | None" = None
user_achievements: list[UserAchievement] = [] user_achievements: list[UserAchievement] = []

View File

@@ -8,6 +8,7 @@ from app.dependencies.user import get_current_user
from .api_router import router from .api_router import router
from fastapi import Depends, HTTPException, Query, Request from fastapi import Depends, HTTPException, Query, Request
from sqlalchemy.orm import joinedload
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -25,7 +26,9 @@ async def get_relationship(
else RelationshipType.BLOCK else RelationshipType.BLOCK
) )
relationships = await db.exec( relationships = await db.exec(
select(Relationship).where( select(Relationship)
.options(joinedload(Relationship.target).options(*DBUser.all_select_option())) # pyright: ignore[reportArgumentType]
.where(
Relationship.user_id == current_user.id, Relationship.user_id == current_user.id,
Relationship.type == relationship_type, Relationship.type == relationship_type,
) )
@@ -79,8 +82,20 @@ async def add_relationship(
if target_relationship and target_relationship.type == RelationshipType.FOLLOW: if target_relationship and target_relationship.type == RelationshipType.FOLLOW:
await db.delete(target_relationship) await db.delete(target_relationship)
await db.commit() await db.commit()
await db.refresh(relationship)
if relationship.type == RelationshipType.FOLLOW: if relationship.type == RelationshipType.FOLLOW:
relationship = (
await db.exec(
select(Relationship)
.where(
Relationship.user_id == current_user.id,
Relationship.target_id == target,
)
.options(
joinedload(Relationship.target).options(*DBUser.all_select_option()) # pyright: ignore[reportArgumentType]
)
)
).first()
assert relationship, "Relationship should exist after commit"
return await RelationshipResp.from_db(db, relationship) return await RelationshipResp.from_db(db, relationship)

View File

@@ -4,12 +4,16 @@ from contextlib import asynccontextmanager
from datetime import datetime from datetime import datetime
from app.config import settings from app.config import settings
from app.database import Team # noqa: F401
from app.dependencies.database import create_tables, engine from app.dependencies.database import create_tables, engine
from app.dependencies.fetcher import get_fetcher from app.dependencies.fetcher import get_fetcher
from app.models.user import User
from app.router import api_router, auth_router, fetcher_router, signalr_router from app.router import api_router, auth_router, fetcher_router, signalr_router
from fastapi import FastAPI from fastapi import FastAPI
User.model_rebuild()
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):