Source code for modularml.core.experiment.results.group_results

"""Results tree mirroring :class:`PhaseGroup` execution structure."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING

from modularml.core.experiment.results.eval_results import EvalResults
from modularml.core.experiment.results.phase_results import PhaseResults
from modularml.core.experiment.results.train_results import TrainResults

if TYPE_CHECKING:
    from collections.abc import Iterator


@dataclass
class PhaseGroupResults:
    """
    Hierarchical results container matching the structure of a PhaseGroup.

    Description:
        PhaseGroupResults stores PhaseResults and nested PhaseGroupResults in
        the same hierarchy as the PhaseGroup that produced them. Results are
        stored in insertion order matching execution order.

        Provides convenience methods for:

        - Accessing results by phase or group label
        - Flattening nested results into a single-level mapping
        - Iterating over results in execution order

    Attributes:
        label (str):
            Phase-group label associated with this result tree.
        _results (dict[str, PhaseResults | PhaseGroupResults]):
            Ordered mapping of labels to phase or nested group results.

    Example:
        Accessing phase group results

        >>> # Access results by label
        >>> train_res = group_results.get_phase_result("train")  # doctest: +SKIP
        >>> eval_res = group_results.get_phase_result("eval")  # doctest: +SKIP
        >>> # Flatten nested structure
        >>> flat = group_results.flatten()  # doctest: +SKIP
        >>> for label, phase_res in flat.items():  # doctest: +SKIP
        ...     print(f"{label}: {phase_res!r}")
        >>> # Iterate in execution order
        >>> for label, result in group_results.items():  # doctest: +SKIP
        ...     print(label, type(result).__name__)

    """

    label: str
    _results: dict[str, PhaseResults | PhaseGroupResults] = field(
        default_factory=dict,
    )

    def __repr__(self):
        """Return a compact summary listing each label and its result type."""
        entries = ", ".join(
            f"'{k}': {type(v).__name__}" for k, v in self._results.items()
        )
        return f"PhaseGroupResults(label='{self.label}', results={{{entries}}})"

    # ================================================
    # Runtime Modifiers
    # ================================================
    def add_result(self, result: PhaseResults | PhaseGroupResults):
        """
        Record a phase or group result.

        Args:
            result (PhaseResults | PhaseGroupResults):
                The result to add. Must have a unique label within this group.

        Raises:
            TypeError: If `result` is not a PhaseResults or PhaseGroupResults.
            ValueError: If a result with the same label already exists.

        """
        if not isinstance(result, (PhaseResults, PhaseGroupResults)):
            msg = (
                f"Expected PhaseResults or PhaseGroupResults. Received: {type(result)}."
            )
            raise TypeError(msg)
        if result.label in self._results:
            msg = f"A result with label '{result.label}' already exists in this group."
            raise ValueError(msg)
        self._results[result.label] = result

    # ================================================
    # Properties
    # ================================================
    @property
    def labels(self) -> list[str]:
        """
        All top-level result labels in insertion order.

        Returns:
            list[str]: Ordered labels.

        """
        return list(self._results.keys())

    @property
    def phase_results(self) -> dict[str, PhaseResults]:
        """
        Only the top-level PhaseResults entries, keyed by label.

        Description:
            Returns only the PhaseResults (not nested PhaseGroupResults) at
            this level of the hierarchy. Because group entries are dropped,
            the returned dict does not encode where each phase ran relative
            to nested groups.

        """
        return {k: v for k, v in self._results.items() if isinstance(v, PhaseResults)}

    @property
    def group_results(self) -> dict[str, PhaseGroupResults]:
        """
        Only the top-level PhaseGroupResults entries, keyed by label.

        Description:
            Returns only the nested PhaseGroupResults (not PhaseResults) at
            this level of the hierarchy. Because phase entries are dropped,
            the returned dict does not encode where each group ran relative
            to individual phases.

        """
        return {
            k: v for k, v in self._results.items() if isinstance(v, PhaseGroupResults)
        }

    # ================================================
    # Accessors
    # ================================================
    def _resolve_single_phase(self, phase: str | None, req_cls: type) -> str:
        """
        Resolve a label for results of type `req_cls`.

        Args:
            phase (str | None):
                Explicit label to use. If None, the single top-level entry
                that is an instance of `req_cls` is auto-detected.
            req_cls (type):
                Result class used for auto-detection. It is not used to
                validate an explicitly provided `phase`; callers perform
                that type check themselves.

        Returns:
            str: The resolved label.

        Raises:
            ValueError: If `phase` is given but unknown, or if `phase` is None
                and zero or more than one `req_cls` entries exist.

        """
        if phase is not None:
            if phase not in self._results:
                msg = f"No phase exists with label '{phase}'."
                raise ValueError(msg)
            return phase

        # Auto-detect: only unambiguous when exactly one top-level entry
        # is an instance of `req_cls`.
        avail_lbls = [lbl for lbl, res in self.items() if isinstance(res, req_cls)]
        if not avail_lbls:
            msg = f"No {req_cls.__qualname__} found in this group."
            raise ValueError(msg)
        if len(avail_lbls) > 1:
            msg = (
                f"Multiple {req_cls.__qualname__} found: {avail_lbls}. "
                "Specify which one with the `phase` argument."
            )
            raise ValueError(msg)
        return avail_lbls[0]

    def __getitem__(self, key: str) -> PhaseResults | PhaseGroupResults:
        """
        Retrieve a result by its label.

        Args:
            key (str): The label of the phase or group result.

        Returns:
            PhaseResults | PhaseGroupResults: The result for the given label.

        Raises:
            KeyError: If no result exists with the given label.

        """
        if key not in self._results:
            msg = f"No result exists with label '{key}'. Available: {self.labels}."
            raise KeyError(msg)
        return self._results[key]

    def __contains__(self, key: str) -> bool:
        """Check if a result exists with the given label."""
        return key in self._results

    def __len__(self) -> int:
        """Number of top-level results in this group."""
        return len(self._results)

    def __iter__(self) -> Iterator[str]:
        """
        Iterate over top-level labels in insertion order.

        Description:
            Defined explicitly because the class implements `__getitem__`.
            Without `__iter__`, Python falls back to the legacy sequence
            protocol and calls `__getitem__(0)`, which raises a misleading
            KeyError instead of iterating.

        Returns:
            Iterator[str]: Iterator over the top-level labels.

        """
        return iter(self._results)

    def items(self) -> Iterator[tuple[str, PhaseResults | PhaseGroupResults]]:
        """
        Iterate over label-result pairs in execution order.

        Returns:
            Iterator[tuple[str, PhaseResults | PhaseGroupResults]]:
                Iterator over label/result pairs.

        """
        yield from self._results.items()

    def get_phase_result(self, label: str) -> PhaseResults:
        """
        Retrieve a PhaseResults by its label.

        Args:
            label (str): The phase label to look up.

        Returns:
            PhaseResults: The results for the specified phase.

        Raises:
            KeyError: If no result exists with the given label.
            TypeError: If the result is not of type PhaseResults.

        """
        result = self[label]
        if not isinstance(result, PhaseResults):
            msg = (
                f"Result with label '{label}' is a "
                f"{type(result).__name__}, not PhaseResults."
            )
            raise TypeError(msg)
        return result

    def get_train_result(self, label: str | None = None) -> TrainResults:
        """
        Retrieve a TrainResults by its label.

        Args:
            label (str, optional):
                The training phase label to look up. Auto-detected if
                omitted. Defaults to None.

        Returns:
            TrainResults: The results for the specified phase.

        Raises:
            ValueError: If `label` is given but no result exists with that
                label, or if `label` is omitted and zero or multiple
                TrainResults exist at the top level.
            TypeError: If the result is not of type TrainResults.

        """
        phase_lbl = self._resolve_single_phase(
            phase=label,
            req_cls=TrainResults,
        )
        result = self[phase_lbl]
        if not isinstance(result, TrainResults):
            msg = (
                f"Result with label '{phase_lbl}' is a "
                f"{type(result).__name__}, not TrainResults."
            )
            raise TypeError(msg)
        return result

    def get_eval_result(self, label: str | None = None) -> EvalResults:
        """
        Retrieve an EvalResults by its label.

        Args:
            label (str, optional):
                The evaluation phase label to look up. Auto-detected if
                omitted. Defaults to None.

        Returns:
            EvalResults: The results for the specified phase.

        Raises:
            ValueError: If `label` is given but no result exists with that
                label, or if `label` is omitted and zero or multiple
                EvalResults exist at the top level.
            TypeError: If the result is not of type EvalResults.

        """
        phase_lbl = self._resolve_single_phase(
            phase=label,
            req_cls=EvalResults,
        )
        result = self[phase_lbl]
        if not isinstance(result, EvalResults):
            msg = (
                f"Result with label '{phase_lbl}' is a "
                f"{type(result).__name__}, not EvalResults."
            )
            raise TypeError(msg)
        return result

    def get_group_result(self, label: str | None = None) -> PhaseGroupResults:
        """
        Retrieve a nested PhaseGroupResults by its label.

        Args:
            label (str, optional):
                The group label to look up. Auto-detected if omitted.
                Defaults to None.

        Returns:
            PhaseGroupResults: The results for the specified group.

        Raises:
            ValueError: If `label` is given but no result exists with that
                label, or if `label` is omitted and zero or multiple nested
                PhaseGroupResults exist at the top level.
            TypeError: If the result is a PhaseResults, not PhaseGroupResults.

        """
        phase_lbl = self._resolve_single_phase(
            phase=label,
            req_cls=PhaseGroupResults,
        )
        result = self[phase_lbl]
        if not isinstance(result, PhaseGroupResults):
            msg = (
                f"Result with label '{phase_lbl}' is a "
                f"{type(result).__name__}, not PhaseGroupResults."
            )
            raise TypeError(msg)
        return result

    # ================================================
    # Flattening
    # ================================================
    def flatten(self) -> dict[str, PhaseResults]:
        """
        Flatten all nested groups into a single-level dict.

        Description:
            Recursively unravels the hierarchy of PhaseGroupResults into a
            flat mapping of phase labels to their PhaseResults. The returned
            dict preserves execution order.

            All phase labels must be unique across the entire hierarchy.
            If duplicate labels are found, a ValueError is raised.

        Returns:
            dict[str, PhaseResults]:
                A flat mapping of phase labels to results in execution order.

        Raises:
            ValueError: If duplicate phase labels exist across the hierarchy.

        Example:
            All internal PhaseResults can be flattened to a single dict:

            >>> flat = group_results.flatten()  # doctest: +SKIP
            >>> for label, phase_res in flat.items():  # doctest: +SKIP
            ...     print(f"{label}: {type(phase_res).__name__}")

        """
        flat: dict[str, PhaseResults] = {}
        duplicates: list[str] = []
        self._collect_flat(into=flat, duplicates=duplicates)
        if duplicates:
            msg = (
                "Cannot flatten PhaseGroupResults; duplicate phase labels "
                f"found across the hierarchy: {duplicates}."
            )
            raise ValueError(msg)
        return flat

    def _collect_flat(
        self,
        *,
        into: dict[str, PhaseResults],
        duplicates: list[str],
    ) -> None:
        """
        Recursively collect PhaseResults into a flat dict.

        Args:
            into (dict[str, PhaseResults]):
                Accumulator mutated in place with every non-group entry.
            duplicates (list[str]):
                Accumulator mutated in place with each label seen more
                than once; every such label is recorded a single time.

        """
        for label, result in self._results.items():
            if isinstance(result, PhaseGroupResults):
                result._collect_flat(into=into, duplicates=duplicates)
                continue
            # Record each clashing label once so the error message stays
            # readable when a label occurs three or more times.
            if label in into and label not in duplicates:
                duplicates.append(label)
            into[label] = result