"""
Metrics tracking for PyTorch training.
Handles training history and metrics computation.
"""
from typing import Dict, List
[docs]
class MetricsTracker:
"""
Tracks training metrics and history.
Maintains per-epoch metrics for energy RMSE, force RMSE, learning rate,
and timing information.
Parameters
----------
track_detailed_timing : bool
Whether to track detailed timing breakdowns (data loading, loss
computation, optimizer steps).
"""
def __init__(self, track_detailed_timing: bool = False):
self.track_detailed_timing = track_detailed_timing
self.history: Dict[str, List[float]] = {
"train_energy_rmse": [],
"test_energy_rmse": [],
"train_energy_mae": [],
"test_energy_mae": [],
"train_force_rmse": [],
"test_force_rmse": [],
"learning_rate": [],
"epoch_time": [],
"forward_time": [],
"backward_time": [],
}
if track_detailed_timing:
self.history.update(
{
"epoch_data_loading_time_train": [],
"epoch_loss_time_train": [],
"epoch_optimizer_time_train": [],
"epoch_data_loading_time_val": [],
"epoch_loss_time_val": [],
"epoch_optimizer_time_val": [],
"epoch_train_time": [],
"epoch_val_time": [],
}
)
[docs]
def update(
self,
train_energy_rmse: float,
train_energy_mae: float,
train_force_rmse: float,
test_energy_rmse: float,
test_energy_mae: float,
test_force_rmse: float,
learning_rate: float,
epoch_time: float,
forward_time: float,
backward_time: float,
train_timing: Dict[str, float],
val_timing: Dict[str, float],
):
"""
Record metrics for one epoch.
Parameters
----------
train_energy_rmse : float
Training energy RMSE.
train_energy_mae : float
Training energy MAE.
train_force_rmse : float
Training force RMSE (or NaN if not computed).
test_energy_rmse : float
Validation energy RMSE (or NaN if no validation).
test_energy_mae : float
Validation energy MAE (or NaN if no validation).
test_force_rmse : float
Validation force RMSE (or NaN if not computed).
learning_rate : float
Current learning rate.
epoch_time : float
Total epoch time in seconds.
forward_time : float
Forward pass time in seconds.
backward_time : float
Backward pass time in seconds.
train_timing : dict
Detailed training timing breakdown with keys:
'data_loading', 'loss_computation', 'optimizer', 'total'
val_timing : dict
Detailed validation timing breakdown (same keys as train_timing).
"""
self.history["train_energy_rmse"].append(float(train_energy_rmse))
self.history["train_energy_mae"].append(float(train_energy_mae))
self.history["train_force_rmse"].append(float(train_force_rmse))
self.history["test_energy_rmse"].append(float(test_energy_rmse))
self.history["test_energy_mae"].append(float(test_energy_mae))
self.history["test_force_rmse"].append(float(test_force_rmse))
self.history["learning_rate"].append(float(learning_rate))
self.history["epoch_time"].append(float(epoch_time))
self.history["forward_time"].append(float(forward_time))
self.history["backward_time"].append(float(backward_time))
if self.track_detailed_timing:
self.history["epoch_data_loading_time_train"].append(
float(train_timing.get("data_loading", 0.0))
)
self.history["epoch_loss_time_train"].append(
float(train_timing.get("loss_computation", 0.0))
)
self.history["epoch_optimizer_time_train"].append(
float(train_timing.get("optimizer", 0.0))
)
self.history["epoch_train_time"].append(
float(train_timing.get("total", 0.0))
)
self.history["epoch_data_loading_time_val"].append(
float(val_timing.get("data_loading", 0.0))
)
self.history["epoch_loss_time_val"].append(
float(val_timing.get("loss_computation", 0.0))
)
self.history["epoch_optimizer_time_val"].append(
float(val_timing.get("optimizer", 0.0))
)
self.history["epoch_val_time"].append(
float(val_timing.get("total", 0.0))
)
[docs]
def get_history(self) -> Dict[str, List[float]]:
"""
Get complete training history.
Returns
-------
dict
Dictionary mapping metric names to lists of per-epoch values.
"""
return self.history
[docs]
def replace_latest(
self,
train_energy_rmse: float,
train_energy_mae: float,
train_force_rmse: float,
test_energy_rmse: float,
test_energy_mae: float,
test_force_rmse: float,
) -> None:
"""
Replace the latest stored metric values.
This is used when the trainer wants to recompute final epoch metrics
with a deterministic full-pass evaluation after the optimization loop
has finished.
Parameters
----------
train_energy_rmse : float
Replacement training energy RMSE.
train_energy_mae : float
Replacement training energy MAE.
train_force_rmse : float
Replacement training force RMSE.
test_energy_rmse : float
Replacement validation energy RMSE.
test_energy_mae : float
Replacement validation energy MAE.
test_force_rmse : float
Replacement validation force RMSE.
"""
replacements = {
"train_energy_rmse": float(train_energy_rmse),
"train_energy_mae": float(train_energy_mae),
"train_force_rmse": float(train_force_rmse),
"test_energy_rmse": float(test_energy_rmse),
"test_energy_mae": float(test_energy_mae),
"test_force_rmse": float(test_force_rmse),
}
for key, value in replacements.items():
if key not in self.history or len(self.history[key]) == 0:
self.history.setdefault(key, []).append(value)
else:
self.history[key][-1] = value
[docs]
def get_latest(self, metric: str) -> float:
"""
Get the latest value for a specific metric.
Parameters
----------
metric : str
Metric name (e.g., 'train_energy_rmse').
Returns
-------
float
Latest value, or NaN if metric not found or empty.
"""
if metric not in self.history or len(self.history[metric]) == 0:
return float("nan")
return float(self.history[metric][-1])
[docs]
def get_best(self, metric: str, mode: str = "min") -> float:
"""
Get the best value for a specific metric.
Parameters
----------
metric : str
Metric name.
mode : str
Either 'min' or 'max'.
Returns
-------
float
Best value, or NaN if metric not found or empty.
"""
if metric not in self.history or len(self.history[metric]) == 0:
return float("nan")
values = self.history[metric]
if mode == "min":
return float(min(values))
elif mode == "max":
return float(max(values))
else:
raise ValueError(f"Invalid mode '{mode}', use 'min' or 'max'")
[docs]
def reset(self):
"""Clear all metrics history."""
for key in self.history:
self.history[key] = []
[docs]
def set_history(self, history: Dict[str, List[float]]):
"""
Set history from a dictionary (e.g., when loading checkpoint).
Parameters
----------
history : dict
Dictionary mapping metric names to lists of values.
"""
self.history = history