diff --git a/app/auth.py b/app/auth.py index 6293c07..3ffac5a 100644 --- a/app/auth.py +++ b/app/auth.py @@ -114,11 +114,14 @@ async def authenticate_user_legacy( pw_md5 = hashlib.md5(password.encode()).hexdigest() # 2. 根据用户名查找用户 - statement = select(User).where(User.username == name) + statement = select(User).where(User.username == name).options() user = (await db.exec(statement)).first() if not user: return None + + await db.refresh(user) + # 3. 验证密码 if user.pw_bcrypt is None or user.pw_bcrypt == "": return None @@ -261,4 +264,8 @@ async def get_user_by_authorization_code( statement = select(User).where(User.id == int(user_id)) user = (await db.exec(statement)).first() - return (user, scopes.split(",")) if user else None + if user: + + await db.refresh(user) + return (user, scopes.split(",")) + return None diff --git a/app/database/__init__.py b/app/database/__init__.py index d8794c4..5c58b47 100644 --- a/app/database/__init__.py +++ b/app/database/__init__.py @@ -60,6 +60,7 @@ from .user_account_history import ( UserAccountHistoryResp, UserAccountHistoryType, ) +from .user_login_log import UserLoginLog __all__ = [ "APIUploadedRoom", @@ -118,6 +119,7 @@ __all__ = [ "UserAchievement", "UserAchievement", "UserAchievementResp", + "UserLoginLog", "UserResp", "UserStatistics", "UserStatisticsResp", diff --git a/app/database/user_login_log.py b/app/database/user_login_log.py new file mode 100644 index 0000000..ace1132 --- /dev/null +++ b/app/database/user_login_log.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +""" +User Login Log Database Model +""" +from datetime import datetime +from typing import Optional +from sqlmodel import Field, SQLModel + + +class UserLoginLog(SQLModel, table=True): + """User login log table""" + __tablename__ = "user_login_log" # pyright: ignore[reportAssignmentType] + + id: Optional[int] = Field(default=None, primary_key=True, description="Record ID") + user_id: int = Field(index=True, description="User ID") + ip_address: str = Field(max_length=45, index=True, description="IP address (supports IPv4 and IPv6)") + user_agent: Optional[str] = Field(default=None, max_length=500, description="User agent information") + login_time: datetime = Field(default_factory=datetime.utcnow, description="Login time") + + # GeoIP information + country_code: Optional[str] = Field(default=None, max_length=2, description="Country code") + country_name: Optional[str] = Field(default=None, max_length=100, description="Country name") + city_name: Optional[str] = Field(default=None, max_length=100, description="City name") + latitude: Optional[str] = Field(default=None, max_length=20, description="Latitude") + longitude: Optional[str] = Field(default=None, max_length=20, description="Longitude") + time_zone: Optional[str] = Field(default=None, max_length=50, description="Time zone") + + # ASN information + asn: Optional[int] = Field(default=None, description="Autonomous System Number") + organization: Optional[str] = Field(default=None, max_length=200, description="Organization name") + + # Login status + login_success: bool = Field(default=True, description="Whether the login was successful") + login_method: str = Field(max_length=50, description="Login method (password/oauth/etc.)") + + # Additional information + notes: Optional[str] = Field(default=None, max_length=500, description="Additional notes") + + class Config: + from_attributes = True diff --git a/app/dependencies/geoip.py b/app/dependencies/geoip.py index 34e8d3e..8b29679 100644 --- a/app/dependencies/geoip.py +++ b/app/dependencies/geoip.py @@ -2,6 +2,7 @@ """ GeoIP dependency for FastAPI """ +import ipaddress from functools import lru_cache from app.helpers.geoip_helper import GeoIPHelper from app.config import settings @@ -23,29 +24,76 @@ def get_geoip_helper() -> GeoIPHelper: def get_client_ip(request) -> str: """ - Get the real client IP address - Supports proxies, load balancers, and Cloudflare headers + 获取客户端真实 IP 地址 + 支持 IPv4 和 IPv6,考虑代理、负载均衡器等情况 """ headers = request.headers - # 1. Cloudflare specific headers + # 1. Cloudflare 专用头部 cf_ip = headers.get("CF-Connecting-IP") if cf_ip: - return cf_ip.strip() + ip = cf_ip.strip() + if is_valid_ip(ip): + return ip true_client_ip = headers.get("True-Client-IP") if true_client_ip: - return true_client_ip.strip() + ip = true_client_ip.strip() + if is_valid_ip(ip): + return ip - # 2. Standard proxy headers + # 2. 标准代理头部 forwarded_for = headers.get("X-Forwarded-For") if forwarded_for: - # X-Forwarded-For may contain multiple IPs, take the first - return forwarded_for.split(",")[0].strip() + # X-Forwarded-For 可能包含多个 IP,取第一个有效的 + for ip_str in forwarded_for.split(","): + ip = ip_str.strip() + if is_valid_ip(ip) and not is_private_ip(ip): + return ip real_ip = headers.get("X-Real-IP") if real_ip: - return real_ip.strip() + ip = real_ip.strip() + if is_valid_ip(ip): + return ip - # 3. Fallback to client host - return request.client.host if request.client else "127.0.0.1" + # 3. 回退到客户端 IP + client_ip = request.client.host if request.client else "127.0.0.1" + return client_ip if is_valid_ip(client_ip) else "127.0.0.1" + + +def is_valid_ip(ip_str: str) -> bool: + """ + 验证 IP 地址是否有效(支持 IPv4 和 IPv6) + """ + try: + ipaddress.ip_address(ip_str) + return True + except ValueError: + return False + + +def is_private_ip(ip_str: str) -> bool: + """ + 判断是否为私有 IP 地址 + """ + try: + ip = ipaddress.ip_address(ip_str) + return ip.is_private + except ValueError: + return False + + +def normalize_ip(ip_str: str) -> str: + """ + 标准化 IP 地址格式 + 对于 IPv6,转换为压缩格式 + """ + try: + ip = ipaddress.ip_address(ip_str) + if isinstance(ip, ipaddress.IPv6Address): + return ip.compressed + else: + return str(ip) + except ValueError: + return ip_str diff --git a/app/dependencies/user.py b/app/dependencies/user.py index 084323f..73e679f 100644 --- a/app/dependencies/user.py +++ b/app/dependencies/user.py @@ -89,6 +89,9 @@ async def get_client_user( user = (await db.exec(select(User).where(User.id == token_record.user_id))).first() if not user: raise HTTPException(status_code=401, detail="Invalid or expired token") + + + await db.refresh(user) return user @@ -125,4 +128,7 @@ async def get_current_user( user = (await db.exec(select(User).where(User.id == token_record.user_id))).first() if not user: raise HTTPException(status_code=401, detail="Invalid or expired token") + + + await db.refresh(user) return user diff --git a/app/router/auth.py b/app/router/auth.py index 4c3505d..3eac5f7 100644 --- a/app/router/auth.py +++ b/app/router/auth.py @@ -23,6 +23,7 @@ from app.dependencies.database import get_redis from app.dependencies.geoip import get_geoip_helper, get_client_ip from app.helpers.geoip_helper import GeoIPHelper from app.log import logger +from app.service.login_log_service import LoginLogService from app.models.oauth import ( OAuthErrorResponse, RegistrationRequestErrors, @@ -201,6 +202,7 @@ async def register_user( description="OAuth 令牌端点,支持密码、刷新令牌和授权码三种授权方式。", ) async def oauth_token( + request: Request, grant_type: Literal[ "authorization_code", "refresh_token", "password", "client_credentials" ] = Form(..., description="授权类型:密码/刷新令牌/授权码/客户端凭证"), @@ -268,6 +270,15 @@ async def oauth_token( # 验证用户 user = await authenticate_user(db, username, password) if not user: + # 记录失败的登录尝试 + await LoginLogService.record_failed_login( + db=db, + request=request, + attempted_username=username, + login_method="password", + notes="Invalid credentials" + ) + return create_oauth_error_response( error="invalid_grant", description=( @@ -280,18 +291,34 @@ async def oauth_token( hint="Incorrect sign in", ) + # 确保用户对象与当前会话关联 + await db.refresh(user) + + # 记录成功的登录 + user_id = getattr(user, 'id') + assert user_id is not None, "User ID should not be None after authentication" + await LoginLogService.record_login( + db=db, + user_id=user_id, + request=request, + login_success=True, + login_method="password", + notes=f"OAuth password grant for client {client_id}" + ) + # 生成令牌 access_token_expires = timedelta(minutes=settings.access_token_expire_minutes) + # 获取用户ID,避免触发延迟加载 access_token = create_access_token( - data={"sub": str(user.id)}, expires_delta=access_token_expires + data={"sub": str(user_id)}, expires_delta=access_token_expires ) refresh_token_str = generate_refresh_token() # 存储令牌 - assert user.id + assert user_id await store_token( db, - user.id, + user_id, client_id, scopes, access_token, @@ -397,18 +424,26 @@ async def oauth_token( hint="Invalid authorization code", ) user, scopes = code_result + + # 确保用户对象与当前会话关联 + await db.refresh(user) + # 生成令牌 access_token_expires = timedelta(minutes=settings.access_token_expire_minutes) + # 重新查询只获取ID,避免触发延迟加载 + id_result = await db.exec(select(User.id).where(User.username == username)) + user_id = id_result.first() + access_token = create_access_token( - data={"sub": str(user.id)}, expires_delta=access_token_expires + data={"sub": str(user_id)}, expires_delta=access_token_expires ) refresh_token_str = generate_refresh_token() # 存储令牌 - assert user.id + assert user_id await store_token( db, - user.id, + user_id, client_id, scopes, access_token, diff --git a/app/service/login_log_service.py b/app/service/login_log_service.py new file mode 100644 index 0000000..a9cc484 --- /dev/null +++ b/app/service/login_log_service.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- +""" +用户登录记录服务 +""" +import asyncio +from datetime import datetime +from typing import Optional +from fastapi import Request +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.database.user_login_log import UserLoginLog +from app.dependencies.geoip import get_geoip_helper, get_client_ip, normalize_ip +from app.log import logger + + +class LoginLogService: + """用户登录记录服务""" + + @staticmethod + async def record_login( + db: AsyncSession, + user_id: int, + request: Request, + login_success: bool = True, + login_method: str = "password", + notes: Optional[str] = None + ) -> UserLoginLog: + """ + 记录用户登录信息 + + Args: + db: 数据库会话 + user_id: 用户ID + request: HTTP请求对象 + login_success: 登录是否成功 + login_method: 登录方式 + notes: 备注信息 + + Returns: + UserLoginLog: 登录记录对象 + """ + # 获取客户端IP并标准化格式 + raw_ip = get_client_ip(request) + ip_address = normalize_ip(raw_ip) + + # 获取User-Agent + user_agent = request.headers.get("User-Agent", "") + + # 创建基本的登录记录 + login_log = UserLoginLog( + user_id=user_id, + ip_address=ip_address, + user_agent=user_agent, + login_time=datetime.utcnow(), + login_success=login_success, + login_method=login_method, + notes=notes + ) + + # 异步获取GeoIP信息 + try: + geoip = get_geoip_helper() + + # 在后台线程中运行GeoIP查询(避免阻塞) + loop = asyncio.get_event_loop() + geo_info = await loop.run_in_executor( + None, + lambda: geoip.lookup(ip_address) + ) + + if geo_info: + login_log.country_code = geo_info.get("country_iso", "") + login_log.country_name = geo_info.get("country_name", "") + login_log.city_name = geo_info.get("city_name", "") + login_log.latitude = geo_info.get("latitude", "") + login_log.longitude = geo_info.get("longitude", "") + login_log.time_zone = geo_info.get("time_zone", "") + + # 处理 ASN(可能是字符串,需要转换为整数) + asn_value = geo_info.get("asn") + if asn_value is not None: + try: + login_log.asn = int(asn_value) + except (ValueError, TypeError): + login_log.asn = None + + login_log.organization = geo_info.get("organization", "") + + logger.debug(f"GeoIP lookup for {ip_address}: {geo_info.get('country_name', 'Unknown')}") + else: + logger.warning(f"GeoIP lookup failed for {ip_address}") + + except Exception as e: + logger.warning(f"GeoIP lookup error for {ip_address}: {e}") + + # 保存到数据库 + db.add(login_log) + await db.commit() + await db.refresh(login_log) + + logger.info(f"Login recorded for user {user_id} from {ip_address} ({login_method})") + return login_log + + @staticmethod + async def record_failed_login( + db: AsyncSession, + request: Request, + attempted_username: Optional[str] = None, + login_method: str = "password", + notes: Optional[str] = None + ) -> UserLoginLog: + """ + 记录失败的登录尝试 + + Args: + db: 数据库会话 + request: HTTP请求对象 + attempted_username: 尝试登录的用户名 + login_method: 登录方式 + notes: 备注信息 + + Returns: + UserLoginLog: 登录记录对象 + """ + # 对于失败的登录,使用user_id=0表示未知用户 + return await LoginLogService.record_login( + db=db, + user_id=0, # 0表示未知/失败的登录 + request=request, + login_success=False, + login_method=login_method, + notes=f"Failed login attempt: {attempted_username}" if attempted_username else "Failed login attempt" + ) + + +def get_request_info(request: Request) -> dict: + """ + 提取请求的详细信息 + + Args: + request: HTTP请求对象 + + Returns: + dict: 包含请求信息的字典 + """ + return { + "ip": get_client_ip(request), + "user_agent": request.headers.get("User-Agent", ""), + "referer": request.headers.get("Referer", ""), + "accept_language": request.headers.get("Accept-Language", ""), + "x_forwarded_for": request.headers.get("X-Forwarded-For", ""), + "x_real_ip": request.headers.get("X-Real-IP", ""), + } diff --git a/migrations/versions/2dcd04d3f4dc_fix_user_login_log_table_name.py b/migrations/versions/2dcd04d3f4dc_fix_user_login_log_table_name.py new file mode 100644 index 0000000..e361e27 --- /dev/null +++ b/migrations/versions/2dcd04d3f4dc_fix_user_login_log_table_name.py @@ -0,0 +1,84 @@ +"""Fix user login log table name + +Revision ID: 2dcd04d3f4dc +Revises: 3eef4794ded1 +Create Date: 2025-08-18 00:07:06.886879 + +""" +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql +import sqlmodel + +# revision identifiers, used by Alembic. +revision: str = "2dcd04d3f4dc" +down_revision: str | Sequence[str] | None = "3eef4794ded1" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table("user_login_log", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("ip_address", sqlmodel.sql.sqltypes.AutoString(length=45), nullable=False), + sa.Column("user_agent", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True), + sa.Column("login_time", sa.DateTime(), nullable=False), + sa.Column("country_code", sqlmodel.sql.sqltypes.AutoString(length=2), nullable=True), + sa.Column("country_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True), + sa.Column("city_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True), + sa.Column("latitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True), + sa.Column("longitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True), + sa.Column("time_zone", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=True), + sa.Column("asn", sa.Integer(), nullable=True), + sa.Column("organization", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=True), + sa.Column("login_success", sa.Boolean(), nullable=False), + sa.Column("login_method", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False), + sa.Column("notes", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True), + sa.PrimaryKeyConstraint("id") + ) + op.create_index(op.f("ix_user_login_log_ip_address"), "user_login_log", ["ip_address"], unique=False) + op.create_index(op.f("ix_user_login_log_user_id"), "user_login_log", ["user_id"], unique=False) + op.drop_index(op.f("ix_userloginlog_ip_address"), table_name="userloginlog") + op.drop_index(op.f("ix_userloginlog_user_id"), table_name="userloginlog") + op.drop_table("userloginlog") + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table("userloginlog", + sa.Column("id", mysql.INTEGER(), autoincrement=True, nullable=False), + sa.Column("user_id", mysql.INTEGER(), autoincrement=False, nullable=False), + sa.Column("ip_address", mysql.VARCHAR(length=45), nullable=False), + sa.Column("user_agent", mysql.VARCHAR(length=500), nullable=True), + sa.Column("login_time", mysql.DATETIME(), nullable=False), + sa.Column("country_code", mysql.VARCHAR(length=2), nullable=True), + sa.Column("country_name", mysql.VARCHAR(length=100), nullable=True), + sa.Column("city_name", mysql.VARCHAR(length=100), nullable=True), + sa.Column("latitude", mysql.VARCHAR(length=20), nullable=True), + sa.Column("longitude", mysql.VARCHAR(length=20), nullable=True), + sa.Column("time_zone", mysql.VARCHAR(length=50), nullable=True), + sa.Column("asn", mysql.INTEGER(), autoincrement=False, nullable=True), + sa.Column("organization", mysql.VARCHAR(length=200), nullable=True), + sa.Column("login_success", mysql.TINYINT(display_width=1), autoincrement=False, nullable=False), + sa.Column("login_method", mysql.VARCHAR(length=50), nullable=False), + sa.Column("notes", mysql.VARCHAR(length=500), nullable=True), + sa.PrimaryKeyConstraint("id"), + mysql_collate="utf8mb4_0900_ai_ci", + mysql_default_charset="utf8mb4", + mysql_engine="InnoDB" + ) + op.create_index(op.f("ix_userloginlog_user_id"), "userloginlog", ["user_id"], unique=False) + op.create_index(op.f("ix_userloginlog_ip_address"), "userloginlog", ["ip_address"], unique=False) + op.drop_index(op.f("ix_user_login_log_user_id"), table_name="user_login_log") + op.drop_index(op.f("ix_user_login_log_ip_address"), table_name="user_login_log") + op.drop_table("user_login_log") + # ### end Alembic commands ### diff --git a/migrations/versions/3eef4794ded1_add_user_login_log_table.py b/migrations/versions/3eef4794ded1_add_user_login_log_table.py new file mode 100644 index 0000000..b4df540 --- /dev/null +++ b/migrations/versions/3eef4794ded1_add_user_login_log_table.py @@ -0,0 +1,56 @@ +"""Add user login log table + +Revision ID: 3eef4794ded1 +Revises: df9f725a077c +Create Date: 2025-08-18 00:00:11.369944 + +""" +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa +import sqlmodel + +# revision identifiers, used by Alembic. +revision: str = "3eef4794ded1" +down_revision: str | Sequence[str] | None = "df9f725a077c" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table("userloginlog", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("ip_address", sqlmodel.sql.sqltypes.AutoString(length=45), nullable=False), + sa.Column("user_agent", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True), + sa.Column("login_time", sa.DateTime(), nullable=False), + sa.Column("country_code", sqlmodel.sql.sqltypes.AutoString(length=2), nullable=True), + sa.Column("country_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True), + sa.Column("city_name", sqlmodel.sql.sqltypes.AutoString(length=100), nullable=True), + sa.Column("latitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True), + sa.Column("longitude", sqlmodel.sql.sqltypes.AutoString(length=20), nullable=True), + sa.Column("time_zone", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=True), + sa.Column("asn", sa.Integer(), nullable=True), + sa.Column("organization", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=True), + sa.Column("login_success", sa.Boolean(), nullable=False), + sa.Column("login_method", sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False), + sa.Column("notes", sqlmodel.sql.sqltypes.AutoString(length=500), nullable=True), + sa.PrimaryKeyConstraint("id") + ) + op.create_index(op.f("ix_userloginlog_ip_address"), "userloginlog", ["ip_address"], unique=False) + op.create_index(op.f("ix_userloginlog_user_id"), "userloginlog", ["user_id"], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_userloginlog_user_id"), table_name="userloginlog") + op.drop_index(op.f("ix_userloginlog_ip_address"), table_name="userloginlog") + op.drop_table("userloginlog") + # ### end Alembic commands ###