feat(beatmap): implement get beatmap arrtibutes
This commit is contained in:
@@ -1,12 +1,14 @@
|
||||
from datetime import datetime
|
||||
|
||||
from app.fetcher import Fetcher
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.models.score import MODE_TO_INT, GameMode
|
||||
|
||||
from .beatmapset import Beatmapset, BeatmapsetResp
|
||||
|
||||
from sqlalchemy import DECIMAL, Column, DateTime
|
||||
from sqlmodel import VARCHAR, Field, Relationship, SQLModel
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@@ -77,6 +79,7 @@ class Beatmap(BeatmapBase, table=True):
|
||||
)
|
||||
session.add(beatmap)
|
||||
await session.commit()
|
||||
await session.refresh(beatmap)
|
||||
return beatmap
|
||||
|
||||
@classmethod
|
||||
@@ -102,6 +105,30 @@ class Beatmap(BeatmapBase, table=True):
|
||||
await session.commit()
|
||||
return beatmaps
|
||||
|
||||
@classmethod
|
||||
async def get_or_fetch(
|
||||
cls, session: AsyncSession, bid: int, fetcher: Fetcher
|
||||
) -> "Beatmap":
|
||||
beatmap = (
|
||||
await session.exec(
|
||||
select(Beatmap)
|
||||
.where(Beatmap.id == bid)
|
||||
.options(
|
||||
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not beatmap:
|
||||
resp = await fetcher.get_beatmap(bid)
|
||||
r = await session.exec(
|
||||
select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id)
|
||||
)
|
||||
if not r.first():
|
||||
set_resp = await fetcher.get_beatmapset(resp.beatmapset_id)
|
||||
await Beatmapset.from_resp(session, set_resp, from_=resp.id)
|
||||
return await Beatmap.from_resp(session, resp)
|
||||
return beatmap
|
||||
|
||||
|
||||
class BeatmapResp(BeatmapBase):
|
||||
id: int
|
||||
|
||||
@@ -2,7 +2,6 @@ from datetime import datetime
|
||||
from typing import TYPE_CHECKING, TypedDict, cast
|
||||
|
||||
from app.models.beatmap import BeatmapRankStatus, Genre, Language
|
||||
from app.models.score import GameMode
|
||||
|
||||
from pydantic import BaseModel, model_serializer
|
||||
from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text
|
||||
@@ -69,7 +68,7 @@ class BeatmapNomination(TypedDict):
|
||||
beatmapset_id: int
|
||||
reset: bool
|
||||
user_id: int
|
||||
rulesets: dict[str, GameMode] | None
|
||||
rulesets: list[str] | None
|
||||
|
||||
|
||||
class BeatmapDescription(SQLModel):
|
||||
|
||||
@@ -2,7 +2,8 @@ from datetime import datetime
|
||||
import math
|
||||
|
||||
from app.database.user import User
|
||||
from app.models.score import MODE_TO_INT, APIMod, GameMode, Rank
|
||||
from app.models.mods import APIMod
|
||||
from app.models.score import MODE_TO_INT, GameMode, Rank
|
||||
|
||||
from .beatmap import Beatmap, BeatmapResp
|
||||
from .beatmapset import BeatmapsetResp
|
||||
|
||||
@@ -2,9 +2,10 @@ from __future__ import annotations
|
||||
|
||||
from .beatmap import BeatmapFetcher
|
||||
from .beatmapset import BeatmapsetFetcher
|
||||
from .osu_dot_direct import OsuDotDirectFetcher
|
||||
|
||||
|
||||
class Fetcher(BeatmapFetcher, BeatmapsetFetcher):
|
||||
class Fetcher(BeatmapFetcher, BeatmapsetFetcher, OsuDotDirectFetcher):
|
||||
"""A class that combines all fetchers for easy access."""
|
||||
|
||||
pass
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.database.beatmap import BeatmapResp
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ._base import BaseFetcher
|
||||
|
||||
from httpx import AsyncClient
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.database.beatmap import BeatmapResp
|
||||
|
||||
|
||||
class BeatmapFetcher(BaseFetcher):
|
||||
async def get_beatmap(self, beatmap_id: int) -> BeatmapResp:
|
||||
async def get_beatmap(self, beatmap_id: int) -> "BeatmapResp":
|
||||
from app.database.beatmap import BeatmapResp
|
||||
|
||||
async with AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"https://osu.ppy.sh/api/v2/beatmaps/{beatmap_id}",
|
||||
|
||||
15
app/fetcher/osu_dot_direct.py
Normal file
15
app/fetcher/osu_dot_direct.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ._base import BaseFetcher
|
||||
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
class OsuDotDirectFetcher(BaseFetcher):
|
||||
async def get_beatmap_raw(self, beatmap_id: int) -> str:
|
||||
async with AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"https://osu.direct/api/osu/{beatmap_id}/raw",
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
@@ -2,6 +2,8 @@ from __future__ import annotations
|
||||
|
||||
from enum import IntEnum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BeatmapRankStatus(IntEnum):
|
||||
GRAVEYARD = -2
|
||||
@@ -45,3 +47,20 @@ class Language(IntEnum):
|
||||
RUSSIAN = 11
|
||||
POLISH = 12
|
||||
OTHER = 13
|
||||
|
||||
|
||||
class BeatmapAttributes(BaseModel):
|
||||
star_rating: float
|
||||
max_combo: int
|
||||
|
||||
# osu
|
||||
aim_difficulty: float | None = None
|
||||
aim_difficult_slider_count: float | None = None
|
||||
speed_difficulty: float | None = None
|
||||
speed_note_count: float | None = None
|
||||
slider_factor: float | None = None
|
||||
aim_difficult_strain_count: float | None = None
|
||||
speed_difficult_strain_count: float | None = None
|
||||
|
||||
# taiko
|
||||
mono_stamina_factor: float | None = None
|
||||
|
||||
56
app/models/mods.py
Normal file
56
app/models/mods.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
class APIMod(TypedDict):
|
||||
acronym: str
|
||||
settings: dict[str, bool | float | str]
|
||||
|
||||
|
||||
# https://github.com/ppy/osu-api/wiki#mods
|
||||
LEGACY_MOD_TO_API_MOD = {
|
||||
(1 << 0): APIMod(acronym="NF", settings={}), # No Fail
|
||||
(1 << 1): APIMod(acronym="EZ", settings={}),
|
||||
(1 << 2): APIMod(acronym="TD", settings={}), # Touch Device
|
||||
(1 << 3): APIMod(acronym="HD", settings={}), # Hidden
|
||||
(1 << 4): APIMod(acronym="HR", settings={}), # Hard Rock
|
||||
(1 << 5): APIMod(acronym="SD", settings={}), # Sudden Death
|
||||
(1 << 6): APIMod(acronym="DT", settings={}), # Double Time
|
||||
(1 << 7): APIMod(acronym="RX", settings={}), # Relax
|
||||
(1 << 8): APIMod(acronym="HT", settings={}), # Half Time
|
||||
(1 << 9): APIMod(acronym="NC", settings={}), # Nightcore
|
||||
(1 << 10): APIMod(acronym="FL", settings={}), # Flashlight
|
||||
(1 << 11): APIMod(acronym="AT", settings={}), # Auto Play
|
||||
(1 << 12): APIMod(acronym="SO", settings={}), # Spun Out
|
||||
(1 << 13): APIMod(acronym="AP", settings={}), # Autopilot
|
||||
(1 << 14): APIMod(acronym="PF", settings={}), # Perfect
|
||||
(1 << 15): APIMod(acronym="4K", settings={}), # 4K
|
||||
(1 << 16): APIMod(acronym="5K", settings={}), # 5K
|
||||
(1 << 17): APIMod(acronym="6K", settings={}), # 6K
|
||||
(1 << 18): APIMod(acronym="7K", settings={}), # 7K
|
||||
(1 << 19): APIMod(acronym="8K", settings={}), # 8K
|
||||
(1 << 20): APIMod(acronym="FI", settings={}), # Fade In
|
||||
(1 << 21): APIMod(acronym="RD", settings={}), # Random
|
||||
(1 << 22): APIMod(acronym="CN", settings={}), # Cinema
|
||||
(1 << 23): APIMod(acronym="TP", settings={}), # Target Practice
|
||||
(1 << 24): APIMod(acronym="9K", settings={}), # 9K
|
||||
(1 << 25): APIMod(acronym="CO", settings={}), # Key Co-op
|
||||
(1 << 26): APIMod(acronym="1K", settings={}), # 1K
|
||||
(1 << 27): APIMod(acronym="2K", settings={}), # 2K
|
||||
(1 << 28): APIMod(acronym="3K", settings={}), # 3K
|
||||
(1 << 29): APIMod(acronym="SV2", settings={}), # Score V2
|
||||
(1 << 30): APIMod(acronym="MR", settings={}), # Mirror
|
||||
}
|
||||
|
||||
|
||||
def int_to_mods(mods: int) -> list[APIMod]:
|
||||
mod_list = []
|
||||
for mod in range(31):
|
||||
if mods & (1 << mod):
|
||||
mod_list.append(LEGACY_MOD_TO_API_MOD[(1 << mod)])
|
||||
if mods & (1 << 14):
|
||||
mod_list.remove(LEGACY_MOD_TO_API_MOD[(1 << 5)])
|
||||
if mods & (1 << 9):
|
||||
mod_list.remove(LEGACY_MOD_TO_API_MOD[(1 << 6)])
|
||||
return mod_list
|
||||
@@ -1,7 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum, IntEnum
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import rosu_pp_py as rosu
|
||||
|
||||
|
||||
class GameMode(str, Enum):
|
||||
@@ -10,6 +11,14 @@ class GameMode(str, Enum):
|
||||
FRUITS = "fruits"
|
||||
MANIA = "mania"
|
||||
|
||||
def to_rosu(self) -> rosu.GameMode:
|
||||
return {
|
||||
GameMode.OSU: rosu.GameMode.Osu,
|
||||
GameMode.TAIKO: rosu.GameMode.Taiko,
|
||||
GameMode.FRUITS: rosu.GameMode.Catch,
|
||||
GameMode.MANIA: rosu.GameMode.Mania,
|
||||
}[self]
|
||||
|
||||
|
||||
MODE_TO_INT = {
|
||||
GameMode.OSU: 0,
|
||||
@@ -32,11 +41,6 @@ class Rank(str, Enum):
|
||||
F = "f"
|
||||
|
||||
|
||||
class APIMod(TypedDict):
|
||||
acronym: str
|
||||
settings: dict[str, Any]
|
||||
|
||||
|
||||
# https://github.com/ppy/osu/blob/master/osu.Game/Rulesets/Scoring/HitResult.cs
|
||||
class HitResult(IntEnum):
|
||||
PERFECT = 0 # [Order(0)]
|
||||
|
||||
@@ -5,6 +5,7 @@ from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401
|
||||
beatmapset,
|
||||
me,
|
||||
relationship,
|
||||
score,
|
||||
)
|
||||
from .api_router import router as api_router
|
||||
from .auth import router as auth_router
|
||||
|
||||
@@ -1,22 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
|
||||
from app.database import (
|
||||
Beatmap,
|
||||
BeatmapResp,
|
||||
User as DBUser,
|
||||
)
|
||||
from app.database.beatmapset import Beatmapset
|
||||
from app.database.score import Score, ScoreResp
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.database import get_db, get_redis
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.fetcher import Fetcher
|
||||
from app.models.beatmap import BeatmapAttributes
|
||||
from app.models.mods import APIMod, int_to_mods
|
||||
from app.models.score import INT_TO_MODE, GameMode
|
||||
from app.utils import calculate_beatmap_attribute
|
||||
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Query
|
||||
from httpx import HTTPStatusError
|
||||
from httpx import HTTPError, HTTPStatusError
|
||||
from pydantic import BaseModel
|
||||
from redis import Redis
|
||||
import rosu_pp_py as rosu
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -29,32 +38,11 @@ async def get_beatmap(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
beatmap = (
|
||||
await db.exec(
|
||||
select(Beatmap)
|
||||
.options(
|
||||
joinedload(Beatmap.beatmapset).selectinload( # pyright: ignore[reportArgumentType]
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
.where(Beatmap.id == bid)
|
||||
)
|
||||
).first()
|
||||
if not beatmap:
|
||||
try:
|
||||
resp = await fetcher.get_beatmap(bid)
|
||||
r = await db.exec(
|
||||
select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id)
|
||||
)
|
||||
if not r.first():
|
||||
set_resp = await fetcher.get_beatmapset(resp.beatmapset_id)
|
||||
await Beatmapset.from_resp(db, set_resp, from_=resp.id)
|
||||
await Beatmap.from_resp(db, resp)
|
||||
except HTTPStatusError:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
else:
|
||||
resp = BeatmapResp.from_db(beatmap)
|
||||
return resp
|
||||
try:
|
||||
beatmap = await Beatmap.get_or_fetch(db, bid, fetcher)
|
||||
return BeatmapResp.from_db(beatmap)
|
||||
except HTTPError:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
|
||||
|
||||
class BatchGetResp(BaseModel):
|
||||
@@ -75,7 +63,7 @@ async def batch_get_beatmaps(
|
||||
select(Beatmap)
|
||||
.options(
|
||||
joinedload(
|
||||
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
|
||||
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
|
||||
).selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
@@ -90,8 +78,8 @@ async def batch_get_beatmaps(
|
||||
select(Beatmap)
|
||||
.options(
|
||||
joinedload(
|
||||
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
|
||||
).selectinload(
|
||||
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
|
||||
).selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
@@ -103,137 +91,52 @@ async def batch_get_beatmaps(
|
||||
return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps])
|
||||
|
||||
|
||||
class BeatmapScores(BaseModel):
|
||||
scores: list[ScoreResp]
|
||||
userScore: ScoreResp | None = None
|
||||
|
||||
|
||||
@router.get(
|
||||
"/beatmaps/{beatmap}/scores", tags=["beatmap"], response_model=BeatmapScores
|
||||
)
|
||||
async def get_beatmap_scores(
|
||||
beatmap: int,
|
||||
legacy_only: bool = Query(None), # TODO:加入对这个参数的查询
|
||||
mode: str = Query(None),
|
||||
# mods: List[APIMod] = Query(None), # TODO:加入指定MOD的查询
|
||||
type: str = Query(None),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="this server only contains lazer scores"
|
||||
)
|
||||
|
||||
all_scores = (
|
||||
await db.exec(
|
||||
select(Score).where(Score.beatmap_id == beatmap)
|
||||
# .where(Score.mods == mods if mods else True)
|
||||
)
|
||||
).all()
|
||||
|
||||
user_score = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.options(
|
||||
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
|
||||
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
|
||||
.selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
.where(Score.beatmap_id == beatmap)
|
||||
.where(Score.user_id == current_user.id)
|
||||
)
|
||||
).first()
|
||||
|
||||
return BeatmapScores(
|
||||
scores=[ScoreResp.from_db(score) for score in all_scores],
|
||||
userScore=ScoreResp.from_db(user_score) if user_score else None,
|
||||
)
|
||||
|
||||
|
||||
class BeatmapUserScore(BaseModel):
|
||||
position: int
|
||||
score: ScoreResp
|
||||
|
||||
|
||||
@router.get(
|
||||
"/beatmaps/{beatmap}/scores/users/{user}",
|
||||
@router.post(
|
||||
"/beatmaps/{beatmap}/attributes",
|
||||
tags=["beatmap"],
|
||||
response_model=BeatmapUserScore,
|
||||
response_model=BeatmapAttributes,
|
||||
)
|
||||
async def get_user_beatmap_score(
|
||||
async def get_beatmap_attributes(
|
||||
beatmap: int,
|
||||
user: int,
|
||||
legacy_only: bool = Query(None),
|
||||
mode: str = Query(None),
|
||||
mods: str = Query(None), # TODO:添加mods筛选
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
mods: list[str] = Query(default_factory=list),
|
||||
ruleset: GameMode | None = Query(default=None),
|
||||
ruleset_id: int | None = Query(default=None),
|
||||
redis: Redis = Depends(get_redis),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="This server only contains non-legacy scores"
|
||||
)
|
||||
user_score = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.options(
|
||||
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
|
||||
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
|
||||
.selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
.where(Score.gamemode==mode if mode is not None else True)
|
||||
.where(Score.beatmap_id == beatmap)
|
||||
.where(Score.user_id == user)
|
||||
.order_by(col(Score.classic_total_score).desc())
|
||||
)
|
||||
).first()
|
||||
|
||||
if not user_score:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find user %s's score on this beatmap" % user
|
||||
)
|
||||
mods_ = []
|
||||
if mods and mods[0].isdigit():
|
||||
mods_ = int_to_mods(int(mods[0]))
|
||||
else:
|
||||
return BeatmapUserScore(
|
||||
position=user_score.position if user_score.position is not None else 0,
|
||||
score=ScoreResp.from_db(user_score),
|
||||
)
|
||||
for i in mods:
|
||||
try:
|
||||
mods_.append(json.loads(i))
|
||||
except json.JSONDecodeError:
|
||||
mods_.append(APIMod(acronym=i, settings={}))
|
||||
mods_.sort(key=lambda x: x["acronym"])
|
||||
if ruleset_id is not None and ruleset is None:
|
||||
ruleset = INT_TO_MODE[ruleset_id]
|
||||
if ruleset is None:
|
||||
beatmap_db = await Beatmap.get_or_fetch(db, beatmap, fetcher)
|
||||
ruleset = beatmap_db.mode
|
||||
key = (
|
||||
f"beatmap:{beatmap}:{ruleset}:"
|
||||
f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
|
||||
)
|
||||
if redis.exists(key):
|
||||
return BeatmapAttributes.model_validate_json(redis.get(key)) # pyright: ignore[reportArgumentType]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/beatmaps/{beatmap}/scores/users/{user}/all",
|
||||
tags=["beatmap"],
|
||||
response_model=list[ScoreResp],
|
||||
)
|
||||
async def get_user_all_beatmap_scores(
|
||||
beatmap: int,
|
||||
user: int,
|
||||
legacy_only: bool = Query(None),
|
||||
ruleset: str = Query(None),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(status_code=404,detail="This server only contains non-legacy scores")
|
||||
all_user_scores=(
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.options(
|
||||
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
|
||||
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
|
||||
.selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
try:
|
||||
resp = await fetcher.get_beatmap_raw(beatmap)
|
||||
try:
|
||||
attr = await asyncio.get_event_loop().run_in_executor(
|
||||
None, calculate_beatmap_attribute, resp, ruleset, mods_
|
||||
)
|
||||
.where(Score.gamemode==ruleset if ruleset is not None else True)
|
||||
.where(Score.beatmap_id == beatmap)
|
||||
.where(Score.user_id == user)
|
||||
.order_by(col(Score.classic_total_score).desc())
|
||||
)
|
||||
).all()
|
||||
|
||||
return [ScoreResp.from_db(score) for score in all_user_scores]
|
||||
except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue]
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
redis.set(key, attr.model_dump_json())
|
||||
return attr
|
||||
except HTTPStatusError:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
|
||||
156
app/router/score.py
Normal file
156
app/router/score.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.database import (
|
||||
Beatmap,
|
||||
User as DBUser,
|
||||
)
|
||||
from app.database.beatmapset import Beatmapset
|
||||
from app.database.score import Score, ScoreResp
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.user import get_current_user
|
||||
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class BeatmapScores(BaseModel):
|
||||
scores: list[ScoreResp]
|
||||
userScore: ScoreResp | None = None
|
||||
|
||||
|
||||
@router.get(
|
||||
"/beatmaps/{beatmap}/scores", tags=["beatmap"], response_model=BeatmapScores
|
||||
)
|
||||
async def get_beatmap_scores(
|
||||
beatmap: int,
|
||||
legacy_only: bool = Query(None), # TODO:加入对这个参数的查询
|
||||
mode: str = Query(None),
|
||||
# mods: List[APIMod] = Query(None), # TODO:加入指定MOD的查询
|
||||
type: str = Query(None),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="this server only contains lazer scores"
|
||||
)
|
||||
|
||||
all_scores = (
|
||||
await db.exec(
|
||||
select(Score).where(Score.beatmap_id == beatmap)
|
||||
# .where(Score.mods == mods if mods else True)
|
||||
)
|
||||
).all()
|
||||
|
||||
user_score = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.options(
|
||||
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
|
||||
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
|
||||
.selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
.where(Score.beatmap_id == beatmap)
|
||||
.where(Score.user_id == current_user.id)
|
||||
)
|
||||
).first()
|
||||
|
||||
return BeatmapScores(
|
||||
scores=[ScoreResp.from_db(score) for score in all_scores],
|
||||
userScore=ScoreResp.from_db(user_score) if user_score else None,
|
||||
)
|
||||
|
||||
|
||||
class BeatmapUserScore(BaseModel):
|
||||
position: int
|
||||
score: ScoreResp
|
||||
|
||||
|
||||
@router.get(
|
||||
"/beatmaps/{beatmap}/scores/users/{user}",
|
||||
tags=["beatmap"],
|
||||
response_model=BeatmapUserScore,
|
||||
)
|
||||
async def get_user_beatmap_score(
|
||||
beatmap: int,
|
||||
user: int,
|
||||
legacy_only: bool = Query(None),
|
||||
mode: str = Query(None),
|
||||
mods: str = Query(None), # TODO:添加mods筛选
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="This server only contains non-legacy scores"
|
||||
)
|
||||
user_score = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.options(
|
||||
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
|
||||
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
|
||||
.selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
.where(Score.gamemode == mode if mode is not None else True)
|
||||
.where(Score.beatmap_id == beatmap)
|
||||
.where(Score.user_id == user)
|
||||
.order_by(col(Score.classic_total_score).desc())
|
||||
)
|
||||
).first()
|
||||
|
||||
if not user_score:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Cannot find user {user}'s score on this beatmap"
|
||||
)
|
||||
else:
|
||||
return BeatmapUserScore(
|
||||
position=user_score.position if user_score.position is not None else 0,
|
||||
score=ScoreResp.from_db(user_score),
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/beatmaps/{beatmap}/scores/users/{user}/all",
|
||||
tags=["beatmap"],
|
||||
response_model=list[ScoreResp],
|
||||
)
|
||||
async def get_user_all_beatmap_scores(
|
||||
beatmap: int,
|
||||
user: int,
|
||||
legacy_only: bool = Query(None),
|
||||
ruleset: str = Query(None),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="This server only contains non-legacy scores"
|
||||
)
|
||||
all_user_scores = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.options(
|
||||
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
|
||||
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
|
||||
.selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
.where(Score.gamemode == ruleset if ruleset is not None else True)
|
||||
.where(Score.beatmap_id == beatmap)
|
||||
.where(Score.user_id == user)
|
||||
.order_by(col(Score.classic_total_score).desc())
|
||||
)
|
||||
).all()
|
||||
|
||||
return [ScoreResp.from_db(score) for score in all_user_scores]
|
||||
28
app/utils.py
28
app/utils.py
@@ -8,6 +8,9 @@ from app.database import (
|
||||
LazerUserStatistics,
|
||||
User as DBUser,
|
||||
)
|
||||
from app.models.beatmap import BeatmapAttributes
|
||||
from app.models.mods import APIMod
|
||||
from app.models.score import GameMode
|
||||
from app.models.user import (
|
||||
Country,
|
||||
Cover,
|
||||
@@ -22,6 +25,8 @@ from app.models.user import (
|
||||
UserAchievement,
|
||||
)
|
||||
|
||||
import rosu_pp_py as rosu
|
||||
|
||||
|
||||
async def convert_db_user_to_api_user(db_user: DBUser, ruleset: str = "osu") -> User:
|
||||
"""将数据库用户模型转换为API用户模型(使用 Lazer 表)"""
|
||||
@@ -425,3 +430,26 @@ def get_country_name(country_code: str) -> str:
|
||||
# 可以添加更多国家
|
||||
}
|
||||
return country_names.get(country_code, "Unknown")
|
||||
|
||||
|
||||
def calculate_beatmap_attribute(
|
||||
beatmap: str,
|
||||
gamemode: GameMode | None = None,
|
||||
mods: int | list[APIMod] | list[str] = 0,
|
||||
) -> BeatmapAttributes:
|
||||
map = rosu.Beatmap(content=beatmap)
|
||||
if gamemode is not None:
|
||||
map.convert(gamemode.to_rosu(), mods)
|
||||
diff = rosu.Difficulty(mods=mods).calculate(map)
|
||||
return BeatmapAttributes(
|
||||
star_rating=diff.stars,
|
||||
max_combo=diff.max_combo,
|
||||
aim_difficulty=diff.aim,
|
||||
aim_difficult_slider_count=diff.aim_difficult_slider_count,
|
||||
speed_difficulty=diff.speed,
|
||||
speed_note_count=diff.speed_note_count,
|
||||
slider_factor=diff.slider_factor,
|
||||
aim_difficult_strain_count=diff.aim_difficult_strain_count,
|
||||
speed_difficult_strain_count=diff.speed_difficult_strain_count,
|
||||
mono_stamina_factor=diff.stamina,
|
||||
)
|
||||
|
||||
@@ -18,7 +18,8 @@ from app.database.beatmapset import Beatmapset
|
||||
from app.database.score import Score
|
||||
from app.dependencies.database import create_tables, engine
|
||||
from app.models.beatmap import BeatmapRankStatus, Genre, Language
|
||||
from app.models.score import APIMod, GameMode, Rank
|
||||
from app.models.mods import APIMod
|
||||
from app.models.score import GameMode, Rank
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -18,6 +18,7 @@ dependencies = [
|
||||
"python-jose[cryptography]>=3.3.0",
|
||||
"python-multipart>=0.0.6",
|
||||
"redis>=5.0.1",
|
||||
"rosu-pp-py>=3.1.0",
|
||||
"sqlalchemy>=2.0.23",
|
||||
"sqlmodel>=0.0.24",
|
||||
"uvicorn[standard]>=0.24.0",
|
||||
|
||||
37
uv.lock
generated
37
uv.lock
generated
@@ -569,6 +569,7 @@ dependencies = [
|
||||
{ name = "python-jose", extra = ["cryptography"] },
|
||||
{ name = "python-multipart" },
|
||||
{ name = "redis" },
|
||||
{ name = "rosu-pp-py" },
|
||||
{ name = "sqlalchemy" },
|
||||
{ name = "sqlmodel" },
|
||||
{ name = "uvicorn", extra = ["standard"] },
|
||||
@@ -596,6 +597,7 @@ requires-dist = [
|
||||
{ name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" },
|
||||
{ name = "python-multipart", specifier = ">=0.0.6" },
|
||||
{ name = "redis", specifier = ">=5.0.1" },
|
||||
{ name = "rosu-pp-py", specifier = ">=3.1.0" },
|
||||
{ name = "sqlalchemy", specifier = ">=2.0.23" },
|
||||
{ name = "sqlmodel", specifier = ">=0.0.24" },
|
||||
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.24.0" },
|
||||
@@ -843,6 +845,41 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/13/67/e60968d3b0e077495a8fee89cf3f2373db98e528288a48f1ee44967f6e8c/redis-6.2.0-py3-none-any.whl", hash = "sha256:c8ddf316ee0aab65f04a11229e94a64b2618451dab7a67cb2f77eb799d872d5e", size = 278659, upload-time = "2025-05-28T05:01:16.955Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rosu-pp-py"
|
||||
version = "3.1.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/6c/19/b44c30066c6e85cd6a4fd8a8983be91d2336a4e7f0ef04e576bc9b1d7c63/rosu_pp_py-3.1.0.tar.gz", hash = "sha256:4aa64eb5e68b8957357f9b304047db285423b207ad913e28829ccfcd5348d41a", size = 31144, upload-time = "2025-06-03T17:14:27.461Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ea/e8/a4a899997304049801c27e1affa4ce7ea60d2ba16caa7c6739a6387f1790/rosu_pp_py-3.1.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:d584ffabb96958d2c90a696a2634fa7336966b429ee0f0d03397763fc73d3237", size = 556133, upload-time = "2025-06-03T17:13:21.925Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/12/be/fc90d17277335a0225b88dd06790f6056bc0e4385e610df4aed471f692d8/rosu_pp_py-3.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ccf125864d0483281ada86e913b8133b53cb62455842bd418a5a4966abb47a67", size = 513148, upload-time = "2025-06-03T17:13:23.071Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f3/9d/26893b6182bd83694974ae6931647801c060b55844696089c463645290d3/rosu_pp_py-3.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:587a16e928c02f1b9439d8140d53ed8ff7ccab8f663b813c44cab9c3a89a1d46", size = 526976, upload-time = "2025-06-03T17:13:24.127Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/50/12/8fd68740f722ffbb792f37577a515c727de52ef14fcf46f22d5c2cdde03e/rosu_pp_py-3.1.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:721b4f9e0c1f17402d23915f8cb8695e476c8841319c71097c5f71dab6c91f1c", size = 550737, upload-time = "2025-06-03T17:13:25.264Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6e/86/ce7b0587800ce1da69b672e7b11ea5ec8469c17bedc3d51efcd400261830/rosu_pp_py-3.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f5fbb9ae415d1f71ca8e3a153b47e584415ae081816d0b60b70a1410c7ed562", size = 566922, upload-time = "2025-06-03T17:13:26.356Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bc/63/8b752c2116777fa03f46d3793fc6e87e262a21a71460a49d503d59690cec/rosu_pp_py-3.1.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:06977b5211da327c27a921e284f5cb678e4a89f00ce76520fee2c33f09b28ab8", size = 705614, upload-time = "2025-06-03T17:13:27.882Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d8/71/88d4051beaad89a29813038c4e391952f017ffe2199efee4469955257167/rosu_pp_py-3.1.0-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:9dbd319039d5803a85e7263a22c808f93970b8bc0ed9e846d66050995d19fdb5", size = 814233, upload-time = "2025-06-03T17:13:29.544Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8d/13/a5b55a928edd2b70fb6d3268f7f344356cc781fd2194076a75af86faedb5/rosu_pp_py-3.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:32d039b60c80bc4c6d4d6ee50918a44ebd95ab36d154da0dcc24af38858d0807", size = 738492, upload-time = "2025-06-03T17:13:31.157Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fc/18/67fa30cab0ff4179533fd2c89e4d8141d01968278ea095a42a06e1350b39/rosu_pp_py-3.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:f21edd037b6e30c019a721d374dc1e72e62c10f1a9a5b22773f1b5e321cf2a1a", size = 460036, upload-time = "2025-06-03T17:13:32.141Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9a/04/d752d7cfb71afcbecd0513ffcc716abcf5c3b2b4b9a4e44a3c7e7fc43fba/rosu_pp_py-3.1.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:61275ddfedd7f67bcb5c42a136fb30a66aeb7e07323c59a67db590de687bd78d", size = 552307, upload-time = "2025-06-03T17:13:33.203Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/27/76/e7d3415cdd384b8ea0a2f461c87d9b451108cbded46e2e88676611a99875/rosu_pp_py-3.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:04aacaa6faba9d0892ba5584884cfaf42eb1a7678dc0dff453fc6988e8be8809", size = 508787, upload-time = "2025-06-03T17:13:34.507Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7d/a0/c59168f75b32b6cf3e41d5d44dc478b113eebe38166e6b87af193ebb8d4f/rosu_pp_py-3.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9eecd7a78aeb82abf39ac7db670350a42b6eb8a54eb4a8a13610def02c56d005", size = 525740, upload-time = "2025-06-03T17:13:35.631Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d6/c0/7b498f8ecd6650d718291994c5e6d3931e5572e408d8d7bc9000f2441575/rosu_pp_py-3.1.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3dd5118614335e9084f076f9fa88fb139e64a9e1750c0d8020c8e8abe9e42dce", size = 550091, upload-time = "2025-06-03T17:13:36.733Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0d/21/85f67440c93bc22135e6e43f6fc1d35d184b9c1523416acfae4b8721d9e5/rosu_pp_py-3.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:edbd67da486af4fbf5d53cd310fddc280a67d06274aea5eb3e322ffc66e82479", size = 566542, upload-time = "2025-06-03T17:13:38.308Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d5/ed/1d3727d327097edf2ecf8a39a267d5f2ba7a82ce2f7c71e1be5b6c278870/rosu_pp_py-3.1.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:af295819cda6df49324179e5c3986eb4215d6c456a055620ec30716ed22ec97c", size = 704380, upload-time = "2025-06-03T17:13:39.839Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a3/4d/db4fb9bcd1cdebbc761728a8684d700559a5b44e5d2baec262e07907917a/rosu_pp_py-3.1.0-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:b0367959b9ef74f51f1cc414d587b6dabab00390496a855a89073b55e08330b0", size = 813664, upload-time = "2025-06-03T17:13:41.052Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b8/a9/3ec4502f4f44c0e22b7658308def31c96320e339b89cdf474c2612b40351/rosu_pp_py-3.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:adf103385961c01859ae99ded0c289e03f5ab33d86ecabdd4e8f3139c84c6240", size = 738024, upload-time = "2025-06-03T17:13:42.132Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9e/f6/d33cde2f911ff2fdedbbc2be6b249e29f3a65e11acd1b645df77ece0747a/rosu_pp_py-3.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:8dc48f45aff62fc2798e3a4adf4596d9e810079f16650a98c8ed6cf1a37e506b", size = 458391, upload-time = "2025-06-03T17:13:43.706Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ac/53/3f68a24d75c65b789200241f490c2379d86a3760f48dc9e22348f0a619c9/rosu_pp_py-3.1.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:5cda7206c2e8c96fdaccf0b531d0614df5e30ad6cd1bf217ec5556406294ed6c", size = 552011, upload-time = "2025-06-03T17:13:44.889Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b6/95/6251e0d7f615c148d17e5151b89e3da7da89ef5363de921b5957b5407510/rosu_pp_py-3.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:d54606719ac93ccadbcb40acd3dda41f6e319e075303b6bbfdebf784ed451281", size = 508659, upload-time = "2025-06-03T17:13:45.968Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7f/2b/23d449a97fb6d34ced7c421a13669d98a5522ce79fabd8151a873d3d152a/rosu_pp_py-3.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec88b95845851018e95e49f3f8610dc989a2cfc74273a8c40fe7ef94e4f37a6a", size = 525367, upload-time = "2025-06-03T17:13:47.56Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/52/9a/c8879dd4f62632d8928cc147bca705eb7e2a21dc0ad43307d6f68e0a3b41/rosu_pp_py-3.1.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f39332ec3c479c68396d0f6ea09ab3ee77ca595ab14f4739581ca8a631dc33d8", size = 549600, upload-time = "2025-06-03T17:13:48.717Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e8/86/a0154a1b3149bd25884ea8009c70b9792a960dbfd4172b65ace0e55394b4/rosu_pp_py-3.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4a290f7920b0015e0a9d829428cce7948ae98043985b237b0d68e2b28c8dba3", size = 566082, upload-time = "2025-06-03T17:13:49.761Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e5/ee/897f5cb48dfe067549dee39cb265581782d1daebc4dd27b1c1bc58551755/rosu_pp_py-3.1.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:11ab7db7903a2752b7c53458e689b2f1f724bee1e99d627d447dee69e7668299", size = 704157, upload-time = "2025-06-03T17:13:51.175Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/43/7d/67ec98bed784807d543106bb517879149bed3544d1987bdf59eab6ced79e/rosu_pp_py-3.1.0-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:bc5350a00a37dc273f7e734364a27820f2c274a5a1715fe3b0ef62bd071fae54", size = 813310, upload-time = "2025-06-03T17:13:52.421Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a9/02/fbbb54b21cec66fbe8e2884a73837e0c4e97ca5c625587d90b378c5354f0/rosu_pp_py-3.1.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:28f171e6042d68df379be0536173626b2ae51ddc4a7b1881209ff384c468918a", size = 737638, upload-time = "2025-06-03T17:13:53.709Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/18/9e/f951ef3508cbfbaf36dcee3bd828eb8f922a21b2791bc852074adc1835a1/rosu_pp_py-3.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:a327e627bc56e55bc8dd3fcc26abcfe60af1497f310dad7aea3ef798434f2e9b", size = 457855, upload-time = "2025-06-03T17:13:55.317Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rsa"
|
||||
version = "4.9.1"
|
||||
|
||||
Reference in New Issue
Block a user