add ip log

This commit is contained in:
咕谷酱
2025-08-18 00:23:57 +08:00
parent de0c86f4a2
commit 6e496a1123
9 changed files with 450 additions and 19 deletions

View File

@@ -114,11 +114,14 @@ async def authenticate_user_legacy(
pw_md5 = hashlib.md5(password.encode()).hexdigest()
# 2. 根据用户名查找用户
statement = select(User).where(User.username == name)
statement = select(User).where(User.username == name).options()
user = (await db.exec(statement)).first()
if not user:
return None
await db.refresh(user)
# 3. 验证密码
if user.pw_bcrypt is None or user.pw_bcrypt == "":
return None
@@ -261,4 +264,8 @@ async def get_user_by_authorization_code(
statement = select(User).where(User.id == int(user_id))
user = (await db.exec(statement)).first()
return (user, scopes.split(",")) if user else None
if user:
await db.refresh(user)
return (user, scopes.split(","))
return None

View File

@@ -60,6 +60,7 @@ from .user_account_history import (
UserAccountHistoryResp,
UserAccountHistoryType,
)
from .user_login_log import UserLoginLog
__all__ = [
"APIUploadedRoom",
@@ -118,6 +119,7 @@ __all__ = [
"UserAchievement",
"UserAchievement",
"UserAchievementResp",
"UserLoginLog",
"UserResp",
"UserStatistics",
"UserStatisticsResp",

View File

@@ -0,0 +1,40 @@
# -*- coding: utf-8 -*-
"""
User Login Log Database Model
"""
from datetime import datetime
from typing import Optional
from sqlmodel import Field, SQLModel
class UserLoginLog(SQLModel, table=True):
"""User login log table"""
__tablename__ = "user_login_log" # pyright: ignore[reportAssignmentType]
id: Optional[int] = Field(default=None, primary_key=True, description="Record ID")
user_id: int = Field(index=True, description="User ID")
ip_address: str = Field(max_length=45, index=True, description="IP address (supports IPv4 and IPv6)")
user_agent: Optional[str] = Field(default=None, max_length=500, description="User agent information")
login_time: datetime = Field(default_factory=datetime.utcnow, description="Login time")
# GeoIP information
country_code: Optional[str] = Field(default=None, max_length=2, description="Country code")
country_name: Optional[str] = Field(default=None, max_length=100, description="Country name")
city_name: Optional[str] = Field(default=None, max_length=100, description="City name")
latitude: Optional[str] = Field(default=None, max_length=20, description="Latitude")
longitude: Optional[str] = Field(default=None, max_length=20, description="Longitude")
time_zone: Optional[str] = Field(default=None, max_length=50, description="Time zone")
# ASN information
asn: Optional[int] = Field(default=None, description="Autonomous System Number")
organization: Optional[str] = Field(default=None, max_length=200, description="Organization name")
# Login status
login_success: bool = Field(default=True, description="Whether the login was successful")
login_method: str = Field(max_length=50, description="Login method (password/oauth/etc.)")
# Additional information
notes: Optional[str] = Field(default=None, max_length=500, description="Additional notes")
class Config:
from_attributes = True

View File

@@ -2,6 +2,7 @@
"""
GeoIP dependency for FastAPI
"""
import ipaddress
from functools import lru_cache
from app.helpers.geoip_helper import GeoIPHelper
from app.config import settings
@@ -23,29 +24,76 @@ def get_geoip_helper() -> GeoIPHelper:
def get_client_ip(request) -> str:
"""
Get the real client IP address
Supports proxies, load balancers, and Cloudflare headers
获取客户端真实 IP 地址
支持 IPv4 和 IPv6考虑代理、负载均衡器等情况
"""
headers = request.headers
# 1. Cloudflare specific headers
# 1. Cloudflare 专用头部
cf_ip = headers.get("CF-Connecting-IP")
if cf_ip:
return cf_ip.strip()
ip = cf_ip.strip()
if is_valid_ip(ip):
return ip
true_client_ip = headers.get("True-Client-IP")
if true_client_ip:
return true_client_ip.strip()
ip = true_client_ip.strip()
if is_valid_ip(ip):
return ip
# 2. Standard proxy headers
# 2. 标准代理头部
forwarded_for = headers.get("X-Forwarded-For")
if forwarded_for:
# X-Forwarded-For may contain multiple IPs, take the first
return forwarded_for.split(",")[0].strip()
# X-Forwarded-For 可能包含多个 IP取第一个有效的
for ip_str in forwarded_for.split(","):
ip = ip_str.strip()
if is_valid_ip(ip) and not is_private_ip(ip):
return ip
real_ip = headers.get("X-Real-IP")
if real_ip:
return real_ip.strip()
ip = real_ip.strip()
if is_valid_ip(ip):
return ip
# 3. Fallback to client host
return request.client.host if request.client else "127.0.0.1"
# 3. 回退到客户端 IP
client_ip = request.client.host if request.client else "127.0.0.1"
return client_ip if is_valid_ip(client_ip) else "127.0.0.1"
def is_valid_ip(ip_str: str) -> bool:
"""
验证 IP 地址是否有效(支持 IPv4 和 IPv6
"""
try:
ipaddress.ip_address(ip_str)
return True
except ValueError:
return False
def is_private_ip(ip_str: str) -> bool:
"""
判断是否为私有 IP 地址
"""
try:
ip = ipaddress.ip_address(ip_str)
return ip.is_private
except ValueError:
return False
def normalize_ip(ip_str: str) -> str:
"""
标准化 IP 地址格式
对于 IPv6转换为压缩格式
"""
try:
ip = ipaddress.ip_address(ip_str)
if isinstance(ip, ipaddress.IPv6Address):
return ip.compressed
else:
return str(ip)
except ValueError:
return ip_str

View File

@@ -89,6 +89,9 @@ async def get_client_user(
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
if not user:
raise HTTPException(status_code=401, detail="Invalid or expired token")
await db.refresh(user)
return user
@@ -125,4 +128,7 @@ async def get_current_user(
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
if not user:
raise HTTPException(status_code=401, detail="Invalid or expired token")
await db.refresh(user)
return user

View File

@@ -23,6 +23,7 @@ from app.dependencies.database import get_redis
from app.dependencies.geoip import get_geoip_helper, get_client_ip
from app.helpers.geoip_helper import GeoIPHelper
from app.log import logger
from app.service.login_log_service import LoginLogService
from app.models.oauth import (
OAuthErrorResponse,
RegistrationRequestErrors,
@@ -201,6 +202,7 @@ async def register_user(
description="OAuth 令牌端点,支持密码、刷新令牌和授权码三种授权方式。",
)
async def oauth_token(
request: Request,
grant_type: Literal[
"authorization_code", "refresh_token", "password", "client_credentials"
] = Form(..., description="授权类型:密码/刷新令牌/授权码/客户端凭证"),
@@ -268,6 +270,15 @@ async def oauth_token(
# 验证用户
user = await authenticate_user(db, username, password)
if not user:
# 记录失败的登录尝试
await LoginLogService.record_failed_login(
db=db,
request=request,
attempted_username=username,
login_method="password",
notes="Invalid credentials"
)
return create_oauth_error_response(
error="invalid_grant",
description=(
@@ -280,18 +291,34 @@ async def oauth_token(
hint="Incorrect sign in",
)
# 确保用户对象与当前会话关联
await db.refresh(user)
# 记录成功的登录
user_id = getattr(user, 'id')
assert user_id is not None, "User ID should not be None after authentication"
await LoginLogService.record_login(
db=db,
user_id=user_id,
request=request,
login_success=True,
login_method="password",
notes=f"OAuth password grant for client {client_id}"
)
# 生成令牌
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
# 获取用户ID避免触发延迟加载
access_token = create_access_token(
data={"sub": str(user.id)}, expires_delta=access_token_expires
data={"sub": str(user_id)}, expires_delta=access_token_expires
)
refresh_token_str = generate_refresh_token()
# 存储令牌
assert user.id
assert user_id
await store_token(
db,
user.id,
user_id,
client_id,
scopes,
access_token,
@@ -397,18 +424,26 @@ async def oauth_token(
hint="Invalid authorization code",
)
user, scopes = code_result
# 确保用户对象与当前会话关联
await db.refresh(user)
# 生成令牌
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
# 重新查询只获取ID避免触发延迟加载
id_result = await db.exec(select(User.id).where(User.username == username))
user_id = id_result.first()
access_token = create_access_token(
data={"sub": str(user.id)}, expires_delta=access_token_expires
data={"sub": str(user_id)}, expires_delta=access_token_expires
)
refresh_token_str = generate_refresh_token()
# 存储令牌
assert user.id
assert user_id
await store_token(
db,
user.id,
user_id,
client_id,
scopes,
access_token,

View File

@@ -0,0 +1,153 @@
# -*- coding: utf-8 -*-
"""
用户登录记录服务
"""
import asyncio
from datetime import datetime
from typing import Optional
from fastapi import Request
from sqlmodel.ext.asyncio.session import AsyncSession
from app.database.user_login_log import UserLoginLog
from app.dependencies.geoip import get_geoip_helper, get_client_ip, normalize_ip
from app.log import logger
class LoginLogService:
"""用户登录记录服务"""
@staticmethod
async def record_login(
db: AsyncSession,
user_id: int,
request: Request,
login_success: bool = True,
login_method: str = "password",
notes: Optional[str] = None
) -> UserLoginLog:
"""
记录用户登录信息
Args:
db: 数据库会话
user_id: 用户ID
request: HTTP请求对象
login_success: 登录是否成功
login_method: 登录方式
notes: 备注信息
Returns:
UserLoginLog: 登录记录对象
"""
# 获取客户端IP并标准化格式
raw_ip = get_client_ip(request)
ip_address = normalize_ip(raw_ip)
# 获取User-Agent
user_agent = request.headers.get("User-Agent", "")
# 创建基本的登录记录
login_log = UserLoginLog(
user_id=user_id,
ip_address=ip_address,
user_agent=user_agent,
login_time=datetime.utcnow(),
login_success=login_success,
login_method=login_method,
notes=notes
)
# 异步获取GeoIP信息
try:
geoip = get_geoip_helper()
# 在后台线程中运行GeoIP查询避免阻塞
loop = asyncio.get_event_loop()
geo_info = await loop.run_in_executor(
None,
lambda: geoip.lookup(ip_address)
)
if geo_info:
login_log.country_code = geo_info.get("country_iso", "")
login_log.country_name = geo_info.get("country_name", "")
login_log.city_name = geo_info.get("city_name", "")
login_log.latitude = geo_info.get("latitude", "")
login_log.longitude = geo_info.get("longitude", "")
login_log.time_zone = geo_info.get("time_zone", "")
# 处理 ASN可能是字符串需要转换为整数
asn_value = geo_info.get("asn")
if asn_value is not None:
try:
login_log.asn = int(asn_value)
except (ValueError, TypeError):
login_log.asn = None
login_log.organization = geo_info.get("organization", "")
logger.debug(f"GeoIP lookup for {ip_address}: {geo_info.get('country_name', 'Unknown')}")
else:
logger.warning(f"GeoIP lookup failed for {ip_address}")
except Exception as e:
logger.warning(f"GeoIP lookup error for {ip_address}: {e}")
# 保存到数据库
db.add(login_log)
await db.commit()
await db.refresh(login_log)
logger.info(f"Login recorded for user {user_id} from {ip_address} ({login_method})")
return login_log
@staticmethod
async def record_failed_login(
db: AsyncSession,
request: Request,
attempted_username: Optional[str] = None,
login_method: str = "password",
notes: Optional[str] = None
) -> UserLoginLog:
"""
记录失败的登录尝试
Args:
db: 数据库会话
request: HTTP请求对象
attempted_username: 尝试登录的用户名
login_method: 登录方式
notes: 备注信息
Returns:
UserLoginLog: 登录记录对象
"""
# 对于失败的登录使用user_id=0表示未知用户
return await LoginLogService.record_login(
db=db,
user_id=0, # 0表示未知/失败的登录
request=request,
login_success=False,
login_method=login_method,
notes=f"Failed login attempt: {attempted_username}" if attempted_username else "Failed login attempt"
)
def get_request_info(request: Request) -> dict:
"""
提取请求的详细信息
Args:
request: HTTP请求对象
Returns:
dict: 包含请求信息的字典
"""
return {
"ip": get_client_ip(request),
"user_agent": request.headers.get("User-Agent", ""),
"referer": request.headers.get("Referer", ""),
"accept_language": request.headers.get("Accept-Language", ""),
"x_forwarded_for": request.headers.get("X-Forwarded-For", ""),
"x_real_ip": request.headers.get("X-Real-IP", ""),
}

View File

@@ -0,0 +1,84 @@
"""Fix user login log table name
Revision ID: 2dcd04d3f4dc
Revises: 3eef4794ded1
Create Date: 2025-08-18 00:07:06.886879
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
import sqlmodel
# revision identifiers, used by Alembic.
revision: str = "2dcd04d3f4dc"
down_revision: str | Sequence[str] | None = "3eef4794ded1"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table("user_login_log",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("ip_address", sqlmodel.sql.sqltypes.AutoString(length=45), nullable=False),
sa.Column("user_agent", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True),
sa.Column("login_time", sa.DateTime(), nullable=False),
sa.Column("country_code", sqlmodel.sql.sqltypes.AutoString(length=2), nullable=True),
sa.Column("country_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True),
sa.Column("city_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True),
sa.Column("latitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True),
sa.Column("longitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True),
sa.Column("time_zone", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=True),
sa.Column("asn", sa.Integer(), nullable=True),
sa.Column("organization", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=True),
sa.Column("login_success", sa.Boolean(), nullable=False),
sa.Column("login_method", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False),
sa.Column("notes", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True),
sa.PrimaryKeyConstraint("id")
)
op.create_index(op.f("ix_user_login_log_ip_address"), "user_login_log", ["ip_address"], unique=False)
op.create_index(op.f("ix_user_login_log_user_id"), "user_login_log", ["user_id"], unique=False)
op.drop_index(op.f("ix_userloginlog_ip_address"), table_name="userloginlog")
op.drop_index(op.f("ix_userloginlog_user_id"), table_name="userloginlog")
op.drop_table("userloginlog")
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table("userloginlog",
sa.Column("id", mysql.INTEGER(), autoincrement=True, nullable=False),
sa.Column("user_id", mysql.INTEGER(), autoincrement=False, nullable=False),
sa.Column("ip_address", mysql.VARCHAR(length=45), nullable=False),
sa.Column("user_agent", mysql.VARCHAR(length=500), nullable=True),
sa.Column("login_time", mysql.DATETIME(), nullable=False),
sa.Column("country_code", mysql.VARCHAR(length=2), nullable=True),
sa.Column("country_name", mysql.VARCHAR(length=100), nullable=True),
sa.Column("city_name", mysql.VARCHAR(length=100), nullable=True),
sa.Column("latitude", mysql.VARCHAR(length=20), nullable=True),
sa.Column("longitude", mysql.VARCHAR(length=20), nullable=True),
sa.Column("time_zone", mysql.VARCHAR(length=50), nullable=True),
sa.Column("asn", mysql.INTEGER(), autoincrement=False, nullable=True),
sa.Column("organization", mysql.VARCHAR(length=200), nullable=True),
sa.Column("login_success", mysql.TINYINT(display_width=1), autoincrement=False, nullable=False),
sa.Column("login_method", mysql.VARCHAR(length=50), nullable=False),
sa.Column("notes", mysql.VARCHAR(length=500), nullable=True),
sa.PrimaryKeyConstraint("id"),
mysql_collate="utf8mb4_0900_ai_ci",
mysql_default_charset="utf8mb4",
mysql_engine="InnoDB"
)
op.create_index(op.f("ix_userloginlog_user_id"), "userloginlog", ["user_id"], unique=False)
op.create_index(op.f("ix_userloginlog_ip_address"), "userloginlog", ["ip_address"], unique=False)
op.drop_index(op.f("ix_user_login_log_user_id"), table_name="user_login_log")
op.drop_index(op.f("ix_user_login_log_ip_address"), table_name="user_login_log")
op.drop_table("user_login_log")
# ### end Alembic commands ###

View File

@@ -0,0 +1,56 @@
"""Add user login log table
Revision ID: 3eef4794ded1
Revises: df9f725a077c
Create Date: 2025-08-18 00:00:11.369944
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic.
revision: str = "3eef4794ded1"
down_revision: str | Sequence[str] | None = "df9f725a077c"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table("userloginlog",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("ip_address", sqlmodel.sql.sqltypes.AutoString(length=45), nullable=False),
sa.Column("user_agent", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True),
sa.Column("login_time", sa.DateTime(), nullable=False),
sa.Column("country_code", sqlmodel.sql.sqltypes.AutoString(length=2), nullable=True),
sa.Column("country_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True),
sa.Column("city_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True),
sa.Column("latitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True),
sa.Column("longitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True),
sa.Column("time_zone", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=True),
sa.Column("asn", sa.Integer(), nullable=True),
sa.Column("organization", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=True),
sa.Column("login_success", sa.Boolean(), nullable=False),
sa.Column("login_method", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False),
sa.Column("notes", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True),
sa.PrimaryKeyConstraint("id")
)
op.create_index(op.f("ix_userloginlog_ip_address"), "userloginlog", ["ip_address"], unique=False)
op.create_index(op.f("ix_userloginlog_user_id"), "userloginlog", ["user_id"], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_userloginlog_user_id"), table_name="userloginlog")
op.drop_index(op.f("ix_userloginlog_ip_address"), table_name="userloginlog")
op.drop_table("userloginlog")
# ### end Alembic commands ###