import copy
import glob
import gzip
import logging
import os
from typing import Any, Dict, Literal
import pytest
import torch
from omegaconf import OmegaConf
from metatrain.utils.abc import ModelInterface, TrainerInterface
from metatrain.utils.hypers import init_with_defaults
from metatrain.utils.loss import LossSpecification
from .architectures import ArchitectureTests
ALLOWED_NEW_KEYS_CONDITIONS = [
    # torch added this key in LambdaLR
    lambda prefix, key: key == "_is_initial" and "scheduler" in f"{prefix}.{key}",
]


def check_same_checkpoint_structure(
    checkpoint: Dict[str, Any], reference: Dict[str, Any], prefix: str = ""
) -> None:
    """
    Check that the structure of two checkpoints is the same.

    :param checkpoint: The checkpoint to be checked.
    :param reference: The reference checkpoint.
    :param prefix: The prefix to be added to the keys in the error messages.
    :raises KeyError: If a reference key is missing from ``checkpoint``, or if
        ``checkpoint`` contains a new key not covered by
        ``ALLOWED_NEW_KEYS_CONDITIONS``.
    """
    assert isinstance(checkpoint, dict)
    assert isinstance(reference, dict)

    # Phase 1: every key of the reference must exist in the checkpoint.
    for missing in (k for k in reference if k not in checkpoint):
        raise KeyError(f"missing key from checkpoint: {prefix}.{missing}")

    # Phase 2: the checkpoint may only add keys that are explicitly allowed.
    for extra in checkpoint:
        if any(cond(prefix, extra) for cond in ALLOWED_NEW_KEYS_CONDITIONS):
            continue
        if extra not in reference:
            raise KeyError(f"new key in checkpoint: {prefix}.{extra}")

    # Phase 3: recurse into nested dictionaries, extending the dotted prefix.
    for key, value in reference.items():
        if isinstance(value, dict):
            check_same_checkpoint_structure(
                checkpoint[key], value, prefix=f"{prefix}.{key}"
            )
class CheckpointTests(ArchitectureTests):
    """Test suite for model and trainer checkpoints.

    This test suite verifies that the checkpoints for the architecture
    follow the expected behavior of ``metatrain`` checkpoints.
    """

    incompatible_trainer_checkpoints: list[str] = []
    """A list of checkpoint paths that are known
    to be incompatible with the current trainer version when restarting.

    This should be overridden in subclasses.
    """

    @pytest.fixture
    def model_trainer(
        self,
        dataset_path: str,
        dataset_targets: dict,
        minimal_model_hypers: dict,
        default_hypers: dict,
    ) -> tuple[ModelInterface, TrainerInterface]:
        """Fixture that returns a trained model and trainer.

        The model and trainer are used in the test suite to verify checkpoint
        functionality.

        :param dataset_path: The path to the dataset file to train on.
        :param dataset_targets: The targets that the dataset contains.
        :param minimal_model_hypers: Hyperparameters to initialize the model.
            These should give the smallest possible model to use as little
            disk space as possible when saving checkpoints.
        :param default_hypers: Default hyperparameters to initialize the trainer.
        :return: A tuple containing the trained model and the trainer.
        """
        # Load dataset
        dataset, _targets_info, dataset_info = self.get_dataset(
            dataset_targets, dataset_path
        )

        # Initialize model
        model = self.model_cls(minimal_model_hypers, dataset_info)

        # Set the training hyperparameters:
        # - Just 1 epoch to keep the test fast
        # - Default loss for each target
        hypers = copy.deepcopy(default_hypers)
        hypers["training"]["num_epochs"] = 1
        loss_hypers = OmegaConf.create(
            {k: init_with_defaults(LossSpecification) for k in dataset_targets}
        )
        loss_hypers = OmegaConf.to_container(loss_hypers, resolve=True)
        hypers["training"]["loss"] = loss_hypers

        # Initialize trainer
        trainer = self.trainer_cls(hypers["training"])

        # Train the model.
        trainer.train(
            model,
            dtype=model.__supported_dtypes__[0],
            devices=[torch.device("cpu")],
            train_datasets=[dataset],
            val_datasets=[dataset],
            checkpoint_dir="",
        )

        return model, trainer

    @pytest.mark.parametrize("context", ["restart", "finetune", "export"])
    def test_loading_old_checkpoints(
        self,
        default_hypers: dict,
        model_trainer: tuple[ModelInterface, TrainerInterface],
        context: Literal["restart", "finetune", "export"],
    ) -> None:
        """Tests that checkpoints from previous versions can be loaded.

        This test goes through all the checkpoint files in the
        ``checkpoints/`` folder of the current directory (presumably the
        architecture's tests folder) and tries to load them in the current
        model and trainer.

        The test skips trainer checkpoints that are listed in this class's
        ``incompatible_trainer_checkpoints`` attribute when the context is
        ``restart``.

        :param default_hypers: Default hyperparameters to initialize the trainer.
        :param model_trainer: Model and trainer to be used for loading the checkpoints.
        :param context: The context in which to load the checkpoint.
        """
        model, trainer = model_trainer

        for path in glob.glob("checkpoints/*.ckpt.gz"):
            if path in self.incompatible_trainer_checkpoints and context == "restart":
                continue

            with gzip.open(path, "rb") as fd:
                checkpoint = torch.load(fd, weights_only=False)

            # Upgrade the model part of the checkpoint if it was saved by an
            # older model version, then load it in the requested context.
            if checkpoint["model_ckpt_version"] != model.__checkpoint_version__:
                checkpoint = model.__class__.upgrade_checkpoint(checkpoint)
            model.load_checkpoint(checkpoint, context)

            # Only a restart also needs the trainer state from the checkpoint.
            if context == "restart":
                if checkpoint["trainer_ckpt_version"] != trainer.__checkpoint_version__:
                    checkpoint = trainer.__class__.upgrade_checkpoint(checkpoint)
                trainer.load_checkpoint(checkpoint, default_hypers, context)

    def test_checkpoint_did_not_change(
        self,
        monkeypatch: Any,
        tmp_path: str,
        model_trainer: tuple[ModelInterface, TrainerInterface],
    ) -> None:
        """
        Test that the checkpoint did not change.

        This test gets the current version of the model and trainer,
        and loads the checkpoint for that version from the ``checkpoints/``
        folder. If that checkpoint is not compatible with the current code,
        this means that the checkpoint version of either the model or the
        trainer needs to be bumped.

        :param monkeypatch: The pytest monkeypatch fixture.
        :param tmp_path: The pytest tmp_path fixture.
        :param model_trainer: Model and trainer to test.
        """
        model, trainer = model_trainer

        # Save a fresh checkpoint in a temporary directory and load it back.
        cwd = os.getcwd()
        monkeypatch.chdir(tmp_path)
        trainer.save_checkpoint(model, "checkpoint.ckpt")
        checkpoint = torch.load("checkpoint.ckpt", weights_only=False)
        monkeypatch.chdir(cwd)

        model_version = model.__checkpoint_version__
        trainer_version = trainer.__checkpoint_version__
        ckpt_name = f"model-v{model_version}_trainer-v{trainer_version}.ckpt.gz"
        ckpt_path = f"checkpoints/{ckpt_name}"

        if not os.path.exists(ckpt_path):
            # No reference checkpoint for the current versions: create one in
            # the current directory so the developer can move it into place.
            with gzip.open(ckpt_name, "wb") as output:
                with open(os.path.join(tmp_path, "checkpoint.ckpt"), "rb") as source:
                    output.write(source.read())

            raise ValueError(
                f"missing reference checkpoint for model version {model_version} and "
                f"trainer version {trainer_version}, we created one for you with the "
                f"current state of the code. Please move it to {ckpt_path} if you "
                "have no other changes to do."
            )
        else:
            with gzip.open(ckpt_path, "rb") as fd:
                reference = torch.load(fd, weights_only=False)

            try:
                check_same_checkpoint_structure(checkpoint, reference)
            except KeyError as e:
                raise ValueError(
                    "checkpoint structure changed. Please increase the checkpoint "
                    "version and implement checkpoint update"
                ) from e

    @pytest.mark.parametrize("context", ["finetune", "restart", "export"])
    def test_get_checkpoint(
        self,
        context: Literal["finetune", "restart", "export"],
        caplog: Any,
        model_trainer: tuple[ModelInterface, TrainerInterface],
    ) -> None:
        """
        Test that the checkpoint created by the ``model.get_checkpoint()``
        function can be loaded back in all possible contexts.

        This test can fail either if the model is unable to produce
        checkpoints, or if the generated checkpoint can't be loaded back
        by the model in the specified context.

        :param context: The context in which to load the generated checkpoint.
        :param caplog: The pytest caplog fixture.
        :param model_trainer: Model and trainer to be used for the test.
        """
        model, _ = model_trainer
        checkpoint = model.get_checkpoint()

        caplog.set_level(logging.INFO)
        self.model_cls.load_checkpoint(checkpoint, context)

        # A restart resumes from the latest model; other contexts pick the
        # best model. The epoch is None because the checkpoint was generated
        # outside of a training loop.
        if context == "restart":
            assert "Using latest model from epoch None" in caplog.text
        else:
            assert "Using best model from epoch None" in caplog.text

    @pytest.mark.parametrize("cls_type", ["model", "trainer"])
    def test_failed_checkpoint_upgrade(
        self, cls_type: Literal["model", "trainer"]
    ) -> None:
        """Test error raised when trying to upgrade an invalid checkpoint version.

        This test creates a checkpoint with an invalid version number and
        tries to upgrade it using the corresponding class.

        If this test fails, it likely means that you are not raising the
        error in your model/trainer's ``upgrade_checkpoint`` method when
        the checkpoint version is not recognized. To raise the appropriate
        error:

        .. code-block:: python

            cls_type = "model"  # or "trainer"
            raise RuntimeError(
                f"Unable to upgrade the checkpoint: the checkpoint is using {cls_type} "
                f"version {checkpoint_version}, while the current {cls_type} version "
                f"is {self.__class__.__checkpoint_version__}."
            )

        :param cls_type: The class type to test.
        """
        invalid_version = 99999999999999
        checkpoint = {f"{cls_type}_ckpt_version": invalid_version}

        cls = self.model_cls if cls_type == "model" else self.trainer_cls
        version = cls.__checkpoint_version__

        # NOTE: pytest.raises treats ``match`` as a regex searched against the
        # message; the literal dots here match themselves, so this is safe.
        match = (
            f"Unable to upgrade the checkpoint: the checkpoint is using {cls_type} "
            f"version {invalid_version}, while the current {cls_type} version is "
            f"{version}."
        )
        with pytest.raises(RuntimeError, match=match):
            cls.upgrade_checkpoint(checkpoint)