Source code for modularml.core.execution.cross_validation.cv_binding
"""Binding definitions for cross-validation FeatureSets."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
from modularml.core.data.featureset import FeatureSet
from modularml.core.experiment.experiment_context import ExperimentContext
from modularml.utils.data.formatting import ensure_list
from modularml.utils.logging.warnings import warn
if TYPE_CHECKING:
from modularml.core.data.featureset_view import FeatureSetView
[docs]
class CVBinding:
    """
    Configuration for cross-validation of a single FeatureSet.

    Description:
        Defines how a specified :class:`FeatureSet` should participate in
        cross-validation. The head nodes of a :class:`ModelGraph` are typically
        bound to a split or :class:`FeatureSet`. This configuration maps
        fold-specific outputs back onto those bindings (for example, mapping
        `train_split_name='my_training'`). Each fold then updates the expected
        split names with data belonging only to the current fold.

    Attributes:
        source_splits (list[str]): Existing split names combined to form the pool.
        group_by (str | list[str] | None): Columns used for group-based splitting.
        stratify_by (str | list[str] | None): Columns used for stratified splitting.
        val_size (float | None): Validation proportion per fold, or None.
        train_split_name (str): Split label receiving each fold's training data.
        val_split_name (str): Split label receiving each fold's validation data.

    Note:
        The bound :class:`FeatureSet` is not stored directly; only its node ID
        is kept, in the private `_fs_id` attribute.

    """

    def __init__(
        self,
        fs: str | FeatureSet,
        source_splits: list[str],
        *,
        group_by: str | list[str] | None = None,
        stratify_by: str | list[str] | None = None,
        train_split_name: str = "train",
        val_split_name: str = "val",
        val_size: float | None = None,
    ) -> None:
        """
        Configure cross-validation for a single FeatureSet.

        Args:
            fs (str | FeatureSet):
                :class:`FeatureSet` (or its node ID/label) to apply cross-validation to.
            source_splits (list[str]):
                Existing splits of `fs` to pool before folding. For example,
                `source_splits=['train', 'val']` merges both splits prior to sampling.
            group_by (str | list[str] | None, optional):
                Optional tag keys used for group-based splitting. Mutually exclusive
                with `stratify_by`. Defaults to None.
            stratify_by (str | list[str] | None, optional):
                Optional tag keys used for stratified splitting. Mutually exclusive
                with `group_by`. Defaults to None.
            train_split_name (str, optional):
                Split label that should receive each fold's training partition.
                Must already exist in `fs`. Defaults to `train`.
            val_split_name (str, optional):
                Split label that should receive each fold's validation partition.
                A warning (not an error) is emitted if it does not exist in `fs`.
                Defaults to `val`.
            val_size (float | None, optional):
                Explicit validation proportion for each fold. If None, computed as
                `1 / n_folds`. Defaults to None.

        Raises:
            ValueError: If any of `source_splits` is missing from `fs`, if both
                `group_by` and `stratify_by` are given, if `val_size` is not
                strictly between 0 and 1 (NaN included), or if `train_split_name`
                is not an existing split of `fs`.

        """
        # Store FeatureSet node ID
        exp_ctx = ExperimentContext.get_active()
        if not isinstance(fs, FeatureSet):
            # `fs` was given as a node ID/label; resolve it to the FeatureSet node
            fs = exp_ctx.get_node(val=fs, enforce_type="FeatureSet")
        self._fs_id = fs.node_id

        # Existing splits to draw CV samples from
        self.source_splits: list[str] = ensure_list(source_splits)
        missing_splits = [
            spl for spl in self.source_splits if spl not in fs.available_splits
        ]
        if missing_splits:
            msg = f"FeatureSet '{fs.label}' does not contain splits: {missing_splits}."
            raise ValueError(msg)

        # Splitting config
        if (group_by is not None) and (stratify_by is not None):
            msg = "Only one of `group_by` and `stratify_by` can be defined, not both."
            raise ValueError(msg)
        self.group_by = group_by
        self.stratify_by = stratify_by

        # The positive-form chained comparison also rejects NaN: every ordering
        # comparison against NaN is False, so a `>= 1 or <= 0` test lets it through.
        if (val_size is not None) and not (0 < val_size < 1):
            msg = "`val_size` must be between 0 and 1, exclusive."
            raise ValueError(msg)
        self.val_size = val_size

        # Fold split naming
        self.train_split_name = train_split_name
        self.val_split_name = val_split_name
        if self.train_split_name not in fs.available_splits:
            msg = (
                f"`train_split_name` must correspond to an existing split name in "
                f"FeatureSet '{fs.label}'. Available: {fs.available_splits}."
            )
            raise ValueError(msg)
        # A missing validation split only triggers a warning, not an error
        if self.val_split_name not in fs.available_splits:
            msg = (
                f"`val_split_name` of '{self.val_split_name}' does not match an "
                f"existing split in FeatureSet '{fs.label}'. The validation split "
                "produced by each fold will not be used. "
            )
            warn(message=msg, stacklevel=2)
@dataclass
class _FoldViews:
    """
    Private record grouping a fold index with its train and validation views.

    Presumably one instance is produced per cross-validation fold; the
    producing code is not visible in this module excerpt.
    """

    # Integer index identifying the fold
    fold_idx: int
    # View holding the fold's training partition
    train: FeatureSetView
    # View holding the fold's validation partition
    val: FeatureSetView