feat(storage): support cloud storage
This commit is contained in:
@@ -1,11 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic import Field, ValidationInfo, field_validator
|
||||
from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict
|
||||
|
||||
|
||||
class AWSS3StorageSettings(BaseSettings):
    """Credentials and bucket configuration for the AWS S3 storage backend."""

    # IAM access key pair used to sign S3 API requests.
    s3_access_key_id: str
    s3_secret_access_key: str
    # Target bucket and the AWS region it lives in.
    s3_bucket_name: str
    s3_region_name: str
    # Optional public/CDN base URL; when unset, presigned URLs are
    # generated instead (see AWSS3StorageService.get_file_url).
    s3_public_url_base: str | None = None
|
||||
|
||||
|
||||
class CloudflareR2Settings(BaseSettings):
    """Credentials and bucket configuration for the Cloudflare R2 storage backend."""

    # Cloudflare account id; used to build the endpoint host
    # "<account_id>.r2.cloudflarestorage.com".
    r2_account_id: str
    # S3-compatible R2 API token key pair.
    r2_access_key_id: str
    r2_secret_access_key: str
    r2_bucket_name: str
    # Optional public base URL; when unset, presigned URLs are generated instead.
    r2_public_url_base: str | None = None
|
||||
|
||||
|
||||
class LocalStorageSettings(BaseSettings):
    """Configuration for storing files on the local filesystem."""

    # Root directory for stored files; created on service init if missing.
    local_storage_path: str = "./storage"
|
||||
|
||||
|
||||
class StorageServiceType(str, Enum):
    """Selects which storage backend the application uses.

    The string values are what the ``storage_service`` setting accepts
    from the environment / .env file.
    """

    LOCAL = "local"
    CLOUDFLARE_R2 = "r2"
    AWS_S3 = "s3"
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
||||
|
||||
@@ -60,11 +87,43 @@ class Settings(BaseSettings):
|
||||
    enable_all_beatmap_leaderboard: bool = False
    seasonal_backgrounds: list[str] = []

    # Storage settings
    # Which backend to use; must match the concrete type of
    # ``storage_settings`` (enforced by ``validate_storage_settings``).
    storage_service: StorageServiceType = StorageServiceType.LOCAL
    storage_settings: (
        LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings
    ) = LocalStorageSettings()
|
||||
|
||||
@field_validator("fetcher_scopes", mode="before")
|
||||
def validate_fetcher_scopes(cls, v: Any) -> list[str]:
|
||||
if isinstance(v, str):
|
||||
return v.split(",")
|
||||
return v
|
||||
|
||||
@field_validator("storage_settings", mode="after")
|
||||
def validate_storage_settings(
|
||||
cls,
|
||||
v: LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings,
|
||||
info: ValidationInfo,
|
||||
) -> LocalStorageSettings | CloudflareR2Settings | AWSS3StorageSettings:
|
||||
if info.data.get("storage_service") == StorageServiceType.CLOUDFLARE_R2:
|
||||
if not isinstance(v, CloudflareR2Settings):
|
||||
raise ValueError(
|
||||
"When storage_service is 'r2', "
|
||||
"storage_settings must be CloudflareR2Settings"
|
||||
)
|
||||
elif info.data.get("storage_service") == StorageServiceType.LOCAL:
|
||||
if not isinstance(v, LocalStorageSettings):
|
||||
raise ValueError(
|
||||
"When storage_service is 'local', "
|
||||
"storage_settings must be LocalStorageSettings"
|
||||
)
|
||||
elif info.data.get("storage_service") == StorageServiceType.AWS_S3:
|
||||
if not isinstance(v, AWSS3StorageSettings):
|
||||
raise ValueError(
|
||||
"When storage_service is 's3', "
|
||||
"storage_settings must be AWSS3StorageSettings"
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
# Singleton settings instance, populated from the environment / .env on import.
settings = Settings()
|
||||
|
||||
52
app/dependencies/storage.py
Normal file
52
app/dependencies/storage.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
from app.config import (
|
||||
AWSS3StorageSettings,
|
||||
CloudflareR2Settings,
|
||||
LocalStorageSettings,
|
||||
StorageServiceType,
|
||||
settings,
|
||||
)
|
||||
from app.storage import StorageService
|
||||
from app.storage.cloudflare_r2 import AWSS3StorageService, CloudflareR2StorageService
|
||||
from app.storage.local import LocalStorageService
|
||||
|
||||
# Module-level singleton; populated lazily by init_storage_service().
storage: StorageService | None = None
|
||||
|
||||
|
||||
def init_storage_service():
    """Build the storage backend selected by ``settings.storage_service``.

    The created instance is stored in the module-level ``storage``
    singleton and also returned.

    Raises:
        ValueError: for an unrecognized ``storage_service`` value.
    """
    global storage
    service = settings.storage_service
    if service == StorageServiceType.LOCAL:
        local_cfg = cast(LocalStorageSettings, settings.storage_settings)
        storage = LocalStorageService(storage_path=local_cfg.local_storage_path)
    elif service == StorageServiceType.CLOUDFLARE_R2:
        r2_cfg = cast(CloudflareR2Settings, settings.storage_settings)
        storage = CloudflareR2StorageService(
            account_id=r2_cfg.r2_account_id,
            access_key_id=r2_cfg.r2_access_key_id,
            secret_access_key=r2_cfg.r2_secret_access_key,
            bucket_name=r2_cfg.r2_bucket_name,
            public_url_base=r2_cfg.r2_public_url_base,
        )
    elif service == StorageServiceType.AWS_S3:
        s3_cfg = cast(AWSS3StorageSettings, settings.storage_settings)
        storage = AWSS3StorageService(
            access_key_id=s3_cfg.s3_access_key_id,
            secret_access_key=s3_cfg.s3_secret_access_key,
            bucket_name=s3_cfg.s3_bucket_name,
            public_url_base=s3_cfg.s3_public_url_base,
            region_name=s3_cfg.s3_region_name,
        )
    else:
        raise ValueError(f"Unsupported storage service: {settings.storage_service}")
    return storage
|
||||
|
||||
|
||||
def get_storage_service():
    """Return the shared storage backend, initializing it on first use."""
    return storage if storage is not None else init_storage_service()
|
||||
13
app/storage/__init__.py
Normal file
13
app/storage/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .aws_s3 import AWSS3StorageService
|
||||
from .base import StorageService
|
||||
from .cloudflare_r2 import CloudflareR2StorageService
|
||||
from .local import LocalStorageService
|
||||
|
||||
# Public storage API re-exported at package level.
__all__ = [
    "AWSS3StorageService",
    "CloudflareR2StorageService",
    "LocalStorageService",
    "StorageService",
]
|
||||
103
app/storage/aws_s3.py
Normal file
103
app/storage/aws_s3.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .base import StorageService
|
||||
|
||||
import aioboto3
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
|
||||
class AWSS3StorageService(StorageService):
    """StorageService backed by an S3-compatible object store via aioboto3.

    Subclasses may override :attr:`endpoint_url` to target other
    S3-compatible providers (e.g. Cloudflare R2).
    """

    def __init__(
        self,
        access_key_id: str,
        secret_access_key: str,
        bucket_name: str,
        region_name: str,
        public_url_base: str | None = None,
    ):
        """
        Args:
            access_key_id: Access key used to sign requests.
            secret_access_key: Secret half of the key pair.
            bucket_name: Bucket all operations are scoped to.
            region_name: Region passed to the S3 client.
            public_url_base: Optional public base URL; when set,
                get_file_url() builds URLs from it instead of presigning.
        """
        self.bucket_name = bucket_name
        self.public_url_base = public_url_base
        self.session = aioboto3.Session()
        self.access_key_id = access_key_id
        self.secret_access_key = secret_access_key
        self.region_name = region_name

    @property
    def endpoint_url(self) -> str | None:
        """Custom endpoint for S3-compatible providers; None means real AWS S3."""
        return None

    def _get_client(self):
        # Fresh client async context manager per operation, built from the
        # shared session with this service's credentials and endpoint.
        return self.session.client(
            "s3",
            endpoint_url=self.endpoint_url,
            aws_access_key_id=self.access_key_id,
            aws_secret_access_key=self.secret_access_key,
            region_name=self.region_name,
        )

    async def write_file(
        self,
        file_path: str,
        content: bytes,
        content_type: str = "application/octet-stream",
        cache_control: str = "public, max-age=31536000",
    ) -> None:
        """Upload ``content`` to ``file_path``, overwriting any existing object."""
        async with self._get_client() as client:
            await client.put_object(
                Bucket=self.bucket_name,
                Key=file_path,
                Body=content,
                ContentType=content_type,
                CacheControl=cache_control,
            )

    async def read_file(self, file_path: str) -> bytes:
        """Return the object's content.

        Raises:
            FileNotFoundError: if the key does not exist.
            RuntimeError: for any other S3 error.
        """
        async with self._get_client() as client:
            try:
                response = await client.get_object(
                    Bucket=self.bucket_name,
                    Key=file_path,
                )
                async with response["Body"] as stream:
                    return await stream.read()
            except ClientError as e:
                # get_object reports a missing key as "NoSuchKey"; some
                # S3-compatible stores surface a bare "404" instead.
                if e.response.get("Error", {}).get("Code") in ("NoSuchKey", "404"):
                    raise FileNotFoundError(f"File not found: {file_path}") from e
                # Fixed copy-paste: this is the S3 service, messages said "R2".
                raise RuntimeError(f"Failed to read file from S3: {e}") from e

    async def delete_file(self, file_path: str) -> None:
        """Delete the object; S3 treats a missing key as success."""
        async with self._get_client() as client:
            try:
                await client.delete_object(
                    Bucket=self.bucket_name,
                    Key=file_path,
                )
            except ClientError as e:
                raise RuntimeError(f"Failed to delete file from S3: {e}") from e

    async def is_exists(self, file_path: str) -> bool:
        """Return True if an object exists at ``file_path``.

        Raises:
            RuntimeError: for S3 errors other than a missing key.
        """
        async with self._get_client() as client:
            try:
                await client.head_object(
                    Bucket=self.bucket_name,
                    Key=file_path,
                )
                return True
            except ClientError as e:
                # head_object returns a bodyless 404 for a missing key.
                if e.response.get("Error", {}).get("Code") == "404":
                    return False
                raise RuntimeError(
                    f"Failed to check file existence in S3: {e}"
                ) from e

    async def get_file_url(self, file_path: str) -> str:
        """Return a URL for the file.

        Uses ``public_url_base`` when configured; otherwise generates a
        presigned GET URL.

        Raises:
            RuntimeError: if presigned-URL generation fails.
        """
        if self.public_url_base:
            return f"{self.public_url_base.rstrip('/')}/{file_path.lstrip('/')}"

        async with self._get_client() as client:
            try:
                url = await client.generate_presigned_url(
                    "get_object",
                    Params={"Bucket": self.bucket_name, "Key": file_path},
                )
                return url
            except ClientError as e:
                raise RuntimeError(f"Failed to generate file URL: {e}") from e
|
||||
34
app/storage/base.py
Normal file
34
app/storage/base.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
|
||||
|
||||
class StorageService(abc.ABC):
    """Abstract interface for file storage backends (local disk, S3, R2)."""

    @abc.abstractmethod
    async def write_file(
        self,
        file_path: str,
        content: bytes,
        content_type: str = "application/octet-stream",
        cache_control: str = "public, max-age=31536000",
    ) -> None:
        """Store ``content`` at ``file_path``, overwriting any existing file.

        ``content_type`` and ``cache_control`` are hints for backends that
        serve files over HTTP; other backends may ignore them.
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def read_file(self, file_path: str) -> bytes:
        """Return the file's content; raise FileNotFoundError if absent."""
        raise NotImplementedError

    @abc.abstractmethod
    async def delete_file(self, file_path: str) -> None:
        """Delete the file; implementations may treat a missing file as a no-op."""
        raise NotImplementedError

    @abc.abstractmethod
    async def is_exists(self, file_path: str) -> bool:
        """Return True if a file exists at ``file_path``."""
        raise NotImplementedError

    @abc.abstractmethod
    async def get_file_url(self, file_path: str) -> str:
        """Return a URL (or local path) from which the file can be fetched."""
        raise NotImplementedError

    async def close(self) -> None:
        """Release backend resources; the default implementation does nothing."""
        pass
|
||||
26
app/storage/cloudflare_r2.py
Normal file
26
app/storage/cloudflare_r2.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .aws_s3 import AWSS3StorageService
|
||||
|
||||
|
||||
class CloudflareR2StorageService(AWSS3StorageService):
    """AWSS3StorageService variant targeting Cloudflare R2's S3-compatible API."""

    def __init__(
        self,
        account_id: str,
        access_key_id: str,
        secret_access_key: str,
        bucket_name: str,
        public_url_base: str | None = None,
    ):
        """
        Args:
            account_id: Cloudflare account id; determines the endpoint host.
            access_key_id: R2 API token key id.
            secret_access_key: R2 API token secret.
            bucket_name: Bucket all operations are scoped to.
            public_url_base: Optional public base URL for get_file_url().
        """
        self.account_id = account_id
        # The region is pinned to "auto" for R2's S3-compatible endpoint.
        super().__init__(
            access_key_id=access_key_id,
            secret_access_key=secret_access_key,
            bucket_name=bucket_name,
            public_url_base=public_url_base,
            region_name="auto",
        )

    @property
    def endpoint_url(self) -> str:
        """Account-scoped R2 endpoint used in place of the default AWS one."""
        host = f"{self.account_id}.r2.cloudflarestorage.com"
        return f"https://{host}"
|
||||
78
app/storage/local.py
Normal file
78
app/storage/local.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from .base import StorageService
|
||||
|
||||
import aiofiles
|
||||
|
||||
|
||||
class LocalStorageService(StorageService):
    """StorageService that keeps files under a directory on the local filesystem."""

    def __init__(
        self,
        storage_path: str,
    ):
        """
        Args:
            storage_path: Root directory for all stored files; created
                (with parents) if it does not exist.
        """
        self.storage_path = Path(storage_path).resolve()
        self.storage_path.mkdir(parents=True, exist_ok=True)

    def _get_file_path(self, file_path: str) -> Path:
        """Map a logical file path to an absolute path inside the storage root.

        Raises:
            ValueError: if the path would escape the storage root
                (e.g. via ".." components).
        """
        clean_path = file_path.lstrip("/")
        full_path = self.storage_path / clean_path

        try:
            # resolve() collapses ".." and symlinks before the containment check.
            full_path.resolve().relative_to(self.storage_path)
        except ValueError:
            raise ValueError(f"Invalid file path: {file_path}") from None

        return full_path

    async def write_file(
        self,
        file_path: str,
        content: bytes,
        content_type: str = "application/octet-stream",
        cache_control: str = "public, max-age=31536000",
    ) -> None:
        """Write ``content`` to ``file_path``, creating parent directories.

        ``content_type`` and ``cache_control`` are accepted for interface
        compatibility and ignored by the local backend.

        Raises:
            ValueError: if the path escapes the storage root.
            RuntimeError: on filesystem write errors.
        """
        full_path = self._get_file_path(file_path)
        full_path.parent.mkdir(parents=True, exist_ok=True)

        try:
            async with aiofiles.open(full_path, "wb") as f:
                await f.write(content)
        except OSError as e:
            raise RuntimeError(f"Failed to write file: {e}") from e

    async def read_file(self, file_path: str) -> bytes:
        """Return the file's content.

        Raises:
            FileNotFoundError: if the file does not exist.
            RuntimeError: on filesystem read errors.
        """
        full_path = self._get_file_path(file_path)

        if not full_path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")

        try:
            async with aiofiles.open(full_path, "rb") as f:
                return await f.read()
        except OSError as e:
            raise RuntimeError(f"Failed to read file: {e}") from e

    async def delete_file(self, file_path: str) -> None:
        """Delete the file if present, then prune now-empty parent directories."""
        full_path = self._get_file_path(file_path)

        if not full_path.exists():
            return

        try:
            full_path.unlink()

            # Remove empty parents up to (but not including) the storage root.
            parent = full_path.parent
            while parent != self.storage_path and not any(parent.iterdir()):
                parent.rmdir()
                parent = parent.parent
        except OSError as e:
            raise RuntimeError(f"Failed to delete file: {e}") from e

    async def is_exists(self, file_path: str) -> bool:
        """Return True only if ``file_path`` exists and is a regular file."""
        full_path = self._get_file_path(file_path)
        return full_path.exists() and full_path.is_file()

    async def get_file_url(self, file_path: str) -> str:
        """Return the absolute filesystem path of the stored file.

        Bug fix: the previous direct join (``self.storage_path / file_path``)
        bypassed the traversal check, and a leading "/" in ``file_path``
        replaced the storage root entirely. Routing through _get_file_path
        keeps the result inside the storage root.
        """
        return str(self._get_file_path(file_path))
|
||||
Reference in New Issue
Block a user