# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Evaluate an audio-visual diffusion model in 'rollout' or 'time' mode.

For every evaluation sample the script writes predicted (or, with ``--gt 1``,
ground-truth) frames as ``<sec>.png`` / ``<sec>.wav``, side-by-side comparison
figures as ``compare_<sec>.png`` and a per-sample ``distance.json``.
"""

from distributed import init_distributed

import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

import yaml
import argparse
import os
import numpy as np

from diffusion import create_diffusion
from diffusers.models import AutoencoderKL

import misc
import distributed as dist
from models import AVCDiT_models
from datasets import EvalDataset
from PIL import Image
from soundstream import SoundStream
import torchaudio
from skimage.measure import block_reduce
import matplotlib.pyplot as plt
import librosa
import time
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
from collections import defaultdict
import json


# Factor applied to VAE latents after encoding and removed before decoding.
VAE_LATENT_SCALE = 0.18215
# Shape (channels, length) of the audio noise fed to the sampler. The last
# position along the length axis is split off after sampling and averaged into
# the scalar prediction (see `_decode_audio_samples`).
# NOTE(review): hard-coded in the original (marked TODO there); presumably tied
# to the SoundStream configuration used in `main` - confirm before changing.
AUDIO_LATENT_SHAPE = (16, 181)
# Sample rates used by the resampler that converts model audio to the rate at
# which evaluation audio is saved.
MODEL_AUDIO_SAMPLE_RATE = 48000
EVAL_AUDIO_SAMPLE_RATE = 16000
# Evaluation modes understood by `main`.
EVAL_TYPES = ('rollout', 'time')
# STFT / pooling parameters for the comparison spectrograms.
SPEC_N_FFT = 512
SPEC_HOP_LENGTH = 160
SPEC_WIN_LENGTH = 400
SPEC_POOL_BLOCK = (4, 4)


def save_image(output_file, img, unnormalize_img):
    """Save a channel-first image tensor as an 8-bit RGB file.

    Args:
        output_file: Destination path handed to ``PIL.Image.save``.
        img: Image tensor indexed as (channel, height, width).
        unnormalize_img: When true, ``misc.unnormalize`` is applied first.
    """
    img = img.detach().cpu()
    if unnormalize_img:
        img = misc.unnormalize(img)
    # Round and clamp before the uint8 cast: a bare `.byte()` truncates and
    # wraps around for values outside [0, 255]. This matches the conversion
    # used for the comparison figures.
    img = (img * 255).round().clamp(0, 255).to(torch.uint8)
    image = Image.fromarray(img.permute(1, 2, 0).numpy(), mode='RGB')
    image.save(output_file)


def save_audio(output_file, audio_tensor, sample_rate):
    """Save an audio tensor with ``torchaudio.save`` as float32.

    A 1-D tensor gets a leading dimension added so that the saved tensor is
    at least 2-D.
    """
    audio_tensor = audio_tensor.detach().cpu()
    if audio_tensor.ndim == 1:
        audio_tensor = audio_tensor.unsqueeze(0)
    torchaudio.save(output_file, audio_tensor.to(torch.float32), sample_rate)


def get_dataset_eval(config, dataset_name, eval_type, predefined_index=True):
    """Build the ``EvalDataset`` for ``dataset_name`` from ``config``.

    Args:
        config: Evaluation configuration mapping (see the keys read below).
        dataset_name: Key into ``config["eval_datasets"]``.
        eval_type: Only used to build the predefined index path.
        predefined_index: When true, the dataset is given the path
            ``data_splits/<dataset_name>/test/<eval_type>.pkl``; otherwise
            ``None`` is passed.

    Returns:
        The constructed ``EvalDataset``.
    """
    data_config = config["eval_datasets"][dataset_name]
    index_path = (
        f"data_splits/{dataset_name}/test/{eval_type}.pkl"
        if predefined_index
        else None
    )

    dataset = EvalDataset(
        data_folder=data_config["data_folder"],
        data_split_folder=data_config["test"],
        dataset_name=dataset_name,
        image_size=config["image_size"],
        min_dist_cat=config["eval_distance"]["eval_min_dist_cat"],
        max_dist_cat=config["eval_distance"]["eval_max_dist_cat"],
        len_traj_pred=config["eval_len_traj_pred"],
        traj_stride=config["traj_stride"],
        context_size=config["eval_context_size"],
        normalize=config["normalize"],
        transform=misc.transform,
        goals_per_obs=4,
        predefined_index=index_path,
        traj_names='traj_names.txt'
    )

    return dataset


def _default_rel_t(batch_size, num_timesteps, device):
    """Return the default relative-time tensor: ``num_timesteps / 128`` per sample."""
    return torch.full((batch_size,), float(num_timesteps) / 128., device=device)


def _expand_context(x, num_cond, num_goals):
    """Take the first ``num_cond`` frames of ``x`` and repeat them per goal.

    ``x`` is indexed as (batch, time, ...). The result has a leading dimension
    of ``batch * num_goals``.
    """
    B = x.shape[0]
    cond = x[:, :num_cond].unsqueeze(1)
    # `num_cond` is passed explicitly (not -1) so that a context shorter than
    # `num_cond` raises instead of being silently accepted.
    return cond.expand(B, num_goals, num_cond, *x.shape[2:]).flatten(0, 1)


def _decode_audio_samples(sstream, samples):
    """Split sampled audio latents into decoded audio and a scalar prediction.

    The last position along the final axis is averaged over dim 1 into
    ``diff_pred``; the remaining positions are quantized and decoded by
    ``sstream``.

    Returns:
        ``(audio, diff_pred)``. With the noise shape used in this file,
        ``diff_pred`` has shape [N, 1, 1].
    """
    # REWARD TOKEN
    patch_tok = samples[..., -1:]                      # [N, C, 1]
    diff_pred = patch_tok.mean(dim=1, keepdim=True)    # [N, 1, 1]
    # AUDIO TOKENS
    samples = samples[..., :-1]
    quantized, _, _ = sstream.quantizer(samples.permute(0, 2, 1))
    audio = sstream.decoder(quantized.permute(0, 2, 1))
    return audio, diff_pred


@torch.no_grad()
def model_forward_wrapper_v(all_models, curr_obs, curr_delta, num_timesteps, latent_size, device, num_cond, num_goals=1, rel_t=None, progress=False):
    """Sample one visual prediction from a video-only model.

    Args:
        all_models: ``(model, diffusion, vae)``.
        curr_obs: Context frames indexed as (batch, time, ...).
        curr_delta: Conditioning tensor; its first two dims are flattened.
        num_timesteps: Scales the default ``rel_t`` (``num_timesteps / 128``).
        latent_size: Spatial size of the 4-channel latent noise.
        device: Device on which sampling runs.
        num_cond: Number of leading context frames used as conditioning.
        num_goals: Number of predictions per sample.
        rel_t: Optional relative-time tensor, used as given when provided.
            NOTE(review): the scaling by ``num_timesteps`` is applied only to
            the default value; the original layout was ambiguous - confirm.
        progress: Forwarded to ``diffusion.p_sample_loop``.

    Returns:
        Decoded images clipped to [-1, 1].
    """
    model, diffusion, vae = all_models
    x = curr_obs.to(device)
    y = curr_delta.to(device)

    with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
        B, T = x.shape[:2]

        if rel_t is None:
            rel_t = _default_rel_t(B, num_timesteps, device)

        x = x.flatten(0, 1)
        x = vae.encode(x).latent_dist.sample().mul_(VAE_LATENT_SCALE).unflatten(0, (B, T))
        x_cond = _expand_context(x, num_cond, num_goals)
        z = torch.randn(B * num_goals, 4, latent_size, latent_size, device=device)
        y = y.flatten(0, 1)
        model_kwargs = dict(y=y, x_cond=x_cond, rel_t=rel_t)
        samples = diffusion.p_sample_loop(
            model.forward, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=progress, device=device
        )
        samples = vae.decode(samples / VAE_LATENT_SCALE).sample

        return torch.clip(samples, -1., 1.)


@torch.no_grad()
def model_forward_wrapper_a(all_models, curr_obs, curr_delta, num_timesteps, latent_size, device, num_cond, num_goals=1, rel_t=None, progress=False):
    """Sample one audio prediction from an audio-only model.

    Arguments mirror ``model_forward_wrapper_v`` except that ``all_models`` is
    ``(model, diffusion, sstream)`` and ``latent_size`` is unused here (the
    audio noise shape is ``AUDIO_LATENT_SHAPE``).

    Returns:
        ``(audio, diff_pred)`` as produced by ``_decode_audio_samples``.
    """
    model, diffusion, sstream = all_models
    x = curr_obs.to(device)
    y = curr_delta.to(device)

    with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
        B, T = x.shape[:2]

        if rel_t is None:
            rel_t = _default_rel_t(B, num_timesteps, device)

        x = x.flatten(0, 1)
        x = sstream.encoder(x).unflatten(0, (B, T))
        x_cond = _expand_context(x, num_cond, num_goals)
        z = torch.randn(B * num_goals, *AUDIO_LATENT_SHAPE, device=device)
        y = y.flatten(0, 1)
        model_kwargs = dict(y=y, x_cond=x_cond, rel_t=rel_t)
        samples = diffusion.p_sample_loop(
            model.forward, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=progress, device=device
        )
        samples, diff_pred = _decode_audio_samples(sstream, samples)

        return samples, diff_pred


@torch.no_grad()
def model_forward_wrapper_av(all_models, curr_obs, curr_delta, num_timesteps, latent_size, device, num_cond, num_goals=1, rel_t=None, progress=False):
    """Jointly sample one visual and one audio prediction.

    Args:
        all_models: ``(model, diffusion, vae, sstream)``.
        curr_obs: ``(visual_context, audio_context)``, each indexed as
            (batch, time, ...).
        curr_delta: Conditioning tensor; its first two dims are flattened.
        num_timesteps: Scales the default ``rel_t`` (``num_timesteps / 128``).
        latent_size: Spatial size of the 4-channel visual latent noise.
        device: Device on which sampling runs.
        num_cond: Number of leading context frames used as conditioning.
        num_goals: Number of predictions per sample.
        rel_t: Optional relative-time tensor, used as given when provided.
        progress: Forwarded to ``diffusion.p_sample_loop``.

    Returns:
        ``(images, audio, diff_pred)`` where ``images`` is clipped to [-1, 1].
    """
    model, diffusion, vae, sstream = all_models
    x_v, x_a = curr_obs
    x_v = x_v.to(device)
    x_a = x_a.to(device)
    y = curr_delta.to(device)

    with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
        B, T_v = x_v.shape[:2]
        B, T_a = x_a.shape[:2]

        if rel_t is None:
            rel_t = _default_rel_t(B, num_timesteps, device)

        x_v = x_v.flatten(0, 1)
        x_a = x_a.flatten(0, 1)
        x_v = vae.encode(x_v).latent_dist.sample().mul_(VAE_LATENT_SCALE).unflatten(0, (B, T_v))
        x_a = sstream.encoder(x_a).unflatten(0, (B, T_a))
        x_v_cond = _expand_context(x_v, num_cond, num_goals)
        x_a_cond = _expand_context(x_a, num_cond, num_goals)
        z_v = torch.randn(B * num_goals, 4, latent_size, latent_size, device=device)
        z_a = torch.randn(B * num_goals, *AUDIO_LATENT_SHAPE, device=device)
        y = y.flatten(0, 1)
        model_kwargs = dict(y=y, x_v_cond=x_v_cond, x_a_cond=x_a_cond, rel_t=rel_t)
        samples_v, samples_a = diffusion.p_sample_loop(
            model.forward, z_v.shape, z_a.shape, z_v, z_a, clip_denoised=False, model_kwargs=model_kwargs, progress=progress, device=device
        )
        samples_v = vae.decode(samples_v / VAE_LATENT_SCALE).sample
        samples_a, diff_pred = _decode_audio_samples(sstream, samples_a)

        return torch.clip(samples_v, -1., 1.), samples_a, diff_pred


def _make_down_resampler(device):
    """Build the bfloat16 resampler from model audio rate to evaluation rate."""
    return torchaudio.transforms.Resample(
        orig_freq=MODEL_AUDIO_SAMPLE_RATE,
        new_freq=EVAL_AUDIO_SAMPLE_RATE,
        lowpass_filter_width=64,
    ).to(device, dtype=torch.bfloat16)


def _record_values(episode_records, idxs, sec, values):
    """Append ``{"sec", "value"}`` rows to ``episode_records`` for each sample."""
    idxs_1d = idxs.detach().reshape(-1).cpu().numpy()
    for b, sample_idx in enumerate(idxs_1d):
        episode_records[int(sample_idx)].append({"sec": int(sec), "value": float(values[b])})


def _write_distance_json(output_dir, episode_records, value_key):
    """Write one ``id_<sample>/distance.json`` per sample, rows sorted by sec."""
    for sample_idx, rows in episode_records.items():
        rows = sorted(rows, key=lambda r: r["sec"])
        sample_folder = os.path.join(output_dir, f"id_{sample_idx}")
        os.makedirs(sample_folder, exist_ok=True)
        out_json = os.path.join(sample_folder, "distance.json")
        compact = [{"sec": r["sec"], value_key: r["value"]} for r in rows]
        with open(out_json, "w", encoding="utf-8") as f:
            json.dump(compact, f, indent=2)


def generate_rollout(args, output_dir, rollout_frames, idxs, all_models, obs_av, gt_av, diffs_seq, delta, num_cond, device):
    """Autoregressively predict ``rollout_frames`` steps and save the results.

    At each step the model predicts one frame from the current context; the
    prediction is appended to the context and the oldest frame is dropped.
    With ``args.gt`` set, ground-truth frames are saved instead and the model
    is never called.

    Args:
        args: Parsed CLI namespace; ``gt`` and ``latent_size`` are read.
        output_dir: Directory that receives the ``id_<sample>`` folders.
        rollout_frames: Number of ground-truth frames to roll out over.
        idxs: Tensor of sample indices for this batch.
        all_models: ``(model, diffusion, vae, sstream)``.
        obs_av: ``(obs_image, obs_audio, orig_obs_audio)``.
        gt_av: ``(gt_image, gt_audio, orig_gt_audio)``.
        diffs_seq: Per-step target values, indexed as (batch, step, ...).
        delta: Per-step conditioning, indexed as (batch, step, ...).
        num_cond: Number of context frames.
        device: Device on which sampling runs.
    """
    (obs_image, obs_audio, _) = obs_av
    (gt_image, _, orig_gt_audio) = gt_av
    gt_image = gt_image[:, :rollout_frames]
    curr_v = obs_image.to(device)
    curr_a = obs_audio.to(device)
    # The resampler is only needed for model predictions.
    down_resampler = None if args.gt else _make_down_resampler(device)

    episode_records = defaultdict(list)  # {sample_idx: [{"sec": int, "value": float}, ...]}
    value_key = "denorm_gt" if args.gt else "denorm_pred"

    for i in range(gt_image.shape[1]):
        sec = i + 1
        curr_delta = delta[:, i:i+1].to(device)
        x_gt_pixels = gt_image[:, i].to(device)
        x_gt_audios_orig = orig_gt_audio[:, i].to(device)

        if args.gt:
            visualize_preds(output_dir, idxs, sec, x_gt_pixels, x_gt_audios_orig, EVAL_AUDIO_SAMPLE_RATE)
            denorm_gt_vals = denorm_from_tensor(diffs_seq[:, i:i+1, :])  # [B]
            _record_values(episode_records, idxs, sec, denorm_gt_vals)
        else:
            diff_gt = diffs_seq[:, i:i+1, :].unsqueeze(1).to(device)
            x_pred_pixels, x_pred_audios, diff_pred = model_forward_wrapper_av(
                all_models, (curr_v, curr_a), curr_delta, num_timesteps=1,
                latent_size=args.latent_size, device=device, num_cond=num_cond, num_goals=1)
            x_pred_audios_orig = down_resampler(x_pred_audios)

            # Slide the context window: append the prediction, drop the oldest frame.
            curr_v = torch.cat((curr_v, x_pred_pixels.unsqueeze(1)), dim=1)[:, 1:]
            curr_a = torch.cat((curr_a, x_pred_audios.unsqueeze(1)), dim=1)[:, 1:]

            denorm_pred_vals = denorm_from_tensor(diff_pred)  # [B]
            denorm_gt_vals = denorm_from_tensor(diff_gt)      # [B]

            visualize_preds(output_dir, idxs, sec, x_pred_pixels, x_pred_audios_orig, EVAL_AUDIO_SAMPLE_RATE)
            visualize_compare(output_dir, idxs, sec, x_pred_pixels, x_pred_audios_orig, x_gt_pixels, x_gt_audios_orig,
                              denorm_pred_vals=denorm_pred_vals, denorm_gt_vals=denorm_gt_vals)
            _record_values(episode_records, idxs, sec, denorm_pred_vals)

    _write_distance_json(output_dir, episode_records, value_key)


def generate_time(args, output_dir, idxs, all_models, obs_av, gt_av, diffs_seq, delta, secs, num_cond, device):
    """Predict each requested second directly from the original context.

    For every ``sec`` in ``secs`` the conditioning and target values are the
    sums of the first ``sec`` steps, and the model is called once with
    ``num_timesteps=sec``. With ``args.gt`` set, ground-truth frames are saved
    instead and the model is never called.

    Args:
        args: Parsed CLI namespace; ``gt`` and ``latent_size`` are read.
        output_dir: Directory that receives the ``id_<sample>`` folders.
        idxs: Tensor of sample indices for this batch.
        all_models: ``(model, diffusion, vae, sstream)``.
        obs_av: ``(obs_image, obs_audio, orig_obs_audio)``.
        gt_av: ``(gt_image, gt_audio, orig_gt_audio)``.
        diffs_seq: Per-step target values, indexed as (batch, step, ...).
        delta: Per-step conditioning, indexed as (batch, step, ...).
        secs: Iterable of 1-based step numbers to evaluate.
        num_cond: Number of context frames.
        device: Device on which sampling runs.

    Raises:
        ValueError: If a requested second is outside ``1..gt_image.shape[1]``.
    """
    (obs_image, obs_audio, _) = obs_av
    (gt_image, _, orig_gt_audio) = gt_av

    # Validate up front: sec == 0 would otherwise index frame -1 (the last
    # frame) and silently produce a wrong comparison.
    max_sec = gt_image.shape[1]
    for sec in secs:
        if not 1 <= int(sec) <= max_sec:
            raise ValueError(f"requested second {int(sec)} is outside the valid range 1..{max_sec}")

    # The resampler is only needed for model predictions.
    down_resampler = None if args.gt else _make_down_resampler(device)

    episode_records = defaultdict(list)  # {sample_idx: [{"sec": int, "value": float}, ...]}
    value_key = "denorm_gt" if args.gt else "denorm_pred"

    for sec in secs:
        curr_delta = delta[:, :sec].sum(dim=1, keepdim=True)
        x_gt_pixels = gt_image[:, sec-1].to(device)
        x_gt_audios_orig = orig_gt_audio[:, sec-1].to(device)

        if args.gt:
            denorm_gt_vals = denorm_from_tensor(diffs_seq[:, :sec, :].sum(dim=1, keepdim=True))  # [B]
            visualize_preds(output_dir, idxs, sec, x_gt_pixels, x_gt_audios_orig, EVAL_AUDIO_SAMPLE_RATE)
            _record_values(episode_records, idxs, sec, denorm_gt_vals)
        else:
            diff_gt = diffs_seq[:, :sec, :].sum(dim=1, keepdim=True).to(device)
            x_pred_pixels, x_pred_audios, diff_pred = model_forward_wrapper_av(
                all_models, (obs_image, obs_audio), curr_delta, sec, args.latent_size,
                num_cond=num_cond, num_goals=1, device=device)
            x_pred_audios_orig = down_resampler(x_pred_audios)

            denorm_pred_vals = denorm_from_tensor(diff_pred)  # [B]
            denorm_gt_vals = denorm_from_tensor(diff_gt)      # [B]

            visualize_preds(output_dir, idxs, sec, x_pred_pixels, x_pred_audios_orig, EVAL_AUDIO_SAMPLE_RATE)
            visualize_compare(output_dir, idxs, sec, x_pred_pixels, x_pred_audios_orig, x_gt_pixels, x_gt_audios_orig,
                              denorm_pred_vals=denorm_pred_vals, denorm_gt_vals=denorm_gt_vals)
            _record_values(episode_records, idxs, sec, denorm_pred_vals)

    _write_distance_json(output_dir, episode_records, value_key)


def visualize_preds(output_dir, idxs, sec, x_pred_pixels, x_pred_audios, sample_rate):
    """Save ``<sec>.png`` and ``<sec>.wav`` into each sample's ``id_<sample>`` folder."""
    idxs_1d = idxs.detach().reshape(-1)
    for batch_idx, sample_idx in enumerate(idxs_1d):
        sample_idx = int(sample_idx.item())
        sample_folder = os.path.join(output_dir, f'id_{sample_idx}')
        os.makedirs(sample_folder, exist_ok=True)
        image_file = os.path.join(sample_folder, f'{sec}.png')
        save_image(image_file, x_pred_pixels[batch_idx], True)
        audio_file = os.path.join(sample_folder, f'{sec}.wav')
        save_audio(audio_file, x_pred_audios[batch_idx], sample_rate)


def _compute_binaural_spectrogram_np(audio_2ch: np.ndarray):
    """Return a log-magnitude spectrogram of a two-channel signal.

    Each of ``audio_2ch[0]`` and ``audio_2ch[1]`` is transformed with an STFT,
    mean-pooled over ``SPEC_POOL_BLOCK`` blocks and passed through ``log1p``.
    The two results are stacked on a new last axis: (F, T, 2).
    """
    def _stft_abs(signal):
        stft = np.abs(librosa.stft(signal, n_fft=SPEC_N_FFT, hop_length=SPEC_HOP_LENGTH, win_length=SPEC_WIN_LENGTH))
        return block_reduce(stft, block_size=SPEC_POOL_BLOCK, func=np.mean)

    left = np.log1p(_stft_abs(audio_2ch[0]))
    right = np.log1p(_stft_abs(audio_2ch[1]))
    return np.stack([left, right], axis=-1)  # (F,T,2)


def denorm_from_tensor(t: torch.Tensor, min_v=-20.0, max_v=20.0, scale=0.15) -> torch.Tensor:
    """Map the first element of each batch row from [-1, 1] to scaled raw units.

    The value is mapped linearly from [-1, 1] onto ``[min_v, max_v]`` and then
    multiplied by ``scale``.

    Returns:
        A float32 tensor of shape [B].
    """
    # `reshape` (not `view`) so that non-contiguous slices are accepted.
    x = t.detach().float().reshape(t.shape[0], -1)[:, 0]
    n01 = (x + 1.0) / 2.0
    raw = n01 * (max_v - min_v) + min_v
    return raw * scale


def _tensor_to_display_img(x: torch.Tensor):
    """Convert a normalized channel-first image tensor to an HWC uint8 array."""
    x = x.detach().cpu()
    x = misc.unnormalize(x)
    x = (x * 255.0).round().clamp(0, 255)
    x = x.to(torch.uint8).permute(1, 2, 0)
    return x.numpy()


def visualize_compare(output_dir, idxs, sec, x_pred_pixels, x_pred_audios_orig, x_gt_pixels, x_gt_audios_orig, denorm_pred_vals, denorm_gt_vals):
    """Save a prediction-vs-ground-truth figure ``compare_<sec>.png`` per sample.

    The figure shows the predicted and ground-truth images plus left/right
    spectrograms that share a colour range per channel. The title carries the
    denormalized predicted and ground-truth values (0 when ``None`` is passed).

    Raises:
        ValueError: If the four tensors do not share the same batch size.
    """
    idxs_np = idxs.detach().reshape(-1).cpu().numpy()
    B = x_pred_pixels.shape[0]
    if not (x_gt_pixels.shape[0] == B and x_pred_audios_orig.shape[0] == B and x_gt_audios_orig.shape[0] == B):
        raise ValueError("prediction and ground-truth tensors must share the same batch size")

    for b in range(B):
        sample_idx = int(idxs_np[b])
        sample_folder = os.path.join(output_dir, f'id_{sample_idx}')
        os.makedirs(sample_folder, exist_ok=True)
        out_path = os.path.join(sample_folder, f'compare_{sec}.png')

        pred_img = _tensor_to_display_img(x_pred_pixels[b])
        gt_img = _tensor_to_display_img(x_gt_pixels[b])

        pred_aud = x_pred_audios_orig[b].detach().cpu().float().numpy()
        gt_aud = x_gt_audios_orig[b].detach().cpu().float().numpy()
        pred_spec = _compute_binaural_spectrogram_np(pred_aud)
        gt_spec = _compute_binaural_spectrogram_np(gt_aud)

        # Shared colour limits per channel so pred and gt are directly comparable.
        vmin_L = min(pred_spec[:, :, 0].min(), gt_spec[:, :, 0].min())
        vmax_L = max(pred_spec[:, :, 0].max(), gt_spec[:, :, 0].max())
        vmin_R = min(pred_spec[:, :, 1].min(), gt_spec[:, :, 1].min())
        vmax_R = max(pred_spec[:, :, 1].max(), gt_spec[:, :, 1].max())

        dn_pred = float(denorm_pred_vals[b]) if denorm_pred_vals is not None else 0
        dn_gt = float(denorm_gt_vals[b]) if denorm_gt_vals is not None else 0

        fig, axes = plt.subplots(2, 4, figsize=(14, 6), constrained_layout=True)
        try:
            axes[0, 0].imshow(pred_img)
            axes[0, 0].set_title('pred image')
            axes[0, 0].axis('off')
            axes[0, 1].imshow(gt_img)
            axes[0, 1].set_title('gt image')
            axes[0, 1].axis('off')
            axes[1, 0].axis('off')
            axes[1, 1].axis('off')

            spec_panels = (
                (axes[0, 2], pred_spec[:, :, 0], 'pred spec (Left)', vmin_L, vmax_L),
                (axes[0, 3], gt_spec[:, :, 0], 'gt spec (Left)', vmin_L, vmax_L),
                (axes[1, 2], pred_spec[:, :, 1], 'pred spec (Right)', vmin_R, vmax_R),
                (axes[1, 3], gt_spec[:, :, 1], 'gt spec (Right)', vmin_R, vmax_R),
            )
            for ax, spec, title, vmin, vmax in spec_panels:
                ax.imshow(spec, origin='lower', aspect='auto', vmin=vmin, vmax=vmax)
                ax.set_title(title)
                ax.set_xticks([])
                ax.set_yticks([])

            fig.suptitle(
                f'id={sample_idx}, sec={sec} | denorm(reward_pred)={dn_pred:.4f}, denorm(reward_gt)={dn_gt:.4f}',
                fontsize=11
            )
            # Save this figure explicitly rather than pyplot's current figure.
            fig.savefig(out_path, dpi=180)
        finally:
            # Always release the figure, even if drawing or saving fails.
            plt.close(fig)


@torch.no_grad()
def main(args):
    """Run the evaluation described by ``args``.

    Loads the default and experiment configs, optionally loads the model,
    VAE and audio tokenizer (skipped when ``args.gt`` is set), builds one
    distributed data loader per dataset and runs ``generate_rollout`` or
    ``generate_time`` on every batch.

    Raises:
        ValueError: If ``args.eval_type`` is not one of ``EVAL_TYPES`` or
            ``args.rollout_frames`` exceeds ``eval_len_traj_pred``.
    """
    # Fail fast: an unknown eval type would otherwise iterate every dataset
    # without producing any output.
    if args.eval_type not in EVAL_TYPES:
        raise ValueError(f"--eval_type must be one of {EVAL_TYPES}, got {args.eval_type!r}")

    _, _, device, _ = init_distributed()
    print(args)
    device = torch.device(device)
    num_tasks = dist.get_world_size()
    global_rank = dist.get_rank()
    exp_eval = args.exp

    # model & config setup
    if args.gt:
        args.save_output_dir = os.path.join(args.output_dir, 'gt')
    else:
        exp_name = os.path.basename(exp_eval).split('.')[0]
        args.save_output_dir = os.path.join(args.output_dir, exp_name)

    if args.ckp != '0100000':
        args.save_output_dir = args.save_output_dir + "_%s"%(args.ckp)

    os.makedirs(args.save_output_dir, exist_ok=True)

    with open("config/eval_config.yaml", "r") as f:
        default_config = yaml.safe_load(f)
    config = default_config

    with open(exp_eval, "r") as f:
        user_config = yaml.safe_load(f)
    # An empty YAML file loads as None; treat it as "no overrides".
    config.update(user_config or {})

    eval_len_traj_pred = config["eval_len_traj_pred"]
    if args.rollout_frames == -1:
        args.rollout_frames = eval_len_traj_pred
    if args.rollout_frames > eval_len_traj_pred:
        raise ValueError(
            f"--rollout_frames ({args.rollout_frames}) exceeds eval_len_traj_pred ({eval_len_traj_pred})")

    latent_size = config['image_size'] // 8
    args.latent_size = latent_size

    num_cond = config['context_size']
    print("loading")
    model_lst = (None, None, None, None)
    if not args.gt:
        model = AVCDiT_models[config['model']](context_size=num_cond, input_size=latent_size, in_channels=4, mode="av")
        # NOTE(security): weights_only=False unpickles arbitrary objects; only
        # load checkpoints from trusted sources.
        ckp = torch.load(f'{config["results_dir"]}/{config["run_name"]}/checkpoints/{args.ckp}.pth.tar', map_location='cpu', weights_only=False)
        print(model.load_state_dict(ckp["ema"], strict=True))
        model.eval()
        model.to(device)
        model = torch.compile(model)
        diffusion = create_diffusion(str(250), dual=True)
        vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-ema").to(device)
        sstream = SoundStream(C=32, D=16, n_q=8, codebook_size=1024).to(device)
        sstream_path = config["tokenizer_a_path"]
        sstream_checkpoint = torch.load(sstream_path, map_location=device)
        sstream.load_state_dict(sstream_checkpoint["model_state"])
        sstream.eval()
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], find_unused_parameters=False)
        model_lst = (model, diffusion, vae, sstream)

    # Loading Datasets
    dataset_names = args.datasets.split(',')
    datasets = {}

    for dataset_name in dataset_names:
        dataset_val = get_dataset_eval(config, dataset_name, args.eval_type, predefined_index=False)

        if len(dataset_val) % num_tasks != 0:
            print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                  'This will slightly alter validation results as extra duplicate entries are added to achieve '
                  'equal num of samples per-process.')

        sampler_val = torch.utils.data.DistributedSampler(
            dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)

        curr_data_loader = torch.utils.data.DataLoader(
            dataset_val, sampler=sampler_val,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=False
        )
        datasets[dataset_name] = curr_data_loader

    # The list of seconds does not depend on the batch; compute it once.
    secs = None
    if args.eval_type == 'time':
        if args.time_secs != '':
            secs = np.array([int(sec) for sec in args.time_secs.split(',')])
        else:
            secs = np.arange(1, args.rollout_frames + 1)

    print_freq = 1
    header = 'Evaluation: '
    metric_logger = dist.MetricLogger(delimiter="  ")

    for dataset_name in dataset_names:
        dataset_save_output_dir = os.path.join(args.save_output_dir, dataset_name)
        os.makedirs(dataset_save_output_dir, exist_ok=True)
        curr_data_loader = datasets[dataset_name]

        if args.eval_type == 'rollout':
            eval_output_dir = os.path.join(dataset_save_output_dir, f'rollout_{args.rollout_frames}frames')
        else:
            eval_output_dir = os.path.join(dataset_save_output_dir, 'time')
        os.makedirs(eval_output_dir, exist_ok=True)

        for data_iter_step, (idxs, obs_image, gt_image, obs_audio, gt_audio, diffs_seq, delta, orig_obs_audio, orig_gt_audio) in enumerate(metric_logger.log_every(curr_data_loader, print_freq, header)):
            with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
                # Keep only the last `num_cond` observed frames as context.
                obs_image = obs_image[:, -num_cond:].to(device)
                gt_image = gt_image.to(device)
                obs_audio = obs_audio[:, -num_cond:].to(device)
                gt_audio = gt_audio.to(device)
                orig_obs_audio = orig_obs_audio[:, -num_cond:].to(device)
                orig_gt_audio = orig_gt_audio.to(device)
                diffs_seq = diffs_seq.to(device)
                obs_av = (obs_image, obs_audio, orig_obs_audio)
                gt_av = (gt_image, gt_audio, orig_gt_audio)

                if args.eval_type == 'rollout':
                    generate_rollout(args, eval_output_dir, args.rollout_frames, idxs, model_lst, obs_av, gt_av, diffs_seq, delta, num_cond, device)
                elif args.eval_type == 'time':
                    generate_time(args, eval_output_dir, idxs, model_lst, obs_av, gt_av, diffs_seq, delta, secs, num_cond, device)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # These four were optional with default None, which only deferred the
    # failure to a cryptic TypeError/AttributeError inside main().
    parser.add_argument("--output_dir", type=str, default=None, required=True, help="output directory")
    parser.add_argument("--exp", type=str, default=None, required=True, help="experiment name")
    parser.add_argument("--ckp", type=str, default='0100000')
    parser.add_argument("--num_sec_eval", type=int, default=5)
    parser.add_argument("--input_fps", type=int, default=4)
    parser.add_argument("--datasets", type=str, default=None, required=True, help="dataset name")
    parser.add_argument("--num_workers", type=int, default=8, help="num workers")
    parser.add_argument("--batch_size", type=int, default=16, help="batch size")
    parser.add_argument("--eval_type", type=str, default=None, required=True, choices=EVAL_TYPES,
                        help="type of evaluation has to be either 'time' or 'rollout'")
    # Rollout Evaluation Args
    parser.add_argument("--time_secs", type=str, default='', help="")  # e.g. '1,2,3,4'
    parser.add_argument("--rollout_frames", type=int, default=-1, help="")
    parser.add_argument("--gt", type=int, default=0, help="set to 1 to produce ground truth evaluation set")
    args = parser.parse_args()

    main(args)