# mypy: disable-error-code=misc
# We ignore misc errors in this file because TypedDict
# with default values is not allowed by mypy.
import torch
from metatomic.torch import System
from typing_extensions import TypedDict
from metatrain.utils.neighbor_lists import NeighborListOptions
[docs]
class LongRangeHypers(TypedDict):
    """Hyperparameters of the optional long-range (Coulomb) featurizer.

    In some systems and datasets, enabling long-range Coulomb interactions
    might be beneficial for the accuracy of the model and/or
    its physical correctness.

    The values assigned below are class-level defaults; assigning defaults in
    a ``TypedDict`` body is rejected by mypy, which is why ``misc`` errors are
    disabled for this file.
    """

    # Not read by ``LongRangeFeaturizer`` in this file; presumably checked by
    # the caller that decides whether to build the featurizer -- verify there.
    enable: bool = False
    """Toggle for enabling long-range interactions"""
    # ``LongRangeFeaturizer.forward`` only selects the Ewald calculator when
    # this flag is set AND the module is in training mode; otherwise P3M is
    # used for periodic systems.
    use_ewald: bool = False
    """Use Ewald summation. If False, P3M is used"""
    # Passed as ``smearing`` to the ``CoulombPotential`` of both the Ewald and
    # the P3M calculators.
    smearing: float = 1.4
    """Smearing width in Fourier space"""
    # Passed as ``lr_wavelength`` to the Ewald calculator and as
    # ``mesh_spacing`` to the P3M calculator.
    kspace_resolution: float = 1.33
    """Resolution of the reciprocal space grid"""
    # NOTE(review): in this file the value is forwarded to ``P3MCalculator``
    # as ``interpolation_nodes``; the "PME" wording below may be stale --
    # confirm against the torch-pme documentation.
    interpolation_nodes: int = 5
    """Number of grid points for interpolation (for PME only)"""
[docs]
class LongRangeFeaturizer(torch.nn.Module):
    """A class to compute long-range features starting from short-range features.

    The short-range features are mapped to per-atom "charges" by a linear
    layer, an electrostatic potential is evaluated separately for every system
    with a ``torchpme`` calculator, and the potential is passed through a small
    MLP to produce the long-range features.

    :param hypers: Dictionary containing the hyperparameters for the long-range
        featurizer.
    :param feature_dim: The dimension of the short-range features (which also
        corresponds to the number of long-range features that will be returned).
    :param neighbor_list_options: A :py:class:`NeighborListOptions` object containing
        the neighbor list information for the short-range model.
    :raises ImportError: If ``torchpme`` cannot be imported.
    """

    def __init__(
        self,
        hypers: LongRangeHypers,
        feature_dim: int,
        neighbor_list_options: NeighborListOptions,
    ) -> None:
        super().__init__()

        # ``torchpme`` is an optional dependency: import it lazily and turn a
        # failure into an actionable message, keeping the original error as
        # the explicit cause.
        try:
            from torchpme import (
                Calculator,
                CoulombPotential,
                EwaldCalculator,
                P3MCalculator,
            )
        except ImportError as err:
            raise ImportError(
                "`torch-pme` is required for long-range models. "
                "Please install it with `pip install 'torch-pme>=0.3.2'`."
            ) from err

        # Values shared by several calculators are read once. Every calculator
        # still receives its own ``CoulombPotential`` instance.
        smearing = float(hypers["smearing"])
        kspace_resolution = float(hypers["kspace_resolution"])
        cutoff = neighbor_list_options.cutoff
        full_list = neighbor_list_options.full_list

        self.ewald_calculator = EwaldCalculator(
            potential=CoulombPotential(
                smearing=smearing,
                exclusion_radius=cutoff,
            ),
            full_neighbor_list=full_list,
            lr_wavelength=kspace_resolution,
        )
        """Calculator to compute the long-range electrostatic potential using the Ewald
        summation method."""

        self.p3m_calculator = P3MCalculator(
            potential=CoulombPotential(
                smearing=smearing,
                exclusion_radius=cutoff,
            ),
            interpolation_nodes=hypers["interpolation_nodes"],
            full_neighbor_list=full_list,
            mesh_spacing=kspace_resolution,
        )
        """Calculator to compute the long-range electrostatic potential using the P3M
        method."""

        self.use_ewald = hypers["use_ewald"]
        """If ``True``, use the Ewald summation method instead of the P3M method for
        periodic systems during training."""

        self.direct_calculator = Calculator(
            potential=CoulombPotential(
                smearing=None,
                exclusion_radius=cutoff,
            ),
            # ``forward`` builds the pairs for non-periodic systems with
            # ``torch.combinations``, which lists every unordered pair once,
            # i.e. a half neighbor list.
            full_neighbor_list=False,
        )
        """Calculator for the electrostatic potential in non-periodic systems."""

        self.neighbor_list_options = neighbor_list_options
        """Neighbor list information for the short-range model."""

        self.charges_map = torch.nn.Linear(feature_dim, feature_dim)
        """Map the short-range features to atomic charges."""

        self.out_projection = torch.nn.Sequential(
            torch.nn.Linear(feature_dim, feature_dim),
            torch.nn.SiLU(),
            torch.nn.Linear(feature_dim, feature_dim),
        )
        """Map the electrostatic potential to the returned long-range features."""

    def forward(
        self,
        systems: list[System],
        features: torch.Tensor,
        neighbor_distances: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the long-range features for a list of systems.

        ``features`` and ``neighbor_distances`` are consumed as consecutive
        slices, one per system and in the order of ``systems``: ``len(system)``
        rows of ``features`` and as many entries of ``neighbor_distances`` as
        the system's neighbor list has samples.

        :param systems: A list of :py:class:`System` objects for which to compute the
            long-range features. Each system must contain a neighbor list consistent
            with the neighbor list options used to create the class.
        :param features: A tensor of short-range features for the systems.
        :param neighbor_distances: A tensor of neighbor distances for the systems,
            which must be consistent with the neighbor list options used to create the
            class.
        :return: A tensor of long-range features for the systems.
        :raises NotImplementedError: If a system is periodic along exactly one
            direction.
        """
        charges = self.charges_map(features)

        # Running offsets into the concatenated per-atom and per-pair tensors.
        atom_start = 0
        edge_start = 0
        long_range_features = []

        for system in systems:
            n_atoms = len(system)
            system_charges = charges[atom_start : atom_start + n_atoms]
            atom_start += n_atoms

            # The neighbor list is looked up for every system, periodic or
            # not, so that the offset into ``neighbor_distances`` stays in
            # sync with the order of ``systems``.
            neighbor_list = system.get_neighbor_list(self.neighbor_list_options)
            neighbor_indices_system = neighbor_list.samples.view(
                ["first_atom", "second_atom"]
            ).values
            n_edges = len(neighbor_indices_system)
            neighbor_distances_system = neighbor_distances[
                edge_start : edge_start + n_edges
            ]
            edge_start += n_edges

            if system.pbc.any():
                if system.pbc.sum() == 1:
                    raise NotImplementedError(
                        "Long-range featurizer does not support 1D systems."
                    )

                if self.use_ewald and self.training:  # use Ewald for training only
                    potential = self.ewald_calculator.forward(
                        charges=system_charges,
                        cell=system.cell,
                        positions=system.positions,
                        neighbor_indices=neighbor_indices_system,
                        neighbor_distances=neighbor_distances_system,
                        periodic=system.pbc,
                    )
                else:
                    potential = self.p3m_calculator.forward(
                        charges=system_charges,
                        cell=system.cell,
                        positions=system.positions,
                        neighbor_indices=neighbor_indices_system,
                        neighbor_distances=neighbor_distances_system,
                        periodic=system.pbc,
                    )
            else:  # non-periodic
                # The short-range neighbor list is not used here: the
                # potential is evaluated over all pairs of atoms, each
                # unordered pair appearing once.
                all_pairs = torch.combinations(
                    torch.arange(n_atoms, device=system.positions.device), 2
                )
                displacements = (
                    system.positions[all_pairs[:, 1]]
                    - system.positions[all_pairs[:, 0]]
                )
                all_pair_distances = torch.sqrt(torch.sum(displacements**2, dim=1))

                potential = self.direct_calculator.forward(
                    charges=system_charges,
                    cell=system.cell,
                    positions=system.positions,
                    neighbor_indices=all_pairs,
                    neighbor_distances=all_pair_distances,
                )

            long_range_features.append(self.out_projection(potential))

        return torch.concatenate(long_range_features)
[docs]
class DummyLongRangeFeaturizer(torch.nn.Module):
    """Placeholder module used for TorchScript.

    It exposes a ``forward`` with the same signature as
    :py:class:`LongRangeFeaturizer` and a ``use_ewald`` attribute, but performs
    no computation.
    """

    def __init__(self) -> None:
        super(DummyLongRangeFeaturizer, self).__init__()
        # Fixed flag; nothing in this class reads it.
        self.use_ewald = True

    def forward(
        self,
        systems: list[System],
        features: torch.Tensor,
        neighbor_distances: torch.Tensor,
    ) -> torch.Tensor:
        """Ignore every argument and return a scalar zero tensor.

        :param systems: Unused.
        :param features: Unused.
        :param neighbor_distances: Unused.
        :return: ``torch.tensor(0)``.
        """
        placeholder = torch.tensor(0)
        return placeholder