chore(linter): update ruff rules

MingxuanGame
2025-10-03 15:46:53 +00:00
parent b10425ad91
commit d490239f46
59 changed files with 393 additions and 425 deletions

app/helpers/__init__.py
@@ -1,19 +1,39 @@
 """
-GeoLite2 Helper Class
+GeoLite2 Helper Class (asynchronous)
 """
 from __future__ import annotations
 
+import asyncio
+from contextlib import suppress
 import os
 from pathlib import Path
 import shutil
 import tarfile
 import tempfile
 import time
+from typing import Any, Required, TypedDict
 
+from app.log import logger
+
+import aiofiles
 import httpx
 import maxminddb
 
+
+class GeoIPLookupResult(TypedDict, total=False):
+    ip: Required[str]
+    country_iso: str
+    country_name: str
+    city_name: str
+    latitude: str
+    longitude: str
+    time_zone: str
+    postal_code: str
+    asn: int | None
+    organization: str
+
+
 BASE_URL = "https://download.maxmind.com/app/geoip_download"
 EDITIONS = {
     "City": "GeoLite2-City",
@@ -25,161 +45,184 @@ EDITIONS = {
 class GeoIPHelper:
     def __init__(
         self,
-        dest_dir="./geoip",
-        license_key=None,
-        editions=None,
-        max_age_days=8,
-        timeout=60.0,
+        dest_dir: str | Path = Path("./geoip"),
+        license_key: str | None = None,
+        editions: list[str] | None = None,
+        max_age_days: int = 8,
+        timeout: float = 60.0,
     ):
-        self.dest_dir = dest_dir
+        self.dest_dir = Path(dest_dir).expanduser()
         self.license_key = license_key or os.getenv("MAXMIND_LICENSE_KEY")
-        self.editions = editions or ["City", "ASN"]
+        self.editions = list(editions or ["City", "ASN"])
         self.max_age_days = max_age_days
         self.timeout = timeout
-        self._readers = {}
+        self._readers: dict[str, maxminddb.Reader] = {}
+        self._update_lock = asyncio.Lock()
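For reference, construction stays source-compatible for existing callers; a quick sketch (the cache path is illustrative):

    helper = GeoIPHelper()                           # defaults: ./geoip, ["City", "ASN"]
    helper = GeoIPHelper(dest_dir="~/.cache/geoip")  # "~" now resolves via expanduser()
    # with license_key=None, the key is read from env MAXMIND_LICENSE_KEY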
     @staticmethod
-    def _safe_extract(tar: tarfile.TarFile, path: str):
-        base = Path(path).resolve()
-        for m in tar.getmembers():
-            target = (base / m.name).resolve()
-            if not str(target).startswith(str(base)):
+    def _safe_extract(tar: tarfile.TarFile, path: Path) -> None:
+        base = path.resolve()
+        for member in tar.getmembers():
+            target = (base / member.name).resolve()
+            if not target.is_relative_to(base):  # Path.is_relative_to: py3.9+
                 raise RuntimeError("Unsafe path in tar file")
-        tar.extractall(path=path, filter="data")
+        tar.extractall(path=base, filter="data")
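The is_relative_to check rejects tar members whose names resolve outside the extraction root (the classic tar path-traversal attack). Illustrative values, using the module's pathlib import:

    base = Path("/srv/geoip").resolve()
    target = (base / "../../etc/passwd").resolve()   # escapes to /etc/passwd
    assert not target.is_relative_to(base)           # _safe_extract raises RuntimeError here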
+    @staticmethod
+    def _as_mapping(value: Any) -> dict[str, Any]:
+        return value if isinstance(value, dict) else {}
+
+    @staticmethod
+    def _as_str(value: Any, default: str = "") -> str:
+        if isinstance(value, str):
+            return value
+        if value is None:
+            return default
+        return str(value)
+
+    @staticmethod
+    def _as_int(value: Any) -> int | None:
+        return value if isinstance(value, int) else None
+
+    @staticmethod
+    def _extract_tarball(src: Path, dest: Path) -> None:
+        with tarfile.open(src, "r:gz") as tar:
+            GeoIPHelper._safe_extract(tar, dest)
+
+    @staticmethod
+    def _find_mmdb(root: Path) -> Path | None:
+        for candidate in root.rglob("*.mmdb"):
+            return candidate
+        return None
+
+    def _latest_file_sync(self, edition_id: str) -> Path | None:
+        directory = self.dest_dir
+        if not directory.is_dir():
+            return None
+        candidates = list(directory.glob(f"{edition_id}*.mmdb"))
+        if not candidates:
+            return None
+        return max(candidates, key=lambda p: p.stat().st_mtime)
+
+    async def _latest_file(self, edition_id: str) -> Path | None:
+        return await asyncio.to_thread(self._latest_file_sync, edition_id)
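The _latest_file_sync/_latest_file pair follows the standard asyncio.to_thread pattern: blocking filesystem work stays in a plain function, and the coroutine offloads it to the default thread pool. A self-contained sketch of the same idea (names are hypothetical):

    import asyncio
    from pathlib import Path

    def _count_mmdb_sync(directory: Path) -> int:
        # blocking directory scan, safe to run in a worker thread
        return sum(1 for _ in directory.glob("*.mmdb"))

    async def count_mmdb(directory: Path) -> int:
        return await asyncio.to_thread(_count_mmdb_sync, directory)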
-    def _download_and_extract(self, edition_id: str) -> str:
-        """
-        Download and extract the .mmdb into dest_dir (keeping only the .mmdb)
-        - follows 302 redirects
-        - streams the download to a temporary file
-        - the temporary directory is cleaned up automatically on exit
-        """
+    async def _download_and_extract(self, edition_id: str) -> Path:
         if not self.license_key:
             raise ValueError("MaxMind License Key is missing. Please configure it via env MAXMIND_LICENSE_KEY.")
         url = f"{BASE_URL}?edition_id={edition_id}&license_key={self.license_key}&suffix=tar.gz"
-        with httpx.Client(follow_redirects=True, timeout=self.timeout) as client:
-            with client.stream("GET", url) as resp:
+        tmp_dir = Path(await asyncio.to_thread(tempfile.mkdtemp))
+        try:
+            tgz_path = tmp_dir / "db.tgz"
+            async with (
+                httpx.AsyncClient(follow_redirects=True, timeout=self.timeout) as client,
+                client.stream("GET", url) as resp,
+            ):
                 resp.raise_for_status()
-                with tempfile.TemporaryDirectory() as tmpd:
-                    tgz_path = os.path.join(tmpd, "db.tgz")
-                    # stream the body to disk
-                    with open(tgz_path, "wb") as f:
-                        for chunk in resp.iter_bytes():
-                            if chunk:
-                                f.write(chunk)
+                async with aiofiles.open(tgz_path, "wb") as download_file:
+                    async for chunk in resp.aiter_bytes():
+                        if chunk:
+                            await download_file.write(chunk)
-                    # extract, then move only the .mmdb
-                    with tarfile.open(tgz_path, "r:gz") as tar:
-                        # safety-check, then extract
-                        self._safe_extract(tar, tmpd)
-                        # recursively locate the .mmdb
-                        mmdb_path = None
-                        for root, _, files in os.walk(tmpd):
-                            for fn in files:
-                                if fn.endswith(".mmdb"):
-                                    mmdb_path = os.path.join(root, fn)
-                                    break
-                            if mmdb_path:
-                                break
-                        if not mmdb_path:
-                            raise RuntimeError("No .mmdb file found in the archive")
-                    os.makedirs(self.dest_dir, exist_ok=True)
-                    dst = os.path.join(self.dest_dir, os.path.basename(mmdb_path))
-                    shutil.move(mmdb_path, dst)
-                    return dst
+            await asyncio.to_thread(self._extract_tarball, tgz_path, tmp_dir)
+            mmdb_path = await asyncio.to_thread(self._find_mmdb, tmp_dir)
+            if mmdb_path is None:
+                raise RuntimeError("No .mmdb file found in the archive")
+            await asyncio.to_thread(self.dest_dir.mkdir, parents=True, exist_ok=True)
+            dst = self.dest_dir / mmdb_path.name
+            await asyncio.to_thread(shutil.move, mmdb_path, dst)
+            return dst
+        finally:
+            await asyncio.to_thread(shutil.rmtree, tmp_dir, ignore_errors=True)
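Worth noting for callers: a missing or expired license key surfaces as an HTTP error from resp.raise_for_status(). A hypothetical wrapper, not part of this commit:

    async def try_download(helper: GeoIPHelper, edition_id: str) -> Path | None:
        try:
            return await helper._download_and_extract(edition_id)
        except httpx.HTTPStatusError as exc:  # e.g. 401 on a bad license key
            logger.error(f"MaxMind download failed: HTTP {exc.response.status_code}")
            return None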
-    def _latest_file(self, edition_id: str):
-        if not os.path.isdir(self.dest_dir):
-            return None
-        files = [
-            os.path.join(self.dest_dir, f)
-            for f in os.listdir(self.dest_dir)
-            if f.startswith(edition_id) and f.endswith(".mmdb")
-        ]
-        return max(files, key=os.path.getmtime) if files else None
-
-    def update(self, force=False):
-        from app.log import logger
-
-        for ed in self.editions:
-            eid = EDITIONS[ed]
-            path = self._latest_file(eid)
-            need = force or not path
-            if path:
-                age_days = (time.time() - os.path.getmtime(path)) / 86400
-                if age_days >= self.max_age_days:
-                    need = True
-                    logger.info(
-                        f"{eid} database is {age_days:.1f} days old "
-                        f"(max: {self.max_age_days}), will download new version"
-                    )
-                else:
-                    logger.info(f"{eid} database is {age_days:.1f} days old, still fresh (max: {self.max_age_days})")
-            else:
-                logger.info(f"{eid} database not found, will download")
-            if need:
-                logger.info(f"Downloading {eid} database...")
-                path = self._download_and_extract(eid)
-                logger.info(f"{eid} database downloaded successfully")
-            else:
-                logger.info(f"Using existing {eid} database")
-            old = self._readers.get(ed)
-            if old:
-                try:
-                    old.close()
-                except Exception:
-                    pass
-            if path is not None:
-                self._readers[ed] = maxminddb.open_database(path)
+    async def update(self, force: bool = False) -> None:
+        async with self._update_lock:
+            for edition in self.editions:
+                edition_id = EDITIONS[edition]
+                path = await self._latest_file(edition_id)
+                need_download = force or path is None
+                if path:
+                    mtime = await asyncio.to_thread(path.stat)
+                    age_days = (time.time() - mtime.st_mtime) / 86400
+                    if age_days >= self.max_age_days:
+                        need_download = True
+                        logger.info(
+                            f"{edition_id} database is {age_days:.1f} days old "
+                            f"(max: {self.max_age_days}), will download new version"
+                        )
+                    else:
+                        logger.info(
+                            f"{edition_id} database is {age_days:.1f} days old, still fresh (max: {self.max_age_days})"
+                        )
+                else:
+                    logger.info(f"{edition_id} database not found, will download")
+                if need_download:
+                    logger.info(f"Downloading {edition_id} database...")
+                    path = await self._download_and_extract(edition_id)
+                    logger.info(f"{edition_id} database downloaded successfully")
+                else:
+                    logger.info(f"Using existing {edition_id} database")
+                old_reader = self._readers.get(edition)
+                if old_reader:
+                    with suppress(Exception):
+                        old_reader.close()
+                if path is not None:
+                    self._readers[edition] = maxminddb.open_database(str(path))
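Because update() now serializes on the internal asyncio.Lock, overlapping refreshes are safe to schedule. A hypothetical periodic refresher (the interval is arbitrary):

    async def refresh_periodically(helper: GeoIPHelper) -> None:
        while True:
            await helper.update()           # concurrent calls queue on _update_lock
            await asyncio.sleep(24 * 3600)  # re-check once a day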
-    def lookup(self, ip: str):
-        res = {"ip": ip}
-        # City
-        city_r = self._readers.get("City")
-        if city_r:
-            data = city_r.get(ip)
-            if data:
-                country = data.get("country") or {}
-                res["country_iso"] = country.get("iso_code") or ""
-                res["country_name"] = (country.get("names") or {}).get("en", "")
-                city = data.get("city") or {}
-                res["city_name"] = (city.get("names") or {}).get("en", "")
-                loc = data.get("location") or {}
-                res["latitude"] = str(loc.get("latitude") or "")
-                res["longitude"] = str(loc.get("longitude") or "")
-                res["time_zone"] = str(loc.get("time_zone") or "")
-                postal = data.get("postal") or {}
-                if "code" in postal:
-                    res["postal_code"] = postal["code"]
-        # ASN
-        asn_r = self._readers.get("ASN")
-        if asn_r:
-            data = asn_r.get(ip)
-            if data:
-                res["asn"] = data.get("autonomous_system_number")
-                res["organization"] = data.get("autonomous_system_organization")
+    def lookup(self, ip: str) -> GeoIPLookupResult:
+        res: GeoIPLookupResult = {"ip": ip}
+        city_reader = self._readers.get("City")
+        if city_reader:
+            data = city_reader.get(ip)
+            if isinstance(data, dict):
+                country = self._as_mapping(data.get("country"))
+                res["country_iso"] = self._as_str(country.get("iso_code"))
+                country_names = self._as_mapping(country.get("names"))
+                res["country_name"] = self._as_str(country_names.get("en"))
+                city = self._as_mapping(data.get("city"))
+                city_names = self._as_mapping(city.get("names"))
+                res["city_name"] = self._as_str(city_names.get("en"))
+                location = self._as_mapping(data.get("location"))
+                latitude = location.get("latitude")
+                longitude = location.get("longitude")
+                res["latitude"] = str(latitude) if latitude is not None else ""
+                res["longitude"] = str(longitude) if longitude is not None else ""
+                res["time_zone"] = self._as_str(location.get("time_zone"))
+                postal = self._as_mapping(data.get("postal"))
+                postal_code = postal.get("code")
+                if postal_code is not None:
+                    res["postal_code"] = self._as_str(postal_code)
+        asn_reader = self._readers.get("ASN")
+        if asn_reader:
+            data = asn_reader.get(ip)
+            if isinstance(data, dict):
+                res["asn"] = self._as_int(data.get("autonomous_system_number"))
+                res["organization"] = self._as_str(data.get("autonomous_system_organization"), default="")
         return res
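Since most GeoIPLookupResult keys are optional, callers should prefer .get() for everything except "ip". A usage sketch, assuming update() has already run:

    info = helper.lookup("1.1.1.1")
    print(info["ip"], info.get("country_name", "unknown"), info.get("asn"))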
-    def close(self):
-        for r in self._readers.values():
-            try:
-                r.close()
-            except Exception:
-                pass
+    def close(self) -> None:
+        for reader in self._readers.values():
+            with suppress(Exception):
+                reader.close()
         self._readers = {}
 if __name__ == "__main__":
-    # example usage
-    geo = GeoIPHelper(dest_dir="./geoip", license_key="")
-    geo.update()
-    print(geo.lookup("8.8.8.8"))
-    geo.close()
+    async def _demo() -> None:
+        geo = GeoIPHelper(dest_dir="./geoip", license_key="")
+        await geo.update()
+        print(geo.lookup("8.8.8.8"))
+        geo.close()
+
+    asyncio.run(_demo())