"""Source code for modularml.core.topology.merge_nodes.concat_node."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np

from modularml.core.data.schema_constants import (
    DOMAIN_FEATURES,
    DOMAIN_SAMPLE_UUIDS,
    DOMAIN_TAGS,
    DOMAIN_TARGETS,
)
from modularml.core.topology.merge_nodes.merge_node import MergeNode
from modularml.core.topology.merge_nodes.merge_strategy import MergeStrategy
from modularml.utils.data.conversion import convert_to_format
from modularml.utils.data.data_format import (
    _TENSORLIKE_FORMATS,
    DataFormat,
    format_is_tensorlike,
    get_data_format_for_backend,
    normalize_format,
)
from modularml.utils.environment.optional_imports import ensure_tensorflow, ensure_torch
from modularml.utils.logging.warnings import warn
from modularml.utils.nn.padding import PadMode, map_pad_mode_to_backend

if TYPE_CHECKING:
    from modularml.core.data.sample_data import SampleData
    from modularml.core.experiment.experiment_node import ExperimentNode
    from modularml.core.references.experiment_reference import ExperimentNodeReference


# Accepted raw forms for a per-domain merge strategy argument: an integer
# concatenation axis, a strategy name, or a MergeStrategy member.
StrategyType = (
    int
    | str
    | MergeStrategy
)


class ConcatNode(MergeNode):
    """
    A merge stage that concatenates multiple inputs along a specified axis.

    Description:
        This stage merges tensors by concatenating them along a specified
        axis. It supports automatic padding of non-concat dimensions to
        align shapes, allowing for flexible merging even when inputs vary
        in size. Padding is controlled via `pad_mode` and `pad_value`;
        currently only the ``"constant"`` mode is implemented.

        For the targets and tags domains, non-concatenation merge strategies
        can be used instead of axis-based concatenation. For example,
        ``"first"`` selects targets from the first input, ``"mean"``
        computes an element-wise average, or an `ExperimentNodeReference`
        can be passed to select targets from a specific upstream input.

    Attributes:
        label (str): Unique identifier for this node.
        upstream_refs (list[ExperimentNode | ExperimentNodeReference]):
            Upstream node references from which inputs will be received.
        concat_axis (int): The axis along which to concatenate feature inputs.
        target_strategy (int | MergeStrategy | ExperimentNodeReference):
            Strategy for merging targets. An int means concatenation along
            that axis; a MergeStrategy applies an aggregation; an
            ExperimentNodeReference selects targets from a specific upstream
            input.
        tags_strategy (int | MergeStrategy | ExperimentNodeReference):
            Strategy for merging tags (same semantics as target_strategy).
        pad_inputs (bool): Whether to pad inputs before merging.
        pad_mode (PadMode): Padding mode (only ``PadMode.CONSTANT`` is
            currently accepted).
        pad_value (float): Value to use for constant padding.
    """

    def __init__(
        self,
        label: str,
        upstream_refs: list[ExperimentNode | ExperimentNodeReference],
        concat_axis: int = 0,
        *,
        concat_axis_targets: StrategyType | ExperimentNodeReference = -1,
        concat_axis_tags: StrategyType | ExperimentNodeReference = -1,
        pad_inputs: bool = False,
        pad_mode: str | PadMode = "constant",
        pad_value: float = 0.0,
        node_id: str | None = None,
        register: bool = True,
    ):
        """
        Initialize a ConcatNode.

        Args:
            label (str):
                Unique identifier for this node.
            upstream_refs (list[ExperimentNode | ExperimentNodeReference]):
                Upstream node references from which inputs will be received.
            concat_axis (int):
                The axis along which to concatenate feature inputs. Does not
                include the batch dimension. That is, for a batched shape of
                (32, 1, 16), axis=0 refers to "1". The examples below omit
                the batch dimension.
                  - ``axis=0``: concat along features.
                    Example: ``(1, 16) + (1, 16) -> (2, 16)``.
                  - ``axis >= 1``: concat within features.
                    Example (``axis=2``):
                    ``(1, 16, 16) + (1, 16, 8) -> (1, 16, 24)``.
                    All data must have more than ``concat_axis`` dimensions.
                  - Negative axes count from the end of the data;
                    ``axis=-1`` concatenates along the last axis.
                    Example: ``(1, 16, 16) + (1, 16, 8) -> (1, 16, 24)``.
            concat_axis_targets (int | str | MergeStrategy | ExperimentNodeReference):
                Strategy for merging the targets domain. Accepts:
                  - ``int``: concatenate along this axis (same semantics as
                    ``concat_axis``). Defaults to ``-1`` (last axis).
                  - ``str`` or ``MergeStrategy``: apply an aggregation
                    strategy. Supported values are ``"first"``, ``"last"``,
                    and ``"mean"``.
                  - ``ExperimentNodeReference``: select targets from the
                    upstream input matching this reference.
            concat_axis_tags (int | str | MergeStrategy | ExperimentNodeReference):
                Strategy for merging the tags domain. Same semantics as
                `concat_axis_targets`. Defaults to -1 (last axis).
            pad_inputs (bool, optional):
                Whether to pad inputs before merging. Defaults to False.
            pad_mode (str | PadMode, optional):
                Padding mode. Only ``"constant"`` is currently implemented;
                any other mode raises NotImplementedError.
                Defaults to "constant".
            pad_value (float, optional):
                Value to use for constant padding. Defaults to 0.0.
            node_id (str, optional):
                Used only for de-serialization.
            register (bool, optional):
                Used only for de-serialization.

        Raises:
            NotImplementedError: If `pad_mode` is not ``"constant"``.

        """
        super().__init__(
            label=label,
            upstream_refs=upstream_refs,
            node_id=node_id,
            register=register,
        )
        self.concat_axis = int(concat_axis)
        self.target_strategy = self._normalize_strategy(concat_axis_targets)
        self.tags_strategy = self._normalize_strategy(concat_axis_tags)

        self.pad_inputs = bool(pad_inputs)
        self.pad_mode = pad_mode if isinstance(pad_mode, PadMode) else PadMode(pad_mode)
        self.pad_value = pad_value
        if self.pad_mode not in [PadMode.CONSTANT]:
            msg = f"Pad mode is not supported yet: {self.pad_mode}"
            raise NotImplementedError(msg)

    # ================================================
    # Strategy normalization
    # ================================================
    @staticmethod
    def _normalize_strategy(
        value: int | str | MergeStrategy | ExperimentNodeReference,
    ) -> int | MergeStrategy | ExperimentNodeReference:
        """
        Normalize a strategy parameter to its canonical form.

        Args:
            value: Raw strategy value from the constructor. An
                `ExperimentNode` is also accepted and converted to its
                reference.

        Returns:
            int | MergeStrategy | ExperimentNodeReference: Normalized strategy.

        Raises:
            ValueError: If the value is an unrecognized strategy string.
            TypeError: If the value is of an unsupported type.

        """
        # Imported lazily (these modules are only imported under
        # TYPE_CHECKING at module level).
        from modularml.core.experiment.experiment_node import ExperimentNode
        from modularml.core.references.experiment_reference import (
            ExperimentNodeReference,
        )

        if isinstance(value, ExperimentNode):
            return value.reference()
        if isinstance(value, ExperimentNodeReference):
            return value
        if isinstance(value, int):
            return value
        if isinstance(value, MergeStrategy):
            return value
        if isinstance(value, str):
            return MergeStrategy(value)

        msg = (
            "Expected type to be one of int, str, MergeStrategy, or "
            f"ExperimentNodeReference. Received: {type(value)}."
        )
        raise TypeError(msg)

    @property
    def target_axis(self) -> int:
        """
        Target concatenation axis (only valid when target_strategy is int).

        Raises:
            TypeError: If the target strategy is not int-based.

        """
        if isinstance(self.target_strategy, int):
            return self.target_strategy
        msg = (
            f"target_axis is not available when target_strategy is "
            f"{self.target_strategy!r}. Use target_strategy instead."
        )
        raise TypeError(msg)

    @property
    def tags_axis(self) -> int:
        """
        Tags concatenation axis (only valid when tags_strategy is int).

        Raises:
            TypeError: If the tags strategy is not int-based.

        """
        if isinstance(self.tags_strategy, int):
            return self.tags_strategy
        msg = (
            f"tags_axis is not available when tags_strategy is "
            f"{self.tags_strategy!r}. Use tags_strategy instead."
        )
        raise TypeError(msg)

    # ================================================
    # Padding & Validation
    # ================================================
    def _pad_inputs(
        self,
        values: list[Any],
        concat_axis: int,
        fmt: DataFormat | None = None,
    ) -> list[Any]:
        """
        Pad all inputs along non-concat dimensions to match the largest shape.

        Description:
            Applies format-specific padding (trailing side only) so that all
            tensors have the same shape on every axis except `concat_axis`.

        Args:
            values (list[Any]):
                List of tensors to be padded. All must have the same rank.
            concat_axis (int):
                The non-negative axis to concatenate along, relative to the
                actual tensors in `values`.
            fmt (DataFormat | None):
                The data format of the tensors. Torch and TensorFlow formats
                use their native pad functions; any other value (including
                None) falls back to `np.pad`. Defaults to None.

        Returns:
            list[Any]: Padded tensors.

        Raises:
            ValueError: If the inputs have differing numbers of dimensions.

        """
        ranks = {len(v.shape) for v in values}
        if len(ranks) > 1:
            msg = (
                "Cannot pad inputs with differing numbers of dimensions: "
                f"{[tuple(v.shape) for v in values]}."
            )
            raise ValueError(msg)

        # Largest size along each axis across all inputs
        max_shape = np.max([np.array(v.shape) for v in values], axis=0)

        # The pad mode mapping does not depend on the individual input
        backend_mode = map_pad_mode_to_backend(
            mode=self.pad_mode,
            backend=self._backend,
        )

        padded = []
        for v in values:
            # (before, after) pad amounts per axis; never pad the concat axis
            pad_width = [
                (0, 0) if dim == concat_axis else (0, int(max_shape[dim]) - int(size))
                for dim, size in enumerate(v.shape)
            ]

            if fmt == DataFormat.TORCH:
                # Verify that torch is installed
                torch = ensure_torch()
                # torch expects a flat sequence of (before, after) pairs
                # starting from the LAST dimension.
                torch_pad = [p for dims in reversed(pad_width) for p in dims]
                padded.append(
                    torch.nn.functional.pad(
                        input=v,
                        pad=torch_pad,
                        mode=backend_mode,
                        value=self.pad_value,
                    ),
                )
            elif fmt == DataFormat.TENSORFLOW:
                # Verify that tf is installed
                tf = ensure_tensorflow()
                padded.append(
                    tf.pad(
                        tensor=v,
                        paddings=tf.constant(pad_width),
                        mode=backend_mode,
                        constant_values=self.pad_value,
                    ),
                )
            else:
                # Default to numpy padding
                padded.append(
                    np.pad(
                        array=v,
                        pad_width=pad_width,
                        mode=backend_mode,
                        constant_values=self.pad_value,
                    ),
                )

        return padded

    def _validate_dims(
        self,
        values: list[Any],
        concat_axis: int,
    ):
        """
        Verify that all inputs can be concatenated along the specified axis.

        Args:
            values (list[Any]):
                List of tensors to concatenate.
            concat_axis (int):
                The non-negative axis to concatenate along, relative to the
                actual tensors in `values`.

        Raises:
            ValueError: If the inputs differ in rank, or differ in size on
                any axis other than `concat_axis`.

        """
        reference_shape = values[0].shape
        for i, v in enumerate(values[1:], start=1):
            if len(v.shape) != len(reference_shape):
                msg = (
                    f"Mismatch in number of dimensions between input 0 and {i}: "
                    f"{len(reference_shape)} vs {len(v.shape)}. "
                    f"{reference_shape} vs {v.shape} on axis={concat_axis}"
                )
                raise ValueError(msg)

            for dim, (ref_dim, val_dim) in enumerate(
                zip(reference_shape, v.shape, strict=True),
            ):
                if dim == concat_axis:
                    continue
                if ref_dim != val_dim:
                    msg = (
                        f"Mismatch in non-concat dimension {dim} between input 0 and {i}: "
                        f"{ref_dim} vs {val_dim}. Set `pad_inputs=True` to auto-align. "
                        f"{reference_shape} vs {v.shape} on axis={concat_axis}"
                    )
                    raise ValueError(msg)

    # ================================================
    # Concatenation (apply_merge)
    # ================================================
    def _domain_concat_axis(self, domain: str) -> int:
        """Return the configured (unresolved) concatenation axis for `domain`."""
        if domain == DOMAIN_FEATURES:
            axis = self.concat_axis
        elif domain == DOMAIN_TARGETS:
            axis = self.target_strategy
        elif domain == DOMAIN_TAGS:
            axis = self.tags_strategy
        elif domain == DOMAIN_SAMPLE_UUIDS:
            axis = -1
        else:
            msg = f"Unknown domain: {domain}"
            raise ValueError(msg)

        # Non-concat strategies must be handled by `_merge_sample_data`
        if not isinstance(axis, int):
            msg = (
                f"Domain '{domain}' uses the non-concat strategy {axis!r}; "
                "apply_merge only supports axis-based concatenation."
            )
            raise TypeError(msg)
        return axis

    @staticmethod
    def _resolve_concat_axis(
        values: list[Any],
        axis: int,
        *,
        includes_batch_dim: bool,
    ) -> int:
        """Convert a user-facing axis into a non-negative axis of the tensors."""
        ndim = max(len(x.shape) for x in values)
        if axis < 0:
            # Negative axes count from the end of the tensors, so the batch
            # dimension does not shift them. Resolve against the largest rank
            # to get a non-negative index comparable with enumerated dims.
            resolved = ndim + axis
        else:
            # Non-negative axes exclude the batch dimension
            resolved = axis + (1 if includes_batch_dim else 0)

        if not 0 <= resolved < ndim:
            msg = (
                f"Concat axis {axis} is out of range for inputs with "
                f"{ndim} dimensions (includes_batch_dim={includes_batch_dim})."
            )
            raise ValueError(msg)
        return resolved

    def apply_merge(
        self,
        values: list[Any],
        *,
        includes_batch_dim: bool = True,
        fmt: DataFormat | None = None,
        domain: str = DOMAIN_FEATURES,
    ) -> Any:
        """
        Concatenate input tensors along the configured axis.

        Description:
            Optionally pads the inputs to align non-concat dimensions before
            applying backend-specific concatenation.

            This method handles axis-based concatenation only. Non-concat
            strategies (e.g., `"first"`, `"mean"`) are handled by
            :meth:`_merge_sample_data` before reaching this method.

        Args:
            values (list[Any]):
                A list of backend-specific tensors to be merged.
            includes_batch_dim (bool):
                Whether the input values have a batch dimension. Defaults to
                True.
            fmt (DataFormat | None):
                The data format expected for the returned tensor. If None, it
                is derived from the `backend` property (numpy when no backend
                is set). Defaults to None.
            domain (str, optional):
                The domain in which the data belongs. This selects the
                domain-specific concat axis.

        Returns:
            Any: Concatenated tensor, in the specified format.

        Raises:
            ValueError: If `fmt` is not tensor-like, `values` is empty, the
                domain is unknown, the axis is out of range, or the shapes
                cannot be concatenated.
            TypeError: If the domain is configured with a non-concat strategy.

        """
        # Get data format (`fmt`)
        if fmt is None:
            if self._backend is None:
                fmt = DataFormat.NUMPY
            else:
                fmt = get_data_format_for_backend(self._backend)
        else:
            fmt = normalize_format(fmt=fmt)
        if not format_is_tensorlike(fmt):
            msg = (
                f"Format {fmt} does not support tensors. "
                f"Must be one of: {_TENSORLIKE_FORMATS}."
            )
            raise ValueError(msg)

        # Ensure all elements in list are converted to fmt
        values = [convert_to_format(x, fmt=fmt) for x in values]
        if not values:
            msg = "ConcatNode received no inputs to merge."
            raise ValueError(msg)

        effective_axis = self._resolve_concat_axis(
            values,
            self._domain_concat_axis(domain),
            includes_batch_dim=includes_batch_dim,
        )

        # Apply padding if defined
        if self.pad_inputs:
            values = self._pad_inputs(values, fmt=fmt, concat_axis=effective_axis)

        # Ensure tensors can be concatenated along `effective_axis`
        self._validate_dims(values=values, concat_axis=effective_axis)

        # Apply backend-specific concatenation
        if fmt == DataFormat.TORCH:
            # Verify that torch is installed
            torch = ensure_torch()
            return torch.cat(tensors=values, dim=effective_axis)
        if fmt == DataFormat.TENSORFLOW:
            # Verify that tf is installed
            tf = ensure_tensorflow()
            return tf.concat(values=values, axis=effective_axis)

        # Default: numpy concatenation
        return np.concatenate(values, axis=effective_axis)

    # ================================================
    # Non-concat merge strategies
    # ================================================
    def _apply_strategy(
        self,
        values: list[Any],
        strategy: MergeStrategy,
        fmt: DataFormat,
    ) -> Any:
        """
        Apply a non-concat merge strategy to a list of tensors.

        Args:
            values (list[Any]): Non-None tensors to aggregate.
            strategy (MergeStrategy): The aggregation strategy to apply.
            fmt (DataFormat): Data format for the output tensor.

        Returns:
            Any: Aggregated tensor.

        Raises:
            ValueError: If `strategy` is not FIRST, LAST, or MEAN.

        """
        values = [convert_to_format(x, fmt=fmt) for x in values]

        if strategy == MergeStrategy.FIRST:
            return values[0]
        if strategy == MergeStrategy.LAST:
            return values[-1]
        if strategy == MergeStrategy.MEAN:
            return self._compute_mean(values, fmt=fmt)

        msg = f"Unsupported merge strategy: {strategy}"
        raise ValueError(msg)

    def _compute_mean(
        self,
        values: list[Any],
        fmt: DataFormat,
    ) -> Any:
        """
        Compute the element-wise mean across inputs.

        Args:
            values (list[Any]): Tensors to average. All must have the same shape.
            fmt (DataFormat): Data format of the tensors.

        Returns:
            Any: Mean tensor (inputs are stacked on a new leading axis, then
            averaged over it).

        """
        if fmt == DataFormat.TORCH:
            torch = ensure_torch()
            return torch.stack(values).mean(dim=0)
        if fmt == DataFormat.TENSORFLOW:
            tf = ensure_tensorflow()
            return tf.reduce_mean(tf.stack(values), axis=0)
        # Default: numpy
        return np.mean(np.stack(values), axis=0)

    def _select_by_reference(
        self,
        data: list[SampleData],
        attr: str,
        ref: ExperimentNodeReference,
    ) -> Any | None:
        """
        Select a domain's data from the input matching the given reference.

        Args:
            data (list[SampleData]):
                List of SampleData inputs, in the same order as
                `self.get_upstream_refs()`.
            attr (str):
                The domain attribute name (e.g., "targets").
            ref (ExperimentNodeReference):
                The upstream reference to select from.

        Returns:
            Any | None: The selected data, or None if that input has no data
            for the given domain.

        Raises:
            ValueError: If the reference is not among upstream refs.

        """
        # NOTE(review): `get_upstream_refs` is defined outside this file and
        # may raise its own errors (e.g., outside a forward pass) — confirm.
        upstream_refs = self.get_upstream_refs()
        for idx, upstream_ref in enumerate(upstream_refs):
            if ref == upstream_ref:
                return getattr(data[idx], attr)

        available = [r.node_id for r in upstream_refs]
        msg = f"Reference '{ref.node_id}' is not among upstream refs: {available}"
        raise ValueError(msg)

    # ================================================
    # Domain-aware sample data merging
    # ================================================
    def _merge_sample_data(
        self,
        data: list[SampleData],
        fmt: DataFormat,
    ) -> SampleData:
        """
        Merge SampleData with flexible per-domain strategies.

        Description:
            Features are always concatenated along `self.concat_axis`.
            Targets and tags support non-concat strategies (e.g., "first",
            "mean", or select-by-reference). Sample UUIDs are always
            concatenated along the last axis. Features and targets are
            returned in `fmt`; tags and sample UUIDs in numpy format.

        Args:
            data (list[SampleData]): Input SampleData objects to merge.
            fmt (DataFormat): Data format for features and targets.

        Returns:
            SampleData: Merged output.

        Raises:
            TypeError: If a domain is configured with an unsupported strategy
                type.

        """
        from modularml.core.data.sample_data import SampleData
        from modularml.core.references.experiment_reference import (
            ExperimentNodeReference,
        )

        merged_attrs: dict[str, Any] = {}

        domain_config: list[
            tuple[str, DataFormat, int | MergeStrategy | ExperimentNodeReference]
        ] = [
            (DOMAIN_FEATURES, fmt, self.concat_axis),
            (DOMAIN_TARGETS, fmt, self.target_strategy),
            (DOMAIN_TAGS, DataFormat.NUMPY, self.tags_strategy),
            (DOMAIN_SAMPLE_UUIDS, DataFormat.NUMPY, -1),
        ]

        for attr, attr_fmt, strategy in domain_config:
            values = [getattr(d, attr) for d in data]
            has_data = [v is not None for v in values]

            # If none have data, skip
            if not any(has_data):
                continue

            if isinstance(strategy, int):
                # Concatenation mode requires every input to provide data
                if not all(has_data):
                    msg = (
                        f"Not all inputs have data for the `{attr}` attribute. "
                        f"The merged results will not contain `{attr}` data."
                    )
                    warn(msg, stacklevel=2)
                    continue
                merged_attrs[attr] = self.apply_merge(
                    values=values,
                    includes_batch_dim=True,
                    fmt=attr_fmt,
                    domain=attr,
                )

            elif isinstance(strategy, MergeStrategy):
                # Non-concat aggregation — silently filter out Nones
                non_none = [v for v in values if v is not None]
                if not non_none:
                    continue
                merged_attrs[attr] = self._apply_strategy(
                    values=non_none,
                    strategy=strategy,
                    fmt=attr_fmt,
                )

            elif isinstance(strategy, ExperimentNodeReference):
                # Select from a specific upstream input
                selected = self._select_by_reference(
                    data=data,
                    attr=attr,
                    ref=strategy,
                )
                if selected is not None:
                    # Match the output format produced by the other strategies
                    merged_attrs[attr] = convert_to_format(selected, fmt=attr_fmt)

            else:
                msg = f"Unsupported merge strategy for `{attr}`: {strategy!r}"
                raise TypeError(msg)

        return SampleData(data=merged_attrs, kind="output")

    # ================================================
    # Configurable
    # ================================================
    def get_config(self) -> dict[str, Any]:
        """
        Return configuration details required to reconstruct this node.

        Returns:
            dict[str, Any]: Configuration used to reconstruct the node.
            Keys must be strings.

        """
        from modularml.core.references.experiment_reference import (
            ExperimentNodeReference,
        )

        def _get_strategy_cfg(arg) -> dict[str, Any]:
            # Tagged encoding: "MergeStrategy" -> enum value,
            # "ExperimentNodeReference" -> its config, "none" -> raw int axis.
            strat_cfg = {
                "type": None,
                "value": None,
            }
            if isinstance(arg, MergeStrategy):
                strat_cfg["type"] = "MergeStrategy"
                strat_cfg["value"] = arg.value
            elif isinstance(arg, ExperimentNodeReference):
                strat_cfg["type"] = "ExperimentNodeReference"
                strat_cfg["value"] = arg.get_config()
            else:
                strat_cfg["type"] = "none"
                strat_cfg["value"] = arg
            return strat_cfg

        cfg = super().get_config()
        cfg.update(
            {
                "merge_node_type": self.__class__.__qualname__,
                "concat_axis": self.concat_axis,
                "target_strategy": _get_strategy_cfg(self.target_strategy),
                "tags_strategy": _get_strategy_cfg(self.tags_strategy),
                "pad_inputs": self.pad_inputs,
                "pad_mode": self.pad_mode.value,
                "pad_value": self.pad_value,
            },
        )
        return cfg

    @classmethod
    def from_config(
        cls,
        config: dict,
        *,
        register: bool = True,
    ) -> MergeNode:
        """
        Construct a ConcatNode from a configuration dictionary.

        Args:
            config (dict[str, Any]): Configuration details, as produced by
                :meth:`get_config`. Keys must be strings.
            register (bool): Whether to register the reconstructed node.

        Returns:
            ConcatNode: Reconstructed ConcatNode.

        Raises:
            ValueError: If the config was not produced by this class, or a
                strategy entry has an unknown type tag.

        """
        from modularml.core.references.experiment_reference import (
            ExperimentNodeReference,
        )

        def _decode_strategy_cfg(
            strat_cfg,
        ) -> int | str | ExperimentNodeReference | MergeStrategy:
            # Inverse of the tagged encoding written by `get_config`
            strat_type = strat_cfg.get("type")
            strat_value = strat_cfg.get("value")
            if strat_type == "none":
                return strat_value
            if strat_type == "MergeStrategy":
                return MergeStrategy(strat_value)
            if strat_type == "ExperimentNodeReference":
                return ExperimentNodeReference.from_config(strat_value)
            msg = "Invalid merge strategy config."
            raise ValueError(msg)

        node_cls_name = config.get("merge_node_type")
        if node_cls_name != cls.__qualname__:
            msg = f"Invalid config for {cls.__qualname__}."
            raise ValueError(msg)

        return cls(
            label=config["label"],
            upstream_refs=config["upstream_refs"],
            concat_axis=config["concat_axis"],
            concat_axis_targets=_decode_strategy_cfg(config["target_strategy"]),
            concat_axis_tags=_decode_strategy_cfg(config["tags_strategy"]),
            pad_inputs=config["pad_inputs"],
            pad_mode=config["pad_mode"],
            pad_value=config["pad_value"],
            node_id=config.get("node_id"),
            register=register,
        )