Source code for metatrain.utils.pydantic
import logging
from typing import Any
from pydantic import BaseModel, TypeAdapter, ValidationError, create_model
from ..share.base_hypers import BaseHypers
[docs]
def validate(model_cls: Any, data: dict, **kwargs: Any) -> None:
    r"""Validate ``data`` with pydantic, raising custom metatrain errors.

    :param model_cls: The Pydantic model class to use for validation.
        If it is not a pydantic model (for example a ``TypedDict``, or a
        non-class annotation such as ``list[int]`` or a ``Union``), it will be
        adapted to pydantic using ``pydantic.TypeAdapter``.
    :param data: The data to validate.
    :param \*\*kwargs: Additional keyword arguments to pass to the validation method.
    :raises MetatrainValidationError: If validation fails.
    """
    try:
        is_pydantic_model = issubclass(model_cls, BaseModel)
    except TypeError:
        # ``issubclass`` rejects non-class arguments such as ``list[int]`` or
        # ``Union[...]``; those are handled by ``TypeAdapter`` below.
        is_pydantic_model = False

    if is_pydantic_model:
        validate_fn = model_cls.model_validate
    else:
        validate_fn = TypeAdapter(model_cls).validate_python

    try:
        validate_fn(data, **kwargs)
    except ValidationError as e:
        # Re-raise as the metatrain-specific error, keeping the original
        # pydantic error as the cause for debugging.
        raise MetatrainValidationError(model_cls, e.errors()) from e
[docs]
def validate_architecture_options(
    options: dict, model_hypers: type, trainer_hypers: type
) -> None:
    """Validate architecture-specific options using Pydantic.

    The options must contain ``name``, ``model`` and ``training`` entries and
    may optionally contain ``atomic_types``; any other key is rejected, and
    validation is strict (no type coercion). If either hypers class is neither
    a pydantic model nor a ``dict`` subclass (such as a ``TypedDict``), a
    warning is logged and validation is skipped.

    The caller's ``options`` dict is never modified.

    :param options: The architecture options to validate.
    :param model_hypers: The ModelHypers class of the architecture.
    :param trainer_hypers: The TrainerHypers class of the architecture.
    :raises MetatrainValidationError: If the options are invalid.
    """

    def _is_validatable(cls: Any) -> bool:
        # Pydantic models and dict subclasses (which includes TypedDict
        # classes) can be used as field types of the model built below.
        try:
            return issubclass(cls, (BaseModel, dict))
        except TypeError:
            # ``issubclass`` raises for non-class arguments (``None``, generic
            # aliases such as ``dict[str, Any]``, ...): treat as unsupported.
            return False

    if not _is_validatable(model_hypers) or not _is_validatable(trainer_hypers):
        logging.warning(
            "Architecture does not provide validation of hyperparameters. "
            "Continuing without validation."
        )
        return

    # ``(type, ...)`` declares a required field; this tuple form is accepted by
    # ``create_model`` in every pydantic v2 release.
    ArchitectureOptions = create_model(
        "ArchitectureOptions",
        name=(str, ...),
        atomic_types=(list[int], ...),
        model=(model_hypers, ...),
        training=(trainer_hypers, ...),
        __config__={"extra": "forbid", "strict": True},
    )

    # ``atomic_types`` is optional in the options, but passing
    # ``NotRequired[list[int]]`` as a pydantic model field is not possible, and
    # a TypedDict cannot be built with the runtime variables ``model_hypers`` /
    # ``trainer_hypers`` as annotations. We therefore validate a shallow copy
    # that carries a dummy ``atomic_types`` entry when the key is missing (an
    # existing value takes precedence). Working on a copy guarantees the
    # caller's dict is left untouched even when validation fails.
    validate(ArchitectureOptions, {"atomic_types": [], **options})
[docs]
def validate_base_options(options: dict) -> None:
    """Validate base options using Pydantic.

    The options are checked against :class:`BaseHypers` via :func:`validate`.

    :param options: The base options to validate.
    :raises MetatrainValidationError: If the options are invalid.
    """
    validate(BaseHypers, options)