Source code for modularml.core.models.torch_base_model

"""Base classes for native PyTorch models in ModularML."""

from abc import ABC
from typing import Any

import numpy as np

from modularml.core.models.base_model import BaseModel
from modularml.utils.environment.optional_imports import check_torch
from modularml.utils.nn.backend import Backend

# Resolve the optional dependency once at import time; the code below treats a
# ``None`` result as "PyTorch is unavailable".
torch = check_torch()

if torch is None:

    class TorchModuleBase:
        """Stand-in base used when PyTorch is missing; instantiation fails."""

        def __init__(self, *args: Any, **kwargs: Any) -> None:  # noqa: ARG002
            # Fail at construction time rather than at module import so that
            # importing this module never requires PyTorch.
            message = (
                "PyTorch is required to use TorchBaseModel. "
                "Install it with `pip install torch`."
            )
            raise ImportError(message)

else:
    # With PyTorch present, models inherit directly from the real module base.
    TorchModuleBase = torch.nn.Module


class TorchBaseModel(BaseModel, TorchModuleBase, ABC):
    """
    Base class for ModularML-native PyTorch models.

    Description:
        Intended for framework-owned PyTorch architectures such as
        :class:`modularml.models.torch.SequentialMLP`. User-defined modules can
        subclass this base, although :class:`TorchModelWrapper` is usually
        simpler for existing :class:`torch.nn.Module` graphs.

    """

    def __init__(self, **init_args: Any):
        """
        Initialize the PyTorch + :class:`BaseModel` inheritance chain.

        Args:
            **init_args: Keyword arguments forwarded to the next ``__init__``
                in the MRO via ``super()``. Any ``backend`` entry is discarded
                because the backend is always forced to ``Backend.TORCH``.

        Raises:
            ImportError: If PyTorch is not installed.

        """
        if torch is None:
            raise ImportError(
                "PyTorch is required to instantiate TorchBaseModel. "
                "Install it with `pip install torch`.",
            )

        # torch.nn.Module must be initialized first
        torch.nn.Module.__init__(self)

        # BaseModel handles backend + built flag
        # A caller-supplied backend is dropped so it cannot clash with the
        # explicit keyword below.
        _ = init_args.pop("backend", None)
        super().__init__(backend=Backend.TORCH, **init_args)

    # ================================================
    # Model Weights (Stateful)
    # ================================================
    def get_weights(self) -> dict[str, np.ndarray]:
        """
        Return PyTorch tensors as numpy arrays via :meth:`state_dict`.

        Returns:
            dict[str, np.ndarray]: Mapping of ``state_dict`` keys to numpy
            arrays, or an empty dict when ``self.is_built`` is falsy. Each
            array is an independent copy of the corresponding tensor.

        """
        if not self.is_built:
            return {}

        # ``Tensor.numpy()`` on a CPU tensor shares storage with the tensor, so
        # without the explicit copy a returned snapshot would change whenever
        # the parameters are later updated in place.
        return {
            key: tensor.detach().cpu().numpy().copy()
            for key, tensor in self.state_dict().items()
        }

    def set_weights(self, weights: dict[str, np.ndarray]) -> None:
        """
        Restore numpy-based weights produced by :meth:`get_weights`.

        Args:
            weights: Mapping of ``state_dict`` keys to arrays. An empty (or
                otherwise falsy) mapping is a no-op. Entries are loaded with
                ``strict=True``.

        """
        if not weights:
            return

        torch_state = {key: torch.as_tensor(value) for key, value in weights.items()}
        self.load_state_dict(torch_state, strict=True)

    def reset_weights(self) -> None:
        """
        Re-initialize all model weights using each layer's default initializer.

        Description:
            Calls ``reset_parameters()`` on every module visited by
            :meth:`torch.nn.Module.apply` that exposes it as a callable;
            modules without such a method are left untouched.

        """

        def _reset(m: TorchModuleBase) -> None:
            # Look the attribute up once; modules lacking it yield ``None``.
            reset = getattr(m, "reset_parameters", None)
            if callable(reset):
                reset()

        self.apply(_reset)