refactor(config): use pydantic-settings
This commit is contained in:
@@ -125,12 +125,12 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None) -> s
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(
|
||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
minutes=settings.access_token_expire_minutes
|
||||
)
|
||||
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
||||
to_encode, settings.secret_key, algorithm=settings.algorithm
|
||||
)
|
||||
return encoded_jwt
|
||||
|
||||
@@ -146,7 +146,7 @@ def verify_token(token: str) -> dict | None:
|
||||
"""验证访问令牌"""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||
token, settings.secret_key, algorithms=[settings.algorithm]
|
||||
)
|
||||
return payload
|
||||
except JWTError:
|
||||
|
||||
@@ -1,51 +1,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Annotated, Any
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings:
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
||||
|
||||
# 数据库设置
|
||||
DATABASE_URL: str = os.getenv(
|
||||
"DATABASE_URL", "mysql+aiomysql://root:password@127.0.0.1:3306/osu_api"
|
||||
)
|
||||
REDIS_URL: str = os.getenv("REDIS_URL", "redis://127.0.0.1:6379/0")
|
||||
database_url: str = "mysql+aiomysql://root:password@127.0.0.1:3306/osu_api"
|
||||
redis_url: str = "redis://127.0.0.1:6379/0"
|
||||
|
||||
# JWT 设置
|
||||
SECRET_KEY: str = os.getenv("SECRET_KEY", "your-secret-key-here")
|
||||
ALGORITHM: str = os.getenv("ALGORITHM", "HS256")
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(
|
||||
os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "1440")
|
||||
)
|
||||
secret_key: str = Field(default="your-secret-key-here", alias="jwt_secret_key")
|
||||
algorithm: str = "HS256"
|
||||
access_token_expire_minutes: int = 1440
|
||||
|
||||
# OAuth 设置
|
||||
OSU_CLIENT_ID: str = os.getenv("OSU_CLIENT_ID", "5")
|
||||
OSU_CLIENT_SECRET: str = os.getenv(
|
||||
"OSU_CLIENT_SECRET", "FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk"
|
||||
)
|
||||
osu_client_id: str = "5"
|
||||
osu_client_secret: str = "FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk"
|
||||
|
||||
# 服务器设置
|
||||
HOST: str = os.getenv("HOST", "0.0.0.0")
|
||||
PORT: int = int(os.getenv("PORT", "8000"))
|
||||
DEBUG: bool = os.getenv("DEBUG", "True").lower() == "true"
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
debug: bool = False
|
||||
|
||||
# SignalR 设置
|
||||
SIGNALR_NEGOTIATE_TIMEOUT: int = int(os.getenv("SIGNALR_NEGOTIATE_TIMEOUT", "30"))
|
||||
SIGNALR_PING_INTERVAL: int = int(os.getenv("SIGNALR_PING_INTERVAL", "15"))
|
||||
signalr_negotiate_timeout: int = 30
|
||||
signalr_ping_interval: int = 15
|
||||
|
||||
# Fetcher 设置
|
||||
FETCHER_CLIENT_ID: str = os.getenv("FETCHER_CLIENT_ID", "")
|
||||
FETCHER_CLIENT_SECRET: str = os.getenv("FETCHER_CLIENT_SECRET", "")
|
||||
FETCHER_SCOPES: list[str] = os.getenv("FETCHER_SCOPES", "public").split(",")
|
||||
FETCHER_CALLBACK_URL: str = os.getenv(
|
||||
"FETCHER_CALLBACK_URL", "http://localhost:8000/fetcher/callback"
|
||||
)
|
||||
fetcher_client_id: str = ""
|
||||
fetcher_client_secret: str = ""
|
||||
fetcher_scopes: Annotated[list[str], NoDecode] = ["public"]
|
||||
fetcher_callback_url: str = "http://localhost:8000/fetcher/callback"
|
||||
|
||||
# 日志设置
|
||||
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO").upper()
|
||||
log_level: str = "INFO"
|
||||
|
||||
@field_validator("fetcher_scopes", mode="before")
|
||||
def validate_fetcher_scopes(cls, v: Any) -> list[str]:
|
||||
if isinstance(v, str):
|
||||
return v.split(",")
|
||||
return v
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@@ -18,10 +18,10 @@ def json_serializer(value):
|
||||
|
||||
|
||||
# 数据库引擎
|
||||
engine = create_async_engine(settings.DATABASE_URL, json_serializer=json_serializer)
|
||||
engine = create_async_engine(settings.database_url, json_serializer=json_serializer)
|
||||
|
||||
# Redis 连接
|
||||
redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
|
||||
redis_client = redis.from_url(settings.redis_url, decode_responses=True)
|
||||
|
||||
|
||||
# 数据库依赖
|
||||
|
||||
@@ -12,10 +12,10 @@ async def get_fetcher() -> Fetcher:
|
||||
global fetcher
|
||||
if fetcher is None:
|
||||
fetcher = Fetcher(
|
||||
settings.FETCHER_CLIENT_ID,
|
||||
settings.FETCHER_CLIENT_SECRET,
|
||||
settings.FETCHER_SCOPES,
|
||||
settings.FETCHER_CALLBACK_URL,
|
||||
settings.fetcher_client_id,
|
||||
settings.fetcher_client_secret,
|
||||
settings.fetcher_scopes,
|
||||
settings.fetcher_callback_url,
|
||||
)
|
||||
redis = get_redis()
|
||||
access_token = await redis.get(f"fetcher:access_token:{fetcher.client_id}")
|
||||
|
||||
@@ -120,10 +120,10 @@ logger.add(
|
||||
format=(
|
||||
"<green>{time:YYYY-MM-DD HH:mm:ss}</green> [<level>{level}</level>] | {message}"
|
||||
),
|
||||
level=settings.LOG_LEVEL,
|
||||
diagnose=settings.DEBUG,
|
||||
level=settings.log_level,
|
||||
diagnose=settings.debug,
|
||||
)
|
||||
logging.basicConfig(handlers=[InterceptHandler()], level=settings.LOG_LEVEL, force=True)
|
||||
logging.basicConfig(handlers=[InterceptHandler()], level=settings.log_level, force=True)
|
||||
|
||||
uvicorn_loggers = [
|
||||
"uvicorn",
|
||||
|
||||
@@ -199,8 +199,8 @@ async def oauth_token(
|
||||
"""OAuth 令牌端点"""
|
||||
# 验证客户端凭据
|
||||
if (
|
||||
client_id != settings.OSU_CLIENT_ID
|
||||
or client_secret != settings.OSU_CLIENT_SECRET
|
||||
client_id != settings.osu_client_id
|
||||
or client_secret != settings.osu_client_secret
|
||||
):
|
||||
return create_oauth_error_response(
|
||||
error="invalid_client",
|
||||
@@ -242,7 +242,7 @@ async def oauth_token(
|
||||
)
|
||||
|
||||
# 生成令牌
|
||||
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
access_token = create_access_token(
|
||||
data={"sub": str(user.id)}, expires_delta=access_token_expires
|
||||
)
|
||||
@@ -255,13 +255,13 @@ async def oauth_token(
|
||||
user.id,
|
||||
access_token,
|
||||
refresh_token_str,
|
||||
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
settings.access_token_expire_minutes * 60,
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
token_type="Bearer",
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
expires_in=settings.access_token_expire_minutes * 60,
|
||||
refresh_token=refresh_token_str,
|
||||
scope=scope,
|
||||
)
|
||||
@@ -295,7 +295,7 @@ async def oauth_token(
|
||||
)
|
||||
|
||||
# 生成新的访问令牌
|
||||
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
access_token = create_access_token(
|
||||
data={"sub": str(token_record.user_id)}, expires_delta=access_token_expires
|
||||
)
|
||||
@@ -307,13 +307,13 @@ async def oauth_token(
|
||||
token_record.user_id,
|
||||
access_token,
|
||||
new_refresh_token,
|
||||
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
settings.access_token_expire_minutes * 60,
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
token_type="Bearer",
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
expires_in=settings.access_token_expire_minutes * 60,
|
||||
refresh_token=new_refresh_token,
|
||||
scope=scope,
|
||||
)
|
||||
|
||||
@@ -74,7 +74,7 @@ class Client:
|
||||
while True:
|
||||
try:
|
||||
await self.send_packet(PingPacket())
|
||||
await asyncio.sleep(settings.SIGNALR_PING_INTERVAL)
|
||||
await asyncio.sleep(settings.signalr_ping_interval)
|
||||
except WebSocketDisconnect:
|
||||
break
|
||||
except Exception as e:
|
||||
@@ -131,7 +131,7 @@ class Hub[TState: UserState]:
|
||||
if connection_token in self.waited_clients:
|
||||
if (
|
||||
self.waited_clients[connection_token]
|
||||
< time.time() - settings.SIGNALR_NEGOTIATE_TIMEOUT
|
||||
< time.time() - settings.signalr_negotiate_timeout
|
||||
):
|
||||
raise TimeoutError(f"Connection {connection_id} has waited too long.")
|
||||
del self.waited_clients[connection_token]
|
||||
|
||||
Reference in New Issue
Block a user