# Source code for modularml.core.models.tensorflow_base_model
"""Base classes for native TensorFlow models in ModularML."""
from __future__ import annotations

from abc import ABC
from collections import Counter
from typing import TYPE_CHECKING, Any

from modularml.core.models.base_model import BaseModel
from modularml.utils.environment.optional_imports import ensure_tensorflow
from modularml.utils.nn.backend import Backend

if TYPE_CHECKING:
    import numpy as np
[docs]
class TensorflowBaseModel(BaseModel, ABC):
    """
    Base class for ModularML-native TensorFlow/Keras models.

    Description:
        Designed for framework-owned Keras implementations. User-defined
        :class:`tf.keras.Model` objects can subclass this base, though
        :class:`TensorflowModelWrapper` is typically simpler.

        Weight access (:meth:`get_weights` / :meth:`set_weights`) goes
        through :meth:`_get_keras_model`, so subclasses must either store
        their Keras model as ``self.model`` or override that helper.
    """

    def __init__(self, **init_args: Any):
        """
        Initialize TensorFlow dependencies and call :class:`BaseModel`.

        Args:
            **init_args: Keyword arguments forwarded to :class:`BaseModel`.
                Any ``backend`` entry is discarded: this class always passes
                ``Backend.TENSORFLOW``.
        """
        _ = ensure_tensorflow()
        # The backend is fixed for this class. Drop any caller-supplied value
        # so ``backend`` is not passed to BaseModel twice.
        # BaseModel handles backend + built flag.
        _ = init_args.pop("backend", None)
        super().__init__(backend=Backend.TENSORFLOW, **init_args)

    # ================================================
    # Model Weights (Stateful)
    # ================================================
    def get_weights(self) -> dict[str, np.ndarray]:
        """
        Return model weights as numpy arrays keyed by variable name.

        Each variable is keyed by its ``name`` when that name is unique
        within the model. Variables whose names collide (some Keras versions
        give every layer's variables short names such as ``"kernel"``) are
        keyed by their ``path`` instead, so no weight is silently dropped
        from the result.

        Returns:
            dict[str, np.ndarray]: One entry per model variable, or an empty
            dict if the model has not been built yet.

        Raises:
            ValueError: If the model's variables cannot be given unique keys.
        """
        if not self.is_built:
            return {}
        # The model comes from `self.model` or an overridden `_get_keras_model()`.
        # Materialize the variables once because they are iterated more than once.
        variables = list(self._get_keras_model().variables)
        keys = self._variable_keys(variables)
        return {key: var.numpy() for key, var in zip(keys, variables)}

    def set_weights(self, weights: dict[str, np.ndarray]) -> None:
        """
        Restore numpy-based weights produced by :meth:`get_weights`.

        All keys are validated before any variable is modified, so a mapping
        containing an unknown key leaves the model untouched instead of
        partially overwritten.

        Args:
            weights: Mapping of variable key (as returned by
                :meth:`get_weights`) to the value to assign. Only the listed
                variables are updated; an empty mapping is a no-op.

        Raises:
            ValueError: If a key does not match any model variable, or if the
                model's variables cannot be given unique keys.
        """
        if not weights:
            return

        variables = list(self._get_keras_model().variables)
        var_map = dict(zip(self._variable_keys(variables), variables))

        # Validate everything up front; assigning inside the same loop would
        # leave earlier variables overwritten when a later key is invalid.
        missing = [name for name in weights if name not in var_map]
        if missing:
            msg = (
                f"Variable `{missing[0]}` not found in model. "
                f"Available: {list(var_map.keys())}"
            )
            raise ValueError(msg)

        for name, value in weights.items():
            var_map[name].assign(value)

    @staticmethod
    def _variable_keys(variables: list[Any]) -> list[str]:
        """
        Return a unique string key for each variable, in the same order.

        A variable keeps its ``name`` as key when no other variable shares
        that name, which preserves the existing key format. Otherwise its
        ``path`` attribute is used when present (falling back to ``name``).

        Args:
            variables: The model's variables.

        Returns:
            list[str]: One key per variable, aligned with ``variables``.

        Raises:
            ValueError: If the keys are still not unique after
                disambiguation.
        """
        names = [var.name for var in variables]
        name_counts = Counter(names)

        keys: list[str] = []
        for var, name in zip(variables, names):
            key = name
            if name_counts[name] > 1:
                # Ambiguous name: prefer the variable's full path, if any.
                key = getattr(var, "path", None) or name
            keys.append(key)

        duplicates = sorted(k for k, n in Counter(keys).items() if n > 1)
        if duplicates:
            msg = (
                "Model variables cannot be keyed uniquely; duplicate keys: "
                f"{duplicates}"
            )
            raise ValueError(msg)
        return keys

    def _get_keras_model(self):
        """
        Return the underlying Keras model for weight access.

        Returns:
            tf.keras.Model: The model stored on ``self.model``.

        Raises:
            AttributeError: If the subclass did not expose a model and did
                not override this helper.
        """
        model = getattr(self, "model", None)
        if model is not None:
            return model
        msg = (
            "No `model` attribute found. Subclasses of TensorflowBaseModel "
            "must either store a Keras model as `self.model` or override "
            "`_get_keras_model()`."
        )
        raise AttributeError(msg)