Source code for modularml.core.models.torch_base_model
"""Base classes for native PyTorch models in ModularML."""
from abc import ABC
from typing import Any
import numpy as np
import torch
from modularml.core.models.base_model import BaseModel
from modularml.utils.nn.backend import Backend
[docs]
class TorchBaseModel(BaseModel, torch.nn.Module, ABC):
    """
    Base class for ModularML-native PyTorch models.

    Description:
        Intended for framework-owned PyTorch architectures such as
        :class:`modularml.models.torch.SequentialMLP`. User-defined modules
        can subclass this base, although :class:`TorchModelWrapper` is
        usually simpler for existing :class:`torch.nn.Module` graphs.

        Weights are exchanged as plain ``dict[str, np.ndarray]`` mappings
        keyed by :meth:`torch.nn.Module.state_dict` names (see
        :meth:`get_weights` and :meth:`set_weights`).
    """

    def __init__(self, **init_args: Any):
        """
        Initialize the PyTorch + :class:`BaseModel` inheritance chain.

        Args:
            **init_args: Keyword arguments forwarded to :class:`BaseModel`.
                Any ``backend`` entry is discarded; this class always passes
                ``backend=Backend.TORCH``.
        """
        # torch.nn.Module must be initialized first: its __setattr__ relies on
        # the internal registries (_parameters, _buffers, _modules) created by
        # Module.__init__, so attributes assigned afterwards are routed
        # correctly.
        torch.nn.Module.__init__(self)
        # The backend is fixed for this class; drop any caller-supplied value
        # so it cannot collide with the explicit keyword passed below.
        init_args.pop("backend", None)
        # BaseModel handles backend + built flag
        super().__init__(backend=Backend.TORCH, **init_args)

    # ================================================
    # Model Weights (Stateful)
    # ================================================
    def get_weights(self) -> dict[str, np.ndarray]:
        """
        Return a snapshot of the model's :meth:`state_dict` as numpy arrays.

        Each array is an independent copy. Later in-place updates to the
        model's parameters/buffers (e.g. optimizer steps) do not change the
        returned arrays, and mutating the arrays does not affect the model.

        Returns:
            dict[str, np.ndarray]: Mapping of state-dict key to array. Empty
            if the model has not been built yet.
        """
        if not self.is_built:
            return {}

        weights: dict[str, np.ndarray] = {}
        for name, tensor in self.state_dict().items():
            t = tensor.detach().cpu()
            # numpy has no bfloat16 dtype, so ``.numpy()`` would raise.
            # float32 represents every bfloat16 value exactly, and
            # set_weights casts back when load_state_dict copies into the
            # bfloat16 parameter.
            if t.dtype == torch.bfloat16:
                t = t.float()
            # For CPU tensors, ``.cpu()`` and ``.numpy()`` share storage with
            # the live parameter; copy so the snapshot is decoupled from it.
            weights[name] = t.numpy().copy()
        return weights

    def set_weights(self, weights: dict[str, np.ndarray]) -> None:
        """
        Restore numpy-based weights produced by :meth:`get_weights`.

        Args:
            weights: Mapping of state-dict key to array. An empty mapping
                (or ``None``) is a no-op.

        Raises:
            RuntimeError: Propagated from
                :meth:`torch.nn.Module.load_state_dict` when keys are missing
                or unexpected (``strict=True``) or when shapes do not match.
        """
        if not weights:
            return
        # load_state_dict copies these values into the existing parameters
        # and buffers (casting dtype/device as needed), so the zero-copy
        # ``as_tensor`` views created here are not retained by the model.
        torch_state = {k: torch.as_tensor(v) for k, v in weights.items()}
        self.load_state_dict(torch_state, strict=True)