Merge branch 'main' into feat/multiplayer-api
This commit is contained in:
@@ -10,6 +10,7 @@ from app.database import (
|
||||
OAuthToken,
|
||||
User as DBUser,
|
||||
)
|
||||
from app.log import logger
|
||||
|
||||
import bcrypt
|
||||
from jose import JWTError, jwt
|
||||
@@ -47,8 +48,8 @@ def verify_password_legacy(plain_password: str, bcrypt_hash: str) -> bool:
|
||||
bcrypt_cache[bcrypt_hash] = pw_md5
|
||||
|
||||
return is_valid
|
||||
except Exception as e:
|
||||
print(f"Password verification error: {e}")
|
||||
except Exception:
|
||||
logger.exception("Password verification error")
|
||||
return False
|
||||
|
||||
|
||||
@@ -104,8 +105,8 @@ async def authenticate_user_legacy(
|
||||
# 缓存验证结果
|
||||
bcrypt_cache[user.pw_bcrypt] = pw_md5.encode()
|
||||
return user
|
||||
except Exception as e:
|
||||
print(f"Authentication error for user {name}: {e}")
|
||||
except Exception:
|
||||
logger.exception(f"Authentication error for user {name}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ class Settings:
|
||||
|
||||
# SignalR 设置
|
||||
SIGNALR_NEGOTIATE_TIMEOUT: int = int(os.getenv("SIGNALR_NEGOTIATE_TIMEOUT", "30"))
|
||||
SIGNALR_PING_INTERVAL: int = int(os.getenv("SIGNALR_PING_INTERVAL", "120"))
|
||||
SIGNALR_PING_INTERVAL: int = int(os.getenv("SIGNALR_PING_INTERVAL", "15"))
|
||||
|
||||
# Fetcher 设置
|
||||
FETCHER_CLIENT_ID: str = os.getenv("FETCHER_CLIENT_ID", "")
|
||||
@@ -44,5 +44,8 @@ class Settings:
|
||||
"FETCHER_CALLBACK_URL", "http://localhost:8000/fetcher/callback"
|
||||
)
|
||||
|
||||
# 日志设置
|
||||
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO").upper()
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@@ -9,6 +9,13 @@ from .beatmapset import (
|
||||
)
|
||||
from .legacy import LegacyOAuthToken, LegacyUserStatistics
|
||||
from .relationship import Relationship, RelationshipResp, RelationshipType
|
||||
from .score import (
|
||||
Score,
|
||||
ScoreBase,
|
||||
ScoreResp,
|
||||
ScoreStatistics,
|
||||
)
|
||||
from .score_token import ScoreToken, ScoreTokenResp
|
||||
from .team import Team, TeamMember
|
||||
from .user import (
|
||||
DailyChallengeStats,
|
||||
@@ -57,6 +64,12 @@ __all__ = [
|
||||
"Relationship",
|
||||
"RelationshipResp",
|
||||
"RelationshipType",
|
||||
"Score",
|
||||
"ScoreBase",
|
||||
"ScoreResp",
|
||||
"ScoreStatistics",
|
||||
"ScoreToken",
|
||||
"ScoreTokenResp",
|
||||
"Team",
|
||||
"TeamMember",
|
||||
"User",
|
||||
|
||||
@@ -2,7 +2,7 @@ from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Column, DateTime
|
||||
from sqlmodel import Field, Relationship, SQLModel
|
||||
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
@@ -12,7 +12,9 @@ class OAuthToken(SQLModel, table=True):
|
||||
__tablename__ = "oauth_tokens" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(foreign_key="users.id")
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("users.id"), index=True)
|
||||
)
|
||||
access_token: str = Field(max_length=500, unique=True)
|
||||
refresh_token: str = Field(max_length=500, unique=True)
|
||||
token_type: str = Field(default="Bearer", max_length=20)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.fetcher import Fetcher
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.models.score import MODE_TO_INT, GameMode
|
||||
|
||||
@@ -11,6 +11,9 @@ from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import VARCHAR, Field, Relationship, SQLModel, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.fetcher import Fetcher
|
||||
|
||||
|
||||
class BeatmapOwner(SQLModel):
|
||||
id: int
|
||||
@@ -65,6 +68,10 @@ class Beatmap(BeatmapBase, table=True):
|
||||
# optional
|
||||
beatmapset: Beatmapset = Relationship(back_populates="beatmaps")
|
||||
|
||||
@property
|
||||
def can_ranked(self) -> bool:
|
||||
return self.beatmap_status > BeatmapRankStatus.PENDING
|
||||
|
||||
@classmethod
|
||||
async def from_resp(cls, session: AsyncSession, resp: "BeatmapResp") -> "Beatmap":
|
||||
d = resp.model_dump()
|
||||
@@ -79,7 +86,16 @@ class Beatmap(BeatmapBase, table=True):
|
||||
)
|
||||
session.add(beatmap)
|
||||
await session.commit()
|
||||
await session.refresh(beatmap)
|
||||
beatmap = (
|
||||
await session.exec(
|
||||
select(Beatmap)
|
||||
.options(
|
||||
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
.where(Beatmap.id == resp.id)
|
||||
)
|
||||
).first()
|
||||
assert beatmap is not None, "Beatmap should not be None after commit"
|
||||
return beatmap
|
||||
|
||||
@classmethod
|
||||
@@ -107,19 +123,25 @@ class Beatmap(BeatmapBase, table=True):
|
||||
|
||||
@classmethod
|
||||
async def get_or_fetch(
|
||||
cls, session: AsyncSession, bid: int, fetcher: Fetcher
|
||||
cls,
|
||||
session: AsyncSession,
|
||||
fetcher: "Fetcher",
|
||||
bid: int | None = None,
|
||||
md5: str | None = None,
|
||||
) -> "Beatmap":
|
||||
beatmap = (
|
||||
await session.exec(
|
||||
select(Beatmap)
|
||||
.where(Beatmap.id == bid)
|
||||
.where(
|
||||
Beatmap.id == bid if bid is not None else Beatmap.checksum == md5
|
||||
)
|
||||
.options(
|
||||
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not beatmap:
|
||||
resp = await fetcher.get_beatmap(bid)
|
||||
resp = await fetcher.get_beatmap(bid, md5)
|
||||
r = await session.exec(
|
||||
select(Beatmapset.id).where(Beatmapset.id == resp.beatmapset_id)
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@ from datetime import datetime
|
||||
from typing import TYPE_CHECKING, TypedDict, cast
|
||||
|
||||
from app.models.beatmap import BeatmapRankStatus, Genre, Language
|
||||
from app.models.score import GameMode
|
||||
|
||||
from pydantic import BaseModel, model_serializer
|
||||
from sqlalchemy import DECIMAL, JSON, Column, DateTime, Text
|
||||
@@ -68,7 +69,7 @@ class BeatmapNomination(TypedDict):
|
||||
beatmapset_id: int
|
||||
reset: bool
|
||||
user_id: int
|
||||
rulesets: list[str] | None
|
||||
rulesets: list[GameMode] | None
|
||||
|
||||
|
||||
class BeatmapDescription(SQLModel):
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import JSON, Column, DateTime
|
||||
from sqlalchemy.orm import Mapped
|
||||
from sqlmodel import Field, Relationship, SQLModel
|
||||
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
@@ -16,7 +16,7 @@ class LegacyUserStatistics(SQLModel, table=True):
|
||||
__tablename__ = "user_statistics" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(foreign_key="users.id")
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
|
||||
mode: str = Field(max_length=10) # osu, taiko, fruits, mania
|
||||
|
||||
# 基本统计
|
||||
@@ -77,7 +77,7 @@ class LegacyOAuthToken(SQLModel, table=True):
|
||||
__tablename__ = "legacy_oauth_tokens" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="users.id")
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
|
||||
access_token: str = Field(max_length=255, index=True)
|
||||
refresh_token: str = Field(max_length=255, index=True)
|
||||
expires_at: datetime = Field(sa_column=Column(DateTime))
|
||||
|
||||
@@ -4,7 +4,10 @@ from .user import User
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import (
|
||||
BigInteger,
|
||||
Column,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship as SQLRelationship,
|
||||
SQLModel,
|
||||
select,
|
||||
@@ -20,10 +23,22 @@ class RelationshipType(str, Enum):
|
||||
class Relationship(SQLModel, table=True):
|
||||
__tablename__ = "relationship" # pyright: ignore[reportAssignmentType]
|
||||
user_id: int = Field(
|
||||
default=None, foreign_key="users.id", primary_key=True, index=True
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
BigInteger,
|
||||
ForeignKey("users.id"),
|
||||
primary_key=True,
|
||||
index=True,
|
||||
),
|
||||
)
|
||||
target_id: int = Field(
|
||||
default=None, foreign_key="users.id", primary_key=True, index=True
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
BigInteger,
|
||||
ForeignKey("users.id"),
|
||||
primary_key=True,
|
||||
index=True,
|
||||
),
|
||||
)
|
||||
type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False)
|
||||
target: "User" = SQLRelationship(
|
||||
|
||||
@@ -2,15 +2,36 @@ from datetime import datetime
|
||||
import math
|
||||
|
||||
from app.database.user import User
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.models.mods import APIMod
|
||||
from app.models.score import MODE_TO_INT, GameMode, Rank
|
||||
from app.models.score import (
|
||||
MODE_TO_INT,
|
||||
GameMode,
|
||||
HitResult,
|
||||
LeaderboardType,
|
||||
Rank,
|
||||
ScoreStatistics,
|
||||
)
|
||||
|
||||
from .beatmap import Beatmap, BeatmapResp
|
||||
from .beatmapset import BeatmapsetResp
|
||||
from .beatmapset import Beatmapset, BeatmapsetResp
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import Column, DateTime
|
||||
from sqlmodel import JSON, BigInteger, Field, Relationship, SQLModel
|
||||
from sqlalchemy import Column, ColumnExpressionArgument, DateTime
|
||||
from sqlalchemy.orm import aliased, joinedload
|
||||
from sqlmodel import (
|
||||
JSON,
|
||||
BigInteger,
|
||||
Field,
|
||||
ForeignKey,
|
||||
Relationship,
|
||||
SQLModel,
|
||||
col,
|
||||
false,
|
||||
func,
|
||||
select,
|
||||
)
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql._expression_select_cls import SelectOfScalar
|
||||
|
||||
|
||||
class ScoreBase(SQLModel):
|
||||
@@ -34,6 +55,9 @@ class ScoreBase(SQLModel):
|
||||
room_id: int | None = Field(default=None) # multiplayer
|
||||
started_at: datetime = Field(sa_column=Column(DateTime))
|
||||
total_score: int = Field(default=0, sa_column=Column(BigInteger))
|
||||
total_score_without_mods: int = Field(
|
||||
default=0, sa_column=Column(BigInteger), exclude=True
|
||||
)
|
||||
type: str
|
||||
|
||||
# optional
|
||||
@@ -41,22 +65,20 @@ class ScoreBase(SQLModel):
|
||||
position: int | None = Field(default=None) # multiplayer
|
||||
|
||||
|
||||
class ScoreStatistics(BaseModel):
|
||||
count_miss: int
|
||||
count_50: int
|
||||
count_100: int
|
||||
count_300: int
|
||||
count_geki: int
|
||||
count_katu: int
|
||||
count_large_tick_miss: int | None = None
|
||||
count_slider_tail_hit: int | None = None
|
||||
|
||||
|
||||
class Score(ScoreBase, table=True):
|
||||
__tablename__ = "scores" # pyright: ignore[reportAssignmentType]
|
||||
id: int = Field(primary_key=True)
|
||||
id: int | None = Field(
|
||||
default=None, sa_column=Column(BigInteger, autoincrement=True, primary_key=True)
|
||||
)
|
||||
beatmap_id: int = Field(index=True, foreign_key="beatmaps.id")
|
||||
user_id: int = Field(foreign_key="users.id", index=True)
|
||||
user_id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
BigInteger,
|
||||
ForeignKey("users.id"),
|
||||
index=True,
|
||||
),
|
||||
)
|
||||
# ScoreStatistics
|
||||
n300: int = Field(exclude=True)
|
||||
n100: int = Field(exclude=True)
|
||||
@@ -72,9 +94,51 @@ class Score(ScoreBase, table=True):
|
||||
beatmap: "Beatmap" = Relationship()
|
||||
user: "User" = Relationship()
|
||||
|
||||
@property
|
||||
def is_perfect_combo(self) -> bool:
|
||||
return self.max_combo == self.beatmap.max_combo
|
||||
|
||||
@staticmethod
|
||||
def select_clause() -> SelectOfScalar["Score"]:
|
||||
return select(Score).options(
|
||||
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
|
||||
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
|
||||
.selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
),
|
||||
joinedload(Score.user).joinedload(User.lazer_profile), # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def select_clause_unique(
|
||||
*where_clauses: ColumnExpressionArgument[bool] | bool,
|
||||
) -> SelectOfScalar["Score"]:
|
||||
rownum = (
|
||||
func.row_number()
|
||||
.over(
|
||||
partition_by=col(Score.user_id), order_by=col(Score.total_score).desc()
|
||||
)
|
||||
.label("rn")
|
||||
)
|
||||
subq = select(Score, rownum).where(*where_clauses).subquery()
|
||||
best = aliased(Score, subq, adapt_on_names=True)
|
||||
return (
|
||||
select(best)
|
||||
.where(subq.c.rn == 1)
|
||||
.options(
|
||||
joinedload(best.beatmap) # pyright: ignore[reportArgumentType]
|
||||
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
|
||||
.selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
),
|
||||
joinedload(best.user).joinedload(User.lazer_profile), # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class ScoreResp(ScoreBase):
|
||||
id: int
|
||||
user_id: int
|
||||
is_perfect_combo: bool = False
|
||||
legacy_perfect: bool = False
|
||||
legacy_total_score: int = 0 # FIXME
|
||||
@@ -85,10 +149,13 @@ class ScoreResp(ScoreBase):
|
||||
beatmapset: BeatmapsetResp | None = None
|
||||
# FIXME: user: APIUser | None = None
|
||||
statistics: ScoreStatistics | None = None
|
||||
rank_global: int | None = None
|
||||
rank_country: int | None = None
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, score: Score) -> "ScoreResp":
|
||||
async def from_db(cls, session: AsyncSession, score: Score) -> "ScoreResp":
|
||||
s = cls.model_validate(score.model_dump())
|
||||
assert score.id
|
||||
s.beatmap = BeatmapResp.from_db(score.beatmap)
|
||||
s.beatmapset = BeatmapsetResp.from_db(score.beatmap.beatmapset)
|
||||
s.is_perfect_combo = s.max_combo == s.beatmap.max_combo
|
||||
@@ -97,14 +164,220 @@ class ScoreResp(ScoreBase):
|
||||
if score.best_id:
|
||||
# https://osu.ppy.sh/wiki/Performance_points/Weighting_system
|
||||
s.weight = math.pow(0.95, score.best_id)
|
||||
s.statistics = ScoreStatistics(
|
||||
count_miss=score.nmiss,
|
||||
count_50=score.n50,
|
||||
count_100=score.n100,
|
||||
count_300=score.n300,
|
||||
count_geki=score.ngeki,
|
||||
count_katu=score.nkatu,
|
||||
count_large_tick_miss=score.nlarge_tick_miss,
|
||||
count_slider_tail_hit=score.nslider_tail_hit,
|
||||
s.statistics = {
|
||||
HitResult.MISS: score.nmiss,
|
||||
HitResult.MEH: score.n50,
|
||||
HitResult.OK: score.n100,
|
||||
HitResult.GREAT: score.n300,
|
||||
HitResult.PERFECT: score.ngeki,
|
||||
HitResult.GOOD: score.nkatu,
|
||||
}
|
||||
if score.nlarge_tick_miss is not None:
|
||||
s.statistics[HitResult.LARGE_TICK_MISS] = score.nlarge_tick_miss
|
||||
if score.nslider_tail_hit is not None:
|
||||
s.statistics[HitResult.SLIDER_TAIL_HIT] = score.nslider_tail_hit
|
||||
# s.user = await convert_db_user_to_api_user(score.user)
|
||||
s.rank_global = (
|
||||
await get_score_position_by_id(
|
||||
session,
|
||||
score.map_md5,
|
||||
score.id,
|
||||
mode=score.gamemode,
|
||||
user=score.user,
|
||||
)
|
||||
or None
|
||||
)
|
||||
s.rank_country = (
|
||||
await get_score_position_by_id(
|
||||
session,
|
||||
score.map_md5,
|
||||
score.id,
|
||||
score.gamemode,
|
||||
score.user,
|
||||
)
|
||||
or None
|
||||
)
|
||||
return s
|
||||
|
||||
|
||||
async def get_leaderboard(
|
||||
session: AsyncSession,
|
||||
beatmap_md5: str,
|
||||
mode: GameMode,
|
||||
type: LeaderboardType = LeaderboardType.GLOBAL,
|
||||
mods: list[APIMod] | None = None,
|
||||
user: User | None = None,
|
||||
limit: int = 50,
|
||||
) -> list[Score]:
|
||||
scores = []
|
||||
if type == LeaderboardType.GLOBAL:
|
||||
query = (
|
||||
select(Score)
|
||||
.where(
|
||||
col(Beatmap.beatmap_status).in_(
|
||||
[
|
||||
BeatmapRankStatus.RANKED,
|
||||
BeatmapRankStatus.LOVED,
|
||||
BeatmapRankStatus.QUALIFIED,
|
||||
BeatmapRankStatus.APPROVED,
|
||||
]
|
||||
),
|
||||
Score.map_md5 == beatmap_md5,
|
||||
Score.gamemode == mode,
|
||||
col(Score.passed).is_(True),
|
||||
Score.mods == mods if user and user.is_supporter else false(),
|
||||
)
|
||||
.limit(limit)
|
||||
.order_by(
|
||||
col(Score.total_score).desc(),
|
||||
)
|
||||
)
|
||||
result = await session.exec(query)
|
||||
scores = list[Score](result.all())
|
||||
elif type == LeaderboardType.FRIENDS and user and user.is_supporter:
|
||||
# TODO
|
||||
...
|
||||
elif type == LeaderboardType.TEAM and user and user.team_membership:
|
||||
team_id = user.team_membership.team_id
|
||||
query = (
|
||||
select(Score)
|
||||
.join(Beatmap)
|
||||
.options(joinedload(Score.user)) # pyright: ignore[reportArgumentType]
|
||||
.where(
|
||||
Score.map_md5 == beatmap_md5,
|
||||
Score.gamemode == mode,
|
||||
col(Score.passed).is_(True),
|
||||
col(Score.user.team_membership).is_not(None),
|
||||
Score.user.team_membership.team_id == team_id, # pyright: ignore[reportOptionalMemberAccess]
|
||||
Score.mods == mods if user and user.is_supporter else false(),
|
||||
)
|
||||
.limit(limit)
|
||||
.order_by(
|
||||
col(Score.total_score).desc(),
|
||||
)
|
||||
)
|
||||
result = await session.exec(query)
|
||||
scores = list[Score](result.all())
|
||||
if user:
|
||||
user_score = (
|
||||
await session.exec(
|
||||
select(Score).where(
|
||||
Score.map_md5 == beatmap_md5,
|
||||
Score.gamemode == mode,
|
||||
Score.user_id == user.id,
|
||||
col(Score.passed).is_(True),
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if user_score and user_score not in scores:
|
||||
scores.append(user_score)
|
||||
return scores
|
||||
|
||||
|
||||
async def get_score_position_by_user(
|
||||
session: AsyncSession,
|
||||
beatmap_md5: str,
|
||||
user: User,
|
||||
mode: GameMode,
|
||||
type: LeaderboardType = LeaderboardType.GLOBAL,
|
||||
mods: list[APIMod] | None = None,
|
||||
) -> int:
|
||||
where_clause = [
|
||||
Score.map_md5 == beatmap_md5,
|
||||
Score.gamemode == mode,
|
||||
col(Score.passed).is_(True),
|
||||
col(Beatmap.beatmap_status).in_(
|
||||
[
|
||||
BeatmapRankStatus.RANKED,
|
||||
BeatmapRankStatus.LOVED,
|
||||
BeatmapRankStatus.QUALIFIED,
|
||||
BeatmapRankStatus.APPROVED,
|
||||
]
|
||||
),
|
||||
]
|
||||
if mods and user.is_supporter:
|
||||
where_clause.append(Score.mods == mods)
|
||||
else:
|
||||
where_clause.append(false())
|
||||
if type == LeaderboardType.FRIENDS and user.is_supporter:
|
||||
# TODO
|
||||
...
|
||||
elif type == LeaderboardType.TEAM and user.team_membership:
|
||||
team_id = user.team_membership.team_id
|
||||
where_clause.append(
|
||||
col(Score.user.team_membership).is_not(None),
|
||||
)
|
||||
where_clause.append(
|
||||
Score.user.team_membership.team_id == team_id, # pyright: ignore[reportOptionalMemberAccess]
|
||||
)
|
||||
rownum = (
|
||||
func.row_number()
|
||||
.over(
|
||||
partition_by=Score.map_md5,
|
||||
order_by=col(Score.total_score).desc(),
|
||||
)
|
||||
.label("row_number")
|
||||
)
|
||||
subq = select(Score, rownum).join(Beatmap).where(*where_clause).subquery()
|
||||
stmt = select(subq.c.row_number).where(subq.c.user == user)
|
||||
result = await session.exec(stmt)
|
||||
s = result.one_or_none()
|
||||
return s if s else 0
|
||||
|
||||
|
||||
async def get_score_position_by_id(
|
||||
session: AsyncSession,
|
||||
beatmap_md5: str,
|
||||
score_id: int,
|
||||
mode: GameMode,
|
||||
user: User | None = None,
|
||||
type: LeaderboardType = LeaderboardType.GLOBAL,
|
||||
mods: list[APIMod] | None = None,
|
||||
) -> int:
|
||||
where_clause = [
|
||||
Score.map_md5 == beatmap_md5,
|
||||
Score.id == score_id,
|
||||
Score.gamemode == mode,
|
||||
col(Score.passed).is_(True),
|
||||
col(Beatmap.beatmap_status).in_(
|
||||
[
|
||||
BeatmapRankStatus.RANKED,
|
||||
BeatmapRankStatus.LOVED,
|
||||
BeatmapRankStatus.QUALIFIED,
|
||||
BeatmapRankStatus.APPROVED,
|
||||
]
|
||||
),
|
||||
]
|
||||
if mods and user and user.is_supporter:
|
||||
where_clause.append(Score.mods == mods)
|
||||
elif mods:
|
||||
where_clause.append(false())
|
||||
rownum = (
|
||||
func.row_number()
|
||||
.over(
|
||||
partition_by=[col(Score.user_id), col(Score.map_md5)],
|
||||
order_by=col(Score.total_score).desc(),
|
||||
)
|
||||
.label("rownum")
|
||||
)
|
||||
subq = (
|
||||
select(Score.user_id, Score.id, Score.total_score, rownum)
|
||||
.join(Beatmap)
|
||||
.where(*where_clause)
|
||||
.subquery()
|
||||
)
|
||||
best_scores = aliased(subq)
|
||||
overall_rank = (
|
||||
func.rank().over(order_by=best_scores.c.total_score.desc()).label("global_rank")
|
||||
)
|
||||
final_q = (
|
||||
select(best_scores.c.id, overall_rank)
|
||||
.select_from(best_scores)
|
||||
.where(best_scores.c.rownum == 1)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
stmt = select(final_q.c.global_rank).where(final_q.c.id == score_id)
|
||||
result = await session.exec(stmt)
|
||||
s = result.one_or_none()
|
||||
return s if s else 0
|
||||
|
||||
50
app/database/score_token.py
Normal file
50
app/database/score_token.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from datetime import datetime
|
||||
|
||||
from app.models.score import GameMode
|
||||
|
||||
from .beatmap import Beatmap
|
||||
from .user import User
|
||||
|
||||
from sqlalchemy import Column, DateTime, Index
|
||||
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
|
||||
|
||||
|
||||
class ScoreTokenBase(SQLModel):
|
||||
score_id: int | None = Field(sa_column=Column(BigInteger), default=None)
|
||||
ruleset_id: GameMode
|
||||
playlist_item_id: int | None = Field(default=None) # playlist
|
||||
created_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
updated_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
|
||||
|
||||
class ScoreToken(ScoreTokenBase, table=True):
|
||||
__tablename__ = "score_tokens" # pyright: ignore[reportAssignmentType]
|
||||
__table_args__ = (Index("idx_user_playlist", "user_id", "playlist_item_id"),)
|
||||
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
BigInteger,
|
||||
primary_key=True,
|
||||
index=True,
|
||||
autoincrement=True,
|
||||
),
|
||||
)
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
|
||||
beatmap_id: int = Field(foreign_key="beatmaps.id")
|
||||
user: "User" = Relationship()
|
||||
beatmap: "Beatmap" = Relationship()
|
||||
|
||||
|
||||
class ScoreTokenResp(ScoreTokenBase):
|
||||
id: int
|
||||
user_id: int
|
||||
beatmap_id: int
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, obj: ScoreToken) -> "ScoreTokenResp":
|
||||
return cls.model_validate(obj)
|
||||
@@ -2,8 +2,7 @@ from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Column, DateTime
|
||||
from sqlalchemy.orm import Mapped
|
||||
from sqlmodel import Field, Relationship, SQLModel
|
||||
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
@@ -20,18 +19,18 @@ class Team(SQLModel, table=True):
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
|
||||
members: Mapped[list["TeamMember"]] = Relationship(back_populates="team")
|
||||
members: list["TeamMember"] = Relationship(back_populates="team")
|
||||
|
||||
|
||||
class TeamMember(SQLModel, table=True):
|
||||
__tablename__ = "team_members" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(foreign_key="users.id")
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
|
||||
team_id: int = Field(foreign_key="teams.id")
|
||||
joined_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
)
|
||||
|
||||
user: Mapped["User"] = Relationship(back_populates="team_membership")
|
||||
team: Mapped["Team"] = Relationship(back_populates="members")
|
||||
user: "User" = Relationship(back_populates="team_membership")
|
||||
team: "Team" = Relationship(back_populates="members")
|
||||
|
||||
@@ -7,16 +7,19 @@ from .team import TeamMember
|
||||
|
||||
from sqlalchemy import DECIMAL, JSON, Column, Date, DateTime, Text
|
||||
from sqlalchemy.dialects.mysql import VARCHAR
|
||||
from sqlmodel import BigInteger, Field, Relationship, SQLModel
|
||||
from sqlalchemy.orm import joinedload, selectinload
|
||||
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel, select
|
||||
|
||||
|
||||
class User(SQLModel, table=True):
|
||||
__tablename__ = "users" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
# 主键
|
||||
id: int = Field(default=None, primary_key=True, index=True, nullable=False)
|
||||
id: int = Field(
|
||||
default=None, sa_column=Column(BigInteger, primary_key=True, index=True)
|
||||
)
|
||||
|
||||
# 基本信息(匹配 migrations 中的结构)
|
||||
# 基本信息(匹配 migrations_old 中的结构)
|
||||
name: str = Field(max_length=32, unique=True, index=True) # 用户名
|
||||
safe_name: str = Field(max_length=32, unique=True, index=True) # 安全用户名
|
||||
email: str = Field(max_length=254, unique=True, index=True)
|
||||
@@ -65,6 +68,10 @@ class User(SQLModel, table=True):
|
||||
latest_activity = getattr(self, "latest_activity", 0)
|
||||
return datetime.fromtimestamp(latest_activity) if latest_activity > 0 else None
|
||||
|
||||
@property
|
||||
def is_supporter(self):
|
||||
return self.lazer_profile.is_supporter if self.lazer_profile else False
|
||||
|
||||
# 关联关系
|
||||
lazer_profile: Optional["LazerUserProfile"] = Relationship(back_populates="user")
|
||||
lazer_statistics: list["LazerUserStatistics"] = Relationship(back_populates="user")
|
||||
@@ -76,7 +83,7 @@ class User(SQLModel, table=True):
|
||||
back_populates="user"
|
||||
)
|
||||
statistics: list["LegacyUserStatistics"] = Relationship(back_populates="user")
|
||||
team_membership: list["TeamMember"] = Relationship(back_populates="user")
|
||||
team_membership: Optional["TeamMember"] = Relationship(back_populates="user")
|
||||
daily_challenge_stats: Optional["DailyChallengeStats"] = Relationship(
|
||||
back_populates="user"
|
||||
)
|
||||
@@ -94,6 +101,26 @@ class User(SQLModel, table=True):
|
||||
back_populates="user"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def all_select_clause(cls):
|
||||
return select(cls).options(
|
||||
joinedload(cls.lazer_profile), # pyright: ignore[reportArgumentType]
|
||||
joinedload(cls.lazer_counts), # pyright: ignore[reportArgumentType]
|
||||
joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType]
|
||||
joinedload(cls.avatar), # pyright: ignore[reportArgumentType]
|
||||
selectinload(cls.lazer_statistics), # pyright: ignore[reportArgumentType]
|
||||
selectinload(cls.lazer_achievements), # pyright: ignore[reportArgumentType]
|
||||
selectinload(cls.lazer_profile_sections), # pyright: ignore[reportArgumentType]
|
||||
selectinload(cls.statistics), # pyright: ignore[reportArgumentType]
|
||||
joinedload(cls.team_membership), # pyright: ignore[reportArgumentType]
|
||||
selectinload(cls.rank_history), # pyright: ignore[reportArgumentType]
|
||||
selectinload(cls.active_banners), # pyright: ignore[reportArgumentType]
|
||||
selectinload(cls.lazer_badges), # pyright: ignore[reportArgumentType]
|
||||
selectinload(cls.lazer_monthly_playcounts), # pyright: ignore[reportArgumentType]
|
||||
selectinload(cls.lazer_previous_usernames), # pyright: ignore[reportArgumentType]
|
||||
selectinload(cls.lazer_replays_watched), # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
|
||||
|
||||
# ============================================
|
||||
# Lazer API 专用表模型
|
||||
@@ -103,7 +130,14 @@ class User(SQLModel, table=True):
|
||||
class LazerUserProfile(SQLModel, table=True):
|
||||
__tablename__ = "lazer_user_profiles" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
user_id: int = Field(foreign_key="users.id", primary_key=True)
|
||||
user_id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
BigInteger,
|
||||
ForeignKey("users.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
|
||||
# 基本状态字段
|
||||
is_active: bool = Field(default=True)
|
||||
@@ -159,7 +193,7 @@ class LazerUserProfileSections(SQLModel, table=True):
|
||||
__tablename__ = "lazer_user_profile_sections" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="users.id")
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
|
||||
section_name: str = Field(sa_column=Column(VARCHAR(50)))
|
||||
display_order: int | None = Field(default=None)
|
||||
|
||||
@@ -176,7 +210,14 @@ class LazerUserProfileSections(SQLModel, table=True):
|
||||
class LazerUserCountry(SQLModel, table=True):
|
||||
__tablename__ = "lazer_user_countries" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
user_id: int = Field(foreign_key="users.id", primary_key=True)
|
||||
user_id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
BigInteger,
|
||||
ForeignKey("users.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
code: str = Field(max_length=2)
|
||||
name: str = Field(max_length=100)
|
||||
|
||||
@@ -191,7 +232,14 @@ class LazerUserCountry(SQLModel, table=True):
|
||||
class LazerUserKudosu(SQLModel, table=True):
|
||||
__tablename__ = "lazer_user_kudosu" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
user_id: int = Field(foreign_key="users.id", primary_key=True)
|
||||
user_id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
BigInteger,
|
||||
ForeignKey("users.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
available: int = Field(default=0)
|
||||
total: int = Field(default=0)
|
||||
|
||||
@@ -206,7 +254,14 @@ class LazerUserKudosu(SQLModel, table=True):
|
||||
class LazerUserCounts(SQLModel, table=True):
|
||||
__tablename__ = "lazer_user_counts" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
user_id: int = Field(foreign_key="users.id", primary_key=True)
|
||||
user_id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
BigInteger,
|
||||
ForeignKey("users.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
|
||||
# 统计计数字段
|
||||
beatmap_playcounts_count: int = Field(default=0)
|
||||
@@ -241,7 +296,14 @@ class LazerUserCounts(SQLModel, table=True):
|
||||
class LazerUserStatistics(SQLModel, table=True):
|
||||
__tablename__ = "lazer_user_statistics" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
user_id: int = Field(foreign_key="users.id", primary_key=True)
|
||||
user_id: int = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
BigInteger,
|
||||
ForeignKey("users.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
mode: str = Field(default="osu", max_length=10, primary_key=True)
|
||||
|
||||
# 基本命中统计
|
||||
@@ -302,7 +364,7 @@ class LazerUserBanners(SQLModel, table=True):
|
||||
__tablename__ = "lazer_user_tournament_banners" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="users.id")
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
|
||||
tournament_id: int
|
||||
image_url: str = Field(sa_column=Column(VARCHAR(500)))
|
||||
is_active: bool | None = Field(default=None)
|
||||
@@ -315,7 +377,7 @@ class LazerUserAchievement(SQLModel, table=True):
|
||||
__tablename__ = "lazer_user_achievements" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(foreign_key="users.id")
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
|
||||
achievement_id: int
|
||||
achieved_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, sa_column=Column(DateTime)
|
||||
@@ -328,7 +390,7 @@ class LazerUserBadge(SQLModel, table=True):
|
||||
__tablename__ = "lazer_user_badges" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(foreign_key="users.id")
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
|
||||
badge_id: int
|
||||
awarded_at: datetime | None = Field(default=None, sa_column=Column(DateTime))
|
||||
description: str | None = Field(default=None, sa_column=Column(Text))
|
||||
@@ -349,7 +411,7 @@ class LazerUserMonthlyPlaycounts(SQLModel, table=True):
|
||||
__tablename__ = "lazer_user_monthly_playcounts" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(foreign_key="users.id")
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
|
||||
start_date: datetime = Field(sa_column=Column(Date))
|
||||
play_count: int = Field(default=0)
|
||||
|
||||
@@ -367,7 +429,7 @@ class LazerUserPreviousUsername(SQLModel, table=True):
|
||||
__tablename__ = "lazer_user_previous_usernames" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(foreign_key="users.id")
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
|
||||
username: str = Field(max_length=32)
|
||||
changed_at: datetime = Field(sa_column=Column(DateTime))
|
||||
|
||||
@@ -385,7 +447,7 @@ class LazerUserReplaysWatched(SQLModel, table=True):
|
||||
__tablename__ = "lazer_user_replays_watched" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(foreign_key="users.id")
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
|
||||
start_date: datetime = Field(sa_column=Column(Date))
|
||||
count: int = Field(default=0)
|
||||
|
||||
@@ -410,7 +472,9 @@ class DailyChallengeStats(SQLModel, table=True):
|
||||
__tablename__ = "daily_challenge_stats" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(foreign_key="users.id", unique=True)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("users.id"), unique=True)
|
||||
)
|
||||
|
||||
daily_streak_best: int = Field(default=0)
|
||||
daily_streak_current: int = Field(default=0)
|
||||
@@ -431,7 +495,7 @@ class RankHistory(SQLModel, table=True):
|
||||
__tablename__ = "rank_history" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(foreign_key="users.id")
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
|
||||
mode: str = Field(max_length=10)
|
||||
rank_data: list = Field(sa_column=Column(JSON)) # Array of ranks
|
||||
date_recorded: datetime = Field(
|
||||
@@ -445,7 +509,7 @@ class UserAvatar(SQLModel, table=True):
|
||||
__tablename__ = "user_avatars" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
||||
user_id: int = Field(foreign_key="users.id")
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("users.id")))
|
||||
filename: str = Field(max_length=255)
|
||||
original_filename: str = Field(max_length=255)
|
||||
file_size: int
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
from app.config import settings
|
||||
from app.dependencies.database import get_redis
|
||||
from app.fetcher import Fetcher
|
||||
from app.log import logger
|
||||
|
||||
fetcher: Fetcher | None = None
|
||||
|
||||
@@ -25,5 +26,7 @@ def get_fetcher() -> Fetcher:
|
||||
if refresh_token:
|
||||
fetcher.refresh_token = str(refresh_token)
|
||||
if not fetcher.access_token or not fetcher.refresh_token:
|
||||
print("Login to initialize fetcher:", fetcher.authorize_url)
|
||||
logger.opt(colors=True).info(
|
||||
f"Login to initialize fetcher: <y>{fetcher.authorize_url}</y>"
|
||||
)
|
||||
return fetcher
|
||||
|
||||
@@ -9,8 +9,6 @@ from .database import get_db
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from sqlalchemy.orm import joinedload, selectinload
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
security = HTTPBearer()
|
||||
@@ -35,25 +33,7 @@ async def get_current_user_by_token(token: str, db: AsyncSession) -> DBUser | No
|
||||
return None
|
||||
user = (
|
||||
await db.exec(
|
||||
select(DBUser)
|
||||
.options(
|
||||
joinedload(DBUser.lazer_profile), # pyright: ignore[reportArgumentType]
|
||||
joinedload(DBUser.lazer_counts), # pyright: ignore[reportArgumentType]
|
||||
joinedload(DBUser.daily_challenge_stats), # pyright: ignore[reportArgumentType]
|
||||
joinedload(DBUser.avatar), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.lazer_statistics), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.lazer_achievements), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.lazer_profile_sections), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.statistics), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.team_membership), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.rank_history), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.active_banners), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.lazer_badges), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.lazer_monthly_playcounts), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.lazer_previous_usernames), # pyright: ignore[reportArgumentType]
|
||||
selectinload(DBUser.lazer_replays_watched), # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
.where(DBUser.id == token_record.user_id)
|
||||
DBUser.all_select_clause().where(DBUser.id == token_record.user_id)
|
||||
)
|
||||
).first()
|
||||
return user
|
||||
|
||||
@@ -1,23 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from app.database.beatmap import BeatmapResp
|
||||
from app.log import logger
|
||||
|
||||
from ._base import BaseFetcher
|
||||
|
||||
from httpx import AsyncClient
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.database.beatmap import BeatmapResp
|
||||
|
||||
|
||||
class BeatmapFetcher(BaseFetcher):
|
||||
async def get_beatmap(self, beatmap_id: int) -> "BeatmapResp":
|
||||
from app.database.beatmap import BeatmapResp
|
||||
|
||||
async def get_beatmap(
|
||||
self, beatmap_id: int | None = None, beatmap_checksum: str | None = None
|
||||
) -> BeatmapResp:
|
||||
if beatmap_id:
|
||||
params = {"id": beatmap_id}
|
||||
elif beatmap_checksum:
|
||||
params = {"checksum": beatmap_checksum}
|
||||
else:
|
||||
raise ValueError("Either beatmap_id or beatmap_checksum must be provided.")
|
||||
logger.opt(colors=True).debug(
|
||||
f"<blue>[BeatmapFetcher]</blue> get_beatmap: <y>{params}</y>"
|
||||
)
|
||||
async with AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"https://osu.ppy.sh/api/v2/beatmaps/{beatmap_id}",
|
||||
"https://osu.ppy.sh/api/v2/beatmaps/lookup",
|
||||
headers=self.header,
|
||||
params=params,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return BeatmapResp.model_validate(response.json())
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.database.beatmapset import BeatmapsetResp
|
||||
from app.log import logger
|
||||
|
||||
from ._base import BaseFetcher
|
||||
|
||||
@@ -9,6 +10,9 @@ from httpx import AsyncClient
|
||||
|
||||
class BeatmapsetFetcher(BaseFetcher):
|
||||
async def get_beatmapset(self, beatmap_set_id: int) -> BeatmapsetResp:
|
||||
logger.opt(colors=True).debug(
|
||||
f"<blue>[BeatmapsetFetcher]</blue> get_beatmapset: <y>{beatmap_set_id}</y>"
|
||||
)
|
||||
async with AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"https://osu.ppy.sh/api/v2/beatmapsets/{beatmap_set_id}",
|
||||
|
||||
@@ -3,10 +3,14 @@ from __future__ import annotations
|
||||
from ._base import BaseFetcher
|
||||
|
||||
from httpx import AsyncClient
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class OsuDotDirectFetcher(BaseFetcher):
|
||||
async def get_beatmap_raw(self, beatmap_id: int) -> str:
|
||||
logger.opt(colors=True).debug(
|
||||
f"<blue>[OsuDotDirectFetcher]</blue> get_beatmap_raw: <y>{beatmap_id}</y>"
|
||||
)
|
||||
async with AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"https://osu.direct/api/osu/{beatmap_id}/raw",
|
||||
|
||||
138
app/log.py
Normal file
138
app/log.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import http
|
||||
import inspect
|
||||
import logging
|
||||
import re
|
||||
from sys import stdout
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.config import settings
|
||||
|
||||
import loguru
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from loguru import Logger
|
||||
|
||||
logger: "Logger" = loguru.logger
|
||||
|
||||
|
||||
class InterceptHandler(logging.Handler):
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
# Get corresponding Loguru level if it exists.
|
||||
try:
|
||||
level: str | int = logger.level(record.levelname).name
|
||||
except ValueError:
|
||||
level = record.levelno
|
||||
|
||||
# Find caller from where originated the logged message.
|
||||
frame, depth = inspect.currentframe(), 0
|
||||
while frame:
|
||||
filename = frame.f_code.co_filename
|
||||
is_logging = filename == logging.__file__
|
||||
is_frozen = "importlib" in filename and "_bootstrap" in filename
|
||||
if depth > 0 and not (is_logging or is_frozen):
|
||||
break
|
||||
frame = frame.f_back
|
||||
depth += 1
|
||||
|
||||
message = record.getMessage()
|
||||
|
||||
if record.name == "uvicorn.access":
|
||||
message = self._format_uvicorn_access_log(message)
|
||||
elif record.name == "uvicorn.error":
|
||||
message = self._format_uvicorn_error_log(message)
|
||||
logger.opt(depth=depth, exception=record.exc_info, colors=True).log(
|
||||
level, message
|
||||
)
|
||||
|
||||
def _format_uvicorn_error_log(self, message: str) -> str:
|
||||
websocket_pattern = (
|
||||
r'(\d+\.\d+\.\d+\.\d+:\d+)\s*-\s*"WebSocket\s+([^"]+)"\s+([\w\[\]]+)'
|
||||
)
|
||||
websocket_match = re.search(websocket_pattern, message)
|
||||
|
||||
if websocket_match:
|
||||
ip, path, status = websocket_match.groups()
|
||||
|
||||
colored_ip = f"<cyan>{ip}</cyan>"
|
||||
status_colors = {
|
||||
"[accepted]": "<green>[accepted]</green>",
|
||||
"403": "<red>403 [rejected]</red>",
|
||||
}
|
||||
colored_status = status_colors.get(
|
||||
status.lower(), f"<white>{status}</white>"
|
||||
)
|
||||
return (
|
||||
f'{colored_ip} - "<bold><magenta>WebSocket</magenta> '
|
||||
f'{path}</bold>" '
|
||||
f"{colored_status}"
|
||||
)
|
||||
else:
|
||||
return message
|
||||
|
||||
def _format_uvicorn_access_log(self, message: str) -> str:
|
||||
http_pattern = r'(\d+\.\d+\.\d+\.\d+:\d+)\s*-\s*"(\w+)\s+([^"]+)"\s+(\d+)'
|
||||
|
||||
http_match = re.search(http_pattern, message)
|
||||
if http_match:
|
||||
ip, method, path, status_code = http_match.groups()
|
||||
try:
|
||||
status_phrase = http.HTTPStatus(int(status_code)).phrase
|
||||
except ValueError:
|
||||
status_phrase = ""
|
||||
|
||||
colored_ip = f"<cyan>{ip}</cyan>"
|
||||
method_colors = {
|
||||
"GET": "<green>GET</green>",
|
||||
"POST": "<blue>POST</blue>",
|
||||
"PUT": "<yellow>PUT</yellow>",
|
||||
"DELETE": "<red>DELETE</red>",
|
||||
"PATCH": "<magenta>PATCH</magenta>",
|
||||
"OPTIONS": "<white>OPTIONS</white>",
|
||||
"HEAD": "<white>HEAD</white>",
|
||||
}
|
||||
colored_method = method_colors.get(method, f"<white>{method}</white>")
|
||||
status = int(status_code)
|
||||
status_color = "white"
|
||||
if 200 <= status < 300:
|
||||
status_color = "green"
|
||||
elif 300 <= status < 400:
|
||||
status_color = "yellow"
|
||||
elif 400 <= status < 500:
|
||||
status_color = "red"
|
||||
elif 500 <= status < 600:
|
||||
status_color = "red"
|
||||
|
||||
return (
|
||||
f'{colored_ip} - "<bold>{colored_method} '
|
||||
f'{path}</bold>" '
|
||||
f"<{status_color}>{status_code} {status_phrase}</{status_color}>"
|
||||
)
|
||||
|
||||
return message
|
||||
|
||||
|
||||
logger.remove()
|
||||
logger.add(
|
||||
stdout,
|
||||
colorize=True,
|
||||
format=(
|
||||
"<green>{time:YYYY-MM-DD HH:mm:ss}</green> [<level>{level}</level>] | {message}"
|
||||
),
|
||||
level=settings.LOG_LEVEL,
|
||||
diagnose=settings.DEBUG,
|
||||
)
|
||||
logging.basicConfig(handlers=[InterceptHandler()], level=settings.LOG_LEVEL, force=True)
|
||||
|
||||
uvicorn_loggers = [
|
||||
"uvicorn",
|
||||
"uvicorn.error",
|
||||
"uvicorn.access",
|
||||
"fastapi",
|
||||
]
|
||||
|
||||
for logger_name in uvicorn_loggers:
|
||||
uvicorn_logger = logging.getLogger(logger_name)
|
||||
uvicorn_logger.handlers = [InterceptHandler()]
|
||||
uvicorn_logger.propagate = False
|
||||
152
app/models/metadata_hub.py
Normal file
152
app/models/metadata_hub.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import IntEnum
|
||||
from typing import Any, Literal
|
||||
|
||||
from app.models.signalr import UserState
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class _UserActivity(BaseModel):
|
||||
model_config = ConfigDict(serialize_by_alias=True)
|
||||
type: Literal[
|
||||
"ChoosingBeatmap",
|
||||
"InSoloGame",
|
||||
"WatchingReplay",
|
||||
"SpectatingUser",
|
||||
"SearchingForLobby",
|
||||
"InLobby",
|
||||
"InMultiplayerGame",
|
||||
"SpectatingMultiplayerGame",
|
||||
"InPlaylistGame",
|
||||
"EditingBeatmap",
|
||||
"ModdingBeatmap",
|
||||
"TestingBeatmap",
|
||||
"InDailyChallengeLobby",
|
||||
"PlayingDailyChallenge",
|
||||
] = Field(alias="$dtype")
|
||||
value: Any | None = Field(alias="$value")
|
||||
|
||||
|
||||
class ChoosingBeatmap(_UserActivity):
|
||||
type: Literal["ChoosingBeatmap"] = Field(alias="$dtype")
|
||||
|
||||
|
||||
class InGameValue(BaseModel):
|
||||
beatmap_id: int = Field(alias="BeatmapID")
|
||||
beatmap_display_title: str = Field(alias="BeatmapDisplayTitle")
|
||||
ruleset_id: int = Field(alias="RulesetID")
|
||||
ruleset_playing_verb: str = Field(alias="RulesetPlayingVerb")
|
||||
|
||||
|
||||
class _InGame(_UserActivity):
|
||||
value: InGameValue = Field(alias="$value")
|
||||
|
||||
|
||||
class InSoloGame(_InGame):
|
||||
type: Literal["InSoloGame"] = Field(alias="$dtype")
|
||||
|
||||
|
||||
class InMultiplayerGame(_InGame):
|
||||
type: Literal["InMultiplayerGame"] = Field(alias="$dtype")
|
||||
|
||||
|
||||
class SpectatingMultiplayerGame(_InGame):
|
||||
type: Literal["SpectatingMultiplayerGame"] = Field(alias="$dtype")
|
||||
|
||||
|
||||
class InPlaylistGame(_InGame):
|
||||
type: Literal["InPlaylistGame"] = Field(alias="$dtype")
|
||||
|
||||
|
||||
class EditingBeatmapValue(BaseModel):
|
||||
beatmap_id: int = Field(alias="BeatmapID")
|
||||
beatmap_display_title: str = Field(alias="BeatmapDisplayTitle")
|
||||
|
||||
|
||||
class EditingBeatmap(_UserActivity):
|
||||
type: Literal["EditingBeatmap"] = Field(alias="$dtype")
|
||||
value: EditingBeatmapValue = Field(alias="$value")
|
||||
|
||||
|
||||
class TestingBeatmap(_UserActivity):
|
||||
type: Literal["TestingBeatmap"] = Field(alias="$dtype")
|
||||
|
||||
|
||||
class ModdingBeatmap(_UserActivity):
|
||||
type: Literal["ModdingBeatmap"] = Field(alias="$dtype")
|
||||
|
||||
|
||||
class WatchingReplayValue(BaseModel):
|
||||
score_id: int = Field(alias="ScoreID")
|
||||
player_name: str = Field(alias="PlayerName")
|
||||
beatmap_id: int = Field(alias="BeatmapID")
|
||||
beatmap_display_title: str = Field(alias="BeatmapDisplayTitle")
|
||||
|
||||
|
||||
class WatchingReplay(_UserActivity):
|
||||
type: Literal["WatchingReplay"] = Field(alias="$dtype")
|
||||
value: int | None = Field(alias="$value") # Replay ID
|
||||
|
||||
|
||||
class SpectatingUser(WatchingReplay):
|
||||
type: Literal["SpectatingUser"] = Field(alias="$dtype")
|
||||
|
||||
|
||||
class SearchingForLobby(_UserActivity):
|
||||
type: Literal["SearchingForLobby"] = Field(alias="$dtype")
|
||||
|
||||
|
||||
class InLobbyValue(BaseModel):
|
||||
room_id: int = Field(alias="RoomID")
|
||||
room_name: str = Field(alias="RoomName")
|
||||
|
||||
|
||||
class InLobby(_UserActivity):
|
||||
type: Literal["InLobby"] = "InLobby"
|
||||
|
||||
|
||||
class InDailyChallengeLobby(_UserActivity):
|
||||
type: Literal["InDailyChallengeLobby"] = Field(alias="$dtype")
|
||||
|
||||
|
||||
UserActivity = (
|
||||
ChoosingBeatmap
|
||||
| InSoloGame
|
||||
| WatchingReplay
|
||||
| SpectatingUser
|
||||
| SearchingForLobby
|
||||
| InLobby
|
||||
| InMultiplayerGame
|
||||
| SpectatingMultiplayerGame
|
||||
| InPlaylistGame
|
||||
| EditingBeatmap
|
||||
| ModdingBeatmap
|
||||
| TestingBeatmap
|
||||
| InDailyChallengeLobby
|
||||
)
|
||||
|
||||
|
||||
class MetadataClientState(UserState):
|
||||
user_activity: UserActivity | None = None
|
||||
status: OnlineStatus | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any] | None:
|
||||
if self.status is None or self.status == OnlineStatus.OFFLINE:
|
||||
return None
|
||||
dumped = self.model_dump(by_alias=True, exclude_none=True)
|
||||
return {
|
||||
"Activity": dumped.get("user_activity"),
|
||||
"Status": dumped.get("status"),
|
||||
}
|
||||
|
||||
@property
|
||||
def pushable(self) -> bool:
|
||||
return self.status is not None and self.status != OnlineStatus.OFFLINE
|
||||
|
||||
|
||||
class OnlineStatus(IntEnum):
|
||||
OFFLINE = 0 # 隐身
|
||||
DO_NOT_DISTURB = 1
|
||||
ONLINE = 2
|
||||
@@ -1,47 +1,91 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypedDict
|
||||
import json
|
||||
from typing import Literal, NotRequired, TypedDict
|
||||
|
||||
from app.path import STATIC_DIR
|
||||
|
||||
|
||||
class APIMod(TypedDict):
|
||||
acronym: str
|
||||
settings: dict[str, bool | float | str]
|
||||
settings: NotRequired[dict[str, bool | float | str]]
|
||||
|
||||
|
||||
# https://github.com/ppy/osu-api/wiki#mods
|
||||
LEGACY_MOD_TO_API_MOD = {
|
||||
(1 << 0): APIMod(acronym="NF", settings={}), # No Fail
|
||||
(1 << 1): APIMod(acronym="EZ", settings={}),
|
||||
(1 << 2): APIMod(acronym="TD", settings={}), # Touch Device
|
||||
(1 << 3): APIMod(acronym="HD", settings={}), # Hidden
|
||||
(1 << 4): APIMod(acronym="HR", settings={}), # Hard Rock
|
||||
(1 << 5): APIMod(acronym="SD", settings={}), # Sudden Death
|
||||
(1 << 6): APIMod(acronym="DT", settings={}), # Double Time
|
||||
(1 << 7): APIMod(acronym="RX", settings={}), # Relax
|
||||
(1 << 8): APIMod(acronym="HT", settings={}), # Half Time
|
||||
(1 << 9): APIMod(acronym="NC", settings={}), # Nightcore
|
||||
(1 << 10): APIMod(acronym="FL", settings={}), # Flashlight
|
||||
(1 << 11): APIMod(acronym="AT", settings={}), # Auto Play
|
||||
(1 << 12): APIMod(acronym="SO", settings={}), # Spun Out
|
||||
(1 << 13): APIMod(acronym="AP", settings={}), # Autopilot
|
||||
(1 << 14): APIMod(acronym="PF", settings={}), # Perfect
|
||||
(1 << 15): APIMod(acronym="4K", settings={}), # 4K
|
||||
(1 << 16): APIMod(acronym="5K", settings={}), # 5K
|
||||
(1 << 17): APIMod(acronym="6K", settings={}), # 6K
|
||||
(1 << 18): APIMod(acronym="7K", settings={}), # 7K
|
||||
(1 << 19): APIMod(acronym="8K", settings={}), # 8K
|
||||
(1 << 20): APIMod(acronym="FI", settings={}), # Fade In
|
||||
(1 << 21): APIMod(acronym="RD", settings={}), # Random
|
||||
(1 << 22): APIMod(acronym="CN", settings={}), # Cinema
|
||||
(1 << 23): APIMod(acronym="TP", settings={}), # Target Practice
|
||||
(1 << 24): APIMod(acronym="9K", settings={}), # 9K
|
||||
(1 << 25): APIMod(acronym="CO", settings={}), # Key Co-op
|
||||
(1 << 26): APIMod(acronym="1K", settings={}), # 1K
|
||||
(1 << 27): APIMod(acronym="2K", settings={}), # 2K
|
||||
(1 << 28): APIMod(acronym="3K", settings={}), # 3K
|
||||
(1 << 29): APIMod(acronym="SV2", settings={}), # Score V2
|
||||
(1 << 30): APIMod(acronym="MR", settings={}), # Mirror
|
||||
API_MOD_TO_LEGACY: dict[str, int] = {
|
||||
"NF": 1 << 0, # No Fail
|
||||
"EZ": 1 << 1, # Easy
|
||||
"TD": 1 << 2, # Touch Device
|
||||
"HD": 1 << 3, # Hidden
|
||||
"HR": 1 << 4, # Hard Rock
|
||||
"SD": 1 << 5, # Sudden Death
|
||||
"DT": 1 << 6, # Double Time
|
||||
"RX": 1 << 7, # Relax
|
||||
"HT": 1 << 8, # Half Time
|
||||
"NC": 1 << 9, # Nightcore
|
||||
"FL": 1 << 10, # Flashlight
|
||||
"AT": 1 << 11, # Autoplay
|
||||
"SO": 1 << 12, # Spun Out
|
||||
"AP": 1 << 13, # Auto Pilot
|
||||
"PF": 1 << 14, # Perfect
|
||||
"4K": 1 << 15, # 4K
|
||||
"5K": 1 << 16, # 5K
|
||||
"6K": 1 << 17, # 6K
|
||||
"7K": 1 << 18, # 7K
|
||||
"8K": 1 << 19, # 8K
|
||||
"FI": 1 << 20, # Fade In
|
||||
"RD": 1 << 21, # Random
|
||||
"CN": 1 << 22, # Cinema
|
||||
"TP": 1 << 23, # Target Practice
|
||||
"9K": 1 << 24, # 9K
|
||||
"CO": 1 << 25, # Key Co-op
|
||||
"1K": 1 << 26, # 1K
|
||||
"3K": 1 << 27, # 3K
|
||||
"2K": 1 << 28, # 2K
|
||||
"SV2": 1 << 29, # ScoreV2
|
||||
"MR": 1 << 30, # Mirror
|
||||
}
|
||||
LEGACY_MOD_TO_API_MOD = {}
|
||||
for k, v in API_MOD_TO_LEGACY.items():
|
||||
LEGACY_MOD_TO_API_MOD[v] = APIMod(acronym=k, settings={})
|
||||
API_MOD_TO_LEGACY["NC"] |= API_MOD_TO_LEGACY["DT"]
|
||||
API_MOD_TO_LEGACY["PF"] |= API_MOD_TO_LEGACY["SD"]
|
||||
|
||||
|
||||
# see static/mods.json
|
||||
class Settings(TypedDict):
|
||||
Name: str
|
||||
Type: str
|
||||
Label: str
|
||||
Description: str
|
||||
|
||||
|
||||
class Mod(TypedDict):
|
||||
Acronym: str
|
||||
Name: str
|
||||
Description: str
|
||||
Type: str
|
||||
Settings: list[Settings]
|
||||
IncompatibleMods: list[str]
|
||||
RequiresConfiguration: bool
|
||||
UserPlayable: bool
|
||||
ValidForMultiplayer: bool
|
||||
ValidForFreestyleAsRequiredMod: bool
|
||||
ValidForMultiplayerAsFreeMod: bool
|
||||
AlwaysValidForSubmission: bool
|
||||
|
||||
|
||||
API_MODS: dict[Literal[0, 1, 2, 3], dict[str, Mod]] = {}
|
||||
|
||||
|
||||
def init_mods():
|
||||
mods_file = STATIC_DIR / "mods.json"
|
||||
raw_mods = json.loads(mods_file.read_text())
|
||||
for ruleset in raw_mods:
|
||||
ruleset_mods = {}
|
||||
for mod in ruleset["Mods"]:
|
||||
ruleset_mods[mod["Acronym"]] = mod
|
||||
API_MODS[ruleset["RulesetID"]] = ruleset_mods
|
||||
|
||||
|
||||
def int_to_mods(mods: int) -> list[APIMod]:
|
||||
@@ -54,3 +98,10 @@ def int_to_mods(mods: int) -> list[APIMod]:
|
||||
if mods & (1 << 9):
|
||||
mod_list.remove(LEGACY_MOD_TO_API_MOD[(1 << 6)])
|
||||
return mod_list
|
||||
|
||||
|
||||
def mods_to_int(mods: list[APIMod]) -> int:
|
||||
sum_ = 0
|
||||
for mod in mods:
|
||||
sum_ |= API_MOD_TO_LEGACY.get(mod["acronym"], 0)
|
||||
return sum_
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# OAuth 相关模型
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -34,3 +35,22 @@ class OAuthErrorResponse(BaseModel):
|
||||
error_description: str
|
||||
hint: str
|
||||
message: str
|
||||
|
||||
|
||||
class RegistrationErrorResponse(BaseModel):
|
||||
"""注册错误响应模型"""
|
||||
form_error: dict
|
||||
|
||||
|
||||
class UserRegistrationErrors(BaseModel):
|
||||
"""用户注册错误模型"""
|
||||
username: List[str] = []
|
||||
user_email: List[str] = []
|
||||
password: List[str] = []
|
||||
|
||||
|
||||
class RegistrationRequestErrors(BaseModel):
|
||||
"""注册请求错误模型"""
|
||||
message: str | None = None
|
||||
redirect: str | None = None
|
||||
user: UserRegistrationErrors | None = None
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum, IntEnum
|
||||
from typing import Literal, TypedDict
|
||||
|
||||
from .mods import API_MODS, APIMod, init_mods
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
import rosu_pp_py as rosu
|
||||
|
||||
|
||||
@@ -30,40 +34,141 @@ INT_TO_MODE = {v: k for k, v in MODE_TO_INT.items()}
|
||||
|
||||
|
||||
class Rank(str, Enum):
|
||||
X = "ss"
|
||||
XH = "ssh"
|
||||
S = "s"
|
||||
SH = "sh"
|
||||
A = "a"
|
||||
B = "b"
|
||||
C = "c"
|
||||
D = "d"
|
||||
F = "f"
|
||||
X = "X"
|
||||
XH = "XH"
|
||||
S = "S"
|
||||
SH = "SH"
|
||||
A = "A"
|
||||
B = "B"
|
||||
C = "C"
|
||||
D = "D"
|
||||
F = "F"
|
||||
|
||||
|
||||
# https://github.com/ppy/osu/blob/master/osu.Game/Rulesets/Scoring/HitResult.cs
|
||||
class HitResult(IntEnum):
|
||||
PERFECT = 0 # [Order(0)]
|
||||
GREAT = 1 # [Order(1)]
|
||||
GOOD = 2 # [Order(2)]
|
||||
OK = 3 # [Order(3)]
|
||||
MEH = 4 # [Order(4)]
|
||||
MISS = 5 # [Order(5)]
|
||||
class HitResult(str, Enum):
|
||||
PERFECT = "perfect" # [Order(0)]
|
||||
GREAT = "great" # [Order(1)]
|
||||
GOOD = "good" # [Order(2)]
|
||||
OK = "ok" # [Order(3)]
|
||||
MEH = "meh" # [Order(4)]
|
||||
MISS = "miss" # [Order(5)]
|
||||
|
||||
LARGE_TICK_HIT = 6 # [Order(6)]
|
||||
SMALL_TICK_HIT = 7 # [Order(7)]
|
||||
SLIDER_TAIL_HIT = 8 # [Order(8)]
|
||||
LARGE_TICK_HIT = "large_tick_hit" # [Order(6)]
|
||||
SMALL_TICK_HIT = "small_tick_hit" # [Order(7)]
|
||||
SLIDER_TAIL_HIT = "slider_tail_hit" # [Order(8)]
|
||||
|
||||
LARGE_BONUS = 9 # [Order(9)]
|
||||
SMALL_BONUS = 10 # [Order(10)]
|
||||
LARGE_BONUS = "large_bonus" # [Order(9)]
|
||||
SMALL_BONUS = "small_bonus" # [Order(10)]
|
||||
|
||||
LARGE_TICK_MISS = 11 # [Order(11)]
|
||||
SMALL_TICK_MISS = 12 # [Order(12)]
|
||||
LARGE_TICK_MISS = "large_tick_miss" # [Order(11)]
|
||||
SMALL_TICK_MISS = "small_tick_miss" # [Order(12)]
|
||||
|
||||
IGNORE_HIT = 13 # [Order(13)]
|
||||
IGNORE_MISS = 14 # [Order(14)]
|
||||
IGNORE_HIT = "ignore_hit" # [Order(13)]
|
||||
IGNORE_MISS = "ignore_miss" # [Order(14)]
|
||||
|
||||
NONE = 15 # [Order(15)]
|
||||
COMBO_BREAK = 16 # [Order(16)]
|
||||
NONE = "none" # [Order(15)]
|
||||
COMBO_BREAK = "combo_break" # [Order(16)]
|
||||
|
||||
LEGACY_COMBO_INCREASE = 99 # [Order(99)] @deprecated
|
||||
LEGACY_COMBO_INCREASE = "legacy_combo_increase" # [Order(99)] @deprecated
|
||||
|
||||
def is_hit(self) -> bool:
|
||||
return self not in (
|
||||
HitResult.NONE,
|
||||
HitResult.IGNORE_MISS,
|
||||
HitResult.COMBO_BREAK,
|
||||
HitResult.LARGE_TICK_MISS,
|
||||
HitResult.SMALL_TICK_MISS,
|
||||
HitResult.MISS,
|
||||
)
|
||||
|
||||
|
||||
class HitResultInt(IntEnum):
|
||||
PERFECT = 0
|
||||
GREAT = 1
|
||||
GOOD = 2
|
||||
OK = 3
|
||||
MEH = 4
|
||||
MISS = 5
|
||||
|
||||
LARGE_TICK_HIT = 6
|
||||
SMALL_TICK_HIT = 7
|
||||
SLIDER_TAIL_HIT = 8
|
||||
|
||||
LARGE_BONUS = 9
|
||||
SMALL_BONUS = 10
|
||||
|
||||
LARGE_TICK_MISS = 11
|
||||
SMALL_TICK_MISS = 12
|
||||
|
||||
IGNORE_HIT = 13
|
||||
IGNORE_MISS = 14
|
||||
|
||||
NONE = 15
|
||||
COMBO_BREAK = 16
|
||||
|
||||
LEGACY_COMBO_INCREASE = 99
|
||||
|
||||
def is_hit(self) -> bool:
|
||||
return self not in (
|
||||
HitResultInt.NONE,
|
||||
HitResultInt.IGNORE_MISS,
|
||||
HitResultInt.COMBO_BREAK,
|
||||
HitResultInt.LARGE_TICK_MISS,
|
||||
HitResultInt.SMALL_TICK_MISS,
|
||||
HitResultInt.MISS,
|
||||
)
|
||||
|
||||
|
||||
class LeaderboardType(Enum):
|
||||
GLOBAL = "global"
|
||||
FRIENDS = "friends"
|
||||
COUNTRY = "country"
|
||||
TEAM = "team"
|
||||
|
||||
|
||||
ScoreStatistics = dict[HitResult, int]
|
||||
ScoreStatisticsInt = dict[HitResultInt, int]
|
||||
|
||||
|
||||
class SoloScoreSubmissionInfo(BaseModel):
|
||||
rank: Rank
|
||||
total_score: int = Field(ge=0, le=2**31 - 1)
|
||||
total_score_without_mods: int = Field(ge=0, le=2**31 - 1)
|
||||
accuracy: float = Field(ge=0, le=1)
|
||||
pp: float = Field(default=0, ge=0, le=2**31 - 1)
|
||||
max_combo: int = 0
|
||||
ruleset_id: Literal[0, 1, 2, 3]
|
||||
passed: bool = False
|
||||
mods: list[APIMod] = Field(default_factory=list)
|
||||
statistics: ScoreStatistics = Field(default_factory=dict)
|
||||
maximum_statistics: ScoreStatistics = Field(default_factory=dict)
|
||||
|
||||
@field_validator("mods", mode="after")
|
||||
@classmethod
|
||||
def validate_mods(cls, mods: list[APIMod], info: ValidationInfo):
|
||||
if not API_MODS:
|
||||
init_mods()
|
||||
incompatible_mods = set()
|
||||
# check incompatible mods
|
||||
for mod in mods:
|
||||
if mod["acronym"] in incompatible_mods:
|
||||
raise ValueError(
|
||||
f"Mod {mod['acronym']} is incompatible with other mods"
|
||||
)
|
||||
setting_mods = API_MODS[info.data["ruleset_id"]].get(mod["acronym"])
|
||||
if not setting_mods:
|
||||
raise ValueError(f"Invalid mod: {mod['acronym']}")
|
||||
incompatible_mods.update(setting_mods["IncompatibleMods"])
|
||||
return mods
|
||||
|
||||
|
||||
class LegacyReplaySoloScoreInfo(TypedDict):
|
||||
online_id: int
|
||||
mods: list[APIMod]
|
||||
statistics: ScoreStatisticsInt
|
||||
maximum_statistics: ScoreStatisticsInt
|
||||
client_version: str
|
||||
rank: Rank
|
||||
user_id: int
|
||||
total_score_without_mods: int
|
||||
|
||||
@@ -1,11 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
import datetime
|
||||
from typing import Any, get_origin
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
import msgpack
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
TypeAdapter,
|
||||
model_serializer,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
|
||||
def serialize_to_list(value: BaseModel) -> list[Any]:
|
||||
data = []
|
||||
for field, info in value.__class__.model_fields.items():
|
||||
v = getattr(value, field)
|
||||
anno = get_origin(info.annotation)
|
||||
if anno and issubclass(anno, BaseModel):
|
||||
data.append(serialize_to_list(v))
|
||||
elif anno and issubclass(anno, list):
|
||||
data.append(
|
||||
TypeAdapter(
|
||||
info.annotation,
|
||||
).dump_python(v)
|
||||
)
|
||||
elif isinstance(v, datetime.datetime):
|
||||
data.append([msgpack.ext.Timestamp.from_datetime(v), 0])
|
||||
else:
|
||||
data.append(v)
|
||||
return data
|
||||
|
||||
|
||||
class MessagePackArrayModel(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def unpack(cls, v: Any) -> Any:
|
||||
@@ -16,11 +47,15 @@ class MessagePackArrayModel(BaseModel):
|
||||
return dict(zip(fields, v))
|
||||
return v
|
||||
|
||||
@model_serializer
|
||||
def serialize(self) -> list[Any]:
|
||||
return serialize_to_list(self)
|
||||
|
||||
|
||||
class Transport(BaseModel):
|
||||
transport: str
|
||||
transfer_formats: list[str] = Field(
|
||||
default_factory=lambda: ["Binary"], alias="transferFormats"
|
||||
default_factory=lambda: ["Binary", "Text"], alias="transferFormats"
|
||||
)
|
||||
|
||||
|
||||
@@ -29,3 +64,8 @@ class NegotiateResponse(BaseModel):
|
||||
connectionToken: str
|
||||
negotiateVersion: int = 1
|
||||
availableTransports: list[Transport]
|
||||
|
||||
|
||||
class UserState(BaseModel):
|
||||
connection_id: str
|
||||
connection_token: str
|
||||
|
||||
@@ -4,18 +4,22 @@ import datetime
|
||||
from enum import IntEnum
|
||||
from typing import Any
|
||||
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
|
||||
from .score import (
|
||||
HitResult,
|
||||
ScoreStatisticsInt,
|
||||
)
|
||||
from .signalr import MessagePackArrayModel
|
||||
from .signalr import MessagePackArrayModel, UserState
|
||||
|
||||
import msgpack
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class APIMod(MessagePackArrayModel):
|
||||
acronym: str
|
||||
settings: dict[str, Any] = Field(default_factory=dict)
|
||||
settings: dict[str, Any] | list = Field(
|
||||
default_factory=dict
|
||||
) # FIXME: with settings
|
||||
|
||||
|
||||
class SpectatedUserState(IntEnum):
|
||||
@@ -32,7 +36,7 @@ class SpectatorState(MessagePackArrayModel):
|
||||
ruleset_id: int | None = None # 0,1,2,3
|
||||
mods: list[APIMod] = Field(default_factory=list)
|
||||
state: SpectatedUserState
|
||||
maximum_statistics: dict[HitResult, int] = Field(default_factory=dict)
|
||||
maximum_statistics: ScoreStatisticsInt = Field(default_factory=dict)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, SpectatorState):
|
||||
@@ -58,7 +62,7 @@ class FrameHeader(MessagePackArrayModel):
|
||||
acc: float
|
||||
combo: int
|
||||
max_combo: int
|
||||
statistics: dict[HitResult, int] = Field(default_factory=dict)
|
||||
statistics: ScoreStatisticsInt = Field(default_factory=dict)
|
||||
score_processor_statistics: ScoreProcessorStatistics
|
||||
received_time: datetime.datetime
|
||||
mods: list[APIMod] = Field(default_factory=list)
|
||||
@@ -79,22 +83,56 @@ class FrameHeader(MessagePackArrayModel):
|
||||
raise ValueError(f"Cannot convert {type(v)} to datetime")
|
||||
|
||||
|
||||
class ReplayButtonState(IntEnum):
|
||||
NONE = 0
|
||||
LEFT1 = 1
|
||||
RIGHT1 = 2
|
||||
LEFT2 = 4
|
||||
RIGHT2 = 8
|
||||
SMOKE = 16
|
||||
# class ReplayButtonState(IntEnum):
|
||||
# NONE = 0
|
||||
# LEFT1 = 1
|
||||
# RIGHT1 = 2
|
||||
# LEFT2 = 4
|
||||
# RIGHT2 = 8
|
||||
# SMOKE = 16
|
||||
|
||||
|
||||
class LegacyReplayFrame(MessagePackArrayModel):
|
||||
time: int # from ReplayFrame,the parent of LegacyReplayFrame
|
||||
time: float # from ReplayFrame,the parent of LegacyReplayFrame
|
||||
x: float | None = None
|
||||
y: float | None = None
|
||||
button_state: ReplayButtonState
|
||||
button_state: int
|
||||
|
||||
|
||||
class FrameDataBundle(MessagePackArrayModel):
|
||||
header: FrameHeader
|
||||
frames: list[LegacyReplayFrame]
|
||||
|
||||
|
||||
# Use for server
|
||||
class APIUser(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
|
||||
class ScoreInfo(BaseModel):
|
||||
mods: list[APIMod]
|
||||
user: APIUser
|
||||
ruleset: int
|
||||
maximum_statistics: ScoreStatisticsInt
|
||||
id: int | None = None
|
||||
total_score: int | None = None
|
||||
acc: float | None = None
|
||||
max_combo: int | None = None
|
||||
combo: int | None = None
|
||||
statistics: ScoreStatisticsInt = Field(default_factory=dict)
|
||||
|
||||
|
||||
class StoreScore(BaseModel):
|
||||
score_info: ScoreInfo
|
||||
replay_frames: list[LegacyReplayFrame] = Field(default_factory=list)
|
||||
|
||||
|
||||
class StoreClientState(UserState):
|
||||
state: SpectatorState | None = None
|
||||
beatmap_status: BeatmapRankStatus | None = None
|
||||
checksum: str | None = None
|
||||
ruleset_id: int | None = None
|
||||
score_token: int | None = None
|
||||
watched_user: set[int] = Field(default_factory=set)
|
||||
score: StoreScore | None = None
|
||||
|
||||
8
app/path.py
Normal file
8
app/path.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
STATIC_DIR = Path(__file__).parent.parent / "static"
|
||||
|
||||
REPLAY_DIR = Path(__file__).parent.parent / "replays"
|
||||
REPLAY_DIR.mkdir(exist_ok=True)
|
||||
@@ -1,15 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.signalr import signalr_router as signalr_router
|
||||
|
||||
from . import ( # pyright: ignore[reportUnusedImport] # noqa: F401
|
||||
beatmap,
|
||||
beatmapset,
|
||||
me,
|
||||
relationship,
|
||||
score,
|
||||
user,
|
||||
)
|
||||
from .api_router import router as api_router
|
||||
from .auth import router as auth_router
|
||||
from .fetcher import fetcher_router as fetcher_router
|
||||
from .signalr import signalr_router as signalr_router
|
||||
|
||||
__all__ = ["api_router", "auth_router", "fetcher_router", "signalr_router"]
|
||||
|
||||
@@ -1,39 +1,264 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
import re
|
||||
|
||||
from app.auth import (
|
||||
authenticate_user,
|
||||
create_access_token,
|
||||
generate_refresh_token,
|
||||
get_password_hash,
|
||||
get_token_by_refresh_token,
|
||||
store_token,
|
||||
)
|
||||
from app.config import settings
|
||||
from app.database import User as DBUser
|
||||
from app.dependencies import get_db
|
||||
from app.models.oauth import TokenResponse, OAuthErrorResponse
|
||||
from app.models.oauth import (
|
||||
OAuthErrorResponse,
|
||||
RegistrationRequestErrors,
|
||||
TokenResponse,
|
||||
UserRegistrationErrors,
|
||||
)
|
||||
|
||||
from fastapi import APIRouter, Depends, Form
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
def create_oauth_error_response(error: str, description: str, hint: str, status_code: int = 400):
|
||||
def create_oauth_error_response(
|
||||
error: str, description: str, hint: str, status_code: int = 400
|
||||
):
|
||||
"""创建标准的 OAuth 错误响应"""
|
||||
error_data = OAuthErrorResponse(
|
||||
error=error,
|
||||
error_description=description,
|
||||
hint=hint,
|
||||
message=description
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content=error_data.model_dump()
|
||||
error=error, error_description=description, hint=hint, message=description
|
||||
)
|
||||
return JSONResponse(status_code=status_code, content=error_data.model_dump())
|
||||
|
||||
|
||||
def validate_username(username: str) -> list[str]:
|
||||
"""验证用户名"""
|
||||
errors = []
|
||||
|
||||
if not username:
|
||||
errors.append("Username is required")
|
||||
return errors
|
||||
|
||||
if len(username) < 3:
|
||||
errors.append("Username must be at least 3 characters long")
|
||||
|
||||
if len(username) > 15:
|
||||
errors.append("Username must be at most 15 characters long")
|
||||
|
||||
# 检查用户名格式(只允许字母、数字、下划线、连字符)
|
||||
if not re.match(r"^[a-zA-Z0-9_-]+$", username):
|
||||
errors.append(
|
||||
"Username can only contain letters, numbers, underscores, and hyphens"
|
||||
)
|
||||
|
||||
# 检查是否以数字开头
|
||||
if username[0].isdigit():
|
||||
errors.append("Username cannot start with a number")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def validate_email(email: str) -> list[str]:
|
||||
"""验证邮箱"""
|
||||
errors = []
|
||||
|
||||
if not email:
|
||||
errors.append("Email is required")
|
||||
return errors
|
||||
|
||||
# 基本的邮箱格式验证
|
||||
email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
|
||||
if not re.match(email_pattern, email):
|
||||
errors.append("Please enter a valid email address")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def validate_password(password: str) -> list[str]:
|
||||
"""验证密码"""
|
||||
errors = []
|
||||
|
||||
if not password:
|
||||
errors.append("Password is required")
|
||||
return errors
|
||||
|
||||
if len(password) < 8:
|
||||
errors.append("Password must be at least 8 characters long")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
router = APIRouter(tags=["osu! OAuth 认证"])
|
||||
|
||||
|
||||
@router.post("/users")
|
||||
async def register_user(
|
||||
user_username: str = Form(..., alias="user[username]"),
|
||||
user_email: str = Form(..., alias="user[user_email]"),
|
||||
user_password: str = Form(..., alias="user[password]"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""用户注册接口 - 匹配 osu! 客户端的注册请求"""
|
||||
|
||||
username_errors = validate_username(user_username)
|
||||
email_errors = validate_email(user_email)
|
||||
password_errors = validate_password(user_password)
|
||||
|
||||
result = await db.exec(select(DBUser).where(DBUser.name == user_username))
|
||||
existing_user = result.first()
|
||||
if existing_user:
|
||||
username_errors.append("Username is already taken")
|
||||
|
||||
result = await db.exec(select(DBUser).where(DBUser.email == user_email))
|
||||
existing_email = result.first()
|
||||
if existing_email:
|
||||
email_errors.append("Email is already taken")
|
||||
|
||||
if username_errors or email_errors or password_errors:
|
||||
errors = RegistrationRequestErrors(
|
||||
user=UserRegistrationErrors(
|
||||
username=username_errors,
|
||||
user_email=email_errors,
|
||||
password=password_errors,
|
||||
)
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=422, content={"form_error": errors.model_dump()}
|
||||
)
|
||||
|
||||
try:
|
||||
# 创建新用户
|
||||
from datetime import datetime
|
||||
import time
|
||||
|
||||
new_user = DBUser(
|
||||
name=user_username,
|
||||
safe_name=user_username.lower(), # 安全用户名(小写)
|
||||
email=user_email,
|
||||
pw_bcrypt=get_password_hash(user_password),
|
||||
priv=1, # 普通用户权限
|
||||
country="CN", # 默认国家
|
||||
creation_time=int(time.time()),
|
||||
latest_activity=int(time.time()),
|
||||
preferred_mode=0, # 默认模式
|
||||
play_style=0, # 默认游戏风格
|
||||
)
|
||||
|
||||
db.add(new_user)
|
||||
await db.commit()
|
||||
await db.refresh(new_user)
|
||||
|
||||
# 保存用户ID,因为会话可能会关闭
|
||||
user_id = new_user.id
|
||||
|
||||
if user_id <= 2:
|
||||
await db.rollback()
|
||||
try:
|
||||
from sqlalchemy import text
|
||||
|
||||
# 确保 AUTO_INCREMENT 值从3开始(ID=1是BanchoBot,ID=2预留给ppy)
|
||||
await db.execute(text("ALTER TABLE users AUTO_INCREMENT = 3"))
|
||||
await db.commit()
|
||||
|
||||
# 重新创建用户
|
||||
new_user = DBUser(
|
||||
name=user_username,
|
||||
safe_name=user_username.lower(),
|
||||
email=user_email,
|
||||
pw_bcrypt=get_password_hash(user_password),
|
||||
priv=1,
|
||||
country="CN",
|
||||
creation_time=int(time.time()),
|
||||
latest_activity=int(time.time()),
|
||||
preferred_mode=0,
|
||||
play_style=0,
|
||||
)
|
||||
|
||||
db.add(new_user)
|
||||
await db.commit()
|
||||
await db.refresh(new_user)
|
||||
user_id = new_user.id
|
||||
|
||||
# 最终检查ID是否有效
|
||||
if user_id <= 2:
|
||||
await db.rollback()
|
||||
errors = RegistrationRequestErrors(
|
||||
message=(
|
||||
"Failed to create account with valid ID. "
|
||||
"Please contact support."
|
||||
)
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=500, content={"form_error": errors.model_dump()}
|
||||
)
|
||||
|
||||
except Exception as fix_error:
|
||||
await db.rollback()
|
||||
print(f"Failed to fix AUTO_INCREMENT: {fix_error}")
|
||||
errors = RegistrationRequestErrors(
|
||||
message="Failed to create account with valid ID. Please try again."
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=500, content={"form_error": errors.model_dump()}
|
||||
)
|
||||
|
||||
# 创建默认的 lazer_profile
|
||||
from app.database.user import LazerUserProfile
|
||||
|
||||
lazer_profile = LazerUserProfile(
|
||||
user_id=user_id,
|
||||
is_active=True,
|
||||
is_bot=False,
|
||||
is_deleted=False,
|
||||
is_online=True,
|
||||
is_supporter=False,
|
||||
is_restricted=False,
|
||||
session_verified=False,
|
||||
has_supported=False,
|
||||
pm_friends_only=False,
|
||||
default_group="default",
|
||||
join_date=datetime.utcnow(),
|
||||
playmode="osu",
|
||||
support_level=0,
|
||||
max_blocks=50,
|
||||
max_friends=250,
|
||||
post_count=0,
|
||||
)
|
||||
|
||||
db.add(lazer_profile)
|
||||
await db.commit()
|
||||
|
||||
# 返回成功响应
|
||||
return JSONResponse(
|
||||
status_code=201,
|
||||
content={"message": "Account created successfully", "user_id": user_id},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
# 打印详细错误信息用于调试
|
||||
print(f"Registration error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
# 返回通用错误
|
||||
errors = RegistrationRequestErrors(
|
||||
message="An error occurred while creating your account. Please try again."
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=500, content={"form_error": errors.model_dump()}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/oauth/token", response_model=TokenResponse)
|
||||
async def oauth_token(
|
||||
grant_type: str = Form(...),
|
||||
@@ -53,9 +278,13 @@ async def oauth_token(
|
||||
):
|
||||
return create_oauth_error_response(
|
||||
error="invalid_client",
|
||||
description="Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method).",
|
||||
description=(
|
||||
"Client authentication failed (e.g., unknown client, "
|
||||
"no client authentication included, "
|
||||
"or unsupported authentication method)."
|
||||
),
|
||||
hint="Invalid client credentials",
|
||||
status_code=401
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
if grant_type == "password":
|
||||
@@ -63,8 +292,12 @@ async def oauth_token(
|
||||
if not username or not password:
|
||||
return create_oauth_error_response(
|
||||
error="invalid_request",
|
||||
description="The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed.",
|
||||
hint="Username and password required"
|
||||
description=(
|
||||
"The request is missing a required parameter, includes an "
|
||||
"invalid parameter value, "
|
||||
"includes a parameter more than once, or is otherwise malformed."
|
||||
),
|
||||
hint="Username and password required",
|
||||
)
|
||||
|
||||
# 验证用户
|
||||
@@ -72,8 +305,14 @@ async def oauth_token(
|
||||
if not user:
|
||||
return create_oauth_error_response(
|
||||
error="invalid_grant",
|
||||
description="The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client.",
|
||||
hint="Incorrect sign in"
|
||||
description=(
|
||||
"The provided authorization grant (e.g., authorization code, "
|
||||
"resource owner credentials) "
|
||||
"or refresh token is invalid, expired, revoked, "
|
||||
"does not match the redirection URI used in "
|
||||
"the authorization request, or was issued to another client."
|
||||
),
|
||||
hint="Incorrect sign in",
|
||||
)
|
||||
|
||||
# 生成令牌
|
||||
@@ -105,8 +344,12 @@ async def oauth_token(
|
||||
if not refresh_token:
|
||||
return create_oauth_error_response(
|
||||
error="invalid_request",
|
||||
description="The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed.",
|
||||
hint="Refresh token required"
|
||||
description=(
|
||||
"The request is missing a required parameter, "
|
||||
"includes an invalid parameter value, "
|
||||
"includes a parameter more than once, or is otherwise malformed."
|
||||
),
|
||||
hint="Refresh token required",
|
||||
)
|
||||
|
||||
# 验证刷新令牌
|
||||
@@ -114,8 +357,14 @@ async def oauth_token(
|
||||
if not token_record:
|
||||
return create_oauth_error_response(
|
||||
error="invalid_grant",
|
||||
description="The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client.",
|
||||
hint="Invalid refresh token"
|
||||
description=(
|
||||
"The provided authorization grant (e.g., authorization code, "
|
||||
"resource owner credentials) or refresh token is "
|
||||
"invalid, expired, revoked, "
|
||||
"does not match the redirection URI used "
|
||||
"in the authorization request, or was issued to another client."
|
||||
),
|
||||
hint="Invalid refresh token",
|
||||
)
|
||||
|
||||
# 生成新的访问令牌
|
||||
@@ -145,6 +394,9 @@ async def oauth_token(
|
||||
else:
|
||||
return create_oauth_error_response(
|
||||
error="unsupported_grant_type",
|
||||
description="The authorization grant type is not supported by the authorization server.",
|
||||
hint="Unsupported grant type"
|
||||
description=(
|
||||
"The authorization grant type is not supported "
|
||||
"by the authorization server."
|
||||
),
|
||||
hint="Unsupported grant type",
|
||||
)
|
||||
|
||||
@@ -16,7 +16,10 @@ from app.dependencies.user import get_current_user
|
||||
from app.fetcher import Fetcher
|
||||
from app.models.beatmap import BeatmapAttributes
|
||||
from app.models.mods import APIMod, int_to_mods
|
||||
from app.models.score import INT_TO_MODE, GameMode
|
||||
from app.models.score import (
|
||||
INT_TO_MODE,
|
||||
GameMode,
|
||||
)
|
||||
from app.utils import calculate_beatmap_attribute
|
||||
|
||||
from .api_router import router
|
||||
@@ -31,6 +34,31 @@ from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@router.get("/beatmaps/lookup", tags=["beatmap"], response_model=BeatmapResp)
|
||||
async def lookup_beatmap(
|
||||
id: int | None = Query(default=None, alias="id"),
|
||||
md5: str | None = Query(default=None, alias="checksum"),
|
||||
filename: str | None = Query(default=None, alias="filename"),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
if id is None and md5 is None and filename is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="At least one of 'id', 'checksum', or 'filename' must be provided.",
|
||||
)
|
||||
try:
|
||||
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid=id, md5=md5)
|
||||
except HTTPError:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
|
||||
if beatmap is None:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
|
||||
return BeatmapResp.from_db(beatmap)
|
||||
|
||||
|
||||
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
|
||||
async def get_beatmap(
|
||||
bid: int,
|
||||
@@ -39,7 +67,7 @@ async def get_beatmap(
|
||||
fetcher: Fetcher = Depends(get_fetcher),
|
||||
):
|
||||
try:
|
||||
beatmap = await Beatmap.get_or_fetch(db, bid, fetcher)
|
||||
beatmap = await Beatmap.get_or_fetch(db, fetcher, bid)
|
||||
return BeatmapResp.from_db(beatmap)
|
||||
except HTTPError:
|
||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||
@@ -119,7 +147,7 @@ async def get_beatmap_attributes(
|
||||
if ruleset_id is not None and ruleset is None:
|
||||
ruleset = INT_TO_MODE[ruleset_id]
|
||||
if ruleset is None:
|
||||
beatmap_db = await Beatmap.get_or_fetch(db, beatmap, fetcher)
|
||||
beatmap_db = await Beatmap.get_or_fetch(db, fetcher, beatmap)
|
||||
ruleset = beatmap_db.mode
|
||||
key = (
|
||||
f"beatmap:{beatmap}:{ruleset}:"
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from app.database import User as DBUser
|
||||
from app.database.relationship import Relationship, RelationshipResp, RelationshipType
|
||||
from app.dependencies.database import get_db
|
||||
@@ -9,21 +7,23 @@ from app.dependencies.user import get_current_user
|
||||
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Query
|
||||
from fastapi import Depends, HTTPException, Query, Request
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@router.get("/{type}", tags=["relationship"], response_model=list[RelationshipResp])
|
||||
@router.get("/friends", tags=["relationship"], response_model=list[RelationshipResp])
|
||||
@router.get("/blocks", tags=["relationship"], response_model=list[RelationshipResp])
|
||||
async def get_relationship(
|
||||
type: Literal["friends", "blocks"],
|
||||
request: Request,
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if type == "friends":
|
||||
relationship_type = RelationshipType.FOLLOW
|
||||
else:
|
||||
relationship_type = RelationshipType.BLOCK
|
||||
relationship_type = (
|
||||
RelationshipType.FOLLOW
|
||||
if request.url.path.endswith("/friends")
|
||||
else RelationshipType.BLOCK
|
||||
)
|
||||
relationships = await db.exec(
|
||||
select(Relationship).where(
|
||||
Relationship.user_id == current_user.id,
|
||||
@@ -33,17 +33,19 @@ async def get_relationship(
|
||||
return [await RelationshipResp.from_db(db, rel) for rel in relationships]
|
||||
|
||||
|
||||
@router.post("/{type}", tags=["relationship"], response_model=RelationshipResp)
|
||||
@router.post("/friends", tags=["relationship"], response_model=RelationshipResp)
|
||||
@router.post("/blocks", tags=["relationship"])
|
||||
async def add_relationship(
|
||||
type: Literal["friends", "blocks"],
|
||||
request: Request,
|
||||
target: int = Query(),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if type == "blocks":
|
||||
relationship_type = RelationshipType.BLOCK
|
||||
else:
|
||||
relationship_type = RelationshipType.FOLLOW
|
||||
relationship_type = (
|
||||
RelationshipType.FOLLOW
|
||||
if request.url.path.endswith("/friends")
|
||||
else RelationshipType.BLOCK
|
||||
)
|
||||
if target == current_user.id:
|
||||
raise HTTPException(422, "Cannot add relationship to yourself")
|
||||
relationship = (
|
||||
@@ -78,18 +80,22 @@ async def add_relationship(
|
||||
await db.delete(target_relationship)
|
||||
await db.commit()
|
||||
await db.refresh(relationship)
|
||||
return await RelationshipResp.from_db(db, relationship)
|
||||
if relationship.type == RelationshipType.FOLLOW:
|
||||
return await RelationshipResp.from_db(db, relationship)
|
||||
|
||||
|
||||
@router.delete("/{type}/{target}", tags=["relationship"])
|
||||
@router.delete("/friends/{target}", tags=["relationship"])
|
||||
@router.delete("/blocks/{target}", tags=["relationship"])
|
||||
async def delete_relationship(
|
||||
type: Literal["friends", "blocks"],
|
||||
request: Request,
|
||||
target: int,
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
relationship_type = (
|
||||
RelationshipType.BLOCK if type == "blocks" else RelationshipType.FOLLOW
|
||||
RelationshipType.BLOCK
|
||||
if "/blocks/" in request.url.path
|
||||
else RelationshipType.FOLLOW
|
||||
)
|
||||
relationship = (
|
||||
await db.exec(
|
||||
|
||||
@@ -1,20 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
|
||||
from app.database import (
|
||||
Beatmap,
|
||||
User as DBUser,
|
||||
)
|
||||
from app.database.beatmapset import Beatmapset
|
||||
from app.database.score import Score, ScoreResp
|
||||
from app.database.score_token import ScoreToken, ScoreTokenResp
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.models.score import (
|
||||
INT_TO_MODE,
|
||||
GameMode,
|
||||
HitResult,
|
||||
Rank,
|
||||
SoloScoreSubmissionInfo,
|
||||
)
|
||||
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Query
|
||||
from fastapi import Depends, Form, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel import col, select, true
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@@ -29,7 +37,7 @@ class BeatmapScores(BaseModel):
|
||||
async def get_beatmap_scores(
|
||||
beatmap: int,
|
||||
legacy_only: bool = Query(None), # TODO:加入对这个参数的查询
|
||||
mode: str = Query(None),
|
||||
mode: GameMode | None = Query(None),
|
||||
# mods: List[APIMod] = Query(None), # TODO:加入指定MOD的查询
|
||||
type: str = Query(None),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
@@ -42,29 +50,28 @@ async def get_beatmap_scores(
|
||||
|
||||
all_scores = (
|
||||
await db.exec(
|
||||
select(Score).where(Score.beatmap_id == beatmap)
|
||||
# .where(Score.mods == mods if mods else True)
|
||||
Score.select_clause_unique(
|
||||
Score.beatmap_id == beatmap,
|
||||
col(Score.passed).is_(True),
|
||||
Score.gamemode == mode if mode is not None else true(),
|
||||
)
|
||||
)
|
||||
).all()
|
||||
|
||||
user_score = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.options(
|
||||
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
|
||||
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
|
||||
.selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
Score.select_clause_unique(
|
||||
Score.beatmap_id == beatmap,
|
||||
Score.user_id == current_user.id,
|
||||
col(Score.passed).is_(True),
|
||||
Score.gamemode == mode if mode is not None else true(),
|
||||
)
|
||||
.where(Score.beatmap_id == beatmap)
|
||||
.where(Score.user_id == current_user.id)
|
||||
)
|
||||
).first()
|
||||
|
||||
return BeatmapScores(
|
||||
scores=[ScoreResp.from_db(score) for score in all_scores],
|
||||
userScore=ScoreResp.from_db(user_score) if user_score else None,
|
||||
scores=[await ScoreResp.from_db(db, score) for score in all_scores],
|
||||
userScore=await ScoreResp.from_db(db, user_score) if user_score else None,
|
||||
)
|
||||
|
||||
|
||||
@@ -93,18 +100,13 @@ async def get_user_beatmap_score(
|
||||
)
|
||||
user_score = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.options(
|
||||
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
|
||||
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
|
||||
.selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
Score.select_clause()
|
||||
.where(
|
||||
Score.gamemode == mode if mode is not None else True,
|
||||
Score.beatmap_id == beatmap,
|
||||
Score.user_id == user,
|
||||
)
|
||||
.where(Score.gamemode == mode if mode is not None else True)
|
||||
.where(Score.beatmap_id == beatmap)
|
||||
.where(Score.user_id == user)
|
||||
.order_by(col(Score.classic_total_score).desc())
|
||||
.order_by(col(Score.total_score).desc())
|
||||
)
|
||||
).first()
|
||||
|
||||
@@ -115,7 +117,7 @@ async def get_user_beatmap_score(
|
||||
else:
|
||||
return BeatmapUserScore(
|
||||
position=user_score.position if user_score.position is not None else 0,
|
||||
score=ScoreResp.from_db(user_score),
|
||||
score=await ScoreResp.from_db(db, user_score),
|
||||
)
|
||||
|
||||
|
||||
@@ -138,19 +140,114 @@ async def get_user_all_beatmap_scores(
|
||||
)
|
||||
all_user_scores = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.options(
|
||||
joinedload(Score.beatmap) # pyright: ignore[reportArgumentType]
|
||||
.joinedload(Beatmap.beatmapset) # pyright: ignore[reportArgumentType]
|
||||
.selectinload(
|
||||
Beatmapset.beatmaps # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
Score.select_clause()
|
||||
.where(
|
||||
Score.gamemode == ruleset if ruleset is not None else True,
|
||||
Score.beatmap_id == beatmap,
|
||||
Score.user_id == user,
|
||||
)
|
||||
.where(Score.gamemode == ruleset if ruleset is not None else True)
|
||||
.where(Score.beatmap_id == beatmap)
|
||||
.where(Score.user_id == user)
|
||||
.order_by(col(Score.classic_total_score).desc())
|
||||
)
|
||||
).all()
|
||||
|
||||
return [ScoreResp.from_db(score) for score in all_user_scores]
|
||||
return [await ScoreResp.from_db(db, score) for score in all_user_scores]
|
||||
|
||||
|
||||
@router.post(
|
||||
"/beatmaps/{beatmap}/solo/scores", tags=["beatmap"], response_model=ScoreTokenResp
|
||||
)
|
||||
async def create_solo_score(
|
||||
beatmap: int,
|
||||
version_hash: str = Form(""),
|
||||
beatmap_hash: str = Form(),
|
||||
ruleset_id: int = Form(..., ge=0, le=3),
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
async with db:
|
||||
score_token = ScoreToken(
|
||||
user_id=current_user.id,
|
||||
beatmap_id=beatmap,
|
||||
ruleset_id=INT_TO_MODE[ruleset_id],
|
||||
)
|
||||
db.add(score_token)
|
||||
await db.commit()
|
||||
await db.refresh(score_token)
|
||||
return ScoreTokenResp.from_db(score_token)
|
||||
|
||||
|
||||
@router.put(
|
||||
"/beatmaps/{beatmap}/solo/scores/{token}",
|
||||
tags=["beatmap"],
|
||||
response_model=ScoreResp,
|
||||
)
|
||||
async def submit_solo_score(
|
||||
beatmap: int,
|
||||
token: int,
|
||||
info: SoloScoreSubmissionInfo,
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if not info.passed:
|
||||
info.rank = Rank.F
|
||||
async with db:
|
||||
score_token = (
|
||||
await db.exec(
|
||||
select(ScoreToken)
|
||||
.options(joinedload(ScoreToken.beatmap)) # pyright: ignore[reportArgumentType]
|
||||
.where(ScoreToken.id == token, ScoreToken.user_id == current_user.id)
|
||||
)
|
||||
).first()
|
||||
if not score_token or score_token.user_id != current_user.id:
|
||||
raise HTTPException(status_code=404, detail="Score token not found")
|
||||
if score_token.score_id:
|
||||
score = (
|
||||
await db.exec(
|
||||
select(Score)
|
||||
.options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType]
|
||||
.where(
|
||||
Score.id == score_token.score_id,
|
||||
Score.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not score:
|
||||
raise HTTPException(status_code=404, detail="Score not found")
|
||||
else:
|
||||
score = Score(
|
||||
accuracy=info.accuracy,
|
||||
max_combo=info.max_combo,
|
||||
# maximum_statistics=info.maximum_statistics,
|
||||
mods=info.mods,
|
||||
passed=info.passed,
|
||||
rank=info.rank,
|
||||
total_score=info.total_score,
|
||||
total_score_without_mods=info.total_score_without_mods,
|
||||
beatmap_id=beatmap,
|
||||
ended_at=datetime.datetime.now(datetime.UTC),
|
||||
gamemode=INT_TO_MODE[info.ruleset_id],
|
||||
started_at=score_token.created_at,
|
||||
user_id=current_user.id,
|
||||
preserve=info.passed,
|
||||
map_md5=score_token.beatmap.checksum,
|
||||
has_replay=False,
|
||||
pp=info.pp,
|
||||
type="solo",
|
||||
n300=info.statistics.get(HitResult.GREAT, 0),
|
||||
n100=info.statistics.get(HitResult.OK, 0),
|
||||
n50=info.statistics.get(HitResult.MEH, 0),
|
||||
nmiss=info.statistics.get(HitResult.MISS, 0),
|
||||
ngeki=info.statistics.get(HitResult.PERFECT, 0),
|
||||
nkatu=info.statistics.get(HitResult.GOOD, 0),
|
||||
)
|
||||
db.add(score)
|
||||
await db.commit()
|
||||
await db.refresh(score)
|
||||
score_id = score.id
|
||||
score_token.score_id = score_id
|
||||
await db.commit()
|
||||
score = (
|
||||
await db.exec(Score.select_clause().where(Score.id == score_id))
|
||||
).first()
|
||||
assert score is not None
|
||||
return await ScoreResp.from_db(db, score)
|
||||
|
||||
@@ -1,211 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from app.config import settings
|
||||
from app.router.signalr.exception import InvokeException
|
||||
from app.router.signalr.packet import (
|
||||
PacketType,
|
||||
ResultKind,
|
||||
encode_varint,
|
||||
parse_packet,
|
||||
)
|
||||
from app.router.signalr.store import ResultStore
|
||||
from app.router.signalr.utils import get_signature
|
||||
|
||||
from fastapi import WebSocket
|
||||
import msgpack
|
||||
from pydantic import BaseModel
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
|
||||
|
||||
class Client:
|
||||
def __init__(
|
||||
self, connection_id: str, connection_token: str, connection: WebSocket
|
||||
) -> None:
|
||||
self.connection_id = connection_id
|
||||
self.connection_token = connection_token
|
||||
self.connection = connection
|
||||
self._listen_task: asyncio.Task | None = None
|
||||
self._ping_task: asyncio.Task | None = None
|
||||
self._store = ResultStore()
|
||||
|
||||
async def send_packet(self, type: PacketType, packet: list[Any]):
|
||||
packet.insert(0, type.value)
|
||||
payload = msgpack.packb(packet)
|
||||
length = encode_varint(len(payload))
|
||||
await self.connection.send_bytes(length + payload)
|
||||
|
||||
async def _ping(self):
|
||||
while True:
|
||||
try:
|
||||
await self.send_packet(PacketType.PING, [])
|
||||
await asyncio.sleep(settings.SIGNALR_PING_INTERVAL)
|
||||
except WebSocketDisconnect:
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Error in ping task for {self.connection_id}: {e}")
|
||||
break
|
||||
|
||||
|
||||
class Hub:
|
||||
def __init__(self) -> None:
|
||||
self.clients: dict[str, Client] = {}
|
||||
self.waited_clients: dict[str, int] = {}
|
||||
self.tasks: set[asyncio.Task] = set()
|
||||
|
||||
def add_waited_client(self, connection_token: str, timestamp: int) -> None:
|
||||
self.waited_clients[connection_token] = timestamp
|
||||
|
||||
def add_client(
|
||||
self, connection_id: str, connection_token: str, connection: WebSocket
|
||||
) -> Client:
|
||||
if connection_token in self.clients:
|
||||
raise ValueError(
|
||||
f"Client with connection token {connection_token} already exists."
|
||||
)
|
||||
if connection_token in self.waited_clients:
|
||||
if (
|
||||
self.waited_clients[connection_token]
|
||||
< time.time() - settings.SIGNALR_NEGOTIATE_TIMEOUT
|
||||
):
|
||||
raise TimeoutError(f"Connection {connection_id} has waited too long.")
|
||||
del self.waited_clients[connection_token]
|
||||
client = Client(connection_id, connection_token, connection)
|
||||
self.clients[connection_token] = client
|
||||
task = asyncio.create_task(client._ping())
|
||||
self.tasks.add(task)
|
||||
client._ping_task = task
|
||||
return client
|
||||
|
||||
async def remove_client(self, connection_id: str) -> None:
|
||||
if client := self.clients.get(connection_id):
|
||||
del self.clients[connection_id]
|
||||
if client._listen_task:
|
||||
client._listen_task.cancel()
|
||||
if client._ping_task:
|
||||
client._ping_task.cancel()
|
||||
await client.connection.close()
|
||||
|
||||
async def send_packet(self, client: Client, type: PacketType, packet: list[Any]):
|
||||
await client.send_packet(type, packet)
|
||||
|
||||
async def _listen_client(self, client: Client) -> None:
|
||||
jump = False
|
||||
while not jump:
|
||||
try:
|
||||
message = await client.connection.receive_bytes()
|
||||
packet_type, packet_data = parse_packet(message)
|
||||
task = asyncio.create_task(
|
||||
self._handle_packet(client, packet_type, packet_data)
|
||||
)
|
||||
self.tasks.add(task)
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
except WebSocketDisconnect as e:
|
||||
if e.code == 1005:
|
||||
continue
|
||||
print(
|
||||
f"Client {client.connection_id} disconnected: {e.code}, {e.reason}"
|
||||
)
|
||||
jump = True
|
||||
except Exception as e:
|
||||
print(f"Error in client {client.connection_id}: {e}")
|
||||
jump = True
|
||||
await self.remove_client(client.connection_id)
|
||||
|
||||
async def _handle_packet(
|
||||
self, client: Client, type: PacketType, packet: list[Any]
|
||||
) -> None:
|
||||
match type:
|
||||
case PacketType.PING:
|
||||
...
|
||||
case PacketType.INVOCATION:
|
||||
invocation_id: str | None = packet[1] # pyright: ignore[reportRedeclaration]
|
||||
target: str = packet[2]
|
||||
args: list[Any] | None = packet[3]
|
||||
if args is None:
|
||||
args = []
|
||||
# streams: list[str] | None = packet[4] # TODO: stream support
|
||||
code = ResultKind.VOID
|
||||
result = None
|
||||
try:
|
||||
result = await self.invoke_method(client, target, args)
|
||||
if result is not None:
|
||||
code = ResultKind.HAS_VALUE
|
||||
except InvokeException as e:
|
||||
code = ResultKind.ERROR
|
||||
result = e.message
|
||||
|
||||
except Exception as e:
|
||||
code = ResultKind.ERROR
|
||||
result = str(e)
|
||||
|
||||
packet = [
|
||||
{}, # header
|
||||
invocation_id,
|
||||
code.value,
|
||||
]
|
||||
if result is not None:
|
||||
packet.append(result)
|
||||
if invocation_id is not None:
|
||||
await client.send_packet(
|
||||
PacketType.COMPLETION,
|
||||
packet,
|
||||
)
|
||||
case PacketType.COMPLETION:
|
||||
invocation_id: str = packet[1]
|
||||
code: ResultKind = ResultKind(packet[2])
|
||||
result: Any = packet[3] if len(packet) > 3 else None
|
||||
client._store.add_result(invocation_id, code, result)
|
||||
|
||||
async def invoke_method(self, client: Client, method: str, args: list[Any]) -> Any:
|
||||
method_ = getattr(self, method, None)
|
||||
call_params = []
|
||||
if not method_:
|
||||
raise InvokeException(f"Method '{method}' not found in hub.")
|
||||
signature = get_signature(method_)
|
||||
for name, param in signature.parameters.items():
|
||||
if name == "self" or param.annotation is Client:
|
||||
continue
|
||||
if issubclass(param.annotation, BaseModel):
|
||||
call_params.append(param.annotation.model_validate(args.pop(0)))
|
||||
else:
|
||||
call_params.append(args.pop(0))
|
||||
return await method_(client, *call_params)
|
||||
|
||||
async def call(self, client: Client, method: str, *args: Any) -> Any:
|
||||
invocation_id = client._store.get_invocation_id()
|
||||
await client.send_packet(
|
||||
PacketType.INVOCATION,
|
||||
[
|
||||
{}, # header
|
||||
invocation_id,
|
||||
method,
|
||||
list(args),
|
||||
None, # streams
|
||||
],
|
||||
)
|
||||
r = await client._store.fetch(invocation_id, None)
|
||||
if r[0] == ResultKind.HAS_VALUE:
|
||||
return r[1]
|
||||
if r[0] == ResultKind.ERROR:
|
||||
raise InvokeException(r[1])
|
||||
return None
|
||||
|
||||
async def call_noblock(self, client: Client, method: str, *args: Any) -> None:
|
||||
await client.send_packet(
|
||||
PacketType.INVOCATION,
|
||||
[
|
||||
{}, # header
|
||||
None, # invocation_id
|
||||
method,
|
||||
list(args),
|
||||
None, # streams
|
||||
],
|
||||
)
|
||||
return None
|
||||
|
||||
def __contains__(self, item: str) -> bool:
|
||||
return item in self.clients or item in self.waited_clients
|
||||
@@ -1,6 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .hub import Hub
|
||||
|
||||
|
||||
class MetadataHub(Hub): ...
|
||||
@@ -1,15 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.models.spectator_hub import FrameDataBundle, SpectatorState
|
||||
|
||||
from .hub import Client, Hub
|
||||
|
||||
|
||||
class SpectatorHub(Hub):
|
||||
async def BeginPlaySession(
|
||||
self, client: Client, score_token: int, state: SpectatorState
|
||||
) -> None: ...
|
||||
|
||||
async def SendFrameData(
|
||||
self, client: Client, frame_data: FrameDataBundle
|
||||
) -> None: ...
|
||||
@@ -1,56 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import IntEnum
|
||||
from typing import Any
|
||||
|
||||
import msgpack
|
||||
|
||||
SEP = b"\x1e"
|
||||
|
||||
|
||||
class PacketType(IntEnum):
|
||||
INVOCATION = 1
|
||||
STREAM_ITEM = 2
|
||||
COMPLETION = 3
|
||||
STREAM_INVOCATION = 4
|
||||
CANCEL_INVOCATION = 5
|
||||
PING = 6
|
||||
CLOSE = 7
|
||||
|
||||
|
||||
class ResultKind(IntEnum):
|
||||
ERROR = 1
|
||||
VOID = 2
|
||||
HAS_VALUE = 3
|
||||
|
||||
|
||||
def parse_packet(data: bytes) -> tuple[PacketType, list[Any]]:
|
||||
length, offset = decode_varint(data)
|
||||
message_data = data[offset : offset + length]
|
||||
unpacked = msgpack.unpackb(message_data, raw=False)
|
||||
return PacketType(unpacked[0]), unpacked[1:]
|
||||
|
||||
|
||||
def encode_varint(value: int) -> bytes:
|
||||
result = []
|
||||
while value >= 0x80:
|
||||
result.append((value & 0x7F) | 0x80)
|
||||
value >>= 7
|
||||
result.append(value & 0x7F)
|
||||
return bytes(result)
|
||||
|
||||
|
||||
def decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]:
|
||||
result = 0
|
||||
shift = 0
|
||||
pos = offset
|
||||
|
||||
while pos < len(data):
|
||||
byte = data[pos]
|
||||
result |= (byte & 0x7F) << shift
|
||||
pos += 1
|
||||
if (byte & 0x80) == 0:
|
||||
break
|
||||
shift += 7
|
||||
|
||||
return result, pos
|
||||
75
app/router/user.py
Normal file
75
app/router/user.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from app.database import (
|
||||
User as DBUser,
|
||||
)
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies.database import get_db
|
||||
from app.models.score import INT_TO_MODE
|
||||
from app.models.user import (
|
||||
User as ApiUser,
|
||||
)
|
||||
from app.utils import convert_db_user_to_api_user
|
||||
|
||||
from .api_router import router
|
||||
|
||||
from fastapi import Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import col
|
||||
|
||||
|
||||
@router.get("/users/{user}/{ruleset}", response_model=ApiUser)
|
||||
@router.get("/users/{user}", response_model=ApiUser)
|
||||
async def get_user_info_default(
|
||||
user: str,
|
||||
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
searched_user = (
|
||||
await session.exec(
|
||||
DBUser.all_select_clause().where(
|
||||
DBUser.id == int(user)
|
||||
if user.isdigit()
|
||||
else DBUser.name == user.removeprefix("@")
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if not searched_user:
|
||||
raise HTTPException(404, detail="User not found")
|
||||
return await convert_db_user_to_api_user(searched_user, ruleset=ruleset)
|
||||
|
||||
|
||||
class BatchUserResponse(BaseModel):
|
||||
users: list[ApiUser]
|
||||
|
||||
|
||||
@router.get("/users", response_model=BatchUserResponse)
|
||||
@router.get("/users/lookup", response_model=BatchUserResponse)
|
||||
async def get_users(
|
||||
user_ids: list[int] = Query(default_factory=list, alias="ids[]"),
|
||||
include_variant_statistics: bool = Query(default=False), # TODO
|
||||
current_user: DBUser = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if user_ids:
|
||||
searched_users = (
|
||||
await session.exec(
|
||||
DBUser.all_select_clause().limit(50).where(col(DBUser.id).in_(user_ids))
|
||||
)
|
||||
).all()
|
||||
else:
|
||||
searched_users = (
|
||||
await session.exec(DBUser.all_select_clause().limit(50))
|
||||
).all()
|
||||
return BatchUserResponse(
|
||||
users=[
|
||||
await convert_db_user_to_api_user(
|
||||
searched_user, ruleset=INT_TO_MODE[current_user.preferred_mode].value
|
||||
)
|
||||
for searched_user in searched_users
|
||||
]
|
||||
)
|
||||
301
app/signalr/hub/hub.py
Normal file
301
app/signalr/hub/hub.py
Normal file
@@ -0,0 +1,301 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from app.config import settings
|
||||
from app.log import logger
|
||||
from app.models.signalr import UserState
|
||||
from app.signalr.exception import InvokeException
|
||||
from app.signalr.packet import (
|
||||
ClosePacket,
|
||||
CompletionPacket,
|
||||
InvocationPacket,
|
||||
Packet,
|
||||
PingPacket,
|
||||
Protocol,
|
||||
)
|
||||
from app.signalr.store import ResultStore
|
||||
from app.signalr.utils import get_signature
|
||||
|
||||
from fastapi import WebSocket
|
||||
from pydantic import BaseModel
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
|
||||
|
||||
class CloseConnection(Exception):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Connection closed",
|
||||
allow_reconnect: bool = False,
|
||||
from_client: bool = False,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.allow_reconnect = allow_reconnect
|
||||
self.from_client = from_client
|
||||
|
||||
|
||||
class Client:
|
||||
def __init__(
|
||||
self,
|
||||
connection_id: str,
|
||||
connection_token: str,
|
||||
connection: WebSocket,
|
||||
protocol: Protocol,
|
||||
) -> None:
|
||||
self.connection_id = connection_id
|
||||
self.connection_token = connection_token
|
||||
self.connection = connection
|
||||
self.procotol = protocol
|
||||
self._listen_task: asyncio.Task | None = None
|
||||
self._ping_task: asyncio.Task | None = None
|
||||
self._store = ResultStore()
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.connection_token)
|
||||
|
||||
@property
|
||||
def user_id(self) -> int:
|
||||
return int(self.connection_id)
|
||||
|
||||
async def send_packet(self, packet: Packet):
|
||||
await self.connection.send_bytes(self.procotol.encode(packet))
|
||||
|
||||
async def receive_packets(self) -> list[Packet]:
|
||||
message = await self.connection.receive()
|
||||
d = message.get("bytes") or message.get("text", "").encode()
|
||||
if not d:
|
||||
return []
|
||||
return self.procotol.decode(d)
|
||||
|
||||
async def _ping(self):
|
||||
while True:
|
||||
try:
|
||||
await self.send_packet(PingPacket())
|
||||
await asyncio.sleep(settings.SIGNALR_PING_INTERVAL)
|
||||
except WebSocketDisconnect:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in ping task for {self.connection_id}: {e}")
|
||||
break
|
||||
|
||||
|
||||
class Hub[TState: UserState]:
|
||||
def __init__(self) -> None:
|
||||
self.clients: dict[str, Client] = {}
|
||||
self.waited_clients: dict[str, int] = {}
|
||||
self.tasks: set[asyncio.Task] = set()
|
||||
self.groups: dict[str, set[Client]] = {}
|
||||
self.state: dict[int, TState] = {}
|
||||
|
||||
def add_waited_client(self, connection_token: str, timestamp: int) -> None:
|
||||
self.waited_clients[connection_token] = timestamp
|
||||
|
||||
def get_client_by_id(self, id: str, default: Any = None) -> Client:
|
||||
for client in self.clients.values():
|
||||
if client.connection_id == id:
|
||||
return client
|
||||
return default
|
||||
|
||||
@abstractmethod
|
||||
def create_state(self, client: Client) -> TState:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_or_create_state(self, client: Client) -> TState:
|
||||
if (state := self.state.get(client.user_id)) is not None:
|
||||
return state
|
||||
state = self.create_state(client)
|
||||
self.state[client.user_id] = state
|
||||
return state
|
||||
|
||||
def add_to_group(self, client: Client, group_id: str) -> None:
|
||||
self.groups.setdefault(group_id, set()).add(client)
|
||||
|
||||
def remove_from_group(self, client: Client, group_id: str) -> None:
|
||||
if group_id in self.groups:
|
||||
self.groups[group_id].discard(client)
|
||||
|
||||
async def add_client(
|
||||
self,
|
||||
connection_id: str,
|
||||
connection_token: str,
|
||||
protocol: Protocol,
|
||||
connection: WebSocket,
|
||||
) -> Client:
|
||||
if connection_token in self.clients:
|
||||
raise ValueError(
|
||||
f"Client with connection token {connection_token} already exists."
|
||||
)
|
||||
if connection_token in self.waited_clients:
|
||||
if (
|
||||
self.waited_clients[connection_token]
|
||||
< time.time() - settings.SIGNALR_NEGOTIATE_TIMEOUT
|
||||
):
|
||||
raise TimeoutError(f"Connection {connection_id} has waited too long.")
|
||||
del self.waited_clients[connection_token]
|
||||
client = Client(connection_id, connection_token, connection, protocol)
|
||||
self.clients[connection_token] = client
|
||||
task = asyncio.create_task(client._ping())
|
||||
self.tasks.add(task)
|
||||
client._ping_task = task
|
||||
return client
|
||||
|
||||
async def remove_client(self, client: Client) -> None:
|
||||
del self.clients[client.connection_token]
|
||||
if client._listen_task:
|
||||
client._listen_task.cancel()
|
||||
if client._ping_task:
|
||||
client._ping_task.cancel()
|
||||
for group in self.groups.values():
|
||||
group.discard(client)
|
||||
await self.clean_state(client, False)
|
||||
|
||||
@abstractmethod
|
||||
async def _clean_state(self, state: TState) -> None:
|
||||
return
|
||||
|
||||
async def clean_state(self, client: Client, disconnected: bool) -> None:
|
||||
if (state := self.state.get(client.user_id)) is None:
|
||||
return
|
||||
if disconnected and client.connection_token != state.connection_token:
|
||||
return
|
||||
try:
|
||||
await self._clean_state(state)
|
||||
except Exception:
|
||||
...
|
||||
|
||||
async def on_connect(self, client: Client) -> None:
|
||||
if method := getattr(self, "on_client_connect", None):
|
||||
await method(client)
|
||||
|
||||
async def send_packet(self, client: Client, packet: Packet) -> None:
|
||||
await client.send_packet(packet)
|
||||
|
||||
async def broadcast_call(self, method: str, *args: Any) -> None:
|
||||
tasks = []
|
||||
for client in self.clients.values():
|
||||
tasks.append(self.call_noblock(client, method, *args))
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def broadcast_group_call(
|
||||
self, group_id: str, method: str, *args: Any
|
||||
) -> None:
|
||||
tasks = []
|
||||
for client in self.groups.get(group_id, []):
|
||||
tasks.append(self.call_noblock(client, method, *args))
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def _listen_client(self, client: Client) -> None:
|
||||
try:
|
||||
while True:
|
||||
packets = await client.receive_packets()
|
||||
for packet in packets:
|
||||
if isinstance(packet, PingPacket):
|
||||
continue
|
||||
elif isinstance(packet, ClosePacket):
|
||||
raise CloseConnection(
|
||||
packet.error or "Connection closed by client",
|
||||
packet.allow_reconnect,
|
||||
True,
|
||||
)
|
||||
task = asyncio.create_task(self._handle_packet(client, packet))
|
||||
self.tasks.add(task)
|
||||
task.add_done_callback(self.tasks.discard)
|
||||
except WebSocketDisconnect as e:
|
||||
logger.info(
|
||||
f"Client {client.connection_id} disconnected: {e.code}, {e.reason}"
|
||||
)
|
||||
except RuntimeError as e:
|
||||
if "disconnect message" in str(e):
|
||||
logger.info(f"Client {client.connection_id} closed the connection.")
|
||||
else:
|
||||
logger.exception(f"RuntimeError in client {client.connection_id}: {e}")
|
||||
except CloseConnection as e:
|
||||
if not e.from_client:
|
||||
await client.send_packet(
|
||||
ClosePacket(error=e.message, allow_reconnect=e.allow_reconnect)
|
||||
)
|
||||
logger.info(
|
||||
f"Client {client.connection_id} closed the connection: {e.message}"
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Error in client {client.connection_id}")
|
||||
|
||||
await self.remove_client(client)
|
||||
|
||||
async def _handle_packet(self, client: Client, packet: Packet) -> None:
|
||||
if isinstance(packet, PingPacket):
|
||||
return
|
||||
elif isinstance(packet, InvocationPacket):
|
||||
args = packet.arguments or []
|
||||
error = None
|
||||
result = None
|
||||
try:
|
||||
result = await self.invoke_method(client, packet.target, args)
|
||||
except InvokeException as e:
|
||||
error = e.message
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Error invoking method {packet.target} for "
|
||||
f"client {client.connection_id}"
|
||||
)
|
||||
error = str(e)
|
||||
if packet.invocation_id is not None:
|
||||
await client.send_packet(
|
||||
CompletionPacket(
|
||||
invocation_id=packet.invocation_id,
|
||||
error=error,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
elif isinstance(packet, CompletionPacket):
|
||||
client._store.add_result(packet.invocation_id, packet.result, packet.error)
|
||||
|
||||
async def invoke_method(self, client: Client, method: str, args: list[Any]) -> Any:
|
||||
method_ = getattr(self, method, None)
|
||||
call_params = []
|
||||
if not method_:
|
||||
raise InvokeException(f"Method '{method}' not found in hub.")
|
||||
signature = get_signature(method_)
|
||||
for name, param in signature.parameters.items():
|
||||
if name == "self" or param.annotation is Client:
|
||||
continue
|
||||
if issubclass(param.annotation, BaseModel):
|
||||
call_params.append(param.annotation.model_validate(args.pop(0)))
|
||||
else:
|
||||
call_params.append(args.pop(0))
|
||||
return await method_(client, *call_params)
|
||||
|
||||
async def call(self, client: Client, method: str, *args: Any) -> Any:
|
||||
invocation_id = client._store.get_invocation_id()
|
||||
await client.send_packet(
|
||||
InvocationPacket(
|
||||
header={},
|
||||
invocation_id=invocation_id,
|
||||
target=method,
|
||||
arguments=list(args),
|
||||
stream_ids=None,
|
||||
)
|
||||
)
|
||||
r = await client._store.fetch(invocation_id, None)
|
||||
if r[1]:
|
||||
raise InvokeException(r[1])
|
||||
return r[0]
|
||||
|
||||
async def call_noblock(self, client: Client, method: str, *args: Any) -> None:
|
||||
await client.send_packet(
|
||||
InvocationPacket(
|
||||
header={},
|
||||
invocation_id=None,
|
||||
target=method,
|
||||
arguments=list(args),
|
||||
stream_ids=None,
|
||||
)
|
||||
)
|
||||
return None
|
||||
|
||||
def __contains__(self, item: str) -> bool:
|
||||
return item in self.clients or item in self.waited_clients
|
||||
151
app/signalr/hub/metadata.py
Normal file
151
app/signalr/hub/metadata.py
Normal file
@@ -0,0 +1,151 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Coroutine
|
||||
from typing import override
|
||||
|
||||
from app.database.relationship import Relationship, RelationshipType
|
||||
from app.dependencies.database import engine
|
||||
from app.models.metadata_hub import MetadataClientState, OnlineStatus, UserActivity
|
||||
|
||||
from .hub import Client, Hub
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
ONLINE_PRESENCE_WATCHERS_GROUP = "metadata:online-presence-watchers"
|
||||
|
||||
|
||||
class MetadataHub(Hub[MetadataClientState]):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
@staticmethod
|
||||
def online_presence_watchers_group() -> str:
|
||||
return ONLINE_PRESENCE_WATCHERS_GROUP
|
||||
|
||||
def broadcast_tasks(
|
||||
self, user_id: int, store: MetadataClientState | None
|
||||
) -> set[Coroutine]:
|
||||
if store is not None and not store.pushable:
|
||||
return set()
|
||||
data = store.to_dict() if store else None
|
||||
return {
|
||||
self.broadcast_group_call(
|
||||
self.online_presence_watchers_group(),
|
||||
"UserPresenceUpdated",
|
||||
user_id,
|
||||
data,
|
||||
),
|
||||
self.broadcast_group_call(
|
||||
self.friend_presence_watchers_group(user_id),
|
||||
"FriendPresenceUpdated",
|
||||
user_id,
|
||||
data,
|
||||
),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def friend_presence_watchers_group(user_id: int):
|
||||
return f"metadata:friend-presence-watchers:{user_id}"
|
||||
|
||||
@override
|
||||
async def _clean_state(self, state: MetadataClientState) -> None:
|
||||
if state.pushable:
|
||||
await asyncio.gather(*self.broadcast_tasks(int(state.connection_id), None))
|
||||
|
||||
@override
|
||||
def create_state(self, client: Client) -> MetadataClientState:
|
||||
return MetadataClientState(
|
||||
connection_id=client.connection_id,
|
||||
connection_token=client.connection_token,
|
||||
)
|
||||
|
||||
async def on_client_connect(self, client: Client) -> None:
|
||||
user_id = int(client.connection_id)
|
||||
self.get_or_create_state(client)
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
async with session.begin():
|
||||
friends = (
|
||||
await session.exec(
|
||||
select(Relationship.target_id).where(
|
||||
Relationship.user_id == user_id,
|
||||
Relationship.type == RelationshipType.FOLLOW,
|
||||
)
|
||||
)
|
||||
).all()
|
||||
tasks = []
|
||||
for friend_id in friends:
|
||||
self.groups.setdefault(
|
||||
self.friend_presence_watchers_group(friend_id), set()
|
||||
).add(client)
|
||||
if (
|
||||
friend_state := self.state.get(friend_id)
|
||||
) and friend_state.pushable:
|
||||
tasks.append(
|
||||
self.broadcast_group_call(
|
||||
self.friend_presence_watchers_group(friend_id),
|
||||
"FriendPresenceUpdated",
|
||||
friend_id,
|
||||
friend_state.to_dict(),
|
||||
)
|
||||
)
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def UpdateStatus(self, client: Client, status: int) -> None:
|
||||
status_ = OnlineStatus(status)
|
||||
user_id = int(client.connection_id)
|
||||
store = self.get_or_create_state(client)
|
||||
if store.status is not None and store.status == status_:
|
||||
return
|
||||
store.status = OnlineStatus(status_)
|
||||
tasks = self.broadcast_tasks(user_id, store)
|
||||
tasks.add(
|
||||
self.call_noblock(
|
||||
client,
|
||||
"UserPresenceUpdated",
|
||||
user_id,
|
||||
store.to_dict(),
|
||||
)
|
||||
)
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def UpdateActivity(self, client: Client, activity_dict: dict | None) -> None:
|
||||
user_id = int(client.connection_id)
|
||||
activity = (
|
||||
TypeAdapter(UserActivity).validate_python(activity_dict)
|
||||
if activity_dict
|
||||
else None
|
||||
)
|
||||
store = self.get_or_create_state(client)
|
||||
store.user_activity = activity
|
||||
tasks = self.broadcast_tasks(user_id, store)
|
||||
tasks.add(
|
||||
self.call_noblock(
|
||||
client,
|
||||
"UserPresenceUpdated",
|
||||
user_id,
|
||||
store.to_dict(),
|
||||
)
|
||||
)
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def BeginWatchingUserPresence(self, client: Client) -> None:
|
||||
await asyncio.gather(
|
||||
*[
|
||||
self.call_noblock(
|
||||
client,
|
||||
"UserPresenceUpdated",
|
||||
user_id,
|
||||
store.to_dict(),
|
||||
)
|
||||
for user_id, store in self.state.items()
|
||||
if store.pushable
|
||||
]
|
||||
)
|
||||
self.add_to_group(client, self.online_presence_watchers_group())
|
||||
|
||||
async def EndWatchingUserPresence(self, client: Client) -> None:
|
||||
self.remove_from_group(client, self.online_presence_watchers_group())
|
||||
355
app/signalr/hub/spectator.py
Normal file
355
app/signalr/hub/spectator.py
Normal file
@@ -0,0 +1,355 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import lzma
|
||||
import struct
|
||||
import time
|
||||
from typing import override
|
||||
|
||||
from app.database import Beatmap
|
||||
from app.database.score import Score
|
||||
from app.database.score_token import ScoreToken
|
||||
from app.database.user import User
|
||||
from app.dependencies.database import engine
|
||||
from app.models.beatmap import BeatmapRankStatus
|
||||
from app.models.mods import mods_to_int
|
||||
from app.models.score import LegacyReplaySoloScoreInfo, ScoreStatisticsInt
|
||||
from app.models.signalr import serialize_to_list
|
||||
from app.models.spectator_hub import (
|
||||
APIUser,
|
||||
FrameDataBundle,
|
||||
LegacyReplayFrame,
|
||||
ScoreInfo,
|
||||
SpectatedUserState,
|
||||
SpectatorState,
|
||||
StoreClientState,
|
||||
StoreScore,
|
||||
)
|
||||
from app.path import REPLAY_DIR
|
||||
from app.utils import unix_timestamp_to_windows
|
||||
|
||||
from .hub import Client, Hub
|
||||
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
READ_SCORE_TIMEOUT = 30
|
||||
REPLAY_LATEST_VER = 30000016
|
||||
|
||||
|
||||
def encode_uleb128(num: int) -> bytes | bytearray:
|
||||
if num == 0:
|
||||
return b"\x00"
|
||||
|
||||
ret = bytearray()
|
||||
|
||||
while num != 0:
|
||||
ret.append(num & 0x7F)
|
||||
num >>= 7
|
||||
if num != 0:
|
||||
ret[-1] |= 0x80
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def encode_string(s: str) -> bytes:
|
||||
"""Write `s` into bytes (ULEB128 & string)."""
|
||||
if s:
|
||||
encoded = s.encode()
|
||||
ret = b"\x0b" + encode_uleb128(len(encoded)) + encoded
|
||||
else:
|
||||
ret = b"\x00"
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def save_replay(
|
||||
ruleset_id: int,
|
||||
md5: str,
|
||||
username: str,
|
||||
score: Score,
|
||||
statistics: ScoreStatisticsInt,
|
||||
maximum_statistics: ScoreStatisticsInt,
|
||||
frames: list[LegacyReplayFrame],
|
||||
) -> None:
|
||||
data = bytearray()
|
||||
data.extend(struct.pack("<bi", ruleset_id, REPLAY_LATEST_VER))
|
||||
data.extend(encode_string(md5))
|
||||
data.extend(encode_string(username))
|
||||
data.extend(encode_string(f"lazer-{username}-{score.started_at.isoformat()}"))
|
||||
data.extend(
|
||||
struct.pack(
|
||||
"<hhhhhhihbi",
|
||||
score.n300,
|
||||
score.n100,
|
||||
score.n50,
|
||||
score.ngeki,
|
||||
score.nkatu,
|
||||
score.nmiss,
|
||||
score.total_score,
|
||||
score.max_combo,
|
||||
score.is_perfect_combo,
|
||||
mods_to_int(score.mods),
|
||||
)
|
||||
)
|
||||
data.extend(encode_string("")) # hp graph
|
||||
data.extend(
|
||||
struct.pack(
|
||||
"<q",
|
||||
unix_timestamp_to_windows(round(score.started_at.timestamp())),
|
||||
)
|
||||
)
|
||||
|
||||
# write frames
|
||||
# FIXME: cannot play in stable
|
||||
frame_strs = []
|
||||
last_time = 0
|
||||
for frame in frames:
|
||||
frame_strs.append(
|
||||
f"{frame.time - last_time}|{frame.x or 0.0}"
|
||||
f"|{frame.y or 0.0}|{frame.button_state}"
|
||||
)
|
||||
last_time = frame.time
|
||||
frame_strs.append("-12345|0|0|0")
|
||||
|
||||
compressed = lzma.compress(
|
||||
",".join(frame_strs).encode("ascii"), format=lzma.FORMAT_ALONE
|
||||
)
|
||||
data.extend(struct.pack("<i", len(compressed)))
|
||||
data.extend(compressed)
|
||||
data.extend(struct.pack("<q", score.id))
|
||||
assert score.id
|
||||
score_info = LegacyReplaySoloScoreInfo(
|
||||
online_id=score.id,
|
||||
mods=score.mods,
|
||||
statistics=statistics,
|
||||
maximum_statistics=maximum_statistics,
|
||||
client_version="",
|
||||
rank=score.rank,
|
||||
user_id=score.user_id,
|
||||
total_score_without_mods=score.total_score_without_mods,
|
||||
)
|
||||
compressed = lzma.compress(
|
||||
json.dumps(score_info).encode(), format=lzma.FORMAT_ALONE
|
||||
)
|
||||
data.extend(struct.pack("<i", len(compressed)))
|
||||
data.extend(compressed)
|
||||
|
||||
replay_path = REPLAY_DIR / f"lazer-{score.type}-{username}-{score.id}.osr"
|
||||
replay_path.write_bytes(data)
|
||||
|
||||
|
||||
class SpectatorHub(Hub[StoreClientState]):
|
||||
@staticmethod
|
||||
def group_id(user_id: int) -> str:
|
||||
return f"watch:{user_id}"
|
||||
|
||||
@override
|
||||
def create_state(self, client: Client) -> StoreClientState:
|
||||
return StoreClientState(
|
||||
connection_id=client.connection_id,
|
||||
connection_token=client.connection_token,
|
||||
)
|
||||
|
||||
@override
|
||||
async def _clean_state(self, state: StoreClientState) -> None:
|
||||
if state.state:
|
||||
await self._end_session(int(state.connection_id), state.state)
|
||||
for target in self.waited_clients:
|
||||
target_client = self.get_client_by_id(target)
|
||||
if target_client:
|
||||
await self.call_noblock(
|
||||
target_client, "UserEndedWatching", int(state.connection_id)
|
||||
)
|
||||
|
||||
async def on_client_connect(self, client: Client) -> None:
|
||||
tasks = [
|
||||
self.call_noblock(
|
||||
client, "UserBeganPlaying", user_id, serialize_to_list(store.state)
|
||||
)
|
||||
for user_id, store in self.state.items()
|
||||
if store.state is not None
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def BeginPlaySession(
|
||||
self, client: Client, score_token: int, state: SpectatorState
|
||||
) -> None:
|
||||
user_id = int(client.connection_id)
|
||||
store = self.get_or_create_state(client)
|
||||
if store.state is not None:
|
||||
return
|
||||
if state.beatmap_id is None or state.ruleset_id is None:
|
||||
return
|
||||
async with AsyncSession(engine) as session:
|
||||
async with session.begin():
|
||||
beatmap = (
|
||||
await session.exec(
|
||||
select(Beatmap).where(Beatmap.id == state.beatmap_id)
|
||||
)
|
||||
).first()
|
||||
if not beatmap:
|
||||
return
|
||||
user = (
|
||||
await session.exec(select(User).where(User.id == user_id))
|
||||
).first()
|
||||
if not user:
|
||||
return
|
||||
name = user.name
|
||||
store.state = state
|
||||
store.beatmap_status = beatmap.beatmap_status
|
||||
store.checksum = beatmap.checksum
|
||||
store.ruleset_id = state.ruleset_id
|
||||
store.score_token = score_token
|
||||
store.score = StoreScore(
|
||||
score_info=ScoreInfo(
|
||||
mods=state.mods,
|
||||
user=APIUser(id=user_id, name=name),
|
||||
ruleset=state.ruleset_id,
|
||||
maximum_statistics=state.maximum_statistics,
|
||||
)
|
||||
)
|
||||
await self.broadcast_group_call(
|
||||
self.group_id(user_id),
|
||||
"UserBeganPlaying",
|
||||
user_id,
|
||||
serialize_to_list(state),
|
||||
)
|
||||
|
||||
async def SendFrameData(self, client: Client, frame_data: FrameDataBundle) -> None:
|
||||
user_id = int(client.connection_id)
|
||||
state = self.get_or_create_state(client)
|
||||
if not state.score:
|
||||
return
|
||||
state.score.score_info.acc = frame_data.header.acc
|
||||
state.score.score_info.combo = frame_data.header.combo
|
||||
state.score.score_info.max_combo = frame_data.header.max_combo
|
||||
state.score.score_info.statistics = frame_data.header.statistics
|
||||
state.score.score_info.total_score = frame_data.header.total_score
|
||||
state.score.score_info.mods = frame_data.header.mods
|
||||
state.score.replay_frames.extend(frame_data.frames)
|
||||
await self.broadcast_group_call(
|
||||
self.group_id(user_id),
|
||||
"UserSentFrames",
|
||||
user_id,
|
||||
frame_data.model_dump(),
|
||||
)
|
||||
|
||||
async def EndPlaySession(self, client: Client, state: SpectatorState) -> None:
|
||||
user_id = int(client.connection_id)
|
||||
store = self.get_or_create_state(client)
|
||||
score = store.score
|
||||
if not score or not store.score_token:
|
||||
return
|
||||
|
||||
assert store.beatmap_status is not None
|
||||
|
||||
async def _save_replay():
|
||||
assert store.checksum is not None
|
||||
assert store.ruleset_id is not None
|
||||
assert store.state is not None
|
||||
assert store.score is not None
|
||||
async with AsyncSession(engine) as session:
|
||||
async with session:
|
||||
start_time = time.time()
|
||||
score_record = None
|
||||
while time.time() - start_time < READ_SCORE_TIMEOUT:
|
||||
sub_query = select(ScoreToken.score_id).where(
|
||||
ScoreToken.id == store.score_token,
|
||||
)
|
||||
result = await session.exec(
|
||||
select(Score)
|
||||
.options(joinedload(Score.beatmap)) # pyright: ignore[reportArgumentType]
|
||||
.where(
|
||||
Score.id == sub_query,
|
||||
Score.user_id == user_id,
|
||||
)
|
||||
)
|
||||
score_record = result.first()
|
||||
if score_record:
|
||||
break
|
||||
if not score_record:
|
||||
return
|
||||
if not score_record.passed:
|
||||
return
|
||||
score_record.has_replay = True
|
||||
await session.commit()
|
||||
await session.refresh(score_record)
|
||||
save_replay(
|
||||
ruleset_id=store.ruleset_id,
|
||||
md5=store.checksum,
|
||||
username=store.score.score_info.user.name,
|
||||
score=score_record,
|
||||
statistics=score.score_info.statistics,
|
||||
maximum_statistics=score.score_info.maximum_statistics,
|
||||
frames=score.replay_frames,
|
||||
)
|
||||
|
||||
if (
|
||||
(
|
||||
BeatmapRankStatus.PENDING
|
||||
< store.beatmap_status
|
||||
<= BeatmapRankStatus.LOVED
|
||||
)
|
||||
and any(
|
||||
k.is_hit() and v > 0 for k, v in score.score_info.statistics.items()
|
||||
)
|
||||
and state.state != SpectatedUserState.Failed
|
||||
):
|
||||
# save replay
|
||||
await _save_replay()
|
||||
store.state = None
|
||||
store.beatmap_status = None
|
||||
store.checksum = None
|
||||
store.ruleset_id = None
|
||||
store.score_token = None
|
||||
store.score = None
|
||||
await self._end_session(user_id, state)
|
||||
|
||||
async def _end_session(self, user_id: int, state: SpectatorState) -> None:
|
||||
if state.state == SpectatedUserState.Playing:
|
||||
state.state = SpectatedUserState.Quit
|
||||
await self.broadcast_group_call(
|
||||
self.group_id(user_id),
|
||||
"UserFinishedPlaying",
|
||||
user_id,
|
||||
serialize_to_list(state) if state else None,
|
||||
)
|
||||
|
||||
async def StartWatchingUser(self, client: Client, target_id: int) -> None:
|
||||
user_id = int(client.connection_id)
|
||||
target_store = self.get_or_create_state(client)
|
||||
if target_store.state:
|
||||
await self.call_noblock(
|
||||
client,
|
||||
"UserBeganPlaying",
|
||||
target_id,
|
||||
serialize_to_list(target_store.state),
|
||||
)
|
||||
store = self.get_or_create_state(client)
|
||||
store.watched_user.add(target_id)
|
||||
|
||||
self.add_to_group(client, self.group_id(target_id))
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
async with session.begin():
|
||||
username = (
|
||||
await session.exec(select(User.name).where(User.id == user_id))
|
||||
).first()
|
||||
if not username:
|
||||
return
|
||||
if (target_client := self.get_client_by_id(str(target_id))) is not None:
|
||||
await self.call_noblock(
|
||||
target_client, "UserStartedWatching", [[user_id, username]]
|
||||
)
|
||||
|
||||
async def EndWatchingUser(self, client: Client, target_id: int) -> None:
|
||||
user_id = int(client.connection_id)
|
||||
self.remove_from_group(client, self.group_id(target_id))
|
||||
store = self.state.get(user_id)
|
||||
if store:
|
||||
store.watched_user.discard(target_id)
|
||||
if (target_client := self.get_client_by_id(str(target_id))) is not None:
|
||||
await self.call_noblock(target_client, "UserEndedWatching", user_id)
|
||||
277
app/signalr/packet.py
Normal file
277
app/signalr/packet.py
Normal file
@@ -0,0 +1,277 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
import json
|
||||
from typing import (
|
||||
Any,
|
||||
Protocol as TypingProtocol,
|
||||
)
|
||||
|
||||
import msgpack
|
||||
|
||||
SEP = b"\x1e"
|
||||
|
||||
|
||||
class PacketType(IntEnum):
|
||||
INVOCATION = 1
|
||||
STREAM_ITEM = 2
|
||||
COMPLETION = 3
|
||||
STREAM_INVOCATION = 4
|
||||
CANCEL_INVOCATION = 5
|
||||
PING = 6
|
||||
CLOSE = 7
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Packet:
|
||||
type: PacketType
|
||||
header: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class InvocationPacket(Packet):
|
||||
type: PacketType = PacketType.INVOCATION
|
||||
invocation_id: str | None
|
||||
target: str
|
||||
arguments: list[Any] | None = None
|
||||
stream_ids: list[str] | None = None
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class CompletionPacket(Packet):
|
||||
type: PacketType = PacketType.COMPLETION
|
||||
invocation_id: str
|
||||
result: Any
|
||||
error: str | None = None
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class PingPacket(Packet):
|
||||
type: PacketType = PacketType.PING
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class ClosePacket(Packet):
|
||||
type: PacketType = PacketType.CLOSE
|
||||
error: str | None = None
|
||||
allow_reconnect: bool = False
|
||||
|
||||
|
||||
PACKETS = {
|
||||
PacketType.INVOCATION: InvocationPacket,
|
||||
PacketType.COMPLETION: CompletionPacket,
|
||||
PacketType.PING: PingPacket,
|
||||
PacketType.CLOSE: ClosePacket,
|
||||
}
|
||||
|
||||
|
||||
class Protocol(TypingProtocol):
|
||||
@staticmethod
|
||||
def decode(input: bytes) -> list[Packet]: ...
|
||||
|
||||
@staticmethod
|
||||
def encode(packet: Packet) -> bytes: ...
|
||||
|
||||
|
||||
class MsgpackProtocol:
|
||||
@staticmethod
|
||||
def _encode_varint(value: int) -> bytes:
|
||||
result = []
|
||||
while value >= 0x80:
|
||||
result.append((value & 0x7F) | 0x80)
|
||||
value >>= 7
|
||||
result.append(value & 0x7F)
|
||||
return bytes(result)
|
||||
|
||||
@staticmethod
|
||||
def _decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]:
|
||||
result = 0
|
||||
shift = 0
|
||||
pos = offset
|
||||
|
||||
while pos < len(data):
|
||||
byte = data[pos]
|
||||
result |= (byte & 0x7F) << shift
|
||||
pos += 1
|
||||
if (byte & 0x80) == 0:
|
||||
break
|
||||
shift += 7
|
||||
|
||||
return result, pos
|
||||
|
||||
@staticmethod
|
||||
def decode(input: bytes) -> list[Packet]:
|
||||
length, offset = MsgpackProtocol._decode_varint(input)
|
||||
message_data = input[offset : offset + length]
|
||||
# FIXME: custom deserializer for APIMod
|
||||
# https://github.com/ppy/osu/blob/master/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs
|
||||
unpacked = msgpack.unpackb(
|
||||
message_data, raw=False, strict_map_key=False, use_list=True
|
||||
)
|
||||
packet_type = PacketType(unpacked[0])
|
||||
if packet_type not in PACKETS:
|
||||
raise ValueError(f"Unknown packet type: {packet_type}")
|
||||
match packet_type:
|
||||
case PacketType.INVOCATION:
|
||||
return [
|
||||
InvocationPacket(
|
||||
header=unpacked[1],
|
||||
invocation_id=unpacked[2],
|
||||
target=unpacked[3],
|
||||
arguments=unpacked[4] if len(unpacked) > 4 else None,
|
||||
stream_ids=unpacked[5] if len(unpacked) > 5 else None,
|
||||
)
|
||||
]
|
||||
case PacketType.COMPLETION:
|
||||
result_kind = unpacked[3]
|
||||
return [
|
||||
CompletionPacket(
|
||||
header=unpacked[1],
|
||||
invocation_id=unpacked[2],
|
||||
error=unpacked[4] if result_kind == 1 else None,
|
||||
result=unpacked[5] if result_kind == 3 else None,
|
||||
)
|
||||
]
|
||||
case PacketType.PING:
|
||||
return [PingPacket()]
|
||||
case PacketType.CLOSE:
|
||||
return [
|
||||
ClosePacket(
|
||||
error=unpacked[1],
|
||||
allow_reconnect=unpacked[2] if len(unpacked) > 2 else False,
|
||||
)
|
||||
]
|
||||
raise ValueError(f"Unsupported packet type: {packet_type}")
|
||||
|
||||
@staticmethod
|
||||
def encode(packet: Packet) -> bytes:
|
||||
payload = [packet.type.value, packet.header or {}]
|
||||
if isinstance(packet, InvocationPacket):
|
||||
payload.extend(
|
||||
[
|
||||
packet.invocation_id,
|
||||
packet.target,
|
||||
]
|
||||
)
|
||||
if packet.arguments is not None:
|
||||
payload.append(packet.arguments)
|
||||
if packet.stream_ids is not None:
|
||||
payload.append(packet.stream_ids)
|
||||
elif isinstance(packet, CompletionPacket):
|
||||
result_kind = 2
|
||||
if packet.error:
|
||||
result_kind = 1
|
||||
elif packet.result is None:
|
||||
result_kind = 3
|
||||
payload.extend(
|
||||
[
|
||||
packet.invocation_id,
|
||||
result_kind,
|
||||
packet.error or packet.result or None,
|
||||
]
|
||||
)
|
||||
elif isinstance(packet, ClosePacket):
|
||||
payload.extend(
|
||||
[
|
||||
packet.error or "",
|
||||
packet.allow_reconnect,
|
||||
]
|
||||
)
|
||||
elif isinstance(packet, PingPacket):
|
||||
payload.pop(-1)
|
||||
data = msgpack.packb(payload, use_bin_type=True, datetime=True)
|
||||
return MsgpackProtocol._encode_varint(len(data)) + data
|
||||
|
||||
|
||||
class JSONProtocol:
|
||||
@staticmethod
|
||||
def decode(input: bytes) -> list[Packet]:
|
||||
packets_raw = input.removesuffix(SEP).split(SEP)
|
||||
packets = []
|
||||
if len(packets_raw) > 1:
|
||||
for packet_raw in packets_raw:
|
||||
packets.extend(JSONProtocol.decode(packet_raw))
|
||||
return packets
|
||||
else:
|
||||
data = json.loads(packets_raw[0])
|
||||
packet_type = PacketType(data["type"])
|
||||
if packet_type not in PACKETS:
|
||||
raise ValueError(f"Unknown packet type: {packet_type}")
|
||||
match packet_type:
|
||||
case PacketType.INVOCATION:
|
||||
return [
|
||||
InvocationPacket(
|
||||
header=data.get("header"),
|
||||
invocation_id=data.get("invocationId"),
|
||||
target=data["target"],
|
||||
arguments=data.get("arguments"),
|
||||
stream_ids=data.get("streamIds"),
|
||||
)
|
||||
]
|
||||
case PacketType.COMPLETION:
|
||||
return [
|
||||
CompletionPacket(
|
||||
header=data.get("header"),
|
||||
invocation_id=data["invocationId"],
|
||||
error=data.get("error"),
|
||||
result=data.get("result"),
|
||||
)
|
||||
]
|
||||
case PacketType.PING:
|
||||
return [PingPacket()]
|
||||
case PacketType.CLOSE:
|
||||
return [
|
||||
ClosePacket(
|
||||
error=data.get("error"),
|
||||
allow_reconnect=data.get("allowReconnect", False),
|
||||
)
|
||||
]
|
||||
raise ValueError(f"Unsupported packet type: {packet_type}")
|
||||
|
||||
@staticmethod
|
||||
def encode(packet: Packet) -> bytes:
|
||||
payload: dict[str, Any] = {
|
||||
"type": packet.type.value,
|
||||
}
|
||||
if packet.header:
|
||||
payload["header"] = packet.header
|
||||
if isinstance(packet, InvocationPacket):
|
||||
payload.update(
|
||||
{
|
||||
"target": packet.target,
|
||||
}
|
||||
)
|
||||
if packet.invocation_id is not None:
|
||||
payload["invocationId"] = packet.invocation_id
|
||||
if packet.arguments is not None:
|
||||
payload["arguments"] = packet.arguments
|
||||
if packet.stream_ids is not None:
|
||||
payload["streamIds"] = packet.stream_ids
|
||||
elif isinstance(packet, CompletionPacket):
|
||||
payload.update(
|
||||
{
|
||||
"invocationId": packet.invocation_id,
|
||||
}
|
||||
)
|
||||
if packet.error is not None:
|
||||
payload["error"] = packet.error
|
||||
if packet.result is not None:
|
||||
payload["result"] = packet.result
|
||||
elif isinstance(packet, PingPacket):
|
||||
pass
|
||||
elif isinstance(packet, ClosePacket):
|
||||
payload.update(
|
||||
{
|
||||
"allowReconnect": packet.allow_reconnect,
|
||||
}
|
||||
)
|
||||
if packet.error is not None:
|
||||
payload["error"] = packet.error
|
||||
return json.dumps(payload).encode("utf-8") + SEP
|
||||
|
||||
|
||||
PROTOCOLS: dict[str, Protocol] = {
|
||||
"json": JSONProtocol,
|
||||
"messagepack": MsgpackProtocol,
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Literal
|
||||
@@ -10,9 +11,9 @@ from app.dependencies import get_current_user
|
||||
from app.dependencies.database import get_db
|
||||
from app.dependencies.user import get_current_user_by_token
|
||||
from app.models.signalr import NegotiateResponse, Transport
|
||||
from app.router.signalr.packet import SEP
|
||||
|
||||
from .hub import Hubs
|
||||
from .packet import PROTOCOLS, SEP
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, Query, WebSocket
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -62,30 +63,41 @@ async def connect(
|
||||
await websocket.accept()
|
||||
|
||||
# handshake
|
||||
handshake = await websocket.receive_bytes()
|
||||
handshake_payload = json.loads(handshake[:-1])
|
||||
handshake = await websocket.receive()
|
||||
message = handshake.get("bytes") or handshake.get("text")
|
||||
if not message:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
handshake_payload = json.loads(message[:-1])
|
||||
error = ""
|
||||
if (protocol := handshake_payload.get("protocol")) != "messagepack" or (
|
||||
handshake_payload.get("version")
|
||||
) != 1:
|
||||
error = f"Requested protocol '{protocol}' is not available."
|
||||
protocol = handshake_payload.get("protocol", "json")
|
||||
|
||||
client = None
|
||||
try:
|
||||
client = hub_.add_client(
|
||||
client = await hub_.add_client(
|
||||
connection_id=user_id,
|
||||
connection_token=id,
|
||||
connection=websocket,
|
||||
protocol=PROTOCOLS[protocol],
|
||||
)
|
||||
except KeyError:
|
||||
error = f"Protocol '{protocol}' is not supported."
|
||||
except TimeoutError:
|
||||
error = f"Connection {id} has waited too long."
|
||||
except ValueError as e:
|
||||
error = str(e)
|
||||
payload = {"error": error} if error else {}
|
||||
|
||||
# finish handshake
|
||||
await websocket.send_bytes(json.dumps(payload).encode() + SEP)
|
||||
if error or not client:
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
await hub_.clean_state(client, False)
|
||||
task = asyncio.create_task(hub_.on_connect(client))
|
||||
hub_.tasks.add(task)
|
||||
task.add_done_callback(hub_.tasks.discard)
|
||||
await hub_._listen_client(client)
|
||||
try:
|
||||
await websocket.close()
|
||||
except Exception:
|
||||
...
|
||||
@@ -2,9 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from typing import Any, Literal
|
||||
|
||||
from app.router.signalr.packet import ResultKind
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ResultStore:
|
||||
@@ -22,21 +20,17 @@ class ResultStore:
|
||||
return str(s)
|
||||
|
||||
def add_result(
|
||||
self, invocation_id: str, type: ResultKind, result: dict[str, Any] | None
|
||||
self, invocation_id: str, result: Any, error: str | None = None
|
||||
) -> None:
|
||||
if isinstance(invocation_id, str) and invocation_id.isdecimal():
|
||||
if future := self._futures.get(invocation_id):
|
||||
future.set_result((type, result))
|
||||
future.set_result((result, error))
|
||||
|
||||
async def fetch(
|
||||
self,
|
||||
invocation_id: str,
|
||||
timeout: float | None, # noqa: ASYNC109
|
||||
) -> (
|
||||
tuple[Literal[ResultKind.ERROR], str]
|
||||
| tuple[Literal[ResultKind.VOID], None]
|
||||
| tuple[Literal[ResultKind.HAS_VALUE], Any]
|
||||
):
|
||||
) -> tuple[Any, str | None]:
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self._futures[invocation_id] = future
|
||||
try:
|
||||
@@ -2,24 +2,20 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import inspect
|
||||
import sys
|
||||
from typing import Any, ForwardRef, cast
|
||||
|
||||
# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L61-L75
|
||||
if sys.version_info < (3, 12, 4):
|
||||
|
||||
# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L56-L66
|
||||
def evaluate_forwardref(
|
||||
type_: ForwardRef,
|
||||
globalns: Any,
|
||||
localns: Any,
|
||||
) -> Any:
|
||||
# Even though it is the right signature for python 3.9,
|
||||
# mypy complains with
|
||||
# `error: Too many arguments for "_evaluate" of
|
||||
# "ForwardRef"` hence the cast...
|
||||
return cast(Any, type_)._evaluate(
|
||||
globalns,
|
||||
localns,
|
||||
set(),
|
||||
)
|
||||
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
|
||||
return cast(Any, type_)._evaluate(globalns, localns, recursive_guard=set())
|
||||
else:
|
||||
|
||||
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
|
||||
return cast(Any, type_)._evaluate(
|
||||
globalns, localns, type_params=(), recursive_guard=set()
|
||||
)
|
||||
|
||||
|
||||
def get_annotation(param: inspect.Parameter, globalns: dict[str, Any]) -> Any:
|
||||
@@ -28,6 +28,11 @@ from app.models.user import (
|
||||
import rosu_pp_py as rosu
|
||||
|
||||
|
||||
def unix_timestamp_to_windows(timestamp: int) -> int:
|
||||
"""Convert a Unix timestamp to a Windows timestamp."""
|
||||
return (timestamp + 62135596800) * 10_000_000
|
||||
|
||||
|
||||
async def convert_db_user_to_api_user(db_user: DBUser, ruleset: str = "osu") -> User:
|
||||
"""将数据库用户模型转换为API用户模型(使用 Lazer 表)"""
|
||||
|
||||
@@ -205,7 +210,7 @@ async def convert_db_user_to_api_user(db_user: DBUser, ruleset: str = "osu") ->
|
||||
# 转换团队信息
|
||||
team = None
|
||||
if db_user.team_membership:
|
||||
team_member = db_user.team_membership[0] # 假设用户只属于一个团队
|
||||
team_member = db_user.team_membership # 假设用户只属于一个团队
|
||||
team = team_member.team
|
||||
|
||||
# 创建用户对象
|
||||
|
||||
Reference in New Issue
Block a user