"""
Complete featurization pipeline using Chebyshev descriptors.
This module implements the AUC (Artrith-Urban-Ceder) descriptor with the
typespin architecture from aenet's Fortran code.
"""
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from torch_scatter import scatter_add
from .chebyshev import AngularBasis, RadialBasis
from ..torch_nblist import TorchNeighborList
from .graph import center_ids_of_edge as _center_ids_of_edge
[docs]
class ChebyshevDescriptor(nn.Module):
"""
Complete featurization pipeline using Chebyshev descriptors.
Implements the typespin architecture from aenet's Fortran code:
- Single radial and angular basis functions
- Two feature sets: unweighted + typespin-weighted
- Typespin coefficients centered around zero
This exactly matches the behavior of aenet's Fortran implementation.
Notes
-----
- Under PBC, neighbor displacements are reconstructed using wrapped
fractional positions and integer image offsets provided by the ghost
PBC backend of TorchNeighborList. Specifically:
positions_frac = remainder(positions @ inv(cell), 1.0)
r_ij = ((positions_frac[j] + offsets) @ cell)-(positions_frac[i] @ cell)
This aligns featurization with the neighbor list semantics and ensures
PBC vs explicit supercell consistency.
- The neighbor cutoff used internally for neighbor list construction is
max(rad_cutoff, ang_cutoff) to ensure complete coverage for both
basis sets.
"""
[docs]
def __init__(
    self,
    species: List[str],
    rad_order: int,
    rad_cutoff: float,
    ang_order: int,
    ang_cutoff: float,
    min_cutoff: float = 0.55,
    device: str = "cpu",
    dtype: torch.dtype = torch.float64,
    profile_timing: bool = False,
):
    """
    Build the descriptor: store hyper-parameters, derive the typespin
    table, and create the neighbor list plus the two shared basis modules.

    Args:
        species: Ordered list of atomic species (e.g., ['O', 'H']); the
            position of a symbol in this list becomes its species index
        rad_order: Maximum radial Chebyshev order
        rad_cutoff: Radial cutoff radius (Angstroms)
        ang_order: Maximum angular Chebyshev order
        ang_cutoff: Angular cutoff radius (Angstroms)
        min_cutoff: Minimum distance cutoff (Angstroms); forwarded to
            both basis modules
        device: 'cpu' or 'cuda'
        dtype: Floating-point type (torch.float64 for double precision)
        profile_timing: If True, neighbor-list and feature timings are
            accumulated separately on this instance
    """
    super().__init__()

    # Hyper-parameters are kept verbatim on the instance.
    self.species = species
    self.n_species = len(species)
    self.rad_order, self.rad_cutoff = rad_order, rad_cutoff
    self.ang_order, self.ang_cutoff = ang_order, ang_cutoff
    self.min_cutoff = min_cutoff
    self.device = device
    self.dtype = dtype

    # Profiling switch and its accumulators.
    self.profile_timing = profile_timing
    self.nblist_time = 0.0
    self.feature_time = 0.0
    self.nblist_calls = 0

    # More than one species switches on the typespin-weighted feature set.
    self.multi = self.n_species > 1

    # Symbol -> integer index, following the order of `species`.
    self.species_to_idx = dict(zip(species, range(len(species))))

    # Zero-centered per-species weights.
    self.typespin = self._compute_typespin()

    # One neighbor list covering the larger of the two cutoffs, so that
    # both basis sets see every neighbor they need.
    self.nbl = TorchNeighborList(
        cutoff=max(rad_cutoff, ang_cutoff), device=device, dtype=dtype
    )

    # A single radial basis shared by all species pairs.
    self.rad_basis = RadialBasis(
        rad_order=rad_order,
        rad_cutoff=rad_cutoff,
        min_cutoff=min_cutoff,
        dtype=dtype,
    )
    self.rad_basis.to(self.device)

    # A single angular basis shared by all species triplets.
    self.ang_basis = AngularBasis(
        ang_order=ang_order,
        ang_cutoff=ang_cutoff,
        min_cutoff=min_cutoff,
        dtype=dtype,
    )
    self.ang_basis.to(self.device)

    # Length of the per-atom feature vector.
    self.n_features = self._calculate_n_features()
def _compute_typespin(self) -> torch.Tensor:
"""
Compute typespin coefficients matching Fortran implementation.
For even number of species, zero is skipped:
- 2 species: {-1, 1}
- 4 species: {-2, -1, 1, 2}
For odd number of species, zero is included:
- 3 species: {-1, 0, 1}
- 5 species: {-2, -1, 0, 1, 2}
Returns
-------
typespin: (n_species,) tensor of typespin values
"""
typespin = torch.zeros(
self.n_species, dtype=self.dtype, device=self.device
)
# Use int() to truncate towards zero like Fortran, not floor division
s = int(-self.n_species / 2)
for i in range(self.n_species):
# Skip zero for even number of species
if s == 0 and self.n_species % 2 == 0:
s += 1
typespin[i] = float(s)
s += 1
return typespin
def _calculate_n_features(self) -> int:
"""
Calculate number of features for each atom.
For multi-species systems:
- Radial: 2 * (rad_order + 1)
- Angular: 2 * (ang_order + 1)
For single-species systems:
- Radial: (rad_order + 1)
- Angular: (ang_order + 1)
Returns
-------
Number of features (same for all species)
"""
n_rad = self.rad_order + 1
n_ang = self.ang_order + 1
if self.multi:
# Two sets: unweighted + typespin-weighted
n_features = 2 * (n_rad + n_ang)
else:
# Single set only
n_features = n_rad + n_ang
return n_features
def get_n_features(self) -> int:
"""Get number of features (same for all species)."""
return self.n_features
def get_timing_stats(self) -> dict:
"""
Get timing statistics for neighbor list vs feature computation.
Returns
-------
Dictionary with keys:
- nblist_time: Total time spent in neighbor list construction
- feature_time: Total time spent in feature computation
- nblist_calls: Number of neighbor list calls
- nblist_time_per_call: Average time per neighbor list call
- feature_time_per_call: Average time per feature computation
"""
stats = {
'nblist_time': self.nblist_time,
'feature_time': self.feature_time,
'nblist_calls': self.nblist_calls,
}
if self.nblist_calls > 0:
stats['nblist_time_per_call'
] = self.nblist_time / self.nblist_calls
stats['feature_time_per_call'
] = self.feature_time / self.nblist_calls
else:
stats['nblist_time_per_call'] = 0.0
stats['feature_time_per_call'] = 0.0
return stats
def reset_timing_stats(self):
"""Reset timing statistics to zero."""
self.nblist_time = 0.0
self.feature_time = 0.0
self.nblist_calls = 0
def compute_radial_features(
self,
positions: torch.Tensor,
species_indices: torch.Tensor,
neighbor_indices: List[torch.Tensor],
neighbor_vectors: List[torch.Tensor],
) -> torch.Tensor:
"""
Compute radial features from pre-computed neighbor information.
This is the core implementation that accepts neighbor vectors directly
from the local structural environment (LSE), enabling clean gradient
computation without re-computing neighbors or applying PBC.
For each atom i:
- Set 1 (unweighted): sum over neighbors j of G_rad(d_ij)
- Set 2 (typespin): sum over neighbors j of s_j * G_rad(d_ij)
Args:
positions: (N, 3) atomic positions (for autograd tracking)
species_indices: (N,) species index for each atom
neighbor_indices: List of (nnb_i,) tensors with neighbor
atom indices
neighbor_vectors: List of (nnb_i, 3) tensors with
displacement vectors
Returns
-------
radial_features: (N, n_rad_features) tensor
"""
n_atoms = len(positions)
n_rad = self.rad_order + 1
# Build edge lists from neighbor information
center_indices_list = []
neighbor_indices_list = []
distances_list = []
for i, (nb_idx, nb_vec) in enumerate(zip(neighbor_indices,
neighbor_vectors)):
if len(nb_idx) == 0:
continue
# Compute distances from vectors
distances = torch.norm(nb_vec, dim=-1)
# Filter by radial cutoff
mask = (distances <= self.rad_cutoff
) & (distances > self.min_cutoff)
if mask.any():
n_valid = mask.sum().item()
center_indices_list.append(torch.full(
(n_valid,), i, dtype=torch.long, device=self.device))
neighbor_indices_list.append(nb_idx[mask])
distances_list.append(distances[mask])
if len(distances_list) == 0:
# Return zeros if no neighbors
if self.multi:
return torch.zeros(
n_atoms, 2 * n_rad, dtype=self.dtype, device=self.device
)
else:
return torch.zeros(
n_atoms, n_rad, dtype=self.dtype, device=self.device
)
# Concatenate all edges
center_indices = torch.cat(center_indices_list)
neighbor_indices_flat = torch.cat(neighbor_indices_list)
distances_rad = torch.cat(distances_list)
# Compute radial basis for all pairs
G_rad = self.rad_basis(distances_rad) # (n_pairs, n_rad)
# Unweighted features: scatter_add over neighbors
rad_features_unweighted = scatter_add(
G_rad, center_indices, dim=0, dim_size=n_atoms
)
if not self.multi:
return rad_features_unweighted
# Typespin-weighted features (for multi-species only)
neighbor_species = species_indices[neighbor_indices_flat]
neighbor_typespin = self.typespin[neighbor_species] # (n_pairs,)
# Multiply by typespin
G_rad_weighted = G_rad * neighbor_typespin.unsqueeze(-1)
# Scatter_add weighted features
rad_features_weighted = scatter_add(
G_rad_weighted, center_indices, dim=0, dim_size=n_atoms
)
# Concatenate unweighted and weighted features
rad_features = torch.cat(
[rad_features_unweighted, rad_features_weighted], dim=1
)
return rad_features
def compute_angular_features(
self,
positions: torch.Tensor,
species_indices: torch.Tensor,
neighbor_indices: List[torch.Tensor],
neighbor_vectors: List[torch.Tensor],
) -> torch.Tensor:
"""
Compute angular features from pre-computed neighbor information.
This is the core implementation that accepts neighbor vectors directly
from the local structural environment (LSE), enabling clean gradient
computation without re-computing neighbors or applying PBC.
For each atom i with neighbors j, k:
- Set 1 (unweighted): sum over triplets of G_ang
- Set 2 (typespin): sum over triplets of s_j * s_k * G_ang
Args:
positions: (N, 3) atomic positions (for autograd tracking)
species_indices: (N,) species index for each atom
neighbor_indices: List of (nnb_i,) tensors with neighbor
atom indices
neighbor_vectors: List of (nnb_i, 3) tensors with
displacement vectors
Returns
-------
angular_features: (N, n_ang_features) tensor
"""
n_atoms = len(positions)
n_ang = self.ang_order + 1
# Generate all valid triplets (i, j, k) from neighbor information
triplet_centers = []
triplet_j_global = []
triplet_k_global = []
triplet_j_local = []
triplet_k_local = []
for i, (nb_idx, nb_vec) in enumerate(zip(neighbor_indices,
neighbor_vectors)):
if len(nb_idx) < 2:
continue
# Compute distances from vectors
distances = torch.norm(nb_vec, dim=-1)
# Filter by angular cutoff
mask = (distances <= self.ang_cutoff
) & (distances > self.min_cutoff)
valid_nb_idx = nb_idx[mask]
n_valid = len(valid_nb_idx)
if n_valid < 2:
continue
# Generate all pairs (j, k) for this center
for j_local in range(n_valid):
for k_local in range(j_local + 1, n_valid):
triplet_centers.append(i)
triplet_j_global.append(valid_nb_idx[j_local])
triplet_k_global.append(valid_nb_idx[k_local])
triplet_j_local.append(j_local)
triplet_k_local.append(k_local)
if len(triplet_centers) == 0:
# No valid triplets
if self.multi:
return torch.zeros(
n_atoms, 2 * n_ang, dtype=self.dtype, device=self.device
)
else:
return torch.zeros(
n_atoms, n_ang, dtype=self.dtype, device=self.device
)
# Convert to tensors
triplet_centers = torch.tensor(
triplet_centers, dtype=torch.long, device=self.device
)
triplet_j_global = torch.tensor(
triplet_j_global, dtype=torch.long, device=self.device
)
triplet_k_global = torch.tensor(
triplet_k_global, dtype=torch.long, device=self.device
)
# Collect distances and normalized vectors for all triplets
d_ij_list = []
d_ik_list = []
vec_j_norm_list = []
vec_k_norm_list = []
for idx, center_i in enumerate(triplet_centers):
center_i = center_i.item()
j_local = triplet_j_local[idx]
k_local = triplet_k_local[idx]
nb_vec = neighbor_vectors[center_i]
distances = torch.norm(nb_vec, dim=-1)
# Filter by cutoff
mask = (distances <= self.ang_cutoff
) & (distances > self.min_cutoff)
valid_vec = nb_vec[mask]
valid_dist = distances[mask]
d_ij_list.append(valid_dist[j_local])
d_ik_list.append(valid_dist[k_local])
vec_j_norm_list.append(valid_vec[j_local] / valid_dist[j_local])
vec_k_norm_list.append(valid_vec[k_local] / valid_dist[k_local])
d_ij = torch.stack(d_ij_list)
d_ik = torch.stack(d_ik_list)
vec_j_norm = torch.stack(vec_j_norm_list)
vec_k_norm = torch.stack(vec_k_norm_list)
# Compute cos(theta_ijk) for all triplets
cos_theta = (vec_j_norm * vec_k_norm).sum(dim=1).clamp(-1.0, 1.0)
# Compute angular basis for all triplets
G_ang = self.ang_basis(d_ij, d_ik, cos_theta) # (n_triplets, n_ang)
# Scatter_add unweighted features
ang_features_unweighted = scatter_add(
G_ang, triplet_centers, dim=0, dim_size=n_atoms
)
if not self.multi:
return ang_features_unweighted
# Typespin-weighted features (for multi-species only)
neighbor_j_species = species_indices[triplet_j_global]
neighbor_k_species = species_indices[triplet_k_global]
typespin_j = self.typespin[neighbor_j_species]
typespin_k = self.typespin[neighbor_k_species]
typespin_product = typespin_j * typespin_k
# Multiply by typespin product
G_ang_weighted = G_ang * typespin_product.unsqueeze(-1)
# Scatter_add weighted features
ang_features_weighted = scatter_add(
G_ang_weighted, triplet_centers, dim=0, dim_size=n_atoms
)
# Concatenate unweighted and weighted features
ang_features = torch.cat(
[ang_features_unweighted, ang_features_weighted], dim=1
)
return ang_features
def forward(
self,
positions: torch.Tensor,
species: List[str],
neighbor_indices: Optional[List[torch.Tensor]] = None,
neighbor_vectors: Optional[List[torch.Tensor]] = None,
cell: Optional[torch.Tensor] = None,
pbc: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Featurize atomic structure.
This method supports two calling modes:
1. With pre-computed neighbors (for gradient computation):
forward(positions, species, neighbor_indices, neighbor_vectors)
2. Legacy mode (automatic neighbor computation):
forward(positions, species, cell=cell, pbc=pbc)
Args:
positions: (N, 3) atomic positions
species: List of N species names
neighbor_indices: List of (nnb_i,) tensors with neighbor
atom indices (optional)
neighbor_vectors: List of (nnb_i, 3) tensors with
displacement vectors (optional)
cell: (3, 3) lattice vectors (for legacy mode)
pbc: (3,) periodic boundary conditions (for legacy mode)
Returns
-------
features: (N, n_features) feature matrix
"""
# Legacy mode: compute neighbors automatically
if neighbor_indices is None or neighbor_vectors is None:
return self.forward_from_positions(positions, species, cell, pbc)
# New mode: use pre-computed neighbors for gradient computation
# Convert species to indices
species_indices = torch.tensor(
[self.species_to_idx[s] for s in species],
dtype=torch.long,
device=self.device,
)
# Move inputs to device
positions = positions.to(self.device).to(self.dtype)
# Compute radial features
rad_features = self.compute_radial_features(
positions, species_indices, neighbor_indices, neighbor_vectors
)
# Compute angular features
ang_features = self.compute_angular_features(
positions, species_indices, neighbor_indices, neighbor_vectors
)
# Concatenate in Fortran order: [radial_unweighted, angular_unweighted,
# radial_weighted, angular_weighted]
# (or [radial, angular] for single-species)
if self.multi:
n_rad = self.rad_order + 1
n_ang = self.ang_order + 1
features = torch.cat(
[
rad_features[:, :n_rad], # radial unweighted
ang_features[:, :n_ang], # angular unweighted
rad_features[:, n_rad:], # radial weighted
ang_features[:, n_ang:], # angular weighted
],
dim=1,
)
else:
# Single species: just concatenate radial and angular
features = torch.cat([rad_features, ang_features], dim=1)
return features
def forward_from_positions(
    self,
    positions: torch.Tensor,
    species: List[str],
    cell: Optional[torch.Tensor] = None,
    pbc: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Build per-atom neighbor data with the internal neighbor list, then
    delegate to forward().

    Args:
        positions: (N, 3) atomic positions
        species: List of N species names
        cell: (3, 3) lattice vectors (None for isolated)
        pbc: (3,) periodic boundary conditions

    Returns
    -------
    features: (N, n_features) feature matrix

    Notes
    -----
    When both a cell and neighbor-list offsets are available,
    displacements are rebuilt from wrapped fractional coordinates:
        positions_frac = remainder(positions @ inv(cell), 1.0)
        r_ij = ((positions_frac[j] + offsets) @ cell)
               - (positions_frac[i] @ cell)
    Otherwise r_ij = positions[j] - positions[i].

    With ``profile_timing`` enabled, the time up to and including the
    per-atom grouping is added to ``nblist_time`` and the time spent in
    forward() to ``feature_time``.
    """
    import time

    profiling = self.profile_timing

    positions = positions.to(self.device).to(self.dtype)
    if cell is not None:
        cell = cell.to(self.device).to(self.dtype)

    if profiling:
        t_nb_start = time.perf_counter()

    # Edges within max(rad_cutoff, ang_cutoff), the cutoff the neighbor
    # list was created with.
    nb_data = self.nbl.get_neighbors(
        positions, cell, pbc, fractional=False
    )
    edge_index = nb_data['edge_index']
    edge_dist = nb_data['distances']
    offsets = nb_data['offsets']

    # Drop edges at or below the minimum distance.
    keep = edge_dist > self.min_cutoff
    edge_index = edge_index[:, keep]
    if offsets is not None:
        offsets = offsets[keep]

    i_indices, j_indices = edge_index[0], edge_index[1]

    if cell is None or offsets is None:
        r_ij = positions[j_indices] - positions[i_indices]
    else:
        # Offsets are applied to wrapped fractional coordinates, so the
        # positions are wrapped the same way before use.
        positions_frac = torch.remainder(
            positions @ torch.linalg.inv(cell), 1.0
        )
        shifted_j = (
            positions_frac[j_indices] + offsets.to(self.dtype)
        ) @ cell
        r_ij = shifted_j - (positions_frac[i_indices] @ cell)

    # Group the flat edge list by center atom.
    neighbor_indices = []
    neighbor_vectors = []
    for atom_idx in range(len(positions)):
        owned = i_indices == atom_idx
        neighbor_indices.append(j_indices[owned])
        neighbor_vectors.append(r_ij[owned])

    if profiling:
        self.nblist_time += time.perf_counter() - t_nb_start
        self.nblist_calls += 1
        t_feat_start = time.perf_counter()

    # Feature computation proper.
    features = self.forward(
        positions, species, neighbor_indices, neighbor_vectors
    )

    if profiling:
        self.feature_time += time.perf_counter() - t_feat_start

    return features
def _compute_radial_gradients(
self,
positions: torch.Tensor,
species_indices: torch.Tensor,
neighbor_indices: List[torch.Tensor],
neighbor_vectors: List[torch.Tensor],
) -> torch.Tensor:
"""
Compute gradients of radial features using a fully vectorized
semi-analytical method.
This method uses the analytical derivatives of the basis functions
and the chain rule to compute gradients with respect to atomic
positions in a single, vectorized pass.
Args:
positions: (N, 3) atomic positions
species_indices: (N,) species index for each atom
neighbor_indices: List of neighbor indices per atom
neighbor_vectors: List of displacement vectors per atom
Returns
-------
gradients: (N, n_rad_features, N, 3) gradient tensor
"""
n_atoms = len(positions)
n_rad = self.rad_order + 1
# Build edge lists from neighbor information
center_indices_list = []
neighbor_indices_list = []
vectors_list = []
for i, (nb_idx, nb_vec) in enumerate(zip(neighbor_indices,
neighbor_vectors)):
if len(nb_idx) == 0:
continue
distances = torch.norm(nb_vec, dim=-1)
mask = (distances <= self.rad_cutoff
) & (distances > self.min_cutoff)
if mask.any():
n_valid = mask.sum().item()
center_indices_list.append(torch.full(
(n_valid,), i, dtype=torch.long, device=self.device))
neighbor_indices_list.append(nb_idx[mask])
vectors_list.append(nb_vec[mask])
if len(vectors_list) == 0:
# Return zeros if no neighbors
n_rad_features = 2 * n_rad if self.multi else n_rad
return torch.zeros(
n_atoms, n_rad_features, n_atoms, 3,
dtype=self.dtype, device=self.device
)
# Concatenate all edges
center_indices = torch.cat(center_indices_list)
neighbor_indices_flat = torch.cat(neighbor_indices_list)
r_ij = torch.cat(vectors_list)
# Compute distances and basis derivatives
distances = torch.norm(r_ij, dim=-1)
_, dG_rad_dr = self.rad_basis.forward_with_derivatives(distances)
# Chain rule: dG/dr_vec = (dG/dr) * (dr/dr_vec)
# dr/dr_vec is the normalized displacement vector
dG_rad_drij = dG_rad_dr.unsqueeze(-1) * (
r_ij / (distances.unsqueeze(-1) + 1e-10)).unsqueeze(1)
# Initialize gradient tensor
n_rad_features = 2 * n_rad if self.multi else n_rad
gradients = torch.zeros(
n_atoms, n_rad_features, n_atoms, 3,
dtype=self.dtype, device=self.device
)
# Unweighted gradients
# Contribution to central atom i: -dG/drij
# Contribution to neighbor atom j: +dG/drij
# We can do this with two scatter_add operations
# Reshape for scattering: (n_pairs, n_rad, 3)
grad_unweighted = dG_rad_drij
# Scatter to central atoms (negative contribution)
# This is tricky because scatter_add doesn't support (i, j) indexing
# We will loop for now, but this can be optimized
for pair_idx in range(len(center_indices)):
i = center_indices[pair_idx].item()
j = neighbor_indices_flat[pair_idx].item()
gradients[i, :n_rad, i] -= grad_unweighted[pair_idx]
gradients[i, :n_rad, j] += grad_unweighted[pair_idx]
if not self.multi:
return gradients
# Typespin-weighted gradients
neighbor_species = species_indices[neighbor_indices_flat]
neighbor_typespin = self.typespin[neighbor_species]
grad_weighted = (grad_unweighted
* neighbor_typespin.unsqueeze(-1).unsqueeze(-1))
for pair_idx in range(len(center_indices)):
i = center_indices[pair_idx].item()
j = neighbor_indices_flat[pair_idx].item()
gradients[i, n_rad:, i] -= grad_weighted[pair_idx]
gradients[i, n_rad:, j] += grad_weighted[pair_idx]
return gradients
def _compute_angular_gradients(
self,
positions: torch.Tensor,
species_indices: torch.Tensor,
neighbor_indices: List[torch.Tensor],
neighbor_vectors: List[torch.Tensor],
) -> torch.Tensor:
"""
Compute gradients of angular features using a fully analytical,
vectorized method.
Avoids autograd on geometric quantities to
improve performance and numerical stability (especially under PBC).
Returns
-------
gradients: (N, n_ang_features, N, 3) gradient tensor
"""
n_atoms = len(positions)
n_ang = self.ang_order + 1
eps_norm = 1e-12
eps_dist = 1e-20
# Build triplet lists (i, j, k) and corresponding displacement vectors
triplet_i: list[int] = []
triplet_j: list[int] = []
triplet_k: list[int] = []
r_ij_list: list[torch.Tensor] = []
r_ik_list: list[torch.Tensor] = []
for i, (nb_idx, nb_vec) in enumerate(zip(neighbor_indices,
neighbor_vectors)):
if len(nb_idx) < 2:
continue
distances = torch.norm(nb_vec, dim=-1)
mask = (distances <= self.ang_cutoff
) & (distances > self.min_cutoff)
if not mask.any():
continue
valid_nb_idx = nb_idx[mask]
valid_vectors = nb_vec[mask]
n_valid = len(valid_nb_idx)
if n_valid < 2:
continue
# Generate all unique neighbor pairs (j, k) with j_local < k_local
for j_local in range(n_valid):
for k_local in range(j_local + 1, n_valid):
triplet_i.append(i)
triplet_j.append(int(valid_nb_idx[j_local].item()))
triplet_k.append(int(valid_nb_idx[k_local].item()))
r_ij_list.append(valid_vectors[j_local])
r_ik_list.append(valid_vectors[k_local])
if len(triplet_i) == 0:
n_ang_features = 2 * n_ang if self.multi else n_ang
return torch.zeros(
n_atoms, n_ang_features, n_atoms, 3,
dtype=self.dtype, device=self.device
)
# Stack and prepare tensors
triplet_i_t = torch.tensor(
triplet_i, dtype=torch.long, device=self.device)
triplet_j_t = torch.tensor(
triplet_j, dtype=torch.long, device=self.device)
triplet_k_t = torch.tensor(
triplet_k, dtype=torch.long, device=self.device)
r_ij = torch.stack(r_ij_list
).to(self.dtype).to(self.device) # (T, 3)
r_ik = torch.stack(r_ik_list
).to(self.dtype).to(self.device) # (T, 3)
# Distances with tiny epsilon to avoid zero-norm issues
d_ij = torch.sqrt((r_ij * r_ij).sum(dim=-1) + eps_dist) # (T,)
d_ik = torch.sqrt((r_ik * r_ik).sum(dim=-1) + eps_dist) # (T,)
# Unit vectors
u_ij = r_ij / (d_ij.unsqueeze(-1) + eps_norm) # (T, 3)
u_ik = r_ik / (d_ik.unsqueeze(-1) + eps_norm) # (T, 3)
# Cosine of angle at i
cos_theta = (u_ij * u_ik).sum(dim=-1).clamp(-1.0, 1.0) # (T,)
# Evaluate angular basis and partial derivatives
# Shapes: (T, n_ang)
_, dG_dcos, dG_drij, dG_drik = self.ang_basis.forward_with_derivatives(
d_ij, d_ik, cos_theta
)
# Geometric derivatives of cos(theta) wrt positions r_j, r_k, r_i
# dcos/dr_j = (1/d_ij) * (u_ik - cos_theta * u_ij)
# dcos/dr_k = (1/d_ik) * (u_ij - cos_theta * u_ik)
# dcos/dr_i = - (dcos/dr_j + dcos/dr_k)
dcos_drj = (u_ik - cos_theta.unsqueeze(-1) * u_ij
) / (d_ij.unsqueeze(-1) + eps_norm) # (T,3)
dcos_drk = (u_ij - cos_theta.unsqueeze(-1) * u_ik
) / (d_ik.unsqueeze(-1) + eps_norm) # (T,3)
dcos_dri = -(dcos_drj + dcos_drk) # (T,3)
# Total gradients per triplet and per angular feature (vectorized)
# Shapes for grads_*: (T, n_ang, 3)
dG_dcos_e = dG_dcos.unsqueeze(-1) # (T, n_ang, 1)
dG_drij_e = dG_drij.unsqueeze(-1) # (T, n_ang, 1)
dG_drik_e = dG_drik.unsqueeze(-1) # (T, n_ang, 1)
u_ij_e = u_ij.unsqueeze(1) # (T, 1, 3)
u_ik_e = u_ik.unsqueeze(1) # (T, 1, 3)
dcos_drj_e = dcos_drj.unsqueeze(1) # (T, 1, 3)
dcos_drk_e = dcos_drk.unsqueeze(1) # (T, 1, 3)
dcos_dri_e = dcos_dri.unsqueeze(1) # (T, 1, 3)
# Chain rule combinations
grads_j = dG_dcos_e * dcos_drj_e + dG_drij_e * u_ij_e
grads_k = dG_dcos_e * dcos_drk_e + dG_drik_e * u_ik_e
grads_i = (dG_dcos_e * dcos_dri_e
- dG_drij_e * u_ij_e - dG_drik_e * u_ik_e)
# Initialize output gradient tensor
n_ang_features = 2 * n_ang if self.multi else n_ang
gradients = torch.zeros(
n_atoms, n_ang_features, n_atoms, 3,
dtype=self.dtype, device=self.device
)
# Prepare flattened center-target indices for efficient scatter
# idx = center * n_atoms + target
flat_size = n_atoms * n_atoms
flat_idx_cc = triplet_i_t * n_atoms + triplet_i_t # (T,)
flat_idx_cj = triplet_i_t * n_atoms + triplet_j_t # (T,)
flat_idx_ck = triplet_i_t * n_atoms + triplet_k_t # (T,)
# Accumulate unweighted gradients feature-by-feature
for ang_idx in range(n_ang):
# (T,3) slices
gi = grads_i[:, ang_idx, :]
gj = grads_j[:, ang_idx, :]
gk = grads_k[:, ang_idx, :]
# Accumulate into flattened (center, target) axis
accum_flat = torch.zeros(
flat_size, 3, dtype=self.dtype, device=self.device)
accum_flat.index_add_(0, flat_idx_cc, gi)
accum_flat.index_add_(0, flat_idx_cj, gj)
accum_flat.index_add_(0, flat_idx_ck, gk)
# Reshape back to (n_atoms, n_atoms, 3) and assign
gradients[:, ang_idx, :, :] = accum_flat.view(n_atoms, n_atoms, 3)
# Typespin-weighted angular gradients (multi-species only)
if self.multi:
species_j = species_indices[triplet_j_t]
species_k = species_indices[triplet_k_t]
typespin_prod = (
self.typespin[species_j] * self.typespin[species_k]
).unsqueeze(-1) # (T,1)
for ang_idx in range(n_ang):
gi_w = grads_i[:, ang_idx, :] * typespin_prod
gj_w = grads_j[:, ang_idx, :] * typespin_prod
gk_w = grads_k[:, ang_idx, :] * typespin_prod
accum_flat_w = torch.zeros(
flat_size, 3, dtype=self.dtype, device=self.device)
accum_flat_w.index_add_(0, flat_idx_cc, gi_w)
accum_flat_w.index_add_(0, flat_idx_cj, gj_w)
accum_flat_w.index_add_(0, flat_idx_ck, gk_w)
gradients[:, n_ang + ang_idx, :, :
] = accum_flat_w.view(n_atoms, n_atoms, 3)
return gradients
def compute_feature_gradients(
    self,
    positions: torch.Tensor,
    species: List[str],
    cell: Optional[torch.Tensor] = None,
    pbc: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute features together with their dense position gradients.

    Neighbor data is built once with the internal neighbor list and then
    shared by the feature evaluation and by the radial and angular
    gradient routines.

    Args:
        positions: (N, 3) atomic positions
        species: List of N species names
        cell: (3, 3) lattice vectors
        pbc: (3,) periodic boundary conditions

    Returns
    -------
    features: (N, F) feature tensor
    gradients: (N, F, N, 3) gradient tensor where
        gradients[i, f, j, k] = ∂feature[i,f]/∂position[j,k];
        the feature axis uses the same ordering as ``features``
    """
    # Neighbor search on device copies of the inputs.
    pos_dev = positions.to(self.device).to(self.dtype)
    if cell is not None:
        cell = cell.to(self.device).to(self.dtype)
    nb_data = self.nbl.get_neighbors(
        pos_dev, cell, pbc, fractional=False
    )
    edge_index = nb_data['edge_index']
    edge_dist = nb_data['distances']
    offsets = nb_data['offsets']

    # Drop edges at or below the minimum distance.
    keep = edge_dist > self.min_cutoff
    edge_index = edge_index[:, keep]
    if offsets is not None:
        offsets = offsets[keep]
    i_indices, j_indices = edge_index[0], edge_index[1]

    if cell is None or offsets is None:
        r_ij = pos_dev[j_indices] - pos_dev[i_indices]
    else:
        # Offsets are applied to wrapped fractional coordinates, so the
        # positions are wrapped the same way before use.
        positions_frac = torch.remainder(
            pos_dev @ torch.linalg.inv(cell), 1.0
        )
        shifted_j = (
            positions_frac[j_indices] + offsets.to(self.dtype)
        ) @ cell
        r_ij = shifted_j - (positions_frac[i_indices] @ cell)

    # Group the flat edge list by center atom.
    n_atoms = len(positions)
    neighbor_indices = []
    neighbor_vectors = []
    for atom_idx in range(n_atoms):
        owned = i_indices == atom_idx
        neighbor_indices.append(j_indices[owned])
        neighbor_vectors.append(r_ij[owned])

    # Features from the shared neighbor data.
    features = self.forward(
        positions, species, neighbor_indices, neighbor_vectors
    )

    # Species names -> integer indices for the gradient routines.
    lookup = self.species_to_idx
    species_indices = torch.tensor(
        [lookup[name] for name in species],
        dtype=torch.long,
        device=self.device,
    )

    rad_grads = self._compute_radial_gradients(
        positions, species_indices, neighbor_indices, neighbor_vectors
    )
    ang_grads = self._compute_angular_gradients(
        positions, species_indices, neighbor_indices, neighbor_vectors
    )

    # Assemble along the feature axis in the same order as `features`.
    n_rad = self.rad_order + 1
    n_ang = self.ang_order + 1
    gradients = torch.zeros(
        n_atoms, self.n_features, n_atoms, 3,
        dtype=self.dtype, device=self.device
    )
    if not self.multi:
        gradients[:, :n_rad] = rad_grads
        gradients[:, n_rad:] = ang_grads
    else:
        # [rad_unweighted | ang_unweighted | rad_weighted | ang_weighted]
        end_rad_u = n_rad
        end_ang_u = end_rad_u + n_ang
        end_rad_w = end_ang_u + n_rad
        gradients[:, :end_rad_u] = rad_grads[:, :n_rad]
        gradients[:, end_rad_u:end_ang_u] = ang_grads[:, :n_ang]
        gradients[:, end_ang_u:end_rad_w] = rad_grads[:, n_rad:]
        gradients[:, end_rad_w:] = ang_grads[:, n_ang:]

    return features, gradients
def forward_with_graph(
    self,
    positions: torch.Tensor,
    species_indices: torch.Tensor,
    graph,
    triplets=None,
    center_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Vectorized featurization from a CSR neighbor graph and optional
    triplet index.

    Radial features are evaluated on every edge of ``graph``. Angular
    features are evaluated on every entry of ``triplets``; when no
    triplets are given the angular columns are zero.

    Args:
        positions: (N,3) positions, used only for dtype/device alignment
            and to read N
        species_indices: (N,) long tensor of species indices
        graph: NeighborGraph dict with keys
            center_ptr[int32 N+1], nbr_idx[int32 E],
            r_ij[float E,3], d_ij[float E]
        triplets: Optional TripletIndex dict with keys
            tri_i/j/k[int32 T], tri_j_local/tri_k_local[int32 T]
        center_indices: Optional (M,) indices of centers to include
            (accepted but not used by this method)

    Returns
    -------
    features: (N, F), ordered [radial_unweighted, angular_unweighted,
        radial_weighted, angular_weighted] for multi-species descriptors
        and [radial, angular] otherwise
    """
    dev = self.device
    real = self.dtype
    multi = self.multi

    def as_index(tensor):
        # Integer arrays are used for indexing, hence int64.
        return tensor.to(device=dev, dtype=torch.int64)

    def as_real(tensor):
        return tensor.to(device=dev, dtype=real)

    positions = positions.to(device=dev, dtype=real)
    species_indices = species_indices.to(device=dev)
    n_atoms = positions.shape[0]
    n_rad = self.rad_order + 1
    n_ang = self.ang_order + 1

    # ---- Radial part: one basis evaluation per edge ----
    neighbor_of_edge = as_index(graph["nbr_idx"])
    edge_length = as_real(graph["d_ij"])
    center_of_edge = as_index(_center_ids_of_edge(graph))

    rad_values = self.rad_basis(edge_length)  # (E, n_rad)
    rad_plain = scatter_add(
        rad_values, center_of_edge, dim=0, dim_size=n_atoms)
    rad_spin = None
    if multi:
        # Scale every edge by the typespin of its neighbor atom.
        spin_j = self.typespin[species_indices[neighbor_of_edge]]  # (E,)
        rad_spin = scatter_add(
            rad_values * spin_j.unsqueeze(-1),
            center_of_edge, dim=0, dim_size=n_atoms)

    # ---- Angular part: one basis evaluation per triplet ----
    if triplets is None or n_ang <= 0:
        ang_plain = torch.zeros(n_atoms, n_ang, dtype=real, device=dev)
        ang_spin = torch.zeros_like(ang_plain) if multi else None
    else:
        row_start = as_index(graph["center_ptr"])
        edge_vector = as_real(graph["r_ij"])
        tri_center = as_index(triplets["tri_i"])
        tri_first = as_index(triplets["tri_j"])
        tri_second = as_index(triplets["tri_k"])
        first_local = as_index(triplets["tri_j_local"])
        second_local = as_index(triplets["tri_k_local"])

        # Edge position = start of the center's CSR row + local offset.
        base = row_start[tri_center]                  # (T,)
        vec_first = edge_vector[base + first_local]   # (T,3)
        vec_second = edge_vector[base + second_local]  # (T,3)

        tiny = 1e-20
        len_first = torch.sqrt((vec_first * vec_first).sum(dim=-1) + tiny)
        len_second = torch.sqrt(
            (vec_second * vec_second).sum(dim=-1) + tiny)
        unit_first = vec_first / (len_first.unsqueeze(-1) + 1e-12)
        unit_second = vec_second / (len_second.unsqueeze(-1) + 1e-12)
        cosine = (unit_first * unit_second).sum(dim=-1).clamp(-1.0, 1.0)

        ang_values = self.ang_basis(
            len_first, len_second, cosine)  # (T, n_ang)
        ang_plain = scatter_add(
            ang_values, tri_center, dim=0, dim_size=n_atoms)
        if multi:
            spin_pair = (self.typespin[species_indices[tri_first]]
                         * self.typespin[species_indices[tri_second]])
            ang_spin = scatter_add(
                ang_values * spin_pair.unsqueeze(-1),
                tri_center, dim=0, dim_size=n_atoms)
        else:
            ang_spin = None

    # ---- Assemble: both unweighted sets first, then both weighted ----
    pieces = [rad_plain[:, :n_rad], ang_plain[:, :n_ang]]
    if multi:
        pieces.extend([rad_spin[:, :n_rad], ang_spin[:, :n_ang]])
    return torch.cat(pieces, dim=1)
def compute_features_and_local_derivatives_with_graph(
    self,
    positions: torch.Tensor,
    species_indices: torch.Tensor,
    graph,
    triplets,
    center_indices: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, dict[str, dict[str, Optional[torch.Tensor]]]]:
    """
    Evaluate features together with sparse, local derivative blocks.

    No dense ``(N, F, N, 3)`` tensor is built here. Radial derivatives
    are stored with one row per neighbor edge and angular derivatives
    with one row per triplet, which is the form intended for direct
    force contraction and for serialization into HDF5-backed caches.

    Parameters
    ----------
    positions : torch.Tensor
        Atomic positions with shape ``(N, 3)``.
    species_indices : torch.Tensor
        Species indices with shape ``(N,)``.
    graph : dict
        CSR neighbor graph providing ``center_ptr``, ``nbr_idx``,
        ``r_ij`` and ``d_ij``.
    triplets : dict or None
        Triplet index dictionary (``tri_i``, ``tri_j``, ``tri_k``,
        ``tri_j_local``, ``tri_k_local``) or ``None`` to skip the
        angular part.
    center_indices : torch.Tensor, optional
        Accepted for signature compatibility only; it is discarded and
        no center filtering takes place.

    Returns
    -------
    tuple
        ``(features, local_derivatives)``. ``features`` is whatever
        :meth:`forward_with_graph` returns for the same inputs.
        ``local_derivatives`` holds two dictionaries:

        ``"radial"``
            ``center_idx``, ``neighbor_idx``, ``dG_drij`` (radial
            basis slope times the edge unit vector) and
            ``neighbor_typespin`` (``None`` unless ``self.multi``).
        ``"angular"``
            ``center_idx``, ``neighbor_j_idx``, ``neighbor_k_idx``,
            ``grads_i``, ``grads_j``, ``grads_k`` and
            ``triplet_typespin`` (``None`` unless ``self.multi``).
            All tensors have zero rows when no triplets are given.
    """
    # Center filtering is not implemented; drop the argument up front.
    del center_indices

    dev = self.device
    real = self.dtype
    index_t = torch.int64

    positions = positions.to(device=dev, dtype=real)
    species_indices = species_indices.to(device=dev)

    features = self.forward_with_graph(
        positions=positions,
        species_indices=species_indices,
        graph=graph,
        triplets=triplets,
        center_indices=None,
    )

    # ------------------------------------------------------------------
    # Radial part: one derivative row per neighbor edge.
    # ------------------------------------------------------------------
    edge_nbr = graph["nbr_idx"].to(device=dev, dtype=index_t)
    edge_vec = graph["r_ij"].to(device=dev, dtype=real)
    # Clamp and pad the length so a zero-length edge cannot divide by 0.
    edge_len = graph["d_ij"].to(device=dev, dtype=real).clamp_min(1e-20)
    edge_dir = edge_vec / (edge_len.unsqueeze(-1) + 1e-12)
    edge_center = _center_ids_of_edge(graph).to(device=dev, dtype=index_t)

    _values, rad_slope = self.rad_basis.forward_with_derivatives(edge_len)

    radial_block: dict[str, Optional[torch.Tensor]] = {
        "center_idx": edge_center,
        "neighbor_idx": edge_nbr,
        # Chain rule: d/dr_j = (dG/dd) * unit vector along the edge.
        "dG_drij": rad_slope.unsqueeze(-1) * edge_dir.unsqueeze(1),
        "neighbor_typespin": (
            self.typespin[species_indices[edge_nbr]]
            if self.multi else None
        ),
    }

    # ------------------------------------------------------------------
    # Angular part: one derivative row per (i, j, k) triplet.
    # ------------------------------------------------------------------
    n_ang = self.ang_order + 1
    angular_block: dict[str, Optional[torch.Tensor]]

    if triplets is None or n_ang <= 0:
        # Nothing to differentiate: emit correctly-typed empty tensors.
        no_idx = torch.empty(0, dtype=index_t, device=dev)
        no_grad = torch.empty(0, n_ang, 3, dtype=real, device=dev)
        angular_block = {
            "center_idx": no_idx,
            "neighbor_j_idx": no_idx,
            "neighbor_k_idx": no_idx,
            "grads_i": no_grad,
            "grads_j": no_grad,
            "grads_k": no_grad,
            "triplet_typespin": (
                torch.empty(0, dtype=real, device=dev)
                if self.multi else None
            ),
        }
    else:
        row_start = graph["center_ptr"].to(device=dev, dtype=index_t)
        tri_center = triplets["tri_i"].to(device=dev, dtype=index_t)
        tri_nbr_j = triplets["tri_j"].to(device=dev, dtype=index_t)
        tri_nbr_k = triplets["tri_k"].to(device=dev, dtype=index_t)
        slot_j = triplets["tri_j_local"].to(device=dev, dtype=index_t)
        slot_k = triplets["tri_k_local"].to(device=dev, dtype=index_t)

        # Local neighbor slots become absolute edge rows by adding the
        # CSR start offset of the triplet's center atom.
        first_edge = row_start[tri_center]
        vec_j = edge_vec[first_edge + slot_j]
        vec_k = edge_vec[first_edge + slot_k]

        tiny = 1e-20
        len_j = torch.sqrt((vec_j * vec_j).sum(dim=-1) + tiny)
        len_k = torch.sqrt((vec_k * vec_k).sum(dim=-1) + tiny)
        den_j = len_j.unsqueeze(-1) + 1e-12
        den_k = len_k.unsqueeze(-1) + 1e-12
        dir_j = vec_j / den_j
        dir_k = vec_k / den_k
        cosine = (dir_j * dir_k).sum(dim=-1).clamp(-1.0, 1.0)

        (_values, slope_cos, slope_dij, slope_dik
         ) = self.ang_basis.forward_with_derivatives(len_j, len_k, cosine)

        # Gradient of cos(theta) with respect to each atom position.
        cos_col = cosine.unsqueeze(-1)
        dcos_j = (dir_k - cos_col * dir_j) / den_j
        dcos_k = (dir_j - cos_col * dir_k) / den_k
        dcos_i = -(dcos_j + dcos_k)

        # Insert singleton axes so per-feature slopes broadcast against
        # Cartesian vectors; the result matches the (rows, n_ang, 3)
        # layout of the empty placeholders above.
        w_cos = slope_cos.unsqueeze(-1)
        w_dij = slope_dij.unsqueeze(-1)
        w_dik = slope_dik.unsqueeze(-1)
        along_j = dir_j.unsqueeze(1)
        along_k = dir_k.unsqueeze(1)

        grads_j = w_cos * dcos_j.unsqueeze(1) + w_dij * along_j
        grads_k = w_cos * dcos_k.unsqueeze(1) + w_dik * along_k
        grads_i = (
            w_cos * dcos_i.unsqueeze(1)
            - w_dij * along_j
            - w_dik * along_k
        )

        angular_block = {
            "center_idx": tri_center,
            "neighbor_j_idx": tri_nbr_j,
            "neighbor_k_idx": tri_nbr_k,
            "grads_i": grads_i,
            "grads_j": grads_j,
            "grads_k": grads_k,
            "triplet_typespin": (
                self.typespin[species_indices[tri_nbr_j]]
                * self.typespin[species_indices[tri_nbr_k]]
                if self.multi else None
            ),
        }

    return features, {"radial": radial_block, "angular": angular_block}
def compute_feature_gradients_with_graph(
    self,
    positions: torch.Tensor,
    species_indices: torch.Tensor,
    graph,
    triplets,
    center_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Dense features and gradients from a CSR neighbor graph and triplets.

    The sparse edge/triplet derivative blocks produced by
    :meth:`compute_features_and_local_derivatives_with_graph` are
    scattered into a dense tensor, one feature column at a time.

    Returns
    -------
    features: (N, F)
    gradients: (N, F, N, 3)
    """
    dev = self.device
    real = self.dtype
    positions = positions.to(device=dev, dtype=real)
    species_indices = species_indices.to(device=dev)

    n_atoms = positions.shape[0]
    n_rad = self.rad_order + 1
    n_ang = self.ang_order + 1

    features, blocks = (
        self.compute_features_and_local_derivatives_with_graph(
            positions=positions,
            species_indices=species_indices,
            graph=graph,
            triplets=triplets,
            center_indices=center_indices,
        )
    )

    dense = torch.zeros(
        n_atoms, self.get_n_features(), n_atoms, 3,
        dtype=real, device=dev,
    )
    n_pairs = n_atoms * n_atoms

    def _pairwise_dense(contributions):
        # Accumulate (flat pair index, vector) contributions into an
        # (N, N, 3) slab; flat index = center * N + moved atom.
        slab = torch.zeros(n_pairs, 3, dtype=real, device=dev)
        for flat_idx, vectors in contributions:
            slab.index_add_(0, flat_idx, vectors)
        return slab.view(n_atoms, n_atoms, 3)

    # ---- radial columns, rebuilt from per-edge derivatives ----
    radial = blocks["radial"]
    edge_center = radial["center_idx"]
    edge_grad = radial["dG_drij"]
    to_self = edge_center * n_atoms + edge_center
    to_nbr = edge_center * n_atoms + radial["neighbor_idx"]

    for col in range(n_rad):
        g = edge_grad[:, col, :]
        # Moving the center changes the edge by the opposite amount.
        dense[:, col, :, :] = _pairwise_dense(
            ((to_self, -g), (to_nbr, g))
        )

    if self.multi:
        weighted_edge_grad = (
            edge_grad * radial["neighbor_typespin"].view(-1, 1, 1)
        )
        rad_w_start = n_rad + n_ang
        for col in range(n_rad):
            g = weighted_edge_grad[:, col, :]
            dense[:, rad_w_start + col, :, :] = _pairwise_dense(
                ((to_self, -g), (to_nbr, g))
            )

    # ---- angular columns, rebuilt from per-triplet derivatives ----
    angular = blocks["angular"]
    tri_center = angular["center_idx"]
    if tri_center.numel() > 0 and n_ang > 0:
        g_i = angular["grads_i"]
        g_j = angular["grads_j"]
        g_k = angular["grads_k"]
        to_self = tri_center * n_atoms + tri_center
        to_j = tri_center * n_atoms + angular["neighbor_j_idx"]
        to_k = tri_center * n_atoms + angular["neighbor_k_idx"]

        for col in range(n_ang):
            dense[:, n_rad + col, :, :] = _pairwise_dense((
                (to_self, g_i[:, col, :]),
                (to_j, g_j[:, col, :]),
                (to_k, g_k[:, col, :]),
            ))

        if self.multi:
            weight = angular["triplet_typespin"].unsqueeze(-1)
            ang_w_start = 2 * n_rad + n_ang
            for col in range(n_ang):
                dense[:, ang_w_start + col, :, :] = _pairwise_dense((
                    (to_self, g_i[:, col, :] * weight),
                    (to_j, g_j[:, col, :] * weight),
                    (to_k, g_k[:, col, :] * weight),
                ))

    return features, dense
def compute_feature_gradients_from_neighbor_info(
    self,
    positions: torch.Tensor,
    species: List[str],
    neighbor_indices: List[torch.Tensor],
    neighbor_vectors: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Features and dense gradients from an existing neighbor list.

    The caller supplies per-atom neighbor indices and displacement
    vectors (for instance those returned by
    featurize_with_neighbor_info), so no neighbor search is repeated
    here.

    Args:
        positions: (N, 3) atomic positions
        species: List of N species names
        neighbor_indices: List of (nnb_i,) tensors with neighbor indices
        neighbor_vectors: List of (nnb_i, 3) tensors with
            displacement vectors

    Returns
    -------
    features: (N, F) feature tensor
    gradients: (N, F, N, 3) gradient tensor where
        gradients[i, f, j, k] = ∂feature[i,f]/∂position[j,k]
    """
    dev = self.device
    real = self.dtype
    pos = positions.to(dev).to(real)

    # Features first, evaluated on the supplied neighbor data.
    features = self.forward(
        pos, species, neighbor_indices, neighbor_vectors
    )

    # Map species names to integer type indices.
    type_ids = torch.tensor(
        [self.species_to_idx[name] for name in species],
        dtype=torch.long,
        device=dev,
    )

    # Gradients of the radial and the angular feature families.
    neighbor_args = (pos, type_ids, neighbor_indices, neighbor_vectors)
    rad_grads = self._compute_radial_gradients(*neighbor_args)
    ang_grads = self._compute_angular_gradients(*neighbor_args)

    n_rad = self.rad_order + 1
    n_ang = self.ang_order + 1
    n_atoms = pos.shape[0]
    gradients = torch.zeros(
        n_atoms, self.n_features, n_atoms, 3,
        dtype=real, device=dev,
    )

    if not self.multi:
        # Single-species layout: [radial, angular]
        gradients[:, :n_rad] = rad_grads
        gradients[:, n_rad:] = ang_grads
        return features, gradients

    # Multi-species layout:
    # [rad_unweighted, ang_unweighted, rad_weighted, ang_weighted]
    rad_w_start = n_rad + n_ang
    ang_w_start = 2 * n_rad + n_ang
    gradients[:, :n_rad] = rad_grads[:, :n_rad]
    gradients[:, n_rad:rad_w_start] = ang_grads[:, :n_ang]
    gradients[:, rad_w_start:ang_w_start] = rad_grads[:, n_rad:]
    gradients[:, ang_w_start:] = ang_grads[:, n_ang:]
    return features, gradients
def compute_forces_from_energy(
    self,
    positions: torch.Tensor,
    species: List[str],
    energy_model: nn.Module,
    cell: Optional[torch.Tensor] = None,
    pbc: Optional[torch.Tensor] = None,
    *,
    create_graph: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute atomic forces from an energy model.

    Forces are computed as F = -∂E/∂r using autograd through
    the full featurization → energy prediction pipeline. The input
    ``positions`` tensor is not modified; differentiation happens on a
    detached copy.

    Args:
        positions: (N, 3) atomic positions
        species: List of N species names
        energy_model: Neural network that predicts energy from features
        cell: (3, 3) lattice vectors
        pbc: (3,) periodic boundary conditions
        create_graph: Keep the autograd graph of the force computation
            so the returned forces are themselves differentiable
            (needed when a loss on the forces is back-propagated).
            Defaults to True, which is the historical behavior. Pass
            False for pure inference to avoid building the
            second-order graph; the graph is then freed, so the
            returned energy can no longer be back-propagated either.

    Returns
    -------
    energy: Scalar total energy
    forces: (N, 3) force on each atom
    """
    # Differentiate with respect to a fresh leaf so the caller's tensor
    # (and any graph attached to it) is left untouched.
    positions_grad = positions.clone().detach().requires_grad_(True)

    # Compute features using convenience wrapper
    features = self.forward_from_positions(
        positions_grad, species, cell, pbc)

    # Total energy is the sum over the model outputs.
    energy = energy_model(features).sum()

    # Compute forces via autograd: F = -∂E/∂r
    (dE_dr,) = torch.autograd.grad(
        energy,
        positions_grad,
        create_graph=create_graph,
    )
    return energy, -dE_dr
def featurize_with_neighbor_info(
    self,
    positions: torch.Tensor,
    species: List[str],
    cell: Optional[torch.Tensor] = None,
    pbc: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, dict]:
    """
    Featurize structure and extract neighbor information for
    force training.

    This method computes atomic features and extracts neighbor lists and
    displacement vectors needed for computing feature derivatives during
    force training. The neighbor information can be saved to HDF5 for
    later use with on-demand derivative computation.

    Args:
        positions: (N, 3) atomic positions (Cartesian coordinates)
        species: List of N species names
        cell: (3, 3) lattice vectors as rows (None for isolated systems)
        pbc: (3,) periodic boundary conditions; passed unchanged to the
            neighbor list

    Returns
    -------
    features: (N, n_features) feature tensor
    neighbor_info: Dictionary containing:
        - 'neighbor_counts': (N,) NumPy integer array with the number
          of neighbors per atom
        - 'neighbor_lists': List of N NumPy arrays, each containing
          neighbor indices
        - 'neighbor_vectors': List of N NumPy arrays, each (nnb, 3)
          displacement vectors
        - 'max_neighbors': int, maximum number of neighbors
          across all atoms (0 for an empty structure)

    Example
    -------
    >>> descriptor = ChebyshevDescriptor(['O', 'H'], 10, 4.0, 3, 1.5)
    >>> positions = torch.tensor([[0.0, 0.0, 0.0],
    ...                           [0.0, 0.0, 1.0],
    ...                           [0.0, 1.0, 0.0]])
    >>> features, neighbor_info = \
    ...     descriptor.featurize_with_neighbor_info(
    ...         positions, ['O', 'H', 'H']
    ...     )
    >>> print(neighbor_info['neighbor_counts'])  # [2 2 2]
    >>> print(neighbor_info['max_neighbors'])    # 2

    Notes
    -----
    The neighbor cutoff used is the maximum of rad_cutoff and ang_cutoff
    to ensure all neighbors relevant for both radial and angular features
    are included. Pairs closer than ``min_cutoff`` are discarded.

    The displacement vectors are computed as r_j - r_i for neighbor j
    of atom i, and include periodic image offsets when applicable.

    Within each atom's list, neighbors keep the order in which the
    neighbor list reported them.
    """
    # Move inputs to device
    positions = positions.to(self.device).to(self.dtype)
    if cell is not None:
        cell = cell.to(self.device).to(self.dtype)

    # Get neighbor data using the maximum cutoff
    # Use fractional=False since positions are assumed Cartesian
    neighbor_data = self.nbl.get_neighbors(
        positions, cell, pbc, fractional=False
    )
    edge_index = neighbor_data['edge_index']
    offsets = neighbor_data['offsets']  # None for isolated systems

    # Filter by minimum cutoff (remove too-close neighbors)
    keep = neighbor_data['distances'] > self.min_cutoff
    edge_index = edge_index[:, keep]
    if offsets is not None:
        offsets = offsets[keep]

    # Compute displacement vectors for all neighbor pairs
    i_indices = edge_index[0]
    j_indices = edge_index[1]
    if cell is not None and offsets is not None:
        # Periodic system: reconstruct displacements using wrapped
        # fractional coordinates to ensure consistency with
        # ghost-backend image offsets
        cell_inv = torch.linalg.inv(cell)
        positions_frac = torch.remainder(positions @ cell_inv, 1.0)
        r_ij = (
            (positions_frac[j_indices] + offsets.to(self.dtype)) @ cell
        ) - (positions_frac[i_indices] @ cell)
    else:
        # Isolated system
        r_ij = positions[j_indices] - positions[i_indices]

    # Group edges by center atom with one stable sort instead of one
    # boolean mask per atom (which costs O(N * E)). Stability keeps the
    # original neighbor order inside every atom's group.
    n_atoms = len(positions)
    centers_sorted, order = torch.sort(i_indices, stable=True)
    nbrs_sorted = j_indices[order]
    vecs_sorted = r_ij[order]

    # CSR-style row pointers: ptr[a] is the first edge of atom a and
    # ptr[n_atoms] the total edge count, so atoms without neighbors get
    # an empty slice.
    atom_ids = torch.arange(
        n_atoms + 1,
        dtype=centers_sorted.dtype,
        device=centers_sorted.device,
    )
    ptr_tensor = torch.searchsorted(centers_sorted, atom_ids)
    neighbor_counts = ptr_tensor[1:] - ptr_tensor[:-1]
    ptr = ptr_tensor.tolist()
    bounds = list(zip(ptr, ptr[1:]))
    neighbor_indices_list = [nbrs_sorted[a:b] for a, b in bounds]
    neighbor_vectors_list = [vecs_sorted[a:b] for a, b in bounds]

    # Compute features using core forward method with neighbor data
    features = self.forward(
        positions, species, neighbor_indices_list, neighbor_vectors_list)

    # Package neighbor info for output (convert to numpy). Detach first:
    # the vectors inherit requires_grad from positions, and NumPy
    # conversion of a grad-tracking tensor raises.
    neighbor_lists_np = [
        nb.detach().cpu().numpy() for nb in neighbor_indices_list
    ]
    neighbor_vectors_np = [
        vec.detach().cpu().numpy() for vec in neighbor_vectors_list
    ]
    max_neighbors = neighbor_counts.max().item() if n_atoms > 0 else 0

    neighbor_info = {
        'neighbor_counts': neighbor_counts.cpu().numpy(),
        'neighbor_lists': neighbor_lists_np,
        'neighbor_vectors': neighbor_vectors_np,
        'max_neighbors': max_neighbors,
    }
    return features, neighbor_info
[docs]
class BatchedFeaturizer(nn.Module):
    """
    Batched featurization for multiple structures.

    Structures are featurized one after another with the wrapped
    descriptor and the per-atom features are concatenated, together
    with an index telling which structure every row belongs to.
    More efficient for training on datasets.
    """

    def __init__(self, featurizer: "ChebyshevDescriptor"):
        """
        Initialize batched featurizer.

        Args:
            featurizer: ChebyshevDescriptor instance to use
        """
        super().__init__()
        self.featurizer = featurizer

    def forward(
        self,
        batch_positions: List[torch.Tensor],
        batch_species: List[List[str]],
        batch_cells: Optional[List[torch.Tensor]] = None,
        batch_pbc: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Featurize batch of structures.

        Args:
            batch_positions: List of (N_i, 3) position tensors
            batch_species: List of species lists
            batch_cells: List of (3, 3) cell tensors; None or an empty
                list treats every structure as having no cell
            batch_pbc: List of (3,) pbc tensors; None or an empty list
                passes None for every structure

        Returns
        -------
        features: (total_atoms, n_features) concatenated features
        batch_indices: (total_atoms,) batch index for each atom, stored
            on the same device as ``features``
        """
        all_features = []
        all_batch_indices = []
        for batch_idx, (pos, species) in enumerate(
            zip(batch_positions, batch_species)
        ):
            cell = batch_cells[batch_idx] if batch_cells else None
            pbc = batch_pbc[batch_idx] if batch_pbc else None

            # Use convenience wrapper
            features = self.featurizer.forward_from_positions(
                pos, species, cell, pbc)
            all_features.append(features)

            # Create the index on the features' device so both outputs
            # can be combined without an explicit transfer.
            all_batch_indices.append(
                torch.full(
                    (len(pos),),
                    batch_idx,
                    dtype=torch.long,
                    device=features.device,
                )
            )

        # Concatenate
        features = torch.cat(all_features, dim=0)
        batch_indices = torch.cat(all_batch_indices, dim=0)
        return features, batch_indices