{ "cells": [ { "cell_type": "markdown", "id": "0", "metadata": {}, "source": [ "# How to: Create and Use a MergeNode\n", "\n", "A `MergeNode` is a computational node that combines outputs from multiple upstream nodes into a single output. It is the counterpart to `ModelNode` in a `ModelGraph`: while a `ModelNode` accepts exactly one input, a `MergeNode` accepts two or more.\n", "\n", "Currently, ModularML provides one concrete implementation:\n", "\n", "- **`ConcatNode`** — Concatenates inputs along a specified axis, with optional padding for mismatched dimensions.\n", "\n", "```\n", "ComputeNode (abstract)\n", "├── ModelNode        # Single-input, wraps a model\n", "└── MergeNode        # Multi-input, merges upstream outputs (abstract)\n", "    └── ConcatNode   # Concatenates along an axis\n", "```\n", "\n", "This notebook covers:\n", "\n", "- {ref}`04-create-mergenode-when-to-use-a-mergenode`\n", "- {ref}`04-create-mergenode-creating-a-concatnode`\n", "- {ref}`04-create-mergenode-feature-axis-behavior`\n", "- {ref}`04-create-mergenode-per-domain-axes`\n", "- {ref}`04-create-mergenode-target-and-tag-aggregation-strategies`\n", "- {ref}`04-create-mergenode-padding-mismatched-dimensions`\n", "- {ref}`04-create-mergenode-building-a-graph-with-mergenodes`\n", "- {ref}`04-create-mergenode-forward-pass`\n", "- {ref}`04-create-mergenode-key-properties-and-methods`\n" ] }, { "cell_type": "code", "execution_count": null, "id": "1", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import torch\n", "\n", "from modularml import (\n", "    ConcatNode,\n", "    Experiment,\n", "    FeatureSet,\n", "    ModelGraph,\n", "    ModelNode,\n", "    Optimizer,\n", ")\n", "from modularml.core.topology.merge_nodes.merge_strategy import MergeStrategy\n", "from modularml.models.torch import SequentialMLP\n", "\n", "# We don't need to create an Experiment explicitly at this point.\n", "# We create one here so we can disable the warning raised when creating multiple\n", "# nodes with the same name 
(`registration_policy` is what controls this).\n", "exp = Experiment(label=\"create_mergenode\", registration_policy=\"overwrite\")" ] }, { "cell_type": "markdown", "id": "2", "metadata": {}, "source": [ "We'll use a simple synthetic dataset: 200 samples, each with a 10-element feature `X` and a scalar target `Y`." ] }, { "cell_type": "code", "execution_count": null, "id": "4", "metadata": {}, "outputs": [], "source": [ "rng = np.random.default_rng(42)\n", "\n", "fs = FeatureSet.from_dict(\n", "    label=\"Data A\",\n", "    data={\n", "        \"X\": list(rng.standard_normal((200, 10))),\n", "        \"Y\": list(rng.standard_normal((200, 1))),\n", "    },\n", "    feature_keys=\"X\",\n", "    target_keys=\"Y\",\n", ")\n", "\n", "fs_ref = fs.reference(features=\"X\", targets=\"Y\")\n", "print(fs)" ] }, { "cell_type": "markdown", "id": "5", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "6", "metadata": {}, "source": [ "(04-create-mergenode-when-to-use-a-mergenode)=\n", "## When to Use a MergeNode" ] }, { "cell_type": "markdown", "id": "7", "metadata": {}, "source": [ "\n", "A `MergeNode` is needed when your model graph has **multiple parallel branches** that must be combined before continuing to a downstream node. Common patterns include:\n", "\n", "- **Multi-encoder fusion:** Several encoders process the same (or different) inputs, and their representations are concatenated before a final regressor.\n", "- **Feature augmentation:** A raw feature path is concatenated with a learned embedding path.\n", "- **Ensemble merging:** Outputs from several models are merged (by concatenation, averaging, etc.) 
for downstream processing.\n", "\n", "```\n", "FeatureSet ─┬─> EncoderA ──┐\n", "            │              ├─> ConcatNode ──> Regressor\n", "            └─> EncoderB ──┘\n", "```" ] }, { "cell_type": "markdown", "id": "8", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "9", "metadata": {}, "source": [ "(04-create-mergenode-creating-a-concatnode)=\n", "## Creating a ConcatNode" ] }, { "cell_type": "markdown", "id": "10", "metadata": {}, "source": [ "\n", "`ConcatNode` concatenates multiple inputs along a specified axis. Its signature is:\n", "\n", "```python\n", "ConcatNode(\n", "    label: str,\n", "    upstream_refs: list[ExperimentNode | ExperimentNodeReference],\n", "    concat_axis: int = 0,\n", "    *,\n", "    concat_axis_targets: int | str | MergeStrategy | ExperimentNodeReference = -1,\n", "    concat_axis_tags: int | str | MergeStrategy | ExperimentNodeReference = -1,\n", "    pad_inputs: bool = False,\n", "    pad_mode: str = \"constant\",\n", "    pad_value: float = 0.0,\n", ")\n", "```\n", "\n", "| Parameter | Type | Default | Description |\n", "|-----------|------|---------|-------------|\n", "| `label` | `str` | (required) | Unique name for this node. |\n", "| `upstream_refs` | `list` | (required) | List of upstream nodes or references to merge. |\n", "| `concat_axis` | `int` | `0` | Axis along which to concatenate **features** (see {ref}`04-create-mergenode-feature-axis-behavior`). |\n", "| `concat_axis_targets` | `int \\| str \\| MergeStrategy \\| ExperimentNodeReference` | `-1` | Strategy for merging **targets** (see {ref}`04-create-mergenode-per-domain-axes` and {ref}`04-create-mergenode-target-and-tag-aggregation-strategies`). |\n", "| `concat_axis_tags` | `int \\| str \\| MergeStrategy \\| ExperimentNodeReference` | `-1` | Strategy for merging **tags** (same semantics as `concat_axis_targets`). |\n", "| `pad_inputs` | `bool` | `False` | Whether to pad inputs to align non-concat dimensions. 
|\n", "| `pad_mode` | `str` | `\"constant\"` | Padding mode: `\"constant\"`, `\"reflect\"`, `\"replicate\"`, or `\"circular\"`. |\n", "| `pad_value` | `float` | `0.0` | Fill value when `pad_mode=\"constant\"`. |\n", "\n", "The `concat_axis` parameter controls how **features** are merged and is the primary axis used for shape inference during `ModelGraph.build()`. Targets, tags, and sample UUIDs each have their own merge behavior (see {ref}`04-create-mergenode-per-domain-axes` and {ref}`04-create-mergenode-target-and-tag-aggregation-strategies`)." ] }, { "cell_type": "markdown", "id": "11", "metadata": {}, "source": [ "This tutorial uses the `ModelGraph` class to show how connected `ModelNode`s and `ConcatNode`s are built.\n", "\n", "Details on the `ModelGraph` class are provided in {doc}`03_create_modelgraph`." ] }, { "cell_type": "code", "execution_count": null, "id": "12", "metadata": {}, "outputs": [], "source": [ "def create_model_graph(\n", "    output_shape_a: tuple[int, ...],\n", "    output_shape_b: tuple[int, ...],\n", "    concat_axis: int,\n", "):\n", "    \"\"\"\n", "    Build a two-encoder graph to demonstrate different feature concatenation axes.\n", "\n", "    Args:\n", "        output_shape_a (tuple[int, ...]):\n", "            Output shape of encoder A (excluding batch dimension).\n", "        output_shape_b (tuple[int, ...]):\n", "            Output shape of encoder B (excluding batch dimension).\n", "        concat_axis (int):\n", "            The feature concatenation axis.\n", "\n", "    Returns:\n", "        ModelGraph: The built graph.\n", "\n", "    \"\"\"\n", "    enc_a = ModelNode(\n", "        label=\"EncoderA\",\n", "        model=SequentialMLP(output_shape=output_shape_a, n_layers=1, hidden_dim=16),\n", "        upstream_ref=fs_ref,\n", "    )\n", "    enc_b = ModelNode(\n", "        label=\"EncoderB\",\n", "        model=SequentialMLP(output_shape=output_shape_b, n_layers=1, hidden_dim=16),\n", "        upstream_ref=fs_ref,\n", "    )\n", "    # Padding is enabled so encoders whose non-concat dimensions differ can still be merged\n", "    merge = ConcatNode(\n", "        label=\"Merge\",\n", "        upstream_refs=[enc_a, enc_b],\n", "        concat_axis=concat_axis,\n", "        pad_inputs=True,\n", "    )\n", "\n", "    reg = ModelNode(\n", 
"        label=\"Regressor\",\n", "        model=SequentialMLP(n_layers=1, hidden_dim=8),\n", "        upstream_ref=merge,\n", "    )\n", "\n", "    mg = ModelGraph(\n", "        nodes=[enc_a, enc_b, merge, reg],\n", "        optimizer=Optimizer(opt=\"adam\", backend=\"torch\"),\n", "    )\n", "    mg.build()\n", "\n", "    print(merge)\n", "    for k, inp_shape in merge.input_shapes.items():\n", "        print(f\"  - Data from {k.resolve()}: {inp_shape}\")\n", "    print(f\"  - Merged output shape: {merge.output_shape}\")\n", "\n", "    return mg\n", "\n", "\n", "mg = create_model_graph(output_shape_a=(1, 10), output_shape_b=(1, 5), concat_axis=0)\n", "mg.visualize()" ] }, { "cell_type": "markdown", "id": "13", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "14", "metadata": {}, "source": [ "(04-create-mergenode-feature-axis-behavior)=\n", "## Feature Axis Behavior" ] }, { "cell_type": "markdown", "id": "15", "metadata": {}, "source": [ "\n", "The `concat_axis` parameter controls which dimension the **feature** inputs are concatenated along.\n", "All axis values are relative to the **data shape excluding the batch dimension**.\n", "\n", "For example, with upstream output shapes of `(1, 8)` (excluding batch), a training batch of size 32 produces tensors of shape `(32, 1, 8)`. Here, `concat_axis=0` refers to the `1` dimension and `concat_axis=1` refers to the `8` dimension.\n", "\n", "| `concat_axis` | Behavior | Example: `(1, 8)` + `(1, 8)` |\n", "|---------------|----------|-------------------------------|\n", "| `0` | Concat along first data dim | `(2, 8)` |\n", "| `1` | Concat along second data dim | `(1, 16)` |\n", "| `-1` | Concat along last data dim | `(1, 16)` (same as `axis=1` here) |\n", "\n", "When non-concat dimensions don't match, building the node raises a `ValueError` unless `pad_inputs=True` (see {ref}`04-create-mergenode-padding-mismatched-dimensions`)." 
] }, { "cell_type": "code", "execution_count": null, "id": "16", "metadata": {}, "outputs": [], "source": [ "# concat_axis=0: stack along first data dim\n", "# (1, 8) + (1, 8) -> (2, 8)\n", "mg = create_model_graph((1, 8), (1, 8), concat_axis=0)\n", "mg.visualize()" ] }, { "cell_type": "code", "execution_count": null, "id": "17", "metadata": {}, "outputs": [], "source": [ "# concat_axis=1: concat along second data dim\n", "# (1, 8) + (1, 8) -> (1, 16)\n", "mg = create_model_graph((1, 8), (1, 8), concat_axis=1)\n", "mg.visualize()" ] }, { "cell_type": "code", "execution_count": null, "id": "18", "metadata": {}, "outputs": [], "source": [ "# concat_axis=-1: concat along last dim (useful when ndim may vary)\n", "# (1, 8) + (1, 16) -> (1, 24)\n", "mg = create_model_graph((1, 8), (1, 16), concat_axis=-1)\n", "mg.visualize()" ] }, { "cell_type": "markdown", "id": "19", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "20", "metadata": {}, "source": [ "(04-create-mergenode-per-domain-axes)=\n", "## Per-Domain Axes" ] }, { "cell_type": "markdown", "id": "21", "metadata": {}, "source": [ "\n", "When a `ConcatNode` merges data from upstream nodes, it processes each domain of the `SampleData` independently:\n", "\n", "| Domain | Parameter | Default | Description |\n", "|--------|-----------|---------|-------------|\n", "| **Features** | `concat_axis` | `0` | Primary axis, also used for shape inference. Always int-based. |\n", "| **Targets** | `concat_axis_targets` | `-1` | Concatenation axis or aggregation strategy (see {ref}`04-create-mergenode-target-and-tag-aggregation-strategies`). |\n", "| **Tags** | `concat_axis_tags` | `-1` | Concatenation axis or aggregation strategy. |\n", "| **Sample UUIDs** | *(fixed)* | `-1` | Always concatenated along the last axis. Not configurable. |\n", "\n", "By default, all domains use int-based concatenation. 
When an `int` is provided, it specifies the axis along which to concatenate, with the same semantics as the feature `concat_axis`. For 1-D arrays (the most common case for targets, tags, and sample UUIDs), `-1` is equivalent to `axis=0`.\n", "\n", "To use a non-concatenation strategy for targets or tags, see {ref}`04-create-mergenode-target-and-tag-aggregation-strategies`." ] }, { "cell_type": "code", "execution_count": null, "id": "22", "metadata": {}, "outputs": [], "source": [ "# Example: concat features along axis 0, targets along last axis (default)\n", "enc_a = ModelNode(\n", "    label=\"EncoderA\",\n", "    model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),\n", "    upstream_ref=fs_ref,\n", ")\n", "enc_b = ModelNode(\n", "    label=\"EncoderB\",\n", "    model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),\n", "    upstream_ref=fs_ref,\n", ")\n", "merge = ConcatNode(\n", "    label=\"Merge\",\n", "    upstream_refs=[enc_a, enc_b],\n", "    concat_axis=0,  # features: (1,8) + (1,8) -> (2,8)\n", "    concat_axis_targets=-1,  # targets: concat along last axis (default)\n", "    concat_axis_tags=-1,  # tags: concat along last axis (default)\n", ")\n", "print(f\"Feature axis: {merge.concat_axis}\")\n", "print(f\"Target strategy: {merge.target_strategy}\")\n", "print(f\"Tags strategy: {merge.tags_strategy}\")" ] }, { "cell_type": "markdown", "id": "23", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "24", "metadata": {}, "source": [ "(04-create-mergenode-target-and-tag-aggregation-strategies)=\n", "## Target and Tag Aggregation Strategies" ] }, { "cell_type": "markdown", "id": "25", "metadata": {}, "source": [ "\n", "When concatenating features from multiple upstream nodes, the default behavior is to also concatenate the associated targets and tags. 
This is often undesirable. For example, if both encoders receive the same FeatureSet targets, concatenation duplicates them, so the `(200, 1)` targets in this notebook become `(200, 2)`.\n", "\n", "The `concat_axis_targets` and `concat_axis_tags` parameters accept several types to control how these domains are merged:\n", "\n", "| Value | Type | Behavior |\n", "|-------|------|----------|\n", "| `-1` (default) | `int` | Concatenate along the last axis. |\n", "| Any `int` | `int` | Concatenate along the specified axis. |\n", "| `\"first\"` | `str` or `MergeStrategy.FIRST` | Use targets/tags from the **first** upstream input only. |\n", "| `\"last\"` | `str` or `MergeStrategy.LAST` | Use targets/tags from the **last** upstream input only. |\n", "| `\"mean\"` | `str` or `MergeStrategy.MEAN` | Element-wise mean across all inputs (shapes must match). |\n", "| An upstream node, e.g. `enc_a` | `ExperimentNode` or `ExperimentNodeReference` | Use targets/tags from that **specific** upstream input. |\n", "\n", "When a non-concat strategy is used, any upstream inputs with `None` data for that domain are silently filtered out.\n", "\n", "Strings are automatically converted to `MergeStrategy` enum values, so `\"first\"` and `MergeStrategy.FIRST` are equivalent." 
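] }, { "cell_type": "markdown", "id": "25a", "metadata": {}, "source": [ "The strategies in the table above can be illustrated with a few lines of plain PyTorch. The cell below is a minimal sketch under stated assumptions, not ModularML's implementation: `sketch_merge_targets` is a hypothetical helper, upstream inputs are modeled as an insertion-ordered dict of `{label: tensor}` with the batch dimension first, and select-by-reference is approximated by passing an upstream label string." ] }, { "cell_type": "code", "execution_count": null, "id": "25b", "metadata": {}, "outputs": [], "source": [
"import torch\n",
"\n",
"\n",
"def sketch_merge_targets(inputs, strategy):\n",
"    # Hypothetical helper (not part of ModularML). `inputs` maps upstream labels\n",
"    # to batch-first target tensors (or None); `strategy` is an int axis,\n",
"    # \"first\", \"last\", \"mean\", or an upstream label (select-by-reference).\n",
"    if isinstance(strategy, int):\n",
"        # Concatenate; like concat_axis, the axis excludes the batch dim\n",
"        dim = strategy + 1 if strategy >= 0 else strategy\n",
"        return torch.cat(list(inputs.values()), dim=dim)\n",
"    # Non-concat strategies silently skip inputs without data\n",
"    present = {k: v for k, v in inputs.items() if v is not None}\n",
"    values = list(present.values())\n",
"    if strategy == \"first\":\n",
"        return values[0]\n",
"    if strategy == \"last\":\n",
"        return values[-1]\n",
"    if strategy == \"mean\":\n",
"        # Element-wise mean; torch.stack requires identical shapes\n",
"        return torch.stack(values).mean(dim=0)\n",
"    # Otherwise treat the strategy as the label of one upstream input\n",
"    return present[strategy]\n",
"\n",
"\n",
"y = torch.randn(200, 1)\n",
"inputs = {\"EncoderA\": y, \"EncoderB\": y.clone()}\n",
"print(sketch_merge_targets(inputs, -1).shape)  # torch.Size([200, 2])\n",
"print(sketch_merge_targets(inputs, \"first\").shape)  # torch.Size([200, 1])\n",
"print(torch.allclose(sketch_merge_targets(inputs, \"mean\"), y))  # True\n",
"print(sketch_merge_targets(inputs, \"EncoderB\") is inputs[\"EncoderB\"])  # True\n",
"\n",
"# With \"first\", an upstream input that has no targets is skipped\n",
"print(sketch_merge_targets({\"EncoderA\": None, \"EncoderB\": y}, \"first\") is y)  # True"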
] }, { "cell_type": "code", "execution_count": null, "id": "26", "metadata": {}, "outputs": [], "source": [ "# Strategy: \"first\" - use targets from the first upstream input only (enc_a)\n", "merge_first = ConcatNode(\n", " label=\"MergeFirst\",\n", " upstream_refs=[enc_a, enc_b],\n", " concat_axis=-1,\n", " concat_axis_targets=\"first\",\n", ")\n", "print(f\"target_strategy: {merge_first.target_strategy}\")\n", "\n", "# Strategy: MergeStrategy enum (equivalent to string)\n", "merge_mean = ConcatNode(\n", " label=\"MergeMean\",\n", " upstream_refs=[enc_a, enc_b],\n", " concat_axis=-1,\n", " concat_axis_targets=MergeStrategy.MEAN,\n", ")\n", "print(f\"target_strategy: {merge_mean.target_strategy}\")\n", "\n", "# Strategy: select by reference — use targets from a specific upstream node\n", "merge_ref = ConcatNode(\n", " label=\"MergeRef\",\n", " upstream_refs=[enc_a, enc_b],\n", " concat_axis=-1,\n", " concat_axis_targets=enc_a, # use EncoderA's targets\n", ")\n", "print(f\"target_strategy: {merge_ref.target_strategy.node_label}\")" ] }, { "cell_type": "markdown", "id": "27", "metadata": {}, "source": [ "### Comparing Strategies on a Forward Pass\n", "\n", "Let's run the same data through merge nodes with different target strategies to see how the output targets differ." 
] }, { "cell_type": "code", "execution_count": null, "id": "28", "metadata": {}, "outputs": [], "source": [ "from modularml.core.data.sample_data import SampleData\n", "from modularml.utils.data.data_format import DataFormat\n", "\n", "# Build enc_a, enc_b, and merge_first by placing merge_first in a small graph\n", "reg_demo = ModelNode(\n", "    label=\"Reg_demo\",\n", "    model=SequentialMLP(n_layers=1, hidden_dim=8),\n", "    upstream_ref=merge_first,\n", ")\n", "mg = ModelGraph(\n", "    nodes=[enc_a, enc_b, merge_first, reg_demo],\n", "    optimizer=Optimizer(opt=\"adam\", backend=\"torch\"),\n", ")\n", "mg.build()\n", "\n", "\n", "# Build the remaining merge nodes manually, using the encoder output shapes\n", "# that are now known because enc_a and enc_b are built\n", "input_shapes = {\n", "    enc_a.reference(): enc_a.output_shape,\n", "    enc_b.reference(): enc_b.output_shape,\n", "}\n", "for m in [merge_mean, merge_ref]:\n", "    m.build(input_shapes=input_shapes, includes_batch_dim=False, backend=\"torch\")\n", "\n", "# Also build a default-concat merge for comparison\n", "merge_concat = ConcatNode(\n", "    label=\"MergeConcat\",\n", "    upstream_refs=[enc_a, enc_b],\n", "    concat_axis=-1,\n", ")\n", "merge_concat.build(input_shapes=input_shapes, includes_batch_dim=False, backend=\"torch\")\n", "\n", "# Wrap the FeatureSet data as SampleData (torch tensors)\n", "fsv = fs_ref.resolve()\n", "sample_data = SampleData(\n", "    features=fsv.get_features(fmt=DataFormat.TORCH),\n", "    targets=fsv.get_targets(fmt=DataFormat.TORCH),\n", ")\n", "\n", "# Run both encoders once, then feed the same outputs to every merge node\n", "with torch.no_grad():\n", "    out_a = enc_a(sample_data)\n", "    out_b = enc_b(sample_data)\n", "    merge_inputs = {enc_a.reference(): out_a, enc_b.reference(): out_b}\n", "\n", "    out_concat = merge_concat.forward(merge_inputs)\n", "    out_first = merge_first.forward(merge_inputs)\n", "    out_mean = merge_mean.forward(merge_inputs)\n", "    out_ref = merge_ref.forward(merge_inputs)\n", "\n", "print(f\"Input targets shape: {sample_data.targets.shape}\")\n", "print(f\"concat (default) targets shape: {out_concat.targets.shape}\")\n", 
"print(f\"'first' strategy targets shape: {out_first.targets.shape}\")\n", "print(f\"'mean' strategy targets shape: {out_mean.targets.shape}\")\n", "print(f\"select-by-ref targets shape: {out_ref.targets.shape}\")\n", "print()\n", "print(\n", " f\"Targets match (first == ref): {torch.equal(out_first.targets, out_ref.targets)}\",\n", ")\n", "print(\n", " f\"Targets match (first == mean): {torch.equal(out_first.targets, out_mean.targets)}\",\n", ")" ] }, { "cell_type": "markdown", "id": "29", "metadata": {}, "source": [ "In this example both encoders receive the same FeatureSet targets, so:\n", "\n", "- **concat (default):** Targets are doubled — `(200, 1)` + `(200, 1)` → `(200, 2)`.\n", "- **\"first\":** Only the first input's targets are kept — shape stays `(200, 1)`.\n", "- **\"mean\":** Element-wise average of identical targets — shape stays `(200, 1)`, values unchanged.\n", "- **select-by-ref (`enc_a`):** Identical to \"first\" here since `enc_a` is the first input.\n", "\n", "The \"first\"/\"last\" and select-by-reference strategies are most useful when upstream nodes have different targets, or when you want to pass through a specific node's targets unchanged." 
] }, { "cell_type": "markdown", "id": "30", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "31", "metadata": {}, "source": [ "(04-create-mergenode-padding-mismatched-dimensions)=\n", "## Padding Mismatched Dimensions" ] }, { "cell_type": "markdown", "id": "32", "metadata": {}, "source": [ "\n", "When inputs have different sizes in a non-concat dimension, `ConcatNode` can automatically pad each input along that dimension up to the largest size among the inputs.\n", "\n", "Consider two encoders with outputs `(2, 8)` and `(3, 6)`, concatenated along axis 0 (first data dim):\n", "\n", "- **Concat dim:** `2 + 3 = 5`\n", "- **Non-concat dim:** `max(8, 6) = 8` (the `(3, 6)` output is padded to `(3, 8)`)\n", "- **Result:** `(5, 8)`" ] }, { "cell_type": "code", "execution_count": null, "id": "33", "metadata": {}, "outputs": [], "source": [ "# Two encoders whose output shapes differ in BOTH data dimensions\n", "enc_wide = ModelNode(\n", "    label=\"WideEncoder\",\n", "    model=SequentialMLP(output_shape=(2, 8), n_layers=1, hidden_dim=16),\n", "    upstream_ref=fs_ref,\n", ")\n", "\n", "enc_tall = ModelNode(\n", "    label=\"TallEncoder\",\n", "    model=SequentialMLP(output_shape=(3, 6), n_layers=1, hidden_dim=16),\n", "    upstream_ref=fs_ref,\n", ")\n", "\n", "# Concat on axis 0 with padding enabled\n", "# dim 0: concatenated (2+3=5), dim 1: padded to max(8,6)=8\n", "merge_padded = ConcatNode(\n", "    label=\"PaddedMerge\",\n", "    upstream_refs=[enc_wide, enc_tall],\n", "    concat_axis=0,\n", "    concat_axis_targets=\"first\",  # keep WideEncoder's targets instead of duplicating them\n", "    pad_inputs=True,\n", "    pad_mode=\"constant\",\n", "    pad_value=0.0,\n", ")\n", "\n", "reg = ModelNode(\n", "    label=\"Regressor\",\n", "    model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),\n", "    upstream_ref=merge_padded,\n", ")\n", "\n", "mg = ModelGraph(\n", "    nodes=[enc_wide, enc_tall, merge_padded, reg],\n", "    optimizer=Optimizer(opt=\"adam\", backend=\"torch\"),\n", ")\n", "mg.build()\n", "mg.visualize()\n", "\n", "\n", 
"print(merge_padded)\n", "for k, inp_shape in merge_padded.input_shapes.items():\n", " print(f\" - Data from {k.resolve()}: {inp_shape}\")\n", "print(f\" - Merged output shape: {merge_padded.output_shape}\")" ] }, { "cell_type": "markdown", "id": "34", "metadata": {}, "source": [ "### Without Padding\n", "\n", "If `pad_inputs=False` (the default) and non-concat dimensions don't match, a `ValueError` is raised at build time with a helpful message." ] }, { "cell_type": "code", "execution_count": null, "id": "35", "metadata": {}, "outputs": [], "source": [ "merge_no_pad = ConcatNode(\n", " label=\"NoPadMerge\",\n", " upstream_refs=[enc_wide, enc_tall],\n", " concat_axis=0,\n", " pad_inputs=False,\n", ")\n", "\n", "try:\n", " merge_no_pad.build(\n", " input_shapes={\n", " enc_wide.reference(): enc_wide.output_shape,\n", " enc_tall.reference(): enc_tall.output_shape,\n", " },\n", " includes_batch_dim=False,\n", " )\n", "except ValueError as e:\n", " print(f\"ValueError: {e}\")" ] }, { "cell_type": "markdown", "id": "36", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "37", "metadata": {}, "source": [ "(04-create-mergenode-building-a-graph-with-mergenodes)=\n", "## Building a Graph with MergeNodes" ] }, { "cell_type": "markdown", "id": "38", "metadata": {}, "source": [ "\n", "In practice, you don't need to build `MergeNode`s manually. `ModelGraph.build()` handles\n", "shape inference and build order for all nodes, including merge nodes.\n", "\n", "We already saw this in the `create_model_graph` helper above. 
Here's the full pattern with a non-default target strategy:" ] }, { "cell_type": "code", "execution_count": null, "id": "39", "metadata": {}, "outputs": [], "source": [ "enc_a = ModelNode(\n", "    label=\"EncoderA\",\n", "    model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),\n", "    upstream_ref=fs_ref,\n", ")\n", "enc_b = ModelNode(\n", "    label=\"EncoderB\",\n", "    model=SequentialMLP(output_shape=(1, 4), n_layers=1, hidden_dim=16),\n", "    upstream_ref=fs_ref,\n", ")\n", "\n", "merge = ConcatNode(\n", "    label=\"Merge\",\n", "    upstream_refs=[enc_a, enc_b],\n", "    concat_axis=-1,  # concat features along last axis\n", "    concat_axis_targets=enc_a,  # use only enc_a's targets\n", ")\n", "\n", "regressor = ModelNode(\n", "    label=\"Regressor\",\n", "    model=SequentialMLP(n_layers=1, hidden_dim=8),\n", "    upstream_ref=merge,\n", ")\n", "\n", "mg = ModelGraph(\n", "    nodes=[enc_a, enc_b, merge, regressor],\n", "    optimizer=Optimizer(opt=\"adam\", backend=\"torch\"),\n", ")\n", "mg.build()\n", "mg.visualize()\n", "\n", "print(\"Graph built successfully!\")\n", "for node in mg.nodes.values():\n", "    # Report `input_shape` (single input) or `input_shapes` (merge nodes)\n", "    in_shapes = None\n", "    out_shape = None\n", "    if hasattr(node, \"input_shape\"):\n", "        in_shapes = node.input_shape\n", "    elif hasattr(node, \"input_shapes\"):\n", "        in_shapes = node.input_shapes\n", "    if hasattr(node, \"output_shape\"):\n", "        out_shape = node.output_shape\n", "\n", "    print(f\"  {node.label}: {in_shapes} -> {out_shape}\")" ] }, { "cell_type": "markdown", "id": "40", "metadata": {}, "source": [ "Here `concat_axis_targets=enc_a` tells the merge node to use **EncoderA's targets** as the output targets instead of concatenating targets from both inputs. The node is passed as an `ExperimentNode` instance, which is automatically converted to an `ExperimentNodeReference`." 
] }, { "cell_type": "markdown", "id": "41", "metadata": {}, "source": [ "The graph correctly infers:\n", "\n", "- **EncoderA:** input `(1, 10)` → output `(1, 8)`\n", "- **EncoderB:** input `(1, 10)` → output `(1, 4)`\n", "- **Merge:** `(1, 8)` + `(1, 4)` along last axis → `(1, 12)` features, targets selected from EncoderA\n", "- **Regressor:** input `(1, 12)` → output `(1, 1)`" ] }, { "cell_type": "markdown", "id": "42", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "43", "metadata": {}, "source": [ "(04-create-mergenode-forward-pass)=\n", "## Forward Pass" ] }, { "cell_type": "markdown", "id": "44", "metadata": {}, "source": [ "\n", "Forward passes through a `MergeNode` work like those through a `ModelNode`, except that a merge takes one input per upstream node: `merge(x)` accepts a list of `SampleData`, `RoleData`, or `Batch` objects, `forward(inputs)` accepts a dict of `{reference: data}`, and both return a single merged object of the same type.\n", "\n", "When running through a `ModelGraph`, this is all handled automatically. Below we trace\n", "a manual forward pass to show how data flows through each node, using the `concat_axis_targets=enc_a` merge node from {ref}`04-create-mergenode-building-a-graph-with-mergenodes`." 
] }, { "cell_type": "code", "execution_count": null, "id": "45", "metadata": {}, "outputs": [], "source": [ "# Create SampleData from the FeatureSet reference\n", "# (SampleData and DataFormat were imported in an earlier cell)\n", "fsv = fs_ref.resolve()\n", "sample_data = SampleData(\n", "    features=fsv.get_features(fmt=DataFormat.TORCH),\n", "    targets=fsv.get_targets(fmt=DataFormat.TORCH),\n", ")\n", "print(f\"Input features shape: {sample_data.features.shape}\")\n", "print(f\"Input targets shape: {sample_data.targets.shape}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "46", "metadata": {}, "outputs": [], "source": [ "# Trace through each node manually\n", "with torch.no_grad():\n", "    out_a = enc_a(sample_data)\n", "    out_b = enc_b(sample_data)\n", "    print(f\"EncoderA features: {out_a.features.shape}\")\n", "    print(f\"EncoderA targets: {out_a.targets.shape}\")\n", "    print(f\"EncoderB features: {out_b.features.shape}\")\n", "    print(f\"EncoderB targets: {out_b.targets.shape}\")\n", "\n", "    # Merge expects a dict of {reference: data}\n", "    merge_inputs = {\n", "        enc_a.reference(): out_a,\n", "        enc_b.reference(): out_b,\n", "    }\n", "    out_merge = merge.forward(merge_inputs)\n", "    print(f\"\\nMerge features: {out_merge.features.shape}\")\n", "    print(f\"Merge targets: {out_merge.targets.shape} (selected from EncoderA)\")\n", "\n", "    out_final = regressor(out_merge)\n", "    print(f\"Regressor output: {out_final.features.shape}\")" ] }, { "cell_type": "markdown", "id": "47", "metadata": {}, "source": [ "Notice that features are concatenated along the last axis (`concat_axis=-1`): `(1, 8) + (1, 4) -> (1, 12)`. Because `concat_axis_targets=enc_a`, the merged targets keep the shape of the original FeatureSet targets, `(200, 1)`, rather than being concatenated.\n", "\n", "Compare this with the default behavior (shown in {ref}`04-create-mergenode-target-and-tag-aggregation-strategies`), where targets would be `(200, 2)` due to concatenation." 
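] }, { "cell_type": "markdown", "id": "47a", "metadata": {}, "source": [ "The feature shape arithmetic from {ref}`04-create-mergenode-feature-axis-behavior` and {ref}`04-create-mergenode-padding-mismatched-dimensions` can be reproduced in plain PyTorch. The cell below is a minimal sketch, not `ConcatNode`'s actual implementation: `sketch_concat` is a hypothetical helper, and it assumes that all inputs share the same batch size and that padding is appended at the end of each non-concat dimension (consistent with the zero-filled columns `6:8` checked in the next cell)." ] }, { "cell_type": "code", "execution_count": null, "id": "47b", "metadata": {}, "outputs": [], "source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"\n",
"def sketch_concat(tensors, concat_axis, pad_value=0.0):\n",
"    # Hypothetical helper (not part of ModularML): concatenate batch-first tensors\n",
"    # along a data axis, padding every other data dim up to the largest size.\n",
"    ndim = tensors[0].dim()\n",
"    # concat_axis excludes the batch dim, so shift non-negative axes by one\n",
"    dim = concat_axis + 1 if concat_axis >= 0 else ndim + concat_axis\n",
"    # Largest size of each dimension across all inputs\n",
"    target = [max(t.shape[d] for t in tensors) for d in range(ndim)]\n",
"    padded = []\n",
"    for t in tensors:\n",
"        pad = []\n",
"        # F.pad takes (before, after) pairs starting from the LAST dimension;\n",
"        # the batch dim (0) is never padded and the concat dim is left as-is\n",
"        for d in reversed(range(1, ndim)):\n",
"            extra = 0 if d == dim else target[d] - t.shape[d]\n",
"            pad.extend([0, extra])\n",
"        padded.append(F.pad(t, pad, mode=\"constant\", value=pad_value))\n",
"    return torch.cat(padded, dim=dim)\n",
"\n",
"\n",
"batch = 32\n",
"a, b = torch.randn(batch, 1, 8), torch.randn(batch, 1, 8)\n",
"print(sketch_concat([a, b], concat_axis=0).shape)  # torch.Size([32, 2, 8])\n",
"print(sketch_concat([a, b], concat_axis=1).shape)  # torch.Size([32, 1, 16])\n",
"print(sketch_concat([a, b], concat_axis=-1).shape)  # torch.Size([32, 1, 16])\n",
"\n",
"# (2, 8) + (3, 6) along axis 0 with padding -> (5, 8)\n",
"wide, tall = torch.randn(batch, 2, 8), torch.randn(batch, 3, 6)\n",
"out = sketch_concat([wide, tall], concat_axis=0)\n",
"print(out.shape)  # torch.Size([32, 5, 8])\n",
"print(torch.count_nonzero(out[:, 2:5, 6:8]).item())  # 0 (padded region)"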
] }, { "cell_type": "markdown", "id": "48", "metadata": {}, "source": [ "### Verifying Padded Forward Pass\n", "\n", "Let's verify that the padded merge node (from {ref}`04-create-mergenode-padding-mismatched-dimensions`) produces the expected shapes and that padded regions are filled with zeros." ] }, { "cell_type": "code", "execution_count": null, "id": "49", "metadata": {}, "outputs": [], "source": [ "print(f\"PaddedMerge output_shape: {merge_padded.output_shape}\")\n", "\n", "# Forward pass\n", "with torch.no_grad():\n", " out_wide = enc_wide(sample_data)\n", " out_tall = enc_tall(sample_data)\n", " print(f\"WideEncoder output: {out_wide.features.shape}\")\n", " print(f\"TallEncoder output: {out_tall.features.shape}\")\n", "\n", " padded_inputs = {\n", " enc_wide.reference(): out_wide,\n", " enc_tall.reference(): out_tall,\n", " }\n", " out_padded = merge_padded.forward(padded_inputs)\n", " print(f\"Padded merge output: {out_padded.features.shape}\")\n", "\n", " # Verify padding: TallEncoder (3, 6) is padded to (3, 8)\n", " # After concat on axis 0: rows 0:2 from WideEncoder, rows 2:5 from TallEncoder\n", " # Columns 6:8 of TallEncoder's contribution should be zero\n", " padded_region = out_padded.features[:, 2:5, 6:8].numpy()\n", " print(f\"Padded region values (should be all zeros): {np.unique(padded_region)}\")" ] }, { "cell_type": "markdown", "id": "50", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "51", "metadata": {}, "source": [ "(04-create-mergenode-key-properties-and-methods)=\n", "## Key Properties and Methods" ] }, { "cell_type": "markdown", "id": "52", "metadata": {}, "source": [ "\n", "### MergeNode (base class)\n", "\n", "| Property / Method | Description |\n", "|-------------------|-------------|\n", "| `.is_built` | Whether shape inference has been completed. |\n", "| `.output_shape` | Output shape (no batch dim) after merging. |\n", "| `.input_shapes` | Dict mapping each upstream reference to its input shape. 
|\n", "| `.backend` | Backend enum, or `None` if not set. |\n", "| `merge(x)` | Forward pass on a list of `SampleData`, `RoleData`, or `Batch`. |\n", "| `forward(inputs)` | Forward pass from a dict of `{reference: data}`. |\n", "| `apply_merge(values, domain=...)` | Abstract method that subclasses implement. Receives a `domain` string to allow per-domain merge logic. |\n", "\n", "### ConcatNode\n", "\n", "| Property / Method | Description |\n", "|-------------------|-------------|\n", "| `.concat_axis` | The axis along which **features** are concatenated (`int`). |\n", "| `.target_strategy` | Strategy for merging targets: `int` (concat axis), `MergeStrategy`, or `ExperimentNodeReference`. |\n", "| `.tags_strategy` | Strategy for merging tags (same types as `target_strategy`). |\n", "| `.target_axis` | Convenience property — returns the int axis when `target_strategy` is `int`. Raises `TypeError` otherwise. |\n", "| `.tags_axis` | Convenience property — returns the int axis when `tags_strategy` is `int`. Raises `TypeError` otherwise. |\n", "| `.pad_inputs` | Whether padding is enabled. |\n", "| `.pad_mode` | Padding mode (`\"constant\"`, `\"reflect\"`, etc.). |\n", "| `.pad_value` | Fill value for constant padding. |\n", "\n", "### MergeStrategy Enum\n", "\n", "| Value | Description |\n", "|-------|-------------|\n", "| `MergeStrategy.CONCAT` | Concatenate along an axis (requires an int axis). |\n", "| `MergeStrategy.FIRST` | Use data from the first upstream input. |\n", "| `MergeStrategy.LAST` | Use data from the last upstream input. |\n", "| `MergeStrategy.MEAN` | Element-wise mean across inputs (shapes must match). 
|\n", "\n", "### Next Steps\n", "\n", "- **ModelGraph:** See how `ModelNode`s and `MergeNode`s are composed into a full computational graph with automatic shape inference.\n", "- **Experiment:** Use `Experiment` to combine a `ModelGraph` with training phases, loss functions, and evaluation.\n", "- **Custom MergeNode:** Subclass `MergeNode` and implement `apply_merge()` for custom merging strategies (e.g., averaging, attention-based fusion)." ] } ], "metadata": { "kernelspec": { "display_name": ".venv (3.10.18)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 5 }