Source code for modularml.core.models.base_model

"""Abstract base class definitions for ModularML model backends."""

from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any

from modularml.core.io.protocols import Configurable, Stateful
from modularml.core.io.symbol_registry import symbol_registry
from modularml.utils.data.comparators import deep_equal
from modularml.utils.nn.backend import Backend, normalize_backend


class BaseModel(Configurable, Stateful, ABC):
    """
    Abstract base class for backend-agnostic ModularML models.

    Attributes:
        _backend (Backend):
            Normalized backend enum powering the model.
        _built (bool):
            Flag indicating whether :meth:`build` completed.
        _init_args (dict[str, Any]):
            Keyword arguments stored for config round-trips via
            :meth:`get_config`.

    """

    def __init__(self, backend: str | Backend, **init_args: Any):
        """
        Initialize the model with a backend and keyword arguments.

        Args:
            backend (str | Backend):
                Backend identifier or enum value used to normalize
                :attr:`_backend`.
            **init_args (Any):
                Keyword arguments cached for config serialization so
                subclasses can reconstruct themselves.

        """
        super().__init__()
        self._backend = normalize_backend(backend)
        self._built = False
        # Store a shallow copy so later mutation of the caller's mapping
        # does not leak into the cached init args.
        self._init_args = dict(init_args)

    def __eq__(self, other):
        """
        Return True when configs and states match another model.

        Args:
            other (BaseModel):
                Model to compare against.

        Returns:
            bool: True only if both the normalized config and the state
            of the two models are deeply equal.

        Raises:
            TypeError: If `other` is not a :class:`BaseModel`.

        """
        if not isinstance(other, BaseModel):
            msg = f"Cannot compare equality between BaseModel and {type(other)}"
            raise TypeError(msg)

        def grab_clean_cfg(x: Configurable) -> dict[str, Any]:
            """Return a copy of the config with the class replaced by its qualname."""
            # Shallow copy: the "model_class" entry is overwritten below and
            # a dict owned (or cached) by the model must not be modified.
            cfg = dict(x.get_config())
            if "model_class" in cfg:
                cls_or_obj = cfg["model_class"]
                cls = (
                    cls_or_obj
                    if isinstance(cls_or_obj, type)
                    else cls_or_obj.__class__
                )
                cfg["model_class"] = cls.__qualname__
            # Without this return both sides evaluate to None and the
            # config comparison below would always succeed.
            return cfg

        # Compare config (class object replaced with its name)
        if not deep_equal(grab_clean_cfg(self), grab_clean_cfg(other)):
            return False

        # Compare state
        return deep_equal(self.get_state(), other.get_state())

    def __hash__(self) -> int:
        """
        Return identity-based hash so models remain hashable.

        NOTE: :meth:`__eq__` compares by value while this hash is based on
        object identity, so two models that compare equal will generally
        have different hashes.
        """
        return id(self)

    def __repr__(self) -> str:
        """Return string representation including backend value."""
        return f"<{self.backend.value} model>"

    # ================================================
    # Properties
    # ================================================
    @property
    @abstractmethod
    def input_shape(self) -> tuple[int, ...] | None:
        """
        Return the expected per-sample input shape for the model.

        Returns:
            tuple[int, ...] | None:
                Shape tuple without the batch dimension, or None when
                the shape is not yet known.

        """

    @property
    @abstractmethod
    def output_shape(self) -> tuple[int, ...] | None:
        """
        Return the per-sample output shape produced by the model.

        Returns:
            tuple[int, ...] | None:
                Shape tuple without the batch dimension, or None when
                not yet known.

        """

    @property
    def backend(self) -> Backend:
        """Return the :class:`Backend` powering this model."""
        return self._backend

    @property
    def is_built(self) -> bool:
        """Return True when the model has been built and shapes resolved."""
        return (
            self._built
            and (self.input_shape is not None)
            and (self.output_shape is not None)
        )

    # ================================================
    # Methods
    # ================================================
    @abstractmethod
    def build(
        self,
        input_shape: tuple[int, ...] | None = None,
        output_shape: tuple[int, ...] | None = None,
        *,
        force: bool = False,
    ):
        """
        Build backend layers for the model implementation.

        Args:
            input_shape (tuple[int, ...] | None):
                Per-sample input shape, excluding batch dimension, if known.
            output_shape (tuple[int, ...] | None):
                Per-sample output shape, excluding batch dimension, if known.
            force (bool):
                Whether to rebuild even if the model is already constructed.

        """

    @abstractmethod
    def forward(self, *args, **kwargs):
        """Run a forward pass for the underlying backend."""

    def __call__(self, *args, **kwargs):
        """Execute :meth:`forward` to make the instance callable."""
        return self.forward(*args, **kwargs)

    # ================================================
    # Model Weights & Initialization Params
    # ================================================
    @abstractmethod
    def get_weights(self) -> dict[str, Any]:
        """Return backend weights as pure-Python or numpy data."""

    @abstractmethod
    def set_weights(self, weights: dict[str, Any]) -> None:
        """
        Restore the weights previously produced by :meth:`get_weights`.

        Args:
            weights (dict[str, Any]):
                Dictionary returned by :meth:`get_weights`.

        """

    def get_init_args(self) -> dict[str, Any]:
        """
        Return keyword arguments needed to reconstruct this model.

        Description:
            Subclasses may override this to control config serialization,
            but caching `init_args` automatically works for most models.

        Returns:
            dict[str, Any]: Shallow copy of the initialization keyword
            arguments.

        """
        return dict(self._init_args)

    # ================================================
    # Configurable
    # ================================================
    def get_config(self) -> dict[str, Any]:
        """
        Return configuration required to re-instantiate this model.

        Returns:
            dict[str, Any]:
                Configuration capturing the model class, init args,
                backend value, build flag, and input/output shapes.

        """
        return {
            "model_class": type(self),
            "init_args": self.get_init_args(),
            "backend": self.backend.value,
            "is_built": self._built,
            "input_shape": self.input_shape,
            "output_shape": self.output_shape,
        }

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> BaseModel:
        """
        Construct a :class:`BaseModel` from serialized configuration.

        Description:
            Instantiates a new model with the provided config. Learned
            weights are not restored; call :meth:`set_state` afterwards.

        Args:
            config (dict[str, Any]):
                Configuration emitted by :meth:`get_config`.

        Returns:
            BaseModel: Newly constructed model instance.

        """
        # The concrete class comes from the config, not from `cls`.
        model_cls = config["model_class"]
        model: BaseModel = model_cls(backend=config["backend"], **config["init_args"])

        # Re-run the build step when the source model had been built.
        if config["is_built"]:
            model.build(
                input_shape=config.get("input_shape"),
                output_shape=config.get("output_shape"),
                force=True,
            )
        return model

    # ================================================
    # Stateful
    # ================================================
    def get_state(self) -> dict[str, Any]:
        """Return the learned state for serialization."""
        state: dict[str, Any] = {
            "weights": self.get_weights(),
        }
        return state

    def set_state(self, state: dict[str, Any]) -> None:
        """
        Restore learned state produced by :meth:`get_state`.

        Args:
            state (dict[str, Any]):
                State dictionary containing weights.

        """
        self.set_weights(state["weights"])

    # ================================================
    # Serialization
    # ================================================
    def save(self, filepath: Path, *, overwrite: bool = False) -> Path:
        """
        Serialize this model to disk using the serializer.

        Args:
            filepath (Path):
                Destination path. The suffix may be adjusted to follow
                ModularML naming conventions.
            overwrite (bool):
                Whether to overwrite an existing artifact.

        Returns:
            Path: Actual file path written by the serializer.

        """
        from modularml.core.io.serialization_policy import SerializationPolicy
        from modularml.core.io.serializer import serializer

        # Default to PACKAGED; switch to BUILTIN or REGISTERED when the
        # symbol registry recognizes this model.
        policy = SerializationPolicy.PACKAGED
        if symbol_registry.obj_is_a_builtin_class(self):
            policy = SerializationPolicy.BUILTIN
        elif symbol_registry.obj_in_a_builtin_registry(
            obj_or_cls=self,
            registry_name="model_registry",
        ):
            policy = SerializationPolicy.REGISTERED

        return serializer.save(
            self,
            filepath,
            policy=policy,
            overwrite=overwrite,
        )

    @classmethod
    def load(cls, filepath: Path, *, allow_packaged_code: bool = False) -> BaseModel:
        """
        Load a serialized :class:`BaseModel` from disk.

        Args:
            filepath (Path):
                Path pointing to a saved model.
            allow_packaged_code (bool):
                Whether packaged code execution is permitted for
                user-defined artifacts.

        Returns:
            BaseModel: Reloaded model instance.

        """
        from modularml.core.io.serializer import _enforce_file_suffix, serializer

        # Append the proper suffix only if no suffix is given
        if Path(filepath).suffix == "":
            filepath = _enforce_file_suffix(path=filepath, cls=cls)

        return serializer.load(filepath, allow_packaged_code=allow_packaged_code)