feat: 基础 SignalR 服务器支持

This commit is contained in:
MingxuanGame
2025-07-24 18:45:08 +08:00
committed by GitHub
parent 6ed5a2d347
commit 1655bb9f53
28 changed files with 1394 additions and 644 deletions

14
.vscode/launch.json vendored Normal file
View File

@@ -0,0 +1,14 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "Server",
"type": "debugpy",
"request": "launch",
"program": "main.py",
"console": "integratedTerminal",
"justMyCode": true
}
]
}

View File

@@ -1,6 +1,7 @@
FROM python:3.11-slim
FROM ghcr.io/astral-sh/uv:python3.11-bookworm-slim
WORKDIR /app
ENV UV_PROJECT_ENVIRONMENT venv
# 安装系统依赖
RUN apt-get update && apt-get install -y \
@@ -10,10 +11,11 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/*
# 复制依赖文件
COPY requirements.txt .
COPY uv.lock .
COPY pyproject.toml .
# 安装Python依赖
RUN pip install --no-cache-dir -r requirements.txt
RUN uv sync --locked
# 复制应用代码
COPY . .
@@ -22,4 +24,4 @@ COPY . .
EXPOSE 8000
# 启动命令
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
CMD ["uv", "run", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]

View File

@@ -32,5 +32,9 @@ class Settings:
PORT: int = int(os.getenv("PORT", "8000"))
DEBUG: bool = os.getenv("DEBUG", "True").lower() == "true"
# SignalR 设置
SIGNALR_NEGOTIATE_TIMEOUT: int = int(os.getenv("SIGNALR_NEGOTIATE_TIMEOUT", "30"))
SIGNALR_PING_INTERVAL: int = int(os.getenv("SIGNALR_PING_INTERVAL", "120"))
settings = Settings()

View File

@@ -19,14 +19,15 @@ async def get_current_user(
"""获取当前认证用户"""
token = credentials.credentials
# 验证令牌
user = await get_current_user_by_token(token, db)
if not user:
raise HTTPException(status_code=401, detail="Invalid or expired token")
return user
async def get_current_user_by_token(token: str, db: Session) -> DBUser | None:
token_record = get_token_by_access_token(db, token)
if not token_record:
raise HTTPException(status_code=401, detail="Invalid or expired token")
# 获取用户
return None
user = db.query(DBUser).filter(DBUser.id == token_record.user_id).first()
if not user:
raise HTTPException(status_code=404, detail="User not found")
return user

1
app/models/__init__.py Normal file
View File

@@ -0,0 +1 @@

29
app/models/oauth.py Normal file
View File

@@ -0,0 +1,29 @@
# OAuth 相关模型
from __future__ import annotations
from pydantic import BaseModel
class TokenRequest(BaseModel):
grant_type: str
username: str | None = None
password: str | None = None
refresh_token: str | None = None
client_id: str
client_secret: str
scope: str = "*"
class TokenResponse(BaseModel):
access_token: str
token_type: str = "Bearer"
expires_in: int
refresh_token: str
scope: str = "*"
class UserCreate(BaseModel):
username: str
password: str
email: str
country_code: str = "CN"

40
app/models/score.py Normal file
View File

@@ -0,0 +1,40 @@
from enum import Enum, IntEnum
from typing import Any
from pydantic import BaseModel
class GameMode(str, Enum):
OSU = "osu"
TAIKO = "taiko"
FRUITS = "fruits"
MANIA = "mania"
class APIMod(BaseModel):
acronym: str
settings: dict[str, Any] = {}
# https://github.com/ppy/osu/blob/master/osu.Game/Rulesets/Scoring/HitResult.cs
class HitResult(IntEnum):
PERFECT = 0 # [Order(0)]
GREAT = 1 # [Order(1)]
GOOD = 2 # [Order(2)]
OK = 3 # [Order(3)]
MEH = 4 # [Order(4)]
MISS = 5 # [Order(5)]
LARGE_TICK_HIT = 6 # [Order(6)]
SMALL_TICK_HIT = 7 # [Order(7)]
SLIDER_TAIL_HIT = 8 # [Order(8)]
LARGE_BONUS = 9 # [Order(9)]
SMALL_BONUS = 10 # [Order(10)]
LARGE_TICK_MISS = 11 # [Order(11)]
SMALL_TICK_MISS = 12 # [Order(12)]
IGNORE_HIT = 13 # [Order(13)]
IGNORE_MISS = 14 # [Order(14)]
NONE = 15 # [Order(15)]
COMBO_BREAK = 16 # [Order(16)]
LEGACY_COMBO_INCREASE = 99 # [Order(99)] @deprecated

31
app/models/signalr.py Normal file
View File

@@ -0,0 +1,31 @@
from __future__ import annotations
from typing import Any
from pydantic import BaseModel, Field, model_validator
class MessagePackArrayModel(BaseModel):
@model_validator(mode="before")
@classmethod
def unpack(cls, v: Any) -> Any:
if isinstance(v, list):
fields = list(cls.model_fields.keys())
if len(v) != len(fields):
raise ValueError(f"Expected list of length {len(fields)}, got {len(v)}")
return dict(zip(fields, v))
return v
class Transport(BaseModel):
transport: str
transfer_formats: list[str] = Field(
default_factory=lambda: ["Binary"], alias="transferFormats"
)
class NegotiateResponse(BaseModel):
connectionId: str
connectionToken: str
negotiateVersion: int = 1
availableTransports: list[Transport]

View File

@@ -0,0 +1,99 @@
from __future__ import annotations
import datetime
from enum import IntEnum
from typing import Any
import msgpack
from pydantic import Field, field_validator
from .signalr import MessagePackArrayModel
from .score import (
APIMod as APIModBase,
HitResult,
)
class APIMod(APIModBase, MessagePackArrayModel): ...
class SpectatedUserState(IntEnum):
Idle = 0
Playing = 1
Paused = 2
Passed = 3
Failed = 4
Quit = 5
class SpectatorState(MessagePackArrayModel):
beatmap_id: int | None = None
ruleset_id: int | None = None # 0,1,2,3
mods: list[APIMod] = Field(default_factory=list)
state: SpectatedUserState
maximum_statistics: dict[HitResult, int] = Field(default_factory=dict)
def __eq__(self, other: object) -> bool:
if not isinstance(other, SpectatorState):
return False
return (
self.beatmap_id == other.beatmap_id
and self.ruleset_id == other.ruleset_id
and self.mods == other.mods
and self.state == other.state
)
class ScoreProcessorStatistics(MessagePackArrayModel):
base_score: int
maximum_base_score: int
accuracy_judgement_count: int
combo_portion: float
bouns_portion: float
class FrameHeader(MessagePackArrayModel):
total_score: int
acc: float
combo: int
max_combo: int
statistics: dict[HitResult, int] = Field(default_factory=dict)
score_processor_statistics: ScoreProcessorStatistics
received_time: datetime.datetime
mods: list[APIMod] = Field(default_factory=list)
@field_validator("received_time", mode="before")
@classmethod
def validate_timestamp(cls, v: Any) -> datetime.datetime:
if isinstance(v, msgpack.ext.Timestamp):
return v.to_datetime()
if isinstance(v, list):
return v[0].to_datetime()
if isinstance(v, datetime.datetime):
return v
if isinstance(v, int | float):
return datetime.datetime.fromtimestamp(v, tz=datetime.UTC)
if isinstance(v, str):
return datetime.datetime.fromisoformat(v)
raise ValueError(f"Cannot convert {type(v)} to datetime")
class ReplayButtonState(IntEnum):
NONE = 0
LEFT1 = 1
RIGHT1 = 2
LEFT2 = 4
RIGHT2 = 8
SMOKE = 16
class LegacyReplayFrame(MessagePackArrayModel):
time: int # from ReplayFrame,the parent of LegacyReplayFrame
x: float | None = None
y: float | None = None
button_state: ReplayButtonState
class FrameDataBundle(MessagePackArrayModel):
header: FrameHeader
frames: list[LegacyReplayFrame]

View File

@@ -2,19 +2,13 @@ from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import Optional
from .score import GameMode
from pydantic import BaseModel
from app.database import LazerUserAchievement # 添加数据库模型导入
class GameMode(str, Enum):
OSU = "osu"
TAIKO = "taiko"
FRUITS = "fruits"
MANIA = "mania"
class PlayStyle(str, Enum):
MOUSE = "mouse"
KEYBOARD = "keyboard"
@@ -28,9 +22,9 @@ class Country(BaseModel):
class Cover(BaseModel):
custom_url: Optional[str] = None
custom_url: str | None = None
url: str
id: Optional[int] = None
id: int | None = None
class Level(BaseModel):
@@ -52,8 +46,8 @@ class Statistics(BaseModel):
count_50: int = 0
count_miss: int = 0
level: Level
global_rank: Optional[int] = None
global_rank_exp: Optional[int] = None
global_rank: int | None = None
global_rank_exp: int | None = None
pp: float = 0.0
pp_exp: float = 0.0
ranked_score: int = 0
@@ -66,8 +60,8 @@ class Statistics(BaseModel):
replays_watched_by_others: int = 0
is_ranked: bool = False
grade_counts: GradeCounts
country_rank: Optional[int] = None
rank: Optional[dict] = None
country_rank: int | None = None
rank: dict | None = None
class Kudosu(BaseModel):
@@ -106,8 +100,8 @@ class RankHistory(BaseModel):
class DailyChallengeStats(BaseModel):
daily_streak_best: int = 0
daily_streak_current: int = 0
last_update: Optional[datetime] = None
last_weekly_streak: Optional[datetime] = None
last_update: datetime | None = None
last_weekly_streak: datetime | None = None
playcount: int = 0
top_10p_placements: int = 0
top_50p_placements: int = 0
@@ -141,24 +135,24 @@ class User(BaseModel):
is_online: bool = True
is_supporter: bool = False
is_restricted: bool = False
last_visit: Optional[datetime] = None
last_visit: datetime | None = None
pm_friends_only: bool = False
profile_colour: Optional[str] = None
profile_colour: str | None = None
# 个人资料
cover_url: Optional[str] = None
discord: Optional[str] = None
cover_url: str | None = None
discord: str | None = None
has_supported: bool = False
interests: Optional[str] = None
interests: str | None = None
join_date: datetime
location: Optional[str] = None
location: str | None = None
max_blocks: int = 100
max_friends: int = 500
occupation: Optional[str] = None
occupation: str | None = None
playmode: GameMode = GameMode.OSU
playstyle: list[PlayStyle] = []
post_count: int = 0
profile_hue: Optional[int] = None
profile_hue: int | None = None
profile_order: list[str] = [
"me",
"recent_activity",
@@ -168,10 +162,10 @@ class User(BaseModel):
"beatmaps",
"kudosu",
]
title: Optional[str] = None
title_url: Optional[str] = None
twitter: Optional[str] = None
website: Optional[str] = None
title: str | None = None
title_url: str | None = None
twitter: str | None = None
website: str | None = None
session_verified: bool = False
support_level: int = 0
@@ -203,44 +197,18 @@ class User(BaseModel):
# 历史数据
account_history: list[dict] = []
active_tournament_banner: Optional[dict] = None
active_tournament_banner: dict | None = None
active_tournament_banners: list[dict] = []
badges: list[dict] = []
current_season_stats: Optional[dict] = None
daily_challenge_user_stats: Optional[DailyChallengeStats] = None
current_season_stats: dict | None = None
daily_challenge_user_stats: DailyChallengeStats | None = None
groups: list[dict] = []
monthly_playcounts: list[MonthlyPlaycount] = []
page: Page = Page()
previous_usernames: list[str] = []
rank_highest: Optional[RankHighest] = None
rank_history: Optional[RankHistory] = None
rankHistory: Optional[RankHistory] = None # 兼容性别名
rank_highest: RankHighest | None = None
rank_history: RankHistory | None = None
rankHistory: RankHistory | None = None # 兼容性别名
replays_watched_counts: list[dict] = []
team: Optional[Team] = None
team: Team | None = None
user_achievements: list[UserAchievement] = []
# OAuth 相关模型
class TokenRequest(BaseModel):
grant_type: str
username: Optional[str] = None
password: Optional[str] = None
refresh_token: Optional[str] = None
client_id: str
client_secret: str
scope: str = "*"
class TokenResponse(BaseModel):
access_token: str
token_type: str = "Bearer"
expires_in: int
refresh_token: str
scope: str = "*"
class UserCreate(BaseModel):
username: str
password: str
email: str
country_code: str = "CN"

View File

@@ -3,3 +3,4 @@ from __future__ import annotations
from . import me # pyright: ignore[reportUnusedImport] # noqa: F401
from .api_router import router as api_router
from .auth import router as auth_router
from .signalr import signalr_router as signalr_router

View File

@@ -11,7 +11,7 @@ from app.auth import (
)
from app.config import settings
from app.dependencies import get_db
from app.models import TokenResponse
from app.models.oauth import TokenResponse
from fastapi import APIRouter, Depends, Form, HTTPException
from sqlalchemy.orm import Session

View File

@@ -6,7 +6,7 @@ from app.database import (
User as DBUser,
)
from app.dependencies import get_current_user, get_db
from app.models import (
from app.models.user import (
User as ApiUser,
)
from app.utils import convert_db_user_to_api_user
@@ -18,6 +18,7 @@ from sqlalchemy.orm import Session
@router.get("/me/{ruleset}", response_model=ApiUser)
@router.get("/me/", response_model=ApiUser)
async def get_user_info_default(
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
current_user: DBUser = Depends(get_current_user),

View File

@@ -0,0 +1 @@
from .router import router as signalr_router

View File

@@ -0,0 +1,6 @@
class SignalRException(Exception):
pass
class InvokeException(SignalRException):
def __init__(self, message: str) -> None:
self.message = message

View File

@@ -0,0 +1,15 @@
from __future__ import annotations
from .hub import Hub
from .metadata import MetadataHub
from .multiplayer import MultiplayerHub
from .spectator import SpectatorHub
SpectatorHubs = SpectatorHub()
MultiplayerHubs = MultiplayerHub()
MetadataHubs = MetadataHub()
Hubs: dict[str, Hub] = {
"spectator": SpectatorHubs,
"multiplayer": MultiplayerHubs,
"metadata": MetadataHubs,
}

View File

@@ -0,0 +1,211 @@
from __future__ import annotations
import asyncio
import time
from typing import Any
from app.config import settings
from app.router.signalr.exception import InvokeException
from app.router.signalr.packet import (
PacketType,
ResultKind,
encode_varint,
parse_packet,
)
from app.router.signalr.store import ResultStore
from app.router.signalr.utils import get_signature
from fastapi import WebSocket
import msgpack
from pydantic import BaseModel
from starlette.websockets import WebSocketDisconnect
class Client:
def __init__(
self, connection_id: str, connection_token: str, connection: WebSocket
) -> None:
self.connection_id = connection_id
self.connection_token = connection_token
self.connection = connection
self._listen_task: asyncio.Task | None = None
self._ping_task: asyncio.Task | None = None
self._store = ResultStore()
async def send_packet(self, type: PacketType, packet: list[Any]):
packet.insert(0, type.value)
payload = msgpack.packb(packet)
length = encode_varint(len(payload))
await self.connection.send_bytes(length + payload)
async def _ping(self):
while True:
try:
await self.send_packet(PacketType.PING, [])
await asyncio.sleep(settings.SIGNALR_PING_INTERVAL)
except WebSocketDisconnect:
break
except Exception as e:
print(f"Error in ping task for {self.connection_id}: {e}")
break
class Hub:
def __init__(self) -> None:
self.clients: dict[str, Client] = {}
self.waited_clients: dict[str, int] = {}
self.tasks: set[asyncio.Task] = set()
def add_waited_client(self, connection_token: str, timestamp: int) -> None:
self.waited_clients[connection_token] = timestamp
def add_client(
self, connection_id: str, connection_token: str, connection: WebSocket
) -> Client:
if connection_token in self.clients:
raise ValueError(
f"Client with connection token {connection_token} already exists."
)
if connection_token in self.waited_clients:
if (
self.waited_clients[connection_token]
< time.time() - settings.SIGNALR_NEGOTIATE_TIMEOUT
):
raise TimeoutError(f"Connection {connection_id} has waited too long.")
del self.waited_clients[connection_token]
client = Client(connection_id, connection_token, connection)
self.clients[connection_token] = client
task = asyncio.create_task(client._ping())
self.tasks.add(task)
client._ping_task = task
return client
async def remove_client(self, connection_id: str) -> None:
if client := self.clients.get(connection_id):
del self.clients[connection_id]
if client._listen_task:
client._listen_task.cancel()
if client._ping_task:
client._ping_task.cancel()
await client.connection.close()
async def send_packet(self, client: Client, type: PacketType, packet: list[Any]):
await client.send_packet(type, packet)
async def _listen_client(self, client: Client) -> None:
jump = False
while not jump:
try:
message = await client.connection.receive_bytes()
packet_type, packet_data = parse_packet(message)
task = asyncio.create_task(
self._handle_packet(client, packet_type, packet_data)
)
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
except WebSocketDisconnect as e:
if e.code == 1005:
continue
print(
f"Client {client.connection_id} disconnected: {e.code}, {e.reason}"
)
jump = True
except Exception as e:
print(f"Error in client {client.connection_id}: {e}")
jump = True
await self.remove_client(client.connection_id)
async def _handle_packet(
self, client: Client, type: PacketType, packet: list[Any]
) -> None:
match type:
case PacketType.PING:
...
case PacketType.INVOCATION:
invocation_id: str | None = packet[1] # pyright: ignore[reportRedeclaration]
target: str = packet[2]
args: list[Any] | None = packet[3]
if args is None:
args = []
streams: list[str] | None = packet[4] # TODO: stream support
code = ResultKind.VOID
result = None
try:
result = await self.invoke_method(client, target, args)
if result is not None:
code = ResultKind.HAS_VALUE
except InvokeException as e:
code = ResultKind.ERROR
result = e.message
except Exception as e:
code = ResultKind.ERROR
result = str(e)
packet = [
{}, # header
invocation_id,
code.value,
]
if result is not None:
packet.append(result)
if invocation_id is not None:
await client.send_packet(
PacketType.COMPLETION,
packet,
)
case PacketType.COMPLETION:
invocation_id: str = packet[1]
code: ResultKind = ResultKind(packet[2])
result: Any = packet[3] if len(packet) > 3 else None
client._store.add_result(invocation_id, code, result)
async def invoke_method(self, client: Client, method: str, args: list[Any]) -> Any:
method_ = getattr(self, method, None)
call_params = []
if not method_:
raise InvokeException(f"Method '{method}' not found in hub.")
signature = get_signature(method_)
for name, param in signature.parameters.items():
if name == "self" or param.annotation is Client:
continue
if issubclass(param.annotation, BaseModel):
call_params.append(param.annotation.model_validate(args.pop(0)))
else:
call_params.append(args.pop(0))
return await method_(client, *call_params)
async def call(self, client: Client, method: str, *args: Any) -> Any:
invocation_id = client._store.get_invocation_id()
await client.send_packet(
PacketType.INVOCATION,
[
{}, # header
invocation_id,
method,
list(args),
None, # streams
],
)
r = await client._store.fetch(invocation_id, None)
if r[0] == ResultKind.HAS_VALUE:
return r[1]
if r[0] == ResultKind.ERROR:
raise InvokeException(r[1])
return None
async def call_noblock(self, client: Client, method: str, *args: Any) -> None:
await client.send_packet(
PacketType.INVOCATION,
[
{}, # header
None, # invocation_id
method,
list(args),
None, # streams
],
)
return None
def __contains__(self, item: str) -> bool:
return item in self.clients or item in self.waited_clients

View File

@@ -0,0 +1,4 @@
from .hub import Hub
class MetadataHub(Hub): ...

View File

@@ -0,0 +1,4 @@
from .hub import Hub
class MultiplayerHub(Hub): ...

View File

@@ -0,0 +1,17 @@
from __future__ import annotations
from app.models.spectator_hub import FrameDataBundle, SpectatorState
from .hub import Client, Hub
class SpectatorHub(Hub):
async def BeginPlaySession(
self, client: Client, score_token: int, state: SpectatorState
) -> None:
...
async def SendFrameData(
self, client: Client, frame_data: FrameDataBundle
) -> None:
...

View File

@@ -0,0 +1,56 @@
from __future__ import annotations
from enum import IntEnum
from typing import Any
import msgpack
from pydantic import BaseModel, model_validator
SEP = b"\x1e"
class PacketType(IntEnum):
INVOCATION = 1
STREAM_ITEM = 2
COMPLETION = 3
STREAM_INVOCATION = 4
CANCEL_INVOCATION = 5
PING = 6
CLOSE = 7
class ResultKind(IntEnum):
ERROR = 1
VOID = 2
HAS_VALUE = 3
def parse_packet(data: bytes) -> tuple[PacketType, list[Any]]:
length, offset = decode_varint(data)
message_data = data[offset : offset + length]
unpacked = msgpack.unpackb(message_data, raw=False)
return PacketType(unpacked[0]), unpacked[1:]
def encode_varint(value: int) -> bytes:
result = []
while value >= 0x80:
result.append((value & 0x7F) | 0x80)
value >>= 7
result.append(value & 0x7F)
return bytes(result)
def decode_varint(data: bytes, offset: int = 0) -> tuple[int, int]:
result = 0
shift = 0
pos = offset
while pos < len(data):
byte = data[pos]
result |= (byte & 0x7F) << shift
pos += 1
if (byte & 0x80) == 0:
break
shift += 7
return result, pos

View File

@@ -0,0 +1,93 @@
from __future__ import annotations
import json
from logging import info
import time
from typing import Literal
import uuid
from app.database import User as DBUser
from app.dependencies import get_current_user
from app.dependencies.database import get_db
from app.dependencies.user import get_current_user_by_token, security
from app.models.signalr import NegotiateResponse, Transport
from app.router.signalr.packet import SEP
from .hub import Hubs
from fastapi import APIRouter, Depends, Header, Query, WebSocket
from fastapi.security import HTTPAuthorizationCredentials
from sqlalchemy.orm import Session
router = APIRouter()
@router.post("/{hub}/negotiate", response_model=NegotiateResponse)
async def negotiate(
hub: Literal["spectator", "multiplayer", "metadata"],
negotiate_version: int = Query(1, alias="negotiateVersion"),
user: DBUser = Depends(get_current_user),
):
connectionId = str(user.id)
connectionToken = f"{connectionId}:{uuid.uuid4()}"
Hubs[hub].add_waited_client(
connection_token=connectionToken,
timestamp=int(time.time()),
)
return NegotiateResponse(
connectionId=connectionId,
connectionToken=connectionToken,
negotiateVersion=negotiate_version,
availableTransports=[Transport(transport="WebSockets")],
)
@router.websocket("/{hub}")
async def connect(
hub: Literal["spectator", "multiplayer", "metadata"],
websocket: WebSocket,
id: str,
authorization: str = Header(...),
db: Session = Depends(get_db),
):
token = authorization[7:]
user_id = id.split(":")[0]
hub_ = Hubs[hub]
if id not in hub_:
await websocket.close(code=1008)
return
if (user := await get_current_user_by_token(token, db)) is None or str(
user.id
) != user_id:
await websocket.close(code=1008)
return
await websocket.accept()
# handshake
handshake = await websocket.receive_bytes()
handshake_payload = json.loads(handshake[:-1])
error = ""
if (protocol := handshake_payload.get("protocol")) != "messagepack" or (
handshake_payload.get("version")
) != 1:
error = f"Requested protocol '{protocol}' is not available."
client = None
try:
client = hub_.add_client(
connection_id=user_id,
connection_token=id,
connection=websocket,
)
except TimeoutError:
error = f"Connection {id} has waited too long."
except ValueError as e:
error = str(e)
payload = {"error": error} if error else {}
# finish handshake
await websocket.send_bytes(json.dumps(payload).encode() + SEP)
if error or not client:
await websocket.close(code=1008)
return
await hub_._listen_client(client)

View File

@@ -0,0 +1,45 @@
from __future__ import annotations
import asyncio
import sys
from typing import Any, Literal
from app.router.signalr.packet import ResultKind
class ResultStore:
def __init__(self) -> None:
self._seq: int = 1
self._futures: dict[str, asyncio.Future] = {}
@property
def current_invocation_id(self) -> int:
return self._seq
def get_invocation_id(self) -> str:
s = self._seq
self._seq = (self._seq + 1) % sys.maxsize
return str(s)
def add_result(
self, invocation_id: str, type: ResultKind, result: dict[str, Any] | None
) -> None:
if isinstance(invocation_id, str) and invocation_id.isdecimal():
if future := self._futures.get(invocation_id):
future.set_result((type, result))
async def fetch(
self,
invocation_id: str,
timeout: float | None, # noqa: ASYNC109
) -> (
tuple[Literal[ResultKind.ERROR], str]
| tuple[Literal[ResultKind.VOID], None]
| tuple[Literal[ResultKind.HAS_VALUE], Any]
):
future = asyncio.get_event_loop().create_future()
self._futures[invocation_id] = future
try:
return await asyncio.wait_for(future, timeout)
finally:
del self._futures[invocation_id]

View File

@@ -0,0 +1,45 @@
import inspect
from typing import Any, Callable, ForwardRef, cast
# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L56-L66
def evaluate_forwardref(
type_: ForwardRef,
globalns: Any,
localns: Any,
) -> Any:
# Even though it is the right signature for python 3.9,
# mypy complains with
# `error: Too many arguments for "_evaluate" of
# "ForwardRef"` hence the cast...
return cast(Any, type_)._evaluate(
globalns,
localns,
set(),
)
def get_annotation(param: inspect.Parameter, globalns: dict[str, Any]) -> Any:
annotation = param.annotation
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
try:
annotation = evaluate_forwardref(annotation, globalns, globalns)
except Exception:
return inspect.Parameter.empty
return annotation
def get_signature(call: Callable[..., Any]) -> inspect.Signature:
signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {})
typed_params = [
inspect.Parameter(
name=param.name,
kind=param.kind,
default=param.default,
annotation=get_annotation(param, globalns),
)
for param in signature.parameters.values()
]
return inspect.Signature(typed_params)

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from datetime import datetime
from app.database import User as DBUser
from app.models import (
from app.models.user import (
Country,
Cover,
DailyChallengeStats,
@@ -598,3 +598,5 @@ class MockLazerTournamentBanner:
MockLazerTournamentBanner(1, "https://example.com/banner1.jpg", True),
MockLazerTournamentBanner(2, "https://example.com/banner2.jpg", False),
]

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from datetime import datetime
from app.config import settings
from app.router import api_router, auth_router
from app.router import api_router, auth_router, signalr_router
from fastapi import FastAPI
@@ -12,6 +12,7 @@ from fastapi import FastAPI
app = FastAPI(title="osu! API 模拟服务器", version="1.0.0")
app.include_router(api_router, prefix="/api/v2")
app.include_router(signalr_router, prefix="/signalr")
app.include_router(auth_router)

View File

@@ -9,6 +9,7 @@ dependencies = [
"bcrypt>=4.1.2",
"cryptography>=41.0.7",
"fastapi>=0.104.1",
"msgpack>=1.1.1",
"passlib[bcrypt]>=1.7.4",
"pydantic[email]>=2.5.0",
"pymysql>=1.1.0",
@@ -22,7 +23,7 @@ dependencies = [
[tool.ruff]
line-length = 88
target-version = "py39"
target-version = "py311"
[tool.ruff.format]
line-ending = "lf"
@@ -82,5 +83,6 @@ reportIncompatibleVariableOverride = false
[dependency-groups]
dev = [
"msgpack-types>=0.5.0",
"ruff>=0.12.4",
]

1191
uv.lock generated

File diff suppressed because it is too large Load Diff