"""Training-phase implementation and batch scheduling utilities."""
from __future__ import annotations
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any
import numpy as np
from modularml.core.data.batch_view import BatchView
from modularml.core.data.execution_context import ExecutionContext
from modularml.core.data.featureset_view import FeatureSetView
from modularml.core.experiment.callbacks.callback import Callback
from modularml.core.experiment.checkpointing import (
TRAINING_HOOKS,
TRAINING_NAME_TEMPLATE,
TRAINING_PLACEHOLDERS,
Checkpointing,
)
from modularml.core.experiment.experiment_context import ExperimentContext
from modularml.core.experiment.phases.phase import ExperimentPhase, InputBinding
from modularml.core.training.applied_loss import AppliedLoss
from modularml.utils.data.formatting import ensure_list
from modularml.utils.environment.environment import IN_NOTEBOOK
from modularml.utils.logging.logger import get_logger
from modularml.utils.progress_bars.progress_task import ProgressTask
if TYPE_CHECKING:
from collections.abc import Iterator
from numpy.typing import NDArray
from modularml.core.data.sampled_view import SampledView
from modularml.core.experiment.experiment import Experiment
from modularml.core.experiment.results.train_results import TrainResults
from modularml.core.references.featureset_reference import FeatureSetReference
from modularml.core.sampling.base_sampler import BaseSampler
from modularml.core.topology.graph_node import GraphNode
logger = get_logger(name="TrainPhase")
[docs]
class ResultRecording(Enum):
    """
    Controls which execution contexts are retained in TrainResults.

    ALL:
        Record every batch of every epoch (default). Provides full
        access to all outputs, losses, and tensors across the entire
        training run but uses more memory.
    LAST:
        Keep only the final epoch's execution contexts. When an
        :class:`~modularml.callbacks.early_stopping.EarlyStopping`
        callback with ``restore_best=True`` is active, "last" is
        interpreted as the **best** epoch.
    NONE:
        Do not record execution contexts at all. Scalar metrics
        (e.g. ``train_loss``, ``val_loss``) are still logged to the
        :class:`MetricStore` and remain accessible.
    """

    ALL = "all"
    LAST = "last"
    NONE = "none"

    @classmethod
    def from_value(cls, value: str | ResultRecording) -> ResultRecording:
        """
        Normalize strings or enums to a :class:`ResultRecording`.

        String matching is case-insensitive (the input is lower-cased
        before the enum lookup).

        Args:
            value (str | ResultRecording): Input value to normalize.

        Returns:
            ResultRecording: Canonical enum member.

        Raises:
            ValueError: If ``value`` cannot be mapped to a valid member
                (unknown string, or neither a string nor a member).

        """
        if isinstance(value, cls):
            return value
        if isinstance(value, str):
            try:
                return cls(value.lower())
            except ValueError:
                # Unknown name: fall through so that every invalid input
                # produces the same descriptive error below.
                pass
        msg = (
            f"Invalid ResultRecording: {value}. "
            f"Expected one of {[r.value for r in cls]}"
        )
        raise ValueError(msg)
class BatchSchedulingPolicy(Enum):
    """
    Defines how batches from multiple samplers are scheduled during training.

    Let samplers produce the following batch sequences:
        S1 = [b1, b2, b3]
        S2 = [c1, c2]

    With ``N`` samplers, the available scheduling policies behave as follows:

    ZIP_STRICT:
        Lockstep iteration, stopping when the shortest sampler is exhausted.
        Output:
            (b1, c1), (b2, c2)
        Total steps:
            min(len(Si))
    ZIP_CYCLE:
        Lockstep iteration until the longest sampler is exhausted.
        Shorter samplers cycle from the beginning as needed.
        Output:
            (b1, c1), (b2, c2), (b3, c1)
        Total steps:
            max(len(Si))
    ALTERNATE_STRICT:
        Alternate one batch at a time from each sampler in round-robin order,
        running only as many full rounds as the shortest sampler allows.
        Output:
            b1, c1, b2, c2
        Total steps:
            N * min(len(Si))
    ALTERNATE_CYCLE:
        Alternate one batch at a time from each sampler in round-robin order,
        cycling shorter samplers until the longest sampler is exhausted.
        Output:
            b1, c1, b2, c2, b3, c1
        Total steps:
            N * max(len(Si))

    Notes:
        - This policy controls batch ordering only.
        - No semantic alignment between samplers is performed.
        - If semantic alignment is required (e.g. contrastive pairs),
          it must be handled inside the sampler via roles.
        - Sequential training on different samplers should be expressed
          as multiple training phases, not via batch scheduling.
    """

    ZIP_STRICT = "zip_strict"
    ZIP_CYCLE = "zip_cycle"
    ALTERNATE_STRICT = "alternate_strict"
    ALTERNATE_CYCLE = "alternate_cycle"

    @classmethod
    def from_value(cls, value: str | BatchSchedulingPolicy) -> BatchSchedulingPolicy:
        """
        Normalize strings or enums to a :class:`BatchSchedulingPolicy`.

        String matching is case-insensitive (the input is lower-cased
        before the enum lookup).

        Args:
            value (str | BatchSchedulingPolicy): Input value to normalize.

        Returns:
            BatchSchedulingPolicy: Canonical policy enum.

        Raises:
            ValueError: If `value` cannot be mapped to a valid policy.

        """
        if isinstance(value, cls):
            return value
        if isinstance(value, str):
            try:
                return cls(value.lower())
            except ValueError:
                # Unknown name: fall through to the shared error below.
                pass
        msg = (
            f"Invalid BatchSchedulingPolicy: {value}. "
            f"Expected one of {[p.value for p in cls]}"
        )
        raise ValueError(msg)
@dataclass(frozen=True)
class SamplerExecutionKey:
    """
    Immutable, hashable identity of one sampler execution.

    Two bindings whose keys compare equal can share a single
    materialized sampler run.

    Attributes:
        featureset_id (str): Identifier of the source :class:`FeatureSet`.
        split (str | None): Split name used for sampling, if any.
        sampler_cfg (Any | None): Sampler configuration; it takes part in
            equality and hashing, so it has to be hashable for the key
            to be usable as a dict key.

    """

    # Source FeatureSet identifier.
    featureset_id: str
    # Split restriction (None when no split is applied).
    split: str | None
    # Sampler configuration used for de-duplication.
    sampler_cfg: Any | None
@dataclass
class SamplerExecution:
    """
    One materialized sampler run plus the bindings that consume it.

    Attributes:
        sampler_id (int): Stable identifier assigned to the sampler instance.
        sampled (SampledView): Materialized data produced by the sampler.
        bindings (list[InputBinding]): Input bindings satisfied by the sampler.

    """

    # Integer identifier of this execution group.
    sampler_id: int
    # Output of running the sampler once.
    sampled: SampledView
    # Every binding that reuses `sampled`.
    bindings: list[InputBinding]
[docs]
class TrainPhase(ExperimentPhase):
"""Phase that trains model graph nodes over one or more epochs."""
[docs]
def __init__(
    self,
    label: str,
    *,
    input_sources: list[InputBinding],
    losses: list[AppliedLoss],
    n_epochs: int = 1,
    active_nodes: list[str | GraphNode] | None = None,
    batch_schedule: BatchSchedulingPolicy | str = BatchSchedulingPolicy.ZIP_STRICT,
    callbacks: list[Callback] | None = None,
    checkpointing: Checkpointing | None = None,
    result_recording: ResultRecording | str = ResultRecording.ALL,
):
    """
    Initializes a new training phase for the experiment.

    Args:
        label (str):
            A label to assign to this phase of the experiment. Used for logging.
        input_sources (list[InputBinding]):
            Input bindings for each head node in ModelGraph.
        losses (list[AppliedLoss]):
            A non-empty list of losses to be applied during this training phase.
        n_epochs (int):
            Number of epochs to perform. Must be >= 1.
        active_nodes (list[str | GraphNode] | None, optional):
            A list of GraphNodes to train in this training phase. Nodes can
            be listed via their ID, label, or with the actual node instance. If
            None, all nodes comprising the ModelGraph are used. Defaults to None.
        batch_schedule (str | BatchSchedulingPolicy, optional):
            Defines how batches from multiple samplers are scheduled during
            training. This is only relevant if more than one sampler is defined
            in `input_sources`.
            Let samplers `S1` and `S2` produce: `S1 = [b1, b2, b3]` and
            `S2 = [c1, c2]`
            The outputs of each policy are given below:
            * "zip_strict": (b1, c1), (b2, c2)
            * "zip_cycle": (b1, c1), (b2, c2), (b3, c1)
            * "alternate_strict": b1, c1, b2, c2
            * "alternate_cycle": b1, c1, b2, c2, b3, c1
            See also :class:`BatchSchedulingPolicy`.
        callbacks (list[Callback] | None, optional):
            An optional list of Callbacks to run during phase execution.
        checkpointing (Checkpointing | None, optional):
            An optional Checkpointing callback that automatically saves model
            state at configurable lifecycle hook points. Unlike regular callbacks,
            this is configured as a phase-level argument rather than added manually
            via ``add_callback()``. Defaults to None.
        result_recording (ResultRecording | str, optional):
            Controls which execution contexts are retained in the returned
            :class:`TrainResults`. See :class:`ResultRecording` for details.
            Defaults to ``ResultRecording.ALL``.

    Raises:
        ValueError: If `losses` is None or empty, if `n_epochs` < 1, if
            `batch_schedule` / `result_recording` cannot be normalized, if
            `checkpointing` fails validation, or if any input source has
            no sampler.

    """
    # Reject both a missing and an empty loss list: a training phase
    # with nothing to optimize cannot make progress.
    if not losses:
        raise ValueError("Training requires at least one defined loss.")

    super().__init__(
        label=label,
        input_sources=input_sources,
        losses=losses,
        active_nodes=active_nodes,
        callbacks=callbacks,
    )

    self.batch_schedule = BatchSchedulingPolicy.from_value(batch_schedule)

    if n_epochs < 1:
        raise ValueError("n_epochs must be >= 1")
    self.n_epochs = n_epochs

    self.result_recording = ResultRecording.from_value(result_recording)

    # Checkpointing (validated and normalized by `set_checkpointing`)
    self._checkpointing: Checkpointing | None = None
    self.set_checkpointing(checkpointing)

    self._validate_samplers()

    # Integer IDs assigned to each binding
    # Each ID corresponds to a unique sampler (used for alternating schedulers)
    self._sampler_ids: NDArray[np.int_] = np.arange(
        len(self.input_sources),
        dtype=int,
    )

    # Stop flag for callbacks like EarlyStopping
    self._stop_requested = False
# ================================================
# Convenience Constructors
# ================================================
[docs]
@classmethod
def from_split(
    cls,
    label: str,
    *,
    split: str,
    sampler: BaseSampler,
    losses: list[AppliedLoss],
    n_epochs: int = 1,
    active_nodes: list[str | GraphNode] | None = None,
    batch_schedule: BatchSchedulingPolicy | str = BatchSchedulingPolicy.ZIP_STRICT,
    callbacks: list[Callback] | None = None,
    checkpointing: Checkpointing | None = None,
    result_recording: ResultRecording | str = ResultRecording.ALL,
) -> TrainPhase:
    """
    Build a training phase whose inputs all come from one FeatureSet split.

    The input bindings are derived via
    ``cls._build_input_sources_from_split`` and every remaining argument
    is forwarded unchanged to the regular constructor.

    Notes:
        All active head nodes must input from the defined split. If the model
        graph has multiple head nodes that input from different FeatureSets,
        you will need to use the default TrainPhase constructor.

    Args:
        label (str):
            A label to assign to this phase of the experiment. Used for logging.
        split (str):
            The FeatureSet split to train on.
        sampler (BaseSampler):
            A sampler to use to generate batches from this split.
        losses (list[AppliedLoss]):
            A list of losses to be applied during this training phase.
        n_epochs (int):
            Number of epochs to perform.
        active_nodes (list[str | GraphNode] | None, optional):
            A list of GraphNodes to train in this training phase. Nodes can
            be listed via their ID, label, or with the actual node instance. If
            None, all nodes comprising the ModelGraph are used. Defaults to None.
        batch_schedule (str | BatchSchedulingPolicy, optional):
            Defines how batches from multiple samplers are scheduled during
            training. This is only relevant if there is more than one head node.
            Let samplers `S1` and `S2` produce: `S1 = [b1, b2, b3]` and
            `S2 = [c1, c2]`
            The outputs of each policy are given below:
            * "zip_strict": (b1, c1), (b2, c2)
            * "zip_cycle": (b1, c1), (b2, c2), (b3, c1)
            * "alternate_strict": b1, c1, b2, c2
            * "alternate_cycle": b1, c1, b2, c2, b3, c1
            See also :class:`BatchSchedulingPolicy`.
        callbacks (list[Callback] | None, optional):
            An optional list of Callbacks to run during phase execution.
        checkpointing (Checkpointing | None, optional):
            An optional Checkpointing callback that automatically saves model
            state at configurable lifecycle hook points. Defaults to None.
        result_recording (ResultRecording | str, optional):
            Controls which execution contexts are retained in the returned
            :class:`TrainResults`. See :class:`ResultRecording` for details.
            Defaults to `ResultRecording.ALL`.

    Returns:
        TrainPhase: The newly constructed phase.

    """
    # Derive the bindings first; the constructor validates them.
    split_bindings = cls._build_input_sources_from_split(
        split=split,
        sampler=sampler,
        active_nodes=active_nodes,
    )
    ctor_kwargs = {
        "label": label,
        "input_sources": split_bindings,
        "losses": losses,
        "n_epochs": n_epochs,
        "active_nodes": active_nodes,
        "batch_schedule": batch_schedule,
        "callbacks": callbacks,
        "checkpointing": checkpointing,
        "result_recording": result_recording,
    }
    return cls(**ctor_kwargs)
# ================================================
# Validation
# ================================================
def _validate_samplers(self):
    """
    Ensure every input binding of this phase carries a sampler.

    Raises:
        ValueError: For the first binding (in `self.input_sources` order)
            whose `sampler` is None; the message names the label of the
            node that binding feeds.

    """
    # Resolved up front, exactly as before, even when nothing is missing.
    active_ctx = ExperimentContext.get_active()

    # First binding lacking a sampler (None when all are fine).
    unsampled = next(
        (src for src in self.input_sources if src.sampler is None),
        None,
    )
    if unsampled is None:
        return

    target_node = active_ctx.get_node(
        node_id=unsampled.node_id,
        enforce_type="GraphNode",
    )
    raise ValueError(
        "TrainPhase requires that samplers are defined for all input "
        f"sources. Missing sampler for node: '{target_node.label}'. Use "
        "`<node>.create_input_binding(...)` to create the input source.",
    )
# ================================================
# Representation
# ================================================
def __repr__(self):
return f"TrainPhase(label='{self.label}')"
# ================================================
# Callback Convenience
# ================================================
[docs]
def add_callback(self, callback: Callback):
    """
    Register an additional callback on this training phase.

    A callback is rejected only when an already-registered callback of
    the exact same type (compared with ``type(...) is type(...)``)
    carries the same label.

    Args:
        callback (Callback): Callback to append.

    Raises:
        ValueError: If another callback of the same type and label exists.

    """
    new_type = type(callback)
    # Labels already taken by callbacks of this exact type.
    taken_labels = [
        existing.label
        for existing in ensure_list(self.callbacks)
        if new_type is type(existing)
    ]
    if callback.label in taken_labels:
        raise ValueError(
            f"Another {type(callback).__qualname__} callback already "
            f"exists with label '{callback.label}'. ",
        )
    self.callbacks.append(callback)
[docs]
def request_stop(self) -> None:
"""
Request early termination of this training phase.
Description:
Sets an internal flag that is checked at the end of each epoch.
When set, the training loop will break cleanly after the current
epoch completes. Intended to be called by callbacks like
EarlyStopping.
"""
self._stop_requested = True
# ================================================
# Checkpointing
# ================================================
@property
def checkpointing(self) -> Checkpointing | None:
"""The Checkpointing instance configured for this phase, or None."""
return self._checkpointing
[docs]
def set_checkpointing(self, checkpointing: Checkpointing | None) -> None:
    """
    Attach, replace, or clear the Checkpointing configuration of this phase.

    For a non-None configuration this checks that every ``save_on`` hook
    is a valid TrainPhase hook and that ``name_template`` only uses
    allowed placeholders. When ``name_template`` is unset, the training
    default is written onto the passed object (i.e. the argument itself
    is modified).

    Args:
        checkpointing (Checkpointing | None):
            The Checkpointing configuration, or None to disable.

    Raises:
        ValueError: If ``save_on`` contains hooks outside ``TRAINING_HOOKS``.

    """
    if checkpointing is not None:
        # Hooks that a TrainPhase never fires
        unsupported = set(checkpointing.save_on) - TRAINING_HOOKS
        if unsupported:
            raise ValueError(
                f"Invalid `save_on` hooks for TrainPhase: {sorted(unsupported)}. "
                f"Valid hooks: {sorted(TRAINING_HOOKS)}.",
            )

        # Fall back to the training template when none was given
        if checkpointing.name_template is None:
            checkpointing.name_template = TRAINING_NAME_TEMPLATE

        # Placeholder validation against the training placeholder set
        Checkpointing.validate_placeholders(
            checkpointing.name_template,
            TRAINING_PLACEHOLDERS,
            context_name="TrainPhase",
        )

    # Reached with None (disable) or with a validated configuration
    self._checkpointing = checkpointing
def _invoke_checkpointing(
self,
hook: str,
*,
experiment: Experiment,
epoch_idx: int = 0,
batch_idx: int = 0,
) -> None:
"""Invoke Checkpointing if configured and conditions are met."""
if self._checkpointing is None:
return
if experiment._checkpointing_disabled:
return
if not self._checkpointing.should_save(hook):
return
if self._checkpointing.mode == "memory":
state = experiment.model_graph.get_state()
self._checkpointing.record_memory(key=epoch_idx, state=state)
else:
name = self._checkpointing.format_name(
phase=self.label,
epoch=epoch_idx,
batch=batch_idx,
)
if self._checkpointing.directory is None:
msg = (
"Cannot save disk checkpoint: no checkpoint directory "
"set. Either set `directory` on the TrainPhase "
"Checkpointing, or set one on the parent Experiment "
"so it can be inherited."
)
raise RuntimeError(msg)
# Ensure directory exists
self._checkpointing.directory.mkdir(parents=True, exist_ok=True)
filepath = self._checkpointing.directory / name
path = experiment.model_graph.save_checkpoint(
filepath=filepath,
overwrite=self._checkpointing.overwrite,
)
self._checkpointing.record_disk(key=epoch_idx, path=Path(path))
# Register with experiment for centralized tracking
experiment._checkpoints[f"{self.label}/{name}"] = Path(path)
# ================================================
# Execution
# ================================================
[docs]
def is_epoch_end(self) -> bool:
"""Whether current `iter_execution` state is at the end of an epoch."""
if not hasattr(self, "_is_epoch_end"):
self._is_epoch_end = False
return self._is_epoch_end
def _masked_batch_like(self, bv: BatchView) -> BatchView:
    """
    Build a fully masked stand-in for `bv`.

    The result shares `bv.source`, has every one of its `bv.n_samples`
    role indices set to -1, and carries no weights.
    """
    origin = bv.source
    # -1 marks every sample position as masked
    all_masked = np.full(bv.n_samples, -1, dtype=int)
    return BatchView(
        source=origin,
        role_indices=all_masked,
        role_indice_weights=None,
    )
def _build_sampler_executions(
    self,
    *,
    show_sampler_progress: bool = True,
) -> list[SamplerExecution]:
    """
    Collapse `self.input_sources` into unique SamplerExecution groups.

    Description:
        Several bindings can describe the same sampling "execution":
        identical upstream FeatureSet, identical split restriction, and
        an identical (hashable) sampler configuration. Batches for such a
        group are materialized once, by the sampler of the first binding
        encountered, and the resulting SampledView is shared by every
        binding in the group.

        De-duplication is conservative. A binding always gets its own
        group when:
            - its sampler has no `get_config()`,
            - `get_config()` raises or returns None,
            - the returned configuration is not hashable.

        Each SamplerExecution holds:
            - sampler_id: integer in [0..N_unique-1], assigned in order
              of first appearance
            - sampled: the SampledView from executing that sampler once
            - bindings: all InputBindings that reuse this execution

    Args:
        show_sampler_progress (bool, optional):
            Forwarded to each executed sampler (its `show_progress`
            attribute, its progress task, and `materialize_batches`).
            Defaults to True.

    Returns:
        list[SamplerExecution]:
            All unique SamplerExecution groups, sorted by `sampler_id`.

    """
    ids_by_key: dict[SamplerExecutionKey, int] = {}
    bindings_by_id: dict[int, list[InputBinding]] = defaultdict(list)
    sampled_by_id: dict[int, SampledView] = {}
    # Running counter that keeps non-dedupable keys distinct
    n_unique_fallbacks = 0

    for binding in self.input_sources:
        # --- Derive the configuration used for de-duplication ---
        cfg = None
        if hasattr(binding.sampler, "get_config"):
            try:
                cfg = binding.sampler.get_config()
            except Exception:  # noqa: BLE001
                cfg = None
        if cfg is not None:
            try:
                hash(cfg)
            except TypeError:
                # Unhashable configs cannot be part of a dict key
                cfg = None
        if cfg is None:
            # No reliable config => give this binding a one-off key
            n_unique_fallbacks += 1
            cfg = ("__no_dedupe__", n_unique_fallbacks)

        group_key = SamplerExecutionKey(
            featureset_id=binding.upstream_ref.node_id,
            split=binding.split,
            sampler_cfg=cfg,
        )

        # --- Look up (or create) the execution group ---
        group_id = ids_by_key.get(group_key)
        if group_id is None:
            group_id = len(ids_by_key)
            ids_by_key[group_key] = group_id

            # `resolve_input_view` restricts columns to this one binding.
            # Samplers are column-agnostic, so hand them every column.
            bound_view = binding.resolve_input_view()
            every_column = bound_view.source.get_all_keys(
                include_domain_prefix=True,
                include_rep_suffix=True,
            )
            sampler_src = FeatureSetView(
                source=bound_view.source,
                indices=bound_view.indices,
                columns=every_column,
                label=f"{bound_view.label}_for_sampler",
            )

            # Run the sampler exactly once for the whole group
            sampler = binding.sampler
            sampler.bind_sources(sources=[sampler_src])
            sampler.show_progress = show_sampler_progress
            sampler._progress_task.enabled = show_sampler_progress
            sampler.materialize_batches(show_progress=show_sampler_progress)

            sampled_by_id[group_id] = sampler.sampled

        bindings_by_id[group_id].append(binding)

    # Emit groups ordered by their sampler ID
    return [
        SamplerExecution(
            sampler_id=group_id,
            sampled=sampled_by_id[group_id],
            bindings=bindings_by_id[group_id],
        )
        for group_id in sorted(sampled_by_id)
    ]
def _iter_schedule(
    self,
    *,
    policy: BatchSchedulingPolicy,
    sampler_lengths: list[int],
) -> Iterator[dict[int, int]]:
    """
    Generate the per-step plan mapping sampler_id -> batch_index.

    Sampler IDs are the positions in `sampler_lengths`.

    ZIP_*:
        Every yielded dict holds all sampler_ids (one batch per sampler
        per step).
    ALTERNATE_*:
        Every yielded dict holds exactly one sampler_id (the sampler
        active in that step). The caller is responsible for masking the
        other samplers.

    *_STRICT variants run for ``min(sampler_lengths)`` rounds; *_CYCLE
    variants run for ``max(sampler_lengths)`` rounds and wrap the batch
    index of shorter samplers with a modulo.

    Raises:
        TypeError: If `policy` is not a known BatchSchedulingPolicy.

    """
    strict = (
        BatchSchedulingPolicy.ZIP_STRICT,
        BatchSchedulingPolicy.ALTERNATE_STRICT,
    )
    cycling = (
        BatchSchedulingPolicy.ZIP_CYCLE,
        BatchSchedulingPolicy.ALTERNATE_CYCLE,
    )
    lockstep = (
        BatchSchedulingPolicy.ZIP_STRICT,
        BatchSchedulingPolicy.ZIP_CYCLE,
    )

    if policy in strict:
        # Stop before the shortest sampler would be exceeded
        n_rounds = min(sampler_lengths)
    elif policy in cycling:
        # Keep going until the longest sampler is used up
        n_rounds = max(sampler_lengths)
    else:
        msg = f"Unknown BatchSchedulingPolicy: {policy}"
        raise TypeError(msg)

    sampler_ids = range(len(sampler_lengths))
    zipped = policy in lockstep

    for round_idx in range(n_rounds):
        # In strict mode round_idx is below every length, so the modulo
        # leaves it unchanged; in cycle mode it wraps shorter samplers.
        if zipped:
            yield {
                sid: round_idx % sampler_lengths[sid] for sid in sampler_ids
            }
        else:
            for sid in sampler_ids:
                yield {sid: round_idx % sampler_lengths[sid]}
[docs]
def iter_execution(
    self,
    *,
    results: TrainResults | None = None,
    show_sampler_progress: bool = True,
    show_training_progress: bool = True,
    persist_progress: bool = IN_NOTEBOOK,
    persist_epoch_progress: bool = IN_NOTEBOOK,
    val_loss_metric: str = "val_loss",
) -> Iterator[ExecutionContext]:
    """
    Iterate over all execution steps for this training phase.

    This generator produces a sequence of :class:`ExecutionContext`
    objects representing the full training schedule of the phase,
    across all epochs and all batch steps within each epoch.

    The execution flow is:
        1. Group input bindings into unique sampler executions via
           ``_build_sampler_executions()``. Samplers with identical configuration,
           FeatureSet, and split are executed only once and shared across bindings.
        2. For each sampler execution, obtain the number of materialized batches
           from its :class:`SampledView`.
        3. For each epoch:
            a. Generate a step-wise batch schedule using ``_iter_schedule()``
               according to ``self.batch_schedule``.
            b. For each schedule step, construct the inputs for *all* head-node
               bindings:
                - If a sampler is active in the current step, its corresponding
                  batch is selected.
                - If a sampler is inactive (ALTERNATE policies), its inputs are
                  replaced with a fully masked :class:`BatchView`.
        4. Yield an :class:`ExecutionContext` containing:
            - Phase label
            - Epoch index
            - Batch index (within the epoch)
            - Resolved input :class:`BatchView` objects for all bindings

    After each ``yield`` resumes, the losses stored on the yielded context
    (``ctx.losses``) are read, so the consumer is expected to have run the
    training step on that context before requesting the next one.

    Notes:
        - ZIP scheduling policies always provide one batch per sampler per step.
        - ALTERNATE scheduling policies activate exactly one sampler per step;
          all others are masked.
        - No semantic alignment between samplers is performed here. Any required
          alignment (e.g., contrastive pairs or matched samples) must be handled
          inside the sampler itself via roles.
        - The yielded :class:`ExecutionContext` objects are intended to be
          consumed directly by the ModelGraph training loop
          (e.g., ``ModelGraph.train_step(ctx)``).
        - Any exception raised inside this generator is passed to every
          callback's ``_on_exception`` and then re-raised. Because
          ``BaseException`` is caught, this also covers the ``GeneratorExit``
          raised when the consumer closes the generator early.
        - The phase-level progress task is always finished, whether the
          generator completes, fails, or is closed early.

    Args:
        results (TrainResults | None, optional):
            Optional container in which results will be registered.
        show_sampler_progress (bool, optional):
            Whether to show a progress bar for sampler batching.
            Defaults to True.
        show_training_progress (bool, optional):
            Whether to show a progress bar for training execution.
            Defaults to True.
        persist_progress (bool, optional):
            Whether to leave all progress bars visible after completion.
            Overrides all nested progress bar persistence settings.
            Defaults to ``IN_NOTEBOOK`` (True in notebooks, False in scripts).
        persist_epoch_progress (bool, optional):
            Whether to leave per-epoch training bars visible after completion.
            Defaults to ``IN_NOTEBOOK``.
        val_loss_metric (str, optional):
            Name of a recorded validation loss metric to display in the
            progress bar. Results must be tracked and the metric must exist.
            If not, no validation loss field will be shown.
            Defaults to ``"val_loss"``.

    Yields:
        ExecutionContext:
            A fully specified execution context for a single batch step
            within a specific epoch of this training phase.

    Raises:
        RuntimeError: If any sampler produced zero batches.

    """
    # Reset stop flag
    self._stop_requested = False

    # Samplers may be repeated over input bindings
    # We group by unique samplers (same sampler cfg, same FeatureSet + split)
    sampler_execs = self._build_sampler_executions(
        show_sampler_progress=show_sampler_progress,
    )
    sampler_lens = [se.sampled.num_batches for se in sampler_execs]
    if any(x == 0 for x in sampler_lens):
        msg = (
            "One or more samplers produced zero batches; cannot execute TrainPhase."
        )
        raise RuntimeError(msg)

    # Number of schedule steps per epoch (the schedule is identical each epoch)
    step_cnt = sum(
        1
        for _ in self._iter_schedule(
            policy=self.batch_schedule,
            sampler_lengths=sampler_lens,
        )
    )

    # Resolve the experiment before any progress bar is started, so a
    # failed lookup cannot leave a started bar behind.
    exp_ctx = ExperimentContext.get_active()
    experiment = exp_ctx.get_experiment()
    last_ctx: ExecutionContext | None = None

    # ------------------------------------------------
    # Progress Bar: epoch counter
    # ------------------------------------------------
    epoch_ptask = ProgressTask(
        style="training",
        description=f"Training ['{self.label}']",
        total=self.n_epochs,
        enabled=show_training_progress,
        persist=persist_progress,
    )
    epoch_ptask.start()

    try:
        # ------------------------------------------------
        # Callbacks: on_phase_start
        # ------------------------------------------------
        experiment._in_callback = True
        try:
            for cb in self.callbacks:
                cb._on_phase_start(
                    experiment=experiment,
                    phase=self,
                    results=results,
                )
        finally:
            experiment._in_callback = False

        # Checkpointing: reset and optionally save at phase_start
        if self._checkpointing is not None:
            self._checkpointing.reset()
        self._invoke_checkpointing(
            "phase_start",
            experiment=experiment,
        )

        # ------------------------------------------------
        # Iterate over all epochs
        # ------------------------------------------------
        for epoch_idx in range(self.n_epochs):
            # ------------------------------------------------
            # Progress Bar: batch counter
            # ------------------------------------------------
            per_epoch_ptask = ProgressTask(
                style="training_loss",
                description=f"Epoch {epoch_idx}",
                total=step_cnt,
                enabled=show_training_progress,
                persist=(persist_epoch_progress and persist_progress)
                or ((epoch_idx == self.n_epochs - 1) and persist_progress),
            )

            # Determine scheduling from self.batch_schedule
            step_iter = self._iter_schedule(
                policy=self.batch_schedule,
                sampler_lengths=sampler_lens,
            )

            # Variables for loss tracking
            running_train = 0
            running_aux = 0

            # ------------------------------------------------
            # Iterate over all batches in this epoch
            # ------------------------------------------------
            for step_idx, step_plan in enumerate(step_iter):
                # step_plan: {sampler_id: batch_idx_for_that_sampler}
                inputs: dict[tuple[str, FeatureSetReference], BatchView] = {}

                # For each sampler, decide whether it's active this step and select/mask
                for sid, se in enumerate(sampler_execs):
                    for binding in se.bindings:
                        # Get list of batches for a given binding
                        all_bvs = se.sampled.get_stream(name=binding.stream)
                        key = (binding.node_id, binding.upstream_ref)

                        # Use real batch view, if this sampler is active in this step
                        if sid in step_plan:
                            bv = all_bvs[step_plan[sid]]
                            inputs[key] = bv
                        # Otherwise, use a fully masked batch
                        else:
                            bv = all_bvs[0]
                            inputs[key] = self._masked_batch_like(bv=bv)

                ctx = ExecutionContext(
                    phase_label=self.label,
                    epoch_idx=epoch_idx,
                    batch_idx=step_idx,
                    inputs=inputs,
                )

                # ------------------------------------------------
                # Callbacks: on_epoch_start
                # ------------------------------------------------
                if step_idx == 0:
                    experiment._in_callback = True
                    try:
                        for cb in self.callbacks:
                            cb._on_epoch_start(
                                experiment=experiment,
                                phase=self,
                                exec_ctx=ctx,
                                results=results,
                            )
                    finally:
                        experiment._in_callback = False
                    self._invoke_checkpointing(
                        "epoch_start",
                        experiment=experiment,
                        epoch_idx=epoch_idx,
                    )

                # ------------------------------------------------
                # Callbacks: on_batch_start
                # ------------------------------------------------
                experiment._in_callback = True
                try:
                    for cb in self.callbacks:
                        cb._on_batch_start(
                            experiment=experiment,
                            phase=self,
                            exec_ctx=ctx,
                            results=results,
                        )
                finally:
                    experiment._in_callback = False
                self._invoke_checkpointing(
                    "batch_start",
                    experiment=experiment,
                    epoch_idx=epoch_idx,
                    batch_idx=step_idx,
                )

                yield ctx
                # Only updated once the consumer asks for the next step, so
                # on an early close `last_ctx` is the previously completed step.
                last_ctx = ctx

                # Get step-wise loss totals (convert once, read both parts)
                step_losses = ctx.losses.to_float()
                step_train = step_losses.trainable
                step_aux = step_losses.auxiliary

                # Update running sums (per-epoch)
                running_train += step_train
                running_aux += step_aux

                # Compute running averages
                avg_train = running_train / (step_idx + 1)
                avg_aux = running_aux / (step_idx + 1)

                # Log raw train_loss to MetricStore (batch-level)
                if results is not None:
                    results._metrics.log(
                        name="train_loss",
                        value=step_train,
                        epoch_idx=epoch_idx,
                        batch_idx=step_idx,
                    )

                # ------------------------------------------------
                # Callbacks: on_batch_end
                # ------------------------------------------------
                experiment._in_callback = True
                try:
                    for cb in self.callbacks:
                        cb._on_batch_end(
                            experiment=experiment,
                            phase=self,
                            exec_ctx=ctx,
                            results=results,
                        )
                finally:
                    experiment._in_callback = False
                self._invoke_checkpointing(
                    "batch_end",
                    experiment=experiment,
                    epoch_idx=epoch_idx,
                    batch_idx=step_idx,
                )

                # ------------------------------------------------
                # Callbacks: on_epoch_end
                # ------------------------------------------------
                if step_idx == step_cnt - 1:
                    experiment._in_callback = True
                    try:
                        for cb in self.callbacks:
                            cb._on_epoch_end(
                                experiment=experiment,
                                phase=self,
                                exec_ctx=ctx,
                                results=results,
                            )
                    finally:
                        experiment._in_callback = False
                    self._invoke_checkpointing(
                        "epoch_end",
                        experiment=experiment,
                        epoch_idx=epoch_idx,
                        batch_idx=step_idx,
                    )

                # Increment batch progress bar
                tick_fields = {
                    "loss_total": avg_train + avg_aux,
                    "loss_train": avg_train,
                    "loss_aux": avg_aux,
                }

                # Show val_loss on final step if validation ran this epoch
                if (
                    (step_idx == step_cnt - 1)
                    and (results is not None)
                    and (val_loss_metric in results.metric_names())
                ):
                    # Filter all val losses to those executed this epoch
                    val_entries = results.metrics().select(
                        name=val_loss_metric,
                        epoch=epoch_idx,
                    )
                    # Take only the most recent
                    val_entries.sort(
                        key=lambda x: x.batch_idx if x.batch_idx is not None else 0,
                    )
                    if len(val_entries) > 0:
                        tick_fields["val_loss"] = val_entries[-1].value

                per_epoch_ptask.tick(n=1, **tick_fields)

            # Log epoch-level train_loss
            if results is not None:
                results._metrics.log(
                    name="train_loss",
                    value=avg_train,
                    epoch_idx=epoch_idx,
                )

            per_epoch_ptask.finish()
            epoch_ptask.tick(n=1)

            # Check stop flag (set by callbacks like EarlyStopping)
            if self._stop_requested:
                msg = f"Training stopped at epoch {epoch_idx}."
                logger.debug(msg=msg, stacklevel=2)
                break

        # ------------------------------------------------
        # Callbacks: on_phase_end
        # ------------------------------------------------
        experiment._in_callback = True
        try:
            for cb in self.callbacks:
                cb._on_phase_end(
                    experiment=experiment,
                    phase=self,
                    results=results,
                )
        finally:
            experiment._in_callback = False
        self._invoke_checkpointing(
            "phase_end",
            experiment=experiment,
            epoch_idx=epoch_idx,
        )

    except BaseException as exc:
        # Includes GeneratorExit / KeyboardInterrupt: callbacks are told
        # about every abnormal exit before the exception propagates.
        experiment._in_callback = True
        try:
            for cb in self.callbacks:
                cb._on_exception(
                    experiment=experiment,
                    phase=self,
                    exec_ctx=last_ctx,
                    exception=exc,
                    results=results,
                )
        finally:
            experiment._in_callback = False
        raise

    finally:
        # Finish progress bar on every exit path, not only on success
        epoch_ptask.finish()
# ================================================
# Configurable
# ================================================
[docs]
def get_config(self) -> dict[str, Any]:
    """
    Serialize this phase into a configuration dictionary.

    The parent-class configuration is extended with the TrainPhase
    specific entries ``phase_type``, ``n_epochs``, ``batch_schedule``,
    ``checkpointing`` (None when no Checkpointing is attached) and
    ``result_recording``.

    Returns:
        dict[str, Any]:
            Configuration used to reconstruct the phase.

    """
    phase_cfg = super().get_config()

    phase_cfg["phase_type"] = "TrainPhase"
    phase_cfg["n_epochs"] = self.n_epochs
    phase_cfg["batch_schedule"] = self.batch_schedule.value

    # Checkpointing is optional, so its config may be None
    if self._checkpointing is not None:
        phase_cfg["checkpointing"] = self._checkpointing.get_config()
    else:
        phase_cfg["checkpointing"] = None

    phase_cfg["result_recording"] = self.result_recording.value
    return phase_cfg
[docs]
@classmethod
def from_config(cls, config: dict) -> TrainPhase:
    """
    Rebuild a training phase from a configuration dictionary.

    Args:
        config (dict[str, Any]):
            Configuration details. Keys must be strings. ``phase_type``
            must be present and equal to ``"TrainPhase"``;
            ``checkpointing`` and ``result_recording`` are optional
            (``result_recording`` falls back to ``"all"``).

    Returns:
        TrainPhase: Reconstructed phase.

    Raises:
        ValueError: If ``phase_type`` is missing or names another phase type.

    """
    if "phase_type" not in config:
        raise ValueError("TrainPhase config must include `phase_type`")
    if config["phase_type"] != "TrainPhase":
        raise ValueError(
            "Invalid config for TrainPhase. Received config for: "
            f"{config['phase_type']}",
        )

    # Losses: stays None when the config stores None
    loss_cfgs = config["losses"]
    rebuilt_losses = (
        None
        if loss_cfgs is None
        else [AppliedLoss.from_config(cfg) for cfg in loss_cfgs]
    )

    # Checkpointing: only rebuilt when a config entry is present
    ckpt_cfg = config.get("checkpointing")
    rebuilt_ckpt = None
    if ckpt_cfg is not None:
        rebuilt_ckpt = Checkpointing.from_config(ckpt_cfg)

    # Build TrainPhase
    return cls(
        label=config["label"],
        input_sources=[
            InputBinding.from_config(cfg) for cfg in config["input_sources"]
        ],
        losses=rebuilt_losses,
        n_epochs=config["n_epochs"],
        active_nodes=config["active_nodes"],
        batch_schedule=config["batch_schedule"],
        callbacks=[Callback.from_config(cfg) for cfg in config["callbacks"]],
        checkpointing=rebuilt_ckpt,
        result_recording=config.get("result_recording", "all"),
    )