"""Backend-agnostic optimizer utilities for ModularML training."""
from __future__ import annotations

from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Any

from modularml.core.io.protocols import Configurable, Stateful
from modularml.utils.data.comparators import deep_equal
from modularml.utils.environment.optional_imports import ensure_tensorflow, ensure_torch
from modularml.utils.errors.exceptions import (
    BackendNotSupportedError,
    OptimizerError,
    OptimizerNotSetError,
)
from modularml.utils.nn.backend import Backend, infer_backend, normalize_backend

if TYPE_CHECKING:
    from collections.abc import Callable
def _safe_infer_backend(obj_or_cls: Any) -> Backend:
    """
    Resolve the backend of an optimizer instance or class, rejecting unknowns.

    Args:
        obj_or_cls (Any): Optimizer instance or class to inspect.

    Returns:
        Backend: The backend reported by ``infer_backend``.

    Raises:
        ValueError: If ``infer_backend`` reports ``Backend.NONE``.
    """
    resolved = infer_backend(obj_or_cls=obj_or_cls)
    if resolved != Backend.NONE:
        return resolved
    raise ValueError(
        "Could not infer backend from optimizer class. Specify backend explicitly.",
    )
class Optimizer(Configurable, Stateful):
    """
    Backend-agnostic optimizer wrapper that lazily constructs backend objects.

    Description:
        Supports initialization from optimizer names, classes, or callables while
        normalizing backend selection for PyTorch and TensorFlow. The concrete
        backend optimizer is only created when :meth:`Optimizer.build` is called.

    Attributes:
        name (str | None):
            Optimizer name when known. Lowercased when initialized from a string;
            otherwise the backend class name (``cls.__name__``).
        cls (type | None):
            Backend optimizer class, resolved via introspection or taken from the
            built instance.
        kwargs (dict[str, Any]):
            Keyword arguments applied when instantiating the optimizer.
        _backend (Backend | None):
            Resolved backend enum for this optimizer.
        _factory (Callable | None):
            Optional callable used to construct the optimizer in :meth:`build`.
        instance (Any | None):
            Concrete optimizer object once built.
        parameters (Any | None):
            Model parameters passed at build time (PyTorch); ``None`` for TensorFlow.
        _pending_state (dict[str, Any] | None):
            Deferred backend state restored after :meth:`Optimizer.build`.
    """

    def __init__(
        self,
        opt: str | type | None = None,
        *,
        opt_kwargs: dict[str, Any] | None = None,
        factory: Callable | None = None,
        backend: Backend | None = None,
    ):
        """
        Initialize the optimizer wrapper from a name, class, or factory.

        Args:
            opt (str | type | None):
                Optimizer name or class, mutually exclusive with `factory`.
            opt_kwargs (dict[str, Any] | None):
                Keyword arguments provided when instantiating the optimizer.
                The mapping is copied; later changes to it do not affect this object.
            factory (Callable | None):
                Callable returning an optimizer when invoked during :meth:`Optimizer.build`.
                It receives the model parameters (PyTorch) or ``None`` (TensorFlow).
            backend (Backend | None):
                Backend enum. Required for name/factory initialization. Optional for
                class initialization, where it overrides introspection.

        Raises:
            ValueError:
                If arguments conflict, `backend` is missing when required, `backend`
                contradicts the backend inferred from an optimizer class, or inputs
                are invalid.
            TypeError:
                If `opt` is neither a string nor a type.
        """
        if opt is not None and factory is not None:
            msg = (
                "Provide either an optimizer (`opt`) or a `factory` callable, not both."
            )
            raise ValueError(msg)

        # Copy so later mutation of the caller's dict cannot leak into this optimizer.
        self.kwargs: dict[str, Any] = dict(opt_kwargs or {})

        if opt is not None:
            # Case 1: class / name + kwargs
            self._factory = None
            if isinstance(opt, str):
                self.name = opt.lower()
                if backend is None:
                    msg = (
                        "Backend must be specified when initializing an optimizer "
                        "with a string-name."
                    )
                    raise ValueError(msg)
                self._backend = normalize_backend(backend)
                self.cls = self._resolve()
            elif isinstance(opt, type):
                self.name = opt.__name__
                self.cls = opt
                if backend is None:
                    self._backend = _safe_infer_backend(self.cls)
                else:
                    # Honor an explicit backend (e.g. for custom classes whose
                    # backend cannot be inferred), but refuse a contradiction.
                    self._backend = normalize_backend(backend)
                    inferred = infer_backend(obj_or_cls=self.cls)
                    if inferred not in (Backend.NONE, self._backend):
                        msg = (
                            f"Backend '{self._backend}' does not match the backend "
                            f"inferred from optimizer class '{self.name}': '{inferred}'."
                        )
                        raise ValueError(msg)
            else:
                msg = (
                    "Optimizer (`opt`) must be a string-name or class. "
                    f"Received: {type(opt)}"
                )
                raise TypeError(msg)
        elif factory is not None:
            # Case 2: factory. Name and class stay unknown until the factory runs.
            if backend is None:
                msg = (
                    "Backend must be specified when initializing an optimizer "
                    "with a factory."
                )
                raise ValueError(msg)
            self._factory = factory
            self.cls = None
            self.name = None
            self._backend = normalize_backend(backend)
        else:
            raise ValueError(
                "Must provide either an optimizer (`opt`) or a `factory` callable.",
            )

        # Runtime state
        self.instance: Any | None = None
        self.parameters: Any | None = None

        # Pending serialized internal state (set by set_state() before build()).
        # Structure:
        #   PyTorch: {"state_dict": ...}
        #   TF:      {"weights": [...]}
        self._pending_state: dict[str, Any] | None = None

    @classmethod
    def from_factory(cls, factory: Callable, *, backend: Backend) -> Optimizer:
        """
        Instantiate an :class:`Optimizer` directly from a factory and backend.

        Args:
            factory (Callable): Callable that produces an optimizer instance.
            backend (Backend): Backend enum applied to the optimizer.

        Returns:
            Optimizer: New wrapper configured to call the provided factory.
        """
        return cls(factory=factory, backend=backend)

    def __eq__(self, other: object) -> bool:
        """
        Compare two optimizers by configuration and captured state.

        Returns ``NotImplemented`` for non-Optimizer operands. Python then falls
        back to identity, so ``opt == None`` is simply ``False``.
        """
        if not isinstance(other, Optimizer):
            return NotImplemented
        if not deep_equal(self.get_config(), other.get_config()):
            return False
        return deep_equal(self.get_state(), other.get_state())

    # Value-based equality on a mutable object: instances are deliberately unhashable.
    __hash__ = None

    # ================================================
    # Core Properties
    # ================================================
    @property
    def is_built(self) -> bool:
        """Whether the backend optimizer instance has been constructed."""
        return self.instance is not None

    @property
    def backend(self) -> Backend | None:
        """
        Backend enum associated with this optimizer.

        Returns:
            Backend | None: Resolved backend value if available.
        """
        return self._backend

    @backend.setter
    def backend(self, value: Backend | None):
        # Normalize so that later comparisons against Backend members also work
        # when a string (e.g. "torch") is assigned.
        self._backend = None if value is None else normalize_backend(value)

    # ================================================
    # Representation
    # ================================================
    def _summary_rows(self) -> list[tuple]:
        """
        Return summary rows describing the optimizer configuration.

        Returns:
            list[tuple]: Key/value pairs rendered in textual summaries.
        """
        return [
            ("name", self.name),
            ("cls", str(self.cls.__name__ if self.cls else None)),
            ("kwargs", [(k, str(v)) for k, v in self.kwargs.items()]),
            ("backend", f"{self.backend!r}"),
        ]

    def __repr__(self):
        msg_kwargs = "".join(f", {k}={v}" for k, v in self.kwargs.items())
        name = self.name if self.name is not None else "<custom>"
        return f"Optimizer('{name}'{msg_kwargs})"

    # ================================================
    # Internal helpers
    # ================================================
    def _resolve(self) -> Callable:
        """
        Resolve a named optimizer to its backend-specific class via introspection.

        Matching is case-insensitive against the class names exposed by
        ``torch.optim`` or ``tf.keras.optimizers``.

        Returns:
            Callable: Optimizer class pulled from the backend module.

        Raises:
            OptimizerError: If the name is missing or cannot be matched.
            BackendNotSupportedError: If the backend lacks optimizer resolution support.
        """
        if not isinstance(self.name, str):
            raise OptimizerError(
                "Optimizer name must be a string to resolve dynamically.",
            )
        name_lc = self.name.lower()

        # Resolve backend optimizer module
        if self.backend == Backend.TORCH:
            torch = ensure_torch()
            module = torch.optim
        elif self.backend == Backend.TENSORFLOW:
            tf = ensure_tensorflow()
            module = tf.keras.optimizers
        else:
            raise BackendNotSupportedError(
                backend=self.backend,
                method="Optimizer._resolve()",
            )

        # Collect every class exposed by the module, keyed by lowercase name.
        candidates: dict[str, type] = {}
        for attr_name in dir(module):
            try:
                attr = getattr(module, attr_name)
            except Exception:  # noqa: BLE001, S112
                # Best effort: some lazily-loaded attributes raise on access.
                continue
            if isinstance(attr, type):
                candidates[attr_name.lower()] = attr

        opt_cls = candidates.get(name_lc)
        if opt_cls is None:
            available = sorted(candidates.keys())
            msg = (
                f"Unknown optimizer name '{self.name}' for backend '{self.backend}'. "
                f"Available optimizers: {available}"
            )
            raise OptimizerError(msg)
        return opt_cls

    def _check_optimizer(self):
        """
        Ensure the optimizer has been built before performing backend operations.

        Raises:
            OptimizerNotSetError: If the optimizer instance is missing.
        """
        if not self.is_built:
            raise OptimizerNotSetError(message="Optimizer has not been built.")

    def _extract_kwargs_from_instance(self):
        """
        Populate ``kwargs``, ``cls`` and ``name`` from the built backend optimizer.

        PyTorch kwargs come from ``instance.defaults`` and TensorFlow kwargs from
        ``instance.get_config()``. Both are deep-copied.

        Raises:
            ValueError: If the optimizer instance has not been set yet.
        """
        if self.instance is None:
            raise ValueError("Instance cannot be None.")

        # If lazy backend, infer from instance
        if self._backend is None:
            self._backend = _safe_infer_backend(self.instance)

        if self.backend == Backend.TORCH:
            self.kwargs = deepcopy(self.instance.defaults)
        elif self.backend == Backend.TENSORFLOW:
            self.kwargs = deepcopy(self.instance.get_config())
        else:
            return
        self.cls = self.instance.__class__
        self.name = self.cls.__name__

    # ================================================
    # Build
    # ================================================
    def build(
        self,
        *,
        parameters: Any | None = None,
        backend: Backend | None = None,
        force_rebuild: bool = False,
    ):
        """
        Instantiate the backend optimizer if not already provided.

        Args:
            parameters (Any | None):
                Trainable parameters required when building PyTorch optimizers.
            backend (Backend | None):
                Backend override. It is normalized and must match the backend
                set at init.
            force_rebuild (bool):
                Whether to rebuild even if the optimizer is already instantiated.

        Raises:
            OptimizerNotSetError:
                If attempting to rebuild without `force_rebuild`.
            ValueError:
                If backend validations fail or parameters are missing for PyTorch.
            BackendNotSupportedError:
                If an unsupported backend is requested.
            RuntimeError:
                If the initialization mode is unsupported.
        """
        if self.is_built and not force_rebuild:
            msg = (
                "Optimizer.build() is being called on an already instantiated "
                "optimizer. If you want to rebuild the optimizer, set "
                "`force_rebuild=True`."
            )
            raise OptimizerNotSetError(message=msg)

        # Set/validate backend. Normalize first so a string such as "torch" compares
        # equal to the enum stored at init.
        if backend is not None:
            backend = normalize_backend(backend)
            if self.backend is not None and backend != self.backend:
                msg = (
                    "Backend passed to Optimizer.build differs from backend "
                    f"at init: {backend} != {self.backend}"
                )
                raise ValueError(msg)
            self.backend = backend
        if self.backend is None:
            raise ValueError("Backend must be set before building optimizer.")

        if self.cls is not None:
            # Case 1: class + kwargs
            if self.backend == Backend.TORCH:
                if parameters is None:
                    raise ValueError(
                        "Torch Optimizer requires model parameters during build.",
                    )
                self.parameters = parameters
                self.instance = self.cls(self.parameters, **self.kwargs)
            elif self.backend == Backend.TENSORFLOW:
                self.parameters = None  # TF doesn't need parameters at construction
                self.instance = self.cls(**self.kwargs)
            else:
                raise BackendNotSupportedError(
                    backend=self._backend,
                    method="Optimizer.build()",
                )
        elif self._factory is not None:
            # Case 2: factory
            if self.backend == Backend.TORCH:
                if parameters is None:
                    raise ValueError("Torch optimizer factory requires parameters.")
                self.parameters = parameters
                self.instance = self._factory(parameters)
            elif self.backend == Backend.TENSORFLOW:
                self.parameters = None
                self.instance = self._factory(None)
            else:
                raise BackendNotSupportedError(
                    backend=self._backend,
                    method="Optimizer.build()",
                )
            # Learn cls, name, and kwargs from what the factory produced
            self._extract_kwargs_from_instance()
        else:
            raise RuntimeError("Unsupported initialization state.")

        # Apply any state stashed by set_state() before this build
        if self._pending_state is not None:
            self._restore_internal_state(self._pending_state)
            self._pending_state = None

    # ================================================
    # Backprop methods
    # ================================================
    def step(self, grads=None, variables=None):
        """
        Perform a backend-specific optimizer step.

        Args:
            grads (Any | None):
                Gradient tensors required by TensorFlow optimizers.
            variables (Any | None):
                Trainable variables paired with `grads` in TensorFlow. They are
                zipped with ``strict=True``, so the lengths must match.

        Raises:
            OptimizerNotSetError: If the optimizer has not been built.
            ValueError: If TensorFlow requires gradients or variables that are missing.
            BackendNotSupportedError: If the backend is unsupported.
        """
        self._check_optimizer()
        if self._backend == Backend.TORCH:
            self.instance.step()
        elif self._backend == Backend.TENSORFLOW:
            if grads is None or variables is None:
                msg = (
                    "TensorFlow backend requires both `grads` and `variables` to be "
                    "set in Optimizer.step()."
                )
                raise ValueError(msg)
            self.instance.apply_gradients(zip(grads, variables, strict=True))
        else:
            raise BackendNotSupportedError(
                backend=self._backend,
                method="Optimizer.step()",
            )

    def zero_grad(self):
        """
        Reset accumulated gradients on the backend optimizer.

        On PyTorch this delegates to ``optimizer.zero_grad()``. On TensorFlow it is
        a no-op: gradients are computed fresh by each ``GradientTape`` and are never
        accumulated on the optimizer.

        Raises:
            OptimizerNotSetError: If the optimizer has not been built.
            BackendNotSupportedError: If the backend is unsupported.
        """
        self._check_optimizer()
        if self._backend == Backend.TORCH:
            self.instance.zero_grad()
        elif self._backend == Backend.TENSORFLOW:
            # Intentionally nothing to do. Zeroing the optimizer's own variables
            # would wipe its slot state (e.g. Adam moments, iteration count)
            # rather than any gradients.
            return
        else:
            raise BackendNotSupportedError(
                backend=self._backend,
                method="Optimizer.zero_grad()",
            )

    # ================================================
    # Configurable
    # ================================================
    def get_config(self) -> dict[str, Any]:
        """
        Return the serialized configuration for this optimizer.

        A built optimizer is described by its (lowercased) name, kwargs and backend
        only. An unbuilt one also carries its ``factory``. ``opt_kwargs`` is a copy.

        Returns:
            dict[str, Any]:
                Dictionary capturing optimizer definition for reconstruction.
        """
        backend_str = (
            None if self.backend is None else str(self.backend.value).lower()
        )
        if self.is_built:
            return {
                "opt": str(self.name).lower(),
                "opt_kwargs": dict(self.kwargs),
                "backend": backend_str,
            }
        return {
            "opt": None if self.name is None else str(self.name).lower(),
            "opt_kwargs": dict(self.kwargs),
            "backend": backend_str,
            "factory": self._factory,
        }

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> Optimizer:
        """
        Instantiate an optimizer from serialized configuration.

        Args:
            config (dict[str, Any]):
                Configuration dictionary produced by :meth:`get_config`.

        Returns:
            Optimizer: Reconstructed optimizer wrapper.
        """
        return cls(**config)

    # ================================================
    # Stateful
    # ================================================
    def get_state(self) -> dict[str, Any]:
        """
        Capture serialized optimizer state including backend internals.

        Returns:
            dict[str, Any]: ``{"is_built": bool}`` plus ``"internal"`` when built.
        """
        state: dict[str, Any] = {"is_built": self.is_built}
        if self.is_built:
            state["internal"] = self._capture_internal_state()
        return state

    def set_state(self, state: dict[str, Any]) -> None:
        """
        Apply serialized optimizer state.

        If the optimizer is already built, the backend state is restored
        immediately. Otherwise it is stashed and restored by :meth:`build`.

        Args:
            state (dict[str, Any]):
                Serialized optimizer state produced by :meth:`get_state`.
        """
        if state.get("is_built") is None:
            return
        internal = state.get("internal")
        if self.is_built:
            self._restore_internal_state(internal)
        else:
            self._pending_state = internal

    # ================================================
    # Internal state handling
    # ================================================
    def _capture_internal_state(self) -> dict[str, Any] | None:
        """
        Capture backend-specific optimizer state for serialization.

        Returns:
            dict[str, Any] | None:
                ``{"state_dict": ...}`` for PyTorch, ``{"weights": ...}`` for TF
                (``None`` weights when the optimizer exposes no ``get_weights``), or
                ``None`` when unbuilt or the backend is unsupported.
        """
        if not self.is_built:
            return None
        if self._backend == Backend.TORCH:
            return {"state_dict": self.instance.state_dict()}
        if self._backend == Backend.TENSORFLOW:
            try:
                return {"weights": self.instance.get_weights()}
            except AttributeError:
                # Optimizer not yet initialized with variables
                return {"weights": None}
        return None

    def _restore_internal_state(self, state: dict[str, Any] | None) -> None:
        """
        Restore backend-specific optimizer state captured during serialization.

        Does nothing when the optimizer is unbuilt or `state` is ``None``.

        Args:
            state (dict[str, Any] | None):
                Serialized backend state captured by :meth:`_capture_internal_state`.

        Raises:
            BackendNotSupportedError: If the backend is unsupported during restoration.
        """
        if not self.is_built or state is None:
            return
        if self._backend == Backend.TORCH:
            d_state = state.get("state_dict")
            if d_state is not None:
                self.instance.load_state_dict(d_state)
        elif self._backend == Backend.TENSORFLOW:
            weights = state.get("weights")
            if weights is not None:
                self.instance.set_weights(weights)
        else:
            raise BackendNotSupportedError(
                backend=self._backend,
                method="Optimizer._restore_internal_state()",
            )

    # ================================================
    # Serialization
    # ================================================
    def save(self, filepath: Path, *, overwrite: bool = False) -> Path:
        """
        Serialize the optimizer to disk using the built-in serializer.

        Args:
            filepath (Path):
                Destination path; suffix may be adjusted to match conventions.
            overwrite (bool):
                Whether to overwrite an existing artifact.

        Returns:
            Path: Actual file path written by the serializer.
        """
        from modularml.core.io.serialization_policy import SerializationPolicy
        from modularml.core.io.serializer import serializer

        return serializer.save(
            self,
            filepath,
            policy=SerializationPolicy.BUILTIN,
            overwrite=overwrite,
        )

    @classmethod
    def load(cls, filepath: Path, *, allow_packaged_code: bool = False) -> Optimizer:
        """
        Load a serialized optimizer from disk.

        Args:
            filepath (Path): Path to a serialized optimizer artifact.
            allow_packaged_code (bool): Whether bundled code execution is permitted.

        Returns:
            Optimizer: Reloaded optimizer instance.
        """
        from modularml.core.io.serializer import _enforce_file_suffix, serializer

        # Append the proper suffix only if none is given
        if Path(filepath).suffix == "":
            filepath = _enforce_file_suffix(path=filepath, cls=cls)
        return serializer.load(filepath, allow_packaged_code=allow_packaged_code)