{ "cells": [ { "cell_type": "markdown", "id": "0", "metadata": {}, "source": [ "# How to: Create and Use an Experiment\n", "\n", "An `Experiment` is the top-level orchestrator in ModularML. It coordinates:\n", "\n", "- **Phases** — units of work such as training (`TrainPhase`), evaluation (`EvalPhase`), or batch fitting (`FitPhase`)\n", "- **Phase Groups** — named collections of phases that execute in order\n", "- **Callbacks** — hooks at phase, group, and experiment boundaries\n", "- **Checkpointing** — automatic saving and restoring of experiment state\n", "- **Execution History** — records of every run for reproducibility\n", "\n", "> **Note:** This notebook covers the `Experiment` API and how phases are registered,\n", "> organized, and executed. Phase-specific details (configuration, advanced usage) are\n", "> covered in dedicated notebooks:\n", "> $\\textcolor{red}{\\text{...to be added soon}}$\n", "\n", "This notebook covers:\n", "\n", "- {ref}`05-create-experiment-creating-an-experiment`\n", "- {ref}`05-create-experiment-setting-up-a-model-graph`\n", "- {ref}`05-create-experiment-defining-phases`\n", "- {ref}`05-create-experiment-the-execution-plan`\n", "- {ref}`05-create-experiment-running-phases`\n", "- {ref}`05-create-experiment-running-the-full-execution-plan`\n", "- {ref}`05-create-experiment-preview-mode`\n", "- {ref}`05-create-experiment-execution-history`\n", "- {ref}`05-create-experiment-phase-groups`\n", "- {ref}`05-create-experiment-experiment-callbacks`\n", "- {ref}`05-create-experiment-checkpointing`\n", "- {ref}`05-create-experiment-serialization`\n", "- {ref}`05-create-experiment-summary`" ] }, { "cell_type": "code", "execution_count": null, "id": "1", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "from modularml import (\n", " AppliedLoss,\n", " EvalPhase,\n", " Experiment,\n", " FeatureSet,\n", " InputBinding,\n", " Loss,\n", " ModelGraph,\n", " ModelNode,\n", " Optimizer,\n", " TrainPhase,\n", ")\n", "from 
modularml.core.experiment.phases.phase_group import PhaseGroup\n", "from modularml.samplers import SimpleSampler" ] }, { "cell_type": "markdown", "id": "2", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "3", "metadata": {}, "source": [ "(05-create-experiment-creating-an-experiment)=\n", "## Creating an Experiment" ] }, { "cell_type": "markdown", "id": "4", "metadata": {}, "source": [ "\n", "An `Experiment` is created with a label and an optional `registration_policy` that\n", "controls how duplicate node names are handled.\n", "\n", "```python\n", " Experiment(\n", " label: str,\n", " registration_policy: str | None = None,\n", " ctx: ExperimentContext | None = None,\n", " checkpointing: Checkpointing | None = None,\n", " callbacks: list[ExperimentCallback] | None = None,\n", " )\n", "```\n", "\n", "| Parameter | Type | Default | Description |\n", "|-----------|------|---------|-------------|\n", "| `label` | `str` | (required) | Name for this experiment. |\n", "| `registration_policy` | `str \\| None` | `None` | How to handle duplicate node labels: `\"raise\"`, `\"overwrite\"`, or `\"rename\"`. |\n", "| `ctx` | `ExperimentContext \\| None` | `None` | Context to associate with. If `None`, a new context is created. |\n", "| `checkpointing` | `Checkpointing \\| None` | `None` | Experiment-level checkpointing configuration. |\n", "| `callbacks` | `list[ExperimentCallback] \\| None` | `None` | Experiment-level callbacks for phase/group boundaries. 
|" ] }, { "cell_type": "code", "execution_count": null, "id": "5", "metadata": {}, "outputs": [], "source": [ "exp = Experiment(label=\"my_experiment\", registration_policy=\"overwrite\")\n", "print(f\"Experiment: {exp.label}\")\n", "print(f\"Context: {exp.ctx}\")" ] }, { "cell_type": "markdown", "id": "6", "metadata": {}, "source": [ "### Registration Policy\n", "\n", "The `registration_policy` determines what happens when two nodes share the same label.\n", "This is primarily useful in notebook environments where cells may be re-executed.\n", "\n", "| Policy | Behavior |\n", "|--------|----------|\n", "| `\"raise\"` | Raises an error on duplicate labels (default). |\n", "| `\"overwrite\"` | Silently replaces the existing node. |\n", "| `\"rename\"` | Assigns a unique suffix to the new node's label. |" ] }, { "cell_type": "markdown", "id": "7", "metadata": {}, "source": [ "### Creating from an Active Context\n", "\n", "If nodes have already been registered in the current `ExperimentContext`,\n", "you can bind a new `Experiment` to that existing context with `from_active_context()`.\n", "This retains all previously registered nodes.\n", "\n", "```python\n", " exp = Experiment.from_active_context(\n", " label=\"my_experiment\",\n", " registration_policy=\"overwrite\",\n", " )\n", "```" ] }, { "cell_type": "markdown", "id": "8", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "9", "metadata": {}, "source": [ "(05-create-experiment-setting-up-a-model-graph)=\n", "## Setting Up a Model Graph" ] }, { "cell_type": "markdown", "id": "10", "metadata": {}, "source": [ "\n", "Before defining phases, we need a `ModelGraph` with at least one `ModelNode` and a\n", "`FeatureSet` to supply data. The `Experiment` automatically tracks the `ModelGraph`\n", "registered in its context.\n", "\n", "For details on creating model graphs, see {doc}`03_create_modelgraph`." 
] }, { "cell_type": "code", "execution_count": null, "id": "11", "metadata": {}, "outputs": [], "source": [ "# Create synthetic data\n", "rng = np.random.default_rng(42)\n", "\n", "fs = FeatureSet.from_dict(\n", " label=\"SensorData\",\n", " data={\n", " \"voltage\": list(rng.standard_normal((500, 10))),\n", " \"soh\": list(rng.standard_normal((500, 1))),\n", " },\n", " feature_keys=\"voltage\",\n", " target_keys=\"soh\",\n", ")\n", "\n", "# Create a train/test split\n", "fs.split_random(\n", " ratios={\n", " \"train\": 0.8,\n", " \"test\": 0.2,\n", " },\n", " seed=13,\n", ")\n", "print(fs)\n", "print(f\"Splits: {fs.available_splits}\")\n", "fs.visualize()" ] }, { "cell_type": "code", "execution_count": null, "id": "12", "metadata": {}, "outputs": [], "source": [ "from modularml.models.torch import SequentialMLP\n", "\n", "# Reference defining which columns feed into the model\n", "fs_ref = fs.reference(features=\"voltage\", targets=\"soh\")\n", "\n", "# Create model node\n", "node = ModelNode(\n", " label=\"MLP\",\n", " model=SequentialMLP(output_shape=(1, 1), n_layers=2, hidden_dim=32),\n", " upstream_ref=fs_ref,\n", ")\n", "\n", "# Create model graph with a global optimizer\n", "graph = ModelGraph(\n", " label=\"SimpleGraph\",\n", " nodes=[node],\n", " optimizer=Optimizer(\"adam\", opt_kwargs={\"lr\": 1e-3}, backend=\"torch\"),\n", ")\n", "\n", "# Build the graph (infers shapes)\n", "graph.build()\n", "graph.visualize()\n", "\n", "print(f\"Experiment model_graph: {exp.model_graph}\")" ] }, { "cell_type": "markdown", "id": "13", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "14", "metadata": {}, "source": [ "(05-create-experiment-defining-phases)=\n", "## Defining Phases" ] }, { "cell_type": "markdown", "id": "15", "metadata": {}, "source": [ "\n", "Phases are the executable units of an `Experiment`. 
Each phase type handles a\n", "different style of model execution:\n", "\n", "| Phase | Purpose | Key Concept |\n", "|-------|---------|-------------|\n", "| `TrainPhase` | Mini-batch gradient training | Requires a `Sampler` and `Loss` |\n", "| `EvalPhase` | Forward-only evaluation | No sampler; runs on full split |\n", "| `FitPhase` | Batch fitting (e.g., scikit-learn) | Entire dataset passed at once |\n", "\n", "All phases require **input bindings** that connect `FeatureSet` data to head\n", "`GraphNode`s in the model graph." ] }, { "cell_type": "markdown", "id": "16", "metadata": {}, "source": [ "### Input Bindings\n", "\n", "An `InputBinding` defines how data flows from a `FeatureSet` into a head `GraphNode`\n", "during a specific phase. There are two constructors:\n", "\n", "- **`InputBinding.for_training(...)`** — requires a `Sampler` to generate batches\n", "- **`InputBinding.for_evaluation(...)`** — passes data directly (no sampler)\n", "\n", "| Parameter | `for_training` | `for_evaluation` |\n", "|-----------|:-:|:-:|\n", "| `node` | required | required |\n", "| `sampler` | required | — |\n", "| `upstream` | required\\* | required\\* |\n", "| `split` | optional | optional |\n", "\n", "\\* Can be `None` if the node has exactly one upstream `FeatureSet`." 
] }, { "cell_type": "code", "execution_count": null, "id": "17", "metadata": {}, "outputs": [], "source": [ "# Training binding: requires a sampler\n", "train_binding = InputBinding.for_training(\n", " node=node,\n", " sampler=SimpleSampler(batch_size=32, shuffle=True, seed=42),\n", " upstream=None, # auto-resolved (node has one upstream FeatureSet)\n", " split=\"train\",\n", ")\n", "print(f\"Train binding node: {train_binding.node_id[:8]}...\")\n", "print(f\"Train binding split: {train_binding.split}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "18", "metadata": {}, "outputs": [], "source": [ "# Evaluation binding: no sampler needed\n", "eval_binding = InputBinding.for_evaluation(\n", " node=node,\n", " upstream=None,\n", " split=\"test\",\n", ")\n", "print(f\"Eval binding split: {eval_binding.split}\")" ] }, { "cell_type": "markdown", "id": "19", "metadata": {}, "source": [ "### Defining a Loss\n", "\n", "Training phases require at least one `AppliedLoss`, which binds a `Loss` function to\n", "a specific `ModelNode` and specifies what inputs the loss receives.\n", "\n", "```python\n", " AppliedLoss(\n", " loss: Loss,\n", " on: str | ModelNode,\n", " inputs: list[str] | dict[str, str],\n", " weight: float = 1.0,\n", " label: str | None = None,\n", " )\n", "```\n", "\n", "The `inputs` argument uses string references to resolve data at runtime:\n", "- `\"outputs\"` — the model node's predictions\n", "- `\"targets\"` — the target data passed through the model node" ] }, { "cell_type": "code", "execution_count": null, "id": "20", "metadata": {}, "outputs": [], "source": [ "mse_loss = AppliedLoss(\n", " loss=Loss(\"mse\", backend=\"torch\"),\n", " on=node,\n", " inputs=[\"outputs\", \"targets\"],\n", ")\n", "print(f\"Loss: {mse_loss.label}\")\n", "print(f\"Applied on: {mse_loss.node_id[:8]}...\")" ] }, { "cell_type": "markdown", "id": "21", "metadata": {}, "source": [ "### Creating a TrainPhase\n", "\n", "A `TrainPhase` performs mini-batch gradient 
training over one or more epochs.\n", "\n", "There are two ways to create a `TrainPhase`:\n", "\n", "1. **Default constructor** — provide `InputBinding`s explicitly\n", "2. **`from_split()` convenience** — auto-generates bindings from a split name" ] }, { "cell_type": "code", "execution_count": null, "id": "22", "metadata": {}, "outputs": [], "source": [ "# Option A: Using explicit InputBindings\n", "train_phase = TrainPhase(\n", " label=\"train\",\n", " input_sources=[train_binding],\n", " losses=[mse_loss],\n", " n_epochs=3,\n", ")\n", "print(f\"TrainPhase: {train_phase.label}\")\n", "print(f\" n_epochs: {train_phase.n_epochs}\")\n", "print(f\" losses: {[ls.label for ls in train_phase.losses]}\")\n", "\n", "train_phase.visualize()" ] }, { "cell_type": "code", "execution_count": null, "id": "23", "metadata": {}, "outputs": [], "source": [ "# Option B: Using the from_split() convenience constructor\n", "# This auto-generates InputBindings for all active head nodes\n", "train_phase_b = TrainPhase.from_split(\n", " label=\"train_from_split\",\n", " split=\"train\",\n", " sampler=SimpleSampler(batch_size=32, shuffle=True, seed=42),\n", " losses=[mse_loss],\n", " n_epochs=3,\n", ")\n", "print(f\"TrainPhase (from_split): {train_phase_b.label}\")\n", "\n", "train_phase_b.visualize()" ] }, { "cell_type": "markdown", "id": "24", "metadata": {}, "source": [ "### Creating an EvalPhase\n", "\n", "An `EvalPhase` runs a forward pass over a FeatureSet split without any gradient\n", "computation. All graph nodes are automatically frozen during evaluation."
] }, { "cell_type": "code", "execution_count": null, "id": "25", "metadata": {}, "outputs": [], "source": [ "# Using the from_split() convenience constructor\n", "eval_phase = EvalPhase.from_split(\n", " label=\"eval\",\n", " split=\"test\",\n", " losses=[mse_loss],\n", ")\n", "print(f\"EvalPhase: {eval_phase.label}\")\n", "\n", "eval_phase.visualize()" ] }, { "cell_type": "markdown", "id": "26", "metadata": {}, "source": [ "### Creating a FitPhase\n", "\n", "A `FitPhase` fits batch-fit models (like scikit-learn estimators) on the entire\n", "dataset at once. It has no epochs or sampling. By default, fitted nodes are frozen\n", "after fitting.\n", "\n", "```python\n", " fit_phase = FitPhase.from_split(\n", " label=\"fit_rf\",\n", " split=\"train\",\n", " freeze_after_fit=True, # default\n", " )\n", "```\n", "\n", "> **Note:** FitPhase is only relevant when your `ModelGraph` contains scikit-learn\n", "> (batch-fit) model nodes. We will not use it in the running examples below since\n", "> our graph uses PyTorch models." ] }, { "cell_type": "markdown", "id": "27", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "28", "metadata": {}, "source": [ "(05-create-experiment-the-execution-plan)=\n", "## The Execution Plan\n", "\n", "Every `Experiment` has an `execution_plan` property — a `PhaseGroup` that defines the\n", "order in which phases execute when you call `experiment.run()`.\n", "\n", "Phases are added with `add_phase()` and execute in the order they are registered." 
] }, { "cell_type": "code", "execution_count": null, "id": "29", "metadata": {}, "outputs": [], "source": [ "# Access the execution plan\n", "plan = exp.execution_plan\n", "print(f\"Execution plan: {plan}\")\n", "print(f\"Currently empty: {len(plan.all) == 0}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "30", "metadata": {}, "outputs": [], "source": [ "# Register phases in execution order\n", "plan.add_phase(train_phase)\n", "plan.add_phase(eval_phase)\n", "\n", "print(f\"Plan entries: {len(plan.all)}\")\n", "for i, entry in enumerate(plan.all):\n", " print(f\" [{i}] {entry.label} ({type(entry).__name__})\")" ] }, { "cell_type": "markdown", "id": "31", "metadata": {}, "source": [ "### Accessing Phases\n", "\n", "Phases can be accessed by position (index) or by label." ] }, { "cell_type": "code", "execution_count": null, "id": "32", "metadata": {}, "outputs": [], "source": [ "# By index\n", "first_phase = plan[0]\n", "print(f\"By index: {first_phase.label}\")\n", "\n", "# By label\n", "train_ref = plan[\"train\"]\n", "print(f\"By label: {train_ref.label}\")\n", "\n", "# Type-safe accessors\n", "tp = plan.get_train_phase(\"train\")\n", "ep = plan.get_eval_phase(\"eval\")\n", "print(f\"TrainPhase: {tp.label}, EvalPhase: {ep.label}\")" ] }, { "cell_type": "markdown", "id": "33", "metadata": {}, "source": [ "### Removing Phases\n", "\n", "Phases can be removed by index, label, or instance." 
] }, { "cell_type": "code", "execution_count": null, "id": "34", "metadata": {}, "outputs": [], "source": [ "# Remove by label\n", "plan.remove_phase(\"eval\")\n", "print(f\"After remove: {[e.label for e in plan.all]}\")\n", "\n", "# Re-add for later examples\n", "plan.add_phase(eval_phase)\n", "print(f\"After re-add: {[e.label for e in plan.all]}\")" ] }, { "cell_type": "markdown", "id": "35", "metadata": {}, "source": [ "### Convenience Methods\n", "\n", "The execution plan also provides convenience methods to construct and register\n", "phases in a single call:\n", "\n", "```python\n", " plan.add_train_phase(\n", " label=\"train\",\n", " input_sources=[...],\n", " losses=[...],\n", " n_epochs=5,\n", " )\n", "\n", " plan.add_eval_phase(\n", " label=\"eval\",\n", " input_sources=[...],\n", " losses=[...],\n", " )\n", "```\n", "\n", "Aliases `add_train()`, `add_training()`, `add_eval()`, and `add_evaluation()` are also available." ] }, { "cell_type": "markdown", "id": "36", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "37", "metadata": {}, "source": [ "(05-create-experiment-running-phases)=\n", "## Running Phases" ] }, { "cell_type": "markdown", "id": "38", "metadata": {}, "source": [ "Phases can be run individually with `run_phase()`, regardless of whether they\n", "are registered on the execution plan. Each run mutates experiment state and\n", "records an entry in `history`." 
] }, { "cell_type": "code", "execution_count": null, "id": "39", "metadata": {}, "outputs": [], "source": [ "# Run the training phase\n", "train_results = exp.run_phase(train_phase)\n", "print(\"Training completed.\")\n", "print(f\" History entries: {len(exp.history)}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "40", "metadata": {}, "outputs": [], "source": [ "# Run the evaluation phase\n", "eval_results = exp.run_phase(eval_phase)\n", "print(\"Evaluation completed.\")\n", "print(f\" History entries: {len(exp.history)}\")" ] }, { "cell_type": "markdown", "id": "41", "metadata": {}, "source": [ "### Display Options\n", "\n", "Each phase type accepts display-related keyword arguments to control progress bars:\n", "\n", "**TrainPhase:**\n", "\n", "| Parameter | Default | Description |\n", "|-----------|---------|-------------|\n", "| `show_sampler_progress` | `True` | Show progress for batch creation |\n", "| `show_training_progress` | `True` | Show epoch-level progress bar |\n", "| `persist_progress` | `IN_NOTEBOOK` | Keep progress bars visible after completion |\n", "| `persist_epoch_progress` | `IN_NOTEBOOK` | Keep per-epoch bars visible |\n", "\n", "**EvalPhase:**\n", "\n", "| Parameter | Default | Description |\n", "|-----------|---------|-------------|\n", "| `show_eval_progress` | `False` | Show evaluation progress bar |\n", "| `persist_progress` | `IN_NOTEBOOK` | Keep progress bars visible after completion |" ] }, { "cell_type": "markdown", "id": "42", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "43", "metadata": {}, "source": [ "(05-create-experiment-running-the-full-execution-plan)=\n", "## Running the Full Execution Plan" ] }, { "cell_type": "markdown", "id": "44", "metadata": {}, "source": [ "Calling `experiment.run()` executes all phases registered on the execution plan,\n", "in the order they were added. This is the primary entry point for running a\n", "complete experiment." 
] }, { "cell_type": "code", "execution_count": null, "id": "45", "metadata": {}, "outputs": [], "source": [ "# Run the full execution plan (train -> eval)\n", "results = exp.run()\n", "print(\"Full run completed.\")\n", "print(f\" History entries: {len(exp.history)}\")" ] }, { "cell_type": "markdown", "id": "46", "metadata": {}, "source": [ "`run()` returns a `PhaseGroupResults` object that contains results from all\n", "executed phases. Individual phase results can be accessed by label." ] }, { "cell_type": "code", "execution_count": null, "id": "47", "metadata": {}, "outputs": [], "source": [ "# Inspect results\n", "print(f\"Result type: {type(results).__name__}\")\n", "print(f\"Contained results: {results.flatten()}\")" ] }, { "cell_type": "markdown", "id": "48", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "49", "metadata": {}, "source": [ "(05-create-experiment-preview-mode)=\n", "## Preview Mode" ] }, { "cell_type": "markdown", "id": "50", "metadata": {}, "source": [ "Sometimes you want to evaluate a phase without permanently changing experiment\n", "state. The `preview_phase()` and `preview_group()` methods do exactly this:\n", "\n", "1. Capture the current experiment state\n", "2. Execute the phase/group\n", "3. Restore the original state\n", "\n", "Preview runs are **not** recorded in `history`, and checkpointing is disabled." 
] }, { "cell_type": "code", "execution_count": null, "id": "51", "metadata": {}, "outputs": [], "source": [ "history_before = len(exp.history)\n", "\n", "# Preview does not mutate state\n", "preview_res = exp.preview_phase(eval_phase)\n", "\n", "history_after = len(exp.history)\n", "print(f\"History before: {history_before}\")\n", "print(f\"History after: {history_after}\")\n", "print(f\"State was restored: {history_before == history_after}\")" ] }, { "cell_type": "markdown", "id": "52", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "53", "metadata": {}, "source": [ "(05-create-experiment-execution-history)=\n", "## Execution History" ] }, { "cell_type": "markdown", "id": "54", "metadata": {}, "source": [ "Every call to `run_phase()`, `run_group()`, or `run()` records an `ExperimentRun`\n", "in `experiment.history`. Each run captures:\n", "\n", "- Label, start/end timestamps, and status\n", "- Phase results (losses, outputs, etc.)\n", "- Execution metadata (timing per phase)" ] }, { "cell_type": "code", "execution_count": null, "id": "55", "metadata": {}, "outputs": [], "source": [ "for i, run in enumerate(exp.history):\n", " print(\n", " f\" Run {i}: label={run.label!r}, \"\n", " f\"status={run.status}, \"\n", " f\"duration={run.ended_at - run.started_at}\",\n", " )" ] }, { "cell_type": "code", "execution_count": null, "id": "56", "metadata": {}, "outputs": [], "source": [ "# Access the most recent run\n", "last = exp.last_run\n", "print(f\"Last run: {last.label}\")\n", "print(f\" Status: {last.status}\")\n", "print(f\" Results: {type(last.results).__name__}\")" ] }, { "cell_type": "markdown", "id": "57", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "58", "metadata": {}, "source": [ "(05-create-experiment-phase-groups)=\n", "## Phase Groups" ] }, { "cell_type": "markdown", "id": "59", "metadata": {}, "source": [ "A `PhaseGroup` is a named collection that organizes phases into logical blocks.\n", "Phase groups 
can be nested (a group can contain other groups), enabling\n", "hierarchical experiment structures.\n", "\n", "The experiment's `execution_plan` is itself a `PhaseGroup`." ] }, { "cell_type": "code", "execution_count": null, "id": "60", "metadata": {}, "outputs": [], "source": [ "# Create a sub-group for a train-eval cycle\n", "cycle = PhaseGroup(label=\"train_eval_cycle\")\n", "\n", "cycle.add_phase(\n", " TrainPhase.from_split(\n", " label=\"cycle_train\",\n", " split=\"train\",\n", " sampler=SimpleSampler(batch_size=32, shuffle=True, seed=42),\n", " losses=[mse_loss],\n", " n_epochs=2,\n", " ),\n", ")\n", "cycle.add_phase(\n", " EvalPhase.from_split(\n", " label=\"cycle_eval\",\n", " split=\"test\",\n", " losses=[mse_loss],\n", " ),\n", ")\n", "\n", "print(f\"Group: {cycle}\")\n", "print(f\"Entries: {[e.label for e in cycle.all]}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "61", "metadata": {}, "outputs": [], "source": [ "# Run the group directly\n", "group_results = exp.run_group(cycle)\n", "print(f\"Group results: {group_results.flatten()}\")" ] }, { "cell_type": "markdown", "id": "62", "metadata": {}, "source": [ "### Nesting Groups\n", "\n", "Groups can be nested within the execution plan or within other groups.\n", "Use `add_group()` to nest a `PhaseGroup` inside another." 
] }, { "cell_type": "code", "execution_count": null, "id": "63", "metadata": {}, "outputs": [], "source": [ "# Build a nested plan\n", "outer = PhaseGroup(label=\"outer\")\n", "\n", "inner = PhaseGroup(label=\"inner\")\n", "inner.add_phase(\n", " TrainPhase.from_split(\n", " label=\"inner_train\",\n", " split=\"train\",\n", " sampler=SimpleSampler(batch_size=64, shuffle=True, seed=0),\n", " losses=[mse_loss],\n", " n_epochs=1,\n", " ),\n", ")\n", "\n", "outer.add_group(inner)\n", "outer.add_phase(\n", " EvalPhase.from_split(\n", " label=\"outer_eval\",\n", " split=\"test\",\n", " losses=[mse_loss],\n", " ),\n", ")\n", "\n", "# flatten() unrolls all nested groups into execution order\n", "print(f\"Flattened: {[p.label for p in outer.flatten()]}\")" ] }, { "cell_type": "markdown", "id": "64", "metadata": {}, "source": [ "### PhaseGroup API\n", "\n", "| Method | Description |\n", "|--------|-------------|\n", "| `add_phase(phase)` | Register a phase. |\n", "| `add_group(group)` | Register a nested group. |\n", "| `add_train_phase(...)` | Construct and register a `TrainPhase`. |\n", "| `add_eval_phase(...)` | Construct and register an `EvalPhase`. |\n", "| `remove_phase(key)` | Remove a phase by index, label, or instance. |\n", "| `remove_group(key)` | Remove a group by index, label, or instance. |\n", "| `clear()` | Remove all entries. |\n", "| `flatten()` | Unroll all nested groups into a flat list of phases. |\n", "| `get_phase(key)` | Get a phase by index or label. |\n", "| `get_train_phase(key)` | Get a `TrainPhase` by index or label. |\n", "| `get_eval_phase(key)` | Get an `EvalPhase` by index or label. |\n", "| `get_group(key)` | Get a nested `PhaseGroup` by index or label. |\n", "| `items()` | Iterate over `(label, entry)` pairs. 
|" ] }, { "cell_type": "markdown", "id": "65", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "66", "metadata": {}, "source": [ "(05-create-experiment-experiment-callbacks)=\n", "## Experiment Callbacks" ] }, { "cell_type": "markdown", "id": "67", "metadata": {}, "source": [ "Experiment-level callbacks (`ExperimentCallback`) fire at phase and group\n", "boundaries during `run()`. They are distinct from phase-level `Callback`s that\n", "fire at batch/epoch boundaries within a single phase.\n", "\n", "| Hook | Trigger |\n", "|------|---------|\n", "| `on_experiment_start(experiment)` | Before the execution plan begins |\n", "| `on_experiment_end(experiment)` | After the execution plan completes |\n", "| `on_phase_start(experiment, phase)` | Before each phase executes |\n", "| `on_phase_end(experiment, phase)` | After each phase completes |\n", "| `on_group_start(experiment, group)` | Before each group executes |\n", "| `on_group_end(experiment, group)` | After each group completes |\n", "| `on_exception(experiment, phase, exception)` | On unhandled exception |\n", "\n", "Callbacks are registered via the constructor or `add_callback()`:\n", "\n", "```python\n", " exp = Experiment(\n", " label=\"my_exp\",\n", " callbacks=[my_callback],\n", " )\n", "\n", " # Or add later\n", " exp.add_callback(another_callback)\n", "```" ] }, { "cell_type": "markdown", "id": "68", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "69", "metadata": {}, "source": [ "(05-create-experiment-checkpointing)=\n", "## Checkpointing" ] }, { "cell_type": "markdown", "id": "70", "metadata": {}, "source": [ "Experiment-level checkpointing automatically saves the full experiment state to\n", "disk at configurable lifecycle hooks. This is useful for fault tolerance and\n", "resumption.\n", "\n", "Experiment checkpointing only supports `mode=\"disk\"` (in-memory snapshots of the\n", "full experiment state would be too large)." 
] }, { "cell_type": "markdown", "id": "71", "metadata": {}, "source": [ "### Configuring Checkpointing\n", "\n", "Checkpointing is configured via the `Checkpointing` class and passed at\n", "construction time or via `set_checkpointing()`.\n", "\n", "Valid `save_on` hooks for experiment-level checkpointing:\n", "\n", "| Hook | When |\n", "|------|------|\n", "| `\"phase_start\"` | Before each phase |\n", "| `\"phase_end\"` | After each phase |\n", "| `\"group_start\"` | Before each group |\n", "| `\"group_end\"` | After each group |\n", "| `\"experiment_start\"` | Before `run()` begins |\n", "| `\"experiment_end\"` | After `run()` completes |\n", "\n", "```python\n", " from modularml import Checkpointing\n", "\n", " exp = Experiment(\n", " label=\"checkpointed_exp\",\n", " checkpointing=Checkpointing(\n", " mode=\"disk\",\n", " save_on=[\"phase_end\"],\n", " directory=\"./checkpoints\",\n", " ),\n", " )\n", "```" ] }, { "cell_type": "markdown", "id": "72", "metadata": {}, "source": [ "### Manual Checkpointing\n", "\n", "You can also save and restore checkpoints manually." 
] }, { "cell_type": "code", "execution_count": null, "id": "73", "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "from tempfile import TemporaryDirectory\n", "\n", "CKPT_DIR = TemporaryDirectory()\n", "\n", "# Set the checkpoint directory\n", "exp.set_checkpoint_dir(Path(CKPT_DIR.name))\n", "\n", "# Save a checkpoint\n", "ckpt_path = exp.save_checkpoint(\"after_training\", overwrite=True)\n", "print(f\"Checkpoint saved to: {ckpt_path}\")\n", "print(f\"Available checkpoints: {list(exp.available_checkpoints.keys())}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "74", "metadata": {}, "outputs": [], "source": [ "# Restore from a checkpoint (by name or path)\n", "exp.restore_checkpoint(\"after_training\")\n", "print(\"Checkpoint restored.\")" ] }, { "cell_type": "markdown", "id": "75", "metadata": {}, "source": [ "### Disabling Checkpointing\n", "\n", "Use the `disable_checkpointing()` context manager to temporarily suppress all\n", "checkpointing (both experiment-level and TrainPhase-level).\n", "\n", "```python\n", " with exp.disable_checkpointing():\n", " exp.run_phase(train_phase) # No checkpoints saved\n", "```" ] }, { "cell_type": "markdown", "id": "76", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "77", "metadata": {}, "source": [ "(05-create-experiment-serialization)=\n", "## Serialization" ] }, { "cell_type": "markdown", "id": "78", "metadata": {}, "source": [ "An `Experiment` can be fully serialized to disk via `save()` and reloaded with `load()`.\n", "This includes the model graph state, execution plan, and execution history." 
] }, { "cell_type": "code", "execution_count": null, "id": "79", "metadata": {}, "outputs": [], "source": [ "SAVE_DIR = TemporaryDirectory()\n", "\n", "# Save the experiment\n", "save_path = exp.save(Path(SAVE_DIR.name) / \"my_experiment\", overwrite=True)\n", "print(f\"Experiment saved to: {save_path}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "80", "metadata": {}, "outputs": [], "source": [ "# Load the experiment\n", "loaded_exp = Experiment.load(save_path, overwrite=True)\n", "print(f\"Loaded experiment: {loaded_exp.label}\")\n", "print(f\" Model graph: {loaded_exp.model_graph}\")" ] }, { "cell_type": "markdown", "id": "81", "metadata": {}, "source": [ "The `get_config()` and `get_state()` methods provide lower-level access to the\n", "experiment's structure and mutable state for custom serialization workflows.\n", "\n", "```python\n", " config = exp.get_config() # Structure (label, plan, policy)\n", " state = exp.get_state() # Mutable state (context, history, checkpoints)\n", "\n", " # Restore\n", " exp.set_state(state)\n", "```" ] }, { "cell_type": "markdown", "id": "82", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "id": "83", "metadata": {}, "source": [ "(05-create-experiment-summary)=\n", "## Summary" ] }, { "cell_type": "markdown", "id": "84", "metadata": {}, "source": [ "### Experiment Constructor\n", "\n", "| Parameter | Type | Default | Description |\n", "|-----------|------|---------|-------------|\n", "| `label` | `str` | (required) | Name for this experiment. |\n", "| `registration_policy` | `str \\| None` | `None` | `\"raise\"`, `\"overwrite\"`, or `\"rename\"`. |\n", "| `ctx` | `ExperimentContext \\| None` | `None` | Context to bind to. |\n", "| `checkpointing` | `Checkpointing \\| None` | `None` | Auto-checkpoint configuration. |\n", "| `callbacks` | `list[ExperimentCallback] \\| None` | `None` | Experiment-level callbacks. 
|\n", "\n", "### Experiment Properties\n", "\n", "| Property | Type | Description |\n", "|----------|------|-------------|\n", "| `ctx` | `ExperimentContext` | The associated context. |\n", "| `model_graph` | `ModelGraph \\| None` | The registered model graph. |\n", "| `execution_plan` | `PhaseGroup` | Phases to run on `run()`. |\n", "| `history` | `list[ExperimentRun]` | All completed runs. |\n", "| `last_run` | `ExperimentRun \\| None` | Most recent run. |\n", "| `checkpointing` | `Checkpointing \\| None` | Checkpoint configuration. |\n", "| `available_checkpoints` | `dict[str, Path]` | Saved checkpoint registry. |\n", "| `exp_callbacks` | `list[ExperimentCallback]` | Registered callbacks. |\n", "\n", "### Experiment Methods\n", "\n", "| Method | Description |\n", "|--------|-------------|\n", "| `run()` | Execute the full execution plan. |\n", "| `run_phase(phase)` | Execute a single phase (records history). |\n", "| `run_group(group)` | Execute a phase group (records history). |\n", "| `preview_phase(phase)` | Execute a phase without mutating state. |\n", "| `preview_group(group)` | Execute a group without mutating state. |\n", "| `add_callback(cb)` | Register an experiment-level callback. |\n", "| `set_checkpointing(ckpt)` | Attach/replace checkpointing configuration. |\n", "| `set_checkpoint_dir(path)` | Set the checkpoint save directory. |\n", "| `save_checkpoint(name)` | Manually save a checkpoint. |\n", "| `restore_checkpoint(name)` | Restore from a saved checkpoint. |\n", "| `disable_checkpointing()` | Context manager to suppress checkpointing. |\n", "| `save(filepath)` | Serialize experiment to disk. |\n", "| `load(filepath)` | Load experiment from disk. |\n", "| `get_config()` / `from_config()` | Config serialization. |\n", "| `get_state()` / `set_state()` | State serialization. 
|\n", "\n", "### Phase Types\n", "\n", "| Phase | Module | Use Case |\n", "|-------|--------|----------|\n", "| `TrainPhase` | `modularml` | Mini-batch gradient training with epochs and sampling. |\n", "| `EvalPhase` | `modularml` | Forward-only evaluation on a data split. |\n", "| `FitPhase` | `modularml` | Batch fitting for scikit-learn models. |" ] }, { "cell_type": "markdown", "id": "85", "metadata": {}, "source": [ "### Next Steps\n", "\n", "- **TrainPhase:** Detailed training configuration, batch scheduling, and\n", " TrainPhase-level checkpointing — see $\\textcolor{red}{\\text{...to be added soon}}$\n", "\n", "- **EvalPhase:** Evaluation strategies, batched evaluation, and metrics —\n", " see $\\textcolor{red}{\\text{...to be added soon}}$\n", "\n", "- **FitPhase:** Batch-fit workflows for scikit-learn models —\n", " see $\\textcolor{red}{\\text{...to be added soon}}$\n" ] } ], "metadata": { "kernelspec": { "display_name": ".venv (3.10.18)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 5 }