Source code for aenet.torch_training.inference.predictor

"""
Prediction and inference for trained PyTorch models.

Handles energy and force prediction on new structures.
"""

import time
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch.utils.data import DataLoader

from ..config import Structure
from ..loss import compute_force_loss
from ..training.normalization import NormalizationManager
from .datasets import (
    EnergyInferenceDataset,
    ForceInferenceDataset,
    dataset_energy_collate,
    energy_collate,
    force_collate,
)


class Predictor:
    """
    Handles prediction for trained models.

    Parameters
    ----------
    model : nn.Module
        Trained model (EnergyModelAdapter). Called as
        ``model(features, species_indices)`` and expected to return per-atom
        energies in normalized units.
    descriptor : ChebyshevDescriptor
        Descriptor for featurization. Must expose ``species``,
        ``species_to_idx`` and ``featurize_with_neighbor_info``.
    normalizer : NormalizationManager
        Normalization manager with trained statistics.
    atomic_energies : dict, optional
        Atomic reference energies per species. These are added back to
        model predictions to get total energies. Species without an entry
        (or all species, if not provided) use 0.0.
    device : torch.device
        Device for computation.
    dtype : torch.dtype
        Data type for tensors.
    """

    # Keys of the timing dictionary returned by the predict methods.
    _TIMING_KEYS = ("featurization", "energy_eval", "force_eval", "total")

    def __init__(
        self,
        model,
        descriptor,
        normalizer: "NormalizationManager",
        atomic_energies: Optional[Dict[str, float]] = None,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float64,
    ):
        self.model = model
        self.descriptor = descriptor
        self.normalizer = normalizer
        self.atomic_energies = atomic_energies
        self.device = device
        self.dtype = dtype

        # Reference energy per species index (descriptor order), so that a
        # tensor of species indices can be mapped to reference energies by
        # plain indexing. Missing species default to 0.0.
        ref = atomic_energies if atomic_energies is not None else {}
        self.atomic_energies_by_index = torch.tensor(
            [float(ref.get(s, 0.0)) for s in descriptor.species],
            dtype=dtype,
            device=device,
        )

    # ------------------------------------------------------------------
    # Small private helpers
    # ------------------------------------------------------------------

    @classmethod
    def _new_timing_dict(cls) -> Dict[str, List[float]]:
        """Return an empty timing dictionary with the standard keys."""
        return {key: [] for key in cls._TIMING_KEYS}

    @classmethod
    def _record_batch_timing(
        cls,
        timing: Dict[str, List[float]],
        n_structures: int,
        total: float,
        energy: float,
        force: float,
    ) -> None:
        """
        Append per-structure timings for one DataLoader batch.

        Batch wall times are split evenly over the structures of the batch
        so that every key is on the same per-structure scale. Featurization
        happens inside the DataLoader and is therefore reported as 0.0.
        """
        if n_structures <= 0:
            return
        shares = (
            0.0,
            energy / n_structures,
            force / n_structures,
            total / n_structures,
        )
        for key, share in zip(cls._TIMING_KEYS, shares):
            timing[key].extend([share] * n_structures)

    @staticmethod
    def _loader_kwargs(
        num_workers: Optional[int],
        prefetch_factor: Optional[int],
        persistent_workers: Optional[bool],
    ) -> Dict[str, Any]:
        """
        Build DataLoader worker options (defaults: 0 workers, prefetch 2,
        persistent workers on).
        """
        nw = int(num_workers) if num_workers is not None else 0
        pf = int(prefetch_factor) if prefetch_factor is not None else 2
        pw = (bool(persistent_workers)
              if persistent_workers is not None else True)
        kwargs: Dict[str, Any] = {"num_workers": nw}
        # DataLoader only accepts these two options when workers are used.
        if nw > 0:
            kwargs.update(prefetch_factor=pf, persistent_workers=pw)
        return kwargs

    def _feature_stats(self) -> Tuple[Any, Any]:
        """Feature mean/std for the force path (None if not normalizing)."""
        if self.normalizer.normalize_features:
            return self.normalizer.feature_mean, self.normalizer.feature_std
        return None, None

    def _batch_energies(
        self, batch: Dict[str, Any], return_atom_energies: bool
    ) -> Tuple[List[float], List[torch.Tensor], List[int], float]:
        """
        Predict total energies for every structure of a collated batch.

        ``batch`` must contain the concatenated ``features`` and
        ``species_indices`` of all atoms plus the per-structure ``n_atoms``.

        Returns
        -------
        energies : list of float
            Denormalized model energy plus the sum of atomic reference
            energies, one per structure.
        atom_energies : list of Tensor
            Per-atom model energies multiplied by ``E_scaling`` (CPU,
            detached), one per structure; empty unless requested.
        offsets : list of int
            Structure ``j`` owns atoms ``offsets[j]:offsets[j + 1]`` of the
            concatenated arrays.
        energy_time : float
            Wall time (s) of the model forward pass and per-structure sum.
        """
        feats = batch["features"].to(device=self.device, dtype=self.dtype)
        species_idx = batch["species_indices"].to(device=self.device)
        n_atoms = batch["n_atoms"].to(device=self.device)
        feats = self.normalizer.apply_feature_normalization(feats)

        t_start = time.time()
        E_atomic = self.model(feats, species_idx)
        # Map every atom to the index of the structure it belongs to.
        batch_idx = torch.repeat_interleave(
            torch.arange(len(n_atoms), device=self.device), n_atoms.long()
        )
        energy_norm = torch.zeros(
            len(n_atoms), dtype=feats.dtype, device=self.device
        )
        # reshape(-1) rather than squeeze(): squeeze() turns a batch holding
        # a single atom into a 0-d tensor, which scatter_add_ rejects.
        energy_norm.scatter_add_(0, batch_idx, E_atomic.reshape(-1))
        energy_time = time.time() - t_start

        counts = [int(n) for n in n_atoms.tolist()]
        offsets = [0]
        for n in counts:
            offsets.append(offsets[-1] + n)

        energies: List[float] = []
        atom_energies: List[torch.Tensor] = []
        for j, n in enumerate(counts):
            sl = slice(offsets[j], offsets[j + 1])
            E_pred = self.normalizer.denormalize_energy(energy_norm[j], n)
            E_ref = float(
                self.atomic_energies_by_index[species_idx[sl]]
                .sum()
                .detach()
                .cpu()
            )
            # float() honours the documented return type and drops any
            # autograd graph attached to the prediction.
            energies.append(float(E_pred + E_ref))
            if return_atom_energies:
                atom_energies.append(
                    (E_atomic[sl] * self.normalizer.E_scaling).detach().cpu()
                )
        return energies, atom_energies, offsets, energy_time

    def _batch_forces(
        self, batch: Dict[str, Any]
    ) -> Tuple[torch.Tensor, float]:
        """
        Predict forces for all atoms of a collated force batch.

        Uses the precomputed neighbor graph / triplets of the batch (moved
        to the predictor's device). Returns the concatenated forces and the
        wall time (s) spent in ``compute_force_loss``.
        """
        positions = batch["positions_f"].to(
            device=self.device, dtype=self.dtype
        )
        species_idx = batch["species_indices_f"].to(self.device)
        graph = batch["graph_f"]
        triplets = batch["triplets_f"]

        graph_dev = None
        if graph is not None:
            graph_dev = {
                "center_ptr": graph["center_ptr"].to(self.device),
                "nbr_idx": graph["nbr_idx"].to(self.device),
                "r_ij": graph["r_ij"].to(self.device, dtype=self.dtype),
                "d_ij": graph["d_ij"].to(self.device, dtype=self.dtype),
            }
        triplets_dev = None
        if triplets is not None:
            triplets_dev = {
                key: triplets[key].to(self.device)
                for key in (
                    "tri_i", "tri_j", "tri_k", "tri_j_local", "tri_k_local"
                )
            }

        mean, std = self._feature_stats()
        t_start = time.time()
        # Forces need dE/dR even if the caller runs under torch.no_grad().
        with torch.enable_grad():
            _, forces = compute_force_loss(
                positions=positions.clone(),
                species=batch["species_f"],
                # Dummy reference; only the predicted forces are used.
                forces_ref=torch.zeros_like(positions),
                descriptor=self.descriptor,
                network=self.model,
                species_indices=species_idx,
                # NOTE(review): no cell/pbc here; periodicity is presumably
                # already encoded in the precomputed graph - TODO confirm.
                cell=None,
                pbc=None,
                E_scaling=float(self.normalizer.E_scaling),
                neighbor_info=None,
                chunk_size=None,
                feature_mean=mean,
                feature_std=std,
                graph=graph_dev,
                triplets=triplets_dev,
            )
        return forces, time.time() - t_start

    def _predict_one(
        self, st, eval_forces: bool, return_atom_energies: bool
    ) -> Tuple[
        float,
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Tuple[float, float, float, float],
    ]:
        """
        Featurize and predict a single structure (non-DataLoader path).

        Returns ``(energy, forces, atom_energies, timings)`` where forces and
        atom energies are None unless requested, and ``timings`` holds the
        (featurization, energy_eval, force_eval, total) wall times in s.
        """
        t_start = time.time()
        # Positions/features are computed in float64 when self.dtype is
        # float64 and in float32 otherwise.
        work_dtype = (
            torch.float64 if self.dtype == torch.float64 else torch.float32
        )
        positions = (
            torch.from_numpy(st.positions).to(self.device).to(work_dtype)
        )

        # Periodic boundary information is optional on the structure.
        cell = None
        pbc = None
        if getattr(st, "cell", None) is not None:
            cell = torch.as_tensor(
                st.cell, dtype=self.dtype, device=self.device
            )
        if getattr(st, "pbc", None) is not None:
            pbc = torch.as_tensor(
                st.pbc, dtype=torch.bool, device=self.device
            )

        t_feat_start = time.time()
        features, nb_info = self.descriptor.featurize_with_neighbor_info(
            positions, st.species, cell, pbc
        )
        species_indices = torch.tensor(
            [self.descriptor.species_to_idx[s] for s in st.species],
            dtype=torch.long,
            device=self.device,
        )
        features = self.normalizer.apply_feature_normalization(
            features.to(work_dtype)
        )
        t_feat = time.time() - t_feat_start

        t_energy_start = time.time()
        E_atomic = self.model(features, species_indices)
        E_pred = self.normalizer.denormalize_energy(
            E_atomic.sum(), len(st.species)
        )
        t_energy = time.time() - t_energy_start

        E_ref = float(
            self.atomic_energies_by_index[species_indices]
            .sum()
            .detach()
            .cpu()
        )
        energy = float(E_pred + E_ref)

        atom_energies = None
        if return_atom_energies:
            atom_energies = (
                (E_atomic * self.normalizer.E_scaling).detach().cpu()
            )

        forces = None
        t_force = 0.0
        if eval_forces:
            t_force_start = time.time()
            mean, std = self._feature_stats()
            # Semi-analytical gradient path; forces need autograd dE/dR.
            with torch.enable_grad():
                _, forces_pred = compute_force_loss(
                    positions=positions.clone(),
                    species=st.species,
                    # Dummy reference; only the predicted forces are used.
                    forces_ref=torch.zeros_like(positions),
                    descriptor=self.descriptor,
                    network=self.model,
                    species_indices=species_indices,
                    cell=cell,
                    pbc=pbc,
                    E_scaling=float(self.normalizer.E_scaling),
                    neighbor_info={
                        "neighbor_lists": nb_info["neighbor_lists"],
                        "neighbor_vectors": nb_info["neighbor_vectors"],
                    },
                    chunk_size=None,
                    feature_mean=mean,
                    feature_std=std,
                )
            forces = forces_pred.detach().cpu()
            t_force = time.time() - t_force_start

        t_total = time.time() - t_start
        return energy, forces, atom_energies, (
            t_feat, t_energy, t_force, t_total
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def predict(
        self,
        structures: List["Structure"],
        eval_forces: bool = False,
        return_atom_energies: bool = False,
        track_timing: bool = False,
        batch_size: Optional[int] = None,
        num_workers: Optional[int] = None,
        prefetch_factor: Optional[int] = None,
        persistent_workers: Optional[bool] = None,
    ) -> Tuple[
        List[float],
        Optional[List[torch.Tensor]],
        Optional[List[torch.Tensor]],
        Optional[Dict[str, List[float]]],
    ]:
        """
        Predict energies (and optionally forces) for structures.

        Parameters
        ----------
        structures : list of Structure
            Structures to predict on.
        eval_forces : bool
            Whether to also predict forces.
        return_atom_energies : bool
            Whether to return per-atom energies.
        track_timing : bool
            Whether to track and return timing information.
        batch_size : int, optional
            If given, structures are featurized by a DataLoader
            (``EnergyInferenceDataset`` / ``ForceInferenceDataset``) and
            evaluated in batches of this size. Otherwise each structure is
            featurized and evaluated individually.
        num_workers : int, optional
            DataLoader worker processes (default 0). Batched mode only.
        prefetch_factor : int, optional
            Batches prefetched per worker (default 2). Only used with
            ``num_workers > 0``.
        persistent_workers : bool, optional
            Keep workers alive between iterations (default True). Only used
            with ``num_workers > 0``.

        Returns
        -------
        energies : list of float
            Predicted total energies: denormalized model energy plus the sum
            of the atomic reference energies (0.0 for species without one).
        forces : list of Tensor or None
            Predicted forces (N_i, 3) per structure if requested, otherwise
            None.
        atom_energies : list of Tensor or None
            Per-atom model energies multiplied by ``E_scaling`` per
            structure if requested, otherwise None.
        timing : dict or None
            Timing information per structure if requested, otherwise None.
            Keys: 'featurization', 'energy_eval', 'force_eval', 'total'.
            In batched mode featurization happens in the DataLoader and is
            reported as 0.0, and batch wall times are split evenly across
            the structures of each batch.
        """
        energies: List[float] = []
        forces_out: List[torch.Tensor] = []
        atom_energies_out: List[torch.Tensor] = []
        timing_dict: Optional[Dict[str, List[float]]] = (
            self._new_timing_dict() if track_timing else None
        )

        if batch_size is not None:
            loader_kwargs = self._loader_kwargs(
                num_workers, prefetch_factor, persistent_workers
            )
            if eval_forces:
                dataset = ForceInferenceDataset(structures, self.descriptor)
                collate = force_collate
            else:
                dataset = EnergyInferenceDataset(structures, self.descriptor)
                collate = energy_collate
            loader = DataLoader(
                dataset,
                batch_size=int(batch_size),
                shuffle=False,
                collate_fn=collate,
                **loader_kwargs,
            )

            for batch in loader:
                t_batch_start = time.time()
                batch_energies, batch_atom_e, offsets, t_energy = (
                    self._batch_energies(batch, return_atom_energies)
                )
                energies.extend(batch_energies)
                atom_energies_out.extend(batch_atom_e)

                t_force = 0.0
                if eval_forces:
                    forces_pred, t_force = self._batch_forces(batch)
                    # Split the concatenated forces back per structure.
                    forces_out.extend(
                        forces_pred[lo:hi].detach().cpu()
                        for lo, hi in zip(offsets, offsets[1:])
                    )

                if timing_dict is not None:
                    self._record_batch_timing(
                        timing_dict,
                        len(batch_energies),
                        time.time() - t_batch_start,
                        t_energy,
                        t_force,
                    )

            return (
                energies,
                forces_out if eval_forces else None,
                atom_energies_out if return_atom_energies else None,
                timing_dict,
            )

        for st in structures:
            energy, forces, atom_e, timings = self._predict_one(
                st, eval_forces, return_atom_energies
            )
            energies.append(energy)
            if forces is not None:
                forces_out.append(forces)
            if atom_e is not None:
                atom_energies_out.append(atom_e)
            if timing_dict is not None:
                for key, value in zip(self._TIMING_KEYS, timings):
                    timing_dict[key].append(value)

        return (
            energies,
            forces_out if eval_forces else None,
            atom_energies_out if return_atom_energies else None,
            timing_dict,
        )

    def predict_dataset(
        self,
        dataset,
        eval_forces: bool = False,
        return_atom_energies: bool = False,
        track_timing: bool = False,
        batch_size: Optional[int] = None,
        num_workers: Optional[int] = None,
        prefetch_factor: Optional[int] = None,
        persistent_workers: Optional[bool] = None,
    ) -> Tuple[
        List[float],
        Optional[List[torch.Tensor]],
        Optional[List[torch.Tensor]],
        Optional[Dict[str, List[float]]],
        Dict[str, Any],
    ]:
        """
        Predict energies from a dataset that already yields featurized samples.

        This path is intended for PyTorch-only dataset-backed inference
        where reusing cached features is desirable. Batches are collated
        with ``dataset_energy_collate``.

        Parameters
        ----------
        dataset : Dataset
            Dataset of featurized samples.
        eval_forces : bool
            Must be False; force prediction is not supported here.
        return_atom_energies : bool
            Whether to return per-atom energies.
        track_timing : bool
            Whether to track and return (batch-amortized) timing information.
        batch_size : int, optional
            Batch size (default 32).
        num_workers, prefetch_factor, persistent_workers : optional
            DataLoader worker options (defaults 0, 2, True; the latter two
            are only used with ``num_workers > 0``).

        Returns
        -------
        energies, None, atom_energies, timing, metadata
            As in :meth:`predict` (forces are always None), plus a metadata
            dict with per-structure ``coords`` (numpy array, raw value or
            None), ``atom_types`` (list of str) and ``names`` (str or None).

        Raises
        ------
        NotImplementedError
            If ``eval_forces`` is True.
        """
        if eval_forces:
            raise NotImplementedError(
                "predict_dataset() currently supports energy-only inference."
            )

        loader = DataLoader(
            dataset,
            batch_size=int(batch_size) if batch_size is not None else 32,
            shuffle=False,
            collate_fn=dataset_energy_collate,
            **self._loader_kwargs(
                num_workers, prefetch_factor, persistent_workers
            ),
        )

        energies: List[float] = []
        atom_energies_out: List[torch.Tensor] = []
        coords_out: List[Optional[Any]] = []
        atom_types_out: List[List[str]] = []
        names_out: List[Optional[str]] = []
        timing_dict: Optional[Dict[str, List[float]]] = (
            self._new_timing_dict() if track_timing else None
        )

        for batch in loader:
            t_batch_start = time.time()
            batch_energies, batch_atom_e, _, t_energy = (
                self._batch_energies(batch, return_atom_energies)
            )
            energies.extend(batch_energies)
            atom_energies_out.extend(batch_atom_e)

            # Per-structure metadata carried through by the collate fn.
            for species, pos, name in zip(
                batch["species_lists"],
                batch["positions_list"],
                batch["names_list"],
            ):
                atom_types_out.append(list(species))
                names_out.append(str(name) if name is not None else None)
                if isinstance(pos, torch.Tensor):
                    pos = pos.detach().cpu().numpy()
                coords_out.append(pos)

            if timing_dict is not None:
                self._record_batch_timing(
                    timing_dict,
                    len(batch_energies),
                    time.time() - t_batch_start,
                    t_energy,
                    0.0,
                )

        metadata = {
            "coords": coords_out,
            "atom_types": atom_types_out,
            "names": names_out,
        }
        return (
            energies,
            None,
            atom_energies_out if return_atom_energies else None,
            timing_dict,
            metadata,
        )

    def cohesive_energy(
        self,
        structure: "Structure",
        atomic_energies: Optional[Dict[str, float]] = None,
    ) -> float:
        """
        Compute cohesive energy from a structure with total energy.

        Parameters
        ----------
        structure : Structure
            Structure containing total energy and species list.
        atomic_energies : dict, optional
            Per-species atomic reference energies. If None, uses the
            predictor's stored ``atomic_energies``. Species missing from the
            mapping contribute 0.0.

        Returns
        -------
        float
            Cohesive energy (total - sum of atomic reference energies).

        Raises
        ------
        ValueError
            If no atomic energies are available.
        """
        if atomic_energies is None:
            atomic_energies = self.atomic_energies
        if atomic_energies is None:
            raise ValueError(
                "Atomic energies not available. Provide atomic_energies or "
                "use a predictor with atomic_energies set."
            )
        E_atoms = sum(
            float(atomic_energies.get(el, 0.0)) for el in structure.species
        )
        return float(structure.energy) - E_atoms