fix(relationship): add target in response
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
from enum import Enum
|
||||
|
||||
from .user import User
|
||||
from app.models.user import User as APIUser
|
||||
|
||||
from .user import User as DBUser
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import (
|
||||
@@ -41,14 +43,14 @@ class Relationship(SQLModel, table=True):
|
||||
),
|
||||
)
|
||||
type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False)
|
||||
target: "User" = SQLRelationship(
|
||||
target: DBUser = SQLRelationship(
|
||||
sa_relationship_kwargs={"foreign_keys": "[Relationship.target_id]"}
|
||||
)
|
||||
|
||||
|
||||
class RelationshipResp(BaseModel):
|
||||
target_id: int
|
||||
# FIXME: target: User
|
||||
target: APIUser
|
||||
mutual: bool = False
|
||||
type: RelationshipType
|
||||
|
||||
@@ -56,6 +58,8 @@ class RelationshipResp(BaseModel):
|
||||
async def from_db(
|
||||
cls, session: AsyncSession, relationship: Relationship
|
||||
) -> "RelationshipResp":
|
||||
from app.utils import convert_db_user_to_api_user
|
||||
|
||||
target_relationship = (
|
||||
await session.exec(
|
||||
select(Relationship).where(
|
||||
@@ -71,7 +75,7 @@ class RelationshipResp(BaseModel):
|
||||
)
|
||||
return cls(
|
||||
target_id=relationship.target_id,
|
||||
# target=relationship.target,
|
||||
target=await convert_db_user_to_api_user(relationship.target),
|
||||
mutual=mutual,
|
||||
type=relationship.type,
|
||||
)
|
||||
|
||||
@@ -102,8 +102,8 @@ class User(SQLModel, table=True):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def all_select_clause(cls):
|
||||
return select(cls).options(
|
||||
def all_select_option(cls):
|
||||
return (
|
||||
joinedload(cls.lazer_profile), # pyright: ignore[reportArgumentType]
|
||||
joinedload(cls.lazer_counts), # pyright: ignore[reportArgumentType]
|
||||
joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType]
|
||||
@@ -121,6 +121,10 @@ class User(SQLModel, table=True):
|
||||
selectinload(cls.lazer_replays_watched), # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def all_select_clause(cls):
|
||||
return select(cls).options(*cls.all_select_option())
|
||||
|
||||
|
||||
# ============================================
|
||||
# Lazer API 专用表模型
|
||||
|
||||
@@ -2,16 +2,15 @@ from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from app.database import (
|
||||
LazerUserAchievement,
|
||||
Team as Team,
|
||||
)
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .score import GameMode
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.database import LazerUserAchievement, Team
|
||||
|
||||
|
||||
class PlayStyle(str, Enum):
|
||||
MOUSE = "mouse"
|
||||
@@ -83,7 +82,11 @@ class UserAchievement(BaseModel):
|
||||
achievement_id: int
|
||||
|
||||
# 添加数据库模型转换方法
|
||||
def to_db_model(self, user_id: int) -> LazerUserAchievement:
|
||||
def to_db_model(self, user_id: int) -> "LazerUserAchievement":
|
||||
from app.database import (
|
||||
LazerUserAchievement,
|
||||
)
|
||||
|
||||
return LazerUserAchievement(
|
||||
user_id=user_id,
|
||||
achievement_id=self.achievement_id,
|
||||
@@ -207,5 +210,5 @@ class User(BaseModel):
|
||||
rank_history: RankHistory | None = None
|
||||
rankHistory: RankHistory | None = None # 兼容性别名
|
||||
replays_watched_counts: list[dict] = []
|
||||
team: Team | None = None
|
||||
team: "Team | None" = None
|
||||
user_achievements: list[UserAchievement] = []
|
||||
|
||||
@@ -8,6 +8,7 @@ from app.dependencies.user import get_current_user
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Query, Request
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -25,7 +26,9 @@ async def get_relationship(
|
||||
else RelationshipType.BLOCK
|
||||
)
|
||||
relationships = await db.exec(
|
||||
select(Relationship).where(
|
||||
select(Relationship)
|
||||
.options(joinedload(Relationship.target).options(*DBUser.all_select_option())) # pyright: ignore[reportArgumentType]
|
||||
.where(
|
||||
Relationship.user_id == current_user.id,
|
||||
Relationship.type == relationship_type,
|
||||
)
|
||||
@@ -79,8 +82,20 @@ async def add_relationship(
|
||||
if target_relationship and target_relationship.type == RelationshipType.FOLLOW:
|
||||
await db.delete(target_relationship)
|
||||
await db.commit()
|
||||
await db.refresh(relationship)
|
||||
if relationship.type == RelationshipType.FOLLOW:
|
||||
relationship = (
|
||||
await db.exec(
|
||||
select(Relationship)
|
||||
.where(
|
||||
Relationship.user_id == current_user.id,
|
||||
Relationship.target_id == target,
|
||||
)
|
||||
.options(
|
||||
joinedload(Relationship.target).options(*DBUser.all_select_option()) # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
).first()
|
||||
assert relationship, "Relationship should exist after commit"
|
||||
return await RelationshipResp.from_db(db, relationship)
|
||||
|
||||
|
||||
|
||||
4
main.py
4
main.py
@@ -4,12 +4,16 @@ from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
|
||||
from app.config import settings
|
||||
from app.database import Team # noqa: F401
|
||||
from app.dependencies.database import create_tables, engine
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.models.user import User
|
||||
from app.router import api_router, auth_router, fetcher_router, signalr_router
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
User.model_rebuild()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
|
||||
Reference in New Issue
Block a user