"""Core definitions for experiment phases and input bindings."""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from modularml.core.data.schema_constants import STREAM_DEFAULT
from modularml.core.experiment.experiment_context import ExperimentContext
from modularml.core.references.featureset_reference import FeatureSetReference
from modularml.core.sampling.base_sampler import BaseSampler
from modularml.utils.data.formatting import ensure_list, find_duplicates
from modularml.utils.errors.error_handling import ErrorMode
from modularml.utils.nn.accelerator import Accelerator
from modularml.visualization.visualizer.visualizer import Visualizer
if TYPE_CHECKING:
from collections.abc import Iterator
from pathlib import Path
from modularml.core.data.batch_view import BatchView
from modularml.core.data.execution_context import ExecutionContext
from modularml.core.data.featureset import FeatureSet
from modularml.core.data.featureset_view import FeatureSetView
from modularml.core.experiment.callbacks.callback import Callback
from modularml.core.experiment.results.phase_results import PhaseResults
from modularml.core.topology.graph_node import GraphNode
from modularml.core.training.applied_loss import AppliedLoss
def _get_upstream_featureset_refs_for_node(node_id: str) -> list[FeatureSetReference]:
# Resolve node
exp_ctx = ExperimentContext.get_active()
node: GraphNode = exp_ctx.get_node(node_id=node_id, enforce_type="GraphNode")
# Get all upstream FeatureSetReference of this node
ups_fs_refs: list[FeatureSetReference] = [
ref
for ref in node.get_upstream_refs(error_mode=ErrorMode.IGNORE)
if isinstance(ref, FeatureSetReference)
]
return ups_fs_refs
def _resolve_upstream_featureset_ref(
node_id: str,
val: str | FeatureSetReference | FeatureSetView | FeatureSet | None = None,
) -> FeatureSetReference:
from modularml.core.data.featureset import FeatureSet
from modularml.core.data.featureset_view import FeatureSetView
# Resolve node
exp_ctx = ExperimentContext.get_active()
node: GraphNode = exp_ctx.get_node(node_id=node_id, enforce_type="GraphNode")
# Get all upstream FeatureSetReference of this node
ups_fs_refs: list[FeatureSetReference] = _get_upstream_featureset_refs_for_node(
node_id=node_id,
)
if len(ups_fs_refs) == 0:
msg = (
f"There are no upstream FeatureSets of '{node.label}'. Cannot "
"generate binding."
)
raise RuntimeError(msg)
if val is None:
if len(ups_fs_refs) > 1:
msg = (
f"GraphNode '{node.label}' has multiple upstream FeatureSets. "
f"You must specify `upstream` explicitly."
)
raise ValueError(msg)
return ups_fs_refs[0]
if isinstance(val, str):
for ref in ups_fs_refs:
if ref.node_id == val:
return ref
if ref.node_label == val:
return ref
msg = (
f"No upstream FeatureSet of node '{node.label}' found with ID or "
f"label of '{val}'."
)
raise ValueError(msg)
if isinstance(val, FeatureSetReference):
if val not in ups_fs_refs:
msg = f"No matching FeatureSetReference exists on node '{node.label}'."
raise ValueError(msg)
return val
if isinstance(val, FeatureSetView):
val = val.source
if isinstance(val, FeatureSet):
for ref in ups_fs_refs:
if ref.node_id == val.node_id:
return ref
if ref.node_label == val.label:
return ref
msg = (
f"No upstream FeatureSet of node '{node.label}' matches given "
"the given FeatureSet/FeatureSetView."
)
raise ValueError(msg)
msg = (
"`upstream` must of type str, FeatureSet, or FeatureSetView. "
f"Received: {type(val)}."
)
raise TypeError(msg)
def _normalize_node_value_to_id(value: str | GraphNode) -> str:
"""Converts a node ID, label, or instance to its node ID."""
from modularml.core.topology.graph_node import GraphNode
exp_ctx = ExperimentContext.get_active()
if isinstance(value, GraphNode):
return value.node_id
if isinstance(value, str):
if exp_ctx.has_node(node_id=value):
return value
if exp_ctx.has_node(label=value):
gnode: GraphNode = exp_ctx.get_node(
label=value,
enforce_type="GraphNode",
)
return gnode.node_id
msg = f"The given GraphNode value ('{value}') does not correspond to any node IDs or labels in the active ExperimentContext."
raise ValueError(msg)
msg = f"GraphNode values must be instances, node IDs, or node labels. Received: {type(value)}."
raise TypeError(msg)
class ExperimentPhase(ABC):
"""Abstract base for executable experiment phases."""
def __init__(
self,
label: str,
input_sources: list[InputBinding],
losses: list[AppliedLoss] | None = None,
active_nodes: list[GraphNode] | None = None,
callbacks: list[Callback] | None = None,
accelerator: Accelerator | str | None = None,
):
"""
Initiallizes a new phase of the experiment.
Args:
label (str):
A label to assign to this phase of the experiment. Used for logging.
input_sources (list[InputBinding]):
Input bindings for each head node in ModelGraph.
losses (list[AppliedLoss], optional):
A list of losses to be applied during this experiment phase.
active_nodes (list[GraphNode] | None, optional):
A list of active GraphNodes in this phase of the experiment. Nodes can
be listed via their ID, label, or with the actual node instance. If
None, all nodes comprising the ModelGraph are used. Defaults to None.
callbacks (list[Callback] | None, optional):
An optional list of Callbacks to run during phase execution.
accelerator (Accelerator | str | None, optional):
Optional phase-level accelerator. When provided, it is passed to
model graph execution and reused by all nodes unless a node-level
override exists.
"""
self.label = label
self.input_sources = self._normalize_input_sources(sources=input_sources)
self.losses: list[AppliedLoss] = ensure_list(losses)
self._validate_losses()
self.callbacks: list[Callback] = ensure_list(callbacks)
self._validate_callbacks()
self.active_nodes = self._resolve_active_nodes(active_nodes)
self.accelerator = (
accelerator
if isinstance(accelerator, Accelerator) or accelerator is None
else Accelerator(device=accelerator)
)
self._validate_inputs_for_head_nodes()
def __repr__(self):
return f"ExperimentPhase(label={self.label})"
# ================================================
# Convenience Constructors
# ================================================
@classmethod
def _build_input_sources_from_split(
cls,
*,
split: str,
sampler: BaseSampler | None = None,
active_nodes: list[str | GraphNode] | None = None,
) -> list[InputBinding]:
"""
Build InputBindings automatically from a split name.
Rules:
- All active head nodes must resolve to exactly one upstream FeatureSet
- All must resolve to the same FeatureSet
- The FeatureSet must contain the given split
Args:
split (str):
Split name of the upstream FeatureSet (e.g. "train", "val").
Onnly rows from this split are use for phase execution.
sampler (BaseSampler, optional):
An optional sampler to use to generate batches from this split.
Required if this binding is for a TrainPhase.
active_nodes (list[GraphNode] | None, optional):
A list of active GraphNodes in this phase of the experiment. Nodes can
be listed via their ID, label, or with the actual node instance. If
None, all nodes comprising the ModelGraph are used. Defaults to None.
"""
# Validate environment
exp_ctx = ExperimentContext.get_active()
mg = exp_ctx.model_graph
if mg is None:
msg = "Cannot infer input sources without a registered ModelGraph."
raise RuntimeError(msg)
# Resolve active nodes
active_node_ids = cls._resolve_active_nodes(nodes=active_nodes)
# Identify active head nodes
head_nodes: list[GraphNode] = [
exp_ctx.get_node(node_id=n_id, enforce_type="GraphNode")
for n_id in mg.head_nodes
if n_id in active_node_ids
]
if not head_nodes:
msg = "No active head nodes found for phase."
raise ValueError(msg)
# Resolve upstream fsv per head node
fs_refs: list[FeatureSetReference] = []
for node in head_nodes:
ups = _get_upstream_featureset_refs_for_node(node.node_id)
if len(ups) != 1:
msg = (
f"Head node '{node.label}' has {len(ups)} upstream FeatureSets. "
"Automatic split-based binding requires exactly one. "
"Use the default phase constructor instead."
)
raise ValueError(msg)
fs_refs.append(ups[0])
# Ensure all refs point to same FeatureSet
fs_ids = {ref.node_id for ref in fs_refs}
if len(fs_ids) != 1:
msg = (
"Automatic split-based binding requires all head nodes to "
"share the same upstream FeatureSet. "
"Use the default phase constructor instead."
)
raise ValueError(msg)
# Validate split exists
fs = fs_refs[0].resolve().source
if split not in fs.available_splits:
msg = (
f"Split '{split}' does not exist in FeatureSet '{fs.label}'. "
f"Available splits: {fs.available_splits}."
)
raise ValueError(msg)
# Create bindings
return [
node.create_input_binding(
split=split,
sampler=sampler,
)
for node in head_nodes
]
# ================================================
# Validation
# ================================================
def _validate_losses(self):
from modularml.core.training.applied_loss import AppliedLoss
# Validate loss type
loss_lbls = []
for ls in self.losses:
if not isinstance(ls, AppliedLoss):
msg = f"Loss entries must be of type AppliedLoss. Received: {type(ls)}."
raise TypeError(msg)
loss_lbls.append(ls.label)
# Ensure unique loss labels (used for tracking)
dup_lbls = find_duplicates(items=loss_lbls)
if len(dup_lbls) > 0:
msg = (
f"Multiple AppliedLosses have the same label: {dup_lbls}. "
"Loss labels must be unique."
)
raise ValueError(msg)
def _validate_callbacks(self):
from modularml.core.experiment.callbacks.callback import Callback
# Validate callback type
callback_lbls = []
for cb in self.callbacks:
if not isinstance(cb, Callback):
msg = (
f"Callback entries must be of type Callback. Received: {type(cb)}."
)
raise TypeError(msg)
callback_lbls.append(cb.label)
# Ensure unique callback labels (used for tracking)
dup_lbls = find_duplicates(items=callback_lbls)
if len(dup_lbls) > 0:
msg = (
f"Multiple Callbacks have the same label: {dup_lbls}. "
"Callbacks labels must be unique."
)
raise ValueError(msg)
# Ensure callbacks are ordered for execution
self.callbacks.sort(key=lambda cb: cb._exec_order)
def _normalize_input_sources(
self,
sources: list[InputBinding],
) -> list[InputBinding]:
"""
Validate input sources.
Returns:
list[InputBinding]:
Validated/cleaned bindings.
"""
sources = ensure_list(sources)
clean_sources: list[InputBinding] = []
for binding in sources:
# Validate binding
if not isinstance(binding, InputBinding):
msg = (
f"Input source values must be of type InputBinding. "
f"Received: {type(binding)}."
)
raise TypeError(msg)
clean_sources.append(binding)
return clean_sources
@staticmethod
def _resolve_active_nodes(
nodes: list[str | GraphNode] | None,
) -> list[str]:
"""
Resolve active GraphNode.
Returns:
list[str]
List of node IDs of active nodes in this phase.
"""
exp_ctx = ExperimentContext.get_active()
# If None, use all nodes in the active ModelGraph
if nodes is None:
mg = exp_ctx.model_graph
if mg is None:
msg = "No ModelGraph has been set in the active ExperimentContext. Either explictly list out `active_nodes`, or register a ModelGraph."
raise ValueError(msg)
# Get all node IDs of the nodes comprising the ModelGraph
return list(mg.nodes.keys())
# Otherwise, normalize each node value to a known node ID
node_ids: list[str] = []
for n in ensure_list(nodes):
n_id = _normalize_node_value_to_id(value=n)
node_ids.append(n_id)
return node_ids
def _validate_inputs_for_head_nodes(self):
"""Validates that all head nodes have defined inputs."""
# Get active ModelGraph, must be defined prior to phase init
exp_ctx = ExperimentContext.get_active()
mg = exp_ctx.model_graph
if mg is None:
msg = "Cannot define an ExperimentPhase before a ModelGraph has been registered."
raise RuntimeError(msg)
# Check that all active head nodes have all inputs defined
for n_id in mg.head_nodes:
# Skip if not active
if n_id not in self.active_nodes:
continue
node = exp_ctx.get_node(node_id=n_id, enforce_type="GraphNode")
# Get all upstream FeatureSetRefs of this head node
ups_fs_refs = _get_upstream_featureset_refs_for_node(node_id=n_id)
req_refs = [
inp.upstream_ref for inp in self.input_sources if inp.node_id == n_id
]
missing: list[FeatureSetReference] = [
ref for ref in ups_fs_refs if ref not in req_refs
]
if missing:
msg = (
f"Head node '{node.label}' is missing an input binding for "
f"upstream FeatureSet(s): '{[r.node_label for r in missing]}'."
)
raise ValueError(msg)
# ================================================
# Stop Control
# ================================================
def request_stop(self) -> None:
"""Request early termination of this phase."""
raise NotImplementedError
# ================================================
# Execution
# ================================================
@abstractmethod
def iter_execution(
self,
*,
results: PhaseResults | None = None,
**kwargs,
) -> Iterator[ExecutionContext]:
"""Iterate over execution steps for this phase."""
...
# ================================================
# Configurable
# ================================================
def get_config(self) -> dict[str, Any]:
"""
Return configuration details required to reconstruct this phase.
Returns:
dict[str, Any]:
Configuration used to reconstruct the phase.
"""
losses_cfg = None
if self.losses is not None:
losses_cfg = [ls.get_config() for ls in self.losses]
return {
"label": self.label,
"input_sources": [inp.get_config() for inp in self.input_sources],
"losses": losses_cfg,
"active_nodes": self.active_nodes,
"callbacks": [cb.get_config() for cb in ensure_list(self.callbacks)],
"accelerator": (
self.accelerator.get_config() if self.accelerator is not None else None
),
}
@classmethod
def from_config(cls, config: dict) -> ExperimentPhase:
"""
Construct a phase from a configuration dictionary.
Args:
config (dict[str, Any]):
Configuration details. Keys must be strings.
Returns:
ExperimentPhase: Reconstructed phase.
"""
if "phase_type" not in config:
raise ValueError("ExperimentPhase config must include `phase_type`")
# Create subclasses directly
phase_type = config["phase_type"]
if phase_type == "EvalPhase":
from modularml.core.experiment.phases.eval_phase import EvalPhase
return EvalPhase.from_config(config=config)
if phase_type == "TrainPhase":
from modularml.core.experiment.phases.train_phase import TrainPhase
return TrainPhase.from_config(config=config)
if phase_type == "FitPhase":
from modularml.core.experiment.phases.fit_phase import FitPhase
return FitPhase.from_config(config=config)
msg = f"Unknown ExperimentPhase subclass: {phase_type}."
raise ValueError(msg)
# ================================================
# YAML
# ================================================
def to_yaml(self, path: str | Path, *, overwrite: bool = False) -> Path:
"""
Export this phase to a human-readable YAML file.
Args:
path (str | Path): Destination file path. A ``.yaml`` extension
is added automatically if not already present.
overwrite (bool, optional): Whether to overwrite an existing file
at ``path``. Defaults to False.
Returns:
Path: The resolved path the file was written to.
Raises:
FileExistsError: If ``path`` already exists and ``overwrite`` is False.
"""
from modularml.core.io.yaml import to_yaml
return to_yaml(self, path, overwrite=overwrite)
@classmethod
def from_yaml(cls, path: str | Path, *, overwrite: bool = False) -> ExperimentPhase:
"""
Reconstruct a phase from a YAML file.
All referenced graph nodes and FeatureSets must already exist in
the active :class:`ExperimentContext`.
Args:
path (str | Path): Path to the YAML file.
overwrite (bool, optional): Passed to node-registering translators.
When False (default), raises ValueError if any reconstructed
node label conflicts with an existing registration in the active
ExperimentContext. Defaults to False.
Returns:
ExperimentPhase: Reconstructed phase.
"""
from modularml.core.io.yaml import from_yaml
return from_yaml(path, overwrite=overwrite)
# ================================================
# Visualizer
# ================================================
def visualize(self):
"""Displays a mermaid diagram for this phase."""
return Visualizer(self).display_mermaid()