Source code for metatrain.utils.testing.training
import copy
from pathlib import Path
from typing import Any
import metatensor.torch as mts
import torch
from omegaconf import OmegaConf
from metatrain.utils.data.readers import read_systems
from metatrain.utils.hypers import init_with_defaults
from metatrain.utils.io import model_from_checkpoint
from metatrain.utils.loss import LossSpecification
from metatrain.utils.neighbor_lists import (
get_system_with_neighbor_lists,
)
from .architectures import ArchitectureTests
class TrainingTests(ArchitectureTests):
"""Puts architectures to test in real training scenarios."""
check_gradients: bool = True
def test_continue(
self,
monkeypatch: Any,
tmp_path: Path,
dataset_path: str,
dataset_targets: dict[str, dict],
default_hypers: dict[str, Any],
model_hypers: dict[str, Any],
) -> None:
"""Tests that a model can be checkpointed and loaded
for a continuation of the training process
:param monkeypatch: Pytest fixture to modify the current working
directory.
:param tmp_path: Temporary path to use for saving checkpoints.
:param dataset_path: Path to the dataset to use for training.
:param dataset_targets: Target hypers for the targets in the dataset.
:param default_hypers: Default hyperparameters for the architecture.
:param model_hypers: Hyperparameters to initialize the model.
"""
monkeypatch.chdir(tmp_path)
dataset, targets_info, dataset_info = self.get_dataset(
dataset_targets, dataset_path
)
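        # Build the model from its hypers and the dataset metadata. Training
        # is set to zero epochs: no weights are updated, but the trainer still
        # goes through its full setup, which is enough to write a checkpoint.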
model = self.model_cls(model_hypers, dataset_info)
hypers = copy.deepcopy(default_hypers)
hypers["training"]["num_epochs"] = 0
loss_conf = OmegaConf.create(
{k: init_with_defaults(LossSpecification) for k in dataset_targets}
)
OmegaConf.resolve(loss_conf)
hypers["training"]["loss"] = loss_conf
trainer = self.trainer_cls(hypers["training"])
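        # Run the zero-epoch training loop on CPU, using the same dataset for
        # training and validation.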
trainer.train(
model=model,
dtype=torch.float32,
devices=[torch.device("cpu")],
train_datasets=[dataset],
val_datasets=[dataset],
checkpoint_dir=".",
)
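        # Save a checkpoint and reload it in the "restart" context, mimicking
        # a training run that is resumed from disk.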
trainer.save_checkpoint(model, "tmp.ckpt")
checkpoint = torch.load("tmp.ckpt", weights_only=False, map_location="cpu")
model_after = model_from_checkpoint(checkpoint, context="restart")
assert isinstance(model_after, self.model_cls)
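        # restart() gives the restored model a chance to adapt to the dataset
        # metadata (unchanged here) before training resumes.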
model_after.restart(model.dataset_info)
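        # Resume "training" (again for zero epochs) from the restored model.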
hypers["training"]["num_epochs"] = 0
trainer = self.trainer_cls(hypers["training"])
trainer.train(
model=model_after,
dtype=torch.float32,
devices=[torch.device("cpu")],
train_datasets=[dataset],
val_datasets=[dataset],
checkpoint_dir=".",
)
        # Evaluate both models on a few systems and check that predictions
        # (and position gradients) agree.
systems = read_systems(dataset_path)
systems = [system.to(torch.float32) for system in systems[:5]]
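        # Enable position gradients and attach the neighbor lists requested by
        # the model.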
for system in systems:
system.positions.requires_grad_(True)
get_system_with_neighbor_lists(system, model.requested_neighbor_lists())
model.eval()
model_after.eval()
        output_before = model(
            systems, {k: model.outputs[k] for k in dataset_targets}
        )
        output_after = model_after(
            systems, {k: model_after.outputs[k] for k in dataset_targets}
        )
# For each target, check that outputs are the same after loading
# from checkpoint, including gradients
for i, target_key in enumerate(dataset_targets):
assert mts.allclose(output_before[target_key], output_after[target_key]), (
f"Output mismatch for {target_key}"
)
            # Gradients are only compared for the first target: the autograd
            # graph is freed after the first backward pass and cannot be
            # traversed again.
            if i > 0 or not self.check_gradients:
                continue
target_before = output_before[target_key].block().values
target_before.backward(torch.ones_like(target_before))
gradients_before = [s.positions.grad for s in systems]
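            # The captured .grad tensors are held by reference: resetting
            # .grad to None detaches them from the systems without mutating
            # the captured values, so the second backward pass starts clean.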
for system in systems:
system.positions.grad = None
target_after = output_after[target_key].block().values
target_after.backward(torch.ones_like(target_after))
gradients_after = [s.positions.grad for s in systems]
            assert torch.allclose(
                torch.vstack(gradients_before), torch.vstack(gradients_after)
            ), f"Gradient mismatch for {target_key}"