chore(lint): make ruff happy

This commit is contained in:
MingxuanGame
2025-08-17 16:57:27 +00:00
parent 3c460f1d82
commit 86bea5d4b5
13 changed files with 316 additions and 181 deletions

View File

@@ -119,7 +119,6 @@ async def authenticate_user_legacy(
if not user: if not user:
return None return None
await db.refresh(user) await db.refresh(user)
# 3. 验证密码 # 3. 验证密码
@@ -265,7 +264,6 @@ async def get_user_by_authorization_code(
statement = select(User).where(User.id == int(user_id)) statement = select(User).where(User.id == int(user_id))
user = (await db.exec(statement)).first() user = (await db.exec(statement)).first()
if user: if user:
await db.refresh(user) await db.refresh(user)
return (user, scopes.split(",")) return (user, scopes.split(","))
return None return None

View File

@@ -3,9 +3,17 @@ from __future__ import annotations
from enum import Enum from enum import Enum
from typing import Annotated, Any from typing import Annotated, Any
from pydantic import AliasChoices, Field, HttpUrl, ValidationInfo, field_validator, BeforeValidator from pydantic import (
AliasChoices,
BeforeValidator,
Field,
HttpUrl,
ValidationInfo,
field_validator,
)
from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict
def _parse_list(v): def _parse_list(v):
if v is None or v == "" or str(v).strip() in ("[]", "{}"): if v is None or v == "" or str(v).strip() in ("[]", "{}"):
return [] return []
@@ -14,6 +22,7 @@ def _parse_list(v):
s = str(v).strip() s = str(v).strip()
try: try:
import json import json
parsed = json.loads(s) parsed = json.loads(s)
if isinstance(parsed, list): if isinstance(parsed, list):
return parsed return parsed
@@ -21,6 +30,7 @@ def _parse_list(v):
pass pass
return [x.strip() for x in s.split(",") if x.strip()] return [x.strip() for x in s.split(",") if x.strip()]
class AWSS3StorageSettings(BaseSettings): class AWSS3StorageSettings(BaseSettings):
s3_access_key_id: str s3_access_key_id: str
s3_secret_access_key: str s3_secret_access_key: str

View File

@@ -1,40 +1,59 @@
# -*- coding: utf-8 -*-
""" """
User Login Log Database Model User Login Log Database Model
""" """
from datetime import datetime from datetime import datetime
from typing import Optional
from sqlmodel import Field, SQLModel from sqlmodel import Field, SQLModel
class UserLoginLog(SQLModel, table=True): class UserLoginLog(SQLModel, table=True):
"""User login log table""" """User login log table"""
__tablename__ = "user_login_log" # pyright: ignore[reportAssignmentType] __tablename__ = "user_login_log" # pyright: ignore[reportAssignmentType]
id: Optional[int] = Field(default=None, primary_key=True, description="Record ID") id: int | None = Field(default=None, primary_key=True, description="Record ID")
user_id: int = Field(index=True, description="User ID") user_id: int = Field(index=True, description="User ID")
ip_address: str = Field(max_length=45, index=True, description="IP address (supports IPv4 and IPv6)") ip_address: str = Field(
user_agent: Optional[str] = Field(default=None, max_length=500, description="User agent information") max_length=45, index=True, description="IP address (supports IPv4 and IPv6)"
login_time: datetime = Field(default_factory=datetime.utcnow, description="Login time") )
user_agent: str | None = Field(
default=None, max_length=500, description="User agent information"
)
login_time: datetime = Field(
default_factory=datetime.utcnow, description="Login time"
)
# GeoIP information # GeoIP information
country_code: Optional[str] = Field(default=None, max_length=2, description="Country code") country_code: str | None = Field(
country_name: Optional[str] = Field(default=None, max_length=100, description="Country name") default=None, max_length=2, description="Country code"
city_name: Optional[str] = Field(default=None, max_length=100, description="City name") )
latitude: Optional[str] = Field(default=None, max_length=20, description="Latitude") country_name: str | None = Field(
longitude: Optional[str] = Field(default=None, max_length=20, description="Longitude") default=None, max_length=100, description="Country name"
time_zone: Optional[str] = Field(default=None, max_length=50, description="Time zone") )
city_name: str | None = Field(default=None, max_length=100, description="City name")
latitude: str | None = Field(default=None, max_length=20, description="Latitude")
longitude: str | None = Field(default=None, max_length=20, description="Longitude")
time_zone: str | None = Field(default=None, max_length=50, description="Time zone")
# ASN information # ASN information
asn: Optional[int] = Field(default=None, description="Autonomous System Number") asn: int | None = Field(default=None, description="Autonomous System Number")
organization: Optional[str] = Field(default=None, max_length=200, description="Organization name") organization: str | None = Field(
default=None, max_length=200, description="Organization name"
)
# Login status # Login status
login_success: bool = Field(default=True, description="Whether the login was successful") login_success: bool = Field(
login_method: str = Field(max_length=50, description="Login method (password/oauth/etc.)") default=True, description="Whether the login was successful"
)
login_method: str = Field(
max_length=50, description="Login method (password/oauth/etc.)"
)
# Additional information # Additional information
notes: Optional[str] = Field(default=None, max_length=500, description="Additional notes") notes: str | None = Field(
default=None, max_length=500, description="Additional notes"
)
class Config: class Config:
from_attributes = True from_attributes = True

View File

@@ -1,13 +1,17 @@
# -*- coding: utf-8 -*-
""" """
GeoIP dependency for FastAPI GeoIP dependency for FastAPI
""" """
import ipaddress
from functools import lru_cache
from app.helpers.geoip_helper import GeoIPHelper
from app.config import settings
@lru_cache() from __future__ import annotations
from functools import lru_cache
import ipaddress
from app.config import settings
from app.helpers.geoip_helper import GeoIPHelper
@lru_cache
def get_geoip_helper() -> GeoIPHelper: def get_geoip_helper() -> GeoIPHelper:
""" """
获取 GeoIP 帮助类实例 获取 GeoIP 帮助类实例
@@ -18,7 +22,7 @@ def get_geoip_helper() -> GeoIPHelper:
license_key=settings.maxmind_license_key, license_key=settings.maxmind_license_key,
editions=["City", "ASN"], editions=["City", "ASN"],
max_age_days=8, max_age_days=8,
timeout=60.0 timeout=60.0,
) )

View File

@@ -89,8 +89,7 @@ async def get_client_user(
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first() user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
if not user: if not user:
raise HTTPException(status_code=401, detail="Invalid or expired token") raise HTTPException(status_code=401, detail="Invalid or expired token")
await db.refresh(user) await db.refresh(user)
return user return user
@@ -128,7 +127,6 @@ async def get_current_user(
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first() user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
if not user: if not user:
raise HTTPException(status_code=401, detail="Invalid or expired token") raise HTTPException(status_code=401, detail="Invalid or expired token")
await db.refresh(user) await db.refresh(user)
return user return user

View File

@@ -1,21 +1,36 @@
# -*- coding: utf-8 -*-
""" """
GeoLite2 Helper Class GeoLite2 Helper Class
""" """
from __future__ import annotations
import os import os
import tarfile from pathlib import Path
import shutil import shutil
import tarfile
import tempfile import tempfile
import time import time
import httpx import httpx
import maxminddb import maxminddb
from pathlib import Path
BASE_URL = "https://download.maxmind.com/app/geoip_download"
EDITIONS = {
"City": "GeoLite2-City",
"Country": "GeoLite2-Country",
"ASN": "GeoLite2-ASN",
}
class GeoIPHelper: class GeoIPHelper:
BASE_URL = "https://download.maxmind.com/app/geoip_download" def __init__(
EDITIONS = {"City": "GeoLite2-City", "Country": "GeoLite2-Country", "ASN": "GeoLite2-ASN"} self,
dest_dir="./geoip",
def __init__(self, dest_dir="./geoip", license_key=None, editions=None, max_age_days=8, timeout=60.0): license_key=None,
editions=None,
max_age_days=8,
timeout=60.0,
):
self.dest_dir = dest_dir self.dest_dir = dest_dir
self.license_key = license_key or os.getenv("MAXMIND_LICENSE_KEY") self.license_key = license_key or os.getenv("MAXMIND_LICENSE_KEY")
self.editions = editions or ["City", "ASN"] self.editions = editions or ["City", "ASN"]
@@ -30,7 +45,7 @@ class GeoIPHelper:
target = (base / m.name).resolve() target = (base / m.name).resolve()
if not str(target).startswith(str(base)): if not str(target).startswith(str(base)):
raise RuntimeError("Unsafe path in tar file") raise RuntimeError("Unsafe path in tar file")
tar.extractall(path=path, filter='data') tar.extractall(path=path, filter="data")
def _download_and_extract(self, edition_id: str) -> str: def _download_and_extract(self, edition_id: str) -> str:
""" """
@@ -40,9 +55,14 @@ class GeoIPHelper:
- 临时目录退出后自动清理 - 临时目录退出后自动清理
""" """
if not self.license_key: if not self.license_key:
raise ValueError("缺少 MaxMind License Key请传入或设置环境变量 MAXMIND_LICENSE_KEY") raise ValueError(
"缺少 MaxMind License Key请传入或设置环境变量 MAXMIND_LICENSE_KEY"
)
url = f"{self.BASE_URL}?edition_id={edition_id}&license_key={self.license_key}&suffix=tar.gz" url = (
f"{BASE_URL}?edition_id={edition_id}&"
f"license_key={self.license_key}&suffix=tar.gz"
)
with httpx.Client(follow_redirects=True, timeout=self.timeout) as client: with httpx.Client(follow_redirects=True, timeout=self.timeout) as client:
with client.stream("GET", url) as resp: with client.stream("GET", url) as resp:
@@ -81,13 +101,16 @@ class GeoIPHelper:
def _latest_file(self, edition_id: str): def _latest_file(self, edition_id: str):
if not os.path.isdir(self.dest_dir): if not os.path.isdir(self.dest_dir):
return None return None
files = [os.path.join(self.dest_dir, f) for f in os.listdir(self.dest_dir) files = [
if f.startswith(edition_id) and f.endswith(".mmdb")] os.path.join(self.dest_dir, f)
for f in os.listdir(self.dest_dir)
if f.startswith(edition_id) and f.endswith(".mmdb")
]
return max(files, key=os.path.getmtime) if files else None return max(files, key=os.path.getmtime) if files else None
def update(self, force=False): def update(self, force=False):
for ed in self.editions: for ed in self.editions:
eid = self.EDITIONS[ed] eid = EDITIONS[ed]
path = self._latest_file(eid) path = self._latest_file(eid)
need = force or not path need = force or not path
if path: if path:
@@ -97,12 +120,11 @@ class GeoIPHelper:
if need: if need:
path = self._download_and_extract(eid) path = self._download_and_extract(eid)
old = self._readers.get(ed) old = self._readers.get(ed)
if old: if old:
try: try:
old.close() old.close()
except: except Exception:
pass pass
if path is not None: if path is not None:
self._readers[ed] = maxminddb.open_database(path) self._readers[ed] = maxminddb.open_database(path)
@@ -139,15 +161,14 @@ class GeoIPHelper:
for r in self._readers.values(): for r in self._readers.values():
try: try:
r.close() r.close()
except: except Exception:
pass pass
self._readers = {} self._readers = {}
if __name__ == "__main__": if __name__ == "__main__":
# 示例用法 # 示例用法
geo = GeoIPHelper(dest_dir="./geoip", license_key="") geo = GeoIPHelper(dest_dir="./geoip", license_key="")
geo.update() geo.update()
print(geo.lookup("8.8.8.8")) print(geo.lookup("8.8.8.8"))
geo.close() geo.close()

View File

@@ -20,10 +20,9 @@ from app.database import DailyChallengeStats, OAuthClient, User
from app.database.statistics import UserStatistics from app.database.statistics import UserStatistics
from app.dependencies import get_db from app.dependencies import get_db
from app.dependencies.database import get_redis from app.dependencies.database import get_redis
from app.dependencies.geoip import get_geoip_helper, get_client_ip from app.dependencies.geoip import get_client_ip, get_geoip_helper
from app.helpers.geoip_helper import GeoIPHelper from app.helpers.geoip_helper import GeoIPHelper
from app.log import logger from app.log import logger
from app.service.login_log_service import LoginLogService
from app.models.oauth import ( from app.models.oauth import (
OAuthErrorResponse, OAuthErrorResponse,
RegistrationRequestErrors, RegistrationRequestErrors,
@@ -31,6 +30,7 @@ from app.models.oauth import (
UserRegistrationErrors, UserRegistrationErrors,
) )
from app.models.score import GameMode from app.models.score import GameMode
from app.service.login_log_service import LoginLogService
from fastapi import APIRouter, Depends, Form, Request from fastapi import APIRouter, Depends, Form, Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
@@ -82,6 +82,7 @@ def validate_password(password: str) -> list[str]:
router = APIRouter(tags=["osu! OAuth 认证"]) router = APIRouter(tags=["osu! OAuth 认证"])
@router.post( @router.post(
"/users", "/users",
name="注册用户", name="注册用户",
@@ -93,9 +94,8 @@ async def register_user(
user_email: str = Form(..., alias="user[user_email]", description="电子邮箱"), user_email: str = Form(..., alias="user[user_email]", description="电子邮箱"),
user_password: str = Form(..., alias="user[password]", description="密码"), user_password: str = Form(..., alias="user[password]", description="密码"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
geoip: GeoIPHelper = Depends(get_geoip_helper) geoip: GeoIPHelper = Depends(get_geoip_helper),
): ):
username_errors = validate_username(user_username) username_errors = validate_username(user_username)
email_errors = validate_email(user_email) email_errors = validate_email(user_email)
password_errors = validate_password(user_password) password_errors = validate_password(user_password)
@@ -127,18 +127,21 @@ async def register_user(
# 获取客户端 IP 并查询地理位置 # 获取客户端 IP 并查询地理位置
client_ip = get_client_ip(request) client_ip = get_client_ip(request)
country_code = "CN" # 默认国家代码 country_code = "CN" # 默认国家代码
try: try:
# 查询 IP 地理位置 # 查询 IP 地理位置
geo_info = geoip.lookup(client_ip) geo_info = geoip.lookup(client_ip)
if geo_info and geo_info.get("country_iso"): if geo_info and geo_info.get("country_iso"):
country_code = geo_info["country_iso"] country_code = geo_info["country_iso"]
logger.info(f"User {user_username} registering from {client_ip}, country: {country_code}") logger.info(
f"User {user_username} registering from "
f"{client_ip}, country: {country_code}"
)
else: else:
logger.warning(f"Could not determine country for IP {client_ip}") logger.warning(f"Could not determine country for IP {client_ip}")
except Exception as e: except Exception as e:
logger.warning(f"GeoIP lookup failed for {client_ip}: {e}") logger.warning(f"GeoIP lookup failed for {client_ip}: {e}")
# 创建新用户 # 创建新用户
# 确保 AUTO_INCREMENT 值从3开始ID=1是BanchoBotID=2预留给ppy # 确保 AUTO_INCREMENT 值从3开始ID=1是BanchoBotID=2预留给ppy
result = await db.execute( # pyright: ignore[reportDeprecated] result = await db.execute( # pyright: ignore[reportDeprecated]
@@ -276,9 +279,9 @@ async def oauth_token(
request=request, request=request,
attempted_username=username, attempted_username=username,
login_method="password", login_method="password",
notes="Invalid credentials" notes="Invalid credentials",
) )
return create_oauth_error_response( return create_oauth_error_response(
error="invalid_grant", error="invalid_grant",
description=( description=(
@@ -293,9 +296,9 @@ async def oauth_token(
# 确保用户对象与当前会话关联 # 确保用户对象与当前会话关联
await db.refresh(user) await db.refresh(user)
# 记录成功的登录 # 记录成功的登录
user_id = getattr(user, 'id') user_id = getattr(user, "id")
assert user_id is not None, "User ID should not be None after authentication" assert user_id is not None, "User ID should not be None after authentication"
await LoginLogService.record_login( await LoginLogService.record_login(
db=db, db=db,
@@ -303,7 +306,7 @@ async def oauth_token(
request=request, request=request,
login_success=True, login_success=True,
login_method="password", login_method="password",
notes=f"OAuth password grant for client {client_id}" notes=f"OAuth password grant for client {client_id}",
) )
# 生成令牌 # 生成令牌
@@ -424,16 +427,16 @@ async def oauth_token(
hint="Invalid authorization code", hint="Invalid authorization code",
) )
user, scopes = code_result user, scopes = code_result
# 确保用户对象与当前会话关联 # 确保用户对象与当前会话关联
await db.refresh(user) await db.refresh(user)
# 生成令牌 # 生成令牌
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes) access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
# 重新查询只获取ID避免触发延迟加载 # 重新查询只获取ID避免触发延迟加载
id_result = await db.exec(select(User.id).where(User.username == username)) id_result = await db.exec(select(User.id).where(User.username == username))
user_id = id_result.first() user_id = id_result.first()
access_token = create_access_token( access_token = create_access_token(
data={"sub": str(user_id)}, expires_delta=access_token_expires data={"sub": str(user_id)}, expires_delta=access_token_expires
) )

View File

@@ -1,10 +1,12 @@
# -*- coding: utf-8 -*-
""" """
[GeoIP] Scheduled Update Service [GeoIP] Scheduled Update Service
Periodically update the MaxMind GeoIP database Periodically update the MaxMind GeoIP database
""" """
from __future__ import annotations
import asyncio import asyncio
from datetime import datetime
from app.config import settings from app.config import settings
from app.dependencies.geoip import get_geoip_helper from app.dependencies.geoip import get_geoip_helper
from app.dependencies.scheduler import get_scheduler from app.dependencies.scheduler import get_scheduler
@@ -18,11 +20,11 @@ async def update_geoip_database():
try: try:
logger.info("[GeoIP] Starting scheduled GeoIP database update...") logger.info("[GeoIP] Starting scheduled GeoIP database update...")
geoip = get_geoip_helper() geoip = get_geoip_helper()
# Run the synchronous update method in a background thread # Run the synchronous update method in a background thread
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
await loop.run_in_executor(None, lambda: geoip.update(force=False)) await loop.run_in_executor(None, lambda: geoip.update(force=False))
logger.info("[GeoIP] Scheduled GeoIP database update completed successfully") logger.info("[GeoIP] Scheduled GeoIP database update completed successfully")
except Exception as e: except Exception as e:
logger.error(f"[GeoIP] Scheduled GeoIP database update failed: {e}") logger.error(f"[GeoIP] Scheduled GeoIP database update failed: {e}")
@@ -33,20 +35,21 @@ def schedule_geoip_updates():
Schedule the GeoIP database update task Schedule the GeoIP database update task
""" """
scheduler = get_scheduler() scheduler = get_scheduler()
# Use settings to configure the update time: update once a week # Use settings to configure the update time: update once a week
scheduler.add_job( scheduler.add_job(
update_geoip_database, update_geoip_database,
'cron', "cron",
day_of_week=settings.geoip_update_day, day_of_week=settings.geoip_update_day,
hour=settings.geoip_update_hour, hour=settings.geoip_update_hour,
minute=0, minute=0,
id='geoip_weekly_update', id="geoip_weekly_update",
name='Weekly GeoIP database update', name="Weekly GeoIP database update",
replace_existing=True replace_existing=True,
) )
logger.info( logger.info(
f"[GeoIP] Scheduled update task registered: " f"[GeoIP] Scheduled update task registered: "
f"every week on day {settings.geoip_update_day} at {settings.geoip_update_hour}:00" f"every week on day {settings.geoip_update_day} "
f"at {settings.geoip_update_hour}:00"
) )

View File

@@ -1,12 +1,16 @@
# -*- coding: utf-8 -*-
""" """
[GeoIP] Initialization Service [GeoIP] Initialization Service
Initialize the GeoIP database when the application starts Initialize the GeoIP database when the application starts
""" """
from __future__ import annotations
import asyncio import asyncio
from app.dependencies.geoip import get_geoip_helper from app.dependencies.geoip import get_geoip_helper
from app.log import logger from app.log import logger
async def init_geoip(): async def init_geoip():
""" """
Asynchronously initialize the GeoIP database Asynchronously initialize the GeoIP database
@@ -14,11 +18,11 @@ async def init_geoip():
try: try:
geoip = get_geoip_helper() geoip = get_geoip_helper()
logger.info("[GeoIP] Initializing GeoIP database...") logger.info("[GeoIP] Initializing GeoIP database...")
# Run the synchronous update method in a background thread # Run the synchronous update method in a background thread
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
await loop.run_in_executor(None, geoip.update) await loop.run_in_executor(None, geoip.update)
logger.info("[GeoIP] GeoIP database initialization completed") logger.info("[GeoIP] GeoIP database initialization completed")
except Exception as e: except Exception as e:
logger.error(f"[GeoIP] GeoIP database initialization failed: {e}") logger.error(f"[GeoIP] GeoIP database initialization failed: {e}")

View File

@@ -1,21 +1,23 @@
# -*- coding: utf-8 -*-
""" """
用户登录记录服务 用户登录记录服务
""" """
from __future__ import annotations
import asyncio import asyncio
from datetime import datetime from datetime import datetime
from typing import Optional
from fastapi import Request
from sqlmodel.ext.asyncio.session import AsyncSession
from app.database.user_login_log import UserLoginLog from app.database.user_login_log import UserLoginLog
from app.dependencies.geoip import get_geoip_helper, get_client_ip, normalize_ip from app.dependencies.geoip import get_client_ip, get_geoip_helper, normalize_ip
from app.log import logger from app.log import logger
from fastapi import Request
from sqlmodel.ext.asyncio.session import AsyncSession
class LoginLogService: class LoginLogService:
"""用户登录记录服务""" """用户登录记录服务"""
@staticmethod @staticmethod
async def record_login( async def record_login(
db: AsyncSession, db: AsyncSession,
@@ -23,11 +25,11 @@ class LoginLogService:
request: Request, request: Request,
login_success: bool = True, login_success: bool = True,
login_method: str = "password", login_method: str = "password",
notes: Optional[str] = None notes: str | None = None,
) -> UserLoginLog: ) -> UserLoginLog:
""" """
记录用户登录信息 记录用户登录信息
Args: Args:
db: 数据库会话 db: 数据库会话
user_id: 用户ID user_id: 用户ID
@@ -35,17 +37,17 @@ class LoginLogService:
login_success: 登录是否成功 login_success: 登录是否成功
login_method: 登录方式 login_method: 登录方式
notes: 备注信息 notes: 备注信息
Returns: Returns:
UserLoginLog: 登录记录对象 UserLoginLog: 登录记录对象
""" """
# 获取客户端IP并标准化格式 # 获取客户端IP并标准化格式
raw_ip = get_client_ip(request) raw_ip = get_client_ip(request)
ip_address = normalize_ip(raw_ip) ip_address = normalize_ip(raw_ip)
# 获取User-Agent # 获取User-Agent
user_agent = request.headers.get("User-Agent", "") user_agent = request.headers.get("User-Agent", "")
# 创建基本的登录记录 # 创建基本的登录记录
login_log = UserLoginLog( login_log = UserLoginLog(
user_id=user_id, user_id=user_id,
@@ -54,20 +56,19 @@ class LoginLogService:
login_time=datetime.utcnow(), login_time=datetime.utcnow(),
login_success=login_success, login_success=login_success,
login_method=login_method, login_method=login_method,
notes=notes notes=notes,
) )
# 异步获取GeoIP信息 # 异步获取GeoIP信息
try: try:
geoip = get_geoip_helper() geoip = get_geoip_helper()
# 在后台线程中运行GeoIP查询避免阻塞 # 在后台线程中运行GeoIP查询避免阻塞
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
geo_info = await loop.run_in_executor( geo_info = await loop.run_in_executor(
None, None, lambda: geoip.lookup(ip_address)
lambda: geoip.lookup(ip_address)
) )
if geo_info: if geo_info:
login_log.country_code = geo_info.get("country_iso", "") login_log.country_code = geo_info.get("country_iso", "")
login_log.country_name = geo_info.get("country_name", "") login_log.country_name = geo_info.get("country_name", "")
@@ -75,7 +76,7 @@ class LoginLogService:
login_log.latitude = geo_info.get("latitude", "") login_log.latitude = geo_info.get("latitude", "")
login_log.longitude = geo_info.get("longitude", "") login_log.longitude = geo_info.get("longitude", "")
login_log.time_zone = geo_info.get("time_zone", "") login_log.time_zone = geo_info.get("time_zone", "")
# 处理 ASN可能是字符串需要转换为整数 # 处理 ASN可能是字符串需要转换为整数
asn_value = geo_info.get("asn") asn_value = geo_info.get("asn")
if asn_value is not None: if asn_value is not None:
@@ -83,42 +84,47 @@ class LoginLogService:
login_log.asn = int(asn_value) login_log.asn = int(asn_value)
except (ValueError, TypeError): except (ValueError, TypeError):
login_log.asn = None login_log.asn = None
login_log.organization = geo_info.get("organization", "") login_log.organization = geo_info.get("organization", "")
logger.debug(f"GeoIP lookup for {ip_address}: {geo_info.get('country_name', 'Unknown')}") logger.debug(
f"GeoIP lookup for {ip_address}: "
f"{geo_info.get('country_name', 'Unknown')}"
)
else: else:
logger.warning(f"GeoIP lookup failed for {ip_address}") logger.warning(f"GeoIP lookup failed for {ip_address}")
except Exception as e: except Exception as e:
logger.warning(f"GeoIP lookup error for {ip_address}: {e}") logger.warning(f"GeoIP lookup error for {ip_address}: {e}")
# 保存到数据库 # 保存到数据库
db.add(login_log) db.add(login_log)
await db.commit() await db.commit()
await db.refresh(login_log) await db.refresh(login_log)
logger.info(f"Login recorded for user {user_id} from {ip_address} ({login_method})") logger.info(
f"Login recorded for user {user_id} from {ip_address} ({login_method})"
)
return login_log return login_log
@staticmethod @staticmethod
async def record_failed_login( async def record_failed_login(
db: AsyncSession, db: AsyncSession,
request: Request, request: Request,
attempted_username: Optional[str] = None, attempted_username: str | None = None,
login_method: str = "password", login_method: str = "password",
notes: Optional[str] = None notes: str | None = None,
) -> UserLoginLog: ) -> UserLoginLog:
""" """
记录失败的登录尝试 记录失败的登录尝试
Args: Args:
db: 数据库会话 db: 数据库会话
request: HTTP请求对象 request: HTTP请求对象
attempted_username: 尝试登录的用户名 attempted_username: 尝试登录的用户名
login_method: 登录方式 login_method: 登录方式
notes: 备注信息 notes: 备注信息
Returns: Returns:
UserLoginLog: 登录记录对象 UserLoginLog: 登录记录对象
""" """
@@ -129,17 +135,19 @@ class LoginLogService:
request=request, request=request,
login_success=False, login_success=False,
login_method=login_method, login_method=login_method,
notes=f"Failed login attempt: {attempted_username}" if attempted_username else "Failed login attempt" notes=f"Failed login attempt: {attempted_username}"
if attempted_username
else "Failed login attempt",
) )
def get_request_info(request: Request) -> dict: def get_request_info(request: Request) -> dict:
""" """
提取请求的详细信息 提取请求的详细信息
Args: Args:
request: HTTP请求对象 request: HTTP请求对象
Returns: Returns:
dict: 包含请求信息的字典 dict: 包含请求信息的字典
""" """

View File

@@ -37,13 +37,15 @@ if os.path.exists(newrelic_config_path):
environment = os.environ.get( environment = os.environ.get(
"NEW_RELIC_ENVIRONMENT", "NEW_RELIC_ENVIRONMENT",
"production" if not settings.debug else "development" "production" if not settings.debug else "development",
) )
newrelic.agent.initialize(newrelic_config_path, environment) newrelic.agent.initialize(newrelic_config_path, environment)
logger.info(f"[NewRelic] Enabled, environment: {environment}") logger.info(f"[NewRelic] Enabled, environment: {environment}")
except ImportError: except ImportError:
logger.warning("[NewRelic] Config file found but 'newrelic' package is not installed") logger.warning(
"[NewRelic] Config file found but 'newrelic' package is not installed"
)
except Exception as e: except Exception as e:
logger.error(f"[NewRelic] Initialization failed: {e}") logger.error(f"[NewRelic] Initialization failed: {e}")
else: else:

View File

@@ -5,6 +5,7 @@ Revises: 3eef4794ded1
Create Date: 2025-08-18 00:07:06.886879 Create Date: 2025-08-18 00:07:06.886879
""" """
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Sequence
@@ -24,27 +25,55 @@ depends_on: str | Sequence[str] | None = None
def upgrade() -> None: def upgrade() -> None:
"""Upgrade schema.""" """Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.create_table("user_login_log", op.create_table(
sa.Column("id", sa.Integer(), nullable=False), "user_login_log",
sa.Column("user_id", sa.Integer(), nullable=False), sa.Column("id", sa.Integer(), nullable=False),
sa.Column("ip_address", sqlmodel.sql.sqltypes.AutoString(length=45), nullable=False), sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("user_agent", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True), sa.Column(
sa.Column("login_time", sa.DateTime(), nullable=False), "ip_address", sqlmodel.sql.sqltypes.AutoString(length=45), nullable=False
sa.Column("country_code", sqlmodel.sql.sqltypes.AutoString(length=2), nullable=True), ),
sa.Column("country_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True), sa.Column(
sa.Column("city_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True), "user_agent", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True
sa.Column("latitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True), ),
sa.Column("longitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True), sa.Column("login_time", sa.DateTime(), nullable=False),
sa.Column("time_zone", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=True), sa.Column(
sa.Column("asn", sa.Integer(), nullable=True), "country_code", sqlmodel.sql.sqltypes.AutoString(length=2), nullable=True
sa.Column("organization", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=True), ),
sa.Column("login_success", sa.Boolean(), nullable=False), sa.Column(
sa.Column("login_method", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False), "country_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True
sa.Column("notes", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True), ),
sa.PrimaryKeyConstraint("id") sa.Column(
"city_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True
),
sa.Column(
"latitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True
),
sa.Column(
"longitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True
),
sa.Column(
"time_zone", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=True
),
sa.Column("asn", sa.Integer(), nullable=True),
sa.Column(
"organization", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=True
),
sa.Column("login_success", sa.Boolean(), nullable=False),
sa.Column(
"login_method", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False
),
sa.Column("notes", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_user_login_log_ip_address"),
"user_login_log",
["ip_address"],
unique=False,
)
op.create_index(
op.f("ix_user_login_log_user_id"), "user_login_log", ["user_id"], unique=False
) )
op.create_index(op.f("ix_user_login_log_ip_address"), "user_login_log", ["ip_address"], unique=False)
op.create_index(op.f("ix_user_login_log_user_id"), "user_login_log", ["user_id"], unique=False)
op.drop_index(op.f("ix_userloginlog_ip_address"), table_name="userloginlog") op.drop_index(op.f("ix_userloginlog_ip_address"), table_name="userloginlog")
op.drop_index(op.f("ix_userloginlog_user_id"), table_name="userloginlog") op.drop_index(op.f("ix_userloginlog_user_id"), table_name="userloginlog")
op.drop_table("userloginlog") op.drop_table("userloginlog")
@@ -54,30 +83,40 @@ def upgrade() -> None:
def downgrade() -> None: def downgrade() -> None:
"""Downgrade schema.""" """Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.create_table("userloginlog", op.create_table(
sa.Column("id", mysql.INTEGER(), autoincrement=True, nullable=False), "userloginlog",
sa.Column("user_id", mysql.INTEGER(), autoincrement=False, nullable=False), sa.Column("id", mysql.INTEGER(), autoincrement=True, nullable=False),
sa.Column("ip_address", mysql.VARCHAR(length=45), nullable=False), sa.Column("user_id", mysql.INTEGER(), autoincrement=False, nullable=False),
sa.Column("user_agent", mysql.VARCHAR(length=500), nullable=True), sa.Column("ip_address", mysql.VARCHAR(length=45), nullable=False),
sa.Column("login_time", mysql.DATETIME(), nullable=False), sa.Column("user_agent", mysql.VARCHAR(length=500), nullable=True),
sa.Column("country_code", mysql.VARCHAR(length=2), nullable=True), sa.Column("login_time", mysql.DATETIME(), nullable=False),
sa.Column("country_name", mysql.VARCHAR(length=100), nullable=True), sa.Column("country_code", mysql.VARCHAR(length=2), nullable=True),
sa.Column("city_name", mysql.VARCHAR(length=100), nullable=True), sa.Column("country_name", mysql.VARCHAR(length=100), nullable=True),
sa.Column("latitude", mysql.VARCHAR(length=20), nullable=True), sa.Column("city_name", mysql.VARCHAR(length=100), nullable=True),
sa.Column("longitude", mysql.VARCHAR(length=20), nullable=True), sa.Column("latitude", mysql.VARCHAR(length=20), nullable=True),
sa.Column("time_zone", mysql.VARCHAR(length=50), nullable=True), sa.Column("longitude", mysql.VARCHAR(length=20), nullable=True),
sa.Column("asn", mysql.INTEGER(), autoincrement=False, nullable=True), sa.Column("time_zone", mysql.VARCHAR(length=50), nullable=True),
sa.Column("organization", mysql.VARCHAR(length=200), nullable=True), sa.Column("asn", mysql.INTEGER(), autoincrement=False, nullable=True),
sa.Column("login_success", mysql.TINYINT(display_width=1), autoincrement=False, nullable=False), sa.Column("organization", mysql.VARCHAR(length=200), nullable=True),
sa.Column("login_method", mysql.VARCHAR(length=50), nullable=False), sa.Column(
sa.Column("notes", mysql.VARCHAR(length=500), nullable=True), "login_success",
sa.PrimaryKeyConstraint("id"), mysql.TINYINT(display_width=1),
mysql_collate="utf8mb4_0900_ai_ci", autoincrement=False,
mysql_default_charset="utf8mb4", nullable=False,
mysql_engine="InnoDB" ),
sa.Column("login_method", mysql.VARCHAR(length=50), nullable=False),
sa.Column("notes", mysql.VARCHAR(length=500), nullable=True),
sa.PrimaryKeyConstraint("id"),
mysql_collate="utf8mb4_0900_ai_ci",
mysql_default_charset="utf8mb4",
mysql_engine="InnoDB",
)
op.create_index(
op.f("ix_userloginlog_user_id"), "userloginlog", ["user_id"], unique=False
)
op.create_index(
op.f("ix_userloginlog_ip_address"), "userloginlog", ["ip_address"], unique=False
) )
op.create_index(op.f("ix_userloginlog_user_id"), "userloginlog", ["user_id"], unique=False)
op.create_index(op.f("ix_userloginlog_ip_address"), "userloginlog", ["ip_address"], unique=False)
op.drop_index(op.f("ix_user_login_log_user_id"), table_name="user_login_log") op.drop_index(op.f("ix_user_login_log_user_id"), table_name="user_login_log")
op.drop_index(op.f("ix_user_login_log_ip_address"), table_name="user_login_log") op.drop_index(op.f("ix_user_login_log_ip_address"), table_name="user_login_log")
op.drop_table("user_login_log") op.drop_table("user_login_log")

View File

@@ -5,6 +5,7 @@ Revises: df9f725a077c
Create Date: 2025-08-18 00:00:11.369944 Create Date: 2025-08-18 00:00:11.369944
""" """
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Sequence
@@ -23,27 +24,52 @@ depends_on: str | Sequence[str] | None = None
def upgrade() -> None: def upgrade() -> None:
"""Upgrade schema.""" """Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.create_table("userloginlog", op.create_table(
sa.Column("id", sa.Integer(), nullable=False), "userloginlog",
sa.Column("user_id", sa.Integer(), nullable=False), sa.Column("id", sa.Integer(), nullable=False),
sa.Column("ip_address", sqlmodel.sql.sqltypes.AutoString(length=45), nullable=False), sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("user_agent", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True), sa.Column(
sa.Column("login_time", sa.DateTime(), nullable=False), "ip_address", sqlmodel.sql.sqltypes.AutoString(length=45), nullable=False
sa.Column("country_code", sqlmodel.sql.sqltypes.AutoString(length=2), nullable=True), ),
sa.Column("country_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True), sa.Column(
sa.Column("city_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True), "user_agent", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True
sa.Column("latitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True), ),
sa.Column("longitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True), sa.Column("login_time", sa.DateTime(), nullable=False),
sa.Column("time_zone", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=True), sa.Column(
sa.Column("asn", sa.Integer(), nullable=True), "country_code", sqlmodel.sql.sqltypes.AutoString(length=2), nullable=True
sa.Column("organization", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=True), ),
sa.Column("login_success", sa.Boolean(), nullable=False), sa.Column(
sa.Column("login_method", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False), "country_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True
sa.Column("notes", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True), ),
sa.PrimaryKeyConstraint("id") sa.Column(
"city_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True
),
sa.Column(
"latitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True
),
sa.Column(
"longitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True
),
sa.Column(
"time_zone", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=True
),
sa.Column("asn", sa.Integer(), nullable=True),
sa.Column(
"organization", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=True
),
sa.Column("login_success", sa.Boolean(), nullable=False),
sa.Column(
"login_method", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False
),
sa.Column("notes", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_userloginlog_ip_address"), "userloginlog", ["ip_address"], unique=False
)
op.create_index(
op.f("ix_userloginlog_user_id"), "userloginlog", ["user_id"], unique=False
) )
op.create_index(op.f("ix_userloginlog_ip_address"), "userloginlog", ["ip_address"], unique=False)
op.create_index(op.f("ix_userloginlog_user_id"), "userloginlog", ["user_id"], unique=False)
# ### end Alembic commands ### # ### end Alembic commands ###