PyTorch-based Training

This page covers training machine learning interatomic potentials (MLIPs) using the PyTorch-based implementation in aenet-python. The PyTorch implementation provides a pure Python workflow with GPU acceleration and automatic differentiation for forces.

Note

Training as described here makes use of PyTorch. Make sure to install core torch support as described in Installation & Set-up. Most descriptor-based training workflows also require the matching torch-scatter and torch-cluster wheels.

Note

Alternative: For training using ænet’s Fortran-based tools, see Training ANN Potentials (Fortran).

Overview

The PyTorch training workflow consists of three main steps:

  1. Prepare structures: Load atomic structures with energies (and optionally forces)

  2. Configure training: Set up the model architecture and training parameters

  3. Train the model: Run the training loop and save the trained potential

This tutorial demonstrates both energy-only training and training on energies and forces.

Example notebooks

Jupyter notebooks with examples can be found in the notebooks directory within the repository.

For the maintained PyTorch training walkthrough, including the file-backed TiO2 workflow, explicit CachedStructureDataset usage, fixed train/test splits, dataset-backed prediction, and committee training, see example-05-torch-training.ipynb.

If you need to construct atomic_energies programmatically before training or before building a large HDF5 dataset, see aenet.reference_energies.ReferenceEnergies. Its regression helper accepts lazy (composition, energy) samples directly, and its reference-compound helper selects the lowest-energy sample for each requested composition before solving the constrained system. The module also provides a file-path iterator backed by aenet.io.structure for streaming-friendly preprocessing.

Energy-Only Training

Here’s a compact CPU-only example that keeps the full setup in memory. The notebook linked above remains the maintained home for the file-backed TiO2 workflow, checkpoint rotation, explicit CachedStructureDataset usage, fixed train/test splits, dataset-backed prediction, and plotting.

import numpy as np
import torch

from aenet.torch_featurize import ChebyshevDescriptor
from aenet.torch_training import (
    Adam,
    Structure,
    TorchANNPotential,
    TorchTrainingConfig,
)

structures = [
    Structure(
        positions=np.array(
            [
                [0.0, 0.0, 0.0],
                [0.9, 0.0, 0.0],
                [0.0, 0.9, 0.0],
            ]
        ),
        species=["H", "H", "H"],
        energy=0.0,
    ),
    Structure(
        positions=np.array(
            [
                [0.1, 0.0, 0.0],
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
            ]
        ),
        species=["H", "H", "H"],
        energy=0.5,
    ),
]

descriptor = ChebyshevDescriptor(
    species=["H"],
    rad_order=1,
    rad_cutoff=2.0,
    ang_order=0,
    ang_cutoff=2.0,
    min_cutoff=0.1,
    device="cpu",
    dtype=torch.float64,
)
arch = {"H": [(4, "tanh")]}

mlp = TorchANNPotential(arch, descriptor=descriptor)

config = TorchTrainingConfig(
    iterations=1,
    method=Adam(mu=0.001, batchsize=1),
    testpercent=50,
    force_weight=0.0,
    atomic_energies={"H": 0.0},
    normalize_features=False,
    normalize_energy=False,
    memory_mode="cpu",
    device="cpu",
    checkpoint_dir=None,
    checkpoint_interval=0,
    max_checkpoints=None,
    save_best=False,
    use_scheduler=False,
)

results = mlp.train(structures=structures, config=config)
print(results.errors[["RMSE_train", "RMSE_test"]].tail(1))

This trains a neural network potential using energies only, with 50% of the structures held out for validation. The train() method returns a TrainOut object containing training history, statistics, and plotting helpers.

Note

Setting testpercent > 0 does more than hold out structures. It also enables any validation-driven controls in your configuration, such as use_scheduler=True and save_best=True. On very small validation splits, these controls can react to noisy metrics and change the training behavior qualitatively.

Reproducibility Controls

The PyTorch trainer separates run-level stochastic behavior from split selection:

>>> from aenet.torch_training import TorchTrainingConfig
>>> config = TorchTrainingConfig(seed=11, split_seed=7)
>>> (config.seed, config.split_seed)
(11, 7)

Use split_seed when you want the trainer-owned train/validation partition to stay fixed across runs. Use seed when you want model initialization, training-shuffle order, weighted sampling, and random force-subset selection to be reproducible. Committee workflows typically keep split_seed shared across members while varying seed per member.
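The separation between the two seeds can be pictured as two independent random streams. The following is a minimal standard-library sketch, not the trainer's implementation: the helper names split_indices and shuffled_epoch_order are hypothetical, and the real trainer may derive its random streams differently.

```python
import random


def split_indices(n_structures, test_percent, split_seed):
    """Partition indices into train/validation using only split_seed."""
    rng = random.Random(split_seed)
    indices = list(range(n_structures))
    rng.shuffle(indices)
    n_test = round(n_structures * test_percent / 100)
    return sorted(indices[n_test:]), sorted(indices[:n_test])


def shuffled_epoch_order(train_indices, seed):
    """Shuffle the training order using only the run-level seed."""
    rng = random.Random(seed)
    order = list(train_indices)
    rng.shuffle(order)
    return order


# Two hypothetical committee members: shared split_seed, different seed.
train_a, test_a = split_indices(10, 20, split_seed=7)
train_b, test_b = split_indices(10, 20, split_seed=7)
order_a = shuffled_epoch_order(train_a, seed=11)
order_b = shuffled_epoch_order(train_b, seed=12)

print(train_a == train_b, test_a == test_b)  # identical partition
print(sorted(order_a) == sorted(order_b))    # same members, order may differ
```

In this picture, changing seed never moves a structure between the training and validation sets, which is what makes per-member metrics comparable across a committee.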

Committee Training

Committee support adds a trainer-side orchestration layer on top of the single-member TorchANNPotential workflow:

from pathlib import Path

from aenet.torch_training import (
    Adam,
    TorchCommitteeConfig,
    TorchCommitteePotential,
    TorchTrainingConfig,
)

# Reuses arch, descriptor, and structures from the energy-only example above.
committee = TorchCommitteePotential(arch, descriptor=descriptor)
train_config = TorchTrainingConfig(
    iterations=1,
    method=Adam(mu=0.001, batchsize=1),
    testpercent=50,
    split_seed=7,
    atomic_energies={"H": 0.0},
    normalize_features=False,
    normalize_energy=False,
    memory_mode="cpu",
    device="cpu",
    checkpoint_dir=None,
    checkpoint_interval=0,
    max_checkpoints=None,
    save_best=False,
    use_scheduler=False,
)
committee_config = TorchCommitteeConfig(
    num_members=2,
    base_seed=11,
    max_parallel=1,
    output_dir=Path("committee_run"),
)

result = committee.train(
    structures=structures,
    config=train_config,
    committee_config=committee_config,
)
print(result.metadata_path)
print([member.seed for member in result.members])
print(result)

member_results = result.trainouts
member_0_errors = member_results[0].errors
committee_table = result.to_dataframe()
committee_stats = result.stats

Committee runs materialize a stable output layout:

committee_run/
  committee_metadata.json
  member_000/
    model.pt
    history.json
    history.csv
    summary.json
  member_001/
    model.pt
    history.json
    history.csv
    summary.json

The committee layer computes any trainer-owned train/validation split once in the parent process and reuses that split across all members. In the first committee implementation, the main reproducibility pattern is a shared split_seed with distinct per-member seed values derived from base_seed or from member_seeds.

TorchCommitteeTrainResult mirrors the single-network TrainOut summary style where possible. print(result) reports the mean and standard deviation of each available final metric across completed committee members. result.stats exposes the same aggregate values programmatically, result.to_dataframe() returns one row per member, and member.trainout or result.trainouts rebuilds the familiar per-member TrainOut objects from the persisted history.json files.

Committee Inference and ASCII Export

The committee layer also provides committee-level loading, aggregated prediction, and a committee-wide ASCII export helper:

# Continues from the committee-training example: result, structures, and Path are defined there.
reloaded = TorchCommitteePotential.from_directory(result.output_dir)
predictions = reloaded.predict(structures, eval_forces=False)
first_result = predictions[0]

print(first_result.energy_mean, first_result.energy_std)
print(first_result.member_energies)
print(predictions.member_outputs[0].total_energy)

uncertainty_table = predictions.to_dataframe()
most_uncertain = predictions.top_uncertain(n=10)

# test_dataset must be an existing dataset object; it is not defined in this snippet.
dataset_predictions = reloaded.predict_dataset(test_dataset)
dataset_uncertainty_table = dataset_predictions.to_dataframe()

members = reloaded.to_aenet_ascii(
    Path("ascii_committee"),
    prefix="committee",
    structures=structures,
)
print(members[0])

predict() and the dataset-backed predict_dataset() return a list-like TorchCommitteePredictResult. Iterating over it or indexing it returns one aenet.mlip.ensemble.AenetEnsembleResult per input structure, so existing list-style code remains valid. The result also keeps the per-member aenet.io.predict.PredictOut objects in member_outputs and provides to_dataframe(), sort_by(), and top_uncertain() helpers for uncertainty-driven structure selection. Dataset-backed prediction tracks both split-local index and root-dataset source_index where possible. When eval_forces=False, it follows the cached-feature TorchANNPotential.predict_dataset() path. When eval_forces=True, each member falls back to materialized structures, so the dataset must expose raw structures through get_structure(), structures, or a supported Subset wrapper.
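To make the aggregated quantities concrete, here is a small standard-library sketch of how a per-structure mean and spread across members can be computed and ranked. It does not call the aenet API: the energies are made-up values, and the use of the population standard deviation is an assumption of this sketch rather than a statement about what energy_std contains.

```python
from statistics import mean, pstdev

# Hypothetical per-member total energies (eV) for three structures.
member_energies = {
    "structure_a": [-10.02, -10.00, -9.98],
    "structure_b": [-5.50, -5.10, -5.90],
    "structure_c": [-7.31, -7.30, -7.29],
}


def summarize(energies):
    """Return (mean, std) across committee members for one structure."""
    # Population standard deviation is an assumption made for this sketch.
    return mean(energies), pstdev(energies)


summary = {name: summarize(vals) for name, vals in member_energies.items()}

# Rank structures from largest to smallest committee spread.
ranked = sorted(summary, key=lambda name: summary[name][1], reverse=True)
print(ranked)
```

Structures at the front of such a ranking are the natural candidates for additional reference calculations in an uncertainty-driven selection loop.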

The maintained notebook notebooks/example-05-torch-training.ipynb now includes a TiO2 committee-training example that trains a small committee, reloads it from committee_metadata.json, inspects aggregated uncertainty, and exports the member manifest for later Fortran-backed ensemble inference.

to_aenet_ascii() exports each committee member into a stable layout and returns the member manifest expected by AenetEnsembleInterface and AenetEnsembleCalculator. Pass structures=... or explicit descriptor_stats=... when exact descriptor statistics must be written into the ASCII files.

ascii_committee/
  member_000/
    committee.H.nn.ascii
  member_001/
    committee.H.nn.ascii

Structure Sampling Policies

The PyTorch trainer distinguishes three separate concepts that all affect training behavior:

  • use_scheduler controls the learning-rate scheduler

  • force_sampling controls which force-labeled structures contribute force loss in a given epoch window

  • sampling_policy controls how structures in the training split are drawn into training batches

The default structure-sampling policy is uniform shuffled batching:

>>> from aenet.torch_training import TorchTrainingConfig
>>> config = TorchTrainingConfig(sampling_policy="uniform")
>>> config.sampling_policy
'uniform'

Epoch semantics are different for uniform and non-uniform policies:

  • sampling_policy="uniform" uses shuffled batching without replacement, so each training structure appears exactly once per epoch.

  • Non-uniform policies use weighted sampling with replacement and draw len(train_split) structures per epoch. Some structures may appear multiple times in one epoch and some may not appear at all.

  • iterations still means training epochs. Under non-uniform sampling, one epoch is not guaranteed to be a full pass over distinct training structures.

  • Validation sampling remains uniform and deterministic.
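The difference in epoch semantics can be illustrated with the standard library alone. This sketch is only an analogy for the two sampling modes: the structure names and weights are hypothetical, and the trainer's own sampler is not used here.

```python
import random
from collections import Counter

train_split = ["s0", "s1", "s2", "s3", "s4", "s5"]
rng = random.Random(0)

# Uniform policy: shuffle without replacement, one visit per structure.
uniform_epoch = list(train_split)
rng.shuffle(uniform_epoch)

# Non-uniform policy: len(train_split) weighted draws with replacement.
weights = [3.0, 1.0, 1.0, 1.0, 1.0, 1.0]  # hypothetical weights
weighted_epoch = rng.choices(train_split, weights=weights, k=len(train_split))

print(Counter(uniform_epoch))   # every count is exactly 1
print(Counter(weighted_epoch))  # counts can exceed 1; some may be missing
```

Both epochs contain the same number of samples, which is why iterations keeps its meaning while the coverage of distinct structures per epoch can differ.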

The static non-uniform option sampling_policy="energy_weighted" biases sampling toward lower cohesive or referenced formation energy per atom:

>>> config = TorchTrainingConfig(
...     sampling_policy="energy_weighted",
...     atomic_energies={"H": 0.0},
... )
>>> config.sampling_policy
'energy_weighted'

The weighting always uses the same atomic-reference convention as the training targets. When the trainer builds datasets from raw structures=... input, that convention comes from TorchTrainingConfig.atomic_energies. When you pass a prebuilt dataset, the dataset owns atomic_energies and the trainer uses those instead. If no atomic references are provided in either path, training still proceeds with all-zero atomic references; in that case, the energy-weighted policy uses the provided per-atom labels as-is and emits a warning so the fallback is explicit.

The exact per-draw sampling probability is determined as follows for a training split with N structures. For structure i:

\[e_i = \frac{E_i - \sum_{a \in i} E^{\mathrm{atom}}_a}{n_i}\]

where E_i is the stored total energy, E^{atom}_a comes from the resolved atomic-reference convention, and n_i is the atom count. Then:

\[\Delta_i = e_i - \min_j e_j\]
\[\Delta_{\max} = \max_j \Delta_j\]

If \Delta_{\max} \le 0, all structures receive equal weight:

\[w_i = 1\]

Otherwise:

\[w_i = \frac{1}{1 + \Delta_i / \Delta_{\max}}\]

The trainer draws with replacement using num_samples = N per epoch, so the probability that a single draw selects structure i is:

\[p_i = \frac{w_i}{\sum_j w_j}\]

This is the full implementation. Lower referenced per-atom energy means larger w_i and therefore larger sampling probability p_i.
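The formulas above translate directly into a few lines of plain Python. The sketch below is a re-derivation for illustration, not the trainer's code: the function name and the representation of each structure as a species-count dictionary are assumptions of this example.

```python
def energy_weighted_probabilities(total_energies, compositions, atomic_energies):
    """Per-draw sampling probabilities following the formulas above."""
    per_atom = []
    for energy, composition in zip(total_energies, compositions):
        reference = sum(
            atomic_energies.get(species, 0.0) * count
            for species, count in composition.items()
        )
        n_atoms = sum(composition.values())
        per_atom.append((energy - reference) / n_atoms)  # e_i

    e_min = min(per_atom)
    deltas = [e - e_min for e in per_atom]  # Delta_i
    delta_max = max(deltas)

    if delta_max <= 0:
        weights = [1.0] * len(deltas)  # all structures equally weighted
    else:
        weights = [1.0 / (1.0 + d / delta_max) for d in deltas]  # w_i

    total = sum(weights)
    return [w / total for w in weights]  # p_i


# Hypothetical three-structure split with an H reference energy of 0.0 eV.
probs = energy_weighted_probabilities(
    total_energies=[-3.0, -1.5, 0.0],
    compositions=[{"H": 3}, {"H": 3}, {"H": 3}],
    atomic_energies={"H": 0.0},
)
print(probs)
```

For this input the weights are 1, 2/3, and 1/2, so the lowest-energy structure is drawn with probability 6/13 and the highest-energy one with probability 3/13. The weights are bounded between 1/2 and 1, so the policy biases sampling without ever starving high-energy structures.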

The adaptive non-uniform option sampling_policy="error_weighted" starts with uniform epoch-0 sampling and then increases the sampling frequency of structures with higher recently observed training loss:

>>> config = TorchTrainingConfig(
...     sampling_policy="error_weighted",
... )
>>> config.sampling_policy
'error_weighted'

Its behavior is:

  • Epoch 0 uses uniform weights because no per-structure error history exists yet.

  • After each training epoch, the trainer computes a structure-level score from the sampled training structures and normalizes the next epoch’s weights so the mean weight is 1.

  • Those structure-level scores are measured in the same training target space used for the energy loss. If training uses referenced cohesive or formation energies, adaptive sampling uses the same references; if training uses raw total energies, adaptive sampling follows that convention instead.

  • Force losses do not contribute to these adaptive structure scores, even during force training. error_weighted always uses energy error only.

  • If a structure is sampled multiple times in an epoch, its next score uses the mean of those sampled occurrences.

  • If a structure is not sampled in an epoch, it keeps its previous score.

  • Resume currently does not persist adaptive sampler state; resumed error_weighted training therefore bootstraps from uniform sampling again.

The structure-level score used by error_weighted is:

  • all training modes: absolute energy error per atom for that structure

  • force-training settings such as force_weight, force_fraction, and force_sampling do not change the adaptive structure score definition

The exact adaptive-sampling update is:

  1. Epoch 0 starts with uniform structure scores:

    \[s_i^{(0)} = 1\]
  2. After epoch t, each sampled structure gets an observed score \hat{s}_i^{(t)} equal to the mean absolute energy error per atom of that structure’s sampled occurrences during the epoch. If a structure is not sampled in epoch t, it keeps its previous score:

    \[\begin{split}s_i^{(t+1)} = \begin{cases} \hat{s}_i^{(t)} & \text{if structure } i \text{ was sampled in epoch } t \\ s_i^{(t)} & \text{otherwise} \end{cases}\end{split}\]
  3. The trainer converts scores into positive sampler weights by first replacing non-finite values with 0 and clamping negative values to 0:

    \[u_i = \max(0, s_i)\]
  4. If all u_i are zero, the sampler falls back to uniform weights:

    \[w_i = 1\]
  5. Otherwise, the trainer clamps each nonzero weight to at least 10^{-12} and normalizes weights to unit mean:

    \[\tilde{u}_i = \max(u_i, 10^{-12})\]
    \[w_i = \frac{\tilde{u}_i}{\frac{1}{N}\sum_j \tilde{u}_j}\]
  6. As with energy_weighted, sampling is with replacement and num_samples = N per epoch, so each individual draw uses:

    \[p_i = \frac{w_i}{\sum_j w_j}\]

Because dividing by the mean does not change normalized probabilities, error_weighted is equivalent to drawing with probability proportional to the latest clamped per-structure score. The unit-mean normalization only keeps the raw weight magnitudes numerically well scaled.
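The update steps above can be sketched in plain Python as follows. This is an illustrative re-implementation of the listed formulas under stated assumptions, not the trainer's code: update_scores and scores_to_weights are hypothetical helper names, and the observed errors are made-up numbers.

```python
import math


def update_scores(scores, sampled_errors):
    """Replace scores of sampled structures by their mean observed error."""
    updated = list(scores)
    for index, errors in sampled_errors.items():
        updated[index] = sum(errors) / len(errors)
    return updated


def scores_to_weights(scores):
    """Convert scores to unit-mean sampler weights (steps 3 to 5)."""
    clamped = [max(0.0, s) if math.isfinite(s) else 0.0 for s in scores]
    if all(u == 0.0 for u in clamped):
        return [1.0] * len(clamped)  # uniform fallback
    floored = [max(u, 1e-12) for u in clamped]
    mean_value = sum(floored) / len(floored)
    return [u / mean_value for u in floored]


scores = [1.0, 1.0, 1.0, 1.0]  # epoch 0: uniform scores
# Hypothetical epoch: structure 0 drawn twice, structure 2 once, others not drawn.
observed = {0: [0.04, 0.02], 2: [0.09]}
scores = update_scores(scores, observed)
weights = scores_to_weights(scores)
probabilities = [w / sum(weights) for w in weights]
print(scores)
print(probabilities)
```

Note that in this sketch the structures that were not drawn keep their initial score of 1, exactly as step 2 specifies, so they remain comparatively likely to be drawn until they have been observed at least once.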

For both non-uniform policies, force_sampling remains a separate control. It determines whether a sampled force-labeled structure contributes force loss; it does not define how often that structure is drawn into batches.

Force Training

To include force supervision, add force arrays to the structures and set force_weight > 0.0:

>>> from aenet.torch_training import Adam, TorchTrainingConfig

>>> config = TorchTrainingConfig(
...     iterations=2,
...     method=Adam(mu=0.001, batchsize=1),
...     testpercent=50,
...     force_weight=0.1,
...     force_fraction=0.5,
...     force_sampling="fixed",
... )
>>> config.force_weight
0.1
>>> config.force_fraction
0.5
>>> config.force_sampling
'fixed'

The force_weight parameter (α) balances energy and force contributions:

\[\text{Loss} = (1 - \alpha) \cdot \text{RMSE}_{\text{energy}} + \alpha \cdot \text{RMSE}_{\text{forces}}\]

Common values:

  • force_weight=0.0: Energy-only (fastest training)

  • force_weight=0.1: Primarily energy, with force regularization

  • force_weight=0.5: Equal weighting

  • force_weight=1.0: Force-only (rarely used)
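As a quick numerical illustration of the weighting formula, the sketch below evaluates the combined loss for the common values listed above. The RMSE numbers are hypothetical, and the range check on force_weight is an assumption of this example rather than documented trainer behavior.

```python
def combined_loss(rmse_energy, rmse_forces, force_weight):
    """Weighted sum of energy and force RMSE for a given force_weight."""
    # The [0, 1] range check is an assumption made for this sketch.
    if not 0.0 <= force_weight <= 1.0:
        raise ValueError("force_weight must lie in [0, 1]")
    return (1.0 - force_weight) * rmse_energy + force_weight * rmse_forces


# Hypothetical RMSE values for one batch.
for alpha in (0.0, 0.1, 0.5, 1.0):
    loss = combined_loss(rmse_energy=0.02, rmse_forces=0.30, force_weight=alpha)
    print(alpha, loss)
```

Because the two RMSE terms carry different units and typical magnitudes, even a small force_weight can let the force term dominate the total, which is worth keeping in mind when comparing loss values across runs with different weights.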

Note

Force training requires structures with force data. Structures without forces will only contribute to the energy loss term.

The notebook linked above remains the maintained home for the longer force-training workflow, including checkpoint output and plotting.

Dataset Options

The PyTorch training workflow supports flexible dataset options, from simple structure lists to advanced HDF5-backed lazy-loading for large-scale training.

For detailed information about dataset classes, input formats, and performance optimization, see PyTorch Dataset Options.

The longer file-backed dataset workflow is intentionally kept in the training notebook above so the torch_datasets page can stay focused on compact API-facing examples.

Execution Model

The current trainer has two distinct runtime stages:

  1. Sample preparation happens in the main process when num_workers=0, or in DataLoader workers when num_workers > 0. Structures are converted to tensors on descriptor.device, and descriptor featurization, neighbor reuse, graph/triplet construction, and lazy HDF5 cache reads happen there.

  2. The collated batch is then moved onto config.device inside the training loop. Model forward passes, normalization, loss computation, and optimizer steps run on that device.

In practice, GPU training with num_workers > 0 is best understood as worker-side data preparation feeding a training loop on the selected device. It is not currently a separate mixed CPU/GPU execution pipeline.

If descriptor.device and config.device match, featurization and model compute happen on the same device. If they differ, samples are materialized on descriptor.device and transferred before the forward pass. The compact examples on this page create the descriptor on CPU, so later device='cuda' examples describe CPU-side sample preparation feeding GPU training unless you also move the descriptor to CUDA.

For HDF5-backed datasets, each worker reopens its own read-only file handle and keeps its own bounded in_memory_cache_size LRU cache. Trainer-owned runtime caches (cache_features, cache_neighbors, cache_force_triplets) are also per process/worker, so cache_warmup=True is skipped automatically when num_workers > 0. See PyTorch Dataset Options for persisted HDF5 cache precedence and for the distinction between build-time build_workers and training-time num_workers.

memory_mode='mixed' is reserved for a future real mixed-memory mode and currently raises NotImplementedError if requested. Today, the supported execution modes remain 'cpu' and 'gpu'.

Performance Optimization Tips

For Energy-Only Training

>>> from aenet.torch_training import TorchTrainingConfig
>>> config = TorchTrainingConfig(
...     force_weight=0.0,
...     cache_features=True,
...     num_workers=4,
...     prefetch_factor=4,
...     persistent_workers=True,
... )
>>> (config.cache_features, config.num_workers, config.prefetch_factor)
(True, 4, 4)

For Force Training

>>> config = TorchTrainingConfig(
...     force_weight=0.1,
...     force_fraction=0.3,
...     force_sampling="random",
...     cache_features=True,
...     cache_neighbors=True,
...     num_workers=4,
...     prefetch_factor=4,
... )
>>> (config.cache_neighbors, config.cache_force_triplets)
(True, False)

Caching Strategies

  • cache_features: For energy-only structure-list workflows, this can precompute features eagerly. For force training, it caches energy-view features for structures not selected for force supervision in the current epoch.

  • cache_neighbors: Reuse neighbor search results for energy-view reuse and legacy non-graph paths

  • cache_force_triplets: Cache CSR graphs and triplets for the default sparse force-training path instead of rebuilding them on demand

  • cache_*_max_entries: Bound the trainer-owned runtime caches per split and per process/worker instead of letting them grow without limit

  • cache_warmup: Optional single-process prefill of trainer-owned runtime caches before epoch 0; skipped automatically when num_workers > 0
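The effect of a bounded cache limit can be pictured with a toy cache. The sketch below is not the trainer's cache class: BoundedCache is a hypothetical name, and least-recently-used eviction is an assumption of this example. It only mirrors the documented conventions that None means unbounded and 0 suppresses storage.

```python
from collections import OrderedDict


class BoundedCache:
    """Toy LRU cache: None means unbounded, 0 means store nothing."""

    def __init__(self, max_entries):
        self.max_entries = max_entries
        self._data = OrderedDict()

    def get(self, key):
        if key in self._data:
            self._data.move_to_end(key)  # mark as most recently used
            return self._data[key]
        return None

    def put(self, key, value):
        if self.max_entries == 0:
            return  # storage suppressed
        self._data[key] = value
        self._data.move_to_end(key)
        if self.max_entries is not None:
            while len(self._data) > self.max_entries:
                self._data.popitem(last=False)  # evict least recently used

    def __len__(self):
        return len(self._data)


cache = BoundedCache(max_entries=2)
for structure_id in ("a", "b", "c"):
    cache.put(structure_id, f"features-{structure_id}")
print(len(cache), cache.get("a"), cache.get("c"))
```

With num_workers > 0, each worker process would hold its own instance of such a cache, so the effective memory footprint scales with the number of workers.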

These runtime caches are distinct from the on-disk HDF5 persisted cache sections created with HDF5StructureDataset.build_database(...). For HDF5 datasets, cache_features=True is still only a per-run in-memory layer; it does not replace persist_features=True or persist_force_derivatives=True, which are the build-time options for reusing raw features or sparse local derivatives across sessions. See PyTorch Dataset Options for the full cache-precedence workflow. HDF5 energy filtering is also a build-time concern: set max_energy and atomic_energies on the HDF5StructureDataset before calling build_database() rather than relying on TorchTrainingConfig.max_energy at runtime.

Common Pitfalls

  1. Descriptor mismatch: Ensure descriptor species order matches your dataset. Datasets use descriptor.species_to_idx for species indexing.

Training Configuration

The TorchTrainingConfig class provides extensive control over the training process. Here are the most commonly used parameters:

Basic Settings

>>> from aenet.torch_training import TorchTrainingConfig
>>> config = TorchTrainingConfig(
...     iterations=100,
...     testpercent=10,
...     device="cpu",
...     show_progress=True,
... )
>>> (config.iterations, config.device, config.show_progress)
(100, 'cpu', True)

Optimizer Selection

Choose and configure the optimization algorithm:

>>> from aenet.torch_training import Adam, SGD, TorchTrainingConfig

>>> method = Adam(
...     mu=0.001,
...     batchsize=32,
...     beta1=0.9,
...     beta2=0.999,
...     weight_decay=0.0,
... )
>>> (method.method_name, method.batchsize)
('adam', 32)

>>> method = SGD(
...     lr=0.01,
...     batchsize=32,
...     momentum=0.9,
...     weight_decay=0.0,
... )
>>> TorchTrainingConfig(iterations=100, method=method).method.method_name
'sgd'

Adam is recommended for most applications due to its adaptive learning rates and robust convergence properties.

Common Training Patterns

Small Dataset (< 100 structures)

from aenet.torch_training import Adam, TorchTrainingConfig

config = TorchTrainingConfig(
    iterations=200,  # More epochs for small data
    method=Adam(mu=0.001, batchsize=16),  # Smaller batches
    testpercent=10,
    force_weight=0.1,
    device='cpu'  # CPU fine for small datasets
)

Large Dataset (> 500 structures)

config = TorchTrainingConfig(
    iterations=50,   # Fewer epochs needed
    method=Adam(mu=0.001, batchsize=64),  # Larger batches
    testpercent=10,
    force_weight=0.1,
    device='cuda',  # Model/loss on GPU
    # Performance optimizations
    cache_features=True,  # Runtime in-memory feature cache
    cache_feature_max_entries=1024,
    num_workers=8,         # Parallel CPU-side sample preparation
    prefetch_factor=4
)

Energy-Only with Maximum Speed

config = TorchTrainingConfig(
    iterations=100,
    method=Adam(mu=0.001, batchsize=32),
    testpercent=10,
    force_weight=0.0,  # Energy-only
    cache_features=True,  # Bounded runtime feature cache for this run
    cache_warmup=True,    # Optional single-process prefill
    device='cuda'
)

Force Training with Optimizations

config = TorchTrainingConfig(
    iterations=100,
    method=Adam(mu=0.001, batchsize=32),
    testpercent=10,
    force_weight=0.1,
    force_fraction=0.3,  # Use 30% of forces (3× faster)
    cache_neighbors=True,  # Cache worker-local neighbor lists
    num_workers=4,         # Parallel CPU-side sample preparation
    device='cuda'
)

Advanced Configuration Reference

This section documents all configuration parameters available in TorchTrainingConfig.

Checkpointing & Model Saving

checkpoint_dir : str (default: 'checkpoints')

Directory to save checkpoint files. Set to None to disable checkpointing.

checkpoint_interval : int (default: 1)

Save a checkpoint every N epochs. Set to 0 to disable periodic checkpoints.

max_checkpoints : int (default: None)

Maximum number of checkpoint files to keep. Older checkpoints are automatically deleted. None = keep all checkpoints.

save_best : bool (default: True)

Save the model with the best validation loss as best_model.pt. Requires testpercent > 0 to compute validation loss.

For very small validation sets, the selected checkpoint can be unstable. In that case prefer save_best=False or supply a larger or explicit validation split.
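The max_checkpoints rotation described above can be sketched as follows. This is an illustration only: rotate_checkpoints is a hypothetical helper, the file names follow the checkpoint_epoch_NNNN.pt pattern used elsewhere on this page, and how the trainer treats best_model.pt during rotation is not covered here.

```python
def rotate_checkpoints(existing, max_checkpoints):
    """Return (kept, deleted) checkpoint names, oldest first."""
    if max_checkpoints is None:
        return sorted(existing), []  # keep everything
    # Zero-padded epoch numbers make lexical order match epoch order.
    ordered = sorted(existing)
    n_delete = max(0, len(ordered) - max_checkpoints)
    return ordered[n_delete:], ordered[:n_delete]


files = [
    "checkpoint_epoch_0003.pt",
    "checkpoint_epoch_0001.pt",
    "checkpoint_epoch_0002.pt",
]
kept, deleted = rotate_checkpoints(files, max_checkpoints=2)
print(kept)
print(deleted)
```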

Resuming Training

To resume training from a checkpoint, pass the checkpoint path to train(..., resume_from="checkpoints/checkpoint_epoch_0050.pt"). The notebook above contains the maintained checkpoint workflow.

When resume_from is provided, config.iterations means the number of additional epochs to run in that train() call. For example, resuming a checkpoint with iterations=10 runs 10 more epochs after the saved checkpoint epoch, regardless of how many epochs were completed in the original run. This applies to numbered checkpoints and best_model.pt alike.

The trainer will automatically:

  • Load model and optimizer state

  • Restore training history and normalization statistics

  • Continue from the next epoch

Note

Checkpoint files are NOT interchangeable with model files created by save(). Checkpoints include additional training state (optimizer, history) needed for resuming, while model files are optimized for deployment and inference.

Learning Rate Scheduling

use_scheduler : bool (default: False)

Enable learning rate scheduler. Uses ReduceLROnPlateau, which reduces the learning rate when validation loss plateaus.

scheduler_patience : int (default: 10)

Number of epochs with no improvement before reducing learning rate.

scheduler_factor : float (default: 0.5)

Factor by which to reduce learning rate. New LR = current LR × factor.

scheduler_min_lr : float (default: 1e-6)

Minimum allowed learning rate. Scheduler stops reducing below this value.

Example Usage

>>> from aenet.torch_training import Adam, TorchTrainingConfig
>>> config = TorchTrainingConfig(
...     iterations=200,
...     method=Adam(mu=0.001, batchsize=32),
...     testpercent=10,
...     use_scheduler=True,
...     scheduler_patience=10,
...     scheduler_factor=0.5,
...     scheduler_min_lr=1e-6,
... )
>>> (config.use_scheduler, config.scheduler_patience)
(True, 10)

The scheduler lowers the learning rate when the validation loss stops improving, which can help training continue to converge after progress stalls.

Note

The scheduler requires testpercent > 0 to monitor validation loss. With only a few validation structures, the monitored loss can be too noisy for stable plateau detection. In that case prefer use_scheduler=False or a larger or explicit validation split.
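The interplay of scheduler_patience, scheduler_factor, and scheduler_min_lr can be illustrated with a simplified plateau rule in plain Python. This sketch is not PyTorch's ReduceLROnPlateau: it omits the improvement threshold and cooldown options, and plateau_schedule is a hypothetical helper name.

```python
def plateau_schedule(val_losses, lr, patience=10, factor=0.5, min_lr=1e-6):
    """Return the learning rate after each epoch under a simplified plateau rule."""
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in val_losses:
        if loss < best:
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        # Reduce only once more than `patience` epochs passed without improvement.
        if bad_epochs > patience:
            lr = max(lr * factor, min_lr)
            bad_epochs = 0
        history.append(lr)
    return history


# Hypothetical validation losses that stop improving after the second epoch.
losses = [1.0, 0.8, 0.8, 0.8, 0.8, 0.8]
history = plateau_schedule(losses, lr=0.001, patience=2)
print(history)
```

With patience=2, the first two stalled epochs are tolerated and the learning rate is halved on the third. A noisy validation loss that rarely sets a new best value would trigger such reductions repeatedly, which is the failure mode the note above warns about.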

Force Training Parameters

force_fraction : float (default: 1.0)

Fraction of structures (0.0-1.0) to use for force training. Using a subset can significantly speed up training while maintaining accuracy. Example: force_fraction=0.3 uses 30% of force-labeled structures.

force_sampling : str (default: 'random')

Sampling strategy for force subset: 'random' (resample periodically) or 'fixed' (static subset). Random sampling provides better generalization.

force_resample_num_epochs : int (default: 0)

Number of epochs between resampling the force-trained subset when force_sampling='random'. Controls the resampling frequency:

  • 0 = No resampling (use fixed subset for entire training)

  • 1 = Resample every epoch (maximum variety, highest computational cost)

  • N > 1 = Resample every N epochs (balance between variety and efficiency)

Note

The default value of 0 (no resampling) represents a conservative choice that maintains consistent training dynamics and reduces computational overhead. Set to 1 or higher for dynamic resampling.

force_min_structures_per_epoch : int (default: 1)

Minimum number of force-labeled structures per epoch, regardless of force_fraction. Ensures force gradient signal is not lost.

force_scale_unbiased : bool (default: False)

Apply sqrt(1/f) scaling to force RMSE where f is the supervised fraction, approximating constant scale under sub-sampling.
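The sqrt(1/f) factor used by force_scale_unbiased is easy to tabulate. The sketch below only evaluates that factor for a few supervised fractions; the helper name and the range check are assumptions of this example, and where exactly the trainer applies the factor is not shown.

```python
import math


def unbiased_force_scale(supervised_fraction):
    """Return the sqrt(1/f) factor for a supervised fraction f in (0, 1]."""
    # The range check is an assumption made for this sketch.
    if not 0.0 < supervised_fraction <= 1.0:
        raise ValueError("supervised_fraction must lie in (0, 1]")
    return math.sqrt(1.0 / supervised_fraction)


for f in (1.0, 0.5, 0.25):
    print(f, unbiased_force_scale(f))
```

Supervising a quarter of the force-labeled structures therefore doubles the force RMSE term, while full supervision leaves it unchanged.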

Performance & Caching

cache_features : bool (default: False)

Enable feature caching. Behavior depends on training mode:

  • For energy-only training (force_weight=0): Pre-computes all features once, providing ~100× speedup

  • For force training (force_weight > 0): Caches features for structures not selected for force supervision in current epoch (useful with force_fraction < 1.0)

cache_feature_max_entries : int or None (default: 1024)

Maximum number of trainer-owned energy-view feature entries to retain per split and per process/worker when cache_features=True. Use None for an explicit unbounded cache or 0 to suppress storage.

cache_neighbors : bool (default: False)

Cache per-structure neighbor graphs (indices, displacement vectors) across epochs. Avoids repeated neighbor searches for fixed geometries on energy-view reuse and legacy non-graph paths. Supported force training does not require this option.

cache_neighbor_max_entries : int or None (default: 512)

Maximum number of trainer-owned neighbor payload entries to retain per split and per process/worker when cache_neighbors=True. Use None for an explicit unbounded cache or 0 to suppress storage.

cache_force_triplets : bool (default: False)

Cache CSR neighbor graphs and precompute angular triplet indices for the default sparse force-training path. Leaving this disabled still uses the sparse graph/triplet path, but rebuilds those graph payloads on demand.

cache_force_triplet_max_entries : int or None (default: 256)

Maximum number of trainer-owned graph/triplet payload entries to retain per split and per process/worker when cache_force_triplets=True. Use None for an explicit unbounded cache or 0 to suppress storage.

cache_persist_dir : str (default: None)

Directory for persisting graph/triplet caches to disk for reuse across runs.

cache_scope : str (default: 'all')

Which dataset splits to cache: 'train', 'val', or 'all'.

cache_warmup : bool (default: False)

If True, pre-populate trainer-owned runtime caches before the first epoch in single-process training. When all enabled caches have finite entry limits, warmup stops once those limits are filled. Warmup is skipped automatically when num_workers > 0 because workers own their own cache instances and the main-process warmup would not populate those worker-local caches.

num_workers : int (default: 0)

Number of parallel DataLoader workers for structure loading, HDF5 reads, and on-the-fly featurization. 0 keeps sample preparation in the main process. Values >0 parallelize worker-side sample preparation; they do not parallelize model compute.

prefetch_factor : int (default: 2)

Number of batches to prefetch per worker when num_workers > 0.

persistent_workers : bool (default: True)

Keep DataLoader workers alive between epochs for faster iteration. During training, this is disabled automatically when force_sampling='random' uses epoch-level resampling, because worker copies would otherwise keep a stale force-supervision subset. Trainer-owned runtime caches and HDF5 in_memory_cache_size state are also worker-local when num_workers > 0. For HDF5-backed datasets, worker handles are opened lazily per worker and closed explicitly when that worker exits.

Data Filtering & Quality Control

max_energy : float (default: None)

Exclude structures with referenced cohesive or formation energy per atom above this threshold when the trainer constructs datasets from raw structures=... input. If atomic_energies is omitted, the filter falls back to all-zero atomic references and uses the provided per-atom labels as-is. When you pass a prebuilt dataset=... or explicit train_dataset=.../test_dataset=..., this option is ignored and the trainer emits a warning.

max_forces : float (default: None)

Exclude structures with maximum atomic force magnitude above this threshold. Units: eV/Å.
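
As an illustration of the force filter, the sketch below keeps a structure only if its largest per-atom force magnitude does not exceed the threshold. It assumes that "force magnitude" means the Euclidean norm of each atom's force vector and that a value exactly equal to the threshold is kept. The function names are hypothetical and the library's own implementation may differ.

```python
import math


def max_force_magnitude(forces):
    """Largest Euclidean norm among per-atom force vectors (eV/Å)."""
    return max(math.sqrt(fx * fx + fy * fy + fz * fz) for fx, fy, fz in forces)


def keep_structure(forces, max_forces=None):
    """Return True if the structure passes the force filter.

    Hypothetical helper; max_forces=None disables the filter.
    """
    if max_forces is None:
        return True
    return max_force_magnitude(forces) <= max_forces


# Illustrative force sets in eV/Å (made-up numbers).
relaxed = [(0.1, 0.0, 0.0), (0.0, -0.2, 0.1)]
strained = [(3.0, 4.0, 0.0), (0.0, 0.1, 0.0)]  # first atom has |F| = 5.0

kept = [keep_structure(f, max_forces=2.0) for f in (relaxed, strained)]
```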

Energy Configuration

atomic_energies : dict (default: None)

Optional atomic reference energies used to convert total energies to cohesive-energy targets during training when the trainer constructs datasets from raw structures=... input. Format: {'H': -13.6, 'O': -432.0, ...} in eV. If omitted, the training target remains the total energy because all atomic reference energies default to 0.0. When you pass a prebuilt dataset=... or explicit train_dataset=.../test_dataset=..., the dataset owns atomic_energies instead; matching config values are allowed, but mismatched values raise an error.
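
The conversion described above can be sketched as subtracting the summed atomic reference energies from the total energy. The function name cohesive_energy is hypothetical and the total energy used below is a made-up value for illustration. Species missing from the dictionary are treated as having a reference energy of 0.0, matching the stated default.

```python
from collections import Counter


def cohesive_energy(total_energy, species, atomic_energies=None):
    """Total energy minus the sum of atomic reference energies (eV).

    Hypothetical helper for illustration; not an aenet-python function.
    """
    atomic_energies = atomic_energies or {}
    counts = Counter(species)
    reference = sum(n * atomic_energies.get(el, 0.0) for el, n in counts.items())
    return total_energy - reference


atomic_energies = {"H": -13.6, "O": -432.0}
water = ["O", "H", "H"]

# Reference energy: -432.0 + 2 * (-13.6) = -459.2 eV, so the target is about -9.8 eV.
e_referenced = cohesive_energy(-469.0, water, atomic_energies)

# Without atomic_energies the target stays equal to the total energy.
e_unreferenced = cohesive_energy(-469.0, water)
```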

normalize_features : bool (default: True)

Normalize features to zero mean and unit variance. Improves training stability and convergence.

normalize_energy : bool (default: True)

Normalize energies by shifting and scaling. When atomic_energies is provided, normalization is applied after the conversion to cohesive energies.

E_shift : float (default: None)

Override per-atom energy shift for normalization. Auto-computed from training set if None.

E_scaling : float (default: None)

Override energy scaling factor. Auto-computed from training set if None.

feature_stats : dict (default: None)

Override feature normalization statistics. Format: {'mean': np.ndarray, 'std': np.ndarray}. Auto-computed from training set if None.
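
A dictionary in the documented {'mean': ..., 'std': ...} format can be computed with NumPy as sketched below. The feature matrix is made up, and the expected array shape (for example, whether statistics are kept per species) is not stated here, so treat the layout as an assumption.

```python
import numpy as np

# Hypothetical feature matrix: one row per atom, one column per descriptor component.
features = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])

feature_stats = {
    "mean": features.mean(axis=0),
    "std": features.std(axis=0),
}

# Guard against constant features (zero standard deviation) before dividing.
safe_std = np.where(feature_stats["std"] > 0.0, feature_stats["std"], 1.0)

# Zero-mean, unit-variance features, as described for normalize_features.
normalized = (features - feature_stats["mean"]) / safe_std
```

Computing the statistics once and passing them explicitly can be useful when several runs should share the same normalization.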

Output & Diagnostics

save_energies : bool (default: False)

Save predicted energies for train/test sets to disk. The Path-of-input-file column preserves the original structure path or name when available; otherwise it uses a stable structure_XXXXXX identifier from the pre-split input order. For HDF5-backed datasets, the identifier is synthesized from persisted source metadata, using the first of the following that is available: display_name#frame=N, source_id#frame=N, name#frame=N (the persisted structure name), and structure_XXXXXX#frame=N as the final fallback. Source metadata is validated at HDF5 build time so these identifiers are not silently truncated on write.
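
The identifier fallback for HDF5-backed datasets can be read as an ordered chain. The sketch below is one interpretation of that order, not the library's code: the function name is hypothetical, and it assumes that XXXXXX stands for a zero-padded six-digit index.

```python
def energy_row_identifier(index, frame, display_name=None, source_id=None, name=None):
    """Pick the first available label and append the frame number.

    Hypothetical helper; the zero-padded six-digit index is an assumption.
    """
    if display_name:
        base = display_name
    elif source_id:
        base = source_id
    elif name:
        base = name
    else:
        base = f"structure_{index:06d}"
    return f"{base}#frame={frame}"


with_display_name = energy_row_identifier(7, 3, display_name="rutile.xsf", name="rutile")
only_name = energy_row_identifier(7, 3, name="rutile")
no_metadata = energy_row_identifier(7, 0)
```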

save_forces : bool (default: False)

Save predicted forces for train/test sets to disk.

timing : bool (default: False)

Enable detailed timing output for performance profiling.

show_progress : bool (default: True)

Display epoch-level progress bar. The reported training errors depend on the active sampling strategy: with sampling_policy="uniform", the epoch training error is computed from one full pass over the training split without replacement; with non-uniform sampling, the displayed training error is computed from that epoch’s sampled-with-replacement training draws and may therefore include repeated structures and omit others. The final metrics returned by train() are recomputed afterwards from a deterministic full pass over the train/test splits.

show_batch_progress : bool (default: False)

Display batch-level progress bar within each epoch. Verbose for large datasets.

Advanced Options

precision : str (default: 'auto')

Numeric precision: 'auto' (match descriptor dtype), 'float32', or 'float64'. 'float64' reduces rounding error but doubles the memory needed per value compared with 'float32'.

memory_mode : str (default: 'gpu')

Memory management strategy: 'cpu', 'gpu', or 'mixed'. 'mixed' is reserved for a future mixed-memory implementation and currently raises NotImplementedError. To control the current execution path, use 'cpu' or 'gpu' and set descriptor.device and device explicitly.

device : str (default: None)

PyTorch device: 'cpu', 'cuda', or 'cuda:0'. Auto-detected if None. This selects the model/training-loop device. descriptor.device separately controls where structures are featurized. When the two differ, samples are prepared on descriptor.device and moved to device before the forward pass.

Monitoring Training Progress

The TrainOut object returned by train() (called results below) provides built-in visualization and analysis tools. Common entry points are:

  • results.plot_training_summary(outfile="training_summary.png") for a combined energy/force plot

  • results.plot_training_errors(outfile="energy_errors.png") for energy-only training curves

  • results.plot_force_errors(outfile="force_errors.png") when force data are present

  • results.errors for direct access to the underlying pandas DataFrame used for custom plotting

The example-05-torch-training.ipynb notebook demonstrates these plotting helpers in a full training workflow.

Signs of good training:

  • Steady decrease in both train and test RMSE

  • Test RMSE follows train RMSE (no overfitting)

  • Convergence to acceptable error levels (< 0.01 eV/atom for energy)

Signs of problems:

  • Test RMSE increases while train RMSE decreases (overfitting)

  • Both RMSEs plateau at high values (underfitting or an unsuitable architecture)

  • Divergence or oscillation (learning rate too high)
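
The warning signs above can also be checked programmatically. The column names of results.errors are not listed here, so the sketch below works on plain per-epoch RMSE lists that you would extract yourself. The function name, the window size, and the labels are hypothetical, and the rule is a rough heuristic rather than a feature of the library.

```python
def diagnose_rmse(train_rmse, test_rmse, window=3):
    """Classify the recent trend of two RMSE histories (one value per epoch).

    Hypothetical heuristic: compares the latest value with the value
    `window` epochs earlier for both curves.
    """
    if len(train_rmse) != len(test_rmse):
        raise ValueError("train and test histories must have the same length")
    if len(train_rmse) <= window:
        return "too few epochs"

    train_delta = train_rmse[-1] - train_rmse[-1 - window]
    test_delta = test_rmse[-1] - test_rmse[-1 - window]

    if train_delta < 0 and test_delta > 0:
        return "possible overfitting"  # test error rises while train error falls
    if train_delta > 0 and test_delta > 0:
        return "possible divergence"  # both errors rise
    return "no warning"


# Made-up RMSE histories in eV/atom.
healthy = diagnose_rmse([0.10, 0.06, 0.04, 0.03, 0.025], [0.11, 0.07, 0.05, 0.04, 0.035])
overfit = diagnose_rmse([0.10, 0.06, 0.04, 0.03, 0.02], [0.11, 0.07, 0.06, 0.08, 0.10])
```

A check like this complements the plots but does not replace inspecting them, since short windows are sensitive to noise in the per-epoch errors.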

See Also