# Source code for modularml.core.sampling.similiarity_condition

"""Declarative similarity/dissimilarity rules for sampler roles."""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal

import numpy as np

from modularml.core.io.protocols import Configurable

if TYPE_CHECKING:
    from collections.abc import Callable


@dataclass
class SimilarityCondition(Configurable):
    """
    A flexible rule for defining similarity or dissimilarity between two samples.

    Attributes:
        mode ({"similar", "dissimilar"}, default="similar"):
            Whether to select pairs that are within tolerance ("similar")
            or explicitly outside tolerance ("dissimilar").
        tolerance (float, default=0.0):
            Numeric threshold for values to be considered "similar".
            Example: `tolerance=0.05` means `|a - b| <= 0.05` is a valid match.
        metric (Callable[[float, float], float] | None, default=None):
            Optional custom distance function; the absolute value of its
            result is used as the difference. If not provided:
                - Numeric types use absolute difference.
                - Other types use equality (difference 0 if equal, else infinity).
        weight_mode ({"binary", "linear", "exp"}, default="binary"):
            Strategy for assigning weights when `allow_fallback=True`:
                - "binary": Matches get `max_weight`; non-matches get `min_weight`.
                - "linear": "similar" uses `tolerance / diff`;
                  "dissimilar" uses `diff / tolerance`.
                - "exp": "similar" uses `exp(1 - diff / tolerance)`;
                  "dissimilar" uses `exp(diff / tolerance - 1)`.
            "linear" and "exp" weights are clipped to `[min_weight, max_weight]`.
        max_weight (float, default=1.0):
            Upper clip bound for "linear"/"exp" weights, and the weight given
            to matches in "binary" mode.
        min_weight (float, default=0.0):
            Lower clip bound for "linear"/"exp" weights, and the weight given
            to non-matches in "binary" mode.
        allow_fallback (bool, default=False):
            If True, pairs that fail the match condition are not discarded;
            they are instead given a down-weighted score.
            If False, non-matches always score 0.0.

    Weight semantics:
        - With `allow_fallback=False`, `score` returns 1.0 for matches and 0.0
          for non-matches, ignoring `weight_mode`, `min_weight` and `max_weight`.
          This is equivalent to
          `weight_mode='binary', min_weight=0.0, max_weight=1.0`.
        - With `allow_fallback=True`, unclipped "linear"/"exp" weights are >= 1.0
          for matches (better matches -> larger weight) and <= 1.0 for
          non-matches. Note that the default `max_weight=1.0` caps every match
          at 1.0; raise it to keep graded weights for matches.
        - In "similar" mode with "linear" weighting, an exact match
          (difference 0) receives `max_weight`.
    """

    mode: Literal["similar", "dissimilar"] = "similar"
    tolerance: float = 0.0
    metric: Callable[[float, float], float] | None = None
    weight_mode: Literal["binary", "linear", "exp"] = "binary"
    max_weight: float = 1.0
    min_weight: float = 0.0
    allow_fallback: bool = False

    # Guard against division by (near-)zero tolerances/differences.
    # Unannotated, so dataclass does not treat it as a field.
    _EPS = 1e-9

    def __post_init__(self):
        """
        Validate mode and weight-mode configuration.

        Raises:
            ValueError: If `mode` or `weight_mode` is not a supported value.

        """
        valid_modes = ["similar", "dissimilar"]
        if self.mode not in valid_modes:
            msg = f"`mode` must be one of: {valid_modes}. Received: {self.mode}"
            raise ValueError(msg)
        valid_weight_modes = ["binary", "linear", "exp"]
        if self.weight_mode not in valid_weight_modes:
            msg = (
                f"`weight_mode` must be one of: {valid_weight_modes}. "
                f"Received: {self.weight_mode}"
            )
            raise ValueError(msg)

    def score(self, a: Any, b: Any) -> float:
        """
        Compute the similarity/dissimilarity score between two values.

        Args:
            a: First value (anchor).
            b: Second value (candidate).

        Returns:
            float: A weight score. With `allow_fallback=False` this is 1.0 for a
            match and 0.0 otherwise. With `allow_fallback=True` it is derived from
            `weight_mode`: "binary" returns `max_weight`/`min_weight` directly, and
            "linear"/"exp" are clipped to `[min_weight, max_weight]`.

        Notes:
            - A pair matches when `diff <= tolerance` in "similar" mode, and
              when `diff > tolerance` in "dissimilar" mode.

        """
        # Step 1: compute difference
        diff = self._difference(a, b)

        # Step 2: check match condition
        is_match = diff <= self.tolerance
        if self.mode == "dissimilar":
            is_match = not is_match

        if not self.allow_fallback:
            # Pure accept/reject: weight_mode and clip bounds are not consulted.
            return float(is_match)

        # Step 3: assign weight
        if self.weight_mode == "binary":
            return self.max_weight if is_match else self.min_weight
        if self.weight_mode == "linear":
            return self._clip(self._linear_weight(diff))
        if self.weight_mode == "exp":
            return self._clip(self._exp_weight(diff))
        # Only reachable if `weight_mode` was set to an unsupported value
        # after construction (construction itself validates it).
        return 0.0

    def _difference(self, a: Any, b: Any) -> float:
        """Return the non-negative difference between `a` and `b`."""
        if self.metric is not None:
            return abs(self.metric(a, b))
        numeric = (int, float, np.number)
        if isinstance(a, numeric) and isinstance(b, numeric):
            return abs(a - b)
        # Non-numeric values: exact equality or "infinitely" different.
        return 0 if a == b else float("inf")

    def _linear_weight(self, diff: float) -> float:
        """Return the unclipped linear weight for `diff` under the current mode."""
        if self.mode == "similar":
            # Smaller diff -> higher weight, e.g. diff=0.2, tol=0.5 -> 2.5.
            if diff == 0 and self.tolerance >= 0:
                # A perfect match is the best possible similar pair. Without this
                # guard, tolerance=0 would give 0 / eps == 0 and push an exact
                # match down to `min_weight`.
                return np.inf
            return self.tolerance / max(diff, self._EPS)
        # Dissimilar: larger diff -> higher weight, e.g. diff=1.0, tol=0.5 -> 2.0.
        return diff / max(self.tolerance, self._EPS)

    def _exp_weight(self, diff: float) -> float:
        """Return the unclipped exponential weight for `diff` under the current mode."""
        scale = max(self.tolerance, self._EPS)
        if self.mode == "similar":
            # Smaller diff -> higher weight, e.g. diff=0.2, tol=0.5 -> e^0.6 = 1.822.
            exponent = 1 - diff / scale
        else:
            # Larger diff -> higher weight, e.g. diff=1.0, tol=0.5 -> e^1 = 2.718.
            exponent = diff / scale - 1
        # Huge exponents overflow to inf, which the caller clips to max_weight;
        # silence the per-call overflow RuntimeWarning without changing the value.
        with np.errstate(over="ignore"):
            return float(np.exp(exponent))

    def _clip(self, weight: float) -> float:
        """Clip `weight` to `[min_weight, max_weight]` and return a Python float."""
        return float(np.clip(weight, self.min_weight, self.max_weight))

    # ================================================
    # Configuration
    # ================================================
    def get_config(self) -> dict[str, Any]:
        """
        Return a configuration dictionary for this condition.

        Returns:
            dict[str, Any]: Configuration capturing all constructor arguments.
            `metric` is returned as-is, so a custom callable metric makes the
            result non-JSON-serializable.

        """
        return {
            "mode": self.mode,
            "tolerance": self.tolerance,
            "metric": self.metric,
            "weight_mode": self.weight_mode,
            "max_weight": self.max_weight,
            "min_weight": self.min_weight,
            "allow_fallback": self.allow_fallback,
        }

    @classmethod
    def from_config(cls, cfg: dict[str, Any]) -> SimilarityCondition:
        """
        Reconstruct a condition from configuration.

        Args:
            cfg (dict[str, Any]): Dictionary previously produced by :meth:`get_config`.

        Returns:
            SimilarityCondition: Rehydrated similarity condition.

        Raises:
            ValueError: If `cfg` holds an unsupported `mode` or `weight_mode`.

        """
        return cls(**cfg)