feat(beatmap): implement get beatmap arrtibutes
This commit is contained in:
@@ -1,12 +1,14 @@
|
||||
from datetime import datetime
|
||||
|
||||
from app.fetcher import Fetcher
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.models.score import MODE_TO_INT, GameMode
|
||||
|
||||
from .beatmapset import Beatmapset, BeatmapsetResp
|
||||
|
||||
from sqlalchemy import DECIMAL, Column, DateTime
|
||||
from sqlmodel import VARCHAR, Field, Relationship, SQLModel
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@@ -77,6 +79,7 @@ class Beatmap(BeatmapBase, table=True):
|
||||
)
|
||||
session.add(beatmap)
|
||||
await session.commit()
|
||||
await session.refresh(beatmap)
|
||||
return beatmap
|
||||
|
||||
@classmethod
|
||||
@@ -102,6 +105,30 @@ class Beatmap(BeatmapBase, table=True):
|
||||
await session.commit()
|
||||
return beatmaps
|
||||
|
||||
@classmethod
|
||||
async def get_or_fetch(
|
||||
cls, session: AsyncSession, bid: int, fetcher: Fetcher
|
||||
) -> "Beatmap":
|
||||
beatmap = (
|
||||
await session.exec(
|
||||
select(Beatmap)
|
||||
.where(Beatmap.id == bid)
|
||||
.options(
|
||||
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not beatmap:
|
||||
resp = await fetcher.get_beatmap(bid)
|
||||
r = await session.exec(
|
||||
select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id)
|
||||
)
|
||||
if not r.first():
|
||||
set_resp = await fetcher.get_beatmapset(resp.beatmapset_id)
|
||||
await Beatmapset.from_resp(session, set_resp, from_=resp.id)
|
||||
return await Beatmap.from_resp(session, resp)
|
||||
return beatmap
|
||||
|
||||
|
||||
class BeatmapResp(BeatmapBase):
|
||||
id: int
|
||||
|
||||
@@ -2,7 +2,6 @@ from datetime import datetime
|
||||
from typing import TYPE_CHECKING, TypedDict, cast
|
||||
|
||||
from app.models.beatmap import BeatmapRankStatus, Genre, Language
|
||||
from app.models.score import GameMode
|
||||
|
||||
from pydantic import BaseModel, model_serializer
|
||||
from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text
|
||||
@@ -69,7 +68,7 @@ class BeatmapNomination(TypedDict):
|
||||
beatmapset_id: int
|
||||
reset: bool
|
||||
user_id: int
|
||||
rulesets: dict[str, GameMode] | None
|
||||
rulesets: list[str] | None
|
||||
|
||||
|
||||
class BeatmapDescription(SQLModel):
|
||||
|
||||
@@ -2,7 +2,8 @@ from datetime import datetime
|
||||
import math
|
||||
|
||||
from app.database.user import User
|
||||
from app.models.score import MODE_TO_INT, APIMod, GameMode, Rank
|
||||
from app.models.mods import APIMod
|
||||
from app.models.score import MODE_TO_INT, GameMode, Rank
|
||||
|
||||
from .beatmap import Beatmap, BeatmapResp
|
||||
from .beatmapset import BeatmapsetResp
|
||||
|
||||
@@ -2,9 +2,10 @@ from __future__ import annotations
|
||||
|
||||
from .beatmap import BeatmapFetcher
|
||||
from .beatmapset import BeatmapsetFetcher
|
||||
from .osu_dot_direct import OsuDotDirectFetcher
|
||||
|
||||
|
||||
class Fetcher(BeatmapFetcher, BeatmapsetFetcher):
|
||||
class Fetcher(BeatmapFetcher, BeatmapsetFetcher, OsuDotDirectFetcher):
|
||||
"""A class that combines all fetchers for easy access."""
|
||||
|
||||
pass
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.database.beatmap import BeatmapResp
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ._base import BaseFetcher
|
||||
|
||||
from httpx import AsyncClient
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.database.beatmap import BeatmapResp
|
||||
|
||||
|
||||
class BeatmapFetcher(BaseFetcher):
|
||||
async def get_beatmap(self, beatmap_id: int) -> BeatmapResp:
|
||||
async def get_beatmap(self, beatmap_id: int) -> "BeatmapResp":
|
||||
from app.database.beatmap import BeatmapResp
|
||||
|
||||
async with AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"https://osu.ppy.sh/api/v2/beatmaps/{beatmap_id}",
|
||||
|
||||
15
app/fetcher/osu_dot_direct.py
Normal file
15
app/fetcher/osu_dot_direct.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ._base import BaseFetcher
|
||||
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
class OsuDotDirectFetcher(BaseFetcher):
|
||||
async def get_beatmap_raw(self, beatmap_id: int) -> str:
|
||||
async with AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"https://osu.direct/api/osu/{beatmap_id}/raw",
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
@@ -2,6 +2,8 @@ from __future__ import annotations
|
||||
|
||||
from enum import IntEnum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BeatmapRankStatus(IntEnum):
|
||||
GRAVEYARD = -2
|
||||
@@ -45,3 +47,20 @@ class Language(IntEnum):
|
||||
RUSSIAN = 11
|
||||
POLISH = 12
|
||||
OTHER = 13
|
||||
|
||||
|
||||
class BeatmapAttributes(BaseModel):
|
||||
star_rating: float
|
||||
max_combo: int
|
||||
|
||||
# osu
|
||||
aim_difficulty: float | None = None
|
||||
aim_difficult_slider_count: float | None = None
|
||||
speed_difficulty: float | None = None
|
||||
speed_note_count: float | None = None
|
||||
slider_factor: float | None = None
|
||||
aim_difficult_strain_count: float | None = None
|
||||
speed_difficult_strain_count: float | None = None
|
||||
|
||||
# taiko
|
||||
mono_stamina_factor: float | None = None
|
||||
|
||||
56
app/models/mods.py
Normal file
56
app/models/mods.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
class APIMod(TypedDict):
|
||||
acronym: str
|
||||
settings: dict[str, bool | float | str]
|
||||
|
||||
|
||||
# https://github.com/ppy/osu-api/wiki#mods
|
||||
LEGACY_MOD_TO_API_MOD = {
|
||||
(1 << 0): APIMod(acronym="NF", settings={}), # No Fail
|
||||
(1 << 1): APIMod(acronym="EZ", settings={}),
|
||||
(1 << 2): APIMod(acronym="TD", settings={}), # Touch Device
|
||||
(1 << 3): APIMod(acronym="HD", settings={}), # Hidden
|
||||
(1 << 4): APIMod(acronym="HR", settings={}), # Hard Rock
|
||||
(1 << 5): APIMod(acronym="SD", settings={}), # Sudden Death
|
||||
(1 << 6): APIMod(acronym="DT", settings={}), # Double Time
|
||||
(1 << 7): APIMod(acronym="RX", settings={}), # Relax
|
||||
(1 << 8): APIMod(acronym="HT", settings={}), # Half Time
|
||||
(1 << 9): APIMod(acronym="NC", settings={}), # Nightcore
|
||||
(1 << 10): APIMod(acronym="FL", settings={}), # Flashlight
|
||||
(1 << 11): APIMod(acronym="AT", settings={}), # Auto Play
|
||||
(1 << 12): APIMod(acronym="SO", settings={}), # Spun Out
|
||||
(1 << 13): APIMod(acronym="AP", settings={}), # Autopilot
|
||||
(1 << 14): APIMod(acronym="PF", settings={}), # Perfect
|
||||
(1 << 15): APIMod(acronym="4K", settings={}), # 4K
|
||||
(1 << 16): APIMod(acronym="5K", settings={}), # 5K
|
||||
(1 << 17): APIMod(acronym="6K", settings={}), # 6K
|
||||
(1 << 18): APIMod(acronym="7K", settings={}), # 7K
|
||||
(1 << 19): APIMod(acronym="8K", settings={}), # 8K
|
||||
(1 << 20): APIMod(acronym="FI", settings={}), # Fade In
|
||||
(1 << 21): APIMod(acronym="RD", settings={}), # Random
|
||||
(1 << 22): APIMod(acronym="CN", settings={}), # Cinema
|
||||
(1 << 23): APIMod(acronym="TP", settings={}), # Target Practice
|
||||
(1 << 24): APIMod(acronym="9K", settings={}), # 9K
|
||||
(1 << 25): APIMod(acronym="CO", settings={}), # Key Co-op
|
||||
(1 << 26): APIMod(acronym="1K", settings={}), # 1K
|
||||
(1 << 27): APIMod(acronym="2K", settings={}), # 2K
|
||||
(1 << 28): APIMod(acronym="3K", settings={}), # 3K
|
||||
(1 << 29): APIMod(acronym="SV2", settings={}), # Score V2
|
||||
(1 << 30): APIMod(acronym="MR", settings={}), # Mirror
|
||||
}
|
||||
|
||||
|
||||
def int_to_mods(mods: int) -> list[APIMod]:
|
||||
mod_list = []
|
||||
for mod in range(31):
|
||||
if mods & (1 << mod):
|
||||
mod_list.append(LEGACY_MOD_TO_API_MOD[(1 << mod)])
|
||||
if mods & (1 << 14):
|
||||
mod_list.remove(LEGACY_MOD_TO_API_MOD[(1 << 5)])
|
||||
if mods & (1 << 9):
|
||||
mod_list.remove(LEGACY_MOD_TO_API_MOD[(1 << 6)])
|
||||
return mod_list
|
||||
@@ -1,7 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum, IntEnum
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import rosu_pp_py as rosu
|
||||
|
||||
|
||||
class GameMode(str, Enum):
|
||||
@@ -10,6 +11,14 @@ class GameMode(str, Enum):
|
||||
FRUITS = "fruits"
|
||||
MANIA = "mania"
|
||||
|
||||
def to_rosu(self) -> rosu.GameMode:
|
||||
return {
|
||||
GameMode.OSU: rosu.GameMode.Osu,
|
||||
GameMode.TAIKO: rosu.GameMode.Taiko,
|
||||
GameMode.FRUITS: rosu.GameMode.Catch,
|
||||
GameMode.MANIA: rosu.GameMode.Mania,
|
||||
}[self]
|
||||
|
||||
|
||||
MODE_TO_INT = {
|
||||
GameMode.OSU: 0,
|
||||
@@ -32,11 +41,6 @@ class Rank(str, Enum):
|
||||
F = "f"
|
||||
|
||||
|
||||
class APIMod(TypedDict):
|
||||
acronym: str
|
||||
settings: dict[str, Any]
|
||||
|
||||
|
||||
# https://github.com/ppy/osu/blob/master/osu.Game/Rulesets/Scoring/HitResult.cs
|
||||
class HitResult(IntEnum):
|
||||
PERFECT = 0 # [Order(0)]
|
||||
|
||||
@@ -5,6 +5,7 @@ from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401
|
||||
beatmapset,
|
||||
me,
|
||||
relationship,
|
||||
score,
|
||||
)
|
||||
from .api_router import router as api_router
|
||||
from .auth import router as auth_router
|
||||
|
||||
@@ -1,22 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
|
||||
from app.database import (
|
||||
Beatmap,
|
||||
BeatmapResp,
|
||||
User as DBUser,
|
||||
)
|
||||
from app.database.beatmapset import Beatmapset
|
||||
from app.database.score import Score, ScoreResp
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.database import get_db, get_redis
|
||||
from app.dependencies.fetcher import get_fetcher
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.fetcher import Fetcher
|
||||
from app.models.beatmap import BeatmapAttributes
|
||||
from app.models.mods import APIMod, int_to_mods
|
||||
from app.models.score import INT_TO_MODE, GameMode
|
||||
from app.utils import calculate_beatmap_attribute
|
||||
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Query
|
||||
from httpx import HTTPStatusError
|
||||
from httpx import HTTPError, HTTPStatusError
|
||||
from pydantic import BaseModel
|
||||
from redis import Redis
|
||||
import rosu_pp_py as rosu
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -29,32 +38,11 @@ async def get_beatmap(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
beatmap = (
|
||||
await db.exec(
|
||||
select(Beatmap)
|
||||
.options(
|
||||
joinedload(Beatmap.beatmapset).selectinload( # pyright: ignore[reportArgumentType]
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
.where(Beatmap.id == bid)
|
||||
)
|
||||
).first()
|
||||
if not beatmap:
|
||||
try:
|
||||
resp = await fetcher.get_beatmap(bid)
|
||||
r = await db.exec(
|
||||
select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id)
|
||||
)
|
||||
if not r.first():
|
||||
set_resp = await fetcher.get_beatmapset(resp.beatmapset_id)
|
||||
await Beatmapset.from_resp(db, set_resp, from_=resp.id)
|
||||
await Beatmap.from_resp(db, resp)
|
||||
except HTTPStatusError:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
else:
|
||||
resp = BeatmapResp.from_db(beatmap)
|
||||
return resp
|
||||
try:
|
||||
beatmap = await Beatmap.get_or_fetch(db, bid, fetcher)
|
||||
return BeatmapResp.from_db(beatmap)
|
||||
except HTTPError:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
|
||||
|
||||
class BatchGetResp(BaseModel):
|
||||
@@ -75,7 +63,7 @@ async def batch_get_beatmaps(
|
||||
select(Beatmap)
|
||||
.options(
|
||||
joinedload(
|
||||
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
|
||||
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
|
||||
).selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
@@ -90,8 +78,8 @@ async def batch_get_beatmaps(
|
||||
select(Beatmap)
|
||||
.options(
|
||||
joinedload(
|
||||
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
|
||||
).selectinload(
|
||||
Beatmap.beatmapset # pyright: ignore[reportArgumentType]
|
||||
).selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
@@ -103,137 +91,52 @@ async def batch_get_beatmaps(
|
||||
return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps])
|
||||
|
||||
|
||||
class BeatmapScores(BaseModel):
|
||||
scores: list[ScoreResp]
|
||||
userScore: ScoreResp | None = None
|
||||
|
||||
|
||||
@router.get(
|
||||
"/beatmaps/{beatmap}/scores", tags=["beatmap"], response_model=BeatmapScores
|
||||
)
|
||||
async def get_beatmap_scores(
|
||||
beatmap: int,
|
||||
legacy_only: bool = Query(None), # TODO:加入对这个参数的查询
|
||||
mode: str = Query(None),
|
||||
# mods: List[APIMod] = Query(None), # TODO:加入指定MOD的查询
|
||||
type: str = Query(None),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="this server only contains lazer scores"
|
||||
)
|
||||
|
||||
all_scores = (
|
||||
await db.exec(
|
||||
select(Score).where(Score.beatmap_id == beatmap)
|
||||
# .where(Score.mods == mods if mods else True)
|
||||
)
|
||||
).all()
|
||||
|
||||
user_score = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.options(
|
||||
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
|
||||
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
|
||||
.selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
.where(Score.beatmap_id == beatmap)
|
||||
.where(Score.user_id == current_user.id)
|
||||
)
|
||||
).first()
|
||||
|
||||
return BeatmapScores(
|
||||
scores=[ScoreResp.from_db(score) for score in all_scores],
|
||||
userScore=ScoreResp.from_db(user_score) if user_score else None,
|
||||
)
|
||||
|
||||
|
||||
class BeatmapUserScore(BaseModel):
|
||||
position: int
|
||||
score: ScoreResp
|
||||
|
||||
|
||||
@router.get(
|
||||
"/beatmaps/{beatmap}/scores/users/{user}",
|
||||
@router.post(
|
||||
"/beatmaps/{beatmap}/attributes",
|
||||
tags=["beatmap"],
|
||||
response_model=BeatmapUserScore,
|
||||
response_model=BeatmapAttributes,
|
||||
)
|
||||
async def get_user_beatmap_score(
|
||||
async def get_beatmap_attributes(
|
||||
beatmap: int,
|
||||
user: int,
|
||||
legacy_only: bool = Query(None),
|
||||
mode: str = Query(None),
|
||||
mods: str = Query(None), # TODO:添加mods筛选
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
mods: list[str] = Query(default_factory=list),
|
||||
ruleset: GameMode | None = Query(default=None),
|
||||
ruleset_id: int | None = Query(default=None),
|
||||
redis: Redis = Depends(get_redis),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="This server only contains non-legacy scores"
|
||||
)
|
||||
user_score = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.options(
|
||||
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
|
||||
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
|
||||
.selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
.where(Score.gamemode==mode if mode is not None else True)
|
||||
.where(Score.beatmap_id == beatmap)
|
||||
.where(Score.user_id == user)
|
||||
.order_by(col(Score.classic_total_score).desc())
|
||||
)
|
||||
).first()
|
||||
|
||||
if not user_score:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Cannot find user %s's score on this beatmap" % user
|
||||
)
|
||||
mods_ = []
|
||||
if mods and mods[0].isdigit():
|
||||
mods_ = int_to_mods(int(mods[0]))
|
||||
else:
|
||||
return BeatmapUserScore(
|
||||
position=user_score.position if user_score.position is not None else 0,
|
||||
score=ScoreResp.from_db(user_score),
|
||||
)
|
||||
for i in mods:
|
||||
try:
|
||||
mods_.append(json.loads(i))
|
||||
except json.JSONDecodeError:
|
||||
mods_.append(APIMod(acronym=i, settings={}))
|
||||
mods_.sort(key=lambda x: x["acronym"])
|
||||
if ruleset_id is not None and ruleset is None:
|
||||
ruleset = INT_TO_MODE[ruleset_id]
|
||||
if ruleset is None:
|
||||
beatmap_db = await Beatmap.get_or_fetch(db, beatmap, fetcher)
|
||||
ruleset = beatmap_db.mode
|
||||
key = (
|
||||
f"beatmap:{beatmap}:{ruleset}:"
|
||||
f"{hashlib.md5(str(mods_).encode()).hexdigest()}:attributes"
|
||||
)
|
||||
if redis.exists(key):
|
||||
return BeatmapAttributes.model_validate_json(redis.get(key)) # pyright: ignore[reportArgumentType]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/beatmaps/{beatmap}/scores/users/{user}/all",
|
||||
tags=["beatmap"],
|
||||
response_model=list[ScoreResp],
|
||||
)
|
||||
async def get_user_all_beatmap_scores(
|
||||
beatmap: int,
|
||||
user: int,
|
||||
legacy_only: bool = Query(None),
|
||||
ruleset: str = Query(None),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(status_code=404,detail="This server only contains non-legacy scores")
|
||||
all_user_scores=(
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.options(
|
||||
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
|
||||
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
|
||||
.selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
try:
|
||||
resp = await fetcher.get_beatmap_raw(beatmap)
|
||||
try:
|
||||
attr = await asyncio.get_event_loop().run_in_executor(
|
||||
None, calculate_beatmap_attribute, resp, ruleset, mods_
|
||||
)
|
||||
.where(Score.gamemode==ruleset if ruleset is not None else True)
|
||||
.where(Score.beatmap_id == beatmap)
|
||||
.where(Score.user_id == user)
|
||||
.order_by(col(Score.classic_total_score).desc())
|
||||
)
|
||||
).all()
|
||||
|
||||
return [ScoreResp.from_db(score) for score in all_user_scores]
|
||||
except rosu.ConvertError as e: # pyright: ignore[reportAttributeAccessIssue]
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
redis.set(key, attr.model_dump_json())
|
||||
return attr
|
||||
except HTTPStatusError:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
|
||||
156
app/router/score.py
Normal file
156
app/router/score.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.database import (
|
||||
Beatmap,
|
||||
User as DBUser,
|
||||
)
|
||||
from app.database.beatmapset import Beatmapset
|
||||
from app.database.score import Score, ScoreResp
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.user import get_current_user
|
||||
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class BeatmapScores(BaseModel):
|
||||
scores: list[ScoreResp]
|
||||
userScore: ScoreResp | None = None
|
||||
|
||||
|
||||
@router.get(
|
||||
"/beatmaps/{beatmap}/scores", tags=["beatmap"], response_model=BeatmapScores
|
||||
)
|
||||
async def get_beatmap_scores(
|
||||
beatmap: int,
|
||||
legacy_only: bool = Query(None), # TODO:加入对这个参数的查询
|
||||
mode: str = Query(None),
|
||||
# mods: List[APIMod] = Query(None), # TODO:加入指定MOD的查询
|
||||
type: str = Query(None),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="this server only contains lazer scores"
|
||||
)
|
||||
|
||||
all_scores = (
|
||||
await db.exec(
|
||||
select(Score).where(Score.beatmap_id == beatmap)
|
||||
# .where(Score.mods == mods if mods else True)
|
||||
)
|
||||
).all()
|
||||
|
||||
user_score = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.options(
|
||||
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
|
||||
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
|
||||
.selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
.where(Score.beatmap_id == beatmap)
|
||||
.where(Score.user_id == current_user.id)
|
||||
)
|
||||
).first()
|
||||
|
||||
return BeatmapScores(
|
||||
scores=[ScoreResp.from_db(score) for score in all_scores],
|
||||
userScore=ScoreResp.from_db(user_score) if user_score else None,
|
||||
)
|
||||
|
||||
|
||||
class BeatmapUserScore(BaseModel):
|
||||
position: int
|
||||
score: ScoreResp
|
||||
|
||||
|
||||
@router.get(
|
||||
"/beatmaps/{beatmap}/scores/users/{user}",
|
||||
tags=["beatmap"],
|
||||
response_model=BeatmapUserScore,
|
||||
)
|
||||
async def get_user_beatmap_score(
|
||||
beatmap: int,
|
||||
user: int,
|
||||
legacy_only: bool = Query(None),
|
||||
mode: str = Query(None),
|
||||
mods: str = Query(None), # TODO:添加mods筛选
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="This server only contains non-legacy scores"
|
||||
)
|
||||
user_score = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.options(
|
||||
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
|
||||
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
|
||||
.selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
.where(Score.gamemode == mode if mode is not None else True)
|
||||
.where(Score.beatmap_id == beatmap)
|
||||
.where(Score.user_id == user)
|
||||
.order_by(col(Score.classic_total_score).desc())
|
||||
)
|
||||
).first()
|
||||
|
||||
if not user_score:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Cannot find user {user}'s score on this beatmap"
|
||||
)
|
||||
else:
|
||||
return BeatmapUserScore(
|
||||
position=user_score.position if user_score.position is not None else 0,
|
||||
score=ScoreResp.from_db(user_score),
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/beatmaps/{beatmap}/scores/users/{user}/all",
|
||||
tags=["beatmap"],
|
||||
response_model=list[ScoreResp],
|
||||
)
|
||||
async def get_user_all_beatmap_scores(
|
||||
beatmap: int,
|
||||
user: int,
|
||||
legacy_only: bool = Query(None),
|
||||
ruleset: str = Query(None),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if legacy_only:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="This server only contains non-legacy scores"
|
||||
)
|
||||
all_user_scores = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.options(
|
||||
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
|
||||
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
|
||||
.selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
.where(Score.gamemode == ruleset if ruleset is not None else True)
|
||||
.where(Score.beatmap_id == beatmap)
|
||||
.where(Score.user_id == user)
|
||||
.order_by(col(Score.classic_total_score).desc())
|
||||
)
|
||||
).all()
|
||||
|
||||
return [ScoreResp.from_db(score) for score in all_user_scores]
|
||||
28
app/utils.py
28
app/utils.py
@@ -8,6 +8,9 @@ from app.database import (
|
||||
LazerUserStatistics,
|
||||
User as DBUser,
|
||||
)
|
||||
from app.models.beatmap import BeatmapAttributes
|
||||
from app.models.mods import APIMod
|
||||
from app.models.score import GameMode
|
||||
from app.models.user import (
|
||||
Country,
|
||||
Cover,
|
||||
@@ -22,6 +25,8 @@ from app.models.user import (
|
||||
UserAchievement,
|
||||
)
|
||||
|
||||
import rosu_pp_py as rosu
|
||||
|
||||
|
||||
async def convert_db_user_to_api_user(db_user: DBUser, ruleset: str = "osu") -> User:
|
||||
"""将数据库用户模型转换为API用户模型(使用 Lazer 表)"""
|
||||
@@ -425,3 +430,26 @@ def get_country_name(country_code: str) -> str:
|
||||
# 可以添加更多国家
|
||||
}
|
||||
return country_names.get(country_code, "Unknown")
|
||||
|
||||
|
||||
def calculate_beatmap_attribute(
|
||||
beatmap: str,
|
||||
gamemode: GameMode | None = None,
|
||||
mods: int | list[APIMod] | list[str] = 0,
|
||||
) -> BeatmapAttributes:
|
||||
map = rosu.Beatmap(content=beatmap)
|
||||
if gamemode is not None:
|
||||
map.convert(gamemode.to_rosu(), mods)
|
||||
diff = rosu.Difficulty(mods=mods).calculate(map)
|
||||
return BeatmapAttributes(
|
||||
star_rating=diff.stars,
|
||||
max_combo=diff.max_combo,
|
||||
aim_difficulty=diff.aim,
|
||||
aim_difficult_slider_count=diff.aim_difficult_slider_count,
|
||||
speed_difficulty=diff.speed,
|
||||
speed_note_count=diff.speed_note_count,
|
||||
slider_factor=diff.slider_factor,
|
||||
aim_difficult_strain_count=diff.aim_difficult_strain_count,
|
||||
speed_difficult_strain_count=diff.speed_difficult_strain_count,
|
||||
mono_stamina_factor=diff.stamina,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user