diff --git a/app/database/__init__.py b/app/database/__init__.py index 2d5f155..b3a0e3d 100644 --- a/app/database/__init__.py +++ b/app/database/__init__.py @@ -5,6 +5,7 @@ from .beatmap import ( BeatmapResp, ) from .beatmap_playcounts import BeatmapPlaycounts, BeatmapPlaycountsResp +from .beatmap_tags import BeatmapTagVote from .beatmapset import ( Beatmapset, BeatmapsetResp, @@ -74,6 +75,7 @@ __all__ = [ "BeatmapPlaycountsResp", "BeatmapRating", "BeatmapResp", + "BeatmapTagVote", "Beatmapset", "BeatmapsetResp", "BestScore", diff --git a/app/database/beatmap.py b/app/database/beatmap.py index 1f83aa2..6ca0d30 100644 --- a/app/database/beatmap.py +++ b/app/database/beatmap.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING from app.calculator import calculate_beatmap_attribute from app.config import settings +from app.database.beatmap_tags import BeatmapTagVote from app.database.failtime import FailTime, FailTimeResp from app.models.beatmap import BeatmapAttributes, BeatmapRankStatus from app.models.mods import APIMod @@ -13,6 +14,7 @@ from app.models.score import GameMode from .beatmap_playcounts import BeatmapPlaycounts from .beatmapset import Beatmapset, BeatmapsetResp +from pydantic import BaseModel from redis.asyncio import Redis from sqlalchemy import Column, DateTime from sqlmodel import VARCHAR, Field, Relationship, SQLModel, col, exists, func, select @@ -130,6 +132,11 @@ class Beatmap(BeatmapBase, table=True): return beatmap +class APIBeatmapTag(BaseModel): + tag_id: int + count: int + + class BeatmapResp(BeatmapBase): id: int beatmapset_id: int @@ -143,6 +150,8 @@ class BeatmapResp(BeatmapBase): playcount: int = 0 passcount: int = 0 failtimes: FailTimeResp | None = None + top_tag_ids: list[APIBeatmapTag] | None = None + current_user_tag_ids: list[int] | None = None @classmethod async def from_db( @@ -191,6 +200,29 @@ class BeatmapResp(BeatmapBase): ) ) ).one() + + all_votes = ( + await session.exec( + select(BeatmapTagVote.tag_id, func.count().label("vote_count")) + .where(BeatmapTagVote.beatmap_id == beatmap.id) + .group_by(col(BeatmapTagVote.tag_id)) + ) + ).all() + top_tag_ids: list[dict[str, int]] = [] + for id, votes in all_votes: + top_tag_ids.append({"tag_id": id, "count": votes}) + beatmap_["top_tag_ids"] = top_tag_ids + + if user is not None: + beatmap_["current_user_tag_ids"] = ( + await session.exec( + select(BeatmapTagVote.tag_id) + .where(BeatmapTagVote.beatmap_id == beatmap.id) + .where(BeatmapTagVote.user_id == user.id) + ) + ).all() + else: + beatmap_["current_user_tag_ids"] = [] return cls.model_validate(beatmap_) diff --git a/app/database/beatmap_tags.py b/app/database/beatmap_tags.py new file mode 100644 index 0000000..69d5a31 --- /dev/null +++ b/app/database/beatmap_tags.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from sqlmodel import Field, SQLModel + + +class BeatmapTagVote(SQLModel, table=True): + __tablename__: str = "beatmap_tags" + tag_id: int = Field(primary_key=True, index=True, default=None) + beatmap_id: int = Field(primary_key=True, index=True, default=None) + user_id: int = Field(primary_key=True, index=True, default=None) diff --git a/app/models/tags.py b/app/models/tags.py new file mode 100644 index 0000000..ddcc24e --- /dev/null +++ b/app/models/tags.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import json + +from app.log import logger +from app.path import STATIC_DIR + +from pydantic import BaseModel + + +class BeatmapTags(BaseModel): + id: int + name: str = "" + description: str = "" + ruleset_id: int | None = None + + +ALL_TAGS: dict[int, BeatmapTags] = {} + + +def load_tags() -> None: + if len(ALL_TAGS) > 0: + return + if not (STATIC_DIR / "beatmap_tags.json").exists(): + logger.warning("beatmap tags description file does not exist, using no tags") + return + tags_list = json.loads((STATIC_DIR / "beatmap_tags.json").read_text()) + for tag in tags_list: + if tag["id"] in ALL_TAGS: + logger.error("find duplicated beatmap tag id") + logger.info(f"tag {ALL_TAGS[tag['id']].name} and tag {tag['name']} have the same tag id") + raise ValueError("duplicated tag id found") + ALL_TAGS[tag["id"]] = BeatmapTags.model_validate(tag) + + +def get_tag_by_id(id: int) -> BeatmapTags: + load_tags() + tag = ALL_TAGS.get(id) + if tag is None: + logger.error(f"tag id {id} not found") + raise ValueError("tag id not found") + return tag + + +def get_all_tags() -> list[BeatmapTags]: + load_tags() + return list(ALL_TAGS.values()) diff --git a/app/router/v2/__init__.py b/app/router/v2/__init__.py index 968b030..0bd06cd 100644 --- a/app/router/v2/__init__.py +++ b/app/router/v2/__init__.py @@ -10,6 +10,7 @@ from . import ( # noqa: F401 room, score, session_verify, + tags, user, ) from .router import router as api_v2_router diff --git a/app/router/v2/tags.py b/app/router/v2/tags.py new file mode 100644 index 0000000..af124ef --- /dev/null +++ b/app/router/v2/tags.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +from app.database.beatmap import Beatmap +from app.database.beatmap_tags import BeatmapTagVote +from app.database.lazer_user import User +from app.database.score import Score +from app.dependencies.database import get_db +from app.dependencies.user import get_client_user +from app.models.score import Rank +from app.models.tags import BeatmapTags, get_all_tags, get_tag_by_id + +from .router import router + +from fastapi import Depends, HTTPException, Path +from pydantic import BaseModel +from sqlmodel import col, exists, select +from sqlmodel.ext.asyncio.session import AsyncSession + + +class APITagCollection(BaseModel): + tags: list[BeatmapTags] + + +@router.get( + "/tags", + tags=["用户标签"], + response_model=APITagCollection, + name="获取所有标签", + description="获取所有可用的谱面标签。", +) +async def router_get_all_tags(): + return APITagCollection(tags=get_all_tags()) + + +async def check_user_can_vote(user: User, beatmap_id: int, session: AsyncSession): + user_beatmap_score = ( + await session.exec( + select(exists()) + .where(Score.beatmap_id == beatmap_id) + .where(Score.user_id == user.id) + .where(col(Score.rank).not_in([Rank.F, Rank.D])) + .where(col(Score.beatmap).has(col(Beatmap.mode) == Score.gamemode)) + ) + ).first() + if user_beatmap_score is None: + return False + return True + + +@router.put( + "/beatmaps/{beatmap_id}/tags/{tag_id}", + tags=["用户标签"], + status_code=204, + name="为谱面投票标签", + description="为指定谱面添加标签投票。", +) +async def vote_beatmap_tags( + beatmap_id: int = Path(..., description="谱面 ID"), + tag_id: int = Path(..., description="标签 ID"), + session: AsyncSession = Depends(get_db), + current_user: User = Depends(get_client_user), +): + try: + get_tag_by_id(tag_id) + beatmap = (await session.exec(select(exists()).where(Beatmap.id == beatmap_id))).first() + if beatmap is None or (not beatmap): + raise HTTPException(404, "beatmap not found") + previous_votes = ( + await session.exec( + select(BeatmapTagVote) + .where(BeatmapTagVote.beatmap_id == beatmap_id) + .where(BeatmapTagVote.tag_id == tag_id) + .where(BeatmapTagVote.user_id == current_user.id) + ) + ).first() + if previous_votes is None: + if check_user_can_vote(current_user, beatmap_id, session): + new_vote = BeatmapTagVote(tag_id=tag_id, beatmap_id=beatmap_id, user_id=current_user.id) + session.add(new_vote) + await session.commit() + except ValueError: + raise HTTPException(400, "Tag is not found") + + +@router.delete( + "/beatmaps/{beatmap_id}/tags/{tag_id}", + tags=["用户标签", "谱面"], + status_code=204, + name="取消谱面标签投票", + description="取消对指定谱面标签的投票。", +) +async def devote_beatmap_tags( + beatmap_id: int = Path(..., description="谱面 ID"), + tag_id: int = Path(..., description="标签 ID"), + session: AsyncSession = Depends(get_db), + current_user: User = Depends(get_client_user), +): + """ + 取消对谱面指定标签的投票。 + + - **beatmap_id**: 谱面ID + - **tag_id**: 标签ID + """ + try: + tag = get_tag_by_id(tag_id) + assert tag is not None + beatmap = await session.get(Beatmap, beatmap_id) + if beatmap is None: + raise HTTPException(404, "beatmap not found") + previous_votes = ( + await session.exec( + select(BeatmapTagVote) + .where(BeatmapTagVote.beatmap_id == beatmap_id) + .where(BeatmapTagVote.tag_id == tag_id) + .where(BeatmapTagVote.user_id == current_user.id) + ) + ).first() + if previous_votes is not None: + await session.delete(previous_votes) + await session.commit() + except ValueError: + raise HTTPException(400, "Tag is not found") diff --git a/migrations/versions/ebaa317ad928_add_beatmap_tag.py b/migrations/versions/ebaa317ad928_add_beatmap_tag.py new file mode 100644 index 0000000..4d0d5bb --- /dev/null +++ b/migrations/versions/ebaa317ad928_add_beatmap_tag.py @@ -0,0 +1,50 @@ +"""add beatmap_tag + + +Revision ID: ebaa317ad928 +Revises: 24a32515292d +Create Date: 2025-08-29 12:29:23.267557 + +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + +# revision identifiers, used by Alembic. +revision: str = "ebaa317ad928" +down_revision: str | Sequence[str] | None = "24a32515292d" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "beatmap_tags", + sa.Column("tag_id", sa.Integer(), nullable=False), + sa.Column("beatmap_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.PrimaryKeyConstraint("tag_id", "beatmap_id", "user_id"), + ) + op.create_index(op.f("ix_beatmap_tags_beatmap_id"), "beatmap_tags", ["beatmap_id"], unique=False) + op.create_index(op.f("ix_beatmap_tags_tag_id"), "beatmap_tags", ["tag_id"], unique=False) + op.create_index(op.f("ix_beatmap_tags_user_id"), "beatmap_tags", ["user_id"], unique=False) + op.drop_column("beatmapsets", "ratings") + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("beatmapsets", sa.Column("ratings", mysql.JSON(), nullable=True)) + op.drop_index(op.f("ix_beatmap_tags_user_id"), table_name="beatmap_tags") + op.drop_index(op.f("ix_beatmap_tags_tag_id"), table_name="beatmap_tags") + op.drop_index(op.f("ix_beatmap_tags_beatmap_id"), table_name="beatmap_tags") + op.drop_table("beatmap_tags") + # ### end Alembic commands ###