"""
Feature and energy normalization for PyTorch training.
Handles computation and application of normalization statistics.
"""
from typing import Any, Dict, Optional
import torch
from torch.utils.data import DataLoader
# Progress bar (match aenet.mlip behavior)
try:
from tqdm import tqdm # type: ignore
except Exception:
tqdm = None # type: ignore
class NormalizationManager:
    """
    Manages feature and energy normalization for training.

    Handles:
    - Computing feature statistics (mean/std, plus exact min/max) from
      training data
    - Computing energy normalization (shift/scaling) from training data
    - Applying normalization during training and inference
    - Storing and retrieving normalization parameters

    Parameters
    ----------
    normalize_features : bool
        Whether to normalize features.
    normalize_energy : bool
        Whether to normalize energy targets.
    dtype : torch.dtype
        Data type for statistics.
    device : torch.device
        Device for statistics tensors.
    """

    # Feature-statistic attributes; also used verbatim as keys in get_state().
    _FEATURE_STAT_KEYS = (
        "feature_mean",
        "feature_std",
        "feature_min",
        "feature_max",
    )

    def __init__(
        self,
        normalize_features: bool = True,
        normalize_energy: bool = True,
        dtype: torch.dtype = torch.float64,
        device: torch.device = torch.device("cpu"),
    ):
        self._normalize_features = normalize_features
        self._normalize_energy = normalize_energy
        self.dtype = dtype
        self.device = device

        # Feature normalization statistics (flattened, one entry per feature).
        self.feature_mean: Optional[torch.Tensor] = None
        self.feature_std: Optional[torch.Tensor] = None
        # Exact per-feature range observed during compute_feature_stats.
        self.feature_min: Optional[torch.Tensor] = None
        self.feature_max: Optional[torch.Tensor] = None

        # Energy normalization: a per-atom energy e maps to
        # (e - E_shift) * E_scaling (see denormalize_energy for the inverse).
        self.E_shift: float = 0.0
        self.E_scaling: float = 1.0

    @property
    def normalize_features(self) -> bool:
        """Whether feature normalization is enabled."""
        return self._normalize_features

    @property
    def normalize_energy(self) -> bool:
        """Whether energy normalization is enabled."""
        return self._normalize_energy

    @staticmethod
    def _wrap_progress(dataloader, desc: str, show_progress: bool):
        """Return ``dataloader`` wrapped in a tqdm bar if requested and available."""
        if not show_progress or tqdm is None:
            return dataloader
        try:
            total = len(dataloader)
        except Exception:
            # Iterable loaders may not define a length; the bar still works.
            total = None
        return tqdm(
            dataloader,
            total=total,
            desc=desc,
            ncols=80,
            leave=False,
        )

    def has_feature_stats(self) -> bool:
        """Check if feature statistics have been computed or set."""
        return self.feature_mean is not None and self.feature_std is not None

    def set_feature_stats(self, mean, std):
        """
        Set feature normalization statistics from provided values.

        Both inputs are converted to flat tensors of ``self.dtype`` on
        ``self.device``.

        Parameters
        ----------
        mean : array-like
            Feature means.
        std : array-like
            Feature standard deviations.
        """
        self.feature_mean = torch.as_tensor(
            mean, dtype=self.dtype, device=self.device
        ).reshape(-1)
        self.feature_std = torch.as_tensor(
            std, dtype=self.dtype, device=self.device
        ).reshape(-1)

    def set_energy_stats(self, shift: float, scaling: float):
        """
        Set energy normalization parameters.

        Parameters
        ----------
        shift : float
            Energy shift (per-atom midpoint).
        scaling : float
            Energy scaling factor.
        """
        self.E_shift = float(shift)
        self.E_scaling = float(scaling)

    def compute_feature_stats(
        self,
        dataloader: DataLoader,
        n_features: int,
        provided_stats: Optional[Dict[str, Any]] = None,
        show_progress: bool = True,
    ):
        """
        Compute or load feature normalization statistics.

        When computing from data, per-feature mean, standard deviation,
        minimum and maximum are accumulated over every atom row of
        ``batch["features"]``. Empty batches are skipped. If no atoms are
        seen at all, existing statistics are left unchanged.

        Parameters
        ----------
        dataloader : DataLoader
            DataLoader for training data (no shuffle needed). Each batch
            must be a mapping with a ``"features"`` tensor.
        n_features : int
            Number of feature dimensions.
        provided_stats : dict, optional
            Pre-computed statistics with 'mean' and 'std'/'cov' keys.
            If provided, these will be used instead of computing from data.
        show_progress : bool, optional
            Whether to show progress bar during computation. Default: True
        """
        if not self.normalize_features:
            return

        if provided_stats is not None:
            mean_in = provided_stats.get("mean", None)
            std_in = provided_stats.get("std", None)
            if std_in is None:
                # Fall back to the "cov" key; its values are used as-is
                # as standard deviations.
                std_in = provided_stats.get("cov", None)
            if mean_in is not None and std_in is not None:
                self.set_feature_stats(mean_in, std_in)
                return

        # Running sums and the exact range, accumulated on CPU in self.dtype.
        sum_f = torch.zeros(n_features, dtype=self.dtype, device="cpu")
        sumsq_f = torch.zeros(n_features, dtype=self.dtype, device="cpu")
        min_f = torch.full(
            (n_features,), float("inf"), dtype=self.dtype, device="cpu"
        )
        max_f = torch.full(
            (n_features,), float("-inf"), dtype=self.dtype, device="cpu"
        )
        total_atoms = 0

        batches = self._wrap_progress(
            dataloader, "Computing feature stats", show_progress
        )
        with torch.no_grad():
            for batch in batches:
                feats = batch["features"].to(dtype=self.dtype, device="cpu")
                if feats.shape[0] == 0:
                    # No rows: contributes nothing, and amin/amax over an
                    # empty dimension would raise.
                    continue
                sum_f += feats.sum(dim=0)
                sumsq_f += (feats * feats).sum(dim=0)
                min_f = torch.minimum(min_f, feats.amin(dim=0))
                max_f = torch.maximum(max_f, feats.amax(dim=0))
                total_atoms += int(feats.shape[0])

        if total_atoms == 0:
            return

        mean = sum_f / float(total_atoms)
        # E[x^2] - E[x]^2 can dip slightly below zero from rounding.
        var = torch.clamp(sumsq_f / float(total_atoms) - mean * mean, min=0.0)
        std = torch.sqrt(var + 1e-12)

        self.feature_mean = mean.to(device=self.device)
        self.feature_std = std.to(device=self.device)
        self.feature_min = min_f.to(device=self.device)
        self.feature_max = max_f.to(device=self.device)

    def compute_energy_stats(
        self,
        dataloader: DataLoader,
        atomic_energies_by_index: Optional[torch.Tensor] = None,
        provided_shift: Optional[float] = None,
        provided_scaling: Optional[float] = None,
        show_progress: bool = True,
    ):
        """
        Compute energy normalization statistics.

        Computes the per-atom energy min/max over the data and derives
        shift/scaling that map that range onto [-1, 1] (matching aenet
        convention). Caller-provided values are never overwritten, including
        in the degenerate case (no data, or all per-atom energies equal).
        In that case the values the caller did not supply fall back to the
        identity (shift 0.0, scaling 1.0).

        Parameters
        ----------
        dataloader : DataLoader
            DataLoader for training data. Each batch must be a mapping with
            ``"n_atoms"`` and ``"energy_ref"`` tensors, plus
            ``"species_indices"`` when ``atomic_energies_by_index`` is given.
        atomic_energies_by_index : torch.Tensor, optional
            Per-species atomic energies, indexed by species_indices.
            Always applied to subtract from total energies. If all zeros,
            statistics are computed on total energies. Converted to
            ``self.dtype``/``self.device`` before use.
        provided_shift : float, optional
            Override computed E_shift.
        provided_scaling : float, optional
            Override computed E_scaling.
        show_progress : bool, optional
            Whether to show progress bar during computation. Default: True
        """
        if not self.normalize_energy:
            return

        if provided_shift is not None:
            self.E_shift = float(provided_shift)
        if provided_scaling is not None:
            self.E_scaling = float(provided_scaling)
        if provided_shift is not None and provided_scaling is not None:
            return

        # Convert once, so the indexing and scatter_add_ below see a matching
        # device and dtype (scatter_add_ rejects mixed dtypes).
        ref_table = None
        if atomic_energies_by_index is not None:
            ref_table = torch.as_tensor(atomic_energies_by_index).to(
                device=self.device, dtype=self.dtype
            )

        e_min: Optional[float] = None
        e_max: Optional[float] = None

        batches = self._wrap_progress(
            dataloader, "Computing energy stats", show_progress
        )
        with torch.no_grad():
            for batch in batches:
                n_atoms_b = batch["n_atoms"].to(self.device)
                energy_b = batch["energy_ref"].to(self.device, dtype=self.dtype)

                if ref_table is not None:
                    species_b = batch["species_indices"].to(self.device)
                    n_struct_b = len(n_atoms_b)
                    # Structure index of every atom, used to sum the
                    # per-atom reference energies per structure.
                    struct_idx = torch.repeat_interleave(
                        torch.arange(n_struct_b, device=self.device),
                        n_atoms_b.long(),
                    )
                    ref_sum = torch.zeros(
                        n_struct_b, dtype=self.dtype, device=self.device
                    )
                    ref_sum.scatter_add_(0, struct_idx, ref_table[species_b])
                    energy_b = energy_b - ref_sum

                e_pa = energy_b / n_atoms_b
                if e_pa.numel() == 0:
                    continue
                batch_min = float(e_pa.min().item())
                batch_max = float(e_pa.max().item())
                e_min = batch_min if e_min is None else min(e_min, batch_min)
                e_max = batch_max if e_max is None else max(e_max, batch_max)

        if e_min is not None and e_max is not None and e_max > e_min:
            # Map [e_min, e_max] onto [-1, 1].
            if provided_scaling is None:
                self.E_scaling = float(2.0 / (e_max - e_min))
            if provided_shift is None:
                self.E_shift = float(0.5 * (e_max + e_min))
        else:
            # Degenerate range: identity for whatever was not supplied.
            if provided_scaling is None:
                self.E_scaling = 1.0
            if provided_shift is None:
                self.E_shift = 0.0

    def apply_feature_normalization(
        self, features: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply feature normalization.

        Computes ``(features - mean) / max(std, 1e-12)``, with the statistics
        cast to the input's device and dtype.

        Parameters
        ----------
        features : torch.Tensor
            Raw features, shape (N, F).

        Returns
        -------
        torch.Tensor
            Normalized features if normalization is enabled and statistics
            are available, otherwise the input unchanged.
        """
        if (
            not self.normalize_features
            or self.feature_mean is None
            or self.feature_std is None
        ):
            return features
        mean = self.feature_mean.to(device=features.device, dtype=features.dtype)
        std = torch.clamp(
            self.feature_std.to(device=features.device, dtype=features.dtype),
            min=1e-12,
        )
        return (features - mean) / std

    def apply_energy_normalization(self, energy: torch.Tensor) -> torch.Tensor:
        """
        Apply energy normalization (model output space).

        This is an identity mapping: ``energy`` is returned unchanged whether
        or not energy normalization is enabled. NOTE(review): the original
        implementation states model outputs are "already in normalized space
        during training" — confirm that callers normalize targets elsewhere.

        Parameters
        ----------
        energy : torch.Tensor
            Raw per-atom energies summed over structure.

        Returns
        -------
        torch.Tensor
            The input tensor, unchanged.
        """
        return energy

    def denormalize_energy(
        self, energy_norm: torch.Tensor, n_atoms: int
    ) -> float:
        """
        Convert normalized model output to physical energy.

        Inverts the per-atom map:
        ``E = energy_norm / E_scaling + E_shift * n_atoms``.

        Parameters
        ----------
        energy_norm : torch.Tensor
            Normalized energy from model (sum of per-atom energies); must
            hold a single value.
        n_atoms : int
            Number of atoms in structure.

        Returns
        -------
        float
            Physical energy.
        """
        if not self.normalize_energy:
            return float(energy_norm.detach().cpu())
        physical = energy_norm / self.E_scaling + self.E_shift * n_atoms
        return float(physical.detach().cpu())

    def get_state(self) -> Dict[str, Any]:
        """
        Get normalization state for serialization.

        Returns
        -------
        dict
            Flags, energy parameters, and any available feature statistics
            (as NumPy arrays).
        """
        state: Dict[str, Any] = {
            "normalize_features": self.normalize_features,
            "normalize_energy": self.normalize_energy,
            "E_shift": self.E_shift,
            "E_scaling": self.E_scaling,
        }
        for key in self._FEATURE_STAT_KEYS:
            value = getattr(self, key)
            if value is not None:
                state[key] = value.detach().cpu().numpy()
        return state

    def set_state(self, state: Dict[str, Any]):
        """
        Restore normalization state from dictionary.

        Missing flags and energy parameters fall back to their defaults.
        Feature statistics are only overwritten for keys present in
        ``state``.

        Parameters
        ----------
        state : dict
            Dictionary containing normalization parameters.
        """
        self._normalize_features = state.get("normalize_features", True)
        self._normalize_energy = state.get("normalize_energy", True)
        self.E_shift = float(state.get("E_shift", 0.0))
        self.E_scaling = float(state.get("E_scaling", 1.0))
        for key in self._FEATURE_STAT_KEYS:
            if key in state:
                setattr(
                    self,
                    key,
                    torch.as_tensor(
                        state[key], dtype=self.dtype, device=self.device
                    ),
                )