"""Fit-phase implementation for batch-fit scikit-learn models."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from modularml.core.data.batch_view import BatchView
from modularml.core.data.execution_context import ExecutionContext
from modularml.core.data.featureset_view import FeatureSetView
from modularml.core.data.schema_constants import ROLE_DEFAULT
from modularml.core.experiment.callbacks.callback import Callback
from modularml.core.experiment.experiment_context import ExperimentContext
from modularml.core.experiment.phases.phase import ExperimentPhase, InputBinding
from modularml.core.training.applied_loss import AppliedLoss
if TYPE_CHECKING:
from collections.abc import Iterator
from modularml.core.data.featureset import FeatureSet
from modularml.core.experiment.results.fit_results import FitResults
from modularml.core.references.featureset_reference import FeatureSetReference
from modularml.core.topology.graph_node import GraphNode
[docs]
class FitPhase(ExperimentPhase):
    """
    Phase that fits batch-fit model nodes on the complete dataset.

    Description:
        FitPhase is designed for scikit-learn models (and similar) that require
        all training data at once via `.fit(X, y)` rather than iterative
        mini-batch gradient updates.

        Unlike TrainPhase, FitPhase has no epochs or sampling. It yields a
        single ExecutionContext containing the entire dataset from the specified
        split(s).

        By default, fitted nodes are frozen after fitting so that downstream
        gradient-trained nodes can use their outputs without interference.
        This class only stores and serializes the `freeze_after_fit` flag; the
        freezing itself is performed by whatever consumes the phase.
    """

    def __init__(
        self,
        label: str,
        *,
        input_sources: list[InputBinding],
        losses: list[AppliedLoss] | None = None,
        active_nodes: list[GraphNode] | None = None,
        freeze_after_fit: bool = True,
        callbacks: list[Callback] | None = None,
    ):
        """
        Initialize a new fit phase for the experiment.

        Notes:
            All `input_sources` must originate from the same upstream FeatureSet.
            If multiple FeatureSets need to be fitted, they must be done so in
            separate FitPhases.

        Args:
            label (str):
                A label to assign to this phase of the experiment. Used for logging.
            input_sources (list[InputBinding]):
                Input bindings for each head node in ModelGraph. All bindings must
                resolve to the same FeatureSet.
            losses (list[AppliedLoss], optional):
                A list of losses to compute after fitting (for metrics only).
            active_nodes (list[GraphNode] | None, optional):
                A list of GraphNodes to fit. Nodes can be listed via their ID,
                label, or with the actual node instance. If None, all nodes
                comprising the ModelGraph are used. Defaults to None.
            freeze_after_fit (bool, optional):
                Whether to freeze fitted nodes after `.fit()` completes.
                Defaults to True.
            callbacks (list[Callback] | None, optional):
                An optional list of Callbacks to run during phase execution.

        Raises:
            ValueError:
                If `input_sources` is empty, or if the bindings reference more
                than one FeatureSet or more than one split.

        """
        super().__init__(
            label=label,
            input_sources=input_sources,
            losses=losses,
            active_nodes=active_nodes,
            callbacks=callbacks,
        )
        self.freeze_after_fit = freeze_after_fit

        # Resolved eagerly by `_validate_single_featureset` below.
        self._inp_fsv: FeatureSetView | None = None
        self._validate_single_featureset()

    # ================================================
    # Convenience Constructors
    # ================================================
    @classmethod
    def from_split(
        cls,
        label: str,
        *,
        split: str,
        losses: list[AppliedLoss] | None = None,
        active_nodes: list[GraphNode] | None = None,
        freeze_after_fit: bool = True,
        callbacks: list[Callback] | None = None,
    ) -> FitPhase:
        """
        Initialize a new fit phase for a given FeatureSet split.

        Notes:
            All active head nodes must input from the defined split.

        Args:
            label (str):
                A label to assign to this phase of the experiment. Used for logging.
            split (str):
                The FeatureSet split name to fit on (e.g., "train").
            losses (list[AppliedLoss], optional):
                A list of losses to compute after fitting (for metrics only).
            active_nodes (list[GraphNode] | None, optional):
                A list of GraphNodes to fit. Nodes can be listed via their ID,
                label, or with the actual node instance. If None, all nodes
                comprising the ModelGraph are used. Defaults to None.
            freeze_after_fit (bool, optional):
                Whether to freeze fitted nodes after `.fit()` completes.
                Defaults to True.
            callbacks (list[Callback] | None, optional):
                An optional list of Callbacks to run during phase execution.

        Returns:
            FitPhase: A new phase whose input bindings target `split`.

        """
        # No sampler is passed: a fit phase consumes the whole split at once.
        input_sources = cls._build_input_sources_from_split(
            split=split,
            sampler=None,
            active_nodes=active_nodes,
        )
        return cls(
            label=label,
            input_sources=input_sources,
            losses=losses,
            active_nodes=active_nodes,
            freeze_after_fit=freeze_after_fit,
            callbacks=callbacks,
        )

    # ================================================
    # Validation
    # ================================================
    def _validate_single_featureset(self):
        """
        Ensure all input sources originate from same FeatureSet (and split).

        Description:
            After validation, the shared FeatureSet is looked up on the active
            ExperimentContext and stored on `self._inp_fsv` as a view: the full
            FeatureSet when the bindings carry no split, otherwise the named
            split.

        Raises:
            ValueError:
                If there are no input sources, or if the bindings reference
                more than one FeatureSet or more than one split.

        """
        # Without this guard the `next(iter(...))` calls below would leak a
        # bare StopIteration out of the constructor.
        if not self.input_sources:
            msg = f"FitPhase '{self.label}' requires at least one input source."
            raise ValueError(msg)

        fs_node_ids = {binding.upstream_ref.node_id for binding in self.input_sources}
        if len(fs_node_ids) > 1:
            fs_lbls = {
                binding.upstream_ref.node_label for binding in self.input_sources
            }
            msg = (
                "All `input_sources` of a FitPhase must resolve to a single upstream "
                f"FeatureSet. Detected multiple: {fs_lbls}."
            )
            raise ValueError(msg)

        fs_splits: set[str | None] = {binding.split for binding in self.input_sources}
        if len(fs_splits) > 1:
            msg = (
                "All `input_sources` of a FitPhase must resolve to the same split of "
                f"the same FeatureSet. Detected multiple splits: {fs_splits}."
            )
            raise ValueError(msg)

        # Convert this FeatureSet + split to a view
        fs: FeatureSet = ExperimentContext.get_active().get_node(
            node_id=next(iter(fs_node_ids)),
            enforce_type="FeatureSet",
        )
        split = next(iter(fs_splits))
        if split is None:
            self._inp_fsv = fs.to_view()
        else:
            self._inp_fsv = fs.get_split(split_name=split)

    # ================================================
    # Representation
    # ================================================
    def __repr__(self):
        """Return a short representation containing only the phase label."""
        return f"FitPhase(label='{self.label}')"

    # ================================================
    # Execution
    # ================================================
    def iter_execution(
        self,
        *,
        results: FitResults | None = None,
    ) -> Iterator[ExecutionContext]:
        """
        Iterate over execution steps for this fit phase.

        Description:
            Generates a single ExecutionContext containing the entire dataset
            from the specified split. No batching or epochs are used.

            Callback hooks are invoked in this order: `_on_phase_start`,
            `_on_batch_start`, (yield), `_on_batch_end`, `_on_phase_end`.
            Any exception raised after `_on_phase_start` is reported to every
            callback via `_on_exception` and then re-raised.

            Because this is a generator, nothing (including validation) runs
            until the first item is requested.

        Args:
            results (FitResults | None, optional):
                Optional container in which results will be registered.

        Yields:
            ExecutionContext:
                A single execution context containing all data, suitable for
                `ModelGraph.fit_step(ctx)`.

        Raises:
            TypeError: If the input view was never resolved.
            RuntimeError: If the resolved input view contains no samples.

        """
        # Validate input view
        if not isinstance(self._inp_fsv, FeatureSetView):
            msg = f"Failed to resolve input view for FitPhase '{self.label}'."
            raise TypeError(msg)
        if len(self._inp_fsv) == 0:
            msg = f"FitPhase '{self.label}' has no samples in view '{self._inp_fsv!r}'."
            raise RuntimeError(msg)

        experiment = ExperimentContext.get_active().get_experiment()

        def fire(hook: str, *, guarded: bool = True, **extra: Any) -> None:
            """Call `hook` on every callback, optionally flagging the experiment."""
            # `_in_callback` is raised while hooks run and always reset to
            # False afterwards (presumably so the experiment can detect work
            # triggered from inside a callback -- confirm against Experiment).
            if guarded:
                experiment._in_callback = True
            try:
                for cb in self.callbacks:
                    getattr(cb, hook)(
                        experiment=experiment,
                        phase=self,
                        results=results,
                        **extra,
                    )
            finally:
                if guarded:
                    experiment._in_callback = False

        # ------------------------------------------------
        # Callbacks: on_phase_start
        # ------------------------------------------------
        fire("_on_phase_start")

        # Stays None until the context is built, so the exception hook can
        # report whichever state was reached.
        exec_ctx: ExecutionContext | None = None
        try:
            # Create a single BatchView over the entire dataset
            bv = BatchView(
                source=self._inp_fsv.source,
                role_indices={ROLE_DEFAULT: self._inp_fsv.indices},
            )
            # Every head-node binding receives the same full-dataset view.
            inputs: dict[tuple[str, FeatureSetReference], BatchView] = {
                (binding.node_id, binding.upstream_ref): bv
                for binding in self.input_sources
            }
            exec_ctx = ExecutionContext(
                phase_label=self.label,
                epoch_idx=0,
                batch_idx=0,
                inputs=inputs,
            )

            # ------------------------------------------------
            # Callbacks: on_batch_start
            # ------------------------------------------------
            fire("_on_batch_start", exec_ctx=exec_ctx)

            yield exec_ctx

            # ------------------------------------------------
            # Callbacks: on_batch_end
            # ------------------------------------------------
            fire("_on_batch_end", exec_ctx=exec_ctx)

            # ------------------------------------------------
            # Callbacks: on_phase_end
            # ------------------------------------------------
            # NOTE(review): unlike every other hook, on_phase_end has always
            # run without the `_in_callback` guard. Kept as-is; confirm
            # whether that is intentional.
            fire("_on_phase_end", guarded=False)
        except BaseException as exc:
            # NOTE(review): BaseException includes GeneratorExit, so closing
            # this generator before it is exhausted also notifies callbacks
            # through `_on_exception` -- confirm this is intended.
            fire("_on_exception", exec_ctx=exec_ctx, exception=exc)
            raise

    # ================================================
    # Configurable
    # ================================================
    def get_config(self) -> dict[str, Any]:
        """
        Return configuration details required to reconstruct this phase.

        Returns:
            dict[str, Any]:
                Configuration used to reconstruct the phase. Extends the parent
                configuration with `phase_type` and `freeze_after_fit`.

        """
        cfg = super().get_config()
        cfg.update(
            {
                "phase_type": "FitPhase",
                "freeze_after_fit": self.freeze_after_fit,
            },
        )
        return cfg

    @classmethod
    def from_config(cls, config: dict) -> FitPhase:
        """
        Construct a FitPhase from a configuration dictionary.

        Args:
            config (dict[str, Any]):
                Configuration details. Keys must be strings. `phase_type`,
                `label` and `input_sources` are required. `losses`,
                `active_nodes`, `callbacks` and `freeze_after_fit` are optional
                and fall back to the constructor defaults when missing (or,
                for the first three, when None).

        Returns:
            FitPhase: Reconstructed phase.

        Raises:
            ValueError:
                If `phase_type` is missing or is not "FitPhase".
            KeyError:
                If `label` or `input_sources` is missing.

        """
        if "phase_type" not in config:
            raise ValueError("FitPhase config must include `phase_type`")
        if config["phase_type"] != "FitPhase":
            msg = (
                "Invalid config for FitPhase. Received config for: "
                f"{config['phase_type']}"
            )
            raise ValueError(msg)

        # Optional sections: a missing key and an explicit None both map to
        # the constructor default instead of raising KeyError/TypeError.
        loss_cfgs = config.get("losses")
        losses = None
        if loss_cfgs is not None:
            losses = [AppliedLoss.from_config(cfg) for cfg in loss_cfgs]

        callback_cfgs = config.get("callbacks")
        callbacks = None
        if callback_cfgs is not None:
            callbacks = [Callback.from_config(cfg) for cfg in callback_cfgs]

        return cls(
            label=config["label"],
            input_sources=[
                InputBinding.from_config(cfg) for cfg in config["input_sources"]
            ],
            losses=losses,
            active_nodes=config.get("active_nodes"),
            freeze_after_fit=config.get("freeze_after_fit", True),
            callbacks=callbacks,
        )