refactor(config): use pydantic-settings

This commit is contained in:
MingxuanGame
2025-08-10 05:38:28 +00:00
parent 703a7901b3
commit f165ae5dc3
12 changed files with 105 additions and 59 deletions

View File

@@ -1,4 +0,0 @@
# osu! API 客户端配置
OSU_CLIENT_ID=5
OSU_CLIENT_SECRET=FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk
OSU_API_URL=http://localhost:8000

34
.env.example Normal file
View File

@@ -0,0 +1,34 @@
# 数据库 URL
DATABASE_URL="mysql+aiomysql://root:password@127.0.0.1:3306/osu_api"
# Redis URL
REDIS_URL="redis://127.0.0.1:6379/0"
# JWT 密钥,使用 openssl rand -hex 32 生成
JWT_SECRET_KEY="your_jwt_secret_here"
# JWT 算法
ALGORITHM="HS256"
# JWT 过期时间
ACCESS_TOKEN_EXPIRE_MINUTES=1440
# 服务器地址
HOST="0.0.0.0"
PORT=8000
# 调试模式,生产环境请设置为 false
DEBUG=false
# osu!lazer 登录设置
OSU_CLIENT_ID="5"
OSU_CLIENT_SECRET="FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk"
# SignalR 服务器设置
SIGNALR_NEGOTIATE_TIMEOUT=30
SIGNALR_PING_INTERVAL=15
# Fetcher 设置
FETCHER_CLIENT_ID=""
FETCHER_CLIENT_SECRET=""
FETCHER_SCOPES=["public"]
FETCHER_CALLBACK_URL="http://localhost:8000/fetcher/callback"
# 日志设置
LOG_LEVEL="INFO"

View File

@@ -125,12 +125,12 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None) -> s
expire = datetime.utcnow() + expires_delta expire = datetime.utcnow() + expires_delta
else: else:
expire = datetime.utcnow() + timedelta( expire = datetime.utcnow() + timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES minutes=settings.access_token_expire_minutes
) )
to_encode.update({"exp": expire}) to_encode.update({"exp": expire})
encoded_jwt = jwt.encode( encoded_jwt = jwt.encode(
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM to_encode, settings.secret_key, algorithm=settings.algorithm
) )
return encoded_jwt return encoded_jwt
@@ -146,7 +146,7 @@ def verify_token(token: str) -> dict | None:
"""验证访问令牌""" """验证访问令牌"""
try: try:
payload = jwt.decode( payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] token, settings.secret_key, algorithms=[settings.algorithm]
) )
return payload return payload
except JWTError: except JWTError:

View File

@@ -1,51 +1,50 @@
from __future__ import annotations from __future__ import annotations
import os from typing import Annotated, Any
from dotenv import load_dotenv from pydantic import Field, field_validator
from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict
load_dotenv()
class Settings: class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
# 数据库设置 # 数据库设置
DATABASE_URL: str = os.getenv( database_url: str = "mysql+aiomysql://root:password@127.0.0.1:3306/osu_api"
"DATABASE_URL", "mysql+aiomysql://root:password@127.0.0.1:3306/osu_api" redis_url: str = "redis://127.0.0.1:6379/0"
)
REDIS_URL: str = os.getenv("REDIS_URL", "redis://127.0.0.1:6379/0")
# JWT 设置 # JWT 设置
SECRET_KEY: str = os.getenv("SECRET_KEY", "your-secret-key-here") secret_key: str = Field(default="your-secret-key-here", alias="jwt_secret_key")
ALGORITHM: str = os.getenv("ALGORITHM", "HS256") algorithm: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = int( access_token_expire_minutes: int = 1440
os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "1440")
)
# OAuth 设置 # OAuth 设置
OSU_CLIENT_ID: str = os.getenv("OSU_CLIENT_ID", "5") osu_client_id: str = "5"
OSU_CLIENT_SECRET: str = os.getenv( osu_client_secret: str = "FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk"
"OSU_CLIENT_SECRET", "FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk"
)
# 服务器设置 # 服务器设置
HOST: str = os.getenv("HOST", "0.0.0.0") host: str = "0.0.0.0"
PORT: int = int(os.getenv("PORT", "8000")) port: int = 8000
DEBUG: bool = os.getenv("DEBUG", "True").lower() == "true" debug: bool = False
# SignalR 设置 # SignalR 设置
SIGNALR_NEGOTIATE_TIMEOUT: int = int(os.getenv("SIGNALR_NEGOTIATE_TIMEOUT", "30")) signalr_negotiate_timeout: int = 30
SIGNALR_PING_INTERVAL: int = int(os.getenv("SIGNALR_PING_INTERVAL", "15")) signalr_ping_interval: int = 15
# Fetcher 设置 # Fetcher 设置
FETCHER_CLIENT_ID: str = os.getenv("FETCHER_CLIENT_ID", "") fetcher_client_id: str = ""
FETCHER_CLIENT_SECRET: str = os.getenv("FETCHER_CLIENT_SECRET", "") fetcher_client_secret: str = ""
FETCHER_SCOPES: list[str] = os.getenv("FETCHER_SCOPES", "public").split(",") fetcher_scopes: Annotated[list[str], NoDecode] = ["public"]
FETCHER_CALLBACK_URL: str = os.getenv( fetcher_callback_url: str = "http://localhost:8000/fetcher/callback"
"FETCHER_CALLBACK_URL", "http://localhost:8000/fetcher/callback"
)
# 日志设置 # 日志设置
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO").upper() log_level: str = "INFO"
@field_validator("fetcher_scopes", mode="before")
def validate_fetcher_scopes(cls, v: Any) -> list[str]:
if isinstance(v, str):
return v.split(",")
return v
settings = Settings() settings = Settings()

View File

@@ -18,10 +18,10 @@ def json_serializer(value):
# 数据库引擎 # 数据库引擎
engine = create_async_engine(settings.DATABASE_URL, json_serializer=json_serializer) engine = create_async_engine(settings.database_url, json_serializer=json_serializer)
# Redis 连接 # Redis 连接
redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True) redis_client = redis.from_url(settings.redis_url, decode_responses=True)
# 数据库依赖 # 数据库依赖

View File

@@ -12,10 +12,10 @@ async def get_fetcher() -> Fetcher:
global fetcher global fetcher
if fetcher is None: if fetcher is None:
fetcher = Fetcher( fetcher = Fetcher(
settings.FETCHER_CLIENT_ID, settings.fetcher_client_id,
settings.FETCHER_CLIENT_SECRET, settings.fetcher_client_secret,
settings.FETCHER_SCOPES, settings.fetcher_scopes,
settings.FETCHER_CALLBACK_URL, settings.fetcher_callback_url,
) )
redis = get_redis() redis = get_redis()
access_token = await redis.get(f"fetcher:access_token:{fetcher.client_id}") access_token = await redis.get(f"fetcher:access_token:{fetcher.client_id}")

View File

@@ -120,10 +120,10 @@ logger.add(
format=( format=(
"<green>{time:YYYY-MM-DD HH:mm:ss}</green> [<level>{level}</level>] | {message}" "<green>{time:YYYY-MM-DD HH:mm:ss}</green> [<level>{level}</level>] | {message}"
), ),
level=settings.LOG_LEVEL, level=settings.log_level,
diagnose=settings.DEBUG, diagnose=settings.debug,
) )
logging.basicConfig(handlers=[InterceptHandler()], level=settings.LOG_LEVEL, force=True) logging.basicConfig(handlers=[InterceptHandler()], level=settings.log_level, force=True)
uvicorn_loggers = [ uvicorn_loggers = [
"uvicorn", "uvicorn",

View File

@@ -199,8 +199,8 @@ async def oauth_token(
"""OAuth 令牌端点""" """OAuth 令牌端点"""
# 验证客户端凭据 # 验证客户端凭据
if ( if (
client_id != settings.OSU_CLIENT_ID client_id != settings.osu_client_id
or client_secret != settings.OSU_CLIENT_SECRET or client_secret != settings.osu_client_secret
): ):
return create_oauth_error_response( return create_oauth_error_response(
error="invalid_client", error="invalid_client",
@@ -242,7 +242,7 @@ async def oauth_token(
) )
# 生成令牌 # 生成令牌
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
access_token = create_access_token( access_token = create_access_token(
data={"sub": str(user.id)}, expires_delta=access_token_expires data={"sub": str(user.id)}, expires_delta=access_token_expires
) )
@@ -255,13 +255,13 @@ async def oauth_token(
user.id, user.id,
access_token, access_token,
refresh_token_str, refresh_token_str,
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, settings.access_token_expire_minutes * 60,
) )
return TokenResponse( return TokenResponse(
access_token=access_token, access_token=access_token,
token_type="Bearer", token_type="Bearer",
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, expires_in=settings.access_token_expire_minutes * 60,
refresh_token=refresh_token_str, refresh_token=refresh_token_str,
scope=scope, scope=scope,
) )
@@ -295,7 +295,7 @@ async def oauth_token(
) )
# 生成新的访问令牌 # 生成新的访问令牌
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
access_token = create_access_token( access_token = create_access_token(
data={"sub": str(token_record.user_id)}, expires_delta=access_token_expires data={"sub": str(token_record.user_id)}, expires_delta=access_token_expires
) )
@@ -307,13 +307,13 @@ async def oauth_token(
token_record.user_id, token_record.user_id,
access_token, access_token,
new_refresh_token, new_refresh_token,
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, settings.access_token_expire_minutes * 60,
) )
return TokenResponse( return TokenResponse(
access_token=access_token, access_token=access_token,
token_type="Bearer", token_type="Bearer",
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, expires_in=settings.access_token_expire_minutes * 60,
refresh_token=new_refresh_token, refresh_token=new_refresh_token,
scope=scope, scope=scope,
) )

View File

@@ -74,7 +74,7 @@ class Client:
while True: while True:
try: try:
await self.send_packet(PingPacket()) await self.send_packet(PingPacket())
await asyncio.sleep(settings.SIGNALR_PING_INTERVAL) await asyncio.sleep(settings.signalr_ping_interval)
except WebSocketDisconnect: except WebSocketDisconnect:
break break
except Exception as e: except Exception as e:
@@ -131,7 +131,7 @@ class Hub[TState: UserState]:
if connection_token in self.waited_clients: if connection_token in self.waited_clients:
if ( if (
self.waited_clients[connection_token] self.waited_clients[connection_token]
< time.time() - settings.SIGNALR_NEGOTIATE_TIMEOUT < time.time() - settings.signalr_negotiate_timeout
): ):
raise TimeoutError(f"Connection {connection_id} has waited too long.") raise TimeoutError(f"Connection {connection_id} has waited too long.")
del self.waited_clients[connection_token] del self.waited_clients[connection_token]

View File

@@ -57,9 +57,9 @@ if __name__ == "__main__":
uvicorn.run( uvicorn.run(
"main:app", "main:app",
host=settings.HOST, host=settings.host,
port=settings.PORT, port=settings.port,
reload=settings.DEBUG, reload=settings.debug,
log_config=None, # 禁用uvicorn默认日志配置 log_config=None, # 禁用uvicorn默认日志配置
access_log=True, # 启用访问日志 access_log=True, # 启用访问日志
) )

View File

@@ -15,6 +15,7 @@ dependencies = [
"loguru>=0.7.3", "loguru>=0.7.3",
"msgpack-lazer-api", "msgpack-lazer-api",
"passlib[bcrypt]>=1.7.4", "passlib[bcrypt]>=1.7.4",
"pydantic-settings>=2.10.1",
"pydantic[email]>=2.5.0", "pydantic[email]>=2.5.0",
"python-dotenv>=1.0.0", "python-dotenv>=1.0.0",
"python-jose[cryptography]>=3.3.0", "python-jose[cryptography]>=3.3.0",

16
uv.lock generated
View File

@@ -514,6 +514,7 @@ dependencies = [
{ name = "msgpack-lazer-api" }, { name = "msgpack-lazer-api" },
{ name = "passlib", extra = ["bcrypt"] }, { name = "passlib", extra = ["bcrypt"] },
{ name = "pydantic", extra = ["email"] }, { name = "pydantic", extra = ["email"] },
{ name = "pydantic-settings" },
{ name = "python-dotenv" }, { name = "python-dotenv" },
{ name = "python-jose", extra = ["cryptography"] }, { name = "python-jose", extra = ["cryptography"] },
{ name = "python-multipart" }, { name = "python-multipart" },
@@ -543,6 +544,7 @@ requires-dist = [
{ name = "msgpack-lazer-api", editable = "packages/msgpack_lazer_api" }, { name = "msgpack-lazer-api", editable = "packages/msgpack_lazer_api" },
{ name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" }, { name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" },
{ name = "pydantic", extras = ["email"], specifier = ">=2.5.0" }, { name = "pydantic", extras = ["email"], specifier = ">=2.5.0" },
{ name = "pydantic-settings", specifier = ">=2.10.1" },
{ name = "python-dotenv", specifier = ">=1.0.0" }, { name = "python-dotenv", specifier = ">=1.0.0" },
{ name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" }, { name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" },
{ name = "python-multipart", specifier = ">=0.0.6" }, { name = "python-multipart", specifier = ">=0.0.6" },
@@ -678,6 +680,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/6f/9a/e73262f6c6656262b5fdd723ad90f518f579b7bc8622e43a942eec53c938/pydantic_core-2.33.2-cp313-cp313t-win_amd64.whl", hash = "sha256:c2fc0a768ef76c15ab9238afa6da7f69895bb5d1ee83aeea2e3509af4472d0b9", size = 1935777, upload-time = "2025-04-23T18:32:25.088Z" }, { url = "https://files.pythonhosted.org/packages/6f/9a/e73262f6c6656262b5fdd723ad90f518f579b7bc8622e43a942eec53c938/pydantic_core-2.33.2-cp313-cp313t-win_amd64.whl", hash = "sha256:c2fc0a768ef76c15ab9238afa6da7f69895bb5d1ee83aeea2e3509af4472d0b9", size = 1935777, upload-time = "2025-04-23T18:32:25.088Z" },
] ]
[[package]]
name = "pydantic-settings"
version = "2.10.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pydantic" },
{ name = "python-dotenv" },
{ name = "typing-inspection" },
]
sdist = { url = "https://files.pythonhosted.org/packages/68/85/1ea668bbab3c50071ca613c6ab30047fb36ab0da1b92fa8f17bbc38fd36c/pydantic_settings-2.10.1.tar.gz", hash = "sha256:06f0062169818d0f5524420a360d632d5857b83cffd4d42fe29597807a1614ee", size = 172583, upload-time = "2025-06-24T13:26:46.841Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/58/f0/427018098906416f580e3cf1366d3b1abfb408a0652e9f31600c24a1903c/pydantic_settings-2.10.1-py3-none-any.whl", hash = "sha256:a60952460b99cf661dc25c29c0ef171721f98bfcb52ef8d9ea4c943d7c8cc796", size = 45235, upload-time = "2025-06-24T13:26:45.485Z" },
]
[[package]] [[package]]
name = "pymysql" name = "pymysql"
version = "1.1.1" version = "1.1.1"