Files
g0v0-server/app/dependencies/geoip.py

109 lines
2.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
GeoIP dependency for FastAPI
"""
from functools import lru_cache
import ipaddress
from typing import Annotated
from app.config import settings
from app.helpers.geoip_helper import GeoIPHelper
from fastapi import Depends, Request
@lru_cache
def get_geoip_helper() -> GeoIPHelper:
"""
获取 GeoIP 帮助类实例
使用 lru_cache 确保单例模式
"""
return GeoIPHelper(
dest_dir=settings.geoip_dest_dir,
license_key=settings.maxmind_license_key,
editions=["City", "ASN"],
max_age_days=8,
timeout=60.0,
)
def get_client_ip(request: Request) -> str:
"""
获取客户端真实 IP 地址
支持 IPv4 和 IPv6考虑代理、负载均衡器等情况
"""
headers = request.headers
# 1. Cloudflare 专用头部
cf_ip = headers.get("CF-Connecting-IP")
if cf_ip:
ip = cf_ip.strip()
if is_valid_ip(ip):
return ip
true_client_ip = headers.get("True-Client-IP")
if true_client_ip:
ip = true_client_ip.strip()
if is_valid_ip(ip):
return ip
# 2. 标准代理头部
forwarded_for = headers.get("X-Forwarded-For")
if forwarded_for:
# X-Forwarded-For 可能包含多个 IP取第一个有效的
for ip_str in forwarded_for.split(","):
ip = ip_str.strip()
if is_valid_ip(ip) and not is_private_ip(ip):
return ip
real_ip = headers.get("X-Real-IP")
if real_ip:
ip = real_ip.strip()
if is_valid_ip(ip):
return ip
# 3. 回退到客户端 IP
client_ip = request.client.host if request.client else "127.0.0.1"
return client_ip if is_valid_ip(client_ip) else "127.0.0.1"
IPAddress = Annotated[str, Depends(get_client_ip)]
GeoIPService = Annotated[GeoIPHelper, Depends(get_geoip_helper)]
def is_valid_ip(ip_str: str) -> bool:
"""
验证 IP 地址是否有效(支持 IPv4 和 IPv6
"""
try:
ipaddress.ip_address(ip_str)
return True
except ValueError:
return False
def is_private_ip(ip_str: str) -> bool:
"""
判断是否为私有 IP 地址
"""
try:
ip = ipaddress.ip_address(ip_str)
return ip.is_private
except ValueError:
return False
def normalize_ip(ip_str: str) -> str:
"""
标准化 IP 地址格式
对于 IPv6转换为压缩格式
"""
try:
ip = ipaddress.ip_address(ip_str)
if isinstance(ip, ipaddress.IPv6Address):
return ip.compressed
else:
return str(ip)
except ValueError:
return ip_str