| import matplotlib.pyplot as plt |
| import numpy as np |
| import torch |
|
|
def plot_losses(log_dir):
    """Plot training losses from TensorBoard logs found in *log_dir*.

    NOTE(review): not implemented yet -- this is a placeholder that ignores
    ``log_dir``, draws nothing and returns ``None``.
    """
    return None
|
|
def save_checkpoint(model, optimizer, epoch, path):
    """Write the current training state to *path* via ``torch.save``.

    The saved object is a dict with the keys ``'epoch'``,
    ``'model_state_dict'`` and ``'optimizer_state_dict'``, which is the
    layout ``load_checkpoint`` reads back.

    Args:
        model: object providing ``state_dict()`` (e.g. an ``nn.Module``).
        optimizer: object providing ``state_dict()`` (e.g. a torch optimizer).
        epoch: value stored under ``'epoch'``.
        path: destination accepted by ``torch.save`` (path or file-like).
    """
    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
    )
    torch.save(state, path)
|
|
def load_checkpoint(model, optimizer, path, map_location=None):
    """Restore model (and optionally optimizer) state from a checkpoint.

    Expects the dict layout written by ``save_checkpoint``: keys ``'epoch'``,
    ``'model_state_dict'`` and ``'optimizer_state_dict'``.

    Args:
        model: object whose ``load_state_dict`` receives the saved weights.
        optimizer: optimizer to restore, or ``None`` to skip restoring
            optimizer state (e.g. when loading only for inference).
        path: file path or file-like object accepted by ``torch.load``.
        map_location: forwarded to ``torch.load``; e.g. pass ``'cpu'`` to
            load a checkpoint saved from GPU tensors on a CPU-only machine.
            The default ``None`` keeps ``torch.load``'s default behavior.

    Returns:
        The value stored under ``'epoch'`` in the checkpoint.

    Raises:
        KeyError: if a required key is missing from the checkpoint dict.
    """
    # NOTE(review): torch.load deserializes via pickle; only load checkpoints
    # from trusted sources (consider weights_only=True where torch supports it).
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']
|
|
def show_samples(samples):
    """Display generated samples as one image with matplotlib.

    Args:
        samples: 3-D tensor whose first axis is moved last before
            ``imshow`` -- presumably channel-first (C, H, W); confirm with
            callers. It may live on any device and may require grad; it is
            detached and copied to the CPU before conversion.
    """
    # Tensor.numpy() raises for CUDA tensors and for tensors that require
    # grad, so detach and move to host memory first. Converting before
    # creating the figure also avoids leaving an empty figure on bad input.
    image = samples.detach().cpu().numpy()
    plt.figure(figsize=(10, 10))
    plt.imshow(np.transpose(image, (1, 2, 0)))
    plt.axis('off')
    plt.show()