import math
import sys
from typing import Iterable

import torch

import util.misc as misc
import util.lr_sched as lr_sched


def train_one_epoch(
    model: torch.nn.Module,
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    loss_scaler,
    log_writer=None,
    args=None,
):
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)
    print_freq = 20

    # Gradients are cleared once up front; loss_scaler steps the optimizer
    # every iteration, so they are also zeroed after each step below.
    optimizer.zero_grad()

    if log_writer is not None:
        print("log_dir: {}".format(log_writer.log_dir))

    for data_iter_step, (samples, _) in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):
        samples = samples.to(device, non_blocking=True)

        # Use a per-iteration (rather than per-epoch) lr schedule by passing
        # a fractional epoch value to the scheduler.
        lr_sched.adjust_learning_rate(
            optimizer, data_iter_step / len(data_loader) + epoch, args
        )

        # Mixed-precision forward pass; only the loss from the model's
        # output tuple is used here.
        with torch.cuda.amp.autocast():
            loss, _, _ = model(samples, mask_ratio=args.mask_ratio)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # loss_scaler handles the backward pass, gradient clipping, and the
        # optimizer step (scaled for mixed precision) in a single call.
        loss_scaler(loss, optimizer, parameters=model.parameters(), clip_grad=1.0)
        optimizer.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None:
            """We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar("train_loss", loss_value_reduce, epoch_1000x)
            log_writer.add_scalar("lr", lr, epoch_1000x)

    # Gather the stats from all processes before reporting epoch averages.
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
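

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): drives the loop with
# a toy model exposing the (loss, pred, mask) forward interface the loop
# expects. It assumes a CUDA device is available (train_one_epoch calls
# torch.cuda.synchronize()) and that util.misc provides the
# NativeScalerWithGradNormCount helper, as in the MAE codebase; the argparse
# fields mirror what util.lr_sched.adjust_learning_rate reads off `args`.
# All names below are illustrative stand-ins, not part of this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse

    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    class _ToyModel(nn.Module):
        """Stand-in whose forward returns a (loss, pred, mask) tuple."""

        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(16, 16)

        def forward(self, x, mask_ratio=0.75):
            pred = self.proj(x)
            loss = (pred - x).pow(2).mean()  # trivial reconstruction loss
            return loss, pred, None

    args = argparse.Namespace(
        mask_ratio=0.75, lr=1e-3, min_lr=0.0, warmup_epochs=1, epochs=2
    )
    device = torch.device("cuda")
    model = _ToyModel().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    loss_scaler = misc.NativeScalerWithGradNormCount()  # assumed MAE helper
    loader = DataLoader(
        TensorDataset(torch.randn(64, 16), torch.zeros(64)), batch_size=8
    )

    for epoch in range(args.epochs):
        stats = train_one_epoch(
            model, loader, optimizer, device, epoch, loss_scaler, args=args
        )
        print(stats)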