diff --git a/app/database/auth.py b/app/database/auth.py index 8e9032b..ae49676 100644 --- a/app/database/auth.py +++ b/app/database/auth.py @@ -2,7 +2,7 @@ from datetime import datetime from typing import TYPE_CHECKING from sqlalchemy import Column, DateTime -from sqlmodel import Field, Relationship, SQLModel +from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel if TYPE_CHECKING: from .user import User @@ -12,7 +12,9 @@ class OAuthToken(SQLModel, table=True): __tablename__ = "oauth_tokens" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(foreign_key="users.id") + user_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("users.id"), index=True) + ) access_token: str = Field(max_length=500, unique=True) refresh_token: str = Field(max_length=500, unique=True) token_type: str = Field(default="Bearer", max_length=20) diff --git a/app/database/beatmap.py b/app/database/beatmap.py index ea15799..46fdd96 100644 --- a/app/database/beatmap.py +++ b/app/database/beatmap.py @@ -1,6 +1,6 @@ from datetime import datetime +from typing import TYPE_CHECKING -from app.fetcher import Fetcher from app.models.beatmap import BeatmapRankStatus from app.models.score import MODE_TO_INT, GameMode @@ -11,6 +11,9 @@ from sqlalchemy.orm import joinedload from sqlmodel import VARCHAR, Field, Relationship, SQLModel, select from sqlmodel.ext.asyncio.session import AsyncSession +if TYPE_CHECKING: + from app.fetcher import Fetcher + class BeatmapOwner(SQLModel): id: int @@ -111,7 +114,7 @@ class Beatmap(BeatmapBase, table=True): @classmethod async def get_or_fetch( - cls, session: AsyncSession, bid: int, fetcher: Fetcher + cls, session: AsyncSession, bid: int, fetcher: "Fetcher" ) -> "Beatmap": beatmap = ( await session.exec( diff --git a/app/database/legacy.py b/app/database/legacy.py index c0db405..ff1e957 100644 --- a/app/database/legacy.py +++ b/app/database/legacy.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING from sqlalchemy import JSON, Column, DateTime from sqlalchemy.orm import Mapped -from sqlmodel import Field, Relationship, SQLModel +from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel if TYPE_CHECKING: from .user import User @@ -16,7 +16,7 @@ class LegacyUserStatistics(SQLModel, table=True): __tablename__ = "user_statistics" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(foreign_key="users.id") + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) mode: str = Field(max_length=10) # osu, taiko, fruits, mania # 基本统计 @@ -77,7 +77,7 @@ class LegacyOAuthToken(SQLModel, table=True): __tablename__ = "legacy_oauth_tokens" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True) - user_id: int = Field(foreign_key="users.id") + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) access_token: str = Field(max_length=255, index=True) refresh_token: str = Field(max_length=255, index=True) expires_at: datetime = Field(sa_column=Column(DateTime)) diff --git a/app/database/relationship.py b/app/database/relationship.py index e352b81..cbf7643 100644 --- a/app/database/relationship.py +++ b/app/database/relationship.py @@ -4,7 +4,10 @@ from .user import User from pydantic import BaseModel from sqlmodel import ( + BigInteger, + Column, Field, + ForeignKey, Relationship as SQLRelationship, SQLModel, select, @@ -20,10 +23,22 @@ class RelationshipType(str, Enum): class Relationship(SQLModel, table=True): __tablename__ = "relationship" # pyright: ignore[reportAssignmentType] user_id: int = Field( - default=None, foreign_key="users.id", primary_key=True, index=True + default=None, + sa_column=Column( + BigInteger, + ForeignKey("users.id"), + primary_key=True, + index=True, + ), ) target_id: int = Field( - default=None, foreign_key="users.id", primary_key=True, index=True + default=None, + sa_column=Column( + BigInteger, + ForeignKey("users.id"), + primary_key=True, + index=True, + ), ) type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False) target: "User" = SQLRelationship( diff --git a/app/database/score.py b/app/database/score.py index f82a813..1bc2e58 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -22,6 +22,7 @@ from sqlmodel import ( JSON, BigInteger, Field, + ForeignKey, Relationship, SQLModel, col, @@ -69,7 +70,14 @@ class Score(ScoreBase, table=True): default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True) ) beatmap_id: int = Field(index=True, foreign_key="beatmaps.id") - user_id: int = Field(foreign_key="users.id", index=True) + user_id: int = Field( + default=None, + sa_column=Column( + BigInteger, + ForeignKey("users.id"), + index=True, + ), + ) # ScoreStatistics n300: int = Field(exclude=True) n100: int = Field(exclude=True) @@ -92,6 +100,7 @@ class Score(ScoreBase, table=True): class ScoreResp(ScoreBase): id: int + user_id: int is_perfect_combo: bool = False legacy_perfect: bool = False legacy_total_score: int = 0 # FIXME diff --git a/app/database/score_token.py b/app/database/score_token.py index 195a174..6a6edb3 100644 --- a/app/database/score_token.py +++ b/app/database/score_token.py @@ -27,12 +27,15 @@ class ScoreToken(ScoreTokenBase, table=True): id: int | None = Field( default=None, - primary_key=True, - index=True, - sa_column_kwargs={"autoincrement": True}, + sa_column=Column( + BigInteger, + primary_key=True, + index=True, + autoincrement=True, + ), ) user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) - beatmap_id: int = Field(sa_column=Column(BigInteger, ForeignKey("beatmaps.id"))) + beatmap_id: int = Field(foreign_key="beatmaps.id") user: "User" = Relationship() beatmap: "Beatmap" = Relationship() diff --git a/app/database/team.py b/app/database/team.py index 2722319..360e805 100644 --- a/app/database/team.py +++ b/app/database/team.py @@ -2,7 +2,7 @@ from datetime import datetime from typing import TYPE_CHECKING from sqlalchemy import Column, DateTime -from sqlmodel import Field, Relationship, SQLModel +from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel if TYPE_CHECKING: from .user import User @@ -26,7 +26,7 @@ class TeamMember(SQLModel, table=True): __tablename__ = "team_members" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(foreign_key="users.id") + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) team_id: int = Field(foreign_key="teams.id") joined_at: datetime = Field( default_factory=datetime.utcnow, sa_column=Column(DateTime) diff --git a/app/database/user.py b/app/database/user.py index 71a6eda..6c70ce0 100644 --- a/app/database/user.py +++ b/app/database/user.py @@ -7,7 +7,7 @@ from .team import TeamMember from sqlalchemy import DECIMAL, JSON, Column, Date, DateTime, Text from sqlalchemy.dialects.mysql import VARCHAR -from sqlmodel import BigInteger, Field, Relationship, SQLModel +from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel class User(SQLModel, table=True): @@ -109,7 +109,14 @@ class User(SQLModel, table=True): class LazerUserProfile(SQLModel, table=True): __tablename__ = "lazer_user_profiles" # pyright: ignore[reportAssignmentType] - user_id: int = Field(foreign_key="users.id", primary_key=True) + user_id: int = Field( + default=None, + sa_column=Column( + BigInteger, + ForeignKey("users.id"), + primary_key=True, + ), + ) # 基本状态字段 is_active: bool = Field(default=True) @@ -165,7 +172,7 @@ class LazerUserProfileSections(SQLModel, table=True): __tablename__ = "lazer_user_profile_sections" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True) - user_id: int = Field(foreign_key="users.id") + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) section_name: str = Field(sa_column=Column(VARCHAR(50))) display_order: int | None = Field(default=None) @@ -182,7 +189,14 @@ class LazerUserProfileSections(SQLModel, table=True): class LazerUserCountry(SQLModel, table=True): __tablename__ = "lazer_user_countries" # pyright: ignore[reportAssignmentType] - user_id: int = Field(foreign_key="users.id", primary_key=True) + user_id: int = Field( + default=None, + sa_column=Column( + BigInteger, + ForeignKey("users.id"), + primary_key=True, + ), + ) code: str = Field(max_length=2) name: str = Field(max_length=100) @@ -197,7 +211,14 @@ class LazerUserCountry(SQLModel, table=True): class LazerUserKudosu(SQLModel, table=True): __tablename__ = "lazer_user_kudosu" # pyright: ignore[reportAssignmentType] - user_id: int = Field(foreign_key="users.id", primary_key=True) + user_id: int = Field( + default=None, + sa_column=Column( + BigInteger, + ForeignKey("users.id"), + primary_key=True, + ), + ) available: int = Field(default=0) total: int = Field(default=0) @@ -212,7 +233,14 @@ class LazerUserKudosu(SQLModel, table=True): class LazerUserCounts(SQLModel, table=True): __tablename__ = "lazer_user_counts" # pyright: ignore[reportAssignmentType] - user_id: int = Field(foreign_key="users.id", primary_key=True) + user_id: int = Field( + default=None, + sa_column=Column( + BigInteger, + ForeignKey("users.id"), + primary_key=True, + ), + ) # 统计计数字段 beatmap_playcounts_count: int = Field(default=0) @@ -247,7 +275,14 @@ class LazerUserCounts(SQLModel, table=True): class LazerUserStatistics(SQLModel, table=True): __tablename__ = "lazer_user_statistics" # pyright: ignore[reportAssignmentType] - user_id: int = Field(foreign_key="users.id", primary_key=True) + user_id: int = Field( + default=None, + sa_column=Column( + BigInteger, + ForeignKey("users.id"), + primary_key=True, + ), + ) mode: str = Field(default="osu", max_length=10, primary_key=True) # 基本命中统计 @@ -308,7 +343,7 @@ class LazerUserBanners(SQLModel, table=True): __tablename__ = "lazer_user_tournament_banners" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True) - user_id: int = Field(foreign_key="users.id") + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) tournament_id: int image_url: str = Field(sa_column=Column(VARCHAR(500))) is_active: bool | None = Field(default=None) @@ -321,7 +356,7 @@ class LazerUserAchievement(SQLModel, table=True): __tablename__ = "lazer_user_achievements" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(foreign_key="users.id") + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) achievement_id: int achieved_at: datetime = Field( default_factory=datetime.utcnow, sa_column=Column(DateTime) @@ -334,7 +369,7 @@ class LazerUserBadge(SQLModel, table=True): __tablename__ = "lazer_user_badges" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(foreign_key="users.id") + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) badge_id: int awarded_at: datetime | None = Field(default=None, sa_column=Column(DateTime)) description: str | None = Field(default=None, sa_column=Column(Text)) @@ -355,7 +390,7 @@ class LazerUserMonthlyPlaycounts(SQLModel, table=True): __tablename__ = "lazer_user_monthly_playcounts" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(foreign_key="users.id") + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) start_date: datetime = Field(sa_column=Column(Date)) play_count: int = Field(default=0) @@ -373,7 +408,7 @@ class LazerUserPreviousUsername(SQLModel, table=True): __tablename__ = "lazer_user_previous_usernames" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(foreign_key="users.id") + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) username: str = Field(max_length=32) changed_at: datetime = Field(sa_column=Column(DateTime)) @@ -391,7 +426,7 @@ class LazerUserReplaysWatched(SQLModel, table=True): __tablename__ = "lazer_user_replays_watched" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(foreign_key="users.id") + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) start_date: datetime = Field(sa_column=Column(Date)) count: int = Field(default=0) @@ -416,7 +451,9 @@ class DailyChallengeStats(SQLModel, table=True): __tablename__ = "daily_challenge_stats" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(foreign_key="users.id", unique=True) + user_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("users.id"), unique=True) + ) daily_streak_best: int = Field(default=0) daily_streak_current: int = Field(default=0) @@ -437,7 +474,7 @@ class RankHistory(SQLModel, table=True): __tablename__ = "rank_history" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(foreign_key="users.id") + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) mode: str = Field(max_length=10) rank_data: list = Field(sa_column=Column(JSON)) # Array of ranks date_recorded: datetime = Field( @@ -451,7 +488,7 @@ class UserAvatar(SQLModel, table=True): __tablename__ = "user_avatars" # pyright: ignore[reportAssignmentType] id: int | None = Field(default=None, primary_key=True, index=True) - user_id: int = Field(foreign_key="users.id") + user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id"))) filename: str = Field(max_length=255) original_filename: str = Field(max_length=255) file_size: int diff --git a/app/fetcher/beatmap.py b/app/fetcher/beatmap.py index d9da207..8e770f1 100644 --- a/app/fetcher/beatmap.py +++ b/app/fetcher/beatmap.py @@ -1,19 +1,14 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from app.database.beatmap import BeatmapResp from ._base import BaseFetcher from httpx import AsyncClient -if TYPE_CHECKING: - from app.database.beatmap import BeatmapResp - class BeatmapFetcher(BaseFetcher): - async def get_beatmap(self, beatmap_id: int) -> "BeatmapResp": - from app.database.beatmap import BeatmapResp - + async def get_beatmap(self, beatmap_id: int) -> BeatmapResp: async with AsyncClient() as client: response = await client.get( f"https://osu.ppy.sh/api/v2/beatmaps/{beatmap_id}", diff --git a/app/models/score.py b/app/models/score.py index f038988..d1e391a 100644 --- a/app/models/score.py +++ b/app/models/score.py @@ -3,7 +3,7 @@ from __future__ import annotations from enum import Enum from typing import Literal, TypedDict -from .mods import API_MOD_TO_LEGACY, API_MODS, APIMod, init_mods +from .mods import API_MODS, APIMod, init_mods from pydantic import BaseModel, Field, ValidationInfo, field_validator import rosu_pp_py as rosu @@ -109,7 +109,7 @@ class SoloScoreSubmissionInfo(BaseModel): @field_validator("mods", mode="after") @classmethod def validate_mods(cls, mods: list[APIMod], info: ValidationInfo): - if not API_MOD_TO_LEGACY: + if not API_MODS: init_mods() incompatible_mods = set() # check incompatible mods @@ -122,6 +122,7 @@ class SoloScoreSubmissionInfo(BaseModel): if not setting_mods: raise ValueError(f"Invalid mod: {mod['acronym']}") incompatible_mods.update(setting_mods["IncompatibleMods"]) + return mods class LegacyReplaySoloScoreInfo(TypedDict): diff --git a/app/router/score.py b/app/router/score.py index 9cc57f1..2bf9519 100644 --- a/app/router/score.py +++ b/app/router/score.py @@ -1,17 +1,22 @@ from __future__ import annotations +import datetime + from app.database import ( Beatmap, User as DBUser, ) from app.database.beatmapset import Beatmapset from app.database.score import Score, ScoreResp +from app.database.score_token import ScoreToken, ScoreTokenResp +from app.database.user import User from app.dependencies.database import get_db from app.dependencies.user import get_current_user +from app.models.score import INT_TO_MODE, HitResult, Rank, SoloScoreSubmissionInfo from .api_router import router -from fastapi import Depends, HTTPException, Query +from fastapi import Depends, Form, HTTPException, Query from pydantic import BaseModel from sqlalchemy.orm import joinedload from sqlmodel import col, select @@ -63,8 +68,8 @@ async def get_beatmap_scores( ).first() return BeatmapScores( - scores=[ScoreResp.from_db(score) for score in all_scores], - userScore=ScoreResp.from_db(user_score) if user_score else None, + scores=[await ScoreResp.from_db(db, score) for score in all_scores], + userScore=await ScoreResp.from_db(db, user_score) if user_score else None, ) @@ -115,7 +120,7 @@ async def get_user_beatmap_score( else: return BeatmapUserScore( position=user_score.position if user_score.position is not None else 0, - score=ScoreResp.from_db(user_score), + score=await ScoreResp.from_db(db, user_score), ) @@ -153,4 +158,113 @@ async def get_user_all_beatmap_scores( ) ).all() - return [ScoreResp.from_db(score) for score in all_user_scores] + return [await ScoreResp.from_db(db, score) for score in all_user_scores] + + +@router.post( + "/beatmaps/{beatmap}/solo/scores", tags=["beatmap"], response_model=ScoreTokenResp +) +async def create_solo_score( + beatmap: int, + version_hash: str = Form(""), + beatmap_hash: str = Form(), + ruleset_id: int = Form(..., ge=0, le=3), + current_user: DBUser = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + async with db: + score_token = ScoreToken( + user_id=current_user.id, + beatmap_id=beatmap, + ruleset_id=INT_TO_MODE[ruleset_id], + ) + db.add(score_token) + await db.commit() + await db.refresh(score_token) + return ScoreTokenResp.from_db(score_token) + + +@router.put( + "/beatmaps/{beatmap}/solo/scores/{token}", + tags=["beatmap"], + response_model=ScoreResp, +) +async def submit_solo_score( + beatmap: int, + token: int, + info: SoloScoreSubmissionInfo, + current_user: DBUser = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + if not info.passed: + info.rank = Rank.F + async with db: + score_token = ( + await db.exec( + select(ScoreToken) + .options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType] + .where(ScoreToken.id == token, ScoreToken.user_id == current_user.id) + ) + ).first() + if not score_token or score_token.user_id != current_user.id: + raise HTTPException(status_code=404, detail="Score token not found") + if score_token.score_id: + score = ( + await db.exec( + select(Score) + .options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType] + .where( + Score.id == score_token.score_id, + Score.user_id == current_user.id, + ) + ) + ).first() + if not score: + raise HTTPException(status_code=404, detail="Score not found") + else: + score = Score( + accuracy=info.accuracy, + max_combo=info.max_combo, + # maximum_statistics=info.maximum_statistics, + mods=info.mods, + passed=info.passed, + rank=info.rank, + total_score=info.total_score, + total_score_without_mods=info.total_score_without_mods, + beatmap_id=beatmap, + ended_at=datetime.datetime.now(datetime.UTC), + gamemode=INT_TO_MODE[info.ruleset_id], + started_at=score_token.created_at, + user_id=current_user.id, + preserve=info.passed, + map_md5=score_token.beatmap.checksum, + has_replay=False, + pp=info.pp, + type="solo", + n300=info.statistics.get(HitResult.GREAT, 0), + n100=info.statistics.get(HitResult.OK, 0), + n50=info.statistics.get(HitResult.MEH, 0), + nmiss=info.statistics.get(HitResult.MISS, 0), + ngeki=info.statistics.get(HitResult.PERFECT, 0), + nkatu=info.statistics.get(HitResult.GOOD, 0), + ) + db.add(score) + await db.commit() + await db.refresh(score) + score_id = score.id + score_token.score_id = score_id + await db.commit() + score = ( + await db.exec( + select(Score) + .options( + joinedload(Score.beatmap) # pyright: ignore[reportArgumentType] + .joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType] + .selectinload(Beatmapset.beatmaps), # pyright: ignore[reportArgumentType] + joinedload(Score.user).joinedload(User.lazer_profile), # pyright: ignore[reportArgumentType] + ) + .where(Score.id == score_id) + ) + ).first() + assert score is not None + return await ScoreResp.from_db(db, score)