Source code for metatrain.utils.hypers
import copy
from typing import Type, TypedDict, TypeVar

from typing_extensions import TypedDict as TE_TypedDict
# Generic placeholder for the hypers class passed as ``hypers_cls`` to
# ``get_hypers_list`` and ``init_with_defaults``.
HypersType = TypeVar("HypersType")
[docs]
def get_hypers_list(hypers_cls: Type[HypersType]) -> list[str]:
    """Get the list of hyperparameter names defined in a TypedDict hypers class.

    Inheritance of parameters is allowed from parent classes, but make
    sure that the parent classes only contain hyperparameters as
    attributes! (i.e., no methods allowed). Private attributes (starting
    with "_") are not considered as hyperparameters, so one can have
    arbitrary private methods or attributes in the class and its parents,
    although this is not recommended.

    Only names that are assigned a value in a class body are found: an
    annotation without a default (``x: int``) creates no class attribute
    and is therefore not listed.

    Names inherited from the parent classes come first (parents visited in
    declaration order), followed by the names introduced by ``hypers_cls``
    itself. Every name appears exactly once, even when a class redefines a
    parent's hyperparameter or the same ancestor is reached through
    several parents.

    :param hypers_cls: The class defining the hyperparameters.
    :return: A list with the names of the hyperparameters.
    """
    # Direct parents only: the recursive calls handle the grandparents, so
    # walking the (already transitive) MRO here would revisit ancestors.
    # TypedDict subclasses have just ``dict`` as runtime base, so their real
    # parents come from ``__orig_bases__``; it is read from the class's own
    # namespace because ``getattr`` could return one inherited from an
    # ancestor. Non-class entries (e.g. ``Generic[T]`` aliases) and the
    # TypedDict/builtin roots carry no hyperparameters and are skipped.
    parents = dict.fromkeys(
        base
        for base in (
            *hypers_cls.__bases__,
            *vars(hypers_cls).get("__orig_bases__", ()),
        )
        if isinstance(base, type)
        and base not in (TE_TypedDict, TypedDict, dict, object)
    )

    # A dict serves as an insertion-ordered set: the first occurrence of a
    # name fixes its position and later duplicates are dropped.
    hypers: dict[str, None] = {}
    for base in parents:
        hypers.update(dict.fromkeys(get_hypers_list(base)))

    # Now add the hypers defined in this class, skipping private names.
    hypers.update(
        dict.fromkeys(key for key in vars(hypers_cls) if not key.startswith("_"))
    )
    return list(hypers)
[docs]
def init_with_defaults(hypers_cls: Type[HypersType]) -> dict:
    """Initialize a TypedDict hypers class with its default values.

    Inheritance of parameters is allowed from parent classes, but make
    sure that the parent classes only contain hyperparameters as
    attributes! (i.e., no methods allowed). Private attributes (starting
    with "_") are not considered as hyperparameters, so one can have
    arbitrary private methods or attributes in the class and its parents,
    although this is not recommended.

    Precedence, from lowest to highest:

    1. values inherited from the parent classes (when two direct parents
       define the same name, the one listed later wins);
    2. values assigned in the body of ``hypers_cls``;
    3. the overwrites registered for ``hypers_cls`` with
       :func:`overwrite_defaults`, applied only to names that are
       hyperparameters.

    Every value in the returned dict is a deep copy, so mutating it (e.g.
    appending to a list default) does not change the class attributes or
    the registered overwrites seen by later calls.

    :param hypers_cls: The class defining the hyperparameters.
    :return: A dict with the default hyperparameters.
    """
    defaults_dict = {}

    # Direct parents only: each recursive call already resolves that
    # parent's whole ancestry (including the parent's own registered
    # overwrites). Recursing over the full MRO instead let a grandparent's
    # value be merged *after* a parent's override and clobber it. TypedDict
    # subclasses have just ``dict`` as runtime base, so their real parents
    # come from the class's own ``__orig_bases__``. Non-class entries and
    # the TypedDict/builtin roots carry no hyperparameters and are skipped.
    parents = dict.fromkeys(
        base
        for base in (
            *hypers_cls.__bases__,
            *vars(hypers_cls).get("__orig_bases__", ()),
        )
        if isinstance(base, type)
        and base not in (TE_TypedDict, TypedDict, dict, object)
    )
    for base in parents:
        defaults_dict.update(init_with_defaults(base))

    # Values assigned in this class body override inherited ones. They are
    # deep-copied so callers cannot mutate the class attribute through the
    # returned dict.
    for key, value in vars(hypers_cls).items():
        if not key.startswith("_"):
            defaults_dict[key] = copy.deepcopy(value)

    # Finally apply the overwrites registered for exactly this class; keys
    # that are not hyperparameters of the class are ignored.
    for key, value in _OVERWRITTEN_DEFAULTS.get(hypers_cls, {}).items():
        if key in defaults_dict:
            defaults_dict[key] = copy.deepcopy(value)
    return defaults_dict
# Private global registry of overwritten defaults: maps a hypers class to
# the dict of default values registered for it via ``overwrite_defaults``.
_OVERWRITTEN_DEFAULTS: dict[type, dict] = {}
[docs]
def overwrite_defaults(
    hypers_cls: Type,
    new_defaults: dict,
) -> None:
    """Overwrite the default hyperparameters.

    This function does not check that the new defaults correspond
    to valid hyperparameters of the given hypers class. If the new
    defaults contain keys that are not hyperparameters of the class,
    they will simply be ignored.

    Calling this function again for the same class replaces (does not
    merge with) the earlier registration. A copy of ``new_defaults`` is
    stored, so adding, removing or reassigning keys in the caller's dict
    afterwards does not affect the registered defaults.

    :param hypers_cls: The hypers class whose defaults to overwrite.
    :param new_defaults: A dict with the new default hyperparameters.
    """
    # Snapshot the mapping: keeping a reference to the caller's object would
    # let any later in-place edit silently change the defaults of the class.
    _OVERWRITTEN_DEFAULTS[hypers_cls] = dict(new_defaults)