# Source code for modularml.core.experiment.experiment_context

"""Experiment context registry and lifecycle helpers."""

from __future__ import annotations

import os
from contextlib import contextmanager
from contextvars import ContextVar
from typing import TYPE_CHECKING, Any
from weakref import ref

from matplotlib.pylab import Enum

from modularml.utils.environment.environment import IN_NOTEBOOK
from modularml.utils.logging.warnings import warn

if TYPE_CHECKING:
    from modularml.core.data.featureset import FeatureSet
    from modularml.core.experiment.experiment import Experiment
    from modularml.core.experiment.experiment_node import ExperimentNode
    from modularml.core.topology.compute_node import ComputeNode
    from modularml.core.topology.model_graph import ModelGraph


# Context variable holding the currently active ExperimentContext, or None when
# no context has been activated. It is written by
# `ExperimentContext._set_active`, `ExperimentContext.activate` and
# `ExperimentContext.temporary`, and read by `ExperimentContext.get_active`.
_ACTIVE_EXPERIMENT_CONTEXT: ContextVar[ExperimentContext | None] = ContextVar(
    "_ACTIVE_EXPERIMENT_CONTEXT",
    default=None,
)


# NOTE(review): `Enum` is imported from `matplotlib.pylab` at the top of this
# file; the stdlib `enum.Enum` is presumably what was intended -- verify.
class RegistrationPolicy(Enum):
    """Policy applied when a registration collides with an existing entry."""

    ERROR = "error"
    OVERWRITE = "overwrite"
    OVERWRITE_WARN = "overwrite_warn"
    NO_REGISTER = "no_register"

    @classmethod
    def from_value(cls, value):
        """
        Normalize a policy name or member into a :class:`RegistrationPolicy`.

        Members are returned unchanged. Strings are stripped of surrounding
        whitespace and lower-cased before being matched against member values.

        Args:
            value (str | RegistrationPolicy): Policy name or enum member.

        Returns:
            RegistrationPolicy: The matching policy member.

        Raises:
            ValueError: If `value` is neither a member nor a known policy name.

        """
        if isinstance(value, cls):
            return value

        if isinstance(value, str):
            members_by_value = {member.value: member for member in cls}
            match = members_by_value.get(value.strip().lower())
            if match is not None:
                return match

        msg = f"Invalid registration policy: {value!r}. Expected one of: {[p.value for p in cls]}"
        raise ValueError(msg)


class ExperimentContext:
    """
    Registry and lifecycle controller for a single Experiment.

    Description:
        A context keeps:
        - a registry of ExperimentNodes keyed by node ID, plus a
          label -> node ID index
        - at most one registered ModelGraph
        - the :class:`RegistrationPolicy` applied when registrations collide
        - a weak reference to the owning Experiment (if any)

        The "active" context is tracked in the module-level
        `_ACTIVE_EXPERIMENT_CONTEXT` context variable.
    """

    def __init__(
        self,
        *,
        experiment: Experiment | None = None,
        registration_policy: RegistrationPolicy | str | None = None,
    ):
        """
        Create an empty context.

        Args:
            experiment (Experiment | None): Experiment to bind. Only a weak
                reference is stored.
            registration_policy (RegistrationPolicy | str | None): Policy to
                use on collisions. When None, the default is resolved from the
                environment (see :meth:`_resolve_default_policy`).

        Raises:
            ValueError: If `registration_policy` is not a valid policy.

        """
        # Compare against None explicitly: an Experiment that happens to be
        # falsy (e.g. defines `__len__`/`__bool__`) must still be bound.
        self._experiment_ref = ref(experiment) if experiment is not None else None

        # Registries
        self._nodes_by_id: dict[str, ExperimentNode] = {}
        self._node_label_to_id: dict[str, str] = {}
        self._mg: ModelGraph | None = None

        # Registration policy
        if registration_policy is None:
            self._policy = self._resolve_default_policy()
        else:
            self._policy = RegistrationPolicy.from_value(registration_policy)

    # ================================================
    # Active Context Helpers
    # ================================================
    @classmethod
    def _set_active(cls, ctx: ExperimentContext):
        """Make `ctx` the active context without recording a reset token."""
        _ACTIVE_EXPERIMENT_CONTEXT.set(ctx)

    @classmethod
    def get_active(cls) -> ExperimentContext:
        """
        Return the active :class:`ExperimentContext`.

        Returns:
            ExperimentContext: Currently active context.

        Raises:
            RuntimeError: If no active context is set.

        """
        ctx = _ACTIVE_EXPERIMENT_CONTEXT.get()
        if ctx is None:
            raise RuntimeError("There is no active ExperimentContext.")
        return ctx

    @contextmanager
    def activate(self):
        """
        Activate this ExperimentContext within the context scope.

        The previously active context is restored on exit.

        Yields:
            ExperimentContext

        Example:
            Activating a new context is done as follows:

            >>> with ExperimentContext(experiment=my_exp).activate():  # doctest: +SKIP
            ...     ref.resolve()  # resolves using a context of `my_exp`

        """
        token = _ACTIVE_EXPERIMENT_CONTEXT.set(self)
        try:
            yield self
        finally:
            _ACTIVE_EXPERIMENT_CONTEXT.reset(token)

    @contextmanager
    def temporary(self):
        """
        Create a fully isolated temporary execution scope.

        Description:
            All modifications to:
            - registered nodes
            - model graph
            - registration policy
            - experiment binding
            will be reverted when the context exits.

            This is primarily used for cross-validation and other
            meta-execution procedures.

        Yields:
            ExperimentContext

        Example:
            Creating a temporary context scope is done as follows:

            >>> ctx = ExperimentContext.get_active()  # doctest: +SKIP
            >>> with ctx.temporary():  # doctest: +SKIP
            ...     ctx.set_registration_policy("overwrite")
            ...     ctx.register_experiment_node(new_fs)
            ...     run_fold()
            ... # context fully restored on exit

        """
        # Record context state
        old_state = self.get_state()
        token = _ACTIVE_EXPERIMENT_CONTEXT.set(self)
        try:
            yield self
        finally:
            # Reset state and active context. The nested `finally` guarantees
            # the previously active context is restored even if restoring the
            # snapshot raises.
            try:
                self.set_state(old_state)
            finally:
                _ACTIVE_EXPERIMENT_CONTEXT.reset(token)

    # ================================================
    # Policy Management
    # ================================================
    def _resolve_default_policy(self) -> RegistrationPolicy:
        """
        Determine the default registration policy based on environment.

        Priority (highest to lowest):
            1. Environment variable ``MODULARML_EXP_REGISTRATION_POLICY``
            2. Jupyter notebook detection
            3. Script default

        Returns:
            RegistrationPolicy: Resolved default policy.

        Raises:
            ValueError: If the environment variable holds an invalid policy.

        """
        # 1. Explicit env override
        env = os.getenv("MODULARML_EXP_REGISTRATION_POLICY")
        if env:
            return RegistrationPolicy.from_value(env)

        # 2. If running in Jupyter Notebook -> default to OVERWRITE_WARN
        if IN_NOTEBOOK:
            return RegistrationPolicy.OVERWRITE_WARN

        # 3. Else -> default to ERROR
        return RegistrationPolicy.ERROR

    def set_registration_policy(self, policy: str | RegistrationPolicy):
        """
        Permanently set the registration policy.

        Args:
            policy (str | RegistrationPolicy): Policy name or enum.

        Raises:
            ValueError: If `policy` is not a valid policy.

        """
        self._policy = RegistrationPolicy.from_value(policy)

    @contextmanager
    def use_policy(self, policy: str | RegistrationPolicy):
        """
        Temporarily override the registration policy inside a context.

        Args:
            policy (str | RegistrationPolicy): Policy to use within the scope.

        Yields:
            None: Control returns to the caller once the context exits.

        Raises:
            ValueError: If `policy` is not a valid policy.

        """
        previous = self._policy
        self._policy = RegistrationPolicy.from_value(policy)
        try:
            yield
        finally:
            self._policy = previous

    @contextmanager
    def dont_register(self):
        """
        Temporarily disable ExperimentNode registration.

        Any nodes created inside this context will not be registered to
        this ExperimentContext.

        Example:
            Scoped policy setting to no registration:

            >>> ctx = ExperimentContext.get_active()  # doctest: +SKIP
            >>> with ctx.dont_register():  # doctest: +SKIP
            ...     internal_copy = ModelStage.from_state(...)

        """
        with self.use_policy(RegistrationPolicy.NO_REGISTER):
            yield

    # ================================================
    # Experiment Lifecycle
    # ================================================
    def set_experiment(
        self,
        experiment: Experiment,
        *,
        reset_registries: bool = False,
    ):
        """
        Set the experiment reference for this context.

        Args:
            experiment (Experiment): Experiment to associate. Only a weak
                reference is stored.
            reset_registries (bool, optional): Whether to clear node
                registries when associating. Defaults to False.

        """
        self._experiment_ref = ref(experiment)
        if reset_registries:
            self.clear_registries()

    def get_experiment(self) -> Experiment | None:
        """
        Return the Experiment bound to this context, if defined.

        Returns:
            Experiment | None: The experiment, or None if no experiment was
            bound or the weakly referenced experiment no longer exists.

        """
        return None if self._experiment_ref is None else self._experiment_ref()

    # ================================================
    # Registry Helpers
    # ================================================
    def clear_registries(self):
        """
        Clear all registered nodes.

        Both the ID registry and the label index are emptied in place. The
        registered ModelGraph is not touched.
        """
        self._nodes_by_id.clear()
        self._node_label_to_id.clear()

    def register_experiment_node(
        self,
        node: ExperimentNode,
        *,
        check_label_collision: bool = True,
    ):
        """
        Register a node with optional collision handling.

        Description:
            Behavior depends on the current registration policy:
            - ``NO_REGISTER``: the node is validated but never registered.
            - ``ERROR``: a duplicate ID (or label, when checked) raises.
            - ``OVERWRITE`` / ``OVERWRITE_WARN``: the colliding entry is
              replaced; ``OVERWRITE_WARN`` additionally emits a UserWarning.

        Args:
            node (ExperimentNode): Node to register in this context.
            check_label_collision (bool, optional): Whether to enforce
                uniqueness for labels. Defaults to True.

        Raises:
            TypeError: If `node` is not an :class:`ExperimentNode`.
            ValueError: If duplicates are encountered under
                :attr:`RegistrationPolicy.ERROR`.

        """
        from modularml.core.experiment.experiment_node import ExperimentNode

        # Validate node
        if not isinstance(node, ExperimentNode):
            msg = f"`node` must be an ExperimentNode. Received: {type(node)}"
            raise TypeError(msg)

        # Registration disabled -> nothing to do, collision or not. This
        # mirrors `register_model_graph` and the contract of `dont_register`.
        if self._policy is RegistrationPolicy.NO_REGISTER:
            return

        node_id = node.node_id
        label = node.label
        overwrite_policies = (
            RegistrationPolicy.OVERWRITE,
            RegistrationPolicy.OVERWRITE_WARN,
        )

        # ID collision checks
        if node_id in self._nodes_by_id:
            if self._policy is RegistrationPolicy.ERROR:
                msg = f"ExperimentNode with ID '{node_id}' is already registered."
                raise ValueError(msg)
            if self._policy in overwrite_policies:
                # Drop the label entry of the node being replaced
                old = self._nodes_by_id[node_id]
                self._node_label_to_id.pop(old.label, None)
                if self._policy == RegistrationPolicy.OVERWRITE_WARN:
                    msg = f"Overwriting existing node with ID '{old.node_id}'."
                    warn(msg, category=UserWarning, stacklevel=2)

        # Label collision checks
        if check_label_collision and (label in self._node_label_to_id):
            if self._policy is RegistrationPolicy.ERROR:
                msg = f"ExperimentNode label '{label}' already exists."
                raise ValueError(msg)
            if self._policy in overwrite_policies:
                # Evict the node currently owning this label. The label index
                # may be stale, so the ID is allowed to be missing.
                old_id = self._node_label_to_id[label]
                self._nodes_by_id.pop(old_id, None)
                if self._policy == RegistrationPolicy.OVERWRITE_WARN:
                    msg = f"Overwriting existing node with label '{label}'."
                    warn(msg, category=UserWarning, stacklevel=2)

        # Register unique node UUID and string-based label
        self._nodes_by_id[node_id] = node
        self._node_label_to_id[label] = node_id

    def remove_node(
        self,
        *,
        node_id: str | None = None,
        label: str | None = None,
        error_if_missing: bool = True,
    ) -> ExperimentNode | None:
        """
        Remove a registered ExperimentNode from this context.

        Exactly one of `node_id` or `label` must be provided.

        Args:
            node_id (str | None): Internal node UUID to remove.
            label (str | None): Node label to remove.
            error_if_missing (bool): Whether to raise if the node does not
                exist.

        Returns:
            ExperimentNode | None: The removed node if found, otherwise None.

        Raises:
            ValueError: If neither or both of `node_id` / `label` are provided.
            KeyError: If the node does not exist and `error_if_missing=True`.

        """
        if (node_id is None) == (label is None):
            raise ValueError("Must provide exactly one of `node_id` or `label`.")

        # Resolve node_id if label was given
        if label is not None:
            node_id = self._node_label_to_id.get(label)
            if node_id is None:
                if error_if_missing:
                    msg = f"No ExperimentNode with label '{label}'."
                    raise KeyError(msg)
                return None

        # Remove node
        node = self._nodes_by_id.pop(node_id, None)
        if node is None:
            if error_if_missing:
                msg = f"No ExperimentNode with id '{node_id}'."
                raise KeyError(msg)
            return None

        # Remove label mapping
        self._node_label_to_id.pop(node.label, None)
        return node

    def update_node_label(
        self,
        node_id: str,
        new_label: str,
        *,
        check_label_collision: bool = True,
    ):
        """
        Update the label mapping for a registered node.

        Only the label index of this context is updated; the node's own
        `label` attribute is read but not modified here.

        Args:
            node_id (str): Identifier of the node whose label is updated.
            new_label (str): Replacement label.
            check_label_collision (bool, optional): Whether to enforce
                uniqueness of labels. Defaults to True.

        Raises:
            KeyError: If the node ID is not registered.
            ValueError: If a collision occurs and `check_label_collision`
                is True.

        """
        if node_id not in self._nodes_by_id:
            msg = f"Node ID '{node_id}' not registered."
            raise KeyError(msg)

        old_label = self._nodes_by_id[node_id].label

        # If unchanged -> skip
        if new_label == old_label:
            return

        # Check collision
        if check_label_collision and (new_label in self._node_label_to_id):
            msg = f"ExperimentNode label '{new_label}' already exists."
            raise ValueError(msg)

        # Update registry mapping
        self._node_label_to_id.pop(old_label, None)
        self._node_label_to_id[new_label] = node_id

    def register_model_graph(self, graph: ModelGraph):
        """
        Register a ModelGraph to this context.

        Under ``NO_REGISTER`` the graph is validated but not stored.

        Args:
            graph (ModelGraph): Model graph instance to associate.

        Raises:
            TypeError: If `graph` is not a :class:`ModelGraph`.
            ValueError: If a graph already exists and the policy is
                :attr:`RegistrationPolicy.ERROR`.

        """
        from modularml.core.topology.model_graph import ModelGraph

        # Validate graph
        if not isinstance(graph, ModelGraph):
            msg = f"`graph` must be a ModelGraph instance. Received: {type(graph)}"
            raise TypeError(msg)

        if self._policy == RegistrationPolicy.NO_REGISTER:
            return

        # Check collisions
        if self._mg is not None:
            if self._policy == RegistrationPolicy.ERROR:
                msg = "A ModelGraph has already been registered to this context."
                raise ValueError(msg)
            if self._policy == RegistrationPolicy.OVERWRITE_WARN:
                msg = f"Overwriting existing ModelGraph '{self._mg.label}'."
                warn(msg, category=UserWarning, stacklevel=2)

        # Update internal reference
        self._mg = graph

    def remove_model_graph(self):
        """Remove the registered ModelGraph from this context."""
        self._mg = None

    # ================================================
    # Node Lookup
    # ================================================
    def has_node(self, *, node_id: str | None = None, label: str | None = None) -> bool:
        """
        Check whether a node is registered in this context.

        Args:
            node_id (str | None): Node ID to look up. Takes precedence over
                `label` when both are given.
            label (str | None): Node label to look up.

        Returns:
            bool: True if a matching node is registered.

        Raises:
            ValueError: If neither `node_id` nor `label` is provided.

        """
        if node_id is not None:
            return node_id in self._nodes_by_id
        if label is not None:
            return label in self._node_label_to_id
        raise ValueError("Must provide `node_id` or `label`.")

    @staticmethod
    def _resolve_enforced_type(enforce_type: str) -> type:
        """Lazily import and return the node class named by `enforce_type`."""
        # Imports are kept local and explicit: these modules are only imported
        # at module level under TYPE_CHECKING.
        if enforce_type == "ExperimentNode":
            from modularml.core.experiment.experiment_node import ExperimentNode

            return ExperimentNode
        if enforce_type == "GraphNode":
            from modularml.core.topology.graph_node import GraphNode

            return GraphNode
        if enforce_type == "ComputeNode":
            from modularml.core.topology.compute_node import ComputeNode

            return ComputeNode
        if enforce_type == "ModelNode":
            from modularml.core.topology.model_node import ModelNode

            return ModelNode
        if enforce_type == "MergeNode":
            from modularml.core.topology.merge_nodes.merge_node import MergeNode

            return MergeNode
        if enforce_type == "FeatureSet":
            from modularml.core.data.featureset import FeatureSet

            return FeatureSet

        msg = f"Unsupported `enforce_type`: {enforce_type}."
        raise ValueError(msg)

    def get_node(
        self,
        val: str | None = None,
        *,
        node_id: str | None = None,
        label: str | None = None,
        enforce_type: str = "ExperimentNode",
    ) -> ExperimentNode:
        """
        Retrieve the specified node, as registered in this context.

        Args:
            val (str, optional): Either the ID or label of a node. ID is
                checked first. If provided, `node_id` and `label` must be None.
            node_id (str, optional): ID of node to retrieve. If provided,
                `val` and `label` must be None.
            label (str, optional): Label of node to retrieve. If provided,
                `val` and `node_id` must be None.
            enforce_type (str, optional): Name of the class the returned node
                must be an instance of. One of "ExperimentNode", "GraphNode",
                "ComputeNode", "ModelNode", "MergeNode" or "FeatureSet".
                Defaults to "ExperimentNode".

        Returns:
            ExperimentNode: The registered node.

        Raises:
            ValueError: If the lookup arguments are missing or combined, or
                if `enforce_type` is not supported.
            KeyError: If no matching node is registered.
            TypeError: If the node is not an instance of `enforce_type`.

        """
        # If val, check ID then label
        if val is not None:
            if node_id is not None or label is not None:
                msg = "`node_id` and `label` must be None if `val` is defined."
                raise ValueError(msg)
            if self.has_node(node_id=val):
                return self.get_node(node_id=val, enforce_type=enforce_type)
            return self.get_node(label=val, enforce_type=enforce_type)

        # Get node from node_id
        if node_id is not None:
            if label is not None:
                msg = "`val` and `label` must be None if `node_id` is defined."
                raise ValueError(msg)
            try:
                node = self._nodes_by_id[node_id]
            except KeyError as exc:
                msg = f"No ExperimentNode with id '{node_id}'"
                raise KeyError(msg) from exc

        # Get node from label
        elif label is not None:
            try:
                node = self._nodes_by_id[self._node_label_to_id[label]]
            except KeyError as exc:
                msg = f"No ExperimentNode with label '{label}'"
                raise KeyError(msg) from exc

        else:
            raise ValueError("Must provide node_id or label.")

        # Enforce node type
        expected_cls = self._resolve_enforced_type(enforce_type)
        if not isinstance(node, expected_cls):
            msg = f"Retrieved node is not of type '{enforce_type}'. Received: {type(node)}."
            raise TypeError(msg)
        return node

    @property
    def available_nodes(self) -> dict[str, ExperimentNode]:
        """
        All registered ExperimentNodes.

        Returns:
            dict[str, ExperimentNode]: Nodes keyed by node_id. This is the
            internal registry itself, not a copy.

        """
        return self._nodes_by_id

    @property
    def available_computenodes(self) -> dict[str, ComputeNode]:
        """
        All registered ComputeNodes.

        Returns:
            dict[str, ComputeNode]: Nodes keyed by node_id.

        """
        from modularml.core.topology.compute_node import ComputeNode

        return {n.node_id: n for n in self._nodes_by_id.values() if isinstance(n, ComputeNode)}

    @property
    def available_featuresets(self) -> dict[str, FeatureSet]:
        """
        All registered FeatureSets.

        Returns:
            dict[str, FeatureSet]: Nodes keyed by node_id.

        """
        from modularml.core.data.featureset import FeatureSet

        return {n.node_id: n for n in self._nodes_by_id.values() if isinstance(n, FeatureSet)}

    @property
    def model_graph(self) -> ModelGraph | None:
        """The ModelGraph registered in this context, if any."""
        return self._mg

    # ================================================
    # Stateful
    # ================================================
    def get_state(self) -> dict[str, Any]:
        """
        Capture the current registration state for restoration.

        The node registry is copied shallowly; nodes exposing `get_state`
        additionally have their own state captured.

        Returns:
            dict[str, Any]: Snapshot that can be supplied to :meth:`set_state`.

        """
        return {
            "nodes": self._nodes_by_id.copy(),  # shallow
            "node_states": {
                k: (v.get_state() if hasattr(v, "get_state") else None)
                for k, v in self._nodes_by_id.items()
            },
            "model_graph": self._mg,
            "model_graph_state": self._mg.get_state() if self._mg is not None else None,
            "policy": self._policy,
            "experiment_ref": self._experiment_ref,
        }

    def set_state(self, state: dict[str, Any]):
        """
        Restore the context from a state snapshot.

        The snapshot itself is left untouched, so the same snapshot can be
        restored more than once.

        Args:
            state (dict[str, Any]): Snapshot produced by :meth:`get_state`.

        """
        self.clear_registries()
        self._experiment_ref = state["experiment_ref"]
        self._policy = state["policy"]

        # Restore all nodes. Copy the mapping: adopting the snapshot's dict
        # would let later registrations (and the `clear_registries` call of a
        # subsequent restore) mutate the snapshot.
        self._nodes_by_id = dict(state["nodes"])
        self._node_label_to_id = {}
        for node_id, n in self._nodes_by_id.items():
            if hasattr(n, "set_state"):
                n.set_state(state["node_states"][node_id])
            self._node_label_to_id[n.label] = node_id

        # Restore model graph
        self._mg = state["model_graph"]
        if self._mg is not None:
            self._mg.set_state(state["model_graph_state"])