Training Utilities
The training package provides utilities for managing the training process,
including checkpointing, normalization, metrics tracking, and training loop execution.
Checkpoint Manager
- class aenet.torch_training.training.CheckpointManager(checkpoint_dir: str | None = None, max_to_keep: int | None = None, save_best: bool = True)
Bases: object
Manages checkpoint operations for training.
Handles:
  - Saving checkpoints with full training state
  - Loading checkpoints to resume training
  - Rotating old checkpoints to limit disk usage
  - Tracking the best model based on validation metrics
- Parameters:
checkpoint_dir (str or Path, optional) – Directory to save checkpoints. If None, checkpoints are disabled.
max_to_keep (int, optional) – Maximum number of checkpoints to keep. Older ones are deleted.
save_best (bool) – Whether to save the best model separately.
- infer_start_epoch(checkpoint_path: str, payload: Dict[str, Any] | None = None) → int
Infer starting epoch from checkpoint metadata or filename.
- Parameters:
checkpoint_path (str) – Path to checkpoint file.
payload (dict, optional) – Loaded checkpoint payload. When present, payload["epoch"] is used as the authoritative epoch number.
- Returns:
Starting epoch (checkpoint epoch + 1), or 0 if it cannot be inferred.
- Return type:
int
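The lookup order described above (payload epoch first, then the filename) can be sketched in plain Python. This is a hypothetical re-implementation, not the library's code: the helper name infer_start_epoch_sketch is invented, falling back to the filename when the payload lacks an "epoch" key is an assumption, and the filename pattern assumes the default checkpoint_epoch_{epoch:04d}.pt naming used by save_checkpoint.

```python
import re
from pathlib import Path
from typing import Any, Dict, Optional

# Assumed filename pattern, matching the default "checkpoint_epoch_{epoch:04d}.pt".
_EPOCH_RE = re.compile(r"checkpoint_epoch_(\d+)")


def infer_start_epoch_sketch(checkpoint_path: str,
                             payload: Optional[Dict[str, Any]] = None) -> int:
    """Hypothetical stand-in for CheckpointManager.infer_start_epoch."""
    # The payload's "epoch" entry, when present, is authoritative.
    if payload is not None and "epoch" in payload:
        return int(payload["epoch"]) + 1
    # Otherwise try to parse the epoch number out of the filename (assumption).
    match = _EPOCH_RE.search(Path(checkpoint_path).stem)
    if match:
        return int(match.group(1)) + 1
    # Nothing usable: start from scratch.
    return 0


print(infer_start_epoch_sketch("ckpt/checkpoint_epoch_0041.pt"))    # 42
print(infer_start_epoch_sketch("ckpt/best_model.pt", {"epoch": 9}))  # 10
print(infer_start_epoch_sketch("ckpt/best_model.pt"))                # 0
```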
- load_checkpoint(path: str, model: torch.nn.Module, optimizer: torch.optim.Optimizer, device: torch.device) → Dict[str, Any]
Load a checkpoint and restore training state.
- Parameters:
path (str) – Path to checkpoint file.
model (nn.Module) – Model to load state into.
optimizer (torch.optim.Optimizer) – Optimizer to load state into.
device (torch.device) – Device to map tensors to.
- Returns:
Checkpoint payload containing epoch, history, etc.
- Return type:
dict
- Raises:
RuntimeError – If checkpoint loading fails.
- save_best_model(trainer, optimizer: torch.optim.Optimizer, epoch: int, training_config: Any | None = None)
Save the best model checkpoint using the unified format.
- save_checkpoint(trainer, optimizer: torch.optim.Optimizer, epoch: int, training_config: Any | None = None, filename: str | None = None)
Save a training checkpoint using the unified model format.
- Parameters:
trainer (TorchANNPotential) – Trainer instance to save.
optimizer (torch.optim.Optimizer) – Optimizer to save.
epoch (int) – Current epoch number.
training_config (TorchTrainingConfig, optional) – Training configuration.
filename (str, optional) – Filename for the checkpoint. If None, defaults to “checkpoint_epoch_{epoch:04d}.pt”.
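The default filename and the max_to_keep limit interact in a simple way, sketched below. The rotation policy shown (sort epoch checkpoints by name, delete the oldest beyond the limit, leave other files such as a best-model file untouched) is an assumption, and checkpoint_filename and rotate_checkpoints are hypothetical helpers rather than CheckpointManager methods.

```python
import tempfile
from pathlib import Path
from typing import List, Optional


def checkpoint_filename(epoch: int, filename: Optional[str] = None) -> str:
    # Documented default name when no explicit filename is given.
    return filename if filename is not None else f"checkpoint_epoch_{epoch:04d}.pt"


def rotate_checkpoints(checkpoint_dir: Path, max_to_keep: Optional[int]) -> List[Path]:
    """Delete the oldest epoch checkpoints so at most max_to_keep remain (assumed policy)."""
    # Zero-padding makes lexicographic order match epoch order (up to epoch 9999).
    ckpts = sorted(checkpoint_dir.glob("checkpoint_epoch_*.pt"))
    if max_to_keep is None:
        return ckpts  # no limit: keep everything
    n_delete = max(len(ckpts) - max_to_keep, 0)
    for old in ckpts[:n_delete]:
        old.unlink()
    return ckpts[n_delete:]


with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp)
    for epoch in range(5):
        (ckpt_dir / checkpoint_filename(epoch)).touch()
        kept = rotate_checkpoints(ckpt_dir, max_to_keep=2)
    print([p.name for p in kept])
    # ['checkpoint_epoch_0003.pt', 'checkpoint_epoch_0004.pt']
```

Relying on the zero-padded epoch in the name avoids parsing filenames during rotation, at the cost of breaking ordering past four digits.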
Normalization Manager
- class aenet.torch_training.training.NormalizationManager(normalize_features: bool = True, normalize_energy: bool = True, dtype: torch.dtype = torch.float64, device: torch.device = torch.device)
Bases: object
Manages feature and energy normalization for training.
Handles:
  - Computing feature statistics (mean/std) from training data
  - Computing energy normalization (shift/scaling) from training data
  - Applying normalization during training and inference
  - Storing and retrieving normalization parameters
- Parameters:
normalize_features (bool) – Whether to normalize features.
normalize_energy (bool) – Whether to normalize energy targets.
dtype (torch.dtype) – Data type for statistics.
device (torch.device) – Device for statistics tensors.
- apply_energy_normalization(energy: torch.Tensor) → torch.Tensor
Apply energy normalization (model output space).
- Parameters:
energy (torch.Tensor) – Raw per-atom energies summed over the structure.
- Returns:
Normalized energy.
- Return type:
torch.Tensor
- apply_feature_normalization(features: torch.Tensor) → torch.Tensor
Apply feature normalization.
- Parameters:
features (torch.Tensor) – Raw features, shape (N, F).
- Returns:
Normalized features if normalization is enabled, otherwise returns input unchanged.
- Return type:
torch.Tensor
- compute_energy_stats(dataloader: torch.utils.data.DataLoader, atomic_energies_by_index: torch.Tensor | None = None, provided_shift: float | None = None, provided_scaling: float | None = None, show_progress: bool = True)
Compute energy normalization statistics.
Computes per-atom energy min/max/avg and derives the shift and scaling that map per-atom energies onto the [-1, 1] range (matching the aenet convention).
- Parameters:
dataloader (DataLoader) – DataLoader for training data.
atomic_energies_by_index (torch.Tensor, optional) – Per-species atomic energies, indexed by species_indices. When provided, they are always subtracted from the total energies; if all entries are zero, statistics are computed on total energies.
provided_shift (float, optional) – Override computed E_shift.
provided_scaling (float, optional) – Override computed E_scaling.
show_progress (bool, optional) – Whether to show progress bar during computation. Default: True
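Mapping the per-atom range [E_min, E_max] onto [-1, 1] fixes the shift as the per-atom midpoint and the scaling as 2 / (E_max - E_min). The sketch below works on plain lists instead of a DataLoader. It assumes that scaling multiplies (normalized = scaling * (e - shift)) and that atomic energies enter as one summed reference value per structure; energy_shift_scaling is a hypothetical helper, not the library method.

```python
from typing import Optional, Sequence, Tuple


def energy_shift_scaling(total_energies: Sequence[float],
                         n_atoms: Sequence[int],
                         atomic_energy_sums: Optional[Sequence[float]] = None,
                         provided_shift: Optional[float] = None,
                         provided_scaling: Optional[float] = None,
                         ) -> Tuple[float, float]:
    """Derive (shift, scaling) mapping per-atom energies onto [-1, 1]."""
    if atomic_energy_sums is None:
        # No reference energies: statistics are taken on total energies.
        atomic_energy_sums = [0.0] * len(total_energies)
    # Subtract each structure's summed atomic energies, then divide by its atom count.
    per_atom = [(e - ref) / n
                for e, ref, n in zip(total_energies, atomic_energy_sums, n_atoms)]
    e_min, e_max = min(per_atom), max(per_atom)
    # Midpoint shift and half-range scaling send e_min -> -1 and e_max -> +1.
    shift = 0.5 * (e_max + e_min)
    scaling = 2.0 / (e_max - e_min)  # assumes at least two distinct per-atom energies
    # Explicit overrides take precedence, mirroring provided_shift / provided_scaling.
    if provided_shift is not None:
        shift = provided_shift
    if provided_scaling is not None:
        scaling = provided_scaling
    return shift, scaling


shift, scaling = energy_shift_scaling([-10.0, -18.0, -33.0], [2, 4, 6])
print(shift, scaling)  # -5.0 2.0
```

Here the per-atom energies are -5.0, -4.5 and -5.5 eV-like units, so the midpoint is -5.0 and the half-range is 0.5.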
- compute_feature_stats(dataloader: torch.utils.data.DataLoader, n_features: int, provided_stats: Dict[str, Any] | None = None, show_progress: bool = True)
Compute or load feature normalization statistics.
- Parameters:
dataloader (DataLoader) – DataLoader for training data (no shuffle needed).
n_features (int) – Number of feature dimensions.
provided_stats (dict, optional) – Pre-computed statistics with ‘mean’ and ‘std’/‘cov’ keys. If provided, these are used instead of computing statistics from data.
show_progress (bool, optional) – Whether to show progress bar during computation. Default: True
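Because the statistics are accumulated over a DataLoader, they can be gathered in a single pass from running sums. The pure-Python sketch below stands in for the tensor version; it computes the population standard deviation from E[x²] - E[x]², and only centers features whose std is (near) zero. Both choices, the eps threshold, and the helper names feature_mean_std and normalize_row are assumptions, not the library's implementation. Welford's algorithm is a numerically steadier alternative to the plain sum-of-squares shown here.

```python
import math
from typing import Iterable, List, Sequence, Tuple

Batch = Sequence[Sequence[float]]  # one (N, F) block of feature rows


def feature_mean_std(batches: Iterable[Batch], n_features: int) -> Tuple[List[float], List[float]]:
    """Accumulate per-feature mean and population std in a single pass."""
    count = 0
    sums = [0.0] * n_features
    sq_sums = [0.0] * n_features
    for batch in batches:
        for row in batch:
            count += 1
            for j, x in enumerate(row):
                sums[j] += x
                sq_sums[j] += x * x
    mean = [s / count for s in sums]
    # E[x^2] - E[x]^2, clipped at zero to absorb round-off.
    var = [max(sq / count - m * m, 0.0) for sq, m in zip(sq_sums, mean)]
    return mean, [math.sqrt(v) for v in var]


def normalize_row(row: Sequence[float], mean: Sequence[float], std: Sequence[float],
                  eps: float = 1e-12) -> List[float]:
    # Constant features (std ~ 0) are only centered, not scaled; eps is an assumption.
    return [(x - m) / (s if s > eps else 1.0) for x, m, s in zip(row, mean, std)]


batches = [[[1.0, 5.0], [3.0, 5.0]], [[5.0, 5.0]]]
mean, std = feature_mean_std(batches, n_features=2)
print(mean)  # [3.0, 5.0]
print(normalize_row([5.0, 5.0], mean, std))
```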
- denormalize_energy(energy_norm: torch.Tensor, n_atoms: int) → float
Convert normalized model output to physical energy.
- Parameters:
energy_norm (torch.Tensor) – Normalized energy from model (sum of per-atom energies).
n_atoms (int) – Number of atoms in structure.
- Returns:
Physical energy.
- Return type:
float
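The n_atoms argument is needed because the shift is defined per atom: if every per-atom energy e_i is mapped to scaling * (e_i - shift), the summed model output equals scaling * (E - n_atoms * shift), and inverting that requires the atom count. The sketch assumes this multiplicative convention and ignores any atomic reference energies; both helper names are hypothetical.

```python
def normalize_total_energy(energy: float, n_atoms: int, shift: float, scaling: float) -> float:
    # Each per-atom energy e_i maps to scaling * (e_i - shift); summed over n_atoms.
    return scaling * (energy - n_atoms * shift)


def denormalize_total_energy(energy_norm: float, n_atoms: int, shift: float, scaling: float) -> float:
    # Inverse of the map above: undo the scaling, then add back n_atoms shifts.
    return energy_norm / scaling + n_atoms * shift


e_norm = normalize_total_energy(-33.0, 6, shift=-5.0, scaling=2.0)
print(e_norm)  # -6.0
print(denormalize_total_energy(e_norm, 6, shift=-5.0, scaling=2.0))  # -33.0
```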
- get_state() → Dict[str, Any]
Get normalization state for serialization.
- Returns:
Dictionary containing normalization parameters.
- Return type:
dict
- property normalize_energy: bool
Whether energy normalization is enabled.
- property normalize_features: bool
Whether feature normalization is enabled.
- set_energy_stats(shift: float, scaling: float)
Set energy normalization parameters.
- Parameters:
shift (float) – Energy shift (per-atom midpoint).
scaling (float) – Energy scaling factor.
Metrics Tracker
- class aenet.torch_training.training.MetricsTracker(track_detailed_timing: bool = False)
Bases: object
Tracks training metrics and history.
Maintains per-epoch metrics for energy RMSE and MAE, force RMSE, learning rate, and timing information.
- Parameters:
track_detailed_timing (bool) – Whether to track detailed timing breakdowns (data loading, loss computation, optimizer steps).
- get_best(metric: str, mode: str = 'min') → float
Get the best value for a specific metric.
- Parameters:
metric (str) – Metric name.
mode (str) – Either ‘min’ or ‘max’.
- Returns:
Best value, or NaN if metric not found or empty.
- Return type:
float
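Validation metrics are stored as NaN for epochs without validation, so a naive min() or max() over the history could return NaN depending on where it appears. The sketch below skips NaN entries, which is an assumption about the intended semantics; the ValueError for an invalid mode is also assumed, and get_best here is a free function over a plain history dictionary, not the MetricsTracker method.

```python
import math
from typing import Dict, List


def get_best(history: Dict[str, List[float]], metric: str, mode: str = "min") -> float:
    """Best non-NaN value of a metric, or NaN if there is none."""
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    # NaN entries (e.g. epochs without validation) are skipped; min()/max() would
    # otherwise give order-dependent results when NaN is present.
    values = [v for v in history.get(metric, []) if not math.isnan(v)]
    if not values:
        return float("nan")
    return min(values) if mode == "min" else max(values)


history = {"test_energy_rmse": [0.12, 0.08, float("nan"), 0.09]}
print(get_best(history, "test_energy_rmse"))         # 0.08
print(get_best(history, "test_energy_rmse", "max"))  # 0.12
print(get_best(history, "test_force_rmse"))          # nan
```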
- get_history() → Dict[str, List[float]]
Get complete training history.
- Returns:
Dictionary mapping metric names to lists of per-epoch values.
- Return type:
dict
- get_latest(metric: str) → float
Get the latest value for a specific metric.
- Parameters:
metric (str) – Metric name (e.g., ‘train_energy_rmse’).
- Returns:
Latest value, or NaN if metric not found or empty.
- Return type:
float
- replace_latest(train_energy_rmse: float, train_energy_mae: float, train_force_rmse: float, test_energy_rmse: float, test_energy_mae: float, test_force_rmse: float) → None
Replace the latest stored metric values.
The trainer uses this to overwrite the final epoch's metrics with those from a deterministic full-pass evaluation after the optimization loop has finished.
- Parameters:
train_energy_rmse (float) – Replacement training energy RMSE.
train_energy_mae (float) – Replacement training energy MAE.
train_force_rmse (float) – Replacement training force RMSE.
test_energy_rmse (float) – Replacement validation energy RMSE.
test_energy_mae (float) – Replacement validation energy MAE.
test_force_rmse (float) – Replacement validation force RMSE.
- set_history(history: Dict[str, List[float]])
Set history from a dictionary (e.g., when loading a checkpoint).
- Parameters:
history (dict) – Dictionary mapping metric names to lists of values.
- update(train_energy_rmse: float, train_energy_mae: float, train_force_rmse: float, test_energy_rmse: float, test_energy_mae: float, test_force_rmse: float, learning_rate: float, epoch_time: float, forward_time: float, backward_time: float, train_timing: Dict[str, float], val_timing: Dict[str, float])
Record metrics for one epoch.
- Parameters:
train_energy_rmse (float) – Training energy RMSE.
train_energy_mae (float) – Training energy MAE.
train_force_rmse (float) – Training force RMSE (or NaN if not computed).
test_energy_rmse (float) – Validation energy RMSE (or NaN if no validation).
test_energy_mae (float) – Validation energy MAE (or NaN if no validation).
test_force_rmse (float) – Validation force RMSE (or NaN if not computed).
learning_rate (float) – Current learning rate.
epoch_time (float) – Total epoch time in seconds.
forward_time (float) – Forward pass time in seconds.
backward_time (float) – Backward pass time in seconds.
train_timing (dict) – Detailed training timing breakdown with keys ‘data_loading’, ‘loss_computation’, ‘optimizer’, and ‘total’.
val_timing (dict) – Detailed validation timing breakdown (same keys as train_timing).
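The history returned by get_history() maps metric names to per-epoch lists, which suggests the following mental model for update(), get_latest() and replace_latest(). The functions below are hypothetical free-function stand-ins operating on a plain dictionary, not the MetricsTracker methods; in particular, raising on replace_latest() before any epoch is recorded is an assumed behavior.

```python
import math
from typing import Dict, List

History = Dict[str, List[float]]


def record_epoch(history: History, **metrics: float) -> None:
    # Like update(): append one value per metric for the current epoch.
    for name, value in metrics.items():
        history.setdefault(name, []).append(value)


def replace_latest(history: History, **metrics: float) -> None:
    # Like replace_latest(): overwrite the last epoch's value, keeping lengths fixed.
    # Raises KeyError/IndexError if nothing was recorded yet (assumed behavior).
    for name, value in metrics.items():
        history[name][-1] = value


def get_latest(history: History, metric: str) -> float:
    values = history.get(metric, [])
    return values[-1] if values else float("nan")


history: History = {}
record_epoch(history, train_energy_rmse=0.20, test_energy_rmse=float("nan"))
record_epoch(history, train_energy_rmse=0.15, test_energy_rmse=0.18)
# Deterministic full-pass re-evaluation after training replaces the final entry.
replace_latest(history, train_energy_rmse=0.14, test_energy_rmse=0.17)
print(history["train_energy_rmse"])                      # [0.2, 0.14]
print(get_latest(history, "test_energy_rmse"))           # 0.17
print(math.isnan(get_latest(history, "learning_rate")))  # True
```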
Training Loop
- class aenet.torch_training.training.TrainingLoop(model, descriptor, normalizer: NormalizationManager, device: torch.device, dtype: torch.dtype)
Bases: object
Executes training and validation epochs.
Handles:
  - Batch iteration and processing
  - Loss computation (energy + forces)
  - Backpropagation and optimization
  - Timing and metrics collection
- Parameters:
model (nn.Module) – Model to train (EnergyModelAdapter).
descriptor (ChebyshevDescriptor) – Descriptor for force computation.
normalizer (NormalizationManager) – Normalization manager for features and energies.
device (torch.device) – Device for computation.
dtype (torch.dtype) – Data type for tensors.
- run_epoch(loader: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer | None, alpha: float, atomic_energies_by_index: torch.Tensor | None = None, train: bool = True, show_batch_progress: bool = False, force_scale_unbiased: bool = False, collect_structure_scores: bool = False) → Tuple[float, float, float, Dict[str, float], dict[int, float] | None]
Run one epoch over the loader.
- Parameters:
loader (DataLoader) – DataLoader for batches.
optimizer (torch.optim.Optimizer or None) – Optimizer for training (None for validation).
alpha (float) – Force loss weight (0.0 = energy only, 1.0 = forces only).
atomic_energies_by_index (torch.Tensor, optional) – Atomic energies indexed by species. When provided, they are always subtracted to convert total energies to cohesive energies. If not provided or all zeros, training targets remain total energies.
train (bool) – Whether this is training (True) or validation (False).
show_batch_progress (bool) – Whether to show per-batch progress bar.
force_scale_unbiased (bool) – If True, scale the per-batch force RMSE by sqrt(1/f), where f is the fraction of atoms with available force labels. This keeps the loss magnitude approximately constant when forces are sub-sampled.
collect_structure_scores (bool) – If True, accumulate per-structure energy-error scores for adaptive sampling updates.
- Returns:
energy_rmse (float) – Energy RMSE for epoch.
energy_mae (float) – Energy MAE for epoch.
force_rmse (float) – Force RMSE for epoch (NaN if no forces).
timing (dict) – Timing breakdown with keys: ‘data_loading’, ‘loss_computation’, ‘optimizer’, ‘total’.
structure_scores (dict[int, float] or None) – Per-structure mean scores keyed by split-local dataset index; None unless collect_structure_scores=True.
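The alpha and force_scale_unbiased parameters can be illustrated with scalar arithmetic. The exact loss that run_epoch optimizes is not stated here; the sketch assumes a linear mix (1 - alpha) * energy_term + alpha * force_term, which matches the documented endpoints (0.0 = energy only, 1.0 = forces only), and applies the documented sqrt(1/f) factor to the force RMSE. scaled_force_rmse and combined_loss are hypothetical helpers.

```python
import math


def scaled_force_rmse(force_rmse: float, n_supervised: int, n_atoms: int) -> float:
    """Apply the sqrt(1/f) factor, f = fraction of atoms with force labels (n_supervised > 0)."""
    f = n_supervised / n_atoms
    return force_rmse * math.sqrt(1.0 / f)


def combined_loss(energy_term: float, force_term: float, alpha: float) -> float:
    # Assumed linear mix: alpha = 0.0 -> energy only, alpha = 1.0 -> forces only.
    return (1.0 - alpha) * energy_term + alpha * force_term


# 25 of 100 atoms carry force labels in this batch: f = 0.25, so the RMSE doubles.
f_term = scaled_force_rmse(0.1, n_supervised=25, n_atoms=100)
print(f_term)                                  # 0.2
print(combined_loss(0.05, f_term, alpha=0.0))  # 0.05
print(combined_loss(0.05, f_term, alpha=1.0))  # 0.2
```

Scaling by sqrt(1/f) rather than 1/f reflects that the factor is applied to an RMSE, not a squared error.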