| |
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| from inference_avwm import model_forward_wrapper_av |
| import torch |
| |
| torch.backends.cuda.matmul.allow_tf32 = True |
| torch.backends.cudnn.allow_tf32 = True |
|
|
| import matplotlib |
| matplotlib.use('Agg') |
| from collections import OrderedDict |
| from copy import deepcopy |
| from time import time |
| import argparse |
| import logging |
| import os |
| import matplotlib.pyplot as plt |
| import yaml |
|
|
|
|
| import torch.distributed as dist |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| from torch.utils.data import DataLoader, ConcatDataset |
| from torch.utils.data.distributed import DistributedSampler |
| from diffusers.models import AutoencoderKL |
|
|
| from distributed import init_distributed |
| from models import AVCDiT_models |
| from diffusion import create_diffusion |
| from datasets import TrainingDataset |
| from misc import transform |
| from soundstream import SoundStream |
| import torchaudio |
| from eval_audio import build_mel_transform, mel_cosine_stereo, drms_avg_db_stereo, save_ref_hat_spectrogram_panel |
|
|
| |
| |
| |
|
|
|
|
def load_checkpoint_if_available(model, ema, opt, scaler, config, device, logger, args):
    """
    Restore training state from a checkpoint when one can be found.

    The checkpoint is looked up at
    ``<results_dir>/<run_name>/checkpoints/latest.pth.tar``; if that file does
    not exist, ``config['from_checkpoint']`` is used instead (when set).

    Args:
        model: Online (trained) model whose weights are restored in place.
        ema: EMA copy of the model; restored from ``checkpoint["ema"]``.
        opt: Optimizer; its state is restored only when resuming.
        scaler: GradScaler or ``None``; restored only when resuming.
        config: Mapping with ``results_dir``, ``run_name`` and optionally
            ``from_checkpoint``.
        device: CUDA device index, used as ``cuda:{device}`` for map_location.
        logger: Logger used for progress messages.
        args: Namespace providing ``restart_from_checkpoint``.

    Returns:
        ``(start_epoch, train_steps)``; both are 0 when nothing was loaded or
        when ``args.restart_from_checkpoint`` is truthy.
    """
    start_epoch = 0
    train_steps = 0

    latest_path = os.path.join(config['results_dir'], config['run_name'], "checkpoints", "latest.pth.tar")
    if not os.path.isfile(latest_path):
        fallback_path = config.get('from_checkpoint', 0)
        if not fallback_path:
            # Nothing to load: fresh run.
            return start_epoch, train_steps
        latest_path = fallback_path

    def _strip_compile_prefix(state_dict):
        # torch.compile wraps the module, which prefixes keys with "_orig_mod.".
        return {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}

    def _report(tag, result):
        # strict=False silently tolerates mismatches; surface them for the user.
        if result.missing_keys or result.unexpected_keys:
            logger.warning(
                f"{tag}: {len(result.missing_keys)} missing and "
                f"{len(result.unexpected_keys)} unexpected keys while loading checkpoint."
            )

    print("Loading model from ", latest_path)
    # NOTE: weights_only=False unpickles arbitrary objects (the checkpoint
    # stores the argparse namespace); only load checkpoints you trust.
    checkpoint = torch.load(latest_path, map_location=f"cuda:{device}", weights_only=False)

    ema_ckp = _strip_compile_prefix(checkpoint["ema"])
    resuming = not args.restart_from_checkpoint

    # When resuming, restore the online weights that the optimizer state was
    # computed for. A cold start (or a checkpoint without "model") initialises
    # the online model from the EMA weights instead.
    if resuming and "model" in checkpoint:
        model_ckp = _strip_compile_prefix(checkpoint["model"])
    else:
        model_ckp = ema_ckp

    _report("model", model.load_state_dict(model_ckp, strict=False))
    print("Model weights loaded.")
    _report("ema", ema.load_state_dict(ema_ckp, strict=False))
    print("EMA weights loaded.")

    if not resuming:
        logger.info("Restarting training: epoch and step counters set to 0.")
        return start_epoch, train_steps

    if "opt" in checkpoint:
        # Optimizer state dicts are keyed by "state"/"param_groups", so no
        # key rewriting is needed here.
        opt.load_state_dict(checkpoint["opt"])
        print("Optimizer state loaded.")
    if "scaler" in checkpoint and scaler is not None:
        scaler.load_state_dict(checkpoint["scaler"])
        print("GradScaler state loaded.")
    if "epoch" in checkpoint:
        start_epoch = checkpoint["epoch"] + 1
    if "train_steps" in checkpoint:
        train_steps = checkpoint["train_steps"]
    logger.info(f"Resuming from epoch {start_epoch}, step {train_steps}")

    return start_epoch, train_steps
|
|
|
|
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Move every EMA parameter towards its counterpart in ``model``.

    Each EMA tensor is updated in place as
    ``ema = decay * ema + (1 - decay) * param``. Parameters are matched by
    name after removing the ``_orig_mod.`` prefix from the names of ``model``.
    """
    targets = OrderedDict(ema_model.named_parameters())
    blend = 1 - decay

    for raw_name, source in OrderedDict(model.named_parameters()).items():
        key = raw_name.replace('_orig_mod.', '')
        averaged = targets[key]
        averaged.mul_(decay)
        averaged.add_(source.data, alpha=blend)
|
|
|
|
def requires_grad(model, flag=True):
    """
    Enable or disable gradient tracking on every parameter of ``model``.
    """
    for param in model.parameters():
        param.requires_grad_(flag)
|
|
|
|
def cleanup():
    """
    End DDP training.

    Destroys the default process group. The call is a no-op when
    torch.distributed is unavailable or no process group is initialised, so
    it is safe to call more than once or on an early-exit path.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
|
|
|
|
def create_logger(logging_dir):
    """
    Build the module logger for the current rank.

    Rank 0 configures logging to both stdout/stderr stream and
    ``<logging_dir>/log.txt``; every other rank only gets a NullHandler
    attached to the module logger.
    """
    logger = logging.getLogger(__name__)

    if dist.get_rank() != 0:
        logger.addHandler(logging.NullHandler())
        return logger

    log_file = f"{logging_dir}/log.txt"
    logging.basicConfig(
        level=logging.INFO,
        format='[\033[34m%(asctime)s\033[0m] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=[logging.StreamHandler(), logging.FileHandler(log_file)],
    )
    return logger
|
|
| |
| |
| |
|
|
def main(args):
    """
    Trains a new AVCDiT model.

    Loads ``config/eval_config.yaml`` as defaults, overlays the YAML file at
    ``args.config``, builds the visual/audio tokenizers, the model, its EMA
    copy, the optimizer and the datasets, optionally restores a checkpoint,
    and then runs the distributed training loop with periodic logging,
    checkpointing and evaluation.

    Args:
        args: Namespace produced by ``get_args_parser()``.
    """
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."

    # ---- Distributed setup and per-rank seeding ----
    _, rank, device, _ = init_distributed()

    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # ---- Configuration: defaults overlaid by the user file ----
    with open("config/eval_config.yaml", "r") as f:
        default_config = yaml.safe_load(f)
    config = default_config

    with open(args.config, "r") as f:
        user_config = yaml.safe_load(f)
    # An empty YAML file parses to None; treat it as "no overrides".
    config.update(user_config or {})

    # ---- Experiment folders and logger ----
    os.makedirs(config['results_dir'], exist_ok=True)
    experiment_dir = f"{config['results_dir']}/{config['run_name']}"
    checkpoint_dir = f"{experiment_dir}/checkpoints"
    if rank == 0:
        os.makedirs(checkpoint_dir, exist_ok=True)
        logger = create_logger(experiment_dir)
        logger.info(f"Experiment directory created at {experiment_dir}")
    else:
        logger = create_logger(None)

    # ---- Tokenizers ----
    tokenizer_v = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(device)

    tokenizer_a = SoundStream(C=32, D=16, n_q=8, codebook_size=1024).to(device)
    tokenizer_a_path = config["tokenizer_a_path"]
    tokenizer_a_checkpoint = torch.load(tokenizer_a_path, map_location=f"cuda:{device}")
    tokenizer_a.load_state_dict(tokenizer_a_checkpoint["model_state"])
    tokenizer_a.eval()

    # ---- Model, EMA copy and optimizer ----
    # Validate before deriving the latent size from it.
    assert config['image_size'] % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
    latent_size = config['image_size'] // 8

    num_cond = config['context_size']
    model = AVCDiT_models[config['model']](context_size=num_cond, input_size=latent_size, in_channels=4).to(device)

    ema = deepcopy(model).to(device)
    requires_grad(ema, False)

    lr = float(config.get('lr', 1e-4))
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0)

    bfloat_enable = bool(hasattr(args, 'bfloat16') and args.bfloat16)
    if bfloat_enable:
        scaler = torch.amp.GradScaler()

    # Restore state before compiling / wrapping so state-dict keys are plain.
    start_epoch, train_steps = load_checkpoint_if_available(
        model, ema, opt, scaler if bfloat_enable else None, config, device, logger, args
    )

    if args.torch_compile:
        model = torch.compile(model)
    model = DDP(model, device_ids=[device])
    diffusion = create_diffusion(timestep_respacing="", dual=True)

    logger.info(f"AVCDiT Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # ---- Datasets ----
    train_dataset = []
    test_dataset = []

    for dataset_name in config["datasets"]:
        data_config = config["datasets"][dataset_name]

        for data_split_type in ["train", "test"]:
            if data_split_type in data_config:
                goals_per_obs = int(data_config["goals_per_obs"])
                if data_split_type == 'test':
                    goals_per_obs = 4

                # Per-dataset settings take precedence over the global ones.
                if "distance" in data_config:
                    min_dist_cat = data_config["distance"]["min_dist_cat"]
                    max_dist_cat = data_config["distance"]["max_dist_cat"]
                else:
                    min_dist_cat = config["distance"]["min_dist_cat"]
                    max_dist_cat = config["distance"]["max_dist_cat"]

                if "len_traj_pred" in data_config:
                    len_traj_pred = data_config["len_traj_pred"]
                else:
                    len_traj_pred = config["len_traj_pred"]

                dataset = TrainingDataset(
                    data_folder=data_config["data_folder"],
                    data_split_folder=data_config[data_split_type],
                    dataset_name=dataset_name,
                    image_size=config["image_size"],
                    min_dist_cat=min_dist_cat,
                    max_dist_cat=max_dist_cat,
                    len_traj_pred=len_traj_pred,
                    context_size=config["context_size"],
                    normalize=config["normalize"],
                    goals_per_obs=goals_per_obs,
                    transform=transform,
                    predefined_index=None,
                    traj_stride=1,
                    sample_rate=config["sample_rate"],
                    input_sr=config["input_sr"],
                    evaluate=(data_split_type == "test")
                )
                if data_split_type == "train":
                    train_dataset.append(dataset)
                else:
                    test_dataset.append(dataset)
                print(f"Dataset: {dataset_name} ({data_split_type}), size: {len(dataset)}")

    print(f"Combining {len(train_dataset)} datasets.")
    train_dataset = ConcatDataset(train_dataset)
    test_dataset = ConcatDataset(test_dataset)

    sampler = DistributedSampler(
        train_dataset,
        num_replicas=dist.get_world_size(),
        rank=rank,
        shuffle=True,
        seed=args.global_seed
    )
    loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        sampler=sampler,
        num_workers=config['num_workers'],
        pin_memory=True,
        drop_last=True,
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=config['num_workers'] > 0
    )
    logger.info(f"Dataset contains {len(train_dataset):,} images")

    # ---- Training loop ----
    model.train()
    ema.eval()

    # Counters for the running average between two log lines.
    log_steps = 0
    running_loss = 0
    start_time = time()

    logger.info(f"Training for {args.epochs} epochs...")
    for epoch in range(start_epoch, args.epochs):
        sampler.set_epoch(epoch)
        steps_per_epoch = len(loader)
        if rank == 0:
            logger.info(f"Epoch {epoch} contains {steps_per_epoch} steps.")
        logger.info(f"Beginning epoch {epoch}...")

        for x_v, x_a, y, diff, rel_t in loader:
            x_v = x_v.to(device, non_blocking=True)
            x_a = x_a.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            diff = diff.to(device, non_blocking=True)
            rel_t = rel_t.to(device, non_blocking=True)

            with torch.amp.autocast('cuda', enabled=bfloat_enable, dtype=torch.bfloat16):
                # Tokenization and input reshaping need no gradients.
                with torch.no_grad():
                    B, T = x_v.shape[:2]

                    # Encode all frames at once, then restore the (B, T) layout.
                    x_v = x_v.flatten(0, 1)
                    x_v = tokenizer_v.encode(x_v).latent_dist.sample().mul_(0.18215)
                    x_v = x_v.unflatten(0, (B, T))

                    x_a = x_a.flatten(0, 1)
                    x_a = tokenizer_a.encoder(x_a)
                    x_a = x_a.unflatten(0, (B, T))

                    num_goals = T - num_cond

                    # The first num_cond frames are context; the rest are targets.
                    # The context is repeated once per target.
                    x_v_start = x_v[:, num_cond:].flatten(0, 1)
                    x_v_cond = x_v[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x_v.shape[2], x_v.shape[3], x_v.shape[4]).flatten(0, 1)
                    x_a_start = x_a[:, num_cond:].flatten(0, 1)
                    x_a_cond = x_a[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x_a.shape[2], x_a.shape[3]).flatten(0, 1)

                    y = y.flatten(0, 1)
                    rel_t = rel_t.flatten(0, 1)

                    # Broadcast `diff` along dim 1 of the audio target and append it
                    # on the last dim. The length is taken from x_a_start (torch.cat
                    # requires it to match) instead of a hard-coded 16.
                    diff = diff.flatten(0, 1)
                    diff_tok = diff.unsqueeze(1).expand(-1, x_a_start.shape[1], -1)
                    x_a_start = torch.cat([x_a_start, diff_tok], dim=2)

                t = torch.randint(0, diffusion.num_timesteps, (x_v_start.shape[0],), device=device)
                model_kwargs = dict(y=y, x_v_cond=x_v_cond, x_a_cond=x_a_cond, rel_t=rel_t)
                loss_dict = diffusion.training_losses(model, x_v_start, x_a_start, t, model_kwargs)
                loss = loss_dict["loss"].mean()

            # Clear gradients every step in BOTH precision modes; without this the
            # scaler branch would keep accumulating gradients across iterations.
            opt.zero_grad()
            if not bfloat_enable:
                # NOTE(review): grad_clip_val is not applied in this branch — confirm intent.
                loss.backward()
                opt.step()
            else:
                scaler.scale(loss).backward()
                if config.get('grad_clip_val', 0) > 0:
                    # Gradients must be unscaled before clipping by norm.
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config['grad_clip_val'])
                scaler.step(opt)
                scaler.update()

            update_ema(ema, model.module)

            # ---- Logging ----
            running_loss += loss.detach().item()
            log_steps += 1
            train_steps += 1
            if train_steps % args.log_every == 0:
                # Wait for queued GPU work so the timing is meaningful.
                torch.cuda.synchronize()
                end_time = time()
                steps_per_sec = log_steps / (end_time - start_time)
                samples_per_sec = dist.get_world_size() * x_v_cond.shape[0] * steps_per_sec

                # Average the loss over all ranks.
                avg_loss = torch.tensor(running_loss / log_steps, device=device)
                dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
                avg_loss = avg_loss.item() / dist.get_world_size()
                total_steps = steps_per_epoch * args.epochs
                progress_pct = train_steps / total_steps * 100

                remaining_steps = total_steps - train_steps
                eta_seconds = remaining_steps / steps_per_sec if steps_per_sec > 0 else 0
                eta_hours = eta_seconds / 3600

                logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}, Samples/Sec: {samples_per_sec:.2f}")
                logger.info(f"Progress: {progress_pct:.2f}% | ETA: {eta_hours:.1f}h")

                running_loss = 0
                log_steps = 0
                start_time = time()

            # ---- Checkpointing (rank 0 only) ----
            if train_steps % args.ckpt_every == 0 and train_steps > 0:
                if rank == 0:
                    checkpoint = {
                        "model": model.module.state_dict(),
                        "ema": ema.state_dict(),
                        "opt": opt.state_dict(),
                        "args": args,
                        "epoch": epoch,
                        "train_steps": train_steps
                    }
                    if bfloat_enable:
                        checkpoint.update({"scaler": scaler.state_dict()})
                    checkpoint_path = f"{checkpoint_dir}/latest.pth.tar"
                    torch.save(checkpoint, checkpoint_path)
                    # Every 10th checkpoint is additionally kept under its step number.
                    if train_steps % (10 * args.ckpt_every) == 0 and train_steps > 0:
                        checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pth.tar"
                        torch.save(checkpoint, checkpoint_path)
                    logger.info(f"Saved checkpoint to {checkpoint_path}")

            # ---- Periodic evaluation with the EMA weights ----
            if train_steps % args.eval_every == 0 and train_steps > 0:
                eval_start_time = time()

                save_dir = os.path.join(experiment_dir, str(train_steps))
                sim_score_val = evaluate(ema, tokenizer_v, tokenizer_a, diffusion, test_dataset, rank, config["batch_size"], config["num_workers"], latent_size, device, save_dir, args.global_seed, bfloat_enable, num_cond, config["sample_rate"], config["input_sr"], logger)
                dist.barrier()
                eval_end_time = time()
                eval_time = eval_end_time - eval_start_time

                logger.info(f"(step={train_steps:07d}) Val Perceptual Loss: {sim_score_val:.4f}, Eval Time: {eval_time:.2f}")

    # Training finished: leave the model in inference mode.
    model.eval()

    logger.info("Done!")
    cleanup()
|
|
def denormalize_dis(ndata: float, min_v=-20.0, max_v=20.0, scale=0.15):
    """
    Map a value normalised to [-1, 1] back to [min_v, max_v], then multiply by ``scale``.
    """
    unit_fraction = (float(ndata) + 1.0) / 2.0  # [-1, 1] -> [0, 1]
    span = max_v - min_v
    return (unit_fraction * span + min_v) * scale
|
|
@torch.no_grad()
def evaluate(model, vae, sstream, diffusion, test_dataloaders, rank, batch_size, num_workers, latent_size, device, save_dir, seed, bfloat_enable, num_cond, sample_rate, input_sr, logger):
    """
    Run a one-batch evaluation and return the DreamSim score averaged over ranks.

    One batch per rank is sampled from ``test_dataloaders`` (a dataset, despite
    the name), predictions are generated with ``model_forward_wrapper_av`` and
    scored against the target frames with DreamSim. On rank 0, up to ten
    examples are written to ``save_dir`` as image panels, spectrogram panels
    and wav files, and the distance MAE and mel-cosine values are logged.

    Args:
        model, vae, sstream, diffusion: Passed through to the sampling wrapper.
        test_dataloaders: Dataset to evaluate on.
        rank: Rank of this process.
        batch_size, num_workers: DataLoader settings.
        latent_size, device, num_cond: Passed through to the sampling wrapper.
        save_dir: Output directory for rank-0 artefacts.
        seed: Seed of the DistributedSampler.
        bfloat_enable: Currently unused; autocast is always enabled here.
        sample_rate: Target audio sample rate for resampling and saving.
        input_sr: Sample rate the generated audio is resampled from.
        logger: Logger for metric messages.

    Returns:
        0-dim tensor: summed DreamSim output divided by the number of scored
        samples, both all-reduced across ranks.
    """
    sampler = DistributedSampler(
        test_dataloaders,
        num_replicas=dist.get_world_size(),
        rank=rank,
        shuffle=True,
        seed=seed
    )
    loader = DataLoader(
        test_dataloaders,
        batch_size=batch_size,
        shuffle=False,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True
    )
    from dreamsim import dreamsim
    eval_model, _ = dreamsim(pretrained=True)
    score = torch.tensor(0.).to(device)
    n_samples = torch.tensor(0).to(device)

    down_resampler = torchaudio.transforms.Resample(orig_freq=input_sr, new_freq=sample_rate, lowpass_filter_width=64).to(device, dtype=torch.bfloat16)
    mel_tf = build_mel_transform(
        sample_rate=sample_rate,
        n_fft=1024, win_length=1024, hop_length=256,
        n_mels=80, power=1.0,
        device=device,
    )

    # Stays False when the loader yields nothing (e.g. drop_last on a test set
    # smaller than one batch); the visualisation below needs the batch tensors.
    has_batch = False
    for x_v, x_a, y, diff, rel_t, x_a_orig in loader:
        x_v = x_v.to(device)
        x_a = x_a.to(device)
        x_a_orig = x_a_orig.to(device)
        y = y.to(device)
        diff = diff.to(device).flatten(0, 1)
        rel_t = rel_t.to(device).flatten(0, 1)
        with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
            B, T = x_v.shape[:2]
            num_goals = T - num_cond
            samples_v, samples_a, diff_pred = model_forward_wrapper_av((model, diffusion, vae, sstream), (x_v, x_a), y, num_timesteps=None, latent_size=latent_size, device=device, num_cond=num_cond, num_goals=num_goals, rel_t=rel_t)

            samples_a = down_resampler(samples_a)

            x_start_pixels = x_v[:, num_cond:].flatten(0, 1)
            x_cond_pixels = x_v[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x_v.shape[2], x_v.shape[3], x_v.shape[4]).flatten(0, 1)
            # Rescale x * 0.5 + 0.5 before scoring and plotting
            # (presumably [-1, 1] -> [0, 1]; confirm against the dataset transform).
            samples_v = samples_v * 0.5 + 0.5
            x_start_pixels = x_start_pixels * 0.5 + 0.5
            x_cond_pixels = x_cond_pixels * 0.5 + 0.5
            res = eval_model(x_start_pixels, samples_v)
            score += res.sum()
            n_samples += len(res)

        x_start_audio = x_a_orig[:, num_cond:].flatten(0, 1)
        x_cond_audio = x_a_orig[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x_a_orig.shape[2], x_a_orig.shape[3]).flatten(0, 1)
        has_batch = True
        # Only the first batch is evaluated.
        break

    if rank == 0 and not has_batch:
        logger.warning("Evaluation loader produced no batch; skipping visualisations.")

    if rank == 0 and has_batch:
        os.makedirs(save_dir, exist_ok=True)

        if diff is not None:
            mae = torch.mean(torch.abs(diff_pred - diff))
            logger.info(f"Distance Diff MAE = {mae.item():.6f}")

        mel_cosine_ls = []
        for i in range(min(samples_v.shape[0], 10)):
            # Panel: last context frame | ground-truth target | prediction.
            _, ax = plt.subplots(1, 3, dpi=256)
            ax[0].imshow((x_cond_pixels[i, -1].permute(1, 2, 0).cpu().numpy() * 255).astype('uint8'))
            ax[1].imshow((x_start_pixels[i].permute(1, 2, 0).cpu().numpy() * 255).astype('uint8'))
            ax[2].imshow((samples_v[i].permute(1, 2, 0).cpu().float().numpy() * 255).astype('uint8'))
            plt.savefig(f'{save_dir}/{i}.png')
            plt.close()

            mel_cos = mel_cosine_stereo(x_start_audio[i], samples_a[i], sample_rate=sample_rate, mel_tf=mel_tf)
            mel_cosine_ls.append(mel_cos)
            save_ref_hat_spectrogram_panel(
                x_start_audio[i], samples_a[i],
                out_path=f"{save_dir}/{i}_spectrograms.png",
                n_fft=512, hop_length=160, win_length=400, pool=4,
                title="gt vs pred"
            )

            torchaudio.save(f"{save_dir}/{i}_gen.wav", samples_a[i].cpu().to(torch.float32), sample_rate=sample_rate)
            torchaudio.save(f"{save_dir}/{i}_gt.wav", x_start_audio[i].cpu().to(torch.float32), sample_rate=sample_rate)
            torchaudio.save(f"{save_dir}/{i}_cond.wav", x_cond_audio[i, -1].cpu().to(torch.float32), sample_rate=sample_rate)
        logger.info("the first 10 mel cosine: " + ", ".join(f"{v:.6f}" for v in mel_cosine_ls))

    # Sum the score and sample count across ranks before averaging.
    dist.all_reduce(score)
    dist.all_reduce(n_samples)
    sim_score = score / n_samples
    return sim_score
|
|
|
|
def get_args_parser():
    """
    Build the command-line parser for the training script.

    ``--config`` is the only required option; all other options are integers
    with defaults.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)

    # Plain integer options, registered in display order: (flag, default).
    int_options = (
        ("--epochs", 300),
        ("--global-seed", 0),
        ("--log-every", 100),
        ("--ckpt-every", 2000),
        ("--eval-every", 5000),
        ("--bfloat16", 1),
        ("--torch-compile", 1),
    )
    for flag, default in int_options:
        parser.add_argument(flag, type=int, default=default)

    parser.add_argument(
        "--restart-from-checkpoint",
        type=int,
        default=0,
        help="If 1, only load model weights and reset epoch/step to zero (cold start)",
    )
    return parser
|
|
if __name__ == "__main__":
    # Script entry point: parse the CLI options and start training.
    parser = get_args_parser()
    args = parser.parse_args()
    main(args)
|
|