from pathlib import Path
from typing import Any
import pytest
import torch
from metatrain.utils.abc import TrainerInterface
from metatrain.utils.architectures import get_default_hypers, import_architecture
from metatrain.utils.data import (
Dataset,
DatasetInfo,
TargetInfo,
get_atomic_types,
get_dataset,
)
from metatrain.utils.data.target_info import (
get_energy_target_info,
get_generic_target_info,
)
class ArchitectureTests:
    """This is the base class for all architecture tests.

    It doesn't implement any tests itself, but provides fixtures and helper
    functions that are generally useful for testing architectures.

    Child classes can override everything, including fixtures, to make the
    tests suit their needs. Note that some fixtures defined here depend on
    other fixtures, but when overriding them, you can change completely
    their signature.
    """

    architecture: str
    """Name of the architecture to be tested.

    Based on this, the test suite will find the model and trainer classes
    as well as the hyperparameters.
    """

    @pytest.fixture
    def dataset_path(self) -> str:
        """Fixture that provides a path to a dataset file for testing.

        :return: The path to the dataset file.
        """
        return str(Path(__file__).parents[4] / "tests/resources/qm9_reduced_100.xyz")

    @pytest.fixture
    def dataset_targets(self, dataset_path: str) -> dict[str, dict]:
        """Fixture that provides the target hyperparameters for the dataset used
        in testing.

        :param dataset_path: The path to the dataset file.
        :return: A dictionary with target hyperparameters.
        """
        energy_target = {
            "quantity": "energy",
            "read_from": dataset_path,
            "reader": "ase",
            "key": "U0",
            "unit": "eV",
            "type": "scalar",
            "per_atom": False,
            "num_subtargets": 1,
            "forces": False,
            "stress": False,
            "virial": False,
        }
        return {"energy": energy_target}

    def get_dataset(
        self, dataset_targets: dict[str, dict], dataset_path: str
    ) -> tuple[Dataset, dict[str, TargetInfo], DatasetInfo]:
        """Helper function to load the dataset used in testing.

        :param dataset_targets: The target hyperparameters for the dataset.
        :param dataset_path: The path to the dataset file.
        :return: A tuple containing the dataset, target info, and dataset info.
        """
        # NOTE: this calls the module-level ``get_dataset`` helper imported at
        # the top of the file, not this method (the method name only shadows
        # it at attribute-access time, not in this function's scope).
        dataset, targets_info, _ = get_dataset(
            {
                "systems": {
                    "read_from": dataset_path,
                    "reader": "ase",
                },
                "targets": dataset_targets,
            }
        )
        dataset_info = DatasetInfo(
            length_unit="",
            atomic_types=get_atomic_types(dataset),
            targets=targets_info,
        )
        return dataset, targets_info, dataset_info

    @pytest.fixture(params=("cpu", "cuda"))
    def device(self, request: pytest.FixtureRequest) -> torch.device:
        """Fixture to provide the torch device for testing.

        :param request: The pytest request fixture.
        :return: The torch device to be used.
        """
        device = request.param
        # Skip (rather than fail) the CUDA parametrization on CPU-only hosts.
        if device == "cuda" and not torch.cuda.is_available():
            pytest.skip("CUDA is not available")
        return torch.device(device)

    @pytest.fixture(params=[torch.float32, torch.float64])
    def dtype(self, request: pytest.FixtureRequest) -> torch.dtype:
        """Fixture to provide the model data type for testing.

        :param request: The pytest request fixture.
        :return: The torch data type to be used.
        """
        return request.param

    @pytest.fixture
    def dataset_info(self) -> DatasetInfo:
        """Fixture that provides a basic ``DatasetInfo`` with an
        energy target for testing.

        :return: A ``DatasetInfo`` instance with an energy target.
        """
        return DatasetInfo(
            length_unit="Angstrom",
            atomic_types=[1, 6, 7, 8],
            targets={
                "energy": get_energy_target_info(
                    "energy", {"quantity": "energy", "unit": "eV"}
                )
            },
        )

    @pytest.fixture(params=[True, False])
    def per_atom(self, request: pytest.FixtureRequest) -> bool:
        """Fixture to test both per-atom and per-system targets.

        :param request: The pytest request fixture.
        :return: Whether the target is per-atom or not.
        """
        return request.param

    @pytest.fixture
    def dataset_info_scalar(self, per_atom: bool) -> DatasetInfo:
        """Fixture that provides a basic ``DatasetInfo`` with a scalar target
        for testing.

        :param per_atom: Whether the target is per-atom or not.
        :return: A ``DatasetInfo`` instance with a scalar target.
        """
        return DatasetInfo(
            length_unit="Angstrom",
            atomic_types=[1, 6, 7, 8],
            targets={
                "scalar": get_generic_target_info(
                    "scalar",
                    {
                        "quantity": "scalar",
                        "unit": "",
                        "type": "scalar",
                        "num_subtargets": 5,
                        "per_atom": per_atom,
                    },
                )
            },
        )

    @pytest.fixture
    def dataset_info_vector(self, per_atom: bool) -> DatasetInfo:
        """Fixture that provides a basic ``DatasetInfo`` with a vector target
        for testing.

        :param per_atom: Whether the target is per-atom or not.
        :return: A ``DatasetInfo`` instance with a vector target.
        """
        return DatasetInfo(
            length_unit="Angstrom",
            atomic_types=[1, 6, 7, 8],
            targets={
                "vector": get_generic_target_info(
                    "vector",
                    {
                        "quantity": "vector",
                        "unit": "",
                        "type": {"cartesian": {"rank": 1}},
                        "num_subtargets": 5,
                        "per_atom": per_atom,
                    },
                )
            },
        )

    @pytest.fixture(params=[0, 1, 2, 3])
    def o3_lambda(self, request: pytest.FixtureRequest) -> int:
        """Fixture to provide different O(3) lambda values for
        testing spherical tensors.

        :param request: The pytest request fixture.
        :return: The O(3) lambda value.
        """
        return request.param

    @pytest.fixture(params=[-1, 1])
    def o3_sigma(self, request: pytest.FixtureRequest) -> int:
        """Fixture to provide different O(3) sigma values for
        testing spherical tensors.

        :param request: The pytest request fixture.
        :return: The O(3) sigma value.
        """
        return request.param

    @pytest.fixture
    def dataset_info_spherical(self, o3_lambda: int, o3_sigma: int) -> DatasetInfo:
        """Fixture that provides a basic ``DatasetInfo`` with a
        spherical target for testing.

        :param o3_lambda: The O(3) lambda of the spherical target.
        :param o3_sigma: The O(3) sigma of the spherical target.
        :return: A ``DatasetInfo`` instance with a spherical target.
        """
        return DatasetInfo(
            length_unit="Angstrom",
            atomic_types=[1, 6, 7, 8],
            targets={
                "spherical_target": get_generic_target_info(
                    "spherical_target",
                    {
                        "quantity": "",
                        "unit": "",
                        "type": {
                            "spherical": {
                                "irreps": [
                                    {"o3_lambda": o3_lambda, "o3_sigma": o3_sigma}
                                ]
                            }
                        },
                        "num_subtargets": 5,
                        "per_atom": False,
                    },
                )
            },
        )

    @pytest.fixture
    def dataset_info_multispherical(self, per_atom: bool) -> DatasetInfo:
        """Fixture that provides a basic ``DatasetInfo`` with multiple spherical
        targets for testing.

        :param per_atom: Whether the target is per-atom or not.
        :return: A ``DatasetInfo`` instance with a multiple spherical targets.
        """
        return DatasetInfo(
            length_unit="Angstrom",
            atomic_types=[1, 6, 7, 8],
            targets={
                "spherical_tensor": get_generic_target_info(
                    "spherical_tensor",
                    {
                        "quantity": "spherical_tensor",
                        "unit": "",
                        "type": {
                            "spherical": {
                                "irreps": [
                                    {"o3_lambda": 2, "o3_sigma": 1},
                                    {"o3_lambda": 1, "o3_sigma": 1},
                                    {"o3_lambda": 0, "o3_sigma": 1},
                                ]
                            }
                        },
                        "num_subtargets": 100,
                        "per_atom": per_atom,
                    },
                )
            },
        )

    # Replace the Any type hint with type[ModelInterface]
    # once https://github.com/metatensor/metatrain/issues/942 is solved.
    @property
    def model_cls(self) -> Any:
        """The model class to be tested."""
        architecture = import_architecture(self.architecture)
        return architecture.__model__

    @property
    def trainer_cls(self) -> type[TrainerInterface]:
        """The trainer class to be tested."""
        architecture = import_architecture(self.architecture)
        return architecture.__trainer__

    @pytest.fixture
    def default_hypers(self) -> dict:
        """Fixture that provides the default hyperparameters for testing.

        :return: The default hyperparameters for the architecture.
        """
        return get_default_hypers(self.architecture)

    @pytest.fixture
    def model_hypers(self) -> dict:
        """Fixture that provides the model hyperparameters for testing.

        If not overridden, these are the default model hyperparameters.

        :return: The model hyperparameters for testing.
        """
        return get_default_hypers(self.architecture)["model"]

    @pytest.fixture
    def minimal_model_hypers(self) -> dict:
        """The hypers that produce the smallest possible model.

        This should be overridden in each architecture test class
        to ensure that the tests run quickly/checkpoints occupy
        little disk space.

        :return: The minimal model hyperparameters for testing.
        """
        return get_default_hypers(self.architecture)["model"]