Source code for aenet.torch_training.builders.network_builder

"""
Network architecture builder for PyTorch training.

Handles construction of per-species neural networks from architecture
specifications.
"""

import importlib.util
from collections import OrderedDict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn


class _FallbackNet(nn.Module):
    """
    Minimal container used when aenet-PyTorch's NetAtom cannot be imported.

    Defined at module level (rather than inside ``build_network``) so that
    instances can be pickled.

    Parameters
    ----------
    functions : nn.ModuleList
        Per-species ``nn.Sequential`` networks.
    species : list of str
        Species symbols, kept for debugging.
    device : str
        Device string stored on the ``.device`` attribute.
    """

    def __init__(
        self, functions: nn.ModuleList, species: List[str], device: str
    ):
        super().__init__()
        self.functions = functions
        self.species = species  # for debugging
        self.device = device

    def forward(self, *args, **kwargs):
        """Not used; adapter calls .functions."""
        raise RuntimeError(
            "Use EnergyModelAdapter to call per-atom energies."
        )


class NetworkBuilder:
    """
    Builds neural network architectures for atomic energy prediction.

    Supports both aenet-PyTorch NetAtom (if available) and a fallback
    implementation using standard PyTorch modules.

    Parameters
    ----------
    descriptor : ChebyshevDescriptor
        Descriptor instance providing feature dimension and species info.
        This class reads ``descriptor.species`` and calls
        ``descriptor.get_n_features()``.
    device : torch.device
        Device for network.
    dtype : torch.dtype
        Data type for network parameters. ``torch.float64`` selects double
        precision; any other value selects single precision.
    """

    # Activation name -> module class. Classes (not instances) are stored so
    # every layer gets its own module object; a shared instance would be
    # de-duplicated by ``nn.Module.named_modules()`` and would share hooks.
    _ACTIVATIONS: Dict[str, type] = {
        "linear": nn.Identity,
        "tanh": nn.Tanh,
        "sigmoid": nn.Sigmoid,
    }

    def __init__(self, descriptor, device: torch.device, dtype: torch.dtype):
        self.descriptor = descriptor
        self.device = device
        self.dtype = dtype

    @staticmethod
    def _import_netatom() -> Optional[type]:
        """
        Dynamically import aenet-PyTorch NetAtom from external/aenet-pytorch.

        This is a best-effort lookup: any failure (file missing, import
        error inside the external module, ...) results in ``None`` so the
        caller can fall back to the built-in implementation. Failures are
        logged at DEBUG level.

        Returns
        -------
        NetAtom class or None if not found.
        """
        import logging

        try:
            # parents[4] is the fifth ancestor directory of this file,
            # which is assumed to be the project root.
            root = Path(__file__).resolve().parents[4]
            net_path = (
                root / "external" / "aenet-pytorch" / "src" / "network.py"
            )
            if not net_path.exists():
                return None
            spec = importlib.util.spec_from_file_location(
                "aenet_pytorch.network", str(net_path)
            )
            if spec is None or spec.loader is None:
                return None
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)  # type: ignore
            return getattr(module, "NetAtom", None)
        except Exception:
            # Executing third-party code can raise anything; stay
            # best-effort but leave a trace for debugging.
            logging.getLogger(__name__).debug(
                "Could not import aenet-PyTorch NetAtom; using fallback.",
                exc_info=True,
            )
            return None

    def validate_arch(
        self, arch: Dict[str, List[Tuple[int, str]]]
    ) -> Tuple[List[List[int]], List[List[str]]]:
        """
        Validate architecture and produce per-species hidden sizes and
        activations.

        Species are processed in the order of ``self.descriptor.species``;
        entries of ``arch`` for other species are ignored.

        Parameters
        ----------
        arch : dict
            {species_symbol: [(nodes, activation), ...]}
            Output layer is implicit.

        Returns
        -------
        hidden_sizes : list of list of int
            Per-species hidden layer sizes.
        activations : list of list of str
            Per-species activation functions (lower-cased).

        Raises
        ------
        ValueError
            On unsupported activation, non-positive layer size, empty
            layer list, or missing species.
        """
        supported = set(self._ACTIVATIONS)

        hidden_sizes: List[List[int]] = []
        activations: List[List[str]] = []
        for s in list(self.descriptor.species):
            if s not in arch:
                raise ValueError(f"Species '{s}' missing in architecture.")

            sizes: List[int] = []
            acts: List[str] = []
            for nodes, act in arch[s]:
                # Non-string activations are reported as unsupported
                # instead of failing with AttributeError on .lower().
                if not isinstance(act, str) or act.lower() not in supported:
                    raise ValueError(
                        f"Unsupported activation '{act}' for species '{s}'. "
                        f"Supported: {sorted(supported)}"
                    )
                n_nodes = int(nodes)
                if n_nodes <= 0:
                    raise ValueError(
                        f"Layer size must be a positive integer for "
                        f"species '{s}', got {nodes!r}."
                    )
                sizes.append(n_nodes)
                acts.append(act.lower())

            if not sizes:
                raise ValueError(
                    f"Architecture for species '{s}' must be non-empty."
                )
            hidden_sizes.append(sizes)
            activations.append(acts)

        return hidden_sizes, activations

    def _build_fallback_per_species_mlps(
        self,
        n_features: int,
        species: List[str],
        hidden_sizes: List[List[int]],
        activations: List[List[str]],
    ) -> nn.ModuleList:
        """
        Fallback builder that mimics NetAtom.functions[iesp] layout.

        Layers are named ``Linear_Sp{i}_F{j}`` / ``Active_Sp{i}_F{j}`` with
        1-based species index ``i`` and layer index ``j``; the final
        ``Linear`` layer has a single output and no activation.

        Parameters
        ----------
        n_features : int
            Input feature dimension of every per-species network.
        species : list of str
            Species symbols; only the count and order are used.
        hidden_sizes : list of list of int
            Per-species hidden layer sizes.
        activations : list of list of str
            Per-species activation names (keys of ``_ACTIVATIONS``).

        Returns
        -------
        nn.ModuleList
            Per-species nn.Sequential models mapping (F) -> (1).
        """
        seqs: List[nn.Sequential] = []
        for i, _sp in enumerate(species, start=1):
            sizes = hidden_sizes[i - 1]
            acts = activations[i - 1]
            # Input width of layer j is the output width of layer j-1.
            widths = [n_features, *sizes]

            layers: List[Tuple[str, nn.Module]] = []
            for j, (fan_in, fan_out) in enumerate(zip(widths, sizes), 1):
                layers.append(
                    (f"Linear_Sp{i}_F{j}", nn.Linear(fan_in, fan_out))
                )
                # Fresh activation instance per layer (no aliasing).
                layers.append(
                    (f"Active_Sp{i}_F{j}", self._ACTIVATIONS[acts[j - 1]]())
                )
            # Output layer
            layers.append(
                (f"Linear_Sp{i}_F{len(sizes) + 1}", nn.Linear(sizes[-1], 1))
            )
            seqs.append(nn.Sequential(OrderedDict(layers)))
        return nn.ModuleList(seqs)

    def _cast_and_place(self, net: nn.Module) -> nn.Module:
        """Cast ``net`` to the configured precision and move it to device."""
        net = net.double() if self.dtype == torch.float64 else net.float()
        # Both NetAtom and the fallback expose the device as a string.
        net.device = str(self.device)
        net.to(self.device)
        return net

    def build_network(
        self, arch: Dict[str, List[Tuple[int, str]]]
    ) -> nn.Module:
        """
        Build NetAtom (preferred) or fallback per-species MLPs.

        Parameters
        ----------
        arch : dict
            Architecture specification per species.

        Returns
        -------
        nn.Module
            Network with attributes:
            - .functions: ModuleList of per-species Sequential MLPs
            - .device: device string

        Raises
        ------
        ValueError
            If ``arch`` fails validation (see ``validate_arch``).
        """
        species = list(self.descriptor.species)
        n_features = int(self.descriptor.get_n_features())
        hidden_sizes, activations = self.validate_arch(arch)

        netatom_cls = self._import_netatom()
        if netatom_cls is not None:
            # NetAtom expects lists per species
            net = netatom_cls(
                input_size=[n_features] * len(species),
                hidden_size=hidden_sizes,
                species=species,
                activations=activations,
                alpha=1.0,
                device=str(self.device),
            )
        else:
            # Fallback: simple wrapper with .functions and .device
            net = _FallbackNet(
                functions=self._build_fallback_per_species_mlps(
                    n_features=n_features,
                    species=species,
                    hidden_sizes=hidden_sizes,
                    activations=activations,
                ),
                species=species,
                device=str(self.device),
            )
        return self._cast_and_place(net)