feat(client-verification): add client verification (#104)
New configurations: - `CHECK_CLIENT_VERSION` enables the check (default=True) - `CLIENT_VERSION_URLS` contains a chain of valid client hashes. [osu!](https://osu.ppy.sh/home/download) and [osu! GU](https://github.com/GooGuTeam/osu/releases) are valid by default. View [g0v0-client-versions](https://github.com/GooGuTeam/g0v0-client-versions) to learn how to support your own client. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -705,6 +705,21 @@ CALCULATOR_CONFIG='{}'
|
|||||||
Field(default=True, description="检查自定义 ruleset 版本"),
|
Field(default=True, description="检查自定义 ruleset 版本"),
|
||||||
"反作弊设置",
|
"反作弊设置",
|
||||||
]
|
]
|
||||||
|
check_client_version: Annotated[
|
||||||
|
bool,
|
||||||
|
Field(default=True, description="检查客户端版本"),
|
||||||
|
"反作弊设置",
|
||||||
|
]
|
||||||
|
client_version_urls: Annotated[
|
||||||
|
list[str],
|
||||||
|
Field(
|
||||||
|
default=["https://raw.githubusercontent.com/GooGuTeam/g0v0-client-versions/main/version_list.json"],
|
||||||
|
description=(
|
||||||
|
"客户端版本列表 URL, 查看 https://github.com/GooGuTeam/g0v0-client-versions 来添加你自己的客户端"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
"反作弊设置",
|
||||||
|
]
|
||||||
|
|
||||||
# 存储设置
|
# 存储设置
|
||||||
storage_service: Annotated[
|
storage_service: Annotated[
|
||||||
|
|||||||
10
app/dependencies/client_verification.py
Normal file
10
app/dependencies/client_verification.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from app.service.client_verification_service import (
|
||||||
|
ClientVerificationService as OriginalClientVerificationService,
|
||||||
|
get_client_verification_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
|
||||||
|
ClientVerificationService = Annotated[OriginalClientVerificationService, Depends(get_client_verification_service)]
|
||||||
27
app/models/version.py
Normal file
27
app/models/version.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
from typing import NamedTuple, TypedDict
|
||||||
|
|
||||||
|
|
||||||
|
class VersionInfo(TypedDict):
|
||||||
|
version: str
|
||||||
|
release_date: str
|
||||||
|
hashes: dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
class VersionList(TypedDict):
|
||||||
|
name: str
|
||||||
|
versions: list[VersionInfo]
|
||||||
|
|
||||||
|
|
||||||
|
class VersionCheckResult(NamedTuple):
|
||||||
|
is_valid: bool
|
||||||
|
client_name: str = ""
|
||||||
|
version: str = ""
|
||||||
|
os: str = ""
|
||||||
|
|
||||||
|
def __bool__(self) -> bool:
|
||||||
|
return self.is_valid
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
if self.is_valid:
|
||||||
|
return f"{self.client_name} {self.version} ({self.os})"
|
||||||
|
return "Invalid Client Version"
|
||||||
@@ -35,6 +35,7 @@ from app.database.score import (
|
|||||||
)
|
)
|
||||||
from app.dependencies.api_version import APIVersion
|
from app.dependencies.api_version import APIVersion
|
||||||
from app.dependencies.cache import UserCacheService
|
from app.dependencies.cache import UserCacheService
|
||||||
|
from app.dependencies.client_verification import ClientVerificationService
|
||||||
from app.dependencies.database import Database, Redis, get_redis, with_db
|
from app.dependencies.database import Database, Redis, get_redis, with_db
|
||||||
from app.dependencies.fetcher import Fetcher, get_fetcher
|
from app.dependencies.fetcher import Fetcher, get_fetcher
|
||||||
from app.dependencies.storage import StorageService
|
from app.dependencies.storage import StorageService
|
||||||
@@ -415,6 +416,8 @@ async def get_user_all_beatmap_scores(
|
|||||||
async def create_solo_score(
|
async def create_solo_score(
|
||||||
background_task: BackgroundTasks,
|
background_task: BackgroundTasks,
|
||||||
db: Database,
|
db: Database,
|
||||||
|
fetcher: Fetcher,
|
||||||
|
verification_service: ClientVerificationService,
|
||||||
beatmap_id: Annotated[int, Path(description="谱面 ID")],
|
beatmap_id: Annotated[int, Path(description="谱面 ID")],
|
||||||
beatmap_hash: Annotated[str, Form(description="谱面文件哈希")],
|
beatmap_hash: Annotated[str, Form(description="谱面文件哈希")],
|
||||||
ruleset_id: Annotated[int, Form(..., description="ruleset 数字 ID (0-3)")],
|
ruleset_id: Annotated[int, Form(..., description="ruleset 数字 ID (0-3)")],
|
||||||
@@ -430,6 +433,21 @@ async def create_solo_score(
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
raise HTTPException(status_code=400, detail="Invalid ruleset ID")
|
raise HTTPException(status_code=400, detail="Invalid ruleset ID")
|
||||||
|
|
||||||
|
if not (
|
||||||
|
client_version := await verification_service.validate_client_version(
|
||||||
|
version_hash,
|
||||||
|
)
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
f"Client version check failed for user {current_user.id} on beatmap {beatmap_id} "
|
||||||
|
f"(version hash: {version_hash})"
|
||||||
|
)
|
||||||
|
raise HTTPException(status_code=422, detail="invalid client hash")
|
||||||
|
|
||||||
|
beatmap = await Beatmap.get_or_fetch(db, fetcher, md5=beatmap_hash)
|
||||||
|
if not beatmap or beatmap.id != beatmap_id:
|
||||||
|
raise HTTPException(status_code=422, detail="invalid or missing beatmap_hash")
|
||||||
|
|
||||||
if not (result := gamemode.check_ruleset_version(ruleset_hash)):
|
if not (result := gamemode.check_ruleset_version(ruleset_hash)):
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Ruleset version check failed for user {current_user.id} on beatmap {beatmap_id} "
|
f"Ruleset version check failed for user {current_user.id} on beatmap {beatmap_id} "
|
||||||
@@ -450,6 +468,15 @@ async def create_solo_score(
|
|||||||
db.add(score_token)
|
db.add(score_token)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(score_token)
|
await db.refresh(score_token)
|
||||||
|
logger.debug(
|
||||||
|
"User {user_id} created solo score {score_token} for beatmap {beatmap_id} "
|
||||||
|
"(mode: {mode}), using client {client_version}",
|
||||||
|
user_id=user_id,
|
||||||
|
score_token=score_token.id,
|
||||||
|
beatmap_id=beatmap_id,
|
||||||
|
mode=ruleset_id,
|
||||||
|
client_version=str(client_version),
|
||||||
|
)
|
||||||
return ScoreTokenResp.from_db(score_token)
|
return ScoreTokenResp.from_db(score_token)
|
||||||
|
|
||||||
|
|
||||||
@@ -485,8 +512,9 @@ async def create_playlist_score(
|
|||||||
background_task: BackgroundTasks,
|
background_task: BackgroundTasks,
|
||||||
room_id: int,
|
room_id: int,
|
||||||
playlist_id: int,
|
playlist_id: int,
|
||||||
|
verification_service: ClientVerificationService,
|
||||||
beatmap_id: Annotated[int, Form(description="谱面 ID")],
|
beatmap_id: Annotated[int, Form(description="谱面 ID")],
|
||||||
beatmap_hash: Annotated[str, Form(description="游戏版本哈希")],
|
beatmap_hash: Annotated[str, Form(description="谱面文件哈希")],
|
||||||
ruleset_id: Annotated[int, Form(..., description="ruleset 数字 ID (0-3)")],
|
ruleset_id: Annotated[int, Form(..., description="ruleset 数字 ID (0-3)")],
|
||||||
current_user: ClientUser,
|
current_user: ClientUser,
|
||||||
version_hash: Annotated[str, Form(description="谱面版本哈希")] = "",
|
version_hash: Annotated[str, Form(description="谱面版本哈希")] = "",
|
||||||
@@ -497,6 +525,17 @@ async def create_playlist_score(
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
raise HTTPException(status_code=400, detail="Invalid ruleset ID")
|
raise HTTPException(status_code=400, detail="Invalid ruleset ID")
|
||||||
|
|
||||||
|
if not (
|
||||||
|
client_version := await verification_service.validate_client_version(
|
||||||
|
version_hash,
|
||||||
|
)
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
f"Client version check failed for user {current_user.id} on room {room_id}, playlist {playlist_id} "
|
||||||
|
f"(version hash: {version_hash})"
|
||||||
|
)
|
||||||
|
raise HTTPException(status_code=422, detail="invalid client hash")
|
||||||
|
|
||||||
if not (result := gamemode.check_ruleset_version(ruleset_hash)):
|
if not (result := gamemode.check_ruleset_version(ruleset_hash)):
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Ruleset version check failed for user {current_user.id} on room {room_id}, playlist {playlist_id},"
|
f"Ruleset version check failed for user {current_user.id} on room {room_id}, playlist {playlist_id},"
|
||||||
@@ -556,6 +595,17 @@ async def create_playlist_score(
|
|||||||
session.add(score_token)
|
session.add(score_token)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(score_token)
|
await session.refresh(score_token)
|
||||||
|
logger.debug(
|
||||||
|
"User {user_id} created playlist score {score_token} for beatmap {beatmap_id} "
|
||||||
|
"(mode: {mode}, room {room_id}, item {playlist_id}), using client {client_version}",
|
||||||
|
user_id=user_id,
|
||||||
|
score_token=score_token.id,
|
||||||
|
beatmap_id=beatmap_id,
|
||||||
|
mode=ruleset_id,
|
||||||
|
room_id=room_id,
|
||||||
|
playlist_id=playlist_id,
|
||||||
|
client_version=str(client_version),
|
||||||
|
)
|
||||||
return ScoreTokenResp.from_db(score_token)
|
return ScoreTokenResp.from_db(score_token)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
131
app/service/client_verification_service.py
Normal file
131
app/service/client_verification_service.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
"""Service for verifying client versions against known valid versions."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
from app.log import logger
|
||||||
|
from app.models.version import VersionCheckResult, VersionList
|
||||||
|
from app.path import CONFIG_DIR
|
||||||
|
|
||||||
|
import aiofiles
|
||||||
|
import httpx
|
||||||
|
from httpx import AsyncClient
|
||||||
|
|
||||||
|
HASHES_DIR = CONFIG_DIR / "client_versions.json"
|
||||||
|
|
||||||
|
|
||||||
|
class ClientVerificationService:
|
||||||
|
"""A service to verify client versions against known valid versions.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
version_lists (list[VersionList]): A list of version lists fetched from remote sources.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
init(): Initialize the service by loading version data from disk and refreshing from remote.
|
||||||
|
refresh(): Fetch the latest version lists from configured URLs and store them locally.
|
||||||
|
load_from_disk(): Load version lists from the local JSON file.
|
||||||
|
validate_client_version(client_version: str) -> VersionCheckResult: Validate a given client version against the known versions.
|
||||||
|
""" # noqa: E501
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.original_version_lists: dict[str, list[VersionList]] = {}
|
||||||
|
self.versions: dict[str, tuple[str, str, str]] = {}
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def init(self) -> None:
|
||||||
|
"""Initialize the service by loading version data from disk and refreshing from remote."""
|
||||||
|
await self.load_from_disk(first_load=True)
|
||||||
|
await self.refresh()
|
||||||
|
await self.load_from_disk()
|
||||||
|
|
||||||
|
async def refresh(self) -> None:
|
||||||
|
"""Fetch the latest version lists from configured URLs and store them locally."""
|
||||||
|
lists: dict[str, list[VersionList]] = self.original_version_lists.copy()
|
||||||
|
async with AsyncClient() as client:
|
||||||
|
for url in settings.client_version_urls:
|
||||||
|
try:
|
||||||
|
resp = await client.get(url, timeout=10)
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
if len(data) == 0:
|
||||||
|
logger.warning(f"Client version list from {url} is empty")
|
||||||
|
continue
|
||||||
|
lists[url] = data
|
||||||
|
logger.info(f"Fetched client version list from {url}, total {len(data)} clients")
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
logger.warning(f"Timeout when fetching client version list from {url}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to fetch client version list from {url}: {e}")
|
||||||
|
async with aiofiles.open(HASHES_DIR, "wb") as f:
|
||||||
|
await f.write(json.dumps(lists).encode("utf-8"))
|
||||||
|
|
||||||
|
async def load_from_disk(self, first_load: bool = False) -> None:
|
||||||
|
"""Load version lists from the local JSON file."""
|
||||||
|
async with self._lock:
|
||||||
|
self.versions.clear()
|
||||||
|
try:
|
||||||
|
if not HASHES_DIR.is_file() and not first_load:
|
||||||
|
logger.warning("Client version list file does not exist on disk")
|
||||||
|
return
|
||||||
|
async with aiofiles.open(HASHES_DIR, "rb") as f:
|
||||||
|
content = await f.read()
|
||||||
|
self.original_version_lists = json.loads(content.decode("utf-8"))
|
||||||
|
for version_list_group in self.original_version_lists.values():
|
||||||
|
for version_list in version_list_group:
|
||||||
|
for version_info in version_list["versions"]:
|
||||||
|
for client_hash, os_name in version_info["hashes"].items():
|
||||||
|
self.versions[client_hash] = (
|
||||||
|
version_list["name"],
|
||||||
|
version_info["version"],
|
||||||
|
os_name,
|
||||||
|
)
|
||||||
|
if not first_load:
|
||||||
|
if len(self.versions) == 0:
|
||||||
|
logger.warning("Client version list is empty after loading from disk")
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"Loaded client version list from disk, "
|
||||||
|
f"total {len(self.versions)} clients, {len(self.versions)} versions"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Failed to load client version list from disk: {e}")
|
||||||
|
|
||||||
|
async def validate_client_version(self, client_version: str) -> VersionCheckResult:
|
||||||
|
"""Validate a given client version against the known versions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client_version (str): The client version string to validate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
VersionCheckResult: The result of the validation.
|
||||||
|
"""
|
||||||
|
if not settings.check_client_version:
|
||||||
|
return VersionCheckResult(is_valid=True)
|
||||||
|
async with self._lock:
|
||||||
|
if client_version in self.versions:
|
||||||
|
name, version, os_name = self.versions[client_version]
|
||||||
|
return VersionCheckResult(is_valid=True, client_name=name, version=version, os=os_name)
|
||||||
|
return VersionCheckResult(is_valid=False)
|
||||||
|
|
||||||
|
|
||||||
|
_client_verification_service: ClientVerificationService | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_client_verification_service() -> ClientVerificationService:
|
||||||
|
"""Get the singleton instance of ClientVerificationService.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ClientVerificationService: The singleton instance.
|
||||||
|
"""
|
||||||
|
global _client_verification_service
|
||||||
|
if _client_verification_service is None:
|
||||||
|
_client_verification_service = ClientVerificationService()
|
||||||
|
return _client_verification_service
|
||||||
|
|
||||||
|
|
||||||
|
async def init_client_verification_service() -> None:
|
||||||
|
"""Initialize the ClientVerificationService singleton."""
|
||||||
|
service = get_client_verification_service()
|
||||||
|
logger.info("Initializing ClientVerificationService...")
|
||||||
|
await service.init()
|
||||||
@@ -6,6 +6,7 @@ from . import (
|
|||||||
database_cleanup,
|
database_cleanup,
|
||||||
recalculate_banned_beatmap,
|
recalculate_banned_beatmap,
|
||||||
recalculate_failed_score,
|
recalculate_failed_score,
|
||||||
|
update_client_version,
|
||||||
)
|
)
|
||||||
from .cache import start_cache_tasks, stop_cache_tasks
|
from .cache import start_cache_tasks, stop_cache_tasks
|
||||||
from .calculate_all_user_rank import calculate_user_rank
|
from .calculate_all_user_rank import calculate_user_rank
|
||||||
|
|||||||
13
app/tasks/update_client_version.py
Normal file
13
app/tasks/update_client_version.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
from app.config import settings
|
||||||
|
from app.dependencies.scheduler import get_scheduler
|
||||||
|
from app.log import logger
|
||||||
|
from app.service.client_verification_service import get_client_verification_service
|
||||||
|
|
||||||
|
if settings.check_client_version:
|
||||||
|
|
||||||
|
@get_scheduler().scheduled_job("interval", id="update_client_version", hours=2)
|
||||||
|
async def update_client_version():
|
||||||
|
logger.info("Updating client version lists...")
|
||||||
|
client_verification_service = get_client_verification_service()
|
||||||
|
await client_verification_service.refresh()
|
||||||
|
await client_verification_service.load_from_disk()
|
||||||
4
main.py
4
main.py
@@ -33,6 +33,7 @@ from app.router.redirect import redirect_router
|
|||||||
from app.router.v1 import api_v1_public_router
|
from app.router.v1 import api_v1_public_router
|
||||||
from app.service.beatmap_download_service import download_service
|
from app.service.beatmap_download_service import download_service
|
||||||
from app.service.beatmapset_update_service import init_beatmapset_update_service
|
from app.service.beatmapset_update_service import init_beatmapset_update_service
|
||||||
|
from app.service.client_verification_service import init_client_verification_service
|
||||||
from app.service.email_queue import start_email_processor, stop_email_processor
|
from app.service.email_queue import start_email_processor, stop_email_processor
|
||||||
from app.service.redis_message_system import redis_message_system
|
from app.service.redis_message_system import redis_message_system
|
||||||
from app.service.subscribers.user_cache import user_online_subscriber
|
from app.service.subscribers.user_cache import user_online_subscriber
|
||||||
@@ -68,6 +69,9 @@ async def lifespan(app: FastAPI): # noqa: ARG001
|
|||||||
load_achievements()
|
load_achievements()
|
||||||
await init_calculator()
|
await init_calculator()
|
||||||
|
|
||||||
|
if settings.check_client_version:
|
||||||
|
await init_client_verification_service()
|
||||||
|
|
||||||
# init rate limiter
|
# init rate limiter
|
||||||
await FastAPILimiter.init(redis_rate_limit_client)
|
await FastAPILimiter.init(redis_rate_limit_client)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user