Files
g0v0-server/app/helpers/geoip_helper.py

227 lines
8.1 KiB
Python

"""
GeoLite2 Helper Class (asynchronous)
"""
import asyncio
from contextlib import suppress
import os
from pathlib import Path
import shutil
import tarfile
import tempfile
import time
from typing import Any, Required, TypedDict
from app.log import logger
import aiofiles
import httpx
import maxminddb
class GeoIPLookupResult(TypedDict, total=False):
ip: Required[str]
country_iso: str
country_name: str
city_name: str
latitude: str
longitude: str
time_zone: str
postal_code: str
asn: int | None
organization: str
# Endpoint of the MaxMind download service; edition, licence key and archive
# suffix are supplied as query parameters.
BASE_URL = "https://" + "download.maxmind.com" + "/app/geoip_download"

# Short edition name (as accepted by ``GeoIPHelper(editions=...)``) mapped to
# the edition id used for downloads and on-disk file names.
EDITIONS = {short_name: f"GeoLite2-{short_name}" for short_name in ("City", "Country", "ASN")}
class GeoIPHelper:
def __init__(
self,
dest_dir: str | Path = Path("./geoip"),
license_key: str | None = None,
editions: list[str] | None = None,
max_age_days: int = 8,
timeout: float = 60.0,
):
self.dest_dir = Path(dest_dir).expanduser()
self.license_key = license_key or os.getenv("MAXMIND_LICENSE_KEY")
self.editions = list(editions or ["City", "ASN"])
self.max_age_days = max_age_days
self.timeout = timeout
self._readers: dict[str, maxminddb.Reader] = {}
self._update_lock = asyncio.Lock()
@staticmethod
def _safe_extract(tar: tarfile.TarFile, path: Path) -> None:
base = path.resolve()
for member in tar.getmembers():
target = (base / member.name).resolve()
if not target.is_relative_to(base): # py312
raise RuntimeError("Unsafe path in tar file")
tar.extractall(path=base, filter="data")
@staticmethod
def _as_mapping(value: Any) -> dict[str, Any]:
return value if isinstance(value, dict) else {}
@staticmethod
def _as_str(value: Any, default: str = "") -> str:
if isinstance(value, str):
return value
if value is None:
return default
return str(value)
@staticmethod
def _as_int(value: Any) -> int | None:
return value if isinstance(value, int) else None
@staticmethod
def _extract_tarball(src: Path, dest: Path) -> None:
with tarfile.open(src, "r:gz") as tar:
GeoIPHelper._safe_extract(tar, dest)
@staticmethod
def _find_mmdb(root: Path) -> Path | None:
for candidate in root.rglob("*.mmdb"):
return candidate
return None
def _latest_file_sync(self, edition_id: str) -> Path | None:
directory = self.dest_dir
if not directory.is_dir():
return None
candidates = list(directory.glob(f"{edition_id}*.mmdb"))
if not candidates:
return None
return max(candidates, key=lambda p: p.stat().st_mtime)
async def _latest_file(self, edition_id: str) -> Path | None:
return await asyncio.to_thread(self._latest_file_sync, edition_id)
async def _download_and_extract(self, edition_id: str) -> Path:
if not self.license_key:
raise ValueError("MaxMind License Key is missing. Please configure it via env MAXMIND_LICENSE_KEY.")
url = f"{BASE_URL}?edition_id={edition_id}&license_key={self.license_key}&suffix=tar.gz"
tmp_dir = Path(await asyncio.to_thread(tempfile.mkdtemp))
try:
tgz_path = tmp_dir / "db.tgz"
async with (
httpx.AsyncClient(follow_redirects=True, timeout=self.timeout) as client,
client.stream("GET", url) as resp,
):
resp.raise_for_status()
async with aiofiles.open(tgz_path, "wb") as download_file:
async for chunk in resp.aiter_bytes():
if chunk:
await download_file.write(chunk)
await asyncio.to_thread(self._extract_tarball, tgz_path, tmp_dir)
mmdb_path = await asyncio.to_thread(self._find_mmdb, tmp_dir)
if mmdb_path is None:
raise RuntimeError("未在压缩包中找到 .mmdb 文件")
await asyncio.to_thread(self.dest_dir.mkdir, parents=True, exist_ok=True)
dst = self.dest_dir / mmdb_path.name
await asyncio.to_thread(shutil.move, mmdb_path, dst)
return dst
finally:
await asyncio.to_thread(shutil.rmtree, tmp_dir, ignore_errors=True)
async def update(self, force: bool = False) -> None:
async with self._update_lock:
for edition in self.editions:
edition_id = EDITIONS[edition]
path = await self._latest_file(edition_id)
need_download = force or path is None
if path:
mtime = await asyncio.to_thread(path.stat)
age_days = (time.time() - mtime.st_mtime) / 86400
if age_days >= self.max_age_days:
need_download = True
logger.info(
f"{edition_id} database is {age_days:.1f} days old "
f"(max: {self.max_age_days}), will download new version"
)
else:
logger.info(
f"{edition_id} database is {age_days:.1f} days old, still fresh (max: {self.max_age_days})"
)
else:
logger.info(f"{edition_id} database not found, will download")
if need_download:
logger.info(f"Downloading {edition_id} database...")
path = await self._download_and_extract(edition_id)
logger.info(f"{edition_id} database downloaded successfully")
else:
logger.info(f"Using existing {edition_id} database")
old_reader = self._readers.get(edition)
if old_reader:
with suppress(Exception):
old_reader.close()
if path is not None:
self._readers[edition] = maxminddb.open_database(str(path))
def lookup(self, ip: str) -> GeoIPLookupResult:
res: GeoIPLookupResult = {"ip": ip}
city_reader = self._readers.get("City")
if city_reader:
data = city_reader.get(ip)
if isinstance(data, dict):
country = self._as_mapping(data.get("country"))
res["country_iso"] = self._as_str(country.get("iso_code"))
country_names = self._as_mapping(country.get("names"))
res["country_name"] = self._as_str(country_names.get("en"))
city = self._as_mapping(data.get("city"))
city_names = self._as_mapping(city.get("names"))
res["city_name"] = self._as_str(city_names.get("en"))
location = self._as_mapping(data.get("location"))
latitude = location.get("latitude")
longitude = location.get("longitude")
res["latitude"] = str(latitude) if latitude is not None else ""
res["longitude"] = str(longitude) if longitude is not None else ""
res["time_zone"] = self._as_str(location.get("time_zone"))
postal = self._as_mapping(data.get("postal"))
postal_code = postal.get("code")
if postal_code is not None:
res["postal_code"] = self._as_str(postal_code)
asn_reader = self._readers.get("ASN")
if asn_reader:
data = asn_reader.get(ip)
if isinstance(data, dict):
res["asn"] = self._as_int(data.get("autonomous_system_number"))
res["organization"] = self._as_str(data.get("autonomous_system_organization"), default="")
return res
def close(self) -> None:
for reader in self._readers.values():
with suppress(Exception):
reader.close()
self._readers = {}
if __name__ == "__main__":

    async def _demo() -> None:
        """Refresh the databases, print one sample lookup, then clean up."""
        # An empty licence key makes the helper fall back to the
        # MAXMIND_LICENSE_KEY environment variable.
        geo = GeoIPHelper(dest_dir="./geoip", license_key="")
        try:
            await geo.update()
            print(geo.lookup("8.8.8.8"))
        finally:
            # Release the readers even when the update or lookup fails.
            geo.close()

    asyncio.run(_demo())