# Technical details: https://blog.mxgame.top/2025/11/22/An-On-Demand-Design-Within-SQLModel/
from collections.abc import Awaitable, Callable, Sequence
|
|
from functools import lru_cache, wraps
|
|
import inspect
|
|
import sys
|
|
from types import NoneType, get_original_bases
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
ClassVar,
|
|
Concatenate,
|
|
ForwardRef,
|
|
ParamSpec,
|
|
TypedDict,
|
|
cast,
|
|
get_args,
|
|
get_origin,
|
|
overload,
|
|
)
|
|
|
|
from app.models.model import UTCBaseModel
|
|
from app.utils import type_is_optional
|
|
|
|
from sqlalchemy.ext.asyncio import async_object_session
|
|
from sqlmodel import SQLModel
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
from sqlmodel.main import SQLModelMetaclass
|
|
|
|
# Registry filled by DatabaseModelMetaclass: maps the TDict type argument of a
# class's direct ``DatabaseModel[TDict]`` base to that class, so that
# ``DatabaseModel.generate_typeddict`` can replace a TDict annotation on a
# calculated field with the owning model's generated TypedDict.
_dict_to_model: dict[type, type["DatabaseModel"]] = {}
|
|
|
|
|
|
def _safe_evaluate_forwardref(type_: str | ForwardRef, module_name: str) -> Any:
    """Best-effort evaluation of a forward reference; returns ``None`` on failure.

    The reference is evaluated first against the globals of ``module_name``.
    If that fails, it is evaluated against the globals of ``app.database``,
    which is imported lazily and only when needed.

    Args:
        type_: The annotation to evaluate, as a string or ``ForwardRef``.
        module_name: Name of the module whose globals are tried first. A name
            that is missing from ``sys.modules`` (or mapped to ``None``) skips
            that attempt.

    Returns:
        The evaluated object. Returns ``None`` if neither namespace resolves
        the reference or if ``app.database`` cannot be imported.
    """
    if isinstance(type_, str):
        type_ = ForwardRef(type_)

    module = sys.modules.get(module_name)
    if module is not None:
        try:
            return evaluate_forwardref(type_, globalns=vars(module), localns={})
        except (NameError, AttributeError, KeyError):
            pass

    # Fallback namespace. Import lazily and tolerate failure: the module may be
    # unavailable, or may not yet be bound as an attribute of ``app`` (e.g.
    # during a circular import). This helper promises ``None``, not an exception.
    try:
        import app.database

        fallback_globals = vars(app.database)
    except (ImportError, AttributeError):
        return None

    try:
        return evaluate_forwardref(type_, globalns=fallback_globals, localns={})
    except (NameError, AttributeError, KeyError):
        return None
|
|
|
|
|
|
class OnDemand[T]:
|
|
if TYPE_CHECKING:
|
|
|
|
def __get__(self, instance: object | None, owner: Any) -> T: ...
|
|
|
|
def __set__(self, instance: Any, value: T) -> None: ...
|
|
|
|
def __delete__(self, instance: Any) -> None: ...
|
|
|
|
|
|
class Exclude[T]:
|
|
if TYPE_CHECKING:
|
|
|
|
def __get__(self, instance: object | None, owner: Any) -> T: ...
|
|
|
|
def __set__(self, instance: Any, value: T) -> None: ...
|
|
|
|
def __delete__(self, instance: Any) -> None: ...
|
|
|
|
|
|
# https://github.com/fastapi/sqlmodel/blob/main/sqlmodel/_compat.py#L126-L140
def _get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]:
    """Return the annotations declared in a class namespace under construction.

    On Python 3.14+ the namespace may hold an ``__annotate__`` function instead
    of an ``__annotations__`` dict. In that case the function is called in
    ``FORWARDREF`` format, so names that cannot be resolved yet come back as
    forward references rather than raising.
    """
    if "__annotations__" in class_dict or sys.version_info < (3, 14):
        return class_dict.get("__annotations__", {})

    # See https://github.com/pydantic/pydantic/pull/11991
    from annotationlib import Format, call_annotate_function, get_annotate_from_class_namespace

    annotate_fn = get_annotate_from_class_namespace(class_dict)
    if not annotate_fn:
        return {}
    return call_annotate_function(annotate_fn, format=Format.FORWARDREF)
|
|
|
|
|
|
# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L58-L77
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
    """Evaluate *type_* against *globalns* / *localns* via ``ForwardRef._evaluate``.

    ``recursive_guard`` is always passed by keyword because it is keyword-only
    on newer interpreters (passing it positionally raises ``TypeError: ...
    missing 1 required keyword-only argument: 'recursive_guard'``).

    From Python 3.12.4 on, ``type_params=()`` is passed as well. This mirrors
    pydantic v1, which does so to avoid warnings; PEP 695 type parameters are
    not supported here.
    """
    extra_kwargs: dict[str, Any] = {} if sys.version_info < (3, 12, 4) else {"type_params": ()}
    # cast: type checkers do not model the private ``_evaluate`` signature.
    return cast(Any, type_)._evaluate(globalns, localns, recursive_guard=set(), **extra_kwargs)
|
|
|
|
|
|
class DatabaseModelMetaclass(SQLModelMetaclass):
    """Metaclass for ``DatabaseModel`` that records serialization metadata.

    Before delegating to ``SQLModelMetaclass`` it rewrites ``OnDemand[X]`` and
    ``Exclude[X]`` annotations to plain ``X``, so SQLModel/pydantic see ordinary
    fields. After the class is built it does three things:

    * Copies the inherited registries (``_CALCULATED_FIELDS``,
      ``_ONDEMAND_DATABASE_FIELDS``, ``_ONDEMAND_CALCULATED_FIELDS``,
      ``_EXCLUDED_DATABASE_FIELDS``) onto the new class and extends them with
      this class's markers.
    * Registers ``@included`` / ``@ondemand`` callables defined in this class
      body.
    * Maps the TDict argument of a ``DatabaseModel[TDict]`` base to the new
      class in ``_dict_to_model``.
    """

    def __new__(
        cls,
        name: str,
        bases: tuple[type, ...],
        namespace: dict[str, Any],
        **kwargs: Any,
    ) -> "DatabaseModelMetaclass":
        original_annotations = _get_annotations(namespace)
        new_annotations: dict[str, Any] = {}
        ondemands: list[str] = []
        excludes: list[str] = []

        # NOTE(review): string annotations (``from __future__ import annotations``)
        # have no ``get_origin`` and are therefore passed through unwrapped.
        for field_name, annotation in original_annotations.items():
            origin = get_origin(annotation)
            if origin is OnDemand:
                new_annotations[field_name] = annotation.__args__[0]
                ondemands.append(field_name)
            elif origin is Exclude:
                new_annotations[field_name] = annotation.__args__[0]
                excludes.append(field_name)
            else:
                new_annotations[field_name] = annotation

        new_class = super().__new__(
            cls,
            name,
            bases,
            {**namespace, "__annotations__": new_annotations},
            **kwargs,
        )

        # Copy the inherited registries so a subclass never mutates its parent's.
        # dict.fromkeys de-duplicates (keeping order) when a subclass re-marks an
        # inherited field; a duplicate entry would make ``transform`` remove the
        # same key twice.
        new_class._CALCULATED_FIELDS = dict(getattr(new_class, "_CALCULATED_FIELDS", {}))
        new_class._ONDEMAND_DATABASE_FIELDS = list(
            dict.fromkeys([*getattr(new_class, "_ONDEMAND_DATABASE_FIELDS", []), *ondemands])
        )
        new_class._ONDEMAND_CALCULATED_FIELDS = dict(getattr(new_class, "_ONDEMAND_CALCULATED_FIELDS", {}))
        new_class._EXCLUDED_DATABASE_FIELDS = list(
            dict.fromkeys([*getattr(new_class, "_EXCLUDED_DATABASE_FIELDS", []), *excludes])
        )

        # Register callables flagged by @included / @ondemand in this class body.
        for attr_name, attr_value in namespace.items():
            target = _get_callable_target(attr_value)
            if target is None:
                continue

            if getattr(target, "__included__", False):
                new_class._CALCULATED_FIELDS[attr_name] = _get_return_type(target)
                _pre_calculate_context_params(target, attr_value)

            if getattr(target, "__calculated_ondemand__", False):
                new_class._ONDEMAND_CALCULATED_FIELDS[attr_name] = _get_return_type(target)
                _pre_calculate_context_params(target, attr_value)

        # Register TDict -> model. Parametrized bases are recognized by a
        # ``__name__`` such as "DatabaseModel[UserDict]" (assumed to be pydantic's
        # naming for parametrized generic models; confirm). The bracketed text is
        # evaluated as a one-element list literal in the defining module's globals.
        for base in get_original_bases(new_class):
            base_name = getattr(base, "__name__", "")
            if "DatabaseModel" in base_name and "[" in base_name and "]" in base_name:
                generic_type_name = base_name[base_name.index("[") : base_name.rindex("]") + 1]
                generic_type = evaluate_forwardref(
                    ForwardRef(generic_type_name),
                    globalns=vars(sys.modules[new_class.__module__]),
                    localns={},
                )
                _dict_to_model[generic_type[0]] = new_class

        return new_class
|
|
|
|
|
|
def _pre_calculate_context_params(target: Callable, attr_value: Any) -> None:
    """Store on *target* the parameter names that ``transform`` fills from context.

    The leading parameters that receive ``session`` and ``instance`` are
    skipped, plus ``cls`` when *attr_value* is a ``classmethod``. Every
    remaining parameter name is saved, in order, as
    ``target.__context_params__``. A target that already has that attribute
    is left untouched.
    """
    if hasattr(target, "__context_params__"):
        return

    leading = 3 if isinstance(attr_value, classmethod) else 2
    parameter_names = [*inspect.signature(target).parameters]
    setattr(target, "__context_params__", parameter_names[leading:])
|
|
|
|
|
|
def _get_callable_target(value: Any) -> Callable | None:
|
|
if isinstance(value, (staticmethod, classmethod)):
|
|
return value.__func__
|
|
if inspect.isfunction(value):
|
|
return value
|
|
if inspect.ismethod(value):
|
|
return value.__func__
|
|
return None
|
|
|
|
|
|
def _mark_callable(value: Any, flag: str) -> Callable | None:
    """Set attribute *flag* to ``True`` on the function underlying *value*.

    Returns that function. Returns ``None`` without marking anything when
    ``_get_callable_target`` cannot unwrap *value*.
    """
    if (function := _get_callable_target(value)) is not None:
        setattr(function, flag, True)
    return function
|
|
|
|
|
|
def _get_return_type(func: Callable) -> type:
|
|
sig = inspect.get_annotations(func)
|
|
return sig.get("return", Any)
|
|
|
|
|
|
# Parameter specification for the extra, context-supplied parameters of a
# calculated field.
P = ParamSpec("P")
# Shape of a calculated-field coroutine function. call_awaitable_with_context
# calls it as ``func(session, instance, **context_params)``.
CalculatedField = Callable[Concatenate[AsyncSession, Any, P], Awaitable[Any]]
# What ``@included`` / ``@ondemand`` accept: a bare calculated field, or one
# wrapped in ``staticmethod`` / ``classmethod``.
DecoratorTarget = CalculatedField | staticmethod | classmethod
|
|
|
|
|
|
def included(func: DecoratorTarget) -> DecoratorTarget:
|
|
marker = _mark_callable(func, "__included__")
|
|
if marker is None:
|
|
raise RuntimeError("@included is only usable on callables.")
|
|
|
|
@wraps(marker)
|
|
async def wrapper(*args, **kwargs):
|
|
return await marker(*args, **kwargs)
|
|
|
|
if isinstance(func, staticmethod):
|
|
return staticmethod(wrapper)
|
|
if isinstance(func, classmethod):
|
|
return classmethod(wrapper)
|
|
return wrapper
|
|
|
|
|
|
def ondemand(func: DecoratorTarget) -> DecoratorTarget:
|
|
marker = _mark_callable(func, "__calculated_ondemand__")
|
|
if marker is None:
|
|
raise RuntimeError("@ondemand is only usable on callables.")
|
|
|
|
@wraps(marker)
|
|
async def wrapper(*args, **kwargs):
|
|
return await marker(*args, **kwargs)
|
|
|
|
if isinstance(func, staticmethod):
|
|
return staticmethod(wrapper)
|
|
if isinstance(func, classmethod):
|
|
return classmethod(wrapper)
|
|
return wrapper
|
|
|
|
|
|
async def call_awaitable_with_context(
    func: CalculatedField,
    session: AsyncSession,
    instance: Any,
    context: dict[str, Any],
) -> Any:
    """Await ``func(session, instance, **kwargs)``, taking kwargs from *context*.

    Only parameters after the two positional ones (``session``, ``instance``)
    are candidates, and only those whose name is a key of *context* are
    passed. The candidate names come from ``func.__context_params__`` when the
    metaclass pre-computed it. Otherwise they are derived from
    ``inspect.signature(func)``; for a bound classmethod that signature
    already omits ``cls``.

    Args:
        func: The calculated field (plain function or bound method).
        session: Passed as the first positional argument.
        instance: Passed as the second positional argument.
        context: Candidate keyword arguments, matched by parameter name.

    Returns:
        Whatever awaiting ``func`` produces.
    """
    context_params: list[str] | None = getattr(func, "__context_params__", None)

    if context_params is None:
        # Not pre-computed. Skip the first two parameters exactly like
        # _pre_calculate_context_params does: they already receive session and
        # instance positionally, and filling them from context as well would
        # raise "got multiple values for argument".
        context_params = list(inspect.signature(func).parameters)[2:]

    call_params = {name: context[name] for name in context_params if name in context}
    return await func(session, instance, **call_params)
|
|
|
|
|
|
class DatabaseModel[TDict](SQLModel, UTCBaseModel, metaclass=DatabaseModelMetaclass):
|
|
_CALCULATED_FIELDS: ClassVar[dict[str, type]] = {}
|
|
|
|
_ONDEMAND_DATABASE_FIELDS: ClassVar[list[str]] = []
|
|
_ONDEMAND_CALCULATED_FIELDS: ClassVar[dict[str, type]] = {}
|
|
|
|
_EXCLUDED_DATABASE_FIELDS: ClassVar[list[str]] = []
|
|
|
|
@overload
|
|
@classmethod
|
|
async def transform(
|
|
cls,
|
|
db_instance: "DatabaseModel",
|
|
*,
|
|
session: AsyncSession,
|
|
includes: list[str] | None = None,
|
|
**context: Any,
|
|
) -> TDict: ...
|
|
|
|
@overload
|
|
@classmethod
|
|
async def transform(
|
|
cls,
|
|
db_instance: "DatabaseModel",
|
|
*,
|
|
includes: list[str] | None = None,
|
|
**context: Any,
|
|
) -> TDict: ...
|
|
|
|
@classmethod
|
|
async def transform(
|
|
cls,
|
|
db_instance: "DatabaseModel",
|
|
*,
|
|
session: AsyncSession | None = None,
|
|
includes: list[str] | None = None,
|
|
**context: Any,
|
|
) -> TDict:
|
|
includes = includes.copy() if includes is not None else []
|
|
session = cast(AsyncSession | None, async_object_session(db_instance)) if session is None else session
|
|
if session is None:
|
|
raise RuntimeError("DatabaseModel.transform requires a session-bound instance.")
|
|
resp_obj = cls.model_validate(db_instance.model_dump())
|
|
data = resp_obj.model_dump()
|
|
|
|
for field in cls._CALCULATED_FIELDS:
|
|
func = getattr(cls, field)
|
|
value = await call_awaitable_with_context(func, session, db_instance, context)
|
|
data[field] = value
|
|
|
|
sub_include_map: dict[str, list[str]] = {}
|
|
for include in [i for i in includes if "." in i]:
|
|
parent, sub_include = include.split(".", 1)
|
|
if parent not in sub_include_map:
|
|
sub_include_map[parent] = []
|
|
sub_include_map[parent].append(sub_include)
|
|
includes.remove(include) # pyright: ignore[reportOptionalMemberAccess]
|
|
|
|
for field, sub_includes in sub_include_map.items():
|
|
if field in cls._ONDEMAND_CALCULATED_FIELDS:
|
|
func = getattr(cls, field)
|
|
value = await call_awaitable_with_context(
|
|
func, session, db_instance, {**context, "includes": sub_includes}
|
|
)
|
|
data[field] = value
|
|
|
|
for include in includes:
|
|
if include in data:
|
|
continue
|
|
|
|
if include in cls._ONDEMAND_CALCULATED_FIELDS:
|
|
func = getattr(cls, include)
|
|
value = await call_awaitable_with_context(func, session, db_instance, context)
|
|
data[include] = value
|
|
|
|
for field in cls._ONDEMAND_DATABASE_FIELDS:
|
|
if field not in includes:
|
|
del data[field]
|
|
|
|
for field in cls._EXCLUDED_DATABASE_FIELDS:
|
|
if field in data:
|
|
del data[field]
|
|
|
|
return cast(TDict, data)
|
|
|
|
@classmethod
|
|
async def transform_many(
|
|
cls,
|
|
db_instances: Sequence["DatabaseModel"],
|
|
*,
|
|
session: AsyncSession | None = None,
|
|
includes: list[str] | None = None,
|
|
**context: Any,
|
|
) -> list[TDict]:
|
|
if not db_instances:
|
|
return []
|
|
|
|
# SQLAlchemy AsyncSession is not concurrency-safe, so we cannot use asyncio.gather here
|
|
# if the transform method performs any database operations using the shared session.
|
|
# Since we don't know if the transform method (or its calculated fields) will use the DB,
|
|
# we must execute them serially to be safe.
|
|
results = []
|
|
for instance in db_instances:
|
|
results.append(await cls.transform(instance, session=session, includes=includes, **context))
|
|
return results
|
|
|
|
@classmethod
|
|
@lru_cache
|
|
def generate_typeddict(cls, includes: tuple[str, ...] | None = None) -> type[TypedDict]: # pyright: ignore[reportInvalidTypeForm]
|
|
def _evaluate_type(field_type: Any, *, resolve_database_model: bool = False, field_name: str = "") -> Any:
|
|
# Evaluate ForwardRef if present
|
|
if isinstance(field_type, (str, ForwardRef)):
|
|
resolved = _safe_evaluate_forwardref(field_type, cls.__module__)
|
|
if resolved is not None:
|
|
field_type = resolved
|
|
|
|
origin_type = get_origin(field_type)
|
|
inner_type = field_type
|
|
args = get_args(field_type)
|
|
|
|
is_optional = type_is_optional(field_type) # pyright: ignore[reportArgumentType]
|
|
if is_optional:
|
|
inner_type = next((arg for arg in args if arg is not NoneType), field_type)
|
|
|
|
is_list = False
|
|
if origin_type is list:
|
|
is_list = True
|
|
inner_type = args[0]
|
|
|
|
# Evaluate ForwardRef in inner_type if present
|
|
if isinstance(inner_type, (str, ForwardRef)):
|
|
resolved = _safe_evaluate_forwardref(inner_type, cls.__module__)
|
|
if resolved is not None:
|
|
inner_type = resolved
|
|
|
|
if not resolve_database_model:
|
|
if is_optional:
|
|
return inner_type | None # pyright: ignore[reportOperatorIssue]
|
|
elif is_list:
|
|
return list[inner_type]
|
|
return inner_type
|
|
|
|
model_class = None
|
|
|
|
# First check if inner_type is directly a DatabaseModel subclass
|
|
try:
|
|
if inspect.isclass(inner_type) and issubclass(inner_type, DatabaseModel): # type: ignore
|
|
model_class = inner_type
|
|
except TypeError:
|
|
pass
|
|
|
|
# If not found, look up in _dict_to_model
|
|
if model_class is None:
|
|
model_class = _dict_to_model.get(inner_type) # type: ignore
|
|
|
|
if model_class is not None:
|
|
nested_dict = model_class.generate_typeddict(tuple(sub_include_map.get(field_name, ())))
|
|
resolved_type = list[nested_dict] if is_list else nested_dict # type: ignore
|
|
|
|
if is_optional:
|
|
resolved_type = resolved_type | None # type: ignore
|
|
|
|
return resolved_type
|
|
|
|
# Fallback: use the resolved inner_type
|
|
resolved_type = list[inner_type] if is_list else inner_type # type: ignore
|
|
if is_optional:
|
|
resolved_type = resolved_type | None # type: ignore
|
|
return resolved_type
|
|
|
|
if includes is None:
|
|
includes = ()
|
|
|
|
# Parse nested includes
|
|
direct_includes = []
|
|
sub_include_map: dict[str, list[str]] = {}
|
|
for include in includes:
|
|
if "." in include:
|
|
parent, sub_include = include.split(".", 1)
|
|
if parent not in sub_include_map:
|
|
sub_include_map[parent] = []
|
|
sub_include_map[parent].append(sub_include)
|
|
if parent not in direct_includes:
|
|
direct_includes.append(parent)
|
|
else:
|
|
direct_includes.append(include)
|
|
|
|
fields = {}
|
|
|
|
# Process model fields
|
|
for field_name, field_info in cls.model_fields.items():
|
|
field_type = field_info.annotation or Any
|
|
field_type = _evaluate_type(field_type, field_name=field_name)
|
|
|
|
if field_name in cls._ONDEMAND_DATABASE_FIELDS and field_name not in direct_includes:
|
|
continue
|
|
else:
|
|
fields[field_name] = field_type
|
|
|
|
# Process calculated fields
|
|
for field_name, field_type in cls._CALCULATED_FIELDS.items():
|
|
field_type = _evaluate_type(field_type, resolve_database_model=True, field_name=field_name)
|
|
fields[field_name] = field_type
|
|
|
|
# Process ondemand calculated fields
|
|
for field_name, field_type in cls._ONDEMAND_CALCULATED_FIELDS.items():
|
|
if field_name not in direct_includes:
|
|
continue
|
|
|
|
field_type = _evaluate_type(field_type, resolve_database_model=True, field_name=field_name)
|
|
fields[field_name] = field_type
|
|
|
|
return TypedDict(f"{cls.__name__}Dict[{', '.join(includes)}]" if includes else f"{cls.__name__}Dict", fields) # pyright: ignore[reportArgumentType]
|