add geoip

This commit is contained in:
咕谷酱
2025-08-17 23:56:46 +08:00
parent eaa6d4d92b
commit de0c86f4a2
11 changed files with 503 additions and 82 deletions

View File

@@ -3,9 +3,23 @@ from __future__ import annotations
from enum import Enum
from typing import Annotated, Any
from pydantic import AliasChoices, Field, HttpUrl, ValidationInfo, field_validator
from pydantic import AliasChoices, Field, HttpUrl, ValidationInfo, field_validator, BeforeValidator
from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict
def _parse_list(v):
if v is None or v == "" or str(v).strip() in ("[]", "{}"):
return []
if isinstance(v, list):
return v
s = str(v).strip()
try:
import json
parsed = json.loads(s)
if isinstance(parsed, list):
return parsed
except Exception:
pass
return [x.strip() for x in s.split(",") if x.strip()]
class AWSS3StorageSettings(BaseSettings):
s3_access_key_id: str
@@ -96,6 +110,12 @@ class Settings(BaseSettings):
# Sentry 配置
sentry_dsn: HttpUrl | None = None
# GeoIP 配置
maxmind_license_key: str = ""
geoip_dest_dir: str = "./geoip"
geoip_update_day: int = 1 # 每周更新的星期几0=周一6=周日)
geoip_update_hour: int = 2 # 每周更新的小时数0-23
# 游戏设置
enable_rx: bool = Field(
default=False, validation_alias=AliasChoices("enable_rx", "enable_osu_rx")
@@ -108,7 +128,7 @@ class Settings(BaseSettings):
enable_all_beatmap_leaderboard: bool = False
enable_all_beatmap_pp: bool = False
suspicious_score_check: bool = True
seasonal_backgrounds: list[str] = []
seasonal_backgrounds: Annotated[list[str], BeforeValidator(_parse_list)] = []
banned_name: list[str] = [
"mrekk",
"vaxei",

51
app/dependencies/geoip.py Normal file
View File

@@ -0,0 +1,51 @@
# -*- coding: utf-8 -*-
"""
GeoIP dependency for FastAPI
"""
from functools import lru_cache
from app.helpers.geoip_helper import GeoIPHelper
from app.config import settings
@lru_cache()
def get_geoip_helper() -> GeoIPHelper:
"""
获取 GeoIP 帮助类实例
使用 lru_cache 确保单例模式
"""
return GeoIPHelper(
dest_dir=settings.geoip_dest_dir,
license_key=settings.maxmind_license_key,
editions=["City", "ASN"],
max_age_days=8,
timeout=60.0
)
def get_client_ip(request) -> str:
"""
Get the real client IP address
Supports proxies, load balancers, and Cloudflare headers
"""
headers = request.headers
# 1. Cloudflare specific headers
cf_ip = headers.get("CF-Connecting-IP")
if cf_ip:
return cf_ip.strip()
true_client_ip = headers.get("True-Client-IP")
if true_client_ip:
return true_client_ip.strip()
# 2. Standard proxy headers
forwarded_for = headers.get("X-Forwarded-For")
if forwarded_for:
# X-Forwarded-For may contain multiple IPs, take the first
return forwarded_for.split(",")[0].strip()
real_ip = headers.get("X-Real-IP")
if real_ip:
return real_ip.strip()
# 3. Fallback to client host
return request.client.host if request.client else "127.0.0.1"

153
app/helpers/geoip_helper.py Normal file
View File

@@ -0,0 +1,153 @@
# -*- coding: utf-8 -*-
"""
GeoLite2 Helper Class
"""
import os
import tarfile
import shutil
import tempfile
import time
import httpx
import maxminddb
from pathlib import Path
class GeoIPHelper:
BASE_URL = "https://download.maxmind.com/app/geoip_download"
EDITIONS = {"City": "GeoLite2-City", "Country": "GeoLite2-Country", "ASN": "GeoLite2-ASN"}
def __init__(self, dest_dir="./geoip", license_key=None, editions=None, max_age_days=8, timeout=60.0):
self.dest_dir = dest_dir
self.license_key = license_key or os.getenv("MAXMIND_LICENSE_KEY")
self.editions = editions or ["City", "ASN"]
self.max_age_days = max_age_days
self.timeout = timeout
self._readers = {}
@staticmethod
def _safe_extract(tar: tarfile.TarFile, path: str):
base = Path(path).resolve()
for m in tar.getmembers():
target = (base / m.name).resolve()
if not str(target).startswith(str(base)):
raise RuntimeError("Unsafe path in tar file")
tar.extractall(path=path, filter='data')
def _download_and_extract(self, edition_id: str) -> str:
"""
下载并解压 mmdb 文件到 dest_dir仅保留 .mmdb
- 跟随 302 重定向
- 流式下载到临时文件
- 临时目录退出后自动清理
"""
if not self.license_key:
raise ValueError("缺少 MaxMind License Key请传入或设置环境变量 MAXMIND_LICENSE_KEY")
url = f"{self.BASE_URL}?edition_id={edition_id}&license_key={self.license_key}&suffix=tar.gz"
with httpx.Client(follow_redirects=True, timeout=self.timeout) as client:
with client.stream("GET", url) as resp:
resp.raise_for_status()
with tempfile.TemporaryDirectory() as tmpd:
tgz_path = os.path.join(tmpd, "db.tgz")
# 流式写入
with open(tgz_path, "wb") as f:
for chunk in resp.iter_bytes():
if chunk:
f.write(chunk)
# 解压并只移动 .mmdb
with tarfile.open(tgz_path, "r:gz") as tar:
# 先安全检查与解压
self._safe_extract(tar, tmpd)
# 递归找 .mmdb
mmdb_path = None
for root, _, files in os.walk(tmpd):
for fn in files:
if fn.endswith(".mmdb"):
mmdb_path = os.path.join(root, fn)
break
if mmdb_path:
break
if not mmdb_path:
raise RuntimeError("未在压缩包中找到 .mmdb 文件")
os.makedirs(self.dest_dir, exist_ok=True)
dst = os.path.join(self.dest_dir, os.path.basename(mmdb_path))
shutil.move(mmdb_path, dst)
return dst
def _latest_file(self, edition_id: str):
if not os.path.isdir(self.dest_dir):
return None
files = [os.path.join(self.dest_dir, f) for f in os.listdir(self.dest_dir)
if f.startswith(edition_id) and f.endswith(".mmdb")]
return max(files, key=os.path.getmtime) if files else None
def update(self, force=False):
for ed in self.editions:
eid = self.EDITIONS[ed]
path = self._latest_file(eid)
need = force or not path
if path:
age_days = (time.time() - os.path.getmtime(path)) / 86400
if age_days >= self.max_age_days:
need = True
if need:
path = self._download_and_extract(eid)
old = self._readers.get(ed)
if old:
try:
old.close()
except:
pass
if path is not None:
self._readers[ed] = maxminddb.open_database(path)
def lookup(self, ip: str):
res = {"ip": ip}
# City
city_r = self._readers.get("City")
if city_r:
data = city_r.get(ip)
if data:
country = data.get("country") or {}
res["country_iso"] = country.get("iso_code") or ""
res["country_name"] = (country.get("names") or {}).get("en", "")
city = data.get("city") or {}
res["city_name"] = (city.get("names") or {}).get("en", "")
loc = data.get("location") or {}
res["latitude"] = str(loc.get("latitude") or "")
res["longitude"] = str(loc.get("longitude") or "")
res["time_zone"] = str(loc.get("time_zone") or "")
postal = data.get("postal") or {}
if "code" in postal:
res["postal_code"] = postal["code"]
# ASN
asn_r = self._readers.get("ASN")
if asn_r:
data = asn_r.get(ip)
if data:
res["asn"] = data.get("autonomous_system_number")
res["organization"] = data.get("autonomous_system_organization")
return res
def close(self):
for r in self._readers.values():
try:
r.close()
except:
pass
self._readers = {}
if __name__ == "__main__":
# 示例用法
geo = GeoIPHelper(dest_dir="./geoip", license_key="")
geo.update()
print(geo.lookup("8.8.8.8"))
geo.close()

View File

@@ -20,6 +20,8 @@ from app.database import DailyChallengeStats, OAuthClient, User
from app.database.statistics import UserStatistics
from app.dependencies import get_db
from app.dependencies.database import get_redis
from app.dependencies.geoip import get_geoip_helper, get_client_ip
from app.helpers.geoip_helper import GeoIPHelper
from app.log import logger
from app.models.oauth import (
OAuthErrorResponse,
@@ -29,7 +31,7 @@ from app.models.oauth import (
)
from app.models.score import GameMode
from fastapi import APIRouter, Depends, Form
from fastapi import APIRouter, Depends, Form, Request
from fastapi.responses import JSONResponse
from redis.asyncio import Redis
from sqlalchemy import text
@@ -79,18 +81,20 @@ def validate_password(password: str) -> list[str]:
router = APIRouter(tags=["osu! OAuth 认证"])
@router.post(
"/users",
name="注册用户",
description="用户注册接口",
)
async def register_user(
request: Request,
user_username: str = Form(..., alias="user[username]", description="用户名"),
user_email: str = Form(..., alias="user[user_email]", description="电子邮箱"),
user_password: str = Form(..., alias="user[password]", description="密码"),
db: AsyncSession = Depends(get_db),
geoip: GeoIPHelper = Depends(get_geoip_helper)
):
username_errors = validate_username(user_username)
email_errors = validate_email(user_email)
password_errors = validate_password(user_password)
@@ -119,6 +123,21 @@ async def register_user(
)
try:
# 获取客户端 IP 并查询地理位置
client_ip = get_client_ip(request)
country_code = "CN" # 默认国家代码
try:
# 查询 IP 地理位置
geo_info = geoip.lookup(client_ip)
if geo_info and geo_info.get("country_iso"):
country_code = geo_info["country_iso"]
logger.info(f"User {user_username} registering from {client_ip}, country: {country_code}")
else:
logger.warning(f"Could not determine country for IP {client_ip}")
except Exception as e:
logger.warning(f"GeoIP lookup failed for {client_ip}: {e}")
# 创建新用户
# 确保 AUTO_INCREMENT 值从3开始ID=1是BanchoBotID=2预留给ppy
result = await db.execute( # pyright: ignore[reportDeprecated]
@@ -137,7 +156,7 @@ async def register_user(
email=user_email,
pw_bcrypt=get_password_hash(user_password),
priv=1, # 普通用户权限
country_code="CN", # 默认国家
country_code=country_code, # 根据 IP 地理位置设置国家
join_date=datetime.now(UTC),
last_visit=datetime.now(UTC),
is_supporter=settings.enable_supporter_for_all_users,

View File

@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
"""
[GeoIP] Scheduled Update Service
Periodically update the MaxMind GeoIP database
"""
import asyncio
from datetime import datetime
from app.config import settings
from app.dependencies.geoip import get_geoip_helper
from app.dependencies.scheduler import get_scheduler
from app.log import logger
async def update_geoip_database():
"""
Asynchronous task to update the GeoIP database
"""
try:
logger.info("[GeoIP] Starting scheduled GeoIP database update...")
geoip = get_geoip_helper()
# Run the synchronous update method in a background thread
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, lambda: geoip.update(force=False))
logger.info("[GeoIP] Scheduled GeoIP database update completed successfully")
except Exception as e:
logger.error(f"[GeoIP] Scheduled GeoIP database update failed: {e}")
def schedule_geoip_updates():
"""
Schedule the GeoIP database update task
"""
scheduler = get_scheduler()
# Use settings to configure the update time: update once a week
scheduler.add_job(
update_geoip_database,
'cron',
day_of_week=settings.geoip_update_day,
hour=settings.geoip_update_hour,
minute=0,
id='geoip_weekly_update',
name='Weekly GeoIP database update',
replace_existing=True
)
logger.info(
f"[GeoIP] Scheduled update task registered: "
f"every week on day {settings.geoip_update_day} at {settings.geoip_update_hour}:00"
)

25
app/service/init_geoip.py Normal file
View File

@@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
"""
[GeoIP] Initialization Service
Initialize the GeoIP database when the application starts
"""
import asyncio
from app.dependencies.geoip import get_geoip_helper
from app.log import logger
async def init_geoip():
"""
Asynchronously initialize the GeoIP database
"""
try:
geoip = get_geoip_helper()
logger.info("[GeoIP] Initializing GeoIP database...")
# Run the synchronous update method in a background thread
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, geoip.update)
logger.info("[GeoIP] GeoIP database initialization completed")
except Exception as e:
logger.error(f"[GeoIP] GeoIP database initialization failed: {e}")
# Do not raise an exception to avoid blocking application startup