refactor(database): use a new 'On-Demand' design (#86)
Technical Details: https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/
This commit is contained in:
499
app/database/_base.py
Normal file
499
app/database/_base.py
Normal file
@@ -0,0 +1,499 @@
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from functools import lru_cache, wraps
|
||||
import inspect
|
||||
import sys
|
||||
from types import NoneType, get_original_bases
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
ClassVar,
|
||||
Concatenate,
|
||||
ForwardRef,
|
||||
ParamSpec,
|
||||
TypedDict,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
overload,
|
||||
)
|
||||
|
||||
from app.models.model import UTCBaseModel
|
||||
from app.utils import type_is_optional
|
||||
|
||||
from sqlalchemy.ext.asyncio import async_object_session
|
||||
from sqlmodel import SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.main import SQLModelMetaclass
|
||||
|
||||
_dict_to_model: dict[type, type["DatabaseModel"]] = {}
|
||||
|
||||
|
||||
def _safe_evaluate_forwardref(type_: str | ForwardRef, module_name: str) -> Any:
|
||||
"""Safely evaluate a ForwardRef, with fallback to app.database module"""
|
||||
if isinstance(type_, str):
|
||||
type_ = ForwardRef(type_)
|
||||
|
||||
try:
|
||||
return evaluate_forwardref(
|
||||
type_,
|
||||
globalns=vars(sys.modules[module_name]),
|
||||
localns={},
|
||||
)
|
||||
except (NameError, AttributeError, KeyError):
|
||||
# Fallback to app.database module
|
||||
try:
|
||||
import app.database
|
||||
|
||||
return evaluate_forwardref(
|
||||
type_,
|
||||
globalns=vars(app.database),
|
||||
localns={},
|
||||
)
|
||||
except (NameError, AttributeError, KeyError):
|
||||
return None
|
||||
|
||||
|
||||
class OnDemand[T]:
|
||||
if TYPE_CHECKING:
|
||||
|
||||
def __get__(self, instance: object | None, owner: Any) -> T: ...
|
||||
|
||||
def __set__(self, instance: Any, value: T) -> None: ...
|
||||
|
||||
def __delete__(self, instance: Any) -> None: ...
|
||||
|
||||
|
||||
class Exclude[T]:
|
||||
if TYPE_CHECKING:
|
||||
|
||||
def __get__(self, instance: object | None, owner: Any) -> T: ...
|
||||
|
||||
def __set__(self, instance: Any, value: T) -> None: ...
|
||||
|
||||
def __delete__(self, instance: Any) -> None: ...
|
||||
|
||||
|
||||
# https://github.com/fastapi/sqlmodel/blob/main/sqlmodel/_compat.py#L126-L140
|
||||
def _get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
raw_annotations: dict[str, Any] = class_dict.get("__annotations__", {})
|
||||
if sys.version_info >= (3, 14) and "__annotations__" not in class_dict:
|
||||
# See https://github.com/pydantic/pydantic/pull/11991
|
||||
from annotationlib import (
|
||||
Format,
|
||||
call_annotate_function,
|
||||
get_annotate_from_class_namespace,
|
||||
)
|
||||
|
||||
if annotate := get_annotate_from_class_namespace(class_dict):
|
||||
raw_annotations = call_annotate_function(annotate, format=Format.FORWARDREF)
|
||||
return raw_annotations
|
||||
|
||||
|
||||
# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L58-L77
|
||||
if sys.version_info < (3, 12, 4):
|
||||
|
||||
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
|
||||
# Even though it is the right signature for python 3.9, mypy complains with
|
||||
# `error: Too many arguments for "_evaluate" of "ForwardRef"` hence the cast...
|
||||
# Python 3.13/3.12.4+ made `recursive_guard` a kwarg, so name it explicitly to avoid:
|
||||
# TypeError: ForwardRef._evaluate() missing 1 required keyword-only argument: 'recursive_guard'
|
||||
return cast(Any, type_)._evaluate(globalns, localns, recursive_guard=set())
|
||||
|
||||
else:
|
||||
|
||||
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
|
||||
# Pydantic 1.x will not support PEP 695 syntax, but provide `type_params` to avoid
|
||||
# warnings:
|
||||
return cast(Any, type_)._evaluate(globalns, localns, type_params=(), recursive_guard=set())
|
||||
|
||||
|
||||
class DatabaseModelMetaclass(SQLModelMetaclass):
|
||||
def __new__(
|
||||
cls,
|
||||
name: str,
|
||||
bases: tuple[type, ...],
|
||||
namespace: dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> "DatabaseModelMetaclass":
|
||||
original_annotations = _get_annotations(namespace)
|
||||
new_annotations = {}
|
||||
ondemands = []
|
||||
excludes = []
|
||||
|
||||
for k, v in original_annotations.items():
|
||||
if get_origin(v) is OnDemand:
|
||||
inner_type = v.__args__[0]
|
||||
new_annotations[k] = inner_type
|
||||
ondemands.append(k)
|
||||
elif get_origin(v) is Exclude:
|
||||
inner_type = v.__args__[0]
|
||||
new_annotations[k] = inner_type
|
||||
excludes.append(k)
|
||||
else:
|
||||
new_annotations[k] = v
|
||||
|
||||
new_class = super().__new__(
|
||||
cls,
|
||||
name,
|
||||
bases,
|
||||
{
|
||||
**namespace,
|
||||
"__annotations__": new_annotations,
|
||||
},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
new_class._CALCULATED_FIELDS = dict(getattr(new_class, "_CALCULATED_FIELDS", {}))
|
||||
new_class._ONDEMAND_DATABASE_FIELDS = list(getattr(new_class, "_ONDEMAND_DATABASE_FIELDS", [])) + list(
|
||||
ondemands
|
||||
)
|
||||
new_class._ONDEMAND_CALCULATED_FIELDS = dict(getattr(new_class, "_ONDEMAND_CALCULATED_FIELDS", {}))
|
||||
new_class._EXCLUDED_DATABASE_FIELDS = list(getattr(new_class, "_EXCLUDED_DATABASE_FIELDS", [])) + list(excludes)
|
||||
|
||||
for attr_name, attr_value in namespace.items():
|
||||
target = _get_callable_target(attr_value)
|
||||
if target is None:
|
||||
continue
|
||||
|
||||
if getattr(target, "__included__", False):
|
||||
new_class._CALCULATED_FIELDS[attr_name] = _get_return_type(target)
|
||||
_pre_calculate_context_params(target, attr_value)
|
||||
|
||||
if getattr(target, "__calculated_ondemand__", False):
|
||||
new_class._ONDEMAND_CALCULATED_FIELDS[attr_name] = _get_return_type(target)
|
||||
_pre_calculate_context_params(target, attr_value)
|
||||
|
||||
# Register TDict to DatabaseModel mapping
|
||||
for base in get_original_bases(new_class):
|
||||
cls_name = base.__name__
|
||||
if "DatabaseModel" in cls_name and "[" in cls_name and "]" in cls_name:
|
||||
generic_type_name = cls_name[cls_name.index("[") : cls_name.rindex("]") + 1]
|
||||
generic_type = evaluate_forwardref(
|
||||
ForwardRef(generic_type_name),
|
||||
globalns=vars(sys.modules[new_class.__module__]),
|
||||
localns={},
|
||||
)
|
||||
_dict_to_model[generic_type[0]] = new_class
|
||||
|
||||
return new_class
|
||||
|
||||
|
||||
def _pre_calculate_context_params(target: Callable, attr_value: Any) -> None:
|
||||
if hasattr(target, "__context_params__"):
|
||||
return
|
||||
|
||||
sig = inspect.signature(target)
|
||||
params = list(sig.parameters.keys())
|
||||
|
||||
start_index = 2
|
||||
if isinstance(attr_value, classmethod):
|
||||
start_index = 3
|
||||
|
||||
context_params = [] if len(params) < start_index else params[start_index:]
|
||||
|
||||
setattr(target, "__context_params__", context_params)
|
||||
|
||||
|
||||
def _get_callable_target(value: Any) -> Callable | None:
|
||||
if isinstance(value, (staticmethod, classmethod)):
|
||||
return value.__func__
|
||||
if inspect.isfunction(value):
|
||||
return value
|
||||
if inspect.ismethod(value):
|
||||
return value.__func__
|
||||
return None
|
||||
|
||||
|
||||
def _mark_callable(value: Any, flag: str) -> Callable | None:
|
||||
target = _get_callable_target(value)
|
||||
if target is None:
|
||||
return None
|
||||
setattr(target, flag, True)
|
||||
return target
|
||||
|
||||
|
||||
def _get_return_type(func: Callable) -> type:
|
||||
sig = inspect.get_annotations(func)
|
||||
return sig.get("return", Any)
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
CalculatedField = Callable[Concatenate[AsyncSession, Any, P], Awaitable[Any]]
|
||||
DecoratorTarget = CalculatedField | staticmethod | classmethod
|
||||
|
||||
|
||||
def included(func: DecoratorTarget) -> DecoratorTarget:
|
||||
marker = _mark_callable(func, "__included__")
|
||||
if marker is None:
|
||||
raise RuntimeError("@included is only usable on callables.")
|
||||
|
||||
@wraps(marker)
|
||||
async def wrapper(*args, **kwargs):
|
||||
return await marker(*args, **kwargs)
|
||||
|
||||
if isinstance(func, staticmethod):
|
||||
return staticmethod(wrapper)
|
||||
if isinstance(func, classmethod):
|
||||
return classmethod(wrapper)
|
||||
return wrapper
|
||||
|
||||
|
||||
def ondemand(func: DecoratorTarget) -> DecoratorTarget:
|
||||
marker = _mark_callable(func, "__calculated_ondemand__")
|
||||
if marker is None:
|
||||
raise RuntimeError("@ondemand is only usable on callables.")
|
||||
|
||||
@wraps(marker)
|
||||
async def wrapper(*args, **kwargs):
|
||||
return await marker(*args, **kwargs)
|
||||
|
||||
if isinstance(func, staticmethod):
|
||||
return staticmethod(wrapper)
|
||||
if isinstance(func, classmethod):
|
||||
return classmethod(wrapper)
|
||||
return wrapper
|
||||
|
||||
|
||||
async def call_awaitable_with_context(
|
||||
func: CalculatedField,
|
||||
session: AsyncSession,
|
||||
instance: Any,
|
||||
context: dict[str, Any],
|
||||
) -> Any:
|
||||
context_params: list[str] | None = getattr(func, "__context_params__", None)
|
||||
|
||||
if context_params is None:
|
||||
# Fallback if not pre-calculated
|
||||
sig = inspect.signature(func)
|
||||
if len(sig.parameters) == 2:
|
||||
return await func(session, instance)
|
||||
else:
|
||||
call_params = {}
|
||||
for param in sig.parameters.values():
|
||||
if param.name in context:
|
||||
call_params[param.name] = context[param.name]
|
||||
return await func(session, instance, **call_params)
|
||||
|
||||
if not context_params:
|
||||
return await func(session, instance)
|
||||
|
||||
call_params = {}
|
||||
for name in context_params:
|
||||
if name in context:
|
||||
call_params[name] = context[name]
|
||||
return await func(session, instance, **call_params)
|
||||
|
||||
|
||||
class DatabaseModel[TDict](SQLModel, UTCBaseModel, metaclass=DatabaseModelMetaclass):
|
||||
_CALCULATED_FIELDS: ClassVar[dict[str, type]] = {}
|
||||
|
||||
_ONDEMAND_DATABASE_FIELDS: ClassVar[list[str]] = []
|
||||
_ONDEMAND_CALCULATED_FIELDS: ClassVar[dict[str, type]] = {}
|
||||
|
||||
_EXCLUDED_DATABASE_FIELDS: ClassVar[list[str]] = []
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def transform(
|
||||
cls,
|
||||
db_instance: "DatabaseModel",
|
||||
*,
|
||||
session: AsyncSession,
|
||||
includes: list[str] | None = None,
|
||||
**context: Any,
|
||||
) -> TDict: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def transform(
|
||||
cls,
|
||||
db_instance: "DatabaseModel",
|
||||
*,
|
||||
includes: list[str] | None = None,
|
||||
**context: Any,
|
||||
) -> TDict: ...
|
||||
|
||||
@classmethod
|
||||
async def transform(
|
||||
cls,
|
||||
db_instance: "DatabaseModel",
|
||||
*,
|
||||
session: AsyncSession | None = None,
|
||||
includes: list[str] | None = None,
|
||||
**context: Any,
|
||||
) -> TDict:
|
||||
includes = includes.copy() if includes is not None else []
|
||||
session = cast(AsyncSession | None, async_object_session(db_instance)) if session is None else session
|
||||
if session is None:
|
||||
raise RuntimeError("DatabaseModel.transform requires a session-bound instance.")
|
||||
resp_obj = cls.model_validate(db_instance.model_dump())
|
||||
data = resp_obj.model_dump()
|
||||
|
||||
for field in cls._CALCULATED_FIELDS:
|
||||
func = getattr(cls, field)
|
||||
value = await call_awaitable_with_context(func, session, db_instance, context)
|
||||
data[field] = value
|
||||
|
||||
sub_include_map: dict[str, list[str]] = {}
|
||||
for include in [i for i in includes if "." in i]:
|
||||
parent, sub_include = include.split(".", 1)
|
||||
if parent not in sub_include_map:
|
||||
sub_include_map[parent] = []
|
||||
sub_include_map[parent].append(sub_include)
|
||||
includes.remove(include) # pyright: ignore[reportOptionalMemberAccess]
|
||||
|
||||
for field, sub_includes in sub_include_map.items():
|
||||
if field in cls._ONDEMAND_CALCULATED_FIELDS:
|
||||
func = getattr(cls, field)
|
||||
value = await call_awaitable_with_context(
|
||||
func, session, db_instance, {**context, "includes": sub_includes}
|
||||
)
|
||||
data[field] = value
|
||||
|
||||
for include in includes:
|
||||
if include in data:
|
||||
continue
|
||||
|
||||
if include in cls._ONDEMAND_CALCULATED_FIELDS:
|
||||
func = getattr(cls, include)
|
||||
value = await call_awaitable_with_context(func, session, db_instance, context)
|
||||
data[include] = value
|
||||
|
||||
for field in cls._ONDEMAND_DATABASE_FIELDS:
|
||||
if field not in includes:
|
||||
del data[field]
|
||||
|
||||
for field in cls._EXCLUDED_DATABASE_FIELDS:
|
||||
if field in data:
|
||||
del data[field]
|
||||
|
||||
return cast(TDict, data)
|
||||
|
||||
@classmethod
|
||||
async def transform_many(
|
||||
cls,
|
||||
db_instances: Sequence["DatabaseModel"],
|
||||
*,
|
||||
session: AsyncSession | None = None,
|
||||
includes: list[str] | None = None,
|
||||
**context: Any,
|
||||
) -> list[TDict]:
|
||||
if not db_instances:
|
||||
return []
|
||||
|
||||
# SQLAlchemy AsyncSession is not concurrency-safe, so we cannot use asyncio.gather here
|
||||
# if the transform method performs any database operations using the shared session.
|
||||
# Since we don't know if the transform method (or its calculated fields) will use the DB,
|
||||
# we must execute them serially to be safe.
|
||||
results = []
|
||||
for instance in db_instances:
|
||||
results.append(await cls.transform(instance, session=session, includes=includes, **context))
|
||||
return results
|
||||
|
||||
@classmethod
|
||||
@lru_cache
|
||||
def generate_typeddict(cls, includes: tuple[str, ...] | None = None) -> type[TypedDict]: # pyright: ignore[reportInvalidTypeForm]
|
||||
def _evaluate_type(field_type: Any, *, resolve_database_model: bool = False, field_name: str = "") -> Any:
|
||||
# Evaluate ForwardRef if present
|
||||
if isinstance(field_type, (str, ForwardRef)):
|
||||
resolved = _safe_evaluate_forwardref(field_type, cls.__module__)
|
||||
if resolved is not None:
|
||||
field_type = resolved
|
||||
|
||||
origin_type = get_origin(field_type)
|
||||
inner_type = field_type
|
||||
args = get_args(field_type)
|
||||
|
||||
is_optional = type_is_optional(field_type) # pyright: ignore[reportArgumentType]
|
||||
if is_optional:
|
||||
inner_type = next((arg for arg in args if arg is not NoneType), field_type)
|
||||
|
||||
is_list = False
|
||||
if origin_type is list:
|
||||
is_list = True
|
||||
inner_type = args[0]
|
||||
|
||||
# Evaluate ForwardRef in inner_type if present
|
||||
if isinstance(inner_type, (str, ForwardRef)):
|
||||
resolved = _safe_evaluate_forwardref(inner_type, cls.__module__)
|
||||
if resolved is not None:
|
||||
inner_type = resolved
|
||||
|
||||
if not resolve_database_model:
|
||||
if is_optional:
|
||||
return inner_type | None # pyright: ignore[reportOperatorIssue]
|
||||
elif is_list:
|
||||
return list[inner_type]
|
||||
return inner_type
|
||||
|
||||
model_class = None
|
||||
|
||||
# First check if inner_type is directly a DatabaseModel subclass
|
||||
try:
|
||||
if inspect.isclass(inner_type) and issubclass(inner_type, DatabaseModel): # type: ignore
|
||||
model_class = inner_type
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
# If not found, look up in _dict_to_model
|
||||
if model_class is None:
|
||||
model_class = _dict_to_model.get(inner_type) # type: ignore
|
||||
|
||||
if model_class is not None:
|
||||
nested_dict = model_class.generate_typeddict(tuple(sub_include_map.get(field_name, ())))
|
||||
resolved_type = list[nested_dict] if is_list else nested_dict # type: ignore
|
||||
|
||||
if is_optional:
|
||||
resolved_type = resolved_type | None # type: ignore
|
||||
|
||||
return resolved_type
|
||||
|
||||
# Fallback: use the resolved inner_type
|
||||
resolved_type = list[inner_type] if is_list else inner_type # type: ignore
|
||||
if is_optional:
|
||||
resolved_type = resolved_type | None # type: ignore
|
||||
return resolved_type
|
||||
|
||||
if includes is None:
|
||||
includes = ()
|
||||
|
||||
# Parse nested includes
|
||||
direct_includes = []
|
||||
sub_include_map: dict[str, list[str]] = {}
|
||||
for include in includes:
|
||||
if "." in include:
|
||||
parent, sub_include = include.split(".", 1)
|
||||
if parent not in sub_include_map:
|
||||
sub_include_map[parent] = []
|
||||
sub_include_map[parent].append(sub_include)
|
||||
if parent not in direct_includes:
|
||||
direct_includes.append(parent)
|
||||
else:
|
||||
direct_includes.append(include)
|
||||
|
||||
fields = {}
|
||||
|
||||
# Process model fields
|
||||
for field_name, field_info in cls.model_fields.items():
|
||||
field_type = field_info.annotation or Any
|
||||
field_type = _evaluate_type(field_type, field_name=field_name)
|
||||
|
||||
if field_name in cls._ONDEMAND_DATABASE_FIELDS and field_name not in direct_includes:
|
||||
continue
|
||||
else:
|
||||
fields[field_name] = field_type
|
||||
|
||||
# Process calculated fields
|
||||
for field_name, field_type in cls._CALCULATED_FIELDS.items():
|
||||
field_type = _evaluate_type(field_type, resolve_database_model=True, field_name=field_name)
|
||||
fields[field_name] = field_type
|
||||
|
||||
# Process ondemand calculated fields
|
||||
for field_name, field_type in cls._ONDEMAND_CALCULATED_FIELDS.items():
|
||||
if field_name not in direct_includes:
|
||||
continue
|
||||
|
||||
field_type = _evaluate_type(field_type, resolve_database_model=True, field_name=field_name)
|
||||
fields[field_name] = field_type
|
||||
|
||||
return TypedDict(f"{cls.__name__}Dict[{', '.join(includes)}]" if includes else f"{cls.__name__}Dict", fields) # pyright: ignore[reportArgumentType]
|
||||
Reference in New Issue
Block a user