# Source code for src.architectures.generate

import ast
import os
import re
from pathlib import Path
from typing import TypedDict

from jinja2 import Environment, FileSystemLoader

from metatrain.utils.architectures import (
    find_all_architectures,
    get_architecture_path,
    get_hypers_classes,
    preload_documentation_module,
    write_hypers_yaml,
)
from metatrain.utils.hypers import get_hypers_list


# Directory containing this script; every other location below derives from it.
ARCHITECTURES_DIR = Path(__file__).parent
# Jinja templates used to render the generated documentation sections.
TEMPLATES_DIR = ARCHITECTURES_DIR.joinpath("templates")
# Output directory for the per-architecture default-hyperparameter yaml files.
DEFAULT_HYPERS_DIR = ARCHITECTURES_DIR.joinpath("default_hypers")
# Output directory for the generated per-architecture rst files.
GENERATED_DIR = ARCHITECTURES_DIR.joinpath("generated")


# Shared Jinja environment: section templates are loaded from TEMPLATES_DIR, and
# the same environment renders each architecture docstring. ``trim_blocks`` drops
# the newline after a block tag and ``lstrip_blocks`` drops the whitespace before
# it, so template control flow does not leave stray blank lines in the output.
JINJA_ENV = Environment(
    lstrip_blocks=True,
    trim_blocks=True,
    loader=FileSystemLoader(TEMPLATES_DIR),
)


# Auto-generated documentation sections, in the order in which they are appended
# to an architecture page whose docstring does not place them itself. Each entry
# ``name`` is rendered from the template ``<name>.rst`` and exposed to the
# docstring as the variable ``SECTION_<NAME>``.
SECTIONS = (
    "installation default_hypers model_hypers trainer_hypers references"
).split()


class ArchitectureDocVariables(TypedDict):
    """Variables to use inside the architecture documentation.

    The module docstring of the architecture's ``documentation.py`` file is
    processed as a `Jinja <https://jinja.palletsprojects.com/en/stable/templates>`_
    template. Its simplest feature is variable substitution: a variable name
    enclosed in double curly braces, ``{{variable_name}}``, is replaced by the
    corresponding value. For example, the following content:

    .. code-block:: rst

        This is the documentation for {{architecture}}.

    is rendered for the architecture ``pet`` as:

    .. code-block:: rst

        This is the documentation for pet.

    Some special variables start with ``SECTION_``. They contain the content of
    the different sections of the documentation, and each of them is appended to
    the docstring if it is not already present in it. For example, given the
    docstring:

    .. code-block:: python

        \"""
        My architecture
        ===============

        This is my architecture.

        {{SECTION_DEFAULT_HYPERS}}

        Some important section
        ======================

        Explain something important here.
        \"""

    the final documentation has all the sections except ``SECTION_DEFAULT_HYPERS``
    appended to it, since that one is already present.

    All the available variables are described below. The sections are appended
    in the order in which they are documented here.
    """

    SECTION_INSTALLATION: str
    """Section containing installation instructions for this architecture."""

    SECTION_DEFAULT_HYPERS: str
    """Section containing a yaml file with the default hyperparameters for this
    architecture."""

    SECTION_MODEL_HYPERS: str
    """Section containing the description of the model hyperparameters for this
    architecture."""

    SECTION_TRAINER_HYPERS: str
    """Section containing the description of the trainer hyperparameters for this
    architecture."""

    SECTION_REFERENCES: str
    """Section containing references for this architecture.

    It renders the references cited with ``:footcite:p:`` in the architecture
    documentation."""

    architecture: str
    """The name of the architecture, without any 'experimental.' or 'deprecated.'
    prefix."""

    architecture_path: str
    """The full python import path to the architecture.

    E.g.: ``"metatrain.experimental.my_architecture"``
    """

    default_hypers_path: str
    """Path to the yaml file with the default hyperparameters for this architecture.

    This path is relative to the ``docs/src/architectures/generated`` directory.
    """

    model_hypers_path: str
    """The full python import path to the model's hypers class of this
    architecture.

    E.g.: ``"metatrain.pet.documentation.ModelHypers"``
    """

    trainer_hypers_path: str
    """The full python import path to the trainer's hypers class of this
    architecture.

    E.g.: ``"metatrain.pet.documentation.TrainerHypers"``
    """

    model_hypers: list[str]
    """List of model hyperparameter names for this architecture."""

    trainer_hypers: list[str]
    """List of trainer hyperparameter names for this architecture."""
def setup_architectures_docs():
    """Generate the architecture documentation files.

    This function goes through all available architectures and, for each of them,
    generates a yaml file with the default hyperparameters (so that it can be easily
    included in the documentation) and its rst documentation file.

    See :ref:`newarchitecture-documentation-page` for more information.
    """
    # Create the output directories if they do not exist yet.
    DEFAULT_HYPERS_DIR.mkdir(exist_ok=True)
    GENERATED_DIR.mkdir(exist_ok=True)

    for architecture_name in find_all_architectures():
        # Load documentation module in an isolated way to avoid
        # requiring dependencies for every architecture.
        preload_documentation_module(architecture_name)

        architecture_real_name = _strip_architecture_prefix(architecture_name)

        # Write default hypers file
        yaml_path = DEFAULT_HYPERS_DIR / f"{architecture_real_name}-default-hypers.yaml"
        write_hypers_yaml(architecture_name, yaml_path)

        generate_rst(architecture_name, yaml_path=yaml_path)


def _strip_architecture_prefix(architecture_name: str) -> str:
    """Remove the ``experimental.`` and ``deprecated.`` markers from a name."""
    return architecture_name.replace("experimental.", "").replace("deprecated.", "")


def _read_module_docstring(architecture_name: str) -> str:
    """Return the module docstring of the architecture's ``documentation.py``.

    The file is parsed with :mod:`ast` rather than imported. It is read as bytes
    so that ``ast.parse`` decodes it with the Python source encoding (UTF-8 unless
    the file declares otherwise) instead of the platform's locale encoding.

    :param architecture_name: The name of the architecture.
    :returns: The cleaned module docstring.
    :raises ValueError: If the file has no module docstring.
    """
    doc_file = get_architecture_path(architecture_name) / "documentation.py"
    with open(doc_file, "rb") as f:
        module = ast.parse(f.read(), filename=str(doc_file))

    docstring = ast.get_docstring(module)
    if docstring is None:
        raise ValueError(
            f"The documentation.py file for architecture "
            f"'{architecture_name}' does not have a module docstring."
        )
    return docstring


def _has_section_placeholder(docstring: str, section: str) -> bool:
    """Check whether ``docstring`` already uses the ``SECTION_<SECTION>`` variable.

    Spaces, Jinja whitespace-control dashes and filters inside the braces are
    accepted, so ``{{ SECTION_REFERENCES }}`` counts as present just like
    ``{{SECTION_REFERENCES}}``. The word boundary after the name keeps e.g.
    ``SECTION_MODEL_HYPERS_EXTRA`` from matching ``SECTION_MODEL_HYPERS``.
    """
    pattern = r"\{\{-?\s*SECTION_" + re.escape(section.upper()) + r"\b[^}]*\}\}"
    return re.search(pattern, docstring) is not None


def generate_rst(
    architecture_name: str,
    yaml_path: Path,
):
    """Generate the rst documentation file for a given architecture.

    The module docstring of the architecture's ``documentation.py`` is rendered as a
    Jinja template with the variables described in :class:`ArchitectureDocVariables`,
    after appending every ``SECTION_*`` placeholder it does not already contain.
    The result is written to ``GENERATED_DIR / "<architecture>.rst"``.

    :param architecture_name: The name of the architecture to generate the
        documentation for.
    :param yaml_path: Path to the yaml file with the default hyperparameters for this
        architecture.
    :raises ValueError: If ``documentation.py`` has no module docstring.
    """
    # Name of the architecture without any prefix, and its full import path.
    architecture_real_name = _strip_architecture_prefix(architecture_name)
    arch_path = f"metatrain.{architecture_name}"

    docstring = _read_module_docstring(architecture_name)
    hypers_classes = get_hypers_classes(architecture_name)

    # The yaml file is included from the generated rst file, so its path must be
    # relative to GENERATED_DIR. POSIX separators keep the rst output identical on
    # every platform.
    default_hypers_path = Path(os.path.relpath(yaml_path, GENERATED_DIR)).as_posix()

    template_variables = dict(
        architecture=architecture_real_name,
        architecture_path=arch_path,
        default_hypers_path=default_hypers_path,
        model_hypers_path=f"{arch_path}.documentation.ModelHypers",
        trainer_hypers_path=f"{arch_path}.documentation.TrainerHypers",
        model_hypers=get_hypers_list(hypers_classes["model"]),
        trainer_hypers=get_hypers_list(hypers_classes["trainer"]),
    )

    # Render every section template. Each rendered section is added to the same
    # dictionary, so later sections can use the content of earlier ones.
    for section in SECTIONS:
        template = JINJA_ENV.get_template(f"{section}.rst")
        template_variables[f"SECTION_{section.upper()}"] = template.render(
            **template_variables
        )

    # Prepend an rst label so that other pages can reference this architecture.
    docstring = f".. _architecture-{architecture_real_name}:\n\n" + docstring

    # Append the placeholders of the sections the docstring does not place itself.
    for section in SECTIONS:
        if not _has_section_placeholder(docstring, section):
            docstring += "\n\n{{SECTION_" + section.upper() + "}}"

    # Render the docstring itself as a Jinja template and write the result.
    rendered = JINJA_ENV.from_string(docstring).render(**template_variables)
    output_path = GENERATED_DIR / f"{architecture_real_name}.rst"
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(rendered + "\n")