diff --git a/app/database/__init__.py b/app/database/__init__.py index 12fa867..6e2e8c5 100644 --- a/app/database/__init__.py +++ b/app/database/__init__.py @@ -10,6 +10,7 @@ from .beatmapset import ( ) from .best_score import BestScore from .daily_challenge import DailyChallengeStats, DailyChallengeStatsResp +from .favourite_beatmapset import FavouriteBeatmapset from .lazer_user import ( User, UserResp, @@ -41,6 +42,7 @@ __all__ = [ "BestScore", "DailyChallengeStats", "DailyChallengeStatsResp", + "FavouriteBeatmapset", "OAuthToken", "PPBestScore", "Relationship", diff --git a/app/database/beatmap.py b/app/database/beatmap.py index 2ab5ad0..c55643a 100644 --- a/app/database/beatmap.py +++ b/app/database/beatmap.py @@ -14,6 +14,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: from app.fetcher import Fetcher + from .lazer_user import User + class BeatmapOwner(SQLModel): id: int @@ -161,6 +163,8 @@ class BeatmapResp(BeatmapBase): beatmap: Beatmap, query_mode: GameMode | None = None, from_set: bool = False, + session: AsyncSession | None = None, + user: "User | None" = None, ) -> "BeatmapResp": beatmap_ = beatmap.model_dump() if query_mode is not None and beatmap.mode != query_mode: @@ -170,5 +174,7 @@ class BeatmapResp(BeatmapBase): beatmap_["ranked"] = beatmap.beatmap_status.value beatmap_["mode_int"] = MODE_TO_INT[beatmap.mode] if not from_set: - beatmap_["beatmapset"] = await BeatmapsetResp.from_db(beatmap.beatmapset) + beatmap_["beatmapset"] = await BeatmapsetResp.from_db( + beatmap.beatmapset, session=session, user=user + ) return cls.model_validate(beatmap_) diff --git a/app/database/beatmapset.py b/app/database/beatmapset.py index 5a618b7..49313b2 100644 --- a/app/database/beatmapset.py +++ b/app/database/beatmapset.py @@ -5,14 +5,17 @@ from app.models.beatmap import BeatmapRankStatus, Genre, Language from app.models.model import UTCBaseModel from app.models.score import GameMode +from .lazer_user import BASE_INCLUDES, User, UserResp + from pydantic import BaseModel, model_serializer from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text from sqlalchemy.ext.asyncio import AsyncAttrs -from sqlmodel import Field, Relationship, SQLModel +from sqlmodel import Field, Relationship, SQLModel, col, func, select from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: from .beatmap import Beatmap, BeatmapResp + from .favourite_beatmapset import FavouriteBeatmapset class BeatmapCovers(SQLModel): @@ -90,7 +93,6 @@ class BeatmapsetBase(SQLModel, UTCBaseModel): artist_unicode: str = Field(index=True) covers: BeatmapCovers | None = Field(sa_column=Column(JSON)) creator: str - favourite_count: int nsfw: bool = Field(default=False) play_count: int preview_url: str @@ -114,11 +116,9 @@ class BeatmapsetBase(SQLModel, UTCBaseModel): pack_tags: list[str] = Field(default=[], sa_column=Column(JSON)) ratings: list[int] = Field(default=None, sa_column=Column(JSON)) - # TODO: recent_favourites: Optional[list[User]] = None # TODO: related_users: Optional[list[User]] = None # TODO: user: Optional[User] = Field(default=None) track_id: int | None = Field(default=None) # feature artist? - # TODO: has_favourited # BeatmapsetExtended bpm: float = Field(default=0.0, sa_column=Column(DECIMAL(10, 2))) @@ -152,6 +152,7 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True): hype_required: int = Field(default=0) availability_info: str | None = Field(default=None) download_disabled: bool = Field(default=False) + favourites: list["FavouriteBeatmapset"] = Relationship(back_populates="beatmapset") @classmethod async def from_resp( @@ -199,40 +200,88 @@ class BeatmapsetResp(BeatmapsetBase): genre: BeatmapTranslationText | None = None language: BeatmapTranslationText | None = None nominations: BeatmapNominations | None = None + has_favourited: bool = False + favourite_count: int = 0 + recent_favourites: list[UserResp] = Field(default_factory=list) @classmethod - async def from_db(cls, beatmapset: Beatmapset) -> "BeatmapsetResp": + async def from_db( + cls, + beatmapset: Beatmapset, + include: list[str] = [], + session: AsyncSession | None = None, + user: User | None = None, + ) -> "BeatmapsetResp": from .beatmap import BeatmapResp + from .favourite_beatmapset import FavouriteBeatmapset - beatmaps = [ - await BeatmapResp.from_db(beatmap, from_set=True) - for beatmap in await beatmapset.awaitable_attrs.beatmaps - ] + update = { + "beatmaps": [ + await BeatmapResp.from_db(beatmap, from_set=True) + for beatmap in await beatmapset.awaitable_attrs.beatmaps + ], + "hype": BeatmapHype( + current=beatmapset.hype_current, required=beatmapset.hype_required + ), + "availability": BeatmapAvailability( + more_information=beatmapset.availability_info, + download_disabled=beatmapset.download_disabled, + ), + "genre": BeatmapTranslationText( + name=beatmapset.beatmap_genre.name, + id=beatmapset.beatmap_genre.value, + ), + "language": BeatmapTranslationText( + name=beatmapset.beatmap_language.name, + id=beatmapset.beatmap_language.value, + ), + "nominations": BeatmapNominations( + required=beatmapset.nominations_required, + current=beatmapset.nominations_current, + ), + "status": beatmapset.beatmap_status.name.lower(), + "ranked": beatmapset.beatmap_status.value, + "is_scoreable": beatmapset.beatmap_status > BeatmapRankStatus.PENDING, + **beatmapset.model_dump(), + } + if session and user: + existing_favourite = ( + await session.exec( + select(FavouriteBeatmapset).where( + FavouriteBeatmapset.beatmapset_id == beatmapset.id + ) + ) + ).first() + update["has_favourited"] = existing_favourite is not None + + if session and "recent_favourites" in include: + recent_favourites = ( + await session.exec( + select(FavouriteBeatmapset) + .where( + FavouriteBeatmapset.beatmapset_id == beatmapset.id, + ) + .order_by(col(FavouriteBeatmapset.date).desc()) + .limit(50) + ) + ).all() + update["recent_favourites"] = [ + await UserResp.from_db( + await favourite.awaitable_attrs.user, + session=session, + include=BASE_INCLUDES, + ) + for favourite in recent_favourites + ] + + if session: + update["favourite_count"] = ( + await session.exec( + select(func.count()) + .select_from(FavouriteBeatmapset) + .where(FavouriteBeatmapset.beatmapset_id == beatmapset.id) + ) + ).one() return cls.model_validate( - { - "beatmaps": beatmaps, - "hype": BeatmapHype( - current=beatmapset.hype_current, required=beatmapset.hype_required - ), - "availability": BeatmapAvailability( - more_information=beatmapset.availability_info, - download_disabled=beatmapset.download_disabled, - ), - "genre": BeatmapTranslationText( - name=beatmapset.beatmap_genre.name, - id=beatmapset.beatmap_genre.value, - ), - "language": BeatmapTranslationText( - name=beatmapset.beatmap_language.name, - id=beatmapset.beatmap_language.value, - ), - "nominations": BeatmapNominations( - required=beatmapset.nominations_required, - current=beatmapset.nominations_current, - ), - "status": beatmapset.beatmap_status.name.lower(), - "ranked": beatmapset.beatmap_status.value, - "is_scoreable": beatmapset.beatmap_status > BeatmapRankStatus.PENDING, - **beatmapset.model_dump(), - } + update, ) diff --git a/app/database/favourite_beatmapset.py b/app/database/favourite_beatmapset.py new file mode 100644 index 0000000..51bd578 --- /dev/null +++ b/app/database/favourite_beatmapset.py @@ -0,0 +1,53 @@ +import datetime + +from app.database.beatmapset import Beatmapset +from app.database.lazer_user import User + +from sqlalchemy.ext.asyncio import AsyncAttrs +from sqlmodel import ( + BigInteger, + Column, + DateTime, + Field, + ForeignKey, + Relationship, + SQLModel, +) + + +class FavouriteBeatmapset(AsyncAttrs, SQLModel, table=True): + __tablename__ = "favourite_beatmapset" # pyright: ignore[reportAssignmentType] + id: int | None = Field( + default=None, + sa_column=Column(BigInteger, autoincrement=True, primary_key=True), + exclude=True, + ) + user_id: int = Field( + default=None, + sa_column=Column( + BigInteger, + ForeignKey("lazer_users.id"), + index=True, + ), + ) + beatmapset_id: int = Field( + default=None, + sa_column=Column( + ForeignKey("beatmapsets.id"), + index=True, + ), + ) + date: datetime.datetime = Field( + default=datetime.datetime.now(datetime.UTC), + sa_column=Column( + DateTime, + ), + ) + + user: User = Relationship(back_populates="favourite_beatmapsets") + beatmapset: Beatmapset = Relationship( + sa_relationship_kwargs={ + "lazy": "selectin", + }, + back_populates="favourites", + ) diff --git a/app/database/lazer_user.py b/app/database/lazer_user.py index 1337cc2..3bd751b 100644 --- a/app/database/lazer_user.py +++ b/app/database/lazer_user.py @@ -1,7 +1,6 @@ from datetime import UTC, datetime from typing import TYPE_CHECKING, NotRequired, TypedDict -from app.dependencies.database import get_redis from app.models.model import UTCBaseModel from app.models.score import GameMode from app.models.user import Country, Page, RankHistory @@ -28,7 +27,8 @@ from sqlmodel import ( from sqlmodel.ext.asyncio.session import AsyncSession if TYPE_CHECKING: - from app.database.relationship import RelationshipResp + from .favourite_beatmapset import FavouriteBeatmapset + from .relationship import RelationshipResp class Kudosu(TypedDict): @@ -144,6 +144,9 @@ class User(AsyncAttrs, UserBase, table=True): back_populates="user" ) monthly_playcounts: list[MonthlyPlaycounts] = Relationship(back_populates="user") + favourite_beatmapsets: list["FavouriteBeatmapset"] = Relationship( + back_populates="user" + ) email: str = Field(max_length=254, unique=True, index=True, exclude=True) priv: int = Field(default=1, exclude=True) @@ -201,6 +204,8 @@ class UserResp(UserBase): include: list[str] = [], ruleset: GameMode | None = None, ) -> "UserResp": + from app.dependencies.database import get_redis + from .best_score import BestScore from .relationship import Relationship, RelationshipResp, RelationshipType @@ -320,3 +325,9 @@ SEARCH_INCLUDED = [ "achievements", "monthly_playcounts", ] + +BASE_INCLUDES = [ + "team", + "daily_challenge_user_stats", + "statistics", +] diff --git a/app/database/score.py b/app/database/score.py index 32ddb6c..79cb005 100644 --- a/app/database/score.py +++ b/app/database/score.py @@ -169,7 +169,9 @@ class ScoreResp(ScoreBase): assert score.id await score.awaitable_attrs.beatmap s.beatmap = await BeatmapResp.from_db(score.beatmap) - s.beatmapset = await BeatmapsetResp.from_db(score.beatmap.beatmapset) + s.beatmapset = await BeatmapsetResp.from_db( + score.beatmap.beatmapset, session=session, user=score.user + ) s.is_perfect_combo = s.max_combo == s.beatmap.max_combo s.legacy_perfect = s.max_combo == s.beatmap.max_combo s.ruleset_id = MODE_TO_INT[score.gamemode] @@ -669,7 +671,7 @@ async def process_score( acc=score.accuracy, ) session.add(best_score) - session.delete(previous_pp_best) if previous_pp_best else None + await session.delete(previous_pp_best) if previous_pp_best else None await session.commit() await session.refresh(score) await session.refresh(score_token) diff --git a/app/router/beatmap.py b/app/router/beatmap.py index 0a25562..9574bdb 100644 --- a/app/router/beatmap.py +++ b/app/router/beatmap.py @@ -50,7 +50,7 @@ async def lookup_beatmap( if beatmap is None: raise HTTPException(status_code=404, detail="Beatmap not found") - return await BeatmapResp.from_db(beatmap) + return await BeatmapResp.from_db(beatmap, session=db, user=current_user) @router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp) @@ -62,7 +62,7 @@ async def get_beatmap( ): try: beatmap = await Beatmap.get_or_fetch(db, fetcher, bid) - return await BeatmapResp.from_db(beatmap) + return await BeatmapResp.from_db(beatmap, session=db, user=current_user) except HTTPError: raise HTTPException(status_code=404, detail="Beatmap not found") @@ -90,7 +90,12 @@ async def batch_get_beatmaps( await db.exec(select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50)) ).all() - return BatchGetResp(beatmaps=[await BeatmapResp.from_db(bm) for bm in beatmaps]) + return BatchGetResp( + beatmaps=[ + await BeatmapResp.from_db(bm, session=db, user=current_user) + for bm in beatmaps + ] + ) @router.post( diff --git a/app/router/beatmapset.py b/app/router/beatmapset.py index b82678d..b4d2e4c 100644 --- a/app/router/beatmapset.py +++ b/app/router/beatmapset.py @@ -1,10 +1,8 @@ from __future__ import annotations -from app.database import ( - Beatmapset, - BeatmapsetResp, - User, -) +from typing import Literal + +from app.database import Beatmapset, BeatmapsetResp, FavouriteBeatmapset, User from app.dependencies.database import get_db from app.dependencies.fetcher import get_fetcher from app.dependencies.user import get_current_user @@ -12,7 +10,7 @@ from app.fetcher import Fetcher from .api_router import router -from fastapi import Depends, HTTPException, Query +from fastapi import Depends, Form, HTTPException, Query from fastapi.responses import RedirectResponse from httpx import HTTPStatusError from sqlmodel import select @@ -34,7 +32,9 @@ async def get_beatmapset( except HTTPStatusError: raise HTTPException(status_code=404, detail="Beatmapset not found") else: - resp = await BeatmapsetResp.from_db(beatmapset) + resp = await BeatmapsetResp.from_db( + beatmapset, session=db, include=["recent_favourites"], user=current_user + ) return resp @@ -53,3 +53,34 @@ async def download_beatmapset( return RedirectResponse( f"https://api.nerinyan.moe/d/{beatmapset}?noVideo={no_video}" ) + + +@router.post("/beatmapsets/{beatmapset}/favourites", tags=["beatmapset"]) +async def favourite_beatmapset( + beatmapset: int, + action: Literal["favourite", "unfavourite"] = Form(), + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + existing_favourite = ( + await db.exec( + select(FavouriteBeatmapset).where( + FavouriteBeatmapset.user_id == current_user.id, + FavouriteBeatmapset.beatmapset_id == beatmapset, + ) + ) + ).first() + + if action == "favourite" and existing_favourite: + raise HTTPException(status_code=400, detail="Already favourited") + elif action == "unfavourite" and not existing_favourite: + raise HTTPException(status_code=400, detail="Not favourited") + + if action == "favourite": + favourite = FavouriteBeatmapset( + user_id=current_user.id, beatmapset_id=beatmapset + ) + db.add(favourite) + else: + await db.delete(existing_favourite) + await db.commit() diff --git a/migrations/versions/1178d0758ebf_beatmapset_support_favourite_count.py b/migrations/versions/1178d0758ebf_beatmapset_support_favourite_count.py new file mode 100644 index 0000000..84bae15 --- /dev/null +++ b/migrations/versions/1178d0758ebf_beatmapset_support_favourite_count.py @@ -0,0 +1,40 @@ +"""beatmapset: support favourite count + +Revision ID: 1178d0758ebf +Revises: +Create Date: 2025-08-01 04:05:09.882800 + +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + +# revision identifiers, used by Alembic. +revision: str = "1178d0758ebf" +down_revision: str | Sequence[str] | None = None +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("beatmapsets", "favourite_count") + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "beatmapsets", + sa.Column( + "favourite_count", mysql.INTEGER(), autoincrement=False, nullable=False + ), + ) + # ### end Alembic commands ###