Source code for modularml.callbacks.eval_loss_metric
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Literal
from modularml.callbacks.metric import EvaluationMetric, MetricResult
from modularml.core.training.applied_loss import AppliedLoss
if TYPE_CHECKING:
from modularml.callbacks.evaluation import EvaluationCallbackResult
from modularml.core.data.execution_context import ExecutionContext
[docs]
class EvalLossMetric(EvaluationMetric):
"""
Extracts a scalar loss from an Evaluation callback and logs it as a metric.
This built-in EvaluationMetric reads the results of its parent
Evaluation callback and extracts the aggregated loss value for a
specific node. The extracted value is logged as a named metric
(default: "val_loss") into the MetricStore.
Example:
The example below shows how to track a "val_loss" metric during a training phase.
>>> mse_loss = AppliedLoss(...) # doctest: +SKIP
>>> val_metric = EvalLossMetric(
... loss=mse_loss, name="val_loss"
... ) # doctest: +SKIP
>>> eval_cb = Evaluation( # doctest: +SKIP
... eval_phase=eval_phase, metrics=[val_metric]
... )
>>> phase.add_callback(eval_cb) # doctest: +SKIP
"""
[docs]
def __init__(
    self,
    *,
    loss: AppliedLoss,
    reducer: Literal["sum", "mean"] = "mean",
    name: str = "val_loss",
) -> None:
    """
    Initialize an EvalLossMetric.

    Args:
        loss (AppliedLoss):
            An applied loss to track. Will be appended to the parent
            EvalPhase in Evaluation, if not already present.
        reducer (Literal["sum", "mean"], optional):
            How to aggregate per-batch losses from the evaluation into
            a single scalar. Defaults to "mean".
        name (str, optional):
            The metric name to log under. Defaults to "val_loss".

    Raises:
        TypeError: If `loss` is not an AppliedLoss instance.
        ValueError: If `reducer` is not one of "sum" or "mean".
    """
    super().__init__(name=name)

    # Validate loss. This isinstance check runs at runtime, which is why
    # AppliedLoss is imported outside the TYPE_CHECKING guard.
    if not isinstance(loss, AppliedLoss):
        msg = f"Expected type of AppliedLoss. Received: {type(loss)}."
        raise TypeError(msg)
    self._loss = loss

    # Validate reducer. The Literal annotation is not enforced at runtime,
    # so the allowed values are checked explicitly here.
    red_methods = ["sum", "mean"]
    if reducer not in red_methods:
        msg = f"Expected one of {red_methods}. Received: {reducer}."
        raise ValueError(msg)
    self._reducer = reducer
# ================================================
# Configurable
# ================================================
[docs]
def get_config(self) -> dict[str, Any]:
    """
    Return the configuration needed to rebuild this metric.

    Starts from the parent class's configuration and adds the serialized
    AppliedLoss under "loss" and the reducer under "reducer". The dict
    obtained from the parent is extended in place and then returned.

    Returns:
        dict[str, Any]: Configuration intended for `from_config`.
    """
    config = super().get_config()
    # Insertion order matches the original: "loss" first, then "reducer".
    config["loss"] = self._loss.get_config()
    config["reducer"] = self._reducer
    return config
[docs]
@classmethod
def from_config(cls, config: dict) -> EvalLossMetric:
    """
    Build an EvalLossMetric from configuration data.

    Args:
        config (dict):
            Mapping with a required "loss" entry (an AppliedLoss config)
            and optional "reducer" (defaults to "mean") and "name"
            (defaults to "val_loss") entries.

    Returns:
        EvalLossMetric: A new metric constructed from `config`.

    Raises:
        KeyError: If `config` has no "loss" entry.
    """
    applied_loss = AppliedLoss.from_config(config["loss"])
    reducer = config.get("reducer", "mean")
    metric_name = config.get("name", "val_loss")
    return cls(loss=applied_loss, reducer=reducer, name=metric_name)