diff --git a/app/calculator.py b/app/calculator.py index 508bf58..8e09e67 100644 --- a/app/calculator.py +++ b/app/calculator.py @@ -1,7 +1,9 @@ from __future__ import annotations import math +import os from typing import TYPE_CHECKING +import zipfile from app.config import settings from app.log import logger @@ -9,6 +11,12 @@ from app.models.beatmap import BeatmapAttributes from app.models.mods import APIMod from app.models.score import GameMode +import httpx +from osupyparser import OsuFile +from osupyparser.osu.objects import Slider +from sqlmodel import Session, create_engine, select +from sqlmodel.ext.asyncio.session import AsyncSession + try: import rosu_pp_py as rosu except ImportError: @@ -78,9 +86,21 @@ def calculate_pp( ) attrs = perf.calculate(map) pp = attrs.pp + engine = create_engine(settings.database_url) + from app.database.beatmap import BannedBeatmaps + + beatmap_banned = False + with Session(engine) as session: + beatmap_id = session.exec( + select(BannedBeatmaps).where(BannedBeatmaps.beatmap_id == score.beatmap_id) + ).first() + if beatmap_id: + beatmap_banned = True # mrekk bp1: 2048pp; ppy-sb top1 rxbp1: 2198pp if settings.suspicious_score_check and ( - (attrs.difficulty.stars > 25 and score.accuracy < 0.8) or pp > 2300 + beatmap_banned + or (attrs.difficulty.stars > 25 and score.accuracy < 0.8) + or pp > 2300 ): logger.warning( f"User {score.user_id} played {score.beatmap_id} with {pp=} " @@ -225,6 +245,24 @@ def calculate_score_to_level(total_score: int) -> float: 99999999999, 99999999999, 99999999999, + 99999999999, + 99999999999, + 99999999999, + 99999999999, + 99999999999, + 99999999999, + 99999999999, + 99999999999, + 99999999999, + 99999999999, + 99999999999, + 99999999999, + 99999999999, + 99999999999, + 99999999999, + 99999999999, + 99999999999, + 99999999999, ] remaining_score = total_score @@ -251,3 +289,50 @@ def calculate_weighted_pp(pp: float, index: int) -> float: def calculate_weighted_acc(acc: float, index: int) -> float: return calculate_pp_weight(index) * acc if acc > 0 else 0.0 + + +async def get_suspscious_beatmap(beatmapset_id: int, session: AsyncSession): + url = ( + f"https://txy1.sayobot.cn/beatmaps/download/novideo/{beatmapset_id}?server=auto" + ) + async with httpx.AsyncClient() as client: + resp = await client.get(url) + if resp.status_code == 200: + import aiofiles + + async with aiofiles.open(f"temp_beatmaps/{beatmapset_id}.osz", "wb") as f: + await f.write(resp.content) + with zipfile.ZipFile(f"temp_beatmaps/{beatmapset_id}.osz", "r") as beatmap_ref: + beatmap_ref.extractall(f"temp_beatmaps/{beatmapset_id}") + os.remove(f"temp_beatmaps/{beatmapset_id}.osz") + all_osu_files = [] + for root, dirs, files in os.walk(f"temp_beatmaps/{beatmapset_id}"): + for name in files: + if name.endswith(".osu"): + all_osu_files.append(os.path.join(root, name)) + for file in all_osu_files: + osufile = OsuFile(file).parse_file() + for obj in osufile.hit_objects: + if obj.pos.x < 0 or obj.pos.y < 0 or obj.pos.x > 512 or obj.pos.y > 384: + # 延迟导入以解决循环导入问题 + from app.database.beatmap import BannedBeatmaps + + session.add( + BannedBeatmaps(id=osufile.beatmap_id, beatmap_id=osufile.beatmap_id) + ) + break + if type(obj) is Slider: + for point in obj.points: + if point.x < 0 or point.y < 0 or point.x > 512 or point.y > 384: + # 延迟导入以解决循环导入问题 + from app.database.beatmap import BannedBeatmaps + + session.add( + BannedBeatmaps( + id=osufile.beatmap_id, beatmap_id=osufile.beatmap_id + ) + ) + break + os.remove(file) + os.remove(f"temp_beatmaps/{beatmapset_id}") + return None diff --git a/app/database/beatmap.py b/app/database/beatmap.py index 1cc8497..cc80b6d 100644 --- a/app/database/beatmap.py +++ b/app/database/beatmap.py @@ -3,7 +3,6 @@ from datetime import datetime import hashlib from typing import TYPE_CHECKING -from app.calculator import calculate_beatmap_attribute from app.config import settings from app.models.beatmap import BeatmapAttributes, BeatmapRankStatus from app.models.mods import APIMod @@ -202,6 +201,12 @@ class BeatmapResp(BeatmapBase): return cls.model_validate(beatmap_) +class BannedBeatmaps(SQLModel, table=True): + __tablename__ = "banned_beatmaps" # pyright: ignore[reportAssignmentType] + id: int = Field(primary_key=True, index=True) + beatmap_id: int = Field(index=True) + + async def calculate_beatmap_attributes( beatmap_id: int, ruleset: GameMode, @@ -216,6 +221,9 @@ async def calculate_beatmap_attributes( if await redis.exists(key): return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType] resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id) + # 延迟导入以解决循环导入问题 + from app.calculator import calculate_beatmap_attribute + attr = await asyncio.get_event_loop().run_in_executor( None, calculate_beatmap_attribute, resp, ruleset, mods_ ) diff --git a/app/database/beatmapset.py b/app/database/beatmapset.py index 6a5ca2b..711ab33 100644 --- a/app/database/beatmapset.py +++ b/app/database/beatmapset.py @@ -183,6 +183,10 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True): if not beatmapset: resp = await fetcher.get_beatmapset(sid) beatmapset = await cls.from_resp(session, resp) + # 检查可疑谱面 + from app.calculator import get_suspscious_beatmap + + await get_suspscious_beatmap(sid, session) return beatmapset diff --git a/migrations/versions/9f6b27e8ea51_add_table_banned_beatmaps.py b/migrations/versions/9f6b27e8ea51_add_table_banned_beatmaps.py new file mode 100644 index 0000000..202d517 --- /dev/null +++ b/migrations/versions/9f6b27e8ea51_add_table_banned_beatmaps.py @@ -0,0 +1,50 @@ +"""add table banned_beatmaps + +Revision ID: 9f6b27e8ea51 +Revises: 951a2188e691 +Create Date: 2025-08-15 07:23:25.645360 + +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "9f6b27e8ea51" +down_revision: str | Sequence[str] | None = "951a2188e691" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "banned_beatmaps", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("beatmap_id", sa.Integer(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_banned_beatmaps_beatmap_id"), + "banned_beatmaps", + ["beatmap_id"], + unique=False, + ) + op.create_index( + op.f("ix_banned_beatmaps_id"), "banned_beatmaps", ["id"], unique=False + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_banned_beatmaps_id"), table_name="banned_beatmaps") + op.drop_index(op.f("ix_banned_beatmaps_beatmap_id"), table_name="banned_beatmaps") + op.drop_table("banned_beatmaps") + # ### end Alembic commands ### diff --git a/pyproject.toml b/pyproject.toml index 361860a..d828667 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "httpx>=0.28.1", "loguru>=0.7.3", "msgpack-lazer-api", + "osupyparser>=1.0.7", "passlib[bcrypt]>=1.7.4", "pillow>=11.3.0", "pydantic-settings>=2.10.1", diff --git a/uv.lock b/uv.lock index 4a700c1..8525d55 100644 --- a/uv.lock +++ b/uv.lock @@ -822,6 +822,7 @@ dependencies = [ { name = "httpx" }, { name = "loguru" }, { name = "msgpack-lazer-api" }, + { name = "osupyparser" }, { name = "passlib", extra = ["bcrypt"] }, { name = "pillow" }, { name = "pydantic", extra = ["email"] }, @@ -857,6 +858,7 @@ requires-dist = [ { name = "httpx", specifier = ">=0.28.1" }, { name = "loguru", specifier = ">=0.7.3" }, { name = "msgpack-lazer-api", editable = "packages/msgpack_lazer_api" }, + { name = "osupyparser", specifier = ">=1.0.7" }, { name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" }, { name = "pillow", specifier = ">=11.3.0" }, { name = "pydantic", extras = ["email"], specifier = ">=2.5.0" }, @@ -879,6 +881,12 @@ dev = [ { name = "types-aioboto3", extras = ["aioboto3", "essential"], specifier = ">=15.0.0" }, ] +[[package]] +name = "osupyparser" +version = "1.0.7" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7b/6b/7f567c2acd1f2028603353da40ad7411bb47754994552d3f0c4cfa6703f9/OsuPyParser-1.0.7.tar.gz", hash = "sha256:67f530c31dd5c288c8fff8f583c899c673536681f8cc3699d0afc5e4d8c2b1ff", size = 9095, upload-time = "2021-08-29T23:48:22.752Z" } + [[package]] name = "passlib" version = "1.7.4"