"""Core Experiment orchestration and execution utilities."""
from __future__ import annotations
from contextlib import contextmanager
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any, overload
from modularml.core.experiment.callbacks.experiment_callback import (
ExperimentCallback,
)
from modularml.core.experiment.checkpointing import (
EXPERIMENT_HOOKS,
EXPERIMENT_NAME_TEMPLATE,
EXPERIMENT_PLACEHOLDERS,
Checkpointing,
)
from modularml.core.experiment.experiment_context import (
ExperimentContext,
RegistrationPolicy,
)
from modularml.core.experiment.phases.eval_phase import EvalPhase
from modularml.core.experiment.phases.fit_phase import FitPhase
from modularml.core.experiment.phases.phase import ExperimentPhase
from modularml.core.experiment.phases.phase_group import PhaseGroup
from modularml.core.experiment.phases.train_phase import ResultRecording, TrainPhase
from modularml.core.experiment.results.eval_results import EvalResults
from modularml.core.experiment.results.execution_meta import (
PhaseExecutionMeta,
PhaseGroupExecutionMeta,
)
from modularml.core.experiment.results.experiment_run import ExperimentRun
from modularml.core.experiment.results.fit_results import FitResults
from modularml.core.experiment.results.group_results import PhaseGroupResults
from modularml.core.experiment.results.train_results import TrainResults
from modularml.core.io.checkpoint import Checkpoint
from modularml.utils.environment.environment import IN_NOTEBOOK
from modularml.utils.logging.logger import get_logger
from modularml.utils.logging.warnings import warn
if TYPE_CHECKING:
from modularml.core.experiment.results.phase_results import PhaseResults
from modularml.core.topology.model_graph import ModelGraph
# Module-level logger, registered under the name "Experiment".
logger = get_logger(name="Experiment")
def _phase_mutates_state(phase_or_group: ExperimentPhase | PhaseGroup) -> bool:
    """
    Report whether running a phase or phase group may change model state.

    Used by the preview API to decide whether a state snapshot/restore is
    required around execution.

    Args:
        phase_or_group (ExperimentPhase | PhaseGroup):
            The phase, or (possibly nested) group of phases, to inspect.

    Returns:
        bool: False for an EvalPhase and for any group containing no
        state-mutating element (including an empty group); True otherwise
        (TrainPhase, FitPhase, unrecognized phase types, or a group that
        contains at least one of them).
    """
    if isinstance(phase_or_group, EvalPhase):
        # Evaluation phases are treated as read-only.
        return False
    if not isinstance(phase_or_group, PhaseGroup):
        # TrainPhase, FitPhase, or anything unknown: be conservative.
        return True
    # A group mutates state if any of its children (recursively) does.
    return any(map(_phase_mutates_state, phase_or_group.all))
[docs]
class Experiment:
"""High-level container coordinating phases, callbacks, and checkpoints."""
[docs]
def __init__(
    self,
    label: str,
    registration_policy: RegistrationPolicy | str | None = None,
    ctx: ExperimentContext | None = None,
    checkpointing: Checkpointing | None = None,
    callbacks: list[ExperimentCallback] | None = None,
):
    """
    Constructs a new Experiment.

    Args:
        label (str):
            A name to assign to this experiment.
        registration_policy (RegistrationPolicy | str, optional):
            Default registration policy for nodes created after this Experiment
            is constructed.
        ctx (ExperimentContext, optional):
            Context to associate with this Experiment. If None, a new context
            is created and activated.
        checkpointing (Checkpointing | None, optional):
            An optional Checkpointing configuration for automatically saving
            the full experiment state to disk at execution lifecycle hooks
            (e.g. `phase_end`, `group_end`). Must use `mode="disk"`
            and `save_on` hooks from: `phase_start`, `phase_end`,
            `group_start`, `group_end`. Defaults to None.
        callbacks (list[ExperimentCallback] | None, optional):
            An optional list of experiment-level callbacks to run during
            `Experiment.run()` execution at phase/group boundaries. They are
            stored sorted by their `_exec_order`. Defaults to None.

    Raises:
        TypeError: If any entry of `callbacks` is not an ExperimentCallback.
        ValueError: If `checkpointing` uses a mode other than "disk" or
            `save_on` hooks not supported by an Experiment.
    """
    # Validate callbacks before touching the context, so a bad argument does
    # not leave a half-constructed Experiment bound to (or activated as) the
    # current ExperimentContext. Same check and message as `add_callback`.
    exp_callbacks: list[ExperimentCallback] = list(callbacks or [])
    for cb in exp_callbacks:
        if not isinstance(cb, ExperimentCallback):
            msg = f"Expected ExperimentCallback, got {type(cb)}."
            raise TypeError(msg)

    self.label = label

    # Initialize / attach context
    if ctx is None:
        ctx = ExperimentContext(
            experiment=self,
            registration_policy=registration_policy,
        )
        ExperimentContext._set_active(ctx)
    else:
        ctx.set_experiment(self)
        if registration_policy is not None:
            ctx.set_registration_policy(registration_policy)
    self._ctx = ctx

    # Root group of phases executed by `run()`
    self._exec_plan = PhaseGroup(label=self.label)

    # Completed ExperimentRun records, oldest first
    self._history: list[ExperimentRun] = []

    # Registered checkpoint name -> file path, and the directory they live in
    self._checkpoints: dict[str, Path] = {}
    self._checkpoint_dir: Path | None = None

    # Guard flags
    # True while executing inside a callback (nested hooks are skipped)
    self._in_callback: bool = False
    # True disables all checkpointing (experiment-level and TrainPhase-level)
    self._checkpointing_disabled: bool = False

    # Experiment-level checkpointing (validated by `set_checkpointing`)
    self._exp_checkpointing: Checkpointing | None = None
    self.set_checkpointing(checkpointing)

    # Experiment-level callbacks, kept in execution order
    self._exp_callbacks: list[ExperimentCallback] = exp_callbacks
    self._exp_callbacks.sort(key=lambda cb: cb._exec_order)
# ================================================
# Constructors
# ================================================
[docs]
@classmethod
def from_active_context(
    cls,
    label: str,
    registration_policy: RegistrationPolicy | str | None = None,
    checkpointing: Checkpointing | None = None,
    callbacks: list[ExperimentCallback] | None = None,
) -> Experiment:
    """
    Build an Experiment bound to the currently active ExperimentContext.

    Description:
        Nodes already registered in the active context are kept; the new
        Experiment simply attaches itself to that context instead of
        creating a fresh one.

    Args:
        label (str):
            Name for the new experiment.
        registration_policy (RegistrationPolicy | str | None, optional):
            Default registration policy for nodes created after
            construction.
        checkpointing (Checkpointing | None, optional):
            Optional disk-mode Checkpointing configuration for saving the
            full experiment state at lifecycle hooks (`phase_start`,
            `phase_end`, `group_start`, `group_end`). Defaults to None.
        callbacks (list[ExperimentCallback] | None, optional):
            Optional experiment-level callbacks run at phase/group
            boundaries during `Experiment.run()`. Defaults to None.

    Returns:
        Experiment: The new Experiment, attached to the active context.

    Raises:
        ValueError: If the active context already has an Experiment.
    """
    ctx = ExperimentContext.get_active()

    # A context may be owned by at most one Experiment.
    already_bound = ctx._experiment_ref is not None
    if already_bound:
        msg = "An Experiment has already been associated with the active context."
        raise ValueError(msg)

    return cls(
        label=label,
        ctx=ctx,
        registration_policy=registration_policy,
        callbacks=callbacks,
        checkpointing=checkpointing,
    )
# ================================================
# Properties
# ================================================
@property
def ctx(self) -> ExperimentContext:
    """The ExperimentContext this Experiment is bound to."""
    return self._ctx

@property
def model_graph(self) -> ModelGraph | None:
    """The ModelGraph held by the bound context, or None if it has none."""
    return self._ctx.model_graph

@property
def execution_plan(self) -> PhaseGroup:
    """Root PhaseGroup of phases and sub-groups executed by `run()`."""
    return self._exec_plan

@property
def history(self) -> list[ExperimentRun]:
    """Shallow copy of recorded ExperimentRun objects, oldest first."""
    return self._history.copy()

@property
def last_run(self) -> ExperimentRun | None:
    """The most recently recorded ExperimentRun, or None if none exist."""
    if not self._history:
        return None
    return self._history[-1]

@property
def checkpointing(self) -> Checkpointing | None:
    """Experiment-level Checkpointing configuration (None when disabled)."""
    return self._exp_checkpointing

@property
def available_checkpoints(self) -> dict[str, Path]:
    """Copy of the registered checkpoint name -> file path mapping."""
    return self._checkpoints.copy()

@property
def exp_callbacks(self) -> list[ExperimentCallback]:
    """Copy of the experiment-level callbacks, ordered by `_exec_order`."""
    return [*self._exp_callbacks]
# ================================================
# Experiment Callback Management
# ================================================
[docs]
def add_callback(self, callback: ExperimentCallback) -> None:
    """
    Attach an additional experiment-level callback.

    The callback list is re-sorted by `_exec_order` afterwards. Python's
    sort is stable, so callbacks with equal order keep insertion order.

    Args:
        callback (ExperimentCallback):
            The callback to register.

    Raises:
        TypeError: If `callback` is not an ExperimentCallback.
    """
    if isinstance(callback, ExperimentCallback):
        self._exp_callbacks.append(callback)
        self._exp_callbacks.sort(key=lambda registered: registered._exec_order)
        return

    msg = f"Expected ExperimentCallback, got {type(callback)}."
    raise TypeError(msg)
[docs]
@contextmanager
def disable_checkpointing(self):
    """
    Temporarily turn off checkpointing within a `with` block.

    Description:
        Sets the experiment's `_checkpointing_disabled` flag for the
        duration of the block. Experiment-level checkpoint hooks during
        phase/group execution are skipped while it is set. Per the flag's
        definition, it is also meant to cover TrainPhase-level
        checkpointing. On exit, including exit via an exception, the flag
        goes back to whatever it was before, so nested uses compose.

    Example:
        Scoped checkpoint disabling:

        >>> with experiment.disable_checkpointing():  # doctest: +SKIP
        ...     experiment.run_phase(training_phase)
    """
    saved_flag = self._checkpointing_disabled
    self._checkpointing_disabled = True
    try:
        yield
    finally:
        # Restore (not reset) so an outer disable stays in effect.
        self._checkpointing_disabled = saved_flag
# ================================================
# Checkpointing
# ================================================
[docs]
def set_checkpointing(self, checkpointing: Checkpointing | None) -> None:
    """
    Install, replace, or clear the experiment-level Checkpointing config.

    The config must use `mode="disk"`. Every `save_on` hook must belong to
    `EXPERIMENT_HOOKS`, and `name_template` may only use
    `EXPERIMENT_PLACEHOLDERS`. A missing `name_template` is filled in-place
    with `EXPERIMENT_NAME_TEMPLATE`. If the config names a directory and
    none has been set yet, it becomes the experiment checkpoint directory.

    Args:
        checkpointing (Checkpointing | None):
            The configuration to install, or None to disable
            experiment-level checkpointing.

    Raises:
        ValueError: If the mode is not "disk" or an unsupported hook is
            requested.
    """
    if checkpointing is None:
        self._exp_checkpointing = None
        return

    # Full experiment snapshots are only ever written to disk.
    if checkpointing.mode != "disk":
        msg = (
            "Experiment-level checkpointing only supports mode='disk'. "
            "In-memory checkpointing of the full experiment state is "
            "not supported due to memory overhead."
        )
        raise ValueError(msg)

    unsupported_hooks = set(checkpointing.save_on) - EXPERIMENT_HOOKS
    if unsupported_hooks:
        msg = (
            f"Invalid save_on hooks for Experiment: {sorted(unsupported_hooks)}. "
            f"Valid hooks: {sorted(EXPERIMENT_HOOKS)}."
        )
        raise ValueError(msg)

    if checkpointing.name_template is None:
        checkpointing.name_template = EXPERIMENT_NAME_TEMPLATE
    Checkpointing.validate_placeholders(
        checkpointing.name_template,
        EXPERIMENT_PLACEHOLDERS,
        context_name="Experiment",
    )

    self._exp_checkpointing = checkpointing

    # Adopt the configured directory eagerly, but never override one that
    # was already set.
    configured_dir = checkpointing.directory
    if configured_dir is not None and self._checkpoint_dir is None:
        self.set_checkpoint_dir(configured_dir, create=True)
[docs]
def set_checkpoint_dir(self, path: Path, *, create: bool = True):
    """
    Set directory used for storing experiment checkpoints.

    Args:
        path (Path):
            Directory path. Any value accepted by `Path()` (e.g. a str)
            is converted to a Path.
        create (bool, optional):
            Whether to create the directory (including missing parents) if
            it does not exist. Defaults to True.

    Raises:
        FileExistsError: If `path` is not an existing directory after the
            optional creation step. (`Path.mkdir` itself also raises
            FileExistsError when `path` exists as a regular file.)

    Warns:
        If the directory already contains `*.ckpt.mml` files.
    """
    path = Path(path)
    if create:
        path.mkdir(parents=True, exist_ok=True)
    if not path.is_dir():
        # Format the plain path: `{path!r}` inside quotes rendered as
        # "'PosixPath('...')'". Exception type kept for backward compatibility.
        msg = f"No directory exists at '{path}'."
        raise FileExistsError(msg)

    # Warn if directory already contains checkpoint files
    existing = list(path.glob("*.ckpt.mml"))
    if existing:
        warn(
            f"Checkpoint directory '{path}' already contains "
            f"{len(existing)} checkpoint file(s). Existing files will "
            f"only be overwritten if an exact name match occurs and "
            f"overwrite=True.",
            stacklevel=2,
        )
    self._checkpoint_dir = path
[docs]
def save_checkpoint(
    self,
    name: str,
    *,
    overwrite: bool = False,
    meta: dict[str, Any] | None = None,
) -> Path:
    """
    Serialize the whole experiment into a Checkpoint file.

    The experiment is added to a new :class:`Checkpoint` under the
    ``"experiment"`` key, optional meta entries are attached, and the
    checkpoint is saved with the BUILTIN serialization policy into the
    experiment checkpoint directory. The saved path is registered under
    `name` in :attr:`available_checkpoints`.

    Args:
        name (str):
            Checkpoint name; also used as the file name within the
            checkpoint directory.
        overwrite (bool, optional):
            Allow replacing a checkpoint already registered under `name`.
            Also forwarded to the serializer. Defaults to False.
        meta (dict[str, Any], optional):
            Extra key/value metadata to attach to the checkpoint.

    Returns:
        Path: Location of the written checkpoint file.

    Raises:
        RuntimeError: If no checkpoint directory has been configured.
        ValueError: If `name` is already registered and `overwrite` is False.
    """
    from modularml.core.io.serialization_policy import SerializationPolicy
    from modularml.core.io.serializer import serializer

    target_dir = self._checkpoint_dir
    if target_dir is None:
        msg = (
            "Checkpoint directory not set. Call `set_checkpoint_dir()` "
            "or set `directory` on the Checkpointing config."
        )
        raise RuntimeError(msg)

    if not overwrite and name in self._checkpoints:
        msg = f"Checkpoint '{name}' already exists."
        raise ValueError(msg)

    container = Checkpoint()
    container.add_entry(key="experiment", obj=self)
    extra_meta = meta if meta is not None else {}
    for meta_key, meta_value in extra_meta.items():
        container.add_meta(meta_key, meta_value)

    written = Path(
        serializer.save(
            container,
            target_dir / name,
            policy=SerializationPolicy.BUILTIN,
            overwrite=overwrite,
        ),
    )
    self._checkpoints[name] = written
    return written
[docs]
def restore_checkpoint(self, name_or_path: str | Path) -> None:
    """
    Restore state from a previously saved checkpoint.

    Description:
        Accepts either a checkpoint name (a key from
        :attr:`available_checkpoints`) or an explicit file path. A path with
        no suffix gets the Checkpoint file suffix appended. The checkpoint
        type is detected from its entries:

        - **Experiment checkpoint** (contains an ``"experiment"`` entry):
          the entry state's ``"mg_state"`` is applied to :attr:`model_graph`
          via ``set_state``. Only after that succeeds is :attr:`history`
          replaced with the entry state's ``"history"``, so a failed restore
          does not leave history out of sync with the model graph.
        - **ModelGraph checkpoint** (contains a ``"modelgraph"`` entry):
          delegated to ``ModelGraph.restore_checkpoint``.

    Args:
        name_or_path (str | Path):
            Either a checkpoint name registered in
            :attr:`available_checkpoints`, or the file path to a
            ``.ckpt.mml`` checkpoint file.

    Raises:
        ValueError: If ``name_or_path`` is not a registered name and
            does not point to an existing file.
        TypeError: If the loaded checkpoint contains neither an
            ``"experiment"`` nor a ``"modelgraph"`` entry.
        RuntimeError: If this Experiment has no ModelGraph to restore into.
    """
    from modularml.core.io.serializer import _enforce_file_suffix, serializer

    # Resolve a registered name first, otherwise treat the input as a path.
    lookup_key = str(name_or_path)
    if lookup_key in self._checkpoints:
        filepath = Path(self._checkpoints[lookup_key])
    else:
        filepath = Path(name_or_path)
    if filepath.suffix == "":
        filepath = _enforce_file_suffix(path=filepath, cls=Checkpoint)
    if not filepath.exists():
        msg = (
            f"No checkpoint named '{name_or_path}' exists and no file "
            f"found at '{filepath}'. "
            f"Available checkpoints: {list(self._checkpoints.keys())}."
        )
        raise ValueError(msg)

    ckpt: Checkpoint = serializer.load(filepath)
    entries = ckpt.entries

    if "experiment" not in entries and "modelgraph" not in entries:
        msg = (
            f"Checkpoint at '{filepath}' does not contain a recognized "
            f"entry. Expected 'experiment' or 'modelgraph' key, "
            f"found: {list(entries.keys())}."
        )
        raise TypeError(msg)

    graph = self.model_graph
    if graph is None:
        msg = (
            f"Cannot restore checkpoint at '{filepath}': this Experiment "
            "has no ModelGraph to restore into."
        )
        raise RuntimeError(msg)

    if "experiment" in entries:
        exp_state = entries["experiment"].entry_state
        # Read both parts before mutating anything, then restore the graph
        # before history, so any failure leaves history untouched.
        history = exp_state["history"]
        mg_state = exp_state["mg_state"]
        graph.set_state(mg_state)
        self._history = history
    else:
        graph.restore_checkpoint(filepath)
def _save_experiment_checkpoint(self, label: str) -> None:
    """
    Write a full-experiment checkpoint as dictated by the checkpointing config.

    The checkpoint name comes from the config's `format_name(label=...)`.
    If no checkpoint directory is set yet, the config's `directory` is
    adopted. After saving, the file path is recorded on the config under
    `label`.

    Args:
        label (str):
            Phase or group label (or a marker such as "START"/"END") used
            both for naming and as the record key.

    Raises:
        RuntimeError: If neither the experiment nor the config provides a
            checkpoint directory.
    """
    cfg = self._exp_checkpointing
    ckpt_name = cfg.format_name(label=label)

    if self._checkpoint_dir is None:
        if cfg.directory is None:
            msg = "Cannot save experiment checkpoint: no checkpoint directory set."
            raise RuntimeError(msg)
        self.set_checkpoint_dir(cfg.directory, create=True)

    written_path = self.save_checkpoint(name=ckpt_name, overwrite=cfg.overwrite)
    cfg.record_disk(key=label, path=written_path)
# ================================================
# Execution
# ================================================
# Private helpers
def _execute_training(
    self,
    phase: TrainPhase,
    *,
    show_sampler_progress: bool = True,
    show_training_progress: bool = True,
    persist_progress: bool = IN_NOTEBOOK,
    persist_epoch_progress: bool = IN_NOTEBOOK,
    val_loss_metric: str = "val_loss",
) -> TrainResults:
    """
    Executes a training phase on this experiment.

    Description:
        The provided TrainPhase will be executed regardless of whether it
        is registered to this Experiment (`execution_plan`). The phase's
        active nodes are unfrozen, then `ModelGraph.train_step` is called
        for every execution context yielded by `phase.iter_execution`.

        Which contexts end up in the returned results depends on
        `phase.result_recording`:

        - `ResultRecording.ALL`: every execution context is recorded.
        - `ResultRecording.LAST`: only the contexts of the most recent
          epoch are kept. If the phase has an `EarlyStopping` callback with
          `restore_best` set, and a snapshot of its `best_epoch` was
          captured, that epoch's contexts are kept instead.
        - Any other value (e.g. `ResultRecording.NONE`): no execution
          contexts are recorded.

        **This will mutate the experiment state, but history will not be
        recorded.**

    Args:
        phase (TrainPhase):
            Training phase to be executed.
        show_sampler_progress (bool, optional):
            Whether to show a progress bar for sampler batching.
            Defaults to True.
        show_training_progress (bool, optional):
            Whether to show a progress bar for training execution.
            Defaults to True.
        persist_progress (bool, optional):
            Whether to leave all epoch progress bars shown after they complete.
            Defaults to `IN_NOTEBOOK` (True if working in a notebook, False if in
            a Python script).
        persist_epoch_progress (bool, optional):
            Whether to leave all per-epoch training bars shown after they complete.
            Defaults to `IN_NOTEBOOK` (True if working in a notebook, False if in
            a Python script).
        val_loss_metric (str, optional):
            The name of a recorded ValidationLossMetrics to show in the progress
            bar. Results must be tracked, and `val_loss_metric` must be an existing
            loss metric. Otherwise, no val_loss field will be shown in the progress
            bar. Defaults to `"val_loss"`.

    Returns:
        TrainResults: Tracked results from training.
    """
    # Ensure active nodes are not frozen
    self.model_graph.unfreeze(phase.active_nodes)

    # Run training and track results
    res = TrainResults(label=phase.label)
    recording = phase.result_recording

    # For LAST mode, find the first EarlyStopping callback with restore_best
    early_stop = None
    if recording == ResultRecording.LAST:
        from modularml.callbacks.early_stopping import EarlyStopping

        for cb in phase.callbacks:
            if isinstance(cb, EarlyStopping) and cb.restore_best:
                early_stop = cb
                break

    # Snapshot of the best epoch's contexts (LAST mode with early stopping)
    best_ctxs: list = []
    # Epoch index of the previously seen context; -1 means none seen yet
    prev_epoch = -1
    for ctx in phase.iter_execution(
        results=res,
        show_sampler_progress=show_sampler_progress,
        show_training_progress=show_training_progress,
        persist_progress=persist_progress,
        persist_epoch_progress=persist_epoch_progress,
        val_loss_metric=val_loss_metric,
    ):
        self.model_graph.train_step(
            ctx=ctx,
            losses=phase.losses,
            active_nodes=phase.active_nodes,
        )

        if recording == ResultRecording.ALL:
            res.add_execution_context(ctx=ctx)
        elif recording == ResultRecording.LAST:
            # First context of a new epoch: if the epoch that just finished
            # is the early-stopping best, snapshot its contexts; then drop
            # them so only the current epoch is retained.
            # NOTE(review): relies on early_stop.best_epoch already being
            # updated for prev_epoch when this context arrives — confirm.
            if ctx.epoch_idx != prev_epoch and prev_epoch >= 0:
                if early_stop and early_stop.best_epoch == prev_epoch:
                    best_ctxs = list(res._execution)
                res._execution.clear()
                # Presumably invalidates series derived from cleared contexts
                res._series_cache.clear()
            res.add_execution_context(ctx=ctx)
            prev_epoch = ctx.epoch_idx
        # NONE: skip recording execution contexts entirely

    # LAST mode: resolve final vs best epoch
    if recording == ResultRecording.LAST and early_stop is not None:
        # Check if the final completed epoch was also the best
        if early_stop.best_epoch == prev_epoch:
            best_ctxs = list(res._execution)
        # If best epoch differs from the final epoch, restore snapshot
        if best_ctxs and early_stop.best_epoch != prev_epoch:
            res._execution = best_ctxs
            res._series_cache.clear()

    return res
def _execute_evaluation(
    self,
    phase: EvalPhase,
    *,
    show_eval_progress: bool = False,
    persist_progress: bool = IN_NOTEBOOK,
) -> EvalResults:
    """
    Run an evaluation phase against the current experiment state.

    Description:
        Runs whether or not the phase belongs to :attr:`execution_plan`.
        Every node of the model graph is frozen first. Then
        `ModelGraph.eval_step` is applied to each execution context yielded
        by the phase, and each context is recorded in the results.
        **This will mutate the experiment state, but history will not be
        recorded.**

    Args:
        phase (EvalPhase):
            The evaluation phase to run.
        show_eval_progress (bool, optional):
            Show a progress bar over eval batches. Defaults to False.
        persist_progress (bool, optional):
            Keep eval progress bars visible once finished. Defaults to
            `IN_NOTEBOOK` (True in a notebook, False in a script).

    Returns:
        EvalResults: Recorded evaluation results.
    """
    self.model_graph.freeze()

    results = EvalResults(label=phase.label)
    batches = phase.iter_execution(
        results=results,
        show_eval_progress=show_eval_progress,
        persist_progress=persist_progress,
    )
    for exec_ctx in batches:
        self.model_graph.eval_step(
            ctx=exec_ctx,
            losses=phase.losses,
            active_nodes=phase.active_nodes,
        )
        results.add_execution_context(ctx=exec_ctx)
    return results
def _execute_fit(
    self,
    phase: FitPhase,
) -> FitResults:
    """
    Run a fit phase against the current experiment state.

    Description:
        Runs whether or not the phase belongs to :attr:`execution_plan`.
        `ModelGraph.fit_step` is applied to each execution context yielded
        by the phase, passing the phase's `freeze_after_fit` setting, and
        each context is recorded in the results.
        **This will mutate the experiment state, but history will not be
        recorded.**

    Args:
        phase (FitPhase):
            The fit phase to run.

    Returns:
        FitResults: Recorded fit results.
    """
    results = FitResults(label=phase.label)
    steps = phase.iter_execution(results=results)
    for exec_ctx in steps:
        self.model_graph.fit_step(
            ctx=exec_ctx,
            losses=phase.losses,
            active_nodes=phase.active_nodes,
            freeze_after_fit=phase.freeze_after_fit,
        )
        results.add_execution_context(ctx=exec_ctx)
    return results
def _execute_phase_with_meta(
    self,
    phase: TrainPhase | EvalPhase | FitPhase,
    **kwargs,
) -> tuple[PhaseResults, PhaseExecutionMeta]:
    """
    Execute one phase, firing experiment hooks and recording timing meta.

    Description:
        For a disk-mode TrainPhase checkpoint config without its own
        directory, a ``<experiment checkpoint dir>/<phase label>``
        subdirectory is created and assigned. Experiment-level
        ``phase_start``/``phase_end`` callbacks fire around the phase, and
        so do checkpoints if configured. Both are skipped while already
        inside a callback; checkpoints are also skipped while checkpointing
        is disabled. The phase is dispatched to its type-specific executor.
        **This will mutate the experiment state, but history will not be
        recorded.**

    Args:
        phase (TrainPhase | EvalPhase | FitPhase):
            The phase to execute.
        **kwargs:
            Display options. Only keys accepted by the matching executor
            are forwarded; other keys are ignored.

    Returns:
        tuple[PhaseResults, PhaseExecutionMeta]: The phase results and a
        meta record with start/end timestamps.

    Raises:
        TypeError: If `phase` is not a TrainPhase, EvalPhase, or FitPhase
            (raised after the ``phase_start`` hooks have run).
    """
    # ------------------------------------------------
    # Propagate checkpoint directory to TrainPhase if needed
    # ------------------------------------------------
    if (
        isinstance(phase, TrainPhase)
        and phase.checkpointing is not None
        and phase.checkpointing.mode == "disk"
        and phase.checkpointing.directory is None
    ):
        exp_dir = self._checkpoint_dir
        if exp_dir is None and self._exp_checkpointing is not None:
            exp_dir = self._exp_checkpointing.directory
        if exp_dir is not None:
            # The config directory may be a plain str; coerce before joining.
            phase_dir = Path(exp_dir) / phase.label
            phase_dir.mkdir(parents=True, exist_ok=True)
            phase.checkpointing._directory = phase_dir

    # Skip callbacks and checkpointing when inside a callback
    run_hooks = not self._in_callback
    run_ckpt = run_hooks and not self._checkpointing_disabled

    def _checkpoint_if_configured(hook: str) -> None:
        # Save a full experiment checkpoint keyed by this phase's label when
        # the experiment-level config requests it for `hook`.
        cfg = self._exp_checkpointing
        if run_ckpt and cfg is not None and cfg.should_save(hook):
            self._save_experiment_checkpoint(label=phase.label)

    # ------------------------------------------------
    # on_phase_start
    # ------------------------------------------------
    if run_hooks:
        self._in_callback = True
        try:
            for cb in self._exp_callbacks:
                cb._on_phase_start(experiment=self, phase=phase)
        finally:
            self._in_callback = False
    _checkpoint_if_configured("phase_start")

    # ------------------------------------------------
    # run phase (modifies experiment state but does not update history)
    # ------------------------------------------------
    phase_start = datetime.now()
    phase_res: PhaseResults
    if isinstance(phase, TrainPhase):
        train_keys = {
            "show_sampler_progress",
            "show_training_progress",
            "persist_progress",
            "persist_epoch_progress",
            "val_loss_metric",
        }
        phase_res = self._execute_training(
            phase,
            **{k: v for k, v in kwargs.items() if k in train_keys},
        )
    elif isinstance(phase, EvalPhase):
        eval_keys = {"show_eval_progress", "persist_progress"}
        phase_res = self._execute_evaluation(
            phase,
            **{k: v for k, v in kwargs.items() if k in eval_keys},
        )
    elif isinstance(phase, FitPhase):
        phase_res = self._execute_fit(phase)
    else:
        msg = f"Expected type of TrainPhase, EvalPhase, or FitPhase. Received: {type(phase)}."
        raise TypeError(msg)

    phase_meta = PhaseExecutionMeta(
        label=phase.label,
        started_at=phase_start,
        ended_at=datetime.now(),
        status="completed",
    )

    # ------------------------------------------------
    # on_phase_end
    # ------------------------------------------------
    if run_hooks:
        self._in_callback = True
        try:
            for cb in self._exp_callbacks:
                cb.on_phase_end(
                    experiment=self,
                    phase=phase,
                    results=phase_res,
                )
        finally:
            self._in_callback = False
    _checkpoint_if_configured("phase_end")

    return phase_res, phase_meta
def _execute_group_with_meta(
    self,
    group: PhaseGroup,
    **kwargs,
) -> tuple[PhaseGroupResults, PhaseGroupExecutionMeta]:
    """
    Execute a phase group recursively, collecting results and timing meta.

    Description:
        Experiment-level ``group_start``/``group_end`` callbacks fire around
        the group, and so do checkpoints if configured. Both are skipped
        while already inside a callback; checkpoints are also skipped while
        checkpointing is disabled. Each element of ``group.all`` runs in
        order: phases via :meth:`_execute_phase_with_meta`, nested groups
        recursively.
        **This will mutate the experiment state, but history will not be
        recorded.**

    Args:
        group (PhaseGroup):
            The group to execute.
        **kwargs:
            Forwarded unchanged to every child phase or group.

    Returns:
        tuple[PhaseGroupResults, PhaseGroupExecutionMeta]: Per-child
        results and the group's meta record (with child meta attached).

    Raises:
        TypeError: If `group` is not a PhaseGroup, or it contains an element
            that is neither an ExperimentPhase nor a PhaseGroup.
    """
    if not isinstance(group, PhaseGroup):
        msg = f"Expected type of PhaseGroup. Received: {type(group)}."
        raise TypeError(msg)

    # Nested executions triggered from inside a callback do not re-fire hooks.
    hooks_on = not self._in_callback
    ckpt_on = hooks_on and not self._checkpointing_disabled

    def _checkpoint_if_configured(hook: str) -> None:
        # Save a full experiment checkpoint keyed by this group's label when
        # the experiment-level config requests it for `hook`.
        cfg = self._exp_checkpointing
        if ckpt_on and cfg is not None and cfg.should_save(hook):
            self._save_experiment_checkpoint(label=group.label)

    # --- group start ---
    if hooks_on:
        self._in_callback = True
        try:
            for callback in self._exp_callbacks:
                callback.on_group_start(experiment=self, group=group)
        finally:
            self._in_callback = False
    _checkpoint_if_configured("group_start")

    # --- run children in order ---
    group_results = PhaseGroupResults(label=group.label)
    group_meta = PhaseGroupExecutionMeta(
        label=group.label,
        started_at=datetime.now(),
        ended_at=None,
    )
    for element in group.all:
        if isinstance(element, ExperimentPhase):
            child_res, child_meta = self._execute_phase_with_meta(
                phase=element,
                **kwargs,
            )
        elif isinstance(element, PhaseGroup):
            child_res, child_meta = self._execute_group_with_meta(
                group=element,
                **kwargs,
            )
        else:
            msg = (
                "Unsupported group element. Expected ExperimentPhase "
                f"or PhaseGroup. Received: {type(element)}."
            )
            raise TypeError(msg)
        group_results.add_result(child_res)
        group_meta.add_child(child_meta)
    group_meta.ended_at = datetime.now()

    # --- group end ---
    if hooks_on:
        self._in_callback = True
        try:
            for callback in self._exp_callbacks:
                callback.on_group_end(
                    experiment=self,
                    group=group,
                    results=group_results,
                )
        finally:
            self._in_callback = False
    _checkpoint_if_configured("group_end")

    return group_results, group_meta
# Run API
# Typing-only overload: a TrainPhase accepts the training display options
# and produces TrainResults.
@overload
def run_phase(
    self,
    phase: TrainPhase,
    *,
    show_sampler_progress: bool = True,
    show_training_progress: bool = True,
    persist_progress: bool = IN_NOTEBOOK,
    persist_epoch_progress: bool = IN_NOTEBOOK,
    val_loss_metric: str = "val_loss",
) -> TrainResults: ...
# Typing-only overload: an EvalPhase accepts the evaluation display options
# and produces EvalResults.
@overload
def run_phase(
    self,
    phase: EvalPhase,
    *,
    show_eval_progress: bool = False,
    persist_progress: bool = IN_NOTEBOOK,
) -> EvalResults: ...
# Typing-only overload: a FitPhase takes no display options and produces
# FitResults.
@overload
def run_phase(
    self,
    phase: FitPhase,
) -> FitResults: ...
[docs]
def run_phase(
    self,
    phase: ExperimentPhase,
    **kwargs,
) -> PhaseResults:
    """
    Run one phase and append the outcome to :attr:`history`.

    Description:
        The phase runs whether or not it is registered on
        :attr:`execution_plan`, and experiment state is mutated. On success
        an :class:`ExperimentRun` with status ``"completed"`` is appended to
        :attr:`history`. If the phase raises, the exception propagates and
        nothing is recorded. Use :meth:`preview_phase` to run without
        keeping state changes.

    Args:
        phase (ExperimentPhase):
            The phase to run.
        **kwargs (Any):
            Display flags forwarded to the phase-specific run method.

    Returns:
        PhaseResults: Results produced by the phase.
    """
    t_start = datetime.now()
    # Any exception from execution propagates unchanged; no run is recorded.
    phase_res, phase_meta = self._execute_phase_with_meta(phase=phase, **kwargs)
    t_end = datetime.now()

    record = ExperimentRun(
        label=phase.label,
        started_at=t_start,
        ended_at=t_end,
        status="completed",
        results=phase_res,
        execution_meta=phase_meta,
    )
    self._history.append(record)
    return phase_res
[docs]
def run_group(
    self,
    group: PhaseGroup,
    **kwargs,
) -> PhaseGroupResults:
    """
    Run every phase of a PhaseGroup and append the outcome to :attr:`history`.

    Description:
        The group runs whether or not it is registered on
        :attr:`execution_plan`, and experiment state is mutated. On success
        an :class:`ExperimentRun` with status ``"completed"`` is appended to
        :attr:`history`. If execution raises, the exception propagates and
        nothing is recorded. Use :meth:`preview_group` to run without
        keeping state changes.

    Args:
        group (PhaseGroup):
            The group to execute.
        **kwargs:
            Display flags forwarded to each phase's run method.

    Returns:
        PhaseGroupResults: Results of the executed group.
    """
    t_start = datetime.now()
    # Any exception from execution propagates unchanged; no run is recorded.
    group_res, group_meta = self._execute_group_with_meta(group=group, **kwargs)
    t_end = datetime.now()

    record = ExperimentRun(
        label=group.label,
        started_at=t_start,
        ended_at=t_end,
        status="completed",
        results=group_res,
        execution_meta=group_meta,
    )
    self._history.append(record)
    return group_res
[docs]
def run(self, **kwargs) -> PhaseGroupResults:
    """
    Run the registered execution plan.

    Description:
        All phases and phase groups added to this experiment are executed
        in the order they were added (via :meth:`run_group`, so the run is
        recorded in :attr:`history`).

        Around the plan:

        - ``on_experiment_start``/``on_experiment_end`` callbacks fire.
        - ``_on_exception`` fires (with ``phase=None``) if execution raises;
          the exception is then re-raised.

        While any of these callbacks run, the in-callback guard is set, so
        phases/groups a callback executes do not re-fire experiment hooks.
        This is the same rule applied to phase/group callbacks.
        ``experiment_start``/``experiment_end`` checkpoints are skipped
        inside :meth:`disable_checkpointing`.

    Args:
        **kwargs:
            Additional arguments to be passed to each executed phase.

    Returns:
        PhaseGroupResults:
            Results of all executed phases.
    """

    def _checkpoint_if_configured(hook: str, label: str) -> None:
        # Honour `disable_checkpointing()` like the phase/group hooks do.
        cfg = self._exp_checkpointing
        if (
            not self._checkpointing_disabled
            and cfg is not None
            and cfg.should_save(hook)
        ):
            self._save_experiment_checkpoint(label=label)

    # ------------------------------------------------
    # on_experiment_start
    # ------------------------------------------------
    prev_in_callback = self._in_callback
    self._in_callback = True
    try:
        for cb in self._exp_callbacks:
            cb.on_experiment_start(experiment=self)
    finally:
        self._in_callback = prev_in_callback
    _checkpoint_if_configured("experiment_start", "START")

    # ------------------------------------------------
    # run all phases (phase/group hooks handled internally)
    # ------------------------------------------------
    try:
        res = self.run_group(group=self._exec_plan, **kwargs)
    except BaseException as exc:
        prev_in_callback = self._in_callback
        self._in_callback = True
        try:
            for cb in self._exp_callbacks:
                cb._on_exception(
                    experiment=self,
                    phase=None,
                    exception=exc,
                )
        finally:
            # Restore rather than force False, in case run() was itself
            # invoked from within a callback.
            self._in_callback = prev_in_callback
        raise

    # ------------------------------------------------
    # on_experiment_end
    # ------------------------------------------------
    prev_in_callback = self._in_callback
    self._in_callback = True
    try:
        for cb in self._exp_callbacks:
            cb.on_experiment_end(experiment=self)
    finally:
        self._in_callback = prev_in_callback
    _checkpoint_if_configured("experiment_end", "END")

    return res
# Preview API
# Typing-only overload: previewing a TrainPhase accepts the training display
# options and produces TrainResults.
@overload
def preview_phase(
    self,
    phase: TrainPhase,
    *,
    show_sampler_progress: bool = True,
    show_training_progress: bool = True,
    persist_progress: bool = IN_NOTEBOOK,
    persist_epoch_progress: bool = IN_NOTEBOOK,
    val_loss_metric: str = "val_loss",
) -> TrainResults: ...
# Typing-only overload: previewing an EvalPhase accepts the evaluation
# display options and produces EvalResults.
@overload
def preview_phase(
    self,
    phase: EvalPhase,
    *,
    show_eval_progress: bool = False,
    persist_progress: bool = IN_NOTEBOOK,
) -> EvalResults: ...
# Typing-only overload: previewing a FitPhase takes no display options and
# produces FitResults. The implementation already handles FitPhase, and
# `run_phase` declares the same overload.
@overload
def preview_phase(
    self,
    phase: FitPhase,
) -> FitResults: ...
[docs]
def preview_phase(
    self,
    phase: ExperimentPhase,
    **kwargs,
) -> PhaseResults:
    """
    Execute a phase without mutating the Experiment state.

    Description:
        The provided :class:`ExperimentPhase` runs against the current
        experiment state, with checkpointing disabled. If the phase may
        mutate model state (anything other than an EvalPhase), a snapshot
        from :meth:`get_state` is taken first. It is restored with
        :meth:`set_state` afterwards, including when the phase raises.
        Execution is not recorded in :attr:`history`. Use
        :meth:`run_phase` to persist results.

    Args:
        phase (ExperimentPhase):
            Phase to run.
        **kwargs (Any):
            Display flags forwarded to the phase-specific run method.

    Returns:
        PhaseResults: Results produced by the previewed phase.
    """
    # Snapshot state only for phases that mutate model weights
    needs_restore = _phase_mutates_state(phase)
    state = self.get_state() if needs_restore else None

    try:
        # Execute phase with checkpointing disabled
        with self.disable_checkpointing():
            res, _ = self._execute_phase_with_meta(
                phase=phase,
                **kwargs,
            )
    finally:
        # Revert state even if the phase failed part-way through.
        if needs_restore:
            self.set_state(state=state)
    return res
[docs]
def preview_group(
    self,
    group: PhaseGroup,
    **kwargs,
) -> PhaseGroupResults:
    """
    Execute a phase group without mutating the Experiment state.

    Description:
        The provided PhaseGroup is executed on the current experiment
        state. Any state changes are reverted after the group runs,
        including when the group raises or is interrupted. Execution is
        not recorded in `history`. To run a group with history tracking,
        use `run_group(...)`.

    Args:
        group (PhaseGroup):
            The phase group to run.
        **kwargs:
            Display flags forwarded to the phase-specific run method.

    Returns:
        PhaseGroupResults: Phase group results.

    """
    # Only groups that `_phase_mutates_state` reports as mutating need a
    # snapshot; the others skip the get_state/set_state round trip.
    needs_restore = _phase_mutates_state(group)
    state = self.get_state() if needs_restore else None

    try:
        # Execute the group with checkpointing disabled.
        with self.disable_checkpointing():
            res, _ = self._execute_group_with_meta(
                group=group,
                **kwargs,
            )
    finally:
        # Restore inside `finally` so an exception in any phase of the
        # group cannot leave the experiment partially modified.
        if needs_restore:
            self.set_state(state=state)
    return res
# ================================================
# Configurable
# ================================================
[docs]
def get_config(self) -> dict[str, Any]:
    """
    Return the configuration of this experiment as a plain dict.

    The result contains the experiment label, the ``value`` of the
    context's registration policy, and the execution plan's own config.
    It does not contain state information of the underlying model graph.
    """
    label = self.label
    policy_value = self._ctx._policy.value
    plan_config = self._exec_plan.get_config()
    return dict(
        label=label,
        registration_policy=policy_value,
        execution_plan=plan_config,
    )
[docs]
@classmethod
def from_config(cls, config: dict[str, Any]) -> Experiment:
    """
    Build a new Experiment from a :meth:`get_config` payload.

    Only configuration is restored (label, registration policy, and the
    execution plan when present); no state information is restored.

    Args:
        config (dict[str, Any]):
            Configuration payload returned by :meth:`get_config`.

    Returns:
        Experiment: Newly constructed experiment bound to the active context.

    """
    # Resolve the active context first, as the construction depends on it.
    context = ExperimentContext.get_active()
    experiment = cls(
        label=config["label"],
        registration_policy=config.get("registration_policy"),
        ctx=context,
    )

    # Without a serialized plan, keep whatever plan the constructor set up.
    plan_config = config.get("execution_plan")
    if plan_config is None:
        return experiment

    experiment._exec_plan = PhaseGroup.from_config(plan_config)
    return experiment
# ================================================
# Stateful
# ================================================
[docs]
def get_state(self) -> dict[str, Any]:
    """
    Capture a snapshot of the mutable experiment state.

    The snapshot has three keys:
      - ``"ctx"``: the result of the context's own ``get_state()``.
      - ``"history"``: a deep copy of the run history.
      - ``"checkpoints"``: a shallow ``.copy()`` of the recorded checkpoints.
    """
    snapshot: dict[str, Any] = {}
    snapshot["ctx"] = self.ctx.get_state()
    snapshot["history"] = deepcopy(self._history)
    snapshot["checkpoints"] = self._checkpoints.copy()
    return snapshot
[docs]
def set_state(self, state: dict[str, Any]) -> None:
    """
    Restore experiment state from :meth:`get_state` output.

    The history and checkpoint containers are copied on restore.
    Subsequent mutation of the experiment's history or checkpoints
    therefore cannot alter the caller's snapshot, and the same snapshot
    can be restored more than once with the same result.

    Args:
        state (dict[str, Any]):
            Snapshot captured by :meth:`get_state`. ``"ctx"`` is
            required; ``"history"`` and ``"checkpoints"`` default to empty
            containers when absent.

    """
    # Restore context state
    self._ctx.set_state(state["ctx"])
    # Restore history as an independent copy (mirrors the deepcopy in get_state)
    self._history = deepcopy(state.get("history", []))
    # Restore recorded checkpoints as an independent (shallow) copy
    self._checkpoints = state.get("checkpoints", {}).copy()
# ================================================
# Serialization
# ================================================
[docs]
def save(self, filepath: Path, *, overwrite: bool = False) -> Path:
    """
    Serialize this experiment to disk.

    Saving is delegated to the ModularML serializer using the
    ``SerializationPolicy.BUILTIN`` policy.

    Args:
        filepath (Path):
            Target location. The suffix may be overwritten to enforce the
            ModularML file extension schema.
        overwrite (bool, optional):
            Whether an existing file at the target location may be
            replaced. Defaults to False.

    Returns:
        Path: The filepath the experiment was actually written to.

    """
    # Imported lazily (kept local as in the original implementation).
    from modularml.core.io.serialization_policy import SerializationPolicy
    from modularml.core.io.serializer import serializer

    saved_path = serializer.save(
        self,
        filepath,
        policy=SerializationPolicy.BUILTIN,
        overwrite=overwrite,
    )
    return saved_path
[docs]
@classmethod
def load(
    cls,
    filepath: Path,
    *,
    checkpoint_dir: Path | None = None,
    allow_packaged_code: bool = False,
    overwrite: bool = False,
) -> Experiment:
    """
    Load an Experiment from file.

    Args:
        filepath (Path):
            File location of a previously saved Experiment. If it has no
            suffix, the ModularML suffix for this class is appended.
        checkpoint_dir (Path | None, optional):
            Directory to extract saved checkpoints into. If the serialized
            experiment contains checkpoint artifacts and this is None, the
            checkpoints will not be restored and a warning will be emitted.
            Defaults to None.
        allow_packaged_code (bool, optional):
            Whether execution of bundled code is allowed. Defaults to False.
        overwrite (bool, optional):
            Whether to replace colliding node registrations in the
            ExperimentContext. If False, new IDs are assigned to the
            reloaded nodes comprising the graph; otherwise, any collisions
            are overwritten with the saved nodes. Defaults to False.
            It is recommended to only reload an Experiment into a new or
            empty `ExperimentContext`.

    Returns:
        Experiment: The reloaded Experiment.

    """
    from modularml.core.io.serializer import _enforce_file_suffix, serializer

    # An explicit suffix is respected as-is; only a bare path gets one added.
    if not Path(filepath).suffix:
        filepath = _enforce_file_suffix(path=filepath, cls=cls)

    # Only pass extras when there is somewhere to extract checkpoints to.
    extras = None
    if checkpoint_dir is not None:
        extras = {"checkpoint_dir": checkpoint_dir}

    return serializer.load(
        filepath,
        allow_packaged_code=allow_packaged_code,
        overwrite=overwrite,
        extras=extras,
    )