{ "cells": [ { "cell_type": "markdown", "id": "0", "metadata": {}, "source": [ "# How to: Create and Use a ModelGraph\n", "\n", "A `ModelGraph` is the computational backbone of a ModularML `Experiment`. It organizes\n", "one or more `ModelNode`s (and optionally `MergeNode`s) into a directed acyclic graph (DAG)\n", "that handles:\n", "\n", "- **Shape inference:** Automatically determines input/output shapes for every node during `build()`.\n", "- **Topological execution:** Ensures nodes execute in dependency order during forward, training, and evaluation passes.\n", "- **Global optimizer management:** Optionally shares a single optimizer across all trainable nodes for end-to-end gradient flow.\n", "- **Freeze / unfreeze control:** Selectively disable training for subsets of the graph.\n", "- **Graph mutation:** Add, remove, replace, or insert nodes dynamically.\n", "- **Serialization & checkpointing:** Save and restore the full graph structure and learned weights.\n", "\n", "```\n", "FeatureSet ──> ModelNode(\"Encoder\") ──> ModelNode(\"Regressor\")\n", "\n", "FeatureSet ─┬─> ModelNode(\"A\") ──┐\n", "            │                    ├─> ConcatNode ──> ModelNode(\"Head\")\n", "            └─> ModelNode(\"B\") ──┘\n", "```\n", "\n", "This notebook covers:\n", "\n", "- {ref}`03-create-modelgraph-creating-a-modelgraph`\n", "- {ref}`03-create-modelgraph-building-the-graph`\n", "- {ref}`03-create-modelgraph-graph-properties`\n", "- {ref}`03-create-modelgraph-forward-pass`\n", "- {ref}`03-create-modelgraph-graph-mutation`\n", "- {ref}`03-create-modelgraph-freezing-and-unfreezing`\n", "- {ref}`03-create-modelgraph-optimizer-management`\n", "- {ref}`03-create-modelgraph-serialization`\n", "- {ref}`03-create-modelgraph-checkpointing`\n", "- {ref}`03-create-modelgraph-summary`" ] }, { "cell_type": "code", "execution_count": null, "id": "1", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import torch\n", "\n", "from modularml import (\n", " ConcatNode,\n", " Experiment,\n", " FeatureSet,\n", " 
ModelGraph,\n", " ModelNode,\n", " Optimizer,\n", ")\n", "from modularml.models.torch import SequentialMLP\n", "\n", "# Create an Experiment with overwrite policy so we can freely recreate nodes\n", "# with the same names (prevent getting a warning each time we overwrite a node)\n", "exp = Experiment(label=\"create_modelgraph\", registration_policy=\"overwrite\")" ] }, { "cell_type": "markdown", "id": "2", "metadata": {}, "source": [ "We'll use a simple synthetic dataset throughout this notebook: 500 samples of a 10-point feature with a scalar target." ] }, { "cell_type": "code", "execution_count": null, "id": "3", "metadata": {}, "outputs": [], "source": [ "rng = np.random.default_rng(42)\n", "\n", "fs = FeatureSet.from_dict(\n", " label=\"SensorData\",\n", " data={\n", " \"voltage\": list(rng.standard_normal((500, 10))),\n", " \"soh\": list(rng.standard_normal((500, 1))),\n", " },\n", " feature_keys=\"voltage\",\n", " target_keys=\"soh\",\n", ")\n", "fs_ref = fs.reference(features=\"voltage\", targets=\"soh\")\n", "print(fs)" ] }, { "cell_type": "markdown", "id": "4", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "5", "metadata": {}, "source": [ "(03-create-modelgraph-creating-a-modelgraph)=\n", "## Creating a ModelGraph" ] }, { "cell_type": "markdown", "id": "6", "metadata": {}, "source": [ "\n", "A `ModelGraph` is constructed from a list of `GraphNode` instances and an optional shared `Optimizer`.\n", "\n", "```python\n", " ModelGraph(\n", " nodes: list[str | GraphNode] | None,\n", " optimizer: Optimizer | None = None,\n", " label: str = \"model-graph\",\n", " )\n", "```\n", "\n", "| Parameter | Type | Default | Description |\n", "|-----------|------|---------|-------------|\n", "| `nodes` | `list[str \\| GraphNode] \\| None` | (required) | Nodes comprising the graph. Pass node instances or their string labels. If `None`, all registered `GraphNode`s in the active `ExperimentContext` are used. 
|\n", "| `optimizer` | `Optimizer \\| None` | `None` | A shared optimizer for end-to-end training. If provided, all trainable nodes must share the same backend. |\n", "| `label` | `str` | `\"model-graph\"` | A human-readable label for this graph. |" ] }, { "cell_type": "markdown", "id": "7", "metadata": {}, "source": [ "### Simple Linear Graph\n", "\n", "The simplest graph is a linear chain: `FeatureSet -> ModelNode`." ] }, { "cell_type": "code", "execution_count": null, "id": "8", "metadata": {}, "outputs": [], "source": [ "node = ModelNode(\n", " label=\"SimpleMLP\",\n", " model=SequentialMLP(output_shape=(1, 1), n_layers=2, hidden_dim=32),\n", " upstream_ref=fs_ref,\n", ")\n", "\n", "mg = ModelGraph(\n", " nodes=[node],\n", " optimizer=Optimizer(opt=\"adam\", opt_kwargs={\"lr\": 1e-3}, backend=\"torch\"),\n", " label=\"simple-graph\",\n", ")\n", "print(f\"Label: {mg.label}\")\n", "print(f\"Nodes: {mg.node_labels}\")\n", "print(f\"Built: {mg.is_built}\")" ] }, { "cell_type": "markdown", "id": "9", "metadata": {}, "source": [ "### Multi-Node Chain\n", "\n", "Chain multiple `ModelNode`s by passing one as the `upstream_ref` of the next.\n", "\n", "ModelGraph supports the `.visualize()` method, which we'll use to show our topology updates." 
] }, { "cell_type": "code", "execution_count": null, "id": "10", "metadata": {}, "outputs": [], "source": [ "encoder = ModelNode(\n", " label=\"Encoder\",\n", " model=SequentialMLP(output_shape=(1, 8), n_layers=2, hidden_dim=32),\n", " upstream_ref=fs_ref,\n", ")\n", "\n", "regressor = ModelNode(\n", " label=\"Regressor\",\n", " model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=16),\n", " upstream_ref=encoder,\n", ")\n", "\n", "mg_chain = ModelGraph(\n", " nodes=[encoder, regressor],\n", " optimizer=Optimizer(opt=\"adam\", backend=\"torch\"),\n", ")\n", "print(f\"Node labels: {mg_chain.node_labels}\")\n", "\n", "mg_chain.visualize()" ] }, { "cell_type": "markdown", "id": "11", "metadata": {}, "source": [ "### Branching Graph with MergeNode\n", "\n", "Use `ConcatNode` (a `MergeNode`) to combine outputs from parallel branches.\n", "\n", "```\n", "FeatureSet ─┬─> EncoderA ──┐\n", "            │              ├─> ConcatNode ──> Head\n", "            └─> EncoderB ──┘\n", "```" ] }, { "cell_type": "code", "execution_count": null, "id": "12", "metadata": {}, "outputs": [], "source": [ "enc_a = ModelNode(\n", " label=\"EncoderA\",\n", " model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),\n", " upstream_ref=fs_ref,\n", ")\n", "enc_b = ModelNode(\n", " label=\"EncoderB\",\n", " model=SequentialMLP(output_shape=(1, 4), n_layers=1, hidden_dim=16),\n", " upstream_ref=fs_ref,\n", ")\n", "\n", "merge = ConcatNode(\n", " label=\"Merge\",\n", " upstream_refs=[enc_a, enc_b],\n", " concat_axis=-1,\n", " concat_axis_targets=\"first\",\n", ")\n", "\n", "head = ModelNode(\n", " label=\"Head\",\n", " model=SequentialMLP(n_layers=1, hidden_dim=8),\n", " upstream_ref=merge,\n", ")\n", "\n", "mg_branch = ModelGraph(\n", " nodes=[enc_a, enc_b, merge, head],\n", " optimizer=Optimizer(opt=\"adam\", backend=\"torch\"),\n", ")\n", "print(f\"Node labels: {mg_branch.node_labels}\")\n", "\n", "mg_branch.visualize()" ] }, { "cell_type": "markdown", "id": "13", "metadata": {}, "source": [ "### Referencing 
Nodes by Label\n", "\n", "Instead of passing node instances, you can pass their string labels. The graph will look them up in the active `ExperimentContext`." ] }, { "cell_type": "code", "execution_count": null, "id": "14", "metadata": {}, "outputs": [], "source": [ "mg_by_label = ModelGraph(\n", " nodes=[\"EncoderA\", \"EncoderB\", \"Merge\", \"Head\"],\n", " optimizer=Optimizer(opt=\"adam\", backend=\"torch\"),\n", ")\n", "print(f\"Node labels: {mg_by_label.node_labels}\")\n", "\n", "mg_by_label.visualize()" ] }, { "cell_type": "markdown", "id": "15", "metadata": {}, "source": [ "### Without a Global Optimizer\n", "\n", "If no global optimizer is provided, each `ModelNode` must define its own local optimizer. This mode, called stage-wise training, is useful when different nodes need different optimizers or learning rates." ] }, { "cell_type": "code", "execution_count": null, "id": "16", "metadata": {}, "outputs": [], "source": [ "node_with_opt = ModelNode(\n", " label=\"StageWiseMLP\",\n", " model=SequentialMLP(output_shape=(1, 1), n_layers=2, hidden_dim=32),\n", " upstream_ref=fs_ref,\n", " optimizer=Optimizer(opt=\"adam\", opt_kwargs={\"lr\": 1e-3}, backend=\"torch\"),\n", ")\n", "\n", "mg_no_global = ModelGraph(\n", " nodes=[node_with_opt],\n", " optimizer=None,\n", ")\n", "# .backend reports the global optimizer's backend, so it is None here\n", "print(f\"Global optimizer backend: {mg_no_global.backend}\")" ] }, { "cell_type": "markdown", "id": "17", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "18", "metadata": {}, "source": [ "(03-create-modelgraph-building-the-graph)=\n", "## Building the Graph" ] }, { "cell_type": "markdown", "id": "19", "metadata": {}, "source": [ "\n", "`ModelGraph.build()` performs the following steps in topological order:\n", "\n", "1. **Validates** the DAG structure (no cycles, all upstream references resolved).\n", "2. **Infers** input and output shapes for each node from upstream outputs and FeatureSet shapes.\n", "3. **Builds** each node's underlying model (lazy initialization).\n", "4. 
**Builds** the global optimizer (if provided) with parameters from all trainable nodes.\n", "\n", "```python\n", " ModelGraph.build(*, force: bool = False)\n", "```\n", "\n", "| Parameter | Type | Default | Description |\n", "|-----------|------|---------|-------------|\n", "| `force` | `bool` | `False` | If `True`, rebuilds even if the graph is already built. |" ] }, { "cell_type": "code", "execution_count": null, "id": "20", "metadata": {}, "outputs": [], "source": [ "mg_branch.build()\n", "print(f\"Built: {mg_branch.is_built}\")\n", "\n", "for node in mg_branch.nodes.values():\n", " in_shape = (\n", " node.input_shape\n", " if hasattr(node, \"input_shape\")\n", " else list(node.input_shapes.values())\n", " )\n", " out_shape = getattr(node, \"output_shape\", None)\n", " print(f\" {node.label}: {in_shape} -> {out_shape}\")\n", "\n", "mg_branch.visualize() # Note how all edges now show the input/output shapes" ] }, { "cell_type": "markdown", "id": "21", "metadata": {}, "source": [ "### Shape Inference Details\n", "\n", "During `build()`, shapes propagate through the graph as follows:\n", "\n", "- **Head nodes** (inputs from a `FeatureSet`): Input shape is pulled directly from the referenced `FeatureSet` data.\n", "- **Intermediate nodes**: Input shape equals the output shape of their upstream node.\n", "- **Tail nodes** (no downstream consumers): If no `output_shape` is specified on the model, it defaults to the target shape propagated from the upstream `FeatureSet`.\n", "- **MergeNodes**: Both feature and target output shapes are determined by a dummy forward pass through the merge logic.\n", "\n", "You generally do not need to specify `input_shape` on your models — `build()` infers it. Specifying `output_shape` is recommended for all non-tail nodes." ] }, { "cell_type": "markdown", "id": "22", "metadata": {}, "source": [ "### Rebuilding\n", "\n", "Calling `build()` on an already-built graph is a no-op unless `force=True`." 
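, "\n",
"\n",
"The first two `build()` steps, DAG validation and shape inference, can be illustrated with a standalone sketch that does not use ModularML. The node names, shapes, and the `infer_shapes` helper are hypothetical; the sketch assumes that a merge node concatenates its inputs along the last axis and that a node with no declared `output_shape` falls back to the target shape, as described for tail nodes above. The standard-library `graphlib.TopologicalSorter` supplies the dependency order and raises `CycleError` for cyclic graphs.\n",
"\n",
"```python\n",
"from graphlib import CycleError, TopologicalSorter\n",
"\n",
"# Hypothetical per-sample shapes, mirroring the branching example\n",
"FEATURE_SHAPE = (1, 10)\n",
"TARGET_SHAPE = (1, 1)\n",
"\n",
"# node -> (upstream nodes, declared output_shape or None, kind)\n",
"GRAPH = {\n",
"    \"EncoderA\": ([], (1, 8), \"model\"),\n",
"    \"EncoderB\": ([], (1, 4), \"model\"),\n",
"    \"Merge\": ([\"EncoderA\", \"EncoderB\"], None, \"concat\"),\n",
"    \"Head\": ([\"Merge\"], None, \"model\"),\n",
"}\n",
"\n",
"\n",
"def infer_shapes(graph):\n",
"    \"\"\"Return {node: (input_shape, output_shape)}, visiting nodes in dependency order.\"\"\"\n",
"    deps = {name: ups for name, (ups, _, _) in graph.items()}\n",
"    shapes = {}\n",
"    for name in TopologicalSorter(deps).static_order():\n",
"        ups, out_shape, kind = graph[name]\n",
"        if kind == \"concat\":\n",
"            # Merge: concatenate the upstream outputs along the last axis\n",
"            in_shapes = [shapes[u][1] for u in ups]\n",
"            out_shape = (*in_shapes[0][:-1], sum(s[-1] for s in in_shapes))\n",
"            shapes[name] = (in_shapes, out_shape)\n",
"            continue\n",
"        # Head nodes read the FeatureSet; other nodes read their single upstream\n",
"        in_shape = shapes[ups[0]][1] if ups else FEATURE_SHAPE\n",
"        if out_shape is None:\n",
"            out_shape = TARGET_SHAPE  # no declared output_shape: use the target shape\n",
"        shapes[name] = (in_shape, out_shape)\n",
"    return shapes\n",
"\n",
"\n",
"for node, (in_s, out_s) in infer_shapes(GRAPH).items():\n",
"    print(f\"{node}: {in_s} -> {out_s}\")\n",
"\n",
"# A cycle is rejected before any shape is inferred\n",
"try:\n",
"    list(TopologicalSorter({\"A\": [\"B\"], \"B\": [\"A\"]}).static_order())\n",
"except CycleError:\n",
"    print(\"cycle detected\")\n",
"```\n",
"\n",
"This only mirrors the propagation rules listed above; `build()` additionally constructs each node's model and the global optimizer."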
] }, { "cell_type": "code", "execution_count": null, "id": "23", "metadata": {}, "outputs": [], "source": [ "# No-op (already built)\n", "mg_branch.build()\n", "\n", "# Force rebuild (e.g., after modifying graph structure)\n", "mg_branch.build(force=True)\n", "print(f\"Rebuilt: {mg_branch.is_built}\")" ] }, { "cell_type": "markdown", "id": "24", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "25", "metadata": {}, "source": [ "(03-create-modelgraph-graph-properties)=\n", "## Graph Properties" ] }, { "cell_type": "markdown", "id": "26", "metadata": {}, "source": [ "\n", "After building, the graph exposes several useful properties for inspecting its structure." ] }, { "cell_type": "code", "execution_count": null, "id": "27", "metadata": {}, "outputs": [], "source": [ "print(f\"Label: {mg_branch.label}\")\n", "print(f\"Built: {mg_branch.is_built}\")\n", "print(f\"Backend: {mg_branch.backend}\")\n", "print(f\"Node labels: {mg_branch.node_labels}\")" ] }, { "cell_type": "markdown", "id": "28", "metadata": {}, "source": [ "### Head and Tail Nodes\n", "\n", "- **Head nodes**: Nodes whose inputs come directly from a `FeatureSet` (no upstream `GraphNode` dependencies).\n", "- **Tail nodes**: Nodes whose outputs are not consumed by any other node in the graph." ] }, { "cell_type": "code", "execution_count": null, "id": "29", "metadata": {}, "outputs": [], "source": [ "print(\"Head nodes (receive FeatureSet data):\")\n", "for n in mg_branch.head_nodes.values():\n", " print(f\" - {n.label}\")\n", "\n", "print(\"\\nTail nodes (produce final outputs):\")\n", "for n in mg_branch.tail_nodes.values():\n", " print(f\" - {n.label}\")" ] }, { "cell_type": "markdown", "id": "30", "metadata": {}, "source": [ "### Accessing Individual Nodes\n", "\n", "Nodes are stored in a dict keyed by `node_id`. 
These IDs are globally unique, which is why nodes can be referenced by their label, ID, or instance at any point in an Experiment.\n", "\n", "You can iterate over `.nodes` directly, as below, or refer to a node by its label in methods such as `freeze()` and `remove_node()`." ] }, { "cell_type": "code", "execution_count": null, "id": "31", "metadata": {}, "outputs": [], "source": [ "# All nodes (keyed by node_id)\n", "for n_id, node in mg_branch.nodes.items():\n", " print(f\" {node.label} (id={n_id[:8]}...)\")" ] }, { "cell_type": "markdown", "id": "32", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "33", "metadata": {}, "source": [ "(03-create-modelgraph-forward-pass)=\n", "## Forward Pass" ] }, { "cell_type": "markdown", "id": "34", "metadata": {}, "source": [ "\n", "Once built, you can execute a forward pass through the graph. The graph handles data routing between nodes in topological order.\n", "\n", "```python\n", " ModelGraph.forward(\n", " inputs: dict[tuple[str, FeatureSetReference], TForward],\n", " *,\n", " active_nodes: list[str | GraphNode] | None = None,\n", " ) -> dict[str, TForward]\n", "```\n", "\n", "\n", "| Parameter | Type | Description |\n", "|-----------|------|-------------|\n", "| `inputs` | `dict` | Mapping of `(head_node_id, FeatureSetReference)` to input data. Each head node needs its upstream `FeatureSet` data. |\n", "| `active_nodes` | `list \\| None` | Optional subset of nodes to execute. Upstream dependencies are included automatically. If `None`, all nodes run. |\n", "\n", "**Returns:** A dict mapping `node_id` to that node's output data for every executed node." 
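, "\n",
"\n",
"Conceptually, routing data in topological order means visiting each node only after all of its upstream nodes and feeding it their stored outputs. The sketch below shows that idea with plain Python callables standing in for models; `NODES` and `run_forward` are hypothetical and are not ModularML's implementation.\n",
"\n",
"```python\n",
"from graphlib import TopologicalSorter\n",
"\n",
"# Hypothetical graph: node -> (upstream nodes, function applied to upstream outputs)\n",
"NODES = {\n",
"    \"EncoderA\": ([], lambda x: [v * 2 for v in x]),\n",
"    \"EncoderB\": ([], lambda x: [v + 1 for v in x]),\n",
"    \"Merge\": ([\"EncoderA\", \"EncoderB\"], lambda a, b: a + b),  # list concatenation\n",
"    \"Head\": ([\"Merge\"], lambda m: [sum(m)]),\n",
"}\n",
"\n",
"\n",
"def run_forward(nodes, head_input):\n",
"    \"\"\"Run every node once in dependency order and return {node: output}.\"\"\"\n",
"    deps = {name: ups for name, (ups, _) in nodes.items()}\n",
"    outputs = {}\n",
"    for name in TopologicalSorter(deps).static_order():\n",
"        ups, fn = nodes[name]\n",
"        if ups:\n",
"            # Downstream node: consume the stored outputs of its upstream nodes\n",
"            outputs[name] = fn(*(outputs[u] for u in ups))\n",
"        else:\n",
"            # Head node: consume the FeatureSet data directly\n",
"            outputs[name] = fn(head_input)\n",
"    return outputs\n",
"\n",
"\n",
"print(run_forward(NODES, [1.0, 2.0]))\n",
"```\n",
"\n",
"As with `ModelGraph.forward()`, the returned dict holds an output for every executed node, not just the tail node."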
] }, { "cell_type": "code", "execution_count": null, "id": "35", "metadata": {}, "outputs": [], "source": [ "from modularml.core.data.sample_data import SampleData\n", "from modularml.utils.data.data_format import DataFormat\n", "\n", "# Prepare input data\n", "fsv = fs_ref.resolve()\n", "sample_data = SampleData(\n", "    features=fsv.get_features(fmt=DataFormat.TORCH),\n", "    targets=fsv.get_targets(fmt=DataFormat.TORCH),\n", ")\n", "\n", "# Build the inputs dict: (head_node_id, featureset_ref) -> data\n", "inputs = {}\n", "for n_id, node in mg_branch.head_nodes.items():\n", "    for ref in node.get_upstream_refs():\n", "        inputs[(n_id, ref)] = sample_data\n", "\n", "print(f\"Number of input entries: {len(inputs)}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "36", "metadata": {}, "outputs": [], "source": [ "# Execute forward pass\n", "with torch.no_grad():\n", "    outputs = mg_branch.forward(inputs)\n", "\n", "print(\"Outputs per node:\")\n", "for n_id, out in outputs.items():\n", "    node_label = mg_branch.nodes[n_id].label\n", "    print(f\"  {node_label}: features={out.features.shape}\")" ] }, { "cell_type": "markdown", "id": "37", "metadata": {}, "source": [ "### Active Nodes\n", "\n", "You can restrict the forward pass to a subset of the graph using `active_nodes`. All required upstream dependencies are automatically included.\n", "\n", "For example, we can mark only `Merge` as active; its upstream nodes (`EncoderA` and `EncoderB`) are then executed as well.\n", "The downstream `Head` node is not needed for this output, so it is skipped." 
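, "\n",
"\n",
"Including upstream dependencies automatically amounts to taking the ancestor closure of the requested nodes. A minimal sketch over a hypothetical upstream-adjacency dict (the `required_nodes` helper is illustrative and not part of ModularML):\n",
"\n",
"```python\n",
"# Hypothetical upstream adjacency for the branching example: node -> upstream nodes\n",
"UPSTREAM = {\n",
"    \"EncoderA\": [],\n",
"    \"EncoderB\": [],\n",
"    \"Merge\": [\"EncoderA\", \"EncoderB\"],\n",
"    \"Head\": [\"Merge\"],\n",
"}\n",
"\n",
"\n",
"def required_nodes(active, upstream):\n",
"    \"\"\"Return the active nodes plus every node they transitively depend on.\"\"\"\n",
"    required = set()\n",
"    stack = list(active)\n",
"    while stack:\n",
"        node = stack.pop()\n",
"        if node not in required:\n",
"            required.add(node)\n",
"            stack.extend(upstream[node])  # walk upstream only, never downstream\n",
"    return required\n",
"\n",
"\n",
"print(sorted(required_nodes([\"Merge\"], UPSTREAM)))  # ['EncoderA', 'EncoderB', 'Merge']\n",
"```\n",
"\n",
"Because the walk only follows upstream edges, the downstream `Head` node is never added."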
] }, { "cell_type": "code", "execution_count": null, "id": "38", "metadata": {}, "outputs": [], "source": [ "# Only mark Merge as active; its upstream dependencies (EncoderA, EncoderB) are included automatically\n", "with torch.no_grad():\n", "    partial_outputs = mg_branch.forward(inputs, active_nodes=[merge])\n", "\n", "print(\"Executed nodes:\")\n", "for n_id in partial_outputs:\n", "    print(f\"  - {mg_branch.nodes[n_id].label}\")" ] }, { "cell_type": "markdown", "id": "39", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "40", "metadata": {}, "source": [ "(03-create-modelgraph-graph-mutation)=\n", "## Graph Mutation" ] }, { "cell_type": "markdown", "id": "41", "metadata": {}, "source": [ "\n", "`ModelGraph` provides several methods to modify the graph structure after creation. All mutation methods return `self` for method chaining.\n", "\n", "After any structural change, the graph automatically revalidates connections and recomputes the topological order. Shapes and optimizers are not reinitialized automatically, so call `build()` again (e.g., `build(force=True)`) afterwards." ] }, { "cell_type": "markdown", "id": "42", "metadata": {}, "source": [ "### `add_node()`\n", "\n", "Add a new node to the graph. The node must already be connected to existing nodes via its `upstream_ref`." 
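, "\n",
"\n",
"One way to enforce that a new node's upstream references already exist, and to recompute the execution order afterwards, is sketched below with plain dicts. The `add_node` function here is hypothetical and is not `ModelGraph.add_node()`, which returns the graph itself rather than an execution order.\n",
"\n",
"```python\n",
"from graphlib import TopologicalSorter\n",
"\n",
"\n",
"def add_node(graph, name, upstreams):\n",
"    \"\"\"Hypothetical: add `name` to a {node: [upstream, ...]} graph, then re-sort.\"\"\"\n",
"    if name in graph:\n",
"        raise ValueError(f\"node {name!r} already exists\")\n",
"    missing = [u for u in upstreams if u not in graph]\n",
"    if missing:\n",
"        raise ValueError(f\"unknown upstream node(s): {missing}\")\n",
"    graph[name] = list(upstreams)\n",
"    # Recompute the execution order after the structural change\n",
"    return list(TopologicalSorter(graph).static_order())\n",
"\n",
"\n",
"graph = {\"Base\": []}\n",
"print(add_node(graph, \"Added\", [\"Base\"]))  # ['Base', 'Added']\n",
"\n",
"try:\n",
"    add_node(graph, \"Orphan\", [\"DoesNotExist\"])\n",
"except ValueError as err:\n",
"    print(err)\n",
"```"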
] }, { "cell_type": "code", "execution_count": null, "id": "43", "metadata": {}, "outputs": [], "source": [ "# Start with a simple single-node graph\n", "base_node = ModelNode(\n", " label=\"Base\",\n", " model=SequentialMLP(output_shape=(1, 4), n_layers=1, hidden_dim=16),\n", " upstream_ref=fs_ref,\n", ")\n", "mg_mut = ModelGraph(\n", " nodes=[base_node],\n", " optimizer=Optimizer(opt=\"adam\", backend=\"torch\"),\n", ")\n", "print(f\"Before: {mg_mut.node_labels}\")\n", "\n", "# Add a downstream node\n", "added_node = ModelNode(\n", " label=\"Added\",\n", " model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),\n", " upstream_ref=base_node,\n", ")\n", "mg_mut.add_node(added_node)\n", "print(f\"After: {mg_mut.node_labels}\")\n", "\n", "mg_mut.visualize()" ] }, { "cell_type": "markdown", "id": "44", "metadata": {}, "source": [ "### `remove_node()`\n", "\n", "Remove a node from the graph. Downstream nodes are reconnected to the removed node's upstream sources.\n", "\n", "```\n", "Given: A -> B -> C\n", "Remove B:\n", "Result: A -> C\n", "```" ] }, { "cell_type": "code", "execution_count": null, "id": "45", "metadata": {}, "outputs": [], "source": [ "# Create a 3-node chain\n", "n1 = ModelNode(\n", " label=\"N1\",\n", " model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),\n", " upstream_ref=fs_ref,\n", ")\n", "n2 = ModelNode(\n", " label=\"N2\",\n", " model=SequentialMLP(output_shape=(1, 4), n_layers=1, hidden_dim=8),\n", " upstream_ref=n1,\n", ")\n", "n3 = ModelNode(\n", " label=\"N3\",\n", " model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),\n", " upstream_ref=n2,\n", ")\n", "mg_rem = ModelGraph(\n", " nodes=[n1, n2, n3],\n", " optimizer=Optimizer(opt=\"adam\", backend=\"torch\"),\n", ")\n", "print(f\"Before: {mg_rem.node_labels}\")\n", "\n", "# Remove the middle node\n", "mg_rem.remove_node(\"N2\")\n", "print(f\"After: {mg_rem.node_labels}\")\n", "\n", "mg_rem.visualize()" ] }, { "cell_type": "markdown", "id": "46", 
"metadata": {}, "source": [ "### `replace_node()`\n", "\n", "Replace an existing node with a new one, preserving all upstream and downstream connections." ] }, { "cell_type": "code", "execution_count": null, "id": "47", "metadata": {}, "outputs": [], "source": [ "# Create a simple chain\n", "old_enc = ModelNode(\n", " label=\"OldEncoder\",\n", " model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),\n", " upstream_ref=fs_ref,\n", ")\n", "reg = ModelNode(\n", " label=\"Reg\",\n", " model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),\n", " upstream_ref=old_enc,\n", ")\n", "mg_rep = ModelGraph(\n", " nodes=[old_enc, reg],\n", " optimizer=Optimizer(opt=\"adam\", backend=\"torch\"),\n", ")\n", "print(f\"Before: {mg_rep.node_labels}\")\n", "\n", "# Replace with a deeper encoder\n", "new_enc = ModelNode(\n", " label=\"NewEncoder\",\n", " model=SequentialMLP(output_shape=(1, 8), n_layers=3, hidden_dim=64),\n", " upstream_ref=fs_ref,\n", ")\n", "mg_rep.replace_node(old_node=\"OldEncoder\", new_node=new_enc)\n", "print(f\"After: {mg_rep.node_labels}\")\n", "\n", "mg_rep.visualize()" ] }, { "cell_type": "markdown", "id": "48", "metadata": {}, "source": [ "### `insert_node_between()`\n", "\n", "Insert a new node between two already-connected nodes.\n", "\n", "```\n", "Given: A -> B\n", "Insert C between A and B:\n", "Result: A -> C -> B\n", "```" ] }, { "cell_type": "code", "execution_count": null, "id": "49", "metadata": {}, "outputs": [], "source": [ "a = ModelNode(\n", " label=\"A\",\n", " model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),\n", " upstream_ref=fs_ref,\n", ")\n", "b = ModelNode(\n", " label=\"B\",\n", " model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),\n", " upstream_ref=a,\n", ")\n", "mg_ins = ModelGraph(\n", " nodes=[a, b],\n", " optimizer=Optimizer(opt=\"adam\", backend=\"torch\"),\n", ")\n", "print(f\"Before: {mg_ins.node_labels}\")\n", "\n", "c = ModelNode(\n", " label=\"C\",\n", " 
model=SequentialMLP(output_shape=(1, 4), n_layers=1, hidden_dim=16),\n", " upstream_ref=fs_ref, # will be overwritten by insert\n", ")\n", "mg_ins.insert_node_between(new_node=c, upstream=a, downstream=b)\n", "print(f\"After: {mg_ins.node_labels}\")\n", "\n", "# Verify connectivity\n", "for node in mg_ins.nodes.values():\n", " ups = [r.node_label for r in node.get_upstream_refs()]\n", " print(f\" {node.label} <- {ups}\")\n", "\n", "mg_ins.visualize()" ] }, { "cell_type": "markdown", "id": "50", "metadata": {}, "source": [ "### `insert_node_before()` and `insert_node_after()`\n", "\n", "- `insert_node_before(new_node, downstream=...)`: Insert before an existing node, taking over all its upstream connections.\n", "- `insert_node_after(new_node, upstream=...)`: Insert after an existing node as an additional downstream consumer." ] }, { "cell_type": "code", "execution_count": null, "id": "51", "metadata": {}, "outputs": [], "source": [ "p = ModelNode(\n", " label=\"P\",\n", " model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),\n", " upstream_ref=fs_ref,\n", ")\n", "q = ModelNode(\n", " label=\"Q\",\n", " model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),\n", " upstream_ref=p,\n", ")\n", "mg_ib = ModelGraph(\n", " nodes=[p, q],\n", " optimizer=Optimizer(opt=\"adam\", backend=\"torch\"),\n", ")\n", "\n", "# Insert a node before Q (takes over Q's upstream connections)\n", "pre_q = ModelNode(\n", " label=\"PreQ\",\n", " model=SequentialMLP(output_shape=(1, 4), n_layers=1, hidden_dim=16),\n", " upstream_ref=fs_ref,\n", ")\n", "mg_ib.insert_node_before(new_node=pre_q, downstream=q)\n", "print(\"After insert_node_before:\")\n", "for node in mg_ib.nodes.values():\n", " ups = [r.node_label for r in node.get_upstream_refs()]\n", " print(f\" {node.label} <- {ups}\")\n", "\n", "mg_ib.visualize()" ] }, { "cell_type": "code", "execution_count": null, "id": "52", "metadata": {}, "outputs": [], "source": [ "# Insert a node after P (adds a new 
branch)\n", "post_p = ModelNode(\n", " label=\"PostP\",\n", " model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),\n", " upstream_ref=fs_ref,\n", ")\n", "mg_ib.insert_node_after(new_node=post_p, upstream=p)\n", "print(\"After insert_node_after:\")\n", "for node in mg_ib.nodes.values():\n", " ups = [r.node_label for r in node.get_upstream_refs()]\n", " print(f\" {node.label} <- {ups}\")\n", "\n", "print(f\"\\nTail nodes: {[n.label for n in mg_ib.tail_nodes.values()]}\")\n", "\n", "mg_ib.visualize()" ] }, { "cell_type": "markdown", "id": "53", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "54", "metadata": {}, "source": [ "(03-create-modelgraph-freezing-and-unfreezing)=\n", "## Freezing and Unfreezing" ] }, { "cell_type": "markdown", "id": "55", "metadata": {}, "source": [ "\n", "Freezing prevents a node's parameters from being updated during training. This is useful for transfer learning, multi-stage training, or keeping pretrained components fixed.\n", "\n", "```python\n", " ModelGraph.freeze(nodes: list[str | GraphNode] | None = None)\n", " ModelGraph.unfreeze(nodes: list[str | GraphNode] | None = None)\n", "```\n", "\n", "| Parameter | Type | Default | Description |\n", "|-----------|------|---------|-------------|\n", "| `nodes` | `list \\| None` | `None` | Nodes to freeze/unfreeze (by label, ID, or instance). If `None`, applies to all trainable nodes. 
|" ] }, { "cell_type": "code", "execution_count": null, "id": "56", "metadata": {}, "outputs": [], "source": [ "# Using the branching graph (mg_branch) built earlier in this notebook\n", "mg_branch.build(force=True)\n", "\n", "# Freeze specific nodes\n", "mg_branch.freeze(nodes=[enc_a])\n", "print(f\"Frozen after freeze: {[n.label for n in mg_branch.frozen_nodes.values()]}\")\n", "mg_branch.visualize(show_frozen=True)\n", "\n", "# Unfreeze\n", "mg_branch.unfreeze(nodes=[enc_a])\n", "print(f\"Frozen after unfreeze: {[n.label for n in mg_branch.frozen_nodes.values()]}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "57", "metadata": {}, "outputs": [], "source": [ "# Freeze all nodes at once\n", "mg_branch.freeze()\n", "print(f\"All frozen: {[n.label for n in mg_branch.frozen_nodes.values()]}\")\n", "\n", "# Unfreeze all (frozen_nodes should now be empty)\n", "mg_branch.unfreeze()\n", "print(f\"Frozen after unfreeze(): {[n.label for n in mg_branch.frozen_nodes.values()]}\")" ] }, { "cell_type": "markdown", "id": "58", "metadata": {}, "source": [ "### Frozen Nodes and the Optimizer\n", "\n", "When using a global optimizer, the optimizer is automatically rebuilt to exclude frozen nodes' parameters before each training step. This means frozen nodes will not accumulate gradients and their weights remain unchanged." 
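, "\n",
"\n",
"In plain PyTorch, the effect described above can be approximated by disabling `requires_grad` on the frozen module and building the optimizer only from parameters that still require gradients. This sketches the general technique and is not ModularML's internal code; the two `nn.Linear` modules are hypothetical stand-ins for an encoder node and a head node.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"\n",
"torch.manual_seed(0)\n",
"encoder = nn.Linear(10, 8)  # stands in for a frozen encoder node\n",
"head = nn.Linear(8, 1)  # stands in for a trainable head node\n",
"\n",
"# Freeze the encoder: its parameters no longer receive gradients\n",
"for p in encoder.parameters():\n",
"    p.requires_grad_(False)\n",
"\n",
"# Rebuild the optimizer from trainable parameters only\n",
"trainable = [p for m in (encoder, head) for p in m.parameters() if p.requires_grad]\n",
"optimizer = torch.optim.Adam(trainable, lr=1e-3)\n",
"\n",
"enc_before = encoder.weight.clone()\n",
"head_before = head.weight.detach().clone()\n",
"\n",
"x, y = torch.randn(32, 10), torch.randn(32, 1)\n",
"loss = nn.functional.mse_loss(head(encoder(x)), y)\n",
"optimizer.zero_grad()\n",
"loss.backward()\n",
"optimizer.step()\n",
"\n",
"print(\"encoder unchanged:\", torch.equal(encoder.weight, enc_before))\n",
"print(\"head updated:\", not torch.equal(head.weight, head_before))\n",
"print(\"encoder grad:\", encoder.weight.grad)\n",
"```"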
] }, { "cell_type": "markdown", "id": "59", "metadata": {}, "source": [ "---\n", "\n" ] }, { "cell_type": "markdown", "id": "60", "metadata": {}, "source": [ "(03-create-modelgraph-optimizer-management)=\n", "## Optimizer Management" ] }, { "cell_type": "markdown", "id": "61", "metadata": {}, "source": [ "\n", "The `ModelGraph` supports two training modes based on whether a global optimizer is provided:\n", "\n", "### Global Optimizer (Graph-Wise Training)\n", "\n", "When a global `Optimizer` is set on the `ModelGraph`:\n", "- A single forward pass runs through the entire graph.\n", "- All losses are accumulated.\n", "- A single backward pass computes gradients across all unfrozen nodes.\n", "- The global optimizer steps once.\n", "\n", "This enables **end-to-end gradient flow** through the full graph, which is the most common training paradigm.\n", "\n", "### No Global Optimizer (Stage-Wise Training)\n", "\n", "When `optimizer=None` on the `ModelGraph`:\n", "- Each `ModelNode` must have its own local `Optimizer`.\n", "- Nodes are trained independently in topological order.\n", "- Each node performs its own forward pass, loss computation, backward pass, and optimizer step.\n", "\n", "This is useful when you need different optimizers per node, or when certain nodes should not share gradient flow." ] }, { "cell_type": "markdown", "id": "62", "metadata": {}, "source": [ "### Inspecting Optimizer Parameters\n", "\n", "After at least one training step (or after calling `build()`), you can inspect which nodes contribute parameters to the global optimizer." 
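, "\n",
"\n",
"The graph-wise mode can be sketched in plain PyTorch: the parameters of every trainable module are gathered into one optimizer, and a single backward pass from the tail loss reaches every branch. The module names and sizes are hypothetical, and this is not how `get_optimizer_parameters()` is implemented.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"\n",
"torch.manual_seed(0)\n",
"# Hypothetical stand-ins for the trainable nodes of a branching graph\n",
"modules = {\n",
"    \"EncoderA\": nn.Linear(10, 8),\n",
"    \"EncoderB\": nn.Linear(10, 4),\n",
"    \"Head\": nn.Linear(12, 1),\n",
"}\n",
"\n",
"# One global optimizer over the parameters of every contributing module\n",
"params = [p for m in modules.values() for p in m.parameters()]\n",
"optimizer = torch.optim.Adam(params, lr=1e-3)\n",
"print(f\"contributing modules: {len(modules)}, parameter tensors: {len(params)}\")\n",
"\n",
"x, y = torch.randn(32, 10), torch.randn(32, 1)\n",
"merged = torch.cat([modules[\"EncoderA\"](x), modules[\"EncoderB\"](x)], dim=-1)\n",
"loss = nn.functional.mse_loss(modules[\"Head\"](merged), y)\n",
"\n",
"optimizer.zero_grad()\n",
"loss.backward()  # one backward pass reaches every branch\n",
"optimizer.step()\n",
"\n",
"for name, m in modules.items():\n",
"    print(name, \"received gradients:\", m.weight.grad is not None)\n",
"```"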
] }, { "cell_type": "code", "execution_count": null, "id": "63", "metadata": {}, "outputs": [], "source": [ "mg_branch.build(force=True)\n", "\n", "opt_info = mg_branch.get_optimizer_parameters()\n", "print(f\"Backend: {opt_info['backend']}\")\n", "print(f\"Contributing nodes: {len(opt_info['contributing_nodes'])}\")\n", "print(f\"Total parameters: {len(opt_info['parameters'])}\")" ] }, { "cell_type": "markdown", "id": "64", "metadata": {}, "source": [ "### Backend Constraints\n", "\n", "When using a global optimizer, all trainable nodes must share the same backend (e.g., all PyTorch). A `RuntimeError` is raised if backends conflict.\n", "\n", "Mixed-backend graphs (e.g., PyTorch encoder + scikit-learn head) must use stage-wise training (no global optimizer)." ] }, { "cell_type": "markdown", "id": "65", "metadata": {}, "source": [ "---\n", "\n" ] }, { "cell_type": "markdown", "id": "66", "metadata": {}, "source": [ "(03-create-modelgraph-serialization)=\n", "## Serialization" ] }, { "cell_type": "markdown", "id": "67", "metadata": {}, "source": [ "`ModelGraph` supports full serialization: saving and loading both the graph structure (config) and learned weights (state)." ] }, { "cell_type": "markdown", "id": "68", "metadata": {}, "source": [ "### Config Serialization\n", "\n", "`get_config()` captures the graph structure (node configs, optimizer config) without learned weights. `from_config()` reconstructs the graph from a config dict." ] }, { "cell_type": "code", "execution_count": null, "id": "69", "metadata": {}, "outputs": [], "source": [ "config = mg_branch.get_config()\n", "print(f\"Config keys: {list(config.keys())}\")\n", "print(f\"Number of node configs: {len(config['nodes'])}\")\n", "print(f\"Optimizer config: {config['optimizer'] is not None}\")" ] }, { "cell_type": "markdown", "id": "70", "metadata": {}, "source": [ "### State Serialization\n", "\n", "`get_state()` captures the learned weights and optimizer state. `set_state()` restores them." 
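, "\n",
"\n",
"The config/state split mirrors a common PyTorch pattern: keep the constructor arguments to rebuild an object, and restore learned values from its `state_dict()`. The analogy below uses a single `nn.Linear`; the `config` dict is hypothetical and is not ModularML's serialization format.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"\n",
"torch.manual_seed(0)\n",
"config = {\"in_features\": 10, \"out_features\": 1}  # structure only, no weights\n",
"original = nn.Linear(**config)\n",
"state = original.state_dict()  # learned values only\n",
"\n",
"# Rebuild from config (fresh random weights), then restore the saved state\n",
"restored = nn.Linear(**config)\n",
"restored.load_state_dict(state)\n",
"\n",
"x = torch.randn(4, 10)\n",
"print(\"outputs match:\", torch.equal(original(x), restored(x)))\n",
"```"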
] }, { "cell_type": "code", "execution_count": null, "id": "71", "metadata": {}, "outputs": [], "source": [ "state = mg_branch.get_state()\n", "print(f\"State keys: {list(state.keys())}\")\n", "print(f\"Number of node states: {len(state['nodes'])}\")\n", "print(f\"Is built: {state['is_built']}\")" ] }, { "cell_type": "markdown", "id": "72", "metadata": {}, "source": [ "### Save and Load to Disk\n", "\n", "Use `save()` and `load()` for persistent serialization. The file includes both config and state." ] }, { "cell_type": "code", "execution_count": null, "id": "73", "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "from tempfile import TemporaryDirectory\n", "\n", "SAVE_DIR = TemporaryDirectory()\n", "\n", "# Save\n", "save_path = mg_branch.save(Path(SAVE_DIR.name) / \"my_graph\", overwrite=True)\n", "print(f\"Saved to: {save_path}\")\n", "\n", "# Load\n", "# overwrite=True is needed because the reloaded node labels/IDs collide\n", "# with the nodes already defined in this notebook\n", "mg_loaded = ModelGraph.load(save_path, overwrite=True)\n", "print(f\"Loaded graph labels: {mg_loaded.node_labels}\")\n", "\n", "mg_loaded.visualize()" ] }, { "cell_type": "markdown", "id": "74", "metadata": {}, "source": [ "---\n" ] }, { "cell_type": "markdown", "id": "75", "metadata": {}, "source": [ "(03-create-modelgraph-checkpointing)=\n", "## Checkpointing" ] }, { "cell_type": "markdown", "id": "76", "metadata": {}, "source": [ "Checkpointing allows you to save and restore the full state of a `ModelGraph` at a specific point during training. 
Unlike `save()` / `load()`, where `load()` creates a new `ModelGraph` instance, checkpointing restores state into an existing graph.\n", "\n", "```python\n", " ModelGraph.save_checkpoint(\n", " filepath: Path,\n", " *,\n", " overwrite: bool = False,\n", " meta: dict[str, Any] | None = None,\n", " ) -> Path\n", "\n", " ModelGraph.restore_checkpoint(filepath: Path) -> ModelGraph\n", "```\n", "\n", "| Parameter | Type | Description |\n", "|-----------|------|-------------|\n", "| `filepath` | `Path` | Location to save/load the checkpoint. |\n", "| `overwrite` | `bool` | Whether to overwrite an existing file. |\n", "| `meta` | `dict` | Optional metadata to attach to the checkpoint (must be pickle-able). |" ] }, { "cell_type": "code", "execution_count": null, "id": "77", "metadata": {}, "outputs": [], "source": [ "# Save a checkpoint (includes model weights and optimizer state)\n", "ckpt_path = mg_branch.save_checkpoint(\n", " Path(SAVE_DIR.name) / \"checkpoint_epoch5\",\n", " overwrite=True,\n", " meta={\"epoch\": 5, \"val_loss\": 0.032},  # illustrative metadata values\n", ")\n", "print(f\"Checkpoint saved to: {ckpt_path}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "78", "metadata": {}, "outputs": [], "source": [ "# Restore the checkpoint into the existing graph\n", "mg_branch.restore_checkpoint(ckpt_path)\n", "print(f\"Restored. Built: {mg_branch.is_built}\")" ] }, { "cell_type": "markdown", "id": "79", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "80", "metadata": {}, "source": [ "(03-create-modelgraph-summary)=\n", "## Summary" ] }, { "cell_type": "markdown", "id": "81", "metadata": {}, "source": [ "\n", "### Constructor\n", "\n", "| Parameter | Type | Default | Description |\n", "|-----------|------|---------|-------------|\n", "| `nodes` | `list[str \\| GraphNode] \\| None` | (required) | Nodes comprising the graph. |\n", "| `optimizer` | `Optimizer \\| None` | `None` | Shared optimizer for graph-wise training. 
|\n", "| `label` | `str` | `\"model-graph\"` | Human-readable label. |\n", "\n", "### Properties\n", "\n", "| Property | Type | Description |\n", "|----------|------|-------------|\n", "| `.nodes` | `dict[str, GraphNode]` | All nodes keyed by `node_id`. |\n", "| `.node_labels` | `set[str]` | Unique node labels. |\n", "| `.head_nodes` | `dict[str, GraphNode]` | Nodes receiving FeatureSet input. |\n", "| `.tail_nodes` | `dict[str, GraphNode]` | Nodes with no downstream consumers. |\n", "| `.is_built` | `bool` | Whether `build()` has been called. |\n", "| `.backend` | `Backend \\| None` | Backend of the global optimizer, or `None`. |\n", "| `.frozen_nodes` | `dict[str, GraphNode]` | Currently frozen trainable nodes. |\n", "\n", "### Methods\n", "\n", "| Method | Description |\n", "|--------|-------------|\n", "| `build(force=False)` | Build all nodes and the global optimizer. |\n", "| `forward(inputs, active_nodes=None)` | Execute a forward pass through the graph. |\n", "| `train_step(ctx, losses, active_nodes=None)` | Execute a single training step (graph-wise or stage-wise). |\n", "| `eval_step(ctx, losses, active_nodes=None)` | Execute a forward-only evaluation step (no gradients). |\n", "| `fit_step(ctx, losses=None, active_nodes=None)` | Fit batch-fit nodes (e.g., scikit-learn) in topological order. |\n", "| `freeze(nodes=None)` | Freeze nodes to prevent training. |\n", "| `unfreeze(nodes=None)` | Unfreeze nodes to allow training. |\n", "| `add_node(node)` | Add a node to the graph. |\n", "| `remove_node(node)` | Remove a node, reconnecting neighbors. |\n", "| `replace_node(old_node, new_node)` | Replace a node, preserving connections. |\n", "| `insert_node_between(new_node, upstream, downstream)` | Insert between two connected nodes. |\n", "| `insert_node_before(new_node, downstream)` | Insert before an existing node. |\n", "| `insert_node_after(new_node, upstream)` | Insert after an existing node. 
|\n", "| `get_config()` / `from_config()` | Config serialization (structure only). |\n", "| `get_state()` / `set_state()` | State serialization (includes weights). |\n", "| `save(filepath)` / `load(filepath)` | Full serialization to/from disk. |\n", "| `save_checkpoint(filepath, overwrite=False, meta=None)` | Save a training checkpoint. |\n", "| `restore_checkpoint(filepath)` | Restore state from a checkpoint into the existing graph. |\n", "\n", "### Training Modes\n", "\n", "| Mode | When | Behavior |\n", "|------|------|----------|\n", "| **Graph-wise** | Global `Optimizer` provided | Single forward and backward pass across all nodes, so gradients flow end-to-end. |\n", "| **Stage-wise** | No global optimizer (`None`) | Each node trains independently with its own optimizer. |" ] }, { "cell_type": "markdown", "id": "82", "metadata": {}, "source": [ "### Next Steps\n", "\n", "- **Experiment:** The primary user-facing entry point. Use `Experiment` to combine a `ModelGraph`\n", " with training phases, loss functions, and evaluation.\n", "\n", "- **ModelNode:** See how individual nodes wrap models and handle forward passes.\n", "\n", "- **MergeNode:** Learn how to combine parallel branches with `ConcatNode`.\n" ] } ], "metadata": { "kernelspec": { "display_name": ".venv (3.10.18)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 5 }