# Source code for metatrain.utils.abc

from abc import ABCMeta, abstractmethod
from pathlib import Path
from typing import (
    Any,
    Dict,
    Generic,
    List,
    Literal,
    Optional,
    TypeVar,
    Union,
)

import torch
from metatensor.torch import Labels, TensorMap
from metatomic.torch import (
    AtomisticModel,
    ModelMetadata,
    ModelOutput,
    System,
)

from metatrain.utils.data.dataset import Dataset, DatasetInfo


# Generic type variable for an architecture's hyper-parameter container;
# parameterizes ModelInterface and TrainerInterface below so each concrete
# architecture can declare its own hypers type.
HypersType = TypeVar("HypersType")


class ModelInterface(torch.nn.Module, Generic[HypersType], metaclass=ABCMeta):
    """
    Abstract base class for a machine learning model in metatrain.

    All architectures in metatrain must be implemented as sub-classes of this
    class, and implement the corresponding methods.

    :param hypers: A dictionary with the model's hyper-parameters.
    :param dataset_info: Information containing details about the dataset, such
        as target quantities and atomic types.
    :param metadata: Metadata about the model, e.g. author, description, and
        references.
    """

    __checkpoint_version__: int
    """The current version of the model's checkpoint. This is used to upgrade
    checkpoints produced with earlier versions of the code. See
    :ref:`ckpt_version` for more information."""

    __supported_devices__: List[torch.device]
    """List of torch devices supported by this model architecture. They should
    be sorted in order of preference since ``metatrain`` will use this and
    ``__supported_dtypes__`` to determine, based on the user request and
    machines' availability, the optimal ``dtype`` and ``device`` for training.
    """

    __supported_dtypes__: List[torch.dtype]
    """List of torch dtypes supported by this model architecture. They should
    be sorted in order of preference since ``metatrain`` will use this and
    ``__supported_devices__`` to determine, based on the user request and
    machines' availability, the optimal ``dtype`` and ``device`` for training.
    """

    __default_metadata__: ModelMetadata
    """Default metadata for this model architecture. Can be used to provide
    references that will be stored in the exported model. The references are
    stored in a dictionary with keys ``implementation`` and ``architecture``.
    The ``implementation`` key should contain references to the software used
    in the implementation of the architecture, while the ``architecture`` key
    should contain references about the general architecture.
    """

    def __init__(
        self, hypers: HypersType, dataset_info: DatasetInfo, metadata: ModelMetadata
    ) -> None:
        """"""
        super().__init__()

        # Fail early, with a clear error, if a concrete architecture forgot to
        # declare one of the mandatory class attributes documented above.
        required_attributes = [
            "__checkpoint_version__",
            "__supported_devices__",
            "__supported_dtypes__",
            "__default_metadata__",
        ]
        for attribute in required_attributes:
            if not hasattr(self.__class__, attribute):
                raise TypeError(
                    f"missing '{attribute}' class attribute for "
                    f"'{self.__class__.__module__}.{self.__class__.__name__}'"
                )

        self.hypers = hypers
        """The model hypers passed at initialization"""
        self.dataset_info = dataset_info
        """The dataset info passed at initialization"""
        self.metadata = metadata
        """The metadata passed at initialization"""

    @abstractmethod
    def forward(
        self,
        systems: List[System],
        outputs: Dict[str, ModelOutput],
        selected_atoms: Optional[Labels] = None,
    ) -> Dict[str, TensorMap]:
        """
        Execute the model for the given ``systems``, computing the requested
        ``outputs``.

        :param systems: List of systems to evaluate the model on.
        :param outputs: Dictionary of outputs that the model should compute.
        :param selected_atoms: Optional ``Labels`` specifying a subset of atoms
            to compute the outputs for. If ``None``, the outputs are computed
            for all atoms in each system.

        :return: A dictionary mapping each requested output name to the
            corresponding ``TensorMap`` containing the computed values.

        .. seealso:: :py:class:`metatomic.torch.ModelInterface` for more
            explanation about the different arguments.
        """

    @abstractmethod
    def supported_outputs(self) -> Dict[str, ModelOutput]:
        """
        Get the outputs currently supported by this model.

        This will likely be the same outputs that are set as this model
        capabilities in :py:func:`ModelInterface.export`.

        :return: A dictionary of the supported outputs by this model.
        """

    @abstractmethod
    def restart(self, dataset_info: DatasetInfo) -> "ModelInterface":
        """
        Update a model to restart training, potentially with different dataset
        and/or targets.

        This function is called whenever training restarts, with the same or a
        different dataset. It enables transfer learning (changing the targets),
        and fine-tuning (same targets, different datasets).

        :param dataset_info: Information about the new dataset, including the
            targets that will be used for training.

        :return: The updated model, or a new instance of the model, that is
            able to handle the new dataset.
        """

    @classmethod
    @abstractmethod
    def load_checkpoint(
        cls,
        checkpoint: Dict[str, Any],
        context: Literal["restart", "finetune", "export"],
    ) -> "ModelInterface":
        """
        Create a model from a checkpoint (i.e. state dictionary).

        :param checkpoint: Checkpoint's state dictionary.
        :param context: Context in which to load the model. Possible values are
            ``"restart"`` when restarting a stopped training run,
            ``"finetune"`` when loading a model for further fine-tuning or
            transfer learning, and ``"export"`` when loading a model for final
            export. When multiple checkpoints are stored together, this can be
            used to pick one of them depending on the context.

        :return: An instance of the model.
        """

    @abstractmethod
    def export(
        self,
        metadata: Optional[ModelMetadata] = None,
    ) -> AtomisticModel:
        """
        Turn this model into an instance of
        :py:class:`metatomic.torch.AtomisticModel`, containing the model
        itself, a definition of the model capabilities and some metadata about
        the model.

        :param metadata: additional metadata to add in the model as specified
            by the user.

        :return: An instance of :py:class:`metatomic.torch.AtomisticModel`
        """

    @classmethod
    @abstractmethod
    def upgrade_checkpoint(cls, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
        """
        Upgrade the checkpoint to the current version of the model.

        :param checkpoint: Checkpoint's state dictionary.

        :raises RuntimeError: if the checkpoint cannot be upgraded to the
            current version of the model.

        :return: The upgraded checkpoint.
        """

    @abstractmethod
    def get_checkpoint(self) -> Dict[str, Any]:
        """
        Get the checkpoint of the model.

        This should contain all the information needed by ``load_checkpoint``
        to recreate the same model instance.

        :return: The model's checkpoint.
        """
class TrainerInterface(Generic[HypersType], metaclass=ABCMeta):
    """
    Abstract base class for a model trainer in metatrain.

    All architectures in metatrain must implement such a trainer, which is
    responsible for training the model. The trainer must be a sub-class of
    this class, and implement the corresponding methods.

    :param hypers: A dictionary with the trainer's hyper-parameters.
    """

    __checkpoint_version__: int
    """The current version of the trainer's checkpoint. This is used to
    upgrade checkpoints produced with earlier versions of the code. See
    :ref:`ckpt_version` for more information."""

    def __init__(self, hypers: HypersType):
        # Fail early, with a clear error, if a concrete trainer forgot to
        # declare one of the mandatory class attributes.
        required_attributes = [
            "__checkpoint_version__",
        ]
        for attribute in required_attributes:
            if not hasattr(self.__class__, attribute):
                raise TypeError(
                    f"missing '{attribute}' class attribute for "
                    f"'{self.__class__.__module__}.{self.__class__.__name__}'"
                )

        # Write directly to __dict__ to bypass the guard in __setattr__ below;
        # from this point on, regular attribute assignment is allowed.
        self.__dict__["__initialized"] = True

        self.hypers = hypers
        """The trainer hypers passed at initialization"""

    def __setattr__(self, name: str, value: Any) -> None:
        # Reject attribute assignment until super().__init__() has run, so
        # sub-classes cannot silently skip the required-attributes check above.
        if not hasattr(self, "__initialized") or not self.__dict__["__initialized"]:
            raise ValueError(
                "you must call `super().__init__(hypers)` before setting new fields"
            )
        super().__setattr__(name, value)

    @abstractmethod
    def train(
        self,
        model: ModelInterface,
        dtype: torch.dtype,
        devices: List[torch.device],
        train_datasets: List[Union[Dataset, torch.utils.data.Subset]],
        val_datasets: List[Union[Dataset, torch.utils.data.Subset]],
        checkpoint_dir: str,
    ) -> None:
        """
        Train the ``model`` using the ``train_datasets``.

        How to train the model is left to this class, using the
        hyper-parameters given in ``__init__``.

        :param model: the model to train
        :param dtype: ``torch.dtype`` used by the data in the datasets
        :param devices: ``torch.device`` to use for training the model. When
            training with more than one device (e.g. multi-GPU training), this
            can contain multiple devices.
        :param train_datasets: datasets to use to train the model
        :param val_datasets: datasets to use for model validation
        :param checkpoint_dir: directory where checkpoints should be saved
        """

    @abstractmethod
    def save_checkpoint(self, model: ModelInterface, path: Union[str, Path]) -> None:
        """
        Save a checkpoint of both the ``model`` and trainer state to the given
        ``path``.

        :param model: The model to save in the checkpoint.
        :param path: The path where to save the checkpoint.
        """

    @classmethod
    @abstractmethod
    def upgrade_checkpoint(cls, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
        """
        Upgrade the checkpoint to the current version of the trainer.

        :param checkpoint: Checkpoint's state dictionary.

        :raises RuntimeError: if the checkpoint cannot be upgraded to the
            current version of the trainer.

        :return: The upgraded checkpoint.
        """

    @classmethod
    @abstractmethod
    def load_checkpoint(
        cls,
        checkpoint: Dict[str, Any],
        hypers: HypersType,
        context: Literal["restart", "finetune"],
    ) -> "TrainerInterface":
        """
        Create a trainer instance from data stored in the ``checkpoint``.

        :param checkpoint: Checkpoint's state dictionary.
        :param hypers: Hyper-parameters for the trainer, as specified by the
            user.
        :param context: Context in which to load the model. Possible values are
            ``"restart"`` when restarting a stopped training run, and
            ``"finetune"`` when loading a model for further fine-tuning or
            transfer learning. When multiple checkpoints are stored together,
            this can be used to pick one of them depending on the context.

        :return: The loaded trainer instance.
        """