"""Model graph orchestration logic for ModularML."""
from __future__ import annotations
from collections import defaultdict, deque
from collections.abc import Iterable
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Any
from modularml.core.experiment.experiment_context import ExperimentContext
from modularml.core.io.checkpoint import Checkpoint
from modularml.core.io.protocols import Configurable, Stateful
from modularml.core.references.featureset_reference import FeatureSetReference
from modularml.core.topology.compute_node import ComputeNode, TForward
from modularml.core.topology.graph_node import GraphNode
from modularml.core.topology.model_node import ModelNode
from modularml.core.topology.protocols import (
Evaluable,
Fittable,
Forwardable,
Trainable,
)
from modularml.core.training.loss_record import LossCollection, LossRecord
from modularml.core.training.optimizer import Optimizer
from modularml.utils.data.comparators import deep_equal
from modularml.utils.data.data_format import DataFormat, get_data_format_for_backend
from modularml.utils.data.dummy_data import make_dummy_sample_data
from modularml.utils.environment.optional_imports import check_tensorflow, check_torch
from modularml.utils.errors.error_handling import ErrorMode
from modularml.utils.errors.exceptions import BackendNotSupportedError
from modularml.utils.logging.warnings import catch_warnings, get_logger, warn
from modularml.utils.nn.backend import (
Backend,
backend_requires_optimizer,
is_valid_backend,
)
from modularml.utils.topology.graph_search_utils import (
get_subgraph_nodes,
is_head_node,
is_tail_node,
)
from modularml.visualization.visualizer.styling import ModelGraphDisplayOptions
from modularml.visualization.visualizer.visualizer import Visualizer
if TYPE_CHECKING:
from modularml.core.data.batch import Batch
from modularml.core.data.execution_context import ExecutionContext
from modularml.core.data.sample_data import SampleData
from modularml.core.references.experiment_reference import (
ExperimentNodeReference,
GraphNodeReference,
)
from modularml.core.training.applied_loss import AppliedLoss
tf = check_tensorflow()
torch = check_torch()
logger = get_logger("ModelGraph")
[docs]
class ModelGraph(Configurable, Stateful):
"""
Directed acyclic graph orchestrating :class:`GraphNode` nodes.
Attributes:
label (str):
User-defined identifier for the graph.
nodes (dict[str, GraphNode]):
Registered nodes keyed by :attr:`GraphNode.node_id`.
_optimizer (Optimizer | None):
Optional graph-level optimizer.
"""
[docs]
def __init__(
    self,
    nodes: list[str | GraphNode] | None,
    optimizer: Optimizer | None = None,
    label: str = "model-graph",
    *,
    ctx: ExperimentContext | None = None,
    register: bool = True,
):
    """
    Build a ModelGraph around a set of nodes and an optional shared optimizer.

    Args:
        nodes (list[str | GraphNode] | None):
            GraphNode instances, or labels of nodes known to the
            ExperimentContext. If None, the context's
            `available_computenodes` mapping is used as the node set.
        optimizer (Optimizer | None, optional):
            Graph-level optimizer shared by all nodes that require one.
            When given, node-level optimizers are overridden during
            optimizer validation. When omitted, every node that requires
            an optimizer must define its own.
        label (str, optional):
            Label assigned to this graph. Defaults to "model-graph".
        ctx (ExperimentContext | None, optional):
            Context this graph belongs to. Falls back to the active
            ExperimentContext when None (or otherwise falsy).
        register (bool, optional):
            Whether to register this graph with the context. Intended to
            be disabled only during de-serialization.

    Raises:
        ValueError:
            If a string in `nodes` does not match a node in the context.
        TypeError:
            If an entry of `nodes` does not resolve to a GraphNode.

    Note:
        Any error raised while validating connections, sorting the graph,
        or validating optimizers (see `_rebuild_connections`) propagates
        out of the constructor.
    """
    # Bind to the supplied context, or the active one as a fallback
    self.exp_ctx = ctx or ExperimentContext.get_active()
    if register:
        self.exp_ctx.register_model_graph(self)

    self.label = label

    # Members of this graph, keyed by node_id
    self._nodes: dict[str, GraphNode] = {}
    if nodes is None:
        self._nodes = self.exp_ctx.available_computenodes
    else:
        for entry in nodes:
            resolved = entry
            # Strings are treated as labels and looked up in the context
            if isinstance(entry, str):
                if not self.exp_ctx.has_node(label=entry):
                    msg = (
                        f"String value given in `nodes` ('{entry}') does not exist "
                        "in the active Experiment Context."
                    )
                    raise ValueError(msg)
                resolved = self.exp_ctx.get_node(label=entry)

            if not isinstance(resolved, GraphNode):
                msg = (
                    "All objects in `nodes` must be of type GraphNode. "
                    f"Received: {type(resolved)}."
                )
                raise TypeError(msg)
            self._nodes[resolved.node_id] = resolved

    self._optimizer = optimizer

    # Optimizer bookkeeping: nodes requiring an optimizer, the node IDs the
    # global optimizer was last built from, and the info dict of that build
    self._nodes_req_opt: dict[str, ModelNode] | None = None
    self._opt_built_from_node_ids: set[str] | None = None
    self._optimizer_state: dict[str, Any] | None = None

    # Cached head/tail nodes; filled in by `_rebuild_connections`
    self._head_nodes: dict[str, GraphNode] = {}
    self._tail_nodes: dict[str, GraphNode] = {}

    self._rebuild_connections()
# ================================================
# Properties & Dunders
# ================================================
@property
def nodes(self) -> dict[str, GraphNode]:
"""All GraphNode in the ModelGraph, keyed by `node_id`."""
return self._nodes
@property
def node_labels(self) -> set[str]:
"""Returns the set of unique node labels in this ModelGraph."""
lbls = []
for n in self.nodes.values():
lbls.append(n.label)
unique_labels = set(lbls)
if len(unique_labels) != len(lbls):
msg = "This ModelGraph contains nodes with identical labels. Only unique labels are returned."
hint = "Use the `nodes` property to retrieve all unique nodes and ids."
warn(msg, category=UserWarning, hints=hint, stacklevel=2)
return unique_labels
@property
def head_nodes(self) -> dict[str, GraphNode]:
"""
Head nodes of the ModelGraph.
Description:
Head nodes are GraphNodes whose inputs originate directly from
FeatureSets (i.e., they have no upstream GraphNode dependencies).
Returns:
dict[str, GraphNode]:
Mapping of node_id to GraphNode for all head nodes.
"""
return self._head_nodes
@property
def tail_nodes(self) -> dict[str, GraphNode]:
"""
Tail nodes of the graph.
Description:
Tail nodes are GraphNodes whose outputs are not consumed
by any other GraphNode in the ModelGraph.
Returns:
dict[str, GraphNode]:
Mapping of node_id to GraphNode for all tail nodes.
"""
return self._tail_nodes
@property
def is_built(self) -> bool:
"""
Return whether the graph has been initialized.
Returns:
bool: True once :meth:`build` has run successfully.
"""
return self._built
def __eq__(self, other: ModelGraph):
"""
Return True when graph configs and states match.
Args:
other (ModelGraph): Graph compared to this instance.
Returns:
bool: True when configs and runtime state are identical.
Raises:
TypeError: If `other` is not a :class:`ModelGraph`.
"""
if not isinstance(other, ModelGraph):
msg = f"Cannot compare equality between ModelGraph and {type(other)}"
raise TypeError(msg)
if not self.label == other.label:
return False
if not deep_equal(self.get_config(), other.get_config()):
return False
return deep_equal(self.get_state(), other.get_state())
__hash__ = None
@property
def frozen_nodes(self) -> dict[str, GraphNode]:
"""All trainable but frozen nodes in this graph, keyed by `node_id`."""
return {
n_id: node
for n_id, node in self.nodes.items()
if isinstance(node, Trainable) and node.is_frozen
}
# ================================================
# Error Checking Methods
# ================================================
def _validate_graph_connections(self):
"""
Validates the internal graph structure.
Perform the following checks:
- Ensures nodes are valid GraphNode instances.
- Propagates inputs to upstream node outputs.
- Validates input/output limits.
- Ensures the graph is a DAG (no cycles).
- Ensures all nodes are reachable from at least one base node.
Raises:
TypeError: If any node is not a GraphNode.
KeyError: If a node references a non-existent input.
ValueError: If input/output constraints are violated or if a cycle is detected.
UserWarning: If unreachable nodes are found or mixed backends are used.
"""
used_backends = [] # record all node backends (for checking)
frontier = [] # get base nodes (for traversal / connection checks)
# Ensure node inherits from GraphNode, and input/ouput properties are fully set
for n_id, node in self._nodes.items():
if not isinstance(node, GraphNode):
msg = f"ModelGraph nodes must be of type GraphNode. Received: {node}"
raise TypeError(msg)
# Record backend for later checking
if hasattr(node, "backend") and is_valid_backend(node.backend):
used_backends.append(node.backend)
# Record base nodes (any GraphNode with upstream_refs = FeatureSetReference)
all_up_refs = node.get_upstream_refs()
if all(isinstance(x, FeatureSetReference) for x in all_up_refs):
frontier.append(n_id)
# Validate all upstream connections
for ups_ref in all_up_refs:
# Ensure node exists in current ctx
ups_node = ups_ref.resolve(ctx=self.exp_ctx)
# If a FeatureSet (or view), continue
if isinstance(ups_ref, FeatureSetReference):
continue
# Ensure this upstream node (a GraphNode) is also in the graph
if ups_node.node_id not in self._nodes:
msg = (
f"Upstream node '{ups_node.label}' for node '{node.label}'"
"not found in ModelGraph."
)
raise KeyError(msg)
if not isinstance(ups_node, GraphNode):
msg = (
"Non-FeatureSet references must resolve to GraphNodes. "
f"Received: {type(ups_node)}."
)
raise TypeError(msg)
# Ensure the upstream node also references this node in its output (if has outputs)
# Using "coerce" ignores the warning if the reference already exists
# but still raises error if reached max number of downstream nodes
ups_node.add_downstream_ref(
node.reference(),
error_mode=ErrorMode.COERCE,
)
# Warn if using mixed backend: not thoroughly tested
if len(set(used_backends)) > 1:
msg = (
"Mixed backends detected in ModelGraph. Though allowed, minimal testing has been "
"conducted. Gradient flow may break during training."
)
warn(msg, category=UserWarning, stacklevel=2)
# Ensure is DAG (check for cycles)
visited = set()
visiting = set()
def dfs(node: GraphNode):
"""Depth first search."""
if node.node_id in visiting:
msg = f"Cycle detected in graph at node '{node.label}'. Graph must be acyclic."
raise ValueError(msg)
if node.node_id in visited:
return
visiting.add(node.node_id)
for dwn_ref in node.get_downstream_refs(error_mode=ErrorMode.IGNORE):
dfs(dwn_ref.resolve(self.exp_ctx))
visiting.remove(node.node_id)
visited.add(node.node_id)
# Perform depth-first-search starting at head nodes
for root_node_id in frontier:
node = self.exp_ctx.get_node(node_id=root_node_id)
dfs(node)
# Ensure reachability of all nodes
reachable = set()
queue = list(frontier)
while queue:
cur_node_id = queue.pop(0)
if cur_node_id in reachable:
continue
reachable.add(cur_node_id)
cur_node: GraphNode = self._nodes[cur_node_id]
dwn_node_ids: list[str] = [
ref.node_id
for ref in cur_node.get_downstream_refs(
error_mode=ErrorMode.IGNORE,
)
]
queue.extend(dwn_node_ids)
unreachable_node_ids = set(self._nodes.keys()) - reachable
if unreachable_node_ids:
node_labels = [self.nodes[un].label for un in unreachable_node_ids]
msg = f"Unreachable nodes detected in ModelGraph: {sorted(node_labels)}."
hint = "Verify upstream reference atributes of all GraphNodes."
warn(msg, category=UserWarning, stacklevel=2, hints=hint)
def _topological_sort(self) -> list[str]:
"""
Perform a topological sort of the ModelGraph using Kahn's algorithm.
Returns:
List[str]: A list of node labels in topological (execution) order.
Raises:
ValueError: If a cycle is detected in the graph.
"""
in_degree = defaultdict(int) # Number of incoming edges (keyed by node ID)
children = defaultdict(list) # Outgoing edges (keyed by node ID)
all_node_ids = set(self._nodes.keys())
# Initialize in-degrees
for node_id in all_node_ids:
in_degree[node_id] = 0
# Record in-degree (number of inputs) and out-degree for each node
for node_id, node in self._nodes.items():
# Get parents of this node (ie, upstream)
parent_node_ids: list[str] = [
ref.node_id
for ref in node.get_upstream_refs(
error_mode=ErrorMode.IGNORE,
)
if not isinstance(ref, FeatureSetReference)
]
for parent_id in parent_node_ids:
if parent_id not in self._nodes:
p_node = self.exp_ctx.get_node(node_id=parent_id)
msg = f"Invalid upstream_node '{p_node.label}' for node `{node.label}`."
raise KeyError(msg)
in_degree[node_id] += 1
children[parent_id].append(node_id)
# Init a queue with base nodes (no inputs)
sorted_node_ids: list[str] = []
queue = deque([node_id for node_id in all_node_ids if in_degree[node_id] == 0])
while queue:
current = queue.popleft()
sorted_node_ids.append(current)
for child in children[current]:
in_degree[child] -= 1
if in_degree[child] == 0:
queue.append(child)
if len(sorted_node_ids) != len(all_node_ids):
unresolved = all_node_ids - set(sorted_node_ids)
unres_node_lbls = [
self.exp_ctx.get_node(node_id=un).label for un in unresolved
]
msg = f"Cyclic dependency detected in ModelGraph: {unres_node_lbls}"
raise ValueError(msg)
return sorted_node_ids
def _validate_optimizer(self):
"""
Validate and assign optimizers to all trainable stages in the graph.
Description:
This method ensures that all GraphNodes which require an optimizer
(based on their backend) are properly configured with one.
- If a global optimizer is provided to the ModelGraph, it will be assigned
to all relevant stages. If those stages already define a local optimizer,
it will be overwritten with a warning.
- If no global optimizer is provided, then every stage that requires an optimizer
must have its own stage-level optimizer defined.
- It also verifies that all optimizers share a consistent backend (e.g., PyTorch).
- If a node without a `backend` attribute exists on a path between two
optimizer-requiring nodes, a global optimizer cannot be used (e.g., a
static merge node that doesn't support gradient propagation).
Raises:
RuntimeError: If any stage that requires an optimizer is missing one and no
global optimizer is provided.
RuntimeError: If a global optimizer is provided but its backend doesn't match
a stage's backend.
RuntimeError: If a node without a backend sits between optimizer-requiring
nodes when a global optimizer is used.
UserWarning: If a stage has its own optimizer but is being overwritten by the
graph-level optimizer.
"""
# Get nodes that require optimizer (only ModelNodes)
self._nodes_req_opt: dict[str, ModelNode] = {
node_id: node
for node_id, node in self._nodes.items()
if isinstance(node, ModelNode) and backend_requires_optimizer(node.backend)
}
# Ensure all stages have their own optimizer if global one not provided
if self._optimizer is None:
for node in self._nodes_req_opt.values():
if not hasattr(node, "_optimizer") or node._optimizer is None:
msg = (
f"ModelNode '{node.label}' is missing an optimizer. "
f"Provide one at the stage level or to ModelGraph."
)
raise RuntimeError(msg)
# Ensure all stages have the same backend
else:
used_backends = []
for node in self._nodes_req_opt.values():
used_backends.append(node.backend)
# Overwrite existing optimizers at stage-level (and warn)
if node._optimizer is not None:
msg = (
f"An optimizer was provided to both the ModelGraph and the '{node.label}' "
f"ModelNode. The optimizer for '{node.label}' will be overwritten."
)
warn(msg, category=UserWarning, stacklevel=2)
node._optimizer = None
# Warn if using mixed backend: not thoroughly tested
if len(set(used_backends)) > 1:
msg = (
"A global optimizer was provided to ModelGraph, but the underlying stages have "
"differing backends. All backends must match to use a single optimizer."
)
raise RuntimeError(msg)
# Check for intermediate nodes without a backend between optimizer-requiring nodes
opt_node_ids = set(self._nodes_req_opt.keys())
if len(opt_node_ids) > 1:
upstream_of_opt: set[str] = set()
downstream_of_opt: set[str] = set()
for nid in opt_node_ids:
upstream_of_opt |= get_subgraph_nodes(
self,
nid,
direction="upstream",
include_roots=False,
)
downstream_of_opt |= get_subgraph_nodes(
self,
nid,
direction="downstream",
include_roots=False,
)
between_ids = (upstream_of_opt & downstream_of_opt) - opt_node_ids
for nid in between_ids:
node = self._nodes[nid]
if not hasattr(node, "backend"):
msg = (
f"A global optimizer was provided to ModelGraph, but node "
f"'{node.label}' sits between optimizer-requiring stages and "
f"does not have a backend. Cannot use a global optimizer."
)
raise RuntimeError(msg)
self._optimizer.backend = used_backends[0]
def _rebuild_connections(self):
"""Recompute cached graph metadata after topology changes."""
# Clear downstream references (auto-generated in validation)
for n in self._nodes.values():
if hasattr(n, "clear_downstream_refs"):
n.clear_downstream_refs(ErrorMode.IGNORE)
# Validate graph and connections
self._validate_graph_connections()
# Cache head/tail nodes (must be after validation ^)
head_nodes: dict[str, GraphNode] = {}
tail_nodes: dict[str, GraphNode] = {}
for n_id, node in self._nodes.items():
# Head nodes: inputs from a FeatureSet
if is_head_node(node):
head_nodes[n_id] = node
# Tail nodes: no downstream consumers
if is_tail_node(node):
tail_nodes[n_id] = node
self._head_nodes = head_nodes
self._tail_nodes = tail_nodes
# Topological sort
self._sorted_node_ids = self._topological_sort()
# If an optimizer is provided, check that:
# 1. all optimizer-requiring stages have same backend
# 2. warn if stages have their own optimizer (will be overwritten)
self._validate_optimizer()
self._built = False
# ================================================
# Connection Modifiers
# ================================================
def _resolve_existing(self, val: str | GraphNode) -> GraphNode:
"""
Verifies that the given values corresponds to a node in this graph.
Args:
val (str | GraphNode):
Node ID, label, or instance of a node in this graph.
Returns:
GraphNode:
The node instance of the existing node. Throws an error
if the value cannot be resolved to an existing node.
"""
# Normalize value to GraphNode instance
existing_node: GraphNode | None = None
if isinstance(val, str):
if val in self._nodes:
existing_node = self._nodes[val]
else:
# Get existing node labels, if not unique, throw error
existing_node_lbls = None
with catch_warnings() as cw:
existing_node_lbls = self.node_labels
if cw.match("contains nodes with identical labels"):
existing_node_lbls = None
if existing_node_lbls is None:
msg = (
"ModelGraph contains nodes with identical labels. Existing nodes "
"must be referenced with either their node ID or the actual instance."
)
raise ValueError(msg)
# Get node instance with that label
matches = [n for n in self._nodes.values() if n.label == val]
if not matches:
msg = f"No node exists in this graph with label '{val}'."
raise ValueError(msg)
existing_node = matches[0]
elif isinstance(val, GraphNode):
if val.node_id not in self._nodes:
msg = f"No node exists in this graph with id '{val.node_id}'."
raise ValueError(msg)
existing_node = val
else:
msg = f"Existing node value must be of type `str` or `GraphNode`. Received: {type(val)}."
raise TypeError(msg)
return existing_node
[docs]
def add_node(self, node: GraphNode) -> ModelGraph:
    """
    Register an additional node with the graph.

    Description:
        Only the graph structure changes: the node is stored under its
        `node_id` and the connection metadata is rebuilt. States of
        existing nodes are neither reset nor copied. The node is expected
        to already be registered in the ExperimentContext.

    Args:
        node (GraphNode): Node to add.

    Returns:
        ModelGraph: self

    Raises:
        TypeError: If `node` is not a GraphNode.
        ValueError: If a node with the same `node_id` already exists.
    """
    if not isinstance(node, GraphNode):
        msg = f"Expected GraphNode, got {type(node)}"
        raise TypeError(msg)

    already_present = node.node_id in self._nodes
    if already_present:
        msg = f"Node '{node.label}' already exists in ModelGraph."
        raise ValueError(msg)

    self._nodes[node.node_id] = node
    self._rebuild_connections()
    return self
[docs]
def replace_node(
    self,
    old_node: str | GraphNode,
    new_node: GraphNode,
) -> ModelGraph:
    """
    Swap an existing node for a new one, keeping its connections.

    Description:
        `new_node` receives the upstream and downstream references of
        `old_node`, and every neighboring node is re-pointed from
        `old_node` to `new_node`. Other nodes are otherwise untouched, and
        `old_node` itself is only removed from this graph.

    Args:
        old_node (str | GraphNode):
            Label, node ID, or instance of the node to be replaced.
        new_node (GraphNode):
            Node instance that takes the place of `old_node`.

    Returns:
        ModelGraph: self

    Raises:
        TypeError: If `new_node` is not a GraphNode.
        ValueError: If `old_node` cannot be resolved to a node of this
            graph (see `_resolve_existing`).
    """
    old_node = self._resolve_existing(val=old_node)

    if not isinstance(new_node, GraphNode):
        msg = f"New node must be of type GraphNode, got {type(new_node)}"
        raise TypeError(msg)

    old_id = old_node.node_id

    # Snapshot the connections of the node being replaced
    ups_refs: list[GraphNodeReference] = old_node.get_upstream_refs(
        error_mode=ErrorMode.IGNORE,
    )
    dwn_refs: list[GraphNodeReference] = old_node.get_downstream_refs(
        error_mode=ErrorMode.IGNORE,
    )

    # The replacement inherits both sides
    new_node.set_upstream_refs(ups_refs, error_mode=ErrorMode.COERCE)
    new_node.set_downstream_refs(dwn_refs, error_mode=ErrorMode.COERCE)

    # Upstream neighbors: drop the ref to old_node, then point at new_node
    for ref in ups_refs:
        # FeatureSet refs are not graph nodes; nothing to rewire
        if isinstance(ref, FeatureSetReference):
            continue
        parent = self._nodes[ref.node_id]
        kept_refs = []
        for r in parent.get_downstream_refs(error_mode=ErrorMode.IGNORE):
            if r.node_id != old_id:
                kept_refs.append(r)
        parent.set_downstream_refs(
            kept_refs,
            error_mode=ErrorMode.IGNORE,
        )
        parent.add_downstream_ref(
            new_node.reference(),
            error_mode=ErrorMode.IGNORE,
        )

    # Downstream neighbors: drop the ref to old_node, then point at new_node
    # NOTE(review): the new reference is added after the remaining refs, so
    # its position among a multi-input neighbor's inputs may differ from
    # old_node's original position — confirm whether ref order matters.
    for ref in dwn_refs:
        child = self._nodes[ref.node_id]
        kept_refs = []
        for r in child.get_upstream_refs(error_mode=ErrorMode.IGNORE):
            if r.node_id != old_id:
                kept_refs.append(r)
        child.set_upstream_refs(
            kept_refs,
            error_mode=ErrorMode.IGNORE,
        )
        child.add_upstream_ref(
            new_node.reference(),
            error_mode=ErrorMode.IGNORE,
        )

    # Swap the entry in the node registry and refresh cached metadata
    self._nodes.pop(old_node.node_id)
    self._nodes[new_node.node_id] = new_node
    self._rebuild_connections()
    return self
[docs]
def insert_node_between(
    self,
    new_node: GraphNode,
    *,
    upstream: str | GraphNode,
    downstream: str | GraphNode,
) -> ModelGraph:
    """
    Insert a node between two already connected nodes.

    Description:
        The existing connection (upstream -> downstream) is replaced with
        (upstream -> new_node -> downstream). Connectivity of the two
        existing nodes is validated before anything is modified, so a
        rejected call leaves `new_node` and the graph untouched. Any
        references previously set on `new_node` are cleared once
        validation has passed.

    Args:
        new_node (GraphNode):
            New node instance to be inserted.
        upstream (str | GraphNode):
            Node ID, label, or instance of an existing ModelGraph node.
            The `new_node` will be inserted downstream of this node.
        downstream (str | GraphNode):
            Node ID, label, or instance of an existing ModelGraph node.
            The `new_node` will be inserted upstream of this node.

    Returns:
        ModelGraph: self

    Raises:
        TypeError: If `new_node` is not a GraphNode.
        ValueError: If `upstream` and `downstream` are not directly
            connected, or cannot be resolved to nodes of this graph.
    """
    ups_node: GraphNode = self._resolve_existing(val=upstream)
    dwn_node: GraphNode = self._resolve_existing(val=downstream)

    if not isinstance(new_node, GraphNode):
        msg = f"New node must be of type GraphNode, got {type(new_node)}"
        raise TypeError(msg)

    # `downstream` must take input from `upstream`
    # (if several refs match, the last one is used, as before)
    existing_dwn_to_ups = None
    for ref in dwn_node.get_upstream_refs(error_mode=ErrorMode.IGNORE):
        if ref.node_id == ups_node.node_id:
            existing_dwn_to_ups = ref
    if existing_dwn_to_ups is None:
        dwn_ups_n_lbls = [
            ref.node_label
            for ref in dwn_node.get_upstream_refs(
                error_mode=ErrorMode.IGNORE,
            )
        ]
        msg = f"`downstream` does not take input `upstream`. Detected inputs from: {dwn_ups_n_lbls}."
        raise ValueError(msg)

    # `upstream` must output to `downstream`
    existing_ups_to_dwn = None
    for ref in ups_node.get_downstream_refs(error_mode=ErrorMode.IGNORE):
        if ref.node_id == dwn_node.node_id:
            existing_ups_to_dwn = ref
            break
    if existing_ups_to_dwn is None:
        ups_dwn_n_lbls = [
            ref.node_label
            for ref in ups_node.get_downstream_refs(
                error_mode=ErrorMode.IGNORE,
            )
        ]
        msg = f"`upstream` does not output to `downstream`. Detected outputs to: {ups_dwn_n_lbls}."
        raise ValueError(msg)

    # Validation passed: only now start mutating nodes
    new_node.clear_upstream_refs(error_mode=ErrorMode.COERCE)
    new_node.clear_downstream_refs(error_mode=ErrorMode.COERCE)

    # 'ups -> dwn' becomes 'ups -> new' on the upstream side
    ups_node.remove_downstream_ref(
        ref=existing_ups_to_dwn,
        error_mode=ErrorMode.RAISE,
    )
    ups_node.add_downstream_ref(
        ref=new_node.reference(),
        error_mode=ErrorMode.RAISE,
    )

    # 'ups -> dwn' becomes 'new -> dwn' on the downstream side
    dwn_node.remove_upstream_ref(
        ref=existing_dwn_to_ups,
        error_mode=ErrorMode.RAISE,
    )
    dwn_node.add_upstream_ref(
        ref=new_node.reference(),
        error_mode=ErrorMode.RAISE,
    )

    # Wire the new node itself
    new_node.add_upstream_ref(ref=ups_node.reference())
    new_node.add_downstream_ref(ref=dwn_node.reference())

    self._nodes[new_node.node_id] = new_node
    self._rebuild_connections()
    return self
[docs]
def insert_node_before(
    self,
    new_node: GraphNode,
    *,
    downstream: str | GraphNode,
) -> ModelGraph:
    """
    Place a new node directly upstream of an existing node.

    Description:
        Every input reference of `downstream` is handed over to
        `new_node`; afterwards `downstream` receives input from `new_node`
        only. References previously set on `new_node` are cleared first.

    Args:
        new_node (GraphNode):
            New node instance to be inserted.
        downstream (str | GraphNode):
            Node ID, label, or instance of an existing ModelGraph node
            that `new_node` is placed in front of.

    Returns:
        ModelGraph: self

    Raises:
        TypeError: If `new_node` is not a GraphNode.
        ValueError: If `downstream` cannot be resolved to a node of this
            graph (see `_resolve_existing`).
    """
    target: GraphNode = self._resolve_existing(val=downstream)

    if not isinstance(new_node, GraphNode):
        msg = f"New node must be of type GraphNode, got {type(new_node)}"
        raise TypeError(msg)

    # Start from a clean slate on the inserted node
    new_node.clear_upstream_refs(error_mode=ErrorMode.COERCE)
    new_node.clear_downstream_refs(error_mode=ErrorMode.COERCE)

    # Hand the target's inputs over to the inserted node
    inherited_refs = target.get_upstream_refs(error_mode=ErrorMode.IGNORE)
    for inherited in inherited_refs:
        new_node.add_upstream_ref(ref=inherited)

    # The target now listens to the inserted node only
    target.clear_upstream_refs()
    target.add_upstream_ref(new_node.reference())
    new_node.add_downstream_ref(target.reference())

    self._nodes[new_node.node_id] = new_node
    self._rebuild_connections()
    return self
[docs]
def insert_node_after(
    self,
    new_node: GraphNode,
    *,
    upstream: str | GraphNode,
) -> ModelGraph:
    """
    Attach a new node as an additional output of an existing node.

    Description:
        A connection from `upstream` to `new_node` is added. Existing
        output connections of `upstream` are left as they are, and
        `new_node` receives input from `upstream` only. References
        previously set on `new_node` are cleared first.

    Args:
        new_node (GraphNode):
            New node instance to be inserted.
        upstream (str | GraphNode):
            Node ID, label, or instance of an existing ModelGraph node
            that `new_node` is placed behind.

    Returns:
        ModelGraph: self

    Raises:
        TypeError: If `new_node` is not a GraphNode.
        ValueError: If `upstream` cannot be resolved to a node of this
            graph (see `_resolve_existing`).
    """
    source: GraphNode = self._resolve_existing(val=upstream)

    if not isinstance(new_node, GraphNode):
        msg = f"New node must be of type GraphNode, got {type(new_node)}"
        raise TypeError(msg)

    # Start from a clean slate on the inserted node
    new_node.clear_upstream_refs(error_mode=ErrorMode.COERCE)
    new_node.clear_downstream_refs(error_mode=ErrorMode.COERCE)

    # Link both directions: source -> new_node
    source.add_downstream_ref(new_node.reference())
    new_node.add_upstream_ref(source.reference())

    self._nodes[new_node.node_id] = new_node
    self._rebuild_connections()
    return self
[docs]
def remove_node(self, node: str | GraphNode) -> ModelGraph:
    """
    Remove an existing node from the graph.

    Description:
        The node is removed and its neighbors are reconnected: every node
        downstream of `node` takes input from *all* references that
        previously fed `node`, and every upstream GraphNode outputs to all
        nodes previously downstream of `node`. Upstream FeatureSet
        references are passed on to the downstream nodes, but are not
        themselves rewired, since they are not graph nodes.

    Args:
        node (str | GraphNode):
            Node ID, label, or instance of an existing ModelGraph node
            to be removed.

    Example:
        1. Removing an existing single-input, single-output node.
           Given: `A -> B -> C`
           Remove: `B`
           Result: `A -> C`
        2. Removing an existing multi-input, single-output node.
           Given: `[A, B] -> C -> D`
           Remove: `C`
           Result: `[A, B] -> D`
           Note that `D` must be able to accept multiple inputs or an error
           will be thrown.
        3. Removing an existing single-input, multi-output node.
           Given: `A -> B -> [C, D]`
           Remove: `B`
           Result: `A -> [C, D]`
           Note that `A` must be able to feed multiple outputs or an error
           will be thrown.

    Returns:
        ModelGraph: self

    Raises:
        ValueError: If `node` cannot be resolved to a node of this graph
            (see `_resolve_existing`).
    """
    node: GraphNode = self._resolve_existing(val=node)

    # Snapshot both sides before any rewiring
    all_ups_refs = node.get_upstream_refs(error_mode=ErrorMode.IGNORE)
    all_dwn_refs = node.get_downstream_refs(error_mode=ErrorMode.IGNORE)

    # Downstream nodes now take input from everything that fed `node`
    for dwn_ref in all_dwn_refs:
        dwn_node = dwn_ref.resolve(ctx=self.exp_ctx)
        dwn_node.remove_upstream_ref(node.reference())
        for r in all_ups_refs:
            dwn_node.add_upstream_ref(r)

    # Upstream nodes now output to everything `node` used to feed
    for ups_ref in all_ups_refs:
        # FeatureSet references are not graph nodes (they are skipped the
        # same way when replacing nodes and validating connections), so
        # there are no downstream refs to maintain on them
        if isinstance(ups_ref, FeatureSetReference):
            continue
        ups_node = ups_ref.resolve(ctx=self.exp_ctx)
        ups_node.remove_downstream_ref(node.reference())
        for r in all_dwn_refs:
            ups_node.add_downstream_ref(r)

    # Drop the node and refresh cached metadata
    self._nodes.pop(node.node_id)
    self._rebuild_connections()
    return self
# ================================================
# Graph Construction
# ================================================
def _select_optimizer_parameters(
self,
nodes_to_include: list[str] | None = None,
*,
include_only_unfrozen: bool = True,
) -> tuple[dict[str, list[Any]], set[str]]:
"""
Collect trainable parameters / variables from ModelNodes for optimizer usage.
Args:
nodes_to_include (list[str] | None):
A list of node IDs to consider for parameter extraction.
If None, all nodes in this graph are considered.
include_only_unfrozen (bool, optional):
If True, only nodes in `nodes_to_include` that are not frozen
will be used for parameter extraction. If False, all nodes in
`nodes_to_include` are used (i.e., ignores any frozen state).
Returns:
dict[str, list[Any]]:
A dict with the set of node_ids actually contributing parameters,
and backend specific fields:
- `"backend": Backend,`
- `"contributing_nodes": set[str],`
- `"parameters": list[torch.nn.Parameter], # PyTorch only`
- `"variables": list[tf.Variable], # TensorFlow only`
"""
if self._optimizer is None:
raise ValueError("No global optimizer exists for the graph.")
# Select candidate nodes
node_ids = (
set(nodes_to_include)
if nodes_to_include is not None
else set(self._nodes_req_opt.keys())
)
selected_nodes: list[ModelNode] = []
for nid in node_ids:
node = self._nodes.get(nid)
if node is None:
continue
if not isinstance(node, ModelNode):
continue
if include_only_unfrozen and node.is_frozen:
continue
selected_nodes.append(node)
used_node_ids = {n.node_id for n in selected_nodes}
# Collect backend-specific trainables
backend = self._optimizer.backend
if backend == Backend.TORCH:
parameters = []
for node in selected_nodes:
if not hasattr(node.model, "parameters"):
msg = f"ModelNode '{node.label}' does not expose .parameters()"
raise AttributeError(msg)
parameters.extend(list(node.model.parameters()))
return {
"backend": Backend.TORCH,
"parameters": parameters,
"contributing_nodes": used_node_ids,
}
if backend == Backend.TENSORFLOW:
variables = []
for node in selected_nodes:
if not hasattr(node.model, "trainable_variables"):
msg = (
f"ModelNode '{node.label}' does not expose trainable_variables"
)
raise AttributeError(msg)
variables.extend(list(node.model.trainable_variables))
return {
"backend": Backend.TENSORFLOW,
"variables": variables,
"contributing_nodes": used_node_ids,
}
if backend == Backend.SCIKIT:
raise NotImplementedError("Scikit optimizers are not supported.")
raise BackendNotSupportedError(
backend=backend,
message="Unknown backend for optimizer parameter collection.",
)
def _build_optimizer(
    self,
    nodes_to_include: list[str] | None = None,
    *,
    include_only_unfrozen: bool = True,
    force: bool = False,
):
    """
    (Re)build the global optimizer from the parameters of the selected nodes.

    Args:
        nodes_to_include (list[str] | None):
            Node IDs to consider for parameter extraction. The value is
            forwarded unchanged to `_select_optimizer_parameters`, which
            chooses the default candidate set when None is given.
        include_only_unfrozen (bool, optional):
            If True, frozen nodes are excluded from parameter extraction.
            If False, the frozen state of the candidate nodes is ignored.
        force (bool, optional):
            If False, the optimizer is rebuilt only when the set of
            contributing node IDs differs from the set used for the
            previous build. If True, it is always rebuilt.

    Raises:
        ValueError: If no global optimizer is attached to the graph.
        NotImplementedError: If the optimizer uses the scikit backend.
        BackendNotSupportedError: If the optimizer backend is not recognized.

    """
    optimizer = self._optimizer
    if optimizer is None:
        msg = "No global optimizer exists for the graph."
        raise ValueError(msg)

    # Resolve which nodes contribute and collect their trainables
    selection = self._select_optimizer_parameters(
        nodes_to_include=nodes_to_include,
        include_only_unfrozen=include_only_unfrozen,
    )
    contributing = selection["contributing_nodes"]

    # Nothing to do when the same nodes were used for the last build
    unchanged = self._opt_built_from_node_ids == contributing
    if unchanged and not force:
        return

    # Backend-specific construction
    selected_backend = selection["backend"]
    if selected_backend == Backend.TORCH:
        optimizer.build(
            parameters=selection["parameters"],
            backend=Backend.TORCH,
            force_rebuild=True,
        )
    elif selected_backend == Backend.TENSORFLOW:
        # No variables are handed over here; they are supplied at step time
        optimizer.build(
            backend=Backend.TENSORFLOW,
            force_rebuild=True,
        )
    elif optimizer.backend == Backend.SCIKIT:
        msg = "Scikit optimizers are not supported yet."
        raise NotImplementedError(msg)
    else:
        raise BackendNotSupportedError(
            backend=optimizer.backend,
            message="Unknown backend for optimizer building.",
        )

    # Remember what the optimizer was built from
    self._opt_built_from_node_ids = set(contributing)
    self._optimizer_state = selection
[docs]
def get_optimizer_parameters(self) -> dict[str, Any]:
    """
    Return the parameter selection the global optimizer was last built from.

    Returns:
        dict[str, Any]:
            The dict stored by the most recent optimizer build. It holds the
            node IDs that contributed parameters plus backend-specific fields:

            - `"backend": Backend,`
            - `"contributing_nodes": set[str],`
            - `"parameters": list[torch.nn.Parameter],  # PyTorch only`
            - `"variables": list[tf.Variable],  # TensorFlow only`

    Raises:
        ValueError: If the graph has no global optimizer.
        RuntimeError: If the optimizer has not been built yet.

    """
    if self._optimizer is None:
        raise ValueError("No global optimizer exists.")

    cached = self._optimizer_state
    if cached is not None:
        return cached

    msg = (
        "Optimizer has not been built yet. "
        "Call train_step() or build_optimizer() first."
    )
    raise RuntimeError(msg)
[docs]
def build(self, *, force: bool = False):
    """
    Build every compute node in the graph, then the global optimizer.

    Description:
        Nodes are visited in `self._sorted_node_ids` order. For each node
        that needs building, input feature/target shapes are taken either
        from the upstream FeatureSet (batch dimension dropped) or from the
        shapes recorded for upstream nodes visited earlier in this call.
        After a node is built, its output shapes are recorded so that
        downstream nodes can be built in turn.

    Args:
        force (bool, optional):
            If the graph is already built it will not be rebuilt unless
            `force=True`. Defaults to False.

    Raises:
        RuntimeError: If an upstream node has no recorded output shape.

    """
    # Already built and no rebuild requested
    if self.is_built and not force:
        return

    # Revalidate all connections
    self._rebuild_connections()

    # Output shapes recorded per node_id while walking the graph
    feature_shapes: dict[str, tuple[int, ...]] = {}
    target_shapes: dict[str, tuple[int, ...]] = {}

    for node_id in self._sorted_node_ids:
        node = self._nodes[node_id]

        # Only compute nodes are buildable
        if not isinstance(node, ComputeNode):
            continue
        # NOTE(review): a node skipped here records no shapes, so an unbuilt
        # downstream node would hit the RuntimeError below - confirm intended.
        if node.is_built and not force:
            continue

        # ------------------------------------------------
        # Gather input feature/target shapes per upstream reference
        # ------------------------------------------------
        ups_refs = node.get_upstream_refs()
        is_single_input = len(ups_refs) == 1
        in_shapes: dict[ExperimentNodeReference, tuple[int, ...]] = {}
        in_target_shapes: dict[ExperimentNodeReference, tuple[int, ...]] = {}
        for ups_ref in ups_refs:
            if not isinstance(ups_ref, FeatureSetReference):
                # Upstream is another graph node: reuse its recorded shapes
                if ups_ref.node_id not in feature_shapes:
                    msg = f"Input shape could not be inferred for node '{node.label}'."
                    raise RuntimeError(msg)
                in_shapes[ups_ref] = feature_shapes[ups_ref.node_id]
                in_target_shapes[ups_ref] = target_shapes[ups_ref.node_id]
                continue

            # Upstream is a FeatureSet: read shapes directly, dropping the
            # leading (batch) dimension via [1:]
            fsv = ups_ref.resolve()
            in_shapes[ups_ref] = tuple(
                fsv.get_features(fmt=DataFormat.NUMPY).shape[1:],
            )
            in_target_shapes[ups_ref] = tuple(
                fsv.get_targets(fmt=DataFormat.NUMPY).shape[1:],
            )

        # ------------------------------------------------
        # Output shape is only fixed for single-input tail nodes, where it
        # equals the propagated target shape
        # ------------------------------------------------
        out_shape: tuple[int, ...] | None = None
        if node_id in self.tail_nodes and is_single_input:
            out_shape = next(iter(in_target_shapes.values()))

        # ------------------------------------------------
        # Build the node
        # ------------------------------------------------
        backend = None if self._optimizer is None else self._optimizer.backend
        node.build(
            input_shapes=in_shapes,
            output_shape=out_shape,
            force=force,  # For ModelNode
            includes_batch_dim=False,  # For MergeNode
            backend=backend,  # For MergeNode
        )

        # ------------------------------------------------
        # Record output shapes for downstream nodes
        # ------------------------------------------------
        if is_single_input:
            if out_shape is None:
                # Infer the feature shape from a dummy forward pass
                probe = {
                    ref: make_dummy_sample_data(
                        batch_size=4,
                        feature_shape=in_shapes[ref],
                    )
                    for ref in ups_refs
                }
                sd_out: SampleData = node.forward(probe)
                feature_shapes[node_id] = tuple(sd_out.shapes.features_shape)
            else:
                feature_shapes[node_id] = out_shape
            # Targets pass through single-input nodes unchanged
            target_shapes[node_id] = next(iter(in_target_shapes.values()))
            continue

        # Multi-input node: features and targets may both change, so a
        # dummy forward pass (with targets) determines both
        probe = {
            ref: make_dummy_sample_data(
                batch_size=4,
                feature_shape=in_shapes[ref],
                target_shape=in_target_shapes[ref],
            )
            for ref in ups_refs
        }
        sd_out: SampleData = node.forward(probe)
        feature_shapes[node_id] = tuple(sd_out.shapes.features_shape)
        merged_targets = sd_out.shapes.targets_shape
        if merged_targets is None:
            target_shapes[node_id] = next(iter(in_target_shapes.values()))
        else:
            target_shapes[node_id] = tuple(merged_targets)

    # Build/rebuild optimizer
    if self._optimizer is not None:
        self._build_optimizer(force=force)

    # Update flag
    self._built = True
# ================================================
# Forward / Calling
# ================================================
[docs]
def forward(
    self,
    inputs: dict[tuple[str, FeatureSetReference], TForward],
    *,
    active_nodes: list[str | GraphNode] | None = None,
) -> dict[str, Batch]:
    """
    Run a forward pass over the graph in topological order.

    Args:
        inputs (dict[tuple[str, FeatureSetReference], TForward]):
            Mapping of (head_node_id, upstream_featureset_ref) -> TForward.
            Keys must correspond to head nodes in this graph (nodes whose
            upstream refs are FeatureSetReferences). A head node may have
            several entries if it consumes several FeatureSets.
        active_nodes (list[str | GraphNode] | None, optional):
            Optional subset of nodes to execute. When given, these nodes and
            all of their upstream dependencies are executed. When None,
            every node in the graph is executed.

    Returns:
        dict[str, TForward]:
            Mapping of node_id -> output for every executed node that is
            `Forwardable`. Tail-node outputs can be obtained by filtering
            this dict with `self.tail_nodes`.

    Raises:
        RuntimeError: If the graph has not been built.

    """
    if not self.is_built:
        raise RuntimeError("ModelGraph must be built before calling forward().")

    # Expand the requested nodes to include everything upstream of them
    if active_nodes is not None:
        active_node_ids = get_subgraph_nodes(
            graph=self,
            roots=active_nodes,
            direction="upstream",
            include_roots=True,
        )
    else:
        active_node_ids: set[str] = set(self._nodes.keys())

    # Walk the sorted IDs so that upstream outputs exist before they are needed
    results: dict[str, TForward] = {}
    for n_id in self._sorted_node_ids:
        if n_id not in active_node_ids:
            continue
        node = self.nodes[n_id]
        if not isinstance(node, Forwardable):
            continue

        # Inputs come from the caller (head nodes) or from earlier outputs
        node_inputs = node.get_input_data(
            inputs=inputs,
            outputs=results,
            fmt=get_data_format_for_backend(node.backend),
        )
        results[n_id] = node.forward(node_inputs)

    return results

__call__ = forward
# ================================================
# Trainable Protocol
# ================================================
@property
def backend(self) -> Backend | None:
    """
    The shared backend of this ModelGraph.

    Description:
        A ModelGraph's backend is only defined if a global optimizer
        is used. Without a global optimizer (e.g., a graph of
        mixed-backend nodes), None is returned.

    Returns:
        Backend | None:
            The backend of the global optimizer, if one is attached.
            Otherwise, None.

    """
    optimizer = self._optimizer
    return None if optimizer is None else optimizer.backend
[docs]
def freeze(self, nodes: list[str | GraphNode] | None = None):
    """
    Sets the trainable state of `nodes` to frozen.

    Args:
        nodes (list[str | GraphNode] | None):
            Node IDs, labels, or instances to freeze. A single string or
            GraphNode is also accepted and treated as a one-element list.
            Each entry is resolved via `self._resolve_existing`.
            If None, every Trainable node in this graph is frozen.
            Entries that resolve to a non-Trainable node are skipped with
            a debug log message.

    """
    # Normalize the request into a list of GraphNode instances
    if nodes is None:
        targets = [n for n in self.nodes.values() if isinstance(n, Trainable)]
    else:
        # A bare string is iterable, so it must be wrapped explicitly
        if isinstance(nodes, (str, GraphNode)) or not isinstance(nodes, Iterable):
            nodes = [nodes]
        targets: list[GraphNode] = [self._resolve_existing(n) for n in nodes]

    # Freeze every trainable target
    for node in targets:
        if not isinstance(node, Trainable):
            msg = f"GraphNode '{node.label}' is not Trainable. It cannot be frozen."
            logger.debug(msg)
            continue
        node.freeze()
[docs]
def unfreeze(self, nodes: list[str | GraphNode] | None = None):
    """
    Sets the trainable state of `nodes` to unfrozen.

    Args:
        nodes (list[str | GraphNode] | None):
            Node IDs, labels, or instances to unfreeze. A single string or
            GraphNode is also accepted and treated as a one-element list.
            Each entry is resolved via `self._resolve_existing`.
            If None, every Trainable node in this graph is unfrozen.
            Entries that resolve to a non-Trainable node are skipped with
            a debug log message.

    """
    # Normalize the request into a list of GraphNode instances
    if nodes is None:
        targets = [n for n in self.nodes.values() if isinstance(n, Trainable)]
    else:
        # A bare string is iterable, so it must be wrapped explicitly
        if isinstance(nodes, (str, GraphNode)) or not isinstance(nodes, Iterable):
            nodes = [nodes]
        targets: list[GraphNode] = [self._resolve_existing(n) for n in nodes]

    # Unfreeze every trainable target
    for node in targets:
        if not isinstance(node, Trainable):
            msg = f"GraphNode '{node.label}' is not Trainable. It cannot be unfrozen."
            logger.debug(msg)
            continue
        node.unfreeze()
def _train_step_torch(
    self,
    ctx: ExecutionContext,
    losses: list[AppliedLoss],
    *,
    active_nodes: list[str | GraphNode] | None = None,
):
    """
    Run one graph-wise training step with a PyTorch global optimizer.

    Description:
        Gradients are cleared, a forward pass is executed and written to
        `ctx`, every loss in `losses` is computed against `ctx`, and a
        single backward pass plus optimizer step is taken on the combined
        trainable loss.

    Args:
        ctx (ExecutionContext):
            Execution context containing inputs, outputs, and loss storage.
        losses (list[AppliedLoss]):
            Losses to compute for this step. Every entry is computed; no
            filtering by node is done in this method.
        active_nodes (list[str | GraphNode] | None, optional):
            Optional subset of nodes passed on to `forward()`, which also
            executes their upstream dependencies.

    """
    # Clear gradients left over from the previous step
    self._optimizer.zero_grad()

    # Forward pass, mirrored into the execution context
    node_outputs = self.forward(inputs=ctx.inputs, active_nodes=active_nodes)
    for n_id, batch in node_outputs.items():
        ctx.set_output(node_id=n_id, batch=batch)

    # One trainable record per applied loss
    records: list[LossRecord] = [
        LossRecord(
            label=loss.label,
            node_id=loss.node_id,
            trainable=loss.compute(ctx=ctx),
        )
        for loss in losses
    ]

    # Single backward pass and optimizer step over the combined loss
    collection = LossCollection(records=records)
    collection.trainable.backward()
    self._optimizer.step()

    # Record loss collection
    ctx.add_losses(collection)
def _train_step_tensorflow(
    self,
    ctx: ExecutionContext,
    losses: list[AppliedLoss],
    *,
    active_nodes: list[str | GraphNode] | None = None,
):
    """
    Run one graph-wise training step with a TensorFlow global optimizer.

    Description:
        The forward pass and loss computation run under a
        `tf.GradientTape`. Gradients of the combined trainable loss are then
        taken with respect to the variables the optimizer was built from
        and handed to the optimizer's `step()`.

    Args:
        ctx (ExecutionContext):
            Execution context containing inputs, outputs, and loss storage.
        losses (list[AppliedLoss]):
            Losses to compute for this step. Every entry is computed; no
            filtering by node is done in this method.
        active_nodes (list[str | GraphNode] | None, optional):
            Optional subset of nodes passed on to `forward()`, which also
            executes their upstream dependencies.

    """
    # Reset optimizer gradients & look up the variables being optimized
    self._optimizer.zero_grad()
    opt_info = self.get_optimizer_parameters()
    variables = opt_info["variables"]

    # Everything that feeds the loss must run while the tape is recording
    with tf.GradientTape() as tape:
        node_outputs = self.forward(inputs=ctx.inputs, active_nodes=active_nodes)
        for n_id, batch in node_outputs.items():
            ctx.set_output(node_id=n_id, batch=batch)

        records: list[LossRecord] = [
            LossRecord(
                label=loss.label,
                node_id=loss.node_id,
                trainable=loss.compute(ctx=ctx),
            )
            for loss in losses
        ]
        collection = LossCollection(records=records)
        total_loss = collection.trainable

    # Differentiate and step outside the recording scope
    grads = tape.gradient(total_loss, variables)
    self._optimizer.step(grads=grads, variables=variables)

    # Record loss collection
    ctx.add_losses(collection)
def _train_step_scikit(
    self,
    ctx: ExecutionContext,
    losses: list[AppliedLoss],
    *,
    active_nodes: list[str | GraphNode] | None = None,
):
    """
    Placeholder for graph-wise training with a scikit global optimizer.

    Args:
        ctx (ExecutionContext):
            Execution context containing inputs, outputs, and loss storage.
            Currently unused.
        losses (list[AppliedLoss]):
            Losses defined for this phase. Currently unused.
        active_nodes (list[str | GraphNode] | None, optional):
            Optional subset of nodes to train. Currently unused.

    Raises:
        NotImplementedError: Always; this training mode is not implemented.

    """
    # TODO: not implemented yet
    raise NotImplementedError(
        "Training with a scikit global optimizer not implemented yet.",
    )
[docs]
def train_step(
    self,
    ctx: ExecutionContext,
    losses: list[AppliedLoss],
    *,
    active_nodes: list[str | GraphNode] | None = None,
):
    """
    Execute a single training step for the ModelGraph.

    Behavior depends on whether a global optimizer is attached:

    - If `self._optimizer is None`:
        Stage-wise training is performed. Each node is visited in
        topological order: unfrozen Trainable nodes run `train_step()`,
        other Evaluable nodes run `eval_step()`, and remaining Forwardable
        nodes run a plain forward pass whose output is stored in `ctx`.
    - If `self._optimizer is not None`:
        Graph-wise training is performed. The global optimizer is rebuilt
        (only if needed) from the unfrozen active nodes, and the
        backend-specific training routine is executed.

    Args:
        ctx (ExecutionContext):
            Execution context containing inputs, outputs, and loss storage.
        losses (list[AppliedLoss]):
            All losses defined for this phase. At least one must target an
            active node.
        active_nodes (list[str | GraphNode] | None, optional):
            Optional subset of nodes to be executed in the forward pass. If
            provided, all upstream dependencies are included automatically.
            Otherwise, all nodes in the graph are executed.
            Note that this does not set the trainable state of the nodes,
            only which nodes a forward pass is called on. Use `freeze()` and
            `unfreeze()` to set the trainable state of graph nodes.

    Raises:
        ValueError: If no loss targets an active node, or if no active
            Trainable node is unfrozen.
        BackendNotSupportedError: If the global optimizer's backend is not
            one of the supported backends.

    """
    # Ensure graph is built
    if not self.is_built:
        self.build()

    # Resolve active nodes (and all upstream dependencies)
    if active_nodes is None:
        active_node_ids: set[str] = set(self._nodes.keys())
    else:
        active_node_ids = get_subgraph_nodes(
            graph=self,
            roots=active_nodes,
            direction="upstream",
            include_roots=True,
        )

    # Maintain execution order
    exec_order: list[str] = [
        nid for nid in self._sorted_node_ids if nid in active_node_ids
    ]

    # Validate that at least one loss is applied to these active nodes
    if not any(loss.node_id in active_node_ids for loss in losses):
        msg = "Training must have at least one loss applied to an active node."
        raise ValueError(msg)

    # Validate not all frozen
    has_unfrozen = any(
        isinstance(node, Trainable) and not node.is_frozen
        for node in (self.nodes[n_id] for n_id in active_node_ids)
    )
    if not has_unfrozen:
        msg = "Training must have at least one unfrozen node."
        raise ValueError(msg)

    # ------------------------------------------------
    # Training Case 1: Stage-wise training (no global optimizer)
    # ------------------------------------------------
    if self._optimizer is None:
        for node_id in exec_order:
            node = self._nodes[node_id]

            # If trainable, use train_step (check if frozen)
            if isinstance(node, Trainable) and not node.is_frozen:
                node.train_step(ctx=ctx, losses=losses)

            # If evaluable (or trainable + frozen), use eval_step
            elif isinstance(node, Evaluable):
                node.eval_step(ctx=ctx, losses=losses)

            # If forwardable, record outputs of manual forward pass
            elif isinstance(node, Forwardable):
                inp_data = node.get_input_data(
                    inputs=ctx.inputs,
                    outputs=ctx.outputs,
                    fmt=get_data_format_for_backend(node.backend),
                )
                ctx.outputs[node_id] = node.forward(inp_data)

            # Otherwise, skip node
        return None

    # ------------------------------------------------
    # Training Case 2: Graph-wise training (global optimizer)
    # ------------------------------------------------
    # Rebuild optimizer with only unfrozen nodes (only rebuilds if necessary)
    self._build_optimizer(
        nodes_to_include=active_node_ids,
        include_only_unfrozen=True,
    )

    # Use backend-specific training logic
    backend = self._optimizer.backend
    if backend == Backend.TORCH:
        return self._train_step_torch(
            ctx=ctx,
            losses=losses,
            active_nodes=active_node_ids,
        )
    if backend == Backend.TENSORFLOW:
        return self._train_step_tensorflow(
            ctx=ctx,
            losses=losses,
            active_nodes=active_node_ids,
        )
    if backend == Backend.SCIKIT:
        return self._train_step_scikit(
            ctx=ctx,
            losses=losses,
            active_nodes=active_node_ids,
        )

    # Keyword form matches the other BackendNotSupportedError call sites
    raise BackendNotSupportedError(
        backend=backend,
        message=f"Unknown backend: {backend}",
    )
# ================================================
# Evaluable Protocol
# ================================================
[docs]
def eval_step(
    self,
    ctx: ExecutionContext,
    losses: list[AppliedLoss],
    *,
    active_nodes: list[str | GraphNode] | None = None,
):
    """
    Execute a single evaluation step for the ModelGraph.

    Description:
        The graph is built if necessary, all Trainable nodes are frozen, and
        a forward-only pass is run. Outputs are written to `ctx`, every loss
        in `losses` is computed and recorded as an auxiliary loss, and the
        resulting collection is added to `ctx`. No optimizer is stepped.
        With a PyTorch global optimizer the forward pass runs under
        `torch.no_grad()`.

    Args:
        ctx (ExecutionContext):
            Execution context containing inputs, outputs, and loss storage.
        losses (list[AppliedLoss]):
            Losses to compute during evaluation.
        active_nodes (list[str | GraphNode] | None, optional):
            Optional subset of nodes to execute. All required upstream
            dependencies are included automatically.

    """
    # Ensure graph is built
    if not self.is_built:
        self.build()

    # NOTE(review): nodes are frozen here and nothing in this method
    # unfreezes them afterwards - confirm callers restore the state.
    self.freeze(nodes=None)

    # Forward pass; gradient tracking is disabled for a PyTorch graph
    if self.backend == Backend.TORCH:
        with torch.no_grad():
            node_outputs = self.forward(
                inputs=ctx.inputs,
                active_nodes=active_nodes,
            )
    else:
        node_outputs = self.forward(
            inputs=ctx.inputs,
            active_nodes=active_nodes,
        )

    # Record outputs
    for n_id, batch in node_outputs.items():
        ctx.set_output(node_id=n_id, batch=batch)

    # Losses are informational during evaluation, hence `auxiliary`
    records: list[LossRecord] = [
        LossRecord(
            label=loss.label,
            node_id=loss.node_id,
            auxiliary=loss.compute(ctx=ctx),
        )
        for loss in losses
    ]

    # Record loss collection
    ctx.add_losses(LossCollection(records=records))
# ================================================
# Fittable Protocol
# ================================================
[docs]
def fit_step(
    self,
    ctx: ExecutionContext,
    losses: list[AppliedLoss] | None = None,
    *,
    active_nodes: list[str | GraphNode] | None = None,
    freeze_after_fit: bool = True,
):
    """
    Fit batch-fit nodes in topological order.

    Description:
        Active nodes are visited in topological order. A node that
        implements the `Fittable` protocol and is not frozen has its
        `fit_step()` called and is then optionally frozen. Any other
        `Forwardable` node runs a forward-only pass whose output is stored
        in `ctx.outputs` for downstream nodes.

    Args:
        ctx (ExecutionContext):
            Execution context containing full-dataset inputs.
        losses (list[AppliedLoss] | None, optional):
            Optional losses forwarded to each node's `fit_step()`.
        active_nodes (list[str | GraphNode] | None, optional):
            Optional subset of nodes to fit; their upstream dependencies
            are included. If None, all nodes in the graph are considered.
        freeze_after_fit (bool, optional):
            Whether to freeze each node right after it has been fitted.
            Defaults to True.

    """
    # Ensure graph is built
    if not self.is_built:
        self.build()

    # Expand the requested nodes to include everything upstream of them
    if active_nodes is not None:
        active_node_ids = get_subgraph_nodes(
            graph=self,
            roots=active_nodes,
            direction="upstream",
            include_roots=True,
        )
    else:
        active_node_ids: set[str] = set(self._nodes.keys())

    # Maintain execution order
    ordered_ids: list[str] = [
        nid for nid in self._sorted_node_ids if nid in active_node_ids
    ]

    for node_id in ordered_ids:
        node = self._nodes[node_id]

        # Fittable and unfrozen -> fit, then optionally lock the result
        if isinstance(node, Fittable) and not node.is_frozen:
            node.fit_step(ctx=ctx, losses=losses)
            if freeze_after_fit:
                node.freeze()
            continue

        # Anything else only needs to propagate its output downstream
        if isinstance(node, Forwardable):
            node_inputs = node.get_input_data(
                inputs=ctx.inputs,
                outputs=ctx.outputs,
                fmt=get_data_format_for_backend(node.backend),
            )
            ctx.outputs[node_id] = node.forward(node_inputs)
# ================================================
# Configurable
# ================================================
[docs]
def get_config(self) -> dict[str, Any]:
    """
    Retrieve the configuration details of this ModelGraph instance.

    The result holds the graph label, the config of every node (in
    `self._sorted_node_ids` order), and the config of the global optimizer
    (None when no global optimizer is attached).
    This does not contain state information of any underlying models or
    optimizers.

    Returns:
        dict[str, Any]: Dict with keys `"label"`, `"nodes"`, `"optimizer"`.

    """
    label = self.label
    node_cfgs = [self.nodes[n_id].get_config() for n_id in self._sorted_node_ids]

    optimizer_cfg = None
    if self._optimizer is not None:
        optimizer_cfg = self._optimizer.get_config()

    return {
        "label": label,
        "nodes": node_cfgs,
        "optimizer": optimizer_cfg,
    }
[docs]
@classmethod
def from_config(
    cls,
    config: dict[str, Any],
    *,
    register: bool = True,
) -> ModelGraph:
    """
    Reconstruct a ModelGraph from configuration details.

    This does not restore state information of any underlying models or
    optimizers.

    Args:
        config (dict[str, Any]):
            Configuration with a `"nodes"` list and optional `"optimizer"`
            and `"label"` entries, as produced by `get_config()`.
        register (bool, optional):
            Forwarded to `GraphNode.from_config` for every node and to the
            ModelGraph constructor. Defaults to True.

    Returns:
        ModelGraph: A new graph bound to the active ExperimentContext.

    """
    active_ctx = ExperimentContext.get_active()

    # Rebuild nodes first (must register them to use a ModelGraph)
    rebuilt_nodes: list[GraphNode] = [
        GraphNode.from_config(config=node_cfg, register=register)
        for node_cfg in config["nodes"]
    ]

    # Rebuild optimizer (absent or None means no global optimizer)
    optimizer_cfg = config.get("optimizer")
    optimizer = (
        Optimizer.from_config(optimizer_cfg) if optimizer_cfg is not None else None
    )

    # Create ModelGraph
    return cls(
        nodes=rebuilt_nodes,
        optimizer=optimizer,
        label=config.get("label", "model-graph"),
        ctx=active_ctx,
        register=register,
    )
# ================================================
# Stateful
# ================================================
[docs]
def get_state(self) -> dict[str, Any]:
    """
    Return serialized state for all nodes and global optimizer.

    Returns:
        dict[str, Any]:
            Snapshot with the keys:

            - `"nodes"`: node_id -> deep copy of the node state, for every
              node that is `Stateful`.
            - `"optimizer"`: deep copy of the global optimizer state, or
              None when no global optimizer is attached.
            - `"opt_built_from_node_ids"`: copy of the node-id set the
              optimizer was last built from (None if never built).
            - `"is_built"`: whether the graph is currently built.

    """
    node_states = {
        n_id: deepcopy(self.nodes[n_id].get_state())
        for n_id in self._sorted_node_ids
        if isinstance(self.nodes[n_id], Stateful)
    }

    optimizer_state = None
    if self._optimizer is not None:
        optimizer_state = deepcopy(self._optimizer.get_state())

    # Copy the set so that the snapshot does not alias live graph state
    built_from = self._opt_built_from_node_ids
    built_from_copy = None if built_from is None else set(built_from)

    return {
        "nodes": node_states,
        "optimizer": optimizer_state,
        "opt_built_from_node_ids": built_from_copy,
        "is_built": self.is_built,
    }
[docs]
def set_state(self, state: dict[str, Any]) -> None:
    """
    Restore node and optimizer state from :meth:`get_state`.

    Description:
        Node states are applied to every matching `Stateful` node. The
        optimizer state is applied only when both this graph and the
        snapshot have a global optimizer. If the snapshot was taken from a
        built graph, the graph is marked as built, and the global optimizer
        (when present) is force-rebuilt from the recorded node IDs.

    Args:
        state (dict[str, Any]):
            Snapshot previously generated by :meth:`get_state`.

    Raises:
        KeyError: If the snapshot references a node ID not in this graph.

    """
    # Restore node state
    for n_id, n_state in state.get("nodes", {}).items():
        node = self._nodes[n_id]
        if isinstance(node, Stateful):
            node.set_state(n_state)

    # Restore optimizer
    has_optimizer = self._optimizer is not None
    if has_optimizer and state.get("optimizer") is not None:
        self._optimizer.set_state(state["optimizer"])

    opt_nodes = state.get("opt_built_from_node_ids")
    self._opt_built_from_node_ids = None if opt_nodes is None else set(opt_nodes)

    if state.get("is_built", False):
        # A built graph without a global optimizer is valid (stage-wise
        # training); `_build_optimizer` raises ValueError in that case,
        # so it is only called when an optimizer exists.
        if has_optimizer:
            self._build_optimizer(
                self._opt_built_from_node_ids,
                include_only_unfrozen=False,
                force=True,
            )
        self._built = True
# ================================================
# Serialization
# ================================================
[docs]
def save(self, filepath: Path, *, overwrite: bool = False) -> Path:
    """
    Serialize this ModelGraph to the specified filepath.

    Args:
        filepath (Path):
            File location to save to. Note that the suffix may be overwritten
            to enforce the ModularML file extension schema.
        overwrite (bool, optional):
            Whether to overwrite any existing file at the save location.
            Defaults to False.

    Returns:
        Path: The path returned by the serializer for the saved ModelGraph.

    """
    # Imported here rather than at module level
    from modularml.core.io.serialization_policy import SerializationPolicy
    from modularml.core.io.serializer import serializer

    saved_path = serializer.save(
        self,
        filepath,
        policy=SerializationPolicy.BUILTIN,
        overwrite=overwrite,
    )
    return saved_path
[docs]
@classmethod
def load(
    cls,
    filepath: Path,
    *,
    allow_packaged_code: bool = False,
    overwrite: bool = False,
) -> ModelGraph:
    """
    Load a ModelGraph from file.

    Args:
        filepath (Path):
            File location of a previously saved ModelGraph. If the path has
            no suffix, the ModelGraph file suffix is applied first.
        allow_packaged_code (bool, optional):
            Whether bundled code execution is allowed. Defaults to False.
        overwrite (bool, optional):
            Whether to replace any colliding node registrations in
            ExperimentContext. If False, new IDs are assigned to the
            reloaded nodes comprising the graph. Otherwise, any collisions
            are overwritten with the saved nodes. Defaults to False.
            It is recommended to only reload a ModelGraph into a new/empty
            `ExperimentContext`.

    Returns:
        ModelGraph: The reloaded ModelGraph.

    """
    from modularml.core.io.serializer import _enforce_file_suffix, serializer

    # Append the proper suffix only if no suffix is given
    has_suffix = Path(filepath).suffix != ""
    if not has_suffix:
        filepath = _enforce_file_suffix(path=filepath, cls=cls)

    return serializer.load(
        filepath,
        allow_packaged_code=allow_packaged_code,
        overwrite=overwrite,
    )
# ================================================
# Checkpointing
# ================================================
[docs]
def save_checkpoint(
    self,
    filepath: Path,
    *,
    overwrite: bool = False,
    meta: dict[str, Any] | None = None,
) -> Path:
    """
    Save full ModelGraph state as a checkpoint.

    Description:
        A `Checkpoint` is filled with one entry for the graph itself
        (`"modelgraph"`), one entry per node (`"nodes:<node_id>"`), and,
        when a global optimizer is attached, an `"optimizer"` entry. The
        checkpoint is then written through the serializer.

    Args:
        filepath (Path):
            File location to save to. Note that the suffix may be overwritten
            to enforce the ModularML file extension schema.
        overwrite (bool, optional):
            Whether to overwrite any existing file at the save location.
            Defaults to False.
        meta (dict[str, Any], optional):
            Additional meta data to attach to the checkpoint.
            Must be pickle-able.

    Returns:
        Path: Final path of saved ModelGraph checkpoint.

    """
    from modularml.core.io.serialization_policy import SerializationPolicy
    from modularml.core.io.serializer import serializer

    checkpoint = Checkpoint()

    # Graph-level entry, then one entry per node, then the optimizer
    checkpoint.add_entry(key="modelgraph", obj=self)
    for n_id, node in self.nodes.items():
        checkpoint.add_entry(key=f"nodes:{n_id}", obj=node)
    if self._optimizer is not None:
        checkpoint.add_entry(key="optimizer", obj=self._optimizer)

    # Attach meta data
    if meta is not None:
        for meta_key, meta_value in meta.items():
            checkpoint.add_meta(meta_key, meta_value)

    return serializer.save(
        checkpoint,
        filepath,
        policy=SerializationPolicy.BUILTIN,
        overwrite=overwrite,
    )
[docs]
def restore_checkpoint(self, filepath: Path) -> ModelGraph:
    """
    Restore ModelGraph state from checkpoint.

    Description:
        Node states stored under `"nodes:<node_id>"` are applied to the
        matching nodes of this graph. The optimizer state is applied only
        when both the checkpoint and this graph have a global optimizer.
        Finally, the build flag and the set of node IDs the optimizer was
        built from are restored from the `"modelgraph"` entry.

    Args:
        filepath (Path):
            File location of a previously saved ModelGraph checkpoint. If
            the path has no suffix, the Checkpoint file suffix is applied.

    Returns:
        self: The ModelGraph restored to the checkpoint state.

    Raises:
        KeyError: If the checkpoint references a node ID not in this graph,
            or has no `"modelgraph"` entry.

    """
    from modularml.core.io.serializer import _enforce_file_suffix, serializer

    # Append proper suffix only if no suffix is given
    if Path(filepath).suffix == "":
        filepath = _enforce_file_suffix(path=filepath, cls=Checkpoint)

    # NOTE(review): packaged code execution is always allowed here, whereas
    # `load()` defaults to False - only restore checkpoints from trusted
    # sources, or confirm this is intended.
    ckpt: Checkpoint = serializer.load(filepath, allow_packaged_code=True)

    # Set node states. Only keys with the exact "nodes:" prefix are node
    # entries; slicing off the prefix keeps IDs that contain ":" intact.
    node_prefix = "nodes:"
    for key, entry in ckpt.entries.items():
        if not key.startswith(node_prefix):
            continue
        self.nodes[key[len(node_prefix):]].set_state(entry.entry_state)

    # Set optimizer state (skipped when this graph has no global optimizer,
    # mirroring `set_state`)
    if self._optimizer is not None and "optimizer" in ckpt.entries:
        self._optimizer.set_state(ckpt.entries["optimizer"].entry_state)

    # Update model graph state; normalize the IDs to a set so that the
    # equality check in `_build_optimizer` compares like with like
    mg_state = ckpt.entries["modelgraph"].entry_state
    built_from = mg_state["opt_built_from_node_ids"]
    self._opt_built_from_node_ids = None if built_from is None else set(built_from)
    self._built = mg_state["is_built"]

    return self
# ================================================
# Visualizer
# ================================================
[docs]
def visualize(
    self,
    *,
    show_features: bool = False,
    show_targets: bool = False,
    show_tags: bool = False,
    show_frozen: bool = True,
    show_splits: bool = False,
):
    """
    Display a mermaid diagram for this ModelGraph.

    Args:
        show_features (bool, optional):
            Show feature columns on head nodes. Defaults to False.
        show_targets (bool, optional):
            Show target columns on head nodes. Defaults to False.
        show_tags (bool, optional):
            Show tags columns on head nodes. Defaults to False.
        show_frozen (bool, optional):
            Show frozen state (label text and dimmed styling) on ModelNodes.
            Defaults to True.
        show_splits (bool, optional):
            Show available splits on FeatureSet nodes. Defaults to False.

    Returns:
        The result of `Visualizer.display_mermaid()`.

    """
    options = ModelGraphDisplayOptions(
        show_features=show_features,
        show_targets=show_targets,
        show_tags=show_tags,
        show_frozen=show_frozen,
        show_splits=show_splits,
    )
    visualizer = Visualizer(self, display_options=options)
    return visualizer.display_mermaid()