diff --git a/app/config.py b/app/config.py index 84d9290..544a690 100644 --- a/app/config.py +++ b/app/config.py @@ -705,6 +705,21 @@ CALCULATOR_CONFIG='{}' Field(default=True, description="检查自定义 ruleset 版本"), "反作弊设置", ] + check_client_version: Annotated[ + bool, + Field(default=True, description="检查客户端版本"), + "反作弊设置", + ] + client_version_urls: Annotated[ + list[str], + Field( + default=["https://raw.githubusercontent.com/GooGuTeam/g0v0-client-versions/main/version_list.json"], + description=( + "客户端版本列表 URL, 查看 https://github.com/GooGuTeam/g0v0-client-versions 来添加你自己的客户端" + ), + ), + "反作弊设置", + ] # 存储设置 storage_service: Annotated[ diff --git a/app/dependencies/client_verification.py b/app/dependencies/client_verification.py new file mode 100644 index 0000000..c8eda3e --- /dev/null +++ b/app/dependencies/client_verification.py @@ -0,0 +1,10 @@ +from typing import Annotated + +from app.service.client_verification_service import ( + ClientVerificationService as OriginalClientVerificationService, + get_client_verification_service, +) + +from fastapi import Depends + +ClientVerificationService = Annotated[OriginalClientVerificationService, Depends(get_client_verification_service)] diff --git a/app/models/version.py b/app/models/version.py new file mode 100644 index 0000000..606eabb --- /dev/null +++ b/app/models/version.py @@ -0,0 +1,27 @@ +from typing import NamedTuple, TypedDict + + +class VersionInfo(TypedDict): + version: str + release_date: str + hashes: dict[str, str] + + +class VersionList(TypedDict): + name: str + versions: list[VersionInfo] + + +class VersionCheckResult(NamedTuple): + is_valid: bool + client_name: str = "" + version: str = "" + os: str = "" + + def __bool__(self) -> bool: + return self.is_valid + + def __str__(self) -> str: + if self.is_valid: + return f"{self.client_name} {self.version} ({self.os})" + return "Invalid Client Version" diff --git a/app/router/v2/score.py b/app/router/v2/score.py index 2e250e4..09956ee 100644 --- a/app/router/v2/score.py +++ b/app/router/v2/score.py @@ -35,6 +35,7 @@ from app.database.score import ( ) from app.dependencies.api_version import APIVersion from app.dependencies.cache import UserCacheService +from app.dependencies.client_verification import ClientVerificationService from app.dependencies.database import Database, Redis, get_redis, with_db from app.dependencies.fetcher import Fetcher, get_fetcher from app.dependencies.storage import StorageService @@ -415,6 +416,8 @@ async def get_user_all_beatmap_scores( async def create_solo_score( background_task: BackgroundTasks, db: Database, + fetcher: Fetcher, + verification_service: ClientVerificationService, beatmap_id: Annotated[int, Path(description="谱面 ID")], beatmap_hash: Annotated[str, Form(description="谱面文件哈希")], ruleset_id: Annotated[int, Form(..., description="ruleset 数字 ID (0-3)")], @@ -430,6 +433,21 @@ async def create_solo_score( except ValueError: raise HTTPException(status_code=400, detail="Invalid ruleset ID") + if not ( + client_version := await verification_service.validate_client_version( + version_hash, + ) + ): + logger.info( + f"Client version check failed for user {current_user.id} on beatmap {beatmap_id} " + f"(version hash: {version_hash})" + ) + raise HTTPException(status_code=422, detail="invalid client hash") + + beatmap = await Beatmap.get_or_fetch(db, fetcher, md5=beatmap_hash) + if not beatmap or beatmap.id != beatmap_id: + raise HTTPException(status_code=422, detail="invalid or missing beatmap_hash") + if not (result := gamemode.check_ruleset_version(ruleset_hash)): logger.info( f"Ruleset version check failed for user {current_user.id} on beatmap {beatmap_id} " @@ -450,6 +468,15 @@ async def create_solo_score( db.add(score_token) await db.commit() await db.refresh(score_token) + logger.debug( + "User {user_id} created solo score {score_token} for beatmap {beatmap_id} " + "(mode: {mode}), using client {client_version}", + user_id=user_id, + score_token=score_token.id, + beatmap_id=beatmap_id, + mode=ruleset_id, + client_version=str(client_version), + ) return ScoreTokenResp.from_db(score_token) @@ -485,8 +512,9 @@ async def create_playlist_score( background_task: BackgroundTasks, room_id: int, playlist_id: int, + verification_service: ClientVerificationService, beatmap_id: Annotated[int, Form(description="谱面 ID")], - beatmap_hash: Annotated[str, Form(description="游戏版本哈希")], + beatmap_hash: Annotated[str, Form(description="谱面文件哈希")], ruleset_id: Annotated[int, Form(..., description="ruleset 数字 ID (0-3)")], current_user: ClientUser, version_hash: Annotated[str, Form(description="谱面版本哈希")] = "", @@ -497,6 +525,17 @@ async def create_playlist_score( except ValueError: raise HTTPException(status_code=400, detail="Invalid ruleset ID") + if not ( + client_version := await verification_service.validate_client_version( + version_hash, + ) + ): + logger.info( + f"Client version check failed for user {current_user.id} on room {room_id}, playlist {playlist_id} " + f"(version hash: {version_hash})" + ) + raise HTTPException(status_code=422, detail="invalid client hash") + if not (result := gamemode.check_ruleset_version(ruleset_hash)): logger.info( f"Ruleset version check failed for user {current_user.id} on room {room_id}, playlist {playlist_id}," @@ -556,6 +595,17 @@ async def create_playlist_score( session.add(score_token) await session.commit() await session.refresh(score_token) + logger.debug( + "User {user_id} created playlist score {score_token} for beatmap {beatmap_id} " + "(mode: {mode}, room {room_id}, item {playlist_id}), using client {client_version}", + user_id=user_id, + score_token=score_token.id, + beatmap_id=beatmap_id, + mode=ruleset_id, + room_id=room_id, + playlist_id=playlist_id, + client_version=str(client_version), + ) return ScoreTokenResp.from_db(score_token) diff --git a/app/service/client_verification_service.py b/app/service/client_verification_service.py new file mode 100644 index 0000000..6ae1119 --- /dev/null +++ b/app/service/client_verification_service.py @@ -0,0 +1,131 @@ +"""Service for verifying client versions against known valid versions.""" + +import asyncio +import json + +from app.config import settings +from app.log import logger +from app.models.version import VersionCheckResult, VersionList +from app.path import CONFIG_DIR + +import aiofiles +import httpx +from httpx import AsyncClient + +HASHES_DIR = CONFIG_DIR / "client_versions.json" + + +class ClientVerificationService: + """A service to verify client versions against known valid versions. + + Attributes: + version_lists (list[VersionList]): A list of version lists fetched from remote sources. + + Methods: + init(): Initialize the service by loading version data from disk and refreshing from remote. + refresh(): Fetch the latest version lists from configured URLs and store them locally. + load_from_disk(): Load version lists from the local JSON file. + validate_client_version(client_version: str) -> VersionCheckResult: Validate a given client version against the known versions. + """ # noqa: E501 + + def __init__(self) -> None: + self.original_version_lists: dict[str, list[VersionList]] = {} + self.versions: dict[str, tuple[str, str, str]] = {} + self._lock = asyncio.Lock() + + async def init(self) -> None: + """Initialize the service by loading version data from disk and refreshing from remote.""" + await self.load_from_disk(first_load=True) + await self.refresh() + await self.load_from_disk() + + async def refresh(self) -> None: + """Fetch the latest version lists from configured URLs and store them locally.""" + lists: dict[str, list[VersionList]] = self.original_version_lists.copy() + async with AsyncClient() as client: + for url in settings.client_version_urls: + try: + resp = await client.get(url, timeout=10) + resp.raise_for_status() + data = resp.json() + if len(data) == 0: + logger.warning(f"Client version list from {url} is empty") + continue + lists[url] = data + logger.info(f"Fetched client version list from {url}, total {len(data)} clients") + except httpx.TimeoutException: + logger.warning(f"Timeout when fetching client version list from {url}") + except Exception as e: + logger.warning(f"Failed to fetch client version list from {url}: {e}") + async with aiofiles.open(HASHES_DIR, "wb") as f: + await f.write(json.dumps(lists).encode("utf-8")) + + async def load_from_disk(self, first_load: bool = False) -> None: + """Load version lists from the local JSON file.""" + async with self._lock: + self.versions.clear() + try: + if not HASHES_DIR.is_file() and not first_load: + logger.warning("Client version list file does not exist on disk") + return + async with aiofiles.open(HASHES_DIR, "rb") as f: + content = await f.read() + self.original_version_lists = json.loads(content.decode("utf-8")) + for version_list_group in self.original_version_lists.values(): + for version_list in version_list_group: + for version_info in version_list["versions"]: + for client_hash, os_name in version_info["hashes"].items(): + self.versions[client_hash] = ( + version_list["name"], + version_info["version"], + os_name, + ) + if not first_load: + if len(self.versions) == 0: + logger.warning("Client version list is empty after loading from disk") + else: + logger.info( + "Loaded client version list from disk, " + f"total {len(self.versions)} clients, {len(self.versions)} versions" + ) + except Exception as e: + logger.exception(f"Failed to load client version list from disk: {e}") + + async def validate_client_version(self, client_version: str) -> VersionCheckResult: + """Validate a given client version against the known versions. + + Args: + client_version (str): The client version string to validate. + + Returns: + VersionCheckResult: The result of the validation. + """ + if not settings.check_client_version: + return VersionCheckResult(is_valid=True) + async with self._lock: + if client_version in self.versions: + name, version, os_name = self.versions[client_version] + return VersionCheckResult(is_valid=True, client_name=name, version=version, os=os_name) + return VersionCheckResult(is_valid=False) + + +_client_verification_service: ClientVerificationService | None = None + + +def get_client_verification_service() -> ClientVerificationService: + """Get the singleton instance of ClientVerificationService. + + Returns: + ClientVerificationService: The singleton instance. + """ + global _client_verification_service + if _client_verification_service is None: + _client_verification_service = ClientVerificationService() + return _client_verification_service + + +async def init_client_verification_service() -> None: + """Initialize the ClientVerificationService singleton.""" + service = get_client_verification_service() + logger.info("Initializing ClientVerificationService...") + await service.init() diff --git a/app/tasks/__init__.py b/app/tasks/__init__.py index 3c2dc94..2a77be0 100644 --- a/app/tasks/__init__.py +++ b/app/tasks/__init__.py @@ -6,6 +6,7 @@ from . import ( database_cleanup, recalculate_banned_beatmap, recalculate_failed_score, + update_client_version, ) from .cache import start_cache_tasks, stop_cache_tasks from .calculate_all_user_rank import calculate_user_rank diff --git a/app/tasks/update_client_version.py b/app/tasks/update_client_version.py new file mode 100644 index 0000000..84462c9 --- /dev/null +++ b/app/tasks/update_client_version.py @@ -0,0 +1,13 @@ +from app.config import settings +from app.dependencies.scheduler import get_scheduler +from app.log import logger +from app.service.client_verification_service import get_client_verification_service + +if settings.check_client_version: + + @get_scheduler().scheduled_job("interval", id="update_client_version", hours=2) + async def update_client_version(): + logger.info("Updating client version lists...") + client_verification_service = get_client_verification_service() + await client_verification_service.refresh() + await client_verification_service.load_from_disk() diff --git a/main.py b/main.py index 344de1c..3e8d22a 100644 --- a/main.py +++ b/main.py @@ -33,6 +33,7 @@ from app.router.redirect import redirect_router from app.router.v1 import api_v1_public_router from app.service.beatmap_download_service import download_service from app.service.beatmapset_update_service import init_beatmapset_update_service +from app.service.client_verification_service import init_client_verification_service from app.service.email_queue import start_email_processor, stop_email_processor from app.service.redis_message_system import redis_message_system from app.service.subscribers.user_cache import user_online_subscriber @@ -68,6 +69,9 @@ async def lifespan(app: FastAPI): # noqa: ARG001 load_achievements() await init_calculator() + if settings.check_client_version: + await init_client_verification_service() + # init rate limiter await FastAPILimiter.init(redis_rate_limit_client)