"""Unified :class:`FeatureSet` backed by a single :class:`SampleCollection`."""
from __future__ import annotations
from copy import deepcopy
from dataclasses import replace
from pathlib import Path
from typing import TYPE_CHECKING, Any
import numpy as np
import pyarrow as pa
from modularml.core.data.sample_collection import SampleCollection
from modularml.core.data.sample_collection_mixin import SampleCollectionMixin
from modularml.core.data.sample_schema import validate_str_list
from modularml.core.data.schema_constants import (
DOMAIN_FEATURES,
DOMAIN_SAMPLE_UUIDS,
DOMAIN_TAGS,
DOMAIN_TARGETS,
REP_RAW,
REP_TRANSFORMED,
)
from modularml.core.experiment.experiment_node import ExperimentNode
from modularml.core.references.experiment_reference import ResolutionError
from modularml.core.references.featureset_reference import (
FeatureSetColumnReference,
FeatureSetReference,
)
from modularml.core.splitting.split_mixin import SplitMixin, SplitterRecord
from modularml.core.transforms.scaler import Scaler
from modularml.core.transforms.scaler_record import ScalerRecord
from modularml.utils.data.comparators import deep_equal
from modularml.utils.data.conversion import flatten_to_2d, to_numpy, unflatten_from_2d
from modularml.utils.data.data_format import DataFormat
from modularml.utils.data.formatting import ensure_list
from modularml.utils.data.pyarrow_data import (
build_sample_schema_table,
hash_pyarrow_table,
resolve_column_selectors,
)
from modularml.utils.errors.exceptions import SplitOverlapWarning
from modularml.utils.io.cloning import clone_via_serialization
from modularml.utils.logging.warnings import warn
from modularml.visualization.visualizer.styling import FeatureSetDisplayOptions
from modularml.visualization.visualizer.visualizer import Visualizer
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
import pandas as pd
from modularml.core.data.featureset_view import FeatureSetView
from modularml.core.data.sample_schema import SampleSchema
[docs]
class FeatureSet(ExperimentNode, SplitMixin, SampleCollectionMixin):
"""
Unified :class:`FeatureSet` backed by a single :class:`SampleCollection`.
Each representation (e.g., "raw", "transformed") lives within the same \
:class:`SampleCollection` rather than as separate sub-collections. \
Splits store indices into this collection and may specify \
which representation(s) to use when retrieving data.
"""
[docs]
def __init__(
self,
label: str,
collection: SampleCollection,
**kwargs,
):
super().__init__(label=label, **kwargs)
# Create SampleCollection attribute
if not isinstance(collection, SampleCollection):
msg = f"Expected SampleCollection. Received: {type(collection)}."
raise TypeError(msg)
self.collection: SampleCollection = collection
# Store splits & spliiter configs
self._splits: dict[str, FeatureSetView] = {}
self._split_recs: list[SplitterRecord] = []
# Store scaler logs/configs
self._scaler_recs: list[ScalerRecord] = []
[docs]
@classmethod
def from_pyarrow_table(
cls,
label: str,
table: pa.Table,
schema: SampleSchema | None = None,
**kwargs,
) -> FeatureSet:
"""
Construct a new FeatureSet from an existing PyArrow table.
Args:
label (str):
Label to assign to this FeatureSet.
table (pa.Table):
Table to build FeatureSet around.
schema (SampleSchema | None):
PyArrow schema to use for SampleCollection.
kwargs:
Additional keyword arguments.
Returns:
FeatureSet: New FeatureSet instance
"""
collection: SampleCollection = SampleCollection(table=table, schema=schema)
return cls(
label=label,
collection=collection,
**kwargs,
)
[docs]
@classmethod
def from_dict(
cls,
label: str,
data: dict[str, Sequence[Any]],
feature_keys: str | list[str],
target_keys: str | list[str],
tag_keys: str | list[str] | None = None,
) -> FeatureSet:
"""
Construct a FeatureSet from a Python dictionary of column data.
Description:
Converts a dictionary of lists/arrays into a column-oriented Arrow table \
following the ModularML `SampleSchema` convention. Each key in the input \
dictionary corresponds to a column name, and values are list-like sequences \
of equal length representing sample data.
Unlike `from_pandas`, this constructor assumes that each list element \
already represents a complete sample (i.e., no grouping is applied).
Args:
label (str):
Label to assign to this FeatureSet.
data (dict[str, Sequence[Any]]):
A mapping from column names to list-like column data. Each list must \
have the same length, corresponding to the total number of samples.
feature_keys (str | list[str]):
Column name(s) in `data` to be used as features.
target_keys (str | list[str]):
Column name(s) in `data` to be used as targets.
tag_keys (str | list[str] | None, optional):
Column name(s) corresponding to identifying or categorical metadata \
(e.g., cell ID, protocol, SOC). Defaults to None.
Returns:
FeatureSet:
A new Arrow-backed FeatureSet containing all provided columns, organized
into standardized domains (`features`, `targets`, and `tags`).
Raises:
ValueError:
If required keys are missing, or column lengths are inconsistent.
Example:
FeatureSet construction via a dict of data:
>>> fs = FeatureSet.from_dict( # doctest: +SKIP
... label="CycleData",
... data={
... "voltage": [[3.1, 3.2, 3.3], [3.2, 3.3, 3.4]],
... "current": [[1.0, 1.1, 1.0], [0.9, 1.0, 1.1]],
... "soh": [0.95, 0.93],
... "cell_id": ["A1", "A2"],
... },
... feature_keys=["voltage", "current"],
... target_keys=["soh"],
... tag_keys=["cell_id"],
... )
"""
# 1. Standardize input args
feature_keys = ensure_list(feature_keys)
target_keys = ensure_list(target_keys)
tag_keys = ensure_list(tag_keys)
# 2. Ensure data is a np.ndarray for table construction
feature_data = {k: to_numpy(data[k]) for k in feature_keys}
target_data = {k: to_numpy(data[k]) for k in target_keys}
tag_data = {k: to_numpy(data[k]) for k in tag_keys}
# 3. Build table
table = build_sample_schema_table(
features=feature_data,
targets=target_data,
tags=tag_data,
)
# 4. Return new FeatureSet
return cls.from_pyarrow_table(
label=str(label),
table=table,
)
[docs]
@classmethod
def from_pandas(
cls,
label: str,
df: pd.DataFrame,
feature_cols: str | list[str],
target_cols: str | list[str],
group_by: str | list[str] | None = None,
tag_cols: str | list[str] | None = None,
sort_by: str | list[str] | None = None,
) -> FeatureSet:
"""
Construct a FeatureSet from a pandas DataFrame (column-wise storage).
Description:
Converts a DataFrame into a column-oriented Arrow table that matches the \
ModularML `SampleSchema` convention. Each domain (features, targets, tags) \
becomes a Struct column in the final Arrow table.
If `group_by` are provided, all rows sharing the same group key are \
aggregated into a single sample row, with feature and target columns \
stored as array-valued sequences (e.g., np.ndarray or list).
Args:
label (str):
Name to assign to this FeatureSet.
df (pd.DataFrame):
Input pandas DataFrame containing raw experimental or measurement data.
feature_cols (str | list[str]):
Column name(s) in `df` to be used as feature variables.
target_cols (str | list[str]):
Column name(s) in `df` to be used as target variables.
group_by (str | list[str] | None, optional):
One or more column names defining group boundaries. Each group becomes \
one sample (row) in the final table, with grouped columns stored as lists. \
If None, each original DataFrame row is treated as a sample.
tag_cols (str | list[str] | None, optional):
Column name(s) corresponding to identifying or categorical metadata \
(e.g., cell ID, protocol, SOC). Defaults to None.
sort_by (str | list[str] | None, optional):
Column name(s) used to sort rows within each group before aggregation. \
Only used when `group_by` is specified; ignored otherwise. \
Defaults to None.
Returns:
FeatureSet:
A new Arrow-backed FeatureSet whose table follows the ModularML \
SampleSchema convention.
Example:
FeatureSet construction from a Pandas dataframe:
>>> fs = FeatureSet.from_pandas( # doctest: +SKIP
... label="PulseData",
... df=raw_df,
... feature_cols=["voltage", "current"],
... target_cols=["soh"],
... group_by=["cell_id", "cycle_index"],
... tag_cols=["temperature", "cell_id"],
... sort_by="timestamp",
... )
"""
# 1. Standardize input args
feature_cols = ensure_list(feature_cols)
target_cols = ensure_list(target_cols)
tag_cols = ensure_list(tag_cols)
group_by = ensure_list(group_by)
sort_by = ensure_list(sort_by) if sort_by is not None else []
# 2. Apply grouping, if defined
if group_by:
grouped = df.groupby(group_by, sort=False)
else:
# Each row is a separate sample -> pseudo group by index
df = df.copy()
df["_temp_index"] = df.index
grouped = df.groupby("_temp_index", sort=False)
# 3. Aggregate each group into one “row”
feature_data: dict[str, list] = {c: [] for c in feature_cols}
target_data: dict[str, list] = {c: [] for c in target_cols}
tag_data: dict[str, list] = {c: [] for c in tag_cols}
for _, df_gb in grouped:
# Sort rows within this group if requested
df_group = df_gb.sort_values(by=sort_by) if sort_by else df_gb
# Convert grouped columns into arrays or scalars
for c in feature_cols:
vals = df_group[c].to_numpy()
feature_data[c].append(vals if len(vals) > 1 else vals[0])
for c in target_cols:
vals = df_group[c].to_numpy()
target_data[c].append(vals if len(vals) > 1 else vals[0])
for c in tag_cols:
unique_vals = df_group[c].unique()
tag_data[c].append(
unique_vals[0] if len(unique_vals) == 1 else unique_vals.tolist(),
)
# 4. Ensure data is a np.ndarray for table construction
feature_data = {k: to_numpy(v) for k, v in feature_data.items()}
target_data = {k: to_numpy(v) for k, v in target_data.items()}
tag_data = {k: to_numpy(v) for k, v in tag_data.items()}
# 5. Build table
table = build_sample_schema_table(
features=feature_data,
targets=target_data,
tags=tag_data,
)
# 6. Return new FeatureSet
return cls.from_pyarrow_table(
label=str(label),
table=table,
)
from_df = from_pandas
def __eq__(self, other: FeatureSet):
if not isinstance(other, FeatureSet):
msg = f"Cannot compare equality between FeatureSet and {type(other)}"
raise TypeError(msg)
# Compare ID and label
if (self.node_id != other.node_id) or (self.label != other.label):
return False
# Compare collection
if self.collection != other.collection:
return False
# Compare splits
if set(self.available_splits) != set(other.available_splits):
return False
for k in self.available_splits:
s_split = self.get_split(k)
o_split = other.get_split(k)
if set(s_split.indices) != set(o_split.indices):
return False
if set(s_split.columns) != set(o_split.columns):
return False
# Compare splitter configs
if len(self._split_recs) != len(other._split_recs):
return False
for i in range(len(self._split_recs)):
if self._split_recs[i].get_config() != other._split_recs[i].get_config():
return False
# Compare scaler configs
if len(self._scaler_recs) != len(other._scaler_recs):
return False
s_scaler_recs = sorted(self._scaler_recs, key=lambda x: x.order)
o_scaler_recs = sorted(other._scaler_recs, key=lambda x: x.order)
for i in range(len(s_scaler_recs)):
if not deep_equal(
s_scaler_recs[i].get_config(),
o_scaler_recs[i].get_config(),
):
return False
return True
__hash__ = None
# ================================================
# SampleCollectionMixin
# ================================================
def _resolve_caller_attributes(
self,
) -> tuple[SampleCollection, list[str] | None, np.ndarray | None]:
return (self.collection, None, None, self.node_id)
# ================================================
# FeatureSet Properties & Dunders
# ================================================
def __len__(self) -> int:
return self.n_samples
def __getitem__(self, key: str) -> FeatureSetView:
"""Alias for get_split(key)."""
return self.get_split(key)
def __repr__(self):
return f"FeatureSet(label='{self.label}', n_samples={len(self)})"
def __str__(self):
return self.__repr__()
def _summary_rows(self) -> list[tuple]:
return [
("label", self.label),
("n_samples", self.n_samples),
(
"columns",
[
(
DOMAIN_FEATURES,
str(
self.get_feature_keys(
include_domain_prefix=False,
include_rep_suffix=True,
),
),
),
(
DOMAIN_TARGETS,
str(
self.get_target_keys(
include_domain_prefix=False,
include_rep_suffix=True,
),
),
),
(
DOMAIN_TAGS,
str(
self.get_tag_keys(
include_domain_prefix=False,
include_rep_suffix=True,
),
),
),
],
),
(
"splits",
[(name, len(self._splits[name])) for name in self.available_splits],
),
# ("transforms", [(f"{rec.domain}", ", ".join(rec.keys)) for rec in self._scaler_recs]),
# ("node_id", self.node_id),
]
[docs]
def to_view(self) -> FeatureSetView:
"""
Create a FeatureSetView over the entire FeatureSet.
Returns:
FeatureSetView: A view referencing all rows.
"""
from modularml.core.data.featureset_view import FeatureSetView
return FeatureSetView(
source=self,
indices=np.arange(self.collection.n_samples),
columns=self.get_all_keys(
include_domain_prefix=True,
include_rep_suffix=True,
),
label=f"{self.label}_view",
)
@property
def splits(self) -> dict[str, FeatureSetView]:
"""Mapping of split names to their :class:`FeatureSetView` instances."""
return {k: self.get_split(k) for k in self.available_splits}
@property
def n_splits(self) -> int:
"""Number of registered splits."""
return len(self._splits)
# ================================================
# Split Utilities
# ================================================
# Most splitting logic is handled in the SplitterMixin class
@property
def available_splits(self) -> list[str]:
"""
All available splits.
Returns:
list[str]: Available split names.
"""
return list(self._splits.keys())
[docs]
def get_split(self, split_name: str) -> FeatureSetView:
"""
Gets the specified split, rebuilding the :class:`FeatureSetView` if necessary.
Description:
If new columns (e.g. transformed representations) have been added to
the :class:`FeatureSet` after the split was defined, the cached
:class:`FeatureSetView` is re-created to include all current columns
while preserving row indices.
Args:
split_name (str): Name of the split to retrieve.
Returns:
FeatureSetView:
A no-copy, row-wise view of the FeatureSet.
"""
if split_name not in self._splits:
msg = f"Split '{split_name}' does not exist. Available: {self.available_splits}"
raise KeyError(msg)
# If columns differ, rebuild the view
view = self._splits[split_name]
current_cols = self.get_all_keys(
include_domain_prefix=True,
include_rep_suffix=True,
)
if set(view.columns) != set(current_cols):
from modularml.core.data.featureset_view import FeatureSetView
view = FeatureSetView(
source=self,
indices=view.indices, # preserve row selection
columns=current_cols, # refresh columns
label=view.label,
)
self._splits[split_name] = view
return view
[docs]
def add_split(self, split: FeatureSetView):
"""
Adds a new FeatureSetView.
Args:
split (FeatureSetView): The new view to add.
"""
# Check that split references this instance and collection exists
if split.source is not self:
msg = (
f"New split `{split.label}` is not a view of this FeatureSet instance."
)
raise ValueError(msg)
# Check that split name is unique
if split.label is None or split.label in self.available_splits:
msg = f"Split label ('{split.label}') is missing or already exists."
raise ValueError(msg)
# Check that new split name follows naming conventions
try:
validate_str_list([*self.available_splits, split.label])
except ValueError as e:
msg = f"Failed to add new split `{split.label}`. {e}"
raise RuntimeError(msg) from e
# Check overlap with existing splits (only within the same collection)
overlap_samples: dict[str, list[int]] = {}
for existing_split in self._splits.values():
if not split.is_disjoint_with(existing_split):
overlap: list[int] = split.get_overlap_with(existing_split)
overlap_samples[existing_split.label] = overlap
if overlap_samples:
msg = f"Split '{split.label}' has overlapping samples with existing split(s): "
for k, v in overlap_samples.items():
msg += f"\n- Split '{k}' has {len(v)} overlapping samples"
hint = "Consider checking for disjoint splits or revising your conditions."
warn(msg, category=SplitOverlapWarning, stacklevel=2, hints=hint)
# Register new split
self._splits[split.label] = split
[docs]
def clear_splits(self) -> None:
"""Removes all previously defined splits."""
self._splits.clear()
self._split_recs.clear()
# ================================================
# Transform/Scaling
# ================================================
def _cleanup_transformed_rep_if_unused(self) -> None:
"""Remove REP_TRANSFORMED from the SampleCollection for any columns with no scalers."""
# Get columns that are not used in any scalers
# They should not have a ".transformed" representation
all_cols = self.get_all_keys(
include_domain_prefix=True,
include_rep_suffix=False,
)
all_cols.remove(DOMAIN_SAMPLE_UUIDS)
used_cols: list[str] = []
for rec in self._scaler_recs:
for k in rec.keys:
used_cols.append(f"{rec.domain}.{k}")
unused_cols = set(all_cols).difference(set(used_cols))
for col in unused_cols:
d, k = col.split(".")
if REP_TRANSFORMED in self.collection._get_rep_keys(domain=d, key=k):
self.collection.delete_rep(domain=d, key=k, rep=REP_TRANSFORMED)
def _get_flat_data(
self,
domain: str,
keys: list[str],
rep: str,
split: str | None,
merged_axes: int | tuple[int],
) -> tuple[np.ndarray, dict]:
"""Returns new 2D array and metadata of flattening process."""
# Get data specified by domain + keys + representation
source = self if split is None else self.get_split(split_name=split)
cols = [f"{domain}.{k}.{rep}" for k in keys]
data = source.get_data(
columns=cols,
fmt=DataFormat.NUMPY,
rep=rep,
include_domain_prefix=True,
include_rep_suffix=True,
)
# Scaler requires 2D array
# If we need to reshape, check `merged_axes`
# - If None, throw error that shapes are >2D, print shapes and state valid
# indices to use for merged_axes
# - If an integer, preserve the specified axis, and flatten the other(s)
flat_data, flat_meta = data, {}
if data.ndim > 2:
if merged_axes is None:
# Common use case is 1D data with 1 feature key (eg, shape = (n_sample, 1, n_f))
# We can just reshape to (n_sample, n_f) without throwing error
if data.ndim == 3 and len(keys) == 1:
merged_axes = (1, 2)
else:
msg = (
f"Scalers can only be fit to 2-dimensional data, `{domain}.{keys}` "
f"results in data with shape: {data.shape}. "
f"Use `merged_axes` to combine the specified axes and flatten the rest."
)
raise ValueError(msg)
flat_data, flat_meta = flatten_to_2d(arr=data, merged_axes=merged_axes)
if data.ndim < 2:
# Add extra dimension to make 2D
flat_data, flat_meta = flatten_to_2d(arr=data, merged_axes=0)
return flat_data, flat_meta
def _resolve_inverse_scaler_chain(
self,
*,
domain: str,
cols: list[str],
) -> tuple[list[ScalerRecord], bool]:
"""
Resolve the ordered list of scalers applicable for inverse scaling.
Description:
Determines which scaler records can be safely inverted for the provided
column set. The resolution enforces strict dependency ordering and
detects partial vs full unscaling cases.
Args:
domain (str):
Domain of the columns ("features", "targets", or "tags").
cols (list[str]):
Column keys (domain-local, no prefixes or suffixes).
Returns:
tuple[list[ScalerRecord], bool]:
A tuple containing:
- List of ScalerRecords to inverse (newest to oldest).
- Boolean indicating whether the unscale is partial.
Raises:
ValueError:
If the requested columns cannot be safely unscaled.
Or if no scalers are found on the specified columns.
"""
col_req = set(cols)
# Relevant scalers in this domain that touch requested cols
relevant = [
rec
for rec in self._scaler_recs
if (rec.domain == domain) and col_req.intersection(rec.keys)
]
if not relevant:
msg = f"No scalers found for {domain}.{cols}. Nothing to unscale."
raise ValueError(msg)
# Sort: most recent first
relevant = sorted(relevant, key=lambda r: r.order, reverse=True)
# Walk order and produce viable chain
chain: list[ScalerRecord] = []
working_cols = set(cols)
is_partial = False
for rec in relevant:
rec_cols = set(rec.keys)
# Case 1: scaler applied to more than defined columns
if working_cols.issubset(rec_cols) and (working_cols != rec_cols):
msg = (
f"Cannot unscale {sorted(working_cols)} because scaler on "
f"{sorted(rec_cols)} was applied later. Split columns first."
)
raise ValueError(msg)
# Case 2: exact match -> can invert
if working_cols == rec_cols:
chain.append(rec)
continue
# Case 3: scaler applied to less than defined columns
if rec_cols.issubset(working_cols) and (working_cols != rec_cols):
# We can only partially unscale to this point
is_partial = True
break
# Any other overlap is invalid
msg = (
f"Invalid scaler dependency ordering for {sorted(col_req)}. "
f"Conflicting scaler keys: {sorted(rec_cols)}."
)
raise ValueError(msg)
return chain, is_partial
def _iter_scaler_records_on_cols(
self,
*,
domain: str,
columns: str | list[str],
) -> Iterable[ScalerRecord]:
"""
Iterate over scaler objects applicable to the provided columns.
Description:
Resolves the scaler dependency chain for inverse scaling and returns
scaler objects in the correct order for calling `inverse_transform`.
Args:
domain (str):
One of {"features", "targets", "tags"}.
columns (str | list[str]):
Column keys within the domain (no domain prefix, no rep suffix).
Returns:
list[ScalerRecord]:
Scaler records ordered newest to oldest.
Raises:
ValueError:
If scalers cannot be safely resolved.
"""
cols = ensure_list(columns)
chain, partial = self._resolve_inverse_scaler_chain(
domain=domain,
cols=cols,
)
if partial:
warn(
"Only partial inverse scaling applied. Full inversion will require "
"seperating interleaved scalers. ",
hints=(
"It is best practice to only stack scalers on the same set of "
"columns during FeatureSet creation."
),
)
yield from chain
[docs]
def iter_scalers_on_cols(
self,
*,
domain: str,
columns: str | list[str],
) -> Iterable[Scaler]:
"""
Iterate over scaler objects applicable to the provided columns.
Description:
Resolves the scaler dependency chain for inverse scaling and returns
scaler objects in the correct order for calling `inverse_transform`.
Args:
domain (str):
One of {"features", "targets", "tags"}.
columns (str | list[str]):
Column keys within the domain (no domain prefix, no rep suffix).
Returns:
list[Scaler]:
Scalers ordered newest to oldest.
Raises:
ValueError:
If scalers cannot be safely resolved.
"""
cols = ensure_list(columns)
chain, partial = self._resolve_inverse_scaler_chain(
domain=domain,
cols=cols,
)
if partial:
warn(
"Only partial inverse scaling applied. Full inversion will require "
"seperating interleaved scalers. ",
hints=(
"It is best practice to only stack scalers on the same set of "
"columns during FeatureSet creation."
),
)
for rec in chain:
yield rec.scaler_obj
[docs]
def unscale_data_for_cols(
self,
*,
data: np.ndarray,
domain: str,
columns: str | list[str],
) -> np.ndarray:
"""
Inverse-scale provided data using FeatureSet scaler history.
Description:
Applies inverse transforms to user-provided NumPy data based on the
scaler records associated with the specified columns. Supports partial
unscaling with warnings when full inversion is not possible.
Args:
data (np.ndarray):
Scaled data to unscale.
domain (str):
One of {"features", "targets", "tags"}.
columns (str | list[str]):
Column keys represented by `data`, in correct order.
E.g., if `data` are the predictions of some model estimating
column `'targets.soh.transformed'`, then `domain='targets'`
and `columns='soh'`.
Returns:
np.ndarray:
Unscaled data array.
Raises:
ValueError:
If data shape is incompatible or scalers cannot be resolved.
"""
cols = ensure_list(columns)
for rec in self._iter_scaler_records_on_cols(domain=domain, columns=cols):
# Scaler requires 2D array
flat_data, flat_meta = data, {}
if data.ndim > 2:
if rec.merged_axes is None:
# Common use case is 1D data with 1 feature key (eg, shape = (n_sample, 1, n_f))
# We can just reshape to (n_sample, n_f) without throwing error
if data.ndim == 3 and len(cols) == 1:
merged_axes = (1, 2)
else:
msg = f"Expected 2-dimensional data. Received: {data.shape}."
raise ValueError(msg)
flat_data, flat_meta = flatten_to_2d(arr=data, merged_axes=merged_axes)
if data.ndim < 2:
# Add extra dimension to make 2D
flat_data, flat_meta = flatten_to_2d(arr=data, merged_axes=0)
# Inverse data and reshape
x_inv_flat = rec.scaler_obj.inverse_transform(data=flat_data)
data = (
unflatten_from_2d(flat=x_inv_flat, meta=flat_meta)
if flat_meta
else x_inv_flat
)
return data
# ================================================
# Duplication
# ================================================
[docs]
def copy(
self,
*,
label: str | None = None,
share_raw_data_buffer: bool = True,
restore_splits: bool = False,
restore_scalers: bool = False,
register: bool = False,
) -> FeatureSet:
"""
Create a copy of this FeatureSet with optional state restoration.
Description:
Constructs a new FeatureSet instance based on this one. By default,
the PyArrow buffers of the raw data columns are shared (zero-copy).
Splitters and scalers can optionally be re-applied to the new instance,
producing non-shared transform data columns.
This method does not mutate the original FeatureSet.
Args:
label (str | None, optional):
Label for the new FeatureSet. If None, appends "_copy"
to the current label.
share_raw_data_buffer (bool, optional):
If True, PyArrow buffers of raw data columns are shared between
the original and copied FeatureSet (zero-copy). If False,
a deep copy of the Arrow table is created.
restore_splits (bool, optional):
If True, all stored SplitterRecords are re-applied to the
new FeatureSet to regenerate splits.
restore_scalers (bool, optional):
If True, all ScalerRecords are re-applied in order to
regenerate transformed representations.
If True, `restore_splits` must be also be enabled.
register (bool, optional):
Whether to register the copied FeatureSet in the
ExperimentContext registry. If True, a new node ID will
be generated for the copied instance.
Must be True if `restore_scalers` or `restore_splits` are enabled.
Returns:
FeatureSet:
A new FeatureSet instance.
Raises:
ValueError:
If `restore_scalers=True` but `restore_splits=False` and
scalers depend on split-specific fitting.
"""
if (restore_scalers or restore_splits) and not register:
msg = (
"Cannot copy splits and/or scaler without registering the copy. "
"Set `register=True` and try again."
)
raise ValueError(msg)
new_label = label if label is not None else f"{self.label}_copy"
# Copy collection (only REP_RAW columns)
new_coll = self.collection.copy(
raw_only=True,
deep=not share_raw_data_buffer,
)
# Instantiate new FeatureSet
# When restoring splits, the node must be registered so that
# splitter reference resolution (e.g. group_by) can find it.
needs_registration = register or restore_splits or restore_scalers
new_fs = FeatureSet(
label=new_label,
collection=new_coll,
register=needs_registration,
)
# Restore splits (if specified)
if restore_splits:
for rec in self._split_recs:
new_fs.split(splitter=rec.splitter, register=True)
# Restore scalers (if specified)
if restore_scalers:
# Check that split dependecies were also restored
if not restore_splits:
for rec in self._scaler_recs:
if rec.fit_split is not None:
msg = (
"Cannot restore scalers that were fit to splits "
"unless `restore_splits=True`."
)
raise ValueError(msg)
# Restore scalers
for rec in self._scaler_recs:
new_fs.fit_transform(
scaler=rec.scaler_obj,
domain=rec.domain,
keys=rec.keys,
fit_to_split=rec.fit_split,
merged_axes=rec.merged_axes,
)
return new_fs
# ================================================
# Referencing
# ================================================
[docs]
def reference(
self,
columns: list[str] | None = None,
*,
features: str | list[str] | None = None,
targets: str | list[str] | None = None,
tags: str | list[str] | None = None,
rep: str | None = None,
) -> FeatureSetReference:
"""
Create a FeatureSetReference object pointing to columns in this FeatureSet.
Description:
Uses the same column-selection semantics as `FeatureSet.select`, but
returns symbolic DataReference objects instead of a view or materialized data.
References preserve:
- Domain (features / targets / tags)
- Key name
- Representation
- Transform history
This is the preferred mechanism for wiring FeatureSets into ModelStages
and other graph components.
Notes:
All columns of a domain are included unless specified. I.e., if `tags=None`
and no tags are specified in `columns`, then all tag columns are included in
the returned FeatureSetReference.
Args:
columns (list[str] | None):
Fully-qualified column names to reference
(e.g. "features.voltage.raw").
features (str | list[str] | None):
Feature-domain selectors. Accepts exact names or wildcards.
Domain prefix may be omitted.
targets (str | list[str] | None):
Target-domain selectors, following the same rules as `features`.
tags (str | list[str] | None):
Tag-domain selectors, following the same rules as `features`.
rep (str | None):
Default representation suffix applied when a selector omits a
representation.
Returns:
FeatureSetReference
"""
# Get all available keys
all_cols = self.get_all_keys(
include_domain_prefix=True,
include_rep_suffix=True,
)
all_cols.remove(DOMAIN_SAMPLE_UUIDS)
# Perform column selection (organized by domain)
selected: dict[str, set[str]] = resolve_column_selectors(
all_columns=all_cols,
columns=columns,
features=features,
targets=targets,
tags=tags,
rep=rep,
include_all_if_empty=True,
)
return FeatureSetReference(
node_id=self.node_id,
node_label=self.label,
features=tuple(selected[DOMAIN_FEATURES]),
targets=tuple(selected[DOMAIN_TARGETS]),
tags=tuple(selected[DOMAIN_TAGS]),
)
[docs]
def column_reference(
self,
column: str | None = None,
*,
feature: str | None = None,
target: str | None = None,
tag: str | None = None,
rep: str | None = None,
) -> FeatureSetColumnReference:
"""
Create a reference to a single column in this FeatureSet.
Args:
column (str | None):
Fully-qualified column name to reference
(e.g. "features.voltage.raw").
feature (str | None):
Feature-domain selector. Domain prefix may be omitted.
target (str | None):
Target-domain selector. Domain prefix may be omitted.
tag (str | None):
Tag-domain selector. Domain prefix may be omitted.
rep (str | None):
Default representation suffix applied when a selector omits a
representation.
Returns:
FeatureSetColumnReference
"""
# Early return if column=sample_id
if column == DOMAIN_SAMPLE_UUIDS:
return FeatureSetColumnReference(
node_id=self.node_id,
node_label=self.label,
domain=DOMAIN_SAMPLE_UUIDS,
key=None,
rep=None,
)
# Argument validation
if column is not None and not isinstance(column, str):
raise TypeError("`column` must be a string. Received: type(column)")
if feature is not None and not isinstance(feature, str):
raise TypeError("`feature` must be a string. Received: type(feature)")
if target is not None and not isinstance(target, str):
raise TypeError("`target` must be a string. Received: type(target)")
if tag is not None and not isinstance(tag, str):
raise TypeError("`tag` must be a string. Received: type(tag)")
if rep is not None and not isinstance(rep, str):
raise TypeError("`rep` must be a string. Received: type(rep)")
# Get all available keys
all_cols = self.get_all_keys(
include_domain_prefix=True,
include_rep_suffix=True,
)
all_cols.remove(DOMAIN_SAMPLE_UUIDS)
# Perform column selection (organized by domain)
selected: dict[str, set[str]] = resolve_column_selectors(
all_columns=all_cols,
columns=[column] if column is not None else None,
features=[feature] if feature is not None else None,
targets=[target] if target is not None else None,
tags=[tag] if tag is not None else None,
rep=rep,
include_all_if_empty=False,
)
flat_selected: set[str] = set().union(*selected.values())
if len(flat_selected) != 1:
msg = f"Cannot construct a column reference to more than one column. Received: {flat_selected}"
raise ResolutionError(msg)
domain, key, rep = next(iter(flat_selected)).split(".", maxsplit=2)
return FeatureSetColumnReference(
node_id=self.node_id,
node_label=self.label,
domain=domain,
key=key,
rep=rep,
)
# ================================================
# Configurable
# ================================================
[docs]
def get_config(self) -> dict[str, Any]:
"""Structural configuration of the FeatureSet and parents."""
config = super().get_config()
return config
[docs]
@classmethod
def from_config(
cls,
config: dict[str, Any],
*,
register: bool = True,
) -> FeatureSet:
"""Instantiate an empty :class:`FeatureSet`."""
empty_table = pa.table({})
return cls.from_pyarrow_table(
table=empty_table,
register=register,
**config,
)
# ================================================
# Stateful
# ================================================
[docs]
def get_state(self) -> dict[str, Any]:
"""Runtime state (i.e., PyArrow table and records) of the FeatureSet."""
return {
"super": super().get_state(),
"sample_collection": self.collection,
"table_hash": hash_pyarrow_table(self.collection.table),
"splits": {k: self.get_split(k) for k in self.available_splits},
"splitter_records": deepcopy(self._split_recs),
"scaler_records": deepcopy(self._scaler_recs),
}
[docs]
def set_state(self, state: dict[str, Any]):
"""
Restore :class:`FeatureSet` from semantic state.
Args:
state (dict[str, Any]): State dictionary produced by :meth:`get_state`.
"""
from modularml.core.data.featureset_view import FeatureSetView
# Set parent state first
super().set_state(state["super"])
# Clone sample collection (to prevent shared mutation)
state_coll: SampleCollection = state["sample_collection"]
self.collection = SampleCollection(
table=pa.Table.from_pandas(state_coll.table.to_pandas()),
schema=state_coll.schema,
)
# PyArrow table integrity check
expected = state.get("table_hash")
if expected is not None:
actual = hash_pyarrow_table(self.collection.table)
if actual != expected:
msg = f"Arrow table integrity check failed: {actual} != {expected}"
raise ValueError(msg)
# Restore splits (copy all instances)
state_splits: dict[str, FeatureSetView] = state["splits"]
self._splits = {
k: FeatureSetView(
source=self,
indices=fsv.indices,
columns=fsv.columns,
label=fsv.label,
)
for k, fsv in state_splits.items()
}
# TODO: splitter and scaler instances are not copied here
# This may result in accidental mutation
# Restore split records (copy all instances)
self._split_recs = state["splitter_records"]
# Restore scalers
self._scaler_recs = state["scaler_records"]
# ================================================
# Serialization
# ================================================
[docs]
def save(self, filepath: Path, *, overwrite: bool = False) -> Path:
"""
Serializes this FeatureSet to the specified filepath.
Args:
filepath (Path):
File location to save to. Note that the suffix may be overwritten
to enforce the ModularML file extension schema.
overwrite (bool, optional):
Whether to overwrite any existing file at the save location.
Defaults to False.
Returns:
Path: The actual filepath to write the FeatureSet is saved.
"""
from modularml.core.io.serialization_policy import SerializationPolicy
from modularml.core.io.serializer import serializer
return serializer.save(
self,
filepath,
policy=SerializationPolicy.BUILTIN,
overwrite=overwrite,
)
[docs]
@classmethod
def load(
cls,
filepath: Path,
*,
allow_packaged_code: bool = False,
overwrite: bool = False,
) -> FeatureSet:
"""
Load a FeatureSet from file.
Args:
filepath (Path):
File location of a previously saved FeatureSet.
allow_packaged_code : bool
Whether bundled code execution is allowed.
overwrite (bool):
Whether to replace any colliding node registrations in ExperimentContext
If False, a new node_id is assigned to the reloaded FeatureSet. Otherwise,
the existing FeatureSet is removed from the ExperimentContext registry.
Defaults to False.
Returns:
FeatureSet: The reloaded FeatureSet.
"""
from modularml.core.io.serializer import _enforce_file_suffix, serializer
# Append proper sufficx only if no suffix is given
if Path(filepath).suffix == "":
filepath = _enforce_file_suffix(path=filepath, cls=cls)
return serializer.load(
filepath,
allow_packaged_code=allow_packaged_code,
overwrite=overwrite,
)
# ================================================
# Visualizer
# ================================================
[docs]
def visualize(
self,
*,
show_features: bool = True,
show_targets: bool = True,
show_tags: bool | str = "root",
show_overlaps: bool = True,
):
"""
Displays a mermaid diagram for this FeatureSet.
Args:
show_features (bool, optional):
Show feature columns and shapes on nodes. Defaults to True.
show_targets (bool, optional):
Show target columns and shapes on nodes. Defaults to True.
show_tags (bool | str, optional):
Show tag columns and shapes. `"root"` shows only on the
FeatureSet root node, `True` shows on all splits, `False` hides everywhere. Defaults to "root".
show_overlaps (bool, optional):
Show overlap counts between splits. Defaults to True.
"""
display_opts = FeatureSetDisplayOptions(
show_features=show_features,
show_targets=show_targets,
show_tags=show_tags,
show_overlaps=show_overlaps,
)
return Visualizer(self, display_options=display_opts).display_mermaid()