"""Utilities for applying configured losses to model outputs."""
from __future__ import annotations
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any
import numpy as np
from modularml.core.data.sample_schema import DOMAIN_TARGETS
from modularml.core.data.schema_constants import DOMAIN_OUTPUTS
from modularml.core.experiment.experiment_context import ExperimentContext
from modularml.core.references.experiment_reference import ResolutionError
from modularml.core.references.featureset_reference import FeatureSetColumnReference
from modularml.core.references.model_io_reference import ModelOutputReference
from modularml.core.topology.model_node import ModelNode
from modularml.utils.data.conversion import align_ranks, convert_to_format, to_numpy
from modularml.utils.data.data_format import get_data_format_for_backend
from modularml.utils.data.shape_utils import ensure_tuple_shape
from modularml.utils.environment.optional_imports import ensure_tensorflow, ensure_torch
from modularml.utils.errors.exceptions import BackendMismatchError
from modularml.utils.nn.backend import Backend
from modularml.utils.representation.summary import Summarizable
# Type-only imports: evaluated by static type checkers, skipped at runtime
# (``TYPE_CHECKING`` is ``False`` when the module is actually executed).
# With ``from __future__ import annotations`` the annotations that use these
# names are never evaluated at runtime, so the names need not exist then.
if TYPE_CHECKING:
    from modularml.core.data.batch import Batch
    from modularml.core.data.batch_view import BatchView
    from modularml.core.data.execution_context import ExecutionContext
    from modularml.core.references.execution_reference import TensorLike
    from modularml.core.references.reference_like import ReferenceLike
    from modularml.core.topology.graph_node import GraphNode
    from modularml.core.training.loss import Loss
class AppliedLoss(Summarizable):
    """
    Bind a :class:`Loss` to a :class:`ModelNode` with resolved inputs.

    Attributes:
        loss (Loss):
            Wrapped :class:`Loss` instance evaluated for each execution step.
        weight (float):
            Scalar multiplier applied to computed loss values before aggregation.
        label (str | None):
            Friendly label used when logging summaries for this applied loss.
        node_id (str):
            Identifier of the :class:`ModelNode` targeted by this loss.
        inputs (dict[str, ReferenceLike]):
            Mapping of loss argument names to runtime references. Positional
            inputs are stored under the keys ``"0"``, ``"1"``, ...

    """

    def __init__(
        self,
        loss: Loss,
        on: str | ModelNode,
        inputs: list[ReferenceLike] | dict[str, ReferenceLike],
        *,
        weight: float = 1.0,
        label: str | None = None,
    ):
        """
        Define a :class:`Loss` applied on a specified :class:`ModelNode`.

        Args:
            loss (Loss):
                Loss instance to apply for each execution step.
            on (str | ModelNode):
                Node label/ID or :class:`ModelNode` object indicating where
                the loss attaches.
            inputs (list[ReferenceLike] | dict[str, ReferenceLike]):
                Positional or keyword references resolved as loss arguments.
            weight (float):
                Scalar multiplier applied to the computed loss value.
            label (str | None):
                Optional label used when logging or summarizing this applied
                loss. Falls back to ``loss.name`` when omitted or empty.

        Raises:
            TypeError:
                If `on` is not a :class:`ModelNode` or string, or `inputs` has
                an unsupported type.

        """
        self.loss = loss
        self.weight = float(weight)
        self.label = label or loss.name

        # Resolve the target ModelNode down to its node ID
        if isinstance(on, ModelNode):
            self.node_id = on.node_id
        elif isinstance(on, str):
            exp_ctx = ExperimentContext.get_active()
            self.node_id = exp_ctx.get_node(
                val=on,
                enforce_type="ModelNode",
            ).node_id
        else:
            msg = f"`on` must be a ModelNode or string. Received: {type(on)}."
            raise TypeError(msg)

        # Normalize inputs to a dict; positional inputs get "0", "1", ... keys
        # so that `compute` can later rebuild the positional call order.
        if isinstance(inputs, list):
            self.inputs = {str(i): v for i, v in enumerate(inputs)}
        elif isinstance(inputs, Mapping):
            self.inputs = dict(inputs)
        else:
            msg = "`inputs` must be list[str] or dict[str, str]"
            raise TypeError(msg)

    def __eq__(self, other: object) -> bool:
        """
        Compare two applied losses by their configuration.

        Raises:
            TypeError: If `other` is not an :class:`AppliedLoss`.

        """
        if not isinstance(other, AppliedLoss):
            msg = f"Cannot compare equality between AppliedLoss and {type(other)}."
            raise TypeError(msg)
        return self.get_config() == other.get_config()

    # Equality is config-based and instances are mutable -> unhashable
    __hash__ = None

    # ================================================
    # Properties
    # ================================================
    @property
    def backend(self) -> Backend:
        """
        Backend used by the underlying :class:`Loss`.

        Returns:
            Backend: Backend declared on the wrapped :class:`Loss`.

        """
        return self.loss.backend

    # ================================================
    # Input resolution
    # ================================================
    def _resolve_input(
        self,
        spec: str,
        ctx: ExecutionContext,
    ) -> tuple[TensorLike, TensorLike, TensorLike]:
        """
        Resolve a loss input reference into tensor data, weights, and masks.

        Args:
            spec (str):
                Reference string such as `outputs.default` or a FeatureSet
                column path. A leading ``"<node label>."`` qualifier is
                stripped before resolution.
            ctx (ExecutionContext):
                Execution context containing upstream batches and outputs.

        Returns:
            tuple[TensorLike, TensorLike, TensorLike]:
                Tuple of tensor-like data, weights, and masks.

        Raises:
            BackendMismatchError:
                If the :class:`Loss` backend does not match the :class:`ModelNode`.
            ResolutionError:
                If the reference cannot be resolved to a model output or
                FeatureSet column.

        """
        exp_ctx = ExperimentContext.get_active()

        # Validate node & backend
        node: ModelNode = exp_ctx.get_node(
            node_id=self.node_id,
            enforce_type="ModelNode",
        )
        if self.loss.backend != node.backend:
            msg = (
                f"ModelNode ('{node.label}') and Loss ('{self.loss.name}') "
                f"backends do not match. {node.backend} != {self.loss.backend}."
            )
            raise BackendMismatchError(message=msg)

        # Drop a leading "<node label>." qualifier. Only the prefix is removed:
        # a blanket `str.replace` would also mangle specs in which the label
        # happens to end another component (e.g. label "ts" in "targets.x").
        label_prefix = f"{node.label}."
        if spec.startswith(label_prefix):
            spec = spec[len(label_prefix) :]

        # Specs starting with the outputs/targets domain refer to model outputs
        if spec.startswith((DOMAIN_OUTPUTS, DOMAIN_TARGETS)):
            return self._resolve_model_output(spec=spec, node=node, ctx=ctx)

        # Otherwise, treat the spec as a FeatureSet column reference
        return self._resolve_featureset_column(spec=spec, node=node, ctx=ctx)

    def _resolve_model_output(
        self,
        spec: str,
        node: ModelNode,
        ctx: ExecutionContext,
    ) -> tuple[TensorLike, TensorLike, TensorLike]:
        """
        Resolve a model output reference into tensor, weight, and mask tuples.

        Args:
            spec (str):
                ``"<domain>"`` or ``"<domain>.<role>"`` specification
                referencing :attr:`ExecutionContext.outputs`.
            node (ModelNode):
                Node supplying the outputs referenced by this loss.
            ctx (ExecutionContext):
                Execution context that holds model outputs for the current step.

        Returns:
            tuple[TensorLike, TensorLike, TensorLike]:
                Tuple containing tensor data, weights, and masks.

        Raises:
            ResolutionError:
                If the reference has more than two components or the role
                cannot be inferred uniquely.

        """
        # Split into domain and (optional) role
        parts = spec.split(".")
        if len(parts) > 2:
            msg = (
                f"AppliedLoss input '{spec}' could not resolved. Too many "
                f"components: {parts}."
            )
            raise ResolutionError(msg)
        domain = parts[0]
        role = parts[1] if len(parts) == 2 else None

        # Get model outputs produced by the target node
        output_batch = ctx.outputs[node.node_id]

        # Without an explicit role, the batch must expose exactly one
        if role is None:
            if len(output_batch.available_roles) != 1:
                msg = (
                    f"Applied loss spec '{spec}' must specify a `role` when multiple "
                    "roles exist in the output data. Available roles: "
                    f"{output_batch.available_roles}."
                )
                raise ResolutionError(msg)
            role = output_batch.available_roles[0]

        # Resolve reference to tensor-like data
        ref = ModelOutputReference(
            node_label=node.label,
            node_id=node.node_id,
            role=role,
            domain=domain,
        )
        tensor_like = ref.resolve(ctx=ctx)

        # Grab weights and mask for the selected role from the output batch
        weights = output_batch.role_weights[role]
        mask = output_batch.role_masks[role]
        return tensor_like, weights, mask

    def _resolve_featureset_column(
        self,
        spec: str,
        node: ModelNode,
        ctx: ExecutionContext,
    ) -> tuple[TensorLike, TensorLike, TensorLike]:
        """
        Resolve a FeatureSet column reference into tensor, weight, and mask tuples.

        Args:
            spec (str):
                Column path referencing upstream FeatureSet data.
            node (ModelNode):
                Node consuming the upstream batch derived from the FeatureSet.
            ctx (ExecutionContext):
                Execution context providing FeatureSet :class:`BatchView` instances.

        Returns:
            tuple[TensorLike, TensorLike, TensorLike]:
                Tuple containing tensor data, weights, and masks. Weights and
                masks are all-ones placeholders with one entry per sample
                (first axis of the data).

        Raises:
            ResolutionError:
                If the upstream FeatureSet cannot be inferred or lacks a usable role.

        """
        bv: BatchView = self._get_upstream_view(node=node, ctx=ctx)
        ref = FeatureSetColumnReference.from_string(
            val=spec,
            experiment=ExperimentContext.get_active(),
            known_attrs={
                "node_id": bv.source.node_id,
                "node_label": bv.source.label,
            },
        )

        # Materialize only the referenced column, in the loss backend's format
        b: Batch = bv.materialize_batch(
            fmt=get_data_format_for_backend(backend=self.backend),
            columns=[f"{ref.domain}.{ref.key}.{ref.rep}"],
        )

        # Infer role: a single role wins, then "default", then "anchor"
        role = None
        if len(b.available_roles) == 1:
            role = b.available_roles[0]
        elif "default" in b.available_roles:
            role = "default"
        elif "anchor" in b.available_roles:
            role = "anchor"
        else:
            msg = (
                f"AppliedLoss input '{spec}' could not resolved. Role must be "
                f"specified when multiple exists. Available: {b.available_roles}."
            )
            raise ResolutionError(msg)

        # Get domain data (tensor like)
        tensor_like = b.role_data.get_data(role=role, domain=ref.domain)

        # Create dummy per-sample weights and mask (no masking available).
        # Both must have shape (n_samples,): `compute` multiplies them together
        # and `_apply_weights` multiplies them with the flattened per-sample
        # loss, so full-shape (n_samples, n_features) weights would not align.
        n_samples = ensure_tuple_shape(tensor_like.shape)[0]
        weights = np.ones(shape=(n_samples,))
        mask = np.ones(shape=(n_samples,), dtype=np.int8)
        return tensor_like, weights, mask

    def _get_upstream_view(
        self,
        node: ModelNode,
        ctx: ExecutionContext,
    ) -> BatchView:
        """
        Determine the :class:`BatchView` feeding the head node on this branch.

        Args:
            node (ModelNode):
                Node whose upstream FeatureSet should be located.
            ctx (ExecutionContext):
                Execution context storing the head inputs per node.

        Returns:
            BatchView:
                Upstream view that supplies data to the requested node.

        Raises:
            ResolutionError: If zero or multiple upstream FeatureSets feed the node.

        """
        exp_ctx = ExperimentContext.get_active()

        # All head node IDs in this ExecutionContext (first element of each key)
        head_node_ids = {key[0] for key in ctx.inputs}

        # Collect the input views of every head node upstream of `node`
        upstream_views: list[BatchView] = []
        visited: set[str] = set()

        def _get_input_view(n: GraphNode):
            # Record visited to protect against incidental loops
            if n.node_id in visited:
                return
            visited.add(n.node_id)

            # If this is a head node, add all of its inputs and stop descending
            if n.node_id in head_node_ids:
                upstream_views.extend(
                    bv for inp_key, bv in ctx.inputs.items() if inp_key[0] == n.node_id
                )
                return

            # Otherwise, recurse on each upstream node
            for ref in n._upstream_refs:
                _get_input_view(n=ref.resolve(ctx=exp_ctx))

        _get_input_view(n=node)

        if len(upstream_views) != 1:
            msg = (
                "FeatureSet-column-based loss inputs require that the applied-to-node "
                f"has exactly one upstream FeatureSet. Detected: {len(upstream_views)}."
            )
            raise ResolutionError(msg)
        return upstream_views[0]

    # ================================================
    # Computation
    # ================================================
    @staticmethod
    def _as_sample_matrix(values: Any) -> np.ndarray:
        """
        View `values` as a 2-D array of shape ``(n_samples, n_trailing)``.

        The first axis is kept as the sample axis and all remaining axes are
        flattened; scalars become a ``(1, 1)`` array.
        """
        arr = np.asarray(values)
        if arr.ndim == 0:
            return arr.reshape(1, 1)
        # Explicit trailing size (instead of -1) keeps empty batches reshapeable
        return arr.reshape(arr.shape[0], int(np.prod(arr.shape[1:])))

    def _apply_weights(self, raw_loss: Any, weights: Any) -> Any:
        """
        Apply sample weights to the raw backend loss output.

        The raw loss is flattened, multiplied element-wise by `weights`,
        summed, scaled by :attr:`weight`, and divided by the number of raw
        loss elements.

        Args:
            raw_loss (Any):
                Backend-specific tensor or array returned by :class:`Loss`.
            weights (Any):
                Per-sample weights aligned with `raw_loss`.

        Returns:
            Any: Weighted scalar compatible with the configured backend.

        """
        if self.backend == Backend.TORCH:
            torch = ensure_torch()
            # Ensure loss has shape (batch_size, ); `reshape` (unlike `view`)
            # also accepts non-contiguous tensors.
            raw_loss = raw_loss.reshape(-1)
            # Match the loss dtype: NumPy weights default to float64, which
            # would otherwise promote the loss to float64 (and is rejected by
            # devices lacking float64 support).
            w = torch.as_tensor(
                weights,
                dtype=raw_loss.dtype,
                device=raw_loss.device,
            )
            return torch.sum(raw_loss * w) * self.weight / len(raw_loss)

        if self.backend == Backend.TENSORFLOW:
            tf = ensure_tensorflow()
            # Ensure loss has shape (batch_size, )
            raw_loss = tf.reshape(raw_loss, [-1])
            w = tf.convert_to_tensor(weights, dtype=raw_loss.dtype)
            return tf.reduce_sum(raw_loss * w) * self.weight / len(raw_loss)

        # Assume NumPy
        raw_loss = np.reshape(raw_loss, (-1,))
        w = np.reshape(weights, (-1,))
        return np.sum(raw_loss * w) * self.weight / len(raw_loss)

    def compute(self, ctx: ExecutionContext) -> Any:
        """
        Compute the weighted loss for a single execution step.

        Each configured input is resolved to data, weights, and a mask. Data
        is converted to the backend format and rank-aligned against the first
        input; weights are averaged across inputs and zeroed wherever any
        input's mask is false; the wrapped loss is then called (positionally
        when all input keys are digits, by keyword otherwise) and weighted.

        Args:
            ctx (ExecutionContext):
                Execution context supplying model outputs and upstream batches.

        Returns:
            Any: Backend-specific scalar/tensor representing the weighted loss value.

        Raises:
            ValueError:
                If this applied loss has no configured inputs.
            BackendMismatchError:
                If the :class:`Loss` backend differs from the :class:`ModelNode`
                backend.
            ResolutionError:
                If any configured reference cannot be resolved from the execution
                context.

        """
        if not self.inputs:
            msg = f"AppliedLoss '{self.label}' has no inputs to resolve."
            raise ValueError(msg)

        fmt = get_data_format_for_backend(backend=self.backend)

        # Map self.inputs.keys() to batch tensor data
        kw_data: dict[str, Any] = {}
        kw_weights: list[np.ndarray] = []
        kw_masks: list[np.ndarray] = []

        # Collect required input(s) for each loss argument
        for arg, spec in self.inputs.items():
            data, weights, mask = self._resolve_input(spec=spec, ctx=ctx)
            kw_data[arg] = convert_to_format(data=data, fmt=fmt)
            # Reduce weights/masks to one value per sample so that inputs with
            # different trailing shapes (e.g. (n,) vs (n, 1)) can be combined.
            kw_weights.append(
                self._as_sample_matrix(to_numpy(weights)).mean(axis=1),
            )
            kw_masks.append(
                self._as_sample_matrix(to_numpy(mask)).astype(bool).all(axis=1),
            )

        # Ensure all kwargs have matching ranks (aligns singletons) relative
        # to the first input
        ref_key = next(iter(kw_data))
        for k in [x for x in kw_data if x != ref_key]:
            kw_data[ref_key], kw_data[k] = align_ranks(
                kw_data[ref_key],
                kw_data[k],
                backend=self.backend,
            )

        # Combine masks (logical AND across inputs)
        combined_mask = np.logical_and.reduce(kw_masks)  # shape: (n_samples, )

        # Combine weights (mean across inputs, then apply mask)
        mean_weights = np.mean(
            np.stack(kw_weights, axis=0),
            axis=0,
        ).reshape(-1)  # shape: (n_samples, )
        mean_weights = mean_weights * combined_mask.astype(mean_weights.dtype)

        # Call loss function (convert to positional args if needed)
        if all(k.isdigit() for k in kw_data):
            args = [kw_data[str(i)] for i in range(len(kw_data))]
            raw = self.loss(*args)
        else:
            raw = self.loss(**kw_data)

        # Apply weighting
        return self._apply_weights(raw, mean_weights)

    # ================================================
    # Representation
    # ================================================
    def _summary_rows(self) -> list[tuple]:
        """
        Return summary table rows representing this applied loss configuration.

        Returns:
            list[tuple]: Sequence of key/value tuples rendered in summaries.

        """
        # Nest the loss' own summary rows when it provides them
        if hasattr(self.loss, "_summary_rows"):
            loss_summary = self.loss._summary_rows()
        else:
            loss_summary = f"{self.loss!r}"

        return [
            ("label", str(self.label)),
            ("loss", loss_summary),
            ("inputs", str(self.inputs)),
            ("weight", str(self.weight)),
        ]

    def __repr__(self):
        return (
            f"AppliedLoss(label={self.label!r}, loss={self.loss.name!r}, "
            f"inputs={self.inputs!r}, weight={self.weight})"
        )

    # ================================================
    # Configurable
    # ================================================
    def get_config(self) -> dict[str, Any]:
        """
        Return configuration required to reconstruct this applied loss.

        The keys mirror the :meth:`__init__` parameters, so the result can be
        passed straight to :meth:`from_config`.

        Returns:
            dict[str, Any]:
                Dictionary capturing loss, node, inputs, weight, and label.
                The ``loss`` and ``inputs`` entries hold live objects and are
                not JSON safe.

        """
        return {
            "loss": self.loss,  # not JSON safe
            "on": self.node_id,
            "inputs": self.inputs,  # not JSON safe
            "weight": self.weight,
            "label": self.label,
        }

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> AppliedLoss:
        """
        Construct an :class:`AppliedLoss` from configuration.

        Args:
            config (dict[str, Any]):
                Dictionary produced by :meth:`get_config`.

        Returns:
            AppliedLoss: Rehydrated applied loss instance.

        """
        return cls(**config)