from collections.abc import Awaitable, Callable, Sequence from functools import lru_cache, wraps import inspect import sys from types import NoneType, get_original_bases from typing import ( TYPE_CHECKING, Any, ClassVar, Concatenate, ForwardRef, ParamSpec, TypedDict, cast, get_args, get_origin, overload, ) from app.models.model import UTCBaseModel from app.utils import type_is_optional from sqlalchemy.ext.asyncio import async_object_session from sqlmodel import SQLModel from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.main import SQLModelMetaclass _dict_to_model: dict[type, type["DatabaseModel"]] = {} def _safe_evaluate_forwardref(type_: str | ForwardRef, module_name: str) -> Any: """Safely evaluate a ForwardRef, with fallback to app.database module""" if isinstance(type_, str): type_ = ForwardRef(type_) try: return evaluate_forwardref( type_, globalns=vars(sys.modules[module_name]), localns={}, ) except (NameError, AttributeError, KeyError): # Fallback to app.database module try: import app.database return evaluate_forwardref( type_, globalns=vars(app.database), localns={}, ) except (NameError, AttributeError, KeyError): return None class OnDemand[T]: if TYPE_CHECKING: def __get__(self, instance: object | None, owner: Any) -> T: ... def __set__(self, instance: Any, value: T) -> None: ... def __delete__(self, instance: Any) -> None: ... class Exclude[T]: if TYPE_CHECKING: def __get__(self, instance: object | None, owner: Any) -> T: ... def __set__(self, instance: Any, value: T) -> None: ... def __delete__(self, instance: Any) -> None: ... # https://github.com/fastapi/sqlmodel/blob/main/sqlmodel/_compat.py#L126-L140 def _get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]: raw_annotations: dict[str, Any] = class_dict.get("__annotations__", {}) if sys.version_info >= (3, 14) and "__annotations__" not in class_dict: # See https://github.com/pydantic/pydantic/pull/11991 from annotationlib import ( Format, call_annotate_function, get_annotate_from_class_namespace, ) if annotate := get_annotate_from_class_namespace(class_dict): raw_annotations = call_annotate_function(annotate, format=Format.FORWARDREF) return raw_annotations # https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L58-L77 if sys.version_info < (3, 12, 4): def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any: # Even though it is the right signature for python 3.9, mypy complains with # `error: Too many arguments for "_evaluate" of "ForwardRef"` hence the cast... # Python 3.13/3.12.4+ made `recursive_guard` a kwarg, so name it explicitly to avoid: # TypeError: ForwardRef._evaluate() missing 1 required keyword-only argument: 'recursive_guard' return cast(Any, type_)._evaluate(globalns, localns, recursive_guard=set()) else: def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any: # Pydantic 1.x will not support PEP 695 syntax, but provide `type_params` to avoid # warnings: return cast(Any, type_)._evaluate(globalns, localns, type_params=(), recursive_guard=set()) class DatabaseModelMetaclass(SQLModelMetaclass): def __new__( cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs: Any, ) -> "DatabaseModelMetaclass": original_annotations = _get_annotations(namespace) new_annotations = {} ondemands = [] excludes = [] for k, v in original_annotations.items(): if get_origin(v) is OnDemand: inner_type = v.__args__[0] new_annotations[k] = inner_type ondemands.append(k) elif get_origin(v) is Exclude: inner_type = v.__args__[0] new_annotations[k] = inner_type excludes.append(k) else: new_annotations[k] = v new_class = super().__new__( cls, name, bases, { **namespace, "__annotations__": new_annotations, }, **kwargs, ) new_class._CALCULATED_FIELDS = dict(getattr(new_class, "_CALCULATED_FIELDS", {})) new_class._ONDEMAND_DATABASE_FIELDS = list(getattr(new_class, "_ONDEMAND_DATABASE_FIELDS", [])) + list( ondemands ) new_class._ONDEMAND_CALCULATED_FIELDS = dict(getattr(new_class, "_ONDEMAND_CALCULATED_FIELDS", {})) new_class._EXCLUDED_DATABASE_FIELDS = list(getattr(new_class, "_EXCLUDED_DATABASE_FIELDS", [])) + list(excludes) for attr_name, attr_value in namespace.items(): target = _get_callable_target(attr_value) if target is None: continue if getattr(target, "__included__", False): new_class._CALCULATED_FIELDS[attr_name] = _get_return_type(target) _pre_calculate_context_params(target, attr_value) if getattr(target, "__calculated_ondemand__", False): new_class._ONDEMAND_CALCULATED_FIELDS[attr_name] = _get_return_type(target) _pre_calculate_context_params(target, attr_value) # Register TDict to DatabaseModel mapping for base in get_original_bases(new_class): cls_name = base.__name__ if "DatabaseModel" in cls_name and "[" in cls_name and "]" in cls_name: generic_type_name = cls_name[cls_name.index("[") : cls_name.rindex("]") + 1] generic_type = evaluate_forwardref( ForwardRef(generic_type_name), globalns=vars(sys.modules[new_class.__module__]), localns={}, ) _dict_to_model[generic_type[0]] = new_class return new_class def _pre_calculate_context_params(target: Callable, attr_value: Any) -> None: if hasattr(target, "__context_params__"): return sig = inspect.signature(target) params = list(sig.parameters.keys()) start_index = 2 if isinstance(attr_value, classmethod): start_index = 3 context_params = [] if len(params) < start_index else params[start_index:] setattr(target, "__context_params__", context_params) def _get_callable_target(value: Any) -> Callable | None: if isinstance(value, (staticmethod, classmethod)): return value.__func__ if inspect.isfunction(value): return value if inspect.ismethod(value): return value.__func__ return None def _mark_callable(value: Any, flag: str) -> Callable | None: target = _get_callable_target(value) if target is None: return None setattr(target, flag, True) return target def _get_return_type(func: Callable) -> type: sig = inspect.get_annotations(func) return sig.get("return", Any) P = ParamSpec("P") CalculatedField = Callable[Concatenate[AsyncSession, Any, P], Awaitable[Any]] DecoratorTarget = CalculatedField | staticmethod | classmethod def included(func: DecoratorTarget) -> DecoratorTarget: marker = _mark_callable(func, "__included__") if marker is None: raise RuntimeError("@included is only usable on callables.") @wraps(marker) async def wrapper(*args, **kwargs): return await marker(*args, **kwargs) if isinstance(func, staticmethod): return staticmethod(wrapper) if isinstance(func, classmethod): return classmethod(wrapper) return wrapper def ondemand(func: DecoratorTarget) -> DecoratorTarget: marker = _mark_callable(func, "__calculated_ondemand__") if marker is None: raise RuntimeError("@ondemand is only usable on callables.") @wraps(marker) async def wrapper(*args, **kwargs): return await marker(*args, **kwargs) if isinstance(func, staticmethod): return staticmethod(wrapper) if isinstance(func, classmethod): return classmethod(wrapper) return wrapper async def call_awaitable_with_context( func: CalculatedField, session: AsyncSession, instance: Any, context: dict[str, Any], ) -> Any: context_params: list[str] | None = getattr(func, "__context_params__", None) if context_params is None: # Fallback if not pre-calculated sig = inspect.signature(func) if len(sig.parameters) == 2: return await func(session, instance) else: call_params = {} for param in sig.parameters.values(): if param.name in context: call_params[param.name] = context[param.name] return await func(session, instance, **call_params) if not context_params: return await func(session, instance) call_params = {} for name in context_params: if name in context: call_params[name] = context[name] return await func(session, instance, **call_params) class DatabaseModel[TDict](SQLModel, UTCBaseModel, metaclass=DatabaseModelMetaclass): _CALCULATED_FIELDS: ClassVar[dict[str, type]] = {} _ONDEMAND_DATABASE_FIELDS: ClassVar[list[str]] = [] _ONDEMAND_CALCULATED_FIELDS: ClassVar[dict[str, type]] = {} _EXCLUDED_DATABASE_FIELDS: ClassVar[list[str]] = [] @overload @classmethod async def transform( cls, db_instance: "DatabaseModel", *, session: AsyncSession, includes: list[str] | None = None, **context: Any, ) -> TDict: ... @overload @classmethod async def transform( cls, db_instance: "DatabaseModel", *, includes: list[str] | None = None, **context: Any, ) -> TDict: ... @classmethod async def transform( cls, db_instance: "DatabaseModel", *, session: AsyncSession | None = None, includes: list[str] | None = None, **context: Any, ) -> TDict: includes = includes.copy() if includes is not None else [] session = cast(AsyncSession | None, async_object_session(db_instance)) if session is None else session if session is None: raise RuntimeError("DatabaseModel.transform requires a session-bound instance.") resp_obj = cls.model_validate(db_instance.model_dump()) data = resp_obj.model_dump() for field in cls._CALCULATED_FIELDS: func = getattr(cls, field) value = await call_awaitable_with_context(func, session, db_instance, context) data[field] = value sub_include_map: dict[str, list[str]] = {} for include in [i for i in includes if "." in i]: parent, sub_include = include.split(".", 1) if parent not in sub_include_map: sub_include_map[parent] = [] sub_include_map[parent].append(sub_include) includes.remove(include) # pyright: ignore[reportOptionalMemberAccess] for field, sub_includes in sub_include_map.items(): if field in cls._ONDEMAND_CALCULATED_FIELDS: func = getattr(cls, field) value = await call_awaitable_with_context( func, session, db_instance, {**context, "includes": sub_includes} ) data[field] = value for include in includes: if include in data: continue if include in cls._ONDEMAND_CALCULATED_FIELDS: func = getattr(cls, include) value = await call_awaitable_with_context(func, session, db_instance, context) data[include] = value for field in cls._ONDEMAND_DATABASE_FIELDS: if field not in includes: del data[field] for field in cls._EXCLUDED_DATABASE_FIELDS: if field in data: del data[field] return cast(TDict, data) @classmethod async def transform_many( cls, db_instances: Sequence["DatabaseModel"], *, session: AsyncSession | None = None, includes: list[str] | None = None, **context: Any, ) -> list[TDict]: if not db_instances: return [] # SQLAlchemy AsyncSession is not concurrency-safe, so we cannot use asyncio.gather here # if the transform method performs any database operations using the shared session. # Since we don't know if the transform method (or its calculated fields) will use the DB, # we must execute them serially to be safe. results = [] for instance in db_instances: results.append(await cls.transform(instance, session=session, includes=includes, **context)) return results @classmethod @lru_cache def generate_typeddict(cls, includes: tuple[str, ...] | None = None) -> type[TypedDict]: # pyright: ignore[reportInvalidTypeForm] def _evaluate_type(field_type: Any, *, resolve_database_model: bool = False, field_name: str = "") -> Any: # Evaluate ForwardRef if present if isinstance(field_type, (str, ForwardRef)): resolved = _safe_evaluate_forwardref(field_type, cls.__module__) if resolved is not None: field_type = resolved origin_type = get_origin(field_type) inner_type = field_type args = get_args(field_type) is_optional = type_is_optional(field_type) # pyright: ignore[reportArgumentType] if is_optional: inner_type = next((arg for arg in args if arg is not NoneType), field_type) is_list = False if origin_type is list: is_list = True inner_type = args[0] # Evaluate ForwardRef in inner_type if present if isinstance(inner_type, (str, ForwardRef)): resolved = _safe_evaluate_forwardref(inner_type, cls.__module__) if resolved is not None: inner_type = resolved if not resolve_database_model: if is_optional: return inner_type | None # pyright: ignore[reportOperatorIssue] elif is_list: return list[inner_type] return inner_type model_class = None # First check if inner_type is directly a DatabaseModel subclass try: if inspect.isclass(inner_type) and issubclass(inner_type, DatabaseModel): # type: ignore model_class = inner_type except TypeError: pass # If not found, look up in _dict_to_model if model_class is None: model_class = _dict_to_model.get(inner_type) # type: ignore if model_class is not None: nested_dict = model_class.generate_typeddict(tuple(sub_include_map.get(field_name, ()))) resolved_type = list[nested_dict] if is_list else nested_dict # type: ignore if is_optional: resolved_type = resolved_type | None # type: ignore return resolved_type # Fallback: use the resolved inner_type resolved_type = list[inner_type] if is_list else inner_type # type: ignore if is_optional: resolved_type = resolved_type | None # type: ignore return resolved_type if includes is None: includes = () # Parse nested includes direct_includes = [] sub_include_map: dict[str, list[str]] = {} for include in includes: if "." in include: parent, sub_include = include.split(".", 1) if parent not in sub_include_map: sub_include_map[parent] = [] sub_include_map[parent].append(sub_include) if parent not in direct_includes: direct_includes.append(parent) else: direct_includes.append(include) fields = {} # Process model fields for field_name, field_info in cls.model_fields.items(): field_type = field_info.annotation or Any field_type = _evaluate_type(field_type, field_name=field_name) if field_name in cls._ONDEMAND_DATABASE_FIELDS and field_name not in direct_includes: continue else: fields[field_name] = field_type # Process calculated fields for field_name, field_type in cls._CALCULATED_FIELDS.items(): field_type = _evaluate_type(field_type, resolve_database_model=True, field_name=field_name) fields[field_name] = field_type # Process ondemand calculated fields for field_name, field_type in cls._ONDEMAND_CALCULATED_FIELDS.items(): if field_name not in direct_includes: continue field_type = _evaluate_type(field_type, resolve_database_model=True, field_name=field_name) fields[field_name] = field_type return TypedDict(f"{cls.__name__}Dict[{', '.join(includes)}]" if includes else f"{cls.__name__}Dict", fields) # pyright: ignore[reportArgumentType]