Source code for aenet.reference_energies

"""Helpers for constructing atomic reference energies.

This module centers the regression workflow on lightweight
``(composition, energy)`` samples, so users can stream data from files,
databases, or custom parsers without materializing full structure objects in
memory. File-backed convenience adapters are provided separately for workflows
that already rely on :mod:`aenet.io.structure`.
"""

from __future__ import annotations

import copy
import math
import random
import re
from collections.abc import Iterable, Iterator, Mapping, Sequence
from dataclasses import dataclass, field
from numbers import Integral, Real
from os import PathLike
from typing import Any

import numpy as np

from .geometry.structure import AtomicStructure
from .io.structure import read as read_structure

__all__ = ["ReferenceEnergies", "iter_composition_energy_samples_from_files"]


@dataclass(frozen=True)
class _EnergySample:
    """Lightweight composition/energy sample used by regression helpers."""

    composition: dict[str, int]
    energy: float


@dataclass(frozen=True)
class _SolveResult:
    """Internal result of one constrained reference-energy solve."""

    atomic_energies: dict[str, float]
    species_order: list[str]
    fixed_atomic_energies: dict[str, float]
    free_species: list[str]
    rank: int
    residual_sum_squares: float
    rmse: float
    singular_values: list[float]


def _validated_energy(value: Any) -> float:
    """Return one finite energy value."""
    try:
        energy = float(value)
    except Exception as exc:
        raise ValueError(
            "Regression requires finite total energies for every sample."
        ) from exc

    if not math.isfinite(energy):
        raise ValueError(
            "Regression requires finite total energies for every sample."
        )
    return energy


def _normalize_composition(
    composition: Mapping[Any, Any],
) -> dict[str, int]:
    """Validate and normalize one composition mapping."""
    if not isinstance(composition, Mapping):
        raise TypeError("Each sample composition must be a mapping.")

    normalized: dict[str, int] = {}
    for species, count in composition.items():
        key = str(species)
        if isinstance(count, bool):
            raise ValueError("Composition counts must be integer-valued.")
        if isinstance(count, Integral):
            value = int(count)
        elif isinstance(count, Real):
            numeric = float(count)
            if not math.isfinite(numeric) or not numeric.is_integer():
                raise ValueError("Composition counts must be integer-valued.")
            value = int(numeric)
        else:
            raise ValueError(
                "Composition counts must be integer-valued."
            )
        if value < 0:
            raise ValueError("Composition counts must be non-negative.")
        normalized[key] = value

    if sum(normalized.values()) <= 0:
        raise ValueError(
            "Each sample composition must contain at least one atom."
        )
    return normalized


def _normalize_energy_sample(sample: Any) -> _EnergySample:
    """Normalize one public regression sample."""
    if isinstance(sample, _EnergySample):
        return sample

    if isinstance(sample, tuple) and len(sample) == 2:
        composition, energy = sample
        return _EnergySample(
            composition=_normalize_composition(composition),
            energy=_validated_energy(energy),
        )

    raise TypeError(
        "Each regression sample must be a ``(composition, energy)`` pair."
    )


def _formula_to_composition(formula: str) -> dict[str, int]:
    """Parse a simple chemical formula into a composition mapping."""
    if not isinstance(formula, str) or not formula.strip():
        raise ValueError(
            "Reference compounds must be non-empty formula strings or "
            "composition mappings."
        )

    token_pattern = re.compile(r"([A-Z][a-z]*)(\d*)")
    normalized = formula.strip()
    composition: dict[str, int] = {}
    position = 0

    for match in token_pattern.finditer(normalized):
        if match.start() != position:
            raise ValueError(
                f"Invalid reference compound formula {formula!r}."
            )
        species = match.group(1)
        count_text = match.group(2)
        count = int(count_text) if count_text else 1
        if count <= 0:
            raise ValueError(
                f"Invalid reference compound formula {formula!r}."
            )
        composition[species] = composition.get(species, 0) + count
        position = match.end()

    if position != len(normalized) or not composition:
        raise ValueError(
            f"Invalid reference compound formula {formula!r}."
        )

    return composition


def _normalize_reference_compound(compound: Any) -> dict[str, int]:
    """Normalize one reference-compound specification."""
    if isinstance(compound, Mapping):
        return _normalize_composition(compound)
    if isinstance(compound, str):
        return _normalize_composition(_formula_to_composition(compound))
    raise TypeError(
        "Each reference compound must be a formula string or a composition "
        "mapping."
    )


def _composition_key(
    composition: Mapping[str, int],
) -> tuple[tuple[str, int], ...]:
    """Return a hashable canonical key for one composition."""
    return tuple(
        (species, int(count))
        for species, count in sorted(composition.items())
        if int(count) != 0
    )


def _solve_atomic_energies(
    samples: Sequence[_EnergySample],
    *,
    fixed_atomic_energies: dict[str, float] | None,
) -> _SolveResult:
    """Solve one constrained atomic-energy system from normalized samples."""
    if len(samples) == 0:
        raise ValueError("At least one regression sample is required.")

    seen_species: set[str] = set()
    for sample in samples:
        seen_species.update(sample.composition.keys())
    species_order = sorted(seen_species)

    fixed = _normalize_fixed_atomic_energies(
        fixed_atomic_energies,
        species_order=species_order,
    )

    composition_rows = [
        _composition_row(sample.composition, species_order=species_order)
        for sample in samples
    ]
    energies = [sample.energy for sample in samples]

    composition_matrix = np.stack(composition_rows)
    energy_vector = np.array(energies, dtype=np.float64)

    fixed_species = [species for species in species_order if species in fixed]
    free_species = [species for species in species_order if species not in fixed]

    energy_offset = np.zeros(len(samples), dtype=np.float64)
    if fixed_species:
        fixed_indices = [species_order.index(species) for species in fixed_species]
        fixed_values = np.array(
            [fixed[species] for species in fixed_species],
            dtype=np.float64,
        )
        energy_offset = composition_matrix[:, fixed_indices] @ fixed_values

    reduced_targets = energy_vector - energy_offset

    rank = 0
    singular_values: list[float] = []
    residual_sum_squares = 0.0
    resolved = dict(fixed)

    if free_species:
        free_indices = [species_order.index(species) for species in free_species]
        free_matrix = composition_matrix[:, free_indices]
        solution, residuals, rank, singular = np.linalg.lstsq(
            free_matrix,
            reduced_targets,
            rcond=None,
        )
        singular_values = [float(value) for value in singular.tolist()]

        if int(rank) < len(free_species):
            raise ValueError(
                "Reference-energy system is underdetermined or rank-deficient. "
                "Provide fixed_atomic_energies for one or more species."
            )

        predicted = free_matrix @ solution
        residual_sum_squares = float(np.sum((predicted - reduced_targets) ** 2))
        for species, energy in zip(free_species, solution.tolist()):
            resolved[species] = float(energy)
    else:
        predicted = np.zeros(len(samples), dtype=np.float64)
        rank = 0
        residual_sum_squares = float(np.sum((predicted - reduced_targets) ** 2))

    rmse = math.sqrt(residual_sum_squares / float(len(samples)))

    return _SolveResult(
        atomic_energies=resolved,
        species_order=list(species_order),
        fixed_atomic_energies=dict(fixed),
        free_species=list(free_species),
        rank=int(rank),
        residual_sum_squares=float(residual_sum_squares),
        rmse=float(rmse),
        singular_values=singular_values,
    )


def _select_reference_compound_samples(
    samples,
    *,
    reference_compounds: Sequence[Any],
) -> tuple[list[_EnergySample], int, list[dict[str, int]], list[int]]:
    """Select minimum-energy samples for the requested reference compounds."""
    if len(reference_compounds) == 0:
        raise ValueError("At least one reference compound is required.")

    normalized_references = [
        _normalize_reference_compound(compound)
        for compound in reference_compounds
    ]
    reference_keys = [_composition_key(comp) for comp in normalized_references]

    duplicates = {
        key for key in reference_keys if reference_keys.count(key) > 1
    }
    if duplicates:
        raise ValueError(
            "reference_compounds contains duplicate or equivalent "
            "compositions."
        )

    best_samples: dict[tuple[tuple[str, int], ...], _EnergySample] = {}
    candidate_counts = {key: 0 for key in reference_keys}
    n_samples_total = 0

    for raw_sample in samples:
        sample = _normalize_energy_sample(raw_sample)
        n_samples_total += 1
        sample_key = _composition_key(sample.composition)
        if sample_key not in candidate_counts:
            continue

        candidate_counts[sample_key] += 1
        previous = best_samples.get(sample_key)
        if previous is None or sample.energy < previous.energy:
            best_samples[sample_key] = sample

    if n_samples_total == 0:
        raise ValueError("At least one regression sample is required.")

    missing = [
        normalized_references[index]
        for index, key in enumerate(reference_keys)
        if key not in best_samples
    ]
    if missing:
        raise ValueError(
            "Missing requested reference compounds in the provided samples: "
            f"{missing!r}."
        )

    selected_samples = [best_samples[key] for key in reference_keys]
    counts_in_order = [candidate_counts[key] for key in reference_keys]
    return (
        selected_samples,
        n_samples_total,
        normalized_references,
        counts_in_order,
    )


def _composition_from_species_sequence(species: Sequence[Any]) -> dict[str, int]:
    """Return a composition mapping from a species list."""
    composition: dict[str, int] = {}
    for species_name in species:
        key = str(species_name)
        composition[key] = composition.get(key, 0) + 1
    return composition


def _iter_composition_energy_samples_from_atomic_structure(
    structure: AtomicStructure,
) -> Iterator[tuple[dict[str, int], float]]:
    """Yield one ``(composition, energy)`` sample per frame."""
    composition = _normalize_composition(structure.composition)
    for frame in range(structure.nframes):
        yield dict(composition), _validated_energy(structure.energy[frame])


[docs] def iter_composition_energy_samples_from_files( paths: PathLike[str] | str | Iterable[PathLike[str] | str], *, frmt: str | None = None, **read_kwargs, ) -> Iterator[tuple[dict[str, int], float]]: """Yield lazy ``(composition, energy)`` samples from structure files. Parameters ---------- paths : path-like or iterable of path-like One path or an iterable of paths readable by :func:`aenet.io.structure.read`. Files containing multiple frames yield one sample per frame. frmt : str, optional Explicit input format forwarded to :func:`aenet.io.structure.read`. **read_kwargs Additional keyword arguments forwarded to :func:`aenet.io.structure.read`. Yields ------ tuple[dict[str, int], float] One ``(composition, energy)`` pair per frame. """ if isinstance(paths, (str, PathLike)): iterable: Iterable[PathLike[str] | str] = [paths] else: iterable = paths for path in iterable: structure = read_structure(path, frmt=frmt, **read_kwargs) yield from _iter_composition_energy_samples_from_atomic_structure( structure )
def _normalize_fixed_atomic_energies( fixed_atomic_energies: dict[str, float] | None, *, species_order: list[str], ) -> dict[str, float]: """Validate user-supplied fixed species energies.""" if fixed_atomic_energies is None: return {} normalized = { str(species): float(energy) for species, energy in fixed_atomic_energies.items() } for species, energy in normalized.items(): if species not in species_order: raise ValueError( f"fixed_atomic_energies contains unknown species {species!r}." ) if not math.isfinite(energy): raise ValueError( "fixed_atomic_energies must contain only finite values." ) return normalized def _composition_row( composition: dict[str, int], *, species_order: list[str], ) -> np.ndarray: """Return the species-count row for one composition.""" return np.array( [composition.get(species, 0) for species in species_order], dtype=np.float64, ) def _normalize_subset_size( *, subset_size: int | None, subset_fraction: float | None, n_hint: int | None = None, ) -> int | None: """Validate and normalize optional subset selection arguments.""" if subset_size is not None and subset_fraction is not None: raise ValueError( "Specify at most one of subset_size and subset_fraction." ) if subset_size is not None: if isinstance(subset_size, bool) or not isinstance( subset_size, Integral ): raise ValueError("subset_size must be a positive integer.") size = int(subset_size) if size <= 0: raise ValueError("subset_size must be a positive integer.") if n_hint is not None and size > n_hint: raise ValueError( "subset_size cannot exceed the number of available samples." ) return size if subset_fraction is None: return None fraction = float(subset_fraction) if not math.isfinite(fraction) or fraction <= 0.0 or fraction > 1.0: raise ValueError( "subset_fraction must be a finite float in the interval (0, 1]." ) if n_hint is None: raise ValueError( "subset_fraction requires a sized sample collection. Use " "subset_size for lazy iterators." ) return max(1, int(math.ceil(n_hint * fraction))) def _sample_count_hint(samples: Any) -> int | None: """Return an optional item-count hint for a sample iterable.""" try: return len(samples) # type: ignore[arg-type] except Exception: return None def _select_regression_samples( samples, *, subset_size: int | None, subset_fraction: float | None, random_seed: int | None, ) -> tuple[list[_EnergySample], int, set[str]]: """Return the chosen regression samples and the total sample count.""" normalized_subset_size = _normalize_subset_size( subset_size=subset_size, subset_fraction=subset_fraction, n_hint=_sample_count_hint(samples), ) rng = random.Random(random_seed) selected: list[_EnergySample] = [] n_samples_total = 0 seen_species: set[str] = set() for raw_sample in samples: sample = _normalize_energy_sample(raw_sample) n_samples_total += 1 seen_species.update(sample.composition.keys()) if normalized_subset_size is None: selected.append(sample) continue if len(selected) < normalized_subset_size: selected.append(sample) continue replace_idx = rng.randint(0, n_samples_total - 1) if replace_idx < normalized_subset_size: selected[replace_idx] = sample if n_samples_total == 0: raise ValueError("At least one regression sample is required.") if ( normalized_subset_size is not None and n_samples_total < normalized_subset_size ): raise ValueError( "subset_size cannot exceed the number of available samples." ) return selected, n_samples_total, seen_species
[docs] @dataclass(frozen=True) class ReferenceEnergies: """Resolved atomic reference energies with provenance metadata. Parameters ---------- _atomic_energies : dict[str, float] Internal storage for the resolved species-energy mapping. method : str Name of the construction workflow. _metadata : dict, optional Workflow metadata and diagnostics. """ _atomic_energies: dict[str, float] method: str _metadata: dict[str, Any] = field(default_factory=dict, repr=False) def __post_init__(self) -> None: """Normalize stored mappings to plain finite floats.""" normalized = { str(species): float(energy) for species, energy in self._atomic_energies.items() } for energy in normalized.values(): if not math.isfinite(energy): raise ValueError( "atomic_energies must contain only finite values." ) object.__setattr__(self, "_atomic_energies", normalized) object.__setattr__(self, "_metadata", copy.deepcopy(self._metadata)) @property def atomic_energies(self) -> dict[str, float]: """Return a copy of the resolved atomic reference energies.""" return dict(self._atomic_energies) @property def metadata(self) -> dict[str, Any]: """Return a copy of the workflow metadata and diagnostics.""" return copy.deepcopy(self._metadata)
[docs] @classmethod def from_regression( cls, samples, *, fixed_atomic_energies: dict[str, float] | None = None, subset_size: int | None = None, subset_fraction: float | None = None, random_seed: int | None = None, ) -> ReferenceEnergies: """Estimate atomic energies from total-energy regression. Parameters ---------- samples : iterable Iterable of ``(composition, energy)`` pairs. ``composition`` must be a mapping from species labels to integer counts, and ``energy`` must be a finite total energy. The iterable may be lazy. fixed_atomic_energies : dict[str, float], optional User-specified species energies held fixed during the fit. This is required for underdetermined composition spaces unless the user explicitly chooses a different reference convention upstream. subset_size : int, optional Number of regression samples to draw uniformly without replacement before fitting. Mutually exclusive with ``subset_fraction``. subset_fraction : float, optional Fraction of regression samples to draw uniformly without replacement before fitting. Mutually exclusive with ``subset_size``. random_seed : int, optional Seed used for deterministic subset selection. Returns ------- ReferenceEnergies Resolved atomic reference energies and fit metadata. Raises ------ ValueError If the inputs are empty, contain invalid samples, specify invalid subset arguments, contain unknown fixed species, or remain underdetermined after applying user constraints. """ selected_samples, n_samples_total, _ = _select_regression_samples( samples, subset_size=subset_size, subset_fraction=subset_fraction, random_seed=random_seed, ) solve = _solve_atomic_energies( selected_samples, fixed_atomic_energies=fixed_atomic_energies, ) metadata = { "species_order": list(solve.species_order), "fixed_atomic_energies": dict(solve.fixed_atomic_energies), "free_species": list(solve.free_species), "n_samples_total": int(n_samples_total), "n_samples_used": int(len(selected_samples)), "subset_size": int(len(selected_samples)) if subset_size is not None or subset_fraction is not None else None, "subset_fraction": ( float(subset_fraction) if subset_fraction is not None else None ), "random_seed": random_seed, "rank": int(solve.rank), "residual_sum_squares": float(solve.residual_sum_squares), "rmse": float(solve.rmse), "singular_values": list(solve.singular_values), } return cls( _atomic_energies=solve.atomic_energies, method="regression", _metadata=metadata, )
[docs] @classmethod def from_reference_compounds( cls, samples, *, reference_compounds: Sequence[Any], fixed_atomic_energies: dict[str, float] | None = None, ) -> ReferenceEnergies: """Construct atomic energies from user-chosen reference compounds. Parameters ---------- samples : iterable Iterable of ``(composition, energy)`` pairs. ``composition`` must be a mapping from species labels to integer counts, and ``energy`` must be a finite total energy. The iterable may be lazy. reference_compounds : sequence Requested reference compositions, specified either as formula strings such as ``"TiO2"`` or as explicit composition mappings. If multiple matching samples exist for one requested composition, the lowest-energy sample is used. fixed_atomic_energies : dict[str, float], optional User-specified species energies held fixed during the solve. Returns ------- ReferenceEnergies Resolved atomic reference energies and solver metadata. Raises ------ ValueError If the inputs are empty, contain invalid samples, request missing or duplicate reference compounds, contain unknown fixed species, or remain underdetermined after applying user constraints. """ ( selected_samples, n_samples_total, normalized_references, candidate_counts, ) = _select_reference_compound_samples( samples, reference_compounds=reference_compounds, ) solve = _solve_atomic_energies( selected_samples, fixed_atomic_energies=fixed_atomic_energies, ) metadata = { "species_order": list(solve.species_order), "fixed_atomic_energies": dict(solve.fixed_atomic_energies), "free_species": list(solve.free_species), "n_samples_total": int(n_samples_total), "n_reference_compounds": int(len(normalized_references)), "reference_compounds": [dict(comp) for comp in normalized_references], "reference_candidate_counts": list(candidate_counts), "selected_reference_samples": [ { "composition": dict(sample.composition), "energy": float(sample.energy), } for sample in selected_samples ], "rank": int(solve.rank), "residual_sum_squares": float(solve.residual_sum_squares), "rmse": float(solve.rmse), "singular_values": list(solve.singular_values), } return cls( _atomic_energies=solve.atomic_energies, method="reference_compounds", _metadata=metadata, )