"""Cross-validation execution strategy implementation."""
from __future__ import annotations
from collections import defaultdict
from typing import TYPE_CHECKING
import numpy as np
from modularml.core.data.featureset_view import FeatureSetView
from modularml.core.execution.cross_validation.cv_binding import CVBinding, _FoldViews
from modularml.core.execution.strategy import ExecutionStrategy
from modularml.core.experiment.experiment import Experiment
from modularml.core.experiment.experiment_context import ExperimentContext
from modularml.core.experiment.phases.phase_group import PhaseGroup
from modularml.core.experiment.phases.train_phase import TrainPhase
from modularml.splitters.random_splitter import RandomSplitter
from modularml.utils.data.formatting import ensure_list
from modularml.utils.environment.environment import IN_NOTEBOOK
from modularml.utils.progress_bars.progress_task import ProgressTask
if TYPE_CHECKING:
from modularml.core.data.featureset import FeatureSet
from modularml.core.execution.cross_validation.cv_results import CVResults
class CrossValidation(ExecutionStrategy):
    """
    Cross-validation execution strategy.

    Description:
        Orchestrates repeated execution of an :class:`Experiment` or
        :class:`PhaseGroup` by remapping :class:`FeatureSet` nodes to
        fold-specific train/validation splits defined via :class:`CVBinding`
        objects.

    Attributes:
        label (str): Label applied to the template phase group and results.
        seed (int): Seed forwarded to the fold :class:`RandomSplitter`.
        experiment (Experiment): Experiment whose context is remapped per fold.
        phase_template (PhaseGroup): Group executed once per fold.
        bindings (dict[str, CVBinding]): Bindings keyed by FeatureSet node ID.
        n_folds (int): Number of folds.
    """

    # Remainders below this are treated as zero when building split ratios
    # (guards against floating-point residue such as `1 - 3 * (1/3)`).
    _RATIO_TOL: float = 1e-9

    def __init__(
        self,
        *,
        label: str = "CV",
        bindings: CVBinding | list[CVBinding],
        n_folds: int = 5,
        seed: int = 13,
        phase: TrainPhase | PhaseGroup | None = None,
        experiment: Experiment | None = None,
    ):
        """
        Initialize the cross-validation strategy.

        Args:
            label (str, optional):
                Human-readable label applied to generated fold groups.
                Defaults to `CV`.
            bindings (CVBinding | list[CVBinding]):
                One or more :class:`CVBinding` instances describing how each
                :class:`FeatureSet` participates in folding. Each FeatureSet
                may be bound at most once.
            n_folds (int, optional):
                Number of folds to generate. Must be an integer greater than
                or equal to 1. Defaults to `5`.
            seed (int, optional):
                Random seed forwarded to :class:`RandomSplitter`. Defaults to `13`.
            phase (TrainPhase | PhaseGroup | None, optional):
                Optional :class:`TrainPhase` or :class:`PhaseGroup` template to run
                inside each fold. If omitted, the experiment execution plan is used.
            experiment (Experiment | None, optional):
                :class:`Experiment` to execute. Defaults to the active experiment
                from :class:`ExperimentContext`.

        Raises:
            TypeError: If no experiment is available, if `phase` has an invalid
                type, or if any binding is not a :class:`CVBinding`.
            ValueError: If no bindings are given, if two bindings reference the
                same FeatureSet, or if `n_folds` / `val_size` settings are
                inconsistent.
        """
        self.label = label
        self.seed = int(seed)

        # Resolve the experiment on which CV is applied
        if experiment is None:
            active_ctx = ExperimentContext.get_active()
            if active_ctx is not None:
                experiment = active_ctx.get_experiment()
        if not isinstance(experiment, Experiment):
            msg = (
                "Cross validation requires a reference to an experiment. "
                "Either provide one explicitly, or set the active context."
            )
            raise TypeError(msg)
        self.experiment = experiment

        # Build the template phase group executed inside each fold
        if phase is None:
            # No phases given: use everything defined in the experiment.
            # NOTE(review): this relabels the experiment's own execution plan
            # object in place (unless `execution_plan` returns a copy) --
            # confirm this side effect on the experiment is intended.
            self.phase_template = self.experiment.execution_plan
            self.phase_template.label = self.label
        elif isinstance(phase, TrainPhase):
            self.phase_template = PhaseGroup(label=self.label)
            self.phase_template.add_phase(phase=phase)
        elif isinstance(phase, PhaseGroup):
            self.phase_template = PhaseGroup(label=self.label)
            self.phase_template.add_group(group=phase)
        else:
            msg = f"Expected TrainPhase or PhaseGroup. Received: {type(phase)}."
            raise TypeError(msg)

        # Validate bindings. The type check must happen before `_fs_id` is
        # read, otherwise a wrong type surfaces as an AttributeError.
        binding_list = ensure_list(bindings)
        if not binding_list:
            msg = "At least one CVBinding must be provided."
            raise ValueError(msg)
        self.bindings: dict[str, CVBinding] = {}
        for b in binding_list:
            if not isinstance(b, CVBinding):
                msg = f"Expected CVBinding. Received: {type(b)}."
                raise TypeError(msg)
            if b._fs_id in self.bindings:
                # A second binding would silently overwrite the first one
                msg = (
                    "Multiple CVBindings reference the same FeatureSet "
                    f"(node ID '{b._fs_id}'). Each FeatureSet may only be bound once."
                )
                raise ValueError(msg)
            self.bindings[b._fs_id] = b

        # Validate `n_folds` (`bool` subclasses `int`, so reject it explicitly)
        is_integer = isinstance(n_folds, (int, np.integer)) and not isinstance(
            n_folds,
            bool,
        )
        if (not is_integer) or (n_folds < 1):
            msg = "`n_folds` must be an integer greater than or equal to 1."
            raise ValueError(msg)
        n_folds = int(n_folds)

        # Validate `val_size`: each fold takes `val_size` of the pool, so all
        # folds together must fit inside it.
        for b in self.bindings.values():
            if b.val_size is None:
                continue
            if b.val_size <= 0:
                msg = f"`val_size` must be greater than 0. Received: {b.val_size}."
                raise ValueError(msg)
            if b.val_size > (1 / n_folds):
                msg = (
                    "`val_size` cannot be larger than `1/n_folds`: "
                    f"{b.val_size} > {1 / n_folds}."
                )
                raise ValueError(msg)
        self.n_folds = n_folds

    def _fold_ratios(self, val_size: float | None) -> dict[str, float]:
        """
        Build the :class:`RandomSplitter` ratios for one binding.

        Args:
            val_size (float | None):
                Per-fold validation fraction, or None to partition the pool
                evenly into `n_folds` folds.

        Returns:
            dict[str, float]: One `fold_{i}` entry per fold, plus a `remaining`
            entry for the unused part of the pool when `val_size` is given and
            leaves a non-negligible remainder.
        """
        if val_size is None:
            return {f"fold_{i}": 1 / self.n_folds for i in range(self.n_folds)}

        ratios = {f"fold_{i}": val_size for i in range(self.n_folds)}
        remaining = 1 - self.n_folds * val_size
        # When `val_size == 1/n_folds` the remainder is zero up to float error
        # (possibly slightly negative); omit it rather than pass a
        # zero/negative ratio to the splitter. Only `fold_*` keys are read.
        if remaining > self._RATIO_TOL:
            ratios["remaining"] = remaining
        return ratios

    def _generate_fold_views(self) -> dict[int, dict[str, _FoldViews]]:
        """
        Generate per-fold train/validation views.

        For each binding, all rows of its source splits are pooled
        (deduplicated). The pool is randomly split into `n_folds` disjoint
        validation folds. The train view of fold `i` is the pool minus fold `i`.

        Returns:
            dict[int, dict[str, _FoldViews]]:
                Fold-specific views keyed first by fold index and then by
                :class:`FeatureSet` node identifier.

        Raises:
            ValueError: If a binding defines no source splits.
            KeyError: If fold data for a FeatureSet is produced twice.
        """
        all_folds: dict[int, dict[str, _FoldViews]] = defaultdict(dict)
        for fs_id, cv_binding in self.bindings.items():
            fs: FeatureSet = self.experiment.ctx.get_node(
                node_id=fs_id,
                enforce_type="FeatureSet",
            )

            # Pool every (unique) row from the binding's source splits
            views = [fs.get_split(spl) for spl in cv_binding.source_splits]
            if not views:
                msg = (
                    f"CVBinding for FeatureSet '{fs.label}' does not define "
                    "any source splits."
                )
                raise ValueError(msg)
            unique_rows = np.unique(np.hstack([v.indices for v in views]))
            src_view = FeatureSetView.from_featureset(
                fs=fs,
                rows=unique_rows,
                label="cv_pool",
            )

            cv_splitter = RandomSplitter(
                ratios=self._fold_ratios(cv_binding.val_size),
                stratify_by=cv_binding.stratify_by,
                group_by=cv_binding.group_by,
                seed=self.seed,
            )

            # Create and record folds
            fold_splits: dict[str, FeatureSetView] = cv_splitter.split(
                view=src_view,
                return_views=True,
            )
            for i in range(self.n_folds):
                val_view = fold_splits[f"fold_{i}"]
                val_view.label = "val"
                train_view = src_view.take_difference(val_view, label="train")
                if fs_id in all_folds[i]:
                    msg = (
                        f"Data for fold {i} already exists for FeatureSet '{fs.label}'."
                    )
                    raise KeyError(msg)
                all_folds[i][fs_id] = _FoldViews(
                    fold_idx=i,
                    train=train_view,
                    val=val_view,
                )
        return all_folds

    def _replace_featuresets(
        self,
        fold_views: dict[str, _FoldViews],
        ctx: ExperimentContext,
    ):
        """
        Replace context :class:`FeatureSet` nodes with fold-specific splits.

        Each bound FeatureSet is removed from `ctx` and replaced by a copy
        (same node ID, shared raw data buffer, no splits or scalers). The
        original split records are then replayed in order:

        - Records that produce the binding's train/val split names receive the
          fold views.
        - Records applied to a fold-dependent split (train/val, or anything
          derived from them) are re-executed against the new split.
        - Any other split is copied verbatim from the old FeatureSet.

        Scalers are then re-fit in their original order.

        Args:
            fold_views (dict[str, _FoldViews]):
                Fold splits keyed by :class:`FeatureSet` node identifier.
            ctx (ExperimentContext):
                Context whose nodes are replaced temporarily for the fold.
        """
        for fs_id, fold_data in fold_views.items():
            cv_binding = self.bindings[fs_id]
            train_name = cv_binding.train_split_name
            val_name = cv_binding.val_split_name

            # Remove the old FeatureSet and register a split/scaler-free copy
            # under the same node ID
            fs_old: FeatureSet = ctx.remove_node(
                node_id=fs_id,
                error_if_missing=True,
            )
            fs_new = fs_old.copy(
                label=fs_old.label,
                share_raw_data_buffer=True,
                restore_splits=False,
                restore_scalers=False,
                register=False,
            )
            fs_new._node_id = fs_old.node_id
            ctx.register_experiment_node(
                node=fs_new,
                check_label_collision=True,
            )

            # Splits not involved in CV may still be needed, so walk back over
            # the recorded split records. Any split derived (directly or
            # transitively) from the fold's train/val splits must be recomputed;
            # copying it from the old FeatureSet could leak samples that now
            # belong to this fold's validation split.
            fold_dependent = {train_name, val_name}
            for rec in sorted(fs_old._split_recs, key=lambda r: r.order):
                parent_name = rec.applied_to.split_name
                if parent_name in fold_dependent:
                    # Re-execute on the new split; `register=True` is expected
                    # to record the result on `fs_new`.
                    src_to_split = fs_new.get_split(split_name=parent_name)
                    src_to_split.split(
                        splitter=rec.splitter,
                        return_views=False,
                        register=True,
                    )
                    fold_dependent.update(rec.produced_splits)
                    continue

                # Otherwise copy exact sample IDs (fold views for train/val)
                for spl_name in rec.produced_splits:
                    if spl_name == train_name:
                        new_view = fold_data.train
                    elif spl_name == val_name:
                        new_view = fold_data.val
                    else:
                        new_view = fs_old.get_split(split_name=spl_name)
                    fs_new._splits[spl_name] = FeatureSetView.from_featureset(
                        fs=fs_new,
                        rows=new_view.indices,
                        columns=new_view.columns,
                        label=spl_name,
                    )
                # Ensure the record is attached to the new FeatureSet
                fs_new._split_recs.append(rec)

            # If no recorded split produced the train/val names (e.g. the
            # validation split only exists as a CV fold), register the fold
            # views directly so that every fold exposes both splits.
            for spl_name, view in (
                (train_name, fold_data.train),
                (val_name, fold_data.val),
            ):
                if spl_name not in fs_new._splits:
                    fs_new._splits[spl_name] = FeatureSetView.from_featureset(
                        fs=fs_new,
                        rows=view.indices,
                        columns=view.columns,
                        label=spl_name,
                    )

            # Re-fit scalers in their original order on the fold-specific splits
            for rec in sorted(fs_old._scaler_recs, key=lambda r: r.order):
                fs_new.fit_transform(
                    scaler=rec.scaler_obj,
                    domain=rec.domain,
                    keys=rec.keys,
                    fit_to_split=rec.fit_split,
                    merged_axes=rec.merged_axes,
                )

    def run(
        self,
        *,
        show_fold_progress: bool = True,
        persist_progress: bool = IN_NOTEBOOK,
        **kwargs,
    ) -> CVResults:
        """
        Execute cross-validation across all folds.

        Each fold runs `phase_template` inside a temporary copy of the
        experiment context in which every bound FeatureSet is replaced by its
        fold-specific version.

        Args:
            show_fold_progress (bool, optional):
                Whether to show a progress bar over fold execution. Defaults to True.
            persist_progress (bool, optional):
                Whether to keep progress bars visible after completion. Defaults to
                `IN_NOTEBOOK` (True in notebooks, False in scripts).
            **kwargs:
                Additional display flags forwarded to :meth:`Experiment.run_group`.

        Returns:
            CVResults: Cross-fold results container.
        """
        # Local import (presumably to avoid an import cycle with cv_results)
        from modularml.core.execution.cross_validation.cv_results import CVResults

        # Precompute all train/val FeatureSetViews for each fold
        all_folds = self._generate_fold_views()

        # ------------------------------------------------
        # Progress Bar: folds
        # ------------------------------------------------
        fold_ptask = ProgressTask(
            style="cross_validation",
            description=f"Cross-Validation ['{self.label}']",
            total=self.n_folds,
            enabled=show_fold_progress,
            persist=persist_progress,
        )
        fold_ptask.start()
        kwargs["persist_progress"] = persist_progress

        # ------------------------------------------------
        # Fold Execution
        # ------------------------------------------------
        all_res = CVResults(label=self.label)
        try:
            for fold_idx in range(self.n_folds):
                # Temporary context: FeatureSet replacements are scoped to this fold
                with self.experiment.ctx.temporary() as tmp_ctx:
                    self._replace_featuresets(
                        fold_views=all_folds[fold_idx],
                        ctx=tmp_ctx,
                    )
                    fold_res = self.experiment.run_group(
                        group=self.phase_template,
                        **kwargs,
                    )
                    fold_res.label = f"fold_{fold_idx}"
                    all_res.add_result(result=fold_res)
                fold_ptask.tick(n=1)
        finally:
            # Always close the progress bar, even if a fold raises
            fold_ptask.finish()
        return all_res