Source code for modularml.core.training.loss

"""Backend-agnostic loss wrapper supporting serialization and registry hooks."""

from __future__ import annotations

import inspect
from pathlib import Path
from typing import TYPE_CHECKING, Any

from modularml.utils.data.comparators import deep_equal
from modularml.utils.environment.optional_imports import ensure_tensorflow, ensure_torch
from modularml.utils.errors.exceptions import BackendNotSupportedError, LossError
from modularml.utils.nn.backend import Backend, infer_backend, normalize_backend

if TYPE_CHECKING:
    from collections.abc import Callable


def _safe_infer_backend(obj_or_cls: Any) -> Backend:
    """
    Resolve the backend of a loss object or class, rejecting unknown results.

    Description:
        Delegates to :func:`infer_backend` and treats a ``Backend.NONE``
        result as a failure instead of passing it on to the caller.

    Args:
        obj_or_cls (Any): Instance or class whose backend should be inferred.

    Returns:
        Backend: The inferred backend; never ``Backend.NONE``.

    Raises:
        ValueError: If :func:`infer_backend` reports ``Backend.NONE``.

    """
    inferred = infer_backend(obj_or_cls=obj_or_cls)

    # Happy path first: anything other than the NONE sentinel is usable.
    if inferred != Backend.NONE:
        return inferred

    raise ValueError(
        "Could not infer backend from loss class. Specify backend explicitly.",
    )


class Loss:
    """
    Backend-agnostic wrapper around loss functions used in model training.

    Description:
        Supports built-in loss classes from PyTorch and TensorFlow along with
        custom callables. Provides serialization, backend normalization, and
        config/state handling to slot into ModularML training workflows.

    Attributes:
        name (str | None): Lowercase name of the selected loss when applicable.
        cls (type | None): Backing loss class when using backend libraries.
        fn (Callable | None): Direct loss callable if provided.
        reduction (str): Reduction passed to backend constructors.
        kwargs (dict[str, Any]): Keyword arguments used for construction.
        _backend (Backend | None): Resolved backend enum.

    """

    def __init__(
        self,
        loss: str | type | Callable | None = None,
        *,
        loss_kwargs: dict[str, Any] | None = None,
        backend: Backend | None = None,
        reduction: str = "none",
        factory: Callable | None = None,
    ):
        """
        Initialize the loss wrapper with a name, class, callable, or factory.

        Args:
            loss (str | type | Callable | None): Loss identifier, class, or
                callable; mutually exclusive with `factory`.
            loss_kwargs (dict[str, Any] | None): Keyword arguments passed to
                loss construction. The mapping is copied, so later calls to
                :meth:`build` never modify the caller's dictionary.
            backend (Backend | None): Backend enum; required when `loss` is a
                string or when `factory` is used.
            reduction (str): Reduction argument forwarded to backend
                constructors.
            factory (Callable | None): Callable producing a loss instance when
                invoked.

        Raises:
            ValueError: If both `loss` and `factory` are provided, a required
                backend is missing, or backend inference fails.
            LossError: If initialization does not provide a valid loss
                definition.

        """
        # Supported initialization modes:
        #   1. Name + backend
        #   2. Loss class
        #   3. Callable loss function / callable instance
        #   4. Callable factory
        if loss is not None and factory is not None:
            msg = (
                "Provide either a loss fnc/cls/name (`loss`) or a factory "
                "`factory`, not both."
            )
            raise ValueError(msg)

        # Runtime attributes
        self.name: str | None = None  # name of importable class (eg, "mseloss")
        self.cls: type | None = None  # loss class (eg, torch.nn.MSELoss)
        self.fn: Callable | None = None  # loss function (callable)
        # factory used to generate the loss callable during build()
        self._factory: Callable | None = factory
        # reduction argument passed during class construction
        self.reduction = reduction
        # Copy so that build() merging extra kwargs never touches the caller's dict.
        self.kwargs: dict[str, Any] = dict(loss_kwargs) if loss_kwargs else {}
        self._backend: Backend | None = normalize_backend(backend) if backend else None
        # built callable -> this is what __call__ invokes
        self._callable: Callable | None = None

        # Case 1: loss name
        if isinstance(loss, str):
            self.name = loss.lower()
            if backend is None:
                msg = (
                    "Backend must be specified when initializing a loss with a "
                    "string-name."
                )
                raise ValueError(msg)
            self._backend = normalize_backend(backend)
            self.cls = self._resolve()

        # Case 2: loss class
        elif isinstance(loss, type):
            self.cls = loss
            self.name = loss.__name__.lower()
            self._backend = _safe_infer_backend(loss)

        # Case 3: callable loss function/instance
        elif callable(loss):
            self.fn = loss
            # Callable instances (partials, instantiated loss modules) have no
            # `__name__`; fall back to the name of their type.
            self.name = getattr(loss, "__name__", None) or type(loss).__name__
            # Keep the normalized backend from above; infer only when absent.
            if self._backend is None:
                self._backend = _safe_infer_backend(loss)

        # Case 4: factory
        elif factory is not None:
            if backend is None:
                raise ValueError("Backend must be specified when using loss factory.")
            self._backend = normalize_backend(backend)

        else:
            msg = "Loss must be initialized with name, class, callable, or factory."
            raise LossError(msg)

    def __eq__(self, other):
        # NOTE: raises (rather than returning NotImplemented) for foreign types;
        # kept as-is because callers may rely on the TypeError.
        if not isinstance(other, Loss):
            msg = f"Cannot compare equality between Loss and {type(other)}"
            raise TypeError(msg)

        # Compare config
        if not deep_equal(self.get_config(), other.get_config()):
            return False

        # State is only compared when this object exposes `get_state`.
        if hasattr(self, "get_state"):
            return deep_equal(self.get_state(), other.get_state())
        return True

    # Defining __eq__ without __hash__: instances are explicitly unhashable.
    __hash__ = None

    # ================================================
    # Internal helpers
    # ================================================
    def _resolve(self) -> type:
        """
        Resolve the configured :class:`Loss` name to a backend-specific class.

        Description:
            Class names are matched case-insensitively. For the torch backend,
            names with the ``"loss"`` substring removed are also accepted
            (eg, ``"mse"`` for ``MSELoss``).

        Returns:
            type: Backend loss class matching `self.name`.

        Raises:
            LossError: If the loss name is missing or cannot be resolved.
            BackendNotSupportedError: If the backend lacks known loss modules.

        """
        if not isinstance(self.name, str):
            raise LossError("Loss name must be a string to resolve dynamically.")
        name_lc = self.name.lower()

        # Resolve backend loss module
        if self.backend == Backend.TORCH:
            torch = ensure_torch()
            module = torch.nn
            keywords = ["loss"]
        elif self.backend == Backend.TENSORFLOW:
            tf = ensure_tensorflow()
            module = tf.keras.losses
            keywords = []
        else:
            raise BackendNotSupportedError(
                backend=self.backend,
                method="Loss._resolve()",
            )

        # Inspect available classes
        candidates: dict[str, type] = {}
        for attr_name in dir(module):
            try:
                attr = getattr(module, attr_name)
            except Exception:  # noqa: BLE001, S112
                # Best effort: skip attributes that fail to load.
                continue
            if not isinstance(attr, type):
                continue

            # Match class names ignoring case; exact names always win.
            attr_key = attr_name.lower()
            candidates[attr_key] = attr
            for _k in keywords:
                if _k in attr_key:
                    alias = attr_key.replace(_k, "")
                    # Never let a shortened alias shadow an exact class name,
                    # and skip aliases that collapse to an empty string.
                    if alias:
                        candidates.setdefault(alias, attr)

        # Resolve loss
        loss_cls = candidates.get(name_lc)
        if loss_cls is None:
            available = sorted(candidates.keys())
            msg = (
                f"Unknown loss name '{self.name}' for backend '{self.backend}'. "
                f"Available losses: {available}"
            )
            raise LossError(msg)
        return loss_cls

    # ================================================
    # Properties
    # ================================================
    @property
    def allowed_keywords(self) -> list[str]:
        """
        List valid keyword arguments for the currently built loss callable.

        Returns:
            list[str]: Argument names accepted by :attr:`_callable`.

        Raises:
            RuntimeError: If the loss has not been built yet.

        """
        if self._callable is None:
            msg = "Loss callable has not been built yet."
            raise RuntimeError(msg)

        # Parameter names in declaration order
        return list(inspect.signature(self._callable).parameters)

    @property
    def backend(self) -> Backend:
        """
        Backend configured for this :class:`Loss`.

        Returns:
            Backend: Resolved backend enum.

        Raises:
            LossError: If the backend has not been set.

        """
        if self._backend is None:
            raise LossError("Loss backend has not been resolved.")
        return self._backend

    @property
    def is_built(self) -> bool:
        """
        Whether the underlying loss callable has been instantiated.

        Returns:
            bool: True if :attr:`_callable` is available.

        """
        return self._callable is not None

    # ================================================
    # Build
    # ================================================
    def build(
        self,
        *,
        backend: Backend | None = None,
        force_rebuild: bool = False,
        **kwargs,
    ) -> None:
        """
        Instantiate the loss callable if not already provided.

        Args:
            backend (Backend | None): Backend to use; must agree with the
                backend chosen at init when one was set.
            force_rebuild (bool): Whether to rebuild even if already
                instantiated.
            **kwargs (Any): Additional keyword arguments merged into
                :attr:`kwargs` and forwarded to loss construction.

        Raises:
            LossError: If rebuilding an already built loss without
                `force_rebuild`, if nothing buildable is configured, or if the
                factory returns a non-callable.
            ValueError: If backend mismatches occur or no backend is set
                before building.
            BackendNotSupportedError: If the backend lacks constructor support.

        """
        if self.is_built and not force_rebuild:
            raise LossError(
                message=(
                    "Loss.build() is being called on an already instantiated loss. "
                    "If you want to rebuild the loss, set `force_rebuild=True`."
                ),
            )

        # Set/validate backend. Work on `_backend` directly: the public
        # `backend` property is read-only and raises while unset.
        if backend is not None:
            requested = normalize_backend(backend)
            if self._backend is not None and requested != self._backend:
                msg = (
                    "Backend passed to Loss.build differs from backend at init: "
                    f"{backend} != {self._backend}"
                )
                raise ValueError(msg)
            self._backend = requested
        if self._backend is None:
            raise ValueError("Backend must be set before building loss.")

        # Merge kwargs into a fresh dict (tolerates `None`, never mutates a
        # dictionary that may be shared with the caller).
        self.kwargs = {**(self.kwargs or {}), **kwargs}

        # Cases 1-2: resolved class (from name or passed directly)
        if self.cls is not None:
            if self._backend in (Backend.TORCH, Backend.TENSORFLOW):
                self._callable = self.cls(reduction=self.reduction, **self.kwargs)
            else:
                raise BackendNotSupportedError(self._backend, "Loss.build")
            return

        # Case 3: callable loss function
        if self.fn is not None:
            self._callable = self.fn
            return

        # Case 4: factory
        if self._factory is not None:
            built = self._factory(**self.kwargs)
            if not callable(built):
                msg = f"Loss factory returned a non-callable object: {type(built)}"
                raise LossError(msg)
            self._callable = built
            return

        # Nothing to build from; fail loudly instead of staying silently unbuilt.
        msg = "Loss has no class, callable, or factory to build from."
        raise LossError(msg)

    # ================================================
    # Callable
    # ================================================
    def __call__(self, *args, **kwargs):
        """
        Execute the underlying loss callable, building it first if needed.

        Args:
            *args (Any): Positional arguments forwarded to the backend callable.
            **kwargs (Any): Keyword arguments forwarded to the backend callable.

        Returns:
            Any: Output of the loss callable.

        Raises:
            LossError: If the callable fails during execution.

        """
        if not self.is_built:
            self.build()
        try:
            return self._callable(*args, **kwargs)
        except Exception as e:
            msg = f"Loss execution failed: {e}"
            raise LossError(msg) from e

    # ================================================
    # Configurable
    # ================================================
    def get_config(self) -> dict[str, Any]:
        """
        Return configuration required to reconstruct this loss wrapper.

        Returns:
            dict[str, Any]: Mapping with the keys ``loss``, ``loss_kwargs``,
                ``backend``, ``reduction`` and ``factory``, matching the
                keyword arguments of :meth:`__init__`.

        """
        name = None if self.name is None else str(self.name).lower()
        # Read `_backend` directly: the `backend` property raises when unset,
        # which would make the `None` branch unreachable.
        backend = self._backend
        return {
            "loss": self.fn or name or None,
            "loss_kwargs": self.kwargs,
            "backend": None if backend is None else str(backend.value).lower(),
            "reduction": self.reduction,
            "factory": self._factory,
        }

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> Loss:
        """
        Construct a :class:`Loss` from configuration.

        Args:
            config (dict[str, Any]): Configuration dictionary whose keys are
                passed to :meth:`__init__` as keyword arguments.

        Returns:
            Loss: Rebuilt loss wrapper instance.

        """
        return cls(**config)

    # ================================================
    # Representation
    # ================================================
    def __repr__(self):
        return f"Loss(name={self.name}, backend={self._backend})"

    # ================================================
    # Serialization
    # ================================================
    def save(self, filepath: Path, *, overwrite: bool = False) -> Path:
        """
        Serialize this loss wrapper to disk.

        Args:
            filepath (Path): Destination path; suffix may be adjusted to match
                ModularML conventions.
            overwrite (bool): Whether to overwrite existing files.

        Returns:
            Path: Actual file path written by the serializer.

        """
        # Imported inside the method rather than at module import time.
        from modularml.core.io.serialization_policy import SerializationPolicy
        from modularml.core.io.serializer import serializer

        return serializer.save(
            self,
            filepath,
            policy=SerializationPolicy.BUILTIN,
            overwrite=overwrite,
        )

    @classmethod
    def load(cls, filepath: Path, *, allow_packaged_code: bool = False) -> Loss:
        """
        Load a loss wrapper from disk.

        Args:
            filepath (Path): Path to a serialized loss artifact.
            allow_packaged_code (bool): Whether packaged code execution is
                allowed.

        Returns:
            Loss: Reloaded loss instance.

        """
        # Imported inside the method rather than at module import time.
        from modularml.core.io.serializer import _enforce_file_suffix, serializer

        # Append the proper suffix only if no suffix is given
        if Path(filepath).suffix == "":
            filepath = _enforce_file_suffix(path=filepath, cls=cls)

        return serializer.load(filepath, allow_packaged_code=allow_packaged_code)