Source code for aenet.torch_training.training.checkpoint_manager

"""
Checkpoint management for PyTorch training.

Handles saving, loading, and rotating model checkpoints during training.
"""

from pathlib import Path
from typing import Any, Dict, Optional

import torch
import torch.nn as nn


[docs] class CheckpointManager: """ Manages checkpoint operations for training. Handles: - Saving checkpoints with full training state - Loading checkpoints to resume training - Rotating old checkpoints to limit disk usage - Tracking best model based on validation metrics Parameters ---------- checkpoint_dir : str or Path, optional Directory to save checkpoints. If None, checkpoints are disabled. max_to_keep : int, optional Maximum number of checkpoints to keep. Older ones are deleted. save_best : bool Whether to save the best model separately. """ def __init__( self, checkpoint_dir: Optional[str] = None, max_to_keep: Optional[int] = None, save_best: bool = True, ): self.checkpoint_dir = Path(checkpoint_dir) if checkpoint_dir else None self.max_to_keep = max_to_keep self.save_best = save_best self.best_val_loss: Optional[float] = None if self.checkpoint_dir is not None: self._ensure_dir(self.checkpoint_dir) def _ensure_dir(self, path: Path): """Create directory if it doesn't exist.""" path.mkdir(parents=True, exist_ok=True) def _rotate_checkpoints(self): """Remove old checkpoints beyond max_to_keep limit.""" if self.checkpoint_dir is None or self.max_to_keep is None: return if self.max_to_keep <= 0: return ckpts = sorted(self.checkpoint_dir.glob("checkpoint_epoch_*.pt")) if len(ckpts) <= self.max_to_keep: return to_remove = ckpts[: len(ckpts) - self.max_to_keep] for p in to_remove: try: p.unlink() except Exception: pass
[docs] def save_checkpoint( self, trainer, optimizer: torch.optim.Optimizer, epoch: int, training_config: Optional[Any] = None, filename: Optional[str] = None, ): """ Save a training checkpoint using the unified model format. Parameters ---------- trainer : TorchANNPotential Trainer instance to save. optimizer : torch.optim.Optimizer Optimizer to save. epoch : int Current epoch number. training_config : TorchTrainingConfig, optional Training configuration. filename : str, optional Filename for checkpoint. If None, uses format "checkpoint_epoch_{epoch:04d}.pt" """ if self.checkpoint_dir is None: return if filename is None: filename = f"checkpoint_epoch_{epoch:04d}.pt" path = self.checkpoint_dir / filename # Import save_model here to avoid circular imports from ..model_export import save_model # Add epoch number to extra_metadata extra_metadata = { "epoch": int(epoch), "best_val_loss": ( float(self.best_val_loss) if self.best_val_loss is not None else None ), } try: save_model( trainer=trainer, path=path, optimizer=optimizer, training_config=training_config, extra_metadata=extra_metadata, ) # Rotate old checkpoints after successful save self._rotate_checkpoints() except Exception as e: print(f"[WARN] Failed to save checkpoint at {path}: {e}")
[docs] def load_checkpoint( self, path: str, model: nn.Module, optimizer: torch.optim.Optimizer, device: torch.device, ) -> Dict[str, Any]: """ Load a checkpoint and restore training state. Parameters ---------- path : str Path to checkpoint file. model : nn.Module Model to load state into. optimizer : torch.optim.Optimizer Optimizer to load state into. device : torch.device Device to map tensors to. Returns ------- dict Checkpoint payload containing epoch, history, etc. Raises ------ RuntimeError If checkpoint loading fails. """ try: payload = torch.load(path, map_location=device, weights_only=False) model.load_state_dict(payload["model_state_dict"]) optimizer.load_state_dict(payload["optimizer_state_dict"]) if "history" not in payload and "training_history" in payload: payload["history"] = payload["training_history"] if "epoch" not in payload: extra_metadata = payload.get("extra_metadata", {}) or {} if "epoch" in extra_metadata: payload["epoch"] = extra_metadata["epoch"] if "best_val_loss" in payload: self.best_val_loss = payload["best_val_loss"] return payload except Exception as e: raise RuntimeError(f"Failed to load checkpoint '{path}': {e}")
[docs] def should_save_best(self, val_loss: float) -> bool: """ Check if current validation loss is the best so far. Parameters ---------- val_loss : float Current validation loss. Returns ------- bool True if this is the best validation loss. """ if not self.save_best: return False is_best = ( (self.best_val_loss is None) or (val_loss < self.best_val_loss) ) if is_best: self.best_val_loss = float(val_loss) return is_best
[docs] def save_best_model( self, trainer, optimizer: torch.optim.Optimizer, epoch: int, training_config: Optional[Any] = None, ): """Save the best model checkpoint using the unified format.""" if self.checkpoint_dir is None or not self.save_best: return self.save_checkpoint( trainer=trainer, optimizer=optimizer, epoch=epoch, training_config=training_config, filename="best_model.pt", )
[docs] def infer_start_epoch( self, checkpoint_path: str, payload: Optional[Dict[str, Any]] = None, ) -> int: """ Infer starting epoch from checkpoint metadata or filename. Parameters ---------- checkpoint_path : str Path to checkpoint file. payload : dict, optional Loaded checkpoint payload. When present, ``payload["epoch"]`` is used as the authoritative epoch number. Returns ------- int Starting epoch (checkpoint epoch + 1), or 0 if it cannot be inferred. """ if payload is not None: try: if "epoch" in payload: return int(payload["epoch"]) + 1 extra_metadata = payload.get("extra_metadata", {}) or {} if "epoch" in extra_metadata: return int(extra_metadata["epoch"]) + 1 except Exception: pass try: name = Path(checkpoint_path).name if name.startswith("checkpoint_epoch_") and name.endswith(".pt"): return int(name[len("checkpoint_epoch_"): -3]) + 1 except Exception: pass return 0