Source code for modularml.core.experiment.results.train_results

"""Results container for training phases."""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal

from modularml.core.experiment.results.phase_results import PhaseResults

if TYPE_CHECKING:
    from modularml.callbacks.early_stopping import EarlyStoppingResult


[docs] @dataclass class TrainResults(PhaseResults): """ Results container for a training phase. Description: TrainResults wraps the outputs of a TrainPhase, which executes multiple epochs with multiple batches per epoch. This class provides: - Access to training data keyed by epoch and batches - Direct access to validation losses/tensors from evaluation callbacks - Loss aggregation per epoch Validation callbacks (kind="evaluation") are automatically detected and their results exposed through dedicated accessors. Attributes: label (str): Phase label. _execution (list[ExecutionContext]): Ordered execution contexts. _callbacks (list[CallbackResult]): Recorded callback outputs. _metrics (MetricStore): Stored scalar metrics. _series_cache (dict[tuple, Any]): Cache of memoized AxisSeries queries. """ # ================================================ # Representation # ================================================ def __repr__(self): n_epochs = self.n_epochs if self._execution else 0 return f"TrainResults(label='{self.label}', epochs={n_epochs})" # ================================================ # Properties # ================================================ @property def epoch_indices(self) -> list[int]: """ Sorted list of recorded epoch indices. Returns: list[int]: Epoch indices in ascending order. """ epoch_vals = self.execution_contexts().axis_values("epoch") return sorted(int(e) for e in epoch_vals) @property def n_epochs(self) -> int: """ The number of epochs executed during training. Returns: int: Total number of recorded epochs. """ return len(self.epoch_indices) # ================================================ # EarlyStopping Convenience # ================================================ @property def early_stopping_result(self) -> EarlyStoppingResult | None: """ Return the EarlyStoppingResult recorded at phase end, if present. Returns: EarlyStoppingResult | None: The result emitted by an :class:`EarlyStopping` callback, or ``None`` if no such callback was attached to this phase. """ values = self.callbacks(kind="early_stopping").values() return values[0] if values else None
[docs] def best_epoch( self, metric: str = "val_loss", direction: Literal["min", "max"] = "min", ) -> int: """ Return the epoch index at which ``metric`` was best. Args: metric (str): Name of the metric to inspect (e.g. ``"val_loss"``). Defaults to ``"val_loss"``. direction (Literal["min", "max"]): Whether a lower (``"min"``) or higher (``"max"``) value is considered better. Defaults to ``"min"``. Returns: int: Epoch index at which ``metric`` achieved its best value. Raises: ValueError: If ``metric`` is not found in the recorded metrics. """ available = self.metric_names() if metric not in available: msg = f"Metric '{metric}' not found. Available metrics: {available}." raise ValueError(msg) entries = self.metrics().where(name=metric).values() best_entry = ( min(entries, key=lambda e: e.value) if direction == "min" else max(entries, key=lambda e: e.value) ) return best_entry.epoch_idx
@property def last_epoch(self) -> int | None: """ The epoch the model is currently at after training. If an :class:`EarlyStopping` callback with ``restore_best=True`` ran and successfully restored model weights, this returns the restored (best) epoch. Otherwise, returns the last recorded epoch index. Returns ``None`` if no epochs were recorded. Returns: int | None: Current model-state epoch, or ``None``. """ es = self.early_stopping_result if es is not None and es.restored and es.best_epoch is not None: return es.best_epoch if self.epoch_indices: return self.epoch_indices[-1] return None