Source code for aenet.torch_featurize.chebyshev

"""
Vectorized Chebyshev Polynomial Evaluation for PyTorch.

This module implements Chebyshev polynomial evaluation using the explicit
cosine form T_n(x) = cos(n * arccos(x)) instead of recurrence relations,
enabling efficient vectorization and GPU acceleration.

References
----------
    N. Artrith, A. Urban, and G. Ceder, PRB 96 (2017) 014112
"""

from typing import Tuple

import torch
import torch.nn as nn


class ChebyshevPolynomials(nn.Module):
    """
    Vectorized Chebyshev polynomial evaluation using cosine form.

    Uses T_n(x) = cos(n * arccos(x)) so that all polynomial orders are
    evaluated in a single broadcasted expression.

    Parameters
    ----------
    max_order : int
        Maximum Chebyshev polynomial order to compute
    r_min : float
        Minimum distance (inner cutoff) in Angstroms
    r_max : float
        Maximum distance (outer cutoff) in Angstroms
    dtype : torch.dtype, optional
        Data type of the order-index buffer (default: torch.float64)

    Examples
    --------
    >>> cheb = ChebyshevPolynomials(max_order=5, r_min=0.5, r_max=4.0)
    >>> r = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
    >>> T = cheb(r)  # Shape: (3, 6) for orders 0-5
    """

    # Width of the band around x = +/-1, measured on 1 - x**2, in which
    # sin(n*theta)/sqrt(1 - x**2) degenerates to 0/0 and the analytical
    # limit of U_{n-1} is used instead.
    _EDGE_EPS = 1e-10

    def __init__(
        self,
        max_order: int,
        r_min: float,
        r_max: float,
        dtype: torch.dtype = torch.float64,
    ):
        super().__init__()
        self.max_order = max_order
        self.r_min = r_min
        self.r_max = r_max
        self.dtype = dtype

        # Register order indices 0..max_order as a buffer so they follow
        # the module through .to(device) / .to(dtype).
        orders = torch.arange(max_order + 1, dtype=dtype)
        self.register_buffer("orders", orders)

    def rescale_distances(self, r: torch.Tensor) -> torch.Tensor:
        """
        Rescale distances from [r_min, r_max] to [-1, 1].

        The Chebyshev polynomials are defined on [-1, 1], so distances are
        mapped with: x = (2*r - r_min - r_max) / (r_max - r_min)

        Parameters
        ----------
        r : torch.Tensor
            Distances in Angstroms, any shape

        Returns
        -------
        torch.Tensor
            Rescaled distances in [-1, 1], same shape as r

        Notes
        -----
        Values are clamped to [-1, 1], so inputs outside [r_min, r_max]
        map onto the nearest endpoint and arccos stays well defined.
        """
        x = (2.0 * r - self.r_min - self.r_max) / (self.r_max - self.r_min)
        return torch.clamp(x, -1.0, 1.0)

    def cutoff_function(self, r: torch.Tensor, Rc: float) -> torch.Tensor:
        """
        Cosine cutoff function.

        Implements:
            fc(r) = 0.5 * [cos(pi*r/Rc) + 1]  for r < Rc
            fc(r) = 0                         for r >= Rc

        Parameters
        ----------
        r : torch.Tensor
            Distances, any shape
        Rc : float
            Cutoff radius in Angstroms

        Returns
        -------
        torch.Tensor
            Cutoff function values, same shape as r
        """
        fc = torch.where(
            r < Rc,
            0.5 * (torch.cos(torch.pi * r / Rc) + 1.0),
            torch.zeros_like(r),
        )
        return fc

    def cutoff_derivative(self, r: torch.Tensor, Rc: float) -> torch.Tensor:
        """
        Derivative of the cosine cutoff function.

        Implements:
            dfc/dr = -0.5 * pi/Rc * sin(pi*r/Rc)  for r < Rc
            dfc/dr = 0                            for r >= Rc

        Parameters
        ----------
        r : torch.Tensor
            Distances, any shape
        Rc : float
            Cutoff radius in Angstroms

        Returns
        -------
        torch.Tensor
            Derivative of cutoff function, same shape as r
        """
        dfc = torch.where(
            r < Rc,
            -0.5 * torch.pi / Rc * torch.sin(torch.pi * r / Rc),
            torch.zeros_like(r),
        )
        return dfc

    def forward(self, r: torch.Tensor) -> torch.Tensor:
        """
        Evaluate Chebyshev polynomials for given distances.

        Uses the explicit formula T_n(x) = cos(n * arccos(x)), which
        yields all orders at once through broadcasting.

        Parameters
        ----------
        r : torch.Tensor
            Distances in Angstroms, shape (..., N)

        Returns
        -------
        torch.Tensor
            Chebyshev polynomials, shape (..., N, max_order+1);
            T[..., i, n] = T_n(x_i) where x_i is the rescaled r[..., i]
        """
        # rescale_distances already clamps to [-1, 1], so arccos is safe.
        theta = torch.arccos(self.rescale_distances(r))

        # (..., N, 1) * (max_order+1,) -> (..., N, max_order+1)
        return torch.cos(self.orders * theta.unsqueeze(-1))

    def evaluate_with_derivatives(
        self, r: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Evaluate Chebyshev polynomials and their derivatives.

        Uses dT_n/dx = n * U_{n-1}(x), where U_n are the Chebyshev
        polynomials of the second kind:
            U_{n-1}(x) = sin(n * arccos(x)) / sqrt(1 - x^2)

        Parameters
        ----------
        r : torch.Tensor
            Distances in Angstroms, shape (..., N)

        Returns
        -------
        T : torch.Tensor
            Chebyshev polynomials, shape (..., N, max_order+1)
        dT_dr : torch.Tensor
            Derivatives w.r.t. r, shape (..., N, max_order+1)

        Notes
        -----
        Chain rule: dT_n/dr = dT_n/dx * dx/dr with
        dx/dr = 2 / (r_max - r_min) from the rescaling.

        At x = +/-1 the quotient above is 0/0. Wherever
        1 - x^2 <= 1e-10 the analytical limit
        U_{n-1}(+/-1) = n * (+/-1)^(n-1) is used, so that e.g.
        dT_n/dx = n^2 at x = 1. Because the rescaled coordinate is
        clamped, inputs outside [r_min, r_max] receive the derivative at
        the nearest endpoint.
        """
        x = self.rescale_distances(r)
        theta = torch.arccos(x).unsqueeze(-1)

        # T_n(x) = cos(n * arccos(x)) for n = 0..max_order
        T = torch.cos(self.orders * theta)

        # n = 0 has zero derivative; orders n >= 1 are filled in below.
        dT_dx = torch.zeros_like(T)
        if self.max_order >= 1:
            # Reuse the registered buffer: it tracks the module's device
            # and dtype, and orders[1:] is exactly the n in U_{n-1}.
            n = self.orders[1:]
            one_minus_x2 = (1.0 - x**2).unsqueeze(-1)

            # Regular expression for U_{n-1}; the clamp only guards the
            # division, affected entries are overwritten just below.
            interior = torch.sin(n * theta) / torch.sqrt(
                one_minus_x2.clamp(min=self._EDGE_EPS)
            )

            # Limit at the endpoints: U_{n-1}(1) = n and
            # U_{n-1}(-1) = (-1)^(n-1) * n.
            parity = 1.0 - 2.0 * torch.remainder(n - 1.0, 2.0)
            edge = n * torch.where(
                x.unsqueeze(-1) < 0, parity, torch.ones_like(parity)
            )

            U_prev = torch.where(
                one_minus_x2 <= self._EDGE_EPS,
                edge.to(interior.dtype),
                interior,
            )
            dT_dx[..., 1:] = n * U_prev

        # Chain rule: dT/dr = dT/dx * dx/dr
        dx_dr = 2.0 / (self.r_max - self.r_min)
        dT_dr = dT_dx * dx_dr

        return T, dT_dr
class RadialBasis(nn.Module):
    """
    Radial basis: Chebyshev polynomials damped by a cosine cutoff.

    Implements: G_rad = T_n(r) * fc(r)

    Parameters
    ----------
    rad_order : int
        Maximum order for radial Chebyshev polynomials
    rad_cutoff : float
        Radial cutoff radius in Angstroms
    min_cutoff : float, optional
        Minimum cutoff (inner radius) in Angstroms (default: 0.55).
        Accepted for interface compatibility; this class neither stores
        nor uses it.
    dtype : torch.dtype, optional
        Data type for computations (default: torch.float64)

    Examples
    --------
    >>> rad_basis = RadialBasis(rad_order=10, rad_cutoff=4.0)
    >>> distances = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
    >>> G_rad = rad_basis(distances)  # Shape: (3, 11) for orders 0-10
    """

    def __init__(
        self,
        rad_order: int,
        rad_cutoff: float,
        min_cutoff: float = 0.55,
        dtype: torch.dtype = torch.float64,
    ):
        super().__init__()

        self.rad_order = rad_order
        self.rad_cutoff = rad_cutoff

        # The polynomial domain is [0, Rc], not [min_cutoff, Rc]: the
        # inner radius is deliberately pinned to 0.0 here. Per the
        # original author's note, min_cutoff only matters for neighbor
        # filtering done elsewhere.
        self.cheb = ChebyshevPolynomials(
            max_order=rad_order,
            r_min=0.0,
            r_max=rad_cutoff,
            dtype=dtype,
        )

    def forward(self, distances: torch.Tensor) -> torch.Tensor:
        """
        Evaluate radial symmetry functions.

        Parameters
        ----------
        distances : torch.Tensor
            Pairwise distances in Angstroms, shape (num_pairs,)

        Returns
        -------
        torch.Tensor
            Radial features, shape (num_pairs, rad_order+1)
        """
        # Trailing axis added so the cutoff broadcasts over all orders.
        envelope = self.cheb.cutoff_function(
            distances, self.rad_cutoff
        ).unsqueeze(-1)

        # G_rad = T * fc
        return self.cheb(distances) * envelope

    def forward_with_derivatives(
        self, distances: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Evaluate radial symmetry functions with derivatives.

        Applies the product rule:
            d(T*fc)/dr = dT/dr * fc + T * dfc/dr

        Parameters
        ----------
        distances : torch.Tensor
            Pairwise distances in Angstroms, shape (num_pairs,)

        Returns
        -------
        G_rad : torch.Tensor
            Radial features, shape (num_pairs, rad_order+1)
        dG_rad_dr : torch.Tensor
            Derivatives w.r.t. distance, shape (num_pairs, rad_order+1)
        """
        cheb = self.cheb
        Rc = self.rad_cutoff

        # Polynomial part and its radial derivative.
        poly, dpoly_dr = cheb.evaluate_with_derivatives(distances)

        # Cutoff part and its radial derivative, shaped to broadcast
        # across the order axis.
        envelope = cheb.cutoff_function(distances, Rc).unsqueeze(-1)
        denvelope_dr = cheb.cutoff_derivative(distances, Rc).unsqueeze(-1)

        features = poly * envelope
        dfeatures_dr = dpoly_dr * envelope + poly * denvelope_dr

        return features, dfeatures_dr
class AngularBasis(nn.Module):
    """
    Angular basis functions built from Chebyshev polynomials.

    For a triplet of atoms (i, j, k), computes:
        G_ang = T_n(cos theta_ijk) * fc(r_ij) * fc(r_ik)
    where theta_ijk is the angle at atom i.

    Parameters
    ----------
    ang_order : int
        Maximum order for angular Chebyshev polynomials
    ang_cutoff : float
        Angular cutoff radius in Angstroms
    min_cutoff : float, optional
        Minimum cutoff (default: 0.55). Kept for signature consistency;
        this class neither stores nor uses it.
    dtype : torch.dtype, optional
        Data type for computations (default: torch.float64)

    Notes
    -----
    cos(theta) already lies in [-1, 1], so the polynomial evaluator is
    built with r_min=-1.0 and r_max=1.0, which makes its rescaling the
    identity map.
    """

    def __init__(
        self,
        ang_order: int,
        ang_cutoff: float,
        min_cutoff: float = 0.55,
        dtype: torch.dtype = torch.float64,
    ):
        super().__init__()

        self.ang_order = ang_order
        self.ang_cutoff = ang_cutoff

        # Domain [-1, 1] matches the range of cos(theta) directly.
        self.cheb = ChebyshevPolynomials(
            max_order=ang_order, r_min=-1.0, r_max=1.0, dtype=dtype
        )

    def forward(
        self, r_ij: torch.Tensor, r_ik: torch.Tensor, cos_theta: torch.Tensor
    ) -> torch.Tensor:
        """
        Evaluate angular symmetry functions.

        Parameters
        ----------
        r_ij : torch.Tensor
            Distances from atom i to atom j, shape (num_triplets,)
        r_ik : torch.Tensor
            Distances from atom i to atom k, shape (num_triplets,)
        cos_theta : torch.Tensor
            Cosine of angles theta_ijk, shape (num_triplets,)

        Returns
        -------
        torch.Tensor
            Angular features, shape (num_triplets, ang_order+1)

        Notes
        -----
        The polynomial evaluator clamps its rescaled input to [-1, 1],
        so cos_theta values slightly outside that range are tolerated.
        """
        Rc = self.ang_cutoff

        # Product of the two pair cutoffs, one value per triplet.
        envelope = self.cheb.cutoff_function(
            r_ij, Rc
        ) * self.cheb.cutoff_function(r_ik, Rc)

        # G_ang = T(cos theta) * fc(r_ij) * fc(r_ik)
        return self.cheb(cos_theta) * envelope.unsqueeze(-1)

    def forward_with_derivatives(
        self, r_ij: torch.Tensor, r_ik: torch.Tensor, cos_theta: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Evaluate angular features and their partial derivatives.

        Parameters
        ----------
        r_ij : torch.Tensor
            Distances from atom i to atom j, shape (num_triplets,)
        r_ik : torch.Tensor
            Distances from atom i to atom k, shape (num_triplets,)
        cos_theta : torch.Tensor
            Cosine of angles theta_ijk, shape (num_triplets,)

        Returns
        -------
        G_ang : torch.Tensor
            Angular features T * fc_ij * fc_ik
        dG_dcos : torch.Tensor
            Partial derivative w.r.t. cos_theta: (dT/dcos) * fc_ij * fc_ik
        dG_drij : torch.Tensor
            Partial derivative w.r.t. r_ij: T * (dfc/dr_ij) * fc_ik
        dG_drik : torch.Tensor
            Partial derivative w.r.t. r_ik: T * fc_ij * (dfc/dr_ik)

        Notes
        -----
        Each partial treats the other two inputs as independent
        variables; all four outputs carry a trailing order axis of
        length ang_order+1.
        """
        cheb, Rc = self.cheb, self.ang_cutoff

        # Polynomials and their derivative w.r.t. cos_theta.
        poly, dpoly_dcos = cheb.evaluate_with_derivatives(cos_theta)

        # Pair cutoffs and their radial derivatives.
        cut_j = cheb.cutoff_function(r_ij, Rc)
        cut_k = cheb.cutoff_function(r_ik, Rc)
        dcut_j = cheb.cutoff_derivative(r_ij, Rc)
        dcut_k = cheb.cutoff_derivative(r_ik, Rc)

        # Shared factor fc_ij * fc_ik, broadcast over the order axis.
        envelope = (cut_j * cut_k).unsqueeze(-1)

        features = poly * envelope
        d_dcos = dpoly_dcos * envelope
        d_drij = poly * (dcut_j * cut_k).unsqueeze(-1)
        d_drik = poly * (cut_j * dcut_k).unsqueeze(-1)

        return features, d_dcos, d_drij, d_drik