How to: Create and Use a ModelGraph

A ModelGraph is the computational backbone of a ModularML Experiment. It organizes one or more ModelNodes (and optionally MergeNodes) into a directed acyclic graph (DAG) that handles:

  • Shape inference: Automatically determines input/output shapes for every node during build().

  • Topological execution: Ensures nodes execute in dependency order during forward, training, and evaluation passes.

  • Global optimizer management: Optionally shares a single optimizer across all trainable nodes for end-to-end gradient flow.

  • Freeze / unfreeze control: Selectively disable training for subsets of the graph.

  • Graph mutation: Add, remove, replace, or insert nodes dynamically.

  • Serialization & checkpointing: Save and restore the full graph structure and learned weights.

FeatureSet ──> ModelNode("Encoder") ──> ModelNode("Regressor")

FeatureSet ─┬─> ModelNode("A") ──┐
            │                    ├─> ConcatNode ──> ModelNode("Head")
            └─> ModelNode("B") ──┘
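Both diagrams are DAGs, so at least one execution order that respects every dependency always exists. The toy sketch below is plain Python and does not use ModularML's internals. It encodes the branching diagram as a mapping from each node to its upstream dependencies and orders the nodes with the standard library's graphlib.TopologicalSorter (Python 3.9+). The string "FeatureSet" stands in for the data source.

```python
from graphlib import TopologicalSorter

# Toy model of the branching diagram: node -> set of upstream dependencies.
# "FeatureSet" is the data source feeding the two parallel branches.
upstreams = {
    "A": {"FeatureSet"},
    "B": {"FeatureSet"},
    "ConcatNode": {"A", "B"},
    "Head": {"ConcatNode"},
}

# static_order() yields every node only after all of its dependencies.
order = list(TopologicalSorter(upstreams).static_order())
print(order)
```

A and B may appear in either order because neither depends on the other. ConcatNode and Head always come after both of them.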

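This notebook covers:

  • Creating a ModelGraph from node instances or labels, with or without a global optimizer

  • Building the graph and how shapes are inferred

  • Inspecting graph properties and running forward passes

  • Mutating the graph (adding, removing, replacing, and inserting nodes)

  • Freezing and unfreezing nodes, and how optimizers are managed

  • Serialization, saving/loading, and checkpointing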

import numpy as np
import torch

from modularml import (
    ConcatNode,
    Experiment,
    FeatureSet,
    ModelGraph,
    ModelNode,
    Optimizer,
)
from modularml.models.torch import SequentialMLP

# Create an Experiment with the "overwrite" registration policy so we can freely
# recreate nodes with the same labels (this avoids a warning on every overwrite)
exp = Experiment(label="create_modelgraph", registration_policy="overwrite")

We’ll use a simple synthetic dataset throughout this notebook: 500 samples of a 10-point feature with a scalar target.

rng = np.random.default_rng(42)

fs = FeatureSet.from_dict(
    label="SensorData",
    data={
        "voltage": list(rng.standard_normal((500, 10))),
        "soh": list(rng.standard_normal((500, 1))),
    },
    feature_keys="voltage",
    target_keys="soh",
)
fs_ref = fs.reference(features="voltage", targets="soh")
print(fs)

Creating a ModelGraph

A ModelGraph is constructed from a list of GraphNode instances and an optional shared Optimizer.

    ModelGraph(
        nodes: list[str | GraphNode] | None,
        optimizer: Optimizer | None = None,
        label: str = "model-graph",
    )

Parameters:

  • nodes (list[str | GraphNode] | None, required): Nodes comprising the graph. Pass node instances or their string labels. If None, all registered GraphNodes in the active ExperimentContext are used.

  • optimizer (Optimizer | None, default None): A shared optimizer for end-to-end training. If provided, all trainable nodes must share the same backend.

  • label (str, default "model-graph"): A human-readable label for this graph.

Simple Linear Graph

The simplest graph is a linear chain: FeatureSet -> ModelNode.

node = ModelNode(
    label="SimpleMLP",
    model=SequentialMLP(output_shape=(1, 1), n_layers=2, hidden_dim=32),
    upstream_ref=fs_ref,
)

mg = ModelGraph(
    nodes=[node],
    optimizer=Optimizer(opt="adam", opt_kwargs={"lr": 1e-3}, backend="torch"),
    label="simple-graph",
)
print(f"Label:  {mg.label}")
print(f"Nodes:  {mg.node_labels}")
print(f"Built:  {mg.is_built}")

Multi-Node Chain

Chain multiple ModelNodes by passing one node as the upstream_ref of the next.

ModelGraph also provides a .visualize() method, which we'll use throughout this notebook to show how the topology changes.

encoder = ModelNode(
    label="Encoder",
    model=SequentialMLP(output_shape=(1, 8), n_layers=2, hidden_dim=32),
    upstream_ref=fs_ref,
)

regressor = ModelNode(
    label="Regressor",
    model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=16),
    upstream_ref=encoder,
)

mg_chain = ModelGraph(
    nodes=[encoder, regressor],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
print(f"Node labels: {mg_chain.node_labels}")

mg_chain.visualize()

Branching Graph with MergeNode

Use ConcatNode (a MergeNode) to combine outputs from parallel branches.

FeatureSet ─┬─> EncoderA ──┐
            │              ├─> ConcatNode ──> Head
            └─> EncoderB ──┘
enc_a = ModelNode(
    label="EncoderA",
    model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)
enc_b = ModelNode(
    label="EncoderB",
    model=SequentialMLP(output_shape=(1, 4), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)

merge = ConcatNode(
    label="Merge",
    upstream_refs=[enc_a, enc_b],
    concat_axis=-1,
    concat_axis_targets="first",
)

head = ModelNode(
    label="Head",
    model=SequentialMLP(n_layers=1, hidden_dim=8),
    upstream_ref=merge,
)

mg_branch = ModelGraph(
    nodes=[enc_a, enc_b, merge, head],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
print(f"Node labels: {mg_branch.node_labels}")

mg_branch.visualize()

Referencing Nodes by Label

Instead of passing node instances, you can pass their string labels. The graph will look them up in the active ExperimentContext.

mg_by_label = ModelGraph(
    nodes=["EncoderA", "EncoderB", "Merge", "Head"],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
print(f"Node labels: {mg_by_label.node_labels}")

mg_by_label.visualize()

Without a Global Optimizer

If no global optimizer is provided, each ModelNode must define its own local optimizer. This is useful when different nodes need different optimizers or learning rates (stage-wise training).

node_with_opt = ModelNode(
    label="StageWiseMLP",
    model=SequentialMLP(output_shape=(1, 1), n_layers=2, hidden_dim=32),
    upstream_ref=fs_ref,
    optimizer=Optimizer("adam", opt_kwargs={"lr": 1e-3}, backend="torch"),
)

mg_no_global = ModelGraph(
    nodes=[node_with_opt],
    optimizer=None,
)
print(f"Global optimizer backend: {mg_no_global.backend}")

Building the Graph

ModelGraph.build() performs the following steps in topological order:

  1. Validates the DAG structure (no cycles, all upstream references resolved).

  2. Infers input and output shapes for each node from upstream outputs and FeatureSet shapes.

  3. Builds each node’s underlying model (lazy initialization).

  4. Builds the global optimizer (if provided) with parameters from all trainable nodes.
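Step 1 comes down to two checks. Every upstream reference must resolve to a known node or data source, and the graph must be acyclic. The sketch below shows those checks under some assumptions. It works on a plain dict of upstream labels and uses a hypothetical source label "SensorData". It illustrates the idea only and is not ModularML's validator.

```python
from graphlib import CycleError, TopologicalSorter

SOURCES = {"SensorData"}  # hypothetical FeatureSet label acting as a data source


def validate_dag(upstreams: dict[str, set[str]]) -> list[str]:
    """Return a topological order of the graph nodes, or raise ValueError."""
    known = set(upstreams) | SOURCES
    # Check 1: every upstream reference must resolve.
    for node, ups in upstreams.items():
        missing = ups - known
        if missing:
            raise ValueError(f"{node!r} has unresolved upstream refs: {sorted(missing)}")
    # Check 2: no cycles (static_order() raises CycleError if one exists).
    try:
        order = list(TopologicalSorter(upstreams).static_order())
    except CycleError as err:
        raise ValueError(f"cycle detected: {err.args[1]}") from err
    # Data sources are not built, so drop them from the returned order.
    return [n for n in order if n not in SOURCES]


print(validate_dag({"Encoder": {"SensorData"}, "Regressor": {"Encoder"}}))
```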

    ModelGraph.build(*, force: bool = False)

Parameters:

  • force (bool, default False): If True, rebuilds even if the graph is already built.

mg_branch.build()
print(f"Built: {mg_branch.is_built}")

for node in mg_branch.nodes.values():
    in_shape = (
        node.input_shape
        if hasattr(node, "input_shape")
        else list(node.input_shapes.values())
    )
    out_shape = getattr(node, "output_shape", None)
    print(f"  {node.label}: {in_shape} -> {out_shape}")

mg_branch.visualize()  # Note how all edges now show the input/output shapes

Shape Inference Details

During build(), shapes propagate through the graph as follows:

  • Head nodes (inputs from a FeatureSet): Input shape is pulled directly from the referenced FeatureSet data.

  • Intermediate nodes: Input shape equals the output shape of their upstream node.

  • Tail nodes (no downstream consumers): If no output_shape is specified on the model, it defaults to the target shape propagated from the upstream FeatureSet.

  • MergeNodes: Both feature and target output shapes are determined by a dummy forward pass through the merge logic.

You generally do not need to specify input_shape on your models — build() infers it. Specifying output_shape is recommended for all non-tail nodes.
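You can trace these rules by hand for the branching example. The toy walk-through below assumes a per-sample feature shape of (1, 10) and a target shape of (1, 1). ModularML runs a dummy forward pass through MergeNodes, but this sketch computes the concatenation shape analytically instead.

```python
# ASSUMED per-sample shapes (batch axis excluded) for the SensorData FeatureSet.
FEATURE_SHAPE = (1, 10)
TARGET_SHAPE = (1, 1)


def concat_shape(shapes, axis=-1):
    """Shape after concatenating along `axis`; all other axes must match."""
    ndim = len(shapes[0])
    axis %= ndim
    for s in shapes[1:]:
        if len(s) != ndim or any(s[i] != shapes[0][i] for i in range(ndim) if i != axis):
            raise ValueError(f"incompatible shapes for concat: {shapes}")
    out = list(shapes[0])
    out[axis] = sum(s[axis] for s in shapes)
    return tuple(out)


# node -> (upstream labels, declared output shape); "concat" marks the MergeNode.
graph = {
    "EncoderA": (["SensorData"], (1, 8)),
    "EncoderB": (["SensorData"], (1, 4)),
    "Merge": (["EncoderA", "EncoderB"], "concat"),
    "Head": (["Merge"], None),  # tail node with no output_shape
}

shapes = {"SensorData": FEATURE_SHAPE}
for name, (ups, declared) in graph.items():  # dict order is already topological here
    in_shapes = [shapes[u] for u in ups]  # inputs = upstream outputs
    if declared == "concat":
        out = concat_shape(in_shapes)
    elif declared is None:
        out = TARGET_SHAPE  # tail default: the FeatureSet's target shape
    else:
        out = declared
    shapes[name] = out
    print(f"{name}: {in_shapes} -> {out}")
```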

Rebuilding

Calling build() on an already-built graph is a no-op unless force=True.

# No-op (already built)
mg_branch.build()

# Force rebuild (e.g., after modifying graph structure)
mg_branch.build(force=True)
print(f"Rebuilt: {mg_branch.is_built}")

Graph Properties

After building, the graph exposes several useful properties for inspecting its structure.

print(f"Label:       {mg_branch.label}")
print(f"Built:       {mg_branch.is_built}")
print(f"Backend:     {mg_branch.backend}")
print(f"Node labels: {mg_branch.node_labels}")

Head and Tail Nodes

  • Head nodes: Nodes whose inputs come directly from a FeatureSet (no upstream GraphNode dependencies).

  • Tail nodes: Nodes whose outputs are not consumed by any other node in the graph.

print("Head nodes (receive FeatureSet data):")
for n in mg_branch.head_nodes.values():
    print(f"  - {n.label}")

print("\nTail nodes (produce final outputs):")
for n in mg_branch.tail_nodes.values():
    print(f"  - {n.label}")
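Both sets can be derived from the upstream connections alone. The toy sketch below works over a plain dict, where the label "SensorData" marks the FeatureSet, a data source rather than a graph node. In this example the node labeled "Head" is a tail node, not a head node.

```python
upstreams = {
    "EncoderA": ["SensorData"],
    "EncoderB": ["SensorData"],
    "Merge": ["EncoderA", "EncoderB"],
    "Head": ["Merge"],
}
graph_nodes = set(upstreams)

# Head nodes: none of their upstreams is another graph node.
heads = sorted(n for n, ups in upstreams.items() if graph_nodes.isdisjoint(ups))
# Tail nodes: never consumed as an upstream by any graph node.
consumed = {u for ups in upstreams.values() for u in ups}
tails = sorted(graph_nodes - consumed)
print("head nodes:", heads)
print("tail nodes:", tails)
```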

Accessing Individual Nodes

Nodes are stored in a dict keyed by node_id. These IDs are globally unique, which is why a node can be referenced by its label, its ID, or the instance itself at any point in an Experiment.

You can iterate over the nodes or access them by label.

# All nodes (keyed by node_id)
for n_id, node in mg_branch.nodes.items():
    print(f"  {node.label}  (id={n_id[:8]}...)")
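Because IDs are unique, resolving a reference takes only a few lookups. The sketch below uses a hypothetical ToyNode class and resolve() helper, neither of which is part of ModularML. It shows one way a label, an ID, or an instance can all map to the same node. Label lookup in this sketch succeeds only when the label is unambiguous.

```python
import uuid
from dataclasses import dataclass, field


@dataclass
class ToyNode:
    label: str
    node_id: str = field(default_factory=lambda: uuid.uuid4().hex)


def resolve(nodes, ref):
    """Return the node for an instance, a node_id, or a (unique) label."""
    if isinstance(ref, ToyNode):
        return nodes[ref.node_id]
    if ref in nodes:  # exact node_id match
        return nodes[ref]
    matches = [n for n in nodes.values() if n.label == ref]
    if len(matches) != 1:
        raise KeyError(f"no unique node labeled {ref!r}")
    return matches[0]


head = ToyNode("Head")
nodes = {n.node_id: n for n in (ToyNode("EncoderA"), ToyNode("Merge"), head)}
print(resolve(nodes, "Head").node_id == head.node_id)
```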

Forward Pass

Once built, you can execute a forward pass through the graph. The graph handles data routing between nodes in topological order.

    ModelGraph.forward(
        inputs: dict[tuple[str, FeatureSetReference], TForward],
        *,
        active_nodes: list[str | GraphNode] | None = None,
    ) -> dict[str, TForward]

Parameters:

  • inputs (dict): Mapping of (head_node_id, FeatureSetReference) to input data. Each head node needs its upstream FeatureSet data.

  • active_nodes (list | None, default None): Optional subset of nodes to execute. Upstream dependencies are included automatically. If None, all nodes run.

Returns: A dict mapping node_id to that node’s output data for every executed node.

from modularml.core.data.sample_data import SampleData
from modularml.utils.data.data_format import DataFormat

# Prepare input data
fsv = fs_ref.resolve()
sample_data = SampleData(
    features=fsv.get_features(fmt=DataFormat.TORCH),
    targets=fsv.get_targets(fmt=DataFormat.TORCH),
)

# Build the inputs dict: (head_node_id, featureset_ref) -> data
inputs = {}
for n_id, node in mg_branch.head_nodes.items():
    for ref in node.get_upstream_refs():
        inputs[(n_id, ref)] = sample_data

print(f"Number of input entries: {len(inputs)}")
# Execute forward pass
with torch.no_grad():
    outputs = mg_branch.forward(inputs)

print("Outputs per node:")
for n_id, out in outputs.items():
    node_label = mg_branch.nodes[n_id].label
    print(f"  {node_label}: features={out.features.shape}")
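In practice, routing means running nodes in dependency order and feeding each one the outputs of its upstreams. The toy version below uses plain functions and Python lists in place of models and tensors. It is simplified in one respect: inputs are keyed by source label rather than by (head_node_id, FeatureSetReference).

```python
from graphlib import TopologicalSorter

# Toy "models": each maps a list of upstream outputs to one output list.
funcs = {
    "EncoderA": lambda xs: [2 * v for v in xs[0]],
    "EncoderB": lambda xs: [v + 1 for v in xs[0]],
    "Merge": lambda xs: xs[0] + xs[1],  # list concatenation stands in for ConcatNode
    "Head": lambda xs: [sum(xs[0])],
}
upstreams = {
    "EncoderA": ["SensorData"],
    "EncoderB": ["SensorData"],
    "Merge": ["EncoderA", "EncoderB"],
    "Head": ["Merge"],
}


def toy_forward(inputs):
    """Run every node in dependency order; return each node's output."""
    values = dict(inputs)  # FeatureSet data, keyed by source label
    outputs = {}
    for name in TopologicalSorter(upstreams).static_order():
        if name not in funcs:  # a data source, not a node
            continue
        out = funcs[name]([values[u] for u in upstreams[name]])
        values[name] = outputs[name] = out
    return outputs


outs = toy_forward({"SensorData": [1.0, 2.0]})
print(outs)
```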

Active Nodes

You can restrict the forward pass to a subset of the graph using active_nodes. All required upstream dependencies are automatically included.

Here only merge is marked as active. Its upstream nodes (EncoderA and EncoderB) must still be executed, but the downstream Head node is skipped.

# Only request the Merge node; its dependencies (EncoderA, EncoderB) run automatically
with torch.no_grad():
    partial_outputs = mg_branch.forward(inputs, active_nodes=[merge])

print("Executed nodes:")
for n_id in partial_outputs:
    print(f"  - {mg_branch.nodes[n_id].label}")
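Expanding active_nodes to include their dependencies is a reachability walk over upstream edges. The minimal sketch below uses a toy upstream map that mirrors the branching example. It only collects the nodes to run. Executing them would still follow topological order.

```python
upstreams = {
    "EncoderA": ["SensorData"],
    "EncoderB": ["SensorData"],
    "Merge": ["EncoderA", "EncoderB"],
    "Head": ["Merge"],
}


def required_nodes(active):
    """Return the active nodes plus every node they transitively depend on."""
    needed, stack = set(), list(active)
    while stack:
        name = stack.pop()
        if name in needed or name not in upstreams:  # already seen, or a data source
            continue
        needed.add(name)
        stack.extend(upstreams[name])
    return needed


print(sorted(required_nodes({"Merge"})))
```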

Graph Mutation

ModelGraph provides several methods to modify the graph structure after creation. All mutation methods return self for method chaining.

After any structural change, the graph automatically revalidates connections and recomputes the topological order. You will need to call build() again to reinitialize shapes and optimizers.

add_node()

Add a new node to the graph. The node must already be connected to existing nodes via its upstream_ref.

# Start with a simple single-node graph
base_node = ModelNode(
    label="Base",
    model=SequentialMLP(output_shape=(1, 4), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)
mg_mut = ModelGraph(
    nodes=[base_node],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
print(f"Before: {mg_mut.node_labels}")

# Add a downstream node
added_node = ModelNode(
    label="Added",
    model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),
    upstream_ref=base_node,
)
mg_mut.add_node(added_node)
print(f"After:  {mg_mut.node_labels}")

mg_mut.visualize()
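The hypothetical ToyGraph class below sketches the validation and chaining behavior described above. It is not the ModularML API. Its add_node() rejects a node whose upstreams are neither existing nodes nor known data sources, and it returns self so calls can be chained.

```python
class ToyGraph:
    """Hypothetical minimal graph: node label -> list of upstream labels."""

    def __init__(self, sources):
        self.sources = set(sources)  # FeatureSet labels usable as inputs
        self.upstreams = {}

    def add_node(self, label, upstream):
        if label in self.upstreams:
            raise ValueError(f"node {label!r} already exists")
        unknown = [u for u in upstream if u not in self.upstreams and u not in self.sources]
        if unknown:
            raise ValueError(f"{label!r} references unknown upstreams: {unknown}")
        self.upstreams[label] = list(upstream)
        return self  # returning self is what makes chaining possible


g = ToyGraph(sources=["SensorData"]).add_node("Base", ["SensorData"]).add_node("Added", ["Base"])
print(g.upstreams)
```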

remove_node()

Remove a node from the graph. Downstream nodes are reconnected to the removed node’s upstream sources.

Given: A -> B -> C
Remove B:
Result: A -> C
# Create a 3-node chain
n1 = ModelNode(
    label="N1",
    model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)
n2 = ModelNode(
    label="N2",
    model=SequentialMLP(output_shape=(1, 4), n_layers=1, hidden_dim=8),
    upstream_ref=n1,
)
n3 = ModelNode(
    label="N3",
    model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),
    upstream_ref=n2,
)
mg_rem = ModelGraph(
    nodes=[n1, n2, n3],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
print(f"Before: {mg_rem.node_labels}")

# Remove the middle node
mg_rem.remove_node("N2")
print(f"After:  {mg_rem.node_labels}")

mg_rem.visualize()
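The rewiring rule is that consumers of the removed node inherit its upstreams. It can be sketched as a pure function over a toy upstream map, which is an illustration only, not ModularML's implementation. In the example above, N3 now receives N1's (1, 8) output instead of N2's (1, 4) output. That shift is one reason build() must be called again after a mutation.

```python
def remove_node(upstreams, target):
    """Drop `target` and splice its upstreams into each of its consumers."""
    removed_ups = upstreams[target]
    result = {}
    for name, ups in upstreams.items():
        if name == target:
            continue
        new_ups = []
        for u in ups:
            # Replace a reference to `target` with target's own upstreams.
            for r in (removed_ups if u == target else [u]):
                if r not in new_ups:  # avoid duplicate edges
                    new_ups.append(r)
        result[name] = new_ups
    return result


chain = {"N1": ["SensorData"], "N2": ["N1"], "N3": ["N2"]}
removed = remove_node(chain, "N2")
print(removed)
```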

replace_node()

Replace an existing node with a new one, preserving all upstream and downstream connections.

# Create a simple chain
old_enc = ModelNode(
    label="OldEncoder",
    model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)
reg = ModelNode(
    label="Reg",
    model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),
    upstream_ref=old_enc,
)
mg_rep = ModelGraph(
    nodes=[old_enc, reg],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
print(f"Before: {mg_rep.node_labels}")

# Replace with a deeper encoder
new_enc = ModelNode(
    label="NewEncoder",
    model=SequentialMLP(output_shape=(1, 8), n_layers=3, hidden_dim=64),
    upstream_ref=fs_ref,
)
mg_rep.replace_node(old_node="OldEncoder", new_node=new_enc)
print(f"After:  {mg_rep.node_labels}")

mg_rep.visualize()

insert_node_between()

Insert a new node between two already-connected nodes.

Given: A -> B
Insert C between A and B:
Result: A -> C -> B
a = ModelNode(
    label="A",
    model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)
b = ModelNode(
    label="B",
    model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),
    upstream_ref=a,
)
mg_ins = ModelGraph(
    nodes=[a, b],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
print(f"Before: {mg_ins.node_labels}")

c = ModelNode(
    label="C",
    model=SequentialMLP(output_shape=(1, 4), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,  # will be overwritten by insert
)
mg_ins.insert_node_between(new_node=c, upstream=a, downstream=b)
print(f"After:  {mg_ins.node_labels}")

# Verify connectivity
for node in mg_ins.nodes.values():
    ups = [r.node_label for r in node.get_upstream_refs()]
    print(f"  {node.label} <- {ups}")

mg_ins.visualize()
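The toy sketch below models the rewiring performed by insert_node_between() over a plain dict; it is an illustration, not ModularML's implementation. The new node's own upstream reference is replaced by the upstream node. The downstream node's reference to the upstream is then redirected to the new node.

```python
def insert_between(upstreams, new, upstream, downstream):
    """Rewire upstream -> downstream into upstream -> new -> downstream."""
    if upstream not in upstreams.get(downstream, []):
        raise ValueError(f"{upstream!r} is not connected to {downstream!r}")
    result = {k: list(v) for k, v in upstreams.items()}
    result[new] = [upstream]  # the new node's original upstream_ref is discarded
    result[downstream] = [new if u == upstream else u for u in result[downstream]]
    return result


graph = insert_between({"A": ["SensorData"], "B": ["A"]}, "C", upstream="A", downstream="B")
for node, ups in graph.items():
    print(f"{node} <- {ups}")
```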

insert_node_before() and insert_node_after()

  • insert_node_before(new_node, downstream=...): Insert before an existing node, taking over all its upstream connections.

  • insert_node_after(new_node, upstream=...): Insert after an existing node as an additional downstream consumer.

p = ModelNode(
    label="P",
    model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)
q = ModelNode(
    label="Q",
    model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),
    upstream_ref=p,
)
mg_ib = ModelGraph(
    nodes=[p, q],
    optimizer=Optimizer(opt="adam", backend="torch"),
)

# Insert a node before Q (takes over Q's upstream connections)
pre_q = ModelNode(
    label="PreQ",
    model=SequentialMLP(output_shape=(1, 4), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)
mg_ib.insert_node_before(new_node=pre_q, downstream=q)
print("After insert_node_before:")
for node in mg_ib.nodes.values():
    ups = [r.node_label for r in node.get_upstream_refs()]
    print(f"  {node.label} <- {ups}")

mg_ib.visualize()
# Insert a node after P (adds a new branch)
post_p = ModelNode(
    label="PostP",
    model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),
    upstream_ref=fs_ref,
)
mg_ib.insert_node_after(new_node=post_p, upstream=p)
print("After insert_node_after:")
for node in mg_ib.nodes.values():
    ups = [r.node_label for r in node.get_upstream_refs()]
    print(f"  {node.label} <- {ups}")

print(f"\nTail nodes: {[n.label for n in mg_ib.tail_nodes.values()]}")

mg_ib.visualize()
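The difference between the two methods is easiest to see on a toy upstream map. The helpers below are hypothetical and not the ModularML API. insert_before hands the downstream node's upstreams to the new node, while insert_after only adds a new consumer. That is why the graph ends up with two tail nodes.

```python
def insert_before(upstreams, new, downstream):
    """New node takes over all of `downstream`'s upstreams and feeds it."""
    result = {k: list(v) for k, v in upstreams.items()}
    result[new] = result[downstream]
    result[downstream] = [new]
    return result


def insert_after(upstreams, new, upstream):
    """New node consumes `upstream`; existing consumers are left untouched."""
    result = {k: list(v) for k, v in upstreams.items()}
    result[new] = [upstream]
    return result


graph = {"P": ["SensorData"], "Q": ["P"]}
graph = insert_before(graph, "PreQ", downstream="Q")
graph = insert_after(graph, "PostP", upstream="P")

consumed = {u for ups in graph.values() for u in ups}
tails = sorted(set(graph) - consumed)
print(graph)
print("tails:", tails)
```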

Freezing and Unfreezing

Freezing prevents a node’s parameters from being updated during training. This is useful for transfer learning, multi-stage training, or keeping pretrained components fixed.

    ModelGraph.freeze(nodes: list[str | GraphNode] | None = None)
    ModelGraph.unfreeze(nodes: list[str | GraphNode] | None = None)

Parameters:

  • nodes (list | None, default None): Nodes to freeze/unfreeze (by label, ID, or instance). If None, applies to all trainable nodes.

# Using the branching graph (mg_branch) built earlier
mg_branch.build(force=True)

# Freeze specific nodes
mg_branch.freeze(nodes=[enc_a])
print(f"Frozen nodes: {[n.label for n in mg_branch.frozen_nodes.values()]}")
mg_branch.visualize(show_frozen=True)

# Unfreeze
mg_branch.unfreeze(nodes=[enc_a])
print(f"Frozen nodes: {[n.label for n in mg_branch.frozen_nodes.values()]}")
# Freeze all nodes at once
mg_branch.freeze()
print(f"All frozen: {[n.label for n in mg_branch.frozen_nodes.values()]}")

# Unfreeze all
mg_branch.unfreeze()
print(f"Frozen after unfreeze(): {[n.label for n in mg_branch.frozen_nodes.values()]}")

Frozen Nodes and the Optimizer

When using a global optimizer, the optimizer is automatically rebuilt before each training step so that it excludes the frozen nodes' parameters. As a result, the optimizer never updates those parameters, and the frozen nodes' weights remain unchanged.
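The effect can be modeled in a few lines. Parameters that never enter the optimizer's parameter list are never stepped, whatever gradients exist. The parameter names and the plain SGD update below are hypothetical stand-ins, not ModularML internals.

```python
# Hypothetical parameter names per node; a real backend would hold tensors.
node_params = {
    "EncoderA": ["enc_a.weight", "enc_a.bias"],
    "EncoderB": ["enc_b.weight", "enc_b.bias"],
    "Head": ["head.weight", "head.bias"],
}


def optimizer_params(frozen):
    """Collect parameters from unfrozen nodes only (insertion order kept)."""
    return [p for node, ps in node_params.items() if node not in frozen for p in ps]


weights = {p: 1.0 for ps in node_params.values() for p in ps}
grads = {p: 0.5 for p in weights}  # pretend every parameter received a gradient

# "Rebuild" the optimizer with EncoderA frozen, then take one plain SGD step.
lr = 0.1
for p in optimizer_params(frozen={"EncoderA"}):
    weights[p] -= lr * grads[p]

print(weights)
```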


Optimizer Management

The ModelGraph supports two training modes based on whether a global optimizer is provided:

Global Optimizer (Graph-Wise Training)

When a global Optimizer is set on the ModelGraph:

  • A single forward pass runs through the entire graph.

  • All losses are accumulated.

  • A single backward pass computes gradients across all unfrozen nodes.

  • The global optimizer steps once.

This enables end-to-end gradient flow through the full graph, which is the most common training paradigm.

No Global Optimizer (Stage-Wise Training)

When optimizer=None on the ModelGraph:

  • Each ModelNode must have its own local Optimizer.

  • Nodes are trained independently in topological order.

  • Each node performs its own forward pass, loss computation, backward pass, and optimizer step.

This is useful when you need different optimizers per node, or when certain nodes should not share gradient flow.

Inspecting Optimizer Parameters

After at least one training step (or after calling build()), you can inspect which nodes contribute parameters to the global optimizer.

mg_branch.build(force=True)

opt_info = mg_branch.get_optimizer_parameters()
print(f"Backend: {opt_info['backend']}")
print(f"Contributing nodes: {len(opt_info['contributing_nodes'])}")
print(f"Parameter entries: {len(opt_info['parameters'])}")

Backend Constraints

When using a global optimizer, all trainable nodes must share the same backend (e.g., all PyTorch). A RuntimeError is raised if backends conflict.

Mixed-backend graphs (e.g., PyTorch encoder + scikit-learn head) must use stage-wise training (no global optimizer).
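The constraint reduces to a set check over the backends of the trainable nodes. The minimal sketch below assumes that nodes without trainable parameters, such as a concatenation merge, report None as their backend. The helper name is hypothetical.

```python
def check_global_backend(node_backends):
    """Return the backend shared by all trainable nodes, or raise RuntimeError.

    node_backends maps a node label to its backend name, or to None for nodes
    that have no trainable parameters (e.g. a concatenation merge).
    """
    backends = {b for b in node_backends.values() if b is not None}
    if len(backends) > 1:
        raise RuntimeError(
            f"A global optimizer needs a single backend; found {sorted(backends)}"
        )
    return next(iter(backends), None)


print(check_global_backend({"EncoderA": "torch", "Merge": None, "Head": "torch"}))

try:
    check_global_backend({"Encoder": "torch", "Head": "scikit-learn"})
except RuntimeError as err:
    print(err)
```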


Serialization

ModelGraph supports full serialization: saving and loading both the graph structure (config) and learned weights (state).

Config Serialization

get_config() captures the graph structure (node configs, optimizer config) without learned weights. from_config() reconstructs the graph from a config dict.

config = mg_branch.get_config()
print(f"Config keys: {list(config.keys())}")
print(f"Number of node configs: {len(config['nodes'])}")
print(f"Optimizer config: {config['optimizer'] is not None}")

State Serialization

get_state() captures the learned weights and optimizer state. set_state() restores them.

state = mg_branch.get_state()
print(f"State keys: {list(state.keys())}")
print(f"Number of node states: {len(state['nodes'])}")
print(f"Is built: {state['is_built']}")
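The split between structure and learned values follows a common pattern. The config is small and JSON-friendly, while the state carries the numbers. The ToyLinear class below is hypothetical and only mirrors the get_config/from_config/get_state/set_state naming used here.

```python
import copy
import json


class ToyLinear:
    """Hypothetical model: structure lives in the config, numbers in the state."""

    def __init__(self, in_dim, out_dim):
        self.in_dim, self.out_dim = in_dim, out_dim
        self.weights = [[0.0] * in_dim for _ in range(out_dim)]

    def get_config(self):
        return {"in_dim": self.in_dim, "out_dim": self.out_dim}

    @classmethod
    def from_config(cls, config):
        return cls(**config)

    def get_state(self):
        return {"weights": copy.deepcopy(self.weights)}

    def set_state(self, state):
        self.weights = copy.deepcopy(state["weights"])


trained = ToyLinear(2, 1)
trained.weights = [[0.5, -1.5]]  # pretend this came from training

config_json = json.dumps(trained.get_config())  # structure only, no weights
rebuilt = ToyLinear.from_config(json.loads(config_json))  # fresh (zero) weights
fresh_weights = copy.deepcopy(rebuilt.weights)
rebuilt.set_state(trained.get_state())  # now carries the learned weights too
print(config_json, fresh_weights, rebuilt.weights)
```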

Save and Load to Disk

Use save() and load() for persistent serialization. The file includes both config and state.

from pathlib import Path
from tempfile import TemporaryDirectory

SAVE_DIR = TemporaryDirectory()

# Save
save_path = mg_branch.save(Path(SAVE_DIR.name) / "my_graph", overwrite=True)
print(f"Saved to: {save_path}")

# Load
# Note that we need to allow overwriting because the reloaded node labels/IDs
# collide with those already defined in this notebook
mg_loaded = ModelGraph.load(save_path, overwrite=True)
print(f"Loaded graph labels: {mg_loaded.node_labels}")

mg_loaded.visualize()

Checkpointing

Checkpointing allows you to save and restore the full state of a ModelGraph at a specific point during training. With save()/load(), load() creates a new ModelGraph instance. Checkpointing, by contrast, restores state into an existing graph.

    ModelGraph.save_checkpoint(
        filepath: Path,
        *,
        overwrite: bool = False,
        meta: dict[str, Any] | None = None,
    ) -> Path

    ModelGraph.restore_checkpoint(filepath: Path) -> ModelGraph

Parameters:

  • filepath (Path): Location to save the checkpoint to (save_checkpoint) or load it from (restore_checkpoint).

  • overwrite (bool, default False): Whether to overwrite an existing file.

  • meta (dict | None, default None): Optional metadata to attach to the checkpoint (must be pickle-able).

# Save a checkpoint (includes model weights and optimizer state)
ckpt_path = mg_branch.save_checkpoint(
    Path(SAVE_DIR.name) / "checkpoint_epoch5",
    overwrite=True,
    meta={"epoch": 5, "val_loss": 0.032},
)
print(f"Checkpoint saved to: {ckpt_path}")
# Restore the checkpoint into the existing graph
mg_branch.restore_checkpoint(ckpt_path)
print(f"Restored. Built: {mg_branch.is_built}")
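The key difference from load() is that restoring mutates the existing object instead of creating a new one. The toy sketch below uses a hypothetical ToyGraph class. It uses pickle only for illustration, and ModularML's file format may differ.

```python
import pickle
import tempfile
from pathlib import Path


class ToyGraph:
    """Hypothetical stand-in for a graph whose state is a dict of weights."""

    def __init__(self):
        self.weights = {"w": 0.0}

    def save_checkpoint(self, filepath, *, overwrite=False, meta=None):
        path = Path(filepath).with_suffix(".pkl")
        if path.exists() and not overwrite:
            raise FileExistsError(path)
        payload = {"state": dict(self.weights), "meta": meta or {}}
        with open(path, "wb") as f:
            pickle.dump(payload, f)
        return path

    def restore_checkpoint(self, filepath):
        with open(filepath, "rb") as f:
            payload = pickle.load(f)
        self.weights = payload["state"]  # overwrite state on THIS instance
        return self


with tempfile.TemporaryDirectory() as tmp:
    graph = ToyGraph()
    graph.weights["w"] = 1.25  # weights at "epoch 5"
    ckpt = graph.save_checkpoint(Path(tmp) / "checkpoint_epoch5", meta={"epoch": 5})
    graph.weights["w"] = 9.0  # training continues and the weights drift
    restored = graph.restore_checkpoint(ckpt)
    with open(ckpt, "rb") as f:
        saved_meta = pickle.load(f)["meta"]
    try:
        graph.save_checkpoint(ckpt.with_suffix(""))  # file exists, overwrite=False
        refused = False
    except FileExistsError:
        refused = True

print(restored is graph, graph.weights, saved_meta, refused)
```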

Summary

Constructor

  • nodes (list[str | GraphNode] | None, required): Nodes comprising the graph.

  • optimizer (Optimizer | None, default None): Shared optimizer for graph-wise training.

  • label (str, default "model-graph"): Human-readable label.

Properties

  • .nodes (dict[str, GraphNode]): All nodes keyed by node_id.

  • .node_labels (set[str]): Unique node labels.

  • .head_nodes (dict[str, GraphNode]): Nodes receiving FeatureSet input.

  • .tail_nodes (dict[str, GraphNode]): Nodes with no downstream consumers.

  • .is_built (bool): Whether build() has been called.

  • .backend (Backend | None): Backend of the global optimizer, or None.

  • .frozen_nodes (dict[str, GraphNode]): Currently frozen trainable nodes.

Methods

  • build(force=False): Build all nodes and the global optimizer.

  • forward(inputs, active_nodes=None): Execute a forward pass through the graph.

  • train_step(ctx, losses, active_nodes=None): Execute a single training step (graph-wise or stage-wise).

  • eval_step(ctx, losses, active_nodes=None): Execute a forward-only evaluation step (no gradients).

  • fit_step(ctx, losses=None, active_nodes=None): Fit batch-fit nodes (e.g., scikit-learn) in topological order.

  • freeze(nodes=None): Freeze nodes to prevent training.

  • unfreeze(nodes=None): Unfreeze nodes to allow training.

  • add_node(node): Add a node to the graph.

  • remove_node(node): Remove a node, reconnecting neighbors.

  • replace_node(old_node, new_node): Replace a node, preserving connections.

  • insert_node_between(new_node, upstream, downstream): Insert between two connected nodes.

  • insert_node_before(new_node, downstream): Insert before an existing node.

  • insert_node_after(new_node, upstream): Insert after an existing node.

  • get_config() / from_config(): Config serialization (structure only).

  • get_state() / set_state(): State serialization (includes weights).

  • save(filepath) / load(filepath): Full serialization to/from disk.

  • save_checkpoint(filepath, meta=None): Save a training checkpoint.

  • restore_checkpoint(filepath): Restore state from a checkpoint.

Training Modes

  • Graph-wise (a global Optimizer is provided): Single forward + backward pass across all nodes, giving end-to-end gradient flow.

  • Stage-wise (no global optimizer, i.e. None): Each node trains independently with its own optimizer.

Next Steps

  • Experiment: Use Experiment to combine a ModelGraph with training phases, loss functions, and evaluation — the primary user-facing entry point.

  • ModelNode: See how individual nodes wrap models and handle forward passes.

  • MergeNode: Learn how to combine parallel branches with ConcatNode.