Source code for metatrain.utils.testing.exported
import torch
from metatomic.torch import ModelEvaluationOptions, ModelMetadata, System
from metatrain.utils.data import DatasetInfo
from metatrain.utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
)
from .architectures import ArchitectureTests
[docs]
class ExportedTests(ArchitectureTests):
"""Test suite to test exported models."""
[docs]
def test_to(
self,
device: torch.device,
dtype: torch.dtype,
model_hypers: dict,
dataset_info: DatasetInfo,
) -> None:
"""Tests that the `.to()` method of the exported model works.
In other words, it tests that the exported model can be moved to
different devices and dtypes.
:param device: The device to move the exported model to.
:param dtype: The dtype to move the exported model to.
:param model_hypers: Hyperparameters to initialize the model.
:param dataset_info: Dataset information to initialize the model.
"""
model = self.model_cls(model_hypers, dataset_info).to(dtype=dtype)
exported = model.export(metadata=ModelMetadata(name="test"))
# test correct metadata
assert "This is the test model" in str(exported.metadata())
exported.to(device=device)
system = System(
types=torch.tensor([6, 6]),
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
cell=torch.zeros(3, 3),
pbc=torch.tensor([False, False, False]),
)
requested_neighbor_lists = get_requested_neighbor_lists(exported)
system = get_system_with_neighbor_lists(system, requested_neighbor_lists)
system = system.to(device=device, dtype=dtype)
evaluation_options = ModelEvaluationOptions(
length_unit=dataset_info.length_unit,
outputs=model.outputs,
)
exported([system], evaluation_options, check_consistency=True)