from contextlib import nullcontext
from typing import Dict, Tuple

import numpy as np
import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils import barrier, reduce_mean, update_loss_info
|
|
|
|
def train(
    model: nn.Module,
    data_loader: DataLoader,
    loss_fn: nn.Module,
    optimizer: Optimizer,
    grad_scaler: GradScaler,
    device: torch.device,
    rank: int,
    nprocs: int,
) -> Tuple[nn.Module, Optimizer, GradScaler, Dict[str, float]]:
    """Run one training epoch over ``data_loader``.

    Supports optional mixed-precision training (when ``grad_scaler`` is not
    ``None``) and multi-process DDP training (when ``nprocs > 1``), in which
    case per-step losses are averaged across ranks.

    Args:
        model: Model to train. When DDP-wrapped, its ``.module.bins``
            attribute selects regression (``None``) vs. classification mode.
        data_loader: Yields ``(image, target_points, target_density)`` batches.
        loss_fn: Loss module; its signature depends on the model mode.
        optimizer: Optimizer stepped once per batch.
        grad_scaler: AMP gradient scaler, or ``None`` to train in full precision.
        device: Device onto which batches are moved.
        rank: Process rank; only rank 0 shows a progress bar.
        nprocs: Total number of processes (``> 1`` implies DDP).

    Returns:
        Tuple of ``(model, optimizer, grad_scaler, epoch_losses)`` where
        ``epoch_losses`` maps each loss name to its epoch mean (empty dict
        if the loader yielded no batches).
    """
    model.train()
    info = None
    # Progress bar only on the main process to avoid duplicated output.
    data_iter = tqdm(data_loader) if rank == 0 else data_loader
    ddp = nprocs > 1
    # DDP wraps the model, so user attributes live under ``model.module``.
    # ``bins is None`` => density regression; otherwise classification head.
    regression = (model.module.bins is None) if ddp else (model.bins is None)
    use_amp = grad_scaler is not None

    for image, target_points, target_density in data_iter:
        image = image.to(device)
        target_points = [p.to(device) for p in target_points]
        target_density = target_density.to(device)

        # ``set_grad_enabled(True)`` guards against this function being
        # called under an enclosing ``torch.no_grad()`` context.
        with torch.set_grad_enabled(True):
            # Single forward/loss path for both precisions: autocast is
            # entered only when a scaler was supplied.
            amp_ctx = autocast(enabled=grad_scaler.is_enabled()) if use_amp else nullcontext()
            with amp_ctx:
                if regression:
                    pred_density = model(image)
                    loss, loss_info = loss_fn(pred_density, target_density, target_points)
                else:
                    pred_class, pred_density = model(image)
                    loss, loss_info = loss_fn(pred_class, pred_density, target_density, target_points)

            optimizer.zero_grad()
            if use_amp:
                # Scale the loss to avoid fp16 gradient underflow, then
                # unscale-and-step; ``update`` adapts the scale factor.
                grad_scaler.scale(loss).backward()
                grad_scaler.step(optimizer)
                grad_scaler.update()
            else:
                loss.backward()
                optimizer.step()

            # Average each loss term across ranks when running DDP so every
            # process logs identical values.
            loss_info = {
                k: reduce_mean(v.detach(), nprocs).item() if ddp else v.detach().item()
                for k, v in loss_info.items()
            }
            info = update_loss_info(info, loss_info)

    # Synchronize processes before leaving the epoch (no-op when not DDP).
    barrier(ddp)

    if info is None:
        # Empty data loader: return no statistics instead of crashing on
        # ``None.items()``.
        return model, optimizer, grad_scaler, {}
    return model, optimizer, grad_scaler, {k: np.mean(v) for k, v in info.items()}
|
|