| |
| |
| |
| |
| |
| from distributed import init_distributed |
| import torch |
| torch.backends.cuda.matmul.allow_tf32 = True |
| torch.backends.cudnn.allow_tf32 = True |
|
|
| import yaml |
| import argparse |
| import os |
| import numpy as np |
|
|
| from diffusion import create_diffusion |
| from diffusers.models import AutoencoderKL |
|
|
| import misc |
| import distributed as dist |
| from models import AVCDiT_models |
| from datasets import EvalDataset |
| from PIL import Image |
| from soundstream import SoundStream |
| import torchaudio |
| from skimage.measure import block_reduce |
|
|
| import matplotlib.pyplot as plt |
| import librosa |
| import time |
| import warnings |
| warnings.filterwarnings("ignore", category=UserWarning) |
| from collections import defaultdict |
| import json |
|
|
def save_image(output_file, img, unnormalize_img):
    """Save one channels-first image tensor to ``output_file`` as 8-bit RGB.

    Args:
        output_file: Destination path; PIL picks the format from the extension.
        img: Image tensor laid out channels-first (it is permuted ``(1, 2, 0)``
            before saving), expected to have 3 channels. It may live on any
            device and have any floating dtype (e.g. bfloat16 under autocast).
        unnormalize_img: If true, ``misc.unnormalize`` is applied first. The
            result is then expected to lie in ``[0, 1]``.
    """
    # Work in float32: bfloat16 carries only 8 significant bits, which is too
    # coarse for accurate 0..255 scaling.
    img = img.detach().cpu().float()
    if unnormalize_img:
        img = misc.unnormalize(img)

    # Round (rather than truncate) and clamp before the uint8 cast. Values that
    # drift slightly outside [0, 1] then saturate instead of wrapping (256 -> 0).
    img = (img * 255.0).round().clamp(0, 255).to(torch.uint8)
    image = Image.fromarray(img.permute(1, 2, 0).numpy(), mode='RGB')
    image.save(output_file)
| |
def save_audio(output_file, audio_tensor, sample_rate):
    """Write ``audio_tensor`` to ``output_file`` using ``torchaudio.save``.

    The tensor is detached, moved to the CPU and cast to float32. A 1-D input
    is treated as a single channel and receives a leading channel axis. Inputs
    with more dimensions are passed through unchanged.
    """
    waveform = audio_tensor.detach().cpu()
    if waveform.ndim == 1:
        waveform = waveform[None]
    torchaudio.save(output_file, waveform.float(), sample_rate)
| |
def get_dataset_eval(config, dataset_name, eval_type, predefined_index=True):
    """Construct the ``EvalDataset`` used to evaluate ``dataset_name``.

    Args:
        config: Merged evaluation config. It reads:
            - ``eval_datasets[dataset_name]`` (``data_folder``, ``test``)
            - ``image_size``
            - ``eval_distance`` (``eval_min_dist_cat``, ``eval_max_dist_cat``)
            - ``eval_len_traj_pred``, ``traj_stride``, ``eval_context_size``
              and ``normalize``.
        dataset_name: Key into ``config["eval_datasets"]``.
        eval_type: Names the pre-built index file when ``predefined_index`` is
            truthy.
        predefined_index: If truthy, the dataset is given the index path
            ``data_splits/<dataset_name>/test/<eval_type>.pkl``. Otherwise it
            is given ``None``.

    Returns:
        The constructed ``EvalDataset``.
    """
    split_cfg = config["eval_datasets"][dataset_name]
    index_file = f"data_splits/{dataset_name}/test/{eval_type}.pkl" if predefined_index else None

    return EvalDataset(
        data_folder=split_cfg["data_folder"],
        data_split_folder=split_cfg["test"],
        dataset_name=dataset_name,
        image_size=config["image_size"],
        min_dist_cat=config["eval_distance"]["eval_min_dist_cat"],
        max_dist_cat=config["eval_distance"]["eval_max_dist_cat"],
        len_traj_pred=config["eval_len_traj_pred"],
        traj_stride=config["traj_stride"],
        context_size=config["eval_context_size"],
        normalize=config["normalize"],
        transform=misc.transform,
        goals_per_obs=4,
        predefined_index=index_file,
        traj_names='traj_names.txt',
    )
|
|
|
|
| @torch.no_grad() |
| def model_forward_wrapper_v(all_models, curr_obs, curr_delta, num_timesteps, latent_size, device, num_cond, num_goals=1, rel_t=None, progress=False): |
| model, diffusion, vae = all_models |
| x = curr_obs.to(device) |
| y = curr_delta.to(device) |
|
|
| with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16): |
| B, T = x.shape[:2] |
|
|
| if rel_t is None: |
| rel_t = (torch.ones(B)* (1. / 128.)).to(device) |
| rel_t *= num_timesteps |
|
|
| x = x.flatten(0,1) |
| x = vae.encode(x).latent_dist.sample().mul_(0.18215).unflatten(0, (B, T)) |
| x_cond = x[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x.shape[2], x.shape[3], x.shape[4]).flatten(0, 1) |
| z = torch.randn(B*num_goals, 4, latent_size, latent_size, device=device) |
| y = y.flatten(0, 1) |
| model_kwargs = dict(y=y, x_cond=x_cond, rel_t=rel_t) |
| samples = diffusion.p_sample_loop( |
| model.forward, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=progress, device=device |
| ) |
| samples = vae.decode(samples / 0.18215).sample |
|
|
| return torch.clip(samples, -1., 1.) |
|
|
|
|
| @torch.no_grad() |
| def model_forward_wrapper_a(all_models, curr_obs, curr_delta, num_timesteps, latent_size, device, num_cond, num_goals=1, rel_t=None, progress=False): |
| model, diffusion, sstream = all_models |
| x = curr_obs.to(device) |
| y = curr_delta.to(device) |
| with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16): |
| B, T = x.shape[:2] |
| if rel_t is None: |
| rel_t = (torch.ones(B)* (1. / 128.)).to(device) |
| rel_t *= num_timesteps |
| x = x.flatten(0,1) |
| x = sstream.encoder(x).unflatten(0, (B, T)) |
| x_cond = x[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x.shape[2], x.shape[3]).flatten(0, 1) |
| z = torch.randn(B*num_goals, 16, 181, device=device) |
| y = y.flatten(0, 1) |
| model_kwargs = dict(y=y, x_cond=x_cond, rel_t=rel_t) |
| samples = diffusion.p_sample_loop( |
| model.forward, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=progress, device=device |
| ) |
| |
| patch_tok = samples[..., -1:] |
| diff_pred = patch_tok.mean(dim=1, keepdim=True) |
| samples = samples[..., :-1] |
| |
| quantized, _, _ = sstream.quantizer(samples.permute(0, 2, 1)) |
| samples = sstream.decoder(quantized.permute(0, 2, 1)) |
| return samples, diff_pred |
| |
|
|
| @torch.no_grad() |
| def model_forward_wrapper_av(all_models, curr_obs, curr_delta, num_timesteps, latent_size, device, num_cond, num_goals=1, rel_t=None, progress=False): |
| model, diffusion, vae, sstream = all_models |
| x_v, x_a = curr_obs |
| x_v = x_v.to(device) |
| x_a = x_a.to(device) |
| y = curr_delta.to(device) |
| with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16): |
| B, T_v = x_v.shape[:2] |
| B, T_a = x_a.shape[:2] |
|
|
| if rel_t is None: |
| rel_t = (torch.ones(B)* (1. / 128.)).to(device) |
| rel_t *= num_timesteps |
| x_v = x_v.flatten(0,1) |
| x_a = x_a.flatten(0,1) |
| x_v = vae.encode(x_v).latent_dist.sample().mul_(0.18215).unflatten(0, (B, T_v)) |
| x_a = sstream.encoder(x_a).unflatten(0, (B, T_a)) |
| x_v_cond = x_v[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x_v.shape[2], x_v.shape[3], x_v.shape[4]).flatten(0, 1) |
| x_a_cond = x_a[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x_a.shape[2], x_a.shape[3]).flatten(0, 1) |
| z_v = torch.randn(B*num_goals, 4, latent_size, latent_size, device=device) |
| z_a = torch.randn(B*num_goals, 16, 181, device=device) |
| y = y.flatten(0, 1) |
| model_kwargs = dict(y=y, x_v_cond=x_v_cond, x_a_cond=x_a_cond, rel_t=rel_t) |
| samples_v, samples_a = diffusion.p_sample_loop( |
| model.forward, z_v.shape, z_a.shape, z_v, z_a, clip_denoised=False, model_kwargs=model_kwargs, progress=progress, device=device |
| ) |
| patch_tok = samples_a[..., -1:] |
| diff_pred = patch_tok.mean(dim=1, keepdim=True) |
| samples_a = samples_a[..., :-1] |
| samples_v = vae.decode(samples_v / 0.18215).sample |
| quantized, _, _ = sstream.quantizer(samples_a.permute(0, 2, 1)) |
| samples_a = sstream.decoder(quantized.permute(0, 2, 1)) |
| return torch.clip(samples_v, -1., 1.), samples_a, diff_pred |
|
|
|
|
def generate_rollout(args, output_dir, rollout_frames, idxs, all_models, obs_av, gt_av, diffs_seq, delta, num_cond, device):
    """Autoregressively roll the AV model forward one step at a time and save outputs.

    At step ``i`` (labelled second ``i + 1``) the model is conditioned on the
    current context window with ``num_timesteps=1``. Its predicted frame and
    audio are appended to the window and the oldest entry is dropped. With
    ``args.gt`` set, no model is called and the ground-truth frame/audio for
    each step are saved instead.

    For every sample, ``<output_dir>/id_<idx>/distance.json`` receives one
    ``{"sec": ..., "denorm_pred" | "denorm_gt": ...}`` entry per step. The
    value is ``denorm_from_tensor`` of the predicted (or GT) ``diffs_seq``
    step.

    Args:
        args: Parsed CLI namespace. Reads ``args.gt`` and ``args.latent_size``.
        output_dir: Root folder for per-sample outputs.
        rollout_frames: Maximum number of steps to roll out.
        idxs: Sample indices of the batch (flattened to 1-D).
        all_models: ``(model, diffusion, vae, sstream)`` for
            ``model_forward_wrapper_av``.
        obs_av: ``(obs_image, obs_audio, orig_obs_audio)`` context tensors.
            Only the first two are used.
        gt_av: ``(gt_image, gt_audio, orig_gt_audio)`` future tensors, with
            dim 1 indexing the step. ``gt_audio`` is not used.
        diffs_seq: Per-step targets, indexed as ``diffs_seq[:, i:i+1, :]``.
        delta: Per-step conditioning, indexed as ``delta[:, i:i+1]``.
        num_cond: Number of context steps.
        device: Device for model inputs.
    """
    obs_image, obs_audio, _ = obs_av
    gt_image, _, orig_gt_audio = gt_av

    gt_image = gt_image[:, :rollout_frames]
    curr_v = obs_image.to(device)
    curr_a = obs_audio.to(device)
    # Predicted audio is resampled 48 kHz -> 16 kHz before saving/plotting;
    # the GT `orig_*` audio is saved at 16 kHz as-is.
    down_resampler = torchaudio.transforms.Resample(orig_freq=48000, new_freq=16000, lowpass_filter_width=64).to(device, dtype=torch.bfloat16)
    episode_records = defaultdict(list)
    value_key = "denorm_gt" if args.gt else "denorm_pred"
    # The sample ids do not change between steps, so build the host copy once.
    idxs_1d = idxs.detach().view(-1).cpu().numpy()

    for i in range(gt_image.shape[1]):
        sec = i + 1
        x_gt_pixels = gt_image[:, i].to(device)
        x_gt_audios_orig = orig_gt_audio[:, i].to(device)

        if args.gt:
            visualize_preds(output_dir, idxs, sec, x_gt_pixels, x_gt_audios_orig, 16000)
            step_vals = denorm_from_tensor(diffs_seq[:, i:i+1, :])
        else:
            curr_delta = delta[:, i:i+1].to(device)
            diff_gt = diffs_seq[:, i:i+1, :].unsqueeze(1).to(device)
            x_pred_pixels, x_pred_audios, diff_pred = model_forward_wrapper_av(
                all_models, (curr_v, curr_a), curr_delta, num_timesteps=1,
                latent_size=args.latent_size, device=device, num_cond=num_cond, num_goals=1)
            x_pred_audios_orig = down_resampler(x_pred_audios)

            # Slide the context window: append the prediction, drop the oldest step.
            curr_v = torch.cat((curr_v, x_pred_pixels.unsqueeze(1)), dim=1)[:, 1:]
            curr_a = torch.cat((curr_a, x_pred_audios.unsqueeze(1)), dim=1)[:, 1:]

            step_vals = denorm_from_tensor(diff_pred)
            denorm_gt_vals = denorm_from_tensor(diff_gt)
            visualize_preds(output_dir, idxs, sec, x_pred_pixels, x_pred_audios_orig, 16000)
            visualize_compare(output_dir, idxs, sec,
                              x_pred_pixels, x_pred_audios_orig,
                              x_gt_pixels, x_gt_audios_orig,
                              denorm_pred_vals=step_vals,
                              denorm_gt_vals=denorm_gt_vals)

        for b, sample_idx in enumerate(idxs_1d):
            episode_records[int(sample_idx)].append({"sec": int(sec), "value": float(step_vals[b])})

    for sample_idx, rows in episode_records.items():
        rows = sorted(rows, key=lambda r: r["sec"])
        sample_folder = os.path.join(output_dir, f"id_{sample_idx}")
        os.makedirs(sample_folder, exist_ok=True)
        out_json = os.path.join(sample_folder, "distance.json")
        compact = [{"sec": r["sec"], value_key: r["value"]} for r in rows]
        with open(out_json, "w") as f:
            json.dump(compact, f, indent=2)
|
|
|
|
def generate_time(args, output_dir, idxs, all_models, obs_av, gt_av, diffs_seq, delta, secs, num_cond, device):
    """For each ``sec`` in ``secs``, predict ``sec`` steps ahead in one jump from the same context.

    The conditioning for a ``sec``-step jump is ``delta[:, :sec]`` summed over
    dim 1. The matching target is ``diffs_seq[:, :sec, :]`` summed over dim 1.
    The model is called with ``num_timesteps=sec``. With ``args.gt`` set, no
    model is called and the ground-truth frame/audio at step ``sec`` are saved
    instead.

    For every sample, ``<output_dir>/id_<idx>/distance.json`` receives one
    ``{"sec": ..., "denorm_pred" | "denorm_gt": ...}`` entry per ``sec``.

    Args:
        args: Parsed CLI namespace. Reads ``args.gt`` and ``args.latent_size``.
        output_dir: Root folder for per-sample outputs.
        idxs: Sample indices of the batch (flattened to 1-D).
        all_models: ``(model, diffusion, vae, sstream)`` for
            ``model_forward_wrapper_av``.
        obs_av: ``(obs_image, obs_audio, orig_obs_audio)``. Only the first two
            are used.
        gt_av: ``(gt_image, gt_audio, orig_gt_audio)``. ``gt_audio`` is not
            used, and dim 1 indexes the step.
        diffs_seq: Per-step targets.
        delta: Per-step conditioning.
        secs: Iterable of 1-based horizons (each must be ``<= gt_image.shape[1]``).
        num_cond: Number of context steps.
        device: Device for model inputs.
    """
    obs_image, obs_audio, _ = obs_av
    gt_image, _, orig_gt_audio = gt_av
    # Predicted audio is resampled 48 kHz -> 16 kHz before saving/plotting.
    down_resampler = torchaudio.transforms.Resample(orig_freq=48000, new_freq=16000, lowpass_filter_width=64).to(device, dtype=torch.bfloat16)
    episode_records = defaultdict(list)
    value_key = "denorm_gt" if args.gt else "denorm_pred"
    # The sample ids do not change between horizons, so build the host copy once.
    idxs_1d = idxs.detach().view(-1).cpu().numpy()

    for sec in secs:
        x_gt_pixels = gt_image[:, sec-1].to(device)
        x_gt_audios_orig = orig_gt_audio[:, sec-1].to(device)

        if args.gt:
            step_vals = denorm_from_tensor(diffs_seq[:, :sec, :].sum(dim=1, keepdim=True))
            visualize_preds(output_dir, idxs, sec, x_gt_pixels, x_gt_audios_orig, 16000)
        else:
            curr_delta = delta[:, :sec].sum(dim=1, keepdim=True)
            diff_gt = diffs_seq[:, :sec, :].sum(dim=1, keepdim=True).to(device)
            x_pred_pixels, x_pred_audios, diff_pred = model_forward_wrapper_av(
                all_models, (obs_image, obs_audio), curr_delta, sec, args.latent_size,
                num_cond=num_cond, num_goals=1, device=device)
            x_pred_audios_orig = down_resampler(x_pred_audios)

            step_vals = denorm_from_tensor(diff_pred)
            denorm_gt_vals = denorm_from_tensor(diff_gt)
            visualize_preds(output_dir, idxs, sec, x_pred_pixels, x_pred_audios_orig, 16000)
            visualize_compare(output_dir, idxs, sec,
                              x_pred_pixels, x_pred_audios_orig,
                              x_gt_pixels, x_gt_audios_orig,
                              denorm_pred_vals=step_vals,
                              denorm_gt_vals=denorm_gt_vals)

        for b, sample_idx in enumerate(idxs_1d):
            episode_records[int(sample_idx)].append({"sec": int(sec), "value": float(step_vals[b])})

    for sample_idx, rows in episode_records.items():
        rows = sorted(rows, key=lambda r: r["sec"])
        sample_folder = os.path.join(output_dir, f"id_{sample_idx}")
        os.makedirs(sample_folder, exist_ok=True)
        out_json = os.path.join(sample_folder, "distance.json")
        compact = [{"sec": r["sec"], value_key: r["value"]} for r in rows]
        with open(out_json, "w") as f:
            json.dump(compact, f, indent=2)
|
|
|
|
def visualize_preds(output_dir, idxs, sec, x_pred_pixels, x_pred_audios, sample_rate):
    """Write ``<sec>.png`` and ``<sec>.wav`` into ``<output_dir>/id_<idx>`` for each sample.

    Row ``i`` of ``x_pred_pixels`` / ``x_pred_audios`` belongs to the ``i``-th
    entry of the flattened ``idxs``. Images are unnormalised before saving.
    Audio is written at ``sample_rate``.
    """
    for row, idx in enumerate(idxs.detach().view(-1)):
        folder = os.path.join(output_dir, f'id_{int(idx.item())}')
        os.makedirs(folder, exist_ok=True)
        save_image(os.path.join(folder, f'{sec}.png'), x_pred_pixels[row], True)
        save_audio(os.path.join(folder, f'{sec}.wav'), x_pred_audios[row], sample_rate)
|
|
def _compute_binaural_spectrogram_np(audio_2ch: np.ndarray):
    """Return log1p STFT magnitudes of channels 0 and 1, stacked on a trailing axis.

    Each channel is processed the same way:
    - STFT with ``n_fft=512``, ``hop_length=160`` and ``win_length=400``;
    - take the magnitude;
    - mean-pool over 4x4 blocks with ``block_reduce``;
    - apply ``log1p``.

    The result is ``np.stack([ch0, ch1], axis=-1)``.
    """
    def _pooled_magnitude(channel):
        magnitude = np.abs(librosa.stft(channel, n_fft=512, hop_length=160, win_length=400))
        return block_reduce(magnitude, block_size=(4, 4), func=np.mean)

    per_channel = [np.log1p(_pooled_magnitude(audio_2ch[ch])) for ch in (0, 1)]
    return np.stack(per_channel, axis=-1)
|
|
def denorm_from_tensor(t: torch.Tensor, min_v=-20.0, max_v=20.0, scale=0.15) -> torch.Tensor:
    """Map the first value of each batch row from ``[-1, 1]`` back to ``[min_v, max_v] * scale``.

    Each row of ``t`` (dim 0 is the batch) is flattened and only its first
    element ``x`` is used. The result is
    ``((x + 1) / 2 * (max_v - min_v) + min_v) * scale``.

    Args:
        t: Tensor with the batch on dim 0. Any trailing shape and any floating
            dtype are accepted, and it need not be contiguous.
        min_v: Value that ``x == -1`` maps to (before ``scale``).
        max_v: Value that ``x == 1`` maps to (before ``scale``).
        scale: Final multiplicative factor.

    Returns:
        A float32 tensor of shape ``(t.shape[0],)`` on ``t``'s device.
    """
    # reshape rather than view: view() raises for non-contiguous inputs whose
    # dims cannot be merged (e.g. transposed tensors), while reshape copies
    # only when it has to.
    x = t.detach().float().reshape(t.shape[0], -1)[:, 0]
    n01 = (x + 1.0) / 2.0
    raw = n01 * (max_v - min_v) + min_v
    return raw * scale
|
|
def visualize_compare(output_dir, idxs, sec,
                      x_pred_pixels, x_pred_audios_orig,
                      x_gt_pixels, x_gt_audios_orig,
                      denorm_pred_vals,
                      denorm_gt_vals):
    """Save a 2x4 prediction-vs-ground-truth figure for every sample in the batch.

    Each figure is written to ``<output_dir>/id_<idx>/compare_<sec>.png``.
    - Top-left: the predicted and GT images.
    - Right half: channel-0 ("Left") and channel-1 ("Right") spectrograms of
      the predicted and GT audio, from ``_compute_binaural_spectrogram_np``.
      Each channel's colour range is shared between pred and GT so the two are
      directly comparable.
    - Title: the denormalised predicted and GT values, shown as 0 when the
      corresponding argument is ``None``.

    Args:
        output_dir: Root folder for per-sample outputs.
        idxs: Sample indices of the batch (flattened to 1-D).
        sec: Step label used in the file name and title.
        x_pred_pixels, x_gt_pixels: Channels-first image batches in the
            normalised space that ``misc.unnormalize`` undoes.
        x_pred_audios_orig, x_gt_audios_orig: Audio batches. Each item needs
            at least two channels on dim 0.
        denorm_pred_vals, denorm_gt_vals: Per-sample scalars, or ``None``.
    """
    def _tensor_to_display_img(x: torch.Tensor):
        # Scale in float32: inputs may be bfloat16 (autocast), which is too
        # coarse for accurate 0..255 quantisation.
        x = misc.unnormalize(x.detach().cpu().float())
        x = (x * 255.0).round().clamp(0, 255)
        return x.to(torch.uint8).permute(1, 2, 0).numpy()

    idxs_np = idxs.detach().view(-1).cpu().numpy()
    B = x_pred_pixels.shape[0]
    assert x_gt_pixels.shape[0] == B and x_pred_audios_orig.shape[0] == B and x_gt_audios_orig.shape[0] == B

    for b in range(B):
        sample_idx = int(idxs_np[b])
        sample_folder = os.path.join(output_dir, f'id_{sample_idx}')
        os.makedirs(sample_folder, exist_ok=True)
        out_path = os.path.join(sample_folder, f'compare_{sec}.png')

        pred_img = _tensor_to_display_img(x_pred_pixels[b])
        gt_img = _tensor_to_display_img(x_gt_pixels[b])

        pred_spec = _compute_binaural_spectrogram_np(x_pred_audios_orig[b].detach().cpu().float().numpy())
        gt_spec = _compute_binaural_spectrogram_np(x_gt_audios_orig[b].detach().cpu().float().numpy())

        # Share each channel's colour range between pred and GT.
        vmin_L = min(pred_spec[:, :, 0].min(), gt_spec[:, :, 0].min())
        vmax_L = max(pred_spec[:, :, 0].max(), gt_spec[:, :, 0].max())
        vmin_R = min(pred_spec[:, :, 1].min(), gt_spec[:, :, 1].min())
        vmax_R = max(pred_spec[:, :, 1].max(), gt_spec[:, :, 1].max())

        dn_pred = float(denorm_pred_vals[b]) if denorm_pred_vals is not None else 0
        dn_gt = float(denorm_gt_vals[b]) if denorm_gt_vals is not None else 0

        fig, axes = plt.subplots(2, 4, figsize=(14, 6), constrained_layout=True)
        try:
            axes[0, 0].imshow(pred_img); axes[0, 0].set_title('pred image'); axes[0, 0].axis('off')
            axes[0, 1].imshow(gt_img); axes[0, 1].set_title('gt image'); axes[0, 1].axis('off')
            axes[1, 0].axis('off')
            axes[1, 1].axis('off')

            axes[0, 2].imshow(pred_spec[:, :, 0], origin='lower', aspect='auto', vmin=vmin_L, vmax=vmax_L)
            axes[0, 2].set_title('pred spec (Left)'); axes[0, 2].set_xticks([]); axes[0, 2].set_yticks([])
            axes[0, 3].imshow(gt_spec[:, :, 0], origin='lower', aspect='auto', vmin=vmin_L, vmax=vmax_L)
            axes[0, 3].set_title('gt spec (Left)'); axes[0, 3].set_xticks([]); axes[0, 3].set_yticks([])
            axes[1, 2].imshow(pred_spec[:, :, 1], origin='lower', aspect='auto', vmin=vmin_R, vmax=vmax_R)
            axes[1, 2].set_title('pred spec (Right)'); axes[1, 2].set_xticks([]); axes[1, 2].set_yticks([])
            axes[1, 3].imshow(gt_spec[:, :, 1], origin='lower', aspect='auto', vmin=vmin_R, vmax=vmax_R)
            axes[1, 3].set_title('gt spec (Right)'); axes[1, 3].set_xticks([]); axes[1, 3].set_yticks([])

            fig.suptitle(
                f'id={sample_idx}, sec={sec} | denorm(reward_pred)={dn_pred:.4f}, denorm(reward_gt)={dn_gt:.4f}',
                fontsize=11
            )
            # Save this specific figure rather than pyplot's implicit "current" one.
            fig.savefig(out_path, dpi=180)
        finally:
            # Always release the figure, even if drawing or saving fails.
            plt.close(fig)
|
|
|
|
@torch.no_grad()
def main(args):
    """Run the audio-visual evaluation described by the parsed CLI ``args``.

    Steps:
    1. Validate ``args.eval_type``.
    2. Initialise distributed state and resolve the output directory.
    3. Merge ``config/eval_config.yaml`` with the experiment YAML
       (``args.exp``); keys from the experiment file win.
    4. Unless ``args.gt`` is set, load the AV model, the SD VAE and the
       SoundStream tokenizer.
    5. Build one ``DistributedSampler``-backed loader per dataset in
       ``args.datasets``.
    6. Run ``generate_rollout`` or ``generate_time`` on every batch.

    Raises:
        ValueError: if ``args.eval_type`` is neither ``'rollout'`` nor
            ``'time'``, or if ``args.time_secs`` is not a comma-separated list
            of integers.
    """
    # Fail fast: any other value used to load every model and walk the whole
    # dataset without writing a single output.
    if args.eval_type not in ('rollout', 'time'):
        raise ValueError(
            f"eval_type has to be either 'time' or 'rollout', got {args.eval_type!r}")

    _, _, device, _ = init_distributed()
    print(args)
    device = torch.device(device)
    num_tasks = dist.get_world_size()
    global_rank = dist.get_rank()
    exp_eval = args.exp

    if args.gt:
        args.save_output_dir = os.path.join(args.output_dir, 'gt')
    else:
        exp_name = os.path.basename(exp_eval).split('.')[0]
        args.save_output_dir = os.path.join(args.output_dir, exp_name)
        # Outputs of a non-default checkpoint go to a suffixed folder.
        if args.ckp != '0100000':
            args.save_output_dir = args.save_output_dir + "_%s" % (args.ckp)
    os.makedirs(args.save_output_dir, exist_ok=True)

    # Experiment-specific keys override the shared evaluation defaults.
    with open("config/eval_config.yaml", "r") as f:
        config = yaml.safe_load(f)
    with open(exp_eval, "r") as f:
        config.update(yaml.safe_load(f))

    eval_len_traj_pred = config["eval_len_traj_pred"]
    if args.rollout_frames == -1:
        args.rollout_frames = eval_len_traj_pred
    assert args.rollout_frames <= eval_len_traj_pred
    latent_size = config['image_size'] // 8
    args.latent_size = latent_size

    # Horizons for 'time' evaluation are the same for every batch; parse once.
    secs = None
    if args.eval_type == 'time':
        if args.time_secs != '':
            secs = np.array([int(sec) for sec in args.time_secs.split(',')])
        else:
            secs = np.array([int(sec) for sec in range(1, args.rollout_frames + 1)])

    num_cond = config['context_size']
    print("loading")
    model_lst = (None, None, None, None)
    if not args.gt:
        model = AVCDiT_models[config['model']](context_size=num_cond, input_size=latent_size, in_channels=4, mode="av")
        # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
        ckp = torch.load(f'{config["results_dir"]}/{config["run_name"]}/checkpoints/{args.ckp}.pth.tar', map_location='cpu', weights_only=False)
        print(model.load_state_dict(ckp["ema"], strict=True))
        model.eval()
        model.to(device)
        model = torch.compile(model)
        diffusion = create_diffusion(str(250), dual=True)
        vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(device)

        sstream = SoundStream(C=32, D=16, n_q=8, codebook_size=1024).to(device)
        sstream_path = config["tokenizer_a_path"]
        sstream_checkpoint = torch.load(sstream_path, map_location=device)
        sstream.load_state_dict(sstream_checkpoint["model_state"])
        sstream.eval()

        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], find_unused_parameters=False)
        model_lst = (model, diffusion, vae, sstream)

    dataset_names = args.datasets.split(',')
    datasets = {}
    for dataset_name in dataset_names:
        dataset_val = get_dataset_eval(config, dataset_name, args.eval_type, predefined_index=False)
        if len(dataset_val) % num_tasks != 0:
            print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                  'This will slightly alter validation results as extra duplicate entries are added to achieve '
                  'equal num of samples per-process.')
        sampler_val = torch.utils.data.DistributedSampler(
            dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
        datasets[dataset_name] = torch.utils.data.DataLoader(
            dataset_val, sampler=sampler_val,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=False
        )

    print_freq = 1
    header = 'Evaluation: '
    metric_logger = dist.MetricLogger(delimiter=" ")

    for dataset_name in dataset_names:
        dataset_save_output_dir = os.path.join(args.save_output_dir, dataset_name)
        os.makedirs(dataset_save_output_dir, exist_ok=True)
        curr_data_loader = datasets[dataset_name]

        for (idxs, obs_image, gt_image, obs_audio, gt_audio, diffs_seq, delta, orig_obs_audio, orig_gt_audio) in metric_logger.log_every(curr_data_loader, print_freq, header):
            with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
                # Keep only the last `num_cond` context steps of each observation.
                obs_image = obs_image[:, -num_cond:].to(device)
                gt_image = gt_image.to(device)
                obs_audio = obs_audio[:, -num_cond:].to(device)
                gt_audio = gt_audio.to(device)
                orig_obs_audio = orig_obs_audio[:, -num_cond:].to(device)
                orig_gt_audio = orig_gt_audio.to(device)
                diffs_seq = diffs_seq.to(device)
                obs_av = (obs_image, obs_audio, orig_obs_audio)
                gt_av = (gt_image, gt_audio, orig_gt_audio)

                if args.eval_type == 'rollout':
                    curr_rollout_output_dir = os.path.join(dataset_save_output_dir, f'rollout_{args.rollout_frames}frames')
                    os.makedirs(curr_rollout_output_dir, exist_ok=True)
                    generate_rollout(args, curr_rollout_output_dir, args.rollout_frames, idxs, model_lst, obs_av, gt_av, diffs_seq, delta, num_cond, device)
                elif args.eval_type == 'time':
                    curr_time_output_dir = os.path.join(dataset_save_output_dir, 'time')
                    os.makedirs(curr_time_output_dir, exist_ok=True)
                    generate_time(args, curr_time_output_dir, idxs, model_lst, obs_av, gt_av, diffs_seq, delta, secs, num_cond, device)
| |
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--output_dir", type=str, default=None, help="output directory")
    parser.add_argument("--exp", type=str, default=None, help="experiment name")
    parser.add_argument("--ckp", type=str, default='0100000')
    parser.add_argument("--num_sec_eval", type=int, default=5)
    parser.add_argument("--input_fps", type=int, default=4)
    parser.add_argument("--datasets", type=str, default=None, help="dataset name")
    parser.add_argument("--num_workers", type=int, default=8, help="num workers")
    parser.add_argument("--batch_size", type=int, default=16, help="batch size")
    # Enforce the documented values at parse time instead of failing later.
    parser.add_argument("--eval_type", type=str, default=None, choices=["time", "rollout"],
                        help="type of evaluation has to be either 'time' or 'rollout'")

    parser.add_argument("--time_secs", type=str, default='', help="")
    parser.add_argument("--rollout_frames", type=int, default=-1, help="")
    parser.add_argument("--gt", type=int, default=0, help="set to 1 to produce ground truth evaluation set")
    args = parser.parse_args()

    main(args)