import copy
from typing import Any
import torch
from metatomic.torch import System
from metatrain.utils.data import DatasetInfo
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists
from .architectures import ArchitectureTests
class TorchscriptTests(ArchitectureTests):
    """Test suite to check that architectures can be jit compiled with
    TorchScript."""

    float_hypers: list[str] = []
    """List of hyperparameter keys (dot-separated for nested keys)
    that are floats. A test will set these to integers to test that
    TorchScript compilation works in that case."""

    @staticmethod
    def _script_and_evaluate(model: Any) -> None:
        """Script ``model`` with TorchScript and evaluate it on a small system.

        The system has four atoms (types 6, 1, 8 and 7) placed along the
        z axis, a zero cell and no periodic boundary conditions. The
        neighbor lists requested by the model are attached to it before
        the scripted model is called.

        This helper is shared by the tests that need to both compile and
        run a model; its name deliberately does not start with ``test`` so
        that it is not collected as a test on its own.

        :param model: Model instance to script and evaluate.
        """
        system = System(
            types=torch.tensor([6, 1, 8, 7]),
            positions=torch.tensor(
                [
                    [0.0, 0.0, 0.0],
                    [0.0, 0.0, 1.0],
                    [0.0, 0.0, 2.0],
                    [0.0, 0.0, 3.0],
                ]
            ),
            cell=torch.zeros(3, 3),
            pbc=torch.tensor([False, False, False]),
        )
        # The neighbor lists must be requested from the eager model, before
        # it is replaced by its scripted counterpart.
        system = get_system_with_neighbor_lists(
            system, model.requested_neighbor_lists()
        )

        scripted_model = torch.jit.script(model)
        # Calling the scripted model (and not only compiling it) also
        # exercises the forward pass under TorchScript.
        scripted_model([system], scripted_model.outputs)

    def test_torchscript(self, model_hypers: dict, dataset_info: DatasetInfo) -> None:
        """Tests that the model can be jitted.

        If this test fails it probably means that there is some
        code in the model that is not compatible with TorchScript.
        The exception raised by the test should indicate where
        the problem is.

        :param model_hypers: Hyperparameters to initialize the model.
        :param dataset_info: Dataset to initialize the model.
        """
        model = self.model_cls(model_hypers, dataset_info)
        self._script_and_evaluate(model)

    def test_torchscript_spherical(
        self, model_hypers: dict, dataset_info_spherical: DatasetInfo
    ) -> None:
        """Tests that there is no problem with jitting with spherical targets.

        :param model_hypers: Hyperparameters to initialize the model.
        :param dataset_info_spherical: Dataset to initialize the model
            (containing spherical targets).
        """
        self.test_torchscript(
            model_hypers=model_hypers, dataset_info=dataset_info_spherical
        )

    def test_torchscript_save_load(
        self, tmpdir: Any, model_hypers: dict, dataset_info: DatasetInfo
    ) -> None:
        """Tests that the model can be jitted, saved and loaded.

        :param tmpdir: Temporary directory where to save the
            model.
        :param model_hypers: Hyperparameters to initialize the model.
        :param dataset_info: Dataset to initialize the model.
        """
        model = self.model_cls(model_hypers, dataset_info)

        # Work inside the temporary directory so that the relative path
        # "model.pt" does not leave a file behind in the working directory.
        with tmpdir.as_cwd():
            torch.jit.save(torch.jit.script(model), "model.pt")
            torch.jit.load("model.pt")

    def test_torchscript_integers(
        self, model_hypers: dict, dataset_info: DatasetInfo
    ) -> None:
        """Tests that the model can be jitted when some float
        parameters are instead supplied as integers.

        The hyperparameters listed in :attr:`float_hypers` are converted
        with ``int`` in a deep copy of ``model_hypers``, so the dictionary
        passed in by the caller is left untouched.

        :param model_hypers: Hyperparameters to initialize the model.
        :param dataset_info: Dataset to initialize the model.
        """
        new_hypers = copy.deepcopy(model_hypers)

        for hyper in self.float_hypers:
            # "a.b.c" addresses new_hypers["a"]["b"]["c"]: walk down to the
            # dictionary holding the last key, then replace the value there.
            *parent_keys, leaf_key = hyper.split(".")
            sub_dict = new_hypers
            for key in parent_keys:
                sub_dict = sub_dict[key]
            sub_dict[leaf_key] = int(sub_dict[leaf_key])

        model = self.model_cls(new_hypers, dataset_info)
        self._script_and_evaluate(model)