chore(linter): update ruff rules
This commit is contained in:
0
app/helpers/__init__.py
Normal file
0
app/helpers/__init__.py
Normal file
@@ -1,19 +1,39 @@
|
||||
"""
|
||||
GeoLite2 Helper Class
|
||||
GeoLite2 Helper Class (asynchronous)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import suppress
|
||||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import tarfile
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Any, Required, TypedDict
|
||||
|
||||
from app.log import logger
|
||||
|
||||
import aiofiles
|
||||
import httpx
|
||||
import maxminddb
|
||||
|
||||
|
||||
class GeoIPLookupResult(TypedDict, total=False):
    """Shape of the dictionary returned by ``GeoIPHelper.lookup``.

    Only ``ip`` is always present (``total=False`` makes every other key
    optional); the remaining keys are filled in only when the corresponding
    database reader is loaded and returns a record for the address.
    """

    # The address that was looked up, echoed back unchanged.
    ip: Required[str]
    # Fields taken from the "City" database record.
    country_iso: str
    country_name: str
    city_name: str
    # Coordinates are stored as strings ("" when the record has none).
    latitude: str
    longitude: str
    time_zone: str
    # Set only when the record carries a postal code.
    postal_code: str
    # Fields taken from the "ASN" database record; asn is None when the
    # stored value is not an integer.
    asn: int | None
    organization: str
|
||||
|
||||
|
||||
BASE_URL = "https://download.maxmind.com/app/geoip_download"
|
||||
EDITIONS = {
|
||||
"City": "GeoLite2-City",
|
||||
@@ -25,161 +45,184 @@ EDITIONS = {
|
||||
class GeoIPHelper:
    """Download, refresh and query MaxMind GeoLite2 databases (asynchronous).

    Databases are stored as ``.mmdb`` files inside ``dest_dir``.  ``update()``
    downloads missing or stale editions and (re)opens a reader per edition;
    ``lookup()`` is a plain synchronous call against the opened readers.
    """

    def __init__(
        self,
        dest_dir: str | Path = Path("./geoip"),
        license_key: str | None = None,
        editions: list[str] | None = None,
        max_age_days: int = 8,
        timeout: float = 60.0,
    ) -> None:
        """Configure the helper; no I/O happens until ``update()`` is called.

        Args:
            dest_dir: Directory holding the ``.mmdb`` files (``~`` is expanded).
            license_key: MaxMind license key.  A falsy value falls back to the
                ``MAXMIND_LICENSE_KEY`` environment variable.
            editions: Keys of ``EDITIONS`` to manage; defaults to City and ASN.
            max_age_days: Age in days at which an existing database is
                considered stale and downloaded again.
            timeout: HTTP timeout in seconds passed to httpx.
        """
        self.dest_dir = Path(dest_dir).expanduser()
        self.license_key = license_key or os.getenv("MAXMIND_LICENSE_KEY")
        # Copy so later mutation of the caller's list does not affect us.
        self.editions = list(editions or ["City", "ASN"])
        self.max_age_days = max_age_days
        self.timeout = timeout
        self._readers: dict[str, maxminddb.Reader] = {}
        # Serialises concurrent update() calls on the same helper.
        self._update_lock = asyncio.Lock()

    @staticmethod
    def _safe_extract(tar: tarfile.TarFile, path: Path) -> None:
        """Extract *tar* into *path*, refusing members that escape it.

        Raises:
            RuntimeError: If any member would resolve outside *path*.
        """
        base = path.resolve()
        for member in tar.getmembers():
            target = (base / member.name).resolve()
            if not target.is_relative_to(base):
                raise RuntimeError("Unsafe path in tar file")
        # The "data" extraction filter needs a Python version that ships
        # tarfile extraction filters.
        tar.extractall(path=base, filter="data")

    @staticmethod
    def _as_mapping(value: Any) -> dict[str, Any]:
        """Return *value* if it is a dict, otherwise an empty dict."""
        return value if isinstance(value, dict) else {}

    @staticmethod
    def _as_str(value: Any, default: str = "") -> str:
        """Coerce *value* to ``str``; ``None`` becomes *default*."""
        if isinstance(value, str):
            return value
        if value is None:
            return default
        return str(value)

    @staticmethod
    def _as_int(value: Any) -> int | None:
        """Return *value* if it is a real integer, otherwise ``None``."""
        # bool is a subclass of int; True/False are not meaningful numbers here.
        if isinstance(value, int) and not isinstance(value, bool):
            return value
        return None

    @staticmethod
    def _extract_tarball(src: Path, dest: Path) -> None:
        """Open the gzip-compressed tarball *src* and safely extract it into *dest*."""
        with tarfile.open(src, "r:gz") as tar:
            GeoIPHelper._safe_extract(tar, dest)

    @staticmethod
    def _find_mmdb(root: Path) -> Path | None:
        """Return the first ``*.mmdb`` file found anywhere below *root*, or ``None``."""
        return next(root.rglob("*.mmdb"), None)

    def _latest_file_sync(self, edition_id: str) -> Path | None:
        """Return the most recently modified ``<edition_id>*.mmdb`` in ``dest_dir``.

        Returns ``None`` when the directory does not exist or holds no match.
        """
        directory = self.dest_dir
        if not directory.is_dir():
            return None
        candidates = list(directory.glob(f"{edition_id}*.mmdb"))
        if not candidates:
            return None
        return max(candidates, key=lambda p: p.stat().st_mtime)

    async def _latest_file(self, edition_id: str) -> Path | None:
        """Async wrapper running ``_latest_file_sync`` in a worker thread."""
        return await asyncio.to_thread(self._latest_file_sync, edition_id)

    async def _download_and_extract(self, edition_id: str) -> Path:
        """Download *edition_id*, unpack it and move its ``.mmdb`` into ``dest_dir``.

        The archive is streamed into a temporary directory (redirects are
        followed), extracted there, and the temporary directory is always
        removed afterwards.

        Returns:
            Path of the ``.mmdb`` file inside ``dest_dir``.

        Raises:
            ValueError: If no license key is configured.
            RuntimeError: If the archive contains no ``.mmdb`` file or an
                unsafe member path.
            httpx.HTTPError: On transport failures or non-success responses.
        """
        if not self.license_key:
            raise ValueError("MaxMind License Key is missing. Please configure it via env MAXMIND_LICENSE_KEY.")

        # Let httpx build the query string so every value is URL-encoded.
        # NOTE(review): httpx error messages include the request URL, so the
        # license key can end up in logs/tracebacks - confirm this is acceptable.
        params = {
            "edition_id": edition_id,
            "license_key": self.license_key,
            "suffix": "tar.gz",
        }
        tmp_dir = Path(await asyncio.to_thread(tempfile.mkdtemp))

        try:
            tgz_path = tmp_dir / "db.tgz"
            async with (
                httpx.AsyncClient(follow_redirects=True, timeout=self.timeout) as client,
                client.stream("GET", BASE_URL, params=params) as resp,
            ):
                resp.raise_for_status()
                # Stream to disk chunk by chunk instead of buffering in memory.
                async with aiofiles.open(tgz_path, "wb") as download_file:
                    async for chunk in resp.aiter_bytes():
                        if chunk:
                            await download_file.write(chunk)

            await asyncio.to_thread(self._extract_tarball, tgz_path, tmp_dir)
            mmdb_path = await asyncio.to_thread(self._find_mmdb, tmp_dir)
            if mmdb_path is None:
                raise RuntimeError("未在压缩包中找到 .mmdb 文件")

            # Keep only the .mmdb file; everything else dies with tmp_dir.
            await asyncio.to_thread(self.dest_dir.mkdir, parents=True, exist_ok=True)
            dst = self.dest_dir / mmdb_path.name
            await asyncio.to_thread(shutil.move, mmdb_path, dst)
            return dst
        finally:
            await asyncio.to_thread(shutil.rmtree, tmp_dir, ignore_errors=True)

    async def update(self, force: bool = False) -> None:
        """Ensure every configured edition is present, fresh and opened.

        An edition is downloaded when *force* is true, when no local file
        exists, or when the newest local file is at least ``max_age_days``
        old.  Afterwards the reader for that edition is (re)opened.

        Raises:
            KeyError: If an entry of ``editions`` is not a key of ``EDITIONS``.
            ValueError: If a download is needed but no license key is set.
        """
        async with self._update_lock:
            for edition in self.editions:
                edition_id = EDITIONS[edition]
                path = await self._latest_file(edition_id)
                need_download = force or path is None

                if path:
                    file_stat = await asyncio.to_thread(path.stat)
                    age_days = (time.time() - file_stat.st_mtime) / 86400
                    if age_days >= self.max_age_days:
                        need_download = True
                        logger.info(
                            f"{edition_id} database is {age_days:.1f} days old "
                            f"(max: {self.max_age_days}), will download new version"
                        )
                    else:
                        logger.info(
                            f"{edition_id} database is {age_days:.1f} days old, still fresh (max: {self.max_age_days})"
                        )
                else:
                    logger.info(f"{edition_id} database not found, will download")

                if need_download:
                    logger.info(f"Downloading {edition_id} database...")
                    path = await self._download_and_extract(edition_id)
                    logger.info(f"{edition_id} database downloaded successfully")
                else:
                    logger.info(f"Using existing {edition_id} database")

                if path is None:
                    continue

                # Open the replacement first: if that fails, the previous
                # reader stays registered and usable instead of being left
                # closed inside self._readers.
                new_reader = await asyncio.to_thread(maxminddb.open_database, str(path))
                old_reader = self._readers.get(edition)
                self._readers[edition] = new_reader
                if old_reader is not None:
                    with suppress(Exception):
                        old_reader.close()

    def lookup(self, ip: str) -> GeoIPLookupResult:
        """Look up *ip* in the opened City and ASN databases.

        Returns a dictionary that always contains ``ip``; the other keys are
        present only when the matching reader is loaded and has a record.
        Missing text values are reported as empty strings.
        """
        res: GeoIPLookupResult = {"ip": ip}

        city_reader = self._readers.get("City")
        if city_reader is not None:
            data = city_reader.get(ip)
            if isinstance(data, dict):
                country = self._as_mapping(data.get("country"))
                res["country_iso"] = self._as_str(country.get("iso_code"))
                country_names = self._as_mapping(country.get("names"))
                res["country_name"] = self._as_str(country_names.get("en"))

                city = self._as_mapping(data.get("city"))
                city_names = self._as_mapping(city.get("names"))
                res["city_name"] = self._as_str(city_names.get("en"))

                location = self._as_mapping(data.get("location"))
                latitude = location.get("latitude")
                longitude = location.get("longitude")
                # Compare with None so a coordinate of 0.0 is preserved.
                res["latitude"] = str(latitude) if latitude is not None else ""
                res["longitude"] = str(longitude) if longitude is not None else ""
                res["time_zone"] = self._as_str(location.get("time_zone"))

                postal = self._as_mapping(data.get("postal"))
                postal_code = postal.get("code")
                if postal_code is not None:
                    res["postal_code"] = self._as_str(postal_code)

        asn_reader = self._readers.get("ASN")
        if asn_reader is not None:
            data = asn_reader.get(ip)
            if isinstance(data, dict):
                res["asn"] = self._as_int(data.get("autonomous_system_number"))
                res["organization"] = self._as_str(data.get("autonomous_system_organization"), default="")
        return res

    def close(self) -> None:
        """Close every opened reader (best effort) and forget them."""
        for reader in self._readers.values():
            # Closing is best effort: one failing reader must not prevent
            # the others from being closed.
            with suppress(Exception):
                reader.close()
        self._readers = {}
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Example usage: refresh the databases, look up one address, then
    # release the readers.

    async def _demo() -> None:
        """Run a single update + lookup round trip against ./geoip."""
        # An empty license_key makes the helper fall back to the
        # MAXMIND_LICENSE_KEY environment variable.
        geo = GeoIPHelper(dest_dir="./geoip", license_key="")
        try:
            await geo.update()
            print(geo.lookup("8.8.8.8"))
        finally:
            # Close the readers even when the update or lookup fails.
            geo.close()

    asyncio.run(_demo())
|
||||
|
||||
Reference in New Issue
Block a user