chore(lint): make ruff happy
This commit is contained in:
@@ -119,7 +119,6 @@ async def authenticate_user_legacy(
|
|||||||
if not user:
|
if not user:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
await db.refresh(user)
|
await db.refresh(user)
|
||||||
|
|
||||||
# 3. 验证密码
|
# 3. 验证密码
|
||||||
@@ -265,7 +264,6 @@ async def get_user_by_authorization_code(
|
|||||||
statement = select(User).where(User.id == int(user_id))
|
statement = select(User).where(User.id == int(user_id))
|
||||||
user = (await db.exec(statement)).first()
|
user = (await db.exec(statement)).first()
|
||||||
if user:
|
if user:
|
||||||
|
|
||||||
await db.refresh(user)
|
await db.refresh(user)
|
||||||
return (user, scopes.split(","))
|
return (user, scopes.split(","))
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -3,9 +3,17 @@ from __future__ import annotations
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Annotated, Any
|
from typing import Annotated, Any
|
||||||
|
|
||||||
from pydantic import AliasChoices, Field, HttpUrl, ValidationInfo, field_validator, BeforeValidator
|
from pydantic import (
|
||||||
|
AliasChoices,
|
||||||
|
BeforeValidator,
|
||||||
|
Field,
|
||||||
|
HttpUrl,
|
||||||
|
ValidationInfo,
|
||||||
|
field_validator,
|
||||||
|
)
|
||||||
from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict
|
from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
def _parse_list(v):
|
def _parse_list(v):
|
||||||
if v is None or v == "" or str(v).strip() in ("[]", "{}"):
|
if v is None or v == "" or str(v).strip() in ("[]", "{}"):
|
||||||
return []
|
return []
|
||||||
@@ -14,6 +22,7 @@ def _parse_list(v):
|
|||||||
s = str(v).strip()
|
s = str(v).strip()
|
||||||
try:
|
try:
|
||||||
import json
|
import json
|
||||||
|
|
||||||
parsed = json.loads(s)
|
parsed = json.loads(s)
|
||||||
if isinstance(parsed, list):
|
if isinstance(parsed, list):
|
||||||
return parsed
|
return parsed
|
||||||
@@ -21,6 +30,7 @@ def _parse_list(v):
|
|||||||
pass
|
pass
|
||||||
return [x.strip() for x in s.split(",") if x.strip()]
|
return [x.strip() for x in s.split(",") if x.strip()]
|
||||||
|
|
||||||
|
|
||||||
class AWSS3StorageSettings(BaseSettings):
|
class AWSS3StorageSettings(BaseSettings):
|
||||||
s3_access_key_id: str
|
s3_access_key_id: str
|
||||||
s3_secret_access_key: str
|
s3_secret_access_key: str
|
||||||
|
|||||||
@@ -1,40 +1,59 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
"""
|
||||||
User Login Log Database Model
|
User Login Log Database Model
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
|
||||||
from sqlmodel import Field, SQLModel
|
from sqlmodel import Field, SQLModel
|
||||||
|
|
||||||
|
|
||||||
class UserLoginLog(SQLModel, table=True):
|
class UserLoginLog(SQLModel, table=True):
|
||||||
"""User login log table"""
|
"""User login log table"""
|
||||||
|
|
||||||
__tablename__ = "user_login_log" # pyright: ignore[reportAssignmentType]
|
__tablename__ = "user_login_log" # pyright: ignore[reportAssignmentType]
|
||||||
|
|
||||||
id: Optional[int] = Field(default=None, primary_key=True, description="Record ID")
|
id: int | None = Field(default=None, primary_key=True, description="Record ID")
|
||||||
user_id: int = Field(index=True, description="User ID")
|
user_id: int = Field(index=True, description="User ID")
|
||||||
ip_address: str = Field(max_length=45, index=True, description="IP address (supports IPv4 and IPv6)")
|
ip_address: str = Field(
|
||||||
user_agent: Optional[str] = Field(default=None, max_length=500, description="User agent information")
|
max_length=45, index=True, description="IP address (supports IPv4 and IPv6)"
|
||||||
login_time: datetime = Field(default_factory=datetime.utcnow, description="Login time")
|
)
|
||||||
|
user_agent: str | None = Field(
|
||||||
|
default=None, max_length=500, description="User agent information"
|
||||||
|
)
|
||||||
|
login_time: datetime = Field(
|
||||||
|
default_factory=datetime.utcnow, description="Login time"
|
||||||
|
)
|
||||||
|
|
||||||
# GeoIP information
|
# GeoIP information
|
||||||
country_code: Optional[str] = Field(default=None, max_length=2, description="Country code")
|
country_code: str | None = Field(
|
||||||
country_name: Optional[str] = Field(default=None, max_length=100, description="Country name")
|
default=None, max_length=2, description="Country code"
|
||||||
city_name: Optional[str] = Field(default=None, max_length=100, description="City name")
|
)
|
||||||
latitude: Optional[str] = Field(default=None, max_length=20, description="Latitude")
|
country_name: str | None = Field(
|
||||||
longitude: Optional[str] = Field(default=None, max_length=20, description="Longitude")
|
default=None, max_length=100, description="Country name"
|
||||||
time_zone: Optional[str] = Field(default=None, max_length=50, description="Time zone")
|
)
|
||||||
|
city_name: str | None = Field(default=None, max_length=100, description="City name")
|
||||||
|
latitude: str | None = Field(default=None, max_length=20, description="Latitude")
|
||||||
|
longitude: str | None = Field(default=None, max_length=20, description="Longitude")
|
||||||
|
time_zone: str | None = Field(default=None, max_length=50, description="Time zone")
|
||||||
|
|
||||||
# ASN information
|
# ASN information
|
||||||
asn: Optional[int] = Field(default=None, description="Autonomous System Number")
|
asn: int | None = Field(default=None, description="Autonomous System Number")
|
||||||
organization: Optional[str] = Field(default=None, max_length=200, description="Organization name")
|
organization: str | None = Field(
|
||||||
|
default=None, max_length=200, description="Organization name"
|
||||||
|
)
|
||||||
|
|
||||||
# Login status
|
# Login status
|
||||||
login_success: bool = Field(default=True, description="Whether the login was successful")
|
login_success: bool = Field(
|
||||||
login_method: str = Field(max_length=50, description="Login method (password/oauth/etc.)")
|
default=True, description="Whether the login was successful"
|
||||||
|
)
|
||||||
|
login_method: str = Field(
|
||||||
|
max_length=50, description="Login method (password/oauth/etc.)"
|
||||||
|
)
|
||||||
|
|
||||||
# Additional information
|
# Additional information
|
||||||
notes: Optional[str] = Field(default=None, max_length=500, description="Additional notes")
|
notes: str | None = Field(
|
||||||
|
default=None, max_length=500, description="Additional notes"
|
||||||
|
)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|||||||
@@ -1,13 +1,17 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
"""
|
||||||
GeoIP dependency for FastAPI
|
GeoIP dependency for FastAPI
|
||||||
"""
|
"""
|
||||||
import ipaddress
|
|
||||||
from functools import lru_cache
|
|
||||||
from app.helpers.geoip_helper import GeoIPHelper
|
|
||||||
from app.config import settings
|
|
||||||
|
|
||||||
@lru_cache()
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
|
import ipaddress
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
from app.helpers.geoip_helper import GeoIPHelper
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
def get_geoip_helper() -> GeoIPHelper:
|
def get_geoip_helper() -> GeoIPHelper:
|
||||||
"""
|
"""
|
||||||
获取 GeoIP 帮助类实例
|
获取 GeoIP 帮助类实例
|
||||||
@@ -18,7 +22,7 @@ def get_geoip_helper() -> GeoIPHelper:
|
|||||||
license_key=settings.maxmind_license_key,
|
license_key=settings.maxmind_license_key,
|
||||||
editions=["City", "ASN"],
|
editions=["City", "ASN"],
|
||||||
max_age_days=8,
|
max_age_days=8,
|
||||||
timeout=60.0
|
timeout=60.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -90,7 +90,6 @@ async def get_client_user(
|
|||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
||||||
|
|
||||||
|
|
||||||
await db.refresh(user)
|
await db.refresh(user)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
@@ -129,6 +128,5 @@ async def get_current_user(
|
|||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
||||||
|
|
||||||
|
|
||||||
await db.refresh(user)
|
await db.refresh(user)
|
||||||
return user
|
return user
|
||||||
|
|||||||
@@ -1,21 +1,36 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
"""
|
||||||
GeoLite2 Helper Class
|
GeoLite2 Helper Class
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import tarfile
|
from pathlib import Path
|
||||||
import shutil
|
import shutil
|
||||||
|
import tarfile
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import maxminddb
|
import maxminddb
|
||||||
from pathlib import Path
|
|
||||||
|
BASE_URL = "https://download.maxmind.com/app/geoip_download"
|
||||||
|
EDITIONS = {
|
||||||
|
"City": "GeoLite2-City",
|
||||||
|
"Country": "GeoLite2-Country",
|
||||||
|
"ASN": "GeoLite2-ASN",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class GeoIPHelper:
|
class GeoIPHelper:
|
||||||
BASE_URL = "https://download.maxmind.com/app/geoip_download"
|
def __init__(
|
||||||
EDITIONS = {"City": "GeoLite2-City", "Country": "GeoLite2-Country", "ASN": "GeoLite2-ASN"}
|
self,
|
||||||
|
dest_dir="./geoip",
|
||||||
def __init__(self, dest_dir="./geoip", license_key=None, editions=None, max_age_days=8, timeout=60.0):
|
license_key=None,
|
||||||
|
editions=None,
|
||||||
|
max_age_days=8,
|
||||||
|
timeout=60.0,
|
||||||
|
):
|
||||||
self.dest_dir = dest_dir
|
self.dest_dir = dest_dir
|
||||||
self.license_key = license_key or os.getenv("MAXMIND_LICENSE_KEY")
|
self.license_key = license_key or os.getenv("MAXMIND_LICENSE_KEY")
|
||||||
self.editions = editions or ["City", "ASN"]
|
self.editions = editions or ["City", "ASN"]
|
||||||
@@ -30,7 +45,7 @@ class GeoIPHelper:
|
|||||||
target = (base / m.name).resolve()
|
target = (base / m.name).resolve()
|
||||||
if not str(target).startswith(str(base)):
|
if not str(target).startswith(str(base)):
|
||||||
raise RuntimeError("Unsafe path in tar file")
|
raise RuntimeError("Unsafe path in tar file")
|
||||||
tar.extractall(path=path, filter='data')
|
tar.extractall(path=path, filter="data")
|
||||||
|
|
||||||
def _download_and_extract(self, edition_id: str) -> str:
|
def _download_and_extract(self, edition_id: str) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -40,9 +55,14 @@ class GeoIPHelper:
|
|||||||
- 临时目录退出后自动清理
|
- 临时目录退出后自动清理
|
||||||
"""
|
"""
|
||||||
if not self.license_key:
|
if not self.license_key:
|
||||||
raise ValueError("缺少 MaxMind License Key,请传入或设置环境变量 MAXMIND_LICENSE_KEY")
|
raise ValueError(
|
||||||
|
"缺少 MaxMind License Key,请传入或设置环境变量 MAXMIND_LICENSE_KEY"
|
||||||
|
)
|
||||||
|
|
||||||
url = f"{self.BASE_URL}?edition_id={edition_id}&license_key={self.license_key}&suffix=tar.gz"
|
url = (
|
||||||
|
f"{BASE_URL}?edition_id={edition_id}&"
|
||||||
|
f"license_key={self.license_key}&suffix=tar.gz"
|
||||||
|
)
|
||||||
|
|
||||||
with httpx.Client(follow_redirects=True, timeout=self.timeout) as client:
|
with httpx.Client(follow_redirects=True, timeout=self.timeout) as client:
|
||||||
with client.stream("GET", url) as resp:
|
with client.stream("GET", url) as resp:
|
||||||
@@ -81,13 +101,16 @@ class GeoIPHelper:
|
|||||||
def _latest_file(self, edition_id: str):
|
def _latest_file(self, edition_id: str):
|
||||||
if not os.path.isdir(self.dest_dir):
|
if not os.path.isdir(self.dest_dir):
|
||||||
return None
|
return None
|
||||||
files = [os.path.join(self.dest_dir, f) for f in os.listdir(self.dest_dir)
|
files = [
|
||||||
if f.startswith(edition_id) and f.endswith(".mmdb")]
|
os.path.join(self.dest_dir, f)
|
||||||
|
for f in os.listdir(self.dest_dir)
|
||||||
|
if f.startswith(edition_id) and f.endswith(".mmdb")
|
||||||
|
]
|
||||||
return max(files, key=os.path.getmtime) if files else None
|
return max(files, key=os.path.getmtime) if files else None
|
||||||
|
|
||||||
def update(self, force=False):
|
def update(self, force=False):
|
||||||
for ed in self.editions:
|
for ed in self.editions:
|
||||||
eid = self.EDITIONS[ed]
|
eid = EDITIONS[ed]
|
||||||
path = self._latest_file(eid)
|
path = self._latest_file(eid)
|
||||||
need = force or not path
|
need = force or not path
|
||||||
if path:
|
if path:
|
||||||
@@ -97,12 +120,11 @@ class GeoIPHelper:
|
|||||||
if need:
|
if need:
|
||||||
path = self._download_and_extract(eid)
|
path = self._download_and_extract(eid)
|
||||||
|
|
||||||
|
|
||||||
old = self._readers.get(ed)
|
old = self._readers.get(ed)
|
||||||
if old:
|
if old:
|
||||||
try:
|
try:
|
||||||
old.close()
|
old.close()
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
if path is not None:
|
if path is not None:
|
||||||
self._readers[ed] = maxminddb.open_database(path)
|
self._readers[ed] = maxminddb.open_database(path)
|
||||||
@@ -139,12 +161,11 @@ class GeoIPHelper:
|
|||||||
for r in self._readers.values():
|
for r in self._readers.values():
|
||||||
try:
|
try:
|
||||||
r.close()
|
r.close()
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
self._readers = {}
|
self._readers = {}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 示例用法
|
# 示例用法
|
||||||
geo = GeoIPHelper(dest_dir="./geoip", license_key="")
|
geo = GeoIPHelper(dest_dir="./geoip", license_key="")
|
||||||
|
|||||||
@@ -20,10 +20,9 @@ from app.database import DailyChallengeStats, OAuthClient, User
|
|||||||
from app.database.statistics import UserStatistics
|
from app.database.statistics import UserStatistics
|
||||||
from app.dependencies import get_db
|
from app.dependencies import get_db
|
||||||
from app.dependencies.database import get_redis
|
from app.dependencies.database import get_redis
|
||||||
from app.dependencies.geoip import get_geoip_helper, get_client_ip
|
from app.dependencies.geoip import get_client_ip, get_geoip_helper
|
||||||
from app.helpers.geoip_helper import GeoIPHelper
|
from app.helpers.geoip_helper import GeoIPHelper
|
||||||
from app.log import logger
|
from app.log import logger
|
||||||
from app.service.login_log_service import LoginLogService
|
|
||||||
from app.models.oauth import (
|
from app.models.oauth import (
|
||||||
OAuthErrorResponse,
|
OAuthErrorResponse,
|
||||||
RegistrationRequestErrors,
|
RegistrationRequestErrors,
|
||||||
@@ -31,6 +30,7 @@ from app.models.oauth import (
|
|||||||
UserRegistrationErrors,
|
UserRegistrationErrors,
|
||||||
)
|
)
|
||||||
from app.models.score import GameMode
|
from app.models.score import GameMode
|
||||||
|
from app.service.login_log_service import LoginLogService
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Form, Request
|
from fastapi import APIRouter, Depends, Form, Request
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
@@ -82,6 +82,7 @@ def validate_password(password: str) -> list[str]:
|
|||||||
|
|
||||||
router = APIRouter(tags=["osu! OAuth 认证"])
|
router = APIRouter(tags=["osu! OAuth 认证"])
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/users",
|
"/users",
|
||||||
name="注册用户",
|
name="注册用户",
|
||||||
@@ -93,9 +94,8 @@ async def register_user(
|
|||||||
user_email: str = Form(..., alias="user[user_email]", description="电子邮箱"),
|
user_email: str = Form(..., alias="user[user_email]", description="电子邮箱"),
|
||||||
user_password: str = Form(..., alias="user[password]", description="密码"),
|
user_password: str = Form(..., alias="user[password]", description="密码"),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
geoip: GeoIPHelper = Depends(get_geoip_helper)
|
geoip: GeoIPHelper = Depends(get_geoip_helper),
|
||||||
):
|
):
|
||||||
|
|
||||||
username_errors = validate_username(user_username)
|
username_errors = validate_username(user_username)
|
||||||
email_errors = validate_email(user_email)
|
email_errors = validate_email(user_email)
|
||||||
password_errors = validate_password(user_password)
|
password_errors = validate_password(user_password)
|
||||||
@@ -133,7 +133,10 @@ async def register_user(
|
|||||||
geo_info = geoip.lookup(client_ip)
|
geo_info = geoip.lookup(client_ip)
|
||||||
if geo_info and geo_info.get("country_iso"):
|
if geo_info and geo_info.get("country_iso"):
|
||||||
country_code = geo_info["country_iso"]
|
country_code = geo_info["country_iso"]
|
||||||
logger.info(f"User {user_username} registering from {client_ip}, country: {country_code}")
|
logger.info(
|
||||||
|
f"User {user_username} registering from "
|
||||||
|
f"{client_ip}, country: {country_code}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Could not determine country for IP {client_ip}")
|
logger.warning(f"Could not determine country for IP {client_ip}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -276,7 +279,7 @@ async def oauth_token(
|
|||||||
request=request,
|
request=request,
|
||||||
attempted_username=username,
|
attempted_username=username,
|
||||||
login_method="password",
|
login_method="password",
|
||||||
notes="Invalid credentials"
|
notes="Invalid credentials",
|
||||||
)
|
)
|
||||||
|
|
||||||
return create_oauth_error_response(
|
return create_oauth_error_response(
|
||||||
@@ -295,7 +298,7 @@ async def oauth_token(
|
|||||||
await db.refresh(user)
|
await db.refresh(user)
|
||||||
|
|
||||||
# 记录成功的登录
|
# 记录成功的登录
|
||||||
user_id = getattr(user, 'id')
|
user_id = getattr(user, "id")
|
||||||
assert user_id is not None, "User ID should not be None after authentication"
|
assert user_id is not None, "User ID should not be None after authentication"
|
||||||
await LoginLogService.record_login(
|
await LoginLogService.record_login(
|
||||||
db=db,
|
db=db,
|
||||||
@@ -303,7 +306,7 @@ async def oauth_token(
|
|||||||
request=request,
|
request=request,
|
||||||
login_success=True,
|
login_success=True,
|
||||||
login_method="password",
|
login_method="password",
|
||||||
notes=f"OAuth password grant for client {client_id}"
|
notes=f"OAuth password grant for client {client_id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
# 生成令牌
|
# 生成令牌
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
"""
|
||||||
[GeoIP] Scheduled Update Service
|
[GeoIP] Scheduled Update Service
|
||||||
Periodically update the MaxMind GeoIP database
|
Periodically update the MaxMind GeoIP database
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from datetime import datetime
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.dependencies.geoip import get_geoip_helper
|
from app.dependencies.geoip import get_geoip_helper
|
||||||
from app.dependencies.scheduler import get_scheduler
|
from app.dependencies.scheduler import get_scheduler
|
||||||
@@ -37,16 +39,17 @@ def schedule_geoip_updates():
|
|||||||
# Use settings to configure the update time: update once a week
|
# Use settings to configure the update time: update once a week
|
||||||
scheduler.add_job(
|
scheduler.add_job(
|
||||||
update_geoip_database,
|
update_geoip_database,
|
||||||
'cron',
|
"cron",
|
||||||
day_of_week=settings.geoip_update_day,
|
day_of_week=settings.geoip_update_day,
|
||||||
hour=settings.geoip_update_hour,
|
hour=settings.geoip_update_hour,
|
||||||
minute=0,
|
minute=0,
|
||||||
id='geoip_weekly_update',
|
id="geoip_weekly_update",
|
||||||
name='Weekly GeoIP database update',
|
name="Weekly GeoIP database update",
|
||||||
replace_existing=True
|
replace_existing=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[GeoIP] Scheduled update task registered: "
|
f"[GeoIP] Scheduled update task registered: "
|
||||||
f"every week on day {settings.geoip_update_day} at {settings.geoip_update_hour}:00"
|
f"every week on day {settings.geoip_update_day} "
|
||||||
|
f"at {settings.geoip_update_hour}:00"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,12 +1,16 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
"""
|
||||||
[GeoIP] Initialization Service
|
[GeoIP] Initialization Service
|
||||||
Initialize the GeoIP database when the application starts
|
Initialize the GeoIP database when the application starts
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from app.dependencies.geoip import get_geoip_helper
|
from app.dependencies.geoip import get_geoip_helper
|
||||||
from app.log import logger
|
from app.log import logger
|
||||||
|
|
||||||
|
|
||||||
async def init_geoip():
|
async def init_geoip():
|
||||||
"""
|
"""
|
||||||
Asynchronously initialize the GeoIP database
|
Asynchronously initialize the GeoIP database
|
||||||
|
|||||||
@@ -1,17 +1,19 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
"""
|
||||||
用户登录记录服务
|
用户登录记录服务
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
|
||||||
from fastapi import Request
|
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
from app.database.user_login_log import UserLoginLog
|
from app.database.user_login_log import UserLoginLog
|
||||||
from app.dependencies.geoip import get_geoip_helper, get_client_ip, normalize_ip
|
from app.dependencies.geoip import get_client_ip, get_geoip_helper, normalize_ip
|
||||||
from app.log import logger
|
from app.log import logger
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
class LoginLogService:
|
class LoginLogService:
|
||||||
"""用户登录记录服务"""
|
"""用户登录记录服务"""
|
||||||
@@ -23,7 +25,7 @@ class LoginLogService:
|
|||||||
request: Request,
|
request: Request,
|
||||||
login_success: bool = True,
|
login_success: bool = True,
|
||||||
login_method: str = "password",
|
login_method: str = "password",
|
||||||
notes: Optional[str] = None
|
notes: str | None = None,
|
||||||
) -> UserLoginLog:
|
) -> UserLoginLog:
|
||||||
"""
|
"""
|
||||||
记录用户登录信息
|
记录用户登录信息
|
||||||
@@ -54,7 +56,7 @@ class LoginLogService:
|
|||||||
login_time=datetime.utcnow(),
|
login_time=datetime.utcnow(),
|
||||||
login_success=login_success,
|
login_success=login_success,
|
||||||
login_method=login_method,
|
login_method=login_method,
|
||||||
notes=notes
|
notes=notes,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 异步获取GeoIP信息
|
# 异步获取GeoIP信息
|
||||||
@@ -64,8 +66,7 @@ class LoginLogService:
|
|||||||
# 在后台线程中运行GeoIP查询(避免阻塞)
|
# 在后台线程中运行GeoIP查询(避免阻塞)
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
geo_info = await loop.run_in_executor(
|
geo_info = await loop.run_in_executor(
|
||||||
None,
|
None, lambda: geoip.lookup(ip_address)
|
||||||
lambda: geoip.lookup(ip_address)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if geo_info:
|
if geo_info:
|
||||||
@@ -86,7 +87,10 @@ class LoginLogService:
|
|||||||
|
|
||||||
login_log.organization = geo_info.get("organization", "")
|
login_log.organization = geo_info.get("organization", "")
|
||||||
|
|
||||||
logger.debug(f"GeoIP lookup for {ip_address}: {geo_info.get('country_name', 'Unknown')}")
|
logger.debug(
|
||||||
|
f"GeoIP lookup for {ip_address}: "
|
||||||
|
f"{geo_info.get('country_name', 'Unknown')}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"GeoIP lookup failed for {ip_address}")
|
logger.warning(f"GeoIP lookup failed for {ip_address}")
|
||||||
|
|
||||||
@@ -98,16 +102,18 @@ class LoginLogService:
|
|||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(login_log)
|
await db.refresh(login_log)
|
||||||
|
|
||||||
logger.info(f"Login recorded for user {user_id} from {ip_address} ({login_method})")
|
logger.info(
|
||||||
|
f"Login recorded for user {user_id} from {ip_address} ({login_method})"
|
||||||
|
)
|
||||||
return login_log
|
return login_log
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def record_failed_login(
|
async def record_failed_login(
|
||||||
db: AsyncSession,
|
db: AsyncSession,
|
||||||
request: Request,
|
request: Request,
|
||||||
attempted_username: Optional[str] = None,
|
attempted_username: str | None = None,
|
||||||
login_method: str = "password",
|
login_method: str = "password",
|
||||||
notes: Optional[str] = None
|
notes: str | None = None,
|
||||||
) -> UserLoginLog:
|
) -> UserLoginLog:
|
||||||
"""
|
"""
|
||||||
记录失败的登录尝试
|
记录失败的登录尝试
|
||||||
@@ -129,7 +135,9 @@ class LoginLogService:
|
|||||||
request=request,
|
request=request,
|
||||||
login_success=False,
|
login_success=False,
|
||||||
login_method=login_method,
|
login_method=login_method,
|
||||||
notes=f"Failed login attempt: {attempted_username}" if attempted_username else "Failed login attempt"
|
notes=f"Failed login attempt: {attempted_username}"
|
||||||
|
if attempted_username
|
||||||
|
else "Failed login attempt",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
6
main.py
6
main.py
@@ -37,13 +37,15 @@ if os.path.exists(newrelic_config_path):
|
|||||||
|
|
||||||
environment = os.environ.get(
|
environment = os.environ.get(
|
||||||
"NEW_RELIC_ENVIRONMENT",
|
"NEW_RELIC_ENVIRONMENT",
|
||||||
"production" if not settings.debug else "development"
|
"production" if not settings.debug else "development",
|
||||||
)
|
)
|
||||||
|
|
||||||
newrelic.agent.initialize(newrelic_config_path, environment)
|
newrelic.agent.initialize(newrelic_config_path, environment)
|
||||||
logger.info(f"[NewRelic] Enabled, environment: {environment}")
|
logger.info(f"[NewRelic] Enabled, environment: {environment}")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.warning("[NewRelic] Config file found but 'newrelic' package is not installed")
|
logger.warning(
|
||||||
|
"[NewRelic] Config file found but 'newrelic' package is not installed"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[NewRelic] Initialization failed: {e}")
|
logger.error(f"[NewRelic] Initialization failed: {e}")
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ Revises: 3eef4794ded1
|
|||||||
Create Date: 2025-08-18 00:07:06.886879
|
Create Date: 2025-08-18 00:07:06.886879
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
@@ -24,27 +25,55 @@ depends_on: str | Sequence[str] | None = None
|
|||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
"""Upgrade schema."""
|
"""Upgrade schema."""
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.create_table("user_login_log",
|
op.create_table(
|
||||||
sa.Column("id", sa.Integer(), nullable=False),
|
"user_login_log",
|
||||||
sa.Column("user_id", sa.Integer(), nullable=False),
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
sa.Column("ip_address", sqlmodel.sql.sqltypes.AutoString(length=45), nullable=False),
|
sa.Column("user_id", sa.Integer(), nullable=False),
|
||||||
sa.Column("user_agent", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True),
|
sa.Column(
|
||||||
sa.Column("login_time", sa.DateTime(), nullable=False),
|
"ip_address", sqlmodel.sql.sqltypes.AutoString(length=45), nullable=False
|
||||||
sa.Column("country_code", sqlmodel.sql.sqltypes.AutoString(length=2), nullable=True),
|
),
|
||||||
sa.Column("country_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True),
|
sa.Column(
|
||||||
sa.Column("city_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True),
|
"user_agent", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True
|
||||||
sa.Column("latitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True),
|
),
|
||||||
sa.Column("longitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True),
|
sa.Column("login_time", sa.DateTime(), nullable=False),
|
||||||
sa.Column("time_zone", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=True),
|
sa.Column(
|
||||||
sa.Column("asn", sa.Integer(), nullable=True),
|
"country_code", sqlmodel.sql.sqltypes.AutoString(length=2), nullable=True
|
||||||
sa.Column("organization", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=True),
|
),
|
||||||
sa.Column("login_success", sa.Boolean(), nullable=False),
|
sa.Column(
|
||||||
sa.Column("login_method", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False),
|
"country_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True
|
||||||
sa.Column("notes", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True),
|
),
|
||||||
sa.PrimaryKeyConstraint("id")
|
sa.Column(
|
||||||
|
"city_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"latitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"longitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"time_zone", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=True
|
||||||
|
),
|
||||||
|
sa.Column("asn", sa.Integer(), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"organization", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=True
|
||||||
|
),
|
||||||
|
sa.Column("login_success", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"login_method", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False
|
||||||
|
),
|
||||||
|
sa.Column("notes", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_user_login_log_ip_address"),
|
||||||
|
"user_login_log",
|
||||||
|
["ip_address"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_user_login_log_user_id"), "user_login_log", ["user_id"], unique=False
|
||||||
)
|
)
|
||||||
op.create_index(op.f("ix_user_login_log_ip_address"), "user_login_log", ["ip_address"], unique=False)
|
|
||||||
op.create_index(op.f("ix_user_login_log_user_id"), "user_login_log", ["user_id"], unique=False)
|
|
||||||
op.drop_index(op.f("ix_userloginlog_ip_address"), table_name="userloginlog")
|
op.drop_index(op.f("ix_userloginlog_ip_address"), table_name="userloginlog")
|
||||||
op.drop_index(op.f("ix_userloginlog_user_id"), table_name="userloginlog")
|
op.drop_index(op.f("ix_userloginlog_user_id"), table_name="userloginlog")
|
||||||
op.drop_table("userloginlog")
|
op.drop_table("userloginlog")
|
||||||
@@ -54,30 +83,40 @@ def upgrade() -> None:
|
|||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
"""Downgrade schema."""
|
"""Downgrade schema."""
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.create_table("userloginlog",
|
op.create_table(
|
||||||
sa.Column("id", mysql.INTEGER(), autoincrement=True, nullable=False),
|
"userloginlog",
|
||||||
sa.Column("user_id", mysql.INTEGER(), autoincrement=False, nullable=False),
|
sa.Column("id", mysql.INTEGER(), autoincrement=True, nullable=False),
|
||||||
sa.Column("ip_address", mysql.VARCHAR(length=45), nullable=False),
|
sa.Column("user_id", mysql.INTEGER(), autoincrement=False, nullable=False),
|
||||||
sa.Column("user_agent", mysql.VARCHAR(length=500), nullable=True),
|
sa.Column("ip_address", mysql.VARCHAR(length=45), nullable=False),
|
||||||
sa.Column("login_time", mysql.DATETIME(), nullable=False),
|
sa.Column("user_agent", mysql.VARCHAR(length=500), nullable=True),
|
||||||
sa.Column("country_code", mysql.VARCHAR(length=2), nullable=True),
|
sa.Column("login_time", mysql.DATETIME(), nullable=False),
|
||||||
sa.Column("country_name", mysql.VARCHAR(length=100), nullable=True),
|
sa.Column("country_code", mysql.VARCHAR(length=2), nullable=True),
|
||||||
sa.Column("city_name", mysql.VARCHAR(length=100), nullable=True),
|
sa.Column("country_name", mysql.VARCHAR(length=100), nullable=True),
|
||||||
sa.Column("latitude", mysql.VARCHAR(length=20), nullable=True),
|
sa.Column("city_name", mysql.VARCHAR(length=100), nullable=True),
|
||||||
sa.Column("longitude", mysql.VARCHAR(length=20), nullable=True),
|
sa.Column("latitude", mysql.VARCHAR(length=20), nullable=True),
|
||||||
sa.Column("time_zone", mysql.VARCHAR(length=50), nullable=True),
|
sa.Column("longitude", mysql.VARCHAR(length=20), nullable=True),
|
||||||
sa.Column("asn", mysql.INTEGER(), autoincrement=False, nullable=True),
|
sa.Column("time_zone", mysql.VARCHAR(length=50), nullable=True),
|
||||||
sa.Column("organization", mysql.VARCHAR(length=200), nullable=True),
|
sa.Column("asn", mysql.INTEGER(), autoincrement=False, nullable=True),
|
||||||
sa.Column("login_success", mysql.TINYINT(display_width=1), autoincrement=False, nullable=False),
|
sa.Column("organization", mysql.VARCHAR(length=200), nullable=True),
|
||||||
sa.Column("login_method", mysql.VARCHAR(length=50), nullable=False),
|
sa.Column(
|
||||||
sa.Column("notes", mysql.VARCHAR(length=500), nullable=True),
|
"login_success",
|
||||||
sa.PrimaryKeyConstraint("id"),
|
mysql.TINYINT(display_width=1),
|
||||||
mysql_collate="utf8mb4_0900_ai_ci",
|
autoincrement=False,
|
||||||
mysql_default_charset="utf8mb4",
|
nullable=False,
|
||||||
mysql_engine="InnoDB"
|
),
|
||||||
|
sa.Column("login_method", mysql.VARCHAR(length=50), nullable=False),
|
||||||
|
sa.Column("notes", mysql.VARCHAR(length=500), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
mysql_collate="utf8mb4_0900_ai_ci",
|
||||||
|
mysql_default_charset="utf8mb4",
|
||||||
|
mysql_engine="InnoDB",
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_userloginlog_user_id"), "userloginlog", ["user_id"], unique=False
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_userloginlog_ip_address"), "userloginlog", ["ip_address"], unique=False
|
||||||
)
|
)
|
||||||
op.create_index(op.f("ix_userloginlog_user_id"), "userloginlog", ["user_id"], unique=False)
|
|
||||||
op.create_index(op.f("ix_userloginlog_ip_address"), "userloginlog", ["ip_address"], unique=False)
|
|
||||||
op.drop_index(op.f("ix_user_login_log_user_id"), table_name="user_login_log")
|
op.drop_index(op.f("ix_user_login_log_user_id"), table_name="user_login_log")
|
||||||
op.drop_index(op.f("ix_user_login_log_ip_address"), table_name="user_login_log")
|
op.drop_index(op.f("ix_user_login_log_ip_address"), table_name="user_login_log")
|
||||||
op.drop_table("user_login_log")
|
op.drop_table("user_login_log")
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ Revises: df9f725a077c
|
|||||||
Create Date: 2025-08-18 00:00:11.369944
|
Create Date: 2025-08-18 00:00:11.369944
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
@@ -23,27 +24,52 @@ depends_on: str | Sequence[str] | None = None
|
|||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
"""Upgrade schema."""
|
"""Upgrade schema."""
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.create_table("userloginlog",
|
op.create_table(
|
||||||
sa.Column("id", sa.Integer(), nullable=False),
|
"userloginlog",
|
||||||
sa.Column("user_id", sa.Integer(), nullable=False),
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
sa.Column("ip_address", sqlmodel.sql.sqltypes.AutoString(length=45), nullable=False),
|
sa.Column("user_id", sa.Integer(), nullable=False),
|
||||||
sa.Column("user_agent", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True),
|
sa.Column(
|
||||||
sa.Column("login_time", sa.DateTime(), nullable=False),
|
"ip_address", sqlmodel.sql.sqltypes.AutoString(length=45), nullable=False
|
||||||
sa.Column("country_code", sqlmodel.sql.sqltypes.AutoString(length=2), nullable=True),
|
),
|
||||||
sa.Column("country_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True),
|
sa.Column(
|
||||||
sa.Column("city_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True),
|
"user_agent", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True
|
||||||
sa.Column("latitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True),
|
),
|
||||||
sa.Column("longitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True),
|
sa.Column("login_time", sa.DateTime(), nullable=False),
|
||||||
sa.Column("time_zone", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=True),
|
sa.Column(
|
||||||
sa.Column("asn", sa.Integer(), nullable=True),
|
"country_code", sqlmodel.sql.sqltypes.AutoString(length=2), nullable=True
|
||||||
sa.Column("organization", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=True),
|
),
|
||||||
sa.Column("login_success", sa.Boolean(), nullable=False),
|
sa.Column(
|
||||||
sa.Column("login_method", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False),
|
"country_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True
|
||||||
sa.Column("notes", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True),
|
),
|
||||||
sa.PrimaryKeyConstraint("id")
|
sa.Column(
|
||||||
|
"city_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"latitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"longitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"time_zone", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=True
|
||||||
|
),
|
||||||
|
sa.Column("asn", sa.Integer(), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"organization", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=True
|
||||||
|
),
|
||||||
|
sa.Column("login_success", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"login_method", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False
|
||||||
|
),
|
||||||
|
sa.Column("notes", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_userloginlog_ip_address"), "userloginlog", ["ip_address"], unique=False
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_userloginlog_user_id"), "userloginlog", ["user_id"], unique=False
|
||||||
)
|
)
|
||||||
op.create_index(op.f("ix_userloginlog_ip_address"), "userloginlog", ["ip_address"], unique=False)
|
|
||||||
op.create_index(op.f("ix_userloginlog_user_id"), "userloginlog", ["user_id"], unique=False)
|
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user