refactor(detector): more readable

This commit is contained in:
MingxuanGame
2025-08-15 08:52:01 +00:00
parent 7c7c68c163
commit 814d9c4618
7 changed files with 40 additions and 110 deletions

View File

@@ -1,9 +1,7 @@
from __future__ import annotations from __future__ import annotations
import math import math
import os
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import zipfile
from app.config import settings from app.config import settings
from app.log import logger from app.log import logger
@@ -11,10 +9,9 @@ from app.models.beatmap import BeatmapAttributes
from app.models.mods import APIMod from app.models.mods import APIMod
from app.models.score import GameMode from app.models.score import GameMode
import httpx
from osupyparser import OsuFile from osupyparser import OsuFile
from osupyparser.osu.objects import Slider from osupyparser.osu.objects import Slider
from sqlmodel import Session, create_engine, select from sqlmodel import col, exists, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
try: try:
@@ -63,10 +60,25 @@ def calculate_beatmap_attribute(
) )
def calculate_pp( async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> float:
score: "Score", from app.database.beatmap import BannedBeatmaps
beatmap: str,
) -> float: if settings.suspicious_score_check:
beatmap_banned = (
await session.exec(
select(exists()).where(
col(BannedBeatmaps.beatmap_id) == score.beatmap_id
)
)
).first()
if beatmap_banned:
return 0
is_suspicious = is_suspicious_beatmap(beatmap)
if is_suspicious:
session.add(BannedBeatmaps(beatmap_id=score.beatmap_id))
logger.warning(f"Beatmap {score.beatmap_id} is suspicious, banned")
return 0
map = rosu.Beatmap(content=beatmap) map = rosu.Beatmap(content=beatmap)
map.convert(score.gamemode.to_rosu(), score.mods) # pyright: ignore[reportArgumentType] map.convert(score.gamemode.to_rosu(), score.mods) # pyright: ignore[reportArgumentType]
perf = rosu.Performance( perf = rosu.Performance(
@@ -86,21 +98,10 @@ def calculate_pp(
) )
attrs = perf.calculate(map) attrs = perf.calculate(map)
pp = attrs.pp pp = attrs.pp
engine = create_engine(settings.database_url)
from app.database.beatmap import BannedBeatmaps
beatmap_banned = False
with Session(engine) as session:
beatmap_id = session.exec(
select(BannedBeatmaps).where(BannedBeatmaps.beatmap_id == score.beatmap_id)
).first()
if beatmap_id:
beatmap_banned = True
# mrekk bp1: 2048pp; ppy-sb top1 rxbp1: 2198pp # mrekk bp1: 2048pp; ppy-sb top1 rxbp1: 2198pp
if settings.suspicious_score_check and ( if settings.suspicious_score_check and (
beatmap_banned (attrs.difficulty.stars > 25 and score.accuracy < 0.8) or pp > 2300
or (attrs.difficulty.stars > 25 and score.accuracy < 0.8)
or pp > 2300
): ):
logger.warning( logger.warning(
f"User {score.user_id} played {score.beatmap_id} with {pp=} " f"User {score.user_id} played {score.beatmap_id} with {pp=} "
@@ -237,32 +238,6 @@ def calculate_score_to_level(total_score: int) -> float:
99999999999, 99999999999,
99999999999, 99999999999,
99999999999, 99999999999,
99999999999,
99999999999,
99999999999,
99999999999,
99999999999,
99999999999,
99999999999,
99999999999,
99999999999,
99999999999,
99999999999,
99999999999,
99999999999,
99999999999,
99999999999,
99999999999,
99999999999,
99999999999,
99999999999,
99999999999,
99999999999,
99999999999,
99999999999,
99999999999,
99999999999,
99999999999,
] ]
remaining_score = total_score remaining_score = total_score
@@ -291,48 +266,13 @@ def calculate_weighted_acc(acc: float, index: int) -> float:
return calculate_pp_weight(index) * acc if acc > 0 else 0.0 return calculate_pp_weight(index) * acc if acc > 0 else 0.0
async def get_suspscious_beatmap(beatmapset_id: int, session: AsyncSession): def is_suspicious_beatmap(content: str) -> bool:
url = ( osufile = OsuFile(content=content.encode("utf-8-sig")).parse_file()
f"https://txy1.sayobot.cn/beatmaps/download/novideo/{beatmapset_id}?server=auto" for obj in osufile.hit_objects:
) if obj.pos.x < 0 or obj.pos.y < 0 or obj.pos.x > 512 or obj.pos.y > 384:
async with httpx.AsyncClient() as client: return True
resp = await client.get(url) if isinstance(obj, Slider):
if resp.status_code == 200: for point in obj.points:
import aiofiles if point.x < 0 or point.y < 0 or point.x > 512 or point.y > 384:
return True
async with aiofiles.open(f"temp_beatmaps/{beatmapset_id}.osz", "wb") as f: return False
await f.write(resp.content)
with zipfile.ZipFile(f"temp_beatmaps/{beatmapset_id}.osz", "r") as beatmap_ref:
beatmap_ref.extractall(f"temp_beatmaps/{beatmapset_id}")
os.remove(f"temp_beatmaps/{beatmapset_id}.osz")
all_osu_files = []
for root, dirs, files in os.walk(f"temp_beatmaps/{beatmapset_id}"):
for name in files:
if name.endswith(".osu"):
all_osu_files.append(os.path.join(root, name))
for file in all_osu_files:
osufile = OsuFile(file).parse_file()
for obj in osufile.hit_objects:
if obj.pos.x < 0 or obj.pos.y < 0 or obj.pos.x > 512 or obj.pos.y > 384:
# 延迟导入以解决循环导入问题
from app.database.beatmap import BannedBeatmaps
session.add(
BannedBeatmaps(id=osufile.beatmap_id, beatmap_id=osufile.beatmap_id)
)
break
if type(obj) is Slider:
for point in obj.points:
if point.x < 0 or point.y < 0 or point.x > 512 or point.y > 384:
# 延迟导入以解决循环导入问题
from app.database.beatmap import BannedBeatmaps
session.add(
BannedBeatmaps(
id=osufile.beatmap_id, beatmap_id=osufile.beatmap_id
)
)
break
os.remove(file)
os.remove(f"temp_beatmaps/{beatmapset_id}")
return None

View File

@@ -3,6 +3,7 @@ from datetime import datetime
import hashlib import hashlib
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from app.calculator import calculate_beatmap_attribute
from app.config import settings from app.config import settings
from app.models.beatmap import BeatmapAttributes, BeatmapRankStatus from app.models.beatmap import BeatmapAttributes, BeatmapRankStatus
from app.models.mods import APIMod from app.models.mods import APIMod
@@ -203,7 +204,7 @@ class BeatmapResp(BeatmapBase):
class BannedBeatmaps(SQLModel, table=True): class BannedBeatmaps(SQLModel, table=True):
__tablename__ = "banned_beatmaps" # pyright: ignore[reportAssignmentType] __tablename__ = "banned_beatmaps" # pyright: ignore[reportAssignmentType]
id: int = Field(primary_key=True, index=True) id: int | None = Field(primary_key=True, index=True, default=None)
beatmap_id: int = Field(index=True) beatmap_id: int = Field(index=True)
@@ -221,9 +222,6 @@ async def calculate_beatmap_attributes(
if await redis.exists(key): if await redis.exists(key):
return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType] return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id) resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
# 延迟导入以解决循环导入问题
from app.calculator import calculate_beatmap_attribute
attr = await asyncio.get_event_loop().run_in_executor( attr = await asyncio.get_event_loop().run_in_executor(
None, calculate_beatmap_attribute, resp, ruleset, mods_ None, calculate_beatmap_attribute, resp, ruleset, mods_
) )

View File

@@ -183,10 +183,6 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
if not beatmapset: if not beatmapset:
resp = await fetcher.get_beatmapset(sid) resp = await fetcher.get_beatmapset(sid)
beatmapset = await cls.from_resp(session, resp) beatmapset = await cls.from_resp(session, resp)
# 检查可疑谱面
from app.calculator import get_suspscious_beatmap
await get_suspscious_beatmap(sid, session)
return beatmapset return beatmapset

View File

@@ -1,4 +1,3 @@
import asyncio
from collections.abc import Sequence from collections.abc import Sequence
from datetime import UTC, date, datetime from datetime import UTC, date, datetime
import json import json
@@ -726,9 +725,7 @@ async def process_score(
) )
if can_get_pp: if can_get_pp:
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id) beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
pp = await asyncio.get_event_loop().run_in_executor( pp = await calculate_pp(score, beatmap_raw, session)
None, calculate_pp, score, beatmap_raw
)
score.pp = pp score.pp = pp
session.add(score) session.add(score)
user_id = user.id user_id = user.id

View File

@@ -85,9 +85,7 @@ async def _recalculate_pp(
break break
try: try:
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id) beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
pp = await asyncio.get_event_loop().run_in_executor( pp = await calculate_pp(score, beatmap_raw, session)
None, calculate_pp, score, beatmap_raw
)
score.pp = pp score.pp = pp
if score.beatmap_id not in prev or prev[score.beatmap_id].pp < pp: if score.beatmap_id not in prev or prev[score.beatmap_id].pp < pp:
best_score = PPBestScore( best_score = PPBestScore(

View File

@@ -96,10 +96,12 @@ reportIncompatibleVariableOverride = false
[tool.uv.workspace] [tool.uv.workspace]
members = [ members = [
"packages/msgpack_lazer_api", "packages/msgpack_lazer_api",
"packages/osupyparser",
] ]
[tool.uv.sources] [tool.uv.sources]
msgpack-lazer-api = { workspace = true } msgpack-lazer-api = { workspace = true }
osupyparser = { git = "https://github.com/MingxuanGame/osupyparser.git" }
[tool.uv] [tool.uv]
cache-keys = [{file = "pyproject.toml"}, {file = "packages/msgpack_lazer_api/Cargo.toml"}, {file = "**/*.rs"}] cache-keys = [{file = "pyproject.toml"}, {file = "packages/msgpack_lazer_api/Cargo.toml"}, {file = "**/*.rs"}]

7
uv.lock generated
View File

@@ -858,7 +858,7 @@ requires-dist = [
{ name = "httpx", specifier = ">=0.28.1" }, { name = "httpx", specifier = ">=0.28.1" },
{ name = "loguru", specifier = ">=0.7.3" }, { name = "loguru", specifier = ">=0.7.3" },
{ name = "msgpack-lazer-api", editable = "packages/msgpack_lazer_api" }, { name = "msgpack-lazer-api", editable = "packages/msgpack_lazer_api" },
{ name = "osupyparser", specifier = ">=1.0.7" }, { name = "osupyparser", git = "https://github.com/MingxuanGame/osupyparser.git" },
{ name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" }, { name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" },
{ name = "pillow", specifier = ">=11.3.0" }, { name = "pillow", specifier = ">=11.3.0" },
{ name = "pydantic", extras = ["email"], specifier = ">=2.5.0" }, { name = "pydantic", extras = ["email"], specifier = ">=2.5.0" },
@@ -883,9 +883,8 @@ dev = [
[[package]] [[package]]
name = "osupyparser" name = "osupyparser"
version = "1.0.7" version = "1.0.8"
source = { registry = "https://pypi.org/simple" } source = { git = "https://github.com/MingxuanGame/osupyparser.git#e41ec1db87ab64531897127a44b86351c21322bd" }
sdist = { url = "https://files.pythonhosted.org/packages/7b/6b/7f567c2acd1f2028603353da40ad7411bb47754994552d3f0c4cfa6703f9/OsuPyParser-1.0.7.tar.gz", hash = "sha256:67f530c31dd5c288c8fff8f583c899c673536681f8cc3699d0afc5e4d8c2b1ff", size = 9095, upload-time = "2021-08-29T23:48:22.752Z" }
[[package]] [[package]]
name = "passlib" name = "passlib"