How to: Create and Use a MergeNode#
A MergeNode is a computational node that combines outputs from multiple upstream nodes into a single output. It is the counterpart to ModelNode in a ModelGraph: while a ModelNode accepts exactly one input, a MergeNode accepts two or more.
Currently, ModularML provides one concrete implementation:
ConcatNode: Concatenates inputs along a specified axis, with optional padding for mismatched dimensions.
ComputeNode (abstract)
├── ModelNode       # Single-input, wraps a model
└── MergeNode       # Multi-input, merges upstream outputs (abstract)
    └── ConcatNode  # Concatenates along an axis
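This notebook covers:
When to use a MergeNode
Creating a ConcatNode
Feature axis behavior
Per-domain axes
Target and tag aggregation strategies
Padding mismatched dimensions
Building a graph with MergeNodes
Forward passes through a MergeNode
Key properties and methods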
import numpy as np
import torch
from modularml import (
ConcatNode,
Experiment,
FeatureSet,
ModelGraph,
ModelNode,
Optimizer,
)
from modularml.core.topology.merge_nodes.merge_strategy import MergeStrategy
from modularml.models.torch import SequentialMLP
# Note that we don't need to explicitly create an Experiment right away
# We do it here so we can disable the warning raised when creating multiple
# nodes with the same name (`registration_policy` is what controls this).
exp = Experiment(label="create_mergenode", registration_policy="overwrite")
We’ll use a simple synthetic dataset: 200 samples of a 10-point feature with a scalar target.
rng = np.random.default_rng(42)
fs = FeatureSet.from_dict(
label="Data A",
data={
"X": list(rng.standard_normal((200, 10))),
"Y": list(rng.standard_normal((200, 1))),
},
feature_keys="X",
target_keys="Y",
)
fs_ref = fs.reference(features="X", targets="Y")
print(fs)
When to Use a MergeNode#
A MergeNode is needed when your model graph has multiple parallel branches that must be combined before continuing to a downstream node. Common patterns include:
Multi-encoder fusion: Several encoders process the same (or different) inputs, and their representations are concatenated before a final regressor.
Feature augmentation: A raw feature path is concatenated with a learned embedding path.
Ensemble merging: Outputs from several models are merged (by concatenation, averaging, etc.) for downstream processing.
FeatureSet ─┬─> EncoderA ──┐
            │              ├─> ConcatNode ──> Regressor
            └─> EncoderB ──┘
Creating a ConcatNode#
ConcatNode concatenates multiple inputs along a specified axis.
ConcatNode(
label: str,
upstream_refs: list[ExperimentNode | ExperimentNodeReference],
concat_axis: int = 0,
*,
concat_axis_targets: int | str | MergeStrategy | ExperimentNodeReference = -1,
concat_axis_tags: int | str | MergeStrategy | ExperimentNodeReference = -1,
pad_inputs: bool = False,
pad_mode: str = "constant",
pad_value: float = 0.0,
)
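| Parameter | Type | Default | Description |
|---|---|---|---|
| label | str | (required) | Unique name for this node. |
| upstream_refs | list of ExperimentNode or ExperimentNodeReference | (required) | List of upstream nodes or references to merge. |
| concat_axis | int | 0 | Axis along which to concatenate features (see Feature Axis Behavior). |
| concat_axis_targets | int, str, MergeStrategy, or ExperimentNodeReference | -1 | Strategy for merging targets (see Per-Domain Axes and Target and Tag Aggregation Strategies). |
| concat_axis_tags | int, str, MergeStrategy, or ExperimentNodeReference | -1 | Strategy for merging tags (same semantics as concat_axis_targets). |
| pad_inputs | bool | False | Whether to pad inputs to align non-concat dimensions. |
| pad_mode | str | "constant" | Padding mode, such as "constant". |
| pad_value | float | 0.0 | Fill value when pad_mode is "constant". |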
The concat_axis parameter controls how features are merged and is the primary axis used for shape inference during ModelGraph.build(). Targets, tags, and sample UUIDs each have their own merge behavior (see Per-Domain Axes and Target and Tag Aggregation Strategies).
This tutorial uses the ModelGraph class to build graphs of connected ModelNodes and ConcatNodes.
Details on the ModelGraph class are provided in How to: Create and Use a ModelGraph.
def create_model_graph(
    output_shape_a: tuple[int, ...],
    output_shape_b: tuple[int, ...],
    concat_axis: int,
):
    """
    Build a two-encoder graph to demonstrate different feature concatenation axes.

    Args:
        output_shape_a (tuple[int, ...]):
            Output shape of encoder A (excluding batch dimension).
        output_shape_b (tuple[int, ...]):
            Output shape of encoder B (excluding batch dimension).
        concat_axis (int):
            The feature concatenation axis.
    """
    enc_a = ModelNode(
        label="EncoderA",
        model=SequentialMLP(output_shape=output_shape_a, n_layers=1, hidden_dim=16),
        upstream_ref=fs_ref,
    )
    enc_b = ModelNode(
        label="EncoderB",
        model=SequentialMLP(output_shape=output_shape_b, n_layers=1, hidden_dim=16),
        upstream_ref=fs_ref,
    )
    merge = ConcatNode(
        label="Merge",
        upstream_refs=[enc_a, enc_b],
        concat_axis=concat_axis,
        pad_inputs=True,
    )
    reg = ModelNode(
        label="Regressor",
        model=SequentialMLP(n_layers=1, hidden_dim=8),
        upstream_ref=merge,
    )
    mg = ModelGraph(
        nodes=[enc_a, enc_b, merge, reg],
        optimizer=Optimizer(opt="adam", backend="torch"),
    )
    mg.build()
    print(merge)
    for k, inp_shape in merge.input_shapes.items():
        print(f" - Data from {k.resolve()}: {inp_shape}")
    print(f" - Merged output shape: {merge.output_shape}")
    return mg
mg = create_model_graph(output_shape_a=(1, 10), output_shape_b=(1, 5), concat_axis=0)
mg.visualize()
Feature Axis Behavior#
The concat_axis parameter controls which dimension the feature inputs are concatenated along.
All axis values are relative to the data shape excluding the batch dimension.
For example, with upstream output shapes of (1, 8) (excluding batch), a training batch of size 32 produces tensors of shape (32, 1, 8). Here, concat_axis=0 refers to the 1 dimension and concat_axis=1 refers to the 8 dimension.
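To make the axis convention concrete, here is a minimal NumPy sketch. It assumes the batch dimension is always axis 0, so a non-negative concat_axis is shifted by one before concatenating. The helper name concat_excluding_batch is hypothetical and is not part of ModularML.

```python
import numpy as np


def concat_excluding_batch(arrays, concat_axis):
    """Concatenate batched arrays along an axis given relative to the non-batch shape."""
    # Non-negative axes skip the leading batch dimension; negative axes
    # already count from the end, so they need no adjustment.
    axis = concat_axis + 1 if concat_axis >= 0 else concat_axis
    return np.concatenate(arrays, axis=axis)


# Two upstream outputs of shape (1, 8), each with a batch of 32 samples.
a = np.zeros((32, 1, 8))
b = np.ones((32, 1, 8))

print(concat_excluding_batch([a, b], concat_axis=0).shape)   # (32, 2, 8)
print(concat_excluding_batch([a, b], concat_axis=1).shape)   # (32, 1, 16)
print(concat_excluding_batch([a, b], concat_axis=-1).shape)  # (32, 1, 16)
```

The batch size of 32 is untouched in every case; only the data dimensions change.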
| concat_axis | Behavior | Example: (1, 8) + (1, 8) |
|---|---|---|
| 0 | Concat along first data dim | (2, 8) |
| 1 | Concat along second data dim | (1, 16) |
| -1 | Concat along last data dim | (1, 16) |
When non-concat dimensions don’t match, the node will raise a ValueError unless pad_inputs=True (see Padding Mismatched Dimensions).
# concat_axis=0: stack along first data dim
# (1, 8) + (1, 8) -> (2, 8)
mg = create_model_graph((1, 8), (1, 8), concat_axis=0)
mg.visualize()
# concat_axis=1: concat along second data dim
# (1, 8) + (1, 8) -> (1, 16)
mg = create_model_graph((1, 8), (1, 8), concat_axis=1)
mg.visualize()
# concat_axis=-1: concat along last dim (useful when ndim may vary)
# (1, 8) + (1, 16) -> (1, 24)
mg = create_model_graph((1, 8), (1, 16), concat_axis=-1)
mg.visualize()
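The shape rules in this section can be summarized in a few lines of plain Python. This is only a sketch of the arithmetic, not ModularML's shape-inference code, and the function name merged_shape is hypothetical. Shapes exclude the batch dimension.

```python
def merged_shape(shapes, concat_axis, pad_inputs=False):
    """Return the merged shape (no batch dim) for inputs joined along concat_axis."""
    ndim = len(shapes[0])
    if any(len(s) != ndim for s in shapes):
        raise ValueError(f"All inputs must have the same number of dims, got {shapes}")
    if not -ndim <= concat_axis < ndim:
        raise ValueError(f"concat_axis={concat_axis} is out of range for {ndim} dims")
    axis = concat_axis % ndim  # normalize negative axes such as -1
    out = []
    for dim in range(ndim):
        sizes = [s[dim] for s in shapes]
        if dim == axis:
            out.append(sum(sizes))   # concat dim: sizes add up
        elif len(set(sizes)) == 1:
            out.append(sizes[0])     # matching dim: unchanged
        elif pad_inputs:
            out.append(max(sizes))   # mismatched dim: pad to the largest
        else:
            raise ValueError(
                f"Dim {dim} differs across inputs {shapes}; set pad_inputs=True to pad."
            )
    return tuple(out)


print(merged_shape([(1, 8), (1, 8)], concat_axis=0))    # (2, 8)
print(merged_shape([(1, 8), (1, 8)], concat_axis=1))    # (1, 16)
print(merged_shape([(1, 8), (1, 16)], concat_axis=-1))  # (1, 24)
```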
Per-Domain Axes#
When a ConcatNode merges data from upstream nodes, it processes each domain of the SampleData independently:
| Domain | Parameter | Default | Description |
|---|---|---|---|
| Features | concat_axis | 0 | Primary axis, also used for shape inference. Always int-based. |
| Targets | concat_axis_targets | -1 | Concatenation axis or aggregation strategy (see Target and Tag Aggregation Strategies). |
| Tags | concat_axis_tags | -1 | Concatenation axis or aggregation strategy. |
| Sample UUIDs | (fixed) | -1 | Always concatenated along the last axis. Not configurable. |
By default, all domains use int-based concatenation. When an int is provided, it specifies the axis along which to concatenate - identical semantics to the feature concat_axis. For 1-D arrays (the most common case for targets, tags, and sample UUIDs), -1 is equivalent to axis=0.
To use a non-concatenation strategy for targets or tags, see Target and Tag Aggregation Strategies.
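The equivalence of -1 and 0 for 1-D per-sample data can be checked with NumPy. The sketch below assumes the batch dimension is axis 0, so each array has shape (n_samples, 1) and data axis 0 corresponds to array axis 1. The values are made up for illustration.

```python
import numpy as np

# Targets for a batch of 3 samples; each sample's target has shape (1,).
t_a = np.array([[0.1], [0.2], [0.3]])
t_b = np.array([[1.1], [1.2], [1.3]])

# Axis values exclude the batch dimension, so data axis 0 is array axis 1.
via_axis_0 = np.concatenate([t_a, t_b], axis=0 + 1)
via_axis_last = np.concatenate([t_a, t_b], axis=-1)

print(via_axis_last.shape)                        # (3, 2)
print(np.array_equal(via_axis_0, via_axis_last))  # True
```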
# Example: concat features along axis 0, targets along last axis (default)
enc_a = ModelNode(
label="EncoderA",
model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),
upstream_ref=fs_ref,
)
enc_b = ModelNode(
label="EncoderB",
model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),
upstream_ref=fs_ref,
)
merge = ConcatNode(
label="Merge",
upstream_refs=[enc_a, enc_b],
concat_axis=0, # features: (1,8) + (1,8) -> (2,8)
concat_axis_targets=-1, # targets: concat along last axis (default)
concat_axis_tags=-1, # tags: concat along last axis (default)
)
print(f"Feature axis: {merge.concat_axis}")
print(f"Target strategy: {merge.target_strategy}")
print(f"Tags strategy: {merge.tags_strategy}")
Target and Tag Aggregation Strategies#
When concatenating features from multiple upstream nodes, the default behavior is to also concatenate the associated targets and tags. This is often undesirable: for example, if both encoders receive the same FeatureSet targets, concatenation duplicates them, turning targets of shape (200, 1) into (200, 2).
The concat_axis_targets and concat_axis_tags parameters accept several types to control how these domains are merged:
| Value | Type | Behavior |
|---|---|---|
| -1 | int | Concatenate along last axis (original behavior). |
| Any int | int | Concatenate along the specified axis. |
| "first" | str or MergeStrategy | Use targets/tags from the first upstream input only. |
| "last" | str or MergeStrategy | Use targets/tags from the last upstream input only. |
| "mean" | str or MergeStrategy | Element-wise mean across all inputs (shapes must match). |
| Node or reference | ExperimentNode or ExperimentNodeReference | Use targets/tags from a specific upstream input. |
When a non-concat strategy is used, any upstream inputs with None data for that domain are silently filtered out.
Strings are automatically converted to MergeStrategy enum values, so "first" and MergeStrategy.FIRST are equivalent.
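The behaviors in the table can be illustrated with a small NumPy sketch. The Strategy enum and merge_targets function below are hypothetical stand-ins written for this example; they are not ModularML's MergeStrategy or its merge logic. For simplicity the int axis is applied to the full array, so the example sticks to -1.

```python
from enum import Enum

import numpy as np


class Strategy(Enum):
    """Hypothetical stand-in for an aggregation-strategy enum."""

    FIRST = "first"
    LAST = "last"
    MEAN = "mean"


def merge_targets(targets, strategy):
    """Merge per-input target arrays; an int strategy means concatenation."""
    if isinstance(strategy, int):
        # Int-based strategy: concatenate every input along the given axis.
        return np.concatenate(targets, axis=strategy)
    strategy = Strategy(strategy)  # accepts "first" as well as Strategy.FIRST
    present = [t for t in targets if t is not None]  # drop inputs without data
    if strategy is Strategy.FIRST:
        return present[0]
    if strategy is Strategy.LAST:
        return present[-1]
    return np.mean(np.stack(present), axis=0)  # element-wise mean


y = np.arange(4.0).reshape(4, 1)  # the same targets reach both inputs
print(merge_targets([y, y], -1).shape)                # (4, 2)
print(merge_targets([y, y], "first").shape)           # (4, 1)
print(merge_targets([y, None], Strategy.MEAN).shape)  # (4, 1)
```

Only the int strategy changes the target shape; the selection and mean strategies keep it at (4, 1).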
# Strategy: "first" - use targets from the first upstream input only (enc_a)
merge_first = ConcatNode(
label="MergeFirst",
upstream_refs=[enc_a, enc_b],
concat_axis=-1,
concat_axis_targets="first",
)
print(f"target_strategy: {merge_first.target_strategy}")
# Strategy: MergeStrategy enum (equivalent to string)
merge_mean = ConcatNode(
label="MergeMean",
upstream_refs=[enc_a, enc_b],
concat_axis=-1,
concat_axis_targets=MergeStrategy.MEAN,
)
print(f"target_strategy: {merge_mean.target_strategy}")
# Strategy: select by reference — use targets from a specific upstream node
merge_ref = ConcatNode(
label="MergeRef",
upstream_refs=[enc_a, enc_b],
concat_axis=-1,
concat_axis_targets=enc_a, # use EncoderA's targets
)
print(f"target_strategy: {merge_ref.target_strategy.node_label}")
Comparing Strategies on a Forward Pass#
Let’s run the same data through merge nodes with different target strategies to see how the output targets differ.
from modularml.core.data.sample_data import SampleData
from modularml.utils.data.data_format import DataFormat
# First, build enc_a and enc_b by constructing a graph with one merge node (above)
reg_demo = ModelNode(
label="Reg_demo",
model=SequentialMLP(n_layers=1, hidden_dim=8),
upstream_ref=merge_first,
)
mg = ModelGraph(
nodes=[enc_a, enc_b, merge_first, reg_demo],
optimizer=Optimizer(opt="adam", backend="torch"),
)
mg.build()
# Now build the remaining merge nodes manually (enc_a and enc_b are already built)
input_shapes = {
enc_a.reference(): enc_a.output_shape,
enc_b.reference(): enc_b.output_shape,
}
for m in [merge_mean, merge_ref]:
    m.build(input_shapes=input_shapes, includes_batch_dim=False, backend="torch")
# Also build a default-concat merge for comparison
merge_concat = ConcatNode(
label="MergeConcat",
upstream_refs=[enc_a, enc_b],
concat_axis=-1,
)
merge_concat.build(input_shapes=input_shapes, includes_batch_dim=False, backend="torch")
# Create sample data
fsv = fs_ref.resolve()
sample_data = SampleData(
features=fsv.get_features(fmt=DataFormat.TORCH),
targets=fsv.get_targets(fmt=DataFormat.TORCH),
)
with torch.no_grad():
    out_a = enc_a(sample_data)
    out_b = enc_b(sample_data)
    merge_inputs = {enc_a.reference(): out_a, enc_b.reference(): out_b}
    out_concat = merge_concat.forward(merge_inputs)
    out_first = merge_first.forward(merge_inputs)
    out_mean = merge_mean.forward(merge_inputs)
    out_ref = merge_ref.forward(merge_inputs)
print(f"Input targets shape: {sample_data.targets.shape}")
print(f"concat (default) targets shape: {out_concat.targets.shape}")
print(f"'first' strategy targets shape: {out_first.targets.shape}")
print(f"'mean' strategy targets shape: {out_mean.targets.shape}")
print(f"select-by-ref targets shape: {out_ref.targets.shape}")
print()
print(
f"Targets match (first == ref): {torch.equal(out_first.targets, out_ref.targets)}",
)
print(
f"Targets match (first == mean): {torch.equal(out_first.targets, out_mean.targets)}",
)
In this example both encoders receive the same FeatureSet targets, so:
concat (default): Targets are doubled: (200, 1) + (200, 1) → (200, 2).
"first": Only the first input's targets are kept; shape stays (200, 1).
"mean": Element-wise average of identical targets; shape stays (200, 1), values unchanged.
select-by-ref (enc_a): Identical to "first" here since enc_a is the first input.
The "first"/"last" and select-by-reference strategies are most useful when upstream nodes have different targets, or when you want to pass through a specific node's targets unchanged.
Padding Mismatched Dimensions#
When inputs have different shapes in non-concat dimensions, ConcatNode can automatically pad the shorter tensors to match the longest one.
Consider two encoders with outputs (2, 8) and (3, 6), concatenated along axis 0 (first data dim):
Concat dim: 2 + 3 = 5
Non-concat dim: max(8, 6) = 8 (shorter tensor is padded)
Result: (5, 8)
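The same arithmetic can be reproduced with np.pad and np.concatenate. This sketch omits the batch dimension and assumes padding is appended at the end of each non-concat dimension; the helper name pad_and_concat is hypothetical, and ConcatNode's own implementation may differ.

```python
import numpy as np


def pad_and_concat(arrays, axis, pad_value=0.0):
    """Pad every non-concat dim to the largest size, then concatenate along axis."""
    ndim = arrays[0].ndim
    axis = axis % ndim
    target = [max(a.shape[d] for a in arrays) for d in range(ndim)]
    padded = []
    for a in arrays:
        # Pad only at the end of each non-concat dim; leave the concat dim alone.
        widths = [
            (0, 0) if d == axis else (0, target[d] - a.shape[d])
            for d in range(ndim)
        ]
        padded.append(np.pad(a, widths, mode="constant", constant_values=pad_value))
    return np.concatenate(padded, axis=axis)


wide = np.ones((2, 8))  # stands in for an encoder output of shape (2, 8)
tall = np.ones((3, 6))  # stands in for an encoder output of shape (3, 6)
merged = pad_and_concat([wide, tall], axis=0)

print(merged.shape)                  # (5, 8)
print(np.unique(merged[2:5, 6:8]))   # [0.]
```

Rows 2:5 come from the shorter input, so its last two columns hold the fill value.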
# Two encoders with different output shapes in BOTH dimensions
enc_wide = ModelNode(
label="WideEncoder",
model=SequentialMLP(output_shape=(2, 8), n_layers=1, hidden_dim=16),
upstream_ref=fs_ref,
)
enc_tall = ModelNode(
label="TallEncoder",
model=SequentialMLP(output_shape=(3, 6), n_layers=1, hidden_dim=16),
upstream_ref=fs_ref,
)
# Concat on axis 0 with padding enabled
# dim 0: concatenated (2+3=5), dim 1: padded to max(8,6)=8
merge_padded = ConcatNode(
label="PaddedMerge",
upstream_refs=[enc_wide, enc_tall],
concat_axis=0,
concat_axis_targets="first", # avoid target concatenation doubling
pad_inputs=True,
pad_mode="constant",
pad_value=0.0,
)
reg = ModelNode(
label="Regressor",
model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),
upstream_ref=merge_padded,
)
mg = ModelGraph(
nodes=[enc_wide, enc_tall, merge_padded, reg],
optimizer=Optimizer(opt="adam", backend="torch"),
)
mg.build()
mg.visualize()
print(merge_padded)
for k, inp_shape in merge_padded.input_shapes.items():
    print(f" - Data from {k.resolve()}: {inp_shape}")
print(f" - Merged output shape: {merge_padded.output_shape}")
Without Padding#
If pad_inputs=False (the default) and non-concat dimensions don’t match, a ValueError is raised at build time with a helpful message.
merge_no_pad = ConcatNode(
label="NoPadMerge",
upstream_refs=[enc_wide, enc_tall],
concat_axis=0,
pad_inputs=False,
)
try:
    merge_no_pad.build(
        input_shapes={
            enc_wide.reference(): enc_wide.output_shape,
            enc_tall.reference(): enc_tall.output_shape,
        },
        includes_batch_dim=False,
    )
except ValueError as e:
    print(f"ValueError: {e}")
Building a Graph with MergeNodes#
In practice, you don’t need to build MergeNodes manually. ModelGraph.build() handles
shape inference and build order for all nodes, including merge nodes.
We already saw this in the create_model_graph helper above. Here’s the full pattern with a non-default target strategy:
enc_a = ModelNode(
label="EncoderA",
model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),
upstream_ref=fs_ref,
)
enc_b = ModelNode(
label="EncoderB",
model=SequentialMLP(output_shape=(1, 4), n_layers=1, hidden_dim=16),
upstream_ref=fs_ref,
)
merge = ConcatNode(
label="Merge",
upstream_refs=[enc_a, enc_b],
concat_axis=-1, # concat features along last axis
concat_axis_targets=enc_a, # use only enc_a's targets
)
regressor = ModelNode(
label="Regressor",
model=SequentialMLP(n_layers=1, hidden_dim=8),
upstream_ref=merge,
)
mg = ModelGraph(
nodes=[enc_a, enc_b, merge, regressor],
optimizer=Optimizer(opt="adam", backend="torch"),
)
mg.build()
mg.visualize()
print("Graph built successfully!")
for node in mg.nodes.values():
    in_shapes = None
    out_shape = None
    # ModelNodes expose `input_shape`; MergeNodes expose `input_shapes`
    if hasattr(node, "input_shape"):
        in_shapes = node.input_shape
    elif hasattr(node, "input_shapes"):
        in_shapes = node.input_shapes
    if hasattr(node, "output_shape"):
        out_shape = node.output_shape
    print(f" {node.label}: {in_shapes} -> {out_shape}")
Here concat_axis_targets=enc_a tells the merge node to use EncoderA’s targets as the output targets instead of concatenating targets from both inputs. This is passed as an ExperimentNode instance (which is automatically converted to an ExperimentNodeReference).
The graph correctly infers:
EncoderA: input (1, 10) → output (1, 8)
EncoderB: input (1, 10) → output (1, 4)
Merge: (1, 8) + (1, 4) along last axis → (1, 12) features, targets selected from EncoderA
Regressor: input (1, 12) → output (1, 1)
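As a rough illustration of why build order matters, the sketch below orders the four nodes topologically with the standard-library graphlib module and propagates shapes in that order. It is not how ModelGraph.build() is implemented; the dictionaries and the assumption that each model node's output shape is already known are simplifications for this example.

```python
from graphlib import TopologicalSorter

# Each node maps to the nodes it depends on (labels taken from the example).
upstream = {
    "EncoderA": [],
    "EncoderB": [],
    "Merge": ["EncoderA", "EncoderB"],
    "Regressor": ["Merge"],
}
# Output shapes assumed to be known for the model nodes (no batch dim).
declared = {"EncoderA": (1, 8), "EncoderB": (1, 4), "Regressor": (1, 1)}

shapes = {}
for name in TopologicalSorter(upstream).static_order():
    if name == "Merge":
        # Concatenate along the last axis: last dims add, leading dims are shared.
        inputs = [shapes[u] for u in upstream[name]]
        shapes[name] = inputs[0][:-1] + (sum(s[-1] for s in inputs),)
    else:
        shapes[name] = declared[name]

print(shapes["Merge"])  # (1, 12)
```

Because both encoders are visited before the merge node, their output shapes are available by the time the merged shape is computed.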
Forward Pass#
Forward passes through a MergeNode work the same as through a ModelNode. The merge
accepts SampleData, RoleData, or Batch and returns the same type.
When running through a ModelGraph, this is all handled automatically. Below we trace
a manual forward pass to show how data flows through each node, using the concat_axis_targets=enc_a merge node from Building a Graph with MergeNodes.
# Create SampleData from the FeatureSet reference (already imported above)
fsv = fs_ref.resolve()
sample_data = SampleData(
features=fsv.get_features(fmt=DataFormat.TORCH),
targets=fsv.get_targets(fmt=DataFormat.TORCH),
)
print(f"Input features shape: {sample_data.features.shape}")
print(f"Input targets shape: {sample_data.targets.shape}")
# Trace through each node manually
with torch.no_grad():
    out_a = enc_a(sample_data)
    out_b = enc_b(sample_data)
    print(f"EncoderA features: {out_a.features.shape}")
    print(f"EncoderA targets: {out_a.targets.shape}")
    print(f"EncoderB features: {out_b.features.shape}")
    print(f"EncoderB targets: {out_b.targets.shape}")
    # Merge expects a dict of {reference: data}
    merge_inputs = {
        enc_a.reference(): out_a,
        enc_b.reference(): out_b,
    }
    out_merge = merge.forward(merge_inputs)
    print(f"\nMerge features: {out_merge.features.shape}")
    print(f"Merge targets: {out_merge.targets.shape} (selected from EncoderA)")
    out_final = regressor(out_merge)
    print(f"Regressor output: {out_final.features.shape}")
Notice that features are concatenated along the last axis (concat_axis=-1): (1,8) + (1,4) -> (1,12). Because concat_axis_targets=enc_a, the merged targets have the same shape as the original FeatureSet targets (200, 1) — they are not concatenated.
Compare this with the default behavior (shown in Target and Tag Aggregation Strategies), where targets would be (200, 2) due to concatenation.
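The data flow of a dict-based merge can be mimicked in NumPy without ModularML. In this sketch the Sample dataclass and merge_forward function are hypothetical; plain string keys stand in for node references, and zero-filled arrays stand in for encoder outputs.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class Sample:
    """Hypothetical container holding one node's output features and targets."""

    features: np.ndarray
    targets: np.ndarray


def merge_forward(inputs, targets_from):
    """Concatenate features along the last axis; take targets from one named input."""
    features = np.concatenate([s.features for s in inputs.values()], axis=-1)
    return Sample(features=features, targets=inputs[targets_from].targets)


y = np.zeros((200, 1))
merge_inputs = {
    "EncoderA": Sample(features=np.zeros((200, 1, 8)), targets=y),
    "EncoderB": Sample(features=np.zeros((200, 1, 4)), targets=y),
}
out = merge_forward(merge_inputs, targets_from="EncoderA")

print(out.features.shape)  # (200, 1, 12)
print(out.targets.shape)   # (200, 1)
```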
Verifying Padded Forward Pass#
Let’s verify that the padded merge node (from Padding Mismatched Dimensions) produces the expected shapes and that padded regions are filled with zeros.
print(f"PaddedMerge output_shape: {merge_padded.output_shape}")
# Forward pass
with torch.no_grad():
    out_wide = enc_wide(sample_data)
    out_tall = enc_tall(sample_data)
    print(f"WideEncoder output: {out_wide.features.shape}")
    print(f"TallEncoder output: {out_tall.features.shape}")
    padded_inputs = {
        enc_wide.reference(): out_wide,
        enc_tall.reference(): out_tall,
    }
    out_padded = merge_padded.forward(padded_inputs)
    print(f"Padded merge output: {out_padded.features.shape}")
# Verify padding: TallEncoder (3, 6) is padded to (3, 8)
# After concat on axis 0: rows 0:2 from WideEncoder, rows 2:5 from TallEncoder
# Columns 6:8 of TallEncoder's contribution should be zero
padded_region = out_padded.features[:, 2:5, 6:8].numpy()
print(f"Padded region values (should be all zeros): {np.unique(padded_region)}")
Key Properties and Methods#
MergeNode (base class)#
| Property / Method | Description |
|---|---|
| … | Whether shape inference has been completed. |
| output_shape | Output shape (no batch dim) after merging. |
| input_shapes | Dict mapping each upstream reference to its input shape. |
| … | Backend enum, or … |
| … | Forward pass on a list of … |
| … | Forward pass from a dict of … |
| apply_merge() | Abstract method that subclasses implement. Receives a … |
ConcatNode#
| Property / Method | Description |
|---|---|
| concat_axis | The axis along which features are concatenated. |
| target_strategy | Strategy for merging targets. |
| tags_strategy | Strategy for merging tags (same types as target_strategy). |
| … | Convenience property: returns the int axis when … |
| … | Convenience property: returns the int axis when … |
| pad_inputs | Whether padding is enabled. |
| pad_mode | Padding mode. |
| pad_value | Fill value for constant padding. |
MergeStrategy Enum#
| Value | Description |
|---|---|
| … | Concatenate along an axis (requires an int axis). |
| FIRST | Use data from the first upstream input. |
| LAST | Use data from the last upstream input. |
| MEAN | Element-wise mean across inputs (shapes must match). |
Next Steps#
ModelGraph: See how ModelNodes and MergeNodes are composed into a full computational graph with automatic shape inference.
Experiment: Use Experiment to combine a ModelGraph with training phases, loss functions, and evaluation.
Custom MergeNode: Subclass MergeNode and implement apply_merge() for custom merging strategies (e.g., averaging, attention-based fusion).
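The subclassing pattern mentioned for custom merge nodes might look roughly like the sketch below. It deliberately avoids the real MergeNode base class, whose apply_merge() signature is not shown in this guide: MergeBase and MeanMerge are hypothetical classes that only illustrate the split between shared forward plumbing and a subclass-specific merge step.

```python
from abc import ABC, abstractmethod

import numpy as np


class MergeBase(ABC):
    """Hypothetical base class; the real MergeNode interface may differ."""

    def forward(self, inputs):
        # Shared plumbing: collect upstream arrays, then delegate to the subclass.
        return self.apply_merge(list(inputs.values()))

    @abstractmethod
    def apply_merge(self, arrays):
        """Combine a list of arrays into one."""


class MeanMerge(MergeBase):
    """Averages same-shaped inputs element-wise instead of concatenating them."""

    def apply_merge(self, arrays):
        if len({a.shape for a in arrays}) != 1:
            raise ValueError("MeanMerge requires inputs with identical shapes")
        return np.mean(np.stack(arrays), axis=0)


out = MeanMerge().forward({"a": np.full((4, 8), 1.0), "b": np.full((4, 8), 3.0)})
print(out.shape)  # (4, 8)
print(out[0, 0])  # 2.0
```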