Training Utilities

The training package provides utilities for managing the training process, including checkpointing, normalization, metrics tracking, and training loop execution.

Checkpoint Manager

class aenet.torch_training.training.CheckpointManager(checkpoint_dir: str | None = None, max_to_keep: int | None = None, save_best: bool = True)

Bases: object

Manages checkpoint operations for training.

Handles:
  • Saving checkpoints with full training state
  • Loading checkpoints to resume training
  • Rotating old checkpoints to limit disk usage
  • Tracking best model based on validation metrics

Parameters:
  • checkpoint_dir (str or Path, optional) – Directory to save checkpoints. If None, checkpoints are disabled.

  • max_to_keep (int, optional) – Maximum number of checkpoints to keep. Older ones are deleted.

  • save_best (bool) – Whether to save the best model separately.
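The rotation behaviour controlled by max_to_keep can be sketched in plain Python. This is a hypothetical illustration, not the CheckpointManager source: the helper name rotate_checkpoints and the best_model.pt filename are assumptions, treating max_to_keep=None as "keep everything" is an assumption, and only files matching the default checkpoint_epoch_{epoch:04d}.pt pattern documented for save_checkpoint are considered for deletion.

```python
import re
import tempfile
from pathlib import Path
from typing import List, Optional

# Hypothetical helper; the real CheckpointManager may rotate differently.
# Only names matching the default "checkpoint_epoch_{epoch:04d}.pt" pattern rotate.
_EPOCH_FILE = re.compile(r"^checkpoint_epoch_(\d+)\.pt$")


def rotate_checkpoints(checkpoint_dir: Path, max_to_keep: Optional[int]) -> List[Path]:
    """Delete the oldest epoch checkpoints so that at most max_to_keep remain."""
    epoch_files = sorted(
        (p for p in checkpoint_dir.iterdir() if _EPOCH_FILE.match(p.name)),
        key=lambda p: int(_EPOCH_FILE.match(p.name).group(1)),  # sort by epoch number
    )
    if max_to_keep is None:  # assumption: None means "never delete"
        return epoch_files
    n_delete = max(0, len(epoch_files) - max_to_keep)
    for old in epoch_files[:n_delete]:
        old.unlink()
    return epoch_files[n_delete:]


with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp)
    for epoch in range(5):
        (ckpt_dir / f"checkpoint_epoch_{epoch:04d}.pt").touch()
    (ckpt_dir / "best_model.pt").touch()  # hypothetical name; never rotated
    kept_names = [p.name for p in rotate_checkpoints(ckpt_dir, max_to_keep=2)]
    remaining = sorted(p.name for p in ckpt_dir.iterdir())

print(kept_names)  # ['checkpoint_epoch_0003.pt', 'checkpoint_epoch_0004.pt']
```

Sorting by the parsed epoch number rather than by file modification time keeps the order stable if checkpoint files are copied between machines.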

infer_start_epoch(checkpoint_path: str, payload: Dict[str, Any] | None = None) → int

Infer starting epoch from checkpoint metadata or filename.

Parameters:
  • checkpoint_path (str) – Path to checkpoint file.

  • payload (dict, optional) – Loaded checkpoint payload. When present, payload["epoch"] is used as the authoritative epoch number.

Returns:

Starting epoch (checkpoint epoch + 1), or 0 if it cannot be inferred.

Return type:

int
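The precedence described above (payload epoch first, then the filename, then 0) can be sketched as follows. The function name infer_start_epoch_sketch is hypothetical, and the filename fallback assumes the default checkpoint_epoch_{epoch:04d}.pt naming documented for save_checkpoint.

```python
import re
from pathlib import Path
from typing import Any, Dict, Optional


def infer_start_epoch_sketch(
    checkpoint_path: str, payload: Optional[Dict[str, Any]] = None
) -> int:
    # 1) Checkpoint metadata is authoritative when it carries an epoch.
    if payload is not None and payload.get("epoch") is not None:
        return int(payload["epoch"]) + 1
    # 2) Otherwise parse the default "checkpoint_epoch_{epoch:04d}.pt" filename.
    match = re.search(r"checkpoint_epoch_(\d+)", Path(checkpoint_path).name)
    if match:
        return int(match.group(1)) + 1
    # 3) Nothing to infer from: start at epoch 0.
    return 0


print(infer_start_epoch_sketch("runs/checkpoint_epoch_0041.pt"))               # 42
print(infer_start_epoch_sketch("runs/checkpoint_epoch_0041.pt", {"epoch": 9}))  # 10
print(infer_start_epoch_sketch("runs/best_model.pt"))                          # 0
```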

load_checkpoint(path: str, model: torch.nn.Module, optimizer: torch.optim.Optimizer, device: torch.device) → Dict[str, Any]

Load a checkpoint and restore training state.

Parameters:
  • path (str) – Path to checkpoint file.

  • model (nn.Module) – Model to load state into.

  • optimizer (torch.optim.Optimizer) – Optimizer to load state into.

  • device (torch.device) – Device to map tensors to.

Returns:

Checkpoint payload containing epoch, history, etc.

Return type:

dict

Raises:

RuntimeError – If checkpoint loading fails.

save_best_model(trainer, optimizer: torch.optim.Optimizer, epoch: int, training_config: Any | None = None)

Save the best model checkpoint using the unified format.

save_checkpoint(trainer, optimizer: torch.optim.Optimizer, epoch: int, training_config: Any | None = None, filename: str | None = None)

Save a training checkpoint using the unified model format.

Parameters:
  • trainer (TorchANNPotential) – Trainer instance to save.

  • optimizer (torch.optim.Optimizer) – Optimizer to save.

  • epoch (int) – Current epoch number.

  • training_config (TorchTrainingConfig, optional) – Training configuration.

  • filename (str, optional) – Filename for the checkpoint. If None, the format “checkpoint_epoch_{epoch:04d}.pt” is used.

should_save_best(val_loss: float) → bool

Check if current validation loss is the best so far.

Parameters:

val_loss (float) – Current validation loss.

Returns:

True if this is the best validation loss.

Return type:

bool
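A minimal sketch of best-so-far tracking, assuming the manager keeps the lowest validation loss seen so far and counts only a strictly lower value as an improvement. BestLossTracker is a hypothetical name, and whether the real method updates its stored best on every call is an assumption.

```python
import math


class BestLossTracker:
    """Hypothetical stand-in for the best-so-far bookkeeping."""

    def __init__(self) -> None:
        self.best_val_loss = math.inf

    def should_save_best(self, val_loss: float) -> bool:
        # Strict improvement only; NaN compares False, so it is never "best".
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            return True
        return False


tracker = BestLossTracker()
losses = [0.50, 0.42, 0.45, float("nan"), 0.40]
decisions = [tracker.should_save_best(v) for v in losses]
print(decisions)  # [True, True, False, False, True]
```

In a training loop, a True result is the natural point to call save_best_model.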

Normalization Manager

class aenet.torch_training.training.NormalizationManager(normalize_features: bool = True, normalize_energy: bool = True, dtype: torch.dtype = torch.float64, device: torch.device = torch.device)

Bases: object

Manages feature and energy normalization for training.

Handles:
  • Computing feature statistics (mean/std) from training data
  • Computing energy normalization (shift/scaling) from training data
  • Applying normalization during training and inference
  • Storing and retrieving normalization parameters

Parameters:
  • normalize_features (bool) – Whether to normalize features.

  • normalize_energy (bool) – Whether to normalize energy targets.

  • dtype (torch.dtype) – Data type for statistics.

  • device (torch.device) – Device for statistics tensors.

apply_energy_normalization(energy: torch.Tensor) → torch.Tensor

Apply energy normalization (model output space).

Parameters:

energy (torch.Tensor) – Raw total energy of the structure (the sum of its per-atom energies).

Returns:

Normalized energy.

Return type:

torch.Tensor

apply_feature_normalization(features: torch.Tensor) → torch.Tensor

Apply feature normalization.

Parameters:

features (torch.Tensor) – Raw features, shape (N, F).

Returns:

Normalized features if normalization is enabled; otherwise the input, unchanged.

Return type:

torch.Tensor
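Given the mean/std statistics this manager stores, feature normalization is most plausibly the per-column standardization (x - mean) / std. The sketch below assumes that form, uses plain lists in place of torch tensors to stay dependency-free, and adds a small eps guard against zero standard deviations; the function name and the eps guard are assumptions, not documented behaviour.

```python
from typing import List, Optional, Sequence


def normalize_features_sketch(
    features: Sequence[Sequence[float]],
    mean: Optional[Sequence[float]],
    std: Optional[Sequence[float]],
    eps: float = 1e-12,
) -> List[List[float]]:
    """Standardize each feature column; return an unchanged copy if no stats are set."""
    if mean is None or std is None:
        return [list(row) for row in features]
    return [
        [(x - m) / max(s, eps) for x, m, s in zip(row, mean, std)]  # eps guards s == 0
        for row in features
    ]


# Two atoms (rows) with F = 2 features each, matching the documented (N, F) shape.
features = [[1.0, 10.0], [3.0, 30.0]]
result = normalize_features_sketch(features, mean=[2.0, 20.0], std=[1.0, 10.0])
print(result)  # [[-1.0, -1.0], [1.0, 1.0]]
```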

compute_energy_stats(dataloader: torch.utils.data.DataLoader, atomic_energies_by_index: torch.Tensor | None = None, provided_shift: float | None = None, provided_scaling: float | None = None, show_progress: bool = True)

Compute energy normalization statistics.

Computes the minimum, maximum, and average per-atom energy and derives a shift and scaling that normalize per-atom energies to the [-1, 1] range (matching the aenet convention).

Parameters:
  • dataloader (DataLoader) – DataLoader for training data.

  • atomic_energies_by_index (torch.Tensor, optional) – Per-species atomic energies, indexed by species_indices. When provided, they are always subtracted from the total energies. If all zeros, statistics are computed on total energies.

  • provided_shift (float, optional) – Override computed E_shift.

  • provided_scaling (float, optional) – Override computed E_scaling.

  • show_progress (bool, optional) – Whether to show progress bar during computation. Default: True
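The shift and scaling can be sketched as below, assuming the aenet-style definitions E_shift = (e_min + e_max) / 2 (the per-atom midpoint, as set_energy_stats describes it) and E_scaling = 2 / (e_max - e_min), so that (e - E_shift) * E_scaling lies in [-1, 1]. The function name is hypothetical, and a per-structure reference energy stands in for the per-species atomic_energies_by_index lookup.

```python
from typing import Optional, Sequence, Tuple


def energy_stats_sketch(
    total_energies: Sequence[float],
    n_atoms: Sequence[int],
    reference_energies: Optional[Sequence[float]] = None,
) -> Tuple[float, float]:
    """Return (shift, scaling) that map per-atom energies onto [-1, 1]."""
    if reference_energies is None:
        # No atomic references: statistics are taken on total energies.
        reference_energies = [0.0] * len(total_energies)
    per_atom = [
        (e_tot - e_ref) / n
        for e_tot, e_ref, n in zip(total_energies, reference_energies, n_atoms)
    ]
    e_min, e_max = min(per_atom), max(per_atom)
    shift = 0.5 * (e_min + e_max)  # per-atom midpoint
    scaling = 2.0 / (e_max - e_min)  # assumes e_max > e_min
    return shift, scaling


totals = [-10.0, -28.0, -3.0]
atoms = [2, 4, 1]  # per-atom energies: -5.0, -7.0, -3.0
shift, scaling = energy_stats_sketch(totals, atoms)
normalized = [(e / n - shift) * scaling for e, n in zip(totals, atoms)]
print(shift, scaling, normalized)  # -5.0 0.5 [0.0, -1.0, 1.0]
```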

compute_feature_stats(dataloader: torch.utils.data.DataLoader, n_features: int, provided_stats: Dict[str, Any] | None = None, show_progress: bool = True)

Compute or load feature normalization statistics.

Parameters:
  • dataloader (DataLoader) – DataLoader for training data (no shuffle needed).

  • n_features (int) – Number of feature dimensions.

  • provided_stats (dict, optional) – Pre-computed statistics with ‘mean’ and ‘std’/’cov’ keys. If provided, these will be used instead of computing from data.

  • show_progress (bool, optional) – Whether to show progress bar during computation. Default: True
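Because the statistics are gathered batch by batch from a DataLoader, a single-pass algorithm is a natural fit. The sketch below uses Welford's online update per feature column and returns the population standard deviation; both choices, the function name, and handling only the 'std' key of provided_stats (not 'cov') are assumptions rather than the documented implementation. Plain lists stand in for the tensors a DataLoader would yield.

```python
import math
from typing import Dict, Iterable, List, Optional, Sequence, Tuple


def feature_stats_sketch(
    batches: Iterable[Sequence[Sequence[float]]],
    n_features: int,
    provided_stats: Optional[Dict[str, Sequence[float]]] = None,
) -> Tuple[List[float], List[float]]:
    """Return per-feature (mean, std) in one pass over the batches."""
    if provided_stats is not None:
        # Mirrors the documented override: precomputed stats win over the data.
        return list(provided_stats["mean"]), list(provided_stats["std"])
    count = 0
    mean = [0.0] * n_features
    m2 = [0.0] * n_features  # running sum of squared deviations
    for batch in batches:
        for row in batch:  # one row of n_features values per atom
            count += 1
            for j, x in enumerate(row):
                delta = x - mean[j]
                mean[j] += delta / count
                m2[j] += delta * (x - mean[j])
    std = [math.sqrt(v / count) for v in m2]  # population std
    return mean, std


batches = [[[1.0, 0.0], [3.0, 0.0]], [[5.0, 4.0]]]
mean, std = feature_stats_sketch(batches, n_features=2)
print(mean, std)
```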

denormalize_energy(energy_norm: torch.Tensor, n_atoms: int) → float

Convert normalized model output to physical energy.

Parameters:
  • energy_norm (torch.Tensor) – Normalized energy from model (sum of per-atom energies).

  • n_atoms (int) – Number of atoms in structure.

Returns:

Physical energy.

Return type:

float
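Assuming each normalized per-atom energy is (e - shift) * scaling, summing over n_atoms atoms gives E_norm = (E - n_atoms * shift) * scaling, which inverts to E = E_norm / scaling + n_atoms * shift. The two functions below are a hypothetical sketch of that round trip; they ignore any atomic reference energies the real method may add back.

```python
def normalize_total_energy(
    energy: float, n_atoms: int, shift: float, scaling: float
) -> float:
    # Sum over atoms of (e_i - shift) * scaling, written in terms of the total E.
    return (energy - n_atoms * shift) * scaling


def denormalize_total_energy(
    energy_norm: float, n_atoms: int, shift: float, scaling: float
) -> float:
    # Exact inverse of normalize_total_energy.
    return energy_norm / scaling + n_atoms * shift


shift, scaling = -5.0, 0.5
e_norm = normalize_total_energy(-28.0, 4, shift, scaling)
e_back = denormalize_total_energy(e_norm, 4, shift, scaling)
print(e_norm, e_back)  # -4.0 -28.0
```

The n_atoms argument is what lets the per-atom shift be restored for a structure of any size.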

get_state() → Dict[str, Any]

Get normalization state for serialization.

Returns:

Dictionary containing normalization parameters.

Return type:

dict

has_feature_stats() → bool

Check if feature statistics have been computed or set.

property normalize_energy: bool

Whether energy normalization is enabled.

property normalize_features: bool

Whether feature normalization is enabled.

set_energy_stats(shift: float, scaling: float)

Set energy normalization parameters.

Parameters:
  • shift (float) – Energy shift (per-atom midpoint).

  • scaling (float) – Energy scaling factor.

set_feature_stats(mean, std)

Set feature normalization statistics from provided values.

Parameters:
  • mean (array-like) – Feature means.

  • std (array-like) – Feature standard deviations.

set_state(state: Dict[str, Any])

Restore normalization state from dictionary.

Parameters:

state (dict) – Dictionary containing normalization parameters.

Metrics Tracker

class aenet.torch_training.training.MetricsTracker(track_detailed_timing: bool = False)

Bases: object

Tracks training metrics and history.

Maintains per-epoch metrics for energy RMSE, force RMSE, learning rate, and timing information.

Parameters:

track_detailed_timing (bool) – Whether to track detailed timing breakdowns (data loading, loss computation, optimizer steps).

get_best(metric: str, mode: str = 'min') → float

Get the best value for a specific metric.

Parameters:
  • metric (str) – Metric name.

  • mode (str) – Either ‘min’ or ‘max’.

Returns:

Best value, or NaN if metric not found or empty.

Return type:

float

get_history() → Dict[str, List[float]]

Get complete training history.

Returns:

Dictionary mapping metric names to lists of per-epoch values.

Return type:

dict

get_latest(metric: str) → float

Get the latest value for a specific metric.

Parameters:

metric (str) – Metric name (e.g., ‘train_energy_rmse’).

Returns:

Latest value, or NaN if metric not found or empty.

Return type:

float
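The lookup rules for get_best and get_latest (NaN when a metric is missing or empty) can be mirrored over a plain dictionary shaped like the one get_history returns. These functions are hypothetical stand-ins; in particular, skipping NaN entries in get_best_sketch is an assumption, chosen because min and max give order-dependent results when NaN is present.

```python
import math
from typing import Dict, List

History = Dict[str, List[float]]


def get_latest_sketch(history: History, metric: str) -> float:
    values = history.get(metric)
    return values[-1] if values else math.nan  # missing or empty -> NaN


def get_best_sketch(history: History, metric: str, mode: str = "min") -> float:
    # Assumption: NaN epochs (e.g. no validation set) are ignored.
    values = [v for v in history.get(metric, []) if not math.isnan(v)]
    if not values:
        return math.nan
    if mode == "min":
        return min(values)
    if mode == "max":
        return max(values)
    raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")


history: History = {
    "train_energy_rmse": [0.30, 0.21, 0.25],
    "test_energy_rmse": [math.nan, math.nan, math.nan],
}
print(get_latest_sketch(history, "train_energy_rmse"))  # 0.25
print(get_best_sketch(history, "train_energy_rmse"))    # 0.21
print(get_best_sketch(history, "test_energy_rmse"))     # nan
```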

replace_latest(train_energy_rmse: float, train_energy_mae: float, train_force_rmse: float, test_energy_rmse: float, test_energy_mae: float, test_force_rmse: float) → None

Replace the latest stored metric values.

This is used when the trainer wants to recompute final epoch metrics with a deterministic full-pass evaluation after the optimization loop has finished.

Parameters:
  • train_energy_rmse (float) – Replacement training energy RMSE.

  • train_energy_mae (float) – Replacement training energy MAE.

  • train_force_rmse (float) – Replacement training force RMSE.

  • test_energy_rmse (float) – Replacement validation energy RMSE.

  • test_energy_mae (float) – Replacement validation energy MAE.

  • test_force_rmse (float) – Replacement validation force RMSE.

reset()

Clear all metrics history.

set_history(history: Dict[str, List[float]])

Set history from a dictionary (e.g., when loading checkpoint).

Parameters:

history (dict) – Dictionary mapping metric names to lists of values.

update(train_energy_rmse: float, train_energy_mae: float, train_force_rmse: float, test_energy_rmse: float, test_energy_mae: float, test_force_rmse: float, learning_rate: float, epoch_time: float, forward_time: float, backward_time: float, train_timing: Dict[str, float], val_timing: Dict[str, float])

Record metrics for one epoch.

Parameters:
  • train_energy_rmse (float) – Training energy RMSE.

  • train_energy_mae (float) – Training energy MAE.

  • train_force_rmse (float) – Training force RMSE (or NaN if not computed).

  • test_energy_rmse (float) – Validation energy RMSE (or NaN if no validation).

  • test_energy_mae (float) – Validation energy MAE (or NaN if no validation).

  • test_force_rmse (float) – Validation force RMSE (or NaN if not computed).

  • learning_rate (float) – Current learning rate.

  • epoch_time (float) – Total epoch time in seconds.

  • forward_time (float) – Forward pass time in seconds.

  • backward_time (float) – Backward pass time in seconds.

  • train_timing (dict) – Detailed training timing breakdown with keys: ‘data_loading’, ‘loss_computation’, ‘optimizer’, ‘total’

  • val_timing (dict) – Detailed validation timing breakdown (same keys as train_timing).
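One way to picture the history that update builds and replace_latest edits is a dictionary of per-epoch lists, one list per scalar metric. The helper names below are hypothetical, the sketch covers only a subset of the documented scalar arguments, and how the real tracker stores the train_timing and val_timing dictionaries is not modelled.

```python
import math
from collections import defaultdict
from typing import DefaultDict, List

History = DefaultDict[str, List[float]]


def record_epoch(history: History, **metrics: float) -> None:
    # Append one value per metric, so every list grows by one entry per epoch.
    for name, value in metrics.items():
        history[name].append(float(value))


def replace_latest_sketch(history: History, **metrics: float) -> None:
    # Overwrite the last entry, e.g. with a deterministic full-pass evaluation.
    for name, value in metrics.items():
        history[name][-1] = float(value)


history: History = defaultdict(list)
record_epoch(history, train_energy_rmse=0.30, test_energy_rmse=math.nan, learning_rate=1e-3)
record_epoch(history, train_energy_rmse=0.25, test_energy_rmse=math.nan, learning_rate=5e-4)
replace_latest_sketch(history, train_energy_rmse=0.24)
print(dict(history))
```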

Training Loop

class aenet.torch_training.training.TrainingLoop(model, descriptor, normalizer: NormalizationManager, device: torch.device, dtype: torch.dtype)

Bases: object

Executes training and validation epochs.

Handles:
  • Batch iteration and processing
  • Loss computation (energy + forces)
  • Backpropagation and optimization
  • Timing and metrics collection

Parameters:
  • model (nn.Module) – Model to train (EnergyModelAdapter).

  • descriptor (ChebyshevDescriptor) – Descriptor for force computation.

  • normalizer (NormalizationManager) – Normalization manager for features and energies.

  • device (torch.device) – Device for computation.

  • dtype (torch.dtype) – Data type for tensors.

run_epoch(loader: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer | None, alpha: float, atomic_energies_by_index: torch.Tensor | None = None, train: bool = True, show_batch_progress: bool = False, force_scale_unbiased: bool = False, collect_structure_scores: bool = False) → Tuple[float, float, float, Dict[str, float], dict[int, float] | None]

Run one epoch over loader.

Parameters:
  • loader (DataLoader) – DataLoader for batches.

  • optimizer (torch.optim.Optimizer or None) – Optimizer for training (None for validation).

  • alpha (float) – Force loss weight (0.0 = energy only, 1.0 = forces only).

  • atomic_energies_by_index (torch.Tensor, optional) – Atomic energies indexed by species. When provided, they are always subtracted to convert total energies to cohesive energies. If not provided or all zeros, the training targets remain total energies.

  • train (bool) – Whether this is training (True) or validation (False).

  • show_batch_progress (bool) – Whether to show per-batch progress bar.

  • force_scale_unbiased (bool) – If True, scale the per-batch force RMSE by sqrt(1/f), where f is the fraction of atoms that have force labels. This approximately keeps the loss magnitude constant when forces are sub-sampled.

  • collect_structure_scores (bool) – If True, accumulate per-structure energy-error scores for adaptive sampling updates.
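The roles of alpha and force_scale_unbiased can be illustrated with a linear mix of an energy RMSE and a force RMSE. The mixing formula (1 - alpha) * energy_term + alpha * force_term is an assumption chosen to match the documented end points (0.0 = energy only, 1.0 = forces only); the function and argument names are hypothetical, and plain lists of scalar errors stand in for batched tensors.

```python
import math
from typing import Sequence


def rmse(errors: Sequence[float]) -> float:
    return math.sqrt(sum(e * e for e in errors) / len(errors))


def combined_loss_sketch(
    energy_errors: Sequence[float],
    force_errors: Sequence[float],  # per-component errors of supervised atoms
    alpha: float,
    n_force_atoms: int,
    n_atoms: int,
    force_scale_unbiased: bool = False,
) -> float:
    energy_term = rmse(energy_errors)
    force_term = rmse(force_errors) if force_errors else 0.0
    if force_scale_unbiased and n_force_atoms > 0:
        f = n_force_atoms / n_atoms  # supervised fraction of atoms
        force_term *= math.sqrt(1.0 / f)
    return (1.0 - alpha) * energy_term + alpha * force_term


energy_errors = [2.0, -2.0]      # energy RMSE = 2.0
force_errors = [1.0, -1.0, 1.0]  # one supervised atom (x, y, z); force RMSE = 1.0
print(combined_loss_sketch(energy_errors, force_errors, 0.5, 1, 4))  # 1.5
print(
    combined_loss_sketch(
        energy_errors, force_errors, 0.5, 1, 4, force_scale_unbiased=True
    )
)  # 2.0
```

With one of four atoms supervised, f = 0.25 and the unbiased option multiplies the force term by sqrt(1/0.25) = 2.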

Returns:

  • energy_rmse (float) – Energy RMSE for epoch.

  • energy_mae (float) – Energy MAE for epoch.

  • force_rmse (float) – Force RMSE for epoch (NaN if no forces).

  • timing (dict) – Timing breakdown with keys: ‘data_loading’, ‘loss_computation’, ‘optimizer’, ‘total’.

  • structure_scores (dict[int, float] or None) – Per-structure mean scores keyed by split-local dataset index. Returned only when collect_structure_scores=True.
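The epoch-level energy metrics and the optional structure scores can be pictured as accumulations over every structure visited during the epoch. In the sketch below, using the per-atom energy error, scoring each structure by its absolute error, and averaging repeated visits into a mean score are all assumptions; epoch_energy_metrics_sketch and its input tuple layout are hypothetical.

```python
import math
from collections import defaultdict
from typing import Dict, Iterable, List, Sequence, Tuple

# (split-local dataset index, predicted energy, reference energy, number of atoms)
Record = Tuple[int, float, float, int]


def epoch_energy_metrics_sketch(
    batches: Iterable[Sequence[Record]],
) -> Tuple[float, float, Dict[int, float]]:
    sq_sum = abs_sum = 0.0
    count = 0
    per_structure: Dict[int, List[float]] = defaultdict(list)
    for batch in batches:
        for idx, pred, ref, n_atoms in batch:
            err = (pred - ref) / n_atoms  # per-atom energy error (assumption)
            sq_sum += err * err
            abs_sum += abs(err)
            count += 1
            per_structure[idx].append(abs(err))
    energy_rmse = math.sqrt(sq_sum / count)
    energy_mae = abs_sum / count
    # Mean score per structure; a structure may be visited more than once.
    scores = {idx: sum(v) / len(v) for idx, v in per_structure.items()}
    return energy_rmse, energy_mae, scores


batches = [
    [(0, -10.0, -10.5, 2), (1, -20.0, -19.0, 4)],
    [(0, -10.25, -10.5, 2)],
]
energy_rmse, energy_mae, scores = epoch_energy_metrics_sketch(batches)
print(scores)  # {0: 0.1875, 1: 0.25}
```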