feat(beatmap): 添加谱面用户标签功能 (#25)
* feat(tags): 添加 beatmap tags 相关功能 - 新增 BeatmapTags 模型类,用于表示 beatmap 的标签信息 - 实现加载标签数据、根据 ID 获取标签、获取所有标签等功能 * feat(database): 新增 BeatmapTagVote 数据库模型和迁移脚本 * fix(database): 修改 BeatmapTagVote 模型并创建新表 - 将 BeatmapTagVote 模型的表名从 "beatmap_tag_votes" 改为 "beatmap_tags" - 创建新的数据库迁移文件以替换错误的原迁移文件 - 删除错误的迁移文件 "4a827ddba235_add_table_beatmap_tags.py" * feat(tags): 添加用户标签功能 - 在 BeatmapResp 类中添加了 top_tag_ids 和 current_user_tag_ids 字段 - 新增了 /tags 相关的路由,包括获取所有标签和投票/取消投票功能 - 实现了标签投票和取消投票的数据库操作 * fix(tags): 修复标签投票查询和返回过程中的逻辑问题 - 修复 BeatmapResp 类中 current_user_tag_ids 字段的查询逻辑 - 优化 vote_beatmap_tags 函数中的标签验证过程 * fix(tags): add suggested changes from reviews - 在 BeatmapResp 中添加 top_tag_ids 和 current_user_tag_ids 字段 - 实现用户标签投票功能,包括检查用户是否有资格投票 - 优化标签数据的加载方式 - 调整标签相关路由,增加路径参数描述 * fix(tags): apply changes from review * fix(tag): apply changes from review suggests - 更新标签接口文档,统一参数描述 - 修改标签投票接口状态码为 204 - 优化标签投票接口的用户认证方式 - 改进标签相关错误处理,使用更友好的错误信息 * fix(tag): use client authorization * chore(linter): auto fix by pre-commit hooks --------- Co-authored-by: MingxuanGame <MingxuanGame@outlook.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -5,6 +5,7 @@ from .beatmap import (
|
||||
BeatmapResp,
|
||||
)
|
||||
from .beatmap_playcounts import BeatmapPlaycounts, BeatmapPlaycountsResp
|
||||
from .beatmap_tags import BeatmapTagVote
|
||||
from .beatmapset import (
|
||||
Beatmapset,
|
||||
BeatmapsetResp,
|
||||
@@ -74,6 +75,7 @@ __all__ = [
|
||||
"BeatmapPlaycountsResp",
|
||||
"BeatmapRating",
|
||||
"BeatmapResp",
|
||||
"BeatmapTagVote",
|
||||
"Beatmapset",
|
||||
"BeatmapsetResp",
|
||||
"BestScore",
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from app.calculator import calculate_beatmap_attribute
|
||||
from app.config import settings
|
||||
from app.database.beatmap_tags import BeatmapTagVote
|
||||
from app.database.failtime import FailTime, FailTimeResp
|
||||
from app.models.beatmap import BeatmapAttributes, BeatmapRankStatus
|
||||
from app.models.mods import APIMod
|
||||
@@ -13,6 +14,7 @@ from app.models.score import GameMode
|
||||
from .beatmap_playcounts import BeatmapPlaycounts
|
||||
from .beatmapset import Beatmapset, BeatmapsetResp
|
||||
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy import Column, DateTime
|
||||
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, col, exists, func, select
|
||||
@@ -130,6 +132,11 @@ class Beatmap(BeatmapBase, table=True):
|
||||
return beatmap
|
||||
|
||||
|
||||
class APIBeatmapTag(BaseModel):
|
||||
tag_id: int
|
||||
count: int
|
||||
|
||||
|
||||
class BeatmapResp(BeatmapBase):
|
||||
id: int
|
||||
beatmapset_id: int
|
||||
@@ -143,6 +150,8 @@ class BeatmapResp(BeatmapBase):
|
||||
playcount: int = 0
|
||||
passcount: int = 0
|
||||
failtimes: FailTimeResp | None = None
|
||||
top_tag_ids: list[APIBeatmapTag] | None = None
|
||||
current_user_tag_ids: list[int] | None = None
|
||||
|
||||
@classmethod
|
||||
async def from_db(
|
||||
@@ -191,6 +200,29 @@ class BeatmapResp(BeatmapBase):
|
||||
)
|
||||
)
|
||||
).one()
|
||||
|
||||
all_votes = (
|
||||
await session.exec(
|
||||
select(BeatmapTagVote.tag_id, func.count().label("vote_count"))
|
||||
.where(BeatmapTagVote.beatmap_id == beatmap.id)
|
||||
.group_by(col(BeatmapTagVote.tag_id))
|
||||
)
|
||||
).all()
|
||||
top_tag_ids: list[dict[str, int]] = []
|
||||
for id, votes in all_votes:
|
||||
top_tag_ids.append({"tag_id": id, "count": votes})
|
||||
beatmap_["top_tag_ids"] = top_tag_ids
|
||||
|
||||
if user is not None:
|
||||
beatmap_["current_user_tag_ids"] = (
|
||||
await session.exec(
|
||||
select(BeatmapTagVote.tag_id)
|
||||
.where(BeatmapTagVote.beatmap_id == beatmap.id)
|
||||
.where(BeatmapTagVote.user_id == user.id)
|
||||
)
|
||||
).all()
|
||||
else:
|
||||
beatmap_["current_user_tag_ids"] = []
|
||||
return cls.model_validate(beatmap_)
|
||||
|
||||
|
||||
|
||||
10
app/database/beatmap_tags.py
Normal file
10
app/database/beatmap_tags.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class BeatmapTagVote(SQLModel, table=True):
|
||||
__tablename__: str = "beatmap_tags"
|
||||
tag_id: int = Field(primary_key=True, index=True, default=None)
|
||||
beatmap_id: int = Field(primary_key=True, index=True, default=None)
|
||||
user_id: int = Field(primary_key=True, index=True, default=None)
|
||||
47
app/models/tags.py
Normal file
47
app/models/tags.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from app.log import logger
|
||||
from app.path import STATIC_DIR
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BeatmapTags(BaseModel):
|
||||
id: int
|
||||
name: str = ""
|
||||
description: str = ""
|
||||
ruleset_id: int | None = None
|
||||
|
||||
|
||||
ALL_TAGS: dict[int, BeatmapTags] = {}
|
||||
|
||||
|
||||
def load_tags() -> None:
|
||||
if len(ALL_TAGS) > 0:
|
||||
return
|
||||
if not (STATIC_DIR / "beatmap_tags.json").exists():
|
||||
logger.warning("beatmap tags description file does not exist, using no tags")
|
||||
return
|
||||
tags_list = json.loads((STATIC_DIR / "beatmap_tags.json").read_text())
|
||||
for tag in tags_list:
|
||||
if tag["id"] in ALL_TAGS:
|
||||
logger.error("find duplicated beatmap tag id")
|
||||
logger.info(f"tag {ALL_TAGS[tag['id']].name} and tag {tag['name']} have the same tag id")
|
||||
raise ValueError("duplicated tag id found")
|
||||
ALL_TAGS[tag["id"]] = BeatmapTags.model_validate(tag)
|
||||
|
||||
|
||||
def get_tag_by_id(id: int) -> BeatmapTags:
|
||||
load_tags()
|
||||
tag = ALL_TAGS.get(id)
|
||||
if tag is None:
|
||||
logger.error(f"tag id {id} not found")
|
||||
raise ValueError("tag id not found")
|
||||
return tag
|
||||
|
||||
|
||||
def get_all_tags() -> list[BeatmapTags]:
|
||||
load_tags()
|
||||
return list(ALL_TAGS.values())
|
||||
@@ -10,6 +10,7 @@ from . import ( # noqa: F401
|
||||
room,
|
||||
score,
|
||||
session_verify,
|
||||
tags,
|
||||
user,
|
||||
)
|
||||
from .router import router as api_v2_router
|
||||
|
||||
122
app/router/v2/tags.py
Normal file
122
app/router/v2/tags.py
Normal file
@@ -0,0 +1,122 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.database.beatmap import Beatmap
|
||||
from app.database.beatmap_tags import BeatmapTagVote
|
||||
from app.database.lazer_user import User
|
||||
from app.database.score import Score
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.user import get_client_user
|
||||
from app.models.score import Rank
|
||||
from app.models.tags import BeatmapTags, get_all_tags, get_tag_by_id
|
||||
|
||||
from .router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Path
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import col, exists, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class APITagCollection(BaseModel):
|
||||
tags: list[BeatmapTags]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tags",
|
||||
tags=["用户标签"],
|
||||
response_model=APITagCollection,
|
||||
name="获取所有标签",
|
||||
description="获取所有可用的谱面标签。",
|
||||
)
|
||||
async def router_get_all_tags():
|
||||
return APITagCollection(tags=get_all_tags())
|
||||
|
||||
|
||||
async def check_user_can_vote(user: User, beatmap_id: int, session: AsyncSession):
|
||||
user_beatmap_score = (
|
||||
await session.exec(
|
||||
select(exists())
|
||||
.where(Score.beatmap_id == beatmap_id)
|
||||
.where(Score.user_id == user.id)
|
||||
.where(col(Score.rank).not_in([Rank.F, Rank.D]))
|
||||
.where(col(Score.beatmap).has(col(Beatmap.mode) == Score.gamemode))
|
||||
)
|
||||
).first()
|
||||
if user_beatmap_score is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@router.put(
|
||||
"/beatmaps/{beatmap_id}/tags/{tag_id}",
|
||||
tags=["用户标签"],
|
||||
status_code=204,
|
||||
name="为谱面投票标签",
|
||||
description="为指定谱面添加标签投票。",
|
||||
)
|
||||
async def vote_beatmap_tags(
|
||||
beatmap_id: int = Path(..., description="谱面 ID"),
|
||||
tag_id: int = Path(..., description="标签 ID"),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_client_user),
|
||||
):
|
||||
try:
|
||||
get_tag_by_id(tag_id)
|
||||
beatmap = (await session.exec(select(exists()).where(Beatmap.id == beatmap_id))).first()
|
||||
if beatmap is None or (not beatmap):
|
||||
raise HTTPException(404, "beatmap not found")
|
||||
previous_votes = (
|
||||
await session.exec(
|
||||
select(BeatmapTagVote)
|
||||
.where(BeatmapTagVote.beatmap_id == beatmap_id)
|
||||
.where(BeatmapTagVote.tag_id == tag_id)
|
||||
.where(BeatmapTagVote.user_id == current_user.id)
|
||||
)
|
||||
).first()
|
||||
if previous_votes is None:
|
||||
if check_user_can_vote(current_user, beatmap_id, session):
|
||||
new_vote = BeatmapTagVote(tag_id=tag_id, beatmap_id=beatmap_id, user_id=current_user.id)
|
||||
session.add(new_vote)
|
||||
await session.commit()
|
||||
except ValueError:
|
||||
raise HTTPException(400, "Tag is not found")
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/beatmaps/{beatmap_id}/tags/{tag_id}",
|
||||
tags=["用户标签", "谱面"],
|
||||
status_code=204,
|
||||
name="取消谱面标签投票",
|
||||
description="取消对指定谱面标签的投票。",
|
||||
)
|
||||
async def devote_beatmap_tags(
|
||||
beatmap_id: int = Path(..., description="谱面 ID"),
|
||||
tag_id: int = Path(..., description="标签 ID"),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_client_user),
|
||||
):
|
||||
"""
|
||||
取消对谱面指定标签的投票。
|
||||
|
||||
- **beatmap_id**: 谱面ID
|
||||
- **tag_id**: 标签ID
|
||||
"""
|
||||
try:
|
||||
tag = get_tag_by_id(tag_id)
|
||||
assert tag is not None
|
||||
beatmap = await session.get(Beatmap, beatmap_id)
|
||||
if beatmap is None:
|
||||
raise HTTPException(404, "beatmap not found")
|
||||
previous_votes = (
|
||||
await session.exec(
|
||||
select(BeatmapTagVote)
|
||||
.where(BeatmapTagVote.beatmap_id == beatmap_id)
|
||||
.where(BeatmapTagVote.tag_id == tag_id)
|
||||
.where(BeatmapTagVote.user_id == current_user.id)
|
||||
)
|
||||
).first()
|
||||
if previous_votes is not None:
|
||||
await session.delete(previous_votes)
|
||||
await session.commit()
|
||||
except ValueError:
|
||||
raise HTTPException(400, "Tag is not found")
|
||||
50
migrations/versions/ebaa317ad928_add_beatmap_tag.py
Normal file
50
migrations/versions/ebaa317ad928_add_beatmap_tag.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""add beatmap_tag
|
||||
|
||||
|
||||
Revision ID: ebaa317ad928
|
||||
Revises: 24a32515292d
|
||||
Create Date: 2025-08-29 12:29:23.267557
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import mysql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "ebaa317ad928"
|
||||
down_revision: str | Sequence[str] | None = "24a32515292d"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"beatmap_tags",
|
||||
sa.Column("tag_id", sa.Integer(), nullable=False),
|
||||
sa.Column("beatmap_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.Integer(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("tag_id", "beatmap_id", "user_id"),
|
||||
)
|
||||
op.create_index(op.f("ix_beatmap_tags_beatmap_id"), "beatmap_tags", ["beatmap_id"], unique=False)
|
||||
op.create_index(op.f("ix_beatmap_tags_tag_id"), "beatmap_tags", ["tag_id"], unique=False)
|
||||
op.create_index(op.f("ix_beatmap_tags_user_id"), "beatmap_tags", ["user_id"], unique=False)
|
||||
op.drop_column("beatmapsets", "ratings")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("beatmapsets", sa.Column("ratings", mysql.JSON(), nullable=True))
|
||||
op.drop_index(op.f("ix_beatmap_tags_user_id"), table_name="beatmap_tags")
|
||||
op.drop_index(op.f("ix_beatmap_tags_tag_id"), table_name="beatmap_tags")
|
||||
op.drop_index(op.f("ix_beatmap_tags_beatmap_id"), table_name="beatmap_tags")
|
||||
op.drop_table("beatmap_tags")
|
||||
# ### end Alembic commands ###
|
||||
Reference in New Issue
Block a user