"""
Training loop execution for PyTorch training.
Handles epoch execution including batch processing, loss computation,
and optimization steps.
"""
import time
from typing import Dict, List, Optional, Tuple
import torch
from torch.utils.data import DataLoader
from ..loss import compute_energy_loss, compute_force_loss
from .normalization import NormalizationManager
# Progress bar (match aenet.mlip behavior).
# tqdm is treated as an optional dependency: if importing it fails for any
# reason, the name ``tqdm`` is bound to None so that later code can test
# ``tqdm is not None`` before using it.
try:
    from tqdm import tqdm  # type: ignore
except Exception:
    tqdm = None  # type: ignore
def _iter_progress(iterable, enable: bool, desc: str):
"""Wrap an iterable with tqdm progress bar if enabled and available."""
if enable and tqdm is not None:
try:
total = len(iterable) # type: ignore[arg-type]
except Exception:
total = None
return tqdm(iterable, total=total, desc=desc, ncols=80, leave=False)
return iterable
# (stray "[docs]" link marker from the rendered documentation page removed; it is not code)
class TrainingLoop:
"""
Executes training and validation epochs.
Handles:
- Batch iteration and processing
- Loss computation (energy + forces)
- Backpropagation and optimization
- Timing and metrics collection
Parameters
----------
model : nn.Module
Model to train (EnergyModelAdapter).
descriptor : ChebyshevDescriptor
Descriptor for force computation.
normalizer : NormalizationManager
Normalization manager for features and energies.
device : torch.device
Device for computation.
dtype : torch.dtype
Data type for tensors.
"""
def __init__(
self,
model,
descriptor,
normalizer: NormalizationManager,
device: torch.device,
dtype: torch.dtype,
):
self.model = model
self.descriptor = descriptor
self.normalizer = normalizer
self.device = device
self.dtype = dtype
# Timing state (last epoch)
self.last_forward_time: float = 0.0
self.last_backward_time: float = 0.0
self.last_data_loading_time: float = 0.0
self.last_loss_computation_time: float = 0.0
self.last_optimizer_time: float = 0.0
[docs]
def run_epoch(
self,
loader: DataLoader,
optimizer: Optional[torch.optim.Optimizer],
alpha: float,
atomic_energies_by_index: Optional[torch.Tensor] = None,
train: bool = True,
show_batch_progress: bool = False,
force_scale_unbiased: bool = False,
collect_structure_scores: bool = False,
) -> Tuple[float, float, float, Dict[str, float], Optional[dict[int, float]]]:
"""
Run one epoch over loader.
Parameters
----------
loader : DataLoader
DataLoader for batches.
optimizer : torch.optim.Optimizer or None
Optimizer for training (None for validation).
alpha : float
Force loss weight (0.0 = energy only, 1.0 = forces only).
atomic_energies_by_index : torch.Tensor, optional
Atomic energies indexed by species. Always applied to convert
total energies to cohesive. If not provided or all zeros,
training targets remain as total energies.
train : bool
Whether this is training (True) or validation (False).
show_batch_progress : bool
Whether to show per-batch progress bar.
force_scale_unbiased : bool
If True, apply optional sqrt(1/f) scaling to the per-batch
force RMSE, where f is the supervised fraction of atoms
with available force labels. This approximates constant
loss magnitude when sub-sampling forces.
collect_structure_scores : bool
If True, accumulate per-structure energy-error scores for
adaptive sampling updates.
Returns
-------
energy_rmse : float
Energy RMSE for epoch.
energy_mae : float
Energy MAE for epoch.
force_rmse : float
Force RMSE for epoch (NaN if no forces).
timing : dict
Timing breakdown with keys: 'data_loading', 'loss_computation',
'optimizer', 'total'.
structure_scores : dict[int, float] or None
Per-structure mean scores keyed by split-local dataset index.
Returned only when ``collect_structure_scores=True``.
"""
energy_losses: List[float] = []
energy_maes: List[float] = []
force_losses: List[float] = []
structure_score_sums: dict[int, float] = {}
structure_score_counts: dict[int, int] = {}
forward_time_total: float = 0.0
backward_time_total: float = 0.0
data_loading_time_total: float = 0.0
loss_computation_time_total: float = 0.0
optimizer_time_total: float = 0.0
t_epoch_start = time.perf_counter()
iterator = _iter_progress(
loader,
enable=show_batch_progress,
desc=("train" if train else "val"),
)
t_batch_start = time.perf_counter()
for batch in iterator:
# Data loading time
t_data_end = time.perf_counter()
data_loading_time_total += t_data_end - t_batch_start
# Energy view tensors
features = batch["features"].to(self.device)
species_indices = batch["species_indices"].to(self.device)
n_atoms = batch["n_atoms"].to(self.device)
energy_ref = batch["energy_ref"].to(self.device)
# Ensure dtype consistency
if self.dtype == torch.float64:
features = features.double()
energy_ref = energy_ref.double()
else:
features = features.float()
energy_ref = energy_ref.float()
# Subtract atomic reference energies (always performed if provided)
# When atomic_energies are all zeros, this becomes a no-op
if atomic_energies_by_index is not None:
per_atom_Ea = atomic_energies_by_index[species_indices]
batch_idx = torch.repeat_interleave(
torch.arange(len(n_atoms), device=self.device),
n_atoms.long(),
)
Ea_sum = torch.zeros(
len(n_atoms), dtype=energy_ref.dtype, device=self.device
)
Ea_sum.scatter_add_(0, batch_idx, per_atom_Ea)
energy_ref = energy_ref - Ea_sum
# Feature normalization
features = self.normalizer.apply_feature_normalization(features)
# Forward + loss computation
t_forward_start = time.perf_counter()
t_loss_start = time.perf_counter()
# Energy loss
E_shift = self.normalizer.E_shift
E_scaling = self.normalizer.E_scaling
energy_loss_t, energy_pred = compute_energy_loss(
features=features,
energy_ref=energy_ref,
n_atoms=n_atoms,
network=self.model,
species_indices=species_indices,
E_shift=float(E_shift),
E_scaling=float(E_scaling),
)
# Compute MAE (per atom)
energy_mae_t = torch.mean(
torch.abs((energy_pred - energy_ref) / n_atoms))
energy_error_per_structure = torch.abs(
(energy_pred - energy_ref) / n_atoms
)
# Optional force loss
force_loss_t: Optional[torch.Tensor] = None
if alpha > 0.0 and batch["positions_f"] is not None:
features_f = batch.get("features_f", None)
positions_f = batch["positions_f"].to(self.device)
forces_ref_f = batch["forces_ref_f"].to(self.device)
species_indices_f = batch["species_indices_f"].to(self.device)
species_f = batch["species_f"]
local_derivatives_f = batch.get("local_derivatives_f", None)
graph_f = batch.get("graph_f", None)
if local_derivatives_f is None and graph_f is None:
raise RuntimeError(
"Force-training batches must include either "
"precomputed local derivatives or graph data."
)
# dtype
if self.dtype == torch.float64:
if features_f is not None:
features_f = features_f.to(self.device).double()
positions_f = positions_f.double()
forces_ref_f = forces_ref_f.double()
else:
if features_f is not None:
features_f = features_f.to(self.device).float()
positions_f = positions_f.float()
forces_ref_f = forces_ref_f.float()
force_loss_t, forces_pred = compute_force_loss(
positions=positions_f,
species=species_f,
forces_ref=forces_ref_f,
descriptor=self.descriptor,
network=self.model,
species_indices=species_indices_f,
cell=None,
pbc=None,
E_scaling=float(E_scaling),
neighbor_info=None,
chunk_size=None,
feature_mean=(
self.normalizer.feature_mean
if self.normalizer.normalize_features
else None
),
feature_std=(
self.normalizer.feature_std
if self.normalizer.normalize_features
else None
),
features=(
features_f
if local_derivatives_f is not None
else None
),
local_derivatives=local_derivatives_f,
graph=graph_f,
triplets=batch.get("triplets_f", None),
center_indices=None,
)
# Optional unbiased scaling of RMSE based on
# supervised fraction
if force_scale_unbiased:
try:
eff_total = int(batch.get("n_atoms_force_total", 0))
eff_supervised = int(
batch.get("n_atoms_force_supervised", 0))
if eff_total > 0:
eff_f = float(eff_supervised) / float(eff_total)
if eff_f > 0.0 and eff_f < 1.0:
scale = torch.tensor(
eff_f,
dtype=force_loss_t.dtype,
device=force_loss_t.device,
)
# Approximate scaling for RMSE (MSE would
# not require scaling)
force_loss_t = force_loss_t / torch.sqrt(scale)
except Exception:
# If bookkeeping not present, skip scaling
pass
# Combine losses
if force_loss_t is None:
combined = (1.0 - alpha) * energy_loss_t
else:
combined = (
(1.0 - alpha) * energy_loss_t + alpha * force_loss_t
)
t_loss_end = time.perf_counter()
loss_computation_time_total += t_loss_end - t_loss_start
forward_time_total += t_loss_end - t_forward_start
# Backward + optimizer
if train and optimizer is not None:
t_backward_start = time.perf_counter()
optimizer.zero_grad(set_to_none=True)
combined.backward()
t_optimizer_start = time.perf_counter()
backward_time_total += t_optimizer_start - t_backward_start
optimizer.step()
t_optimizer_end = time.perf_counter()
optimizer_time_total += t_optimizer_end - t_optimizer_start
# Collect losses and MAE
energy_losses.append(float(energy_loss_t.detach().cpu()))
energy_maes.append(float(energy_mae_t.detach().cpu()))
if force_loss_t is not None:
force_losses.append(float(force_loss_t.detach().cpu()))
if collect_structure_scores:
sample_indices = batch.get("sample_indices", None)
if sample_indices is not None:
energy_scores = energy_error_per_structure.detach().cpu().tolist()
for sample_idx, energy_score in zip(
sample_indices.tolist(),
energy_scores,
strict=True,
):
structure_score = float(energy_score)
idx = int(sample_idx)
structure_score_sums[idx] = (
structure_score_sums.get(idx, 0.0) + structure_score
)
structure_score_counts[idx] = (
structure_score_counts.get(idx, 0) + 1
)
# Prepare for next batch
t_batch_start = time.perf_counter()
t_epoch_end = time.perf_counter()
# Compute RMSE and MAE averages
energy_rmse = float(
sum(energy_losses) / max(1, len(energy_losses))
)
energy_mae = float(
sum(energy_maes) / max(1, len(energy_maes))
)
force_rmse = (
float(sum(force_losses) / max(1, len(force_losses)))
if force_losses
else float("nan")
)
# Store timing for this epoch
self.last_forward_time = forward_time_total
self.last_backward_time = backward_time_total
self.last_data_loading_time = data_loading_time_total
self.last_loss_computation_time = loss_computation_time_total
self.last_optimizer_time = optimizer_time_total
timing = {
"data_loading": data_loading_time_total,
"loss_computation": loss_computation_time_total,
"optimizer": optimizer_time_total,
"total": t_epoch_end - t_epoch_start,
}
structure_scores = None
if collect_structure_scores:
structure_scores = {
idx: (
structure_score_sums[idx]
/ max(1, structure_score_counts.get(idx, 0))
)
for idx in structure_score_sums
}
return energy_rmse, energy_mae, force_rmse, timing, structure_scores