import json
import logging
import math
import os
import time
from contextlib import suppress

import numpy as np
import torch
import torch.nn.functional as F

try:
    import wandb
except ImportError:
    wandb = None

from open_clip import LPLoss, LPMetrics, lp_gather_features
from open_clip.utils import do_mixup, get_mix_lambda
from .distributed import is_master
from .zero_shot import zero_shot_eval


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
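

# Illustrative AverageMeter usage (a sketch; nothing in this file calls it this way):
#   meter = AverageMeter()
#   meter.update(0.5, n=32)  # record a batch loss of 0.5 averaged over 32 samples
#   meter.avg                # sample-weighted running mean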


def unwrap_model(model):
    # Strip a DistributedDataParallel-style wrapper, if present, to reach the bare model.
    if hasattr(model, "module"):
        return model.module
    return model


def train_one_epoch(
    model,
    data,
    epoch,
    optimizer,
    scaler,
    scheduler,
    args,
    tb_writer=None,
    extra_suffix="",
):
    device = torch.device(args.device)
    autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
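    # `contextlib.suppress` with no arguments is a no-op context manager, so the
    # `with autocast():` blocks below work whether or not AMP is enabled.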
    model.train()
    loss = LPLoss(args.lp_loss)

    dataloader, sampler = data["train"].dataloader, data["train"].sampler
    if args.distributed and sampler is not None:
        sampler.set_epoch(epoch)
    num_batches_per_epoch = dataloader.num_batches
    sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))

    # The toy dataset regenerates its sampling queue once per epoch.
    if args.dataset_type == "toy":
        dataloader.dataset.generate_queue()

    # data_time tracks time spent waiting on the dataloader; batch_time tracks the
    # full step, including the forward/backward pass and the optimizer update.
    loss_m = AverageMeter()
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    end = time.time()

    for i, batch in enumerate(dataloader):
        step = num_batches_per_epoch * epoch + i

        if isinstance(scheduler, dict):
            for s in scheduler.values():
                s(step)
        else:
            scheduler(step)

        # The batch dict itself carries the audio inputs (e.g. batch["waveform"]),
        # so it is handed to the model unchanged.
        audio = batch
        class_label = batch["class_label"]
        class_label = class_label.to(device=device, non_blocking=True)

        if args.mixup:
            # Sample one mixup coefficient per example and mix the targets to match.
            mix_lambda = torch.from_numpy(
                get_mix_lambda(0.5, len(audio["waveform"]))
            ).to(device)
            class_label = do_mixup(class_label, mix_lambda)
        else:
            mix_lambda = None
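
        # Rough sketch of the mixup helpers' semantics (assumed from standard mixup;
        # see open_clip.utils for the actual definitions):
        #   get_mix_lambda(0.5, n) -> n coefficients, lambda_i ~ Beta(0.5, 0.5)
        #   do_mixup(x, lam)       -> lam * x + (1 - lam) * x[reversed batch order]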

        data_time_m.update(time.time() - end)
        if isinstance(optimizer, dict):
            for o_ in optimizer.values():
                o_.zero_grad()
        else:
            optimizer.zero_grad()

        with autocast():
            pred = model(audio, mix_lambda=mix_lambda, device=device)
            total_loss = loss(pred, class_label)

        if isinstance(optimizer, dict):
            if scaler is not None:
                scaler.scale(total_loss).backward()
                for o_ in optimizer.values():
                    if args.horovod:
                        o_.synchronize()
                        scaler.unscale_(o_)
                        with o_.skip_synchronize():
                            scaler.step(o_)
                    else:
                        scaler.step(o_)
                scaler.update()
            else:
                total_loss.backward()
                for o_ in optimizer.values():
                    o_.step()
        else:
            if scaler is not None:
                scaler.scale(total_loss).backward()
                if args.horovod:
                    optimizer.synchronize()
                    scaler.unscale_(optimizer)
                    with optimizer.skip_synchronize():
                        scaler.step(optimizer)
                else:
                    scaler.step(optimizer)
                scaler.update()
            else:
                total_loss.backward()
                optimizer.step()
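
        # Both scaler branches follow the standard torch.cuda.amp recipe:
        # scale -> backward -> (unscale early for Horovod's gradient sync) -> step -> update.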

        # Clamp the logit scales to at most ln(100) ~ 4.6052, as in the original CLIP paper.
        with torch.no_grad():
            unwrap_model(model).clap_model.logit_scale_a.clamp_(0, math.log(100))
            unwrap_model(model).clap_model.logit_scale_t.clamp_(0, math.log(100))

        batch_time_m.update(time.time() - end)
        end = time.time()
        batch_count = i + 1

        if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch):
            if isinstance(audio, dict):
                batch_size = len(audio["waveform"])
            else:
                batch_size = len(audio)
            num_samples = batch_count * batch_size * args.world_size
            samples_per_epoch = dataloader.num_samples
            percent_complete = 100.0 * batch_count / num_batches_per_epoch

            # NOTE: the loss is only sampled here, at log intervals on the master
            # node, so loss_m is a coarse running average.
            loss_m.update(total_loss.item(), batch_size)
            if isinstance(optimizer, dict):
                logging.info(
                    f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                    f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
                    f"Data (t): {data_time_m.avg:.3f} "
                    f"Batch (t): {batch_time_m.avg:.3f} "
                    f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]}"
                )
                log_data = {
                    "loss": loss_m.val,
                    "data_time": data_time_m.val,
                    "batch_time": batch_time_m.val,
                    "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()],
                }
            else:
                logging.info(
                    f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                    f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
                    f"Data (t): {data_time_m.avg:.3f} "
                    f"Batch (t): {batch_time_m.avg:.3f} "
                    f"LR: {optimizer.param_groups[0]['lr']:5f} "
                )
                log_data = {
                    "loss": loss_m.val,
                    "data_time": data_time_m.val,
                    "batch_time": batch_time_m.val,
                    "lr": optimizer.param_groups[0]["lr"],
                }

            for name, val in log_data.items():
                name = f"train{extra_suffix}/{name}"
                if tb_writer is not None:
                    tb_writer.add_scalar(name, val, step)
                if args.wandb:
                    assert wandb is not None, "Please install wandb."
                    wandb.log({name: val, "step": step})

            # Reset the timing meters so each logging window reports fresh averages.
            batch_time_m.reset()
            data_time_m.reset()


def evaluate(model, data, epoch, args, tb_writer=None, extra_suffix=""):
    metrics = {}
    if not args.parallel_eval:
        # Without parallel evaluation, only the master process runs this function.
        if not is_master(args):
            return metrics
    device = torch.device(args.device)
    model.eval()

    if is_master(args):
        print("Evaluating...")
    # args.lp_metrics is assumed to be a comma-separated list of metric names
    # understood by LPMetrics (the exact names depend on that API).
    metric_names = args.lp_metrics.split(",")
    eval_tool = LPMetrics(metric_names=metric_names)

    autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
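
    # Evaluation accumulates predictions over the whole validation set and scores
    # them once at the end; with parallel_eval, per-rank shards are gathered first.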
| if "val" in data and ( |
| args.val_frequency |
| and ((epoch % args.val_frequency) == 0 or epoch == args.epochs) |
| ): |
| if args.parallel_eval: |
| dataloader, sampler = data["val"].dataloader, data["val"].sampler |
| if args.distributed and sampler is not None: |
| sampler.set_epoch(epoch) |
| samples_per_val = dataloader.num_samples |
| else: |
| dataloader = data["val"].dataloader |
| num_samples = 0 |
| samples_per_val = dataloader.num_samples |
|
|
        eval_info = {"pred": [], "target": []}
        with torch.no_grad():
            for i, batch in enumerate(dataloader):
                audio = batch  # the batch dict carries the audio inputs
                class_label = batch["class_label"]
                class_label = class_label.to(device=device, non_blocking=True)

                with autocast():
                    pred = model(audio, device=device)
                    if args.parallel_eval:
                        # Gather predictions and targets from all ranks before scoring.
                        pred, class_label = lp_gather_features(
                            pred, class_label, args.world_size, args.horovod
                        )
                    eval_info["pred"].append(pred)
                    eval_info["target"].append(class_label)

                num_samples += class_label.shape[0]

                if (i % 100) == 0:
                    logging.info(
                        f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]"
                    )

            if is_master(args):
                eval_info["pred"] = torch.cat(eval_info["pred"], 0).cpu()
                eval_info["target"] = torch.cat(eval_info["target"], 0).cpu()
                # (sic) `evaluate_mertics` is the method name defined by LPMetrics.
                metric_dict = eval_tool.evaluate_mertics(
                    eval_info["pred"], eval_info["target"]
                )
                metrics.update(metric_dict)
                if "epoch" not in metrics:
                    metrics.update({"epoch": epoch})

    if is_master(args):
        if not metrics:
            return metrics

        logging.info(
            f"Eval Epoch: {epoch} "
            + "\n".join(f"{m}: {round(metrics[m], 4):.4f}" for m in metrics)
        )
        if args.save_logs:
            for name, val in metrics.items():
                if tb_writer is not None:
                    tb_writer.add_scalar(f"val{extra_suffix}/{name}", val, epoch)

            # Append this epoch's metrics as one JSON object per line (JSON Lines).
            with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
                f.write(json.dumps(metrics))
                f.write("\n")

        if args.wandb:
            assert wandb is not None, "Please install wandb."
            for name, val in metrics.items():
                wandb.log({f"val{extra_suffix}/{name}": val, "epoch": epoch})

    return metrics
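

# Typical call pattern from a training driver (illustrative sketch only; the real
# entry point lives elsewhere in the training package):
#   for epoch in range(args.epochs):
#       train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_writer)
#       metrics = evaluate(model, data, epoch + 1, args, tb_writer)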