From f165ae5dc38a2ee22533afa5c54f8d4e3dcd2d6a Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Sun, 10 Aug 2025 05:38:28 +0000 Subject: [PATCH] refactor(config): use pydantic-settings --- .env.client | 4 --- .env.example | 34 +++++++++++++++++++++ app/auth.py | 6 ++-- app/config.py | 59 ++++++++++++++++++------------------ app/dependencies/database.py | 4 +-- app/dependencies/fetcher.py | 8 ++--- app/log.py | 6 ++-- app/router/auth.py | 16 +++++----- app/signalr/hub/hub.py | 4 +-- main.py | 6 ++-- pyproject.toml | 1 + uv.lock | 16 ++++++++++ 12 files changed, 105 insertions(+), 59 deletions(-) delete mode 100644 .env.client create mode 100644 .env.example diff --git a/.env.client b/.env.client deleted file mode 100644 index eac9b22..0000000 --- a/.env.client +++ /dev/null @@ -1,4 +0,0 @@ -# osu! API 客户端配置 -OSU_CLIENT_ID=5 -OSU_CLIENT_SECRET=FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk -OSU_API_URL=http://localhost:8000 diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..88437ff --- /dev/null +++ b/.env.example @@ -0,0 +1,34 @@ +# 数据库 URL +DATABASE_URL="mysql+aiomysql://root:password@127.0.0.1:3306/osu_api" +# Redis URL +REDIS_URL="redis://127.0.0.1:6379/0" + +# JWT 密钥,使用 openssl rand -hex 32 生成 +JWT_SECRET_KEY="your_jwt_secret_here" +# JWT 算法 +ALGORITHM="HS256" +# JWT 过期时间 +ACCESS_TOKEN_EXPIRE_MINUTES=1440 + +# 服务器地址 +HOST="0.0.0.0" +PORT=8000 +# 调试模式,生产环境请设置为 false +DEBUG=false + +# osu!lazer 登录设置 +OSU_CLIENT_ID="5" +OSU_CLIENT_SECRET="FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk" + +# SignalR 服务器设置 +SIGNALR_NEGOTIATE_TIMEOUT=30 +SIGNALR_PING_INTERVAL=15 + +# Fetcher 设置 +FETCHER_CLIENT_ID="" +FETCHER_CLIENT_SECRET="" +FETCHER_SCOPES=["public"] +FETCHER_CALLBACK_URL="http://localhost:8000/fetcher/callback" + +# 日志设置 +LOG_LEVEL="INFO" diff --git a/app/auth.py b/app/auth.py index 4762662..ddf5f56 100644 --- a/app/auth.py +++ b/app/auth.py @@ -125,12 +125,12 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None) -> s expire = datetime.utcnow() + expires_delta else: expire = datetime.utcnow() + timedelta( - minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES + minutes=settings.access_token_expire_minutes ) to_encode.update({"exp": expire}) encoded_jwt = jwt.encode( - to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM + to_encode, settings.secret_key, algorithm=settings.algorithm ) return encoded_jwt @@ -146,7 +146,7 @@ def verify_token(token: str) -> dict | None: """验证访问令牌""" try: payload = jwt.decode( - token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] + token, settings.secret_key, algorithms=[settings.algorithm] ) return payload except JWTError: diff --git a/app/config.py b/app/config.py index 778155f..d008ccb 100644 --- a/app/config.py +++ b/app/config.py @@ -1,51 +1,50 @@ from __future__ import annotations -import os +from typing import Annotated, Any -from dotenv import load_dotenv - -load_dotenv() +from pydantic import Field, field_validator +from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict -class Settings: +class Settings(BaseSettings): + model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8") + # 数据库设置 - DATABASE_URL: str = os.getenv( - "DATABASE_URL", "mysql+aiomysql://root:password@127.0.0.1:3306/osu_api" - ) - REDIS_URL: str = os.getenv("REDIS_URL", "redis://127.0.0.1:6379/0") + database_url: str = "mysql+aiomysql://root:password@127.0.0.1:3306/osu_api" + redis_url: str = "redis://127.0.0.1:6379/0" # JWT 设置 - SECRET_KEY: str = os.getenv("SECRET_KEY", "your-secret-key-here") - ALGORITHM: str = os.getenv("ALGORITHM", "HS256") - ACCESS_TOKEN_EXPIRE_MINUTES: int = int( - os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "1440") - ) + secret_key: str = Field(default="your-secret-key-here", alias="jwt_secret_key") + algorithm: str = "HS256" + access_token_expire_minutes: int = 1440 # OAuth 设置 - OSU_CLIENT_ID: str = os.getenv("OSU_CLIENT_ID", "5") - OSU_CLIENT_SECRET: str = os.getenv( - "OSU_CLIENT_SECRET", "FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk" - ) + osu_client_id: str = "5" + osu_client_secret: str = "FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk" # 服务器设置 - HOST: str = os.getenv("HOST", "0.0.0.0") - PORT: int = int(os.getenv("PORT", "8000")) - DEBUG: bool = os.getenv("DEBUG", "True").lower() == "true" + host: str = "0.0.0.0" + port: int = 8000 + debug: bool = False # SignalR 设置 - SIGNALR_NEGOTIATE_TIMEOUT: int = int(os.getenv("SIGNALR_NEGOTIATE_TIMEOUT", "30")) - SIGNALR_PING_INTERVAL: int = int(os.getenv("SIGNALR_PING_INTERVAL", "15")) + signalr_negotiate_timeout: int = 30 + signalr_ping_interval: int = 15 # Fetcher 设置 - FETCHER_CLIENT_ID: str = os.getenv("FETCHER_CLIENT_ID", "") - FETCHER_CLIENT_SECRET: str = os.getenv("FETCHER_CLIENT_SECRET", "") - FETCHER_SCOPES: list[str] = os.getenv("FETCHER_SCOPES", "public").split(",") - FETCHER_CALLBACK_URL: str = os.getenv( - "FETCHER_CALLBACK_URL", "http://localhost:8000/fetcher/callback" - ) + fetcher_client_id: str = "" + fetcher_client_secret: str = "" + fetcher_scopes: Annotated[list[str], NoDecode] = ["public"] + fetcher_callback_url: str = "http://localhost:8000/fetcher/callback" # 日志设置 - LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO").upper() + log_level: str = "INFO" + + @field_validator("fetcher_scopes", mode="before") + def validate_fetcher_scopes(cls, v: Any) -> list[str]: + if isinstance(v, str): + return v.split(",") + return v settings = Settings() diff --git a/app/dependencies/database.py b/app/dependencies/database.py index 1525bfb..c3ec9a4 100644 --- a/app/dependencies/database.py +++ b/app/dependencies/database.py @@ -18,10 +18,10 @@ def json_serializer(value): # 数据库引擎 -engine = create_async_engine(settings.DATABASE_URL, json_serializer=json_serializer) +engine = create_async_engine(settings.database_url, json_serializer=json_serializer) # Redis 连接 -redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True) +redis_client = redis.from_url(settings.redis_url, decode_responses=True) # 数据库依赖 diff --git a/app/dependencies/fetcher.py b/app/dependencies/fetcher.py index 806eb87..51964f0 100644 --- a/app/dependencies/fetcher.py +++ b/app/dependencies/fetcher.py @@ -12,10 +12,10 @@ async def get_fetcher() -> Fetcher: global fetcher if fetcher is None: fetcher = Fetcher( - settings.FETCHER_CLIENT_ID, - settings.FETCHER_CLIENT_SECRET, - settings.FETCHER_SCOPES, - settings.FETCHER_CALLBACK_URL, + settings.fetcher_client_id, + settings.fetcher_client_secret, + settings.fetcher_scopes, + settings.fetcher_callback_url, ) redis = get_redis() access_token = await redis.get(f"fetcher:access_token:{fetcher.client_id}") diff --git a/app/log.py b/app/log.py index 600ec4d..8383494 100644 --- a/app/log.py +++ b/app/log.py @@ -120,10 +120,10 @@ logger.add( format=( "{time:YYYY-MM-DD HH:mm:ss} [{level}] | {message}" ), - level=settings.LOG_LEVEL, - diagnose=settings.DEBUG, + level=settings.log_level, + diagnose=settings.debug, ) -logging.basicConfig(handlers=[InterceptHandler()], level=settings.LOG_LEVEL, force=True) +logging.basicConfig(handlers=[InterceptHandler()], level=settings.log_level, force=True) uvicorn_loggers = [ "uvicorn", diff --git a/app/router/auth.py b/app/router/auth.py index 7a2a14d..2fa69ad 100644 --- a/app/router/auth.py +++ b/app/router/auth.py @@ -199,8 +199,8 @@ async def oauth_token( """OAuth 令牌端点""" # 验证客户端凭据 if ( - client_id != settings.OSU_CLIENT_ID - or client_secret != settings.OSU_CLIENT_SECRET + client_id != settings.osu_client_id + or client_secret != settings.osu_client_secret ): return create_oauth_error_response( error="invalid_client", @@ -242,7 +242,7 @@ async def oauth_token( ) # 生成令牌 - access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + access_token_expires = timedelta(minutes=settings.access_token_expire_minutes) access_token = create_access_token( data={"sub": str(user.id)}, expires_delta=access_token_expires ) @@ -255,13 +255,13 @@ async def oauth_token( user.id, access_token, refresh_token_str, - settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, + settings.access_token_expire_minutes * 60, ) return TokenResponse( access_token=access_token, token_type="Bearer", - expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, + expires_in=settings.access_token_expire_minutes * 60, refresh_token=refresh_token_str, scope=scope, ) @@ -295,7 +295,7 @@ async def oauth_token( ) # 生成新的访问令牌 - access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + access_token_expires = timedelta(minutes=settings.access_token_expire_minutes) access_token = create_access_token( data={"sub": str(token_record.user_id)}, expires_delta=access_token_expires ) @@ -307,13 +307,13 @@ async def oauth_token( token_record.user_id, access_token, new_refresh_token, - settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, + settings.access_token_expire_minutes * 60, ) return TokenResponse( access_token=access_token, token_type="Bearer", - expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, + expires_in=settings.access_token_expire_minutes * 60, refresh_token=new_refresh_token, scope=scope, ) diff --git a/app/signalr/hub/hub.py b/app/signalr/hub/hub.py index 4bab451..92fc90d 100644 --- a/app/signalr/hub/hub.py +++ b/app/signalr/hub/hub.py @@ -74,7 +74,7 @@ class Client: while True: try: await self.send_packet(PingPacket()) - await asyncio.sleep(settings.SIGNALR_PING_INTERVAL) + await asyncio.sleep(settings.signalr_ping_interval) except WebSocketDisconnect: break except Exception as e: @@ -131,7 +131,7 @@ class Hub[TState: UserState]: if connection_token in self.waited_clients: if ( self.waited_clients[connection_token] - < time.time() - settings.SIGNALR_NEGOTIATE_TIMEOUT + < time.time() - settings.signalr_negotiate_timeout ): raise TimeoutError(f"Connection {connection_id} has waited too long.") del self.waited_clients[connection_token] diff --git a/main.py b/main.py index fe27c14..ade04f2 100644 --- a/main.py +++ b/main.py @@ -57,9 +57,9 @@ if __name__ == "__main__": uvicorn.run( "main:app", - host=settings.HOST, - port=settings.PORT, - reload=settings.DEBUG, + host=settings.host, + port=settings.port, + reload=settings.debug, log_config=None, # 禁用uvicorn默认日志配置 access_log=True, # 启用访问日志 ) diff --git a/pyproject.toml b/pyproject.toml index efadb39..a687f17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "loguru>=0.7.3", "msgpack-lazer-api", "passlib[bcrypt]>=1.7.4", + "pydantic-settings>=2.10.1", "pydantic[email]>=2.5.0", "python-dotenv>=1.0.0", "python-jose[cryptography]>=3.3.0", diff --git a/uv.lock b/uv.lock index a343b97..ffc6105 100644 --- a/uv.lock +++ b/uv.lock @@ -514,6 +514,7 @@ dependencies = [ { name = "msgpack-lazer-api" }, { name = "passlib", extra = ["bcrypt"] }, { name = "pydantic", extra = ["email"] }, + { name = "pydantic-settings" }, { name = "python-dotenv" }, { name = "python-jose", extra = ["cryptography"] }, { name = "python-multipart" }, @@ -543,6 +544,7 @@ requires-dist = [ { name = "msgpack-lazer-api", editable = "packages/msgpack_lazer_api" }, { name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" }, { name = "pydantic", extras = ["email"], specifier = ">=2.5.0" }, + { name = "pydantic-settings", specifier = ">=2.10.1" }, { name = "python-dotenv", specifier = ">=1.0.0" }, { name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" }, { name = "python-multipart", specifier = ">=0.0.6" }, @@ -678,6 +680,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6f/9a/e73262f6c6656262b5fdd723ad90f518f579b7bc8622e43a942eec53c938/pydantic_core-2.33.2-cp313-cp313t-win_amd64.whl", hash = "sha256:c2fc0a768ef76c15ab9238afa6da7f69895bb5d1ee83aeea2e3509af4472d0b9", size = 1935777, upload-time = "2025-04-23T18:32:25.088Z" }, ] +[[package]] +name = "pydantic-settings" +version = "2.10.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "python-dotenv" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/68/85/1ea668bbab3c50071ca613c6ab30047fb36ab0da1b92fa8f17bbc38fd36c/pydantic_settings-2.10.1.tar.gz", hash = "sha256:06f0062169818d0f5524420a360d632d5857b83cffd4d42fe29597807a1614ee", size = 172583, upload-time = "2025-06-24T13:26:46.841Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/58/f0/427018098906416f580e3cf1366d3b1abfb408a0652e9f31600c24a1903c/pydantic_settings-2.10.1-py3-none-any.whl", hash = "sha256:a60952460b99cf661dc25c29c0ef171721f98bfcb52ef8d9ea4c943d7c8cc796", size = 45235, upload-time = "2025-06-24T13:26:45.485Z" }, +] + [[package]] name = "pymysql" version = "1.1.1"