How to: Create and Use a ModelGraph
A ModelGraph is the computational backbone of a ModularML Experiment. It organizes
one or more ModelNodes (and optionally MergeNodes) into a directed acyclic graph (DAG)
that handles:
Shape inference: Automatically determines input/output shapes for every node during build().
Topological execution: Ensures nodes execute in dependency order during forward, training, and evaluation passes.
Global optimizer management: Optionally shares a single optimizer across all trainable nodes for end-to-end gradient flow.
Freeze / unfreeze control: Selectively disable training for subsets of the graph.
Graph mutation: Add, remove, replace, or insert nodes dynamically.
Serialization & checkpointing: Save and restore the full graph structure and learned weights.
A linear graph:
FeatureSet ──> ModelNode("Encoder") ──> ModelNode("Regressor")
A branching graph that merges two parallel models:
FeatureSet ─┬─> ModelNode("A") ──┐
            │                     ├─> ConcatNode ──> ModelNode("Head")
            └─> ModelNode("B") ──┘
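Python's standard-library graphlib module gives a simple picture of the dependency-ordered execution listed above. This is a minimal sketch, not ModularML's implementation. It assumes a plain dict that maps each node label in the branching diagram to the labels of the graph nodes it consumes.

```python
from graphlib import TopologicalSorter

# Hypothetical label-level view of the branching diagram: each node maps to the
# graph nodes it consumes. A and B read from the FeatureSet, so they list none.
upstream = {
    "A": [],
    "B": [],
    "ConcatNode": ["A", "B"],
    "Head": ["ConcatNode"],
}

# TopologicalSorter expects {node: predecessors}, which is exactly this mapping.
order = list(TopologicalSorter(upstream).static_order())
print(order)
```

Any valid order places A and B before ConcatNode, and ConcatNode before Head. Topological execution depends on that guarantee.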
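This notebook covers creating and building a ModelGraph, inspecting its properties, running forward passes, mutating the graph structure, freezing and unfreezing nodes, managing optimizers, and serialization and checkpointing.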
import numpy as np
import torch
from modularml import (
    ConcatNode,
    Experiment,
    FeatureSet,
    ModelGraph,
    ModelNode,
    Optimizer,
)
from modularml.models.torch import SequentialMLP
# Create an Experiment with the "overwrite" registration policy so we can freely
# recreate nodes with the same labels (avoids a warning each time a node is overwritten)
exp = Experiment(label="create_modelgraph", registration_policy="overwrite")
We’ll use a simple synthetic dataset throughout this notebook: 500 samples of a 10-point feature with a scalar target.
rng = np.random.default_rng(42)
fs = FeatureSet.from_dict(
    label="SensorData",
    data={
        "voltage": list(rng.standard_normal((500, 10))),
        "soh": list(rng.standard_normal((500, 1))),
    },
    feature_keys="voltage",
    target_keys="soh",
)
fs_ref = fs.reference(features="voltage", targets="soh")
print(fs)
Creating a ModelGraph
A ModelGraph is constructed from a list of GraphNode instances and an optional shared Optimizer.
ModelGraph(
    nodes: list[str | GraphNode] | None,
    optimizer: Optimizer | None = None,
    label: str = "model-graph",
)
| Parameter | Type | Default | Description |
|---|---|---|---|
| `nodes` | `list[str \| GraphNode] \| None` | (required) | Nodes comprising the graph. Pass node instances or their string labels. If … |
| `optimizer` | `Optimizer \| None` | `None` | A shared optimizer for end-to-end training. If provided, all trainable nodes must share the same backend. |
| `label` | `str` | `"model-graph"` | A human-readable label for this graph. |
Simple Linear Graph
The simplest graph is a linear chain: FeatureSet -> ModelNode.
node = ModelNode(
    label="SimpleMLP",
    model=SequentialMLP(output_shape=(1, 1), n_layers=2, hidden_dim=32),
    upstream_ref=fs_ref,
)
mg = ModelGraph(
    nodes=[node],
    optimizer=Optimizer(opt="adam", opt_kwargs={"lr": 1e-3}, backend="torch"),
    label="simple-graph",
)
print(f"Label: {mg.label}")
print(f"Nodes: {mg.node_labels}")
print(f"Built: {mg.is_built}")
Multi-Node Chain
Chain multiple ModelNodes by passing one as the upstream_ref of the next.
ModelGraph supports the .visualize() method, which we’ll use to show our topology updates.
encoder = ModelNode(
    label="Encoder",
    model=SequentialMLP(output_shape=(1, 8), n_layers=2, hidden_dim=32),
    upstream_ref=fs_ref,
)
regressor = ModelNode(
    label="Regressor",
    model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=16),
    upstream_ref=encoder,
)
mg_chain = ModelGraph(
    nodes=[encoder, regressor],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
print(f"Node labels: {mg_chain.node_labels}")
mg_chain.visualize()
Branching Graph with MergeNode
Use ConcatNode (a MergeNode) to combine outputs from parallel branches.
FeatureSet ─┬─> EncoderA ──┐
            │              ├─> ConcatNode ──> Head
            └─> EncoderB ──┘
enc_a = ModelNode(
    label="EncoderA",
    model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)
enc_b = ModelNode(
    label="EncoderB",
    model=SequentialMLP(output_shape=(1, 4), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)
merge = ConcatNode(
    label="Merge",
    upstream_refs=[enc_a, enc_b],
    concat_axis=-1,
    concat_axis_targets="first",
)
head = ModelNode(
    label="Head",
    # No output_shape: as a tail node, Head falls back to the target shape
    model=SequentialMLP(n_layers=1, hidden_dim=8),
    upstream_ref=merge,
)
mg_branch = ModelGraph(
    nodes=[enc_a, enc_b, merge, head],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
print(f"Node labels: {mg_branch.node_labels}")
mg_branch.visualize()
Referencing Nodes by Label
Instead of passing node instances, you can pass their string labels. The graph will look them up in the active ExperimentContext.
mg_by_label = ModelGraph(
    nodes=["EncoderA", "EncoderB", "Merge", "Head"],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
print(f"Node labels: {mg_by_label.node_labels}")
mg_by_label.visualize()
Without a Global Optimizer
If no global optimizer is provided, each ModelNode must define its own local optimizer. This is useful when different nodes need different optimizers or learning rates (stage-wise training).
node_with_opt = ModelNode(
    label="StageWiseMLP",
    model=SequentialMLP(output_shape=(1, 1), n_layers=2, hidden_dim=32),
    upstream_ref=fs_ref,
    optimizer=Optimizer(opt="adam", opt_kwargs={"lr": 1e-3}, backend="torch"),
)
mg_no_global = ModelGraph(
    nodes=[node_with_opt],
    optimizer=None,
)
print(f"Global optimizer backend: {mg_no_global.backend}")
Building the Graph
ModelGraph.build() performs the following steps in topological order:
Validates the DAG structure (no cycles, all upstream references resolved).
Infers input and output shapes for each node from upstream outputs and FeatureSet shapes.
Builds each node’s underlying model (lazy initialization).
Builds the global optimizer (if provided) with parameters from all trainable nodes.
ModelGraph.build(*, force: bool = False)
| Parameter | Type | Default | Description |
|---|---|---|---|
| `force` | `bool` | `False` | If `True`, rebuild the graph even if it has already been built. |
mg_branch.build()
print(f"Built: {mg_branch.is_built}")
for node in mg_branch.nodes.values():
    in_shape = (
        node.input_shape
        if hasattr(node, "input_shape")
        else list(node.input_shapes.values())
    )
    out_shape = getattr(node, "output_shape", None)
    print(f" {node.label}: {in_shape} -> {out_shape}")
mg_branch.visualize() # Note how all edges now show the input/output shapes
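The first build() step, validating the DAG, can be sketched with the standard-library graphlib module. This is a toy rather than ModularML's code. The hypothetical validate_dag() helper takes a dict of node label to upstream labels. It checks that every reference resolves and lets TopologicalSorter detect cycles.

```python
from graphlib import CycleError, TopologicalSorter


def validate_dag(upstream: dict[str, list[str]]) -> list[str]:
    """Toy validation: resolve references, reject cycles, return an execution order."""
    # Every upstream reference must name a node that belongs to the graph.
    for node, parents in upstream.items():
        missing = [p for p in parents if p not in upstream]
        if missing:
            raise ValueError(f"{node!r} references unknown node(s): {missing}")
    # static_order() raises CycleError when the graph is not acyclic.
    return list(TopologicalSorter(upstream).static_order())


order = validate_dag({"Encoder": [], "Regressor": ["Encoder"]})
print(order)

cycle = None
try:
    validate_dag({"X": ["Y"], "Y": ["X"]})
except CycleError as err:
    cycle = err.args[1]  # the nodes that form the cycle
print("cycle:", cycle)
```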
Shape Inference Details
During build(), shapes propagate through the graph as follows:
Head nodes (inputs from a FeatureSet): Input shape is pulled directly from the referenced FeatureSet data.
Intermediate nodes: Input shape equals the output shape of their upstream node.
Tail nodes (no downstream consumers): If no output_shape is specified on the model, it defaults to the target shape propagated from the upstream FeatureSet.
MergeNodes: Both feature and target output shapes are determined by a dummy forward pass through the merge logic.
You generally do not need to specify input_shape on your models — build() infers it. Specifying output_shape is recommended for all non-tail nodes.
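The following pure-Python sketch replays these rules on the branching example (EncoderA, EncoderB, Merge, Head). It is an illustration built on assumptions. The per-sample feature shape (1, 10) and target shape (1, 1) are assumed rather than read from the FeatureSet. The merge output is also computed arithmetically, whereas ModularML uses a dummy forward pass.

```python
# Assumed per-sample shapes (not read from ModularML): features (1, 10), targets (1, 1).
FEATURE_SHAPE = (1, 10)
TARGET_SHAPE = (1, 1)

# label -> (kind, upstream labels, declared output_shape or None)
nodes = {
    "EncoderA": ("model", [], (1, 8)),
    "EncoderB": ("model", [], (1, 4)),
    "Merge": ("concat", ["EncoderA", "EncoderB"], None),
    "Head": ("model", ["Merge"], None),  # tail node, no output_shape given
}
order = ["EncoderA", "EncoderB", "Merge", "Head"]  # a valid topological order

# A node is a tail node if no other node lists it as an upstream dependency.
consumed = {u for _, ups, _ in nodes.values() for u in ups}

shapes = {}  # label -> (input shape(s), output shape)
for label in order:
    kind, ups, declared = nodes[label]
    if kind == "concat":
        in_shape = [shapes[u][1] for u in ups]
        # Concatenating on axis -1 sums the last dimension of the inputs.
        out_shape = in_shape[0][:-1] + (sum(s[-1] for s in in_shape),)
    else:
        # Head nodes read the FeatureSet; others read their upstream node's output.
        in_shape = shapes[ups[0]][1] if ups else FEATURE_SHAPE
        if declared is not None:
            out_shape = declared
        elif label not in consumed:
            out_shape = TARGET_SHAPE  # tail fallback: the propagated target shape
        else:
            raise ValueError(f"{label}: this sketch needs output_shape on non-tail nodes")
    shapes[label] = (in_shape, out_shape)

for label in order:
    print(f"{label}: {shapes[label][0]} -> {shapes[label][1]}")
```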
Rebuilding
Calling build() on an already-built graph is a no-op unless force=True.
# No-op (already built)
mg_branch.build()
# Force rebuild (e.g., after modifying graph structure)
mg_branch.build(force=True)
print(f"Rebuilt: {mg_branch.is_built}")
Graph Properties
After building, the graph exposes several useful properties for inspecting its structure.
print(f"Label: {mg_branch.label}")
print(f"Built: {mg_branch.is_built}")
print(f"Backend: {mg_branch.backend}")
print(f"Node labels: {mg_branch.node_labels}")
Head and Tail Nodes
Head nodes: Nodes whose inputs come directly from a FeatureSet (no upstream GraphNode dependencies).
Tail nodes: Nodes whose outputs are not consumed by any other node in the graph.
print("Head nodes (receive FeatureSet data):")
for n in mg_branch.head_nodes.values():
    print(f" - {n.label}")
print("\nTail nodes (produce final outputs):")
for n in mg_branch.tail_nodes.values():
    print(f" - {n.label}")
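The head and tail definitions come down to two set operations on the graph's edges. Below is a minimal sketch. It assumes a hypothetical dict that maps each node label to its upstream graph-node labels, where an empty list means the node reads from a FeatureSet.

```python
# Hypothetical label-level view of the branching graph. An empty list means the
# node has no upstream GraphNode, i.e. it reads directly from a FeatureSet.
upstream = {
    "EncoderA": [],
    "EncoderB": [],
    "Merge": ["EncoderA", "EncoderB"],
    "Head": ["Merge"],
}

# Head nodes: no upstream GraphNode dependencies.
head_nodes = [label for label, ups in upstream.items() if not ups]

# Tail nodes: never consumed by another node in the graph.
consumed = {u for ups in upstream.values() for u in ups}
tail_nodes = [label for label in upstream if label not in consumed]

print("Head nodes:", head_nodes)  # ['EncoderA', 'EncoderB']
print("Tail nodes:", tail_nodes)  # ['Head']
```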
Accessing Individual Nodes
Nodes are stored in a dict keyed by node_id. These IDs are globally unique, which is why a node can be referenced by its label, ID, or instance at any point in an Experiment.
You can iterate over the nodes or access them by label.
# All nodes (keyed by node_id)
for n_id, node in mg_branch.nodes.items():
    print(f" {node.label} (id={n_id[:8]}...)")
Forward Pass
Once built, you can execute a forward pass through the graph. The graph handles data routing between nodes in topological order.
ModelGraph.forward(
    inputs: dict[tuple[str, FeatureSetReference], TForward],
    *,
    active_nodes: list[str | GraphNode] | None = None,
) -> dict[str, TForward]
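| Parameter | Type | Description |
|---|---|---|
| `inputs` | `dict[tuple[str, FeatureSetReference], TForward]` | Mapping of (head node ID, FeatureSetReference) pairs to the input data for that head node. |
| `active_nodes` | `list[str \| GraphNode] \| None` | Optional subset of nodes to execute. Upstream dependencies are included automatically. If `None`, all nodes are executed. |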
Returns: A dict mapping node_id to that node’s output data for every executed node.
from modularml.core.data.sample_data import SampleData
from modularml.utils.data.data_format import DataFormat
# Prepare input data
fsv = fs_ref.resolve()
sample_data = SampleData(
    features=fsv.get_features(fmt=DataFormat.TORCH),
    targets=fsv.get_targets(fmt=DataFormat.TORCH),
)
# Build the inputs dict: (head_node_id, featureset_ref) -> data
inputs = {}
for n_id, node in mg_branch.head_nodes.items():
    for ref in node.get_upstream_refs():
        inputs[(n_id, ref)] = sample_data
print(f"Number of input entries: {len(inputs)}")
# Execute forward pass
with torch.no_grad():
    outputs = mg_branch.forward(inputs)
print("Outputs per node:")
for n_id, out in outputs.items():
    node_label = mg_branch.nodes[n_id].label
    print(f" {node_label}: features={out.features.shape}")
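At a conceptual level, forward() walks the nodes in topological order and passes each node the outputs of its upstream nodes. The toy executor below shows that routing, with plain Python functions standing in for models. Everything in it is hypothetical. The node functions operate on lists of floats, and outputs are keyed by label instead of node_id.

```python
from graphlib import TopologicalSorter

# Hypothetical stand-ins: each "node" is a plain function of its upstream outputs.
# Data are lists of floats; Merge concatenates, loosely mimicking ConcatNode.
node_fns = {
    "EncoderA": lambda x: [2 * v for v in x],
    "EncoderB": lambda x: [v + 1 for v in x],
    "Merge": lambda a, b: a + b,  # list concatenation
    "Head": lambda m: [sum(m)],
}
upstream = {
    "EncoderA": [],
    "EncoderB": [],
    "Merge": ["EncoderA", "EncoderB"],
    "Head": ["Merge"],
}


def forward(inputs: dict[str, list[float]]) -> dict[str, list[float]]:
    """Run every node in dependency order; `inputs` maps head-node labels to data."""
    outputs = {}
    for label in TopologicalSorter(upstream).static_order():
        ups = upstream[label]
        # Head nodes get external data; other nodes get their upstream outputs in order.
        args = [outputs[u] for u in ups] if ups else [inputs[label]]
        outputs[label] = node_fns[label](*args)
    return outputs


outs = forward({"EncoderA": [1.0, 2.0], "EncoderB": [1.0, 2.0]})
for label, value in outs.items():
    print(label, value)
```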
Active Nodes
You can restrict the forward pass to a subset of the graph using active_nodes. All required upstream dependencies are automatically included.
If only Merge is marked active, its upstream nodes (EncoderA and EncoderB) must still be executed, but the downstream Head node is skipped.
# Execute only Merge (plus its upstream dependencies, EncoderA and EncoderB)
with torch.no_grad():
    partial_outputs = mg_branch.forward(inputs, active_nodes=[merge])
print("Executed nodes:")
for n_id in partial_outputs:
    print(f" - {mg_branch.nodes[n_id].label}")
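The dependency expansion behind active_nodes is a reverse traversal from the requested nodes toward the inputs. The sketch below computes that closure with a hypothetical required_nodes() helper on a label-to-upstream dict. It is not ModularML's internal function.

```python
def required_nodes(active: list[str], upstream: dict[str, list[str]]) -> set[str]:
    """Return the active nodes plus all of their transitive upstream dependencies."""
    required: set[str] = set()
    stack = list(active)
    while stack:
        node = stack.pop()
        if node not in required:
            required.add(node)
            stack.extend(upstream[node])  # walk back toward the FeatureSet inputs
    return required


# Hypothetical label-level view of the branching graph.
upstream = {
    "EncoderA": [],
    "EncoderB": [],
    "Merge": ["EncoderA", "EncoderB"],
    "Head": ["Merge"],
}
print(sorted(required_nodes(["Merge"], upstream)))  # ['EncoderA', 'EncoderB', 'Merge']
```

The traversal only follows upstream edges, so it never reaches Head. That is why Head is skipped in the partial forward pass.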
Graph Mutation
ModelGraph provides several methods to modify the graph structure after creation. All mutation methods return self for method chaining.
After any structural change, the graph automatically revalidates connections and recomputes the topological order. You will need to call build() again to reinitialize shapes and optimizers.
add_node()
Add a new node to the graph. The node must already be connected to existing nodes via its upstream_ref.
# Start with a simple single-node graph
base_node = ModelNode(
    label="Base",
    model=SequentialMLP(output_shape=(1, 4), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)
mg_mut = ModelGraph(
    nodes=[base_node],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
print(f"Before: {mg_mut.node_labels}")
# Add a downstream node
added_node = ModelNode(
    label="Added",
    model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),
    upstream_ref=base_node,
)
mg_mut.add_node(added_node)
print(f"After: {mg_mut.node_labels}")
mg_mut.visualize()
remove_node()
Remove a node from the graph. Downstream nodes are reconnected to the removed node’s upstream sources.
Given: A -> B -> C
Remove B:
Result: A -> C
# Create a 3-node chain
n1 = ModelNode(
    label="N1",
    model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)
n2 = ModelNode(
    label="N2",
    model=SequentialMLP(output_shape=(1, 4), n_layers=1, hidden_dim=8),
    upstream_ref=n1,
)
n3 = ModelNode(
    label="N3",
    model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),
    upstream_ref=n2,
)
mg_rem = ModelGraph(
    nodes=[n1, n2, n3],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
print(f"Before: {mg_rem.node_labels}")
# Remove the middle node
mg_rem.remove_node("N2")
print(f"After: {mg_rem.node_labels}")
mg_rem.visualize()
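The rewiring done by remove_node() can be shown on a plain adjacency dict. In this hypothetical sketch, each entry maps a node label to its upstream labels. Removing a node splices its upstream sources into every consumer, which reproduces the A -> B -> C to A -> C example.

```python
def remove_node(upstream: dict[str, list[str]], label: str) -> dict[str, list[str]]:
    """Drop `label` and splice its upstream sources into each downstream consumer."""
    sources = upstream[label]
    rewired = {}
    for node, ups in upstream.items():
        if node == label:
            continue
        new_ups: list[str] = []
        for u in ups:
            # Replace a reference to the removed node with that node's own inputs.
            new_ups.extend(sources if u == label else [u])
        rewired[node] = new_ups
    return rewired


chain = {"A": [], "B": ["A"], "C": ["B"]}
print(remove_node(chain, "B"))  # {'A': [], 'C': ['A']}
```

In the N1 -> N2 -> N3 chain above, N3 would now receive N1's (1, 8) output instead of N2's (1, 4) output. This is one reason build() has to be called again after a mutation.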
replace_node()
Replace an existing node with a new one, preserving all upstream and downstream connections.
# Create a simple chain
old_enc = ModelNode(
    label="OldEncoder",
    model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)
reg = ModelNode(
    label="Reg",
    model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),
    upstream_ref=old_enc,
)
mg_rep = ModelGraph(
    nodes=[old_enc, reg],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
print(f"Before: {mg_rep.node_labels}")
# Replace with a deeper encoder
new_enc = ModelNode(
    label="NewEncoder",
    model=SequentialMLP(output_shape=(1, 8), n_layers=3, hidden_dim=64),
    upstream_ref=fs_ref,
)
mg_rep.replace_node(old_node="OldEncoder", new_node=new_enc)
print(f"After: {mg_rep.node_labels}")
mg_rep.visualize()
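At the level of edges, replacing a node amounts to renaming one vertex and keeping its connections. Below is a minimal sketch on a hypothetical label-to-upstream dict. It is not the library's code.

```python
def replace_node(upstream: dict[str, list[str]], old: str, new: str) -> dict[str, list[str]]:
    """Swap `old` for `new`, keeping both its inputs and its consumers."""
    rewired = {}
    for node, ups in upstream.items():
        key = new if node == old else node  # the new node takes the old node's slot
        rewired[key] = [new if u == old else u for u in ups]  # consumers now read `new`
    return rewired


chain = {"OldEncoder": [], "Reg": ["OldEncoder"]}
print(replace_node(chain, "OldEncoder", "NewEncoder"))
# {'NewEncoder': [], 'Reg': ['NewEncoder']}
```

In this toy, the new node inherits the old node's upstream list. In the notebook example, NewEncoder already points at fs_ref, so the connectivity comes out the same.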
insert_node_between()
Insert a new node between two already-connected nodes.
Given: A -> B
Insert C between A and B:
Result: A -> C -> B
a = ModelNode(
    label="A",
    model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)
b = ModelNode(
    label="B",
    model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),
    upstream_ref=a,
)
mg_ins = ModelGraph(
    nodes=[a, b],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
print(f"Before: {mg_ins.node_labels}")
c = ModelNode(
    label="C",
    model=SequentialMLP(output_shape=(1, 4), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,  # will be overwritten by insert
)
mg_ins.insert_node_between(new_node=c, upstream=a, downstream=b)
print(f"After: {mg_ins.node_labels}")
# Verify connectivity
for node in mg_ins.nodes.values():
    ups = [r.node_label for r in node.get_upstream_refs()]
    print(f" {node.label} <- {ups}")
mg_ins.visualize()
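The hypothetical label-to-upstream dict below shows the edge surgery behind insert_node_between(). The sketch follows the method's description and requires the two nodes to be directly connected already. When they are not, it raises ValueError; the library's actual error type is not documented here.

```python
def insert_between(
    upstream: dict[str, list[str]], new: str, up: str, down: str
) -> dict[str, list[str]]:
    """Rewire up -> down into up -> new -> down (toy version on a label dict)."""
    if up not in upstream.get(down, []):
        raise ValueError(f"{up!r} is not directly upstream of {down!r}")
    rewired = {node: list(ups) for node, ups in upstream.items()}
    rewired[new] = [up]  # the new node consumes the old upstream
    rewired[down] = [new if u == up else u for u in rewired[down]]
    return rewired


graph = {"A": [], "B": ["A"]}
print(insert_between(graph, "C", "A", "B"))  # {'A': [], 'B': ['C'], 'C': ['A']}
```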
insert_node_before() and insert_node_after()
insert_node_before(new_node, downstream=...): Insert before an existing node, taking over all its upstream connections.
insert_node_after(new_node, upstream=...): Insert after an existing node as an additional downstream consumer.
p = ModelNode(
    label="P",
    model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)
q = ModelNode(
    label="Q",
    model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),
    upstream_ref=p,
)
mg_ib = ModelGraph(
    nodes=[p, q],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
# Insert a node before Q (takes over Q's upstream connections)
pre_q = ModelNode(
    label="PreQ",
    model=SequentialMLP(output_shape=(1, 4), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)
mg_ib.insert_node_before(new_node=pre_q, downstream=q)
print("After insert_node_before:")
for node in mg_ib.nodes.values():
    ups = [r.node_label for r in node.get_upstream_refs()]
    print(f" {node.label} <- {ups}")
mg_ib.visualize()
# Insert a node after P (adds a new branch)
post_p = ModelNode(
    label="PostP",
    model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),
    upstream_ref=fs_ref,
)
mg_ib.insert_node_after(new_node=post_p, upstream=p)
print("After insert_node_after:")
for node in mg_ib.nodes.values():
    ups = [r.node_label for r in node.get_upstream_refs()]
    print(f" {node.label} <- {ups}")
print(f"\nTail nodes: {[n.label for n in mg_ib.tail_nodes.values()]}")
mg_ib.visualize()
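The difference between the two insertion methods is easiest to see side by side. In this hypothetical sketch on a label-to-upstream dict, insert_before() hands the downstream node's inputs to the new node. By contrast, insert_after() only adds a new consumer and leaves existing edges alone, so the new node becomes an extra tail.

```python
Graph = dict[str, list[str]]  # hypothetical: node label -> upstream labels


def insert_before(graph: Graph, new: str, downstream: str) -> Graph:
    """`new` takes over all of `downstream`'s inputs; `downstream` then reads only `new`."""
    rewired = {node: list(ups) for node, ups in graph.items()}
    rewired[new] = rewired[downstream]
    rewired[downstream] = [new]
    return rewired


def insert_after(graph: Graph, new: str, upstream: str) -> Graph:
    """`new` becomes an extra consumer of `upstream`; existing edges stay as they are."""
    rewired = {node: list(ups) for node, ups in graph.items()}
    rewired[new] = [upstream]
    return rewired


g = {"P": [], "Q": ["P"]}
g = insert_before(g, "PreQ", downstream="Q")
g = insert_after(g, "PostP", upstream="P")
print(g)  # {'P': [], 'Q': ['PreQ'], 'PreQ': ['P'], 'PostP': ['P']}
```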
Freezing and Unfreezing
Freezing prevents a node’s parameters from being updated during training. This is useful for transfer learning, multi-stage training, or keeping pretrained components fixed.
ModelGraph.freeze(nodes: list[str | GraphNode] | None = None)
ModelGraph.unfreeze(nodes: list[str | GraphNode] | None = None)
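| Parameter | Type | Default | Description |
|---|---|---|---|
| `nodes` | `list[str \| GraphNode] \| None` | `None` | Nodes to freeze/unfreeze (by label, ID, or instance). If `None`, all nodes are frozen/unfrozen. |
# Using the branching graph (mg_branch) built earlier in this notebook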
mg_branch.build(force=True)
# Freeze specific nodes
mg_branch.freeze(nodes=[enc_a])
print(f"Frozen nodes: {[n.label for n in mg_branch.frozen_nodes.values()]}")
mg_branch.visualize(show_frozen=True)
# Unfreeze
mg_branch.unfreeze(nodes=[enc_a])
print(f"Frozen nodes: {[n.label for n in mg_branch.frozen_nodes.values()]}")
# Freeze all nodes at once
mg_branch.freeze()
print(f"All frozen: {[n.label for n in mg_branch.frozen_nodes.values()]}")
# Unfreeze all
mg_branch.unfreeze()
print(f"All unfrozen: {[n.label for n in mg_branch.frozen_nodes.values()]}")
Frozen Nodes and the Optimizer
When using a global optimizer, the optimizer is automatically rebuilt to exclude frozen nodes’ parameters before each training step. This means frozen nodes will not accumulate gradients and their weights remain unchanged.
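The rebuild can be thought of as a filter over per-node parameter lists. This sketch is hypothetical. Parameters are represented by name strings, whereas a PyTorch setup would hold tensors that are passed to the optimizer's constructor.

```python
# Hypothetical per-node parameter names standing in for real tensors.
node_params = {
    "EncoderA": ["EncoderA.weight", "EncoderA.bias"],
    "EncoderB": ["EncoderB.weight", "EncoderB.bias"],
    "Head": ["Head.weight", "Head.bias"],
}


def optimizer_params(node_params: dict[str, list[str]], frozen: set[str]) -> list[str]:
    """Collect parameters from every node that is not frozen, in node order."""
    return [
        p
        for node, params in node_params.items()
        if node not in frozen
        for p in params
    ]


# Re-collected before each step, so freezing EncoderA simply drops its parameters.
print(optimizer_params(node_params, frozen={"EncoderA"}))
```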
Optimizer Management
The ModelGraph supports two training modes based on whether a global optimizer is provided:
Global Optimizer (Graph-Wise Training)
When a global Optimizer is set on the ModelGraph:
A single forward pass runs through the entire graph.
All losses are accumulated.
A single backward pass computes gradients across all unfrozen nodes.
The global optimizer steps once.
This enables end-to-end gradient flow through the full graph, which is the most common training paradigm.
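A single backward pass lets gradients reach every node through the chain rule. The toy below illustrates this with two scalar weights, an "encoder" w1 and a "head" w2, trained on one sample with a squared-error loss. It is a hand-derived illustration rather than ModularML's training loop. The frozen argument imitates excluding frozen nodes from the update.

```python
# Toy chain of two scalar "nodes": encoder h = w1 * x, head y = w2 * h.
# One shared loss L = (y - t)^2 drives a single backward pass and a single step.
x, t = 1.0, 2.0
lr = 0.1
weights = {"encoder": 0.5, "head": 0.5}


def loss(w: dict[str, float]) -> float:
    return (w["head"] * w["encoder"] * x - t) ** 2


def graph_wise_step(
    w: dict[str, float], frozen: frozenset[str] = frozenset()
) -> dict[str, float]:
    h = w["encoder"] * x  # forward through the encoder
    y = w["head"] * h  # forward through the head
    dl_dy = 2 * (y - t)  # backward: dL/dy
    grads = {
        "head": dl_dy * h,  # dL/dw2
        "encoder": dl_dy * w["head"] * x,  # dL/dw1: flows back through the head
    }
    # One optimizer step over every node that is not frozen.
    return {k: v if k in frozen else v - lr * grads[k] for k, v in w.items()}


before = loss(weights)
after = loss(graph_wise_step(weights))
print(f"loss before: {before:.4f}, after one step: {after:.4f}")
```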
No Global Optimizer (Stage-Wise Training)
When optimizer=None on the ModelGraph:
Each ModelNode must have its own local Optimizer.
Nodes are trained independently in topological order.
Each node performs its own forward pass, loss computation, backward pass, and optimizer step.
This is useful when you need different optimizers per node, or when certain nodes should not share gradient flow.
Inspecting Optimizer Parameters
After at least one training step (or after calling build()), you can inspect which nodes contribute parameters to the global optimizer.
mg_branch.build(force=True)
opt_info = mg_branch.get_optimizer_parameters()
print(f"Backend: {opt_info['backend']}")
print(f"Contributing nodes: {len(opt_info['contributing_nodes'])}")
print(f"Total parameters: {len(opt_info['parameters'])}")
Backend Constraints
When using a global optimizer, all trainable nodes must share the same backend (e.g., all PyTorch). A RuntimeError is raised if backends conflict.
Mixed-backend graphs (e.g., PyTorch encoder + scikit-learn head) must use stage-wise training (no global optimizer).
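A check like this takes only a few lines. The shared_backend() helper and the backend strings below are hypothetical. The sketch only mirrors the rule stated above: conflicting backends raise a RuntimeError.

```python
from typing import Optional


def shared_backend(trainable_backends: dict[str, str]) -> Optional[str]:
    """Return the single backend shared by all trainable nodes, else raise RuntimeError."""
    backends = set(trainable_backends.values())
    if len(backends) > 1:
        raise RuntimeError(
            f"A global optimizer needs one backend, found {sorted(backends)}; "
            "use stage-wise training (optimizer=None) instead."
        )
    return next(iter(backends), None)  # None if there are no trainable nodes


print(shared_backend({"EncoderA": "torch", "EncoderB": "torch", "Head": "torch"}))

mixed_error = None
try:
    shared_backend({"Encoder": "torch", "Head": "scikit-learn"})
except RuntimeError as err:
    mixed_error = str(err)
print(mixed_error)
```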
Serialization
ModelGraph supports full serialization: saving and loading both the graph structure (config) and learned weights (state).
Config Serialization
get_config() captures the graph structure (node configs, optimizer config) without learned weights. from_config() reconstructs the graph from a config dict.
config = mg_branch.get_config()
print(f"Config keys: {list(config.keys())}")
print(f"Number of node configs: {len(config['nodes'])}")
print(f"Optimizer config: {config['optimizer'] is not None}")
State Serialization
get_state() captures the learned weights and optimizer state. set_state() restores them.
state = mg_branch.get_state()
print(f"State keys: {list(state.keys())}")
print(f"Number of node states: {len(state['nodes'])}")
print(f"Is built: {state['is_built']}")
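A small stand-in class can illustrate the config/state split. ToyGraph and its methods are hypothetical and only borrow the method names used above. Its config is a JSON-serializable description of structure, and its state carries learned values. Rebuilding from the config alone therefore yields fresh weights until the state is applied.

```python
import json


class ToyGraph:
    """Hypothetical stand-in that mirrors the config/state split (not ModularML's class)."""

    def __init__(self, node_labels: list[str], lr: float):
        self.node_labels = list(node_labels)
        self.lr = lr
        self.weights = {label: 0.0 for label in self.node_labels}  # "learned" values

    def get_config(self) -> dict:
        # Structure only: JSON-friendly and free of learned weights.
        return {"nodes": self.node_labels, "optimizer": {"lr": self.lr}}

    @classmethod
    def from_config(cls, config: dict) -> "ToyGraph":
        return cls(config["nodes"], config["optimizer"]["lr"])

    def get_state(self) -> dict:
        # Learned values only.
        return {"weights": dict(self.weights)}

    def set_state(self, state: dict) -> None:
        self.weights.update(state["weights"])


original = ToyGraph(["Encoder", "Head"], lr=1e-3)
original.weights["Head"] = 0.42  # pretend training changed a weight

clone = ToyGraph.from_config(json.loads(json.dumps(original.get_config())))
print(clone.weights)  # {'Encoder': 0.0, 'Head': 0.0}: structure restored, weights fresh
clone.set_state(original.get_state())
print(clone.weights)  # {'Encoder': 0.0, 'Head': 0.42}
```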
Save and Load to Disk
Use save() and load() for persistent serialization. The file includes both config and state.
from pathlib import Path
from tempfile import TemporaryDirectory
SAVE_DIR = TemporaryDirectory()
# Save
save_path = mg_branch.save(Path(SAVE_DIR.name) / "my_graph", overwrite=True)
print(f"Saved to: {save_path}")
# Load
# Note that we need to allow overwriting because all reloaded node labels/IDs
# conflict with those defined in this notebook
mg_loaded = ModelGraph.load(save_path, overwrite=True)
print(f"Loaded graph labels: {mg_loaded.node_labels}")
mg_loaded.visualize()
Checkpointing
Checkpointing allows you to save and restore the full state of a ModelGraph at a specific point during training. With save() / load(), load() creates a new ModelGraph instance; checkpointing instead restores state into an existing graph.
ModelGraph.save_checkpoint(
    filepath: Path,
    *,
    overwrite: bool = False,
    meta: dict[str, Any] | None = None,
) -> Path
ModelGraph.restore_checkpoint(filepath: Path) -> ModelGraph
| Parameter | Type | Description |
|---|---|---|
| `filepath` | `Path` | Location to save/load the checkpoint. |
| `overwrite` | `bool` | Whether to overwrite an existing file. |
| `meta` | `dict[str, Any] \| None` | Optional metadata to attach to the checkpoint (must be pickle-able). |
# Save a checkpoint (includes model weights and optimizer state)
ckpt_path = mg_branch.save_checkpoint(
    Path(SAVE_DIR.name) / "checkpoint_epoch5",
    overwrite=True,
    meta={"epoch": 5, "val_loss": 0.032},
)
print(f"Checkpoint saved to: {ckpt_path}")
# Restore the checkpoint into the existing graph
mg_branch.restore_checkpoint(ckpt_path)
print(f"Restored. Built: {mg_branch.is_built}")
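The in-place semantics of restore_checkpoint() can be sketched with pickle, which also shows why meta must be pickle-able. ToyTrainer is hypothetical and only imitates the documented behaviour. Saving refuses to overwrite an existing file unless asked to. Restoring mutates the existing object and returns it.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Optional


class ToyTrainer:
    """Hypothetical object illustrating in-place checkpoint restore (not ModularML's API)."""

    def __init__(self) -> None:
        self.weights = {"Encoder": 0.0, "Head": 0.0}

    def save_checkpoint(
        self,
        filepath: Path,
        *,
        overwrite: bool = False,
        meta: Optional[dict[str, Any]] = None,
    ) -> Path:
        if filepath.exists() and not overwrite:
            raise FileExistsError(filepath)
        with open(filepath, "wb") as f:
            # Everything in the payload, including meta, must be pickle-able.
            pickle.dump({"state": dict(self.weights), "meta": meta or {}}, f)
        return filepath

    def restore_checkpoint(self, filepath: Path) -> "ToyTrainer":
        with open(filepath, "rb") as f:
            payload = pickle.load(f)
        self.weights = payload["state"]  # update *this* instance in place
        return self


with tempfile.TemporaryDirectory() as tmp:
    trainer = ToyTrainer()
    trainer.weights["Head"] = 0.5  # state at "epoch 5"
    ckpt = trainer.save_checkpoint(Path(tmp) / "ckpt.pkl", meta={"epoch": 5})
    trainer.weights["Head"] = 9.9  # later training moves the weight
    restored = trainer.restore_checkpoint(ckpt)

print(restored is trainer, trainer.weights["Head"])  # True 0.5
```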
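Summary
Constructor
| Parameter | Type | Default | Description |
|---|---|---|---|
| `nodes` | `list[str \| GraphNode] \| None` | (required) | Nodes comprising the graph. |
| `optimizer` | `Optimizer \| None` | `None` | Shared optimizer for graph-wise training. |
| `label` | `str` | `"model-graph"` | Human-readable label. |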
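Properties
| Property | Description |
|---|---|
| `nodes` | All nodes keyed by `node_id`. |
| `node_labels` | Unique node labels. |
| `head_nodes` | Nodes receiving FeatureSet input. |
| `tail_nodes` | Nodes with no downstream consumers. |
| `is_built` | Whether the graph has been built. |
| `backend` | Backend of the global optimizer, or `None` if no global optimizer is set. |
| `frozen_nodes` | Currently frozen trainable nodes. |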
Methods
| Method | Description |
|---|---|
| `build()` | Build all nodes and the global optimizer. |
| `forward()` | Execute a forward pass through the graph. |
| … | Execute a single training step (graph-wise or stage-wise). |
| … | Execute a forward-only evaluation step (no gradients). |
| … | Fit batch-fit nodes (e.g., scikit-learn) in topological order. |
| `freeze()` | Freeze nodes to prevent training. |
| `unfreeze()` | Unfreeze nodes to allow training. |
| `add_node()` | Add a node to the graph. |
| `remove_node()` | Remove a node, reconnecting neighbors. |
| `replace_node()` | Replace a node, preserving connections. |
| `insert_node_between()` | Insert between two connected nodes. |
| `insert_node_before()` | Insert before an existing node. |
| `insert_node_after()` | Insert after an existing node. |
| `get_config()` / `from_config()` | Config serialization (structure only). |
| `get_state()` / `set_state()` | State serialization (includes weights). |
| `save()` / `load()` | Full serialization to/from disk. |
| `save_checkpoint()` | Save a training checkpoint. |
| `restore_checkpoint()` | Restore state from a checkpoint. |
Training Modes
| Mode | When | Behavior |
|---|---|---|
| Graph-wise | Global `Optimizer` set on the `ModelGraph` | Single forward + backward pass across all nodes. End-to-end gradient flow. |
| Stage-wise | No global optimizer (`optimizer=None`) | Each node trains independently with its own optimizer. |
Next Steps
Experiment: Use Experiment to combine a ModelGraph with training phases, loss functions, and evaluation; it is the primary user-facing entry point.
ModelNode: See how individual nodes wrap models and handle forward passes.
MergeNode: Learn how to combine parallel branches with ConcatNode.