Model Graph Design

Most ML frameworks treat a model as a single callable: data goes in, predictions come out. ModularML takes a different approach by representing model architectures as directed acyclic graphs (DAGs) of discrete nodes. This design makes it possible to compose multi-stage pipelines, mix backends, selectively freeze stages, and reason about data flow before any training begins.

This document explains the principles behind graph construction, how nodes connect and communicate, and why certain design choices were made. For a high-level overview of where graphs fit within the broader framework, see Architecture Overview. For how graphs interact with the training lifecycle, see Training Phases.

Why graphs instead of sequential models?

A sequential model is a special case of a graph—a chain of nodes with no branching. Graphs generalize this to support architectures that sequential APIs cannot express cleanly:

  • Multi-input models that consume data from separate FeatureSets (e.g., tabular features and time-series signals fused at a later stage).

  • Multi-stage pipelines where an encoder feeds into both a classifier and a reconstruction decoder, each with its own loss.

  • Ensemble and fusion patterns where multiple upstream models are merged before a final prediction head.

  • Selective freezing where some stages are held fixed while others are trained, as in transfer learning or feature extraction.

A sequential API can approximate some of these patterns through workarounds, but the graph representation makes them first-class. Each node is an independent unit with its own identity, backend, optimizer, and freeze state. The graph manages the wiring and execution order, freeing the user from manually orchestrating data flow between stages.

The node hierarchy

All graph nodes inherit from a common base and specialize through a focused hierarchy:

```mermaid
classDiagram
    ExperimentNode <|-- GraphNode
    GraphNode <|-- ComputeNode
    ComputeNode <|-- ModelNode
    ComputeNode <|-- MergeNode
    MergeNode <|-- ConcatNode

    class ExperimentNode {
        +node_id: str
        +label: str
        +reference()
    }
    class GraphNode {
        +upstream_refs
        +downstream_refs
        +max_upstream_refs
        +max_downstream_refs
    }
    class ComputeNode {
        +forward(inputs)
        +build(input_shapes)
        +get_input_data(inputs, outputs)
    }
    class ModelNode {
        +model: BaseModel
        +optimizer: Optimizer
        +is_frozen: bool
        +train_step(ctx, losses)
        +eval_step(ctx, losses)
    }
    class MergeNode {
        +apply_merge(values)
        +merge(x)
    }
    class ConcatNode {
        +concat_axis: int
        +pad_inputs: bool
    }
```

Each layer adds a specific responsibility:

ExperimentNode provides identity. Every node gets a UUID and a human-readable label, and registers itself with the active ExperimentContext so it can be found by reference later. This is the common base shared by all experiment-aware objects, including FeatureSets.

GraphNode adds topology. It maintains lists of upstream and downstream references and enforces connection limits. A ModelNode allows exactly one upstream connection (it consumes a single input), while a MergeNode allows unlimited upstream connections (it combines multiple inputs). These limits are checked as connections are made at construction time, not deferred until data flows through the graph, so a miswired node fails immediately.

ComputeNode adds computation. It defines the forward() interface and the input resolution logic that determines where a node gets its data from. It also defines build(), which initializes the node’s internal state based on inferred shapes.

ModelNode and MergeNode are the two concrete compute node types. A ModelNode wraps a single BaseModel instance and optionally an Optimizer. A MergeNode combines outputs from multiple upstream nodes using a merge strategy (currently concatenation via ConcatNode, with the architecture open to other strategies like averaging or attention-based fusion).

This hierarchy is deliberately shallow. Rather than deep inheritance chains, the design favors composition: a ModelNode contains a BaseModel rather than being one. This separation means the graph topology is independent of the ML backend—the same graph structure works whether the underlying models are PyTorch, TensorFlow, or scikit-learn.
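The layering and the connection limits can be sketched in a few lines of Python. This is a simplified, hypothetical model and not ModularML's source. Upstream references are plain strings, ComputeNode is omitted, and nothing is registered with an ExperimentContext. The sketch illustrates two points from the text. First, limits are declared per class and checked as each connection is added. Second, a ModelNode holds its model as an attribute rather than inheriting from it.

```python
import uuid
from typing import Optional


class ExperimentNode:
    """Identity: every node gets a UUID and a human-readable label."""

    def __init__(self, label: str):
        self.node_id = str(uuid.uuid4())
        self.label = label


class GraphNode(ExperimentNode):
    """Topology: upstream/downstream lists plus a per-class connection limit."""

    max_upstream_refs: Optional[int] = None  # None means unlimited

    def __init__(self, label: str, upstream_refs=()):
        super().__init__(label)
        self.upstream_refs = []
        self.downstream_refs = []
        for ref in upstream_refs:
            self.add_upstream_ref(ref)

    def add_upstream_ref(self, ref) -> None:
        # The limit is checked as each connection is made, not at execution time.
        limit = self.max_upstream_refs
        if limit is not None and len(self.upstream_refs) >= limit:
            raise ValueError(
                f"{type(self).__name__} '{self.label}' allows at most "
                f"{limit} upstream reference(s)"
            )
        self.upstream_refs.append(ref)


class ModelNode(GraphNode):
    max_upstream_refs = 1  # consumes a single input

    def __init__(self, label: str, model, upstream_ref):
        super().__init__(label, upstream_refs=[upstream_ref])
        self.model = model  # composition: the node has a model, it is not one


class MergeNode(GraphNode):
    max_upstream_refs = None  # combines any number of inputs
```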

References: wiring without coupling

Nodes are connected through references rather than direct object pointers. A reference is a frozen (immutable) dataclass that identifies a target node by label or ID and is resolved against an ExperimentContext at execution time.

This indirection exists for several practical reasons:

Serialization. A graph’s structure can be saved as configuration and reconstructed later. Direct object references cannot survive serialization; symbolic references can.

Cross-validation. When running k-fold cross-validation, the same graph structure is applied to different data folds. References allow the graph to be defined once and bound to different FeatureSetViews at execution time, rather than rebuilding the graph for each fold.

Construction order independence. A ModelNode can reference a FeatureSet that hasn’t been created yet, as long as it exists in the ExperimentContext by the time the reference is resolved. This flexibility is important when building complex experiments programmatically.

There are two primary reference types:

FeatureSetReference points to specific columns of a FeatureSet. It carries optional features, targets, and tags tuples that act as column selectors. When resolved, it produces a FeatureSetView filtered to the selected columns. If no selectors are specified, the full FeatureSet is used. This is how a ModelNode declares which data it consumes without holding a direct pointer to the data container.

GraphNodeReference (and its parent ExperimentNodeReference) points to another node in the graph. This is used for downstream wiring—when one ModelNode’s output feeds into another ModelNode or a MergeNode. Resolution returns the target node itself, and the graph’s forward pass uses this to look up cached outputs.

References prefer resolution by node_id (the UUID) when available, falling back to node_label. This means that label changes do not break existing wiring as long as IDs are preserved, which matters for serialization round-trips.
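Here is a minimal sketch of ID-first resolution. It assumes a plain dictionary as a stand-in for the ExperimentContext, and NodeRef, Context, and Node are hypothetical names used only for illustration. The reference is created before the node it points to, which mirrors construction-order independence. It also keeps resolving after the node's label changes, because the ID still matches.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class Node:
    """Hypothetical registered node; the label may change, the ID should not."""
    node_id: str
    label: str


class Context:
    """Hypothetical stand-in for ExperimentContext: a registry keyed by node ID."""

    def __init__(self):
        self.by_id: Dict[str, Node] = {}

    def register(self, node: Node) -> None:
        self.by_id[node.node_id] = node


@dataclass(frozen=True)
class NodeRef:
    """Hypothetical symbolic reference: stores identifiers, never the node object."""
    node_id: Optional[str] = None
    node_label: Optional[str] = None

    def resolve(self, ctx: Context) -> Node:
        # Prefer the UUID so that renaming a node does not break wiring.
        if self.node_id is not None and self.node_id in ctx.by_id:
            return ctx.by_id[self.node_id]
        # Fall back to the human-readable label.
        if self.node_label is not None:
            for node in ctx.by_id.values():
                if node.label == self.node_label:
                    return node
        raise KeyError(f"unresolvable reference: {self}")


ctx = Context()
ref = NodeRef(node_id="id-1", node_label="encoder")  # created before the node exists
encoder = Node("id-1", "encoder")
ctx.register(encoder)
encoder.label = "encoder_v2"         # rename after wiring
assert ref.resolve(ctx) is encoder  # still resolves by ID
```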

How the graph executes

Building: shape inference before training

Before any data flows through the graph, calling ModelGraph.build() performs shape inference. This step propagates tensor dimensions from the FeatureSet inputs through every node, ensuring that each node’s model is initialized with the correct input and output shapes.

The algorithm traverses nodes in topological order:

  1. For each head node (those whose upstream reference is a FeatureSet), the input shape is determined from the FeatureSet’s column shapes.

  2. For each subsequent node, the input shape comes from the cached output shape of its upstream node.

  3. Each node’s build() method initializes its internal model with the resolved shapes.

  4. For MergeNodes, shape inference uses a dummy forward pass—synthetic tensors of the correct shape are passed through the merge operation to determine the output dimensions. This is necessary because merge operations (concatenation, padding) can produce output shapes that are not trivially derivable from input shapes alone.

  5. Optimizers are constructed after all models are built, since they need access to model parameters.

Shape inference catches dimension mismatches early—before any training data is loaded—which is particularly valuable for multi-stage graphs where a mismatch between stages might otherwise surface as a cryptic runtime error deep in a training loop.
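The build steps can be sketched for a chain of dense layers, using Python's graphlib module (Python 3.9+) to get the topological order. DenseSpec and build_graph are hypothetical. Each DenseSpec stands in for a ModelNode whose build() infers its input width. Optimizer construction (step 5) and MergeNode handling are left out. If a pinned input width disagrees with the upstream output, build raises an error before any data is loaded.

```python
from graphlib import TopologicalSorter
from typing import Dict, Optional, Tuple

Shape = Tuple[int, ...]


class DenseSpec:
    """Hypothetical model stand-in: input width is inferred at build time unless pinned."""

    def __init__(self, out_features: int, in_features: Optional[int] = None):
        self.out_features = out_features
        self.in_features = in_features

    def build(self, input_shape: Shape) -> Shape:
        width = input_shape[-1]
        if self.in_features is not None and self.in_features != width:
            raise ValueError(f"expected input width {self.in_features}, got {width}")
        self.in_features = width
        return input_shape[:-1] + (self.out_features,)


def build_graph(upstream: Dict[str, str], models: Dict[str, DenseSpec],
                featureset_shapes: Dict[str, Shape]) -> Dict[str, Shape]:
    # graphlib wants node -> predecessors; FeatureSet sources are not graph nodes.
    preds = {node: [src] if src in models else [] for node, src in upstream.items()}
    out_shapes: Dict[str, Shape] = {}
    for node in TopologicalSorter(preds).static_order():
        src = upstream[node]
        # Head nodes read FeatureSet column shapes; others read cached upstream output shapes.
        in_shape = out_shapes[src] if src in models else featureset_shapes[src]
        out_shapes[node] = models[node].build(in_shape)
    return out_shapes


models = {"encoder": DenseSpec(8), "head": DenseSpec(2)}
shapes = build_graph({"encoder": "sensors", "head": "encoder"}, models, {"sensors": (50, 4)})
print(shapes)  # {'encoder': (50, 8), 'head': (50, 2)}
```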

Forward pass: topological traversal

ModelGraph.forward() executes the graph by iterating through nodes in topological order. For each node:

  1. Input resolution. The node’s get_input_data() method determines where its input comes from (within a given Batch). If the upstream reference is a FeatureSetReference, the input is looked up in the inputs dictionary (keyed by (node_id, reference) pairs). If the upstream reference is a GraphNodeReference, the input is retrieved from the outputs dictionary, which caches each node’s output as it is computed.

  2. Computation. The node’s forward() method processes the input and produces an output Batch. For a ModelNode, this means passing the feature tensors through the wrapped model while preserving targets, tags, and UUIDs. For a MergeNode, this means applying the merge strategy across all upstream outputs.

  3. Output caching. The result is stored in the outputs dictionary, keyed by node ID, making it available to downstream nodes.

The forward pass can optionally operate on a subset of nodes via the active_nodes parameter. When specified, the graph automatically includes all upstream dependencies of the active nodes, ensuring that every active node has its inputs available.
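The traversal can be sketched by reducing each node to a plain callable and each edge to a name in an upstream list. This is a simplification. Inputs are keyed by source name rather than by (node_id, reference) pairs, and plain lists stand in for Batch objects. A node with two upstream entries behaves like a MergeNode: it receives one argument per upstream output.

```python
from graphlib import TopologicalSorter
from typing import Any, Callable, Dict, List


def forward(upstream: Dict[str, List[str]], fns: Dict[str, Callable[..., Any]],
            inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Run every node once, in topological order, caching outputs by node name."""
    # Names that are not nodes (not in fns) are treated as FeatureSet sources.
    preds = {node: [s for s in srcs if s in fns] for node, srcs in upstream.items()}
    outputs: Dict[str, Any] = {}
    for node in TopologicalSorter(preds).static_order():
        # 1. Input resolution: cached upstream output or FeatureSet data.
        xs = [outputs[s] if s in fns else inputs[s] for s in upstream[node]]
        # 2. Computation, then 3. caching for downstream nodes.
        outputs[node] = fns[node](*xs)
    return outputs


upstream = {"tab_enc": ["tabular"], "ts_enc": ["series"],
            "fuse": ["tab_enc", "ts_enc"], "head": ["fuse"]}
fns = {"tab_enc": lambda x: [v * 2 for v in x],
       "ts_enc": lambda x: [v + 1 for v in x],
       "fuse": lambda a, b: a + b,  # list concatenation plays the MergeNode role
       "head": sum}
out = forward(upstream, fns, {"tabular": [1, 2], "series": [10]})
print(out["head"])  # 17
```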

Partial execution and active nodes

Not every forward pass needs to traverse the entire graph. The active_nodes parameter allows execution of a subgraph while still respecting dependencies. This is useful in scenarios like:

  • Evaluating only a specific output branch of a multi-output graph.

  • Running inference through an encoder without executing the decoder.

  • Debugging a specific stage by isolating its execution.

The graph resolves the full dependency chain automatically—specifying a downstream node as active will include all upstream nodes it depends on.
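The dependency closure behind active_nodes amounts to a walk up the upstream edges. The following is a minimal sketch. It uses a hypothetical helper, with_dependencies, that operates on a mapping from each node name to its list of upstream names. Asking for the classifier pulls in the encoder but leaves the decoder out.

```python
from typing import Dict, Iterable, List, Set


def with_dependencies(upstream: Dict[str, List[str]], active: Iterable[str]) -> Set[str]:
    """Return the active nodes plus every node they transitively depend on."""
    required: Set[str] = set()
    stack = list(active)
    while stack:
        node = stack.pop()
        if node in required:
            continue
        required.add(node)
        # Follow graph nodes only; FeatureSet sources have no entry in `upstream`.
        stack.extend(s for s in upstream.get(node, []) if s in upstream)
    return required


upstream = {"encoder": ["signals"], "classifier": ["encoder"], "decoder": ["encoder"]}
print(sorted(with_dependencies(upstream, ["classifier"])))  # ['classifier', 'encoder']
```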

Training strategies

The graph supports two distinct training strategies, chosen based on whether a global optimizer is provided.

Stage-wise training (no global optimizer)

When each ModelNode has its own optimizer (or no optimizer), training proceeds node by node in topological order. Each trainable, unfrozen node performs its own forward pass, loss computation, backward pass, and optimizer step independently. Frozen nodes and non-trainable nodes perform forward-only passes to propagate their outputs downstream.

This is the appropriate strategy when:

  • Different stages require different optimizers or learning rates.

  • Some stages are pretrained and frozen.

  • Gradient flow between stages is not needed (e.g., the upstream stage’s output is treated as a fixed input).
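A stage-wise epoch over a simple chain might look like the sketch below. StageNode and stagewise_epoch are hypothetical, and the optimizer is a placeholder string. A node without an optimizer is treated as non-trainable, which is an assumption of the sketch. Where a real node would compute a loss, backpropagate, and step its optimizer, train_step here only records that it ran. The dispatch rule is the one described above: unfrozen nodes with an optimizer train, and every other node runs forward only.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Optional


@dataclass
class StageNode:
    """Hypothetical node; optimizer=None marks it as non-trainable."""
    name: str
    fn: Callable[[float], float]
    optimizer: Optional[str] = None  # placeholder for a real optimizer object
    frozen: bool = False
    log: List[str] = field(default_factory=list)

    def train_step(self, x: float) -> float:
        self.log.append("train")  # real node: forward, loss, backward, optimizer step
        return self.fn(x)

    def eval_step(self, x: float) -> float:
        self.log.append("eval")  # real node: forward only, weights untouched
        return self.fn(x)


def stagewise_epoch(chain: List[StageNode], x: float) -> float:
    """One pass over a linear chain of nodes in topological order."""
    for node in chain:
        if node.optimizer is not None and not node.frozen:
            x = node.train_step(x)
        else:
            x = node.eval_step(x)
        # The next stage receives x as a fixed input; no gradient flows back to this node.
    return x


encoder = StageNode("encoder", lambda v: v * 2, optimizer="adam", frozen=True)
head = StageNode("head", lambda v: v + 1, optimizer="sgd")
print(stagewise_epoch([encoder, head], 3.0), encoder.log, head.log)  # 7.0 ['eval'] ['train']
```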

Graph-wise training (global optimizer)

When a global optimizer is provided to the ModelGraph, training performs a single forward pass through the entire graph, computes all losses, and executes a single backward pass that propagates gradients across all unfrozen nodes. This enables end-to-end gradient flow, which is necessary when upstream and downstream stages need to co-adapt.

This strategy requires all trainable nodes to share the same backend (you cannot backpropagate through a PyTorch node into a TensorFlow node). The graph validates this constraint during build() and raises an error if backend-incompatible nodes would need to participate in the same gradient computation.

The choice between strategies reflects a real trade-off: stage-wise training is more flexible (mixed backends, independent optimization) but prevents gradient flow between stages. Graph-wise training enables end-to-end learning but constrains the graph to a single backend. ModularML makes this trade-off explicit rather than hiding it.
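The strategy choice and the backend check performed in build() can be sketched as a single function. choose_strategy and the (name, backend, trainable) tuples are hypothetical. The rule follows the text. Without a global optimizer, the graph trains stage-wise and backends may mix. With a global optimizer, every trainable node must share one backend.

```python
from typing import Iterable, Optional, Tuple

NodeInfo = Tuple[str, str, bool]  # (name, backend, trainable)


def choose_strategy(nodes: Iterable[NodeInfo], global_optimizer: Optional[object] = None) -> str:
    if global_optimizer is None:
        # Each node optimizes itself; gradients stop at stage and backend boundaries.
        return "stage-wise"
    backends = {backend for _, backend, trainable in nodes if trainable}
    if len(backends) > 1:
        # One backward pass cannot cross from one framework's autograd into another's.
        raise ValueError(f"global optimizer requires one backend, found {sorted(backends)}")
    return "graph-wise"


mixed = [("scaler", "sklearn", True), ("encoder", "torch", True), ("head", "torch", True)]
print(choose_strategy(mixed))  # stage-wise
```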

Backend neutrality in practice

A ModelNode wraps a BaseModel, which is an abstract class with backend-specific subclasses: TorchBaseModel, TensorflowBaseModel, and ScikitWrapper. The graph topology itself is backend-agnostic—it operates on Batch objects and SampleData containers that can convert their contents between backends as needed.

When a node receives input from an upstream node using a different backend, the framework handles the conversion at the boundary. This means a graph can mix backends, though with the understanding that gradient flow stops at backend boundaries. In practice, most graphs use a single backend throughout, but the design does not enforce this—it simply makes the consequences explicit.

Each backend’s model wrapper is responsible for its own forward/backward logic:

  • PyTorch nodes use torch.no_grad() for evaluation and standard autograd for training.

  • TensorFlow nodes use tf.GradientTape for training and direct calls for evaluation.

  • Scikit-learn nodes use fit() / partial_fit() for training and predict() / transform() for evaluation.

This delegation means the graph does not need to understand backend internals—it simply calls the protocol methods (train_step, eval_step, fit_step) and lets each node handle its own execution semantics.

Merge nodes and multi-input architectures

When a graph has multiple branches that need to converge, MergeNode provides the junction point. A MergeNode accepts unlimited upstream connections and combines their outputs into a single Batch for downstream consumption.

The current concrete implementation, ConcatNode, concatenates upstream outputs along a configurable axis. It handles several practical concerns:

Dimension alignment. When upstream outputs have different shapes on non-concatenation axes, ConcatNode can optionally pad inputs to align them. The padding mode and value are configurable.

Domain-aware merging. Features, targets, and tags are merged independently. Features are always concatenated along the specified axis, while targets and tags can each use one of three strategies:

  • Axis-based concatenation (same as features).

  • Aggregation (first, last, or mean across upstream outputs).

  • Selection (use the targets/tags from a specific upstream node, identified by reference).

This flexibility reflects the reality that in multi-input architectures, feature fusion and target handling often have different requirements. Two input branches might contribute features that should be concatenated, while the target labels should come from only one of them.
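The domain-aware merge can be sketched with NumPy. concat_features, merge_targets, and pad_to are hypothetical helpers and not ConcatNode's API. Arrays are assumed to have the shape (batch, ..., features). Padding grows every axis other than the batch axis and the concatenation axis to the largest size among the inputs. The aggregation strategies read upstream outputs in insertion order.

```python
from typing import Dict, List, Optional

import numpy as np


def pad_to(x: np.ndarray, target_shape, axis: int, value: float = 0.0) -> np.ndarray:
    """Pad every axis except the batch axis (0) and the concat axis up to target_shape."""
    pad = [(0, 0)] * x.ndim
    for ax in range(1, x.ndim):
        if ax != axis:
            pad[ax] = (0, int(target_shape[ax]) - x.shape[ax])
    return np.pad(x, pad, mode="constant", constant_values=value)


def concat_features(arrays: List[np.ndarray], axis: int = -1,
                    pad_inputs: bool = False, pad_value: float = 0.0) -> np.ndarray:
    axis = axis % arrays[0].ndim  # normalise negative axes
    if pad_inputs:
        target = np.max([a.shape for a in arrays], axis=0)
        arrays = [pad_to(a, target, axis, pad_value) for a in arrays]
    return np.concatenate(arrays, axis=axis)


def merge_targets(targets: Dict[str, np.ndarray], strategy: str = "first",
                  axis: int = -1, select: Optional[str] = None) -> np.ndarray:
    """Merge per-branch targets (or tags); dict order is the upstream order."""
    values = list(targets.values())
    if strategy == "concat":
        return np.concatenate(values, axis=axis)
    if strategy == "first":
        return values[0]
    if strategy == "last":
        return values[-1]
    if strategy == "mean":
        return np.mean(np.stack(values), axis=0)
    if strategy == "select":
        return targets[select]  # targets from one named upstream branch
    raise ValueError(f"unknown strategy: {strategy}")


a = np.ones((2, 5, 3))   # branch A: 5 time steps, 3 features
b = np.zeros((2, 7, 4))  # branch B: 7 time steps, 4 features
print(concat_features([a, b], pad_inputs=True).shape)  # (2, 7, 7)
```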

Shape inference. MergeNode determines its output shape by running a dummy forward pass with synthetic data matching the upstream shapes. This approach is more robust than algebraic shape computation, especially when padding or non-trivial merge strategies are involved.
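The dummy-forward approach to shape inference needs only a merge function and the upstream shapes. The sketch below assumes NumPy arrays and uses a hypothetical helper, infer_merge_output_shape. The example merge pads the time axis before concatenating, which is the kind of case where computing the output shape by hand is easy to get wrong.

```python
from typing import Callable, List, Sequence, Tuple

import numpy as np


def infer_merge_output_shape(merge_fn: Callable[[List[np.ndarray]], np.ndarray],
                             upstream_shapes: Sequence[Tuple[int, ...]],
                             batch_size: int = 1) -> Tuple[int, ...]:
    """Push zero-filled dummy batches through merge_fn and read the output shape."""
    dummies = [np.zeros((batch_size, *shape)) for shape in upstream_shapes]
    return merge_fn(dummies).shape[1:]  # drop the synthetic batch axis


def pad_time_then_concat(xs: List[np.ndarray]) -> np.ndarray:
    # Pad axis 1 (time) to the longest input, then concatenate on the feature axis.
    t = max(x.shape[1] for x in xs)
    xs = [np.pad(x, [(0, 0), (0, t - x.shape[1]), (0, 0)]) for x in xs]
    return np.concatenate(xs, axis=-1)


print(infer_merge_output_shape(pad_time_then_concat, [(5, 3), (7, 4)]))  # (7, 7)
```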

Freezing and selective training

Any ModelNode can be frozen (node.freeze()) or unfrozen (node.unfreeze()), which controls whether its parameters participate in gradient computation and optimizer steps. Frozen nodes still execute their forward pass—they produce outputs that downstream nodes can consume—but they do not update their weights.

This mechanism supports common workflows:

  • Transfer learning: Freeze a pretrained encoder, train only the downstream head.

  • Feature extraction: Freeze all stages, use the graph purely for inference.

  • Progressive unfreezing: Gradually unfreeze stages during training (implementable via callbacks).

The graph’s training methods respect freeze state automatically. During stage-wise training, frozen nodes fall through to evaluation-only execution. During graph-wise training, frozen nodes are excluded from the optimizer’s parameter set, so the global optimizer step leaves their weights unchanged.
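Freeze state can be sketched as a flag that filters the parameter list handed to a global optimizer. FreezableNode and global_param_set are hypothetical, and parameters are represented as plain strings. For comparison, in PyTorch, leaving parameters out of an optimizer stops their updates. Setting requires_grad to False is what skips computing their gradients.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class FreezableNode:
    """Hypothetical node with named parameters and a freeze flag."""
    name: str
    params: List[str] = field(default_factory=list)
    is_frozen: bool = False

    def freeze(self) -> None:
        self.is_frozen = True

    def unfreeze(self) -> None:
        self.is_frozen = False


def global_param_set(nodes: List[FreezableNode]) -> List[str]:
    """Parameters to hand to a global optimizer: frozen nodes contribute none."""
    return [p for node in nodes if not node.is_frozen for p in node.params]


encoder = FreezableNode("encoder", ["enc.w", "enc.b"])
head = FreezableNode("head", ["head.w"])
encoder.freeze()                          # transfer learning: train only the head
print(global_param_set([encoder, head]))  # ['head.w']
encoder.unfreeze()                        # progressive unfreezing, e.g. from a callback
print(global_param_set([encoder, head]))  # ['enc.w', 'enc.b', 'head.w']
```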

Graph manipulation

The ModelGraph provides methods for modifying graph structure after construction:

  • add_node() adds a new node to the graph.

  • remove_node() removes a node and re-routes its upstream connections to its downstream neighbors.

  • replace_node() swaps a node while preserving all existing connections.

  • insert_node_between() places a new node on an existing edge between two nodes.

  • insert_node_before() and insert_node_after() insert a node relative to an existing one.

These operations re-validate the graph’s connection integrity after each modification, ensuring the DAG invariant is maintained. They enable programmatic graph construction patterns where a base graph is modified dynamically—for example, inserting a normalization stage between an encoder and decoder, or replacing a model node with a different architecture while preserving the surrounding topology.
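Here are two of these operations, sketched on a mapping from each node name to its list of upstream names. remove_node and insert_node_between are hypothetical free functions, not the ModelGraph methods, and they skip the re-validation step. That step matters. For example, removing a MergeNode could hand several upstream sources to a ModelNode, and the one-upstream limit would reject that.

```python
from typing import Dict, List

Graph = Dict[str, List[str]]  # node -> upstream names


def insert_node_between(g: Graph, new: str, src: str, dst: str) -> Graph:
    """Place `new` on the existing edge src -> dst."""
    if src not in g.get(dst, []):
        raise ValueError(f"no edge {src} -> {dst}")
    g[new] = [src]
    g[dst] = [new if s == src else s for s in g[dst]]
    return g


def remove_node(g: Graph, node: str) -> Graph:
    """Delete `node` and wire its upstream sources straight to its downstream consumers."""
    sources = g.pop(node)
    for name, srcs in list(g.items()):
        if node in srcs:
            i = srcs.index(node)
            g[name] = srcs[:i] + sources + srcs[i + 1:]
    return g


g = {"encoder": ["signals"], "decoder": ["encoder"]}
insert_node_between(g, "norm", "encoder", "decoder")
print(g["decoder"], g["norm"])  # ['norm'] ['encoder']
remove_node(g, "norm")
print(g["decoder"])  # ['encoder']
```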

Graph validation

The graph performs several validation passes during construction and building:

DAG validation. A depth-first search detects cycles. The graph must be acyclic; cyclic dependencies would create unresolvable execution orders.

Reachability. All nodes must be reachable from at least one head node. Unreachable nodes indicate wiring errors.

Connection limits. Each node type declares its maximum upstream and downstream connections. ModelNode allows exactly one upstream reference; MergeNode allows unlimited. Violations are caught at construction time.

Backend compatibility. When a global optimizer is specified, the graph verifies that all trainable nodes share a compatible backend. Mixed backends are permitted only under stage-wise training.

Reference resolution. All references must resolve to existing nodes in the ExperimentContext. Unresolvable references indicate missing dependencies.

These checks are designed to surface configuration errors early—at graph construction or build time—rather than during training, where debugging becomes significantly harder.
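The cycle and reachability checks can be sketched over a mapping from each node name to its list of upstream names. find_cycle is a standard three-colour depth-first search. unreachable propagates reachability forward from the FeatureSet sources until nothing changes. Both are hypothetical helpers, not ModularML functions.

```python
from typing import Dict, List, Optional, Set

Graph = Dict[str, List[str]]  # node -> upstream names


def find_cycle(g: Graph) -> Optional[str]:
    """Return a node that lies on a cycle, or None if the graph is acyclic."""
    WHITE, GRAY, BLACK = 0, 1, 2  # unvisited, on the DFS stack, finished
    color = {n: WHITE for n in g}

    def visit(n: str) -> Optional[str]:
        color[n] = GRAY
        for src in g[n]:
            if src not in g:
                continue  # FeatureSet source, not a graph node
            if color[src] == GRAY:
                return src  # back edge: src is still on the stack
            if color[src] == WHITE:
                found = visit(src)
                if found is not None:
                    return found
        color[n] = BLACK
        return None

    for n in g:
        if color[n] == WHITE:
            found = visit(n)
            if found is not None:
                return found
    return None


def unreachable(g: Graph, featuresets: Set[str]) -> Set[str]:
    """Nodes that no FeatureSet feeds, directly or through other nodes."""
    reached: Set[str] = set()
    changed = True
    while changed:
        changed = False
        for n, srcs in g.items():
            if n not in reached and any(s in featuresets or s in reached for s in srcs):
                reached.add(n)
                changed = True
    return set(g) - reached


print(find_cycle({"a": ["signals"], "b": ["a", "c"], "c": ["b"]}))        # b
print(unreachable({"a": ["signals"], "orphan": ["ghost"]}, {"signals"}))  # {'orphan'}
```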

Visualization

ModularML includes a visualization system that generates Mermaid diagrams from graph structures. Each node type has a distinct visual style:

  • FeatureSet nodes appear in purple, showing the data sources feeding the graph.

  • ModelNode nodes appear in blue (or light blue when frozen), representing computational stages.

  • MergeNode nodes appear in gray, marking convergence points.

  • AppliedLoss nodes appear in red, indicating where losses are computed.

  • Sampler nodes appear in orange, showing data loading points.

Edge animations indicate data flow activity (static, slow, medium, or fast dash animations). FeatureSet nodes can optionally display their column names, shapes, split overlaps, and sample counts—useful for verifying that the graph is wired to the correct data.
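A Mermaid flowchart with per-type styling can be generated from a node mapping and an edge list. to_mermaid and the hex colours are assumptions chosen to match the colour names above. Edge animation and FeatureSet column details are omitted. The output uses standard Mermaid classDef declarations and the ::: class shorthand.

```python
from typing import Dict, Iterable, Tuple

# Colour values are illustrative; only the colour names come from the text above.
STYLES = {
    "featureset": "fill:#e9d5ff,stroke:#7e22ce",    # purple
    "model": "fill:#60a5fa,stroke:#1d4ed8",         # blue
    "model_frozen": "fill:#dbeafe,stroke:#1d4ed8",  # light blue
    "merge": "fill:#e5e7eb,stroke:#4b5563",         # gray
    "loss": "fill:#fca5a5,stroke:#b91c1c",          # red
    "sampler": "fill:#fdba74,stroke:#c2410c",       # orange
}


def to_mermaid(nodes: Dict[str, str], edges: Iterable[Tuple[str, str]]) -> str:
    """nodes maps a node name to a STYLES key; edges are (source, target) pairs."""
    lines = ["flowchart LR"]
    lines += [f"    classDef {key} {style}" for key, style in STYLES.items()]
    lines += [f"    {name}[{name}]:::{kind}" for name, kind in nodes.items()]
    lines += [f"    {src} --> {dst}" for src, dst in edges]
    return "\n".join(lines)


print(to_mermaid({"signals": "featureset", "encoder": "model_frozen", "head": "model"},
                 [("signals", "encoder"), ("encoder", "head")]))
```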

These diagrams serve as both documentation and debugging tools. A visual representation of the graph makes it easier to verify that the topology matches the intended architecture, especially for complex multi-branch designs. It also speeds up collaboration: a collaborator can load a shared Experiment file, call .visualize(), and check that the data flow matches expectations, for example that the split overlaps shown on FeatureSet nodes are the ones the experiment intends.

Design principles summarized

Several principles guide the graph design:

Topology is separate from computation. The graph structure (which nodes exist, how they connect) is defined independently of what each node computes. This separation enables structural operations (add, remove, replace) without touching model internals.

References over pointers. Symbolic references decouple construction from resolution, enabling serialization, cross-validation, and construction-order independence.

Explicit over implicit. Backend differences, freeze states, and training strategies are all surfaced explicitly rather than hidden behind abstractions. The user knows whether gradients flow between stages and whether backends are compatible.

Validate early. Shape mismatches, backend conflicts, and wiring errors are caught at build time, not mid-training. This front-loads debugging effort to where it is cheapest.

Composition over inheritance. A ModelNode contains a BaseModel rather than extending it. A MergeNode defines a merge strategy rather than implementing concatenation directly. This keeps each component focused and replaceable.