from __future__ import annotations

import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Literal

from modularml.core.experiment.callbacks.callback import Callback
from modularml.core.experiment.callbacks.callback_result import CallbackResult
from modularml.core.experiment.phases.train_phase import TrainPhase
from modularml.utils.logging.logger import get_logger

if TYPE_CHECKING:
    from modularml.core.data.execution_context import ExecutionContext
    from modularml.core.experiment.experiment import Experiment
    from modularml.core.experiment.phases.phase import ExperimentPhase
    from modularml.core.experiment.results.phase_results import PhaseResults
# Module-level logger, created via `get_logger` under the name "EarlyStopping".
# Used below to report whether the best model state could be restored.
logger = get_logger(name="EarlyStopping")
@dataclass
class EarlyStoppingResult(CallbackResult):
    """
    Summary record returned by the EarlyStopping callback when a phase ends.

    Every field defaults to its "nothing happened" value, so an instance
    built without arguments describes a phase in which no metric was seen,
    no early stop was triggered, and no state was restored.

    Attributes:
        best_epoch (int | None):
            Index of the epoch with the best monitored metric value.
            None when the metric was never observed.
        best_value (float | None):
            The metric value recorded at `best_epoch`.
        stopped_epoch (int | None):
            Index of the epoch at which a stop was requested. None when
            the phase ran to completion without early stopping.
        restored (bool):
            True when the model state was reset to that of `best_epoch`.

    """

    # Class-level tag identifying this result type; not an instance field.
    kind: ClassVar[str] = "early_stopping"

    best_epoch: int | None = None
    best_value: float | None = None
    stopped_epoch: int | None = None
    restored: bool = False
class EarlyStopping(Callback):
    """
    Stop training when a monitored metric stops improving.

    This callback monitors a named metric from the MetricStore and stops
    training if no improvement is observed for a given number of epochs
    (patience). Improvement is determined by the `mode` parameter:
    "min" expects the metric to decrease, "max" expects it to increase.
    A NaN metric value is never counted as an improvement.

    When triggered, EarlyStopping calls `phase.request_stop()` which
    sets a flag that the training loop checks at the end of each epoch.

    When `restore_best=True`, the model state is restored to the epoch
    with the best monitored metric value at the end of the phase. If the
    phase has a `Checkpointing` callback configured, its saved states are
    used for restoration. Otherwise, EarlyStopping manages its own
    in-memory state snapshots as a fallback.

    Example:
        >>> # Stop if validation loss doesn't improve for 5 epochs
        >>> phase.add_callback(  # doctest: +SKIP
        ...     EarlyStopping(monitor="val_loss", patience=5)
        ... )

        >>> # Stop and restore model to best epoch
        >>> phase.add_callback(  # doctest: +SKIP
        ...     EarlyStopping(
        ...         monitor="val_loss",
        ...         patience=5,
        ...         restore_best=True,
        ...     )
        ... )

    """

    def __init__(
        self,
        *,
        monitor: str = "val_loss",
        mode: Literal["min", "max"] = "min",
        patience: int = 5,
        min_delta: float = 0.0,
        restore_best: bool = False,
        reducer: Literal["mean", "sum", "last", "first"] = "last",
        label: str | None = None,
        execution_order: int = 1,
    ) -> None:
        """
        Initialize an EarlyStopping callback.

        Args:
            monitor (str, optional):
                Name of the metric to monitor. Must match a metric name that
                is logged into the MetricStore during training (e.g. by a
                EvalLossMetric or custom MetricCallback).
                Defaults to "val_loss".
            mode (Literal["min", "max"], optional):
                Whether the monitored metric should be minimized or maximized.
                Defaults to "min".
            patience (int, optional):
                Number of epochs with no improvement after which training
                will be stopped. Must be non-negative. Defaults to 5.
            min_delta (float, optional):
                Minimum change in the monitored metric to qualify as an
                improvement. Defaults to 0.0.
            restore_best (bool, optional):
                Whether to restore the model state to the epoch with the best
                monitored metric value at the end of the phase. If the phase
                has a Checkpointing callback, its saved states are used.
                Otherwise, in-memory snapshots are managed automatically.
                Defaults to False.
            reducer (Literal["mean", "sum", "last", "first"], optional):
                If the `monitor` metric is produced more than once per epoch,
                `reducer` defines how to aggregate all values in that epoch.
                Typically, `monitor` is produced at most once per epoch and
                this argument is not used. Defaults to "last".
            label (str | None, optional):
                Stable identifier for this callback. Defaults to "EarlyStopping".
            execution_order (int, optional):
                Used for execution ordering of multiple callbacks, where higher
                values are executed later than lower values. This value should
                be greater than the callback that produces the `monitor` metric.
                Unless you manually modified the other callback execution
                orders, a value of 1 is fine.

        Raises:
            ValueError: If `mode` is not "min" or "max", or if `patience`
                is negative.

        """
        # Without this check any value other than "min" would silently be
        # treated as "max" by `_check_improvement`.
        if mode not in ("min", "max"):
            msg = f"`mode` must be 'min' or 'max'. Received: {mode!r}."
            raise ValueError(msg)
        # A negative patience would request a stop on the very first epoch,
        # even when that epoch improved the metric.
        if patience < 0:
            msg = f"`patience` must be non-negative. Received: {patience!r}."
            raise ValueError(msg)

        super().__init__(
            label=label or "EarlyStopping",
            execution_order=execution_order,
        )

        self._monitor = monitor
        self._mode = mode
        self._patience = patience
        self._min_delta = min_delta
        self._restore_best = restore_best
        self._reducer = reducer
        # Kept locally so `get_config` can round-trip the constructor value.
        self._execution_order = execution_order

        self._best_value: float | None = None
        self._best_epoch: int | None = None
        self._wait: int = 0
        self._stopped_epoch: int | None = None

        # In-memory state for restore_best (used when no Checkpointing)
        self._best_state: dict[str, Any] | None = None

    # ================================================
    # Properties
    # ================================================
    @property
    def monitor(self) -> str:
        """The name of the monitored metric."""
        return self._monitor

    @property
    def best_value(self) -> float | None:
        """The best observed metric value so far."""
        return self._best_value

    @property
    def best_epoch(self) -> int | None:
        """The epoch index that achieved the best metric value, or None."""
        return self._best_epoch

    @property
    def stopped_epoch(self) -> int | None:
        """The epoch at which training was stopped, or None if not triggered."""
        return self._stopped_epoch

    @property
    def restore_best(self) -> bool:
        """Whether model state restoration to the best epoch is enabled."""
        return self._restore_best

    # ================================================
    # Lifecycle Hooks
    # ================================================
    def on_phase_start(
        self,
        *,
        experiment: Experiment,
        phase: ExperimentPhase,
        results: PhaseResults | None = None,
    ) -> None:
        """Reset internal state at the start of each phase."""
        self._best_value = None
        self._best_epoch = None
        self._wait = 0
        self._stopped_epoch = None
        self._best_state = None

    def on_epoch_end(
        self,
        *,
        experiment: Experiment,
        phase: ExperimentPhase,
        exec_ctx: ExecutionContext,
        results: PhaseResults | None = None,
    ) -> CallbackResult | None:
        """
        Check monitored metric and request stop if patience is exceeded.

        Does nothing when `results` is None. A missing or NaN metric value
        for the epoch counts as an epoch without improvement.

        Returns:
            CallbackResult | None: Always None; the summary is emitted by
            `on_phase_end`.

        """
        if results is None:
            return None

        # Get metric value for this epoch (None if it was not recorded)
        metric_value = self._get_metric_value(
            results=results,
            epoch_idx=exec_ctx.epoch_idx,
        )

        improved = metric_value is not None and self._check_improvement(
            metric_value,
        )
        if improved:
            self._best_value = metric_value
            self._best_epoch = exec_ctx.epoch_idx
            self._wait = 0

            # Snapshot state for restore_best (in-memory fallback).
            # NOTE(review): assumes `get_state()` returns a snapshot that later
            # training steps do not mutate in place - confirm.
            if self._restore_best and not self._has_phase_checkpointing(phase):
                self._best_state = experiment.model_graph.get_state()
        else:
            self._wait += 1
            if self._wait >= self._patience:
                self._stopped_epoch = exec_ctx.epoch_idx
                phase.request_stop()

        return None

    def on_phase_end(
        self,
        *,
        experiment: Experiment,
        phase: ExperimentPhase,
        results: PhaseResults | None = None,
    ) -> EarlyStoppingResult | None:
        """
        Restore best model state (if enabled) and return summary result.

        Restoration is only attempted when `restore_best` is enabled and a
        best epoch was recorded during the phase.

        Returns:
            EarlyStoppingResult: Best epoch/value, the epoch at which a stop
            was requested (if any), and whether the state was restored.

        """
        restored = False
        if self._restore_best and self._best_epoch is not None:
            restored = self._restore_best_state(
                experiment=experiment,
                phase=phase,
            )

        return EarlyStoppingResult(
            best_epoch=self._best_epoch,
            best_value=self._best_value,
            stopped_epoch=self._stopped_epoch,
            restored=restored,
        )

    # ================================================
    # Internal
    # ================================================
    def _get_metric_value(
        self,
        *,
        results: PhaseResults,
        epoch_idx: int,
    ) -> float | None:
        """Extract the monitored metric value for a given epoch, or None."""
        if self._monitor not in results.metric_names():
            return None

        metric_series = results.metrics().where(
            name=self._monitor,
            epoch=epoch_idx,
        )
        if len(metric_series.values()) == 0:
            return None

        # Reducer on "batch", if needed
        metric_series = metric_series.collapse(axis="batch", reducer=self._reducer)
        entry = metric_series.one()
        return entry.value

    def _check_improvement(self, value: float) -> bool:
        """
        Check whether the new value is an improvement over the best.

        The first non-NaN value always counts as an improvement. Afterwards
        the value must beat the current best by more than `min_delta` in the
        direction given by `mode`.
        """
        # NaN must never become the best value: every later comparison
        # against it is False, which would pin the "best" epoch to a
        # diverged one.
        if math.isnan(value):
            return False
        if self._best_value is None:
            return True
        if self._mode == "min":
            return value < (self._best_value - self._min_delta)
        return value > (self._best_value + self._min_delta)

    def _has_phase_checkpointing(self, phase: ExperimentPhase) -> bool:
        """Check whether the phase has a Checkpointing callback configured."""
        return isinstance(phase, TrainPhase) and (phase.checkpointing is not None)

    def _restore_best_state(
        self,
        *,
        experiment: Experiment,
        phase: TrainPhase,
    ) -> bool:
        """
        Restore model state to the best epoch.

        Tries phase Checkpointing first, falls back to in-memory snapshot.

        Returns:
            bool: Whether restoration was successful.

        """
        # Try phase Checkpointing attr
        if self._has_phase_checkpointing(phase):
            ckpt = phase.checkpointing
            if ckpt.has_key(self._best_epoch):
                if ckpt.mode == "memory":
                    state = ckpt.get_state(self._best_epoch)
                    experiment.model_graph.set_state(state)
                else:
                    path = ckpt.get_path(self._best_epoch)
                    experiment.model_graph.restore_checkpoint(path)
                msg = (
                    f"Restored model to best epoch {self._best_epoch} "
                    f"(via Checkpointing, {self._monitor}={self._best_value})."
                )
                logger.info(msg=msg, stacklevel=2)
                return True
            # NOTE(review): no in-memory snapshot is taken while Checkpointing
            # is configured, so a missing checkpoint for the best epoch ends
            # in the warning below.

        # Fallback to in-memory snapshot
        if self._best_state is not None:
            experiment.model_graph.set_state(self._best_state)
            msg = (
                f"Restored model to best epoch {self._best_epoch} "
                f"(in-memory, {self._monitor}={self._best_value})."
            )
            logger.info(msg=msg, stacklevel=2)
            return True

        msg = (
            f"restore_best=True but no saved state found for epoch "
            f"{self._best_epoch}. Model was not restored."
        )
        logger.warning(msg=msg, stacklevel=2)
        return False

    # ================================================
    # Configurable
    # ================================================
    def get_config(self) -> dict[str, Any]:
        """Return configuration details required to reconstruct this callback."""
        return {
            "callback_type": self.__class__.__qualname__,
            "monitor": self._monitor,
            "mode": self._mode,
            "patience": self._patience,
            "min_delta": self._min_delta,
            "restore_best": self._restore_best,
            "reducer": self._reducer,
            "label": self.label,
            "execution_order": self._execution_order,
        }

    @classmethod
    def from_config(cls, config: dict) -> EarlyStopping:
        """
        Construct from config data.

        Missing keys fall back to the constructor defaults, so configs
        written without "execution_order" still load.
        """
        return cls(
            monitor=config.get("monitor", "val_loss"),
            mode=config.get("mode", "min"),
            patience=config.get("patience", 5),
            min_delta=config.get("min_delta", 0.0),
            restore_best=config.get("restore_best", False),
            reducer=config.get("reducer", "last"),
            label=config.get("label"),
            execution_order=config.get("execution_order", 1),
        )