| import matplotlib.pyplot as plt |
| import os |
| from pathlib import Path |
|
|
| def plot_loss_curve(train_losses, val_accs=None, out_path=None, filename="loss_curve.png", prefix=""): |
| """ |
| Plots the training loss and optionally validation accuracy over epochs. |
| Saves the figure to out_path / filename. |
| """ |
| if out_path is None: |
| return |
| |
| out_path = Path(out_path) |
| out_path.mkdir(parents=True, exist_ok=True) |
| |
| fig, ax1 = plt.subplots(figsize=(10, 5)) |
| |
| epochs = range(len(train_losses)) |
| |
| color = 'tab:red' |
| ax1.set_xlabel('Epochs') |
| ax1.set_ylabel('Training Loss', color=color) |
| ax1.plot(epochs, train_losses, color=color, marker='o', label=f'{prefix} Train Loss') |
| ax1.tick_params(axis='y', labelcolor=color) |
| |
| if val_accs and len(val_accs) == len(train_losses): |
| ax2 = ax1.twinx() |
| color = 'tab:blue' |
| ax2.set_ylabel('Validation Acc', color=color) |
| ax2.plot(epochs, val_accs, color=color, marker='x', label=f'{prefix} Val Acc') |
| ax2.tick_params(axis='y', labelcolor=color) |
| |
| fig.tight_layout() |
| plt.title(f'{prefix} Training Metrics') |
| plt.grid(True) |
| |
| save_path = out_path / filename |
| plt.savefig(save_path) |
| plt.close() |
| print(f"Saved {prefix} loss curve to {save_path}") |
|
|