refactor(detector): more readable
This commit is contained in:
@@ -1,9 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import os
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
import zipfile
|
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.log import logger
|
from app.log import logger
|
||||||
@@ -11,10 +9,9 @@ from app.models.beatmap import BeatmapAttributes
|
|||||||
from app.models.mods import APIMod
|
from app.models.mods import APIMod
|
||||||
from app.models.score import GameMode
|
from app.models.score import GameMode
|
||||||
|
|
||||||
import httpx
|
|
||||||
from osupyparser import OsuFile
|
from osupyparser import OsuFile
|
||||||
from osupyparser.osu.objects import Slider
|
from osupyparser.osu.objects import Slider
|
||||||
from sqlmodel import Session, create_engine, select
|
from sqlmodel import col, exists, select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -63,10 +60,25 @@ def calculate_beatmap_attribute(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def calculate_pp(
|
async def calculate_pp(score: "Score", beatmap: str, session: AsyncSession) -> float:
|
||||||
score: "Score",
|
from app.database.beatmap import BannedBeatmaps
|
||||||
beatmap: str,
|
|
||||||
) -> float:
|
if settings.suspicious_score_check:
|
||||||
|
beatmap_banned = (
|
||||||
|
await session.exec(
|
||||||
|
select(exists()).where(
|
||||||
|
col(BannedBeatmaps.beatmap_id) == score.beatmap_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).first()
|
||||||
|
if beatmap_banned:
|
||||||
|
return 0
|
||||||
|
is_suspicious = is_suspicious_beatmap(beatmap)
|
||||||
|
if is_suspicious:
|
||||||
|
session.add(BannedBeatmaps(beatmap_id=score.beatmap_id))
|
||||||
|
logger.warning(f"Beatmap {score.beatmap_id} is suspicious, banned")
|
||||||
|
return 0
|
||||||
|
|
||||||
map = rosu.Beatmap(content=beatmap)
|
map = rosu.Beatmap(content=beatmap)
|
||||||
map.convert(score.gamemode.to_rosu(), score.mods) # pyright: ignore[reportArgumentType]
|
map.convert(score.gamemode.to_rosu(), score.mods) # pyright: ignore[reportArgumentType]
|
||||||
perf = rosu.Performance(
|
perf = rosu.Performance(
|
||||||
@@ -86,21 +98,10 @@ def calculate_pp(
|
|||||||
)
|
)
|
||||||
attrs = perf.calculate(map)
|
attrs = perf.calculate(map)
|
||||||
pp = attrs.pp
|
pp = attrs.pp
|
||||||
engine = create_engine(settings.database_url)
|
|
||||||
from app.database.beatmap import BannedBeatmaps
|
|
||||||
|
|
||||||
beatmap_banned = False
|
|
||||||
with Session(engine) as session:
|
|
||||||
beatmap_id = session.exec(
|
|
||||||
select(BannedBeatmaps).where(BannedBeatmaps.beatmap_id == score.beatmap_id)
|
|
||||||
).first()
|
|
||||||
if beatmap_id:
|
|
||||||
beatmap_banned = True
|
|
||||||
# mrekk bp1: 2048pp; ppy-sb top1 rxbp1: 2198pp
|
# mrekk bp1: 2048pp; ppy-sb top1 rxbp1: 2198pp
|
||||||
if settings.suspicious_score_check and (
|
if settings.suspicious_score_check and (
|
||||||
beatmap_banned
|
(attrs.difficulty.stars > 25 and score.accuracy < 0.8) or pp > 2300
|
||||||
or (attrs.difficulty.stars > 25 and score.accuracy < 0.8)
|
|
||||||
or pp > 2300
|
|
||||||
):
|
):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"User {score.user_id} played {score.beatmap_id} with {pp=} "
|
f"User {score.user_id} played {score.beatmap_id} with {pp=} "
|
||||||
@@ -237,32 +238,6 @@ def calculate_score_to_level(total_score: int) -> float:
|
|||||||
99999999999,
|
99999999999,
|
||||||
99999999999,
|
99999999999,
|
||||||
99999999999,
|
99999999999,
|
||||||
99999999999,
|
|
||||||
99999999999,
|
|
||||||
99999999999,
|
|
||||||
99999999999,
|
|
||||||
99999999999,
|
|
||||||
99999999999,
|
|
||||||
99999999999,
|
|
||||||
99999999999,
|
|
||||||
99999999999,
|
|
||||||
99999999999,
|
|
||||||
99999999999,
|
|
||||||
99999999999,
|
|
||||||
99999999999,
|
|
||||||
99999999999,
|
|
||||||
99999999999,
|
|
||||||
99999999999,
|
|
||||||
99999999999,
|
|
||||||
99999999999,
|
|
||||||
99999999999,
|
|
||||||
99999999999,
|
|
||||||
99999999999,
|
|
||||||
99999999999,
|
|
||||||
99999999999,
|
|
||||||
99999999999,
|
|
||||||
99999999999,
|
|
||||||
99999999999,
|
|
||||||
]
|
]
|
||||||
|
|
||||||
remaining_score = total_score
|
remaining_score = total_score
|
||||||
@@ -291,48 +266,13 @@ def calculate_weighted_acc(acc: float, index: int) -> float:
|
|||||||
return calculate_pp_weight(index) * acc if acc > 0 else 0.0
|
return calculate_pp_weight(index) * acc if acc > 0 else 0.0
|
||||||
|
|
||||||
|
|
||||||
async def get_suspscious_beatmap(beatmapset_id: int, session: AsyncSession):
|
def is_suspicious_beatmap(content: str) -> bool:
|
||||||
url = (
|
osufile = OsuFile(content=content.encode("utf-8-sig")).parse_file()
|
||||||
f"https://txy1.sayobot.cn/beatmaps/download/novideo/{beatmapset_id}?server=auto"
|
for obj in osufile.hit_objects:
|
||||||
)
|
if obj.pos.x < 0 or obj.pos.y < 0 or obj.pos.x > 512 or obj.pos.y > 384:
|
||||||
async with httpx.AsyncClient() as client:
|
return True
|
||||||
resp = await client.get(url)
|
if isinstance(obj, Slider):
|
||||||
if resp.status_code == 200:
|
for point in obj.points:
|
||||||
import aiofiles
|
if point.x < 0 or point.y < 0 or point.x > 512 or point.y > 384:
|
||||||
|
return True
|
||||||
async with aiofiles.open(f"temp_beatmaps/{beatmapset_id}.osz", "wb") as f:
|
return False
|
||||||
await f.write(resp.content)
|
|
||||||
with zipfile.ZipFile(f"temp_beatmaps/{beatmapset_id}.osz", "r") as beatmap_ref:
|
|
||||||
beatmap_ref.extractall(f"temp_beatmaps/{beatmapset_id}")
|
|
||||||
os.remove(f"temp_beatmaps/{beatmapset_id}.osz")
|
|
||||||
all_osu_files = []
|
|
||||||
for root, dirs, files in os.walk(f"temp_beatmaps/{beatmapset_id}"):
|
|
||||||
for name in files:
|
|
||||||
if name.endswith(".osu"):
|
|
||||||
all_osu_files.append(os.path.join(root, name))
|
|
||||||
for file in all_osu_files:
|
|
||||||
osufile = OsuFile(file).parse_file()
|
|
||||||
for obj in osufile.hit_objects:
|
|
||||||
if obj.pos.x < 0 or obj.pos.y < 0 or obj.pos.x > 512 or obj.pos.y > 384:
|
|
||||||
# 延迟导入以解决循环导入问题
|
|
||||||
from app.database.beatmap import BannedBeatmaps
|
|
||||||
|
|
||||||
session.add(
|
|
||||||
BannedBeatmaps(id=osufile.beatmap_id, beatmap_id=osufile.beatmap_id)
|
|
||||||
)
|
|
||||||
break
|
|
||||||
if type(obj) is Slider:
|
|
||||||
for point in obj.points:
|
|
||||||
if point.x < 0 or point.y < 0 or point.x > 512 or point.y > 384:
|
|
||||||
# 延迟导入以解决循环导入问题
|
|
||||||
from app.database.beatmap import BannedBeatmaps
|
|
||||||
|
|
||||||
session.add(
|
|
||||||
BannedBeatmaps(
|
|
||||||
id=osufile.beatmap_id, beatmap_id=osufile.beatmap_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
break
|
|
||||||
os.remove(file)
|
|
||||||
os.remove(f"temp_beatmaps/{beatmapset_id}")
|
|
||||||
return None
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from datetime import datetime
|
|||||||
import hashlib
|
import hashlib
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from app.calculator import calculate_beatmap_attribute
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.models.beatmap import BeatmapAttributes, BeatmapRankStatus
|
from app.models.beatmap import BeatmapAttributes, BeatmapRankStatus
|
||||||
from app.models.mods import APIMod
|
from app.models.mods import APIMod
|
||||||
@@ -203,7 +204,7 @@ class BeatmapResp(BeatmapBase):
|
|||||||
|
|
||||||
class BannedBeatmaps(SQLModel, table=True):
|
class BannedBeatmaps(SQLModel, table=True):
|
||||||
__tablename__ = "banned_beatmaps" # pyright: ignore[reportAssignmentType]
|
__tablename__ = "banned_beatmaps" # pyright: ignore[reportAssignmentType]
|
||||||
id: int = Field(primary_key=True, index=True)
|
id: int | None = Field(primary_key=True, index=True, default=None)
|
||||||
beatmap_id: int = Field(index=True)
|
beatmap_id: int = Field(index=True)
|
||||||
|
|
||||||
|
|
||||||
@@ -221,9 +222,6 @@ async def calculate_beatmap_attributes(
|
|||||||
if await redis.exists(key):
|
if await redis.exists(key):
|
||||||
return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
|
return BeatmapAttributes.model_validate_json(await redis.get(key)) # pyright: ignore[reportArgumentType]
|
||||||
resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
|
resp = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
|
||||||
# 延迟导入以解决循环导入问题
|
|
||||||
from app.calculator import calculate_beatmap_attribute
|
|
||||||
|
|
||||||
attr = await asyncio.get_event_loop().run_in_executor(
|
attr = await asyncio.get_event_loop().run_in_executor(
|
||||||
None, calculate_beatmap_attribute, resp, ruleset, mods_
|
None, calculate_beatmap_attribute, resp, ruleset, mods_
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -183,10 +183,6 @@ class Beatmapset(AsyncAttrs, BeatmapsetBase, table=True):
|
|||||||
if not beatmapset:
|
if not beatmapset:
|
||||||
resp = await fetcher.get_beatmapset(sid)
|
resp = await fetcher.get_beatmapset(sid)
|
||||||
beatmapset = await cls.from_resp(session, resp)
|
beatmapset = await cls.from_resp(session, resp)
|
||||||
# 检查可疑谱面
|
|
||||||
from app.calculator import get_suspscious_beatmap
|
|
||||||
|
|
||||||
await get_suspscious_beatmap(sid, session)
|
|
||||||
return beatmapset
|
return beatmapset
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import asyncio
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from datetime import UTC, date, datetime
|
from datetime import UTC, date, datetime
|
||||||
import json
|
import json
|
||||||
@@ -726,9 +725,7 @@ async def process_score(
|
|||||||
)
|
)
|
||||||
if can_get_pp:
|
if can_get_pp:
|
||||||
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
|
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
|
||||||
pp = await asyncio.get_event_loop().run_in_executor(
|
pp = await calculate_pp(score, beatmap_raw, session)
|
||||||
None, calculate_pp, score, beatmap_raw
|
|
||||||
)
|
|
||||||
score.pp = pp
|
score.pp = pp
|
||||||
session.add(score)
|
session.add(score)
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
|
|||||||
@@ -85,9 +85,7 @@ async def _recalculate_pp(
|
|||||||
break
|
break
|
||||||
try:
|
try:
|
||||||
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
|
beatmap_raw = await fetcher.get_or_fetch_beatmap_raw(redis, beatmap_id)
|
||||||
pp = await asyncio.get_event_loop().run_in_executor(
|
pp = await calculate_pp(score, beatmap_raw, session)
|
||||||
None, calculate_pp, score, beatmap_raw
|
|
||||||
)
|
|
||||||
score.pp = pp
|
score.pp = pp
|
||||||
if score.beatmap_id not in prev or prev[score.beatmap_id].pp < pp:
|
if score.beatmap_id not in prev or prev[score.beatmap_id].pp < pp:
|
||||||
best_score = PPBestScore(
|
best_score = PPBestScore(
|
||||||
|
|||||||
@@ -96,10 +96,12 @@ reportIncompatibleVariableOverride = false
|
|||||||
[tool.uv.workspace]
|
[tool.uv.workspace]
|
||||||
members = [
|
members = [
|
||||||
"packages/msgpack_lazer_api",
|
"packages/msgpack_lazer_api",
|
||||||
|
"packages/osupyparser",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.uv.sources]
|
[tool.uv.sources]
|
||||||
msgpack-lazer-api = { workspace = true }
|
msgpack-lazer-api = { workspace = true }
|
||||||
|
osupyparser = { git = "https://github.com/MingxuanGame/osupyparser.git" }
|
||||||
|
|
||||||
[tool.uv]
|
[tool.uv]
|
||||||
cache-keys = [{file = "pyproject.toml"}, {file = "packages/msgpack_lazer_api/Cargo.toml"}, {file = "**/*.rs"}]
|
cache-keys = [{file = "pyproject.toml"}, {file = "packages/msgpack_lazer_api/Cargo.toml"}, {file = "**/*.rs"}]
|
||||||
|
|||||||
7
uv.lock
generated
7
uv.lock
generated
@@ -858,7 +858,7 @@ requires-dist = [
|
|||||||
{ name = "httpx", specifier = ">=0.28.1" },
|
{ name = "httpx", specifier = ">=0.28.1" },
|
||||||
{ name = "loguru", specifier = ">=0.7.3" },
|
{ name = "loguru", specifier = ">=0.7.3" },
|
||||||
{ name = "msgpack-lazer-api", editable = "packages/msgpack_lazer_api" },
|
{ name = "msgpack-lazer-api", editable = "packages/msgpack_lazer_api" },
|
||||||
{ name = "osupyparser", specifier = ">=1.0.7" },
|
{ name = "osupyparser", git = "https://github.com/MingxuanGame/osupyparser.git" },
|
||||||
{ name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" },
|
{ name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" },
|
||||||
{ name = "pillow", specifier = ">=11.3.0" },
|
{ name = "pillow", specifier = ">=11.3.0" },
|
||||||
{ name = "pydantic", extras = ["email"], specifier = ">=2.5.0" },
|
{ name = "pydantic", extras = ["email"], specifier = ">=2.5.0" },
|
||||||
@@ -883,9 +883,8 @@ dev = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "osupyparser"
|
name = "osupyparser"
|
||||||
version = "1.0.7"
|
version = "1.0.8"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { git = "https://github.com/MingxuanGame/osupyparser.git#e41ec1db87ab64531897127a44b86351c21322bd" }
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/7b/6b/7f567c2acd1f2028603353da40ad7411bb47754994552d3f0c4cfa6703f9/OsuPyParser-1.0.7.tar.gz", hash = "sha256:67f530c31dd5c288c8fff8f583c899c673536681f8cc3699d0afc5e4d8c2b1ff", size = 9095, upload-time = "2021-08-29T23:48:22.752Z" }
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "passlib"
|
name = "passlib"
|
||||||
|
|||||||
Reference in New Issue
Block a user