import os
import warnings
from pathlib import Path

import matplotlib.pyplot as plt


def plot_loss_curve(train_losses, val_accs=None, out_path=None,
                    filename="loss_curve.png", prefix=""):
    """Plot per-epoch training loss and, optionally, validation accuracy.

    The loss is drawn on the left y-axis. When ``val_accs`` is given and has
    the same length as ``train_losses``, it is drawn on a twin right y-axis.
    The figure is written to ``out_path / filename``.

    Args:
        train_losses: Sequence of training-loss values, one per epoch.
        val_accs: Optional sequence of validation-accuracy values, one per
            epoch. Empty input is ignored. A length that differs from
            ``train_losses`` is skipped and a warning is emitted.
        out_path: Directory to save into; created (with parents) if missing.
            If ``None``, the function returns immediately and nothing is
            plotted or written.
        filename: Name of the image file written inside ``out_path``.
        prefix: Optional tag prepended to the line labels, the title and the
            log message.
    """
    if out_path is None:
        return
    out_path = Path(out_path)
    out_path.mkdir(parents=True, exist_ok=True)

    # Only insert the separating space when a prefix is actually given, so the
    # default prefix="" does not produce " Train Loss" / "Saved  loss curve".
    tag = f"{prefix} " if prefix else ""
    epochs = range(len(train_losses))

    # Use `is not None` + len() rather than truthiness: truthiness of a
    # multi-element NumPy array raises ValueError.
    has_val = val_accs is not None and len(val_accs) > 0
    if has_val and len(val_accs) != len(train_losses):
        warnings.warn(
            f"val_accs has {len(val_accs)} entries but train_losses has "
            f"{len(train_losses)}; skipping validation accuracy curve."
        )
        has_val = False

    save_path = out_path / filename
    fig, ax1 = plt.subplots(figsize=(10, 5))
    try:
        loss_color = "tab:red"
        ax1.set_xlabel("Epochs")
        ax1.set_ylabel("Training Loss", color=loss_color)
        ax1.plot(epochs, train_losses, color=loss_color, marker="o",
                 label=f"{tag}Train Loss")
        ax1.tick_params(axis="y", labelcolor=loss_color)
        # Grid on the loss axis explicitly; pyplot's "current axes" would be
        # the twin accuracy axis once twinx() has been called.
        ax1.grid(True)

        ax2 = None
        if has_val:
            acc_color = "tab:blue"
            ax2 = ax1.twinx()
            ax2.set_ylabel("Validation Acc", color=acc_color)
            ax2.plot(epochs, val_accs, color=acc_color, marker="x",
                     label=f"{tag}Val Acc")
            ax2.tick_params(axis="y", labelcolor=acc_color)

        # Merge legend entries from both axes so the labels are shown.
        handles, labels = ax1.get_legend_handles_labels()
        if ax2 is not None:
            extra_handles, extra_labels = ax2.get_legend_handles_labels()
            handles += extra_handles
            labels += extra_labels
        # Draw on the topmost axes so the twin's lines cannot cover the legend.
        (ax2 if ax2 is not None else ax1).legend(handles, labels, loc="best")

        ax1.set_title(f"{tag}Training Metrics")
        # Lay out only after every artist (including the title) exists,
        # otherwise the title can be clipped.
        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        # Close this specific figure, even if saving fails, so repeated calls
        # (e.g. once per epoch) do not accumulate open figures.
        plt.close(fig)

    print(f"Saved {tag}loss curve to {save_path}")