How to: Use Callbacks#
Callbacks let you inject custom logic at well-defined points in a phase lifecycle
without subclassing the phase itself. They are attached to a TrainPhase,
EvalPhase, or FitPhase and are called automatically at each lifecycle boundary.
Note: This notebook covers phase-level callbacks (
Callbacksubclasses) that fire at batch/epoch boundaries inside a single phase. Experiment-level callbacks (ExperimentCallback) that fire at phase/group boundaries are covered in How to: Create and Use an Experiment.
This notebook covers:
%matplotlib inline
import numpy as np
from modularml import (
AppliedLoss,
EvalPhase,
Experiment,
FeatureSet,
Loss,
ModelGraph,
ModelNode,
Optimizer,
TrainPhase,
)
from modularml.models.torch import SequentialMLP
from modularml.samplers import SimpleSampler
Experiment Setup#
We set up a small synthetic experiment used throughout this notebook. The setup follows the pattern from How to: Create and Use an Experiment.
rng = np.random.default_rng(42)
# Synthetic data: 500 samples, 20-d feature, 1-d target
fs = FeatureSet.from_dict(
label="SensorData",
data={
"voltage": list(rng.standard_normal((500, 20))),
"soh": list(rng.standard_normal((500, 1))),
},
feature_keys="voltage",
target_keys="soh",
)
fs.split_random(ratios={"train": 0.7, "val": 0.15, "test": 0.15}, seed=13)
fs_ref = fs.reference(features="voltage", targets="soh")
mn_mlp = ModelNode(
label="MLP",
model=SequentialMLP(output_shape=(1, 1), n_layers=2, hidden_dim=16),
upstream_ref=fs_ref,
)
graph = ModelGraph(
label="SimpleGraph",
nodes=[mn_mlp],
optimizer=Optimizer("adam", opt_kwargs={"lr": 1e-3}, backend="torch"),
)
graph.build()
exp = Experiment.from_active_context(label="my_experiment")
mse_loss = AppliedLoss(
loss=Loss("mse", backend="torch"),
on="MLP",
inputs=["outputs", "targets"],
)
print(f"Splits: {fs.available_splits}")
The Callback Lifecycle#
Every Callback subclass can override up to seven hook methods. They are called
automatically by the phase in the order shown below:
on_phase_start
for each epoch:
on_epoch_start
for each batch:
on_batch_start
[forward / backward pass]
on_batch_end
on_epoch_end
on_phase_end
on_exception # called if any hook or step raises
Hook |
Signature |
When it fires |
|---|---|---|
|
|
Once before the first epoch |
|
|
Once after the last epoch |
|
|
Before each epoch’s first batch |
|
|
After each epoch’s last batch |
|
|
Before each batch |
|
|
After each batch |
|
|
On any unhandled exception |
Each hook may return any value (or None). Non-None return values are automatically
wrapped in a CallbackResult and stored in the phase’s PhaseResults container,
keyed by the callback’s label.
Writing a Custom Callback#
Subclass Callback and override only the hooks you need. The only abstract
requirement is implementing get_config() and from_config() for serialization.
The example below prints the mean training loss at the end of every epoch.
from typing import Any
from modularml.core.experiment.callbacks.callback import Callback
class LossLogger(Callback):
"""Prints the mean training loss at the end of each epoch."""
def __init__(self, node: str = "MLP", label: str | None = None):
super().__init__(label=label or "LossLogger")
self.node = node
self._epoch_losses: list[float] = []
def on_epoch_end(self, *, experiment, phase, exec_ctx, results=None):
epoch_idx = exec_ctx.epoch_idx
# Use the results provided to this hook
if results is not None:
# Grab the recorded losses for this epoch
epoch_losses = results.losses(node=self.node).where(epoch=epoch_idx)
# Since we likely have a loss recorded for each batch in the epoch,
# we can collapse losses via averaging
if len(epoch_losses.values()) > 0:
collapsed = epoch_losses.collapse(axis="batch", reducer="mean").one()
# Grab the trainable component of this loss
loss_val = collapsed.trainable
self._epoch_losses.append(loss_val)
print(f" epoch {epoch_idx:>3}: train_loss = {loss_val:.4f}")
def get_config(self) -> dict[str, Any]:
return {"callback_type": self.__class__.__name__, "node": self.node, "label": self.label}
@classmethod
def from_config(cls, config: dict) -> "LossLogger":
return cls(node=config["node"], label=config["label"])
Callbacks are registered on a phase via add_callback() or at construction time
via the callbacks= parameter.
loss_logger = LossLogger(node="MLP")
train_phase = TrainPhase.from_split(
label="train",
split="train",
sampler=SimpleSampler(batch_size=32, shuffle=True, seed=42),
losses=[mse_loss],
n_epochs=3,
callbacks=[loss_logger],
)
print(f"Attached callbacks: {[cb.label for cb in train_phase.callbacks]}")
train_results = exp.run_phase(train_phase)
print(f"\nRecorded epoch losses: {loss_logger._epoch_losses}")
ExecutionContext#
In addition to the current phase results, epoch- and batch-level hooks
receive an ExecutionContext (exec_ctx) with the following fields:
Field |
Type |
Description |
|---|---|---|
|
|
Label of the currently executing phase |
|
|
Zero-based epoch index ( |
|
|
Zero-based batch index within the epoch |
|
|
Pre-materialized input batches for all head nodes |
The Evaluation Built-in Callback#
Evaluation runs a full EvalPhase at configurable epoch intervals and stores
the resulting EvalResults as an EvaluationCallbackResult. It is the primary
tool for tracking held-out performance during training.
Evaluation(
eval_phase: EvalPhase,
every_n_epochs: int = 1,
run_on_start: bool = False,
label: str | None = None,
metrics: list[EvaluationMetric] | None = None,
)
Parameter |
Type |
Default |
Description |
|---|---|---|---|
|
|
(required) |
The evaluation phase to run. |
|
|
|
Run evaluation every N completed epochs. |
|
|
|
If |
|
|
|
Result key (defaults to |
|
|
|
Metric extractors to run after each evaluation. |
A convenience constructor Evaluation.from_split() creates the inner EvalPhase
automatically:
from modularml.callbacks import Evaluation
eval_cb = Evaluation.from_split(
label="eval_val",
split="val",
every_n_epochs=1,
)
train_phase_with_eval = TrainPhase.from_split(
label="train_with_eval",
split="train",
sampler=SimpleSampler(batch_size=32, shuffle=True, seed=42),
losses=[mse_loss],
n_epochs=3,
callbacks=[eval_cb],
)
train_results = exp.run_phase(train_phase_with_eval)
print("Training with Evaluation callback complete.")
How Evaluation Works#
Evaluation fires in on_epoch_end. At each selected epoch it calls
experiment.preview_phase(phase=self.eval_phase), a stateless forward pass
that does not mutate experiment history or optimizer state.
The evaluation result is wrapped in an EvaluationCallbackResult and stored
in the TrainResults of the parent TrainPhase, keyed by the callback label.
print(train_results.callbacks())
eval_cbs = train_results.callbacks(kind="evaluation").values()
print(eval_cbs)
# And the eval results in those callbacks can be accessed with
[eval_cb.eval_results for eval_cb in eval_cbs]
While running evaluation on some set of data during callbacks can be useful, a more common case is to record metrics on an evaluation forward pass.
This is where EvaluationMetric and EvalLossMetric come in.
It is equivalent to defining an Evaluation callback with an attached loss; however,
MetricCallbacks provide more-convenient access to scalar values produced by callback hooks.
The EvalLossMetric Built-in Metric#
EvalLossMetric is an EvaluationMetric that extracts a scalar loss value
from an Evaluation result and logs it to the MetricStore under a chosen name
(default: "val_loss"). Pass it to Evaluation via the metrics= argument.
EvalLossMetric(
loss: AppliedLoss,
reducer: Literal["sum", "mean"] = "mean",
name: str = "val_loss",
)
Parameter |
Type |
Default |
Description |
|---|---|---|---|
|
|
(required) |
Applied loss to track. Appended to the inner |
|
|
|
How to aggregate per-batch losses into a scalar. |
|
|
|
Metric name logged to the |
from modularml.callbacks import EvalLossMetric, Evaluation
# We define a loss to be applied during our Evaluation callback
val_loss_metric = EvalLossMetric(
loss=AppliedLoss(
loss=Loss("mse", backend="torch"),
on="MLP",
inputs=["targets", "outputs"],
),
reducer="mean",
name="val_loss",
)
# ^ note that metrics with name "val_loss" will be shown in the progress bars during training
eval_cb_with_metric = Evaluation.from_split(
label="eval_val",
split="val",
every_n_epochs=1,
metrics=[val_loss_metric],
)
train_phase_tracked = TrainPhase.from_split(
label="train_tracked",
split="train",
sampler=SimpleSampler(batch_size=32, shuffle=True, seed=42),
losses=[mse_loss],
n_epochs=5,
callbacks=[eval_cb_with_metric],
)
train_results = exp.run_phase(train_phase_tracked)
print("Training complete.")
Accessing Callback Results#
After training, callback results are stored in the parent phase under a callbacks attribute.
Metrics are available under the metrics attribute.
# Here we grab the training results directly from the experiment's history
last_train_results = exp.last_run.results
# All result attributes return AxisSeries; it's essentially a queriable, multi-keyed dict
# More on those in a later notebook
val_losses = last_train_results.metrics().where(name="val_loss").values()
print("val_loss per epoch:")
for entry in val_losses:
print(f" epoch {entry.epoch_idx}: {entry.value:.4f}")
Full Evaluation results are a little more annoying to access as all values returned by a callback hook are wrapped in a CallbackResult container.
For callbacks that run an EvalPhase, our data structure is along the lines of:
CallbackResults –> EvalResults –> [tensors, losses, metrics]
print(last_train_results.callbacks().axes)
print("Axis values:")
for k, vs in last_train_results.callbacks().axes_values().items():
print(f" {k}: {vs}")
# Grab the underlying EvalResults from epoch 3
cb_res = last_train_results.callbacks(kind="evaluation").where(epoch=3).one()
print(cb_res.eval_results)
# Output tensors accessed directly
pred = cb_res.stacked_tensors(node="MLP", domain="outputs", fmt="np").reshape(-1)
true = cb_res.stacked_tensors(node="MLP", domain="targets", fmt="np").reshape(-1)
import matplotlib.pyplot as plt
plt.figure(figsize=(3,3))
plt.scatter(pred, true)
plt.show()
Pretty terrible results :(
But at least we can run callbacks.
The ArtifactResult - logging rich non-scalar objects#
MetricResult is restricted to scalar return types.
For richer objects (e.g., matplotlib figures, pandas DataFrames, numpy arrays, text), wrap your custom callback returns in an ArtifactResult.
ArtifactResult(
artifact_name: str, # stable name, e.g. "val_scatter"
artifact: Any, # any Python object
)
When an ArtifactResult is returned, the framework:
Stores the result in
PhaseResults._callbacks(accessible viacallbacks(kind="artifact"))Also stores it in
PhaseResults._artifacts(accessible viaresults.artifacts())
This provides more convenient access to artifacts, similar to the access path of metrics.
If the Experiment was created with a ResultsConfig(results_dir=...), the artifact is
serialized to disk automatically. entry.artifact transparently deserializes it on access.
Let’s create a custom callback that produces matplotlib figures at the end of each epoch.
import matplotlib.pyplot as plt
from modularml.callbacks import ArtifactResult
class ScatterPlotCallback(Callback):
"""Produces a pred-vs-true scatter plot at the end of every epoch."""
def __init__(self, node: str = "MLP", label: str | None = None):
super().__init__(label=label or "ScatterPlot")
self.node = node
def on_epoch_end(self, *, experiment, phase, exec_ctx, results=None):
# Only run every other epoch to keep the demo quick
if exec_ctx.epoch_idx % 2 != 0:
return None
# Pull predictions from the most recent eval callback result
if results is None:
return None
eval_cbs = results.callbacks(kind="evaluation").values()
if not eval_cbs:
return None
last_eval = eval_cbs[-1]
pred = last_eval.stacked_tensors(node=self.node, domain="outputs", fmt="np").reshape(-1)
true = last_eval.stacked_tensors(node=self.node, domain="targets", fmt="np").reshape(-1)
fig, ax = plt.subplots(figsize=(3, 3))
ax.scatter(pred, true, alpha=0.4, s=10)
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
ax.set_title(f"Epoch {exec_ctx.epoch_idx}")
plt.tight_layout()
plt.close()
return ArtifactResult(artifact_name="val_scatter", artifact=fig)
def get_config(self):
return {"callback_type": self.__class__.__name__, "node": self.node, "label": self.label}
@classmethod
def from_config(cls, config):
return cls(node=config["node"], label=config["label"])
scatter_cb = ScatterPlotCallback(node="MLP")
eval_cb_for_scatter = Evaluation.from_split(
label="eval_val",
split="val",
every_n_epochs=1,
)
train_phase_with_artifacts = TrainPhase.from_split(
label="train_with_artifacts",
split="train",
sampler=SimpleSampler(batch_size=32, shuffle=True, seed=42),
losses=[mse_loss],
n_epochs=4,
callbacks=[eval_cb_for_scatter, scatter_cb],
)
train_results = exp.run_phase(train_phase_with_artifacts)
print("Artifact names:", train_results.artifact_names())
# List all artifact names produced during the phase
print("Artifact names:", train_results.artifact_names())
# Query via the artifact store; keyed by (name, epoch, batch)
scatter_series = train_results.artifacts().where(name="val_scatter")
print(f"Scatter plots recorded: {len(scatter_series)}")
for entry in scatter_series.values():
print(f" epoch {entry.epoch_idx}: {type(entry.artifact).__name__}")
# entry.artifact transparently loads from disk when ResultsConfig(results_dir=...) is set
entry = scatter_series.where(epoch=0).one()
entry.artifact
Summary#
Callback Base Class#
Method |
Override to… |
|---|---|
|
Run once before the first epoch. |
|
Run once after the last epoch. |
|
Run before each epoch. |
|
Run after each epoch. |
|
Run before each batch. |
|
Run after each batch. |
|
Run on unhandled exception. |
|
Serialize callback config (required). |
|
Reconstruct from config (required). |
Evaluation Callback#
Parameter |
Type |
Default |
Description |
|---|---|---|---|
|
|
(required) |
Phase to run on selected epochs. |
|
|
|
Evaluation frequency. |
|
|
|
Evaluate before training begins. |
|
|
|
Result key (defaults to phase label). |
|
|
|
Metric extractors run after each evaluation. |
Convenience constructor: Evaluation.from_split(label, split, losses, every_n_epochs, metrics, ...)
EvalLossMetric#
Parameter |
Type |
Default |
Description |
|---|---|---|---|
|
|
(required) |
Applied loss to track. |
|
|
|
Batch aggregation method. |
|
|
|
Metric name in the |
ArtifactResult#
Return an ArtifactResult from any callback hook to log a rich non-scalar object
(matplotlib figure, DataFrame, array, text, …).
Field |
Type |
Description |
|---|---|---|
|
|
Stable name for querying (e.g. |
|
|
The object to store. |
Accessing Results#
Expression |
Returns |
Notes |
|---|---|---|
|
|
Scalar metrics keyed by |
|
|
All callback results keyed by |
|
|
Type-narrowed to evaluation results. |
|
|
Type-narrowed to metric results. |
|
|
Type-narrowed to artifact results. |
|
|
Type-narrowed to arbitrary payload results. |
|
|
Artifact entries keyed by |
|
|
All unique artifact names recorded. |
|
|
The artifact object; auto-loads from disk if serialized via |
|
|
Concatenated tensors across all eval batches. |
|
|
Per-loss scalar dict. |
Results Storage#
ResultsConfig controls where artifacts and execution contexts are stored (RAM vs disk).
See How to: Create and Use an Experiment for the full ResultsConfig and ResultRecording reference.
Next Steps#
Cross-validation: Use
EvaluationandEvalLossMetrictogether withCrossValidationandCVBinding; see How to: Use Cross-Validation