# Source code for modularml.core.topology.model_node

"""Model node implementations within ModularML model graphs."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, overload

from modularml.core.data.batch import Batch
from modularml.core.data.featureset import FeatureSet
from modularml.core.data.sample_data import RoleData, SampleData
from modularml.core.experiment.experiment_node import ExperimentNode
from modularml.core.models import wrap_model
from modularml.core.models.base_model import BaseModel
from modularml.core.references.experiment_reference import ExperimentNodeReference
from modularml.core.references.featureset_reference import FeatureSetReference
from modularml.core.topology.compute_node import ComputeNode, TForward
from modularml.core.training.loss_record import LossCollection, LossRecord
from modularml.core.training.optimizer import Optimizer
from modularml.utils.data.data_format import DataFormat, get_data_format_for_backend
from modularml.utils.environment.optional_imports import check_tensorflow, check_torch
from modularml.utils.errors.exceptions import (
    BackendMismatchError,
    BackendNotSupportedError,
    OptimizerNotSetError,
)
from modularml.utils.logging.warnings import catch_warnings, warn
from modularml.utils.nn.backend import Backend
from modularml.utils.representation.summary import safe_cast_to_summary_rows
from modularml.utils.topology.graph_search_utils import find_upstream_featuresets

if TYPE_CHECKING:
    from modularml.core.data.execution_context import ExecutionContext
    from modularml.core.training.applied_loss import AppliedLoss

# Module-level handles to the optional deep-learning backends, used below as
# `tf.GradientTape()` and `torch.no_grad()`.
# NOTE(review): presumably these helpers return the imported module, or a
# placeholder/None when the package is not installed -- confirm in
# `modularml.utils.environment.optional_imports`.
tf = check_tensorflow()
torch = check_torch()


class ModelNode(ComputeNode):
    """
    Single learnable or static stage inside a :class:`ModelGraph`.

    A node wraps exactly one backend model (cast through ``wrap_model``),
    accepts a single upstream reference, and may own an optimizer that is
    used when the node is trained on its own.

    Attributes:
        _model (BaseModel): Wrapped backend model implementation.
        _optimizer (Optimizer | None): Optional optimizer coordinating gradient steps.
        _freeze (bool): Flag indicating whether training is disabled.

    """
def __init__(
    self,
    label: str,
    model: BaseModel | Any,
    upstream_ref: ExperimentNode | ExperimentNodeReference,
    optimizer: Optimizer | None = None,
    *,
    node_id: str | None = None,
    register: bool = True,
):
    """
    Initialize a ModelNode.

    Args:
        label (str):
            Unique name identifying this stage within the model graph.
        model (Union[BaseModel, Any]):
            A backend-specific model instance or config. It is passed through
            ``wrap_model`` so the node always holds a ``BaseModel``.
        upstream_ref (ExperimentNode | ExperimentNodeReference):
            The upstream node, or a reference to it. A ``FeatureSet`` instance
            is also accepted and is converted with ``FeatureSet.reference()``.
        optimizer (Optional[Optimizer]):
            Optimizer to use during training (optional).
        node_id (str, optional):
            Used only for de-serialization.
        register (bool, optional):
            Used only for de-serialization.

    Raises:
        TypeError: If ``upstream_ref`` is none of the accepted types.
        BackendMismatchError: If the optimizer backend differs from the model backend.

    """
    ref = None
    if isinstance(upstream_ref, FeatureSet):
        # Capture warnings raised while building the default reference so the
        # generic "multiple representations" warning can be replaced with a
        # ModelNode-specific one.
        dup_rep_warnings = False
        with catch_warnings() as cw:
            ref = upstream_ref.reference()
            if cw.match("Multiple representations selected"):
                dup_rep_warnings = True

        # Emitted outside the capture context so the warning reaches the user.
        if dup_rep_warnings:
            msg = (
                "Setting a ModelNode `upstream_ref` with a FeatureSet will result in multiple "
                "representations of the same column being combined into input/target tensors. "
            )
            hint = "Use `FeatureSet(...).reference(...)` if this is not intentional."
            warn(msg, category=UserWarning, stacklevel=2, hints=hint)

    elif isinstance(upstream_ref, ExperimentNodeReference):
        ref = upstream_ref
    elif isinstance(upstream_ref, ExperimentNode):
        ref = upstream_ref.reference()
    else:
        msg = f"`upstream_ref` must be of type ExperimentReference or ExperimentNode. Received: {type(upstream_ref)}."
        raise TypeError(msg)

    super().__init__(
        label=label,
        upstream_refs=ref,
        node_id=node_id,
        register=register,
    )

    # Set model (cast to BaseModel if explicit subclass not provided)
    self._model: BaseModel = wrap_model(model)
    self._freeze = False  # make stage trainable as default

    # Error checking on optimizer (can be None)
    self._optimizer = optimizer
    self._check_valid_optimizer(required=False)
@property def model(self) -> BaseModel: """ Return the wrapped backend model instance. Returns: BaseModel: Backend-specific implementation. """ return self._model @property def input_shape(self) -> tuple[int, ...]: """ Return the model's input tensor shape. Returns: tuple[int, ...]: Expected feature tensor shape. """ return self.model.input_shape # ================================================ # ComputeNode Interface # ================================================ @property def output_shape(self) -> tuple[int, ...]: """ Return the model's output tensor shape. Returns: tuple[int, ...]: Output tensor shape. """ return self.model.output_shape @property def max_upstream_refs(self) -> int: """ Return the maximum number of allowed upstream references. Returns: int: Always 1 because :class:`ModelNode` has a single input. """ return 1 @property def is_built(self) -> bool: """ Checks if the model has been built (i.e., instantiated with input/output shape). Returns: bool: True if built, False otherwise. """ return self._model.is_built def _build_impl( self, *, input_shapes: dict[ExperimentNodeReference, tuple[int, ...]] | None = None, output_shape: tuple[int, ...] | None = None, force: bool = False, **kwargs, # noqa: ARG002 ): """ Construct the wrapped model using upstream/downstream shapes. Args: input_shapes (dict[ExperimentNodeReference, tuple[int, ...]] | None): Shapes of upstream tensors; must contain a single entry. output_shape (tuple[int, ...] | None): Expected output shape used to validate decoder layers. force (bool): Whether to rebuild even if already built. **kwargs: Additional subclass parameters (unused). Raises: ValueError: If multiple inputs are provided. """ if input_shapes is None: input_shape = None else: if len(input_shapes) != 1: msg = ( f"{self.__class__.__name__} expects exactly one input. " f"Received {len(input_shapes)}." 
) raise ValueError(msg) input_shape = next(iter(input_shapes.values())) self.build_model( input_shape=input_shape, output_shape=output_shape, force=force, ) def _build_optimizer(self, *, force: bool = False): """ Construct the optimizer once the model weights exist. Args: force (bool): Whether to rebuild even if already built. Raises: ValueError: If optimizer or model state is unavailable. BackendNotSupportedError: If the backend is unknown. """ if self._optimizer is None: raise ValueError("Optimizer is None. Cannot build.") if not self.is_built: raise ValueError("Optimzier cannot be built until model is built.") if self.backend == Backend.TORCH: self._optimizer.build( parameters=self._model.parameters(), backend=self.backend, force_rebuild=force, ) elif self.backend == Backend.TENSORFLOW: self._optimizer.build( backend=self.backend, force_rebuild=force, ) elif self.backend == Backend.SCIKIT: # Scikit-learn optimizers are typically fit internally pass else: raise BackendNotSupportedError( backend=self.backend, message="Unknown backend for optimizer building", )
def build_model(
    self,
    input_shape: tuple[int, ...] | None = None,
    output_shape: tuple[int, ...] | None = None,
    *,
    force: bool = False,
):
    """
    Instantiate the wrapped BaseModel and, when present, its optimizer.

    Args:
        input_shape (tuple[int, ...] | None, optional):
            Shape the model is constructed with on the input side.
            Defaults to None.
        output_shape (tuple[int, ...] | None, optional):
            Shape the model is constructed with on the output side. When
            omitted, the BaseModel has to infer it on its own.
            Defaults to None.
        force (bool, optional):
            An already-built model is left untouched unless ``force=True``.
            Defaults to False.

    Notes:
        - For PyTorch and TensorFlow, optimizers are built after the model
          is initialized.
        - Scikit-learn models typically do not require external optimizers.
        - Shape inference and merge logic (if needed) is assumed to have been
          resolved upstream by the ModelGraph.

    """
    wrapped = self._model

    # Step 1: the model itself (skipped when it already exists and no rebuild is forced)
    needs_build = force or not wrapped.is_built
    if needs_build:
        wrapped.build(
            input_shape=input_shape,
            output_shape=output_shape,
            force=force,
        )

    # Step 2: the optimizer, only if this node owns one
    if self._optimizer is None:
        return
    self._build_optimizer(force=force)
# The overloads use the same positional name (`x`) as the implementation so
# that a keyword call accepted by the type checker also works at runtime.
@overload
def forward_single(self, x: Batch, **kwargs) -> Batch: ...
@overload
def forward_single(self, x: RoleData, **kwargs) -> RoleData: ...
@overload
def forward_single(self, x: SampleData, **kwargs) -> SampleData: ...


def forward_single(
    self,
    x: SampleData | RoleData | Batch,
    **kwargs,
) -> SampleData | RoleData | Batch:
    """
    Perform a forward pass through the wrapped model.

    Raw backend outputs are kept as-is (no format conversion) so that
    backend autograd keeps working on the returned features. Only the
    features are transformed; targets, tags and sample UUIDs are carried
    over unchanged. If the node has not been built yet, a best-effort
    auto-build from upstream FeatureSet shapes is attempted first.

    Args:
        x (SampleData | RoleData | Batch):
            Input data to the model. Each contained ``SampleData`` is
            converted to this node's backend in place.
        **kwargs:
            Any additional keyword arguments to provide to the wrapped model call.

    Returns:
        SampleData | RoleData | Batch:
            Outputs from the model. Output type matches input.

    Raises:
        RuntimeError: If the node is not built and auto-building fails.
        TypeError: If ``x`` is not a SampleData, RoleData or Batch.

    """
    # Ensure built
    if not self.is_built:
        # We can try to auto-build based on runtime upstream/downstream connections
        # If upstream_ref is a FeatureSet, we can take feature shapes
        in_shape = None
        if isinstance(self.upstream_ref, FeatureSetReference):
            # Get feature shape (drops leading dim of n_samples)
            fsv = self.upstream_ref.resolve()
            in_shape = fsv.get_features(fmt=DataFormat.NUMPY).shape[1:]

        # If this node is downstream of only one FeatureSet, we
        # can infer the output shape to be the FeatureSet.targets shape
        out_shape = None
        ups_fs_refs = find_upstream_featuresets(node=self)
        ups_fs_ids = {ref.node_id for ref in ups_fs_refs}
        if len(ups_fs_ids) == 1:
            fsv = ups_fs_refs[0].resolve()
            t_shape = fsv.get_targets(fmt=DataFormat.NUMPY).shape[1:]
            out_shape = tuple(t_shape)

        try:
            self.build_model(
                input_shape=in_shape,
                output_shape=out_shape,
            )
        except Exception as e:
            # Any build failure is reported as "not built"; the root cause
            # stays available through exception chaining.
            msg = (
                f"ModelNode '{self.label}' has not been built yet. "
                "Call `build_model()` first."
            )
            raise RuntimeError(msg) from e

    def _forward_sample_data(d: SampleData) -> SampleData:
        """
        Run backend-forward pass for a single :class:`SampleData`.

        Args:
            d (SampleData): Input sample bundle.

        Returns:
            SampleData: Output bundle preserving metadata.

        """
        # Ensure SampleData is in expected backend (modified inplace)
        d.as_backend(self.backend)

        # Pass features through internal model
        out_features = self._model(d.features, **kwargs)

        # Targets, tags, and uuids pass through without modification
        return SampleData(
            features=out_features,
            targets=d.targets,
            tags=d.tags,
            sample_uuids=d.sample_uuids,
            kind="output",
        )

    if isinstance(x, SampleData):
        return _forward_sample_data(x)

    if isinstance(x, RoleData):
        out = {k: _forward_sample_data(v) for k, v in x.items()}
        return RoleData(data=out)

    if isinstance(x, Batch):
        out = RoleData(
            data={k: _forward_sample_data(v) for k, v in x.role_data.items()},
        )
        # Shapes are recomputed from the outputs; weights and masks are reused.
        return Batch(
            batch_size=x.batch_size,
            role_data=out,
            shapes=out.shapes,
            role_weights=x.role_weights,
            role_masks=x.role_masks,
        )

    msg = f"Input must be of type SampleData or RoleData or Batch. Received: {type(x)}"
    raise TypeError(msg)
def _forward_impl(
    self,
    *,
    inputs: dict[ExperimentNodeReference, TForward],
    **kwargs,
) -> TForward:
    """
    Validate the input mapping, then hand off to :meth:`forward_single`.

    Args:
        inputs (dict[ExperimentNodeReference, TForward]):
            Single upstream tensor keyed by its reference.
        **kwargs:
            Extra arguments forwarded to :meth:`forward_single`.

    Returns:
        TForward: Batch or sample data emitted by the model.

    Raises:
        ValueError: If the mapping does not hold exactly one input.

    """
    n_inputs = len(inputs)
    if n_inputs != 1:
        msg = (
            f"{self.__class__.__name__} expects exactly one input. "
            f"Received {n_inputs}."
        )
        raise ValueError(msg)

    # Exactly one entry is guaranteed here, so tuple-unpacking is safe.
    (only_input,) = inputs.values()
    return self.forward_single(only_input, **kwargs)


__call__ = forward_single


# ================================================
# Representation
# ================================================
def _summary_rows(self) -> list[tuple]:
    """
    Build the key/value rows shown in the node summary.

    Returns:
        list[tuple]: Key/value metadata about the node.

    """
    not_built = "NOT BUILT YET"
    rows: list[tuple] = []
    rows.append(("label", self.label))
    rows.append(("upstream_ref", safe_cast_to_summary_rows(self.upstream_ref)))
    rows.append(
        (
            "downstream_refs",
            [safe_cast_to_summary_rows(r) for r in self._downstream_refs],
        ),
    )
    # Shapes are only queried once the model exists.
    rows.append(("input_shape", str(self.input_shape) if self.is_built else not_built))
    rows.append(("output_shape", str(self.output_shape) if self.is_built else not_built))
    rows.append(("model", safe_cast_to_summary_rows(self._model)))
    rows.append(("optimizer", safe_cast_to_summary_rows(self._optimizer)))
    rows.append(("backend", safe_cast_to_summary_rows(self.backend)))
    rows.append(("frozen", str(bool(self.is_frozen))))
    return rows


def __repr__(self):
    """
    Return a detailed representation intended for debugging.

    Returns:
        str: String showing labels, model, optimizer, and backend.

    """
    parts = (
        f"label='{self.label}'",
        f"upstream_refs={self._upstream_refs}",
        f"downstream_refs={self._downstream_refs}",
        f"model={self._model!r}",
        f"optimizer={self._optimizer}",
        f"backend={self.backend}",
    )
    return f"ModelNode({', '.join(parts)})"


def __str__(self):
    """
    Return a short identifier suitable for log lines.

    Returns:
        str: Node label formatted for readability.

    """
    return f"ModelNode('{self.label}')"


# ================================================
# Error Checking Methods
# ================================================
def _check_valid_optimizer(self, *, required: bool = True):
    """
    Check that the optimizer (if any) targets the same backend as the model.

    An optimizer without a backend adopts the model's backend.

    Args:
        required (bool): Whether an optimizer is required. Default is True.

    Raises:
        OptimizerNotSetError: If required and optimizer is None.
        BackendMismatchError: If optimizer and model backends differ.

    """
    opt = self._optimizer

    # No optimizer: an error only when the caller insists on one.
    if opt is None:
        if required:
            msg = f"Missing optimizer for ModelNode '{self.label}'."
            raise OptimizerNotSetError(message=msg)
        return

    # Backend not chosen yet: inherit it from the model.
    if opt.backend is None:
        opt.backend = self.backend
        return

    if opt.backend != self.backend:
        raise BackendMismatchError(
            expected=self.backend,
            received=opt.backend,
            message=f"Optimizer backend does not match model backend: {opt.backend} != {self.backend}",
        )


def _validate_ctx(self, ctx: ExecutionContext):
    """
    Check that the context holds the data this node consumes.

    Args:
        ctx (ExecutionContext): Execution context to validate.

    Raises:
        ValueError: If the required input or upstream output is missing.

    """
    upstream = self.upstream_ref

    # Fed directly by a FeatureSet: data lives in ctx.inputs under (node_id, ref)
    if isinstance(upstream, FeatureSetReference):
        required_key = (self.node_id, upstream)
        if required_key not in ctx.inputs:
            msg = f"ExecutionContext missing input data for ModelNode '{self.label}'."
            raise ValueError(msg)
        return

    # Fed by another node: its outputs must already be in ctx.outputs
    if upstream.node_id not in ctx.outputs:
        msg = f"ExecutionContext missing output data from upstream node '{upstream.node_label}'."
        raise ValueError(msg)


# ================================================
# Trainable Protocol
# ================================================
@property
def backend(self) -> Backend:
    """
    Backend of the wrapped model.

    Returns:
        Backend: TORCH, TENSORFLOW, SCIKIT, ...

    """
    return self._model.backend


@property
def is_frozen(self) -> bool:
    """
    Whether this stage is frozen (not trainable).

    Returns:
        bool: True if frozen, False if trainable.

    """
    return self._freeze
def freeze(self):
    """Mark this node as frozen and disable training on the backend model."""
    self._freeze = True

    # Mirror the flag onto the backend model
    active_backend = self.backend
    if active_backend == Backend.TORCH:
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()
        return
    if active_backend == Backend.TENSORFLOW:
        self.model.trainable = False
def unfreeze(self):
    """Mark this node as trainable and re-enable training on the backend model."""
    self._freeze = False

    # Mirror the flag onto the backend model
    active_backend = self.backend
    if active_backend == Backend.TORCH:
        for param in self.model.parameters():
            param.requires_grad = True
        self.model.train()
        return
    if active_backend == Backend.TENSORFLOW:
        self.model.trainable = True
def _get_input_batch(
    self,
    ctx: ExecutionContext,
) -> Batch:
    """
    Retrieves Batch data for this ModelNode at the current execution step.

    Args:
        ctx (ExecutionContext):
            Context whose ``inputs`` and ``outputs`` are searched.

    Returns:
        Batch: The entry of the collected input data keyed by this node's
        upstream reference, requested in the data format of this node's backend.

    """
    all_inp_data = self.get_input_data(
        inputs=ctx.inputs,
        outputs=ctx.outputs,
        fmt=get_data_format_for_backend(backend=self.backend),
    )
    return all_inp_data[self.upstream_ref]


def _train_step_torch(
    self,
    ctx: ExecutionContext,
    losses: list[AppliedLoss],
):
    """
    Runs a training step using PyTorch: forward, loss, backward, optimizer.

    The node output is written to ``ctx`` before the losses are computed,
    because each loss is computed from ``ctx``.

    Args:
        ctx (ExecutionContext):
            Context (input/output data) for the given execution step.
        losses (list[AppliedLoss]):
            List of losses to be applied in this execution step.

    """
    # Set optimizer and train mode
    self._model.train()
    self._optimizer.zero_grad()

    loss_records: list[LossRecord] = []

    # Forward pass (ctx.execution modified inplace)
    out_batch: Batch = self.forward_single(self._get_input_batch(ctx=ctx))
    ctx.set_output(node_id=self.node_id, batch=out_batch)

    # Compute losses
    for loss in losses:
        weighted_raw_loss = loss.compute(ctx=ctx)
        # Recorded as `trainable`: these values take part in backpropagation
        lr = LossRecord(
            label=loss.label,
            node_id=loss.node_id,
            trainable=weighted_raw_loss,
        )
        loss_records.append(lr)

    # Backward + opt step
    lc = LossCollection(records=loss_records)
    lc.trainable.backward()
    self._optimizer.step()

    # Record loss collection
    ctx.add_losses(lc)


def _train_step_tensorflow(
    self,
    ctx: ExecutionContext,
    losses: list[AppliedLoss],
):
    """
    Runs a training step using Tensorflow: forward, loss, backward, optimizer.

    Args:
        ctx (ExecutionContext):
            Context (input/output data) for the given execution step.
        losses (list[AppliedLoss]):
            List of losses to be applied in this execution step.

    """
    # Zero optimizer
    self._optimizer.zero_grad()

    loss_records: list[LossRecord] = []

    # Track gradients over forward passes & loss computation
    with tf.GradientTape() as tape:
        # Forward pass (ctx.execution modified inplace)
        out_batch: Batch = self.forward_single(
            self._get_input_batch(ctx=ctx),
            training=True,
        )
        ctx.set_output(node_id=self.node_id, batch=out_batch)

        # Compute losses
        for loss in losses:
            weighted_raw_loss = loss.compute(ctx=ctx)
            lr = LossRecord(
                label=loss.label,
                node_id=loss.node_id,
                trainable=weighted_raw_loss,
            )
            loss_records.append(lr)

    # Backward + opt step
    lc = LossCollection(records=loss_records)
    grads = tape.gradient(lc.trainable, self._model.trainable_variables)
    self._optimizer.step(grads=grads, variables=self._model.trainable_variables)

    # Record loss collection
    ctx.add_losses(lc)


def _train_step_scikit(
    self,
    ctx: ExecutionContext,
    losses: list[AppliedLoss],
):
    """
    Runs a training step using scikit-learn's `partial_fit`.

    Only applicable to models that support incremental learning
    (e.g., SGDRegressor, MLPRegressor). Batch-fit models should use
    `FitPhase` instead.

    Args:
        ctx (ExecutionContext):
            Context (input/output data) for the given execution step.
        losses (list[AppliedLoss]):
            List of losses to be applied in this execution step.

    Raises:
        RuntimeError: If the wrapped scikit model's resolved training mode
            is not ``PARTIAL_FIT``.

    """
    # Imported locally rather than at module top.
    # NOTE(review): presumably to keep scikit-learn optional or to avoid an
    # import cycle -- confirm.
    from modularml.core.models.scikit_wrapper import (
        ScikitModelWrapper,
        ScikitTrainingMode,
    )

    if (
        isinstance(self._model, ScikitModelWrapper)
        and self._model.resolved_training_mode != ScikitTrainingMode.PARTIAL_FIT
    ):
        msg = (
            f"ModelNode '{self.label}' wraps a batch-fit scikit model "
            f"({type(self._model.model).__name__}) that does not support "
            "incremental training. Use `fit_step` instead of `train_step`."
        )
        raise RuntimeError(msg)

    # Get input batch
    input_batch: Batch = self._get_input_batch(ctx=ctx)

    # Merge data from all roles, then partial fit on joint set
    joint_sd = SampleData.concat(
        *list(input_batch.role_data.values()),
        fmt=get_data_format_for_backend(self.backend),
    )

    # Perform incremental fit on this merged data
    self._model.partial_fit(
        joint_sd.features,
        joint_sd.targets,
    )

    # Forward pass to record outputs (equivalent to .predict())
    out_batch: Batch = self.forward_single(input_batch)
    ctx.set_output(node_id=self.node_id, batch=out_batch)

    # Compute losses (recorded as auxiliary since no gradient backprop)
    loss_records: list[LossRecord] = []
    if losses:
        for loss in losses:
            weighted_raw_loss = loss.compute(ctx=ctx)
            lr = LossRecord(
                label=loss.label,
                node_id=loss.node_id,
                auxiliary=weighted_raw_loss,
            )
            loss_records.append(lr)
    # A collection is recorded even when it is empty
    lc = LossCollection(records=loss_records)
    ctx.add_losses(lc)
def train_step(
    self,
    ctx: ExecutionContext,
    losses: list[AppliedLoss],
):
    """
    Performs a training step (forward, loss, backward, optimizer step) for this stage.

    Only callable if this stage is not frozen. Torch and TensorFlow nodes
    additionally need an optimizer; scikit-learn nodes are updated through
    ``partial_fit`` and therefore do not. Otherwise, training must be
    delegated to `ModelGraph`.

    Args:
        ctx (ExecutionContext):
            Context (input/output data) for the given execution step.
        losses (list[AppliedLoss]):
            List of losses to be applied in this execution step. Losses
            registered for other nodes are ignored.

    Raises:
        RuntimeError: If stage is frozen.
        OptimizerNotSetError: If a gradient-based backend has no optimizer.
        BackendMismatchError: If optimizer and model backends differ.
        ValueError: If required data is missing from ``ctx`` or the backend is unknown.

    """
    # If stage is frozen, raise error
    if self.is_frozen:
        msg = "Cannot train a frozen node. Either unfreeze or use `eval_step`."
        raise RuntimeError(msg)

    # Ensure input data exists for this node
    self._validate_ctx(ctx=ctx)

    # Ensure losses only include those applied to this node
    valid_losses = losses
    if losses is not None:
        valid_losses = [loss for loss in losses if loss.node_id == self.node_id]

    # The scikit training path never touches the optimizer, so it is only
    # mandatory for gradient-based backends. A provided optimizer is still
    # checked for backend compatibility in every case.
    self._check_valid_optimizer(required=self.backend != Backend.SCIKIT)

    if self.backend == Backend.TORCH:
        return self._train_step_torch(ctx=ctx, losses=valid_losses)

    if self.backend == Backend.TENSORFLOW:
        return self._train_step_tensorflow(ctx=ctx, losses=valid_losses)

    if self.backend == Backend.SCIKIT:
        return self._train_step_scikit(ctx=ctx, losses=valid_losses)

    msg = f"Unknown backend: {self.backend}"
    raise ValueError(msg)
# ================================================
# Evaluable Protocol
# ================================================
def _eval_step_torch(
    self,
    ctx: ExecutionContext,
    losses: list[AppliedLoss] | None = None,
):
    """
    Runs an evaluation step using PyTorch: forward + loss (no gradients).

    Args:
        ctx (ExecutionContext):
            Context (input/output data) for the given execution step.
        losses (list[AppliedLoss]):
            Optional list of losses to be applied in this execution step.

    """
    # Set eval mode
    self._model.eval()

    loss_records: list[LossRecord] = []

    # Forward pass (ctx.execution modified inplace)
    # NOTE(review): block nesting was lost in the extracted source; the
    # forward pass is inside `torch.no_grad()` for certain, the exact extent
    # of the `with` block should be confirmed against the repository.
    with torch.no_grad():
        out_batch: Batch = self.forward_single(self._get_input_batch(ctx=ctx))
        ctx.set_output(node_id=self.node_id, batch=out_batch)

    # Compute losses
    if losses is not None:
        for loss in losses:
            weighted_raw_loss = loss.compute(ctx=ctx)
            # Recorded as `auxiliary`: evaluation losses are never backpropagated
            lr = LossRecord(
                label=loss.label,
                node_id=loss.node_id,
                auxiliary=weighted_raw_loss,
            )
            loss_records.append(lr)

    # Record loss collection
    lc = LossCollection(records=loss_records)
    ctx.add_losses(lc)


def _eval_step_tensorflow(
    self,
    ctx: ExecutionContext,
    losses: list[AppliedLoss] | None = None,
):
    """
    Runs an evaluation step using Tensorflow: forward + loss (no gradients).

    Args:
        ctx (ExecutionContext):
            Context (input/output data) for the given execution step.
        losses (list[AppliedLoss]):
            Optional list of losses to be applied in this execution step.

    """
    loss_records: list[LossRecord] = []

    # Forward pass (ctx.execution modified inplace)
    # No GradientTape is opened here, and the model is called with training=False
    out_batch: Batch = self.forward_single(
        self._get_input_batch(ctx=ctx),
        training=False,
    )
    ctx.set_output(node_id=self.node_id, batch=out_batch)

    # Compute losses
    if losses is not None:
        for loss in losses:
            weighted_raw_loss = loss.compute(ctx=ctx)
            lr = LossRecord(
                label=loss.label,
                node_id=loss.node_id,
                auxiliary=weighted_raw_loss,
            )
            loss_records.append(lr)

    # Record loss collection
    lc = LossCollection(records=loss_records)
    ctx.add_losses(lc)


def _eval_step_scikit(
    self,
    ctx: ExecutionContext,
    losses: list[AppliedLoss] | None = None,
):
    """
    Runs an evaluation step for a scikit-learn model: forward pass + optional loss.

    Args:
        ctx (ExecutionContext):
            Context (input/output data) for the given execution step.
        losses (list[AppliedLoss]):
            Optional list of losses to be applied in this execution step.

    """
    # Forward pass
    out_batch: Batch = self.forward_single(self._get_input_batch(ctx=ctx))
    ctx.set_output(node_id=self.node_id, batch=out_batch)

    # Compute losses (auxiliary only)
    loss_records: list[LossRecord] = []
    if losses is not None:
        for loss in losses:
            weighted_raw_loss = loss.compute(ctx=ctx)
            lr = LossRecord(
                label=loss.label,
                node_id=loss.node_id,
                auxiliary=weighted_raw_loss,
            )
            loss_records.append(lr)
    # A collection is recorded even when it is empty
    lc = LossCollection(records=loss_records)
    ctx.add_losses(lc)
def eval_step(
    self,
    ctx: ExecutionContext,
    losses: list[AppliedLoss] | None = None,
):
    """
    Performs an evaluation step (forward pass and loss computation) for this stage.

    Only callable if this stage is frozen. No gradient tracking is performed.

    Args:
        ctx (ExecutionContext):
            Context (input/output data) for the given execution step.
        losses (list[AppliedLoss]):
            Optional list of losses to be applied in this execution step.
            Losses registered for other nodes are ignored.

    Raises:
        RuntimeError: If stage is not frozen.
        ValueError: If required data is missing from ``ctx`` or the backend is unknown.

    """
    # Evaluation is reserved for frozen nodes: `train_step` rejects frozen
    # nodes, so this is the only step a frozen node can run.
    if not self.is_frozen:
        msg = "Cannot evaluate an unfrozen node. Either freeze or use `train_step`."
        raise RuntimeError(msg)

    # Ensure input data exists for this node
    self._validate_ctx(ctx=ctx)

    # Ensure losses only include those applied to this node
    valid_losses = losses
    if losses is not None:
        valid_losses = [loss for loss in losses if loss.node_id == self.node_id]

    if self.backend == Backend.TORCH:
        return self._eval_step_torch(ctx=ctx, losses=valid_losses)

    if self.backend == Backend.TENSORFLOW:
        return self._eval_step_tensorflow(ctx=ctx, losses=valid_losses)

    if self.backend == Backend.SCIKIT:
        return self._eval_step_scikit(ctx=ctx, losses=valid_losses)

    msg = f"Unknown backend: {self.backend}"
    raise ValueError(msg)
# ================================================ # Fittable Protocol # ================================================
def fit_step(
    self,
    ctx: ExecutionContext,
    losses: list[AppliedLoss] | None = None,
):
    """
    Fits this node on complete data (for batch-fit scikit-learn models).

    Calls the underlying model's `.fit(X, y)` method using the data provided
    in the execution context, with all roles concatenated into one set. After
    fitting, a forward pass is performed to record outputs for downstream nodes.

    Args:
        ctx (ExecutionContext):
            Context containing full-dataset inputs.
        losses (list[AppliedLoss] | None):
            Optional losses to compute after fitting (for metrics only).
            Losses registered for other nodes are ignored.

    Raises:
        RuntimeError: If this node is frozen.
        AttributeError: If the wrapped model does not provide `.fit()`.
        ValueError: If required data is missing from ``ctx``.

    """
    if self.is_frozen:
        msg = f"Cannot fit a frozen node '{self.label}'."
        raise RuntimeError(msg)

    # `.fit()` is invoked on the wrapped model, so that is the object that
    # has to provide it (the node itself only exposes `fit_step`).
    if not hasattr(self._model, "fit"):
        msg = f"Node `{self.label}` does not implement a `.fit()` method."
        raise AttributeError(msg)

    self._validate_ctx(ctx=ctx)

    # Get input batch
    input_batch: Batch = self._get_input_batch(ctx=ctx)

    # Merge data from all roles, then fit on joint set
    joint_sd = SampleData.concat(
        *list(input_batch.role_data.values()),
        fmt=get_data_format_for_backend(self.backend),
    )

    # Perform a full (non-incremental) fit on this merged data
    self._model.fit(
        joint_sd.features,
        joint_sd.targets,
    )

    # Forward pass to record outputs for downstream nodes
    out_batch: Batch = self.forward_single(input_batch)
    ctx.set_output(node_id=self.node_id, batch=out_batch)

    # Optional loss computation (auxiliary only)
    if losses is not None:
        valid_losses = [loss for loss in losses if loss.node_id == self.node_id]
        loss_records: list[LossRecord] = []
        for loss in valid_losses:
            weighted_raw_loss = loss.compute(ctx=ctx)
            lr = LossRecord(
                label=loss.label,
                node_id=loss.node_id,
                auxiliary=weighted_raw_loss,
            )
            loss_records.append(lr)
        ctx.add_losses(LossCollection(records=loss_records))
# ================================================ # Configurable # ================================================
def get_config(self) -> dict[str, Any]:
    """
    Return the configuration of this ModelNode.

    The result extends the parent configuration with the model and optimizer
    configs, the frozen flag, and a node-type marker. It carries no state
    (weights, optimizer buffers) of the underlying model or optimizer.
    """
    cfg = super().get_config()

    model_cfg = self._model.get_config()
    optimizer_cfg = None
    if self._optimizer is not None:
        optimizer_cfg = self._optimizer.get_config()

    node_cfg = {
        "model": model_cfg,
        "optimizer": optimizer_cfg,
        "frozen": self._freeze,
        "graph_node_type": "ModelNode",
    }
    cfg.update(node_cfg)
    return cfg
@classmethod
def from_config(
    cls,
    config: dict[str, Any],
    *,
    register: bool = True,
) -> ModelNode:
    """
    Rebuild a ModelNode from a configuration dictionary.

    Only structure is restored; state of the underlying model or optimizer
    (e.g. weights) is not.
    """
    if config.get("graph_node_type") != "ModelNode":
        raise ValueError("Invalid config data for ModelNode.")

    # Model without weights
    rebuilt_model = BaseModel.from_config(config["model"])

    # Optimizer is optional in the config
    opt_cfg = config.get("optimizer")
    rebuilt_optimizer = Optimizer.from_config(opt_cfg) if opt_cfg is not None else None

    # Only the first upstream reference is used; an empty list maps to None
    node_label = config["label"]
    upstream_refs = config["upstream_refs"]
    first_upstream = upstream_refs[0] if upstream_refs else None

    node = cls(
        label=node_label,
        model=rebuilt_model,
        upstream_ref=first_upstream,
        optimizer=rebuilt_optimizer,
        node_id=config.get("node_id"),
        register=register,
    )

    # Downstream references and the frozen flag are restored after construction
    node.set_downstream_refs(config.get("downstream_refs", []))
    node._freeze = config.get("frozen", False)

    return node
# ================================================ # Stateful # ================================================
def get_state(self) -> dict[str, Any]:
    """
    Collect the serialized state of the node, its model, and its optimizer.

    Returns:
        dict[str, Any]: Snapshot captured for :meth:`set_state`.

    """
    parent_state = super().get_state()
    model_state = self._model.get_state()

    optimizer_state = None
    if self._optimizer is not None:
        optimizer_state = self._optimizer.get_state()

    return {
        "super": parent_state,
        "model": model_state,
        "optimizer": optimizer_state,
        "frozen": self._freeze,
    }
def set_state(self, state: dict[str, Any]) -> None:
    """
    Restore runtime state produced by :meth:`get_state`.

    Args:
        state (dict[str, Any]):
            Serialized node data.

    """
    # Parent state goes first
    super().set_state(state["super"])

    # Then the model
    self._model.set_state(state["model"])

    # Optimizer state is applied only when both an optimizer and saved state exist
    if self._optimizer is not None:
        saved_optimizer = state.get("optimizer")
        if saved_optimizer is not None:
            self._optimizer.set_state(saved_optimizer)

    # Go through freeze()/unfreeze() rather than setting the flag directly,
    # so the backend model is switched along with it
    apply_flag = self.freeze if state.get("frozen", False) else self.unfreeze
    apply_flag()