From ef977d1c2d495d398bd83240d28390c43a75bf68 Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sat, 26 Jul 2025 15:31:09 +0000 Subject: [PATCH] feat(relationship): implement relationship(friends, blocks) api (close #6) --- app/database/__init__.py | 4 ++ app/database/relationship.py | 62 ++++++++++++++++++++ app/database/score.py | 2 +- app/router/__init__.py | 1 + app/router/relationship.py | 107 +++++++++++++++++++++++++++++++++++ create_sample_data.py | 36 ++++++++++-- 6 files changed, 206 insertions(+), 6 deletions(-) create mode 100644 app/database/relationship.py create mode 100644 app/router/relationship.py diff --git a/app/database/__init__.py b/app/database/__init__.py index a0cdc2a..b7df7d6 100644 --- a/app/database/__init__.py +++ b/app/database/__init__.py @@ -8,6 +8,7 @@ from .beatmapset import ( BeatmapsetResp as BeatmapsetResp, ) from .legacy import LegacyOAuthToken, LegacyUserStatistics +from .relationship import Relationship, RelationshipResp, RelationshipType from .team import Team, TeamMember from .user import ( DailyChallengeStats, @@ -53,6 +54,9 @@ __all__ = [ "LegacyUserStatistics", "OAuthToken", "RankHistory", + "Relationship", + "RelationshipResp", + "RelationshipType", "Team", "TeamMember", "User", diff --git a/app/database/relationship.py b/app/database/relationship.py new file mode 100644 index 0000000..e352b81 --- /dev/null +++ b/app/database/relationship.py @@ -0,0 +1,62 @@ +from enum import Enum + +from .user import User + +from pydantic import BaseModel +from sqlmodel import ( + Field, + Relationship as SQLRelationship, + SQLModel, + select, +) +from sqlmodel.ext.asyncio.session import AsyncSession + + +class RelationshipType(str, Enum): + FOLLOW = "Friend" + BLOCK = "Block" + + +class Relationship(SQLModel, table=True): + __tablename__ = "relationship" # pyright: ignore[reportAssignmentType] + user_id: int = Field( + default=None, foreign_key="users.id", primary_key=True, index=True + ) + target_id: int = Field( + default=None, foreign_key="users.id", primary_key=True, index=True + ) + type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False) + target: "User" = SQLRelationship( + sa_relationship_kwargs={"foreign_keys": "[Relationship.target_id]"} + ) + + +class RelationshipResp(BaseModel): + target_id: int + # FIXME: target: User + mutual: bool = False + type: RelationshipType + + @classmethod + async def from_db( + cls, session: AsyncSession, relationship: Relationship + ) -> "RelationshipResp": + target_relationship = ( + await session.exec( + select(Relationship).where( + Relationship.user_id == relationship.target_id, + Relationship.target_id == relationship.user_id, + ) + ) + ).first() + mutual = bool( + target_relationship is not None + and relationship.type == RelationshipType.FOLLOW + and target_relationship.type == RelationshipType.FOLLOW + ) + return cls( + target_id=relationship.target_id, + # target=relationship.target, + mutual=mutual, + type=relationship.type, + ) diff --git a/app/database/score.py b/app/database/score.py index 50cd097..37b0590 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -65,7 +65,7 @@ class Score(ScoreBase, table=True): nkatu: int = Field(exclude=True) nlarge_tick_miss: int | None = Field(default=None, exclude=True) nslider_tail_hit: int | None = Field(default=None, exclude=True) - gamemode: GameMode = Field(index=True, alias="ruleset_id") + gamemode: GameMode = Field(index=True) # optional beatmap: "Beatmap" = Relationship() diff --git a/app/router/__init__.py b/app/router/__init__.py index 50c2b4d..680bb5e 100644 --- a/app/router/__init__.py +++ b/app/router/__init__.py @@ -4,6 +4,7 @@ from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401 beatmap, beatmapset, me, + relationship, ) from .api_router import router as api_router from .auth import router as auth_router diff --git a/app/router/relationship.py b/app/router/relationship.py new file mode 100644 index 0000000..6f28ec2 --- /dev/null +++ b/app/router/relationship.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +from typing import Literal + +from app.database import User as DBUser +from app.database.relationship import Relationship, RelationshipResp, RelationshipType +from app.dependencies.database import get_db +from app.dependencies.user import get_current_user + +from .api_router import router + +from fastapi import Depends, HTTPException, Query +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession + + +@router.get("/{type}", tags=["relationship"], response_model=list[RelationshipResp]) +async def get_relationship( + type: Literal["friends", "blocks"], + current_user: DBUser = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + if type == "friends": + relationship_type = RelationshipType.FOLLOW + else: + relationship_type = RelationshipType.BLOCK + relationships = await db.exec( + select(Relationship).where( + Relationship.user_id == current_user.id, + Relationship.type == relationship_type, + ) + ) + return [await RelationshipResp.from_db(db, rel) for rel in relationships] + + +@router.post("/{type}", tags=["relationship"], response_model=RelationshipResp) +async def add_relationship( + type: Literal["friends", "blocks"], + target: int = Query(), + current_user: DBUser = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + if type == "blocks": + relationship_type = RelationshipType.BLOCK + else: + relationship_type = RelationshipType.FOLLOW + if target == current_user.id: + raise HTTPException(422, "Cannot add relationship to yourself") + relationship = ( + await db.exec( + select(Relationship).where( + Relationship.user_id == current_user.id, + Relationship.target_id == target, + ) + ) + ).first() + if relationship: + relationship.type = relationship_type + # 这里原来如何是 block 也会修改为 follow + # 与 ppy/osu-web 的行为保持一致 + else: + relationship = Relationship( + user_id=current_user.id, + target_id=target, + type=relationship_type, + ) + db.add(relationship) + if relationship.type == RelationshipType.BLOCK: + target_relationship = ( + await db.exec( + select(Relationship).where( + Relationship.user_id == target, + Relationship.target_id == current_user.id, + ) + ) + ).first() + if target_relationship and target_relationship.type == RelationshipType.FOLLOW: + await db.delete(target_relationship) + await db.commit() + await db.refresh(relationship) + return await RelationshipResp.from_db(db, relationship) + + +@router.delete("/{type}/{target}", tags=["relationship"]) +async def delete_relationship( + type: Literal["friends", "blocks"], + target: int, + current_user: DBUser = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + relationship_type = ( + RelationshipType.BLOCK if type == "blocks" else RelationshipType.FOLLOW + ) + relationship = ( + await db.exec( + select(Relationship).where( + Relationship.user_id == current_user.id, + Relationship.target_id == target, + ) + ) + ).first() + if not relationship: + raise HTTPException(404, "Relationship not found") + if relationship.type != relationship_type: + raise HTTPException(422, "Relationship type mismatch") + await db.delete(relationship) + await db.commit() diff --git a/create_sample_data.py b/create_sample_data.py index 610bff7..8c090ca 100644 --- a/create_sample_data.py +++ b/create_sample_data.py @@ -29,12 +29,14 @@ async def create_sample_user(): async with AsyncSession(engine) as session: async with session.begin(): # 检查用户是否已存在 - statement = select(User).where(User.name == "Googujiang") - result = await session.exec(statement) + result = await session.exec(select(User).where(User.name == "Googujiang")) + result2 = await session.exec( + select(User).where(User.name == "MingxuanGame") + ) existing_user = result.first() - if existing_user: + existing_user2 = result2.first() + if existing_user is not None and existing_user2 is not None: print("示例用户已存在,跳过创建") - return existing_user # 当前时间戳 # current_timestamp = int(time.time()) @@ -62,13 +64,37 @@ async def create_sample_user(): userpage_content="「世界に忘れられた」", api_key=None, ) + user2 = User( + name="MingxuanGame", + safe_name="mingxuangame", # 安全用户名(小写) + email="mingxuangame@example.com", + priv=1, # 默认权限 + pw_bcrypt=get_password_hash("password123"), # 使用新的哈希方式 + country="US", + silence_end=0, + donor_end=0, + creation_time=join_timestamp, + latest_activity=last_visit_timestamp, + clan_id=0, + clan_priv=0, + preferred_mode=0, # 0 = osu! + play_style=0, + custom_badge_name=None, + custom_badge_icon=None, + userpage_content="For love and fun!", + api_key=None, + ) session.add(user) + session.add(user2) print(f"成功创建示例用户: {user.name} (ID: {user.id})") print(f"安全用户名: {user.safe_name}") print(f"邮箱: {user.email}") print(f"国家: {user.country}") - return user + print(f"成功创建示例用户: {user2.name} (ID: {user2.id})") + print(f"安全用户名: {user2.safe_name}") + print(f"邮箱: {user2.email}") + print(f"国家: {user2.country}") async def create_sample_beatmap_data():