"""
PyTorch-based neighbor list for atomic structures.
Supports:
- Periodic boundary conditions (PBC) with arbitrary cell shapes
- Isolated systems (molecules)
- GPU acceleration
- Double precision
"""
import warnings
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
try:
from torch_cluster import radius, radius_graph
TORCH_CLUSTER_AVAILABLE = True
except ImportError:
TORCH_CLUSTER_AVAILABLE = False
warnings.warn(
"torch_cluster not available. Install with: "
"pip install torch-cluster -f "
"https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html"
)
[docs]
class TorchNeighborList:
"""
PyTorch-based neighbor list for atomic structures.
Supports:
- Periodic boundary conditions (PBC)
- Isolated systems
- GPU acceleration
- Double precision
Example:
>>> nbl = TorchNeighborList(cutoff=4.0, device='cpu')
>>> positions = torch.randn(10, 3, dtype=torch.float64)
>>> result = nbl.get_neighbors(positions)
>>> edge_index = result['edge_index'] # (2, num_edges)
>>> distances = result['distances'] # (num_edges,)
Notes:
- Default PBC backend is 'ghost' which constructs a support set of
ghost images near periodic boundaries and performs a single
bipartite radius search (support -> central).
- Offsets returned under PBC are integer lattice offsets defined
relative to wrapped fractional coordinates (positions wrapped
to [0,1) before offsetting). Downstream code should reconstruct
displacements as:
r_ij = ((frac[j] + offsets) @ cell) - (frac[i] @ cell)
where frac = remainder(positions @ inv(cell), 1.0).
- Fully differentiable and GPU-compatible. Truncation is handled by
auto-growing max_num_neighbors up to auto_max_neighbors with
warnings.
"""
[docs]
def __init__(
self,
cutoff: float,
atom_types: Optional[torch.Tensor] = None,
cutoff_dict: Optional[Dict[Tuple[int, int], float]] = None,
device: str = "cpu",
dtype: torch.dtype = torch.float64,
max_num_neighbors: int = 256,
truncation_handling: str = "auto",
auto_multiplier: int = 2,
auto_max_neighbors: int = 65536,
pbc_backend: str = "ghost",
):
"""
Initialize neighbor list.
Args:
cutoff: Maximum interaction cutoff radius in Angstroms
atom_types: (N,) tensor of atom types (e.g., atomic numbers)
cutoff_dict: Dict mapping (type_i, type_j) tuples to cutoff
distances. Keys should be sorted tuples: (min, max)
device: 'cpu' or 'cuda'
dtype: torch.float32 or torch.float64 (recommended: float64)
max_num_neighbors: Maximum number of neighbors per atom to consider
(default: 256). Increase if you encounter systems with very
dense neighbor environments.
Raises
------
ValueError: If cutoff_dict contains types not in atom_types
ValueError: If cutoff_dict values exceed maximum cutoff
"""
if not TORCH_CLUSTER_AVAILABLE:
raise ImportError(
"torch_cluster is required but not installed. "
"Install with: pip install torch-cluster"
)
self.cutoff = cutoff
self.atom_types = atom_types
self.cutoff_dict = cutoff_dict
self.device = device
self.dtype = dtype
self.max_num_neighbors = max_num_neighbors
self.truncation_handling = truncation_handling
self.auto_multiplier = int(auto_multiplier)
self.auto_max_neighbors = int(auto_max_neighbors)
self.pbc_backend = pbc_backend
# Deprecation notice for legacy backend
if self.pbc_backend == "legacy":
warnings.warn(
"TorchNeighborList: legacy PBC backend is deprecated and will"
"be removed in a future release. Use pbc_backend='ghost' "
"(default).",
DeprecationWarning,
)
# Validate cutoff_dict if both types and dict are provided
if cutoff_dict is not None and atom_types is not None:
self._validate_cutoff_dict(cutoff_dict, atom_types)
# Cache for efficiency
self._cached_result = None
self._cache_key = None
def _radius_with_inclusive(self) -> float:
"""
Return radius slightly larger than cutoff to include pairs at exactly
the cutoff distance (matching legacy <= cutoff behavior).
Use a conservative epsilon to account for accumulated FP error from
replication, transforms, and torch_cluster distance computation.
"""
eps = 1e-6 if self.dtype == torch.float64 else 1e-5
return float(self.cutoff + eps)
def _to_tensor(
self,
array: Union[np.ndarray, torch.Tensor],
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
"""
Convert numpy array or torch tensor to appropriate tensor.
Args:
array: Input array (numpy or torch)
dtype: Target dtype (uses self.dtype if not specified)
Returns
-------
torch.Tensor on self.device with appropriate dtype
"""
if dtype is None:
dtype = self.dtype
if isinstance(array, np.ndarray):
tensor = torch.from_numpy(array)
else:
tensor = array
return tensor.to(self.device).to(dtype)
@classmethod
def from_AtomicStructure(
cls,
structure,
cutoff: float,
frame: int = -1,
device: str = "cpu",
dtype: torch.dtype = torch.float64,
max_num_neighbors: int = 256,
):
"""
Factory method: create neighbor list from AtomicStructure.
Args:
structure: Instance of aenet.geometry.AtomicStructure
cutoff: Maximum interaction cutoff radius in Angstroms
frame: Frame index to use (default: -1 for last frame)
device: 'cpu' or 'cuda'
dtype: torch.float32 or torch.float64 (recommended: float64)
max_num_neighbors: Maximum number of neighbors per atom
Returns
-------
TorchNeighborList instance configured for the structure
Example:
>>> from aenet.geometry import AtomicStructure
>>> from aenet.torch_featurize import TorchNeighborList
>>> structure = AtomicStructure(coords, types, avec=avec)
>>> nbl = TorchNeighborList.from_AtomicStructure(
... structure, cutoff=4.0
... )
"""
return cls(
cutoff=cutoff,
device=device,
dtype=dtype,
max_num_neighbors=max_num_neighbors,
)
def get_neighbors(
self,
positions: torch.Tensor,
cell: Optional[torch.Tensor] = None,
pbc: Optional[torch.Tensor] = None,
fractional: bool = True,
) -> Dict[str, Optional[torch.Tensor]]:
"""
Unified interface for neighbor finding.
Args:
positions: (N, 3) atom positions
- For isolated systems: Always Cartesian coordinates
in Angstroms
- For periodic systems: Fractional [0,1) or Cartesian,
see fractional arg
cell: (3, 3) lattice vectors as rows (None for isolated systems)
pbc: (3,) boolean tensor for PBC in each direction
(default: [True, True, True] if cell is provided)
fractional: For periodic systems only. If True, positions
are fractional coordinates [0, 1). If False, positions are
Cartesian (Angstroms).
Returns
-------
Dictionary containing:
- 'edge_index': (2, num_edges) neighbor pairs [source, target]
- 'distances': (num_edges,) pairwise distances in Angstroms
- 'offsets': (num_edges, 3) cell offsets (None for isolated
systems)
- 'num_neighbors': (N,) number of neighbors per atom
"""
if cell is None:
# Isolated system
edge_index, distances = self.get_neighbors_isolated(positions)
num_neighbors = self._count_neighbors(
edge_index, positions.shape[0]
)
return {
"edge_index": edge_index,
"distances": distances,
"offsets": None,
"num_neighbors": num_neighbors,
}
else:
# Periodic system
edge_index, distances, offsets = self.get_neighbors_pbc(
positions, cell, pbc, fractional
)
num_neighbors = self._count_neighbors(
edge_index, positions.shape[0]
)
return {
"edge_index": edge_index,
"distances": distances,
"offsets": offsets,
"num_neighbors": num_neighbors,
}
def get_neighbors_isolated(
self, positions: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Find neighbors for isolated system (no PBC).
Args:
positions: (N, 3) atom positions in Angstroms (Cartesian)
Returns
-------
edge_index: (2, num_edges) neighbor pairs [source, target]
distances: (num_edges,) pairwise distances in Angstroms
"""
positions = positions.to(self.device).to(self.dtype)
# Handle single atom case
if positions.shape[0] <= 1:
edge_index = torch.empty(
(2, 0), dtype=torch.long, device=self.device
)
distances = torch.empty(0, dtype=self.dtype, device=self.device)
return edge_index, distances
# Use radius_graph from torch_cluster with truncation handling
# Start with a higher baseline in isolated mode to reduce retries
cur_max_nn = int(max(self.max_num_neighbors, 2048))
while True:
edge_index = radius_graph(
positions,
r=self._radius_with_inclusive(),
max_num_neighbors=cur_max_nn,
flow="source_to_target",
loop=False, # Don't include self-loops
)
# Compute distances
if edge_index.shape[1] > 0:
row, col = edge_index
diff = positions[row] - positions[col]
distances = torch.norm(diff, dim=1)
else:
distances = torch.empty(
0, dtype=self.dtype, device=self.device)
# Check truncation: any source node with degree == cur_max_nn
num_neighbors = self._count_neighbors(
edge_index, positions.shape[0])
truncated = (num_neighbors.numel() > 0
and num_neighbors.max().item() >= cur_max_nn)
if truncated:
if (self.truncation_handling == "auto"
and cur_max_nn < self.auto_max_neighbors):
new_max = min(cur_max_nn * self.auto_multiplier,
self.auto_max_neighbors)
warnings.warn(
f"TorchNeighborList: max_num_neighbors={cur_max_nn} "
"hit in isolated mode; "
f"retrying with max_num_neighbors={new_max}.",
RuntimeWarning,
)
cur_max_nn = int(new_max)
continue
elif self.truncation_handling == "error":
raise RuntimeError(
"TorchNeighborList: neighbor list truncated at "
f"max_num_neighbors={cur_max_nn} "
f"(isolated). Increase max_num_neighbors or "
"reduce cutoff."
)
elif self.truncation_handling == "warn":
warnings.warn(
f"TorchNeighborList: neighbor list may be "
"truncated at "
f"max_num_neighbors={cur_max_nn} (isolated).",
RuntimeWarning
)
# Update stored value if auto grew
self.max_num_neighbors = max(self.max_num_neighbors, cur_max_nn)
return edge_index, distances
def get_neighbors_pbc(
self,
positions: torch.Tensor,
cell: torch.Tensor,
pbc: Optional[torch.Tensor] = None,
fractional: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Find neighbors with periodic boundary conditions.
Args:
positions: (N, 3) atom positions
cell: (3, 3) lattice vectors as rows
pbc: (3,) boolean tensor for PBC in each direction
(default: [True, True, True])
fractional: If True, positions are fractional [0,1).
If False, positions are Cartesian (Angstroms).
Returns
-------
edge_index: (2, num_edges) neighbor pairs [source, target]
distances: (num_edges,) pairwise distances in Angstroms
offsets: (num_edges, 3) cell offset vectors for each edge
"""
backend = getattr(self, "pbc_backend", "legacy")
if backend == "ghost":
return self._get_neighbors_pbc_cell_list(
positions, cell, pbc, fractional
)
elif backend == "legacy":
return self._get_neighbors_pbc_old(
positions, cell, pbc, fractional
)
else:
raise ValueError(f"Unknown pbc_backend: {backend}")
def _get_neighbors_pbc_cell_list(
self,
positions: torch.Tensor,
cell: torch.Tensor,
pbc: Optional[torch.Tensor] = None,
fractional: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Optimized PBC neighbor finding using ghost atoms.
Instead of iterating over cells in Python (slow), we replicate atoms
near the boundaries to create "ghost" atoms, then run a single
radius_graph on the extended system.
"""
if pbc is None:
pbc = torch.tensor(
[True, True, True], dtype=torch.bool, device=self.device
)
else:
pbc = pbc.to(self.device)
positions = positions.to(self.device).to(self.dtype)
cell = cell.to(self.device).to(self.dtype)
# Ensure we have fractional coordinates
if not fractional:
try:
cell_inv = torch.linalg.inv(cell)
positions_frac = positions @ cell_inv
except RuntimeError:
positions_frac = torch.zeros_like(positions)
else:
positions_frac = positions
# Wrap fractional coordinates to [0, 1)
positions_frac = torch.remainder(positions_frac, 1.0)
# Calculate perpendicular widths of the cell to determine cutoff
# in fractional coords
# Volume = det(cell)
# Area_i = |a_j x a_k|
# Width_i = Volume / Area_i
volume = torch.abs(torch.det(cell)).clamp_min(1e-12)
cross_0 = torch.norm(torch.linalg.cross(cell[1], cell[2]))
cross_1 = torch.norm(torch.linalg.cross(cell[2], cell[0]))
cross_2 = torch.norm(torch.linalg.cross(cell[0], cell[1]))
widths = volume / torch.stack([cross_0, cross_1, cross_2]
).clamp_min(1e-12)
cutoff_frac = (self.cutoff / widths).to(self.device) * 1.01
# Number of periodic image layers needed along each lattice vector
search_layers = torch.ceil(self.cutoff / widths).to(torch.long)
# Start with original atoms
# We track: (positions, indices, offsets)
# indices: original atom index
# offsets: (3,) integer offset of the image
# Current set of atoms (starts with central cell)
current_pos = positions_frac
current_indices = torch.arange(len(positions), device=self.device)
current_offsets = torch.zeros((len(positions), 3),
dtype=torch.long,
device=self.device)
# Iteratively replicate along each periodic dimension
for i in range(3):
if not pbc[i]:
continue
# Determine how many image layers are required along axis i
s_i = int(search_layers[i].item())
if s_i <= 0:
continue
# For each layer k, replicate a snapshot of the current set
# to avoid re-replicating newly added ghosts within the same k
for k in range(1, s_i + 1):
base_pos_k = current_pos
base_idx_k = current_indices
base_off_k = current_offsets
# Lower boundary: replicate to +k (image +k)
mask_lower = base_pos_k[:, i] < cutoff_frac[i]
if mask_lower.any():
ghosts_pos_lower = base_pos_k[mask_lower].clone()
ghosts_pos_lower[:, i] = ghosts_pos_lower[:, i] + float(k)
ghosts_indices_lower = base_idx_k[mask_lower]
ghosts_offsets_lower = base_off_k[mask_lower].clone()
ghosts_offsets_lower[:, i] = ghosts_offsets_lower[:, i] + k
current_pos = torch.cat([current_pos, ghosts_pos_lower],
dim=0)
current_indices = torch.cat([current_indices,
ghosts_indices_lower], dim=0)
current_offsets = torch.cat([current_offsets,
ghosts_offsets_lower], dim=0)
# Upper boundary: replicate to -k (image -k)
mask_upper = base_pos_k[:, i] > (1.0 - cutoff_frac[i])
if mask_upper.any():
ghosts_pos_upper = base_pos_k[mask_upper].clone()
ghosts_pos_upper[:, i] = ghosts_pos_upper[:, i] - float(k)
ghosts_indices_upper = base_idx_k[mask_upper]
ghosts_offsets_upper = base_off_k[mask_upper].clone()
ghosts_offsets_upper[:, i] = ghosts_offsets_upper[:, i] - k
current_pos = torch.cat([current_pos, ghosts_pos_upper],
dim=0)
current_indices = torch.cat([current_indices,
ghosts_indices_upper], dim=0)
current_offsets = torch.cat([current_offsets,
ghosts_offsets_upper], dim=0)
# Deduplicate ghosts by (original_index, offsets) to avoid duplicates
# that can arise from different replication paths.
if current_pos.shape[0] > 0:
keys = torch.stack(
[
current_indices.to(torch.long),
current_offsets[:, 0],
current_offsets[:, 1],
current_offsets[:, 2],
],
dim=1,
)
# Build first-occurrence mask using a Python set
# (small int tensors)
seen = set()
keep_list = []
for r in keys.cpu().tolist():
t = tuple(r)
if t in seen:
keep_list.append(False)
else:
seen.add(t)
keep_list.append(True)
keep_mask = torch.tensor(keep_list, device=self.device,
dtype=torch.bool)
current_pos = current_pos[keep_mask]
current_indices = current_indices[keep_mask]
current_offsets = current_offsets[keep_mask]
# Convert all fractional positions to Cartesian
all_cart_pos = current_pos @ cell
# Run radius_graph on the full set
# We want neighbors for the original atoms (first N atoms)
# But radius_graph(x) computes all pairs.
# We can filter later, or use bipartite radius(x, y).
# x = all atoms (potential neighbors)
# y = original atoms (central cell)
n_orig = len(positions)
central_cart_pos = all_cart_pos[:n_orig]
# Use radius(x, y) -> neighbors of y in x
# x = support (all), y = query (central)
cur_max_nn = int(max(self.max_num_neighbors, 4096))
while True:
row, col = radius(
x=all_cart_pos,
y=central_cart_pos,
r=self._radius_with_inclusive(),
max_num_neighbors=cur_max_nn,
num_workers=1,
)
# Check truncation on the query (central) nodes
truncated = False
if row.numel() > 0:
deg = torch.bincount(row, minlength=n_orig)
truncated = deg.numel() > 0 and deg.max().item() >= cur_max_nn
if truncated:
if (self.truncation_handling == "auto"
and cur_max_nn < self.auto_max_neighbors):
new_max = min(cur_max_nn * self.auto_multiplier,
self.auto_max_neighbors)
warnings.warn(
f"TorchNeighborList: max_num_neighbors={cur_max_nn} "
"hit under PBC (bipartite radius); "
f"retrying with max_num_neighbors={new_max}.",
RuntimeWarning,
)
cur_max_nn = int(new_max)
continue
elif self.truncation_handling == "error":
raise RuntimeError(
"TorchNeighborList: neighbor list truncated at "
f"max_num_neighbors={cur_max_nn} (PBC bipartite). "
"Increase max_num_neighbors or reduce cutoff."
)
elif self.truncation_handling == "warn":
warnings.warn(
f"TorchNeighborList: neighbor list may be truncated "
f"at max_num_neighbors={cur_max_nn} (PBC bipartite).",
RuntimeWarning
)
break
# Update stored value if auto grew
self.max_num_neighbors = max(self.max_num_neighbors, cur_max_nn)
# row: indices in central_cart_pos (0 to N-1) -> source
# col: indices in all_cart_pos -> target
# Filter out self-loops (where index is same AND offset is zero)
# target indices (col) map to original indices via current_indices[col]
target_indices = current_indices[col]
target_offsets = current_offsets[col]
# Self loop condition: source_idx == target_idx AND target_offset == 0
# row is already source_idx (since we queried with
# central atoms 0..N-1)
mask_self = (row == target_indices) & (torch.all(target_offsets == 0,
dim=1))
mask_valid = ~mask_self
# Apply mask
row = row[mask_valid]
col = col[mask_valid]
target_indices = target_indices[mask_valid]
target_offsets = target_offsets[mask_valid]
# Compute distances
diff = central_cart_pos[row] - all_cart_pos[col]
distances = torch.norm(diff, dim=1)
# Construct edge_index
edge_index = torch.stack([row, target_indices], dim=0)
return edge_index, distances, target_offsets
def _get_neighbors_pbc_old(
self,
positions: torch.Tensor,
cell: torch.Tensor,
pbc: Optional[torch.Tensor] = None,
fractional: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Legacy implementation: Find neighbors with periodic boundary
conditions. (Kept for validation purposes)
"""
if pbc is None:
pbc = torch.tensor(
[True, True, True], dtype=torch.bool, device=self.device
)
else:
pbc = pbc.to(self.device)
positions = positions.to(self.device).to(self.dtype)
cell = cell.to(self.device).to(self.dtype)
# Convert to Cartesian if needed
if fractional:
cart_positions = positions @ cell
else:
# Use original Cartesian positions without wrapping to keep
# offsets consistent with the absolute coordinates returned.
cart_positions = positions
# Determine search range in each direction
search_cells = self._determine_search_cells(cell, pbc)
# Create replicated positions for periodic images
all_positions, all_offsets = self._create_periodic_images(
cart_positions, cell, search_cells, pbc
)
# Find neighbors using radius_graph with truncation handling
# Start with a higher baseline under PBC to avoid premature truncation
cur_max_nn = int(max(self.max_num_neighbors, 4096))
while True:
edge_index_all = radius_graph(
all_positions,
r=self._radius_with_inclusive(),
max_num_neighbors=cur_max_nn,
flow="source_to_target",
loop=False,
)
# Filter and compute distances for central cell sources
edge_index, distances, cell_offsets = self._process_pbc_edges(
edge_index_all, all_positions, all_offsets, positions.shape[0]
)
# Check truncation on central unit cell sources.
# We must detect truncation against the degrees computed on the
# full replicated graph (edge_index_all), not the filtered graph.
# Build per-source degree for all_positions nodes:
degrees_all = torch.zeros(
all_positions.shape[0], dtype=torch.long, device=self.device
)
if edge_index_all.shape[1] > 0:
src_nodes = edge_index_all[0]
unique_src, counts_src = torch.unique(
src_nodes, return_counts=True)
degrees_all[unique_src] = counts_src
# Identify the replicated source nodes that are in the central cell
central_source_mask = torch.all(all_offsets == 0, dim=1)
if torch.any(central_source_mask):
degrees_central = degrees_all[central_source_mask]
truncated = (degrees_central.numel() > 0
and degrees_central.max().item() >= cur_max_nn)
else:
truncated = False
if truncated:
if (self.truncation_handling == "auto"
and cur_max_nn < self.auto_max_neighbors):
new_max = min(cur_max_nn * self.auto_multiplier,
self.auto_max_neighbors)
warnings.warn(
f"TorchNeighborList: max_num_neighbors={cur_max_nn} "
"hit under PBC; "
f"retrying with max_num_neighbors={new_max}.",
RuntimeWarning,
)
cur_max_nn = int(new_max)
continue
elif self.truncation_handling == "error":
raise RuntimeError(
"TorchNeighborList: neighbor list truncated at "
f"max_num_neighbors={cur_max_nn} "
f"(PBC). Increase max_num_neighbors or reduce cutoff."
)
elif self.truncation_handling == "warn":
warnings.warn(
f"TorchNeighborList: neighbor list may be "
"truncated at "
f"max_num_neighbors={cur_max_nn} (PBC).",
RuntimeWarning
)
# Update stored value if auto grew
self.max_num_neighbors = max(self.max_num_neighbors, cur_max_nn)
return edge_index, distances, cell_offsets
def _determine_search_cells(
self, cell: torch.Tensor, pbc: torch.Tensor
) -> torch.Tensor:
"""
Determine how many periodic images to check in each direction.
Args:
cell: (3, 3) lattice vectors as rows
pbc: (3,) boolean tensor for PBC
Returns
-------
search_cells: (3,) integer number of cells to search in each
direction
"""
# Distances between opposite faces of the unit cell:
# d_i = Volume / Area(face_i), where face_i is
# opposite lattice vector i.
# This works for arbitrary (skewed) cells.
volume = torch.abs(torch.det(cell)).clamp_min(1e-12)
a = cell[0]
b = cell[1]
c = cell[2]
area0 = torch.norm(torch.linalg.cross(b, c)) # face normal to a
area1 = torch.norm(torch.linalg.cross(c, a)) # face normal to b
area2 = torch.norm(torch.linalg.cross(a, b)) # face normal to c
areas = torch.stack([area0, area1, area2]).clamp_min(1e-12)
face_distances = volume / areas # perpendicular distance between faces
# Number of cells needed to cover the cutoff along each direction
# Add +1 safety margin to ensure full coverage (corner/cross effects)
search_cells = torch.ceil(self.cutoff / face_distances
).to(torch.long) + 1
# Zero out non-periodic directions
search_cells = torch.where(
pbc, search_cells, torch.zeros_like(search_cells)
)
return search_cells
def _create_periodic_images(
self,
positions: torch.Tensor,
cell: torch.Tensor,
search_cells: torch.Tensor,
pbc: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Create periodic images of atoms.
Args:
positions: (N, 3) Cartesian positions
cell: (3, 3) lattice vectors
search_cells: (3,) number of cells to search in each direction
pbc: (3,) boolean for PBC
Returns
-------
all_positions: (num_images * N, 3) Cartesian positions including
periodic images
offsets: (num_images * N, 3) cell offset for each position
"""
n_atoms = positions.shape[0]
# Generate offset vectors
ranges = []
for s, p in zip(search_cells, pbc):
if p:
ranges.append(
torch.arange(
-s, s + 1, device=self.device, dtype=torch.long
)
)
else:
ranges.append(
torch.tensor([0], device=self.device, dtype=torch.long)
)
# Create meshgrid of offsets
offset_grid = torch.stack(
torch.meshgrid(*ranges, indexing="ij"), dim=-1
).reshape(-1, 3)
# Replicate positions for each offset
replicated_positions = positions.unsqueeze(0).expand(
offset_grid.shape[0], -1, -1
)
# Compute offset vectors in Cartesian coordinates
offset_vectors = (
offset_grid.to(self.dtype).unsqueeze(1) @ cell
).expand(-1, n_atoms, -1)
# Apply offsets
all_positions = (replicated_positions + offset_vectors).reshape(-1, 3)
all_offsets = (
offset_grid.unsqueeze(1).expand(-1, n_atoms, -1).reshape(-1, 3)
)
return all_positions, all_offsets
def _process_pbc_edges(
self,
edge_index: torch.Tensor,
all_positions: torch.Tensor,
all_offsets: torch.Tensor,
n_atoms: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Process edges from periodic images and compute distances.
Args:
edge_index: (2, num_edges) edge indices in replicated system
all_positions: (num_images * N, 3) all positions including images
all_offsets: (num_images * N, 3) cell offsets for each position
n_atoms: number of atoms in unit cell
Returns
-------
edge_index: (2, num_edges) filtered edge indices (in unit cell)
distances: (num_edges,) distances
cell_offsets: (num_edges, 3) cell offsets for each edge
"""
if edge_index.shape[1] == 0:
# No edges found
distances = torch.empty(0, dtype=self.dtype, device=self.device)
cell_offsets = torch.empty(
(0, 3), dtype=torch.long, device=self.device
)
return edge_index, distances, cell_offsets
row, col = edge_index
# Map replicated indices back to unit cell indices
unit_cell_row = row % n_atoms
unit_cell_col = col % n_atoms
# Get cell offsets
offset_row = all_offsets[row]
offset_col = all_offsets[col]
cell_offsets = offset_col - offset_row
# Compute distances
diff = all_positions[row] - all_positions[col]
distances = torch.norm(diff, dim=1)
# Filter: only keep edges where source is in central unit cell
# (to avoid double counting)
central_cell_mask = torch.all(offset_row == 0, dim=1)
# Also remove self-interactions in the central cell
self_interaction_mask = (unit_cell_row == unit_cell_col) & (
torch.all(cell_offsets == 0, dim=1)
)
valid_mask = central_cell_mask & (~self_interaction_mask)
# Apply filters
edge_index = torch.stack(
[unit_cell_row[valid_mask], unit_cell_col[valid_mask]]
)
distances = distances[valid_mask]
cell_offsets = cell_offsets[valid_mask]
return edge_index, distances, cell_offsets
def _count_neighbors(
self, edge_index: torch.Tensor, n_atoms: int
) -> torch.Tensor:
"""
Count number of neighbors for each atom.
Args:
edge_index: (2, num_edges) edge indices
n_atoms: total number of atoms
Returns
-------
num_neighbors: (n_atoms,) number of neighbors per atom
"""
if edge_index.shape[1] == 0:
return torch.zeros(n_atoms, dtype=torch.long, device=self.device)
num_neighbors = torch.zeros(
n_atoms, dtype=torch.long, device=self.device
)
unique, counts = torch.unique(edge_index[0], return_counts=True)
num_neighbors[unique] = counts
return num_neighbors
def _validate_cutoff_dict(
self,
cutoff_dict: Dict[Tuple[int, int], float],
atom_types: torch.Tensor,
) -> None:
"""
Validate that cutoff_dict is consistent with atom_types.
Args:
cutoff_dict: Dictionary of pair cutoffs
atom_types: Tensor of atom types
Raises
------
ValueError: If cutoff_dict keys contain undefined types
ValueError: If any cutoff exceeds self.cutoff
"""
unique_types = set(atom_types.cpu().numpy().tolist())
for (type_i, type_j), cutoff_val in cutoff_dict.items():
# Check types are defined
if type_i not in unique_types:
raise ValueError(
f"Type {type_i} in cutoff_dict not found in "
f"atom_types. Available types: {sorted(unique_types)}"
)
if type_j not in unique_types:
raise ValueError(
f"Type {type_j} in cutoff_dict not found in "
f"atom_types. Available types: {sorted(unique_types)}"
)
# Check cutoff doesn't exceed maximum
if cutoff_val > self.cutoff:
raise ValueError(
f"Cutoff {cutoff_val} for pair ({type_i}, {type_j}) "
f"exceeds maximum cutoff {self.cutoff}"
)
def get_neighbors_of_atom(
self,
atom_idx: int,
positions: Union[np.ndarray, torch.Tensor],
cell: Optional[Union[np.ndarray, torch.Tensor]] = None,
pbc: Optional[torch.Tensor] = None,
fractional: bool = True,
atom_types: Optional[torch.Tensor] = None,
cutoff_dict: Optional[Dict[Tuple[int, int], float]] = None,
return_coordinates: bool = False,
full_star: bool = False,
) -> Dict[str, Optional[torch.Tensor]]:
"""
Get neighbors of a specific atom.
Args:
atom_idx: Index of the atom to query
positions: (N, 3) atom positions (numpy array or torch tensor)
cell: (3, 3) lattice vectors (None for isolated systems,
numpy array or torch tensor)
pbc: (3,) PBC flags
fractional: For periodic systems only. If True, positions
are fractional coordinates [0, 1). If False, positions are
Cartesian (Angstroms).
Default: True for backward compatibility.
atom_types: Override stored atom_types (optional)
cutoff_dict: Override stored cutoff_dict (optional)
return_coordinates: If True, also return actual neighbor
coordinates with PBC offsets applied (default: False)
full_star: If True, return all neighbors in both directions
(where atom_idx is source and target). This is useful for
extracting complete neighbor clusters for periodic systems.
Default: False (returns half-star for efficiency).
Returns
-------
Dictionary containing:
- 'indices': (num_neighbors,) neighbor atom indices
- 'distances': (num_neighbors,) distances to neighbors
- 'offsets': (num_neighbors, 3) cell offsets (or None)
- 'coordinates': (num_neighbors, 3) neighbor coordinates
(only if return_coordinates=True)
Note:
If both atom_types and cutoff_dict are provided (either
stored or as arguments), neighbors are filtered by
type-specific cutoffs.
"""
# Convert inputs to tensors
positions = self._to_tensor(positions)
if cell is not None:
cell = self._to_tensor(cell)
# Use stored values or overrides
types = atom_types if atom_types is not None else self.atom_types
cutoffs = cutoff_dict if cutoff_dict is not None else self.cutoff_dict
# Validate if both are provided
if cutoffs is not None and types is not None:
self._validate_cutoff_dict(cutoffs, types)
# Get or compute full neighbor list (with caching)
result = self._get_or_compute_neighbors(
positions, cell, pbc, fractional)
# Extract neighbors of specific atom
edge_index = result["edge_index"]
if full_star:
# Include both directions: where atom_idx is source AND target
# Source edges (atom_idx -> neighbors)
mask_source = edge_index[0] == atom_idx
neighbor_indices_source = edge_index[1][mask_source]
distances_source = result["distances"][mask_source]
offsets_source = (
result["offsets"][mask_source]
if result["offsets"] is not None else None
)
# Target edges (neighbors -> atom_idx)
# For these, we need to flip the direction
mask_target = edge_index[1] == atom_idx
neighbor_indices_target = edge_index[0][mask_target]
distances_target = result["distances"][mask_target]
offsets_target = (
-result["offsets"][mask_target] # Flip offset direction
if result["offsets"] is not None else None
)
# Combine both sets of neighbors
neighbor_indices = torch.cat(
[neighbor_indices_source, neighbor_indices_target])
distances = torch.cat([distances_source, distances_target])
if offsets_source is not None and offsets_target is not None:
offsets = torch.cat([offsets_source, offsets_target])
# Remove duplicates based on (neighbor_index, offset) pairs
# Create unique keys by combining index and offset
unique_mask = []
seen = set()
for i in range(len(neighbor_indices)):
idx = neighbor_indices[i].item()
off = tuple(offsets[i].cpu().numpy())
key = (idx, off)
if key not in seen:
seen.add(key)
unique_mask.append(True)
else:
unique_mask.append(False)
unique_mask = torch.tensor(unique_mask, device=self.device)
neighbor_indices = neighbor_indices[unique_mask]
distances = distances[unique_mask]
offsets = offsets[unique_mask]
else:
offsets = None
else:
# Half-star: only edges where atom_idx is the source
mask = edge_index[0] == atom_idx
neighbor_indices = edge_index[1][mask]
distances = result["distances"][mask]
offsets = (
result["offsets"][mask]
if result["offsets"] is not None else None
)
# Apply type-specific cutoff filtering if applicable
if types is not None and cutoffs is not None:
filter_mask = self._filter_by_type_cutoff(
atom_idx, neighbor_indices, distances, types, cutoffs
)
neighbor_indices = neighbor_indices[filter_mask]
distances = distances[filter_mask]
if offsets is not None:
offsets = offsets[filter_mask]
# Compute coordinates if requested
coordinates = None
if return_coordinates:
coordinates = positions[neighbor_indices]
if offsets is not None and cell is not None:
# Apply PBC offsets: convert to float and apply cell matrix
coordinates = coordinates + (offsets.to(self.dtype) @ cell)
result_dict = {
"indices": neighbor_indices,
"distances": distances,
"offsets": offsets,
}
if return_coordinates:
result_dict["coordinates"] = coordinates
return result_dict
def get_neighbors_by_atom(
self,
positions: torch.Tensor,
cell: Optional[torch.Tensor] = None,
pbc: Optional[torch.Tensor] = None,
atom_types: Optional[torch.Tensor] = None,
cutoff_dict: Optional[Dict[Tuple[int, int], float]] = None,
) -> List[Dict[str, Optional[torch.Tensor]]]:
"""
Get neighbors for all atoms in structured format.
Args:
positions: (N, 3) atom positions
cell: (3, 3) lattice vectors (None for isolated systems)
pbc: (3,) PBC flags
atom_types: Override stored atom_types (optional)
cutoff_dict: Override stored cutoff_dict (optional)
Returns
-------
List of length N_atoms, where each element is a dict:
- 'indices': neighbor indices
- 'distances': distances
- 'offsets': cell offsets
"""
n_atoms = positions.shape[0]
result = []
for i in range(n_atoms):
neighbors = self.get_neighbors_of_atom(
i, positions, cell, pbc, atom_types, cutoff_dict
)
result.append(neighbors)
return result
def _filter_by_type_cutoff(
self,
source_idx: int,
neighbor_indices: torch.Tensor,
distances: torch.Tensor,
atom_types: torch.Tensor,
cutoff_dict: Dict[Tuple[int, int], float],
) -> torch.Tensor:
"""
Filter neighbors based on type-dependent cutoffs.
Args:
source_idx: Index of source atom
neighbor_indices: Indices of neighbor atoms
distances: Distances to neighbors
atom_types: Atom types tensor
cutoff_dict: Dictionary of pair cutoffs
Returns
-------
Boolean mask for neighbors within type-specific cutoffs
"""
source_type = atom_types[source_idx].item()
mask = torch.zeros(
len(neighbor_indices), dtype=torch.bool, device=self.device
)
for i, (neigh_idx, dist) in enumerate(
zip(neighbor_indices, distances)
):
neigh_type = atom_types[neigh_idx].item()
# Use sorted tuple as key
pair = tuple(sorted([source_type, neigh_type]))
pair_cutoff = cutoff_dict.get(pair, self.cutoff)
mask[i] = dist <= pair_cutoff
return mask
def _get_or_compute_neighbors(
self,
positions: torch.Tensor,
cell: Optional[torch.Tensor],
pbc: Optional[torch.Tensor],
fractional: bool = True,
) -> Dict[str, Optional[torch.Tensor]]:
"""
Get cached neighbor list or compute new one.
Args:
positions: Atom positions
cell: Lattice vectors
pbc: PBC flags
fractional: Whether positions are fractional or Cartesian
Returns
-------
Neighbor list result dictionary
"""
# Compute cache key
cache_key = self._compute_cache_key(positions, cell, pbc, fractional)
# Return cached result if available
if cache_key == self._cache_key and self._cached_result is not None:
return self._cached_result
# Compute new result
result = self.get_neighbors(positions, cell, pbc, fractional)
# Update cache
self._cache_key = cache_key
self._cached_result = result
return result
def _compute_cache_key(
self,
positions: torch.Tensor,
cell: Optional[torch.Tensor],
pbc: Optional[torch.Tensor],
fractional: bool = True,
) -> Tuple:
"""
Compute cache key for positions/cell/pbc combination.
Args:
positions: Atom positions
cell: Lattice vectors
pbc: PBC flags
fractional: Whether positions are fractional or Cartesian
Returns
-------
Tuple that can be used as cache key
"""
# Use hash of tensor data
pos_hash = hash(positions.cpu().numpy().tobytes())
cell_hash = (
hash(cell.cpu().numpy().tobytes()) if cell is not None else None
)
pbc_hash = (
hash(pbc.cpu().numpy().tobytes()) if pbc is not None else None
)
return (pos_hash, cell_hash, pbc_hash, fractional)
def __repr__(self) -> str:
return (
f"TorchNeighborList(cutoff={self.cutoff}, "
f"device='{self.device}', dtype={self.dtype}, "
f"pbc_backend='{getattr(self, 'pbc_backend', 'ghost')}')"
)