How to: Create and Use a MergeNode#

A MergeNode is a computational node that combines outputs from multiple upstream nodes into a single output. It is the counterpart to ModelNode in a ModelGraph: while a ModelNode accepts exactly one input, a MergeNode accepts two or more.

Currently, ModularML provides one concrete implementation:

  • ConcatNode — Concatenates inputs along a specified axis, with optional padding for mismatched dimensions.

ComputeNode (abstract)
├── ModelNode       # Single-input, wraps a model
└── MergeNode       # Multi-input, merges upstream outputs (abstract)
    └── ConcatNode  # Concatenates along an axis

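This notebook covers:

  • When to use a MergeNode

  • Creating a ConcatNode and choosing the feature axis

  • Per-domain axes and target/tag aggregation strategies

  • Padding mismatched dimensions

  • Building a graph with MergeNodes and running a forward pass

  • Key properties and methods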

import numpy as np
import torch

from modularml import (
    ConcatNode,
    Experiment,
    FeatureSet,
    ModelGraph,
    ModelNode,
    Optimizer,
)
from modularml.core.topology.merge_nodes.merge_strategy import MergeStrategy
from modularml.models.torch import SequentialMLP

# Note: an Experiment does not have to be created explicitly this early.
# We create one here so we can disable the warning raised when creating multiple
# nodes with the same name (`registration_policy` is what controls this).
exp = Experiment(label="create_mergenode", registration_policy="overwrite")

We’ll use a simple synthetic dataset: 200 samples of a 10-point feature with a scalar target.

rng = np.random.default_rng(42)

fs = FeatureSet.from_dict(
    label="Data A",
    data={
        "X": list(rng.standard_normal((200, 10))),
        "Y": list(rng.standard_normal((200, 1))),
    },
    feature_keys="X",
    target_keys="Y",
)

fs_ref = fs.reference(features="X", targets="Y")
print(fs)

When to Use a MergeNode#

A MergeNode is needed when your model graph has multiple parallel branches that must be combined before continuing to a downstream node. Common patterns include:

  • Multi-encoder fusion: Several encoders process the same (or different) inputs, and their representations are concatenated before a final regressor.

  • Feature augmentation: A raw feature path is concatenated with a learned embedding path.

  • Ensemble merging: Outputs from several models are merged (by concatenation, averaging, etc.) for downstream processing.

FeatureSet ─┬─> EncoderA ──┐
            │              ├─> ConcatNode ──> Regressor
            └─> EncoderB ──┘

Creating a ConcatNode#

ConcatNode concatenates multiple inputs along a specified axis.

    ConcatNode(
        label: str,
        upstream_refs: list[ExperimentNode | ExperimentNodeReference],
        concat_axis: int = 0,
        *,
        concat_axis_targets: int | str | MergeStrategy | ExperimentNodeReference = -1,
        concat_axis_tags: int | str | MergeStrategy | ExperimentNodeReference = -1,
        pad_inputs: bool = False,
        pad_mode: str = "constant",
        pad_value: float = 0.0,
    )

Parameter            Type                                                 Default     Description
-------------------  ---------------------------------------------------  ----------  -----------
label                str                                                  (required)  Unique name for this node.
upstream_refs        list                                                 (required)  List of upstream nodes or references to merge.
concat_axis          int                                                  0           Axis along which to concatenate features (see Feature Axis Behavior).
concat_axis_targets  int | str | MergeStrategy | ExperimentNodeReference  -1          Strategy for merging targets (see Per-Domain Axes and Target and Tag Aggregation Strategies).
concat_axis_tags     int | str | MergeStrategy | ExperimentNodeReference  -1          Strategy for merging tags (same semantics as concat_axis_targets).
pad_inputs           bool                                                 False       Whether to pad inputs to align non-concat dimensions.
pad_mode             str                                                  "constant"  Padding mode: "constant", "reflect", "replicate", or "circular".
pad_value            float                                                0.0         Fill value when pad_mode="constant".

The concat_axis parameter controls how features are merged and is the primary axis used for shape inference during ModelGraph.build(). Targets, tags, and sample UUIDs each have their own merge behavior (see Per-Domain Axes and Target and Tag Aggregation Strategies).

This tutorial uses the ModelGraph class to build graphs of connected ModelNodes and ConcatNodes.

Details on the ModelGraph class are provided in How to: Create and Use a ModelGraph.

def create_model_graph(
    output_shape_a: tuple[int, ...],
    output_shape_b: tuple[int, ...],
    concat_axis: int,
):
    """
    Build a two-encoder graph to demonstrate different feature concatenation axes.

    Args:
        output_shape_a (tuple[int, ...]):
            Output shape of encoder A (excluding batch dimension).
        output_shape_b (tuple[int, ...]):
            Output shape of encoder B (excluding batch dimension).
        concat_axis (int):
            The feature concatenation axis.

    """
    enc_a = ModelNode(
        label="EncoderA",
        model=SequentialMLP(output_shape=output_shape_a, n_layers=1, hidden_dim=16),
        upstream_ref=fs_ref,
    )
    enc_b = ModelNode(
        label="EncoderB",
        model=SequentialMLP(output_shape=output_shape_b, n_layers=1, hidden_dim=16),
        upstream_ref=fs_ref,
    )
    merge = ConcatNode(
        label="Merge",
        upstream_refs=[enc_a, enc_b],
        concat_axis=concat_axis,
        pad_inputs=True,
    )

    reg = ModelNode(
        label="Regressor",
        model=SequentialMLP(n_layers=1, hidden_dim=8),
        upstream_ref=merge,
    )

    mg = ModelGraph(
        nodes=[enc_a, enc_b, merge, reg],
        optimizer=Optimizer(opt="adam", backend="torch"),
    )
    mg.build()

    print(merge)
    for k, inp_shape in merge.input_shapes.items():
        print(f" - Data from {k.resolve()}: {inp_shape}")
    print(f" - Merged output shape: {merge.output_shape}")

    return mg


mg = create_model_graph(output_shape_a=(1, 10), output_shape_b=(1, 5), concat_axis=0)
mg.visualize()

Feature Axis Behavior#

The concat_axis parameter controls which dimension the feature inputs are concatenated along. All axis values are relative to the data shape excluding the batch dimension.

For example, with upstream output shapes of (1, 8) (excluding batch), a training batch of size 32 produces tensors of shape (32, 1, 8). Here, concat_axis=0 refers to the 1 dimension and concat_axis=1 refers to the 8 dimension.
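
The batch offset described above can be sketched in a few lines. `to_tensor_axis` is a hypothetical helper written for this illustration, not a ModularML function, and it assumes the batch dimension is always axis 0 of the runtime tensor.

```python
def to_tensor_axis(concat_axis: int, data_ndim: int) -> int:
    """Map an axis given relative to the data shape (no batch dim)
    to the matching axis of a batched tensor (batch dim at position 0)."""
    if not -data_ndim <= concat_axis < data_ndim:
        raise ValueError(
            f"axis {concat_axis} is out of range for {data_ndim} data dims"
        )
    # Normalize negative axes first, then shift past the batch dimension.
    return concat_axis % data_ndim + 1


data_shape = (1, 8)                 # upstream output shape, excluding batch
batched_shape = (32, *data_shape)   # shape of a batch of 32 samples

for axis in (0, 1, -1):
    tensor_axis = to_tensor_axis(axis, len(data_shape))
    size = batched_shape[tensor_axis]
    print(f"concat_axis={axis:>2} -> tensor axis {tensor_axis} (size {size})")
```

With these inputs, concat_axis=0 maps to tensor axis 1 (the size-1 dimension), while concat_axis=1 and concat_axis=-1 both map to tensor axis 2 (the size-8 dimension).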

concat_axis  Behavior                      Example: (1, 8) + (1, 8)
-----------  ----------------------------  ----------------------------
0            Concat along first data dim   (2, 8)
1            Concat along second data dim  (1, 16)
-1           Concat along last data dim    (1, 16), same as axis=1 here

When non-concat dimensions don’t match, the node will raise a ValueError unless pad_inputs=True (see Padding Mismatched Dimensions).
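
The shape rule in the table, including the error for mismatched non-concat dimensions, can be reproduced with plain tuples. This is a minimal sketch of strict (unpadded) concatenation, not ModularML's implementation; `concat_shape` is a hypothetical name, and the real error message may differ.

```python
def concat_shape(shapes: list[tuple[int, ...]], axis: int) -> tuple[int, ...]:
    """Infer the merged shape (excluding batch) for strict concatenation."""
    ndim = len(shapes[0])
    if any(len(s) != ndim for s in shapes):
        raise ValueError(f"all inputs must have {ndim} dims, got {shapes}")
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of range for {ndim} dims")
    axis %= ndim  # normalize negative axes
    merged = list(shapes[0])
    for shape in shapes[1:]:
        for dim, (a, b) in enumerate(zip(merged, shape)):
            if dim != axis and a != b:
                raise ValueError(
                    f"non-concat dim {dim} differs ({a} vs {b}); "
                    "padding would be needed"
                )
        merged[axis] += shape[axis]
    return tuple(merged)


print(concat_shape([(1, 8), (1, 8)], axis=0))    # (2, 8)
print(concat_shape([(1, 8), (1, 8)], axis=1))    # (1, 16)
print(concat_shape([(1, 8), (1, 16)], axis=-1))  # (1, 24)

try:
    concat_shape([(2, 8), (3, 6)], axis=0)
except ValueError as err:
    print(f"ValueError: {err}")
```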

# concat_axis=0: stack along first data dim
# (1, 8) + (1, 8) -> (2, 8)
mg = create_model_graph((1, 8), (1, 8), concat_axis=0)
mg.visualize()

# concat_axis=1: concat along second data dim
# (1, 8) + (1, 8) -> (1, 16)
mg = create_model_graph((1, 8), (1, 8), concat_axis=1)
mg.visualize()

# concat_axis=-1: concat along last dim (useful when ndim may vary)
# (1, 8) + (1, 16) -> (1, 24)
mg = create_model_graph((1, 8), (1, 16), concat_axis=-1)
mg.visualize()

Per-Domain Axes#

When a ConcatNode merges data from upstream nodes, it processes each domain of the SampleData independently:

Domain        Parameter            Default  Description
------------  -------------------  -------  -----------
Features      concat_axis          0        Primary axis, also used for shape inference. Always int-based.
Targets       concat_axis_targets  -1       Concatenation axis or aggregation strategy (see Target and Tag Aggregation Strategies).
Tags          concat_axis_tags     -1       Concatenation axis or aggregation strategy.
Sample UUIDs  (fixed)              -1       Always concatenated along the last axis. Not configurable.

By default, all domains use int-based concatenation. An int value specifies the axis along which to concatenate, with the same semantics as the feature concat_axis. For 1-D arrays (the most common case for targets, tags, and sample UUIDs), -1 is equivalent to axis=0.

To use a non-concatenation strategy for targets or tags, see Target and Tag Aggregation Strategies.

# Example: concat features along axis 0, targets along last axis (default)
enc_a = ModelNode(
    label="EncoderA",
    model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)
enc_b = ModelNode(
    label="EncoderB",
    model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)
merge = ConcatNode(
    label="Merge",
    upstream_refs=[enc_a, enc_b],
    concat_axis=0,  # features: (1,8) + (1,8) -> (2,8)
    concat_axis_targets=-1,  # targets: concat along last axis (default)
    concat_axis_tags=-1,  # tags: concat along last axis (default)
)
print(f"Feature axis:       {merge.concat_axis}")
print(f"Target strategy:    {merge.target_strategy}")
print(f"Tags strategy:      {merge.tags_strategy}")

Target and Tag Aggregation Strategies#

When concatenating features from multiple upstream nodes, the default behavior is to also concatenate the associated targets and tags. This is often undesirable. For example, if both encoders receive the same FeatureSet targets, concatenation duplicates them: two (200, 1) target arrays become one (200, 2) array.

The concat_axis_targets and concat_axis_tags parameters accept several types to control how these domains are merged:

Value         Type                                       Behavior
------------  -----------------------------------------  --------
-1 (default)  int                                        Concatenate along last axis (original behavior).
Any int       int                                        Concatenate along the specified axis.
"first"       str or MergeStrategy.FIRST                 Use targets/tags from the first upstream input only.
"last"        str or MergeStrategy.LAST                  Use targets/tags from the last upstream input only.
"mean"        str or MergeStrategy.MEAN                  Element-wise mean across all inputs (shapes must match).
enc_a         ExperimentNode or ExperimentNodeReference  Use targets/tags from a specific upstream input.

When a non-concat strategy is used, any upstream inputs with None data for that domain are silently filtered out.

Strings are automatically converted to MergeStrategy enum values, so "first" and MergeStrategy.FIRST are equivalent.
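
To make the strategies concrete, here is a minimal pure-Python sketch that mimics them on nested lists. `merge_targets` is a hypothetical stand-in, not ModularML code: it represents each upstream's targets as one row per sample, supports only last-axis concatenation for the int case, and leaves out select-by-reference.

```python
from __future__ import annotations

from statistics import fmean

Rows = list[list[float]]  # one row of target values per sample


def merge_targets(inputs: list[Rows | None], strategy: int | str) -> Rows | None:
    """Combine per-upstream targets with a concat axis or a named strategy."""
    if isinstance(strategy, int):
        if strategy not in (-1, 1):
            raise NotImplementedError("this sketch only concatenates the last axis")
        # zip(*inputs) groups the rows that belong to the same sample.
        return [[v for row in rows for v in row] for rows in zip(*inputs)]

    # Non-concat strategies skip upstream inputs that carry no data.
    present = [rows for rows in inputs if rows is not None]
    if not present:
        return None
    if strategy == "first":
        return present[0]
    if strategy == "last":
        return present[-1]
    if strategy == "mean":
        # Average matching elements across upstream inputs, sample by sample.
        return [[fmean(vals) for vals in zip(*rows)] for rows in zip(*present)]
    raise ValueError(f"unknown strategy: {strategy!r}")


targets = [[0.5], [1.5], [-2.0]]  # 3 samples, 1 target each
same = [targets, targets]         # both encoders forward the same targets

print(merge_targets(same, -1))       # [[0.5, 0.5], [1.5, 1.5], [-2.0, -2.0]]
print(merge_targets(same, "first"))  # [[0.5], [1.5], [-2.0]]
print(merge_targets(same, "mean"))   # [[0.5], [1.5], [-2.0]]
print(merge_targets([None, targets], "first"))  # [[0.5], [1.5], [-2.0]]
```

Only the concatenating call widens each row from one value to two; the named strategies keep the per-sample target width unchanged.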

# Strategy: "first" - use targets from the first upstream input only (enc_a)
merge_first = ConcatNode(
    label="MergeFirst",
    upstream_refs=[enc_a, enc_b],
    concat_axis=-1,
    concat_axis_targets="first",
)
print(f"target_strategy: {merge_first.target_strategy}")

# Strategy: MergeStrategy enum (equivalent to string)
merge_mean = ConcatNode(
    label="MergeMean",
    upstream_refs=[enc_a, enc_b],
    concat_axis=-1,
    concat_axis_targets=MergeStrategy.MEAN,
)
print(f"target_strategy: {merge_mean.target_strategy}")

# Strategy: select by reference — use targets from a specific upstream node
merge_ref = ConcatNode(
    label="MergeRef",
    upstream_refs=[enc_a, enc_b],
    concat_axis=-1,
    concat_axis_targets=enc_a,  # use EncoderA's targets
)
print(f"target_strategy: {merge_ref.target_strategy.node_label}")

Comparing Strategies on a Forward Pass#

Let’s run the same data through merge nodes with different target strategies to see how the output targets differ.

from modularml.core.data.sample_data import SampleData
from modularml.utils.data.data_format import DataFormat

# First, build enc_a and enc_b by constructing a graph around merge_first (defined above)
reg_demo = ModelNode(
    label="Reg_demo",
    model=SequentialMLP(n_layers=1, hidden_dim=8),
    upstream_ref=merge_first,
)
mg = ModelGraph(
    nodes=[enc_a, enc_b, merge_first, reg_demo],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
mg.build()


# Now build the remaining merge nodes manually (enc_a and enc_b are already built)
input_shapes = {
    enc_a.reference(): enc_a.output_shape,
    enc_b.reference(): enc_b.output_shape,
}
for m in [merge_mean, merge_ref]:
    m.build(input_shapes=input_shapes, includes_batch_dim=False, backend="torch")

# Also build a default-concat merge for comparison
merge_concat = ConcatNode(
    label="MergeConcat",
    upstream_refs=[enc_a, enc_b],
    concat_axis=-1,
)
merge_concat.build(input_shapes=input_shapes, includes_batch_dim=False, backend="torch")

# Create sample data
fsv = fs_ref.resolve()
sample_data = SampleData(
    features=fsv.get_features(fmt=DataFormat.TORCH),
    targets=fsv.get_targets(fmt=DataFormat.TORCH),
)

with torch.no_grad():
    out_a = enc_a(sample_data)
    out_b = enc_b(sample_data)
    merge_inputs = {enc_a.reference(): out_a, enc_b.reference(): out_b}

    out_concat = merge_concat.forward(merge_inputs)
    out_first = merge_first.forward(merge_inputs)
    out_mean = merge_mean.forward(merge_inputs)
    out_ref = merge_ref.forward(merge_inputs)

print(f"Input targets shape:              {sample_data.targets.shape}")
print(f"concat (default) targets shape:   {out_concat.targets.shape}")
print(f"'first' strategy targets shape:   {out_first.targets.shape}")
print(f"'mean' strategy targets shape:    {out_mean.targets.shape}")
print(f"select-by-ref targets shape:      {out_ref.targets.shape}")
print()
print(
    f"Targets match (first == ref):     {torch.equal(out_first.targets, out_ref.targets)}",
)
print(
    f"Targets match (first == mean):    {torch.equal(out_first.targets, out_mean.targets)}",
)

In this example both encoders receive the same FeatureSet targets, so:

  • concat (default): Targets are duplicated, (200, 1) + (200, 1) → (200, 2).

  • “first”: Only the first input’s targets are kept, so the shape stays (200, 1).

  • “mean”: Element-wise average of identical targets, so the shape stays (200, 1) and the values are unchanged.

  • select-by-ref (enc_a): Identical to “first” here since enc_a is the first input.

The “first”/“last” and select-by-reference strategies are most useful when upstream nodes have different targets, or when you want to pass through a specific node’s targets unchanged.


Padding Mismatched Dimensions#

When inputs have different shapes in non-concat dimensions, ConcatNode can automatically pad the shorter tensors to match the longest one.

Consider two encoders with outputs (2, 8) and (3, 6), concatenated along axis 0 (first data dim):

  • Concat dim: 2 + 3 = 5

  • Non-concat dim: max(8, 6) = 8 (shorter tensor is padded)

  • Result: (5, 8)
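
The arithmetic above generalizes to any number of inputs: sizes add up along the concat dimension, and every other dimension takes the maximum. A minimal sketch with a hypothetical helper (`padded_concat_shape` is not a ModularML function):

```python
def padded_concat_shape(
    shapes: list[tuple[int, ...]], axis: int
) -> tuple[int, ...]:
    """Merged shape (excluding batch) when non-concat dims are padded to the max."""
    ndim = len(shapes[0])
    if any(len(s) != ndim for s in shapes):
        raise ValueError(f"all inputs must have {ndim} dims, got {shapes}")
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of range for {ndim} dims")
    axis %= ndim  # normalize negative axes
    return tuple(
        sum(s[d] for s in shapes) if d == axis else max(s[d] for s in shapes)
        for d in range(ndim)
    )


print(padded_concat_shape([(2, 8), (3, 6)], axis=0))   # (5, 8)
print(padded_concat_shape([(2, 8), (3, 6)], axis=1))   # (3, 14)
print(padded_concat_shape([(1, 10), (1, 5)], axis=0))  # (2, 10)
```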

# Two encoders with different output shapes in BOTH dimensions
enc_wide = ModelNode(
    label="WideEncoder",
    model=SequentialMLP(output_shape=(2, 8), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)

enc_tall = ModelNode(
    label="TallEncoder",
    model=SequentialMLP(output_shape=(3, 6), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)

# Concat on axis 0 with padding enabled
# dim 0: concatenated (2+3=5), dim 1: padded to max(8,6)=8
merge_padded = ConcatNode(
    label="PaddedMerge",
    upstream_refs=[enc_wide, enc_tall],
    concat_axis=0,
    concat_axis_targets="first",  # avoid target concatenation doubling
    pad_inputs=True,
    pad_mode="constant",
    pad_value=0.0,
)

reg = ModelNode(
    label="Regressor",
    model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),
    upstream_ref=merge_padded,
)

mg = ModelGraph(
    nodes=[enc_wide, enc_tall, merge_padded, reg],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
mg.build()
mg.visualize()


print(merge_padded)
for k, inp_shape in merge_padded.input_shapes.items():
    print(f" - Data from {k.resolve()}: {inp_shape}")
print(f" - Merged output shape: {merge_padded.output_shape}")
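
The four pad_mode names match the padding modes of torch.nn.functional.pad. Assuming ConcatNode follows those conventions and appends padding at the end of a dimension, the sketch below shows what each mode would append to a single row of length 6 padded to length 8. `pad_end` is a hypothetical pure-Python helper, not part of ModularML.

```python
def pad_end(
    values: list[float],
    target_len: int,
    mode: str = "constant",
    value: float = 0.0,
) -> list[float]:
    """Pad a 1-D list at its end up to target_len using the given mode."""
    n = len(values)
    extra = target_len - n
    if extra < 0:
        raise ValueError("target_len is shorter than the input")
    if extra == 0:
        return list(values)
    if n == 0 and mode != "constant":
        raise ValueError(f"mode {mode!r} needs a non-empty input")
    if mode == "constant":
        tail = [value] * extra                  # fill with a fixed value
    elif mode == "replicate":
        tail = [values[-1]] * extra             # repeat the edge value
    elif mode == "reflect":
        if extra > n - 1:
            raise ValueError("reflect padding must be smaller than the input")
        tail = [values[n - 2 - i] for i in range(extra)]  # mirror, edge excluded
    elif mode == "circular":
        if extra > n:
            raise ValueError("circular padding cannot exceed the input length")
        tail = list(values[:extra])             # wrap around to the start
    else:
        raise ValueError(f"unknown pad mode: {mode!r}")
    return list(values) + tail


row = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
for mode in ("constant", "replicate", "reflect", "circular"):
    print(f"{mode:>9}: {pad_end(row, 8, mode=mode)[6:]}")
# constant:  [0.0, 0.0]
# replicate: [6.0, 6.0]
# reflect:   [5.0, 4.0]
# circular:  [1.0, 2.0]
```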

Without Padding#

If pad_inputs=False (the default) and non-concat dimensions don’t match, a ValueError is raised at build time with a helpful message.

merge_no_pad = ConcatNode(
    label="NoPadMerge",
    upstream_refs=[enc_wide, enc_tall],
    concat_axis=0,
    pad_inputs=False,
)

try:
    merge_no_pad.build(
        input_shapes={
            enc_wide.reference(): enc_wide.output_shape,
            enc_tall.reference(): enc_tall.output_shape,
        },
        includes_batch_dim=False,
    )
except ValueError as e:
    print(f"ValueError: {e}")

Building a Graph with MergeNodes#

In practice, you don’t need to build MergeNodes manually. ModelGraph.build() handles shape inference and build order for all nodes, including merge nodes.
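
One way to picture "build order" is as a topological sort of the graph: a node can only infer its shape once all of its upstream nodes are known. The sketch below uses Python's standard graphlib on the two-encoder layout from this notebook. It illustrates the idea only; ModularML's actual build algorithm is not shown in this notebook, and the shape values are taken from the example that follows.

```python
from graphlib import TopologicalSorter

# Map each node to the set of nodes it depends on (its upstream nodes).
upstream = {
    "EncoderA": {"FeatureSet"},
    "EncoderB": {"FeatureSet"},
    "Merge": {"EncoderA", "EncoderB"},
    "Regressor": {"Merge"},
}
build_order = list(TopologicalSorter(upstream).static_order())
print(build_order)  # FeatureSet first, Regressor last; encoder order may vary

# Shapes (excluding batch) that are declared up front rather than inferred.
declared = {
    "FeatureSet": (1, 10),
    "EncoderA": (1, 8),
    "EncoderB": (1, 4),
    "Regressor": (1, 1),
}

shapes: dict[str, tuple[int, ...]] = {}
for node in build_order:
    if node == "Merge":
        # Concatenate along the last axis: only that dimension grows.
        a, b = shapes["EncoderA"], shapes["EncoderB"]
        shapes[node] = (*a[:-1], a[-1] + b[-1])
    else:
        shapes[node] = declared[node]

print(shapes["Merge"])  # (1, 12)
```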

We already saw this in the create_model_graph helper above. Here’s the full pattern with a non-default target strategy:

enc_a = ModelNode(
    label="EncoderA",
    model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)
enc_b = ModelNode(
    label="EncoderB",
    model=SequentialMLP(output_shape=(1, 4), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)

merge = ConcatNode(
    label="Merge",
    upstream_refs=[enc_a, enc_b],
    concat_axis=-1,  # concat features along last axis
    concat_axis_targets=enc_a,  # use only enc_a's targets
)

regressor = ModelNode(
    label="Regressor",
    model=SequentialMLP(n_layers=1, hidden_dim=8),
    upstream_ref=merge,
)

mg = ModelGraph(
    nodes=[enc_a, enc_b, merge, regressor],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
mg.build()
mg.visualize()

print("Graph built successfully!")
for node in mg.nodes.values():
    in_shapes = None
    out_shape = None
    # Single-input nodes expose `input_shape`; merge nodes expose `input_shapes`.
    if hasattr(node, "input_shape"):
        in_shapes = node.input_shape
    elif hasattr(node, "input_shapes"):
        in_shapes = node.input_shapes
    if hasattr(node, "output_shape"):
        out_shape = node.output_shape

    print(f"  {node.label}: {in_shapes} -> {out_shape}")

Here concat_axis_targets=enc_a tells the merge node to use EncoderA’s targets as the output targets instead of concatenating targets from both inputs. This is passed as an ExperimentNode instance (which is automatically converted to an ExperimentNodeReference).

The graph correctly infers:

  • EncoderA: input (1, 10) → output (1, 8)

  • EncoderB: input (1, 10) → output (1, 4)

  • Merge: (1, 8) + (1, 4) along last axis → (1, 12) features, targets selected from EncoderA

  • Regressor: input (1, 12) → output (1, 1)


Forward Pass#

Forward passes through a MergeNode work the same as through a ModelNode. The merge accepts SampleData, RoleData, or Batch and returns the same type.

When running through a ModelGraph, this is all handled automatically. Below we trace a manual forward pass to show how data flows through each node, using the concat_axis_targets=enc_a merge node from Building a Graph with MergeNodes.

# Create SampleData from the FeatureSet reference (already imported above)
fsv = fs_ref.resolve()
sample_data = SampleData(
    features=fsv.get_features(fmt=DataFormat.TORCH),
    targets=fsv.get_targets(fmt=DataFormat.TORCH),
)
print(f"Input features shape: {sample_data.features.shape}")
print(f"Input targets shape:  {sample_data.targets.shape}")

# Trace through each node manually
with torch.no_grad():
    out_a = enc_a(sample_data)
    out_b = enc_b(sample_data)
    print(f"EncoderA features: {out_a.features.shape}")
    print(f"EncoderA targets:  {out_a.targets.shape}")
    print(f"EncoderB features: {out_b.features.shape}")
    print(f"EncoderB targets:  {out_b.targets.shape}")

    # Merge expects a dict of {reference: data}
    merge_inputs = {
        enc_a.reference(): out_a,
        enc_b.reference(): out_b,
    }
    out_merge = merge.forward(merge_inputs)
    print(f"\nMerge features:    {out_merge.features.shape}")
    print(f"Merge targets:     {out_merge.targets.shape}  (selected from EncoderA)")

    out_final = regressor(out_merge)
    print(f"Regressor output:  {out_final.features.shape}")

Notice that features are concatenated along the last axis (concat_axis=-1): (1,8) + (1,4) -> (1,12). Because concat_axis_targets=enc_a, the merged targets have the same shape as the original FeatureSet targets (200, 1) — they are not concatenated.

Compare this with the default behavior (shown in Target and Tag Aggregation Strategies), where targets would be (200, 2) due to concatenation.

Verifying Padded Forward Pass#

Let’s verify that the padded merge node (from Padding Mismatched Dimensions) produces the expected shapes and that padded regions are filled with zeros.

print(f"PaddedMerge output_shape: {merge_padded.output_shape}")

# Forward pass
with torch.no_grad():
    out_wide = enc_wide(sample_data)
    out_tall = enc_tall(sample_data)
    print(f"WideEncoder output: {out_wide.features.shape}")
    print(f"TallEncoder output: {out_tall.features.shape}")

    padded_inputs = {
        enc_wide.reference(): out_wide,
        enc_tall.reference(): out_tall,
    }
    out_padded = merge_padded.forward(padded_inputs)
    print(f"Padded merge output: {out_padded.features.shape}")

    # Verify padding: TallEncoder (3, 6) is padded to (3, 8)
    # After concat on axis 0: rows 0:2 from WideEncoder, rows 2:5 from TallEncoder
    # Columns 6:8 of TallEncoder's contribution should be zero
    padded_region = out_padded.features[:, 2:5, 6:8].numpy()
    print(f"Padded region values (should be all zeros): {np.unique(padded_region)}")
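
The same check can be reproduced without torch for a single sample, which makes the expected layout easy to inspect. This is a minimal sketch with nested lists standing in for one sample's (2, 8) and (3, 6) encoder outputs; `pad_columns` is a hypothetical helper, and constant padding at the end of each row is assumed.

```python
def pad_columns(
    rows: list[list[float]], width: int, value: float = 0.0
) -> list[list[float]]:
    """Right-pad every row with `value` until it has `width` columns."""
    return [row + [value] * (width - len(row)) for row in rows]


wide = [[1.0] * 8 for _ in range(2)]  # stands in for WideEncoder's (2, 8) output
tall = [[2.0] * 6 for _ in range(3)]  # stands in for TallEncoder's (3, 6) output

width = max(len(wide[0]), len(tall[0]))
# Concatenating along axis 0 means stacking the (padded) rows.
merged = pad_columns(wide, width) + pad_columns(tall, width)

print((len(merged), len(merged[0])))  # (5, 8)

# Rows 2:5 come from the tall input; columns 6:8 are the padded region.
padded_region = [row[6:8] for row in merged[2:5]]
print(padded_region)  # [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
```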

Key Properties and Methods#

MergeNode (base class)#

Property / Method                Description
-------------------------------  -----------
.is_built                        Whether shape inference has been completed.
.output_shape                    Output shape (no batch dim) after merging.
.input_shapes                    Dict mapping each upstream reference to its input shape.
.backend                         Backend enum, or None if not set.
merge(x)                         Forward pass on a list of SampleData, RoleData, or Batch.
forward(inputs)                  Forward pass from a dict of {reference: data}.
apply_merge(values, domain=...)  Abstract method that subclasses implement. Receives a domain string to allow per-domain merge logic.

ConcatNode#

Property / Method  Description
-----------------  -----------
.concat_axis       The axis along which features are concatenated (int).
.target_strategy   Strategy for merging targets: int (concat axis), MergeStrategy, or ExperimentNodeReference.
.tags_strategy     Strategy for merging tags (same types as target_strategy).
.target_axis       Convenience property: returns the int axis when target_strategy is int. Raises TypeError otherwise.
.tags_axis         Convenience property: returns the int axis when tags_strategy is int. Raises TypeError otherwise.
.pad_inputs        Whether padding is enabled.
.pad_mode          Padding mode ("constant", "reflect", etc.).
.pad_value         Fill value for constant padding.

MergeStrategy Enum#

Value                 Description
--------------------  -----------
MergeStrategy.CONCAT  Concatenate along an axis (requires an int axis).
MergeStrategy.FIRST   Use data from the first upstream input.
MergeStrategy.LAST    Use data from the last upstream input.
MergeStrategy.MEAN    Element-wise mean across inputs (shapes must match).

Next Steps#

  • ModelGraph: See how ModelNodes and MergeNodes are composed into a full computational graph with automatic shape inference.

  • Experiment: Use Experiment to combine a ModelGraph with training phases, loss functions, and evaluation.

  • Custom MergeNode: Subclass MergeNode and implement apply_merge() for custom merging strategies (e.g., averaging, attention-based fusion).
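
As a rough picture of the custom-merge pattern, the toy classes below mirror the idea of an abstract apply_merge that receives a domain string. They are standalone stand-ins and do not inherit from or reproduce modularml's MergeNode. The class names, the dict-based inputs, and the domain names "features" and "targets" are all assumptions made for this sketch.

```python
from __future__ import annotations

from abc import ABC, abstractmethod


class ToyMergeNode(ABC):
    """Standalone stand-in for a merge node (NOT modularml.MergeNode)."""

    @abstractmethod
    def apply_merge(self, values: list[list[float]], domain: str) -> list[float]:
        """Merge one domain's values, given one entry per upstream input."""

    def merge(self, inputs: list[dict[str, list[float]]]) -> dict[str, list[float]]:
        # Apply the subclass's merge rule to each domain independently.
        return {
            domain: self.apply_merge([inp[domain] for inp in inputs], domain=domain)
            for domain in inputs[0]
        }


class ToyAverageMerge(ToyMergeNode):
    """Average features element-wise; pass the first input's other domains through."""

    def apply_merge(self, values: list[list[float]], domain: str) -> list[float]:
        if domain == "features":
            return [sum(col) / len(col) for col in zip(*values)]
        return values[0]


node = ToyAverageMerge()
out = node.merge(
    [
        {"features": [1.0, 3.0], "targets": [0.5]},
        {"features": [3.0, 5.0], "targets": [0.5]},
    ]
)
print(out)  # {'features': [2.0, 4.0], 'targets': [0.5]}
```

The per-domain dispatch is the design point: features can use one rule (here, averaging) while targets use another (here, pass-through), which is the same separation the concat_axis_targets and concat_axis_tags parameters provide on ConcatNode.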