chore(merge): merge pull request #7 from GooGuTeam/feat/solo-play

feat: 单人游戏
This commit is contained in:
MingxuanGame
2025-07-28 16:53:20 +08:00
committed by GitHub
43 changed files with 6155 additions and 751 deletions

421
.gitignore vendored
View File

@@ -1,209 +1,212 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[codz]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py.cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
#poetry.toml
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
#pdm.lock
#pdm.toml
.pdm-python
.pdm-build/
# pixi
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
#pixi.lock
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
# in the .venv directory. It is recommended not to include this directory in version control.
.pixi
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.envrc
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/
# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc
# Cursor
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
# refer to https://docs.cursor.com/context/ignore-files
.cursorignore
.cursorindexingignore
# Marimo
marimo/_static/
marimo/_lsp/
__marimo__/
bancho.py-master/*
.vscode/settings.json
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[codz]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py.cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
#poetry.toml
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
#pdm.lock
#pdm.toml
.pdm-python
.pdm-build/
# pixi
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
#pixi.lock
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
# in the .venv directory. It is recommended not to include this directory in version control.
.pixi
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.envrc
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/
# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc
# Cursor
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
# refer to https://docs.cursor.com/context/ignore-files
.cursorignore
.cursorindexingignore
# Marimo
marimo/_static/
marimo/_lsp/
__marimo__/
bancho.py-master/*
.vscode/settings.json
# runtime file
replays/

View File

@@ -34,7 +34,7 @@ class Settings:
# SignalR 设置
SIGNALR_NEGOTIATE_TIMEOUT: int = int(os.getenv("SIGNALR_NEGOTIATE_TIMEOUT", "30"))
SIGNALR_PING_INTERVAL: int = int(os.getenv("SIGNALR_PING_INTERVAL", "120"))
SIGNALR_PING_INTERVAL: int = int(os.getenv("SIGNALR_PING_INTERVAL", "15"))
# Fetcher 设置
FETCHER_CLIENT_ID: str = os.getenv("FETCHER_CLIENT_ID", "")

View File

@@ -9,6 +9,13 @@ from .beatmapset import (
)
from .legacy import LegacyOAuthToken, LegacyUserStatistics
from .relationship import Relationship, RelationshipResp, RelationshipType
from .score import (
Score,
ScoreBase,
ScoreResp,
ScoreStatistics,
)
from .score_token import ScoreToken, ScoreTokenResp
from .team import Team, TeamMember
from .user import (
DailyChallengeStats,
@@ -57,6 +64,12 @@ __all__ = [
"Relationship",
"RelationshipResp",
"RelationshipType",
"Score",
"ScoreBase",
"ScoreResp",
"ScoreStatistics",
"ScoreToken",
"ScoreTokenResp",
"Team",
"TeamMember",
"User",

View File

@@ -2,7 +2,7 @@ from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import Column, DateTime
from sqlmodel import Field, Relationship, SQLModel
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
if TYPE_CHECKING:
from .user import User
@@ -12,7 +12,9 @@ class OAuthToken(SQLModel, table=True):
__tablename__ = "oauth_tokens" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(foreign_key="users.id")
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("users.id"), index=True)
)
access_token: str = Field(max_length=500, unique=True)
refresh_token: str = Field(max_length=500, unique=True)
token_type: str = Field(default="Bearer", max_length=20)

View File

@@ -1,6 +1,6 @@
from datetime import datetime
from typing import TYPE_CHECKING
from app.fetcher import Fetcher
from app.models.beatmap import BeatmapRankStatus
from app.models.score import MODE_TO_INT, GameMode
@@ -11,6 +11,9 @@ from sqlalchemy.orm import joinedload
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession
if TYPE_CHECKING:
from app.fetcher import Fetcher
class BeatmapOwner(SQLModel):
id: int
@@ -65,6 +68,10 @@ class Beatmap(BeatmapBase, table=True):
# optional
beatmapset: Beatmapset = Relationship(back_populates="beatmaps")
@property
def can_ranked(self) -> bool:
return self.beatmap_status > BeatmapRankStatus.PENDING
@classmethod
async def from_resp(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
d = resp.model_dump()
@@ -107,19 +114,25 @@ class Beatmap(BeatmapBase, table=True):
@classmethod
async def get_or_fetch(
cls, session: AsyncSession, bid: int, fetcher: Fetcher
cls,
session: AsyncSession,
fetcher: "Fetcher",
bid: int | None = None,
md5: str | None = None,
) -> "Beatmap":
beatmap = (
await session.exec(
select(Beatmap)
.where(Beatmap.id == bid)
.where(
Beatmap.id == bid if bid is not None else Beatmap.checksum == md5
)
.options(
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
)
)
).first()
if not beatmap:
resp = await fetcher.get_beatmap(bid)
resp = await fetcher.get_beatmap(bid, md5)
r = await session.exec(
select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id)
)

View File

@@ -2,6 +2,7 @@ from datetime import datetime
from typing import TYPE_CHECKING, TypedDict, cast
from app.models.beatmap import BeatmapRankStatus, Genre, Language
from app.models.score import GameMode
from pydantic import BaseModel, model_serializer
from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text
@@ -68,7 +69,7 @@ class BeatmapNomination(TypedDict):
beatmapset_id: int
reset: bool
user_id: int
rulesets: list[str] | None
rulesets: list[GameMode] | None
class BeatmapDescription(SQLModel):

View File

@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING
from sqlalchemy import JSON, Column, DateTime
from sqlalchemy.orm import Mapped
from sqlmodel import Field, Relationship, SQLModel
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
if TYPE_CHECKING:
from .user import User
@@ -16,7 +16,7 @@ class LegacyUserStatistics(SQLModel, table=True):
__tablename__ = "user_statistics" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(foreign_key="users.id")
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
mode: str = Field(max_length=10) # osu, taiko, fruits, mania
# 基本统计
@@ -77,7 +77,7 @@ class LegacyOAuthToken(SQLModel, table=True):
__tablename__ = "legacy_oauth_tokens" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="users.id")
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
access_token: str = Field(max_length=255, index=True)
refresh_token: str = Field(max_length=255, index=True)
expires_at: datetime = Field(sa_column=Column(DateTime))

View File

@@ -4,7 +4,10 @@ from .user import User
from pydantic import BaseModel
from sqlmodel import (
BigInteger,
Column,
Field,
ForeignKey,
Relationship as SQLRelationship,
SQLModel,
select,
@@ -20,10 +23,22 @@ class RelationshipType(str, Enum):
class Relationship(SQLModel, table=True):
__tablename__ = "relationship" # pyright: ignore[reportAssignmentType]
user_id: int = Field(
default=None, foreign_key="users.id", primary_key=True, index=True
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
primary_key=True,
index=True,
),
)
target_id: int = Field(
default=None, foreign_key="users.id", primary_key=True, index=True
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
primary_key=True,
index=True,
),
)
type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False)
target: "User" = SQLRelationship(

View File

@@ -2,15 +2,36 @@ from datetime import datetime
import math
from app.database.user import User
from app.models.beatmap import BeatmapRankStatus
from app.models.mods import APIMod
from app.models.score import MODE_TO_INT, GameMode, Rank
from app.models.score import (
MODE_TO_INT,
GameMode,
HitResult,
LeaderboardType,
Rank,
ScoreStatistics,
)
from .beatmap import Beatmap, BeatmapResp
from .beatmapset import BeatmapsetResp
from .beatmapset import Beatmapset, BeatmapsetResp
from pydantic import BaseModel
from sqlalchemy import Column, DateTime
from sqlmodel import JSON, BigInteger, Field, Relationship, SQLModel
from sqlalchemy import Column, ColumnExpressionArgument, DateTime
from sqlalchemy.orm import aliased, joinedload
from sqlmodel import (
JSON,
BigInteger,
Field,
ForeignKey,
Relationship,
SQLModel,
col,
false,
func,
select,
)
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql._expression_select_cls import SelectOfScalar
class ScoreBase(SQLModel):
@@ -34,6 +55,9 @@ class ScoreBase(SQLModel):
room_id: int | None = Field(default=None) # multiplayer
started_at: datetime = Field(sa_column=Column(DateTime))
total_score: int = Field(default=0, sa_column=Column(BigInteger))
total_score_without_mods: int = Field(
default=0, sa_column=Column(BigInteger), exclude=True
)
type: str
# optional
@@ -41,22 +65,20 @@ class ScoreBase(SQLModel):
position: int | None = Field(default=None) # multiplayer
class ScoreStatistics(BaseModel):
count_miss: int
count_50: int
count_100: int
count_300: int
count_geki: int
count_katu: int
count_large_tick_miss: int | None = None
count_slider_tail_hit: int | None = None
class Score(ScoreBase, table=True):
__tablename__ = "scores" # pyright: ignore[reportAssignmentType]
id: int = Field(primary_key=True)
id: int | None = Field(
default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True)
)
beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
user_id: int = Field(foreign_key="users.id", index=True)
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
index=True,
),
)
# ScoreStatistics
n300: int = Field(exclude=True)
n100: int = Field(exclude=True)
@@ -72,9 +94,51 @@ class Score(ScoreBase, table=True):
beatmap: "Beatmap" = Relationship()
user: "User" = Relationship()
@property
def is_perfect_combo(self) -> bool:
return self.max_combo == self.beatmap.max_combo
@staticmethod
def select_clause() -> SelectOfScalar["Score"]:
return select(Score).options(
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
.selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
),
joinedload(Score.user).joinedload(User.lazer_profile), # pyright: ignore[reportArgumentType]
)
@staticmethod
def select_clause_unique(
*where_clauses: ColumnExpressionArgument[bool] | bool,
) -> SelectOfScalar["Score"]:
rownum = (
func.row_number()
.over(
partition_by=col(Score.user_id), order_by=col(Score.total_score).desc()
)
.label("rn")
)
subq = select(Score, rownum).where(*where_clauses).subquery()
best = aliased(Score, subq, adapt_on_names=True)
return (
select(best)
.where(subq.c.rn == 1)
.options(
joinedload(best.beatmap) # pyright: ignore[reportArgumentType]
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
.selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
),
joinedload(best.user).joinedload(User.lazer_profile), # pyright: ignore[reportArgumentType]
)
)
class ScoreResp(ScoreBase):
id: int
user_id: int
is_perfect_combo: bool = False
legacy_perfect: bool = False
legacy_total_score: int = 0 # FIXME
@@ -85,10 +149,13 @@ class ScoreResp(ScoreBase):
beatmapset: BeatmapsetResp | None = None
# FIXME: user: APIUser | None = None
statistics: ScoreStatistics | None = None
rank_global: int | None = None
rank_country: int | None = None
@classmethod
def from_db(cls, score: Score) -> "ScoreResp":
async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp":
s = cls.model_validate(score.model_dump())
assert score.id
s.beatmap = BeatmapResp.from_db(score.beatmap)
s.beatmapset = BeatmapsetResp.from_db(score.beatmap.beatmapset)
s.is_perfect_combo = s.max_combo == s.beatmap.max_combo
@@ -97,14 +164,220 @@ class ScoreResp(ScoreBase):
if score.best_id:
# https://osu.ppy.sh/wiki/Performance_points/Weighting_system
s.weight = math.pow(0.95, score.best_id)
s.statistics = ScoreStatistics(
count_miss=score.nmiss,
count_50=score.n50,
count_100=score.n100,
count_300=score.n300,
count_geki=score.ngeki,
count_katu=score.nkatu,
count_large_tick_miss=score.nlarge_tick_miss,
count_slider_tail_hit=score.nslider_tail_hit,
s.statistics = {
HitResult.MISS: score.nmiss,
HitResult.MEH: score.n50,
HitResult.OK: score.n100,
HitResult.GREAT: score.n300,
HitResult.PERFECT: score.ngeki,
HitResult.GOOD: score.nkatu,
}
if score.nlarge_tick_miss is not None:
s.statistics[HitResult.LARGE_TICK_MISS] = score.nlarge_tick_miss
if score.nslider_tail_hit is not None:
s.statistics[HitResult.SLIDER_TAIL_HIT] = score.nslider_tail_hit
# s.user = await convert_db_user_to_api_user(score.user)
s.rank_global = (
await get_score_position_by_id(
session,
score.map_md5,
score.id,
mode=score.gamemode,
user=score.user,
)
or None
)
s.rank_country = (
await get_score_position_by_id(
session,
score.map_md5,
score.id,
score.gamemode,
score.user,
)
or None
)
return s
async def get_leaderboard(
session: AsyncSession,
beatmap_md5: str,
mode: GameMode,
type: LeaderboardType = LeaderboardType.GLOBAL,
mods: list[APIMod] | None = None,
user: User | None = None,
limit: int = 50,
) -> list[Score]:
scores = []
if type == LeaderboardType.GLOBAL:
query = (
select(Score)
.where(
col(Beatmap.beatmap_status).in_(
[
BeatmapRankStatus.RANKED,
BeatmapRankStatus.LOVED,
BeatmapRankStatus.QUALIFIED,
BeatmapRankStatus.APPROVED,
]
),
Score.map_md5 == beatmap_md5,
Score.gamemode == mode,
col(Score.passed).is_(True),
Score.mods == mods if user and user.is_supporter else false(),
)
.limit(limit)
.order_by(
col(Score.total_score).desc(),
)
)
result = await session.exec(query)
scores = list[Score](result.all())
elif type == LeaderboardType.FRIENDS and user and user.is_supporter:
# TODO
...
elif type == LeaderboardType.TEAM and user and user.team_membership:
team_id = user.team_membership.team_id
query = (
select(Score)
.join(Beatmap)
.options(joinedload(Score.user)) # pyright: ignore[reportArgumentType]
.where(
Score.map_md5 == beatmap_md5,
Score.gamemode == mode,
col(Score.passed).is_(True),
col(Score.user.team_membership).is_not(None),
Score.user.team_membership.team_id == team_id, # pyright: ignore[reportOptionalMemberAccess]
Score.mods == mods if user and user.is_supporter else false(),
)
.limit(limit)
.order_by(
col(Score.total_score).desc(),
)
)
result = await session.exec(query)
scores = list[Score](result.all())
if user:
user_score = (
await session.exec(
select(Score).where(
Score.map_md5 == beatmap_md5,
Score.gamemode == mode,
Score.user_id == user.id,
col(Score.passed).is_(True),
)
)
).first()
if user_score and user_score not in scores:
scores.append(user_score)
return scores
async def get_score_position_by_user(
session: AsyncSession,
beatmap_md5: str,
user: User,
mode: GameMode,
type: LeaderboardType = LeaderboardType.GLOBAL,
mods: list[APIMod] | None = None,
) -> int:
where_clause = [
Score.map_md5 == beatmap_md5,
Score.gamemode == mode,
col(Score.passed).is_(True),
col(Beatmap.beatmap_status).in_(
[
BeatmapRankStatus.RANKED,
BeatmapRankStatus.LOVED,
BeatmapRankStatus.QUALIFIED,
BeatmapRankStatus.APPROVED,
]
),
]
if mods and user.is_supporter:
where_clause.append(Score.mods == mods)
else:
where_clause.append(false())
if type == LeaderboardType.FRIENDS and user.is_supporter:
# TODO
...
elif type == LeaderboardType.TEAM and user.team_membership:
team_id = user.team_membership.team_id
where_clause.append(
col(Score.user.team_membership).is_not(None),
)
where_clause.append(
Score.user.team_membership.team_id == team_id, # pyright: ignore[reportOptionalMemberAccess]
)
rownum = (
func.row_number()
.over(
partition_by=Score.map_md5,
order_by=col(Score.total_score).desc(),
)
.label("row_number")
)
subq = select(Score, rownum).join(Beatmap).where(*where_clause).subquery()
stmt = select(subq.c.row_number).where(subq.c.user == user)
result = await session.exec(stmt)
s = result.one_or_none()
return s if s else 0
async def get_score_position_by_id(
session: AsyncSession,
beatmap_md5: str,
score_id: int,
mode: GameMode,
user: User | None = None,
type: LeaderboardType = LeaderboardType.GLOBAL,
mods: list[APIMod] | None = None,
) -> int:
where_clause = [
Score.map_md5 == beatmap_md5,
Score.id == score_id,
Score.gamemode == mode,
col(Score.passed).is_(True),
col(Beatmap.beatmap_status).in_(
[
BeatmapRankStatus.RANKED,
BeatmapRankStatus.LOVED,
BeatmapRankStatus.QUALIFIED,
BeatmapRankStatus.APPROVED,
]
),
]
if mods and user and user.is_supporter:
where_clause.append(Score.mods == mods)
elif mods:
where_clause.append(false())
rownum = (
func.row_number()
.over(
partition_by=[col(Score.user_id), col(Score.map_md5)],
order_by=col(Score.total_score).desc(),
)
.label("rownum")
)
subq = (
select(Score.user_id, Score.id, Score.total_score, rownum)
.join(Beatmap)
.where(*where_clause)
.subquery()
)
best_scores = aliased(subq)
overall_rank = (
func.rank().over(order_by=best_scores.c.total_score.desc()).label("global_rank")
)
final_q = (
select(best_scores.c.id, overall_rank)
.select_from(best_scores)
.where(best_scores.c.rownum == 1)
.subquery()
)
stmt = select(final_q.c.global_rank).where(final_q.c.id == score_id)
result = await session.exec(stmt)
s = result.one_or_none()
return s if s else 0

View File

@@ -0,0 +1,50 @@
from datetime import datetime
from app.models.score import GameMode
from .beatmap import Beatmap
from .user import User
from sqlalchemy import Column, DateTime, Index
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
class ScoreTokenBase(SQLModel):
score_id: int | None = Field(sa_column=Column(BigInteger), default=None)
ruleset_id: GameMode
playlist_item_id: int | None = Field(default=None) # playlist
created_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
updated_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
class ScoreToken(ScoreTokenBase, table=True):
__tablename__ = "score_tokens" # pyright: ignore[reportAssignmentType]
__table_args__ = (Index("idx_user_playlist", "user_id", "playlist_item_id"),)
id: int | None = Field(
default=None,
sa_column=Column(
BigInteger,
primary_key=True,
index=True,
autoincrement=True,
),
)
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
beatmap_id: int = Field(foreign_key="beatmaps.id")
user: "User" = Relationship()
beatmap: "Beatmap" = Relationship()
class ScoreTokenResp(ScoreTokenBase):
id: int
user_id: int
beatmap_id: int
@classmethod
def from_db(cls, obj: ScoreToken) -> "ScoreTokenResp":
return cls.model_validate(obj)

View File

@@ -2,8 +2,7 @@ from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import Column, DateTime
from sqlalchemy.orm import Mapped
from sqlmodel import Field, Relationship, SQLModel
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
if TYPE_CHECKING:
from .user import User
@@ -20,18 +19,18 @@ class Team(SQLModel, table=True):
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
members: Mapped[list["TeamMember"]] = Relationship(back_populates="team")
members: list["TeamMember"] = Relationship(back_populates="team")
class TeamMember(SQLModel, table=True):
__tablename__ = "team_members" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(foreign_key="users.id")
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
team_id: int = Field(foreign_key="teams.id")
joined_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
)
user: Mapped["User"] = Relationship(back_populates="team_membership")
team: Mapped["Team"] = Relationship(back_populates="members")
user: "User" = Relationship(back_populates="team_membership")
team: "Team" = Relationship(back_populates="members")

View File

@@ -7,14 +7,16 @@ from .team import TeamMember
from sqlalchemy import DECIMAL, JSON, Column, Date, DateTime, Text
from sqlalchemy.dialects.mysql import VARCHAR
from sqlmodel import BigInteger, Field, Relationship, SQLModel
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
class User(SQLModel, table=True):
__tablename__ = "users" # pyright: ignore[reportAssignmentType]
# 主键
id: int = Field(default=None, primary_key=True, index=True, nullable=False)
id: int = Field(
default=None, sa_column=Column(BigInteger, primary_key=True, index=True)
)
# 基本信息(匹配 migrations 中的结构)
name: str = Field(max_length=32, unique=True, index=True) # 用户名
@@ -65,6 +67,10 @@ class User(SQLModel, table=True):
latest_activity = getattr(self, "latest_activity", 0)
return datetime.fromtimestamp(latest_activity) if latest_activity > 0 else None
@property
def is_supporter(self):
return self.lazer_profile.is_supporter if self.lazer_profile else False
# 关联关系
lazer_profile: Optional["LazerUserProfile"] = Relationship(back_populates="user")
lazer_statistics: list["LazerUserStatistics"] = Relationship(back_populates="user")
@@ -76,7 +82,7 @@ class User(SQLModel, table=True):
back_populates="user"
)
statistics: list["LegacyUserStatistics"] = Relationship(back_populates="user")
team_membership: list["TeamMember"] = Relationship(back_populates="user")
team_membership: Optional["TeamMember"] = Relationship(back_populates="user")
daily_challenge_stats: Optional["DailyChallengeStats"] = Relationship(
back_populates="user"
)
@@ -103,7 +109,14 @@ class User(SQLModel, table=True):
class LazerUserProfile(SQLModel, table=True):
__tablename__ = "lazer_user_profiles" # pyright: ignore[reportAssignmentType]
user_id: int = Field(foreign_key="users.id", primary_key=True)
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
primary_key=True,
),
)
# 基本状态字段
is_active: bool = Field(default=True)
@@ -159,7 +172,7 @@ class LazerUserProfileSections(SQLModel, table=True):
__tablename__ = "lazer_user_profile_sections" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="users.id")
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
section_name: str = Field(sa_column=Column(VARCHAR(50)))
display_order: int | None = Field(default=None)
@@ -176,7 +189,14 @@ class LazerUserProfileSections(SQLModel, table=True):
class LazerUserCountry(SQLModel, table=True):
__tablename__ = "lazer_user_countries" # pyright: ignore[reportAssignmentType]
user_id: int = Field(foreign_key="users.id", primary_key=True)
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
primary_key=True,
),
)
code: str = Field(max_length=2)
name: str = Field(max_length=100)
@@ -191,7 +211,14 @@ class LazerUserCountry(SQLModel, table=True):
class LazerUserKudosu(SQLModel, table=True):
__tablename__ = "lazer_user_kudosu" # pyright: ignore[reportAssignmentType]
user_id: int = Field(foreign_key="users.id", primary_key=True)
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
primary_key=True,
),
)
available: int = Field(default=0)
total: int = Field(default=0)
@@ -206,7 +233,14 @@ class LazerUserKudosu(SQLModel, table=True):
class LazerUserCounts(SQLModel, table=True):
__tablename__ = "lazer_user_counts" # pyright: ignore[reportAssignmentType]
user_id: int = Field(foreign_key="users.id", primary_key=True)
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
primary_key=True,
),
)
# 统计计数字段
beatmap_playcounts_count: int = Field(default=0)
@@ -241,7 +275,14 @@ class LazerUserCounts(SQLModel, table=True):
class LazerUserStatistics(SQLModel, table=True):
__tablename__ = "lazer_user_statistics" # pyright: ignore[reportAssignmentType]
user_id: int = Field(foreign_key="users.id", primary_key=True)
user_id: int = Field(
default=None,
sa_column=Column(
BigInteger,
ForeignKey("users.id"),
primary_key=True,
),
)
mode: str = Field(default="osu", max_length=10, primary_key=True)
# 基本命中统计
@@ -302,7 +343,7 @@ class LazerUserBanners(SQLModel, table=True):
__tablename__ = "lazer_user_tournament_banners" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="users.id")
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
tournament_id: int
image_url: str = Field(sa_column=Column(VARCHAR(500)))
is_active: bool | None = Field(default=None)
@@ -315,7 +356,7 @@ class LazerUserAchievement(SQLModel, table=True):
__tablename__ = "lazer_user_achievements" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(foreign_key="users.id")
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
achievement_id: int
achieved_at: datetime = Field(
default_factory=datetime.utcnow, sa_column=Column(DateTime)
@@ -328,7 +369,7 @@ class LazerUserBadge(SQLModel, table=True):
__tablename__ = "lazer_user_badges" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(foreign_key="users.id")
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
badge_id: int
awarded_at: datetime | None = Field(default=None, sa_column=Column(DateTime))
description: str | None = Field(default=None, sa_column=Column(Text))
@@ -349,7 +390,7 @@ class LazerUserMonthlyPlaycounts(SQLModel, table=True):
__tablename__ = "lazer_user_monthly_playcounts" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(foreign_key="users.id")
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
start_date: datetime = Field(sa_column=Column(Date))
play_count: int = Field(default=0)
@@ -367,7 +408,7 @@ class LazerUserPreviousUsername(SQLModel, table=True):
__tablename__ = "lazer_user_previous_usernames" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(foreign_key="users.id")
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
username: str = Field(max_length=32)
changed_at: datetime = Field(sa_column=Column(DateTime))
@@ -385,7 +426,7 @@ class LazerUserReplaysWatched(SQLModel, table=True):
__tablename__ = "lazer_user_replays_watched" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(foreign_key="users.id")
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
start_date: datetime = Field(sa_column=Column(Date))
count: int = Field(default=0)
@@ -410,7 +451,9 @@ class DailyChallengeStats(SQLModel, table=True):
__tablename__ = "daily_challenge_stats" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(foreign_key="users.id", unique=True)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("users.id"), unique=True)
)
daily_streak_best: int = Field(default=0)
daily_streak_current: int = Field(default=0)
@@ -431,7 +474,7 @@ class RankHistory(SQLModel, table=True):
__tablename__ = "rank_history" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(foreign_key="users.id")
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
mode: str = Field(max_length=10)
rank_data: list = Field(sa_column=Column(JSON)) # Array of ranks
date_recorded: datetime = Field(
@@ -445,7 +488,7 @@ class UserAvatar(SQLModel, table=True):
__tablename__ = "user_avatars" # pyright: ignore[reportAssignmentType]
id: int | None = Field(default=None, primary_key=True, index=True)
user_id: int = Field(foreign_key="users.id")
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
filename: str = Field(max_length=255)
original_filename: str = Field(max_length=255)
file_size: int

View File

@@ -45,7 +45,7 @@ async def get_current_user_by_token(token: str, db: AsyncSession) -> DBUser | No
selectinload(DBUser.lazer_achievements), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_profile_sections), # pyright: ignore[reportArgumentType]
selectinload(DBUser.statistics), # pyright: ignore[reportArgumentType]
selectinload(DBUser.team_membership), # pyright: ignore[reportArgumentType]
joinedload(DBUser.team_membership), # pyright: ignore[reportArgumentType]
selectinload(DBUser.rank_history), # pyright: ignore[reportArgumentType]
selectinload(DBUser.active_banners), # pyright: ignore[reportArgumentType]
selectinload(DBUser.lazer_badges), # pyright: ignore[reportArgumentType]

View File

@@ -1,23 +1,27 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from app.database.beatmap import BeatmapResp
from ._base import BaseFetcher
from httpx import AsyncClient
if TYPE_CHECKING:
from app.database.beatmap import BeatmapResp
class BeatmapFetcher(BaseFetcher):
async def get_beatmap(self, beatmap_id: int) -> "BeatmapResp":
from app.database.beatmap import BeatmapResp
async def get_beatmap(
self, beatmap_id: int | None = None, beatmap_checksum: str | None = None
) -> BeatmapResp:
if beatmap_id:
params = {"id": beatmap_id}
elif beatmap_checksum:
params = {"checksum": beatmap_checksum}
else:
raise ValueError("Either beatmap_id or beatmap_checksum must be provided.")
async with AsyncClient() as client:
response = await client.get(
f"https://osu.ppy.sh/api/v2/beatmaps/{beatmap_id}",
"https://osu.ppy.sh/api/v2/beatmaps/lookup",
headers=self.header,
params=params,
)
response.raise_for_status()
return BeatmapResp.model_validate(response.json())

152
app/models/metadata_hub.py Normal file
View File

@@ -0,0 +1,152 @@
from __future__ import annotations
from enum import IntEnum
from typing import Any, Literal
from app.models.signalr import UserState
from pydantic import BaseModel, ConfigDict, Field
class _UserActivity(BaseModel):
model_config = ConfigDict(serialize_by_alias=True)
type: Literal[
"ChoosingBeatmap",
"InSoloGame",
"WatchingReplay",
"SpectatingUser",
"SearchingForLobby",
"InLobby",
"InMultiplayerGame",
"SpectatingMultiplayerGame",
"InPlaylistGame",
"EditingBeatmap",
"ModdingBeatmap",
"TestingBeatmap",
"InDailyChallengeLobby",
"PlayingDailyChallenge",
] = Field(alias="$dtype")
value: Any | None = Field(alias="$value")
class ChoosingBeatmap(_UserActivity):
type: Literal["ChoosingBeatmap"] = Field(alias="$dtype")
class InGameValue(BaseModel):
beatmap_id: int = Field(alias="BeatmapID")
beatmap_display_title: str = Field(alias="BeatmapDisplayTitle")
ruleset_id: int = Field(alias="RulesetID")
ruleset_playing_verb: str = Field(alias="RulesetPlayingVerb")
class _InGame(_UserActivity):
value: InGameValue = Field(alias="$value")
class InSoloGame(_InGame):
type: Literal["InSoloGame"] = Field(alias="$dtype")
class InMultiplayerGame(_InGame):
type: Literal["InMultiplayerGame"] = Field(alias="$dtype")
class SpectatingMultiplayerGame(_InGame):
type: Literal["SpectatingMultiplayerGame"] = Field(alias="$dtype")
class InPlaylistGame(_InGame):
type: Literal["InPlaylistGame"] = Field(alias="$dtype")
class EditingBeatmapValue(BaseModel):
beatmap_id: int = Field(alias="BeatmapID")
beatmap_display_title: str = Field(alias="BeatmapDisplayTitle")
class EditingBeatmap(_UserActivity):
type: Literal["EditingBeatmap"] = Field(alias="$dtype")
value: EditingBeatmapValue = Field(alias="$value")
class TestingBeatmap(_UserActivity):
type: Literal["TestingBeatmap"] = Field(alias="$dtype")
class ModdingBeatmap(_UserActivity):
type: Literal["ModdingBeatmap"] = Field(alias="$dtype")
class WatchingReplayValue(BaseModel):
score_id: int = Field(alias="ScoreID")
player_name: str = Field(alias="PlayerName")
beatmap_id: int = Field(alias="BeatmapID")
beatmap_display_title: str = Field(alias="BeatmapDisplayTitle")
class WatchingReplay(_UserActivity):
type: Literal["WatchingReplay"] = Field(alias="$dtype")
value: int | None = Field(alias="$value") # Replay ID
class SpectatingUser(WatchingReplay):
type: Literal["SpectatingUser"] = Field(alias="$dtype")
class SearchingForLobby(_UserActivity):
type: Literal["SearchingForLobby"] = Field(alias="$dtype")
class InLobbyValue(BaseModel):
room_id: int = Field(alias="RoomID")
room_name: str = Field(alias="RoomName")
class InLobby(_UserActivity):
type: Literal["InLobby"] = "InLobby"
class InDailyChallengeLobby(_UserActivity):
type: Literal["InDailyChallengeLobby"] = Field(alias="$dtype")
UserActivity = (
ChoosingBeatmap
| InSoloGame
| WatchingReplay
| SpectatingUser
| SearchingForLobby
| InLobby
| InMultiplayerGame
| SpectatingMultiplayerGame
| InPlaylistGame
| EditingBeatmap
| ModdingBeatmap
| TestingBeatmap
| InDailyChallengeLobby
)
class MetadataClientState(UserState):
user_activity: UserActivity | None = None
status: OnlineStatus | None = None
def to_dict(self) -> dict[str, Any] | None:
if self.status is None or self.status == OnlineStatus.OFFLINE:
return None
dumped = self.model_dump(by_alias=True, exclude_none=True)
return {
"Activity": dumped.get("user_activity"),
"Status": dumped.get("status"),
}
@property
def pushable(self) -> bool:
return self.status is not None and self.status != OnlineStatus.OFFLINE
class OnlineStatus(IntEnum):
OFFLINE = 0 # 隐身
DO_NOT_DISTURB = 1
ONLINE = 2

View File

@@ -1,47 +1,91 @@
from __future__ import annotations
from typing import TypedDict
import json
from typing import Literal, NotRequired, TypedDict
from app.path import STATIC_DIR
class APIMod(TypedDict):
acronym: str
settings: dict[str, bool | float | str]
settings: NotRequired[dict[str, bool | float | str]]
# https://github.com/ppy/osu-api/wiki#mods
LEGACY_MOD_TO_API_MOD = {
(1 << 0): APIMod(acronym="NF", settings={}), # No Fail
(1 << 1): APIMod(acronym="EZ", settings={}),
(1 << 2): APIMod(acronym="TD", settings={}), # Touch Device
(1 << 3): APIMod(acronym="HD", settings={}), # Hidden
(1 << 4): APIMod(acronym="HR", settings={}), # Hard Rock
(1 << 5): APIMod(acronym="SD", settings={}), # Sudden Death
(1 << 6): APIMod(acronym="DT", settings={}), # Double Time
(1 << 7): APIMod(acronym="RX", settings={}), # Relax
(1 << 8): APIMod(acronym="HT", settings={}), # Half Time
(1 << 9): APIMod(acronym="NC", settings={}), # Nightcore
(1 << 10): APIMod(acronym="FL", settings={}), # Flashlight
(1 << 11): APIMod(acronym="AT", settings={}), # Auto Play
(1 << 12): APIMod(acronym="SO", settings={}), # Spun Out
(1 << 13): APIMod(acronym="AP", settings={}), # Autopilot
(1 << 14): APIMod(acronym="PF", settings={}), # Perfect
(1 << 15): APIMod(acronym="4K", settings={}), # 4K
(1 << 16): APIMod(acronym="5K", settings={}), # 5K
(1 << 17): APIMod(acronym="6K", settings={}), # 6K
(1 << 18): APIMod(acronym="7K", settings={}), # 7K
(1 << 19): APIMod(acronym="8K", settings={}), # 8K
(1 << 20): APIMod(acronym="FI", settings={}), # Fade In
(1 << 21): APIMod(acronym="RD", settings={}), # Random
(1 << 22): APIMod(acronym="CN", settings={}), # Cinema
(1 << 23): APIMod(acronym="TP", settings={}), # Target Practice
(1 << 24): APIMod(acronym="9K", settings={}), # 9K
(1 << 25): APIMod(acronym="CO", settings={}), # Key Co-op
(1 << 26): APIMod(acronym="1K", settings={}), # 1K
(1 << 27): APIMod(acronym="2K", settings={}), # 2K
(1 << 28): APIMod(acronym="3K", settings={}), # 3K
(1 << 29): APIMod(acronym="SV2", settings={}), # Score V2
(1 << 30): APIMod(acronym="MR", settings={}), # Mirror
API_MOD_TO_LEGACY: dict[str, int] = {
"NF": 1 << 0, # No Fail
"EZ": 1 << 1, # Easy
"TD": 1 << 2, # Touch Device
"HD": 1 << 3, # Hidden
"HR": 1 << 4, # Hard Rock
"SD": 1 << 5, # Sudden Death
"DT": 1 << 6, # Double Time
"RX": 1 << 7, # Relax
"HT": 1 << 8, # Half Time
"NC": 1 << 9, # Nightcore
"FL": 1 << 10, # Flashlight
"AT": 1 << 11, # Autoplay
"SO": 1 << 12, # Spun Out
"AP": 1 << 13, # Auto Pilot
"PF": 1 << 14, # Perfect
"4K": 1 << 15, # 4K
"5K": 1 << 16, # 5K
"6K": 1 << 17, # 6K
"7K": 1 << 18, # 7K
"8K": 1 << 19, # 8K
"FI": 1 << 20, # Fade In
"RD": 1 << 21, # Random
"CN": 1 << 22, # Cinema
"TP": 1 << 23, # Target Practice
"9K": 1 << 24, # 9K
"CO": 1 << 25, # Key Co-op
"1K": 1 << 26, # 1K
"3K": 1 << 27, # 3K
"2K": 1 << 28, # 2K
"SV2": 1 << 29, # ScoreV2
"MR": 1 << 30, # Mirror
}
LEGACY_MOD_TO_API_MOD = {}
for k, v in API_MOD_TO_LEGACY.items():
LEGACY_MOD_TO_API_MOD[v] = APIMod(acronym=k, settings={})
API_MOD_TO_LEGACY["NC"] |= API_MOD_TO_LEGACY["DT"]
API_MOD_TO_LEGACY["PF"] |= API_MOD_TO_LEGACY["SD"]
# see static/mods.json
class Settings(TypedDict):
Name: str
Type: str
Label: str
Description: str
class Mod(TypedDict):
Acronym: str
Name: str
Description: str
Type: str
Settings: list[Settings]
IncompatibleMods: list[str]
RequiresConfiguration: bool
UserPlayable: bool
ValidForMultiplayer: bool
ValidForFreestyleAsRequiredMod: bool
ValidForMultiplayerAsFreeMod: bool
AlwaysValidForSubmission: bool
API_MODS: dict[Literal[0, 1, 2, 3], dict[str, Mod]] = {}
def init_mods():
mods_file = STATIC_DIR / "mods.json"
raw_mods = json.loads(mods_file.read_text())
for ruleset in raw_mods:
ruleset_mods = {}
for mod in ruleset["Mods"]:
ruleset_mods[mod["Acronym"]] = mod
API_MODS[ruleset["RulesetID"]] = ruleset_mods
def int_to_mods(mods: int) -> list[APIMod]:
@@ -54,3 +98,10 @@ def int_to_mods(mods: int) -> list[APIMod]:
if mods & (1 << 9):
mod_list.remove(LEGACY_MOD_TO_API_MOD[(1 << 6)])
return mod_list
def mods_to_int(mods: list[APIMod]) -> int:
sum_ = 0
for mod in mods:
sum_ |= API_MOD_TO_LEGACY.get(mod["acronym"], 0)
return sum_

View File

@@ -1,7 +1,11 @@
from __future__ import annotations
from enum import Enum, IntEnum
from typing import Literal, TypedDict
from .mods import API_MODS, APIMod, init_mods
from pydantic import BaseModel, Field, ValidationInfo, field_validator
import rosu_pp_py as rosu
@@ -30,40 +34,141 @@ INT_TO_MODE = {v: k for k, v in MODE_TO_INT.items()}
class Rank(str, Enum):
X = "ss"
XH = "ssh"
S = "s"
SH = "sh"
A = "a"
B = "b"
C = "c"
D = "d"
F = "f"
X = "X"
XH = "XH"
S = "S"
SH = "SH"
A = "A"
B = "B"
C = "C"
D = "D"
F = "F"
# https://github.com/ppy/osu/blob/master/osu.Game/Rulesets/Scoring/HitResult.cs
class HitResult(IntEnum):
PERFECT = 0 # [Order(0)]
GREAT = 1 # [Order(1)]
GOOD = 2 # [Order(2)]
OK = 3 # [Order(3)]
MEH = 4 # [Order(4)]
MISS = 5 # [Order(5)]
class HitResult(str, Enum):
PERFECT = "perfect" # [Order(0)]
GREAT = "great" # [Order(1)]
GOOD = "good" # [Order(2)]
OK = "ok" # [Order(3)]
MEH = "meh" # [Order(4)]
MISS = "miss" # [Order(5)]
LARGE_TICK_HIT = 6 # [Order(6)]
SMALL_TICK_HIT = 7 # [Order(7)]
SLIDER_TAIL_HIT = 8 # [Order(8)]
LARGE_TICK_HIT = "large_tick_hit" # [Order(6)]
SMALL_TICK_HIT = "small_tick_hit" # [Order(7)]
SLIDER_TAIL_HIT = "slider_tail_hit" # [Order(8)]
LARGE_BONUS = 9 # [Order(9)]
SMALL_BONUS = 10 # [Order(10)]
LARGE_BONUS = "large_bonus" # [Order(9)]
SMALL_BONUS = "small_bonus" # [Order(10)]
LARGE_TICK_MISS = 11 # [Order(11)]
SMALL_TICK_MISS = 12 # [Order(12)]
LARGE_TICK_MISS = "large_tick_miss" # [Order(11)]
SMALL_TICK_MISS = "small_tick_miss" # [Order(12)]
IGNORE_HIT = 13 # [Order(13)]
IGNORE_MISS = 14 # [Order(14)]
IGNORE_HIT = "ignore_hit" # [Order(13)]
IGNORE_MISS = "ignore_miss" # [Order(14)]
NONE = 15 # [Order(15)]
COMBO_BREAK = 16 # [Order(16)]
NONE = "none" # [Order(15)]
COMBO_BREAK = "combo_break" # [Order(16)]
LEGACY_COMBO_INCREASE = 99 # [Order(99)] @deprecated
LEGACY_COMBO_INCREASE = "legacy_combo_increase" # [Order(99)] @deprecated
def is_hit(self) -> bool:
return self not in (
HitResult.NONE,
HitResult.IGNORE_MISS,
HitResult.COMBO_BREAK,
HitResult.LARGE_TICK_MISS,
HitResult.SMALL_TICK_MISS,
HitResult.MISS,
)
class HitResultInt(IntEnum):
PERFECT = 0
GREAT = 1
GOOD = 2
OK = 3
MEH = 4
MISS = 5
LARGE_TICK_HIT = 6
SMALL_TICK_HIT = 7
SLIDER_TAIL_HIT = 8
LARGE_BONUS = 9
SMALL_BONUS = 10
LARGE_TICK_MISS = 11
SMALL_TICK_MISS = 12
IGNORE_HIT = 13
IGNORE_MISS = 14
NONE = 15
COMBO_BREAK = 16
LEGACY_COMBO_INCREASE = 99
def is_hit(self) -> bool:
return self not in (
HitResultInt.NONE,
HitResultInt.IGNORE_MISS,
HitResultInt.COMBO_BREAK,
HitResultInt.LARGE_TICK_MISS,
HitResultInt.SMALL_TICK_MISS,
HitResultInt.MISS,
)
class LeaderboardType(Enum):
GLOBAL = "global"
FRIENDS = "friends"
COUNTRY = "country"
TEAM = "team"
ScoreStatistics = dict[HitResult, int]
ScoreStatisticsInt = dict[HitResultInt, int]
class SoloScoreSubmissionInfo(BaseModel):
rank: Rank
total_score: int = Field(ge=0, le=2**31 - 1)
total_score_without_mods: int = Field(ge=0, le=2**31 - 1)
accuracy: float = Field(ge=0, le=1)
pp: float = Field(default=0, ge=0, le=2**31 - 1)
max_combo: int = 0
ruleset_id: Literal[0, 1, 2, 3]
passed: bool = False
mods: list[APIMod] = Field(default_factory=list)
statistics: ScoreStatistics = Field(default_factory=dict)
maximum_statistics: ScoreStatistics = Field(default_factory=dict)
@field_validator("mods", mode="after")
@classmethod
def validate_mods(cls, mods: list[APIMod], info: ValidationInfo):
if not API_MODS:
init_mods()
incompatible_mods = set()
# check incompatible mods
for mod in mods:
if mod["acronym"] in incompatible_mods:
raise ValueError(
f"Mod {mod['acronym']} is incompatible with other mods"
)
setting_mods = API_MODS[info.data["ruleset_id"]].get(mod["acronym"])
if not setting_mods:
raise ValueError(f"Invalid mod: {mod['acronym']}")
incompatible_mods.update(setting_mods["IncompatibleMods"])
return mods
class LegacyReplaySoloScoreInfo(TypedDict):
online_id: int
mods: list[APIMod]
statistics: ScoreStatisticsInt
maximum_statistics: ScoreStatisticsInt
client_version: str
rank: Rank
user_id: int
total_score_without_mods: int

View File

@@ -1,11 +1,42 @@
from __future__ import annotations
from typing import Any
import datetime
from typing import Any, get_origin
from pydantic import BaseModel, Field, model_validator
import msgpack
from pydantic import (
BaseModel,
ConfigDict,
Field,
TypeAdapter,
model_serializer,
model_validator,
)
def serialize_to_list(value: BaseModel) -> list[Any]:
data = []
for field, info in value.__class__.model_fields.items():
v = getattr(value, field)
anno = get_origin(info.annotation)
if anno and issubclass(anno, BaseModel):
data.append(serialize_to_list(v))
elif anno and issubclass(anno, list):
data.append(
TypeAdapter(
info.annotation,
).dump_python(v)
)
elif isinstance(v, datetime.datetime):
data.append([msgpack.ext.Timestamp.from_datetime(v), 0])
else:
data.append(v)
return data
class MessagePackArrayModel(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
@model_validator(mode="before")
@classmethod
def unpack(cls, v: Any) -> Any:
@@ -16,11 +47,15 @@ class MessagePackArrayModel(BaseModel):
return dict(zip(fields, v))
return v
@model_serializer
def serialize(self) -> list[Any]:
return serialize_to_list(self)
class Transport(BaseModel):
transport: str
transfer_formats: list[str] = Field(
default_factory=lambda: ["Binary"], alias="transferFormats"
default_factory=lambda: ["Binary", "Text"], alias="transferFormats"
)
@@ -29,3 +64,8 @@ class NegotiateResponse(BaseModel):
connectionToken: str
negotiateVersion: int = 1
availableTransports: list[Transport]
class UserState(BaseModel):
connection_id: str
connection_token: str

View File

@@ -4,18 +4,22 @@ import datetime
from enum import IntEnum
from typing import Any
from app.models.beatmap import BeatmapRankStatus
from .score import (
HitResult,
ScoreStatisticsInt,
)
from .signalr import MessagePackArrayModel
from .signalr import MessagePackArrayModel, UserState
import msgpack
from pydantic import Field, field_validator
from pydantic import BaseModel, Field, field_validator
class APIMod(MessagePackArrayModel):
acronym: str
settings: dict[str, Any] = Field(default_factory=dict)
settings: dict[str, Any] | list = Field(
default_factory=dict
) # FIXME: with settings
class SpectatedUserState(IntEnum):
@@ -32,7 +36,7 @@ class SpectatorState(MessagePackArrayModel):
ruleset_id: int | None = None # 0,1,2,3
mods: list[APIMod] = Field(default_factory=list)
state: SpectatedUserState
maximum_statistics: dict[HitResult, int] = Field(default_factory=dict)
maximum_statistics: ScoreStatisticsInt = Field(default_factory=dict)
def __eq__(self, other: object) -> bool:
if not isinstance(other, SpectatorState):
@@ -58,7 +62,7 @@ class FrameHeader(MessagePackArrayModel):
acc: float
combo: int
max_combo: int
statistics: dict[HitResult, int] = Field(default_factory=dict)
statistics: ScoreStatisticsInt = Field(default_factory=dict)
score_processor_statistics: ScoreProcessorStatistics
received_time: datetime.datetime
mods: list[APIMod] = Field(default_factory=list)
@@ -79,22 +83,56 @@ class FrameHeader(MessagePackArrayModel):
raise ValueError(f"Cannot convert {type(v)} to datetime")
class ReplayButtonState(IntEnum):
NONE = 0
LEFT1 = 1
RIGHT1 = 2
LEFT2 = 4
RIGHT2 = 8
SMOKE = 16
# class ReplayButtonState(IntEnum):
# NONE = 0
# LEFT1 = 1
# RIGHT1 = 2
# LEFT2 = 4
# RIGHT2 = 8
# SMOKE = 16
class LegacyReplayFrame(MessagePackArrayModel):
time: int # from ReplayFrame,the parent of LegacyReplayFrame
time: float # from ReplayFrame,the parent of LegacyReplayFrame
x: float | None = None
y: float | None = None
button_state: ReplayButtonState
button_state: int
class FrameDataBundle(MessagePackArrayModel):
header: FrameHeader
frames: list[LegacyReplayFrame]
# Use for server
class APIUser(BaseModel):
id: int
name: str
class ScoreInfo(BaseModel):
mods: list[APIMod]
user: APIUser
ruleset: int
maximum_statistics: ScoreStatisticsInt
id: int | None = None
total_score: int | None = None
acc: float | None = None
max_combo: int | None = None
combo: int | None = None
statistics: ScoreStatisticsInt = Field(default_factory=dict)
class StoreScore(BaseModel):
score_info: ScoreInfo
replay_frames: list[LegacyReplayFrame] = Field(default_factory=list)
class StoreClientState(UserState):
state: SpectatorState | None = None
beatmap_status: BeatmapRankStatus | None = None
checksum: str | None = None
ruleset_id: int | None = None
score_token: int | None = None
watched_user: set[int] = Field(default_factory=set)
score: StoreScore | None = None

8
app/path.py Normal file
View File

@@ -0,0 +1,8 @@
from __future__ import annotations
from pathlib import Path
STATIC_DIR = Path(__file__).parent.parent / "static"
REPLAY_DIR = Path(__file__).parent.parent / "replays"
REPLAY_DIR.mkdir(exist_ok=True)

View File

@@ -1,5 +1,7 @@
from __future__ import annotations
from app.signalr import signalr_router as signalr_router
from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401
beatmap,
beatmapset,
@@ -10,6 +12,5 @@ from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401
from .api_router import router as api_router
from .auth import router as auth_router
from .fetcher import fetcher_router as fetcher_router
from .signalr import signalr_router as signalr_router
__all__ = ["api_router", "auth_router", "fetcher_router", "signalr_router"]

View File

@@ -16,7 +16,10 @@ from app.dependencies.user import get_current_user
from app.fetcher import Fetcher
from app.models.beatmap import BeatmapAttributes
from app.models.mods import APIMod, int_to_mods
from app.models.score import INT_TO_MODE, GameMode
from app.models.score import (
INT_TO_MODE,
GameMode,
)
from app.utils import calculate_beatmap_attribute
from .api_router import router
@@ -31,6 +34,31 @@ from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/beatmaps/lookup", tags=["beatmap"], response_model=BeatmapResp)
async def lookup_beatmap(
id: int | None = Query(default=None, alias="id"),
md5: str | None = Query(default=None, alias="checksum"),
filename: str | None = Query(default=None, alias="filename"),
current_user: DBUser = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
if id is None and md5 is None and filename is None:
raise HTTPException(
status_code=400,
detail="At least one of 'id', 'checksum', or 'filename' must be provided.",
)
try:
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=id, md5=md5)
except HTTPError:
raise HTTPException(status_code=404, detail="Beatmap not found")
if beatmap is None:
raise HTTPException(status_code=404, detail="Beatmap not found")
return BeatmapResp.from_db(beatmap)
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
async def get_beatmap(
bid: int,
@@ -39,7 +67,7 @@ async def get_beatmap(
fetcher: Fetcher = Depends(get_fetcher),
):
try:
beatmap = await Beatmap.get_or_fetch(db, bid, fetcher)
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid)
return BeatmapResp.from_db(beatmap)
except HTTPError:
raise HTTPException(status_code=404, detail="Beatmap not found")
@@ -119,7 +147,7 @@ async def get_beatmap_attributes(
if ruleset_id is not None and ruleset is None:
ruleset = INT_TO_MODE[ruleset_id]
if ruleset is None:
beatmap_db = await Beatmap.get_or_fetch(db, beatmap, fetcher)
beatmap_db = await Beatmap.get_or_fetch(db, fetcher, beatmap)
ruleset = beatmap_db.mode
key = (
f"beatmap:{beatmap}:{ruleset}:"

View File

@@ -1,7 +1,5 @@
from __future__ import annotations
from typing import Literal
from app.database import User as DBUser
from app.database.relationship import Relationship, RelationshipResp, RelationshipType
from app.dependencies.database import get_db
@@ -9,21 +7,23 @@ from app.dependencies.user import get_current_user
from .api_router import router
from fastapi import Depends, HTTPException, Query
from fastapi import Depends, HTTPException, Query, Request
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/{type}", tags=["relationship"], response_model=list[RelationshipResp])
@router.get("/friends", tags=["relationship"], response_model=list[RelationshipResp])
@router.get("/blocks", tags=["relationship"], response_model=list[RelationshipResp])
async def get_relationship(
type: Literal["friends", "blocks"],
request: Request,
current_user: DBUser = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
if type == "friends":
relationship_type = RelationshipType.FOLLOW
else:
relationship_type = RelationshipType.BLOCK
relationship_type = (
RelationshipType.FOLLOW
if request.url.path.endswith("/friends")
else RelationshipType.BLOCK
)
relationships = await db.exec(
select(Relationship).where(
Relationship.user_id == current_user.id,
@@ -33,17 +33,19 @@ async def get_relationship(
return [await RelationshipResp.from_db(db, rel) for rel in relationships]
@router.post("/{type}", tags=["relationship"], response_model=RelationshipResp)
@router.post("/friends", tags=["relationship"], response_model=RelationshipResp)
@router.post("/blocks", tags=["relationship"])
async def add_relationship(
type: Literal["friends", "blocks"],
request: Request,
target: int = Query(),
current_user: DBUser = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
if type == "blocks":
relationship_type = RelationshipType.BLOCK
else:
relationship_type = RelationshipType.FOLLOW
relationship_type = (
RelationshipType.FOLLOW
if request.url.path.endswith("/friends")
else RelationshipType.BLOCK
)
if target == current_user.id:
raise HTTPException(422, "Cannot add relationship to yourself")
relationship = (
@@ -78,18 +80,22 @@ async def add_relationship(
await db.delete(target_relationship)
await db.commit()
await db.refresh(relationship)
return await RelationshipResp.from_db(db, relationship)
if relationship.type == RelationshipType.FOLLOW:
return await RelationshipResp.from_db(db, relationship)
@router.delete("/{type}/{target}", tags=["relationship"])
@router.delete("/friends/{target}", tags=["relationship"])
@router.delete("/blocks/{target}", tags=["relationship"])
async def delete_relationship(
type: Literal["friends", "blocks"],
request: Request,
target: int,
current_user: DBUser = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
relationship_type = (
RelationshipType.BLOCK if type == "blocks" else RelationshipType.FOLLOW
RelationshipType.BLOCK
if "/blocks/" in request.url.path
else RelationshipType.FOLLOW
)
relationship = (
await db.exec(

View File

@@ -1,20 +1,28 @@
from __future__ import annotations
import datetime
from app.database import (
Beatmap,
User as DBUser,
)
from app.database.beatmapset import Beatmapset
from app.database.score import Score, ScoreResp
from app.database.score_token import ScoreToken, ScoreTokenResp
from app.dependencies.database import get_db
from app.dependencies.user import get_current_user
from app.models.score import (
INT_TO_MODE,
GameMode,
HitResult,
Rank,
SoloScoreSubmissionInfo,
)
from .api_router import router
from fastapi import Depends, HTTPException, Query
from fastapi import Depends, Form, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy.orm import joinedload
from sqlmodel import col, select
from sqlmodel import col, select, true
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -29,7 +37,7 @@ class BeatmapScores(BaseModel):
async def get_beatmap_scores(
beatmap: int,
legacy_only: bool = Query(None), # TODO:加入对这个参数的查询
mode: str = Query(None),
mode: GameMode | None = Query(None),
# mods: List[APIMod] = Query(None), # TODO:加入指定MOD的查询
type: str = Query(None),
current_user: DBUser = Depends(get_current_user),
@@ -42,29 +50,28 @@ async def get_beatmap_scores(
all_scores = (
await db.exec(
select(Score).where(Score.beatmap_id == beatmap)
# .where(Score.mods == mods if mods else True)
Score.select_clause_unique(
Score.beatmap_id == beatmap,
col(Score.passed).is_(True),
Score.gamemode == mode if mode is not None else true(),
)
)
).all()
user_score = (
await db.exec(
select(Score)
.options(
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
.selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
)
Score.select_clause_unique(
Score.beatmap_id == beatmap,
Score.user_id == current_user.id,
col(Score.passed).is_(True),
Score.gamemode == mode if mode is not None else true(),
)
.where(Score.beatmap_id == beatmap)
.where(Score.user_id == current_user.id)
)
).first()
return BeatmapScores(
scores=[ScoreResp.from_db(score) for score in all_scores],
userScore=ScoreResp.from_db(user_score) if user_score else None,
scores=[await ScoreResp.from_db(db, score) for score in all_scores],
userScore=await ScoreResp.from_db(db, user_score) if user_score else None,
)
@@ -93,18 +100,13 @@ async def get_user_beatmap_score(
)
user_score = (
await db.exec(
select(Score)
.options(
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
.selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
)
Score.select_clause()
.where(
Score.gamemode == mode if mode is not None else True,
Score.beatmap_id == beatmap,
Score.user_id == user,
)
.where(Score.gamemode == mode if mode is not None else True)
.where(Score.beatmap_id == beatmap)
.where(Score.user_id == user)
.order_by(col(Score.classic_total_score).desc())
.order_by(col(Score.total_score).desc())
)
).first()
@@ -115,7 +117,7 @@ async def get_user_beatmap_score(
else:
return BeatmapUserScore(
position=user_score.position if user_score.position is not None else 0,
score=ScoreResp.from_db(user_score),
score=await ScoreResp.from_db(db, user_score),
)
@@ -138,19 +140,114 @@ async def get_user_all_beatmap_scores(
)
all_user_scores = (
await db.exec(
select(Score)
.options(
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
.selectinload(
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
)
Score.select_clause()
.where(
Score.gamemode == ruleset if ruleset is not None else True,
Score.beatmap_id == beatmap,
Score.user_id == user,
)
.where(Score.gamemode == ruleset if ruleset is not None else True)
.where(Score.beatmap_id == beatmap)
.where(Score.user_id == user)
.order_by(col(Score.classic_total_score).desc())
)
).all()
return [ScoreResp.from_db(score) for score in all_user_scores]
return [await ScoreResp.from_db(db, score) for score in all_user_scores]
@router.post(
"/beatmaps/{beatmap}/solo/scores", tags=["beatmap"], response_model=ScoreTokenResp
)
async def create_solo_score(
beatmap: int,
version_hash: str = Form(""),
beatmap_hash: str = Form(),
ruleset_id: int = Form(..., ge=0, le=3),
current_user: DBUser = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
async with db:
score_token = ScoreToken(
user_id=current_user.id,
beatmap_id=beatmap,
ruleset_id=INT_TO_MODE[ruleset_id],
)
db.add(score_token)
await db.commit()
await db.refresh(score_token)
return ScoreTokenResp.from_db(score_token)
@router.put(
"/beatmaps/{beatmap}/solo/scores/{token}",
tags=["beatmap"],
response_model=ScoreResp,
)
async def submit_solo_score(
beatmap: int,
token: int,
info: SoloScoreSubmissionInfo,
current_user: DBUser = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
if not info.passed:
info.rank = Rank.F
async with db:
score_token = (
await db.exec(
select(ScoreToken)
.options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType]
.where(ScoreToken.id == token, ScoreToken.user_id == current_user.id)
)
).first()
if not score_token or score_token.user_id != current_user.id:
raise HTTPException(status_code=404, detail="Score token not found")
if score_token.score_id:
score = (
await db.exec(
select(Score)
.options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType]
.where(
Score.id == score_token.score_id,
Score.user_id == current_user.id,
)
)
).first()
if not score:
raise HTTPException(status_code=404, detail="Score not found")
else:
score = Score(
accuracy=info.accuracy,
max_combo=info.max_combo,
# maximum_statistics=info.maximum_statistics,
mods=info.mods,
passed=info.passed,
rank=info.rank,
total_score=info.total_score,
total_score_without_mods=info.total_score_without_mods,
beatmap_id=beatmap,
ended_at=datetime.datetime.now(datetime.UTC),
gamemode=INT_TO_MODE[info.ruleset_id],
started_at=score_token.created_at,
user_id=current_user.id,
preserve=info.passed,
map_md5=score_token.beatmap.checksum,
has_replay=False,
pp=info.pp,
type="solo",
n300=info.statistics.get(HitResult.GREAT, 0),
n100=info.statistics.get(HitResult.OK, 0),
n50=info.statistics.get(HitResult.MEH, 0),
nmiss=info.statistics.get(HitResult.MISS, 0),
ngeki=info.statistics.get(HitResult.PERFECT, 0),
nkatu=info.statistics.get(HitResult.GOOD, 0),
)
db.add(score)
await db.commit()
await db.refresh(score)
score_id = score.id
score_token.score_id = score_id
await db.commit()
score = (
await db.exec(Score.select_clause().where(Score.id == score_id))
).first()
assert score is not None
return await ScoreResp.from_db(db, score)

View File

@@ -1,211 +0,0 @@
from __future__ import annotations
import asyncio
import time
from typing import Any
from app.config import settings
from app.router.signalr.exception import InvokeException
from app.router.signalr.packet import (
PacketType,
ResultKind,
encode_varint,
parse_packet,
)
from app.router.signalr.store import ResultStore
from app.router.signalr.utils import get_signature
from fastapi import WebSocket
import msgpack
from pydantic import BaseModel
from starlette.websockets import WebSocketDisconnect
class Client:
def __init__(
self, connection_id: str, connection_token: str, connection: WebSocket
) -> None:
self.connection_id = connection_id
self.connection_token = connection_token
self.connection = connection
self._listen_task: asyncio.Task | None = None
self._ping_task: asyncio.Task | None = None
self._store = ResultStore()
async def send_packet(self, type: PacketType, packet: list[Any]):
packet.insert(0, type.value)
payload = msgpack.packb(packet)
length = encode_varint(len(payload))
await self.connection.send_bytes(length + payload)
async def _ping(self):
while True:
try:
await self.send_packet(PacketType.PING, [])
await asyncio.sleep(settings.SIGNALR_PING_INTERVAL)
except WebSocketDisconnect:
break
except Exception as e:
print(f"Error in ping task for {self.connection_id}: {e}")
break
class Hub:
def __init__(self) -> None:
self.clients: dict[str, Client] = {}
self.waited_clients: dict[str, int] = {}
self.tasks: set[asyncio.Task] = set()
def add_waited_client(self, connection_token: str, timestamp: int) -> None:
self.waited_clients[connection_token] = timestamp
def add_client(
self, connection_id: str, connection_token: str, connection: WebSocket
) -> Client:
if connection_token in self.clients:
raise ValueError(
f"Client with connection token {connection_token} already exists."
)
if connection_token in self.waited_clients:
if (
self.waited_clients[connection_token]
< time.time() - settings.SIGNALR_NEGOTIATE_TIMEOUT
):
raise TimeoutError(f"Connection {connection_id} has waited too long.")
del self.waited_clients[connection_token]
client = Client(connection_id, connection_token, connection)
self.clients[connection_token] = client
task = asyncio.create_task(client._ping())
self.tasks.add(task)
client._ping_task = task
return client
async def remove_client(self, connection_id: str) -> None:
if client := self.clients.get(connection_id):
del self.clients[connection_id]
if client._listen_task:
client._listen_task.cancel()
if client._ping_task:
client._ping_task.cancel()
await client.connection.close()
async def send_packet(self, client: Client, type: PacketType, packet: list[Any]):
await client.send_packet(type, packet)
async def _listen_client(self, client: Client) -> None:
jump = False
while not jump:
try:
message = await client.connection.receive_bytes()
packet_type, packet_data = parse_packet(message)
task = asyncio.create_task(
self._handle_packet(client, packet_type, packet_data)
)
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
except WebSocketDisconnect as e:
if e.code == 1005:
continue
print(
f"Client {client.connection_id} disconnected: {e.code}, {e.reason}"
)
jump = True
except Exception as e:
print(f"Error in client {client.connection_id}: {e}")
jump = True
await self.remove_client(client.connection_id)
async def _handle_packet(
self, client: Client, type: PacketType, packet: list[Any]
) -> None:
match type:
case PacketType.PING:
...
case PacketType.INVOCATION:
invocation_id: str | None = packet[1] # pyright: ignore[reportRedeclaration]
target: str = packet[2]
args: list[Any] | None = packet[3]
if args is None:
args = []
# streams: list[str] | None = packet[4] # TODO: stream support
code = ResultKind.VOID
result = None
try:
result = await self.invoke_method(client, target, args)
if result is not None:
code = ResultKind.HAS_VALUE
except InvokeException as e:
code = ResultKind.ERROR
result = e.message
except Exception as e:
code = ResultKind.ERROR
result = str(e)
packet = [
{}, # header
invocation_id,
code.value,
]
if result is not None:
packet.append(result)
if invocation_id is not None:
await client.send_packet(
PacketType.COMPLETION,
packet,
)
case PacketType.COMPLETION:
invocation_id: str = packet[1]
code: ResultKind = ResultKind(packet[2])
result: Any = packet[3] if len(packet) > 3 else None
client._store.add_result(invocation_id, code, result)
async def invoke_method(self, client: Client, method: str, args: list[Any]) -> Any:
method_ = getattr(self, method, None)
call_params = []
if not method_:
raise InvokeException(f"Method '{method}' not found in hub.")
signature = get_signature(method_)
for name, param in signature.parameters.items():
if name == "self" or param.annotation is Client:
continue
if issubclass(param.annotation, BaseModel):
call_params.append(param.annotation.model_validate(args.pop(0)))
else:
call_params.append(args.pop(0))
return await method_(client, *call_params)
async def call(self, client: Client, method: str, *args: Any) -> Any:
invocation_id = client._store.get_invocation_id()
await client.send_packet(
PacketType.INVOCATION,
[
{}, # header
invocation_id,
method,
list(args),
None, # streams
],
)
r = await client._store.fetch(invocation_id, None)
if r[0] == ResultKind.HAS_VALUE:
return r[1]
if r[0] == ResultKind.ERROR:
raise InvokeException(r[1])
return None
async def call_noblock(self, client: Client, method: str, *args: Any) -> None:
await client.send_packet(
PacketType.INVOCATION,
[
{}, # header
None, # invocation_id
method,
list(args),
None, # streams
],
)
return None
def __contains__(self, item: str) -> bool:
return item in self.clients or item in self.waited_clients

View File

@@ -1,6 +0,0 @@
from __future__ import annotations
from .hub import Hub
class MetadataHub(Hub): ...

View File

@@ -1,15 +0,0 @@
from __future__ import annotations
from app.models.spectator_hub import FrameDataBundle, SpectatorState
from .hub import Client, Hub
class SpectatorHub(Hub):
async def BeginPlaySession(
self, client: Client, score_token: int, state: SpectatorState
) -> None: ...
async def SendFrameData(
self, client: Client, frame_data: FrameDataBundle
) -> None: ...

View File

@@ -1,56 +0,0 @@
from __future__ import annotations
from enum import IntEnum
from typing import Any
import msgpack
SEP = b"\x1e"
class PacketType(IntEnum):
INVOCATION = 1
STREAM_ITEM = 2
COMPLETION = 3
STREAM_INVOCATION = 4
CANCEL_INVOCATION = 5
PING = 6
CLOSE = 7
class ResultKind(IntEnum):
ERROR = 1
VOID = 2
HAS_VALUE = 3
def parse_packet(data: bytes) -> tuple[PacketType, list[Any]]:
length, offset = decode_varint(data)
message_data = data[offset : offset + length]
unpacked = msgpack.unpackb(message_data, raw=False)
return PacketType(unpacked[0]), unpacked[1:]
def encode_varint(value: int) -> bytes:
result = []
while value >= 0x80:
result.append((value & 0x7F) | 0x80)
value >>= 7
result.append(value & 0x7F)
return bytes(result)
def decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]:
result = 0
shift = 0
pos = offset
while pos < len(data):
byte = data[pos]
result |= (byte & 0x7F) << shift
pos += 1
if (byte & 0x80) == 0:
break
shift += 7
return result, pos

296
app/signalr/hub/hub.py Normal file
View File

@@ -0,0 +1,296 @@
from __future__ import annotations
from abc import abstractmethod
import asyncio
import time
import traceback
from typing import Any
from app.config import settings
from app.models.signalr import UserState
from app.signalr.exception import InvokeException
from app.signalr.packet import (
ClosePacket,
CompletionPacket,
InvocationPacket,
Packet,
PingPacket,
Protocol,
)
from app.signalr.store import ResultStore
from app.signalr.utils import get_signature
from fastapi import WebSocket
from pydantic import BaseModel
from starlette.websockets import WebSocketDisconnect
class CloseConnection(Exception):
def __init__(
self,
message: str = "Connection closed",
allow_reconnect: bool = False,
from_client: bool = False,
) -> None:
super().__init__(message)
self.message = message
self.allow_reconnect = allow_reconnect
self.from_client = from_client
class Client:
def __init__(
self,
connection_id: str,
connection_token: str,
connection: WebSocket,
protocol: Protocol,
) -> None:
self.connection_id = connection_id
self.connection_token = connection_token
self.connection = connection
self.procotol = protocol
self._listen_task: asyncio.Task | None = None
self._ping_task: asyncio.Task | None = None
self._store = ResultStore()
def __hash__(self) -> int:
return hash(self.connection_token)
@property
def user_id(self) -> int:
return int(self.connection_id)
async def send_packet(self, packet: Packet):
await self.connection.send_bytes(self.procotol.encode(packet))
async def receive_packets(self) -> list[Packet]:
message = await self.connection.receive()
d = message.get("bytes") or message.get("text", "").encode()
if not d:
return []
return self.procotol.decode(d)
async def _ping(self):
while True:
try:
await self.send_packet(PingPacket())
await asyncio.sleep(settings.SIGNALR_PING_INTERVAL)
except WebSocketDisconnect:
break
except Exception as e:
print(f"Error in ping task for {self.connection_id}: {e}")
break
class Hub[TState: UserState]:
def __init__(self) -> None:
self.clients: dict[str, Client] = {}
self.waited_clients: dict[str, int] = {}
self.tasks: set[asyncio.Task] = set()
self.groups: dict[str, set[Client]] = {}
self.state: dict[int, TState] = {}
def add_waited_client(self, connection_token: str, timestamp: int) -> None:
self.waited_clients[connection_token] = timestamp
def get_client_by_id(self, id: str, default: Any = None) -> Client:
for client in self.clients.values():
if client.connection_id == id:
return client
return default
@abstractmethod
def create_state(self, client: Client) -> TState:
raise NotImplementedError
def get_or_create_state(self, client: Client) -> TState:
if (state := self.state.get(client.user_id)) is not None:
return state
state = self.create_state(client)
self.state[client.user_id] = state
return state
def add_to_group(self, client: Client, group_id: str) -> None:
self.groups.setdefault(group_id, set()).add(client)
def remove_from_group(self, client: Client, group_id: str) -> None:
if group_id in self.groups:
self.groups[group_id].discard(client)
async def add_client(
self,
connection_id: str,
connection_token: str,
protocol: Protocol,
connection: WebSocket,
) -> Client:
if connection_token in self.clients:
raise ValueError(
f"Client with connection token {connection_token} already exists."
)
if connection_token in self.waited_clients:
if (
self.waited_clients[connection_token]
< time.time() - settings.SIGNALR_NEGOTIATE_TIMEOUT
):
raise TimeoutError(f"Connection {connection_id} has waited too long.")
del self.waited_clients[connection_token]
client = Client(connection_id, connection_token, connection, protocol)
self.clients[connection_token] = client
task = asyncio.create_task(client._ping())
self.tasks.add(task)
client._ping_task = task
return client
async def remove_client(self, client: Client) -> None:
del self.clients[client.connection_token]
if client._listen_task:
client._listen_task.cancel()
if client._ping_task:
client._ping_task.cancel()
for group in self.groups.values():
group.discard(client)
await self.clean_state(client, False)
@abstractmethod
async def _clean_state(self, state: TState) -> None:
return
async def clean_state(self, client: Client, disconnected: bool) -> None:
if (state := self.state.get(client.user_id)) is None:
return
if disconnected and client.connection_token != state.connection_token:
return
try:
await self._clean_state(state)
except Exception:
...
async def on_connect(self, client: Client) -> None:
if method := getattr(self, "on_client_connect", None):
await method(client)
async def send_packet(self, client: Client, packet: Packet) -> None:
await client.send_packet(packet)
async def broadcast_call(self, method: str, *args: Any) -> None:
tasks = []
for client in self.clients.values():
tasks.append(self.call_noblock(client, method, *args))
await asyncio.gather(*tasks)
async def broadcast_group_call(
self, group_id: str, method: str, *args: Any
) -> None:
tasks = []
for client in self.groups.get(group_id, []):
tasks.append(self.call_noblock(client, method, *args))
await asyncio.gather(*tasks)
async def _listen_client(self, client: Client) -> None:
try:
while True:
packets = await client.receive_packets()
for packet in packets:
if isinstance(packet, PingPacket):
continue
elif isinstance(packet, ClosePacket):
raise CloseConnection(
packet.error or "Connection closed by client",
packet.allow_reconnect,
True,
)
task = asyncio.create_task(self._handle_packet(client, packet))
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
except WebSocketDisconnect as e:
print(f"Client {client.connection_id} disconnected: {e.code}, {e.reason}")
except RuntimeError as e:
if "disconnect message" in str(e):
print(f"Client {client.connection_id} closed the connection.")
else:
traceback.print_exc()
print(f"RuntimeError in client {client.connection_id}: {e}")
except CloseConnection as e:
if not e.from_client:
await client.send_packet(
ClosePacket(error=e.message, allow_reconnect=e.allow_reconnect)
)
print(f"Client {client.connection_id} closed the connection: {e.message}")
except Exception as e:
traceback.print_exc()
print(f"Error in client {client.connection_id}: {e}")
await self.remove_client(client)
async def _handle_packet(self, client: Client, packet: Packet) -> None:
if isinstance(packet, PingPacket):
return
elif isinstance(packet, InvocationPacket):
args = packet.arguments or []
error = None
result = None
try:
result = await self.invoke_method(client, packet.target, args)
except InvokeException as e:
error = e.message
except Exception as e:
traceback.print_exc()
error = str(e)
if packet.invocation_id is not None:
await client.send_packet(
CompletionPacket(
invocation_id=packet.invocation_id,
error=error,
result=result,
)
)
elif isinstance(packet, CompletionPacket):
client._store.add_result(packet.invocation_id, packet.result, packet.error)
async def invoke_method(self, client: Client, method: str, args: list[Any]) -> Any:
method_ = getattr(self, method, None)
call_params = []
if not method_:
raise InvokeException(f"Method '{method}' not found in hub.")
signature = get_signature(method_)
for name, param in signature.parameters.items():
if name == "self" or param.annotation is Client:
continue
if issubclass(param.annotation, BaseModel):
call_params.append(param.annotation.model_validate(args.pop(0)))
else:
call_params.append(args.pop(0))
return await method_(client, *call_params)
async def call(self, client: Client, method: str, *args: Any) -> Any:
invocation_id = client._store.get_invocation_id()
await client.send_packet(
InvocationPacket(
header={},
invocation_id=invocation_id,
target=method,
arguments=list(args),
stream_ids=None,
)
)
r = await client._store.fetch(invocation_id, None)
if r[1]:
raise InvokeException(r[1])
return r[0]
async def call_noblock(self, client: Client, method: str, *args: Any) -> None:
await client.send_packet(
InvocationPacket(
header={},
invocation_id=None,
target=method,
arguments=list(args),
stream_ids=None,
)
)
return None
def __contains__(self, item: str) -> bool:
return item in self.clients or item in self.waited_clients

152
app/signalr/hub/metadata.py Normal file
View File

@@ -0,0 +1,152 @@
from __future__ import annotations
import asyncio
from collections.abc import Coroutine
from typing import override
from app.database.relationship import Relationship, RelationshipType
from app.dependencies.database import engine
from app.models.metadata_hub import MetadataClientState, OnlineStatus, UserActivity
from .hub import Client, Hub
from pydantic import TypeAdapter
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
ONLINE_PRESENCE_WATCHERS_GROUP = "metadata:online-presence-watchers"
class MetadataHub(Hub[MetadataClientState]):
def __init__(self) -> None:
super().__init__()
@staticmethod
def online_presence_watchers_group() -> str:
return ONLINE_PRESENCE_WATCHERS_GROUP
def broadcast_tasks(
self, user_id: int, store: MetadataClientState | None
) -> set[Coroutine]:
if store is not None and not store.pushable:
return set()
data = store.to_dict() if store else None
return {
self.broadcast_group_call(
self.online_presence_watchers_group(),
"UserPresenceUpdated",
user_id,
data,
),
self.broadcast_group_call(
self.friend_presence_watchers_group(user_id),
"FriendPresenceUpdated",
user_id,
data,
),
}
@staticmethod
def friend_presence_watchers_group(user_id: int):
return f"metadata:friend-presence-watchers:{user_id}"
@override
async def _clean_state(self, state: MetadataClientState) -> None:
if state.pushable:
await asyncio.gather(*self.broadcast_tasks(int(state.connection_id), None))
@override
def create_state(self, client: Client) -> MetadataClientState:
return MetadataClientState(
connection_id=client.connection_id,
connection_token=client.connection_token,
)
async def on_client_connect(self, client: Client) -> None:
user_id = int(client.connection_id)
self.get_or_create_state(client)
async with AsyncSession(engine) as session:
async with session.begin():
friends = (
await session.exec(
select(Relationship.target_id).where(
Relationship.user_id == user_id,
Relationship.type == RelationshipType.FOLLOW,
)
)
).all()
tasks = []
for friend_id in friends:
self.groups.setdefault(
self.friend_presence_watchers_group(friend_id), set()
).add(client)
if (
friend_state := self.state.get(friend_id)
) and friend_state.pushable:
print("Pushed")
tasks.append(
self.broadcast_group_call(
self.friend_presence_watchers_group(friend_id),
"FriendPresenceUpdated",
friend_id,
friend_state.to_dict(),
)
)
await asyncio.gather(*tasks)
async def UpdateStatus(self, client: Client, status: int) -> None:
status_ = OnlineStatus(status)
user_id = int(client.connection_id)
store = self.get_or_create_state(client)
if store.status is not None and store.status == status_:
return
store.status = OnlineStatus(status_)
tasks = self.broadcast_tasks(user_id, store)
tasks.add(
self.call_noblock(
client,
"UserPresenceUpdated",
user_id,
store.to_dict(),
)
)
await asyncio.gather(*tasks)
async def UpdateActivity(self, client: Client, activity_dict: dict | None) -> None:
user_id = int(client.connection_id)
activity = (
TypeAdapter(UserActivity).validate_python(activity_dict)
if activity_dict
else None
)
store = self.get_or_create_state(client)
store.user_activity = activity
tasks = self.broadcast_tasks(user_id, store)
tasks.add(
self.call_noblock(
client,
"UserPresenceUpdated",
user_id,
store.to_dict(),
)
)
await asyncio.gather(*tasks)
async def BeginWatchingUserPresence(self, client: Client) -> None:
await asyncio.gather(
*[
self.call_noblock(
client,
"UserPresenceUpdated",
user_id,
store.to_dict(),
)
for user_id, store in self.state.items()
if store.pushable
]
)
self.add_to_group(client, self.online_presence_watchers_group())
async def EndWatchingUserPresence(self, client: Client) -> None:
self.remove_from_group(client, self.online_presence_watchers_group())

View File

@@ -0,0 +1,357 @@
from __future__ import annotations
import asyncio
import json
import lzma
import struct
import time
from typing import override
from app.database import Beatmap
from app.database.score import Score
from app.database.score_token import ScoreToken
from app.database.user import User
from app.dependencies.database import engine
from app.models.beatmap import BeatmapRankStatus
from app.models.mods import mods_to_int
from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatisticsInt
from app.models.signalr import serialize_to_list
from app.models.spectator_hub import (
APIUser,
FrameDataBundle,
LegacyReplayFrame,
ScoreInfo,
SpectatedUserState,
SpectatorState,
StoreClientState,
StoreScore,
)
from app.path import REPLAY_DIR
from app.utils import unix_timestamp_to_windows
from .hub import Client, Hub
from sqlalchemy.orm import joinedload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
READ_SCORE_TIMEOUT = 30
REPLAY_LATEST_VER = 30000016
def encode_uleb128(num: int) -> bytes | bytearray:
if num == 0:
return b"\x00"
ret = bytearray()
while num != 0:
ret.append(num & 0x7F)
num >>= 7
if num != 0:
ret[-1] |= 0x80
return ret
def encode_string(s: str) -> bytes:
"""Write `s` into bytes (ULEB128 & string)."""
if s:
encoded = s.encode()
ret = b"\x0b" + encode_uleb128(len(encoded)) + encoded
else:
ret = b"\x00"
return ret
def save_replay(
ruleset_id: int,
md5: str,
username: str,
score: Score,
statistics: ScoreStatisticsInt,
maximum_statistics: ScoreStatisticsInt,
frames: list[LegacyReplayFrame],
) -> None:
data = bytearray()
data.extend(struct.pack("<bi", ruleset_id, REPLAY_LATEST_VER))
data.extend(encode_string(md5))
data.extend(encode_string(username))
data.extend(encode_string(f"lazer-{username}-{score.started_at.isoformat()}"))
data.extend(
struct.pack(
"<hhhhhhihbi",
score.n300,
score.n100,
score.n50,
score.ngeki,
score.nkatu,
score.nmiss,
score.total_score,
score.max_combo,
score.is_perfect_combo,
mods_to_int(score.mods),
)
)
data.extend(encode_string("")) # hp graph
data.extend(
struct.pack(
"<q",
unix_timestamp_to_windows(round(score.started_at.timestamp())),
)
)
# write frames
# FIXME: cannot play in stable
frame_strs = []
last_time = 0
for frame in frames:
frame_strs.append(
f"{frame.time - last_time}|{frame.x or 0.0}"
f"|{frame.y or 0.0}|{frame.button_state}"
)
last_time = frame.time
frame_strs.append("-12345|0|0|0")
compressed = lzma.compress(
",".join(frame_strs).encode("ascii"), format=lzma.FORMAT_ALONE
)
data.extend(struct.pack("<i", len(compressed)))
data.extend(compressed)
data.extend(struct.pack("<q", score.id))
assert score.id
score_info = LegacyReplaySoloScoreInfo(
online_id=score.id,
mods=score.mods,
statistics=statistics,
maximum_statistics=maximum_statistics,
client_version="",
rank=score.rank,
user_id=score.user_id,
total_score_without_mods=score.total_score_without_mods,
)
compressed = lzma.compress(
json.dumps(score_info).encode(), format=lzma.FORMAT_ALONE
)
data.extend(struct.pack("<i", len(compressed)))
data.extend(compressed)
replay_path = REPLAY_DIR / f"lazer-{score.type}-{username}-{score.id}.osr"
replay_path.write_bytes(data)
class SpectatorHub(Hub[StoreClientState]):
@staticmethod
def group_id(user_id: int) -> str:
return f"watch:{user_id}"
@override
def create_state(self, client: Client) -> StoreClientState:
return StoreClientState(
connection_id=client.connection_id,
connection_token=client.connection_token,
)
@override
async def _clean_state(self, state: StoreClientState) -> None:
if state.state:
await self._end_session(int(state.connection_id), state.state)
for target in self.waited_clients:
target_client = self.get_client_by_id(target)
if target_client:
await self.call_noblock(
target_client, "UserEndedWatching", int(state.connection_id)
)
async def on_client_connect(self, client: Client) -> None:
tasks = [
self.call_noblock(
client, "UserBeganPlaying", user_id, serialize_to_list(store.state)
)
for user_id, store in self.state.items()
if store.state is not None
]
await asyncio.gather(*tasks)
async def BeginPlaySession(
self, client: Client, score_token: int, state: SpectatorState
) -> None:
user_id = int(client.connection_id)
store = self.get_or_create_state(client)
if store.state is not None:
return
if state.beatmap_id is None or state.ruleset_id is None:
return
async with AsyncSession(engine) as session:
async with session.begin():
beatmap = (
await session.exec(
select(Beatmap).where(Beatmap.id == state.beatmap_id)
)
).first()
if not beatmap:
return
user = (
await session.exec(select(User).where(User.id == user_id))
).first()
if not user:
return
name = user.name
store.state = state
store.beatmap_status = beatmap.beatmap_status
store.checksum = beatmap.checksum
store.ruleset_id = state.ruleset_id
store.score_token = score_token
store.score = StoreScore(
score_info=ScoreInfo(
mods=state.mods,
user=APIUser(id=user_id, name=name),
ruleset=state.ruleset_id,
maximum_statistics=state.maximum_statistics,
)
)
await self.broadcast_group_call(
self.group_id(user_id),
"UserBeganPlaying",
user_id,
serialize_to_list(state),
)
async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None:
user_id = int(client.connection_id)
state = self.get_or_create_state(client)
if not state.score:
return
state.score.score_info.acc = frame_data.header.acc
state.score.score_info.combo = frame_data.header.combo
state.score.score_info.max_combo = frame_data.header.max_combo
state.score.score_info.statistics = frame_data.header.statistics
state.score.score_info.total_score = frame_data.header.total_score
state.score.score_info.mods = frame_data.header.mods
state.score.replay_frames.extend(frame_data.frames)
await self.broadcast_group_call(
self.group_id(user_id),
"UserSentFrames",
user_id,
frame_data.model_dump(),
)
async def EndPlaySession(self, client: Client, state: SpectatorState) -> None:
user_id = int(client.connection_id)
store = self.get_or_create_state(client)
score = store.score
if not score or not store.score_token:
return
assert store.beatmap_status is not None
async def _save_replay():
assert store.checksum is not None
assert store.ruleset_id is not None
assert store.state is not None
assert store.score is not None
async with AsyncSession(engine) as session:
async with session:
start_time = time.time()
score_record = None
while time.time() - start_time < READ_SCORE_TIMEOUT:
sub_query = select(ScoreToken.score_id).where(
ScoreToken.id == store.score_token,
)
result = await session.exec(
select(Score)
.options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType]
.where(
Score.id == sub_query,
Score.user_id == user_id,
)
)
score_record = result.first()
if score_record:
break
if not score_record:
return
if not score_record.passed:
return
score_record.has_replay = True
await session.commit()
await session.refresh(score_record)
save_replay(
ruleset_id=store.ruleset_id,
md5=store.checksum,
username=store.score.score_info.user.name,
score=score_record,
statistics=score.score_info.statistics,
maximum_statistics=score.score_info.maximum_statistics,
frames=score.replay_frames,
)
if (
(
BeatmapRankStatus.PENDING
< store.beatmap_status
<= BeatmapRankStatus.LOVED
)
and any(
k.is_hit() and v > 0 for k, v in score.score_info.statistics.items()
)
and state.state != SpectatedUserState.Failed
):
# save replay
await _save_replay()
store.state = None
store.beatmap_status = None
store.checksum = None
store.ruleset_id = None
store.score_token = None
store.score = None
await self._end_session(user_id, state)
async def _end_session(self, user_id: int, state: SpectatorState) -> None:
if state.state == SpectatedUserState.Playing:
state.state = SpectatedUserState.Quit
await self.broadcast_group_call(
self.group_id(user_id),
"UserFinishedPlaying",
user_id,
serialize_to_list(state) if state else None,
)
async def StartWatchingUser(self, client: Client, target_id: int) -> None:
print(f"StartWatchingUser -> {client.connection_id} {target_id}")
user_id = int(client.connection_id)
target_store = self.get_or_create_state(client)
if target_store.state:
await self.call_noblock(
client,
"UserBeganPlaying",
target_id,
serialize_to_list(target_store.state),
)
store = self.get_or_create_state(client)
store.watched_user.add(target_id)
self.add_to_group(client, self.group_id(target_id))
async with AsyncSession(engine) as session:
async with session.begin():
username = (
await session.exec(select(User.name).where(User.id == user_id))
).first()
if not username:
return
if (target_client := self.get_client_by_id(str(target_id))) is not None:
await self.call_noblock(
target_client, "UserStartedWatching", [[user_id, username]]
)
async def EndWatchingUser(self, client: Client, target_id: int) -> None:
print(f"EndWatchingUser -> {client.connection_id} {target_id}")
user_id = int(client.connection_id)
self.remove_from_group(client, self.group_id(target_id))
store = self.state.get(user_id)
if store:
store.watched_user.discard(target_id)
if (target_client := self.get_client_by_id(str(target_id))) is not None:
await self.call_noblock(target_client, "UserEndedWatching", user_id)

277
app/signalr/packet.py Normal file
View File

@@ -0,0 +1,277 @@
from __future__ import annotations
from dataclasses import dataclass
from enum import IntEnum
import json
from typing import (
Any,
Protocol as TypingProtocol,
)
import msgpack
SEP = b"\x1e"
class PacketType(IntEnum):
INVOCATION = 1
STREAM_ITEM = 2
COMPLETION = 3
STREAM_INVOCATION = 4
CANCEL_INVOCATION = 5
PING = 6
CLOSE = 7
@dataclass(kw_only=True)
class Packet:
type: PacketType
header: dict[str, Any] | None = None
@dataclass(kw_only=True)
class InvocationPacket(Packet):
type: PacketType = PacketType.INVOCATION
invocation_id: str | None
target: str
arguments: list[Any] | None = None
stream_ids: list[str] | None = None
@dataclass(kw_only=True)
class CompletionPacket(Packet):
type: PacketType = PacketType.COMPLETION
invocation_id: str
result: Any
error: str | None = None
@dataclass(kw_only=True)
class PingPacket(Packet):
type: PacketType = PacketType.PING
@dataclass(kw_only=True)
class ClosePacket(Packet):
type: PacketType = PacketType.CLOSE
error: str | None = None
allow_reconnect: bool = False
PACKETS = {
PacketType.INVOCATION: InvocationPacket,
PacketType.COMPLETION: CompletionPacket,
PacketType.PING: PingPacket,
PacketType.CLOSE: ClosePacket,
}
class Protocol(TypingProtocol):
@staticmethod
def decode(input: bytes) -> list[Packet]: ...
@staticmethod
def encode(packet: Packet) -> bytes: ...
class MsgpackProtocol:
@staticmethod
def _encode_varint(value: int) -> bytes:
result = []
while value >= 0x80:
result.append((value & 0x7F) | 0x80)
value >>= 7
result.append(value & 0x7F)
return bytes(result)
@staticmethod
def _decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]:
result = 0
shift = 0
pos = offset
while pos < len(data):
byte = data[pos]
result |= (byte & 0x7F) << shift
pos += 1
if (byte & 0x80) == 0:
break
shift += 7
return result, pos
@staticmethod
def decode(input: bytes) -> list[Packet]:
length, offset = MsgpackProtocol._decode_varint(input)
message_data = input[offset : offset + length]
# FIXME: custom deserializer for APIMod
# https://github.com/ppy/osu/blob/master/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs
unpacked = msgpack.unpackb(
message_data, raw=False, strict_map_key=False, use_list=True
)
packet_type = PacketType(unpacked[0])
if packet_type not in PACKETS:
raise ValueError(f"Unknown packet type: {packet_type}")
match packet_type:
case PacketType.INVOCATION:
return [
InvocationPacket(
header=unpacked[1],
invocation_id=unpacked[2],
target=unpacked[3],
arguments=unpacked[4] if len(unpacked) > 4 else None,
stream_ids=unpacked[5] if len(unpacked) > 5 else None,
)
]
case PacketType.COMPLETION:
result_kind = unpacked[3]
return [
CompletionPacket(
header=unpacked[1],
invocation_id=unpacked[2],
error=unpacked[4] if result_kind == 1 else None,
result=unpacked[5] if result_kind == 3 else None,
)
]
case PacketType.PING:
return [PingPacket()]
case PacketType.CLOSE:
return [
ClosePacket(
error=unpacked[1],
allow_reconnect=unpacked[2] if len(unpacked) > 2 else False,
)
]
raise ValueError(f"Unsupported packet type: {packet_type}")
@staticmethod
def encode(packet: Packet) -> bytes:
payload = [packet.type.value, packet.header or {}]
if isinstance(packet, InvocationPacket):
payload.extend(
[
packet.invocation_id,
packet.target,
]
)
if packet.arguments is not None:
payload.append(packet.arguments)
if packet.stream_ids is not None:
payload.append(packet.stream_ids)
elif isinstance(packet, CompletionPacket):
result_kind = 2
if packet.error:
result_kind = 1
elif packet.result is None:
result_kind = 3
payload.extend(
[
packet.invocation_id,
result_kind,
packet.error or packet.result or None,
]
)
elif isinstance(packet, ClosePacket):
payload.extend(
[
packet.error or "",
packet.allow_reconnect,
]
)
elif isinstance(packet, PingPacket):
payload.pop(-1)
data = msgpack.packb(payload, use_bin_type=True, datetime=True)
return MsgpackProtocol._encode_varint(len(data)) + data
class JSONProtocol:
@staticmethod
def decode(input: bytes) -> list[Packet]:
packets_raw = input.removesuffix(SEP).split(SEP)
packets = []
if len(packets_raw) > 1:
for packet_raw in packets_raw:
packets.extend(JSONProtocol.decode(packet_raw))
return packets
else:
data = json.loads(packets_raw[0])
packet_type = PacketType(data["type"])
if packet_type not in PACKETS:
raise ValueError(f"Unknown packet type: {packet_type}")
match packet_type:
case PacketType.INVOCATION:
return [
InvocationPacket(
header=data.get("header"),
invocation_id=data.get("invocationId"),
target=data["target"],
arguments=data.get("arguments"),
stream_ids=data.get("streamIds"),
)
]
case PacketType.COMPLETION:
return [
CompletionPacket(
header=data.get("header"),
invocation_id=data["invocationId"],
error=data.get("error"),
result=data.get("result"),
)
]
case PacketType.PING:
return [PingPacket()]
case PacketType.CLOSE:
return [
ClosePacket(
error=data.get("error"),
allow_reconnect=data.get("allowReconnect", False),
)
]
raise ValueError(f"Unsupported packet type: {packet_type}")
@staticmethod
def encode(packet: Packet) -> bytes:
payload: dict[str, Any] = {
"type": packet.type.value,
}
if packet.header:
payload["header"] = packet.header
if isinstance(packet, InvocationPacket):
payload.update(
{
"target": packet.target,
}
)
if packet.invocation_id is not None:
payload["invocationId"] = packet.invocation_id
if packet.arguments is not None:
payload["arguments"] = packet.arguments
if packet.stream_ids is not None:
payload["streamIds"] = packet.stream_ids
elif isinstance(packet, CompletionPacket):
payload.update(
{
"invocationId": packet.invocation_id,
}
)
if packet.error is not None:
payload["error"] = packet.error
if packet.result is not None:
payload["result"] = packet.result
elif isinstance(packet, PingPacket):
pass
elif isinstance(packet, ClosePacket):
payload.update(
{
"allowReconnect": packet.allow_reconnect,
}
)
if packet.error is not None:
payload["error"] = packet.error
return json.dumps(payload).encode("utf-8") + SEP
PROTOCOLS: dict[str, Protocol] = {
"json": JSONProtocol,
"messagepack": MsgpackProtocol,
}

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
import json
import time
from typing import Literal
@@ -10,9 +11,9 @@ from app.dependencies import get_current_user
from app.dependencies.database import get_db
from app.dependencies.user import get_current_user_by_token
from app.models.signalr import NegotiateResponse, Transport
from app.router.signalr.packet import SEP
from .hub import Hubs
from .packet import PROTOCOLS, SEP
from fastapi import APIRouter, Depends, Header, Query, WebSocket
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -62,30 +63,41 @@ async def connect(
await websocket.accept()
# handshake
handshake = await websocket.receive_bytes()
handshake_payload = json.loads(handshake[:-1])
handshake = await websocket.receive()
message = handshake.get("bytes") or handshake.get("text")
if not message:
await websocket.close(code=1008)
return
handshake_payload = json.loads(message[:-1])
error = ""
if (protocol := handshake_payload.get("protocol")) != "messagepack" or (
handshake_payload.get("version")
) != 1:
error = f"Requested protocol '{protocol}' is not available."
protocol = handshake_payload.get("protocol", "json")
client = None
try:
client = hub_.add_client(
client = await hub_.add_client(
connection_id=user_id,
connection_token=id,
connection=websocket,
protocol=PROTOCOLS[protocol],
)
except KeyError:
error = f"Protocol '{protocol}' is not supported."
except TimeoutError:
error = f"Connection {id} has waited too long."
except ValueError as e:
error = str(e)
payload = {"error": error} if error else {}
# finish handshake
await websocket.send_bytes(json.dumps(payload).encode() + SEP)
if error or not client:
await websocket.close(code=1008)
return
await hub_.clean_state(client, False)
task = asyncio.create_task(hub_.on_connect(client))
hub_.tasks.add(task)
task.add_done_callback(hub_.tasks.discard)
await hub_._listen_client(client)
try:
await websocket.close()
except Exception:
...

View File

@@ -2,9 +2,7 @@ from __future__ import annotations
import asyncio
import sys
from typing import Any, Literal
from app.router.signalr.packet import ResultKind
from typing import Any
class ResultStore:
@@ -22,21 +20,17 @@ class ResultStore:
return str(s)
def add_result(
self, invocation_id: str, type: ResultKind, result: dict[str, Any] | None
self, invocation_id: str, result: Any, error: str | None = None
) -> None:
if isinstance(invocation_id, str) and invocation_id.isdecimal():
if future := self._futures.get(invocation_id):
future.set_result((type, result))
future.set_result((result, error))
async def fetch(
self,
invocation_id: str,
timeout: float | None, # noqa: ASYNC109
) -> (
tuple[Literal[ResultKind.ERROR], str]
| tuple[Literal[ResultKind.VOID], None]
| tuple[Literal[ResultKind.HAS_VALUE], Any]
):
) -> tuple[Any, str | None]:
future = asyncio.get_event_loop().create_future()
self._futures[invocation_id] = future
try:

View File

@@ -2,24 +2,20 @@ from __future__ import annotations
from collections.abc import Callable
import inspect
import sys
from typing import Any, ForwardRef, cast
# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L61-L75
if sys.version_info < (3, 12, 4):
# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L56-L66
def evaluate_forwardref(
type_: ForwardRef,
globalns: Any,
localns: Any,
) -> Any:
# Even though it is the right signature for python 3.9,
# mypy complains with
# `error: Too many arguments for "_evaluate" of
# "ForwardRef"` hence the cast...
return cast(Any, type_)._evaluate(
globalns,
localns,
set(),
)
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
return cast(Any, type_)._evaluate(globalns, localns, recursive_guard=set())
else:
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
return cast(Any, type_)._evaluate(
globalns, localns, type_params=(), recursive_guard=set()
)
def get_annotation(param: inspect.Parameter, globalns: dict[str, Any]) -> Any:

View File

@@ -28,6 +28,11 @@ from app.models.user import (
import rosu_pp_py as rosu
def unix_timestamp_to_windows(timestamp: int) -> int:
"""Convert a Unix timestamp to a Windows timestamp."""
return (timestamp + 62135596800) * 10_000_000
async def convert_db_user_to_api_user(db_user: DBUser, ruleset: str = "osu") -> User:
"""将数据库用户模型转换为API用户模型使用 Lazer 表)"""
@@ -205,7 +210,7 @@ async def convert_db_user_to_api_user(db_user: DBUser, ruleset: str = "osu") ->
# 转换团队信息
team = None
if db_user.team_membership:
team_member = db_user.team_membership[0] # 假设用户只属于一个团队
team_member = db_user.team_membership # 假设用户只属于一个团队
team = team_member.team
# 创建用户对象

View File

@@ -77,7 +77,7 @@ mark-parentheses = false
keep-runtime-typing = true
[tool.pyright]
pythonVersion = "3.11"
pythonVersion = "3.12"
pythonPlatform = "All"
typeCheckingMode = "standard"

5
static/README.md Normal file
View File

@@ -0,0 +1,5 @@
# 静态文件
- `mods.json`: 包含了游戏中的所有可用mod的详细信息。
- Origin: https://github.com/ppy/osu-web/blob/master/database/mods.json
- Version: 2025/6/10 `b68c920b1db3d443b9302fdc3f86010c875fe380`

3656
static/mods.json Normal file

File diff suppressed because it is too large Load Diff