Upload folder using huggingface_hub
Browse files- config/data_config.yaml +10 -0
- config/eval_config.yaml +12 -0
- config/train_config_stage1.yaml +22 -0
- config/train_config_stage2.yaml +25 -0
- config/train_config_stage3.yaml +25 -0
- datasets.py +307 -0
- diffusion/__init__.py +53 -0
- diffusion/diffusion_utils.py +83 -0
- diffusion/gaussian_diffusion.py +870 -0
- diffusion/gaussian_diffusion_dual.py +975 -0
- diffusion/respace.py +125 -0
- diffusion/respace_dual.py +135 -0
- diffusion/timestep_sampler.py +145 -0
- distributed.py +277 -0
- eval_audio.py +210 -0
- eval_metrics.py +1033 -0
- inference_avwm.py +498 -0
- mel_scale.py +221 -0
- merge_experts.py +128 -0
- misc.py +232 -0
- models.py +482 -0
- soundstream.py +178 -0
- train_avwm_stage1.py +463 -0
- train_avwm_stage2.py +514 -0
- train_avwm_stage3.py +532 -0
config/data_config.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
action_stats:
|
| 2 |
+
min: [-2.5, -4] # [min_dx, min_dy]
|
| 3 |
+
max: [5, 4] # [max_dx, max_dy]
|
| 4 |
+
|
| 5 |
+
distance_diff_stats:
|
| 6 |
+
min: [-20] # [min]
|
| 7 |
+
max: [20] # [max]
|
| 8 |
+
|
| 9 |
+
avw_4k:
|
| 10 |
+
metric_waypoint_spacing: 0.15
|
config/eval_config.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
eval_distance:
|
| 2 |
+
eval_min_dist_cat: -16
|
| 3 |
+
eval_max_dist_cat: 16
|
| 4 |
+
eval_len_traj_pred: 16
|
| 5 |
+
eval_context_size: 4
|
| 6 |
+
traj_stride: 8
|
| 7 |
+
|
| 8 |
+
eval_datasets:
|
| 9 |
+
avw_4k:
|
| 10 |
+
data_folder: /path/to/dataset/avw_4k
|
| 11 |
+
test: /path/to/data_splits/avw_4k/test
|
| 12 |
+
goals_per_obs: 4
|
config/train_config_stage1.yaml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
batch_size: 16
|
| 2 |
+
context_size: 4
|
| 3 |
+
datasets:
|
| 4 |
+
avw_4k:
|
| 5 |
+
data_folder: /path/to/dataset/avw_4k
|
| 6 |
+
goals_per_obs: 4
|
| 7 |
+
test: /path/to/data_splits/avw_4k/val
|
| 8 |
+
train: /path/to/data_splits/avw_4k/train
|
| 9 |
+
distance:
|
| 10 |
+
max_dist_cat: 16
|
| 11 |
+
min_dist_cat: -16
|
| 12 |
+
from_checkpoint: /path/to/pretrained/cdit_b_100000.pth.tar
|
| 13 |
+
grad_clip_val: 10.0
|
| 14 |
+
image_size: 224
|
| 15 |
+
len_traj_pred: 16
|
| 16 |
+
lr: 16.0e-05
|
| 17 |
+
model: AVCDiT-B/2
|
| 18 |
+
normalize: true
|
| 19 |
+
num_workers: 1
|
| 20 |
+
results_dir: logs
|
| 21 |
+
run_name: training_stage1
|
| 22 |
+
train: true
|
config/train_config_stage2.yaml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
batch_size: 24
|
| 2 |
+
context_size: 4
|
| 3 |
+
datasets:
|
| 4 |
+
avw_4k:
|
| 5 |
+
data_folder: /path/to/dataset/avw_4k
|
| 6 |
+
goals_per_obs: 4
|
| 7 |
+
test: /path/to/data_splits/avw_4k/val
|
| 8 |
+
train: /path/to/data_splits/avw_4k/train
|
| 9 |
+
distance:
|
| 10 |
+
max_dist_cat: 16
|
| 11 |
+
min_dist_cat: -16
|
| 12 |
+
from_checkpoint: logs/training_stage1/checkpoints/latest.pth.tar
|
| 13 |
+
sample_rate: 16000
|
| 14 |
+
input_sr: 48000
|
| 15 |
+
tokenizer_a_path: /path/to/pretrained/soundstream.pt
|
| 16 |
+
grad_clip_val: 10.0
|
| 17 |
+
image_size: 224
|
| 18 |
+
len_traj_pred: 16
|
| 19 |
+
lr: 8.0e-4
|
| 20 |
+
model: AVCDiT-B/2
|
| 21 |
+
normalize: true
|
| 22 |
+
num_workers: 12
|
| 23 |
+
results_dir: logs
|
| 24 |
+
run_name: training_stage2
|
| 25 |
+
train: true
|
config/train_config_stage3.yaml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
batch_size: 4
|
| 2 |
+
context_size: 4
|
| 3 |
+
datasets:
|
| 4 |
+
avw_4k:
|
| 5 |
+
data_folder: /path/to/dataset/avw_4k
|
| 6 |
+
goals_per_obs: 4
|
| 7 |
+
test: /path/to/data_splits/avw_4k/val
|
| 8 |
+
train: /path/to/data_splits/avw_4k/train
|
| 9 |
+
distance:
|
| 10 |
+
max_dist_cat: 16
|
| 11 |
+
min_dist_cat: -16
|
| 12 |
+
from_checkpoint: /path/to/pretrained/experts_merged.pth
|
| 13 |
+
sample_rate: 16000
|
| 14 |
+
input_sr: 48000
|
| 15 |
+
tokenizer_a_path: /path/to/pretrained/soundstream.pt
|
| 16 |
+
grad_clip_val: 10.0
|
| 17 |
+
image_size: 224
|
| 18 |
+
len_traj_pred: 16
|
| 19 |
+
lr: 16.0e-05
|
| 20 |
+
model: AVCDiT-B/2
|
| 21 |
+
normalize: true
|
| 22 |
+
num_workers: 12
|
| 23 |
+
results_dir: logs
|
| 24 |
+
run_name: training_stage3
|
| 25 |
+
train: true
|
datasets.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# References:
|
| 8 |
+
# NoMaD, GNM, ViNT: https://github.com/robodhruv/visualnav-transformer
|
| 9 |
+
# --------------------------------------------------------
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import os
|
| 14 |
+
from PIL import Image
|
| 15 |
+
from typing import Tuple
|
| 16 |
+
import yaml
|
| 17 |
+
import pickle
|
| 18 |
+
import tqdm
|
| 19 |
+
from torch.utils.data import Dataset
|
| 20 |
+
from misc import angle_difference, get_data_path, get_delta_np, normalize_data, to_local_coords
|
| 21 |
+
import torchaudio
|
| 22 |
+
|
| 23 |
+
class BaseDataset(Dataset):
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
data_folder: str,
|
| 27 |
+
data_split_folder: str,
|
| 28 |
+
dataset_name: str,
|
| 29 |
+
image_size: Tuple[int, int],
|
| 30 |
+
min_dist_cat: int,
|
| 31 |
+
max_dist_cat: int,
|
| 32 |
+
len_traj_pred: int,
|
| 33 |
+
traj_stride: int,
|
| 34 |
+
context_size: int,
|
| 35 |
+
transform: object,
|
| 36 |
+
traj_names: str,
|
| 37 |
+
normalize: bool = True,
|
| 38 |
+
predefined_index: list = None,
|
| 39 |
+
goals_per_obs: int = 1,
|
| 40 |
+
):
|
| 41 |
+
self.data_folder = data_folder
|
| 42 |
+
self.data_split_folder = data_split_folder
|
| 43 |
+
self.dataset_name = dataset_name
|
| 44 |
+
self.goals_per_obs = goals_per_obs
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
traj_names_file = os.path.join(data_split_folder, traj_names)
|
| 48 |
+
with open(traj_names_file, "r") as f:
|
| 49 |
+
file_lines = f.read()
|
| 50 |
+
self.traj_names = file_lines.split("\n")
|
| 51 |
+
if "" in self.traj_names:
|
| 52 |
+
self.traj_names.remove("")
|
| 53 |
+
|
| 54 |
+
self.image_size = image_size
|
| 55 |
+
self.distance_categories = list(range(min_dist_cat, max_dist_cat + 1))
|
| 56 |
+
self.min_dist_cat = self.distance_categories[0]
|
| 57 |
+
self.max_dist_cat = self.distance_categories[-1]
|
| 58 |
+
self.len_traj_pred = len_traj_pred
|
| 59 |
+
self.traj_stride = traj_stride
|
| 60 |
+
|
| 61 |
+
self.context_size = context_size
|
| 62 |
+
self.normalize = normalize
|
| 63 |
+
|
| 64 |
+
# load data/data_config.yaml
|
| 65 |
+
with open("config/data_config.yaml", "r") as f:
|
| 66 |
+
all_data_config = yaml.safe_load(f)
|
| 67 |
+
|
| 68 |
+
dataset_names = list(all_data_config.keys())
|
| 69 |
+
dataset_names.sort()
|
| 70 |
+
# use this index to retrieve the dataset name from the data_config.yaml
|
| 71 |
+
self.data_config = all_data_config[self.dataset_name]
|
| 72 |
+
self.transform = transform
|
| 73 |
+
self._load_index(predefined_index)
|
| 74 |
+
self.ACTION_STATS = {}
|
| 75 |
+
for key in all_data_config['action_stats']:
|
| 76 |
+
self.ACTION_STATS[key] = np.expand_dims(all_data_config['action_stats'][key], axis=0)
|
| 77 |
+
self.DISTANCE_DIFF_STATS = {} # [NEW]
|
| 78 |
+
for key in all_data_config['distance_diff_stats']: # [NEW]
|
| 79 |
+
self.DISTANCE_DIFF_STATS[key] = np.expand_dims(all_data_config['distance_diff_stats'][key], axis=0) # [NEW]
|
| 80 |
+
|
| 81 |
+
def _load_index(self, predefined_index) -> None:
|
| 82 |
+
"""
|
| 83 |
+
Generates a list of tuples of (obs_traj_name, goal_traj_name, obs_time, goal_time) for each observation in the dataset
|
| 84 |
+
"""
|
| 85 |
+
if predefined_index:
|
| 86 |
+
print(f"****** Using a predefined evaluation index... {predefined_index}******")
|
| 87 |
+
with open(predefined_index, "rb") as f:
|
| 88 |
+
self.index_to_data = pickle.load(f)
|
| 89 |
+
return
|
| 90 |
+
else:
|
| 91 |
+
print("****** Evaluating from NON PREDEFINED index... ******")
|
| 92 |
+
index_to_data_path = os.path.join(
|
| 93 |
+
self.data_split_folder,
|
| 94 |
+
f"dataset_dist_{self.min_dist_cat}_to_{self.max_dist_cat}_n{self.context_size}_len_traj_pred_{self.len_traj_pred}.pkl",
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
self.index_to_data, self.goals_index = self._build_index()
|
| 98 |
+
with open(index_to_data_path, "wb") as f:
|
| 99 |
+
pickle.dump((self.index_to_data, self.goals_index), f)
|
| 100 |
+
|
| 101 |
+
def _build_index(self, use_tqdm: bool = False):
|
| 102 |
+
"""
|
| 103 |
+
Build an index consisting of tuples (trajectory name, time, max goal distance)
|
| 104 |
+
"""
|
| 105 |
+
samples_index = []
|
| 106 |
+
goals_index = []
|
| 107 |
+
|
| 108 |
+
for traj_name in tqdm.tqdm(self.traj_names, disable=not use_tqdm, dynamic_ncols=True):
|
| 109 |
+
traj_data = self._get_trajectory(traj_name)
|
| 110 |
+
traj_len = len(traj_data["position"])
|
| 111 |
+
for goal_time in range(0, traj_len):
|
| 112 |
+
goals_index.append((traj_name, goal_time))
|
| 113 |
+
|
| 114 |
+
begin_time = self.context_size - 1
|
| 115 |
+
end_time = traj_len - self.len_traj_pred
|
| 116 |
+
for curr_time in range(begin_time, end_time, self.traj_stride):
|
| 117 |
+
max_goal_distance = min(self.max_dist_cat, traj_len - curr_time - 1)
|
| 118 |
+
min_goal_distance = max(self.min_dist_cat, -curr_time)
|
| 119 |
+
samples_index.append((traj_name, curr_time, min_goal_distance, max_goal_distance))
|
| 120 |
+
|
| 121 |
+
return samples_index, goals_index
|
| 122 |
+
|
| 123 |
+
def _get_trajectory(self, trajectory_name):
|
| 124 |
+
with open(os.path.join(self.data_folder, trajectory_name, "traj_data.pkl"), "rb") as f:
|
| 125 |
+
traj_data = pickle.load(f)
|
| 126 |
+
for k,v in traj_data.items():
|
| 127 |
+
traj_data[k] = v.astype('float')
|
| 128 |
+
return traj_data
|
| 129 |
+
|
| 130 |
+
def __len__(self) -> int:
|
| 131 |
+
return len(self.index_to_data)
|
| 132 |
+
|
| 133 |
+
def _compute_actions(self, traj_data, curr_time, goal_time):
|
| 134 |
+
start_index = curr_time
|
| 135 |
+
end_index = curr_time + self.len_traj_pred + 1
|
| 136 |
+
yaw = traj_data["yaw"][start_index:end_index]
|
| 137 |
+
positions = traj_data["position"][start_index:end_index]
|
| 138 |
+
goal_pos = traj_data["position"][goal_time]
|
| 139 |
+
goal_yaw = traj_data["yaw"][goal_time]
|
| 140 |
+
dist_window = traj_data["distance_to_target"][start_index:end_index] # shape (len_traj_pred+1,) # [NEW]
|
| 141 |
+
goal_dist = traj_data["distance_to_target"][goal_time] # shape (N,) or scalar # [NEW]
|
| 142 |
+
|
| 143 |
+
if len(yaw.shape) == 2:
|
| 144 |
+
yaw = yaw.squeeze(1)
|
| 145 |
+
|
| 146 |
+
if yaw.shape != (self.len_traj_pred + 1,):
|
| 147 |
+
raise ValueError("is used?")
|
| 148 |
+
|
| 149 |
+
waypoints_pos = to_local_coords(positions, positions[0], yaw[0])
|
| 150 |
+
waypoints_yaw = angle_difference(yaw[0], yaw)
|
| 151 |
+
actions = np.concatenate([waypoints_pos, waypoints_yaw.reshape(-1, 1)], axis=-1)
|
| 152 |
+
actions = actions[1:]
|
| 153 |
+
|
| 154 |
+
goal_pos = to_local_coords(goal_pos, positions[0], yaw[0])
|
| 155 |
+
goal_yaw = angle_difference(yaw[0], goal_yaw)
|
| 156 |
+
|
| 157 |
+
diffs_seq = (dist_window[0] - dist_window).reshape(-1, 1)[1:] # [NEW]
|
| 158 |
+
goal_diff = (dist_window[0] - goal_dist).reshape(-1, 1) # [NEW]
|
| 159 |
+
|
| 160 |
+
if self.normalize:
|
| 161 |
+
actions[:, :2] /= self.data_config["metric_waypoint_spacing"]
|
| 162 |
+
goal_pos[:, :2] /= self.data_config["metric_waypoint_spacing"]
|
| 163 |
+
diffs_seq /= self.data_config["metric_waypoint_spacing"] # [NEW]
|
| 164 |
+
goal_diff /= self.data_config["metric_waypoint_spacing"] # [NEW]
|
| 165 |
+
|
| 166 |
+
goal_pos = np.concatenate([goal_pos, goal_yaw.reshape(-1, 1)], axis=-1)
|
| 167 |
+
return actions, goal_pos, diffs_seq, goal_diff
|
| 168 |
+
|
| 169 |
+
class TrainingDataset(BaseDataset):
|
| 170 |
+
def __init__(
|
| 171 |
+
self,
|
| 172 |
+
data_folder: str,
|
| 173 |
+
data_split_folder: str,
|
| 174 |
+
dataset_name: str,
|
| 175 |
+
image_size: Tuple[int, int],
|
| 176 |
+
min_dist_cat: int,
|
| 177 |
+
max_dist_cat: int,
|
| 178 |
+
len_traj_pred: int,
|
| 179 |
+
traj_stride: int,
|
| 180 |
+
context_size: int,
|
| 181 |
+
transform: object,
|
| 182 |
+
traj_names: str = 'traj_names.txt',
|
| 183 |
+
normalize: bool = True,
|
| 184 |
+
predefined_index: list = None,
|
| 185 |
+
goals_per_obs: int = 1,
|
| 186 |
+
# sample_rate: int = 16000,
|
| 187 |
+
# target_len: int = 7840
|
| 188 |
+
sample_rate: int = 16000,
|
| 189 |
+
input_sr: int = 48000,
|
| 190 |
+
evaluate: bool = False
|
| 191 |
+
):
|
| 192 |
+
super().__init__(data_folder, data_split_folder, dataset_name, image_size, min_dist_cat, max_dist_cat,
|
| 193 |
+
len_traj_pred, traj_stride, context_size, transform, traj_names, normalize, predefined_index, goals_per_obs)
|
| 194 |
+
self.resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=input_sr, lowpass_filter_width=64)
|
| 195 |
+
self.evaluate = evaluate
|
| 196 |
+
|
| 197 |
+
def __getitem__(self, i: int) -> Tuple[torch.Tensor]:
|
| 198 |
+
try:
|
| 199 |
+
f_curr, curr_time, min_goal_dist, max_goal_dist = self.index_to_data[i]
|
| 200 |
+
goal_offset = np.random.randint(min_goal_dist, max_goal_dist + 1, size=(self.goals_per_obs))
|
| 201 |
+
goal_time = (curr_time + goal_offset).astype('int')
|
| 202 |
+
rel_time = (goal_offset).astype('float')/(128.) # TODO: refactor, currently a fixed const
|
| 203 |
+
|
| 204 |
+
context_times = list(range(curr_time - self.context_size + 1, curr_time + 1))
|
| 205 |
+
context = [(f_curr, t) for t in context_times] + [(f_curr, t) for t in goal_time]
|
| 206 |
+
|
| 207 |
+
obs_image = torch.stack([self.transform(Image.open(get_data_path(self.data_folder, f, t))) for f, t in context])
|
| 208 |
+
obs_audio = torch.stack([torchaudio.load(get_data_path(self.data_folder, f, t, data_type="audio"))[0] for f, t in context])
|
| 209 |
+
if self.evaluate:
|
| 210 |
+
orig_obs_audio = obs_audio
|
| 211 |
+
obs_audio = self.resampler(obs_audio)
|
| 212 |
+
|
| 213 |
+
# Load other trajectory data
|
| 214 |
+
curr_traj_data = self._get_trajectory(f_curr)
|
| 215 |
+
|
| 216 |
+
# Compute actions
|
| 217 |
+
_, goal_pos, _, goal_diff = self._compute_actions(curr_traj_data, curr_time, goal_time)
|
| 218 |
+
goal_pos[:, :2] = normalize_data(goal_pos[:, :2], self.ACTION_STATS)
|
| 219 |
+
goal_diff = normalize_data(goal_diff, self.DISTANCE_DIFF_STATS)
|
| 220 |
+
|
| 221 |
+
if self.evaluate:
|
| 222 |
+
return (
|
| 223 |
+
torch.as_tensor(obs_image, dtype=torch.float32),
|
| 224 |
+
torch.as_tensor(obs_audio, dtype=torch.float32),
|
| 225 |
+
torch.as_tensor(goal_pos, dtype=torch.float32),
|
| 226 |
+
torch.as_tensor(goal_diff, dtype=torch.float32),
|
| 227 |
+
torch.as_tensor(rel_time, dtype=torch.float32),
|
| 228 |
+
torch.as_tensor(orig_obs_audio, dtype=torch.float32),
|
| 229 |
+
)
|
| 230 |
+
else:
|
| 231 |
+
return (
|
| 232 |
+
torch.as_tensor(obs_image, dtype=torch.float32),
|
| 233 |
+
torch.as_tensor(obs_audio, dtype=torch.float32),
|
| 234 |
+
torch.as_tensor(goal_pos, dtype=torch.float32),
|
| 235 |
+
torch.as_tensor(goal_diff, dtype=torch.float32),
|
| 236 |
+
torch.as_tensor(rel_time, dtype=torch.float32),
|
| 237 |
+
)
|
| 238 |
+
except Exception as e:
|
| 239 |
+
print(f"Exception in {self.dataset_name}", e)
|
| 240 |
+
raise Exception(e)
|
| 241 |
+
|
| 242 |
+
class EvalDataset(BaseDataset):
|
| 243 |
+
def __init__(
|
| 244 |
+
self,
|
| 245 |
+
data_folder: str,
|
| 246 |
+
data_split_folder: str,
|
| 247 |
+
dataset_name: str,
|
| 248 |
+
image_size: Tuple[int, int],
|
| 249 |
+
min_dist_cat: int,
|
| 250 |
+
max_dist_cat: int,
|
| 251 |
+
len_traj_pred: int,
|
| 252 |
+
traj_stride: int,
|
| 253 |
+
context_size: int,
|
| 254 |
+
transform: object,
|
| 255 |
+
traj_names: str,
|
| 256 |
+
normalize: bool = True,
|
| 257 |
+
predefined_index: list = None,
|
| 258 |
+
goals_per_obs: int = 1,
|
| 259 |
+
sample_rate: int = 16000,
|
| 260 |
+
input_sr: int = 48000
|
| 261 |
+
):
|
| 262 |
+
super().__init__(data_folder, data_split_folder, dataset_name, image_size, min_dist_cat, max_dist_cat,
|
| 263 |
+
len_traj_pred, traj_stride, context_size, transform, traj_names, normalize, predefined_index, goals_per_obs)
|
| 264 |
+
self.resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=input_sr, lowpass_filter_width=64)
|
| 265 |
+
|
| 266 |
+
def __getitem__(self, i: int) -> Tuple[torch.Tensor]:
|
| 267 |
+
try:
|
| 268 |
+
f_curr, curr_time, _, _ = self.index_to_data[i]
|
| 269 |
+
context_times = list(range(curr_time - self.context_size + 1, curr_time + 1))
|
| 270 |
+
pred_times = list(range(curr_time + 1, curr_time + self.len_traj_pred + 1))
|
| 271 |
+
|
| 272 |
+
context = [(f_curr, t) for t in context_times]
|
| 273 |
+
pred = [(f_curr, t) for t in pred_times]
|
| 274 |
+
|
| 275 |
+
obs_image = torch.stack([self.transform(Image.open(get_data_path(self.data_folder, f, t))) for f, t in context])
|
| 276 |
+
pred_image = torch.stack([self.transform(Image.open(get_data_path(self.data_folder, f, t))) for f, t in pred])
|
| 277 |
+
|
| 278 |
+
orig_obs_audio = torch.stack([torchaudio.load(get_data_path(self.data_folder, f, t, data_type="audio"))[0] for f, t in context])
|
| 279 |
+
orig_pred_audio = torch.stack([torchaudio.load(get_data_path(self.data_folder, f, t, data_type="audio"))[0] for f, t in pred])
|
| 280 |
+
|
| 281 |
+
obs_audio = self.resampler(orig_obs_audio)
|
| 282 |
+
pred_audio = self.resampler(orig_pred_audio)
|
| 283 |
+
|
| 284 |
+
curr_traj_data = self._get_trajectory(f_curr)
|
| 285 |
+
|
| 286 |
+
# Compute actions
|
| 287 |
+
actions, _, diffs_seq, _ = self._compute_actions(curr_traj_data, curr_time, np.array([curr_time+1])) # last argument is dummy goal
|
| 288 |
+
actions[:, :2] = normalize_data(actions[:, :2], self.ACTION_STATS)
|
| 289 |
+
diffs_seq = normalize_data(diffs_seq, self.DISTANCE_DIFF_STATS)
|
| 290 |
+
|
| 291 |
+
delta = get_delta_np(actions)
|
| 292 |
+
diffs_seq = get_delta_np(diffs_seq)
|
| 293 |
+
|
| 294 |
+
return (
|
| 295 |
+
torch.tensor([i], dtype=torch.float32), # for logging purposes
|
| 296 |
+
torch.as_tensor(obs_image, dtype=torch.float32),
|
| 297 |
+
torch.as_tensor(pred_image, dtype=torch.float32),
|
| 298 |
+
torch.as_tensor(obs_audio, dtype=torch.float32),
|
| 299 |
+
torch.as_tensor(pred_audio, dtype=torch.float32),
|
| 300 |
+
torch.as_tensor(diffs_seq, dtype=torch.float32),
|
| 301 |
+
torch.as_tensor(delta, dtype=torch.float32),
|
| 302 |
+
torch.as_tensor(orig_obs_audio, dtype=torch.float32),
|
| 303 |
+
torch.as_tensor(orig_pred_audio, dtype=torch.float32),
|
| 304 |
+
)
|
| 305 |
+
except Exception as e:
|
| 306 |
+
print(f"Exception in {self.dataset_name}", e)
|
| 307 |
+
raise Exception(e)
|
diffusion/__init__.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import gaussian_diffusion as gd_orig
|
| 2 |
+
from . import gaussian_diffusion_dual as gd_dual
|
| 3 |
+
# from .respace import SpacedDiffusion, space_timesteps
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def create_diffusion(
|
| 7 |
+
timestep_respacing,
|
| 8 |
+
noise_schedule="linear",
|
| 9 |
+
use_kl=False,
|
| 10 |
+
sigma_small=False,
|
| 11 |
+
predict_xstart=False,
|
| 12 |
+
learn_sigma=True,
|
| 13 |
+
rescale_learned_sigmas=False,
|
| 14 |
+
diffusion_steps=1000,
|
| 15 |
+
dual=False
|
| 16 |
+
):
|
| 17 |
+
if dual:
|
| 18 |
+
print("Using DUAL diffusion")
|
| 19 |
+
from .respace_dual import SpacedDiffusion, space_timesteps
|
| 20 |
+
gd_module = gd_dual
|
| 21 |
+
else:
|
| 22 |
+
print("Using SINGLE diffusion")
|
| 23 |
+
from .respace import SpacedDiffusion, space_timesteps
|
| 24 |
+
gd_module = gd_orig
|
| 25 |
+
|
| 26 |
+
betas = gd_module.get_named_beta_schedule(noise_schedule, diffusion_steps)
|
| 27 |
+
# betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
|
| 28 |
+
if use_kl:
|
| 29 |
+
loss_type = gd_module.LossType.RESCALED_KL
|
| 30 |
+
elif rescale_learned_sigmas:
|
| 31 |
+
loss_type = gd_module.LossType.RESCALED_MSE
|
| 32 |
+
else:
|
| 33 |
+
loss_type = gd_module.LossType.MSE
|
| 34 |
+
if timestep_respacing is None or timestep_respacing == "":
|
| 35 |
+
timestep_respacing = [diffusion_steps]
|
| 36 |
+
return SpacedDiffusion(
|
| 37 |
+
use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
|
| 38 |
+
betas=betas,
|
| 39 |
+
model_mean_type=(
|
| 40 |
+
gd_module.ModelMeanType.EPSILON if not predict_xstart else gd_module.ModelMeanType.START_X
|
| 41 |
+
),
|
| 42 |
+
model_var_type=(
|
| 43 |
+
(
|
| 44 |
+
gd_module.ModelVarType.FIXED_LARGE
|
| 45 |
+
if not sigma_small
|
| 46 |
+
else gd_module.ModelVarType.FIXED_SMALL
|
| 47 |
+
)
|
| 48 |
+
if not learn_sigma
|
| 49 |
+
else gd_module.ModelVarType.LEARNED_RANGE
|
| 50 |
+
),
|
| 51 |
+
loss_type=loss_type
|
| 52 |
+
# rescale_timesteps=rescale_timesteps,
|
| 53 |
+
)
|
diffusion/diffusion_utils.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch as th
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def normal_kl(mean1, logvar1, mean2, logvar2):
|
| 6 |
+
"""
|
| 7 |
+
Compute the KL divergence between two gaussians.
|
| 8 |
+
Shapes are automatically broadcasted, so batches can be compared to
|
| 9 |
+
scalars, among other use cases.
|
| 10 |
+
"""
|
| 11 |
+
tensor = None
|
| 12 |
+
for obj in (mean1, logvar1, mean2, logvar2):
|
| 13 |
+
if isinstance(obj, th.Tensor):
|
| 14 |
+
tensor = obj
|
| 15 |
+
break
|
| 16 |
+
assert tensor is not None, "at least one argument must be a Tensor"
|
| 17 |
+
|
| 18 |
+
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
| 19 |
+
# Tensors, but it does not work for th.exp().
|
| 20 |
+
logvar1, logvar2 = [
|
| 21 |
+
x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
|
| 22 |
+
for x in (logvar1, logvar2)
|
| 23 |
+
]
|
| 24 |
+
|
| 25 |
+
return 0.5 * (
|
| 26 |
+
-1.0
|
| 27 |
+
+ logvar2
|
| 28 |
+
- logvar1
|
| 29 |
+
+ th.exp(logvar1 - logvar2)
|
| 30 |
+
+ ((mean1 - mean2) ** 2) * th.exp(-logvar2)
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def approx_standard_normal_cdf(x):
|
| 35 |
+
"""
|
| 36 |
+
A fast approximation of the cumulative distribution function of the
|
| 37 |
+
standard normal.
|
| 38 |
+
"""
|
| 39 |
+
return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def continuous_gaussian_log_likelihood(x, *, means, log_scales):
|
| 43 |
+
"""
|
| 44 |
+
Compute the log-likelihood of a continuous Gaussian distribution.
|
| 45 |
+
:param x: the targets
|
| 46 |
+
:param means: the Gaussian mean Tensor.
|
| 47 |
+
:param log_scales: the Gaussian log stddev Tensor.
|
| 48 |
+
:return: a tensor like x of log probabilities (in nats).
|
| 49 |
+
"""
|
| 50 |
+
centered_x = x - means
|
| 51 |
+
inv_stdv = th.exp(-log_scales)
|
| 52 |
+
normalized_x = centered_x * inv_stdv
|
| 53 |
+
log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
|
| 54 |
+
return log_probs
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
|
| 58 |
+
"""
|
| 59 |
+
Compute the log-likelihood of a Gaussian distribution discretizing to a
|
| 60 |
+
given image.
|
| 61 |
+
:param x: the target images. It is assumed that this was uint8 values,
|
| 62 |
+
rescaled to the range [-1, 1].
|
| 63 |
+
:param means: the Gaussian mean Tensor.
|
| 64 |
+
:param log_scales: the Gaussian log stddev Tensor.
|
| 65 |
+
:return: a tensor like x of log probabilities (in nats).
|
| 66 |
+
"""
|
| 67 |
+
assert x.shape == means.shape == log_scales.shape
|
| 68 |
+
centered_x = x - means
|
| 69 |
+
inv_stdv = th.exp(-log_scales)
|
| 70 |
+
plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
|
| 71 |
+
cdf_plus = approx_standard_normal_cdf(plus_in)
|
| 72 |
+
min_in = inv_stdv * (centered_x - 1.0 / 255.0)
|
| 73 |
+
cdf_min = approx_standard_normal_cdf(min_in)
|
| 74 |
+
log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
|
| 75 |
+
log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
|
| 76 |
+
cdf_delta = cdf_plus - cdf_min
|
| 77 |
+
log_probs = th.where(
|
| 78 |
+
x < -0.999,
|
| 79 |
+
log_cdf_plus,
|
| 80 |
+
th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
|
| 81 |
+
)
|
| 82 |
+
assert log_probs.shape == x.shape
|
| 83 |
+
return log_probs
|
diffusion/gaussian_diffusion.py
ADDED
|
@@ -0,0 +1,870 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch as th
|
| 5 |
+
import enum
|
| 6 |
+
|
| 7 |
+
from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def mean_flat(tensor):
|
| 13 |
+
"""
|
| 14 |
+
Take the mean over all non-batch dimensions.
|
| 15 |
+
"""
|
| 16 |
+
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class ModelMeanType(enum.Enum):
|
| 20 |
+
"""
|
| 21 |
+
Which type of output the model predicts.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
|
| 25 |
+
START_X = enum.auto() # the model predicts x_0
|
| 26 |
+
EPSILON = enum.auto() # the model predicts epsilon
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class ModelVarType(enum.Enum):
|
| 30 |
+
"""
|
| 31 |
+
What is used as the model's output variance.
|
| 32 |
+
The LEARNED_RANGE option has been added to allow the model to predict
|
| 33 |
+
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
LEARNED = enum.auto()
|
| 37 |
+
FIXED_SMALL = enum.auto()
|
| 38 |
+
FIXED_LARGE = enum.auto()
|
| 39 |
+
LEARNED_RANGE = enum.auto()
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class LossType(enum.Enum):
|
| 43 |
+
MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
|
| 44 |
+
RESCALED_MSE = (
|
| 45 |
+
enum.auto()
|
| 46 |
+
) # use raw MSE loss (with RESCALED_KL when learning variances)
|
| 47 |
+
KL = enum.auto() # use the variational lower-bound
|
| 48 |
+
RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
|
| 49 |
+
|
| 50 |
+
def is_vb(self):
|
| 51 |
+
return self == LossType.KL or self == LossType.RESCALED_KL
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
|
| 55 |
+
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
|
| 56 |
+
warmup_time = int(num_diffusion_timesteps * warmup_frac)
|
| 57 |
+
betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
|
| 58 |
+
return betas
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
|
| 62 |
+
"""
|
| 63 |
+
This is the deprecated API for creating beta schedules.
|
| 64 |
+
See get_named_beta_schedule() for the new library of schedules.
|
| 65 |
+
"""
|
| 66 |
+
if beta_schedule == "quad":
|
| 67 |
+
betas = (
|
| 68 |
+
np.linspace(
|
| 69 |
+
beta_start ** 0.5,
|
| 70 |
+
beta_end ** 0.5,
|
| 71 |
+
num_diffusion_timesteps,
|
| 72 |
+
dtype=np.float64,
|
| 73 |
+
)
|
| 74 |
+
** 2
|
| 75 |
+
)
|
| 76 |
+
elif beta_schedule == "linear":
|
| 77 |
+
betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
|
| 78 |
+
elif beta_schedule == "warmup10":
|
| 79 |
+
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
|
| 80 |
+
elif beta_schedule == "warmup50":
|
| 81 |
+
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
|
| 82 |
+
elif beta_schedule == "const":
|
| 83 |
+
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
|
| 84 |
+
elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
|
| 85 |
+
betas = 1.0 / np.linspace(
|
| 86 |
+
num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
|
| 87 |
+
)
|
| 88 |
+
else:
|
| 89 |
+
raise NotImplementedError(beta_schedule)
|
| 90 |
+
assert betas.shape == (num_diffusion_timesteps,)
|
| 91 |
+
return betas
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
|
| 95 |
+
"""
|
| 96 |
+
Get a pre-defined beta schedule for the given name.
|
| 97 |
+
The beta schedule library consists of beta schedules which remain similar
|
| 98 |
+
in the limit of num_diffusion_timesteps.
|
| 99 |
+
Beta schedules may be added, but should not be removed or changed once
|
| 100 |
+
they are committed to maintain backwards compatibility.
|
| 101 |
+
"""
|
| 102 |
+
if schedule_name == "linear":
|
| 103 |
+
# Linear schedule from Ho et al, extended to work for any number of
|
| 104 |
+
# diffusion steps.
|
| 105 |
+
scale = 1000 / num_diffusion_timesteps
|
| 106 |
+
return get_beta_schedule(
|
| 107 |
+
"linear",
|
| 108 |
+
beta_start=scale * 0.0001,
|
| 109 |
+
beta_end=scale * 0.02,
|
| 110 |
+
num_diffusion_timesteps=num_diffusion_timesteps,
|
| 111 |
+
)
|
| 112 |
+
elif schedule_name == "squaredcos_cap_v2":
|
| 113 |
+
return betas_for_alpha_bar(
|
| 114 |
+
num_diffusion_timesteps,
|
| 115 |
+
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
| 116 |
+
)
|
| 117 |
+
else:
|
| 118 |
+
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
| 122 |
+
"""
|
| 123 |
+
Create a beta schedule that discretizes the given alpha_t_bar function,
|
| 124 |
+
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
| 125 |
+
:param num_diffusion_timesteps: the number of betas to produce.
|
| 126 |
+
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
| 127 |
+
produces the cumulative product of (1-beta) up to that
|
| 128 |
+
part of the diffusion process.
|
| 129 |
+
:param max_beta: the maximum beta to use; use values lower than 1 to
|
| 130 |
+
prevent singularities.
|
| 131 |
+
"""
|
| 132 |
+
betas = []
|
| 133 |
+
for i in range(num_diffusion_timesteps):
|
| 134 |
+
t1 = i / num_diffusion_timesteps
|
| 135 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
| 136 |
+
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
| 137 |
+
return np.array(betas)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class GaussianDiffusion:
|
| 141 |
+
"""
|
| 142 |
+
Utilities for training and sampling diffusion models.
|
| 143 |
+
Original ported from this codebase:
|
| 144 |
+
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
|
| 145 |
+
:param betas: a 1-D numpy array of betas for each diffusion timestep,
|
| 146 |
+
starting at T and going to 1.
|
| 147 |
+
"""
|
| 148 |
+
|
| 149 |
+
def __init__(
|
| 150 |
+
self,
|
| 151 |
+
*,
|
| 152 |
+
betas,
|
| 153 |
+
model_mean_type,
|
| 154 |
+
model_var_type,
|
| 155 |
+
loss_type
|
| 156 |
+
):
|
| 157 |
+
|
| 158 |
+
self.model_mean_type = model_mean_type
|
| 159 |
+
self.model_var_type = model_var_type
|
| 160 |
+
self.loss_type = loss_type
|
| 161 |
+
|
| 162 |
+
# Use float64 for accuracy.
|
| 163 |
+
betas = np.array(betas, dtype=np.float64)
|
| 164 |
+
self.betas = betas
|
| 165 |
+
assert len(betas.shape) == 1, "betas must be 1-D"
|
| 166 |
+
assert (betas > 0).all() and (betas <= 1).all()
|
| 167 |
+
|
| 168 |
+
self.num_timesteps = int(betas.shape[0])
|
| 169 |
+
|
| 170 |
+
alphas = 1.0 - betas
|
| 171 |
+
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
| 172 |
+
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
|
| 173 |
+
self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
|
| 174 |
+
assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
|
| 175 |
+
|
| 176 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
| 177 |
+
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
|
| 178 |
+
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
|
| 179 |
+
self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
|
| 180 |
+
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
|
| 181 |
+
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
|
| 182 |
+
|
| 183 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
| 184 |
+
self.posterior_variance = (
|
| 185 |
+
betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
| 186 |
+
)
|
| 187 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
| 188 |
+
self.posterior_log_variance_clipped = np.log(
|
| 189 |
+
np.append(self.posterior_variance[1], self.posterior_variance[1:])
|
| 190 |
+
) if len(self.posterior_variance) > 1 else np.array([])
|
| 191 |
+
|
| 192 |
+
self.posterior_mean_coef1 = (
|
| 193 |
+
betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
| 194 |
+
)
|
| 195 |
+
self.posterior_mean_coef2 = (
|
| 196 |
+
(1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
def q_mean_variance(self, x_start, t):
|
| 200 |
+
"""
|
| 201 |
+
Get the distribution q(x_t | x_0).
|
| 202 |
+
:param x_start: the [N x C x ...] tensor of noiseless inputs.
|
| 203 |
+
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
| 204 |
+
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
|
| 205 |
+
"""
|
| 206 |
+
mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
| 207 |
+
variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
|
| 208 |
+
log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
| 209 |
+
return mean, variance, log_variance
|
| 210 |
+
|
| 211 |
+
def q_sample(self, x_start, t, noise=None):
|
| 212 |
+
"""
|
| 213 |
+
Diffuse the data for a given number of diffusion steps.
|
| 214 |
+
In other words, sample from q(x_t | x_0).
|
| 215 |
+
:param x_start: the initial data batch.
|
| 216 |
+
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
| 217 |
+
:param noise: if specified, the split-out normal noise.
|
| 218 |
+
:return: A noisy version of x_start.
|
| 219 |
+
"""
|
| 220 |
+
if noise is None:
|
| 221 |
+
noise = th.randn_like(x_start)
|
| 222 |
+
assert noise.shape == x_start.shape
|
| 223 |
+
return (
|
| 224 |
+
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
| 225 |
+
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
def q_posterior_mean_variance(self, x_start, x_t, t):
|
| 229 |
+
"""
|
| 230 |
+
Compute the mean and variance of the diffusion posterior:
|
| 231 |
+
q(x_{t-1} | x_t, x_0)
|
| 232 |
+
"""
|
| 233 |
+
assert x_start.shape == x_t.shape
|
| 234 |
+
posterior_mean = (
|
| 235 |
+
_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
|
| 236 |
+
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
| 237 |
+
)
|
| 238 |
+
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
|
| 239 |
+
posterior_log_variance_clipped = _extract_into_tensor(
|
| 240 |
+
self.posterior_log_variance_clipped, t, x_t.shape
|
| 241 |
+
)
|
| 242 |
+
assert (
|
| 243 |
+
posterior_mean.shape[0]
|
| 244 |
+
== posterior_variance.shape[0]
|
| 245 |
+
== posterior_log_variance_clipped.shape[0]
|
| 246 |
+
== x_start.shape[0]
|
| 247 |
+
)
|
| 248 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
| 249 |
+
|
| 250 |
+
def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
|
| 251 |
+
"""
|
| 252 |
+
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
|
| 253 |
+
the initial x, x_0.
|
| 254 |
+
:param model: the model, which takes a signal and a batch of timesteps
|
| 255 |
+
as input.
|
| 256 |
+
:param x: the [N x C x ...] tensor at time t.
|
| 257 |
+
:param t: a 1-D Tensor of timesteps.
|
| 258 |
+
:param clip_denoised: if True, clip the denoised signal into [-1, 1].
|
| 259 |
+
:param denoised_fn: if not None, a function which applies to the
|
| 260 |
+
x_start prediction before it is used to sample. Applies before
|
| 261 |
+
clip_denoised.
|
| 262 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
| 263 |
+
pass to the model. This can be used for conditioning.
|
| 264 |
+
:return: a dict with the following keys:
|
| 265 |
+
- 'mean': the model mean output.
|
| 266 |
+
- 'variance': the model variance output.
|
| 267 |
+
- 'log_variance': the log of 'variance'.
|
| 268 |
+
- 'pred_xstart': the prediction for x_0.
|
| 269 |
+
"""
|
| 270 |
+
if model_kwargs is None:
|
| 271 |
+
model_kwargs = {}
|
| 272 |
+
|
| 273 |
+
B, C = x.shape[:2]
|
| 274 |
+
assert t.shape == (B,)
|
| 275 |
+
model_output = model(x, t, **model_kwargs)
|
| 276 |
+
if isinstance(model_output, tuple):
|
| 277 |
+
model_output, extra = model_output
|
| 278 |
+
else:
|
| 279 |
+
extra = None
|
| 280 |
+
|
| 281 |
+
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
|
| 282 |
+
assert model_output.shape == (B, C * 2, *x.shape[2:])
|
| 283 |
+
model_output, model_var_values = th.split(model_output, C, dim=1)
|
| 284 |
+
min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
|
| 285 |
+
max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
|
| 286 |
+
# The model_var_values is [-1, 1] for [min_var, max_var].
|
| 287 |
+
frac = (model_var_values + 1) / 2
|
| 288 |
+
model_log_variance = frac * max_log + (1 - frac) * min_log
|
| 289 |
+
model_variance = th.exp(model_log_variance)
|
| 290 |
+
else:
|
| 291 |
+
model_variance, model_log_variance = {
|
| 292 |
+
# for fixedlarge, we set the initial (log-)variance like so
|
| 293 |
+
# to get a better decoder log likelihood.
|
| 294 |
+
ModelVarType.FIXED_LARGE: (
|
| 295 |
+
np.append(self.posterior_variance[1], self.betas[1:]),
|
| 296 |
+
np.log(np.append(self.posterior_variance[1], self.betas[1:])),
|
| 297 |
+
),
|
| 298 |
+
ModelVarType.FIXED_SMALL: (
|
| 299 |
+
self.posterior_variance,
|
| 300 |
+
self.posterior_log_variance_clipped,
|
| 301 |
+
),
|
| 302 |
+
}[self.model_var_type]
|
| 303 |
+
model_variance = _extract_into_tensor(model_variance, t, x.shape)
|
| 304 |
+
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
|
| 305 |
+
|
| 306 |
+
def process_xstart(x):
|
| 307 |
+
if denoised_fn is not None:
|
| 308 |
+
x = denoised_fn(x)
|
| 309 |
+
if clip_denoised:
|
| 310 |
+
return x.clamp(-1, 1)
|
| 311 |
+
return x
|
| 312 |
+
|
| 313 |
+
if self.model_mean_type == ModelMeanType.START_X:
|
| 314 |
+
pred_xstart = process_xstart(model_output)
|
| 315 |
+
else:
|
| 316 |
+
pred_xstart = process_xstart(
|
| 317 |
+
self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
|
| 318 |
+
)
|
| 319 |
+
model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
|
| 320 |
+
|
| 321 |
+
assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
|
| 322 |
+
return {
|
| 323 |
+
"mean": model_mean,
|
| 324 |
+
"variance": model_variance,
|
| 325 |
+
"log_variance": model_log_variance,
|
| 326 |
+
"pred_xstart": pred_xstart,
|
| 327 |
+
"extra": extra,
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
def _predict_xstart_from_eps(self, x_t, t, eps):
|
| 331 |
+
assert x_t.shape == eps.shape
|
| 332 |
+
return (
|
| 333 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
| 334 |
+
- _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
|
| 338 |
+
return (
|
| 339 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
|
| 340 |
+
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
| 341 |
+
|
| 342 |
+
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
| 343 |
+
"""
|
| 344 |
+
Compute the mean for the previous step, given a function cond_fn that
|
| 345 |
+
computes the gradient of a conditional log probability with respect to
|
| 346 |
+
x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
|
| 347 |
+
condition on y.
|
| 348 |
+
This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
|
| 349 |
+
"""
|
| 350 |
+
gradient = cond_fn(x, t, **model_kwargs)
|
| 351 |
+
new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
|
| 352 |
+
return new_mean
|
| 353 |
+
|
| 354 |
+
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
| 355 |
+
"""
|
| 356 |
+
Compute what the p_mean_variance output would have been, should the
|
| 357 |
+
model's score function be conditioned by cond_fn.
|
| 358 |
+
See condition_mean() for details on cond_fn.
|
| 359 |
+
Unlike condition_mean(), this instead uses the conditioning strategy
|
| 360 |
+
from Song et al (2020).
|
| 361 |
+
"""
|
| 362 |
+
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
| 363 |
+
|
| 364 |
+
eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
|
| 365 |
+
eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
|
| 366 |
+
|
| 367 |
+
out = p_mean_var.copy()
|
| 368 |
+
out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
|
| 369 |
+
out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
|
| 370 |
+
return out
|
| 371 |
+
|
| 372 |
+
def p_sample(
|
| 373 |
+
self,
|
| 374 |
+
model,
|
| 375 |
+
x,
|
| 376 |
+
t,
|
| 377 |
+
clip_denoised=True,
|
| 378 |
+
denoised_fn=None,
|
| 379 |
+
cond_fn=None,
|
| 380 |
+
model_kwargs=None,
|
| 381 |
+
):
|
| 382 |
+
"""
|
| 383 |
+
Sample x_{t-1} from the model at the given timestep.
|
| 384 |
+
:param model: the model to sample from.
|
| 385 |
+
:param x: the current tensor at x_{t-1}.
|
| 386 |
+
:param t: the value of t, starting at 0 for the first diffusion step.
|
| 387 |
+
:param clip_denoised: if True, clip the x_start prediction to [-1, 1].
|
| 388 |
+
:param denoised_fn: if not None, a function which applies to the
|
| 389 |
+
x_start prediction before it is used to sample.
|
| 390 |
+
:param cond_fn: if not None, this is a gradient function that acts
|
| 391 |
+
similarly to the model.
|
| 392 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
| 393 |
+
pass to the model. This can be used for conditioning.
|
| 394 |
+
:return: a dict containing the following keys:
|
| 395 |
+
- 'sample': a random sample from the model.
|
| 396 |
+
- 'pred_xstart': a prediction of x_0.
|
| 397 |
+
"""
|
| 398 |
+
out = self.p_mean_variance(
|
| 399 |
+
model,
|
| 400 |
+
x,
|
| 401 |
+
t,
|
| 402 |
+
clip_denoised=clip_denoised,
|
| 403 |
+
denoised_fn=denoised_fn,
|
| 404 |
+
model_kwargs=model_kwargs,
|
| 405 |
+
)
|
| 406 |
+
noise = th.randn_like(x)
|
| 407 |
+
nonzero_mask = (
|
| 408 |
+
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
| 409 |
+
) # no noise when t == 0
|
| 410 |
+
if cond_fn is not None:
|
| 411 |
+
out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
| 412 |
+
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
|
| 413 |
+
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
| 414 |
+
|
| 415 |
+
def p_sample_loop(
|
| 416 |
+
self,
|
| 417 |
+
model,
|
| 418 |
+
shape,
|
| 419 |
+
noise=None,
|
| 420 |
+
clip_denoised=True,
|
| 421 |
+
denoised_fn=None,
|
| 422 |
+
cond_fn=None,
|
| 423 |
+
model_kwargs=None,
|
| 424 |
+
device=None,
|
| 425 |
+
progress=False,
|
| 426 |
+
):
|
| 427 |
+
"""
|
| 428 |
+
Generate samples from the model.
|
| 429 |
+
:param model: the model module.
|
| 430 |
+
:param shape: the shape of the samples, (N, C, H, W).
|
| 431 |
+
:param noise: if specified, the noise from the encoder to sample.
|
| 432 |
+
Should be of the same shape as `shape`.
|
| 433 |
+
:param clip_denoised: if True, clip x_start predictions to [-1, 1].
|
| 434 |
+
:param denoised_fn: if not None, a function which applies to the
|
| 435 |
+
x_start prediction before it is used to sample.
|
| 436 |
+
:param cond_fn: if not None, this is a gradient function that acts
|
| 437 |
+
similarly to the model.
|
| 438 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
| 439 |
+
pass to the model. This can be used for conditioning.
|
| 440 |
+
:param device: if specified, the device to create the samples on.
|
| 441 |
+
If not specified, use a model parameter's device.
|
| 442 |
+
:param progress: if True, show a tqdm progress bar.
|
| 443 |
+
:return: a non-differentiable batch of samples.
|
| 444 |
+
"""
|
| 445 |
+
final = None
|
| 446 |
+
for sample in self.p_sample_loop_progressive(
|
| 447 |
+
model,
|
| 448 |
+
shape,
|
| 449 |
+
noise=noise,
|
| 450 |
+
clip_denoised=clip_denoised,
|
| 451 |
+
denoised_fn=denoised_fn,
|
| 452 |
+
cond_fn=cond_fn,
|
| 453 |
+
model_kwargs=model_kwargs,
|
| 454 |
+
device=device,
|
| 455 |
+
progress=progress,
|
| 456 |
+
):
|
| 457 |
+
final = sample
|
| 458 |
+
return final["sample"]
|
| 459 |
+
|
| 460 |
+
def p_sample_loop_progressive(
|
| 461 |
+
self,
|
| 462 |
+
model,
|
| 463 |
+
shape,
|
| 464 |
+
noise=None,
|
| 465 |
+
clip_denoised=True,
|
| 466 |
+
denoised_fn=None,
|
| 467 |
+
cond_fn=None,
|
| 468 |
+
model_kwargs=None,
|
| 469 |
+
device=None,
|
| 470 |
+
progress=False,
|
| 471 |
+
):
|
| 472 |
+
"""
|
| 473 |
+
Generate samples from the model and yield intermediate samples from
|
| 474 |
+
each timestep of diffusion.
|
| 475 |
+
Arguments are the same as p_sample_loop().
|
| 476 |
+
Returns a generator over dicts, where each dict is the return value of
|
| 477 |
+
p_sample().
|
| 478 |
+
"""
|
| 479 |
+
if device is None:
|
| 480 |
+
device = next(model.parameters()).device
|
| 481 |
+
assert isinstance(shape, (tuple, list))
|
| 482 |
+
if noise is not None:
|
| 483 |
+
img = noise
|
| 484 |
+
else:
|
| 485 |
+
img = th.randn(*shape, device=device)
|
| 486 |
+
indices = list(range(self.num_timesteps))[::-1]
|
| 487 |
+
|
| 488 |
+
if progress:
|
| 489 |
+
# Lazy import so that we don't depend on tqdm.
|
| 490 |
+
from tqdm.auto import tqdm
|
| 491 |
+
|
| 492 |
+
indices = tqdm(indices)
|
| 493 |
+
|
| 494 |
+
for i in indices:
|
| 495 |
+
t = th.tensor([i] * shape[0], device=device)
|
| 496 |
+
with th.no_grad():
|
| 497 |
+
out = self.p_sample(
|
| 498 |
+
model,
|
| 499 |
+
img,
|
| 500 |
+
t,
|
| 501 |
+
clip_denoised=clip_denoised,
|
| 502 |
+
denoised_fn=denoised_fn,
|
| 503 |
+
cond_fn=cond_fn,
|
| 504 |
+
model_kwargs=model_kwargs,
|
| 505 |
+
)
|
| 506 |
+
yield out
|
| 507 |
+
img = out["sample"]
|
| 508 |
+
|
| 509 |
+
def ddim_sample(
|
| 510 |
+
self,
|
| 511 |
+
model,
|
| 512 |
+
x,
|
| 513 |
+
t,
|
| 514 |
+
clip_denoised=True,
|
| 515 |
+
denoised_fn=None,
|
| 516 |
+
cond_fn=None,
|
| 517 |
+
model_kwargs=None,
|
| 518 |
+
eta=0.0,
|
| 519 |
+
):
|
| 520 |
+
"""
|
| 521 |
+
Sample x_{t-1} from the model using DDIM.
|
| 522 |
+
Same usage as p_sample().
|
| 523 |
+
"""
|
| 524 |
+
out = self.p_mean_variance(
|
| 525 |
+
model,
|
| 526 |
+
x,
|
| 527 |
+
t,
|
| 528 |
+
clip_denoised=clip_denoised,
|
| 529 |
+
denoised_fn=denoised_fn,
|
| 530 |
+
model_kwargs=model_kwargs,
|
| 531 |
+
)
|
| 532 |
+
if cond_fn is not None:
|
| 533 |
+
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
| 534 |
+
|
| 535 |
+
# Usually our model outputs epsilon, but we re-derive it
|
| 536 |
+
# in case we used x_start or x_prev prediction.
|
| 537 |
+
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
|
| 538 |
+
|
| 539 |
+
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
| 540 |
+
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
|
| 541 |
+
sigma = (
|
| 542 |
+
eta
|
| 543 |
+
* th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
|
| 544 |
+
* th.sqrt(1 - alpha_bar / alpha_bar_prev)
|
| 545 |
+
)
|
| 546 |
+
# Equation 12.
|
| 547 |
+
noise = th.randn_like(x)
|
| 548 |
+
mean_pred = (
|
| 549 |
+
out["pred_xstart"] * th.sqrt(alpha_bar_prev)
|
| 550 |
+
+ th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
|
| 551 |
+
)
|
| 552 |
+
nonzero_mask = (
|
| 553 |
+
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
| 554 |
+
) # no noise when t == 0
|
| 555 |
+
sample = mean_pred + nonzero_mask * sigma * noise
|
| 556 |
+
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
| 557 |
+
|
| 558 |
+
def ddim_reverse_sample(
|
| 559 |
+
self,
|
| 560 |
+
model,
|
| 561 |
+
x,
|
| 562 |
+
t,
|
| 563 |
+
clip_denoised=True,
|
| 564 |
+
denoised_fn=None,
|
| 565 |
+
cond_fn=None,
|
| 566 |
+
model_kwargs=None,
|
| 567 |
+
eta=0.0,
|
| 568 |
+
):
|
| 569 |
+
"""
|
| 570 |
+
Sample x_{t+1} from the model using DDIM reverse ODE.
|
| 571 |
+
"""
|
| 572 |
+
assert eta == 0.0, "Reverse ODE only for deterministic path"
|
| 573 |
+
out = self.p_mean_variance(
|
| 574 |
+
model,
|
| 575 |
+
x,
|
| 576 |
+
t,
|
| 577 |
+
clip_denoised=clip_denoised,
|
| 578 |
+
denoised_fn=denoised_fn,
|
| 579 |
+
model_kwargs=model_kwargs,
|
| 580 |
+
)
|
| 581 |
+
if cond_fn is not None:
|
| 582 |
+
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
| 583 |
+
# Usually our model outputs epsilon, but we re-derive it
|
| 584 |
+
# in case we used x_start or x_prev prediction.
|
| 585 |
+
eps = (
|
| 586 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
|
| 587 |
+
- out["pred_xstart"]
|
| 588 |
+
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
|
| 589 |
+
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
|
| 590 |
+
|
| 591 |
+
# Equation 12. reversed
|
| 592 |
+
mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
|
| 593 |
+
|
| 594 |
+
return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
|
| 595 |
+
|
| 596 |
+
def ddim_sample_loop(
|
| 597 |
+
self,
|
| 598 |
+
model,
|
| 599 |
+
shape,
|
| 600 |
+
noise=None,
|
| 601 |
+
clip_denoised=True,
|
| 602 |
+
denoised_fn=None,
|
| 603 |
+
cond_fn=None,
|
| 604 |
+
model_kwargs=None,
|
| 605 |
+
device=None,
|
| 606 |
+
progress=False,
|
| 607 |
+
eta=0.0,
|
| 608 |
+
):
|
| 609 |
+
"""
|
| 610 |
+
Generate samples from the model using DDIM.
|
| 611 |
+
Same usage as p_sample_loop().
|
| 612 |
+
"""
|
| 613 |
+
final = None
|
| 614 |
+
for sample in self.ddim_sample_loop_progressive(
|
| 615 |
+
model,
|
| 616 |
+
shape,
|
| 617 |
+
noise=noise,
|
| 618 |
+
clip_denoised=clip_denoised,
|
| 619 |
+
denoised_fn=denoised_fn,
|
| 620 |
+
cond_fn=cond_fn,
|
| 621 |
+
model_kwargs=model_kwargs,
|
| 622 |
+
device=device,
|
| 623 |
+
progress=progress,
|
| 624 |
+
eta=eta,
|
| 625 |
+
):
|
| 626 |
+
final = sample
|
| 627 |
+
return final["sample"]
|
| 628 |
+
|
| 629 |
+
def ddim_sample_loop_progressive(
|
| 630 |
+
self,
|
| 631 |
+
model,
|
| 632 |
+
shape,
|
| 633 |
+
noise=None,
|
| 634 |
+
clip_denoised=True,
|
| 635 |
+
denoised_fn=None,
|
| 636 |
+
cond_fn=None,
|
| 637 |
+
model_kwargs=None,
|
| 638 |
+
device=None,
|
| 639 |
+
progress=False,
|
| 640 |
+
eta=0.0,
|
| 641 |
+
):
|
| 642 |
+
"""
|
| 643 |
+
Use DDIM to sample from the model and yield intermediate samples from
|
| 644 |
+
each timestep of DDIM.
|
| 645 |
+
Same usage as p_sample_loop_progressive().
|
| 646 |
+
"""
|
| 647 |
+
if device is None:
|
| 648 |
+
device = next(model.parameters()).device
|
| 649 |
+
assert isinstance(shape, (tuple, list))
|
| 650 |
+
if noise is not None:
|
| 651 |
+
img = noise
|
| 652 |
+
else:
|
| 653 |
+
img = th.randn(*shape, device=device)
|
| 654 |
+
indices = list(range(self.num_timesteps))[::-1]
|
| 655 |
+
|
| 656 |
+
if progress:
|
| 657 |
+
# Lazy import so that we don't depend on tqdm.
|
| 658 |
+
from tqdm.auto import tqdm
|
| 659 |
+
|
| 660 |
+
indices = tqdm(indices)
|
| 661 |
+
|
| 662 |
+
for i in indices:
|
| 663 |
+
t = th.tensor([i] * shape[0], device=device)
|
| 664 |
+
with th.no_grad():
|
| 665 |
+
out = self.ddim_sample(
|
| 666 |
+
model,
|
| 667 |
+
img,
|
| 668 |
+
t,
|
| 669 |
+
clip_denoised=clip_denoised,
|
| 670 |
+
denoised_fn=denoised_fn,
|
| 671 |
+
cond_fn=cond_fn,
|
| 672 |
+
model_kwargs=model_kwargs,
|
| 673 |
+
eta=eta,
|
| 674 |
+
)
|
| 675 |
+
yield out
|
| 676 |
+
img = out["sample"]
|
| 677 |
+
|
| 678 |
+
def _vb_terms_bpd(
|
| 679 |
+
self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
|
| 680 |
+
):
|
| 681 |
+
"""
|
| 682 |
+
Get a term for the variational lower-bound.
|
| 683 |
+
The resulting units are bits (rather than nats, as one might expect).
|
| 684 |
+
This allows for comparison to other papers.
|
| 685 |
+
:return: a dict with the following keys:
|
| 686 |
+
- 'output': a shape [N] tensor of NLLs or KLs.
|
| 687 |
+
- 'pred_xstart': the x_0 predictions.
|
| 688 |
+
"""
|
| 689 |
+
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
|
| 690 |
+
x_start=x_start, x_t=x_t, t=t
|
| 691 |
+
)
|
| 692 |
+
out = self.p_mean_variance(
|
| 693 |
+
model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
|
| 694 |
+
)
|
| 695 |
+
kl = normal_kl(
|
| 696 |
+
true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
|
| 697 |
+
)
|
| 698 |
+
kl = mean_flat(kl) / np.log(2.0)
|
| 699 |
+
|
| 700 |
+
decoder_nll = -discretized_gaussian_log_likelihood(
|
| 701 |
+
x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
|
| 702 |
+
)
|
| 703 |
+
assert decoder_nll.shape == x_start.shape
|
| 704 |
+
decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
|
| 705 |
+
|
| 706 |
+
# At the first timestep return the decoder NLL,
|
| 707 |
+
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
|
| 708 |
+
output = th.where((t == 0), decoder_nll, kl)
|
| 709 |
+
return {"output": output, "pred_xstart": out["pred_xstart"]}
|
| 710 |
+
|
| 711 |
+
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
|
| 712 |
+
"""
|
| 713 |
+
Compute training losses for a single timestep.
|
| 714 |
+
:param model: the model to evaluate loss on.
|
| 715 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
| 716 |
+
:param t: a batch of timestep indices.
|
| 717 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
| 718 |
+
pass to the model. This can be used for conditioning.
|
| 719 |
+
:param noise: if specified, the specific Gaussian noise to try to remove.
|
| 720 |
+
:return: a dict with the key "loss" containing a tensor of shape [N].
|
| 721 |
+
Some mean or variance settings may also have other keys.
|
| 722 |
+
"""
|
| 723 |
+
if model_kwargs is None:
|
| 724 |
+
model_kwargs = {}
|
| 725 |
+
if noise is None:
|
| 726 |
+
noise = th.randn_like(x_start)
|
| 727 |
+
x_t = self.q_sample(x_start, t, noise=noise)
|
| 728 |
+
|
| 729 |
+
terms = {}
|
| 730 |
+
|
| 731 |
+
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
|
| 732 |
+
terms["loss"] = self._vb_terms_bpd(
|
| 733 |
+
model=model,
|
| 734 |
+
x_start=x_start,
|
| 735 |
+
x_t=x_t,
|
| 736 |
+
t=t,
|
| 737 |
+
clip_denoised=False,
|
| 738 |
+
model_kwargs=model_kwargs,
|
| 739 |
+
)["output"]
|
| 740 |
+
if self.loss_type == LossType.RESCALED_KL:
|
| 741 |
+
terms["loss"] *= self.num_timesteps
|
| 742 |
+
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
|
| 743 |
+
model_output = model(x_t, t, **model_kwargs)
|
| 744 |
+
|
| 745 |
+
if self.model_var_type in [
|
| 746 |
+
ModelVarType.LEARNED,
|
| 747 |
+
ModelVarType.LEARNED_RANGE,
|
| 748 |
+
]:
|
| 749 |
+
B, C = x_t.shape[:2]
|
| 750 |
+
assert model_output.shape == (B, C * 2, *x_t.shape[2:])
|
| 751 |
+
model_output, model_var_values = th.split(model_output, C, dim=1)
|
| 752 |
+
# Learn the variance using the variational bound, but don't let
|
| 753 |
+
# it affect our mean prediction.
|
| 754 |
+
frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
|
| 755 |
+
terms["vb"] = self._vb_terms_bpd(
|
| 756 |
+
model=lambda *args, r=frozen_out: r,
|
| 757 |
+
x_start=x_start,
|
| 758 |
+
x_t=x_t,
|
| 759 |
+
t=t,
|
| 760 |
+
clip_denoised=False,
|
| 761 |
+
)["output"]
|
| 762 |
+
if self.loss_type == LossType.RESCALED_MSE:
|
| 763 |
+
# Divide by 1000 for equivalence with initial implementation.
|
| 764 |
+
# Without a factor of 1/1000, the VB term hurts the MSE term.
|
| 765 |
+
terms["vb"] *= self.num_timesteps / 1000.0
|
| 766 |
+
|
| 767 |
+
target = {
|
| 768 |
+
ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
|
| 769 |
+
x_start=x_start, x_t=x_t, t=t
|
| 770 |
+
)[0],
|
| 771 |
+
ModelMeanType.START_X: x_start,
|
| 772 |
+
ModelMeanType.EPSILON: noise,
|
| 773 |
+
}[self.model_mean_type]
|
| 774 |
+
assert model_output.shape == target.shape == x_start.shape
|
| 775 |
+
|
| 776 |
+
terms["mse"] = mean_flat((target - model_output) ** 2)
|
| 777 |
+
if "vb" in terms:
|
| 778 |
+
terms["loss"] = terms["mse"] + terms["vb"]
|
| 779 |
+
else:
|
| 780 |
+
terms["loss"] = terms["mse"]
|
| 781 |
+
else:
|
| 782 |
+
raise NotImplementedError(self.loss_type)
|
| 783 |
+
|
| 784 |
+
return terms
|
| 785 |
+
|
| 786 |
+
def _prior_bpd(self, x_start):
|
| 787 |
+
"""
|
| 788 |
+
Get the prior KL term for the variational lower-bound, measured in
|
| 789 |
+
bits-per-dim.
|
| 790 |
+
This term can't be optimized, as it only depends on the encoder.
|
| 791 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
| 792 |
+
:return: a batch of [N] KL values (in bits), one per batch element.
|
| 793 |
+
"""
|
| 794 |
+
batch_size = x_start.shape[0]
|
| 795 |
+
t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
|
| 796 |
+
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
|
| 797 |
+
kl_prior = normal_kl(
|
| 798 |
+
mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
|
| 799 |
+
)
|
| 800 |
+
return mean_flat(kl_prior) / np.log(2.0)
|
| 801 |
+
|
| 802 |
+
def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
|
| 803 |
+
"""
|
| 804 |
+
Compute the entire variational lower-bound, measured in bits-per-dim,
|
| 805 |
+
as well as other related quantities.
|
| 806 |
+
:param model: the model to evaluate loss on.
|
| 807 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
| 808 |
+
:param clip_denoised: if True, clip denoised samples.
|
| 809 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
| 810 |
+
pass to the model. This can be used for conditioning.
|
| 811 |
+
:return: a dict containing the following keys:
|
| 812 |
+
- total_bpd: the total variational lower-bound, per batch element.
|
| 813 |
+
- prior_bpd: the prior term in the lower-bound.
|
| 814 |
+
- vb: an [N x T] tensor of terms in the lower-bound.
|
| 815 |
+
- xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
|
| 816 |
+
- mse: an [N x T] tensor of epsilon MSEs for each timestep.
|
| 817 |
+
"""
|
| 818 |
+
device = x_start.device
|
| 819 |
+
batch_size = x_start.shape[0]
|
| 820 |
+
|
| 821 |
+
vb = []
|
| 822 |
+
xstart_mse = []
|
| 823 |
+
mse = []
|
| 824 |
+
for t in list(range(self.num_timesteps))[::-1]:
|
| 825 |
+
t_batch = th.tensor([t] * batch_size, device=device)
|
| 826 |
+
noise = th.randn_like(x_start)
|
| 827 |
+
x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
|
| 828 |
+
# Calculate VLB term at the current timestep
|
| 829 |
+
with th.no_grad():
|
| 830 |
+
out = self._vb_terms_bpd(
|
| 831 |
+
model,
|
| 832 |
+
x_start=x_start,
|
| 833 |
+
x_t=x_t,
|
| 834 |
+
t=t_batch,
|
| 835 |
+
clip_denoised=clip_denoised,
|
| 836 |
+
model_kwargs=model_kwargs,
|
| 837 |
+
)
|
| 838 |
+
vb.append(out["output"])
|
| 839 |
+
xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
|
| 840 |
+
eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
|
| 841 |
+
mse.append(mean_flat((eps - noise) ** 2))
|
| 842 |
+
|
| 843 |
+
vb = th.stack(vb, dim=1)
|
| 844 |
+
xstart_mse = th.stack(xstart_mse, dim=1)
|
| 845 |
+
mse = th.stack(mse, dim=1)
|
| 846 |
+
|
| 847 |
+
prior_bpd = self._prior_bpd(x_start)
|
| 848 |
+
total_bpd = vb.sum(dim=1) + prior_bpd
|
| 849 |
+
return {
|
| 850 |
+
"total_bpd": total_bpd,
|
| 851 |
+
"prior_bpd": prior_bpd,
|
| 852 |
+
"vb": vb,
|
| 853 |
+
"xstart_mse": xstart_mse,
|
| 854 |
+
"mse": mse,
|
| 855 |
+
}
|
| 856 |
+
|
| 857 |
+
|
| 858 |
+
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
| 859 |
+
"""
|
| 860 |
+
Extract values from a 1-D numpy array for a batch of indices.
|
| 861 |
+
:param arr: the 1-D numpy array.
|
| 862 |
+
:param timesteps: a tensor of indices into the array to extract.
|
| 863 |
+
:param broadcast_shape: a larger shape of K dimensions with the batch
|
| 864 |
+
dimension equal to the length of timesteps.
|
| 865 |
+
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
|
| 866 |
+
"""
|
| 867 |
+
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
|
| 868 |
+
while len(res.shape) < len(broadcast_shape):
|
| 869 |
+
res = res[..., None]
|
| 870 |
+
return res + th.zeros(broadcast_shape, device=timesteps.device)
|
diffusion/gaussian_diffusion_dual.py
ADDED
|
@@ -0,0 +1,975 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# Modified from OpenAI's diffusion repos
|
| 8 |
+
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
| 9 |
+
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
| 10 |
+
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
import math
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch as th
|
| 17 |
+
import enum
|
| 18 |
+
|
| 19 |
+
from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def mean_flat(tensor):
|
| 23 |
+
"""
|
| 24 |
+
Take the mean over all non-batch dimensions.
|
| 25 |
+
"""
|
| 26 |
+
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class ModelMeanType(enum.Enum):
|
| 30 |
+
"""
|
| 31 |
+
Which type of output the model predicts.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
|
| 35 |
+
START_X = enum.auto() # the model predicts x_0
|
| 36 |
+
EPSILON = enum.auto() # the model predicts epsilon
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class ModelVarType(enum.Enum):
|
| 40 |
+
"""
|
| 41 |
+
What is used as the model's output variance.
|
| 42 |
+
The LEARNED_RANGE option has been added to allow the model to predict
|
| 43 |
+
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
LEARNED = enum.auto()
|
| 47 |
+
FIXED_SMALL = enum.auto()
|
| 48 |
+
FIXED_LARGE = enum.auto()
|
| 49 |
+
LEARNED_RANGE = enum.auto()
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class LossType(enum.Enum):
|
| 53 |
+
MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
|
| 54 |
+
RESCALED_MSE = (
|
| 55 |
+
enum.auto()
|
| 56 |
+
) # use raw MSE loss (with RESCALED_KL when learning variances)
|
| 57 |
+
KL = enum.auto() # use the variational lower-bound
|
| 58 |
+
RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
|
| 59 |
+
|
| 60 |
+
def is_vb(self):
|
| 61 |
+
return self == LossType.KL or self == LossType.RESCALED_KL
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
|
| 65 |
+
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
|
| 66 |
+
warmup_time = int(num_diffusion_timesteps * warmup_frac)
|
| 67 |
+
betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
|
| 68 |
+
return betas
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
|
| 72 |
+
"""
|
| 73 |
+
This is the deprecated API for creating beta schedules.
|
| 74 |
+
See get_named_beta_schedule() for the new library of schedules.
|
| 75 |
+
"""
|
| 76 |
+
if beta_schedule == "quad":
|
| 77 |
+
betas = (
|
| 78 |
+
np.linspace(
|
| 79 |
+
beta_start ** 0.5,
|
| 80 |
+
beta_end ** 0.5,
|
| 81 |
+
num_diffusion_timesteps,
|
| 82 |
+
dtype=np.float64,
|
| 83 |
+
)
|
| 84 |
+
** 2
|
| 85 |
+
)
|
| 86 |
+
elif beta_schedule == "linear":
|
| 87 |
+
betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
|
| 88 |
+
elif beta_schedule == "warmup10":
|
| 89 |
+
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
|
| 90 |
+
elif beta_schedule == "warmup50":
|
| 91 |
+
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
|
| 92 |
+
elif beta_schedule == "const":
|
| 93 |
+
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
|
| 94 |
+
elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
|
| 95 |
+
betas = 1.0 / np.linspace(
|
| 96 |
+
num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
|
| 97 |
+
)
|
| 98 |
+
else:
|
| 99 |
+
raise NotImplementedError(beta_schedule)
|
| 100 |
+
assert betas.shape == (num_diffusion_timesteps,)
|
| 101 |
+
return betas
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
|
| 105 |
+
"""
|
| 106 |
+
Get a pre-defined beta schedule for the given name.
|
| 107 |
+
The beta schedule library consists of beta schedules which remain similar
|
| 108 |
+
in the limit of num_diffusion_timesteps.
|
| 109 |
+
Beta schedules may be added, but should not be removed or changed once
|
| 110 |
+
they are committed to maintain backwards compatibility.
|
| 111 |
+
"""
|
| 112 |
+
if schedule_name == "linear":
|
| 113 |
+
# Linear schedule from Ho et al, extended to work for any number of
|
| 114 |
+
# diffusion steps.
|
| 115 |
+
scale = 1000 / num_diffusion_timesteps
|
| 116 |
+
return get_beta_schedule(
|
| 117 |
+
"linear",
|
| 118 |
+
beta_start=scale * 0.0001,
|
| 119 |
+
beta_end=scale * 0.02,
|
| 120 |
+
num_diffusion_timesteps=num_diffusion_timesteps,
|
| 121 |
+
)
|
| 122 |
+
elif schedule_name == "squaredcos_cap_v2":
|
| 123 |
+
return betas_for_alpha_bar(
|
| 124 |
+
num_diffusion_timesteps,
|
| 125 |
+
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
| 126 |
+
)
|
| 127 |
+
else:
|
| 128 |
+
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
| 132 |
+
"""
|
| 133 |
+
Create a beta schedule that discretizes the given alpha_t_bar function,
|
| 134 |
+
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
| 135 |
+
:param num_diffusion_timesteps: the number of betas to produce.
|
| 136 |
+
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
| 137 |
+
produces the cumulative product of (1-beta) up to that
|
| 138 |
+
part of the diffusion process.
|
| 139 |
+
:param max_beta: the maximum beta to use; use values lower than 1 to
|
| 140 |
+
prevent singularities.
|
| 141 |
+
"""
|
| 142 |
+
betas = []
|
| 143 |
+
for i in range(num_diffusion_timesteps):
|
| 144 |
+
t1 = i / num_diffusion_timesteps
|
| 145 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
| 146 |
+
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
| 147 |
+
return np.array(betas)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class GaussianDiffusion:
|
| 151 |
+
"""
|
| 152 |
+
Utilities for training and sampling diffusion models.
|
| 153 |
+
Original ported from this codebase:
|
| 154 |
+
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
|
| 155 |
+
:param betas: a 1-D numpy array of betas for each diffusion timestep,
|
| 156 |
+
starting at T and going to 1.
|
| 157 |
+
"""
|
| 158 |
+
|
| 159 |
+
def __init__(
|
| 160 |
+
self,
|
| 161 |
+
*,
|
| 162 |
+
betas,
|
| 163 |
+
model_mean_type,
|
| 164 |
+
model_var_type,
|
| 165 |
+
loss_type
|
| 166 |
+
):
|
| 167 |
+
|
| 168 |
+
self.model_mean_type = model_mean_type
|
| 169 |
+
self.model_var_type = model_var_type
|
| 170 |
+
self.loss_type = loss_type
|
| 171 |
+
|
| 172 |
+
# Use float64 for accuracy.
|
| 173 |
+
betas = np.array(betas, dtype=np.float64)
|
| 174 |
+
self.betas = betas
|
| 175 |
+
assert len(betas.shape) == 1, "betas must be 1-D"
|
| 176 |
+
assert (betas > 0).all() and (betas <= 1).all()
|
| 177 |
+
|
| 178 |
+
self.num_timesteps = int(betas.shape[0])
|
| 179 |
+
|
| 180 |
+
alphas = 1.0 - betas
|
| 181 |
+
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
| 182 |
+
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
|
| 183 |
+
self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
|
| 184 |
+
assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
|
| 185 |
+
|
| 186 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
| 187 |
+
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
|
| 188 |
+
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
|
| 189 |
+
self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
|
| 190 |
+
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
|
| 191 |
+
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
|
| 192 |
+
|
| 193 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
| 194 |
+
self.posterior_variance = (
|
| 195 |
+
betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
| 196 |
+
)
|
| 197 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
| 198 |
+
self.posterior_log_variance_clipped = np.log(
|
| 199 |
+
np.append(self.posterior_variance[1], self.posterior_variance[1:])
|
| 200 |
+
) if len(self.posterior_variance) > 1 else np.array([])
|
| 201 |
+
|
| 202 |
+
self.posterior_mean_coef1 = (
|
| 203 |
+
betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
| 204 |
+
)
|
| 205 |
+
self.posterior_mean_coef2 = (
|
| 206 |
+
(1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
def q_mean_variance(self, x_start, t):
|
| 210 |
+
"""
|
| 211 |
+
Get the distribution q(x_t | x_0).
|
| 212 |
+
:param x_start: the [N x C x ...] tensor of noiseless inputs.
|
| 213 |
+
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
| 214 |
+
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
|
| 215 |
+
"""
|
| 216 |
+
mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
| 217 |
+
variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
|
| 218 |
+
log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
| 219 |
+
return mean, variance, log_variance
|
| 220 |
+
|
| 221 |
+
def q_sample(self, x_start, t, noise=None):
|
| 222 |
+
"""
|
| 223 |
+
Diffuse the data for a given number of diffusion steps.
|
| 224 |
+
In other words, sample from q(x_t | x_0).
|
| 225 |
+
:param x_start: the initial data batch.
|
| 226 |
+
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
| 227 |
+
:param noise: if specified, the split-out normal noise.
|
| 228 |
+
:return: A noisy version of x_start.
|
| 229 |
+
"""
|
| 230 |
+
if noise is None:
|
| 231 |
+
noise = th.randn_like(x_start)
|
| 232 |
+
assert noise.shape == x_start.shape
|
| 233 |
+
return (
|
| 234 |
+
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
| 235 |
+
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
def q_posterior_mean_variance(self, x_start, x_t, t):
|
| 239 |
+
"""
|
| 240 |
+
Compute the mean and variance of the diffusion posterior:
|
| 241 |
+
q(x_{t-1} | x_t, x_0)
|
| 242 |
+
"""
|
| 243 |
+
assert x_start.shape == x_t.shape
|
| 244 |
+
posterior_mean = (
|
| 245 |
+
_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
|
| 246 |
+
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
| 247 |
+
)
|
| 248 |
+
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
|
| 249 |
+
posterior_log_variance_clipped = _extract_into_tensor(
|
| 250 |
+
self.posterior_log_variance_clipped, t, x_t.shape
|
| 251 |
+
)
|
| 252 |
+
assert (
|
| 253 |
+
posterior_mean.shape[0]
|
| 254 |
+
== posterior_variance.shape[0]
|
| 255 |
+
== posterior_log_variance_clipped.shape[0]
|
| 256 |
+
== x_start.shape[0]
|
| 257 |
+
)
|
| 258 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
| 259 |
+
|
| 260 |
+
def q_posterior_mean_variance_dual(self, x_start, x_t, t):
|
| 261 |
+
"""
|
| 262 |
+
Compute the posterior mean and variance for each modality:
|
| 263 |
+
q(x_{t-1} | x_t, x_0)
|
| 264 |
+
Inputs:
|
| 265 |
+
x_start: tuple (x_v_start, x_a_start)
|
| 266 |
+
x_t: tuple (x_v_t, x_a_t)
|
| 267 |
+
t: Tensor of shape [B]
|
| 268 |
+
Outputs:
|
| 269 |
+
posterior_mean: (mean_v, mean_a)
|
| 270 |
+
posterior_variance: (var_v, var_a)
|
| 271 |
+
posterior_log_variance_clipped: (logvar_v, logvar_a)
|
| 272 |
+
"""
|
| 273 |
+
x_v_start, x_a_start = x_start
|
| 274 |
+
x_v_t, x_a_t = x_t
|
| 275 |
+
|
| 276 |
+
def single_modality_q(x_start_i, x_t_i):
|
| 277 |
+
assert x_start_i.shape == x_t_i.shape
|
| 278 |
+
posterior_mean = (
|
| 279 |
+
_extract_into_tensor(self.posterior_mean_coef1, t, x_t_i.shape) * x_start_i
|
| 280 |
+
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t_i.shape) * x_t_i
|
| 281 |
+
)
|
| 282 |
+
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t_i.shape)
|
| 283 |
+
posterior_log_variance_clipped = _extract_into_tensor(
|
| 284 |
+
self.posterior_log_variance_clipped, t, x_t_i.shape
|
| 285 |
+
)
|
| 286 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
| 287 |
+
|
| 288 |
+
mean_v, var_v, logvar_v = single_modality_q(x_v_start, x_v_t)
|
| 289 |
+
mean_a, var_a, logvar_a = single_modality_q(x_a_start, x_a_t)
|
| 290 |
+
|
| 291 |
+
return (mean_v, mean_a), (var_v, var_a), (logvar_v, logvar_a)
|
| 292 |
+
|
| 293 |
+
def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
|
| 294 |
+
"""
|
| 295 |
+
Dual-modality version.
|
| 296 |
+
x: (x_v_t, x_a_t)
|
| 297 |
+
model: takes (x_v_t, x_a_t, t, **model_kwargs)
|
| 298 |
+
returns: out_v, out_a: dicts with 'mean', 'variance', 'log_variance', 'pred_xstart'
|
| 299 |
+
"""
|
| 300 |
+
if model_kwargs is None:
|
| 301 |
+
model_kwargs = {}
|
| 302 |
+
|
| 303 |
+
x_v, x_a = x
|
| 304 |
+
B, C_v = x_v.shape[:2]
|
| 305 |
+
B, C_a = x_a.shape[:2]
|
| 306 |
+
assert t.shape == (B,)
|
| 307 |
+
|
| 308 |
+
# Call model once to get both outputs
|
| 309 |
+
model_output_v, model_output_a = model(x_v, x_a, t, **model_kwargs)
|
| 310 |
+
|
| 311 |
+
# Helper function for one modality
|
| 312 |
+
def process_modality(x_t, model_output, C):
|
| 313 |
+
if isinstance(model_output, tuple):
|
| 314 |
+
model_output, _ = model_output # drop extra output if any
|
| 315 |
+
|
| 316 |
+
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
|
| 317 |
+
assert model_output.shape == (B, C * 2, *x_t.shape[2:])
|
| 318 |
+
model_output, model_var_values = th.split(model_output, C, dim=1)
|
| 319 |
+
min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
|
| 320 |
+
max_log = _extract_into_tensor(np.log(self.betas), t, x_t.shape)
|
| 321 |
+
frac = (model_var_values + 1) / 2
|
| 322 |
+
model_log_variance = frac * max_log + (1 - frac) * min_log
|
| 323 |
+
model_variance = th.exp(model_log_variance)
|
| 324 |
+
else:
|
| 325 |
+
model_variance_, model_log_variance_ = {
|
| 326 |
+
ModelVarType.FIXED_LARGE: (
|
| 327 |
+
np.append(self.posterior_variance[1], self.betas[1:]),
|
| 328 |
+
np.log(np.append(self.posterior_variance[1], self.betas[1:])),
|
| 329 |
+
),
|
| 330 |
+
ModelVarType.FIXED_SMALL: (
|
| 331 |
+
self.posterior_variance,
|
| 332 |
+
self.posterior_log_variance_clipped,
|
| 333 |
+
),
|
| 334 |
+
}[self.model_var_type]
|
| 335 |
+
model_variance = _extract_into_tensor(model_variance_, t, x_t.shape)
|
| 336 |
+
model_log_variance = _extract_into_tensor(model_log_variance_, t, x_t.shape)
|
| 337 |
+
|
| 338 |
+
def process_xstart(x):
|
| 339 |
+
if denoised_fn is not None:
|
| 340 |
+
x = denoised_fn(x)
|
| 341 |
+
if clip_denoised:
|
| 342 |
+
x = x.clamp(-1, 1)
|
| 343 |
+
return x
|
| 344 |
+
|
| 345 |
+
if self.model_mean_type == ModelMeanType.START_X:
|
| 346 |
+
pred_xstart = process_xstart(model_output)
|
| 347 |
+
else:
|
| 348 |
+
pred_xstart = process_xstart(
|
| 349 |
+
self._predict_xstart_from_eps(x_t=x_t, t=t, eps=model_output)
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
model_mean, _, _ = self.q_posterior_mean_variance(
|
| 353 |
+
x_start=pred_xstart, x_t=x_t, t=t
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
return {
|
| 357 |
+
"mean": model_mean,
|
| 358 |
+
"variance": model_variance,
|
| 359 |
+
"log_variance": model_log_variance,
|
| 360 |
+
"pred_xstart": pred_xstart,
|
| 361 |
+
}
|
| 362 |
+
|
| 363 |
+
out_v = process_modality(x_v, model_output_v, C_v)
|
| 364 |
+
out_a = process_modality(x_a, model_output_a, C_a)
|
| 365 |
+
|
| 366 |
+
return out_v, out_a
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def _predict_xstart_from_eps(self, x_t, t, eps):
|
| 370 |
+
assert x_t.shape == eps.shape
|
| 371 |
+
return (
|
| 372 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
| 373 |
+
- _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
|
| 377 |
+
return (
|
| 378 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
|
| 379 |
+
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
| 380 |
+
|
| 381 |
+
def condition_mean(
|
| 382 |
+
self,
|
| 383 |
+
cond_fn, # callable(x_v, x_a, t, **model_kwargs) -> (grad_v, grad_a)
|
| 384 |
+
p_mean_var_v, # dict for video: contains 'mean', 'variance'
|
| 385 |
+
p_mean_var_a, # dict for audio
|
| 386 |
+
x_v, x_a, # x_t for video/audio
|
| 387 |
+
t,
|
| 388 |
+
model_kwargs=None,
|
| 389 |
+
):
|
| 390 |
+
"""
|
| 391 |
+
Compute conditional mean separately for each modality:
|
| 392 |
+
new_mean = mean + variance * ∇ log p(y|x_t)
|
| 393 |
+
"""
|
| 394 |
+
if model_kwargs is None:
|
| 395 |
+
model_kwargs = {}
|
| 396 |
+
|
| 397 |
+
# cond_fn must return (grad_v, grad_a)
|
| 398 |
+
grad_v, grad_a = cond_fn(x_v, x_a, t, **model_kwargs)
|
| 399 |
+
|
| 400 |
+
new_mean_v = p_mean_var_v["mean"].float() + p_mean_var_v["variance"] * grad_v.float()
|
| 401 |
+
new_mean_a = p_mean_var_a["mean"].float() + p_mean_var_a["variance"] * grad_a.float()
|
| 402 |
+
|
| 403 |
+
return new_mean_v, new_mean_a
|
| 404 |
+
|
| 405 |
+
def p_sample(
|
| 406 |
+
self,
|
| 407 |
+
model,
|
| 408 |
+
x_v,
|
| 409 |
+
x_a,
|
| 410 |
+
t,
|
| 411 |
+
clip_denoised=True,
|
| 412 |
+
denoised_fn=None,
|
| 413 |
+
cond_fn=None,
|
| 414 |
+
model_kwargs=None,
|
| 415 |
+
):
|
| 416 |
+
"""
|
| 417 |
+
Sample x_{t-1} from the model at the given timestep.
|
| 418 |
+
:param model: the model to sample from.
|
| 419 |
+
:param x: the current tensor at x_{t-1}.
|
| 420 |
+
:param t: the value of t, starting at 0 for the first diffusion step.
|
| 421 |
+
:param clip_denoised: if True, clip the x_start prediction to [-1, 1].
|
| 422 |
+
:param denoised_fn: if not None, a function which applies to the
|
| 423 |
+
x_start prediction before it is used to sample.
|
| 424 |
+
:param cond_fn: if not None, this is a gradient function that acts
|
| 425 |
+
similarly to the model.
|
| 426 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
| 427 |
+
pass to the model. This can be used for conditioning.
|
| 428 |
+
:return: a dict containing the following keys:
|
| 429 |
+
- 'sample': a random sample from the model.
|
| 430 |
+
- 'pred_xstart': a prediction of x_0.
|
| 431 |
+
"""
|
| 432 |
+
# out = self.p_mean_variance(
|
| 433 |
+
# model,
|
| 434 |
+
# x,
|
| 435 |
+
# t,
|
| 436 |
+
# clip_denoised=clip_denoised,
|
| 437 |
+
# denoised_fn=denoised_fn,
|
| 438 |
+
# model_kwargs=model_kwargs,
|
| 439 |
+
# )
|
| 440 |
+
out_v, out_a = self.p_mean_variance(
|
| 441 |
+
model=model,
|
| 442 |
+
x=(x_v, x_a),
|
| 443 |
+
t=t,
|
| 444 |
+
clip_denoised=clip_denoised,
|
| 445 |
+
denoised_fn=denoised_fn,
|
| 446 |
+
model_kwargs=model_kwargs,
|
| 447 |
+
)
|
| 448 |
+
noise_v = th.randn_like(x_v)
|
| 449 |
+
noise_a = th.randn_like(x_a)
|
| 450 |
+
|
| 451 |
+
nonzero_mask_v = (
|
| 452 |
+
(t != 0).float().view(-1, *([1] * (len(x_v.shape) - 1)))
|
| 453 |
+
) # no noise when t == 0
|
| 454 |
+
nonzero_mask_a = (
|
| 455 |
+
(t != 0).float().view(-1, *([1] * (len(x_a.shape) - 1)))
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
if cond_fn is not None:
|
| 459 |
+
|
| 460 |
+
out_v["mean"], out_a["mean"] = condition_mean(cond_fn, out_v, out_a, x_v, x_a, t, model_kwargs=model_kwargs)
|
| 461 |
+
sample_v = out_v["mean"] + nonzero_mask_v * th.exp(0.5 * out_v["log_variance"]) * noise_v
|
| 462 |
+
sample_a = out_a["mean"] + nonzero_mask_a * th.exp(0.5 * out_a["log_variance"]) * noise_a
|
| 463 |
+
return {"sample_v": sample_v, "sample_a": sample_a, "pred_xstart_v": out_v["pred_xstart"], "pred_xstart_a": out_a["pred_xstart"]}
|
| 464 |
+
|
| 465 |
+
def p_sample_loop(
|
| 466 |
+
self,
|
| 467 |
+
model,
|
| 468 |
+
shape_v,
|
| 469 |
+
shape_a,
|
| 470 |
+
noise_v=None,
|
| 471 |
+
noise_a=None,
|
| 472 |
+
clip_denoised=True,
|
| 473 |
+
denoised_fn=None,
|
| 474 |
+
cond_fn=None,
|
| 475 |
+
model_kwargs=None,
|
| 476 |
+
device=None,
|
| 477 |
+
progress=False,
|
| 478 |
+
):
|
| 479 |
+
"""
|
| 480 |
+
Generate samples from the model.
|
| 481 |
+
:param model: the model module.
|
| 482 |
+
:param shape: the shape of the samples, (N, C, H, W).
|
| 483 |
+
:param noise: if specified, the noise from the encoder to sample.
|
| 484 |
+
Should be of the same shape as `shape`.
|
| 485 |
+
:param clip_denoised: if True, clip x_start predictions to [-1, 1].
|
| 486 |
+
:param denoised_fn: if not None, a function which applies to the
|
| 487 |
+
x_start prediction before it is used to sample.
|
| 488 |
+
:param cond_fn: if not None, this is a gradient function that acts
|
| 489 |
+
similarly to the model.
|
| 490 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
| 491 |
+
pass to the model. This can be used for conditioning.
|
| 492 |
+
:param device: if specified, the device to create the samples on.
|
| 493 |
+
If not specified, use a model parameter's device.
|
| 494 |
+
:param progress: if True, show a tqdm progress bar.
|
| 495 |
+
:return: a non-differentiable batch of samples.
|
| 496 |
+
"""
|
| 497 |
+
final = None
|
| 498 |
+
for sample in self.p_sample_loop_progressive(
|
| 499 |
+
model,
|
| 500 |
+
shape_v,
|
| 501 |
+
shape_a,
|
| 502 |
+
noise_v=noise_v,
|
| 503 |
+
noise_a=noise_a,
|
| 504 |
+
clip_denoised=clip_denoised,
|
| 505 |
+
denoised_fn=denoised_fn,
|
| 506 |
+
cond_fn=cond_fn,
|
| 507 |
+
model_kwargs=model_kwargs,
|
| 508 |
+
device=device,
|
| 509 |
+
progress=progress,
|
| 510 |
+
):
|
| 511 |
+
final = sample
|
| 512 |
+
return final["sample_v"], final["sample_a"]
|
| 513 |
+
|
| 514 |
+
def p_sample_loop_progressive(
|
| 515 |
+
self,
|
| 516 |
+
model,
|
| 517 |
+
shape_v,
|
| 518 |
+
shape_a,
|
| 519 |
+
noise_v=None,
|
| 520 |
+
noise_a=None,
|
| 521 |
+
clip_denoised=True,
|
| 522 |
+
denoised_fn=None,
|
| 523 |
+
cond_fn=None,
|
| 524 |
+
model_kwargs=None,
|
| 525 |
+
device=None,
|
| 526 |
+
progress=False,
|
| 527 |
+
):
|
| 528 |
+
"""
|
| 529 |
+
Generate samples from the model and yield intermediate samples from
|
| 530 |
+
each timestep of diffusion.
|
| 531 |
+
Arguments are the same as p_sample_loop().
|
| 532 |
+
Returns a generator over dicts, where each dict is the return value of
|
| 533 |
+
p_sample().
|
| 534 |
+
"""
|
| 535 |
+
if device is None:
|
| 536 |
+
device = next(model.parameters()).device
|
| 537 |
+
assert isinstance(shape_v, (tuple, list))
|
| 538 |
+
assert isinstance(shape_a, (tuple, list))
|
| 539 |
+
|
| 540 |
+
if noise_v is not None:
|
| 541 |
+
img = noise_v
|
| 542 |
+
else:
|
| 543 |
+
img = th.randn(*shape_v, device=device)
|
| 544 |
+
if noise_a is not None:
|
| 545 |
+
audio = noise_a
|
| 546 |
+
else:
|
| 547 |
+
audio = th.randn(*shape_a, device=device)
|
| 548 |
+
|
| 549 |
+
indices = list(range(self.num_timesteps))[::-1]
|
| 550 |
+
|
| 551 |
+
if progress:
|
| 552 |
+
# Lazy import so that we don't depend on tqdm.
|
| 553 |
+
from tqdm.auto import tqdm
|
| 554 |
+
|
| 555 |
+
indices = tqdm(indices)
|
| 556 |
+
|
| 557 |
+
for i in indices:
|
| 558 |
+
t = th.tensor([i] * shape_v[0], device=device)
|
| 559 |
+
with th.no_grad():
|
| 560 |
+
#{"sample_v": sample_v, "sample_a": sample_a, "pred_xstart_v": out_v["pred_xstart"], "pred_xstart_a": out_a["pred_xstart"]}
|
| 561 |
+
out = self.p_sample(
|
| 562 |
+
model,
|
| 563 |
+
img,
|
| 564 |
+
audio,
|
| 565 |
+
t,
|
| 566 |
+
clip_denoised=clip_denoised,
|
| 567 |
+
denoised_fn=denoised_fn,
|
| 568 |
+
cond_fn=cond_fn,
|
| 569 |
+
model_kwargs=model_kwargs,
|
| 570 |
+
)
|
| 571 |
+
yield out
|
| 572 |
+
img = out["sample_v"]
|
| 573 |
+
audio = out["sample_a"]
|
| 574 |
+
|
| 575 |
+
def ddim_sample(
|
| 576 |
+
self,
|
| 577 |
+
model,
|
| 578 |
+
x,
|
| 579 |
+
t,
|
| 580 |
+
clip_denoised=True,
|
| 581 |
+
denoised_fn=None,
|
| 582 |
+
cond_fn=None,
|
| 583 |
+
model_kwargs=None,
|
| 584 |
+
eta=0.0,
|
| 585 |
+
):
|
| 586 |
+
"""
|
| 587 |
+
Sample x_{t-1} from the model using DDIM.
|
| 588 |
+
Same usage as p_sample().
|
| 589 |
+
"""
|
| 590 |
+
out = self.p_mean_variance(
|
| 591 |
+
model,
|
| 592 |
+
x,
|
| 593 |
+
t,
|
| 594 |
+
clip_denoised=clip_denoised,
|
| 595 |
+
denoised_fn=denoised_fn,
|
| 596 |
+
model_kwargs=model_kwargs,
|
| 597 |
+
)
|
| 598 |
+
if cond_fn is not None:
|
| 599 |
+
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
| 600 |
+
|
| 601 |
+
# Usually our model outputs epsilon, but we re-derive it
|
| 602 |
+
# in case we used x_start or x_prev prediction.
|
| 603 |
+
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
|
| 604 |
+
|
| 605 |
+
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
| 606 |
+
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
|
| 607 |
+
sigma = (
|
| 608 |
+
eta
|
| 609 |
+
* th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
|
| 610 |
+
* th.sqrt(1 - alpha_bar / alpha_bar_prev)
|
| 611 |
+
)
|
| 612 |
+
# Equation 12.
|
| 613 |
+
noise = th.randn_like(x)
|
| 614 |
+
mean_pred = (
|
| 615 |
+
out["pred_xstart"] * th.sqrt(alpha_bar_prev)
|
| 616 |
+
+ th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
|
| 617 |
+
)
|
| 618 |
+
nonzero_mask = (
|
| 619 |
+
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
| 620 |
+
) # no noise when t == 0
|
| 621 |
+
sample = mean_pred + nonzero_mask * sigma * noise
|
| 622 |
+
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
| 623 |
+
|
| 624 |
+
def ddim_reverse_sample(
|
| 625 |
+
self,
|
| 626 |
+
model,
|
| 627 |
+
x,
|
| 628 |
+
t,
|
| 629 |
+
clip_denoised=True,
|
| 630 |
+
denoised_fn=None,
|
| 631 |
+
cond_fn=None,
|
| 632 |
+
model_kwargs=None,
|
| 633 |
+
eta=0.0,
|
| 634 |
+
):
|
| 635 |
+
"""
|
| 636 |
+
Sample x_{t+1} from the model using DDIM reverse ODE.
|
| 637 |
+
"""
|
| 638 |
+
assert eta == 0.0, "Reverse ODE only for deterministic path"
|
| 639 |
+
out = self.p_mean_variance(
|
| 640 |
+
model,
|
| 641 |
+
x,
|
| 642 |
+
t,
|
| 643 |
+
clip_denoised=clip_denoised,
|
| 644 |
+
denoised_fn=denoised_fn,
|
| 645 |
+
model_kwargs=model_kwargs,
|
| 646 |
+
)
|
| 647 |
+
if cond_fn is not None:
|
| 648 |
+
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
| 649 |
+
# Usually our model outputs epsilon, but we re-derive it
|
| 650 |
+
# in case we used x_start or x_prev prediction.
|
| 651 |
+
eps = (
|
| 652 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
|
| 653 |
+
- out["pred_xstart"]
|
| 654 |
+
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
|
| 655 |
+
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
|
| 656 |
+
|
| 657 |
+
# Equation 12. reversed
|
| 658 |
+
mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
|
| 659 |
+
|
| 660 |
+
return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
|
| 661 |
+
|
| 662 |
+
def ddim_sample_loop(
|
| 663 |
+
self,
|
| 664 |
+
model,
|
| 665 |
+
shape,
|
| 666 |
+
noise=None,
|
| 667 |
+
clip_denoised=True,
|
| 668 |
+
denoised_fn=None,
|
| 669 |
+
cond_fn=None,
|
| 670 |
+
model_kwargs=None,
|
| 671 |
+
device=None,
|
| 672 |
+
progress=False,
|
| 673 |
+
eta=0.0,
|
| 674 |
+
):
|
| 675 |
+
"""
|
| 676 |
+
Generate samples from the model using DDIM.
|
| 677 |
+
Same usage as p_sample_loop().
|
| 678 |
+
"""
|
| 679 |
+
final = None
|
| 680 |
+
for sample in self.ddim_sample_loop_progressive(
|
| 681 |
+
model,
|
| 682 |
+
shape,
|
| 683 |
+
noise=noise,
|
| 684 |
+
clip_denoised=clip_denoised,
|
| 685 |
+
denoised_fn=denoised_fn,
|
| 686 |
+
cond_fn=cond_fn,
|
| 687 |
+
model_kwargs=model_kwargs,
|
| 688 |
+
device=device,
|
| 689 |
+
progress=progress,
|
| 690 |
+
eta=eta,
|
| 691 |
+
):
|
| 692 |
+
final = sample
|
| 693 |
+
return final["sample"]
|
| 694 |
+
|
| 695 |
+
def ddim_sample_loop_progressive(
|
| 696 |
+
self,
|
| 697 |
+
model,
|
| 698 |
+
shape,
|
| 699 |
+
noise=None,
|
| 700 |
+
clip_denoised=True,
|
| 701 |
+
denoised_fn=None,
|
| 702 |
+
cond_fn=None,
|
| 703 |
+
model_kwargs=None,
|
| 704 |
+
device=None,
|
| 705 |
+
progress=False,
|
| 706 |
+
eta=0.0,
|
| 707 |
+
):
|
| 708 |
+
"""
|
| 709 |
+
Use DDIM to sample from the model and yield intermediate samples from
|
| 710 |
+
each timestep of DDIM.
|
| 711 |
+
Same usage as p_sample_loop_progressive().
|
| 712 |
+
"""
|
| 713 |
+
if device is None:
|
| 714 |
+
device = next(model.parameters()).device
|
| 715 |
+
assert isinstance(shape, (tuple, list))
|
| 716 |
+
if noise is not None:
|
| 717 |
+
img = noise
|
| 718 |
+
else:
|
| 719 |
+
img = th.randn(*shape, device=device)
|
| 720 |
+
indices = list(range(self.num_timesteps))[::-1]
|
| 721 |
+
|
| 722 |
+
if progress:
|
| 723 |
+
# Lazy import so that we don't depend on tqdm.
|
| 724 |
+
from tqdm.auto import tqdm
|
| 725 |
+
|
| 726 |
+
indices = tqdm(indices)
|
| 727 |
+
|
| 728 |
+
for i in indices:
|
| 729 |
+
t = th.tensor([i] * shape[0], device=device)
|
| 730 |
+
with th.no_grad():
|
| 731 |
+
out = self.ddim_sample(
|
| 732 |
+
model,
|
| 733 |
+
img,
|
| 734 |
+
t,
|
| 735 |
+
clip_denoised=clip_denoised,
|
| 736 |
+
denoised_fn=denoised_fn,
|
| 737 |
+
cond_fn=cond_fn,
|
| 738 |
+
model_kwargs=model_kwargs,
|
| 739 |
+
eta=eta,
|
| 740 |
+
)
|
| 741 |
+
yield out
|
| 742 |
+
img = out["sample"]
|
| 743 |
+
|
| 744 |
+
def _vb_terms_bpd(
|
| 745 |
+
self, model, x_v_start, x_a_start, x_v_t, x_a_t, t, clip_denoised=True, model_kwargs=None
|
| 746 |
+
):
|
| 747 |
+
"""
|
| 748 |
+
Dual-modality VB loss.
|
| 749 |
+
"""
|
| 750 |
+
|
| 751 |
+
# --- True posterior
|
| 752 |
+
(true_mean_v, true_mean_a), _, (logvar_v, logvar_a) = self.q_posterior_mean_variance_dual(
|
| 753 |
+
x_start=(x_v_start, x_a_start),
|
| 754 |
+
x_t=(x_v_t, x_a_t),
|
| 755 |
+
t=t,
|
| 756 |
+
)
|
| 757 |
+
|
| 758 |
+
# --- Model prediction
|
| 759 |
+
out_v, out_a = self.p_mean_variance(
|
| 760 |
+
model=model,
|
| 761 |
+
x=(x_v_t, x_a_t),
|
| 762 |
+
t=t,
|
| 763 |
+
clip_denoised=clip_denoised,
|
| 764 |
+
model_kwargs=model_kwargs,
|
| 765 |
+
)
|
| 766 |
+
|
| 767 |
+
# --- KL loss
|
| 768 |
+
kl_v = normal_kl(true_mean_v, logvar_v, out_v["mean"], out_v["log_variance"])
|
| 769 |
+
kl_a = normal_kl(true_mean_a, logvar_a, out_a["mean"], out_a["log_variance"])
|
| 770 |
+
kl_v = mean_flat(kl_v) / np.log(2.0)
|
| 771 |
+
kl_a = mean_flat(kl_a) / np.log(2.0)
|
| 772 |
+
|
| 773 |
+
# --- NLL loss (only at t=0)
|
| 774 |
+
decoder_nll_v = -discretized_gaussian_log_likelihood(
|
| 775 |
+
x_v_start, means=out_v["mean"], log_scales=0.5 * out_v["log_variance"]
|
| 776 |
+
)
|
| 777 |
+
decoder_nll_v = mean_flat(decoder_nll_v) / np.log(2.0)
|
| 778 |
+
|
| 779 |
+
decoder_nll_a = -discretized_gaussian_log_likelihood(
|
| 780 |
+
x_a_start, means=out_a["mean"], log_scales=0.5 * out_a["log_variance"]
|
| 781 |
+
)
|
| 782 |
+
decoder_nll_a = mean_flat(decoder_nll_a) / np.log(2.0)
|
| 783 |
+
|
| 784 |
+
# --- Final VB loss
|
| 785 |
+
output_v = th.where((t == 0), decoder_nll_v, kl_v)
|
| 786 |
+
output_a = th.where((t == 0), decoder_nll_a, kl_a)
|
| 787 |
+
|
| 788 |
+
return {
|
| 789 |
+
"output_v": output_v,
|
| 790 |
+
"output_a": output_a,
|
| 791 |
+
"pred_xstart": (out_v["pred_xstart"], out_a["pred_xstart"]),
|
| 792 |
+
}
|
| 793 |
+
|
| 794 |
+
def training_losses(self, model, x_v_start, x_a_start, t, model_kwargs=None, noise_v=None, noise_a=None):
|
| 795 |
+
"""
|
| 796 |
+
Compute training losses for a single timestep.
|
| 797 |
+
:param model: the model to evaluate loss on.
|
| 798 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
| 799 |
+
:param t: a batch of timestep indices.
|
| 800 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
| 801 |
+
pass to the model. This can be used for conditioning.
|
| 802 |
+
:param noise: if specified, the specific Gaussian noise to try to remove.
|
| 803 |
+
:return: a dict with the key "loss" containing a tensor of shape [N].
|
| 804 |
+
Some mean or variance settings may also have other keys.
|
| 805 |
+
"""
|
| 806 |
+
if model_kwargs is None:
|
| 807 |
+
model_kwargs = {}
|
| 808 |
+
if noise_v is None:
|
| 809 |
+
noise_v = th.randn_like(x_v_start)
|
| 810 |
+
x_v_t = self.q_sample(x_v_start, t, noise=noise_v)
|
| 811 |
+
if noise_a is None:
|
| 812 |
+
noise_a = th.randn_like(x_a_start)
|
| 813 |
+
x_a_t = self.q_sample(x_a_start, t, noise=noise_a)
|
| 814 |
+
|
| 815 |
+
terms = {}
|
| 816 |
+
|
| 817 |
+
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
|
| 818 |
+
vb_terms = self._vb_terms_bpd(
|
| 819 |
+
model=model,
|
| 820 |
+
x_v_start=x_v_start,
|
| 821 |
+
x_a_start=x_a_start,
|
| 822 |
+
x_v_t=x_v_t,
|
| 823 |
+
x_a_t=x_a_t,
|
| 824 |
+
t=t,
|
| 825 |
+
clip_denoised=False,
|
| 826 |
+
model_kwargs=model_kwargs,
|
| 827 |
+
)
|
| 828 |
+
terms["vb_v"] = vb_terms["output_v"]
|
| 829 |
+
terms["vb_a"] = vb_terms["output_a"]
|
| 830 |
+
terms["loss"] = vb_terms["output_v"] + vb_terms["output_a"]
|
| 831 |
+
if self.loss_type == LossType.RESCALED_KL:
|
| 832 |
+
terms["loss"] *= self.num_timesteps
|
| 833 |
+
|
| 834 |
+
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
|
| 835 |
+
model_output_v, model_output_a = model(x_v_t, x_a_t, t, **model_kwargs)
|
| 836 |
+
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
|
| 837 |
+
B, C_v = x_v_t.shape[:2]
|
| 838 |
+
B, C_a = x_a_t.shape[:2]
|
| 839 |
+
|
| 840 |
+
model_output_v, model_var_v = th.split(model_output_v, C_v, dim=1)
|
| 841 |
+
model_output_a, model_var_a = th.split(model_output_a, C_a, dim=1)
|
| 842 |
+
|
| 843 |
+
frozen_out_v = th.cat([model_output_v.detach(), model_var_v], dim=1)
|
| 844 |
+
frozen_out_a = th.cat([model_output_a.detach(), model_var_a], dim=1)
|
| 845 |
+
|
| 846 |
+
frozen_model = lambda *args, **kwargs: (frozen_out_v, frozen_out_a)
|
| 847 |
+
|
| 848 |
+
vb_output = self._vb_terms_bpd(
|
| 849 |
+
model=frozen_model,
|
| 850 |
+
x_v_start=x_v_start,
|
| 851 |
+
x_a_start=x_a_start,
|
| 852 |
+
x_v_t=x_v_t,
|
| 853 |
+
x_a_t=x_a_t,
|
| 854 |
+
t=t,
|
| 855 |
+
clip_denoised=False,
|
| 856 |
+
)
|
| 857 |
+
|
| 858 |
+
terms["vb_v"] = vb_output["output_v"]
|
| 859 |
+
terms["vb_a"] = vb_output["output_a"]
|
| 860 |
+
|
| 861 |
+
# === MSE Loss ===
|
| 862 |
+
def process_mse(modality, x_start, x_t, model_output, noise):
|
| 863 |
+
target = {
|
| 864 |
+
ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance_dual(
|
| 865 |
+
x_start=(x_v_start, x_a_start),
|
| 866 |
+
x_t=(x_v_t, x_a_t),
|
| 867 |
+
t=t,
|
| 868 |
+
)[0][0 if modality == "v" else 1],
|
| 869 |
+
ModelMeanType.START_X: x_start,
|
| 870 |
+
ModelMeanType.EPSILON: noise,
|
| 871 |
+
}[self.model_mean_type]
|
| 872 |
+
|
| 873 |
+
assert model_output.shape == target.shape == x_start.shape
|
| 874 |
+
terms[f"mse_{modality}"] = mean_flat((target - model_output) ** 2)
|
| 875 |
+
|
| 876 |
+
process_mse("v", x_v_start, x_v_t, model_output_v, noise_v)
|
| 877 |
+
process_mse("a", x_a_start, x_a_t, model_output_a, noise_a)
|
| 878 |
+
|
| 879 |
+
if "vb_v" in terms and "vb_a" in terms:
|
| 880 |
+
terms["vb"] = terms["vb_v"] + terms["vb_a"]
|
| 881 |
+
if self.loss_type == LossType.RESCALED_MSE:
|
| 882 |
+
terms["vb"] *= self.num_timesteps / 1000.0
|
| 883 |
+
|
| 884 |
+
terms["loss"] = terms["mse_v"] + terms["mse_a"]
|
| 885 |
+
if "vb" in terms:
|
| 886 |
+
terms["loss"] += terms["vb"]
|
| 887 |
+
|
| 888 |
+
|
| 889 |
+
return terms
|
| 890 |
+
|
| 891 |
+
def _prior_bpd(self, x_start):
|
| 892 |
+
"""
|
| 893 |
+
Get the prior KL term for the variational lower-bound, measured in
|
| 894 |
+
bits-per-dim.
|
| 895 |
+
This term can't be optimized, as it only depends on the encoder.
|
| 896 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
| 897 |
+
:return: a batch of [N] KL values (in bits), one per batch element.
|
| 898 |
+
"""
|
| 899 |
+
batch_size = x_start.shape[0]
|
| 900 |
+
t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
|
| 901 |
+
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
|
| 902 |
+
kl_prior = normal_kl(
|
| 903 |
+
mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
|
| 904 |
+
)
|
| 905 |
+
return mean_flat(kl_prior) / np.log(2.0)
|
| 906 |
+
|
| 907 |
+
def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
|
| 908 |
+
"""
|
| 909 |
+
Compute the entire variational lower-bound, measured in bits-per-dim,
|
| 910 |
+
as well as other related quantities.
|
| 911 |
+
:param model: the model to evaluate loss on.
|
| 912 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
| 913 |
+
:param clip_denoised: if True, clip denoised samples.
|
| 914 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
| 915 |
+
pass to the model. This can be used for conditioning.
|
| 916 |
+
:return: a dict containing the following keys:
|
| 917 |
+
- total_bpd: the total variational lower-bound, per batch element.
|
| 918 |
+
- prior_bpd: the prior term in the lower-bound.
|
| 919 |
+
- vb: an [N x T] tensor of terms in the lower-bound.
|
| 920 |
+
- xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
|
| 921 |
+
- mse: an [N x T] tensor of epsilon MSEs for each timestep.
|
| 922 |
+
"""
|
| 923 |
+
device = x_start.device
|
| 924 |
+
batch_size = x_start.shape[0]
|
| 925 |
+
|
| 926 |
+
vb = []
|
| 927 |
+
xstart_mse = []
|
| 928 |
+
mse = []
|
| 929 |
+
for t in list(range(self.num_timesteps))[::-1]:
|
| 930 |
+
t_batch = th.tensor([t] * batch_size, device=device)
|
| 931 |
+
noise = th.randn_like(x_start)
|
| 932 |
+
x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
|
| 933 |
+
# Calculate VLB term at the current timestep
|
| 934 |
+
with th.no_grad():
|
| 935 |
+
out = self._vb_terms_bpd(
|
| 936 |
+
model,
|
| 937 |
+
x_start=x_start,
|
| 938 |
+
x_t=x_t,
|
| 939 |
+
t=t_batch,
|
| 940 |
+
clip_denoised=clip_denoised,
|
| 941 |
+
model_kwargs=model_kwargs,
|
| 942 |
+
)
|
| 943 |
+
vb.append(out["output"])
|
| 944 |
+
xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
|
| 945 |
+
eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
|
| 946 |
+
mse.append(mean_flat((eps - noise) ** 2))
|
| 947 |
+
|
| 948 |
+
vb = th.stack(vb, dim=1)
|
| 949 |
+
xstart_mse = th.stack(xstart_mse, dim=1)
|
| 950 |
+
mse = th.stack(mse, dim=1)
|
| 951 |
+
|
| 952 |
+
prior_bpd = self._prior_bpd(x_start)
|
| 953 |
+
total_bpd = vb.sum(dim=1) + prior_bpd
|
| 954 |
+
return {
|
| 955 |
+
"total_bpd": total_bpd,
|
| 956 |
+
"prior_bpd": prior_bpd,
|
| 957 |
+
"vb": vb,
|
| 958 |
+
"xstart_mse": xstart_mse,
|
| 959 |
+
"mse": mse,
|
| 960 |
+
}
|
| 961 |
+
|
| 962 |
+
|
| 963 |
+
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
| 964 |
+
"""
|
| 965 |
+
Extract values from a 1-D numpy array for a batch of indices.
|
| 966 |
+
:param arr: the 1-D numpy array.
|
| 967 |
+
:param timesteps: a tensor of indices into the array to extract.
|
| 968 |
+
:param broadcast_shape: a larger shape of K dimensions with the batch
|
| 969 |
+
dimension equal to the length of timesteps.
|
| 970 |
+
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
|
| 971 |
+
"""
|
| 972 |
+
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
|
| 973 |
+
while len(res.shape) < len(broadcast_shape):
|
| 974 |
+
res = res[..., None]
|
| 975 |
+
return res + th.zeros(broadcast_shape, device=timesteps.device)
|
diffusion/respace.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch as th
|
| 3 |
+
|
| 4 |
+
from .gaussian_diffusion import GaussianDiffusion
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def space_timesteps(num_timesteps, section_counts):
|
| 8 |
+
"""
|
| 9 |
+
Create a list of timesteps to use from an original diffusion process,
|
| 10 |
+
given the number of timesteps we want to take from equally-sized portions
|
| 11 |
+
of the original process.
|
| 12 |
+
For example, if there's 300 timesteps and the section counts are [10,15,20]
|
| 13 |
+
then the first 100 timesteps are strided to be 10 timesteps, the second 100
|
| 14 |
+
are strided to be 15 timesteps, and the final 100 are strided to be 20.
|
| 15 |
+
If the stride is a string starting with "ddim", then the fixed striding
|
| 16 |
+
from the DDIM paper is used, and only one section is allowed.
|
| 17 |
+
:param num_timesteps: the number of diffusion steps in the original
|
| 18 |
+
process to divide up.
|
| 19 |
+
:param section_counts: either a list of numbers, or a string containing
|
| 20 |
+
comma-separated numbers, indicating the step count
|
| 21 |
+
per section. As a special case, use "ddimN" where N
|
| 22 |
+
is a number of steps to use the striding from the
|
| 23 |
+
DDIM paper.
|
| 24 |
+
:return: a set of diffusion steps from the original process to use.
|
| 25 |
+
"""
|
| 26 |
+
if isinstance(section_counts, str):
|
| 27 |
+
if section_counts.startswith("ddim"):
|
| 28 |
+
desired_count = int(section_counts[len("ddim") :])
|
| 29 |
+
for i in range(1, num_timesteps):
|
| 30 |
+
if len(range(0, num_timesteps, i)) == desired_count:
|
| 31 |
+
return set(range(0, num_timesteps, i))
|
| 32 |
+
raise ValueError(
|
| 33 |
+
f"cannot create exactly {num_timesteps} steps with an integer stride"
|
| 34 |
+
)
|
| 35 |
+
section_counts = [int(x) for x in section_counts.split(",")]
|
| 36 |
+
size_per = num_timesteps // len(section_counts)
|
| 37 |
+
extra = num_timesteps % len(section_counts)
|
| 38 |
+
start_idx = 0
|
| 39 |
+
all_steps = []
|
| 40 |
+
for i, section_count in enumerate(section_counts):
|
| 41 |
+
size = size_per + (1 if i < extra else 0)
|
| 42 |
+
if size < section_count:
|
| 43 |
+
raise ValueError(
|
| 44 |
+
f"cannot divide section of {size} steps into {section_count}"
|
| 45 |
+
)
|
| 46 |
+
if section_count <= 1:
|
| 47 |
+
frac_stride = 1
|
| 48 |
+
else:
|
| 49 |
+
frac_stride = (size - 1) / (section_count - 1)
|
| 50 |
+
cur_idx = 0.0
|
| 51 |
+
taken_steps = []
|
| 52 |
+
for _ in range(section_count):
|
| 53 |
+
taken_steps.append(start_idx + round(cur_idx))
|
| 54 |
+
cur_idx += frac_stride
|
| 55 |
+
all_steps += taken_steps
|
| 56 |
+
start_idx += size
|
| 57 |
+
return set(all_steps)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class SpacedDiffusion(GaussianDiffusion):
|
| 61 |
+
"""
|
| 62 |
+
A diffusion process which can skip steps in a base diffusion process.
|
| 63 |
+
:param use_timesteps: a collection (sequence or set) of timesteps from the
|
| 64 |
+
original diffusion process to retain.
|
| 65 |
+
:param kwargs: the kwargs to create the base diffusion process.
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
def __init__(self, use_timesteps, dual, **kwargs):
|
| 69 |
+
self.use_timesteps = set(use_timesteps)
|
| 70 |
+
self.timestep_map = []
|
| 71 |
+
self.original_num_steps = len(kwargs["betas"])
|
| 72 |
+
self.dual = dual
|
| 73 |
+
|
| 74 |
+
base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
|
| 75 |
+
last_alpha_cumprod = 1.0
|
| 76 |
+
new_betas = []
|
| 77 |
+
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
|
| 78 |
+
if i in self.use_timesteps:
|
| 79 |
+
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
|
| 80 |
+
last_alpha_cumprod = alpha_cumprod
|
| 81 |
+
self.timestep_map.append(i)
|
| 82 |
+
kwargs["betas"] = np.array(new_betas)
|
| 83 |
+
super().__init__(**kwargs)
|
| 84 |
+
|
| 85 |
+
def p_mean_variance(
|
| 86 |
+
self, model, *args, **kwargs
|
| 87 |
+
): # pylint: disable=signature-differs
|
| 88 |
+
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
|
| 89 |
+
|
| 90 |
+
def training_losses(
|
| 91 |
+
self, model, *args, **kwargs
|
| 92 |
+
): # pylint: disable=signature-differs
|
| 93 |
+
return super().training_losses(self._wrap_model(model), *args, **kwargs)
|
| 94 |
+
|
| 95 |
+
def condition_mean(self, cond_fn, *args, **kwargs):
|
| 96 |
+
return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
|
| 97 |
+
|
| 98 |
+
def condition_score(self, cond_fn, *args, **kwargs):
|
| 99 |
+
return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
|
| 100 |
+
|
| 101 |
+
def _wrap_model(self, model):
|
| 102 |
+
if isinstance(model, _WrappedModel):
|
| 103 |
+
return model
|
| 104 |
+
return _WrappedModel(
|
| 105 |
+
model, self.timestep_map, self.original_num_steps, self.dual
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
def _scale_timesteps(self, t):
|
| 109 |
+
# Scaling is done by the wrapped model.
|
| 110 |
+
return t
|
| 111 |
+
|
| 112 |
+
class _WrappedModel:
|
| 113 |
+
def __init__(self, model, timestep_map, original_num_steps):
|
| 114 |
+
self.model = model
|
| 115 |
+
self.timestep_map = timestep_map
|
| 116 |
+
# self.rescale_timesteps = rescale_timesteps
|
| 117 |
+
self.original_num_steps = original_num_steps
|
| 118 |
+
|
| 119 |
+
def __call__(self, x, ts, **kwargs):
|
| 120 |
+
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
|
| 121 |
+
new_ts = map_tensor[ts]
|
| 122 |
+
# if self.rescale_timesteps:
|
| 123 |
+
# new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
|
| 124 |
+
return self.model(x, new_ts, **kwargs)
|
| 125 |
+
|
diffusion/respace_dual.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# Modified from OpenAI's diffusion repos
|
| 8 |
+
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
| 9 |
+
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
| 10 |
+
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch as th
|
| 14 |
+
|
| 15 |
+
from .gaussian_diffusion_dual import GaussianDiffusion
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def space_timesteps(num_timesteps, section_counts):
|
| 19 |
+
"""
|
| 20 |
+
Create a list of timesteps to use from an original diffusion process,
|
| 21 |
+
given the number of timesteps we want to take from equally-sized portions
|
| 22 |
+
of the original process.
|
| 23 |
+
For example, if there's 300 timesteps and the section counts are [10,15,20]
|
| 24 |
+
then the first 100 timesteps are strided to be 10 timesteps, the second 100
|
| 25 |
+
are strided to be 15 timesteps, and the final 100 are strided to be 20.
|
| 26 |
+
If the stride is a string starting with "ddim", then the fixed striding
|
| 27 |
+
from the DDIM paper is used, and only one section is allowed.
|
| 28 |
+
:param num_timesteps: the number of diffusion steps in the original
|
| 29 |
+
process to divide up.
|
| 30 |
+
:param section_counts: either a list of numbers, or a string containing
|
| 31 |
+
comma-separated numbers, indicating the step count
|
| 32 |
+
per section. As a special case, use "ddimN" where N
|
| 33 |
+
is a number of steps to use the striding from the
|
| 34 |
+
DDIM paper.
|
| 35 |
+
:return: a set of diffusion steps from the original process to use.
|
| 36 |
+
"""
|
| 37 |
+
if isinstance(section_counts, str):
|
| 38 |
+
if section_counts.startswith("ddim"):
|
| 39 |
+
desired_count = int(section_counts[len("ddim") :])
|
| 40 |
+
for i in range(1, num_timesteps):
|
| 41 |
+
if len(range(0, num_timesteps, i)) == desired_count:
|
| 42 |
+
return set(range(0, num_timesteps, i))
|
| 43 |
+
raise ValueError(
|
| 44 |
+
f"cannot create exactly {num_timesteps} steps with an integer stride"
|
| 45 |
+
)
|
| 46 |
+
section_counts = [int(x) for x in section_counts.split(",")]
|
| 47 |
+
size_per = num_timesteps // len(section_counts)
|
| 48 |
+
extra = num_timesteps % len(section_counts)
|
| 49 |
+
start_idx = 0
|
| 50 |
+
all_steps = []
|
| 51 |
+
for i, section_count in enumerate(section_counts):
|
| 52 |
+
size = size_per + (1 if i < extra else 0)
|
| 53 |
+
if size < section_count:
|
| 54 |
+
raise ValueError(
|
| 55 |
+
f"cannot divide section of {size} steps into {section_count}"
|
| 56 |
+
)
|
| 57 |
+
if section_count <= 1:
|
| 58 |
+
frac_stride = 1
|
| 59 |
+
else:
|
| 60 |
+
frac_stride = (size - 1) / (section_count - 1)
|
| 61 |
+
cur_idx = 0.0
|
| 62 |
+
taken_steps = []
|
| 63 |
+
for _ in range(section_count):
|
| 64 |
+
taken_steps.append(start_idx + round(cur_idx))
|
| 65 |
+
cur_idx += frac_stride
|
| 66 |
+
all_steps += taken_steps
|
| 67 |
+
start_idx += size
|
| 68 |
+
return set(all_steps)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class SpacedDiffusion(GaussianDiffusion):
|
| 72 |
+
"""
|
| 73 |
+
A diffusion process which can skip steps in a base diffusion process.
|
| 74 |
+
:param use_timesteps: a collection (sequence or set) of timesteps from the
|
| 75 |
+
original diffusion process to retain.
|
| 76 |
+
:param kwargs: the kwargs to create the base diffusion process.
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
def __init__(self, use_timesteps, **kwargs):
|
| 80 |
+
self.use_timesteps = set(use_timesteps)
|
| 81 |
+
self.timestep_map = []
|
| 82 |
+
self.original_num_steps = len(kwargs["betas"])
|
| 83 |
+
|
| 84 |
+
base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
|
| 85 |
+
last_alpha_cumprod = 1.0
|
| 86 |
+
new_betas = []
|
| 87 |
+
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
|
| 88 |
+
if i in self.use_timesteps:
|
| 89 |
+
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
|
| 90 |
+
last_alpha_cumprod = alpha_cumprod
|
| 91 |
+
self.timestep_map.append(i)
|
| 92 |
+
kwargs["betas"] = np.array(new_betas)
|
| 93 |
+
super().__init__(**kwargs)
|
| 94 |
+
|
| 95 |
+
def p_mean_variance(
|
| 96 |
+
self, model, *args, **kwargs
|
| 97 |
+
): # pylint: disable=signature-differs
|
| 98 |
+
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
|
| 99 |
+
|
| 100 |
+
def training_losses(
|
| 101 |
+
self, model, *args, **kwargs
|
| 102 |
+
): # pylint: disable=signature-differs
|
| 103 |
+
return super().training_losses(self._wrap_model(model), *args, **kwargs)
|
| 104 |
+
|
| 105 |
+
def condition_mean(self, cond_fn, *args, **kwargs):
|
| 106 |
+
return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
|
| 107 |
+
|
| 108 |
+
def condition_score(self, cond_fn, *args, **kwargs):
|
| 109 |
+
return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
|
| 110 |
+
|
| 111 |
+
def _wrap_model(self, model):
|
| 112 |
+
if isinstance(model, _WrappedModel):
|
| 113 |
+
return model
|
| 114 |
+
return _WrappedModel(
|
| 115 |
+
model, self.timestep_map, self.original_num_steps
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
def _scale_timesteps(self, t):
|
| 119 |
+
# Scaling is done by the wrapped model.
|
| 120 |
+
return t
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class _WrappedModel:
|
| 124 |
+
def __init__(self, model, timestep_map, original_num_steps):
|
| 125 |
+
self.model = model
|
| 126 |
+
self.timestep_map = timestep_map
|
| 127 |
+
# self.rescale_timesteps = rescale_timesteps
|
| 128 |
+
self.original_num_steps = original_num_steps
|
| 129 |
+
|
| 130 |
+
def __call__(self, x_v, x_a, ts, **kwargs):
|
| 131 |
+
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
|
| 132 |
+
new_ts = map_tensor[ts]
|
| 133 |
+
# if self.rescale_timesteps:
|
| 134 |
+
# new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
|
| 135 |
+
return self.model(x_v, x_a, new_ts, **kwargs)
|
diffusion/timestep_sampler.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch as th
|
| 5 |
+
import torch.distributed as dist
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def create_named_schedule_sampler(name, diffusion):
|
| 9 |
+
"""
|
| 10 |
+
Create a ScheduleSampler from a library of pre-defined samplers.
|
| 11 |
+
:param name: the name of the sampler.
|
| 12 |
+
:param diffusion: the diffusion object to sample for.
|
| 13 |
+
"""
|
| 14 |
+
if name == "uniform":
|
| 15 |
+
return UniformSampler(diffusion)
|
| 16 |
+
elif name == "loss-second-moment":
|
| 17 |
+
return LossSecondMomentResampler(diffusion)
|
| 18 |
+
else:
|
| 19 |
+
raise NotImplementedError(f"unknown schedule sampler: {name}")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ScheduleSampler(ABC):
|
| 23 |
+
"""
|
| 24 |
+
A distribution over timesteps in the diffusion process, intended to reduce
|
| 25 |
+
variance of the objective.
|
| 26 |
+
By default, samplers perform unbiased importance sampling, in which the
|
| 27 |
+
objective's mean is unchanged.
|
| 28 |
+
However, subclasses may override sample() to change how the resampled
|
| 29 |
+
terms are reweighted, allowing for actual changes in the objective.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
@abstractmethod
|
| 33 |
+
def weights(self):
|
| 34 |
+
"""
|
| 35 |
+
Get a numpy array of weights, one per diffusion step.
|
| 36 |
+
The weights needn't be normalized, but must be positive.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
def sample(self, batch_size, device):
|
| 40 |
+
"""
|
| 41 |
+
Importance-sample timesteps for a batch.
|
| 42 |
+
:param batch_size: the number of timesteps.
|
| 43 |
+
:param device: the torch device to save to.
|
| 44 |
+
:return: a tuple (timesteps, weights):
|
| 45 |
+
- timesteps: a tensor of timestep indices.
|
| 46 |
+
- weights: a tensor of weights to scale the resulting losses.
|
| 47 |
+
"""
|
| 48 |
+
w = self.weights()
|
| 49 |
+
p = w / np.sum(w)
|
| 50 |
+
indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
|
| 51 |
+
indices = th.from_numpy(indices_np).long().to(device)
|
| 52 |
+
weights_np = 1 / (len(p) * p[indices_np])
|
| 53 |
+
weights = th.from_numpy(weights_np).float().to(device)
|
| 54 |
+
return indices, weights
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class UniformSampler(ScheduleSampler):
|
| 58 |
+
def __init__(self, diffusion):
|
| 59 |
+
self.diffusion = diffusion
|
| 60 |
+
self._weights = np.ones([diffusion.num_timesteps])
|
| 61 |
+
|
| 62 |
+
def weights(self):
|
| 63 |
+
return self._weights
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class LossAwareSampler(ScheduleSampler):
|
| 67 |
+
def update_with_local_losses(self, local_ts, local_losses):
|
| 68 |
+
"""
|
| 69 |
+
Update the reweighting using losses from a model.
|
| 70 |
+
Call this method from each rank with a batch of timesteps and the
|
| 71 |
+
corresponding losses for each of those timesteps.
|
| 72 |
+
This method will perform synchronization to make sure all of the ranks
|
| 73 |
+
maintain the exact same reweighting.
|
| 74 |
+
:param local_ts: an integer Tensor of timesteps.
|
| 75 |
+
:param local_losses: a 1D Tensor of losses.
|
| 76 |
+
"""
|
| 77 |
+
batch_sizes = [
|
| 78 |
+
th.tensor([0], dtype=th.int32, device=local_ts.device)
|
| 79 |
+
for _ in range(dist.get_world_size())
|
| 80 |
+
]
|
| 81 |
+
dist.all_gather(
|
| 82 |
+
batch_sizes,
|
| 83 |
+
th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
# Pad all_gather batches to be the maximum batch size.
|
| 87 |
+
batch_sizes = [x.item() for x in batch_sizes]
|
| 88 |
+
max_bs = max(batch_sizes)
|
| 89 |
+
|
| 90 |
+
timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
|
| 91 |
+
loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
|
| 92 |
+
dist.all_gather(timestep_batches, local_ts)
|
| 93 |
+
dist.all_gather(loss_batches, local_losses)
|
| 94 |
+
timesteps = [
|
| 95 |
+
x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
|
| 96 |
+
]
|
| 97 |
+
losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
|
| 98 |
+
self.update_with_all_losses(timesteps, losses)
|
| 99 |
+
|
| 100 |
+
@abstractmethod
|
| 101 |
+
def update_with_all_losses(self, ts, losses):
|
| 102 |
+
"""
|
| 103 |
+
Update the reweighting using losses from a model.
|
| 104 |
+
Sub-classes should override this method to update the reweighting
|
| 105 |
+
using losses from the model.
|
| 106 |
+
This method directly updates the reweighting without synchronizing
|
| 107 |
+
between workers. It is called by update_with_local_losses from all
|
| 108 |
+
ranks with identical arguments. Thus, it should have deterministic
|
| 109 |
+
behavior to maintain state across workers.
|
| 110 |
+
:param ts: a list of int timesteps.
|
| 111 |
+
:param losses: a list of float losses, one per timestep.
|
| 112 |
+
"""
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class LossSecondMomentResampler(LossAwareSampler):
|
| 116 |
+
def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
|
| 117 |
+
self.diffusion = diffusion
|
| 118 |
+
self.history_per_term = history_per_term
|
| 119 |
+
self.uniform_prob = uniform_prob
|
| 120 |
+
self._loss_history = np.zeros(
|
| 121 |
+
[diffusion.num_timesteps, history_per_term], dtype=np.float64
|
| 122 |
+
)
|
| 123 |
+
self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int)
|
| 124 |
+
|
| 125 |
+
def weights(self):
|
| 126 |
+
if not self._warmed_up():
|
| 127 |
+
return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
|
| 128 |
+
weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
|
| 129 |
+
weights /= np.sum(weights)
|
| 130 |
+
weights *= 1 - self.uniform_prob
|
| 131 |
+
weights += self.uniform_prob / len(weights)
|
| 132 |
+
return weights
|
| 133 |
+
|
| 134 |
+
def update_with_all_losses(self, ts, losses):
|
| 135 |
+
for t, loss in zip(ts, losses):
|
| 136 |
+
if self._loss_counts[t] == self.history_per_term:
|
| 137 |
+
# Shift out the oldest loss term.
|
| 138 |
+
self._loss_history[t, :-1] = self._loss_history[t, 1:]
|
| 139 |
+
self._loss_history[t, -1] = loss
|
| 140 |
+
else:
|
| 141 |
+
self._loss_history[t, self._loss_counts[t]] = loss
|
| 142 |
+
self._loss_counts[t] += 1
|
| 143 |
+
|
| 144 |
+
def _warmed_up(self):
|
| 145 |
+
return (self._loss_counts == self.history_per_term).all()
|
distributed.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.distributed as dist
|
| 3 |
+
from torcheval.metrics import FrechetInceptionDistance
|
| 4 |
+
|
| 5 |
+
from collections import defaultdict, deque
|
| 6 |
+
import os
|
| 7 |
+
import datetime
|
| 8 |
+
import builtins
|
| 9 |
+
from logging import getLogger
|
| 10 |
+
import pickle
|
| 11 |
+
import time
|
| 12 |
+
|
| 13 |
+
logger = getLogger()
|
| 14 |
+
|
| 15 |
+
def is_dist_avail_and_initialized():
|
| 16 |
+
if not dist.is_available():
|
| 17 |
+
return False
|
| 18 |
+
if not dist.is_initialized():
|
| 19 |
+
return False
|
| 20 |
+
return True
|
| 21 |
+
|
| 22 |
+
def get_world_size():
|
| 23 |
+
if not is_dist_avail_and_initialized():
|
| 24 |
+
return 1
|
| 25 |
+
return dist.get_world_size()
|
| 26 |
+
|
| 27 |
+
def get_rank():
|
| 28 |
+
if not is_dist_avail_and_initialized():
|
| 29 |
+
return 0
|
| 30 |
+
return dist.get_rank()
|
| 31 |
+
|
| 32 |
+
def is_main_process():
|
| 33 |
+
return get_rank() == 0
|
| 34 |
+
|
| 35 |
+
def setup_for_distributed(is_master):
|
| 36 |
+
"""
|
| 37 |
+
This function disables printing when not in master process
|
| 38 |
+
"""
|
| 39 |
+
builtin_print = builtins.print
|
| 40 |
+
|
| 41 |
+
def print(*args, **kwargs):
|
| 42 |
+
force = kwargs.pop('force', False)
|
| 43 |
+
force = force or (get_world_size() > 8)
|
| 44 |
+
if is_master or force:
|
| 45 |
+
now = datetime.datetime.now().time()
|
| 46 |
+
builtin_print('[{}] '.format(now), end='') # print with time stamp
|
| 47 |
+
builtin_print(*args, **kwargs)
|
| 48 |
+
|
| 49 |
+
builtins.print = print
|
| 50 |
+
|
| 51 |
+
def init_distributed(port=37124, rank_and_world_size=(None, None)):
|
| 52 |
+
rank, world_size = rank_and_world_size
|
| 53 |
+
dist_url='env://'
|
| 54 |
+
os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(port))
|
| 55 |
+
print("Using port", os.environ['MASTER_PORT'])
|
| 56 |
+
|
| 57 |
+
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
| 58 |
+
try:
|
| 59 |
+
rank = int(os.environ["RANK"])
|
| 60 |
+
world_size = int(os.environ["WORLD_SIZE"])
|
| 61 |
+
gpu = int(os.environ["LOCAL_RANK"])
|
| 62 |
+
except Exception:
|
| 63 |
+
logger.info('torchrun env vars not sets')
|
| 64 |
+
|
| 65 |
+
elif "SLURM_PROCID" in os.environ:
|
| 66 |
+
try:
|
| 67 |
+
world_size = int(os.environ['SLURM_NTASKS'])
|
| 68 |
+
rank = int(os.environ['SLURM_PROCID'])
|
| 69 |
+
gpu = rank % torch.cuda.device_count()
|
| 70 |
+
if 'HOSTNAME' in os.environ:
|
| 71 |
+
os.environ['MASTER_ADDR'] = os.environ['HOSTNAME']
|
| 72 |
+
else:
|
| 73 |
+
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
| 74 |
+
except Exception:
|
| 75 |
+
logger.info('SLURM vars not set')
|
| 76 |
+
|
| 77 |
+
else:
|
| 78 |
+
rank = 0
|
| 79 |
+
world_size = 1
|
| 80 |
+
gpu = 0
|
| 81 |
+
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
| 82 |
+
|
| 83 |
+
torch.cuda.set_device(gpu)
|
| 84 |
+
|
| 85 |
+
torch.distributed.init_process_group(
|
| 86 |
+
backend='nccl',
|
| 87 |
+
world_size=world_size,
|
| 88 |
+
rank=rank,
|
| 89 |
+
init_method=dist_url
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
# setup_for_distributed(rank == 0)
|
| 93 |
+
return world_size, rank, gpu, True
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class SmoothedValue(object):
|
| 97 |
+
"""Track a series of values and provide access to smoothed values over a
|
| 98 |
+
window or the global series average.
|
| 99 |
+
"""
|
| 100 |
+
|
| 101 |
+
def __init__(self, window_size=20, fmt=None):
|
| 102 |
+
if fmt is None:
|
| 103 |
+
fmt = "{median:.4f} ({global_avg:.4f})"
|
| 104 |
+
self.deque = deque(maxlen=window_size)
|
| 105 |
+
self.total = 0.0
|
| 106 |
+
self.count = 0
|
| 107 |
+
self.fmt = fmt
|
| 108 |
+
|
| 109 |
+
def update(self, value, n=1):
|
| 110 |
+
self.deque.append(value)
|
| 111 |
+
self.count += n
|
| 112 |
+
self.total += value * n
|
| 113 |
+
|
| 114 |
+
def synchronize_between_processes(self):
|
| 115 |
+
"""
|
| 116 |
+
Warning: does not synchronize the deque!
|
| 117 |
+
"""
|
| 118 |
+
if not is_dist_avail_and_initialized():
|
| 119 |
+
return
|
| 120 |
+
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
| 121 |
+
dist.barrier()
|
| 122 |
+
dist.all_reduce(t)
|
| 123 |
+
t = t.tolist()
|
| 124 |
+
self.count = int(t[0])
|
| 125 |
+
self.total = t[1]
|
| 126 |
+
|
| 127 |
+
@property
|
| 128 |
+
def median(self):
|
| 129 |
+
d = torch.tensor(list(self.deque))
|
| 130 |
+
return d.median().item()
|
| 131 |
+
|
| 132 |
+
@property
|
| 133 |
+
def avg(self):
|
| 134 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
| 135 |
+
return d.mean().item()
|
| 136 |
+
|
| 137 |
+
@property
|
| 138 |
+
def global_avg(self):
|
| 139 |
+
return self.total / self.count
|
| 140 |
+
|
| 141 |
+
@property
|
| 142 |
+
def max(self):
|
| 143 |
+
return max(self.deque)
|
| 144 |
+
|
| 145 |
+
@property
|
| 146 |
+
def value(self):
|
| 147 |
+
return self.deque[-1]
|
| 148 |
+
|
| 149 |
+
def __str__(self):
|
| 150 |
+
return self.fmt.format(
|
| 151 |
+
median=self.median,
|
| 152 |
+
avg=self.avg,
|
| 153 |
+
global_avg=self.global_avg,
|
| 154 |
+
max=self.max,
|
| 155 |
+
value=self.value)
|
| 156 |
+
|
| 157 |
+
class MetricLogger(object):
|
| 158 |
+
def __init__(self, delimiter="\t"):
|
| 159 |
+
self.meters = defaultdict(SmoothedValue)
|
| 160 |
+
self.delimiter = delimiter
|
| 161 |
+
|
| 162 |
+
def update(self, **kwargs):
|
| 163 |
+
for k, v in kwargs.items():
|
| 164 |
+
if v is None:
|
| 165 |
+
continue
|
| 166 |
+
if isinstance(v, torch.Tensor):
|
| 167 |
+
v = v.item()
|
| 168 |
+
assert isinstance(v, (float, int))
|
| 169 |
+
self.meters[k].update(v)
|
| 170 |
+
|
| 171 |
+
def __getattr__(self, attr):
|
| 172 |
+
if attr in self.meters:
|
| 173 |
+
return self.meters[attr]
|
| 174 |
+
if attr in self.__dict__:
|
| 175 |
+
return self.__dict__[attr]
|
| 176 |
+
raise AttributeError("'{}' object has no attribute '{}'".format(
|
| 177 |
+
type(self).__name__, attr))
|
| 178 |
+
|
| 179 |
+
def __str__(self):
|
| 180 |
+
loss_str = []
|
| 181 |
+
for name, meter in self.meters.items():
|
| 182 |
+
loss_str.append(
|
| 183 |
+
"{}: {}".format(name, str(meter))
|
| 184 |
+
)
|
| 185 |
+
return self.delimiter.join(loss_str)
|
| 186 |
+
|
| 187 |
+
def synchronize_between_processes(self):
|
| 188 |
+
for meter in self.meters.values():
|
| 189 |
+
meter.synchronize_between_processes()
|
| 190 |
+
|
| 191 |
+
def add_meter(self, name, meter):
|
| 192 |
+
self.meters[name] = meter
|
| 193 |
+
|
| 194 |
+
def log_every(self, iterable, print_freq, header=None):
|
| 195 |
+
i = 0
|
| 196 |
+
if not header:
|
| 197 |
+
header = ''
|
| 198 |
+
start_time = time.time()
|
| 199 |
+
end = time.time()
|
| 200 |
+
iter_time = SmoothedValue(fmt='{avg:.4f}')
|
| 201 |
+
data_time = SmoothedValue(fmt='{avg:.4f}')
|
| 202 |
+
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
|
| 203 |
+
log_msg = [
|
| 204 |
+
header,
|
| 205 |
+
'[{0' + space_fmt + '}/{1}]',
|
| 206 |
+
'eta: {eta}',
|
| 207 |
+
'{meters}',
|
| 208 |
+
'time: {time}',
|
| 209 |
+
'data: {data}'
|
| 210 |
+
]
|
| 211 |
+
if torch.cuda.is_available():
|
| 212 |
+
log_msg.append('max mem: {memory:.0f}')
|
| 213 |
+
log_msg = self.delimiter.join(log_msg)
|
| 214 |
+
MB = 1024.0 * 1024.0
|
| 215 |
+
for obj in iterable:
|
| 216 |
+
data_time.update(time.time() - end)
|
| 217 |
+
yield obj
|
| 218 |
+
iter_time.update(time.time() - end)
|
| 219 |
+
if i % print_freq == 0 or i == len(iterable) - 1:
|
| 220 |
+
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
| 221 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
| 222 |
+
if torch.cuda.is_available():
|
| 223 |
+
print(log_msg.format(
|
| 224 |
+
i, len(iterable), eta=eta_string,
|
| 225 |
+
meters=str(self),
|
| 226 |
+
time=str(iter_time), data=str(data_time),
|
| 227 |
+
memory=torch.cuda.max_memory_allocated() / MB))
|
| 228 |
+
else:
|
| 229 |
+
print(log_msg.format(
|
| 230 |
+
i, len(iterable), eta=eta_string,
|
| 231 |
+
meters=str(self),
|
| 232 |
+
time=str(iter_time), data=str(data_time)))
|
| 233 |
+
i += 1
|
| 234 |
+
end = time.time()
|
| 235 |
+
total_time = time.time() - start_time
|
| 236 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
| 237 |
+
print('{} Total time: {} ({:.4f} s / it)'.format(
|
| 238 |
+
header, total_time_str, total_time / len(iterable)))
|
| 239 |
+
self.update(total_time=total_time)
|
| 240 |
+
|
| 241 |
+
def sync_fid_loss_fns(fid_loss_fn, device="cuda"):
|
| 242 |
+
"""
|
| 243 |
+
Synchronizes FID loss function metrics across all processes.
|
| 244 |
+
|
| 245 |
+
Args:
|
| 246 |
+
fid_loss_fn (dict): Local FID loss function metrics on each process.
|
| 247 |
+
device (str): Device to move the merged FID metrics to.
|
| 248 |
+
|
| 249 |
+
Returns:
|
| 250 |
+
final_fid_loss_fn (dict): Merged FID loss function metrics on all processes.
|
| 251 |
+
"""
|
| 252 |
+
if not is_dist_avail_and_initialized():
|
| 253 |
+
return fid_loss_fn
|
| 254 |
+
|
| 255 |
+
serialized_fid_loss_fn = pickle.dumps(fid_loss_fn)
|
| 256 |
+
gathered_fid_loss_fn = [None] * dist.get_world_size()
|
| 257 |
+
|
| 258 |
+
dist.barrier()
|
| 259 |
+
|
| 260 |
+
dist.all_gather_object(gathered_fid_loss_fn, serialized_fid_loss_fn)
|
| 261 |
+
|
| 262 |
+
final_fid_loss_fn = {
|
| 263 |
+
1: FrechetInceptionDistance(feature_dim=2048).to(device),
|
| 264 |
+
2: FrechetInceptionDistance(feature_dim=2048).to(device),
|
| 265 |
+
4: FrechetInceptionDistance(feature_dim=2048).to(device),
|
| 266 |
+
8: FrechetInceptionDistance(feature_dim=2048).to(device),
|
| 267 |
+
16: FrechetInceptionDistance(feature_dim=2048).to(device),
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
for serialized_fid_loss_fn in gathered_fid_loss_fn:
|
| 271 |
+
curr_fid_loss_fn = pickle.loads(serialized_fid_loss_fn)
|
| 272 |
+
for sec in [1, 2, 4, 8, 16]:
|
| 273 |
+
sec_fid_loss_fn = curr_fid_loss_fn[sec]
|
| 274 |
+
final_fid_loss_fn[sec].merge_state([sec_fid_loss_fn])
|
| 275 |
+
|
| 276 |
+
return final_fid_loss_fn
|
| 277 |
+
|
eval_audio.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# eval_audio.py
|
| 2 |
+
from typing import Optional
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
import argparse
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torchaudio
|
| 10 |
+
import librosa
|
| 11 |
+
import matplotlib.pyplot as plt
|
| 12 |
+
|
| 13 |
+
_EPS = 1e-12
|
| 14 |
+
|
| 15 |
+
def build_mel_transform(
|
| 16 |
+
sample_rate,
|
| 17 |
+
n_fft=1024,
|
| 18 |
+
win_length=1024,
|
| 19 |
+
hop_length=256,
|
| 20 |
+
n_mels=80,
|
| 21 |
+
power=1.0,
|
| 22 |
+
f_min=0.0,
|
| 23 |
+
f_max=None,
|
| 24 |
+
mel_scale="htk",
|
| 25 |
+
norm=None,
|
| 26 |
+
device=None,
|
| 27 |
+
):
|
| 28 |
+
mel_tf = torchaudio.transforms.MelSpectrogram(
|
| 29 |
+
sample_rate=sample_rate,
|
| 30 |
+
n_fft=n_fft,
|
| 31 |
+
win_length=win_length,
|
| 32 |
+
hop_length=hop_length,
|
| 33 |
+
f_min=f_min,
|
| 34 |
+
f_max=f_max,
|
| 35 |
+
n_mels=n_mels,
|
| 36 |
+
power=power,
|
| 37 |
+
center=True,
|
| 38 |
+
norm=norm,
|
| 39 |
+
mel_scale=mel_scale,
|
| 40 |
+
)
|
| 41 |
+
if device is not None:
|
| 42 |
+
mel_tf = mel_tf.to(device)
|
| 43 |
+
return mel_tf
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _ensure_stereo_torch(x):
|
| 47 |
+
if x.dim() == 1:
|
| 48 |
+
x = x.unsqueeze(0)
|
| 49 |
+
if x.size(0) == 1:
|
| 50 |
+
x = x.repeat(2, 1)
|
| 51 |
+
elif x.size(0) > 2:
|
| 52 |
+
x = x[:2]
|
| 53 |
+
return x
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@torch.no_grad()
|
| 57 |
+
def mel_cosine_stereo(
|
| 58 |
+
ref, hat, sample_rate,
|
| 59 |
+
n_fft=1024,
|
| 60 |
+
win_length=1024,
|
| 61 |
+
hop_length=256,
|
| 62 |
+
n_mels=80,
|
| 63 |
+
power=1.0,
|
| 64 |
+
mel_tf=None,
|
| 65 |
+
):
|
| 66 |
+
ref = _ensure_stereo_torch(ref)
|
| 67 |
+
hat = _ensure_stereo_torch(hat)
|
| 68 |
+
|
| 69 |
+
device = ref.device
|
| 70 |
+
if mel_tf is None:
|
| 71 |
+
mel_tf = build_mel_transform(
|
| 72 |
+
sample_rate=sample_rate,
|
| 73 |
+
n_fft=n_fft, win_length=win_length, hop_length=hop_length,
|
| 74 |
+
n_mels=n_mels, power=power, device=device
|
| 75 |
+
)
|
| 76 |
+
else:
|
| 77 |
+
mel_tf = mel_tf.to(device)
|
| 78 |
+
|
| 79 |
+
Mr = mel_tf(ref)
|
| 80 |
+
Mh = mel_tf(hat)
|
| 81 |
+
|
| 82 |
+
Ar = Mr.reshape(Mr.size(0), -1)
|
| 83 |
+
Ah = Mh.reshape(Mh.size(0), -1)
|
| 84 |
+
|
| 85 |
+
sim = F.cosine_similarity(Ar, Ah, dim=-1)
|
| 86 |
+
return float(sim.mean().item())
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@torch.no_grad()
|
| 90 |
+
def drms_avg_db_stereo(ref, hat, win_length=1024, hop_length=256):
|
| 91 |
+
ref = _ensure_stereo_torch(ref)
|
| 92 |
+
hat = _ensure_stereo_torch(hat)
|
| 93 |
+
|
| 94 |
+
def _rms_db(x):
|
| 95 |
+
C, T = x.size(0), x.size(1)
|
| 96 |
+
if T < win_length:
|
| 97 |
+
x = F.pad(x, (0, win_length - T))
|
| 98 |
+
frames = x.unfold(dimension=-1, size=win_length, step=hop_length)
|
| 99 |
+
rms = torch.sqrt(frames.pow(2).mean(dim=-1) + _EPS)
|
| 100 |
+
db = 20.0 * torch.log10(rms + _EPS)
|
| 101 |
+
return db
|
| 102 |
+
|
| 103 |
+
dbr = _rms_db(ref)
|
| 104 |
+
dbh = _rms_db(hat)
|
| 105 |
+
|
| 106 |
+
Fmin = min(dbr.size(-1), dbh.size(-1))
|
| 107 |
+
dbr = dbr[:, :Fmin]
|
| 108 |
+
dbh = dbh[:, :Fmin]
|
| 109 |
+
|
| 110 |
+
d_db = dbh - dbr
|
| 111 |
+
return float(d_db.mean(dim=-1).mean().item())
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def load_stereo_wav_np(path):
|
| 115 |
+
y, sr = librosa.load(path, sr=None, mono=False)
|
| 116 |
+
if y.ndim == 1:
|
| 117 |
+
y = np.stack([y, y], axis=0)
|
| 118 |
+
elif y.shape[0] != 2:
|
| 119 |
+
y = y[:2]
|
| 120 |
+
return y, sr
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def compute_spectrogram_np(audio_stereo,
|
| 124 |
+
n_fft=512,
|
| 125 |
+
hop_length=160,
|
| 126 |
+
win_length=400,
|
| 127 |
+
pool=4):
|
| 128 |
+
def _stft_abs(sig):
|
| 129 |
+
st = np.abs(librosa.stft(sig, n_fft=n_fft, hop_length=hop_length, win_length=win_length))
|
| 130 |
+
h, w = st.shape
|
| 131 |
+
hq, wq = h // pool, w // pool
|
| 132 |
+
if hq == 0 or wq == 0:
|
| 133 |
+
raise ValueError(f"audio too short for pooling (stft shape {st.shape})")
|
| 134 |
+
st = st[:hq * pool, :wq * pool]
|
| 135 |
+
st = st.reshape(hq, pool, wq, pool).mean(axis=(1, 3))
|
| 136 |
+
return st
|
| 137 |
+
|
| 138 |
+
L = np.log1p(_stft_abs(audio_stereo[0]))
|
| 139 |
+
if audio_stereo.shape[0] >= 2:
|
| 140 |
+
R = np.log1p(_stft_abs(audio_stereo[1]))
|
| 141 |
+
else:
|
| 142 |
+
R = L.copy()
|
| 143 |
+
spec = np.stack([L, R], axis=-1)
|
| 144 |
+
return spec
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def render_ref_hat_panel(title, spec_ref, spec_hat, out_path, cmap="magma"):
|
| 148 |
+
L_all = [spec_ref[:, :, 0], spec_hat[:, :, 0]]
|
| 149 |
+
R_all = [spec_ref[:, :, 1], spec_hat[:, :, 1]]
|
| 150 |
+
|
| 151 |
+
if any(a.size == 0 for a in L_all + R_all):
|
| 152 |
+
print(f"[SKIP]")
|
| 153 |
+
return False
|
| 154 |
+
|
| 155 |
+
vmin_L = min(a.min() for a in L_all)
|
| 156 |
+
vmax_L = max(a.max() for a in L_all)
|
| 157 |
+
vmin_R = min(a.min() for a in R_all)
|
| 158 |
+
vmax_R = max(a.max() for a in R_all)
|
| 159 |
+
|
| 160 |
+
fig, axes = plt.subplots(2, 2, figsize=(8, 6), constrained_layout=True)
|
| 161 |
+
Lr, Rr = spec_ref[:, :, 0], spec_ref[:, :, 1]
|
| 162 |
+
Lh, Rh = spec_hat[:, :, 0], spec_hat[:, :, 1]
|
| 163 |
+
|
| 164 |
+
axes[0, 0].imshow(Lr, origin="lower", aspect="auto", cmap=cmap, vmin=vmin_L, vmax=vmax_L)
|
| 165 |
+
axes[0, 1].imshow(Lh, origin="lower", aspect="auto", cmap=cmap, vmin=vmin_L, vmax=vmax_L)
|
| 166 |
+
axes[1, 0].imshow(Rr, origin="lower", aspect="auto", cmap=cmap, vmin=vmin_R, vmax=vmax_R)
|
| 167 |
+
axes[1, 1].imshow(Rh, origin="lower", aspect="auto", cmap=cmap, vmin=vmin_R, vmax=vmax_R)
|
| 168 |
+
|
| 169 |
+
axes[0, 0].set_title("ref")
|
| 170 |
+
axes[0, 1].set_title("hat")
|
| 171 |
+
axes[0, 0].set_ylabel("Left")
|
| 172 |
+
axes[1, 0].set_ylabel("Right")
|
| 173 |
+
|
| 174 |
+
for ax in axes.ravel():
|
| 175 |
+
ax.set_xticks([])
|
| 176 |
+
ax.set_yticks([])
|
| 177 |
+
|
| 178 |
+
fig.suptitle(title)
|
| 179 |
+
os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
|
| 180 |
+
plt.savefig(out_path, dpi=180)
|
| 181 |
+
plt.close(fig)
|
| 182 |
+
return True
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def save_ref_hat_spectrogram_panel(
|
| 186 |
+
ref, hat, out_path,
|
| 187 |
+
n_fft=512,
|
| 188 |
+
hop_length=160,
|
| 189 |
+
win_length=400,
|
| 190 |
+
pool=4,
|
| 191 |
+
title="ref vs hat (binaural spectrogram)",
|
| 192 |
+
cmap="magma",
|
| 193 |
+
):
|
| 194 |
+
def _to_np_stereo(x):
|
| 195 |
+
if isinstance(x, torch.Tensor):
|
| 196 |
+
x = x.detach().to(torch.float32).cpu().numpy()
|
| 197 |
+
if x.ndim == 1:
|
| 198 |
+
x = np.stack([x, x], axis=0)
|
| 199 |
+
elif x.shape[0] == 1:
|
| 200 |
+
x = np.repeat(x, 2, axis=0)
|
| 201 |
+
elif x.shape[0] > 2:
|
| 202 |
+
x = x[:2]
|
| 203 |
+
return x
|
| 204 |
+
|
| 205 |
+
ref_np = _to_np_stereo(ref)
|
| 206 |
+
hat_np = _to_np_stereo(hat)
|
| 207 |
+
|
| 208 |
+
spec_ref = compute_spectrogram_np(ref_np, n_fft=n_fft, hop_length=hop_length, win_length=win_length, pool=pool)
|
| 209 |
+
spec_hat = compute_spectrogram_np(hat_np, n_fft=n_fft, hop_length=hop_length, win_length=win_length, pool=pool)
|
| 210 |
+
return render_ref_hat_panel(title, spec_ref, spec_hat, out_path, cmap=cmap)
|
eval_metrics.py
ADDED
|
@@ -0,0 +1,1033 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import json
|
| 6 |
+
import argparse
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.distributed as dist_torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
import numpy as np
|
| 14 |
+
from PIL import Image
|
| 15 |
+
import lpips
|
| 16 |
+
from dreamsim import dreamsim
|
| 17 |
+
from torchvision import transforms
|
| 18 |
+
from torcheval.metrics import FrechetInceptionDistance
|
| 19 |
+
import soundfile as sf
|
| 20 |
+
import resampy
|
| 21 |
+
import distributed as dist
|
| 22 |
+
import librosa
|
| 23 |
+
from skimage.metrics import structural_similarity as sk_ssim
|
| 24 |
+
from mel_scale import MelScale
|
| 25 |
+
|
| 26 |
+
# -----------------------------
|
| 27 |
+
# Safe, lazy import for FAD (avoid argparse conflicts from dependencies)
|
| 28 |
+
# -----------------------------
|
| 29 |
+
def safe_import_fad():
|
| 30 |
+
"""
|
| 31 |
+
Import frechet_audio_distance.FrechetAudioDistance without letting downstream
|
| 32 |
+
libraries parse our CLI args during import time.
|
| 33 |
+
"""
|
| 34 |
+
import importlib, sys
|
| 35 |
+
argv_backup = sys.argv[:]
|
| 36 |
+
try:
|
| 37 |
+
sys.argv = [argv_backup[0]] # hide our CLI flags from misbehaving imports
|
| 38 |
+
fad_mod = importlib.import_module("frechet_audio_distance")
|
| 39 |
+
return getattr(fad_mod, "FrechetAudioDistance")
|
| 40 |
+
finally:
|
| 41 |
+
sys.argv = argv_backup
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# -----------------------------
|
| 45 |
+
# Distributed init
|
| 46 |
+
# -----------------------------
|
| 47 |
+
def setup_distributed():
|
| 48 |
+
if "RANK" in os.environ and "WORLD_SIZE" in os.environ and "LOCAL_RANK" in os.environ:
|
| 49 |
+
rank = int(os.environ["RANK"])
|
| 50 |
+
world_size = int(os.environ["WORLD_SIZE"])
|
| 51 |
+
local_rank = int(os.environ["LOCAL_RANK"])
|
| 52 |
+
else:
|
| 53 |
+
return 0, 1, 0
|
| 54 |
+
|
| 55 |
+
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
|
| 56 |
+
os.environ.setdefault("MASTER_PORT", "29500")
|
| 57 |
+
|
| 58 |
+
assert torch.cuda.is_available(), "CUDA Unavailable"
|
| 59 |
+
assert torch.cuda.device_count() > local_rank, "local_rank out of the number of GPUs"
|
| 60 |
+
torch.cuda.set_device(local_rank)
|
| 61 |
+
|
| 62 |
+
dist_torch.init_process_group(
|
| 63 |
+
backend="nccl",
|
| 64 |
+
init_method="env://",
|
| 65 |
+
rank=rank,
|
| 66 |
+
world_size=world_size,
|
| 67 |
+
)
|
| 68 |
+
dist_torch.barrier()
|
| 69 |
+
|
| 70 |
+
if rank == 0:
|
| 71 |
+
print(f"[init] world_size={world_size} | rank->gpu OK")
|
| 72 |
+
|
| 73 |
+
return rank, world_size, local_rank
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# -----------------------------
|
| 77 |
+
# Vision metrics factory
|
| 78 |
+
# -----------------------------
|
| 79 |
+
def get_loss_fn(loss_fn_type, secs, device):
|
| 80 |
+
if loss_fn_type == 'lpips':
|
| 81 |
+
general_lpips_loss_fn = lpips.LPIPS(net='alex').to(device).eval()
|
| 82 |
+
|
| 83 |
+
def loss_fn(img0_paths, img1_paths):
|
| 84 |
+
img0_list, img1_list = [], []
|
| 85 |
+
for p0, p1 in zip(img0_paths, img1_paths):
|
| 86 |
+
img0 = lpips.im2tensor(lpips.load_image(p0)).to(device) # [-1,1]
|
| 87 |
+
img1 = lpips.im2tensor(lpips.load_image(p1)).to(device)
|
| 88 |
+
img0_list.append(img0)
|
| 89 |
+
img1_list.append(img1)
|
| 90 |
+
all_img0 = torch.cat(img0_list, dim=0)
|
| 91 |
+
all_img1 = torch.cat(img1_list, dim=0)
|
| 92 |
+
with torch.no_grad():
|
| 93 |
+
dist_val = general_lpips_loss_fn.forward(all_img0, all_img1)
|
| 94 |
+
return dist_val.mean()
|
| 95 |
+
|
| 96 |
+
elif loss_fn_type == 'dreamsim':
|
| 97 |
+
dreamsim_loss_fn, preprocess = dreamsim(pretrained=True, device=device)
|
| 98 |
+
dreamsim_loss_fn.eval()
|
| 99 |
+
|
| 100 |
+
def loss_fn(img0_paths, img1_paths):
|
| 101 |
+
img0_list, img1_list = [], []
|
| 102 |
+
for p0, p1 in zip(img0_paths, img1_paths):
|
| 103 |
+
img0 = preprocess(Image.open(p0)).to(device)
|
| 104 |
+
img1 = preprocess(Image.open(p1)).to(device)
|
| 105 |
+
img0_list.append(img0)
|
| 106 |
+
img1_list.append(img1)
|
| 107 |
+
all_img0 = torch.cat(img0_list, dim=0)
|
| 108 |
+
all_img1 = torch.cat(img1_list, dim=0)
|
| 109 |
+
with torch.no_grad():
|
| 110 |
+
dist_val = dreamsim_loss_fn(all_img0, all_img1)
|
| 111 |
+
return dist_val.mean()
|
| 112 |
+
|
| 113 |
+
elif loss_fn_type == 'fid':
|
| 114 |
+
fid_metrics = {}
|
| 115 |
+
for sec in secs:
|
| 116 |
+
fid_metrics[sec] = FrechetInceptionDistance(feature_dim=2048).to(device)
|
| 117 |
+
return fid_metrics
|
| 118 |
+
|
| 119 |
+
else:
|
| 120 |
+
raise NotImplementedError
|
| 121 |
+
|
| 122 |
+
return loss_fn
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# ===== Helpers for LSD/SSIM (reproducing AudioMetrics behavior) =====
|
| 126 |
+
_EPS = 1e-12
|
| 127 |
+
|
| 128 |
+
def _ensure_stereo_np(y: np.ndarray):
|
| 129 |
+
if y.ndim == 1:
|
| 130 |
+
y = np.stack([y, y], axis=0)
|
| 131 |
+
elif y.ndim == 2:
|
| 132 |
+
if y.shape[0] == 1:
|
| 133 |
+
y = np.concatenate([y, y], axis=0)
|
| 134 |
+
elif y.shape[0] > 2:
|
| 135 |
+
y = y[:2, :]
|
| 136 |
+
else:
|
| 137 |
+
raise ValueError("Unsupported audio array shape")
|
| 138 |
+
return y
|
| 139 |
+
|
| 140 |
+
def _wav_to_spectrogram(wav: np.ndarray, rate: int):
|
| 141 |
+
if rate == 44100:
|
| 142 |
+
hop_length = 441
|
| 143 |
+
n_fft = 2048
|
| 144 |
+
elif rate == 16000:
|
| 145 |
+
hop_length = 160
|
| 146 |
+
n_fft = 743
|
| 147 |
+
else:
|
| 148 |
+
raise ValueError("Bad Samplerate (expected 16000 or 44100)")
|
| 149 |
+
|
| 150 |
+
f = np.abs(librosa.stft(wav, hop_length=hop_length, n_fft=n_fft)) # [F, T]
|
| 151 |
+
f = np.transpose(f, (1, 0)) # [T, F]
|
| 152 |
+
f_torch = torch.tensor(f[None, None, ...], dtype=torch.float32) # [1,1,T,F]
|
| 153 |
+
return f_torch
|
| 154 |
+
|
| 155 |
+
def _lsd_from_specs(est: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
| 156 |
+
ratio = (target ** 2) / ((est + _EPS) ** 2) + _EPS
|
| 157 |
+
lsd = torch.log10(ratio) ** 2
|
| 158 |
+
lsd = torch.mean(torch.mean(lsd, dim=3) ** 0.5, dim=2)
|
| 159 |
+
return lsd.mean()
|
| 160 |
+
|
| 161 |
+
def _mel_lsd_ssim_single(
|
| 162 |
+
e_wav: np.ndarray,
|
| 163 |
+
g_wav: np.ndarray,
|
| 164 |
+
mel_tf: MelScale,
|
| 165 |
+
n_fft: int = 743,
|
| 166 |
+
hop_length: int = 160,
|
| 167 |
+
) -> tuple[float, float]:
|
| 168 |
+
est_mag = np.abs(librosa.stft(e_wav, n_fft=n_fft, hop_length=hop_length))
|
| 169 |
+
ref_mag = np.abs(librosa.stft(g_wav, n_fft=n_fft, hop_length=hop_length))
|
| 170 |
+
est_mag_t = torch.from_numpy(est_mag).float()
|
| 171 |
+
ref_mag_t = torch.from_numpy(ref_mag).float()
|
| 172 |
+
est_mel = mel_tf(est_mag_t)
|
| 173 |
+
ref_mel = mel_tf(ref_mag_t)
|
| 174 |
+
ex_m = est_mel.transpose(0, 1).unsqueeze(0).unsqueeze(0)
|
| 175 |
+
gt_m = ref_mel.transpose(0, 1).unsqueeze(0).unsqueeze(0)
|
| 176 |
+
mel_lsd = float(_lsd_from_specs(ex_m, gt_m))
|
| 177 |
+
mel_ssim = float(_ssim_from_specs(ex_m, gt_m))
|
| 178 |
+
return mel_lsd, mel_ssim
|
| 179 |
+
|
| 180 |
+
def _to_log_specs(x: torch.Tensor) -> torch.Tensor:
|
| 181 |
+
return torch.log10(x + _EPS)
|
| 182 |
+
|
| 183 |
+
def _pow_p_norm(x: torch.Tensor) -> torch.Tensor:
|
| 184 |
+
return torch.mean(x.pow(2), dim=(2, 3))
|
| 185 |
+
|
| 186 |
+
def _energy_unify(est: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 187 |
+
p_est = _pow_p_norm(est)
|
| 188 |
+
p_tgt = _pow_p_norm(target)
|
| 189 |
+
scale = torch.sqrt((p_tgt + _EPS) / (p_est + _EPS))
|
| 190 |
+
scale = scale[..., None, None]
|
| 191 |
+
est_scaled = est * scale
|
| 192 |
+
return est_scaled, target
|
| 193 |
+
|
| 194 |
+
def _sispec_from_specs(est: torch.Tensor, target: torch.Tensor, log_domain: bool) -> torch.Tensor:
|
| 195 |
+
if log_domain:
|
| 196 |
+
est = _to_log_specs(est)
|
| 197 |
+
target = _to_log_specs(target)
|
| 198 |
+
est_u, tgt_u = _energy_unify(est, target)
|
| 199 |
+
noise = est_u - tgt_u
|
| 200 |
+
snr = ( _pow_p_norm(tgt_u) / (_pow_p_norm(noise) + _EPS) ) + _EPS
|
| 201 |
+
sp_loss = 10.0 * torch.log10(snr)
|
| 202 |
+
return sp_loss.mean()
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
# ===== Image PSNR (RGB on [0,1]) =====
|
| 206 |
+
def _psnr_from_tensors(gt: torch.Tensor, pred: torch.Tensor, data_range: float = 1.0, eps: float = 1e-10) -> torch.Tensor:
|
| 207 |
+
mse = torch.mean((gt - pred) ** 2, dim=(1, 2, 3))
|
| 208 |
+
dr = torch.as_tensor(data_range, device=gt.device, dtype=gt.dtype)
|
| 209 |
+
psnr = 10.0 * torch.log10((dr * dr) / (mse + eps))
|
| 210 |
+
return psnr
|
| 211 |
+
|
| 212 |
+
def _ssim_from_specs(est: torch.Tensor, target: torch.Tensor) -> float:
|
| 213 |
+
if est.is_cuda:
|
| 214 |
+
est_np = est.detach().cpu().numpy()
|
| 215 |
+
tgt_np = target.detach().cpu().numpy()
|
| 216 |
+
else:
|
| 217 |
+
est_np = est.numpy()
|
| 218 |
+
tgt_np = target.numpy()
|
| 219 |
+
|
| 220 |
+
N, C, _, _ = est_np.shape
|
| 221 |
+
acc, cnt = 0.0, 0
|
| 222 |
+
for n in range(N):
|
| 223 |
+
for c in range(C):
|
| 224 |
+
ref = tgt_np[n, c, ...]
|
| 225 |
+
out = est_np[n, c, ...]
|
| 226 |
+
rng = float(out.max() - out.min())
|
| 227 |
+
rng = 1.0 if rng == 0.0 else rng
|
| 228 |
+
s = sk_ssim(out, ref, win_size=7, data_range=rng)
|
| 229 |
+
acc += float(s); cnt += 1
|
| 230 |
+
return acc / max(cnt, 1)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
# ==========================================================
|
| 234 |
+
# Streaming, DDP-friendly Audio FAD
|
| 235 |
+
# (embeddings identical to official FrechetAudioDistance)
|
| 236 |
+
# ==========================================================
|
| 237 |
+
class _RunningGaussianStats:
|
| 238 |
+
def __init__(self, feat_dim: int, device: torch.device):
|
| 239 |
+
self.D = feat_dim
|
| 240 |
+
self.device = device
|
| 241 |
+
self.reset()
|
| 242 |
+
|
| 243 |
+
def reset(self):
|
| 244 |
+
D = self.D
|
| 245 |
+
self.count = torch.zeros(1, device=self.device, dtype=torch.float64)
|
| 246 |
+
self.sum_feat = torch.zeros(D, device=self.device, dtype=torch.float64)
|
| 247 |
+
self.sum_outer = torch.zeros(D, D, device=self.device, dtype=torch.float64)
|
| 248 |
+
|
| 249 |
+
@torch.no_grad()
|
| 250 |
+
def update(self, feats: torch.Tensor): # [N, D]
|
| 251 |
+
if feats is None or feats.numel() == 0:
|
| 252 |
+
return
|
| 253 |
+
f = feats.to(dtype=torch.float64)
|
| 254 |
+
self.count += torch.tensor([f.shape[0]], device=self.device, dtype=torch.float64)
|
| 255 |
+
self.sum_feat += f.sum(dim=0)
|
| 256 |
+
self.sum_outer += f.t().mm(f)
|
| 257 |
+
|
| 258 |
+
@torch.no_grad()
|
| 259 |
+
def sync(self):
|
| 260 |
+
if dist_torch.is_initialized():
|
| 261 |
+
for t in (self.count, self.sum_feat, self.sum_outer):
|
| 262 |
+
dist_torch.all_reduce(t, op=dist_torch.ReduceOp.SUM)
|
| 263 |
+
|
| 264 |
+
@torch.no_grad()
|
| 265 |
+
def mean_cov(self, eps: float = 1e-6):
|
| 266 |
+
n = int(self.count.item())
|
| 267 |
+
if n == 0:
|
| 268 |
+
return None, None
|
| 269 |
+
mean = self.sum_feat / self.count
|
| 270 |
+
cov = self.sum_outer / self.count - torch.ger(mean, mean)
|
| 271 |
+
cov = cov + torch.eye(self.D, device=self.device, dtype=torch.float64) * eps
|
| 272 |
+
return mean, cov
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
@torch.no_grad()
|
| 276 |
+
def _frechet_distance_torch(mean1, cov1, mean2, cov2) -> float:
|
| 277 |
+
diff = mean1 - mean2
|
| 278 |
+
diff2 = diff.dot(diff)
|
| 279 |
+
evals1, evecs1 = torch.linalg.eigh(cov1)
|
| 280 |
+
sqrt1 = evecs1 @ torch.diag(evals1.clamp(min=0).sqrt()) @ evecs1.t()
|
| 281 |
+
prod = sqrt1 @ cov2 @ sqrt1
|
| 282 |
+
evals_prod = torch.linalg.eigvalsh(prod).clamp(min=0).sqrt()
|
| 283 |
+
trace = torch.trace(cov1 + cov2) - 2.0 * evals_prod.sum()
|
| 284 |
+
return float((diff2 + trace).item())
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class StreamingFAD:
|
| 288 |
+
"""
|
| 289 |
+
Mono (downmix) FID-style streaming FAD:
|
| 290 |
+
- update_from_wavs(paths, is_real=True/False)
|
| 291 |
+
- compute() # does DDP all_reduce internally
|
| 292 |
+
"""
|
| 293 |
+
def __init__(self, fad_backend, pad_seconds: float = 0.96, batch_size: int = 16):
|
| 294 |
+
self.fad = fad_backend
|
| 295 |
+
self.device = self.fad.device
|
| 296 |
+
self.bs = batch_size
|
| 297 |
+
self.pad_len = int(round(self.fad.sample_rate * float(pad_seconds)))
|
| 298 |
+
self.feat_dim = self._infer_feat_dim()
|
| 299 |
+
self.real_stats = _RunningGaussianStats(self.feat_dim, self.device)
|
| 300 |
+
self.fake_stats = _RunningGaussianStats(self.feat_dim, self.device)
|
| 301 |
+
|
| 302 |
+
def _infer_feat_dim(self) -> int:
|
| 303 |
+
sr = self.fad.sample_rate
|
| 304 |
+
x = np.zeros((self.pad_len,), dtype=np.float32)
|
| 305 |
+
emb = self.fad.get_embeddings([x], sr=sr)
|
| 306 |
+
return int(emb.shape[-1]) if isinstance(emb, np.ndarray) else int(emb.shape[-1])
|
| 307 |
+
|
| 308 |
+
@torch.no_grad()
|
| 309 |
+
def _load_and_resample(self, path: str):
|
| 310 |
+
try:
|
| 311 |
+
audio, sr = sf.read(path, dtype="float32", always_2d=False)
|
| 312 |
+
except Exception as e:
|
| 313 |
+
print(f"[StreamingFAD] read error: {path}: {e}")
|
| 314 |
+
return None
|
| 315 |
+
if audio is None or (isinstance(audio, np.ndarray) and audio.size == 0):
|
| 316 |
+
return None
|
| 317 |
+
if isinstance(audio, np.ndarray) and audio.ndim == 2:
|
| 318 |
+
audio = audio.mean(axis=1)
|
| 319 |
+
if sr != self.fad.sample_rate:
|
| 320 |
+
try:
|
| 321 |
+
audio = resampy.resample(audio, sr, self.fad.sample_rate)
|
| 322 |
+
except Exception as e:
|
| 323 |
+
print(f"[StreamingFAD] resample error: {path}: {e}")
|
| 324 |
+
return None
|
| 325 |
+
if audio.shape[0] < self.pad_len:
|
| 326 |
+
pad = np.zeros((self.pad_len - audio.shape[0],), dtype=np.float32)
|
| 327 |
+
audio = np.concatenate([audio, pad], axis=0)
|
| 328 |
+
return audio.astype(np.float32, copy=False)
|
| 329 |
+
|
| 330 |
+
@torch.no_grad()
|
| 331 |
+
def update_from_wavs(self, wav_paths, is_real: bool):
|
| 332 |
+
if not wav_paths:
|
| 333 |
+
return
|
| 334 |
+
xs = []
|
| 335 |
+
for p in wav_paths:
|
| 336 |
+
a = self._load_and_resample(p)
|
| 337 |
+
if a is not None:
|
| 338 |
+
xs.append(a)
|
| 339 |
+
if not xs:
|
| 340 |
+
return
|
| 341 |
+
feats_chunks = []
|
| 342 |
+
for i in range(0, len(xs), self.bs):
|
| 343 |
+
chunk = xs[i:i+self.bs]
|
| 344 |
+
emb_np = self.fad.get_embeddings(chunk, sr=self.fad.sample_rate)
|
| 345 |
+
if isinstance(emb_np, np.ndarray):
|
| 346 |
+
if emb_np.size == 0:
|
| 347 |
+
continue
|
| 348 |
+
feats_chunks.append(torch.from_numpy(emb_np).to(self.device))
|
| 349 |
+
else:
|
| 350 |
+
if emb_np.numel() == 0:
|
| 351 |
+
continue
|
| 352 |
+
feats_chunks.append(emb_np.to(self.device))
|
| 353 |
+
if len(feats_chunks) == 0:
|
| 354 |
+
return
|
| 355 |
+
feats = torch.cat(feats_chunks, dim=0)
|
| 356 |
+
(self.real_stats if is_real else self.fake_stats).update(feats)
|
| 357 |
+
|
| 358 |
+
@torch.no_grad()
|
| 359 |
+
def compute(self) -> float:
|
| 360 |
+
self.real_stats.sync()
|
| 361 |
+
self.fake_stats.sync()
|
| 362 |
+
m1, c1 = self.real_stats.mean_cov()
|
| 363 |
+
m2, c2 = self.fake_stats.mean_cov()
|
| 364 |
+
if (m1 is None) or (m2 is None):
|
| 365 |
+
raise RuntimeError("StreamingFAD: empty stats")
|
| 366 |
+
return _frechet_distance_torch(m1, c1, m2, c2)
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
class StereoStreamingFAD:
|
| 370 |
+
def __init__(self, fad_backend, pad_seconds: float = 0.96, batch_size: int = 16):
|
| 371 |
+
self.fad = fad_backend
|
| 372 |
+
self.device = self.fad.device
|
| 373 |
+
self.bs = batch_size
|
| 374 |
+
self.pad_len = int(round(self.fad.sample_rate * float(pad_seconds)))
|
| 375 |
+
|
| 376 |
+
self.feat_dim = self._infer_feat_dim()
|
| 377 |
+
self.L_real = _RunningGaussianStats(self.feat_dim, self.device)
|
| 378 |
+
self.L_fake = _RunningGaussianStats(self.feat_dim, self.device)
|
| 379 |
+
self.R_real = _RunningGaussianStats(self.feat_dim, self.device)
|
| 380 |
+
self.R_fake = _RunningGaussianStats(self.feat_dim, self.device)
|
| 381 |
+
|
| 382 |
+
def _infer_feat_dim(self) -> int:
|
| 383 |
+
sr = self.fad.sample_rate
|
| 384 |
+
x = np.zeros((self.pad_len,), dtype=np.float32)
|
| 385 |
+
emb = self.fad.get_embeddings([x], sr=sr)
|
| 386 |
+
return int(emb.shape[-1]) if isinstance(emb, np.ndarray) else int(emb.shape[-1])
|
| 387 |
+
|
| 388 |
+
@torch.no_grad()
|
| 389 |
+
def _load_lr_and_resample_pad(self, path: str):
|
| 390 |
+
try:
|
| 391 |
+
audio, sr = sf.read(path, dtype="float32", always_2d=True) # [T, C]
|
| 392 |
+
except Exception as e:
|
| 393 |
+
print(f"[StereoFAD] read error: {path}: {e}")
|
| 394 |
+
return None, None
|
| 395 |
+
if audio is None or audio.size == 0:
|
| 396 |
+
return None, None
|
| 397 |
+
|
| 398 |
+
C = audio.shape[1]
|
| 399 |
+
if C == 1:
|
| 400 |
+
L = audio[:, 0]; R = audio[:, 0]
|
| 401 |
+
else:
|
| 402 |
+
L = audio[:, 0]; R = audio[:, 1] if C >= 2 else audio[:, 0]
|
| 403 |
+
|
| 404 |
+
if sr != self.fad.sample_rate:
|
| 405 |
+
try:
|
| 406 |
+
L = resampy.resample(L, sr, self.fad.sample_rate)
|
| 407 |
+
R = resampy.resample(R, sr, self.fad.sample_rate)
|
| 408 |
+
except Exception as e:
|
| 409 |
+
print(f"[StereoFAD] resample error: {path}: {e}")
|
| 410 |
+
return None, None
|
| 411 |
+
|
| 412 |
+
def _pad_to_len(x: np.ndarray, n: int):
|
| 413 |
+
if x.shape[0] >= n:
|
| 414 |
+
return x.astype(np.float32, copy=False)
|
| 415 |
+
pad = np.zeros((n - x.shape[0],), dtype=np.float32)
|
| 416 |
+
return np.concatenate([x, pad], axis=0)
|
| 417 |
+
|
| 418 |
+
L = _pad_to_len(L, self.pad_len)
|
| 419 |
+
R = _pad_to_len(R, self.pad_len)
|
| 420 |
+
return L, R
|
| 421 |
+
|
| 422 |
+
@torch.no_grad()
|
| 423 |
+
def update_from_wavs(self, wav_paths, is_real: bool):
|
| 424 |
+
if not wav_paths:
|
| 425 |
+
return
|
| 426 |
+
L_list, R_list = [], []
|
| 427 |
+
for p in wav_paths:
|
| 428 |
+
L, R = self._load_lr_and_resample_pad(p)
|
| 429 |
+
if L is not None and R is not None:
|
| 430 |
+
L_list.append(L); R_list.append(R)
|
| 431 |
+
if not L_list:
|
| 432 |
+
return
|
| 433 |
+
|
| 434 |
+
def _embed_and_update(xs, stats_obj: _RunningGaussianStats):
|
| 435 |
+
feats_chunks = []
|
| 436 |
+
for i in range(0, len(xs), self.bs):
|
| 437 |
+
chunk = xs[i:i+self.bs]
|
| 438 |
+
emb_np = self.fad.get_embeddings(chunk, sr=self.fad.sample_rate)
|
| 439 |
+
if isinstance(emb_np, np.ndarray):
|
| 440 |
+
if emb_np.size == 0:
|
| 441 |
+
continue
|
| 442 |
+
feats_chunks.append(torch.from_numpy(emb_np).to(self.device))
|
| 443 |
+
else:
|
| 444 |
+
if emb_np.numel() == 0:
|
| 445 |
+
continue
|
| 446 |
+
feats_chunks.append(emb_np.to(self.device))
|
| 447 |
+
if len(feats_chunks) == 0:
|
| 448 |
+
return
|
| 449 |
+
feats = torch.cat(feats_chunks, dim=0)
|
| 450 |
+
stats_obj.update(feats)
|
| 451 |
+
|
| 452 |
+
if is_real:
|
| 453 |
+
_embed_and_update(L_list, self.L_real)
|
| 454 |
+
_embed_and_update(R_list, self.R_real)
|
| 455 |
+
else:
|
| 456 |
+
_embed_and_update(L_list, self.L_fake)
|
| 457 |
+
_embed_and_update(R_list, self.R_fake)
|
| 458 |
+
|
| 459 |
+
@torch.no_grad()
|
| 460 |
+
def compute(self):
|
| 461 |
+
for t in (self.L_real, self.L_fake, self.R_real, self.R_fake):
|
| 462 |
+
t.sync()
|
| 463 |
+
mL_r, cL_r = self.L_real.mean_cov()
|
| 464 |
+
mL_f, cL_f = self.L_fake.mean_cov()
|
| 465 |
+
mR_r, cR_r = self.R_real.mean_cov()
|
| 466 |
+
mR_f, cR_f = self.R_fake.mean_cov()
|
| 467 |
+
if (mL_r is None) or (mL_f is None) or (mR_r is None) or (mR_f is None):
|
| 468 |
+
raise RuntimeError("StereoStreamingFAD: empty stats")
|
| 469 |
+
|
| 470 |
+
fad_left = _frechet_distance_torch(mL_r, cL_r, mL_f, cL_f)
|
| 471 |
+
fad_right = _frechet_distance_torch(mR_r, cR_r, mR_f, cR_f)
|
| 472 |
+
fad_mean = 0.5 * (fad_left + fad_right)
|
| 473 |
+
return float(fad_left), float(fad_right), float(fad_mean)
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
# -----------------------------
|
| 477 |
+
# Stereo-friendly Audio Metrics (LSD/SSIM/MelCos/DRMS)
|
| 478 |
+
# -----------------------------
|
| 479 |
+
def _load_librosa_stereo(path: str, sr: int) -> np.ndarray:
|
| 480 |
+
y, _ = librosa.load(path, sr=sr, mono=False)
|
| 481 |
+
y = _ensure_stereo_np(y) # (2, T)
|
| 482 |
+
return y
|
| 483 |
+
|
| 484 |
+
def _mel_cosine_single_channel(wav: np.ndarray, ref: np.ndarray, sr: int, mel_tf: MelScale) -> float:
|
| 485 |
+
hop_length = 160; n_fft = 743
|
| 486 |
+
est_mag = np.abs(librosa.stft(wav, hop_length=hop_length, n_fft=n_fft)) # [F, T]
|
| 487 |
+
ref_mag = np.abs(librosa.stft(ref, hop_length=hop_length, n_fft=n_fft))
|
| 488 |
+
|
| 489 |
+
est_mag_t = torch.tensor(est_mag, dtype=torch.float32) # [F,T]
|
| 490 |
+
ref_mag_t = torch.tensor(ref_mag, dtype=torch.float32) # [F,T]
|
| 491 |
+
|
| 492 |
+
est_mel = mel_tf(est_mag_t) # [80, T]
|
| 493 |
+
ref_mel = mel_tf(ref_mag_t) # [80, T]
|
| 494 |
+
|
| 495 |
+
sim = F.cosine_similarity(est_mel.flatten(), ref_mel.flatten(), dim=0)
|
| 496 |
+
return float(sim.item())
|
| 497 |
+
|
| 498 |
+
# -----------------------------
|
| 499 |
+
# Evaluate
|
| 500 |
+
# -----------------------------
|
| 501 |
+
def evaluate(args, dataset_name, eval_type, metric_logger, loss_fns,
|
| 502 |
+
gt_dir, exp_dir, secs, device, rank, world_size, modals):
|
| 503 |
+
|
| 504 |
+
lpips_loss_fn, dreamsim_loss_fn, fid_loss_fn = loss_fns
|
| 505 |
+
|
| 506 |
+
if eval_type == 'rollout':
|
| 507 |
+
eval_name = 'rollout'
|
| 508 |
+
image_idxs = secs.copy()
|
| 509 |
+
elif eval_type == 'time':
|
| 510 |
+
eval_name = eval_type
|
| 511 |
+
image_idxs = secs.copy()
|
| 512 |
+
else:
|
| 513 |
+
raise ValueError(f"Unknown eval_type {eval_type}")
|
| 514 |
+
|
| 515 |
+
if 'v' in modals:
|
| 516 |
+
for s in secs:
|
| 517 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_fid_{int(s)}'].update(0.0, n=0)
|
| 518 |
+
|
| 519 |
+
# Episodes split by rank
|
| 520 |
+
all_eps = sorted([e for e in os.listdir(gt_dir) if os.path.isdir(os.path.join(gt_dir, e))])
|
| 521 |
+
eps = all_eps[rank::world_size]
|
| 522 |
+
if len(eps) == 0:
|
| 523 |
+
return
|
| 524 |
+
|
| 525 |
+
to_tensor = transforms.ToTensor()
|
| 526 |
+
|
| 527 |
+
fad_streams = {}
|
| 528 |
+
stereo_mode = False
|
| 529 |
+
if 'a' in modals:
|
| 530 |
+
try:
|
| 531 |
+
FADLib = safe_import_fad()
|
| 532 |
+
except Exception as e:
|
| 533 |
+
if rank == 0:
|
| 534 |
+
print(f"[WARN] Fail to import frechet_audio_distance:{e}")
|
| 535 |
+
FADLib = None
|
| 536 |
+
|
| 537 |
+
if FADLib is not None:
|
| 538 |
+
base_fad = FADLib(
|
| 539 |
+
model_name=args.fad_model,
|
| 540 |
+
sample_rate=args.fad_sr,
|
| 541 |
+
verbose=False
|
| 542 |
+
)
|
| 543 |
+
if args.fad_model == 'vggish' and not args.mono:
|
| 544 |
+
stereo_mode = True
|
| 545 |
+
for sec in secs:
|
| 546 |
+
fad_streams[sec] = StereoStreamingFAD(base_fad, pad_seconds=args.fad_pad_sec, batch_size=16)
|
| 547 |
+
else:
|
| 548 |
+
for sec in secs:
|
| 549 |
+
fad_streams[sec] = StreamingFAD(base_fad, pad_seconds=args.fad_pad_sec, batch_size=16)
|
| 550 |
+
|
| 551 |
+
mel_tf = MelScale(n_mels=80, sample_rate=16000, n_stft=372)
|
| 552 |
+
|
| 553 |
+
for batch_start in tqdm(range(0, len(eps), args.batch_size),
|
| 554 |
+
total=(len(eps) + args.batch_size - 1) // args.batch_size,
|
| 555 |
+
disable=(rank != 0)):
|
| 556 |
+
batch_eps = eps[batch_start:batch_start + args.batch_size]
|
| 557 |
+
|
| 558 |
+
# per-sec containers (vision)
|
| 559 |
+
gt_img_batch, exp_img_batch = {}, {}
|
| 560 |
+
gt_img_paths_batch, exp_img_paths_batch = {}, {}
|
| 561 |
+
denorm_pairs_by_sec = {}
|
| 562 |
+
secs_py = [int(s) for s in secs]
|
| 563 |
+
denorm_pairs_by_sec = {s: [] for s in secs_py}
|
| 564 |
+
for sec in secs:
|
| 565 |
+
gt_img_batch[sec], exp_img_batch[sec] = [], []
|
| 566 |
+
gt_img_paths_batch[sec], exp_img_paths_batch[sec] = [], []
|
| 567 |
+
|
| 568 |
+
# per-sec containers (audio paths)
|
| 569 |
+
gt_wav_paths_batch, exp_wav_paths_batch = {}, {}
|
| 570 |
+
for sec in secs:
|
| 571 |
+
gt_wav_paths_batch[sec], exp_wav_paths_batch[sec] = [], []
|
| 572 |
+
|
| 573 |
+
for ep in batch_eps:
|
| 574 |
+
gt_ep_dir = os.path.join(gt_dir, ep)
|
| 575 |
+
exp_ep_dir = os.path.join(exp_dir, ep)
|
| 576 |
+
|
| 577 |
+
if (not os.path.isdir(gt_ep_dir)) or (not os.path.isdir(exp_ep_dir)):
|
| 578 |
+
continue
|
| 579 |
+
|
| 580 |
+
gt_dist_p = os.path.join(gt_ep_dir, "distance.json")
|
| 581 |
+
exp_dist_p = os.path.join(exp_ep_dir, "distance.json")
|
| 582 |
+
try:
|
| 583 |
+
if os.path.isfile(gt_dist_p) and os.path.isfile(exp_dist_p):
|
| 584 |
+
with open(gt_dist_p, "r") as f: gt_list = json.load(f)
|
| 585 |
+
with open(exp_dist_p, "r") as f: exp_list = json.load(f)
|
| 586 |
+
gt_map = {int(it["sec"]): float(it["denorm_gt"]) for it in gt_list if "sec" in it and "denorm_gt" in it}
|
| 587 |
+
exp_map = {int(it["sec"]): float(it["denorm_pred"]) for it in exp_list if "sec" in it and "denorm_pred" in it}
|
| 588 |
+
for s in secs_py:
|
| 589 |
+
if s in gt_map and s in exp_map:
|
| 590 |
+
denorm_pairs_by_sec[s].append((gt_map[s], exp_map[s]))
|
| 591 |
+
except Exception:
|
| 592 |
+
pass
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
for sec, image_idx in zip(secs, image_idxs):
|
| 596 |
+
# ---- vision
|
| 597 |
+
if 'v' in modals:
|
| 598 |
+
gt_sec_img_path = os.path.join(gt_ep_dir, f'{int(image_idx)}.png')
|
| 599 |
+
exp_sec_img_path = os.path.join(exp_ep_dir, f'{int(image_idx)}.png')
|
| 600 |
+
if os.path.isfile(gt_sec_img_path) and os.path.isfile(exp_sec_img_path):
|
| 601 |
+
try:
|
| 602 |
+
gt_img = to_tensor(Image.open(gt_sec_img_path).convert("RGB")).unsqueeze(0).to(device)
|
| 603 |
+
exp_img = to_tensor(Image.open(exp_sec_img_path).convert("RGB")).unsqueeze(0).to(device)
|
| 604 |
+
if torch.isfinite(gt_img).all() and torch.isfinite(exp_img).all():
|
| 605 |
+
gt_img_batch[sec].append(gt_img)
|
| 606 |
+
exp_img_batch[sec].append(exp_img)
|
| 607 |
+
gt_img_paths_batch[sec].append(gt_sec_img_path)
|
| 608 |
+
exp_img_paths_batch[sec].append(exp_sec_img_path)
|
| 609 |
+
except Exception:
|
| 610 |
+
pass
|
| 611 |
+
|
| 612 |
+
# ---- audio
|
| 613 |
+
if 'a' in modals:
|
| 614 |
+
gt_sec_wav_path = os.path.join(gt_ep_dir, f'{int(image_idx)}.wav')
|
| 615 |
+
exp_sec_wav_path = os.path.join(exp_ep_dir, f'{int(image_idx)}.wav')
|
| 616 |
+
if os.path.isfile(gt_sec_wav_path) and os.path.isfile(exp_sec_wav_path):
|
| 617 |
+
gt_wav_paths_batch[sec].append(gt_sec_wav_path)
|
| 618 |
+
exp_wav_paths_batch[sec].append(exp_sec_wav_path)
|
| 619 |
+
|
| 620 |
+
# ---- vision metric update per batch
|
| 621 |
+
if 'v' in modals:
|
| 622 |
+
for sec in secs:
|
| 623 |
+
if (len(gt_img_batch[sec]) == 0) or (len(exp_img_batch[sec]) == 0):
|
| 624 |
+
continue
|
| 625 |
+
lpips_dists = lpips_loss_fn(gt_img_paths_batch[sec], exp_img_paths_batch[sec])
|
| 626 |
+
dreamsim_dists = dreamsim_loss_fn(gt_img_paths_batch[sec], exp_img_paths_batch[sec])
|
| 627 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_lpips_{sec}'].update(lpips_dists, n=1)
|
| 628 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_dreamsim_{sec}'].update(dreamsim_dists, n=1)
|
| 629 |
+
|
| 630 |
+
sec_gt_batch = torch.cat(gt_img_batch[sec], dim=0)
|
| 631 |
+
sec_exp_batch = torch.cat(exp_img_batch[sec], dim=0)
|
| 632 |
+
if torch.isfinite(sec_gt_batch).all() and torch.isfinite(sec_exp_batch).all():
|
| 633 |
+
fid_loss_fn[sec].update(images=sec_gt_batch, is_real=True)
|
| 634 |
+
fid_loss_fn[sec].update(images=sec_exp_batch, is_real=False)
|
| 635 |
+
psnr_vals = _psnr_from_tensors(sec_gt_batch, sec_exp_batch, data_range=1.0) # (N,)
|
| 636 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_psnr_{sec}'].update(psnr_vals.mean(), n=1)
|
| 637 |
+
|
| 638 |
+
# ---- audio metrics per batch
|
| 639 |
+
if 'a' in modals:
|
| 640 |
+
# FAD (streaming)
|
| 641 |
+
if len(fad_streams) > 0:
|
| 642 |
+
for sec in secs:
|
| 643 |
+
if len(gt_wav_paths_batch[sec]) == 0 and len(exp_wav_paths_batch[sec]) == 0:
|
| 644 |
+
continue
|
| 645 |
+
fad_streams[sec].update_from_wavs(gt_wav_paths_batch[sec], is_real=True)
|
| 646 |
+
fad_streams[sec].update_from_wavs(exp_wav_paths_batch[sec], is_real=False)
|
| 647 |
+
|
| 648 |
+
# LSD / SSIM / MelCos / dRMS-db
|
| 649 |
+
_AUDIO_SR = 16000
|
| 650 |
+
for sec in secs:
|
| 651 |
+
gt_list = gt_wav_paths_batch[sec]
|
| 652 |
+
exp_list = exp_wav_paths_batch[sec]
|
| 653 |
+
if len(gt_list) == 0 or len(exp_list) == 0:
|
| 654 |
+
continue
|
| 655 |
+
pair_cnt = min(len(gt_list), len(exp_list))
|
| 656 |
+
if pair_cnt == 0:
|
| 657 |
+
continue
|
| 658 |
+
|
| 659 |
+
lsd_L, lsd_R, ssim_L, ssim_R = [], [], [], []
|
| 660 |
+
mel_L, mel_R = [], []
|
| 661 |
+
|
| 662 |
+
mel_lsd_L, mel_lsd_R = [], []
|
| 663 |
+
mel_ssim_L, mel_ssim_R = [], []
|
| 664 |
+
|
| 665 |
+
sispec_nl_L, sispec_nl_R = [], []
|
| 666 |
+
sispec_log_L, sispec_log_R = [], []
|
| 667 |
+
mel_sispec_nl_L, mel_sispec_n_R = [], []
|
| 668 |
+
mel_sispec_log_L, mel_sispec_log_R = [], []
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
for i in range(pair_cnt):
|
| 672 |
+
gpath = gt_list[i]
|
| 673 |
+
epath = exp_list[i]
|
| 674 |
+
try:
|
| 675 |
+
g_st = _load_librosa_stereo(gpath, _AUDIO_SR) # (2,T)
|
| 676 |
+
e_st = _load_librosa_stereo(epath, _AUDIO_SR) # (2,T)
|
| 677 |
+
|
| 678 |
+
if args.mono:
|
| 679 |
+
g_mono = g_st.mean(axis=0)
|
| 680 |
+
e_mono = e_st.mean(axis=0)
|
| 681 |
+
|
| 682 |
+
# LSD/SSIM
|
| 683 |
+
gt_sp = _wav_to_spectrogram(g_mono, rate=_AUDIO_SR)
|
| 684 |
+
ex_sp = _wav_to_spectrogram(e_mono, rate=_AUDIO_SR)
|
| 685 |
+
lsd_val = _lsd_from_specs(ex_sp.clone(), gt_sp.clone())
|
| 686 |
+
ssim_val = _ssim_from_specs(ex_sp.clone(), gt_sp.clone())
|
| 687 |
+
|
| 688 |
+
# MelCos
|
| 689 |
+
mel_val = _mel_cosine_single_channel(e_mono, g_mono, _AUDIO_SR, mel_tf)
|
| 690 |
+
|
| 691 |
+
# mel_lsd & mel_ssim
|
| 692 |
+
mel_lsd_val, mel_ssim_val = _mel_lsd_ssim_single(e_mono, g_mono, mel_tf)
|
| 693 |
+
|
| 694 |
+
# sispec
|
| 695 |
+
sispec_nl = _sispec_from_specs(ex_sp.clone(), gt_sp.clone(), log_domain=False)
|
| 696 |
+
sispec_log = _sispec_from_specs(ex_sp.clone(), gt_sp.clone(), log_domain=True)
|
| 697 |
+
# Mel sispec
|
| 698 |
+
mel_sispec_nl = _sispec_from_specs(ex_m.clone(), gt_m.clone(), log_domain=False)
|
| 699 |
+
mel_sispec_log = _sispec_from_specs(ex_m.clone(), gt_m.clone(), log_domain=True)
|
| 700 |
+
|
| 701 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_lsd_{sec}'].update(lsd_val, n=1)
|
| 702 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_ssim_{sec}'].update(
|
| 703 |
+
torch.tensor(ssim_val), n=1
|
| 704 |
+
)
|
| 705 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_melcos_{sec}'].update(
|
| 706 |
+
torch.tensor(mel_val), n=1
|
| 707 |
+
)
|
| 708 |
+
|
| 709 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_lsd_{sec}'].update(
|
| 710 |
+
torch.tensor(float(mel_lsd_val)), n=1
|
| 711 |
+
)
|
| 712 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_ssim_{sec}'].update(
|
| 713 |
+
torch.tensor(float(mel_ssim_val)), n=1
|
| 714 |
+
)
|
| 715 |
+
|
| 716 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_non_log_sispec_{sec}'].update(
|
| 717 |
+
torch.tensor(float(sispec_nl)), n=1
|
| 718 |
+
)
|
| 719 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_sispec_{sec}'].update(
|
| 720 |
+
torch.tensor(float(sispec_log)), n=1
|
| 721 |
+
)
|
| 722 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_final_non_log_mel_sispec_{sec}'].update(
|
| 723 |
+
torch.tensor(float(mel_sispec_nl)), n=1
|
| 724 |
+
)
|
| 725 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_sispec_{sec}'].update(
|
| 726 |
+
torch.tensor(float(mel_sispec_log)), n=1
|
| 727 |
+
)
|
| 728 |
+
|
| 729 |
+
|
| 730 |
+
else:
|
| 731 |
+
for ch, (acc_lsd, acc_ssim, acc_mel,
|
| 732 |
+
acc_mel_lsd, acc_mel_ssim,
|
| 733 |
+
acc_sispec_nl, acc_sispec_log,
|
| 734 |
+
acc_mel_sispec_nl, acc_mel_sispec_log) in enumerate([
|
| 735 |
+
(lsd_L, ssim_L, mel_L, mel_lsd_L, mel_ssim_L, sispec_nl_L, sispec_log_L, mel_sispec_nl_L, mel_sispec_log_L),
|
| 736 |
+
(lsd_R, ssim_R, mel_R, mel_lsd_R, mel_ssim_R, sispec_nl_R, sispec_log_R, mel_sispec_n_R, mel_sispec_log_R),
|
| 737 |
+
]):
|
| 738 |
+
g = g_st[ch]; e = e_st[ch]
|
| 739 |
+
# LSD/SSIM
|
| 740 |
+
gt_sp = _wav_to_spectrogram(g, rate=_AUDIO_SR)
|
| 741 |
+
ex_sp = _wav_to_spectrogram(e, rate=_AUDIO_SR)
|
| 742 |
+
acc_lsd.append(float(_lsd_from_specs(ex_sp.clone(), gt_sp.clone())))
|
| 743 |
+
acc_ssim.append(float(_ssim_from_specs(ex_sp.clone(), gt_sp.clone())))
|
| 744 |
+
# MelCos
|
| 745 |
+
acc_mel.append(_mel_cosine_single_channel(e, g, _AUDIO_SR, mel_tf))
|
| 746 |
+
|
| 747 |
+
# mel_lsd & mel_ssim
|
| 748 |
+
mel_lsd_val, mel_ssim_val = _mel_lsd_ssim_single(e, g, mel_tf)
|
| 749 |
+
acc_mel_lsd.append(mel_lsd_val)
|
| 750 |
+
acc_mel_ssim.append(mel_ssim_val)
|
| 751 |
+
|
| 752 |
+
# sispec
|
| 753 |
+
acc_sispec_nl.append( float(_sispec_from_specs(ex_sp.clone(), gt_sp.clone(), log_domain=False)) )
|
| 754 |
+
acc_sispec_log.append( float(_sispec_from_specs(ex_sp.clone(), gt_sp.clone(), log_domain=True)) )
|
| 755 |
+
# Mel
|
| 756 |
+
est_mag = np.abs(librosa.stft(e, n_fft=743, hop_length=160))
|
| 757 |
+
ref_mag = np.abs(librosa.stft(g, n_fft=743, hop_length=160))
|
| 758 |
+
est_mel = mel_tf(torch.from_numpy(est_mag).float()) # [M,T]
|
| 759 |
+
ref_mel = mel_tf(torch.from_numpy(ref_mag).float()) # [M,T]
|
| 760 |
+
ex_m = est_mel.T.unsqueeze(0).unsqueeze(0) # [1,1,T,M]
|
| 761 |
+
gt_m = ref_mel.T.unsqueeze(0).unsqueeze(0) # [1,1,T,M]
|
| 762 |
+
# sispec(Mel, non_log / log)
|
| 763 |
+
acc_mel_sispec_nl.append( float(_sispec_from_specs(ex_m.clone(), gt_m.clone(), log_domain=False)) )
|
| 764 |
+
acc_mel_sispec_log.append( float(_sispec_from_specs(ex_m.clone(), gt_m.clone(), log_domain=True)) )
|
| 765 |
+
|
| 766 |
+
except Exception:
|
| 767 |
+
pass
|
| 768 |
+
|
| 769 |
+
if not args.mono:
|
| 770 |
+
def _maybe_mean(x):
|
| 771 |
+
return float(np.mean(x)) if len(x) > 0 else None
|
| 772 |
+
|
| 773 |
+
v = _maybe_mean(lsd_L); w = _maybe_mean(lsd_R)
|
| 774 |
+
if v is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_lsdL_{sec}'].update(torch.tensor(v), n=1)
|
| 775 |
+
if w is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_lsdR_{sec}'].update(torch.tensor(w), n=1)
|
| 776 |
+
if v is not None and w is not None:
|
| 777 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_lsd_{sec}'].update(torch.tensor(0.5*(v+w)), n=1)
|
| 778 |
+
|
| 779 |
+
v = _maybe_mean(ssim_L); w = _maybe_mean(ssim_R)
|
| 780 |
+
if v is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_ssimL_{sec}'].update(torch.tensor(v), n=1)
|
| 781 |
+
if w is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_ssimR_{sec}'].update(torch.tensor(w), n=1)
|
| 782 |
+
if v is not None and w is not None:
|
| 783 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_ssim_{sec}'].update(torch.tensor(0.5*(v+w)), n=1)
|
| 784 |
+
|
| 785 |
+
v = _maybe_mean(mel_L); w = _maybe_mean(mel_R)
|
| 786 |
+
if v is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_melcosL_{sec}'].update(torch.tensor(v), n=1)
|
| 787 |
+
if w is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_melcosR_{sec}'].update(torch.tensor(w), n=1)
|
| 788 |
+
if v is not None and w is not None:
|
| 789 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_melcos_{sec}'].update(torch.tensor(0.5*(v+w)), n=1)
|
| 790 |
+
|
| 791 |
+
v = _maybe_mean(mel_lsd_L); w = _maybe_mean(mel_lsd_R)
|
| 792 |
+
if v is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_lsdL_{sec}'].update(torch.tensor(v), n=1)
|
| 793 |
+
if w is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_lsdR_{sec}'].update(torch.tensor(w), n=1)
|
| 794 |
+
if v is not None and w is not None:
|
| 795 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_lsd_{sec}'].update(torch.tensor(0.5*(v+w)), n=1)
|
| 796 |
+
|
| 797 |
+
v = _maybe_mean(mel_ssim_L); w = _maybe_mean(mel_ssim_R)
|
| 798 |
+
if v is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_ssimL_{sec}'].update(torch.tensor(v), n=1)
|
| 799 |
+
if w is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_ssimR_{sec}'].update(torch.tensor(w), n=1)
|
| 800 |
+
if v is not None and w is not None:
|
| 801 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_ssim_{sec}'].update(torch.tensor(0.5*(v+w)), n=1)
|
| 802 |
+
|
| 803 |
+
v = _maybe_mean(sispec_nl_L); w = _maybe_mean(sispec_nl_R)
|
| 804 |
+
if v is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_non_log_sispecL_{sec}'].update(torch.tensor(v), n=1)
|
| 805 |
+
if w is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_non_log_sispecR_{sec}'].update(torch.tensor(w), n=1)
|
| 806 |
+
if v is not None and w is not None:
|
| 807 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_non_log_sispec_{sec}'].update(torch.tensor(0.5*(v+w)), n=1)
|
| 808 |
+
|
| 809 |
+
v = _maybe_mean(sispec_log_L); w = _maybe_mean(sispec_log_R)
|
| 810 |
+
if v is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_sispecL_{sec}'].update(torch.tensor(v), n=1)
|
| 811 |
+
if w is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_sispecR_{sec}'].update(torch.tensor(w), n=1)
|
| 812 |
+
if v is not None and w is not None:
|
| 813 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_sispec_{sec}'].update(torch.tensor(0.5*(v+w)), n=1)
|
| 814 |
+
|
| 815 |
+
v = _maybe_mean(mel_sispec_nl_L); w = _maybe_mean(mel_sispec_n_R)
|
| 816 |
+
if v is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_final_non_log_mel_sispecL_{sec}'].update(torch.tensor(v), n=1)
|
| 817 |
+
if w is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_final_non_log_mel_sispecR_{sec}'].update(torch.tensor(w), n=1)
|
| 818 |
+
if v is not None and w is not None:
|
| 819 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_final_non_log_mel_sispec_{sec}'].update(torch.tensor(0.5*(v+w)), n=1)
|
| 820 |
+
|
| 821 |
+
v = _maybe_mean(mel_sispec_log_L); w = _maybe_mean(mel_sispec_log_R)
|
| 822 |
+
if v is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_sispecL_{sec}'].update(torch.tensor(v), n=1)
|
| 823 |
+
if w is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_sispecR_{sec}'].update(torch.tensor(w), n=1)
|
| 824 |
+
if v is not None and w is not None:
|
| 825 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_sispec_{sec}'].update(torch.tensor(0.5*(v+w)), n=1)
|
| 826 |
+
for s in secs_py:
|
| 827 |
+
pairs = denorm_pairs_by_sec[s]
|
| 828 |
+
if not pairs:
|
| 829 |
+
continue
|
| 830 |
+
arr = np.asarray(pairs, dtype=np.float32)
|
| 831 |
+
mask = np.isfinite(arr).all(axis=1)
|
| 832 |
+
if not np.any(mask):
|
| 833 |
+
continue
|
| 834 |
+
se_mean = float(np.mean((arr[mask, 1] - arr[mask, 0]) ** 2))
|
| 835 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_denorm_mse_{s}'].update(
|
| 836 |
+
torch.tensor(se_mean), n=1
|
| 837 |
+
)
|
| 838 |
+
|
| 839 |
+
if 'v' in modals:
|
| 840 |
+
feature_dim = 2048
|
| 841 |
+
sec_list = [int(s) for s in secs]
|
| 842 |
+
tmp_dir = Path(os.path.join(args.exp_dir, ".fid_tmp"))
|
| 843 |
+
if dist_torch.is_initialized():
|
| 844 |
+
if dist_torch.get_rank() == 0:
|
| 845 |
+
tmp_dir.mkdir(parents=True, exist_ok=True)
|
| 846 |
+
dist_torch.barrier()
|
| 847 |
+
else:
|
| 848 |
+
tmp_dir.mkdir(parents=True, exist_ok=True)
|
| 849 |
+
if dist_torch.is_initialized():
|
| 850 |
+
my_rank = dist_torch.get_rank()
|
| 851 |
+
world_size = dist_torch.get_world_size()
|
| 852 |
+
else:
|
| 853 |
+
my_rank = 0
|
| 854 |
+
world_size = 1
|
| 855 |
+
|
| 856 |
+
for s in sec_list:
|
| 857 |
+
fid_m = fid_loss_fn[s]
|
| 858 |
+
state = {
|
| 859 |
+
"real_sum": fid_m.real_sum.detach().to("cpu", torch.float64),
|
| 860 |
+
"real_cov_sum": fid_m.real_cov_sum.detach().to("cpu", torch.float64),
|
| 861 |
+
"fake_sum": fid_m.fake_sum.detach().to("cpu", torch.float64),
|
| 862 |
+
"fake_cov_sum": fid_m.fake_cov_sum.detach().to("cpu", torch.float64),
|
| 863 |
+
"num_real_images": torch.tensor(int(fid_m.num_real_images.item()), dtype=torch.int64),
|
| 864 |
+
"num_fake_images": torch.tensor(int(fid_m.num_fake_images.item()), dtype=torch.int64),
|
| 865 |
+
}
|
| 866 |
+
out_path = tmp_dir / f"fid_sec{s}_rank{my_rank}.pt"
|
| 867 |
+
torch.save(state, out_path)
|
| 868 |
+
if dist_torch.is_initialized():
|
| 869 |
+
dist_torch.barrier()
|
| 870 |
+
if (not dist_torch.is_initialized()) or my_rank == 0:
|
| 871 |
+
for s in sec_list:
|
| 872 |
+
agg = {
|
| 873 |
+
"real_sum": torch.zeros(feature_dim, dtype=torch.float64),
|
| 874 |
+
"real_cov_sum": torch.zeros((feature_dim, feature_dim), dtype=torch.float64),
|
| 875 |
+
"fake_sum": torch.zeros(feature_dim, dtype=torch.float64),
|
| 876 |
+
"fake_cov_sum": torch.zeros((feature_dim, feature_dim), dtype=torch.float64),
|
| 877 |
+
"num_real_images": torch.tensor(0, dtype=torch.int64),
|
| 878 |
+
"num_fake_images": torch.tensor(0, dtype=torch.int64),
|
| 879 |
+
}
|
| 880 |
+
for r in range(world_size):
|
| 881 |
+
p = tmp_dir / f"fid_sec{s}_rank{r}.pt"
|
| 882 |
+
if not p.exists():
|
| 883 |
+
continue
|
| 884 |
+
st = torch.load(p, map_location="cpu")
|
| 885 |
+
agg["real_sum"] += st["real_sum"]
|
| 886 |
+
agg["real_cov_sum"] += st["real_cov_sum"]
|
| 887 |
+
agg["fake_sum"] += st["fake_sum"]
|
| 888 |
+
agg["fake_cov_sum"] += st["fake_cov_sum"]
|
| 889 |
+
agg["num_real_images"] += st["num_real_images"]
|
| 890 |
+
agg["num_fake_images"] += st["num_fake_images"]
|
| 891 |
+
fid_m = fid_loss_fn[s]
|
| 892 |
+
fid_m.real_sum = agg["real_sum"].to(fid_m.device, fid_m.real_sum.dtype)
|
| 893 |
+
fid_m.real_cov_sum = agg["real_cov_sum"].to(fid_m.device, fid_m.real_cov_sum.dtype)
|
| 894 |
+
fid_m.fake_sum = agg["fake_sum"].to(fid_m.device, fid_m.fake_sum.dtype)
|
| 895 |
+
fid_m.fake_cov_sum = agg["fake_cov_sum"].to(fid_m.device, fid_m.fake_cov_sum.dtype)
|
| 896 |
+
fid_m.num_real_images = torch.tensor(
|
| 897 |
+
int(agg["num_real_images"].item()), device=fid_m.device, dtype=fid_m.num_real_images.dtype
|
| 898 |
+
)
|
| 899 |
+
fid_m.num_fake_images = torch.tensor(
|
| 900 |
+
int(agg["num_fake_images"].item()), device=fid_m.device, dtype=fid_m.num_fake_images.dtype
|
| 901 |
+
)
|
| 902 |
+
|
| 903 |
+
try:
|
| 904 |
+
val = float(fid_m.compute().item())
|
| 905 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_fid_{s}'].update(val, n=1)
|
| 906 |
+
except Exception as e:
|
| 907 |
+
print(f"[WARN] FID compute failed at sec={s}: {e}")
|
| 908 |
+
for s in sec_list:
|
| 909 |
+
for r in range(world_size):
|
| 910 |
+
p = tmp_dir / f"fid_sec{s}_rank{r}.pt"
|
| 911 |
+
try:
|
| 912 |
+
if p.exists():
|
| 913 |
+
p.unlink()
|
| 914 |
+
except Exception:
|
| 915 |
+
pass
|
| 916 |
+
try:
|
| 917 |
+
tmp_dir.rmdir()
|
| 918 |
+
except Exception:
|
| 919 |
+
pass
|
| 920 |
+
if dist_torch.is_initialized():
|
| 921 |
+
dist_torch.barrier()
|
| 922 |
+
|
| 923 |
+
if 'a' in modals and len(fad_streams) > 0:
|
| 924 |
+
for sec in secs:
|
| 925 |
+
try:
|
| 926 |
+
if stereo_mode:
|
| 927 |
+
fad_L, fad_R, fad_avg = fad_streams[sec].compute()
|
| 928 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_fadL_{sec}'].update(fad_L, n=1)
|
| 929 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_fadR_{sec}'].update(fad_R, n=1)
|
| 930 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_fad_{sec}'].update(fad_avg, n=1)
|
| 931 |
+
else:
|
| 932 |
+
fad_val = float(fad_streams[sec].compute())
|
| 933 |
+
metric_logger.meters[f'{dataset_name}_{eval_name}_fad_{sec}'].update(fad_val, n=1)
|
| 934 |
+
except Exception as e:
|
| 935 |
+
if rank == 0:
|
| 936 |
+
print(f"[WARN] FAD compute failed at sec={sec}: {e}")
|
| 937 |
+
continue
|
| 938 |
+
|
| 939 |
+
|
| 940 |
+
# -----------------------------
|
| 941 |
+
# Save
|
| 942 |
+
# -----------------------------
|
| 943 |
+
def save_metric_to_disk(metric_logger, log_p, rank):
|
| 944 |
+
if dist_torch.is_initialized():
|
| 945 |
+
metric_logger.synchronize_between_processes()
|
| 946 |
+
if rank == 0:
|
| 947 |
+
log_stats = {k: float(meter.global_avg) for k, meter in metric_logger.meters.items()}
|
| 948 |
+
os.makedirs(os.path.dirname(log_p), exist_ok=True)
|
| 949 |
+
with open(log_p, 'w') as json_file:
|
| 950 |
+
json.dump(log_stats, json_file, indent=4)
|
| 951 |
+
print(f"[OK] Metrics saved to: {log_p}")
|
| 952 |
+
|
| 953 |
+
|
| 954 |
+
# -----------------------------
|
| 955 |
+
# Main
|
| 956 |
+
# -----------------------------
|
| 957 |
+
def main(args):
|
| 958 |
+
rank, world_size, local_rank = setup_distributed()
|
| 959 |
+
device = f"cuda:{local_rank}" if world_size > 1 else ("cuda" if torch.cuda.is_available() else "cpu")
|
| 960 |
+
torch.backends.cudnn.benchmark = True
|
| 961 |
+
|
| 962 |
+
dataset_name = args.dataset
|
| 963 |
+
secs = np.array([i for i in range(1, 17)], dtype=int)
|
| 964 |
+
|
| 965 |
+
# vision metrics (will only be used if 'v' in modals)
|
| 966 |
+
lpips_loss_fn = get_loss_fn('lpips', secs, device)
|
| 967 |
+
dreamsim_loss_fn = get_loss_fn('dreamsim', secs, device)
|
| 968 |
+
fid_metrics_vision = get_loss_fn('fid', secs, device)
|
| 969 |
+
|
| 970 |
+
try:
|
| 971 |
+
metric_logger = dist.MetricLogger(delimiter=" ")
|
| 972 |
+
if rank == 0:
|
| 973 |
+
print(f"Evaluating {args.eval_name} {dataset_name} | modals = {args.modals}")
|
| 974 |
+
|
| 975 |
+
time_loss_fns = (lpips_loss_fn, dreamsim_loss_fn, fid_metrics_vision)
|
| 976 |
+
|
| 977 |
+
with torch.no_grad():
|
| 978 |
+
evaluate(
|
| 979 |
+
args=args,
|
| 980 |
+
dataset_name=dataset_name,
|
| 981 |
+
eval_type=args.eval_name,
|
| 982 |
+
metric_logger=metric_logger,
|
| 983 |
+
loss_fns=time_loss_fns,
|
| 984 |
+
gt_dir=args.gt_dir,
|
| 985 |
+
exp_dir=args.exp_dir,
|
| 986 |
+
secs=secs,
|
| 987 |
+
device=device,
|
| 988 |
+
rank=rank,
|
| 989 |
+
world_size=world_size,
|
| 990 |
+
modals=args.modals
|
| 991 |
+
)
|
| 992 |
+
|
| 993 |
+
output_fn = os.path.join(args.exp_dir, f'{dataset_name}_{args.eval_name}.json')
|
| 994 |
+
save_metric_to_disk(metric_logger, output_fn, rank)
|
| 995 |
+
|
| 996 |
+
except Exception as e:
|
| 997 |
+
if rank == 0:
|
| 998 |
+
print(e)
|
| 999 |
+
finally:
|
| 1000 |
+
if dist_torch.is_initialized():
|
| 1001 |
+
dist_torch.barrier()
|
| 1002 |
+
dist_torch.destroy_process_group()
|
| 1003 |
+
|
| 1004 |
+
|
| 1005 |
+
# -----------------------------
|
| 1006 |
+
# CLI
|
| 1007 |
+
# -----------------------------
|
| 1008 |
+
if __name__ == "__main__":
|
| 1009 |
+
parser = argparse.ArgumentParser(allow_abbrev=False)
|
| 1010 |
+
|
| 1011 |
+
parser.add_argument("--batch_size", type=int, default=64, help="batch size")
|
| 1012 |
+
parser.add_argument("--gt_dir", type=str, required=True, help="gt directory")
|
| 1013 |
+
parser.add_argument("--exp_dir", type=str, required=True, help="experiment directory (also save json here)")
|
| 1014 |
+
parser.add_argument("--eval_name", type=str, default='time', choices=['time', 'rollout'], help="eval type")
|
| 1015 |
+
parser.add_argument("--dataset", type=str, required=True, help="dataset name (for metric keys & json name)")
|
| 1016 |
+
parser.add_argument("--modals", type=str, default="av", choices=["a", "v", "av"],
|
| 1017 |
+
help="a=audio only (wav), v= image only (png), av=both")
|
| 1018 |
+
|
| 1019 |
+
# FAD options
|
| 1020 |
+
parser.add_argument("--fad_model", type=str, default="vggish",
|
| 1021 |
+
choices=["vggish", "pann", "clap", "encodec"],
|
| 1022 |
+
help="embedding model for FAD")
|
| 1023 |
+
parser.add_argument("--fad_sr", type=int, default=16000,
|
| 1024 |
+
help="sampling rate for FAD")
|
| 1025 |
+
|
| 1026 |
+
# Stereo VGGish FAD options
|
| 1027 |
+
parser.add_argument("--mono", action="store_true",
|
| 1028 |
+
help="default as stereo, add --mono to mono")
|
| 1029 |
+
parser.add_argument("--fad_pad_sec", type=float, default=1.0,
|
| 1030 |
+
help="pad the input of VGGish to x seconds")
|
| 1031 |
+
|
| 1032 |
+
args = parser.parse_args()
|
| 1033 |
+
main(args)
|
inference_avwm.py
ADDED
|
@@ -0,0 +1,498 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
from distributed import init_distributed
|
| 7 |
+
import torch
|
| 8 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 9 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 10 |
+
|
| 11 |
+
import yaml
|
| 12 |
+
import argparse
|
| 13 |
+
import os
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
from diffusion import create_diffusion
|
| 17 |
+
from diffusers.models import AutoencoderKL
|
| 18 |
+
|
| 19 |
+
import misc
|
| 20 |
+
import distributed as dist
|
| 21 |
+
from models import AVCDiT_models
|
| 22 |
+
from datasets import EvalDataset
|
| 23 |
+
from PIL import Image
|
| 24 |
+
from soundstream import SoundStream
|
| 25 |
+
import torchaudio
|
| 26 |
+
from skimage.measure import block_reduce
|
| 27 |
+
|
| 28 |
+
import matplotlib.pyplot as plt
|
| 29 |
+
import librosa
|
| 30 |
+
import time
|
| 31 |
+
import warnings
|
| 32 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 33 |
+
from collections import defaultdict
|
| 34 |
+
import json
|
| 35 |
+
|
| 36 |
+
def save_image(output_file, img, unnormalize_img):
|
| 37 |
+
img = img.detach().cpu()
|
| 38 |
+
if unnormalize_img:
|
| 39 |
+
img = misc.unnormalize(img)
|
| 40 |
+
|
| 41 |
+
img = img * 255
|
| 42 |
+
img = img.byte()
|
| 43 |
+
image = Image.fromarray(img.permute(1, 2, 0).numpy(), mode='RGB')
|
| 44 |
+
|
| 45 |
+
image.save(output_file)
|
| 46 |
+
|
| 47 |
+
def save_audio(output_file, audio_tensor, sample_rate):
|
| 48 |
+
audio_tensor = audio_tensor.detach().cpu()
|
| 49 |
+
if audio_tensor.ndim == 1:
|
| 50 |
+
audio_tensor = audio_tensor.unsqueeze(0)
|
| 51 |
+
torchaudio.save(output_file, audio_tensor.to(torch.float32), sample_rate)
|
| 52 |
+
|
| 53 |
+
def get_dataset_eval(config, dataset_name, eval_type, predefined_index=True):
|
| 54 |
+
data_config = config["eval_datasets"][dataset_name]
|
| 55 |
+
if predefined_index:
|
| 56 |
+
predefined_index = f"data_splits/{dataset_name}/test/{eval_type}.pkl"
|
| 57 |
+
else:
|
| 58 |
+
predefined_index=None
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
dataset = EvalDataset(
|
| 62 |
+
data_folder=data_config["data_folder"],
|
| 63 |
+
data_split_folder=data_config["test"],
|
| 64 |
+
dataset_name=dataset_name,
|
| 65 |
+
image_size=config["image_size"],
|
| 66 |
+
min_dist_cat=config["eval_distance"]["eval_min_dist_cat"],
|
| 67 |
+
max_dist_cat=config["eval_distance"]["eval_max_dist_cat"],
|
| 68 |
+
len_traj_pred=config["eval_len_traj_pred"],
|
| 69 |
+
traj_stride=config["traj_stride"],
|
| 70 |
+
context_size=config["eval_context_size"],
|
| 71 |
+
normalize=config["normalize"],
|
| 72 |
+
transform=misc.transform,
|
| 73 |
+
goals_per_obs=4,
|
| 74 |
+
predefined_index=predefined_index,
|
| 75 |
+
traj_names='traj_names.txt'
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
return dataset
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@torch.no_grad()
|
| 82 |
+
def model_forward_wrapper_v(all_models, curr_obs, curr_delta, num_timesteps, latent_size, device, num_cond, num_goals=1, rel_t=None, progress=False):
|
| 83 |
+
model, diffusion, vae = all_models
|
| 84 |
+
x = curr_obs.to(device)
|
| 85 |
+
y = curr_delta.to(device)
|
| 86 |
+
|
| 87 |
+
with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
|
| 88 |
+
B, T = x.shape[:2]
|
| 89 |
+
|
| 90 |
+
if rel_t is None:
|
| 91 |
+
rel_t = (torch.ones(B)* (1. / 128.)).to(device)
|
| 92 |
+
rel_t *= num_timesteps
|
| 93 |
+
|
| 94 |
+
x = x.flatten(0,1)
|
| 95 |
+
x = vae.encode(x).latent_dist.sample().mul_(0.18215).unflatten(0, (B, T))
|
| 96 |
+
x_cond = x[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x.shape[2], x.shape[3], x.shape[4]).flatten(0, 1)
|
| 97 |
+
z = torch.randn(B*num_goals, 4, latent_size, latent_size, device=device)
|
| 98 |
+
y = y.flatten(0, 1)
|
| 99 |
+
model_kwargs = dict(y=y, x_cond=x_cond, rel_t=rel_t)
|
| 100 |
+
samples = diffusion.p_sample_loop(
|
| 101 |
+
model.forward, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=progress, device=device
|
| 102 |
+
)
|
| 103 |
+
samples = vae.decode(samples / 0.18215).sample
|
| 104 |
+
|
| 105 |
+
return torch.clip(samples, -1., 1.)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
@torch.no_grad()
|
| 109 |
+
def model_forward_wrapper_a(all_models, curr_obs, curr_delta, num_timesteps, latent_size, device, num_cond, num_goals=1, rel_t=None, progress=False):
|
| 110 |
+
model, diffusion, sstream = all_models
|
| 111 |
+
x = curr_obs.to(device)
|
| 112 |
+
y = curr_delta.to(device)
|
| 113 |
+
with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
|
| 114 |
+
B, T = x.shape[:2]
|
| 115 |
+
if rel_t is None:
|
| 116 |
+
rel_t = (torch.ones(B)* (1. / 128.)).to(device)
|
| 117 |
+
rel_t *= num_timesteps
|
| 118 |
+
x = x.flatten(0,1)
|
| 119 |
+
x = sstream.encoder(x).unflatten(0, (B, T))
|
| 120 |
+
x_cond = x[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x.shape[2], x.shape[3]).flatten(0, 1)
|
| 121 |
+
z = torch.randn(B*num_goals, 16, 181, device=device)
|
| 122 |
+
y = y.flatten(0, 1)
|
| 123 |
+
model_kwargs = dict(y=y, x_cond=x_cond, rel_t=rel_t)
|
| 124 |
+
samples = diffusion.p_sample_loop(
|
| 125 |
+
model.forward, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=progress, device=device
|
| 126 |
+
)
|
| 127 |
+
# REWARD TOKEN
|
| 128 |
+
patch_tok = samples[..., -1:] # [N, 64, 1]
|
| 129 |
+
diff_pred = patch_tok.mean(dim=1, keepdim=True) # [N, 1]
|
| 130 |
+
samples = samples[..., :-1]
|
| 131 |
+
# AUDIO TOKENS
|
| 132 |
+
quantized, _, _ = sstream.quantizer(samples.permute(0, 2, 1)) # [1, T', D]
|
| 133 |
+
samples = sstream.decoder(quantized.permute(0, 2, 1))
|
| 134 |
+
return samples, diff_pred
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@torch.no_grad()
|
| 138 |
+
def model_forward_wrapper_av(all_models, curr_obs, curr_delta, num_timesteps, latent_size, device, num_cond, num_goals=1, rel_t=None, progress=False):
|
| 139 |
+
model, diffusion, vae, sstream = all_models
|
| 140 |
+
x_v, x_a = curr_obs
|
| 141 |
+
x_v = x_v.to(device)
|
| 142 |
+
x_a = x_a.to(device)
|
| 143 |
+
y = curr_delta.to(device)
|
| 144 |
+
with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
|
| 145 |
+
B, T_v = x_v.shape[:2]
|
| 146 |
+
B, T_a = x_a.shape[:2]
|
| 147 |
+
|
| 148 |
+
if rel_t is None:
|
| 149 |
+
rel_t = (torch.ones(B)* (1. / 128.)).to(device)
|
| 150 |
+
rel_t *= num_timesteps
|
| 151 |
+
x_v = x_v.flatten(0,1)
|
| 152 |
+
x_a = x_a.flatten(0,1)
|
| 153 |
+
x_v = vae.encode(x_v).latent_dist.sample().mul_(0.18215).unflatten(0, (B, T_v))
|
| 154 |
+
x_a = sstream.encoder(x_a).unflatten(0, (B, T_a))
|
| 155 |
+
x_v_cond = x_v[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x_v.shape[2], x_v.shape[3], x_v.shape[4]).flatten(0, 1)
|
| 156 |
+
x_a_cond = x_a[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x_a.shape[2], x_a.shape[3]).flatten(0, 1)
|
| 157 |
+
z_v = torch.randn(B*num_goals, 4, latent_size, latent_size, device=device)
|
| 158 |
+
z_a = torch.randn(B*num_goals, 16, 181, device=device) #TODO
|
| 159 |
+
y = y.flatten(0, 1)
|
| 160 |
+
model_kwargs = dict(y=y, x_v_cond=x_v_cond, x_a_cond=x_a_cond, rel_t=rel_t)
|
| 161 |
+
samples_v, samples_a = diffusion.p_sample_loop(
|
| 162 |
+
model.forward, z_v.shape, z_a.shape, z_v, z_a, clip_denoised=False, model_kwargs=model_kwargs, progress=progress, device=device
|
| 163 |
+
)
|
| 164 |
+
patch_tok = samples_a[..., -1:] # [N, 16, 1]
|
| 165 |
+
diff_pred = patch_tok.mean(dim=1, keepdim=True) # [N, 1]
|
| 166 |
+
samples_a = samples_a[..., :-1]
|
| 167 |
+
samples_v = vae.decode(samples_v / 0.18215).sample
|
| 168 |
+
quantized, _, _ = sstream.quantizer(samples_a.permute(0, 2, 1)) # [1, T', D]
|
| 169 |
+
samples_a = sstream.decoder(quantized.permute(0, 2, 1))
|
| 170 |
+
return torch.clip(samples_v, -1., 1.), samples_a, diff_pred
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def generate_rollout(args, output_dir, rollout_frames, idxs, all_models, obs_av, gt_av, diffs_seq, delta, num_cond, device):
|
| 174 |
+
(obs_image, obs_audio, orig_obs_audio)=obs_av
|
| 175 |
+
(gt_image, gt_audio, orig_gt_audio)=gt_av
|
| 176 |
+
|
| 177 |
+
gt_image = gt_image[:,:rollout_frames]
|
| 178 |
+
gt_audio = gt_audio[:,:rollout_frames]
|
| 179 |
+
curr_v = obs_image.to(device)
|
| 180 |
+
curr_a = obs_audio.to(device)
|
| 181 |
+
down_resampler = torchaudio.transforms.Resample(orig_freq=48000, new_freq=16000, lowpass_filter_width=64).to(device, dtype=torch.bfloat16)
|
| 182 |
+
episode_records = defaultdict(list)
|
| 183 |
+
value_key = "denorm_gt" if args.gt else "denorm_pred"
|
| 184 |
+
|
| 185 |
+
for i in range(gt_image.shape[1]):
|
| 186 |
+
curr_delta = delta[:, i:i+1].to(device)
|
| 187 |
+
|
| 188 |
+
x_gt_pixels = gt_image[:, i].to(device)
|
| 189 |
+
x_gt_audios_orig = orig_gt_audio[:, i].to(device)
|
| 190 |
+
if args.gt:
|
| 191 |
+
visualize_preds(output_dir, idxs, i+1, x_gt_pixels, x_gt_audios_orig, 16000)
|
| 192 |
+
denorm_gt_vals = denorm_from_tensor(diffs_seq[:, i:i+1, :]) # [B]
|
| 193 |
+
idxs_1d = idxs.detach().view(-1).cpu().numpy()
|
| 194 |
+
for b, sample_idx in enumerate(idxs_1d):
|
| 195 |
+
episode_records[int(sample_idx)].append({"sec": int(i+1), "value": float(denorm_gt_vals[b])})
|
| 196 |
+
else:
|
| 197 |
+
diff_gt = diffs_seq[:, i:i+1, :].unsqueeze(1).to(device)
|
| 198 |
+
x_pred_pixels, x_pred_audios, diff_pred = model_forward_wrapper_av(all_models, (curr_v, curr_a), curr_delta, num_timesteps=1, latent_size=args.latent_size, device=device, num_cond=num_cond, num_goals=1)
|
| 199 |
+
x_pred_audios_orig = down_resampler(x_pred_audios)
|
| 200 |
+
curr_v = torch.cat((curr_v, x_pred_pixels.unsqueeze(1)), dim=1) # append current prediction
|
| 201 |
+
curr_v = curr_v[:, 1:] # remove first observation
|
| 202 |
+
curr_a = torch.cat((curr_a, x_pred_audios.unsqueeze(1)), dim=1) # append current prediction
|
| 203 |
+
curr_a = curr_a[:, 1:] # remove first observation
|
| 204 |
+
denorm_pred_vals = denorm_from_tensor(diff_pred) # [B]
|
| 205 |
+
denorm_gt_vals = denorm_from_tensor(diff_gt) # [B]
|
| 206 |
+
visualize_preds(output_dir, idxs, i+1, x_pred_pixels, x_pred_audios_orig, 16000)
|
| 207 |
+
visualize_compare(output_dir, idxs, i+1,
|
| 208 |
+
x_pred_pixels, x_pred_audios_orig,
|
| 209 |
+
x_gt_pixels, x_gt_audios_orig,
|
| 210 |
+
denorm_pred_vals=denorm_pred_vals,
|
| 211 |
+
denorm_gt_vals=denorm_gt_vals)
|
| 212 |
+
idxs_1d = idxs.detach().view(-1).cpu().numpy()
|
| 213 |
+
for b, sample_idx in enumerate(idxs_1d):
|
| 214 |
+
episode_records[int(sample_idx)].append({"sec": int(i+1), "value": float(denorm_pred_vals[b])})
|
| 215 |
+
|
| 216 |
+
for sample_idx, rows in episode_records.items():
|
| 217 |
+
rows = sorted(rows, key=lambda r: r["sec"])
|
| 218 |
+
sample_folder = os.path.join(output_dir, f"id_{sample_idx}")
|
| 219 |
+
os.makedirs(sample_folder, exist_ok=True)
|
| 220 |
+
out_json = os.path.join(sample_folder, "distance.json")
|
| 221 |
+
compact = [{ "sec": r["sec"], value_key: r["value"] } for r in rows]
|
| 222 |
+
with open(out_json, "w") as f:
|
| 223 |
+
json.dump(compact, f, indent=2)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def generate_time(args, output_dir, idxs, all_models, obs_av, gt_av, diffs_seq, delta, secs, num_cond, device):
|
| 227 |
+
(obs_image, obs_audio, _)=obs_av
|
| 228 |
+
(gt_image, _, orig_gt_audio)=gt_av
|
| 229 |
+
down_resampler = torchaudio.transforms.Resample(orig_freq=48000, new_freq=16000, lowpass_filter_width=64).to(device, dtype=torch.bfloat16)
|
| 230 |
+
episode_records = defaultdict(list) # {sample_idx: [{"sec": int, "value": float}, ...]}
|
| 231 |
+
value_key = "denorm_gt" if args.gt else "denorm_pred"
|
| 232 |
+
|
| 233 |
+
for sec in secs:
|
| 234 |
+
curr_delta = delta[:, :sec].sum(dim=1, keepdim=True)
|
| 235 |
+
x_gt_pixels = gt_image[:, sec-1].to(device)
|
| 236 |
+
x_gt_audios_orig = orig_gt_audio[:, sec-1].to(device)
|
| 237 |
+
if args.gt:
|
| 238 |
+
denorm_gt_vals = denorm_from_tensor(diffs_seq[:, :sec, :].sum(dim=1, keepdim=True)) # [B]
|
| 239 |
+
visualize_preds(output_dir, idxs, sec, x_gt_pixels, x_gt_audios_orig, 16000)
|
| 240 |
+
idxs_1d = idxs.detach().view(-1).cpu().numpy()
|
| 241 |
+
for b, sample_idx in enumerate(idxs_1d):
|
| 242 |
+
episode_records[int(sample_idx)].append({"sec": int(sec), "value": float(denorm_gt_vals[b])})
|
| 243 |
+
else:
|
| 244 |
+
diff_gt = diffs_seq[:, :sec, :].sum(dim=1, keepdim=True).to(device)
|
| 245 |
+
|
| 246 |
+
print(obs_image.shape, obs_audio.shape, curr_delta.shape, obs_image.dtype, obs_audio.dtype, curr_delta.dtype)
|
| 247 |
+
x_pred_pixels, x_pred_audios, diff_pred = model_forward_wrapper_av(all_models, (obs_image, obs_audio) , curr_delta, sec, args.latent_size, num_cond=num_cond, num_goals=1, device=device)
|
| 248 |
+
x_pred_audios_orig = down_resampler(x_pred_audios)
|
| 249 |
+
denorm_pred_vals = denorm_from_tensor(diff_pred) # [B]
|
| 250 |
+
denorm_gt_vals = denorm_from_tensor(diff_gt) # [B]
|
| 251 |
+
|
| 252 |
+
visualize_preds(output_dir, idxs, sec, x_pred_pixels, x_pred_audios_orig, 16000)
|
| 253 |
+
visualize_compare(output_dir, idxs, sec,
|
| 254 |
+
x_pred_pixels, x_pred_audios_orig,
|
| 255 |
+
x_gt_pixels, x_gt_audios_orig,
|
| 256 |
+
denorm_pred_vals=denorm_pred_vals,
|
| 257 |
+
denorm_gt_vals=denorm_gt_vals)
|
| 258 |
+
idxs_1d = idxs.detach().view(-1).cpu().numpy()
|
| 259 |
+
for b, sample_idx in enumerate(idxs_1d):
|
| 260 |
+
episode_records[int(sample_idx)].append({"sec": int(sec), "value": float(denorm_pred_vals[b])})
|
| 261 |
+
for sample_idx, rows in episode_records.items():
|
| 262 |
+
rows = sorted(rows, key=lambda r: r["sec"])
|
| 263 |
+
sample_folder = os.path.join(output_dir, f"id_{sample_idx}")
|
| 264 |
+
os.makedirs(sample_folder, exist_ok=True)
|
| 265 |
+
out_json = os.path.join(sample_folder, "distance.json")
|
| 266 |
+
compact = [{ "sec": r["sec"], value_key: r["value"] } for r in rows]
|
| 267 |
+
with open(out_json, "w") as f:
|
| 268 |
+
json.dump(compact, f, indent=2)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def visualize_preds(output_dir, idxs, sec, x_pred_pixels, x_pred_audios, sample_rate):
|
| 272 |
+
idxs_1d = idxs.detach().view(-1)
|
| 273 |
+
for batch_idx, sample_idx in enumerate(idxs_1d):
|
| 274 |
+
sample_idx = int(sample_idx.item())
|
| 275 |
+
sample_folder = os.path.join(output_dir, f'id_{sample_idx}')
|
| 276 |
+
os.makedirs(sample_folder, exist_ok=True)
|
| 277 |
+
image_file = os.path.join(sample_folder, f'{sec}.png')
|
| 278 |
+
save_image(image_file, x_pred_pixels[batch_idx], True)
|
| 279 |
+
audio_file = os.path.join(sample_folder, f'{sec}.wav')
|
| 280 |
+
save_audio(audio_file, x_pred_audios[batch_idx], sample_rate)
|
| 281 |
+
|
| 282 |
+
def _compute_binaural_spectrogram_np(audio_2ch: np.ndarray):
|
| 283 |
+
def _stft_abs(signal):
|
| 284 |
+
n_fft = 512
|
| 285 |
+
hop_length = 160
|
| 286 |
+
win_length = 400
|
| 287 |
+
stft = np.abs(librosa.stft(signal, n_fft=n_fft, hop_length=hop_length, win_length=win_length))
|
| 288 |
+
stft = block_reduce(stft, block_size=(4, 4), func=np.mean)
|
| 289 |
+
return stft
|
| 290 |
+
L = np.log1p(_stft_abs(audio_2ch[0]))
|
| 291 |
+
R = np.log1p(_stft_abs(audio_2ch[1]))
|
| 292 |
+
spec = np.stack([L, R], axis=-1) # (F,T,2)
|
| 293 |
+
return spec
|
| 294 |
+
|
| 295 |
+
def denorm_from_tensor(t: torch.Tensor, min_v=-20.0, max_v=20.0, scale=0.15) -> torch.Tensor:
|
| 296 |
+
x = t.detach().float().view(t.shape[0], -1)[:, 0]
|
| 297 |
+
n01 = (x + 1.0) / 2.0
|
| 298 |
+
raw = n01 * (max_v - min_v) + min_v
|
| 299 |
+
return raw * scale
|
| 300 |
+
|
| 301 |
+
def visualize_compare(output_dir, idxs, sec,
|
| 302 |
+
x_pred_pixels, x_pred_audios_orig,
|
| 303 |
+
x_gt_pixels, x_gt_audios_orig,
|
| 304 |
+
denorm_pred_vals,
|
| 305 |
+
denorm_gt_vals):
|
| 306 |
+
idxs_np = idxs.detach().view(-1).cpu().numpy()
|
| 307 |
+
|
| 308 |
+
B = x_pred_pixels.shape[0]
|
| 309 |
+
assert x_gt_pixels.shape[0] == B and x_pred_audios_orig.shape[0] == B and x_gt_audios_orig.shape[0] == B
|
| 310 |
+
|
| 311 |
+
for b in range(B):
|
| 312 |
+
sample_idx = int(idxs_np[b])
|
| 313 |
+
sample_folder = os.path.join(output_dir, f'id_{sample_idx}')
|
| 314 |
+
os.makedirs(sample_folder, exist_ok=True)
|
| 315 |
+
out_path = os.path.join(sample_folder, f'compare_{sec}.png')
|
| 316 |
+
def _tensor_to_display_img(x: torch.Tensor):
|
| 317 |
+
x = x.detach().cpu()
|
| 318 |
+
x = misc.unnormalize(x)
|
| 319 |
+
x = (x * 255.0).round().clamp(0, 255)
|
| 320 |
+
x = x.to(torch.uint8).permute(1, 2, 0)
|
| 321 |
+
return x.numpy()
|
| 322 |
+
|
| 323 |
+
pred_img = _tensor_to_display_img(x_pred_pixels[b])
|
| 324 |
+
gt_img = _tensor_to_display_img(x_gt_pixels[b])
|
| 325 |
+
|
| 326 |
+
pred_aud = x_pred_audios_orig[b].detach().cpu().float().numpy()
|
| 327 |
+
gt_aud = x_gt_audios_orig[b].detach().cpu().float().numpy()
|
| 328 |
+
pred_spec = _compute_binaural_spectrogram_np(pred_aud)
|
| 329 |
+
gt_spec = _compute_binaural_spectrogram_np(gt_aud)
|
| 330 |
+
|
| 331 |
+
vmin_L = min(pred_spec[:, :, 0].min(), gt_spec[:, :, 0].min())
|
| 332 |
+
vmax_L = max(pred_spec[:, :, 0].max(), gt_spec[:, :, 0].max())
|
| 333 |
+
vmin_R = min(pred_spec[:, :, 1].min(), gt_spec[:, :, 1].min())
|
| 334 |
+
vmax_R = max(pred_spec[:, :, 1].max(), gt_spec[:, :, 1].max())
|
| 335 |
+
|
| 336 |
+
dn_pred = float(denorm_pred_vals[b]) if denorm_pred_vals is not None else 0
|
| 337 |
+
dn_gt = float(denorm_gt_vals[b]) if denorm_gt_vals is not None else 0
|
| 338 |
+
|
| 339 |
+
fig, axes = plt.subplots(2, 4, figsize=(14, 6), constrained_layout=True)
|
| 340 |
+
|
| 341 |
+
axes[0, 0].imshow(pred_img); axes[0, 0].set_title('pred image'); axes[0, 0].axis('off')
|
| 342 |
+
axes[0, 1].imshow(gt_img); axes[0, 1].set_title('gt image'); axes[0, 1].axis('off')
|
| 343 |
+
|
| 344 |
+
axes[1, 0].axis('off')
|
| 345 |
+
axes[1, 1].axis('off')
|
| 346 |
+
|
| 347 |
+
im_pred_L = axes[0, 2].imshow(pred_spec[:, :, 0], origin='lower', aspect='auto', vmin=vmin_L, vmax=vmax_L)
|
| 348 |
+
axes[0, 2].set_title('pred spec (Left)'); axes[0, 2].set_xticks([]); axes[0, 2].set_yticks([])
|
| 349 |
+
im_gt_L = axes[0, 3].imshow(gt_spec[:, :, 0], origin='lower', aspect='auto', vmin=vmin_L, vmax=vmax_L)
|
| 350 |
+
axes[0, 3].set_title('gt spec (Left)'); axes[0, 3].set_xticks([]); axes[0, 3].set_yticks([])
|
| 351 |
+
im_pred_R = axes[1, 2].imshow(pred_spec[:, :, 1], origin='lower', aspect='auto', vmin=vmin_R, vmax=vmax_R)
|
| 352 |
+
axes[1, 2].set_title('pred spec (Right)'); axes[1, 2].set_xticks([]); axes[1, 2].set_yticks([])
|
| 353 |
+
im_gt_R = axes[1, 3].imshow(gt_spec[:, :, 1], origin='lower', aspect='auto', vmin=vmin_R, vmax=vmax_R)
|
| 354 |
+
axes[1, 3].set_title('gt spec (Right)'); axes[1, 3].set_xticks([]); axes[1, 3].set_yticks([])
|
| 355 |
+
|
| 356 |
+
fig.suptitle(
|
| 357 |
+
f'id={sample_idx}, sec={sec} | denorm(reward_pred)={dn_pred:.4f}, denorm(reward_gt)={dn_gt:.4f}',
|
| 358 |
+
fontsize=11
|
| 359 |
+
)
|
| 360 |
+
plt.savefig(out_path, dpi=180)
|
| 361 |
+
plt.close(fig)
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
@torch.no_grad()
|
| 365 |
+
def main(args):
|
| 366 |
+
_, _, device, _ = init_distributed()
|
| 367 |
+
print(args)
|
| 368 |
+
device = torch.device(device)
|
| 369 |
+
num_tasks = dist.get_world_size()
|
| 370 |
+
global_rank = dist.get_rank()
|
| 371 |
+
exp_eval = args.exp
|
| 372 |
+
|
| 373 |
+
# model & config setup
|
| 374 |
+
if args.gt:
|
| 375 |
+
args.save_output_dir = os.path.join(args.output_dir, 'gt')
|
| 376 |
+
else:
|
| 377 |
+
exp_name = os.path.basename(exp_eval).split('.')[0]
|
| 378 |
+
args.save_output_dir = os.path.join(args.output_dir, exp_name)
|
| 379 |
+
|
| 380 |
+
if args.ckp != '0100000':
|
| 381 |
+
args.save_output_dir = args.save_output_dir + "_%s"%(args.ckp)
|
| 382 |
+
|
| 383 |
+
os.makedirs(args.save_output_dir, exist_ok=True)
|
| 384 |
+
|
| 385 |
+
with open("config/eval_config.yaml", "r") as f:
|
| 386 |
+
default_config = yaml.safe_load(f)
|
| 387 |
+
config = default_config
|
| 388 |
+
|
| 389 |
+
with open(exp_eval, "r") as f:
|
| 390 |
+
user_config = yaml.safe_load(f)
|
| 391 |
+
config.update(user_config)
|
| 392 |
+
|
| 393 |
+
eval_len_traj_pred=config["eval_len_traj_pred"]
|
| 394 |
+
if args.rollout_frames==-1:
|
| 395 |
+
args.rollout_frames=eval_len_traj_pred
|
| 396 |
+
assert args.rollout_frames<=eval_len_traj_pred
|
| 397 |
+
latent_size = config['image_size'] // 8
|
| 398 |
+
args.latent_size = config['image_size'] // 8
|
| 399 |
+
|
| 400 |
+
num_cond = config['context_size']
|
| 401 |
+
print("loading")
|
| 402 |
+
model_lst = (None, None, None, None)
|
| 403 |
+
if not args.gt:
|
| 404 |
+
model = AVCDiT_models[config['model']](context_size=num_cond, input_size=latent_size, in_channels=4, mode="av")
|
| 405 |
+
ckp = torch.load(f'{config["results_dir"]}/{config["run_name"]}/checkpoints/{args.ckp}.pth.tar', map_location='cpu', weights_only=False)
|
| 406 |
+
print(model.load_state_dict(ckp["ema"], strict=True))
|
| 407 |
+
model.eval()
|
| 408 |
+
model.to(device)
|
| 409 |
+
model = torch.compile(model)
|
| 410 |
+
diffusion = create_diffusion(str(250), dual=True)
|
| 411 |
+
vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-ema").to(device)
|
| 412 |
+
|
| 413 |
+
sstream = SoundStream(C=32, D=16, n_q=8, codebook_size=1024).to(device)
|
| 414 |
+
sstream_path=config["tokenizer_a_path"]
|
| 415 |
+
sstream_checkpoint = torch.load(sstream_path, map_location=device)
|
| 416 |
+
sstream.load_state_dict(sstream_checkpoint["model_state"])
|
| 417 |
+
sstream.eval()
|
| 418 |
+
|
| 419 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], find_unused_parameters=False)
|
| 420 |
+
model_lst = (model, diffusion, vae, sstream)
|
| 421 |
+
|
| 422 |
+
# Loading Datasets
|
| 423 |
+
dataset_names = args.datasets.split(',')
|
| 424 |
+
datasets = {}
|
| 425 |
+
|
| 426 |
+
for dataset_name in dataset_names:
|
| 427 |
+
dataset_val = get_dataset_eval(config, dataset_name, args.eval_type, predefined_index=False)
|
| 428 |
+
|
| 429 |
+
if len(dataset_val) % num_tasks != 0:
|
| 430 |
+
print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
|
| 431 |
+
'This will slightly alter validation results as extra duplicate entries are added to achieve '
|
| 432 |
+
'equal num of samples per-process.')
|
| 433 |
+
sampler_val = torch.utils.data.DistributedSampler(
|
| 434 |
+
dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
|
| 435 |
+
|
| 436 |
+
curr_data_loader = torch.utils.data.DataLoader(
|
| 437 |
+
dataset_val, sampler=sampler_val,
|
| 438 |
+
batch_size=args.batch_size,
|
| 439 |
+
num_workers=args.num_workers,
|
| 440 |
+
pin_memory=True,
|
| 441 |
+
drop_last=False
|
| 442 |
+
)
|
| 443 |
+
datasets[dataset_name] = curr_data_loader
|
| 444 |
+
|
| 445 |
+
print_freq = 1
|
| 446 |
+
header = 'Evaluation: '
|
| 447 |
+
metric_logger = dist.MetricLogger(delimiter=" ")
|
| 448 |
+
|
| 449 |
+
for dataset_name in dataset_names:
|
| 450 |
+
dataset_save_output_dir = os.path.join(args.save_output_dir, dataset_name)
|
| 451 |
+
os.makedirs(dataset_save_output_dir, exist_ok=True)
|
| 452 |
+
curr_data_loader = datasets[dataset_name]
|
| 453 |
+
|
| 454 |
+
for data_iter_step, (idxs, obs_image, gt_image, obs_audio, gt_audio, diffs_seq, delta, orig_obs_audio, orig_gt_audio) in enumerate(metric_logger.log_every(curr_data_loader, print_freq, header)):
|
| 455 |
+
with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
|
| 456 |
+
obs_image = obs_image[:, -num_cond:].to(device)
|
| 457 |
+
gt_image = gt_image.to(device)
|
| 458 |
+
obs_audio = obs_audio[:, -num_cond:].to(device)
|
| 459 |
+
gt_audio = gt_audio.to(device)
|
| 460 |
+
orig_obs_audio = orig_obs_audio[:, -num_cond:].to(device)
|
| 461 |
+
orig_gt_audio = orig_gt_audio.to(device)
|
| 462 |
+
|
| 463 |
+
diffs_seq = diffs_seq.to(device)
|
| 464 |
+
obs_av=(obs_image, obs_audio, orig_obs_audio)
|
| 465 |
+
gt_av=(gt_image, gt_audio, orig_gt_audio)
|
| 466 |
+
if args.eval_type == 'rollout':
|
| 467 |
+
curr_rollout_output_dir = os.path.join(dataset_save_output_dir, f'rollout_{args.rollout_frames}frames')
|
| 468 |
+
os.makedirs(curr_rollout_output_dir, exist_ok=True)
|
| 469 |
+
generate_rollout(args, curr_rollout_output_dir, args.rollout_frames, idxs, model_lst, obs_av, gt_av, diffs_seq, delta, num_cond, device)
|
| 470 |
+
elif args.eval_type == 'time':
|
| 471 |
+
if args.time_secs != '':
|
| 472 |
+
secs = np.array([int(sec) for sec in args.time_secs.split(',')])
|
| 473 |
+
else:
|
| 474 |
+
secs = np.array([int(sec) for sec in range(1,args.rollout_frames+1)])
|
| 475 |
+
curr_time_output_dir = os.path.join(dataset_save_output_dir, 'time')
|
| 476 |
+
os.makedirs(curr_time_output_dir, exist_ok=True)
|
| 477 |
+
generate_time(args, curr_time_output_dir, idxs, model_lst, obs_av, gt_av, diffs_seq, delta, secs, num_cond, device)
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
if __name__ == "__main__":
|
| 481 |
+
parser = argparse.ArgumentParser()
|
| 482 |
+
|
| 483 |
+
parser.add_argument("--output_dir", type=str, default=None, help="output directory")
|
| 484 |
+
parser.add_argument("--exp", type=str, default=None, help="experiment name")
|
| 485 |
+
parser.add_argument("--ckp", type=str, default='0100000')
|
| 486 |
+
parser.add_argument("--num_sec_eval", type=int, default=5)
|
| 487 |
+
parser.add_argument("--input_fps", type=int, default=4)
|
| 488 |
+
parser.add_argument("--datasets", type=str, default=None, help="dataset name")
|
| 489 |
+
parser.add_argument("--num_workers", type=int, default=8, help="num workers")
|
| 490 |
+
parser.add_argument("--batch_size", type=int, default=16, help="batch size")
|
| 491 |
+
parser.add_argument("--eval_type", type=str, default=None, help="type of evaluation has to be either 'time' or 'rollout'")
|
| 492 |
+
# Rollout Evaluation Args
|
| 493 |
+
parser.add_argument("--time_secs", type=str, default='', help="") #'1,2,3,4'
|
| 494 |
+
parser.add_argument("--rollout_frames", type=int, default=-1, help="")
|
| 495 |
+
parser.add_argument("--gt", type=int, default=0, help="set to 1 to produce ground truth evaluation set")
|
| 496 |
+
args = parser.parse_args()
|
| 497 |
+
|
| 498 |
+
main(args)
|
mel_scale.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import Tensor
|
| 3 |
+
from typing import Optional
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
import warnings
|
| 7 |
+
|
| 8 |
+
class MelScale(torch.nn.Module):
|
| 9 |
+
r"""Turn a normal STFT into a mel frequency STFT, using a conversion
|
| 10 |
+
matrix. This uses triangular filter banks.
|
| 11 |
+
|
| 12 |
+
User can control which device the filter bank (`fb`) is (e.g. fb.to(spec_f.device)).
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
|
| 16 |
+
sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
|
| 17 |
+
f_min (float, optional): Minimum frequency. (Default: ``0.``)
|
| 18 |
+
f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
|
| 19 |
+
n_stft (int, optional): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`. (Default: ``201``)
|
| 20 |
+
norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band
|
| 21 |
+
(area normalization). (Default: ``None``)
|
| 22 |
+
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
|
| 23 |
+
|
| 24 |
+
See also:
|
| 25 |
+
:py:func:`torchaudio.functional.melscale_fbanks` - The function used to
|
| 26 |
+
generate the filter banks.
|
| 27 |
+
"""
|
| 28 |
+
__constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']
|
| 29 |
+
|
| 30 |
+
def __init__(self,
|
| 31 |
+
n_mels: int = 128,
|
| 32 |
+
sample_rate: int = 16000,
|
| 33 |
+
f_min: float = 0.,
|
| 34 |
+
f_max: Optional[float] = None,
|
| 35 |
+
n_stft: int = 201,
|
| 36 |
+
norm: Optional[str] = None,
|
| 37 |
+
mel_scale: str = "htk") -> None:
|
| 38 |
+
super(MelScale, self).__init__()
|
| 39 |
+
self.n_mels = n_mels
|
| 40 |
+
self.sample_rate = sample_rate
|
| 41 |
+
self.f_max = f_max if f_max is not None else float(sample_rate // 2)
|
| 42 |
+
self.f_min = f_min
|
| 43 |
+
self.norm = norm
|
| 44 |
+
self.mel_scale = mel_scale
|
| 45 |
+
|
| 46 |
+
assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)
|
| 47 |
+
fb = melscale_fbanks(
|
| 48 |
+
n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm,
|
| 49 |
+
self.mel_scale)
|
| 50 |
+
self.register_buffer('fb', fb)
|
| 51 |
+
|
| 52 |
+
def forward(self, specgram: Tensor) -> Tensor:
|
| 53 |
+
r"""
|
| 54 |
+
Args:
|
| 55 |
+
specgram (Tensor): A spectrogram STFT of dimension (..., freq, time).
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
# (..., time, freq) dot (freq, n_mels) -> (..., n_mels, time)
|
| 62 |
+
mel_specgram = torch.matmul(specgram.transpose(-1, -2), self.fb).transpose(-1, -2)
|
| 63 |
+
|
| 64 |
+
return mel_specgram
|
| 65 |
+
|
| 66 |
+
def _hz_to_mel(freq: float, mel_scale: str = "htk") -> float:
|
| 67 |
+
r"""Convert Hz to Mels.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
freqs (float): Frequencies in Hz
|
| 71 |
+
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
|
| 72 |
+
|
| 73 |
+
Returns:
|
| 74 |
+
mels (float): Frequency in Mels
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
if mel_scale not in ['slaney', 'htk']:
|
| 78 |
+
raise ValueError('mel_scale should be one of "htk" or "slaney".')
|
| 79 |
+
|
| 80 |
+
if mel_scale == "htk":
|
| 81 |
+
return 2595.0 * math.log10(1.0 + (freq / 700.0))
|
| 82 |
+
|
| 83 |
+
# Fill in the linear part
|
| 84 |
+
f_min = 0.0
|
| 85 |
+
f_sp = 200.0 / 3
|
| 86 |
+
|
| 87 |
+
mels = (freq - f_min) / f_sp
|
| 88 |
+
|
| 89 |
+
# Fill in the log-scale part
|
| 90 |
+
min_log_hz = 1000.0
|
| 91 |
+
min_log_mel = (min_log_hz - f_min) / f_sp
|
| 92 |
+
logstep = math.log(6.4) / 27.0
|
| 93 |
+
|
| 94 |
+
if freq >= min_log_hz:
|
| 95 |
+
mels = min_log_mel + math.log(freq / min_log_hz) / logstep
|
| 96 |
+
|
| 97 |
+
return mels
|
| 98 |
+
|
| 99 |
+
def _mel_to_hz(mels: Tensor, mel_scale: str = "htk") -> Tensor:
|
| 100 |
+
"""Convert mel bin numbers to frequencies.
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
mels (Tensor): Mel frequencies
|
| 104 |
+
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
|
| 105 |
+
|
| 106 |
+
Returns:
|
| 107 |
+
freqs (Tensor): Mels converted in Hz
|
| 108 |
+
"""
|
| 109 |
+
|
| 110 |
+
if mel_scale not in ['slaney', 'htk']:
|
| 111 |
+
raise ValueError('mel_scale should be one of "htk" or "slaney".')
|
| 112 |
+
|
| 113 |
+
if mel_scale == "htk":
|
| 114 |
+
return 700.0 * (10.0**(mels / 2595.0) - 1.0)
|
| 115 |
+
|
| 116 |
+
# Fill in the linear scale
|
| 117 |
+
f_min = 0.0
|
| 118 |
+
f_sp = 200.0 / 3
|
| 119 |
+
freqs = f_min + f_sp * mels
|
| 120 |
+
|
| 121 |
+
# And now the nonlinear scale
|
| 122 |
+
min_log_hz = 1000.0
|
| 123 |
+
min_log_mel = (min_log_hz - f_min) / f_sp
|
| 124 |
+
logstep = math.log(6.4) / 27.0
|
| 125 |
+
|
| 126 |
+
log_t = (mels >= min_log_mel)
|
| 127 |
+
freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))
|
| 128 |
+
|
| 129 |
+
return freqs
|
| 130 |
+
|
| 131 |
+
def _create_triangular_filterbank(
|
| 132 |
+
all_freqs: Tensor,
|
| 133 |
+
f_pts: Tensor,
|
| 134 |
+
) -> Tensor:
|
| 135 |
+
"""Create a triangular filter bank.
|
| 136 |
+
|
| 137 |
+
Args:
|
| 138 |
+
all_freqs (Tensor): STFT freq points of size (`n_freqs`).
|
| 139 |
+
f_pts (Tensor): Filter mid points of size (`n_filter`).
|
| 140 |
+
|
| 141 |
+
Returns:
|
| 142 |
+
fb (Tensor): The filter bank of size (`n_freqs`, `n_filter`).
|
| 143 |
+
"""
|
| 144 |
+
# Adopted from Librosa
|
| 145 |
+
# calculate the difference between each filter mid point and each stft freq point in hertz
|
| 146 |
+
f_diff = f_pts[1:] - f_pts[:-1] # (n_filter + 1)
|
| 147 |
+
slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1) # (n_freqs, n_filter + 2)
|
| 148 |
+
# create overlapping triangles
|
| 149 |
+
zero = torch.zeros(1)
|
| 150 |
+
down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_filter)
|
| 151 |
+
up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_filter)
|
| 152 |
+
fb = torch.max(zero, torch.min(down_slopes, up_slopes))
|
| 153 |
+
|
| 154 |
+
return fb
|
| 155 |
+
|
| 156 |
+
def melscale_fbanks(
|
| 157 |
+
n_freqs: int,
|
| 158 |
+
f_min: float,
|
| 159 |
+
f_max: float,
|
| 160 |
+
n_mels: int,
|
| 161 |
+
sample_rate: int,
|
| 162 |
+
norm: Optional[str] = None,
|
| 163 |
+
mel_scale: str = "htk",
|
| 164 |
+
) -> Tensor:
|
| 165 |
+
r"""Create a frequency bin conversion matrix.
|
| 166 |
+
|
| 167 |
+
Note:
|
| 168 |
+
For the sake of the numerical compatibility with librosa, not all the coefficients
|
| 169 |
+
in the resulting filter bank has magnitude of 1.
|
| 170 |
+
|
| 171 |
+
.. image:: https://download.pytorch.org/torchaudio/doc-assets/mel_fbanks.png
|
| 172 |
+
:alt: Visualization of generated filter bank
|
| 173 |
+
|
| 174 |
+
Args:
|
| 175 |
+
n_freqs (int): Number of frequencies to highlight/apply
|
| 176 |
+
f_min (float): Minimum frequency (Hz)
|
| 177 |
+
f_max (float): Maximum frequency (Hz)
|
| 178 |
+
n_mels (int): Number of mel filterbanks
|
| 179 |
+
sample_rate (int): Sample rate of the audio waveform
|
| 180 |
+
norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band
|
| 181 |
+
(area normalization). (Default: ``None``)
|
| 182 |
+
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
|
| 183 |
+
|
| 184 |
+
Returns:
|
| 185 |
+
Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``)
|
| 186 |
+
meaning number of frequencies to highlight/apply to x the number of filterbanks.
|
| 187 |
+
Each column is a filterbank so that assuming there is a matrix A of
|
| 188 |
+
size (..., ``n_freqs``), the applied result would be
|
| 189 |
+
``A * melscale_fbanks(A.size(-1), ...)``.
|
| 190 |
+
|
| 191 |
+
"""
|
| 192 |
+
|
| 193 |
+
if norm is not None and norm != "slaney":
|
| 194 |
+
raise ValueError("norm must be one of None or 'slaney'")
|
| 195 |
+
|
| 196 |
+
# freq bins
|
| 197 |
+
all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
|
| 198 |
+
|
| 199 |
+
# calculate mel freq bins
|
| 200 |
+
m_min = _hz_to_mel(f_min, mel_scale=mel_scale)
|
| 201 |
+
m_max = _hz_to_mel(f_max, mel_scale=mel_scale)
|
| 202 |
+
|
| 203 |
+
m_pts = torch.linspace(m_min, m_max, n_mels + 2)
|
| 204 |
+
f_pts = _mel_to_hz(m_pts, mel_scale=mel_scale)
|
| 205 |
+
|
| 206 |
+
# create filterbank
|
| 207 |
+
fb = _create_triangular_filterbank(all_freqs, f_pts)
|
| 208 |
+
|
| 209 |
+
if norm is not None and norm == "slaney":
|
| 210 |
+
# Slaney-style mel is scaled to be approx constant energy per channel
|
| 211 |
+
enorm = 2.0 / (f_pts[2:n_mels + 2] - f_pts[:n_mels])
|
| 212 |
+
fb *= enorm.unsqueeze(0)
|
| 213 |
+
|
| 214 |
+
if (fb.max(dim=0).values == 0.).any():
|
| 215 |
+
warnings.warn(
|
| 216 |
+
"At least one mel filterbank has all zero values. "
|
| 217 |
+
f"The value for `n_mels` ({n_mels}) may be set too high. "
|
| 218 |
+
f"Or, the value for `n_freqs` ({n_freqs}) may be set too low."
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
return fb
|
merge_experts.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import yaml
|
| 3 |
+
import argparse
|
| 4 |
+
from models import AVCDiT_models
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def add_exact_keys(mapping, keys):
|
| 8 |
+
for k in keys:
|
| 9 |
+
mapping[k] = k
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def add_mlp_block_keys(mapping, mlp_name, num_blocks):
|
| 13 |
+
for i in range(num_blocks):
|
| 14 |
+
for fc in ["fc1", "fc2"]:
|
| 15 |
+
for param in ["weight", "bias"]:
|
| 16 |
+
k = f"blocks.{i}.{mlp_name}.{fc}.{param}"
|
| 17 |
+
mapping[k] = k
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def load_from_two_checkpoints(model, ckpt1_path, ckpt2_path, map1=None, map2=None, device='cuda'):
|
| 21 |
+
ckpt1 = torch.load(ckpt1_path, map_location=device, weights_only=False)
|
| 22 |
+
ckpt2 = torch.load(ckpt2_path, map_location=device, weights_only=False)
|
| 23 |
+
|
| 24 |
+
state1 = {k.replace('_orig_mod.', ''): v for k, v in ckpt1["ema"].items()}
|
| 25 |
+
state2 = {k.replace('_orig_mod.', ''): v for k, v in ckpt2["ema"].items()}
|
| 26 |
+
|
| 27 |
+
model_state = model.state_dict()
|
| 28 |
+
|
| 29 |
+
new_state = {}
|
| 30 |
+
source_info = {} # key: model param name, value: ckpt source name
|
| 31 |
+
|
| 32 |
+
if map1:
|
| 33 |
+
for k_model, k_ckpt in map1.items():
|
| 34 |
+
if (
|
| 35 |
+
k_ckpt in state1
|
| 36 |
+
and k_model in model_state
|
| 37 |
+
and state1[k_ckpt].shape == model_state[k_model].shape
|
| 38 |
+
):
|
| 39 |
+
new_state[k_model] = state1[k_ckpt]
|
| 40 |
+
source_info[k_model] = "ckpt1"
|
| 41 |
+
|
| 42 |
+
if map2:
|
| 43 |
+
for k_model, k_ckpt in map2.items():
|
| 44 |
+
if (
|
| 45 |
+
k_ckpt in state2
|
| 46 |
+
and k_model in model_state
|
| 47 |
+
and state2[k_ckpt].shape == model_state[k_model].shape
|
| 48 |
+
):
|
| 49 |
+
new_state[k_model] = state2[k_ckpt]
|
| 50 |
+
source_info[k_model] = "ckpt2"
|
| 51 |
+
|
| 52 |
+
for k_model, tensor in model_state.items():
|
| 53 |
+
if k_model not in new_state:
|
| 54 |
+
if k_model in state1 and state1[k_model].shape == tensor.shape:
|
| 55 |
+
new_state[k_model] = state1[k_model]
|
| 56 |
+
source_info[k_model] = "fallback_ckpt1"
|
| 57 |
+
|
| 58 |
+
model.load_state_dict(new_state, strict=False)
|
| 59 |
+
print(f"Loaded {len(new_state)} / {len(model_state)} parameters")
|
| 60 |
+
|
| 61 |
+
return new_state
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def main(args):
|
| 65 |
+
with open(args.config, "r") as f:
|
| 66 |
+
config = yaml.safe_load(f)
|
| 67 |
+
|
| 68 |
+
model_name = config.get("model", "AVCDiT-B/2")
|
| 69 |
+
print(f"Using model: {model_name}")
|
| 70 |
+
|
| 71 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 72 |
+
|
| 73 |
+
model = AVCDiT_models[model_name](
|
| 74 |
+
context_size=4,
|
| 75 |
+
input_size=28,
|
| 76 |
+
in_channels=4,
|
| 77 |
+
mode="av"
|
| 78 |
+
).to(device)
|
| 79 |
+
|
| 80 |
+
depth = len(model.blocks)
|
| 81 |
+
|
| 82 |
+
map1 = {}
|
| 83 |
+
add_exact_keys(map1, [
|
| 84 |
+
"pos_embed_v",
|
| 85 |
+
"x_embedder_v.proj.weight",
|
| 86 |
+
"x_embedder_v.proj.bias",
|
| 87 |
+
"final_layer.linear.weight",
|
| 88 |
+
"final_layer.linear.bias",
|
| 89 |
+
"final_layer.adaLN_modulation.1.weight",
|
| 90 |
+
"final_layer.adaLN_modulation.1.bias",
|
| 91 |
+
])
|
| 92 |
+
add_mlp_block_keys(map1, "mlp_v", depth)
|
| 93 |
+
|
| 94 |
+
map2 = {}
|
| 95 |
+
add_exact_keys(map2, [
|
| 96 |
+
"pos_embed_a_cond",
|
| 97 |
+
"pos_embed_a_pred",
|
| 98 |
+
"x_embedder_a.weight",
|
| 99 |
+
"x_embedder_a.bias",
|
| 100 |
+
"final_layer_a.linear.weight",
|
| 101 |
+
"final_layer_a.linear.bias",
|
| 102 |
+
"final_layer_a.adaLN_modulation.1.weight",
|
| 103 |
+
"final_layer_a.adaLN_modulation.1.bias",
|
| 104 |
+
])
|
| 105 |
+
add_mlp_block_keys(map2, "mlp_a", depth)
|
| 106 |
+
|
| 107 |
+
merged_state_dict = load_from_two_checkpoints(
|
| 108 |
+
model,
|
| 109 |
+
ckpt1_path=args.v_expert,
|
| 110 |
+
ckpt2_path=args.a_expert,
|
| 111 |
+
map1=map1,
|
| 112 |
+
map2=map2,
|
| 113 |
+
device=device
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
torch.save({"ema": merged_state_dict}, args.output)
|
| 117 |
+
print(f"Merged model saved to {args.output}")
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
if __name__ == "__main__":
|
| 121 |
+
parser = argparse.ArgumentParser()
|
| 122 |
+
parser.add_argument("--config", type=str, required=True)
|
| 123 |
+
parser.add_argument("--v_expert", type=str, required=True)
|
| 124 |
+
parser.add_argument("--a_expert", type=str, required=True)
|
| 125 |
+
parser.add_argument("--output", type=str, default="experts_merged.pth")
|
| 126 |
+
args = parser.parse_args()
|
| 127 |
+
|
| 128 |
+
main(args)
|
misc.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import yaml
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import os
|
| 6 |
+
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from torchvision import transforms
|
| 9 |
+
import torchvision.transforms.functional as TF
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
IMAGE_ASPECT_RATIO = (4 / 3) # all images are centered cropped to a 4:3 aspect ratio in training
|
| 13 |
+
|
| 14 |
+
with open("config/data_config.yaml", "r") as f:
|
| 15 |
+
data_config = yaml.safe_load(f)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_action_torch(diffusion_output, action_stats):
|
| 19 |
+
ndeltas = diffusion_output
|
| 20 |
+
ndeltas = ndeltas.reshape(ndeltas.shape[0], -1, 2)
|
| 21 |
+
ndeltas = unnormalize_data(ndeltas, action_stats)
|
| 22 |
+
actions = torch.cumsum(ndeltas, dim=1)
|
| 23 |
+
return actions.to(ndeltas)
|
| 24 |
+
|
| 25 |
+
def log_viz_single(dataset_name, obs_image, goal_image, preds, deltas, loss, min_idx, actions, action_stats, plan_iter=0, output_dir='plot.png'):
|
| 26 |
+
'''
|
| 27 |
+
Visualize a single instance
|
| 28 |
+
actions is gt actions
|
| 29 |
+
'''
|
| 30 |
+
viz_obs_image = unnormalize(obs_image.detach().cpu())[-1] # take last img
|
| 31 |
+
viz_goal_image = unnormalize(goal_image.detach().cpu())
|
| 32 |
+
deltas = deltas.detach().cpu()
|
| 33 |
+
loss = loss.detach().cpu()
|
| 34 |
+
actions = actions.detach().cpu()
|
| 35 |
+
pred_actions = get_action_torch(deltas[:, :, :2], action_stats)
|
| 36 |
+
plot_array = plot_images_and_actions(dataset_name, viz_obs_image, viz_goal_image, pred_actions, actions, min_idx, loss=loss)
|
| 37 |
+
|
| 38 |
+
plt.imshow(plot_array)
|
| 39 |
+
plt.axis('off') # Hide axes for a cleaner image
|
| 40 |
+
|
| 41 |
+
# Save the plot array as a PNG file locally
|
| 42 |
+
plt.savefig(output_dir, format='png', dpi=300, bbox_inches='tight')
|
| 43 |
+
|
| 44 |
+
def plot_images_and_actions(dataset_name, curr_viz_obs_image, curr_viz_goal_image, curr_viz_pred_actions, curr_viz_actions, min_idx, loss):
|
| 45 |
+
curr_viz_obs_image = curr_viz_obs_image.permute(1, 2, 0).cpu().numpy()
|
| 46 |
+
curr_viz_goal_image = curr_viz_goal_image.permute(1, 2, 0).cpu().numpy()
|
| 47 |
+
|
| 48 |
+
# scale back to metric space for plotting
|
| 49 |
+
curr_viz_pred_actions = curr_viz_pred_actions * data_config[dataset_name]['metric_waypoint_spacing']
|
| 50 |
+
curr_viz_actions = curr_viz_actions * data_config[dataset_name]['metric_waypoint_spacing']
|
| 51 |
+
|
| 52 |
+
# Create the figure with three subplots
|
| 53 |
+
fig, axs = plt.subplots(1, 3, figsize=(9, 3))
|
| 54 |
+
|
| 55 |
+
# Plot condition image
|
| 56 |
+
axs[0].imshow(curr_viz_obs_image)
|
| 57 |
+
axs[0].set_title("Condition Image", fontsize=13)
|
| 58 |
+
axs[0].axis("off")
|
| 59 |
+
|
| 60 |
+
# Plot goal image
|
| 61 |
+
axs[1].imshow(curr_viz_goal_image)
|
| 62 |
+
axs[1].set_title("Goal Image", fontsize=13)
|
| 63 |
+
axs[1].axis("off")
|
| 64 |
+
|
| 65 |
+
colors = ['red', 'orange', 'cyan']
|
| 66 |
+
for i in range(1, curr_viz_pred_actions.shape[0]):
|
| 67 |
+
color = colors[(i - 1) % len(colors)]
|
| 68 |
+
label = f"Sample {i} Min Loss" if i == min_idx.item() else f"{i}"
|
| 69 |
+
|
| 70 |
+
if i != min_idx.item():
|
| 71 |
+
axs[2].plot(-curr_viz_pred_actions[i, :, 1], curr_viz_pred_actions[i, :, 0],
|
| 72 |
+
color=color, marker="o", markersize=5, label=label)
|
| 73 |
+
axs[2].text(-curr_viz_pred_actions[i, -1, 1],
|
| 74 |
+
curr_viz_pred_actions[i, -1, 0],
|
| 75 |
+
round(loss[i].item(), 3),
|
| 76 |
+
color='black',
|
| 77 |
+
fontsize=10,
|
| 78 |
+
ha='left', va='bottom') # Adjust position to avoid overlap
|
| 79 |
+
|
| 80 |
+
# Highlight the minimum loss sample
|
| 81 |
+
axs[2].plot(-curr_viz_pred_actions[min_idx.item(), :, 1], curr_viz_pred_actions[min_idx.item(), :, 0],
|
| 82 |
+
color='green', marker="o", markersize=5, label=f"{min_idx.item()}")
|
| 83 |
+
axs[2].text(-curr_viz_pred_actions[min_idx.item(), -1, 1],
|
| 84 |
+
curr_viz_pred_actions[min_idx.item(), -1, 0],
|
| 85 |
+
round(loss[min_idx.item()].item(), 3),
|
| 86 |
+
color='black',
|
| 87 |
+
fontsize=10,
|
| 88 |
+
ha='left', va='bottom') # Adjust position to avoid overlap
|
| 89 |
+
|
| 90 |
+
# Plot ground truth actions
|
| 91 |
+
axs[2].plot(-curr_viz_actions[:, 1], curr_viz_actions[:, 0], color='blue', marker="o", label="GT")
|
| 92 |
+
|
| 93 |
+
# Set titles and labels with larger font size
|
| 94 |
+
axs[2].set_title(" ", fontsize=13)
|
| 95 |
+
axs[2].set_xlabel("X (m)", fontsize=11)
|
| 96 |
+
axs[2].set_ylabel("Y (m)", fontsize=11)
|
| 97 |
+
|
| 98 |
+
# Set equal aspect ratio and adjust axis limits
|
| 99 |
+
axs[2].set_aspect('equal', adjustable='box')
|
| 100 |
+
x_min, x_max = axs[2].get_xlim()
|
| 101 |
+
y_min, y_max = axs[2].get_ylim()
|
| 102 |
+
axis_range = max(x_max - x_min, y_max - y_min) / 2
|
| 103 |
+
x_mid = (x_max + x_min) / 2
|
| 104 |
+
y_mid = (y_max + y_min) / 2
|
| 105 |
+
axs[2].set_xlim(x_mid - axis_range, x_mid + axis_range)
|
| 106 |
+
axs[2].set_ylim(y_mid - axis_range, y_mid + axis_range)
|
| 107 |
+
|
| 108 |
+
axs[2].legend(loc='lower left', fontsize=10, frameon=True, bbox_to_anchor=(0, 0))
|
| 109 |
+
plt.tight_layout()
|
| 110 |
+
|
| 111 |
+
canvas = FigureCanvas(fig)
|
| 112 |
+
canvas.draw()
|
| 113 |
+
plot_array = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
|
| 114 |
+
plot_array = plot_array.reshape(canvas.get_width_height()[::-1] + (3,))
|
| 115 |
+
plt.close(fig)
|
| 116 |
+
return plot_array
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def normalize_data(data, stats):
|
| 120 |
+
# nomalize to [0,1]
|
| 121 |
+
ndata = (data - stats['min']) / (stats['max'] - stats['min'])
|
| 122 |
+
# normalize to [-1, 1]
|
| 123 |
+
ndata = ndata * 2 - 1
|
| 124 |
+
return ndata
|
| 125 |
+
|
| 126 |
+
def unnormalize_data(ndata, stats):
|
| 127 |
+
ndata = (ndata + 1) / 2
|
| 128 |
+
data = ndata * (stats['max'].to(ndata) - stats['min'].to(ndata)) + stats['min'].to(ndata)
|
| 129 |
+
return data
|
| 130 |
+
|
| 131 |
+
def get_data_path(data_folder: str, f: str, time: int, data_type: str = "image"):
|
| 132 |
+
data_ext = {
|
| 133 |
+
"image": ".jpg",
|
| 134 |
+
"audio": ".wav"
|
| 135 |
+
# add more data types here
|
| 136 |
+
}
|
| 137 |
+
return os.path.join(data_folder, f, f"{str(time)}{data_ext[data_type]}")
|
| 138 |
+
|
| 139 |
+
def yaw_rotmat(yaw: float) -> np.ndarray:
|
| 140 |
+
return np.array(
|
| 141 |
+
[
|
| 142 |
+
[np.cos(yaw), -np.sin(yaw), 0.0],
|
| 143 |
+
[np.sin(yaw), np.cos(yaw), 0.0],
|
| 144 |
+
[0.0, 0.0, 1.0],
|
| 145 |
+
],
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
def angle_difference(theta1, theta2):
|
| 149 |
+
delta_theta = theta2 - theta1
|
| 150 |
+
delta_theta = delta_theta - 2 * np.pi * np.floor((delta_theta + np.pi) / (2 * np.pi))
|
| 151 |
+
return delta_theta
|
| 152 |
+
|
| 153 |
+
def get_delta_np(actions):
|
| 154 |
+
# append zeros to first action (unbatched)
|
| 155 |
+
ex_actions = np.concatenate((np.zeros((1, actions.shape[1])), actions), axis=0)
|
| 156 |
+
delta = ex_actions[1:] - ex_actions[:-1]
|
| 157 |
+
|
| 158 |
+
return delta
|
| 159 |
+
|
| 160 |
+
def to_local_coords(
|
| 161 |
+
positions: np.ndarray, curr_pos: np.ndarray, curr_yaw: float
|
| 162 |
+
) -> np.ndarray:
|
| 163 |
+
"""
|
| 164 |
+
Convert positions to local coordinates
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
positions (np.ndarray): positions to convert
|
| 168 |
+
curr_pos (np.ndarray): current position
|
| 169 |
+
curr_yaw (float): current yaw
|
| 170 |
+
Returns:
|
| 171 |
+
np.ndarray: positions in local coordinates
|
| 172 |
+
"""
|
| 173 |
+
rotmat = yaw_rotmat(curr_yaw)
|
| 174 |
+
if positions.shape[-1] == 2:
|
| 175 |
+
rotmat = rotmat[:2, :2]
|
| 176 |
+
elif positions.shape[-1] == 3:
|
| 177 |
+
pass
|
| 178 |
+
else:
|
| 179 |
+
raise ValueError
|
| 180 |
+
|
| 181 |
+
return (positions - curr_pos).dot(rotmat)
|
| 182 |
+
|
| 183 |
+
def calculate_delta_yaw(unnorm_actions):
|
| 184 |
+
x = unnorm_actions[..., 0]
|
| 185 |
+
y = unnorm_actions[..., 1]
|
| 186 |
+
|
| 187 |
+
yaw = torch.atan2(y, x).unsqueeze(-1)
|
| 188 |
+
delta_yaw = torch.cat((torch.zeros(yaw.shape[0], 1, yaw.shape[2]).to(yaw.device), yaw), dim=1)
|
| 189 |
+
delta_yaw = delta_yaw[:, 1:, :] - delta_yaw[:, :-1, :]
|
| 190 |
+
|
| 191 |
+
return delta_yaw
|
| 192 |
+
|
| 193 |
+
def save_planning_pred(dataset_save_output_dir, B, idxs, obs_image, goal_image, preds, deltas, loss, gt_actions, plan_iter=0):
|
| 194 |
+
for batch_idx, idx in enumerate(idxs.flatten()):
|
| 195 |
+
sample_idx = int(idx)
|
| 196 |
+
sample_folder = os.path.join(dataset_save_output_dir, f'id_{sample_idx}')
|
| 197 |
+
os.makedirs(sample_folder, exist_ok=True)
|
| 198 |
+
|
| 199 |
+
preds_save = {
|
| 200 |
+
'obs_image': obs_image[batch_idx],
|
| 201 |
+
'goal_image': goal_image[batch_idx],
|
| 202 |
+
'preds': preds[batch_idx],
|
| 203 |
+
'deltas': deltas[batch_idx],
|
| 204 |
+
'loss': loss[batch_idx],
|
| 205 |
+
'gt_actions': gt_actions[batch_idx],
|
| 206 |
+
}
|
| 207 |
+
preds_file = os.path.join(sample_folder, f"preds_{plan_iter}.pth")
|
| 208 |
+
torch.save(preds_save, preds_file)
|
| 209 |
+
|
| 210 |
+
class CenterCropAR:
|
| 211 |
+
def __init__(self, ar: float = IMAGE_ASPECT_RATIO):
|
| 212 |
+
self.ar = ar
|
| 213 |
+
|
| 214 |
+
def __call__(self, img: Image.Image):
|
| 215 |
+
w, h = img.size
|
| 216 |
+
if w > h:
|
| 217 |
+
img = TF.center_crop(img, (h, int(h * self.ar)))
|
| 218 |
+
else:
|
| 219 |
+
img = TF.center_crop(img, (int(w / self.ar), w))
|
| 220 |
+
return img
|
| 221 |
+
|
| 222 |
+
transform = transforms.Compose([
|
| 223 |
+
CenterCropAR(),
|
| 224 |
+
transforms.Resize((224, 224)),
|
| 225 |
+
transforms.ToTensor(),
|
| 226 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
| 227 |
+
])
|
| 228 |
+
|
| 229 |
+
unnormalize = transforms.Normalize(
|
| 230 |
+
mean=[-0.5 / 0.5, -0.5 / 0.5, -0.5 / 0.5],
|
| 231 |
+
std=[1 / 0.5, 1 / 0.5, 1 / 0.5]
|
| 232 |
+
)
|
models.py
ADDED
|
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# References:
|
| 8 |
+
# GLIDE: https://github.com/openai/glide-text2im
|
| 9 |
+
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
|
| 10 |
+
# --------------------------------------------------------
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import numpy as np
|
| 14 |
+
import math
|
| 15 |
+
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def modulate(x, shift, scale):
|
| 19 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
#################################################################################
|
| 23 |
+
# Embedding Layers for Timesteps and Class Labels #
|
| 24 |
+
#################################################################################
|
| 25 |
+
|
| 26 |
+
class TimestepEmbedder(nn.Module):
|
| 27 |
+
"""
|
| 28 |
+
Embeds scalar timesteps into vector representations.
|
| 29 |
+
"""
|
| 30 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
| 31 |
+
super().__init__()
|
| 32 |
+
self.mlp = nn.Sequential(
|
| 33 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
| 34 |
+
nn.SiLU(),
|
| 35 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
| 36 |
+
)
|
| 37 |
+
self.frequency_embedding_size = frequency_embedding_size
|
| 38 |
+
|
| 39 |
+
@staticmethod
|
| 40 |
+
def timestep_embedding(t, dim, max_period=10000):
|
| 41 |
+
"""
|
| 42 |
+
Create sinusoidal timestep embeddings.
|
| 43 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
| 44 |
+
These may be fractional.
|
| 45 |
+
:param dim: the dimension of the output.
|
| 46 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
| 47 |
+
:return: an (N, D) Tensor of positional embeddings.
|
| 48 |
+
"""
|
| 49 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
| 50 |
+
half = dim // 2
|
| 51 |
+
freqs = torch.exp(
|
| 52 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
| 53 |
+
).to(device=t.device)
|
| 54 |
+
args = t.float() * freqs[None]
|
| 55 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 56 |
+
if dim % 2:
|
| 57 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 58 |
+
return embedding
|
| 59 |
+
|
| 60 |
+
def forward(self, t):
|
| 61 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
| 62 |
+
t_emb = self.mlp(t_freq)
|
| 63 |
+
return t_emb
|
| 64 |
+
|
| 65 |
+
class ActionEmbedder(nn.Module):
|
| 66 |
+
"""
|
| 67 |
+
Embeds action xy into vector representations.
|
| 68 |
+
"""
|
| 69 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
| 70 |
+
super().__init__()
|
| 71 |
+
hsize = hidden_size//3
|
| 72 |
+
self.x_emb = TimestepEmbedder(hsize, frequency_embedding_size)
|
| 73 |
+
self.y_emb = TimestepEmbedder(hsize, frequency_embedding_size)
|
| 74 |
+
self.angle_emb = TimestepEmbedder(hidden_size -2*hsize, frequency_embedding_size)
|
| 75 |
+
|
| 76 |
+
def forward(self, xya):
|
| 77 |
+
return torch.cat([self.x_emb(xya[...,0:1]), self.y_emb(xya[...,1:2]), self.angle_emb(xya[...,2:3])], dim=-1)
|
| 78 |
+
|
| 79 |
+
#################################################################################
|
| 80 |
+
# Core AVCDiT Model #
|
| 81 |
+
#################################################################################
|
| 82 |
+
|
| 83 |
+
class AVCDiTBlock(nn.Module):
|
| 84 |
+
"""
|
| 85 |
+
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning and two modalities.
|
| 86 |
+
"""
|
| 87 |
+
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, mode="av", **block_kwargs):
|
| 88 |
+
super().__init__()
|
| 89 |
+
self.mode = mode
|
| 90 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 91 |
+
self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
|
| 92 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 93 |
+
self.norm_cond = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 94 |
+
self.cttn = nn.MultiheadAttention(hidden_size, num_heads=num_heads, add_bias_kv=True, bias=True, batch_first=True, **block_kwargs)
|
| 95 |
+
self.adaLN_modulation = nn.Sequential(
|
| 96 |
+
nn.SiLU(),
|
| 97 |
+
nn.Linear(hidden_size, 11 * hidden_size, bias=True)
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
self.norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 101 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 102 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
| 103 |
+
if self.mode == "av" or self.mode == "v":
|
| 104 |
+
self.mlp_v = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
|
| 105 |
+
if self.mode == "av" or self.mode == "a":
|
| 106 |
+
self.mlp_a = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
|
| 107 |
+
|
| 108 |
+
# def forward(self, x_v, x_a, c, x_v_cond, x_a_cond, mode="av"):
|
| 109 |
+
def forward(self, *args):
|
| 110 |
+
if self.mode == "av":
|
| 111 |
+
x_v, x_a, c, x_v_cond, x_a_cond = args
|
| 112 |
+
shift_msa, scale_msa, gate_msa, shift_ca_xcond, scale_ca_xcond, shift_ca_x, scale_ca_x, gate_ca_x, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(11, dim=1)
|
| 113 |
+
_, v_token_num, _ = x_v.shape
|
| 114 |
+
x = torch.cat([x_v, x_a], dim=1)
|
| 115 |
+
x_cond = torch.cat([x_v_cond, x_a_cond], dim=1)
|
| 116 |
+
x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
|
| 117 |
+
x_cond_norm = modulate(self.norm_cond(x_cond), shift_ca_xcond, scale_ca_xcond)
|
| 118 |
+
x = x + gate_ca_x.unsqueeze(1) * self.cttn(query=modulate(self.norm2(x), shift_ca_x, scale_ca_x), key=x_cond_norm, value=x_cond_norm, need_weights=False)[0]
|
| 119 |
+
x_v = x[:,:v_token_num,:]
|
| 120 |
+
x_a = x[:,v_token_num:,:]
|
| 121 |
+
x_v = x_v + gate_mlp.unsqueeze(1) * self.mlp_v(modulate(self.norm3(x_v), shift_mlp, scale_mlp))
|
| 122 |
+
x_a = x_a + gate_mlp.unsqueeze(1) * self.mlp_a(modulate(self.norm3(x_a), shift_mlp, scale_mlp))
|
| 123 |
+
return x_v, x_a
|
| 124 |
+
elif self.mode == "v":
|
| 125 |
+
x, c, x_cond = args
|
| 126 |
+
shift_msa, scale_msa, gate_msa, shift_ca_xcond, scale_ca_xcond, shift_ca_x, scale_ca_x, gate_ca_x, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(11, dim=1)
|
| 127 |
+
x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
|
| 128 |
+
x_cond_norm = modulate(self.norm_cond(x_cond), shift_ca_xcond, scale_ca_xcond)
|
| 129 |
+
x = x + gate_ca_x.unsqueeze(1) * self.cttn(query=modulate(self.norm2(x), shift_ca_x, scale_ca_x), key=x_cond_norm, value=x_cond_norm, need_weights=False)[0]
|
| 130 |
+
x = x + gate_mlp.unsqueeze(1) * self.mlp_v(modulate(self.norm3(x), shift_mlp, scale_mlp))
|
| 131 |
+
return x
|
| 132 |
+
elif self.mode == "a":
|
| 133 |
+
x, c, x_cond = args
|
| 134 |
+
shift_msa, scale_msa, gate_msa, shift_ca_xcond, scale_ca_xcond, shift_ca_x, scale_ca_x, gate_ca_x, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(11, dim=1)
|
| 135 |
+
x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
|
| 136 |
+
x_cond_norm = modulate(self.norm_cond(x_cond), shift_ca_xcond, scale_ca_xcond)
|
| 137 |
+
x = x + gate_ca_x.unsqueeze(1) * self.cttn(query=modulate(self.norm2(x), shift_ca_x, scale_ca_x), key=x_cond_norm, value=x_cond_norm, need_weights=False)[0]
|
| 138 |
+
x = x + gate_mlp.unsqueeze(1) * self.mlp_a(modulate(self.norm3(x), shift_mlp, scale_mlp))
|
| 139 |
+
return x
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class FinalLayer(nn.Module):
|
| 143 |
+
"""
|
| 144 |
+
The final layer of DiT.
|
| 145 |
+
"""
|
| 146 |
+
def __init__(self, hidden_size, patch_size, out_channels):
|
| 147 |
+
super().__init__()
|
| 148 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 149 |
+
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
| 150 |
+
self.adaLN_modulation = nn.Sequential(
|
| 151 |
+
nn.SiLU(),
|
| 152 |
+
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
def forward(self, x, c):
|
| 156 |
+
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
| 157 |
+
x = modulate(self.norm_final(x), shift, scale)
|
| 158 |
+
x = self.linear(x)
|
| 159 |
+
return x
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class FinalLayer_audio(nn.Module):
|
| 163 |
+
def __init__(self, hidden_size, out_channels):
|
| 164 |
+
super().__init__()
|
| 165 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 166 |
+
self.linear = nn.Linear(hidden_size, out_channels, bias=True) # no patch²
|
| 167 |
+
self.adaLN_modulation = nn.Sequential(
|
| 168 |
+
nn.SiLU(),
|
| 169 |
+
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
def forward(self, x, c):
|
| 173 |
+
# x: (B, N, hidden_size), c: (B, hidden_size)
|
| 174 |
+
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) # shape (B, hidden_size)
|
| 175 |
+
x = modulate(self.norm_final(x), shift, scale) # apply AdaLN
|
| 176 |
+
x = self.linear(x) # → (B, N, out_channels)
|
| 177 |
+
return x
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class AVCDiT(nn.Module):
|
| 181 |
+
"""
|
| 182 |
+
Diffusion model with a Transformer backbone.
|
| 183 |
+
"""
|
| 184 |
+
def __init__(
|
| 185 |
+
self,
|
| 186 |
+
input_size=32,
|
| 187 |
+
context_size=2,
|
| 188 |
+
patch_size=2,
|
| 189 |
+
in_channels=4,
|
| 190 |
+
hidden_size=1152,
|
| 191 |
+
depth=28,
|
| 192 |
+
num_heads=16,
|
| 193 |
+
mlp_ratio=4.0,
|
| 194 |
+
learn_sigma=True,
|
| 195 |
+
num_patches_a=180,
|
| 196 |
+
mode="av",
|
| 197 |
+
):
|
| 198 |
+
super().__init__()
|
| 199 |
+
self.mode = mode
|
| 200 |
+
assert (self.mode=="av" or self.mode=="v" or self.mode=="a")
|
| 201 |
+
self.context_size = context_size
|
| 202 |
+
self.learn_sigma = learn_sigma
|
| 203 |
+
self.in_channels = in_channels
|
| 204 |
+
self.out_channels = in_channels * 2 if learn_sigma else in_channels
|
| 205 |
+
self.patch_size = patch_size
|
| 206 |
+
self.num_heads = num_heads
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
if self.mode == "av" or self.mode == "v":
|
| 210 |
+
self.x_embedder_v = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
|
| 211 |
+
num_patches_v = self.x_embedder_v.num_patches
|
| 212 |
+
self.pos_embed_v = nn.Parameter(torch.zeros(self.context_size + 1, num_patches_v, hidden_size), requires_grad=True) # for context and for predicted frame
|
| 213 |
+
self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
|
| 214 |
+
if self.mode == "av" or self.mode == "a":
|
| 215 |
+
self.x_embedder_a = nn.Conv1d(
|
| 216 |
+
in_channels=16,
|
| 217 |
+
out_channels=hidden_size, # [B]
|
| 218 |
+
kernel_size=1,
|
| 219 |
+
stride=1,
|
| 220 |
+
bias=True
|
| 221 |
+
) #TODO
|
| 222 |
+
self.pos_embed_a_cond = nn.Parameter(torch.zeros(self.context_size, num_patches_a, hidden_size), requires_grad=True)
|
| 223 |
+
self.pos_embed_a_pred = nn.Parameter(torch.zeros(1, num_patches_a+1, hidden_size), requires_grad=True)
|
| 224 |
+
self.final_layer_a = FinalLayer_audio(hidden_size=hidden_size, out_channels=32) # [B]
|
| 225 |
+
|
| 226 |
+
self.t_embedder = TimestepEmbedder(hidden_size)
|
| 227 |
+
self.y_embedder = ActionEmbedder(hidden_size)
|
| 228 |
+
|
| 229 |
+
# self.blocks = nn.ModuleList([AVCDiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)])
|
| 230 |
+
self.blocks = nn.ModuleList([
|
| 231 |
+
AVCDiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, mode=self.mode)
|
| 232 |
+
for _ in range(depth)
|
| 233 |
+
])
|
| 234 |
+
|
| 235 |
+
self.time_embedder = TimestepEmbedder(hidden_size)
|
| 236 |
+
self.initialize_weights()
|
| 237 |
+
|
| 238 |
+
def initialize_weights(self):
|
| 239 |
+
# Initialize transformer layers:
|
| 240 |
+
def _basic_init(module):
|
| 241 |
+
if isinstance(module, nn.Linear):
|
| 242 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 243 |
+
if module.bias is not None:
|
| 244 |
+
nn.init.constant_(module.bias, 0)
|
| 245 |
+
self.apply(_basic_init)
|
| 246 |
+
|
| 247 |
+
# Initialize (and freeze) pos_embed by sin-cos embedding:
|
| 248 |
+
if self.mode == "av" or self.mode == "v":
|
| 249 |
+
nn.init.normal_(self.pos_embed_v, std=0.02)
|
| 250 |
+
if self.mode == "av" or self.mode == "a":
|
| 251 |
+
nn.init.normal_(self.pos_embed_a_pred, std=0.02)
|
| 252 |
+
nn.init.normal_(self.pos_embed_a_cond, std=0.02)
|
| 253 |
+
|
| 254 |
+
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
| 255 |
+
if self.mode == "av" or self.mode == "v":
|
| 256 |
+
w = self.x_embedder_v.proj.weight.data
|
| 257 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
| 258 |
+
nn.init.constant_(self.x_embedder_v.proj.bias, 0)
|
| 259 |
+
|
| 260 |
+
# Initialize x_embedder_a (Conv1d) like linear
|
| 261 |
+
if self.mode == "av" or self.mode == "a":
|
| 262 |
+
w = self.x_embedder_a.weight.data
|
| 263 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
| 264 |
+
nn.init.constant_(self.x_embedder_a.bias, 0)
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
# Initialize action embedding:
|
| 268 |
+
nn.init.normal_(self.y_embedder.x_emb.mlp[0].weight, std=0.02)
|
| 269 |
+
nn.init.normal_(self.y_embedder.x_emb.mlp[2].weight, std=0.02)
|
| 270 |
+
|
| 271 |
+
nn.init.normal_(self.y_embedder.y_emb.mlp[0].weight, std=0.02)
|
| 272 |
+
nn.init.normal_(self.y_embedder.y_emb.mlp[2].weight, std=0.02)
|
| 273 |
+
|
| 274 |
+
nn.init.normal_(self.y_embedder.angle_emb.mlp[0].weight, std=0.02)
|
| 275 |
+
nn.init.normal_(self.y_embedder.angle_emb.mlp[2].weight, std=0.02)
|
| 276 |
+
|
| 277 |
+
# Initialize timestep embedding MLP:
|
| 278 |
+
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
| 279 |
+
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
| 280 |
+
|
| 281 |
+
nn.init.normal_(self.time_embedder.mlp[0].weight, std=0.02)
|
| 282 |
+
nn.init.normal_(self.time_embedder.mlp[2].weight, std=0.02)
|
| 283 |
+
|
| 284 |
+
# Zero-out adaLN modulation layers in DiT blocks:
|
| 285 |
+
for block in self.blocks:
|
| 286 |
+
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
| 287 |
+
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
| 288 |
+
|
| 289 |
+
# Zero-out output layers:
|
| 290 |
+
if self.mode == "av" or self.mode == "v":
|
| 291 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
| 292 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
| 293 |
+
nn.init.constant_(self.final_layer.linear.weight, 0)
|
| 294 |
+
nn.init.constant_(self.final_layer.linear.bias, 0)
|
| 295 |
+
|
| 296 |
+
if self.mode == "av" or self.mode == "a":
|
| 297 |
+
nn.init.constant_(self.final_layer_a.adaLN_modulation[-1].weight, 0)
|
| 298 |
+
nn.init.constant_(self.final_layer_a.adaLN_modulation[-1].bias, 0)
|
| 299 |
+
nn.init.constant_(self.final_layer_a.linear.weight, 0)
|
| 300 |
+
nn.init.constant_(self.final_layer_a.linear.bias, 0)
|
| 301 |
+
|
| 302 |
+
def unpatchify(self, x):
|
| 303 |
+
"""
|
| 304 |
+
x: (N, T, patch_size**2 * C)
|
| 305 |
+
imgs: (N, H, W, C)
|
| 306 |
+
"""
|
| 307 |
+
c = self.out_channels
|
| 308 |
+
p = self.x_embedder_v.patch_size[0]
|
| 309 |
+
h = w = int(x.shape[1] ** 0.5)
|
| 310 |
+
assert h * w == x.shape[1]
|
| 311 |
+
|
| 312 |
+
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
| 313 |
+
x = torch.einsum('nhwpqc->nchpwq', x)
|
| 314 |
+
imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
|
| 315 |
+
return imgs
|
| 316 |
+
|
| 317 |
+
# def forward(self, x_v, x_a, t, y, x_v_cond, x_a_cond, rel_t):
|
| 318 |
+
# def forward(self, *args):
|
| 319 |
+
def forward(self, *args, **kwargs):
|
| 320 |
+
"""
|
| 321 |
+
Forward pass of DiT.
|
| 322 |
+
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
| 323 |
+
t: (N,) tensor of diffusion timesteps
|
| 324 |
+
y: (N,) tensor of class labels
|
| 325 |
+
"""
|
| 326 |
+
if self.mode == "av":
|
| 327 |
+
if len(args) >= 7:
|
| 328 |
+
x_v, x_a, t, y, x_v_cond, x_a_cond, rel_t = args[:7]
|
| 329 |
+
else:
|
| 330 |
+
assert len(args) == 3, f"mode='v' expects 2 or 5 positional args, got {len(args)}"
|
| 331 |
+
x_v, x_a, t = args
|
| 332 |
+
y = kwargs["y"]
|
| 333 |
+
x_v_cond = kwargs["x_v_cond"]
|
| 334 |
+
x_a_cond = kwargs["x_a_cond"]
|
| 335 |
+
rel_t = kwargs["rel_t"]
|
| 336 |
+
|
| 337 |
+
x_v = self.x_embedder_v(x_v) + self.pos_embed_v[self.context_size:]
|
| 338 |
+
x_v_cond = self.x_embedder_v(x_v_cond.flatten(0, 1)).unflatten(0, (x_v_cond.shape[0], x_v_cond.shape[1])) + self.pos_embed_v[:self.context_size] # (N, T, D), where T = H * W / patch_size ** 2.flatten(1, 2)
|
| 339 |
+
x_v_cond = x_v_cond.flatten(1, 2)
|
| 340 |
+
|
| 341 |
+
x_a = self.x_embedder_a(x_a) # → (B, embed_dim, L')
|
| 342 |
+
x_a = x_a.transpose(1, 2) # → (B, L', embed_dim)
|
| 343 |
+
x_a = x_a + self.pos_embed_a_pred
|
| 344 |
+
|
| 345 |
+
x_a_cond = self.x_embedder_a(x_a_cond.flatten(0, 1)).transpose(1, 2).unflatten(0, (x_a_cond.shape[0], x_a_cond.shape[1])) + self.pos_embed_a_cond
|
| 346 |
+
x_a_cond = x_a_cond.flatten(1, 2)
|
| 347 |
+
|
| 348 |
+
t = self.t_embedder(t[..., None])
|
| 349 |
+
y = self.y_embedder(y)
|
| 350 |
+
time_emb = self.time_embedder(rel_t[..., None])
|
| 351 |
+
c = t + time_emb + y # if training on unlabeled data, dont add y.
|
| 352 |
+
|
| 353 |
+
for block in self.blocks:
|
| 354 |
+
x_v, x_a = block(x_v, x_a, c, x_v_cond, x_a_cond)
|
| 355 |
+
x_v = self.final_layer(x_v, c)
|
| 356 |
+
x_v = self.unpatchify(x_v)
|
| 357 |
+
x_a = self.final_layer_a(x_a, c)
|
| 358 |
+
x_a = x_a.transpose(1, 2)
|
| 359 |
+
return x_v, x_a
|
| 360 |
+
elif self.mode == "v":
|
| 361 |
+
if len(args) >= 5:
|
| 362 |
+
x, t, y, x_cond, rel_t = args[:5]
|
| 363 |
+
else:
|
| 364 |
+
assert len(args) == 2, f"mode='v' expects 2 or 5 positional args, got {len(args)}"
|
| 365 |
+
x, t = args
|
| 366 |
+
y = kwargs["y"]
|
| 367 |
+
x_cond = kwargs["x_cond"]
|
| 368 |
+
rel_t = kwargs["rel_t"]
|
| 369 |
+
x = self.x_embedder_v(x) + self.pos_embed_v[self.context_size:]
|
| 370 |
+
x_cond = self.x_embedder_v(x_cond.flatten(0, 1)).unflatten(0, (x_cond.shape[0], x_cond.shape[1])) + self.pos_embed_v[:self.context_size] # (N, T, D), where T = H * W / patch_size ** 2.flatten(1, 2)
|
| 371 |
+
x_cond = x_cond.flatten(1, 2)
|
| 372 |
+
t = self.t_embedder(t[..., None])
|
| 373 |
+
y = self.y_embedder(y)
|
| 374 |
+
time_emb = self.time_embedder(rel_t[..., None])
|
| 375 |
+
c = t + time_emb + y # if training on unlabeled data, dont add y.
|
| 376 |
+
for block in self.blocks:
|
| 377 |
+
x = block(x, c, x_cond)
|
| 378 |
+
x = self.final_layer(x, c)
|
| 379 |
+
x = self.unpatchify(x)
|
| 380 |
+
return x
|
| 381 |
+
elif self.mode == "a":
|
| 382 |
+
if len(args) >= 5:
|
| 383 |
+
x, t, y, x_cond, rel_t = args[:5]
|
| 384 |
+
else:
|
| 385 |
+
assert len(args) == 2, f"mode='v' expects 2 or 5 positional args, got {len(args)}"
|
| 386 |
+
x, t = args
|
| 387 |
+
y = kwargs["y"]
|
| 388 |
+
x_cond = kwargs["x_cond"]
|
| 389 |
+
rel_t = kwargs["rel_t"]
|
| 390 |
+
x = self.x_embedder_a(x) # → (B, embed_dim, L')
|
| 391 |
+
x = x.transpose(1, 2) # → (B, L', embed_dim)
|
| 392 |
+
x = x + self.pos_embed_a_pred # [REWARD]
|
| 393 |
+
x_cond = self.x_embedder_a(x_cond.flatten(0, 1)).transpose(1, 2).unflatten(0, (x_cond.shape[0], x_cond.shape[1])) + self.pos_embed_a_cond # [REWARD]
|
| 394 |
+
x_cond = x_cond.flatten(1, 2)
|
| 395 |
+
t = self.t_embedder(t[..., None])
|
| 396 |
+
y = self.y_embedder(y)
|
| 397 |
+
time_emb = self.time_embedder(rel_t[..., None])
|
| 398 |
+
c = t + time_emb + y # if training on unlabeled data, dont add y.
|
| 399 |
+
for block in self.blocks:
|
| 400 |
+
x = block(x, c, x_cond)
|
| 401 |
+
x = self.final_layer_a(x, c)
|
| 402 |
+
x = x.transpose(1, 2)
|
| 403 |
+
return x
|
| 404 |
+
|
| 405 |
+
#################################################################################
|
| 406 |
+
# Sine/Cosine Positional Embedding Functions #
|
| 407 |
+
#################################################################################
|
| 408 |
+
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
| 409 |
+
|
| 410 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
|
| 411 |
+
"""
|
| 412 |
+
grid_size: int of the grid height and width
|
| 413 |
+
return:
|
| 414 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
| 415 |
+
"""
|
| 416 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
| 417 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
| 418 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
| 419 |
+
grid = np.stack(grid, axis=0)
|
| 420 |
+
|
| 421 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
| 422 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 423 |
+
if cls_token and extra_tokens > 0:
|
| 424 |
+
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
| 425 |
+
return pos_embed
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
| 429 |
+
assert embed_dim % 2 == 0
|
| 430 |
+
|
| 431 |
+
# use half of dimensions to encode grid_h
|
| 432 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
| 433 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
| 434 |
+
|
| 435 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
| 436 |
+
return emb
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 440 |
+
"""
|
| 441 |
+
embed_dim: output dimension for each position
|
| 442 |
+
pos: a list of positions to be encoded: size (M,)
|
| 443 |
+
out: (M, D)
|
| 444 |
+
"""
|
| 445 |
+
assert embed_dim % 2 == 0
|
| 446 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
| 447 |
+
omega /= embed_dim / 2.
|
| 448 |
+
omega = 1. / 10000**omega # (D/2,)
|
| 449 |
+
|
| 450 |
+
pos = pos.reshape(-1) # (M,)
|
| 451 |
+
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
| 452 |
+
|
| 453 |
+
emb_sin = np.sin(out) # (M, D/2)
|
| 454 |
+
emb_cos = np.cos(out) # (M, D/2)
|
| 455 |
+
|
| 456 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
| 457 |
+
return emb
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
#################################################################################
|
| 461 |
+
# AVCDiT Configs #
|
| 462 |
+
#################################################################################
|
| 463 |
+
|
| 464 |
+
def AVCDiT_XL_2(**kwargs):
|
| 465 |
+
return AVCDiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
|
| 466 |
+
|
| 467 |
+
def AVCDiT_L_2(**kwargs):
|
| 468 |
+
return AVCDiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
|
| 469 |
+
|
| 470 |
+
def AVCDiT_B_2(**kwargs):
|
| 471 |
+
return AVCDiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
|
| 472 |
+
|
| 473 |
+
def AVCDiT_S_2(**kwargs):
|
| 474 |
+
return AVCDiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
AVCDiT_models = {
|
| 478 |
+
'AVCDiT-XL/2': AVCDiT_XL_2,
|
| 479 |
+
'AVCDiT-L/2': AVCDiT_L_2,
|
| 480 |
+
'AVCDiT-B/2': AVCDiT_B_2,
|
| 481 |
+
'AVCDiT-S/2': AVCDiT_S_2
|
| 482 |
+
}
|
soundstream.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch.nn.utils import weight_norm
|
| 5 |
+
|
| 6 |
+
from vector_quantize_pytorch import ResidualVQ
|
| 7 |
+
|
| 8 |
+
class CausalConv1d(nn.Conv1d):
|
| 9 |
+
def __init__(self, *args, **kwargs):
|
| 10 |
+
super().__init__(*args, **kwargs)
|
| 11 |
+
self.causal_padding = self.dilation[0] * (self.kernel_size[0] - 1)
|
| 12 |
+
|
| 13 |
+
def forward(self, x):
|
| 14 |
+
return self._conv_forward(F.pad(x, [self.causal_padding, 0]), self.weight, self.bias)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class CausalConvTranspose1d(nn.ConvTranspose1d):
|
| 18 |
+
def __init__(self, *args, **kwargs):
|
| 19 |
+
super().__init__(*args, **kwargs)
|
| 20 |
+
self.causal_padding = self.dilation[0] * (self.kernel_size[0] - 1) + self.output_padding[0] + 1 - self.stride[0]
|
| 21 |
+
|
| 22 |
+
def forward(self, x, output_size=None):
|
| 23 |
+
if self.padding_mode != 'zeros':
|
| 24 |
+
raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')
|
| 25 |
+
|
| 26 |
+
assert isinstance(self.padding, tuple)
|
| 27 |
+
output_padding = self._output_padding(
|
| 28 |
+
x, output_size, self.stride, self.padding, self.kernel_size, self.dilation)
|
| 29 |
+
return F.conv_transpose1d(
|
| 30 |
+
x, self.weight, self.bias, self.stride, self.padding,
|
| 31 |
+
output_padding, self.groups, self.dilation)[...,:-self.causal_padding]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class ResidualUnit(nn.Module):
|
| 35 |
+
def __init__(self, in_channels, out_channels, dilation):
|
| 36 |
+
super().__init__()
|
| 37 |
+
|
| 38 |
+
self.dilation = dilation
|
| 39 |
+
|
| 40 |
+
self.layers = nn.Sequential(
|
| 41 |
+
CausalConv1d(in_channels=in_channels, out_channels=out_channels,
|
| 42 |
+
kernel_size=7, dilation=dilation),
|
| 43 |
+
nn.ELU(),
|
| 44 |
+
nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
|
| 45 |
+
kernel_size=1)
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
def forward(self, x):
|
| 49 |
+
return x + self.layers(x)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class EncoderBlock(nn.Module):
|
| 53 |
+
def __init__(self, out_channels, stride):
|
| 54 |
+
super().__init__()
|
| 55 |
+
|
| 56 |
+
self.layers = nn.Sequential(
|
| 57 |
+
ResidualUnit(in_channels=out_channels//2,
|
| 58 |
+
out_channels=out_channels//2, dilation=1),
|
| 59 |
+
nn.ELU(),
|
| 60 |
+
ResidualUnit(in_channels=out_channels//2,
|
| 61 |
+
out_channels=out_channels//2, dilation=3),
|
| 62 |
+
nn.ELU(),
|
| 63 |
+
ResidualUnit(in_channels=out_channels//2,
|
| 64 |
+
out_channels=out_channels//2, dilation=9),
|
| 65 |
+
nn.ELU(),
|
| 66 |
+
CausalConv1d(in_channels=out_channels//2, out_channels=out_channels,
|
| 67 |
+
kernel_size=2*stride, stride=stride)
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
def forward(self, x):
|
| 71 |
+
return self.layers(x)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class DecoderBlock(nn.Module):
|
| 75 |
+
def __init__(self, out_channels, stride):
|
| 76 |
+
super().__init__()
|
| 77 |
+
|
| 78 |
+
self.layers = nn.Sequential(
|
| 79 |
+
CausalConvTranspose1d(in_channels=2*out_channels,
|
| 80 |
+
out_channels=out_channels,
|
| 81 |
+
kernel_size=2*stride, stride=stride),
|
| 82 |
+
nn.ELU(),
|
| 83 |
+
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
|
| 84 |
+
dilation=1),
|
| 85 |
+
nn.ELU(),
|
| 86 |
+
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
|
| 87 |
+
dilation=3),
|
| 88 |
+
nn.ELU(),
|
| 89 |
+
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
|
| 90 |
+
dilation=9),
|
| 91 |
+
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
def forward(self, x):
|
| 95 |
+
return self.layers(x)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class Encoder(nn.Module):
|
| 99 |
+
def __init__(self, C, D):
|
| 100 |
+
super().__init__()
|
| 101 |
+
|
| 102 |
+
self.layers = nn.Sequential(
|
| 103 |
+
CausalConv1d(in_channels=2, out_channels=C, kernel_size=7),
|
| 104 |
+
nn.ELU(),
|
| 105 |
+
EncoderBlock(out_channels=2*C, stride=2),
|
| 106 |
+
nn.ELU(),
|
| 107 |
+
EncoderBlock(out_channels=4*C, stride=4),
|
| 108 |
+
nn.ELU(),
|
| 109 |
+
EncoderBlock(out_channels=8*C, stride=5),
|
| 110 |
+
nn.ELU(),
|
| 111 |
+
# EncoderBlock(out_channels=16*C, stride=8),
|
| 112 |
+
# nn.ELU(),
|
| 113 |
+
# CausalConv1d(in_channels=16*C, out_channels=D, kernel_size=3)
|
| 114 |
+
CausalConv1d(in_channels=8*C, out_channels=D, kernel_size=3)
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
def forward(self, x):
|
| 118 |
+
return self.layers(x)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class Decoder(nn.Module):
|
| 122 |
+
def __init__(self, C, D):
|
| 123 |
+
super().__init__()
|
| 124 |
+
|
| 125 |
+
self.layers = nn.Sequential(
|
| 126 |
+
CausalConv1d(in_channels=D, out_channels=8*C, kernel_size=7),
|
| 127 |
+
# CausalConv1d(in_channels=D, out_channels=16*C, kernel_size=7),
|
| 128 |
+
# nn.ELU(),
|
| 129 |
+
# DecoderBlock(out_channels=8*C, stride=8),
|
| 130 |
+
nn.ELU(),
|
| 131 |
+
DecoderBlock(out_channels=4*C, stride=5),
|
| 132 |
+
nn.ELU(),
|
| 133 |
+
DecoderBlock(out_channels=2*C, stride=4),
|
| 134 |
+
nn.ELU(),
|
| 135 |
+
DecoderBlock(out_channels=C, stride=2),
|
| 136 |
+
nn.ELU(),
|
| 137 |
+
CausalConv1d(in_channels=C, out_channels=2, kernel_size=7)
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
def forward(self, x):
|
| 141 |
+
return self.layers(x)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class SoundStream(nn.Module):
|
| 145 |
+
def __init__(self, C, D, n_q, codebook_size):
|
| 146 |
+
super().__init__()
|
| 147 |
+
|
| 148 |
+
self.encoder = Encoder(C=C, D=D)
|
| 149 |
+
self.quantizer = ResidualVQ(
|
| 150 |
+
num_quantizers=n_q, dim=D, codebook_size=codebook_size,
|
| 151 |
+
kmeans_init=True, kmeans_iters=100, threshold_ema_dead_code=2
|
| 152 |
+
)
|
| 153 |
+
self.decoder = Decoder(C=C, D=D)
|
| 154 |
+
|
| 155 |
+
@staticmethod
|
| 156 |
+
def pad_to_multiple(x, multiple):
|
| 157 |
+
"""
|
| 158 |
+
x: [B, C, T]
|
| 159 |
+
multiple: int, e.g., 320
|
| 160 |
+
return: padded_x, original_length
|
| 161 |
+
"""
|
| 162 |
+
B, C, T = x.shape
|
| 163 |
+
target_len = ((T + multiple - 1) // multiple) * multiple
|
| 164 |
+
pad_len = target_len - T
|
| 165 |
+
padded_x = F.pad(x, (0, pad_len), mode='reflect')
|
| 166 |
+
return padded_x, T
|
| 167 |
+
|
| 168 |
+
@staticmethod
|
| 169 |
+
def crop_to_length(x, original_length):
|
| 170 |
+
return x[..., :original_length]
|
| 171 |
+
|
| 172 |
+
def forward(self, x):
|
| 173 |
+
e = self.encoder(x) # [B, D, T']
|
| 174 |
+
e = e.permute(0, 2, 1) # → [B, T', D]
|
| 175 |
+
quantized, _, _ = self.quantizer(e)
|
| 176 |
+
quantized = quantized.permute(0, 2, 1) # → [B, D, T']
|
| 177 |
+
o = self.decoder(quantized) # → [B, 2, T_padded]
|
| 178 |
+
return o
|
train_avwm_stage1.py
ADDED
|
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from inference_avwm import model_forward_wrapper_v
|
| 2 |
+
import torch
|
| 3 |
+
# the first flag below was False when we tested this script but True makes A100 training a lot faster:
|
| 4 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 5 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 6 |
+
|
| 7 |
+
import matplotlib
|
| 8 |
+
matplotlib.use('Agg')
|
| 9 |
+
from collections import OrderedDict
|
| 10 |
+
from copy import deepcopy
|
| 11 |
+
from time import time
|
| 12 |
+
import argparse
|
| 13 |
+
import logging
|
| 14 |
+
import os
|
| 15 |
+
import matplotlib.pyplot as plt
|
| 16 |
+
import yaml
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
import torch.distributed as dist
|
| 20 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 21 |
+
from torch.utils.data import DataLoader, ConcatDataset
|
| 22 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 23 |
+
from diffusers.models import AutoencoderKL
|
| 24 |
+
|
| 25 |
+
from distributed import init_distributed
|
| 26 |
+
from models import AVCDiT_models
|
| 27 |
+
from diffusion import create_diffusion
|
| 28 |
+
from datasets import TrainingDataset
|
| 29 |
+
from misc import transform
|
| 30 |
+
|
| 31 |
+
#################################################################################
|
| 32 |
+
# Training Helper Functions #
|
| 33 |
+
#################################################################################
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def load_checkpoint_if_available(model, ema, opt, scaler, config, device, logger, args):
|
| 37 |
+
start_epoch = 0
|
| 38 |
+
train_steps = 0
|
| 39 |
+
latest_path = os.path.join(config['results_dir'], config['run_name'], "checkpoints", "latest.pth.tar")
|
| 40 |
+
if os.path.isfile(latest_path) or config.get('from_checkpoint', 0):
|
| 41 |
+
latest_path = latest_path if os.path.isfile(latest_path) else config.get('from_checkpoint', 0)
|
| 42 |
+
print("Loading model from ", latest_path)
|
| 43 |
+
checkpoint = torch.load(latest_path, map_location=f"cuda:{device}", weights_only=False)
|
| 44 |
+
|
| 45 |
+
ema_ckp = {k.replace('_orig_mod.', ''): v for k, v in checkpoint["ema"].items()}
|
| 46 |
+
remapped = {}
|
| 47 |
+
for k, v in ema_ckp.items():
|
| 48 |
+
new_k = k
|
| 49 |
+
# 1) pos_embed -> pos_embed_v
|
| 50 |
+
if k.startswith("pos_embed"):
|
| 51 |
+
new_k = k.replace("pos_embed", "pos_embed_v", 1)
|
| 52 |
+
# 2) x_embedder. -> x_embedder_v.
|
| 53 |
+
if new_k.startswith("x_embedder."):
|
| 54 |
+
new_k = new_k.replace("x_embedder.", "x_embedder_v.", 1)
|
| 55 |
+
# 3) blocks.*.mlp.*: .mlp. -> .mlp_v.
|
| 56 |
+
if new_k.startswith("blocks.") and ".mlp." in new_k:
|
| 57 |
+
new_k = new_k.replace(".mlp.", ".mlp_v.", 1)
|
| 58 |
+
remapped[new_k] = v
|
| 59 |
+
ema_ckp = remapped
|
| 60 |
+
model.load_state_dict(ema_ckp, strict=True)
|
| 61 |
+
print("Model weights loaded.")
|
| 62 |
+
ema.load_state_dict(ema_ckp, strict=True)
|
| 63 |
+
print("EMA weights loaded.")
|
| 64 |
+
|
| 65 |
+
if args.restart_from_checkpoint:
|
| 66 |
+
logger.info("Restarting training: epoch and step counters set to 0.")
|
| 67 |
+
else:
|
| 68 |
+
if "opt" in checkpoint:
|
| 69 |
+
opt_ckp = {k.replace('_orig_mod.', ''): v for k, v in checkpoint["opt"].items()}
|
| 70 |
+
opt.load_state_dict(opt_ckp)
|
| 71 |
+
print("Optimizer state loaded.")
|
| 72 |
+
if "scaler" in checkpoint and scaler is not None:
|
| 73 |
+
scaler.load_state_dict(checkpoint["scaler"])
|
| 74 |
+
print("GradScaler state loaded.")
|
| 75 |
+
if "epoch" in checkpoint:
|
| 76 |
+
start_epoch = checkpoint["epoch"] + 1
|
| 77 |
+
if "train_steps" in checkpoint:
|
| 78 |
+
train_steps = checkpoint["train_steps"]
|
| 79 |
+
logger.info(f"Resuming from epoch {start_epoch}, step {train_steps}")
|
| 80 |
+
|
| 81 |
+
return start_epoch, train_steps
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@torch.no_grad()
|
| 85 |
+
def update_ema(ema_model, model, decay=0.9999):
|
| 86 |
+
"""
|
| 87 |
+
Step the EMA model towards the current model.
|
| 88 |
+
"""
|
| 89 |
+
ema_params = OrderedDict(ema_model.named_parameters())
|
| 90 |
+
model_params = OrderedDict(model.named_parameters())
|
| 91 |
+
|
| 92 |
+
for name, param in model_params.items():
|
| 93 |
+
name = name.replace('_orig_mod.', '')
|
| 94 |
+
ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def requires_grad(model, flag=True):
|
| 98 |
+
"""
|
| 99 |
+
Set requires_grad flag for all parameters in a model.
|
| 100 |
+
"""
|
| 101 |
+
for p in model.parameters():
|
| 102 |
+
p.requires_grad = flag
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def cleanup():
|
| 106 |
+
"""
|
| 107 |
+
End DDP training.
|
| 108 |
+
"""
|
| 109 |
+
dist.destroy_process_group()
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def create_logger(logging_dir):
|
| 113 |
+
"""
|
| 114 |
+
Create a logger that writes to a log file and stdout.
|
| 115 |
+
"""
|
| 116 |
+
if dist.get_rank() == 0: # real logger
|
| 117 |
+
logging.basicConfig(
|
| 118 |
+
level=logging.INFO,
|
| 119 |
+
format='[\033[34m%(asctime)s\033[0m] %(message)s',
|
| 120 |
+
datefmt='%Y-%m-%d %H:%M:%S',
|
| 121 |
+
handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
|
| 122 |
+
)
|
| 123 |
+
logger = logging.getLogger(__name__)
|
| 124 |
+
else: # dummy logger (does nothing)
|
| 125 |
+
logger = logging.getLogger(__name__)
|
| 126 |
+
logger.addHandler(logging.NullHandler())
|
| 127 |
+
return logger
|
| 128 |
+
|
| 129 |
+
#################################################################################
|
| 130 |
+
# Training Loop #
|
| 131 |
+
#################################################################################
|
| 132 |
+
|
| 133 |
+
def main(args):
|
| 134 |
+
"""
|
| 135 |
+
Trains a new AVCDiT model.
|
| 136 |
+
"""
|
| 137 |
+
assert torch.cuda.is_available(), "Training currently requires at least one GPU."
|
| 138 |
+
|
| 139 |
+
# Setup DDP:
|
| 140 |
+
_, rank, device, _ = init_distributed()
|
| 141 |
+
# rank = dist.get_rank()
|
| 142 |
+
seed = args.global_seed * dist.get_world_size() + rank
|
| 143 |
+
torch.manual_seed(seed)
|
| 144 |
+
print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
|
| 145 |
+
with open("config/eval_config.yaml", "r") as f:
|
| 146 |
+
default_config = yaml.safe_load(f)
|
| 147 |
+
config = default_config
|
| 148 |
+
|
| 149 |
+
with open(args.config, "r") as f:
|
| 150 |
+
user_config = yaml.safe_load(f)
|
| 151 |
+
config.update(user_config)
|
| 152 |
+
|
| 153 |
+
# Setup an experiment folder:
|
| 154 |
+
os.makedirs(config['results_dir'], exist_ok=True) # Make results folder (holds all experiment subfolders)
|
| 155 |
+
experiment_dir = f"{config['results_dir']}/{config['run_name']}" # Create an experiment folder
|
| 156 |
+
checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints
|
| 157 |
+
if rank == 0:
|
| 158 |
+
os.makedirs(checkpoint_dir, exist_ok=True)
|
| 159 |
+
logger = create_logger(experiment_dir)
|
| 160 |
+
logger.info(f"Experiment directory created at {experiment_dir}")
|
| 161 |
+
else:
|
| 162 |
+
logger = create_logger(None)
|
| 163 |
+
|
| 164 |
+
# Create model:
|
| 165 |
+
tokenizer = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-ema").to(device)
|
| 166 |
+
latent_size = config['image_size'] // 8
|
| 167 |
+
|
| 168 |
+
assert config['image_size'] % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
|
| 169 |
+
num_cond = config['context_size']
|
| 170 |
+
model = AVCDiT_models[config['model']](context_size=num_cond, input_size=latent_size, in_channels=4, mode="v").to(device)
|
| 171 |
+
|
| 172 |
+
ema = deepcopy(model).to(device) # Create an EMA of the model for use after training
|
| 173 |
+
requires_grad(ema, False)
|
| 174 |
+
|
| 175 |
+
# Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
|
| 176 |
+
lr = float(config.get('lr', 1e-4))
|
| 177 |
+
opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
bfloat_enable = bool(hasattr(args, 'bfloat16') and args.bfloat16)
|
| 181 |
+
if bfloat_enable:
|
| 182 |
+
scaler = torch.amp.GradScaler()
|
| 183 |
+
|
| 184 |
+
# load existing checkpoint
|
| 185 |
+
# latest_path = os.path.join(checkpoint_dir, "latest.pth.tar")
|
| 186 |
+
# === Load checkpoint or start from a pretrained one ===
|
| 187 |
+
start_epoch, train_steps = load_checkpoint_if_available(
|
| 188 |
+
model, ema, opt, scaler if bfloat_enable else None, config, device, logger, args
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
# ~40% speedup but might leads to worse performance depending on pytorch version
|
| 192 |
+
if args.torch_compile:
|
| 193 |
+
model = torch.compile(model)
|
| 194 |
+
model = DDP(model, device_ids=[device])
|
| 195 |
+
diffusion = create_diffusion(timestep_respacing="") # default: 1000 steps, linear noise schedule
|
| 196 |
+
# ,predict_xstart=True
|
| 197 |
+
logger.info(f"AVCDiT Parameters: {sum(p.numel() for p in model.parameters()):,}")
|
| 198 |
+
|
| 199 |
+
train_dataset = []
|
| 200 |
+
test_dataset = []
|
| 201 |
+
|
| 202 |
+
for dataset_name in config["datasets"]:
|
| 203 |
+
data_config = config["datasets"][dataset_name]
|
| 204 |
+
|
| 205 |
+
for data_split_type in ["train", "test"]:
|
| 206 |
+
if data_split_type in data_config:
|
| 207 |
+
goals_per_obs = int(data_config["goals_per_obs"])
|
| 208 |
+
if data_split_type == 'test':
|
| 209 |
+
goals_per_obs = 4 # standardize testing
|
| 210 |
+
|
| 211 |
+
if "distance" in data_config:
|
| 212 |
+
min_dist_cat=data_config["distance"]["min_dist_cat"]
|
| 213 |
+
max_dist_cat=data_config["distance"]["max_dist_cat"]
|
| 214 |
+
else:
|
| 215 |
+
min_dist_cat=config["distance"]["min_dist_cat"]
|
| 216 |
+
max_dist_cat=config["distance"]["max_dist_cat"]
|
| 217 |
+
|
| 218 |
+
if "len_traj_pred" in data_config:
|
| 219 |
+
len_traj_pred=data_config["len_traj_pred"]
|
| 220 |
+
else:
|
| 221 |
+
len_traj_pred=config["len_traj_pred"]
|
| 222 |
+
|
| 223 |
+
dataset = TrainingDataset(
|
| 224 |
+
data_folder=data_config["data_folder"],
|
| 225 |
+
data_split_folder=data_config[data_split_type],
|
| 226 |
+
dataset_name=dataset_name,
|
| 227 |
+
image_size=config["image_size"],
|
| 228 |
+
min_dist_cat=min_dist_cat,
|
| 229 |
+
max_dist_cat=max_dist_cat,
|
| 230 |
+
len_traj_pred=len_traj_pred,
|
| 231 |
+
context_size=config["context_size"],
|
| 232 |
+
normalize=config["normalize"],
|
| 233 |
+
goals_per_obs=goals_per_obs,
|
| 234 |
+
transform=transform,
|
| 235 |
+
predefined_index=None,
|
| 236 |
+
traj_stride=1,
|
| 237 |
+
evaluate=(data_split_type=="test")
|
| 238 |
+
)
|
| 239 |
+
if data_split_type == "train":
|
| 240 |
+
train_dataset.append(dataset)
|
| 241 |
+
else:
|
| 242 |
+
test_dataset.append(dataset)
|
| 243 |
+
print(f"Dataset: {dataset_name} ({data_split_type}), size: {len(dataset)}")
|
| 244 |
+
|
| 245 |
+
# combine all the datasets from different robots
|
| 246 |
+
print(f"Combining {len(train_dataset)} datasets.")
|
| 247 |
+
train_dataset = ConcatDataset(train_dataset)
|
| 248 |
+
test_dataset = ConcatDataset(test_dataset)
|
| 249 |
+
|
| 250 |
+
sampler = DistributedSampler(
|
| 251 |
+
train_dataset,
|
| 252 |
+
num_replicas=dist.get_world_size(),
|
| 253 |
+
rank=rank,
|
| 254 |
+
shuffle=True,
|
| 255 |
+
seed=args.global_seed
|
| 256 |
+
)
|
| 257 |
+
loader = DataLoader(
|
| 258 |
+
train_dataset,
|
| 259 |
+
batch_size=config['batch_size'],
|
| 260 |
+
shuffle=False,
|
| 261 |
+
sampler=sampler,
|
| 262 |
+
num_workers=config['num_workers'],
|
| 263 |
+
pin_memory=True,
|
| 264 |
+
drop_last=True,
|
| 265 |
+
persistent_workers=True
|
| 266 |
+
)
|
| 267 |
+
logger.info(f"Dataset contains {len(train_dataset):,} images")
|
| 268 |
+
|
| 269 |
+
# Prepare models for training:
|
| 270 |
+
model.train() # important! This enables embedding dropout for classifier-free guidance
|
| 271 |
+
ema.eval() # EMA model should always be in eval mode
|
| 272 |
+
|
| 273 |
+
# Variables for monitoring/logging purposes:
|
| 274 |
+
log_steps = 0
|
| 275 |
+
running_loss = 0
|
| 276 |
+
start_time = time()
|
| 277 |
+
|
| 278 |
+
logger.info(f"Training for {args.epochs} epochs...")
|
| 279 |
+
for epoch in range(start_epoch, args.epochs):
|
| 280 |
+
sampler.set_epoch(epoch)
|
| 281 |
+
steps_per_epoch = len(loader)
|
| 282 |
+
if rank == 0:
|
| 283 |
+
logger.info(f"Epoch {epoch} contains {steps_per_epoch} steps.")
|
| 284 |
+
logger.info(f"Beginning epoch {epoch}...")
|
| 285 |
+
|
| 286 |
+
for x, _, y, _, rel_t in loader:
|
| 287 |
+
x = x.to(device, non_blocking=True)
|
| 288 |
+
y = y.to(device, non_blocking=True)
|
| 289 |
+
rel_t = rel_t.to(device, non_blocking=True)
|
| 290 |
+
|
| 291 |
+
with torch.amp.autocast('cuda', enabled=bfloat_enable, dtype=torch.bfloat16):
|
| 292 |
+
with torch.no_grad():
|
| 293 |
+
# Map input images to latent space + normalize latents:
|
| 294 |
+
B, T = x.shape[:2]
|
| 295 |
+
x = x.flatten(0,1)
|
| 296 |
+
x = tokenizer.encode(x).latent_dist.sample().mul_(0.18215)
|
| 297 |
+
x = x.unflatten(0, (B, T))
|
| 298 |
+
|
| 299 |
+
num_goals = T - num_cond
|
| 300 |
+
x_start = x[:, num_cond:].flatten(0, 1)
|
| 301 |
+
x_cond = x[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x.shape[2], x.shape[3], x.shape[4]).flatten(0, 1)
|
| 302 |
+
y = y.flatten(0, 1)
|
| 303 |
+
rel_t = rel_t.flatten(0, 1)
|
| 304 |
+
|
| 305 |
+
t = torch.randint(0, diffusion.num_timesteps, (x_start.shape[0],), device=device)
|
| 306 |
+
model_kwargs = dict(y=y, x_cond=x_cond, rel_t=rel_t)
|
| 307 |
+
loss_dict = diffusion.training_losses(model, x_start, t, model_kwargs)
|
| 308 |
+
loss = loss_dict["loss"].mean()
|
| 309 |
+
|
| 310 |
+
if not bfloat_enable:
|
| 311 |
+
opt.zero_grad()
|
| 312 |
+
loss.backward()
|
| 313 |
+
opt.step()
|
| 314 |
+
else:
|
| 315 |
+
scaler.scale(loss).backward()
|
| 316 |
+
if config.get('grad_clip_val', 0) > 0:
|
| 317 |
+
scaler.unscale_(opt)
|
| 318 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config['grad_clip_val'])
|
| 319 |
+
scaler.step(opt)
|
| 320 |
+
scaler.update()
|
| 321 |
+
|
| 322 |
+
update_ema(ema, model.module)
|
| 323 |
+
|
| 324 |
+
# Log loss values:
|
| 325 |
+
running_loss += loss.detach().item()
|
| 326 |
+
log_steps += 1
|
| 327 |
+
train_steps += 1
|
| 328 |
+
if train_steps % args.log_every == 0:
|
| 329 |
+
# Measure training speed:
|
| 330 |
+
torch.cuda.synchronize()
|
| 331 |
+
end_time = time()
|
| 332 |
+
steps_per_sec = log_steps / (end_time - start_time)
|
| 333 |
+
samples_per_sec = dist.get_world_size()*x_cond.shape[0]*steps_per_sec
|
| 334 |
+
# Reduce loss history over all processes:
|
| 335 |
+
avg_loss = torch.tensor(running_loss / log_steps, device=device)
|
| 336 |
+
dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
|
| 337 |
+
avg_loss = avg_loss.item() / dist.get_world_size()
|
| 338 |
+
total_steps = len(loader) * args.epochs
|
| 339 |
+
progress_pct = train_steps / total_steps * 100
|
| 340 |
+
|
| 341 |
+
remaining_steps = total_steps - train_steps
|
| 342 |
+
eta_seconds = remaining_steps / steps_per_sec if steps_per_sec > 0 else 0
|
| 343 |
+
eta_hours = eta_seconds / 3600
|
| 344 |
+
|
| 345 |
+
logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}, Samples/Sec: {samples_per_sec:.2f}")
|
| 346 |
+
logger.info(f"Progress: {progress_pct:.2f}% | ETA: {eta_hours:.1f}h")
|
| 347 |
+
# Reset monitoring variables:
|
| 348 |
+
running_loss = 0
|
| 349 |
+
log_steps = 0
|
| 350 |
+
start_time = time()
|
| 351 |
+
|
| 352 |
+
# Save DiT checkpoint:
|
| 353 |
+
if train_steps % args.ckpt_every == 0 and train_steps > 0:
|
| 354 |
+
if rank == 0:
|
| 355 |
+
checkpoint = {
|
| 356 |
+
"model": model.module.state_dict(),
|
| 357 |
+
"ema": ema.state_dict(),
|
| 358 |
+
"opt": opt.state_dict(),
|
| 359 |
+
"args": args,
|
| 360 |
+
"epoch": epoch,
|
| 361 |
+
"train_steps": train_steps
|
| 362 |
+
}
|
| 363 |
+
if bfloat_enable:
|
| 364 |
+
checkpoint.update({"scaler": scaler.state_dict()})
|
| 365 |
+
checkpoint_path = f"{checkpoint_dir}/latest.pth.tar"
|
| 366 |
+
torch.save(checkpoint, checkpoint_path)
|
| 367 |
+
if train_steps % (10*args.ckpt_every) == 0 and train_steps > 0:
|
| 368 |
+
checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pth.tar"
|
| 369 |
+
torch.save(checkpoint, checkpoint_path)
|
| 370 |
+
logger.info(f"Saved checkpoint to {checkpoint_path}")
|
| 371 |
+
|
| 372 |
+
if train_steps % args.eval_every == 0 and train_steps > 0:
|
| 373 |
+
eval_start_time = time()
|
| 374 |
+
# validation / test set evaluation
|
| 375 |
+
save_dir = os.path.join(experiment_dir, str(train_steps))
|
| 376 |
+
sim_score_val = evaluate(ema, tokenizer, diffusion, test_dataset, rank, config["batch_size"], config["num_workers"], latent_size, device, save_dir, args.global_seed, bfloat_enable, num_cond)
|
| 377 |
+
dist.barrier()
|
| 378 |
+
eval_end_time = time()
|
| 379 |
+
eval_time = eval_end_time - eval_start_time
|
| 380 |
+
# logger.info(f"(step={train_steps:07d}) Val Perceptual Loss: {sim_score_val:.4f}, Train Perceptual Loss: {sim_score_train:.4f}, Eval Time: {eval_time:.2f}")
|
| 381 |
+
logger.info(f"(step={train_steps:07d}) Val Perceptual Loss: {sim_score_val:.4f}, Eval Time: {eval_time:.2f}")
|
| 382 |
+
|
| 383 |
+
model.eval()
|
| 384 |
+
logger.info("Done!")
|
| 385 |
+
cleanup()
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
@torch.no_grad
|
| 389 |
+
def evaluate(model, vae, diffusion, test_dataloaders, rank, batch_size, num_workers, latent_size, device, save_dir, seed, bfloat_enable, num_cond):
|
| 390 |
+
sampler = DistributedSampler(
|
| 391 |
+
test_dataloaders,
|
| 392 |
+
num_replicas=dist.get_world_size(),
|
| 393 |
+
rank=rank,
|
| 394 |
+
shuffle=True,
|
| 395 |
+
seed=seed
|
| 396 |
+
)
|
| 397 |
+
loader = DataLoader(
|
| 398 |
+
test_dataloaders,
|
| 399 |
+
batch_size=batch_size,
|
| 400 |
+
shuffle=False,
|
| 401 |
+
sampler=sampler,
|
| 402 |
+
num_workers=num_workers,
|
| 403 |
+
pin_memory=True,
|
| 404 |
+
drop_last=True
|
| 405 |
+
)
|
| 406 |
+
from dreamsim import dreamsim
|
| 407 |
+
eval_model, _ = dreamsim(pretrained=True)
|
| 408 |
+
score = torch.tensor(0.).to(device)
|
| 409 |
+
n_samples = torch.tensor(0).to(device)
|
| 410 |
+
|
| 411 |
+
# Run for 1 step
|
| 412 |
+
for x, _, y, _, rel_t, _ in loader:
|
| 413 |
+
x = x.to(device)
|
| 414 |
+
y = y.to(device)
|
| 415 |
+
rel_t = rel_t.to(device).flatten(0, 1)
|
| 416 |
+
with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
|
| 417 |
+
B, T = x.shape[:2]
|
| 418 |
+
num_goals = T - num_cond
|
| 419 |
+
samples = model_forward_wrapper_v((model, diffusion, vae), x, y, num_timesteps=None, latent_size=latent_size, device=device, num_cond=num_cond, num_goals=num_goals, rel_t=rel_t)
|
| 420 |
+
x_start_pixels = x[:, num_cond:].flatten(0, 1)
|
| 421 |
+
x_cond_pixels = x[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x.shape[2], x.shape[3], x.shape[4]).flatten(0, 1)
|
| 422 |
+
samples = samples * 0.5 + 0.5
|
| 423 |
+
x_start_pixels = x_start_pixels * 0.5 + 0.5
|
| 424 |
+
x_cond_pixels = x_cond_pixels * 0.5 + 0.5
|
| 425 |
+
res = eval_model(x_start_pixels, samples)
|
| 426 |
+
score += res.sum()
|
| 427 |
+
n_samples += len(res)
|
| 428 |
+
break
|
| 429 |
+
|
| 430 |
+
if rank == 0:
|
| 431 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 432 |
+
for i in range(min(samples.shape[0], 10)):
|
| 433 |
+
_, ax = plt.subplots(1,3,dpi=256)
|
| 434 |
+
ax[0].imshow((x_cond_pixels[i, -1].permute(1,2,0).cpu().numpy()*255).astype('uint8'))
|
| 435 |
+
ax[1].imshow((x_start_pixels[i].permute(1,2,0).cpu().numpy()*255).astype('uint8'))
|
| 436 |
+
ax[2].imshow((samples[i].permute(1,2,0).cpu().float().numpy()*255).astype('uint8'))
|
| 437 |
+
plt.savefig(f'{save_dir}/{i}.png')
|
| 438 |
+
plt.close()
|
| 439 |
+
|
| 440 |
+
dist.all_reduce(score)
|
| 441 |
+
dist.all_reduce(n_samples)
|
| 442 |
+
sim_score = score/n_samples
|
| 443 |
+
return sim_score
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
def get_args_parser():
|
| 447 |
+
parser = argparse.ArgumentParser()
|
| 448 |
+
parser.add_argument("--config", type=str, required=True)
|
| 449 |
+
parser.add_argument("--epochs", type=int, default=300)
|
| 450 |
+
# parser.add_argument("--global-batch-size", type=int, default=256)
|
| 451 |
+
parser.add_argument("--global-seed", type=int, default=0)
|
| 452 |
+
parser.add_argument("--log-every", type=int, default=100)
|
| 453 |
+
parser.add_argument("--ckpt-every", type=int, default=2000)
|
| 454 |
+
parser.add_argument("--eval-every", type=int, default=5000)
|
| 455 |
+
parser.add_argument("--bfloat16", type=int, default=1)
|
| 456 |
+
parser.add_argument("--torch-compile", type=int, default=1)
|
| 457 |
+
parser.add_argument("--restart-from-checkpoint", type=int, default=0,
|
| 458 |
+
help="If 1, only load model weights and reset epoch/step to zero (cold start)")
|
| 459 |
+
return parser
|
| 460 |
+
|
| 461 |
+
if __name__ == "__main__":
|
| 462 |
+
args = get_args_parser().parse_args()
|
| 463 |
+
main(args)
|
train_avwm_stage2.py
ADDED
|
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# References:
|
| 8 |
+
# NoMaD, GNM, ViNT: https://github.com/robodhruv/visualnav-transformer
|
| 9 |
+
# --------------------------------------------------------
|
| 10 |
+
|
| 11 |
+
from inference_avwm import model_forward_wrapper_a
|
| 12 |
+
import torch
|
| 13 |
+
# the first flag below was False when we tested this script but True makes A100 training a lot faster:
|
| 14 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 15 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 16 |
+
|
| 17 |
+
import matplotlib
|
| 18 |
+
matplotlib.use('Agg')
|
| 19 |
+
from collections import OrderedDict
|
| 20 |
+
from copy import deepcopy
|
| 21 |
+
from time import time
|
| 22 |
+
import argparse
|
| 23 |
+
import logging
|
| 24 |
+
import os
|
| 25 |
+
import matplotlib.pyplot as plt
|
| 26 |
+
import yaml
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
import torch.distributed as dist
|
| 30 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 31 |
+
from torch.utils.data import DataLoader, ConcatDataset
|
| 32 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 33 |
+
from diffusers.models import AutoencoderKL
|
| 34 |
+
|
| 35 |
+
from distributed import init_distributed
|
| 36 |
+
from models import AVCDiT_models
|
| 37 |
+
from diffusion import create_diffusion
|
| 38 |
+
from datasets import TrainingDataset
|
| 39 |
+
from misc import transform
|
| 40 |
+
from soundstream import SoundStream
|
| 41 |
+
# from audiovae import BinauralSeqTokenCodec
|
| 42 |
+
import torchaudio
|
| 43 |
+
from eval_audio import build_mel_transform, mel_cosine_stereo, drms_avg_db_stereo, save_ref_hat_spectrogram_panel
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def load_checkpoint_if_available(model, ema, opt, scaler, config, device, logger, args):
|
| 47 |
+
start_epoch = 0
|
| 48 |
+
train_steps = 0
|
| 49 |
+
latest_path = os.path.join(config['results_dir'], config['run_name'], "checkpoints", "latest.pth.tar")
|
| 50 |
+
if os.path.isfile(latest_path) or config.get('from_checkpoint', 0):
|
| 51 |
+
latest_path = latest_path if os.path.isfile(latest_path) else config.get('from_checkpoint', 0)
|
| 52 |
+
print("Loading model from ", latest_path)
|
| 53 |
+
checkpoint = torch.load(latest_path, map_location=f"cuda:{device}", weights_only=False)
|
| 54 |
+
ema_ckp = {k.replace('_orig_mod.', ''): v for k, v in checkpoint["ema"].items()}
|
| 55 |
+
remapped = {}
|
| 56 |
+
for k, v in ema_ckp.items():
|
| 57 |
+
new_k = k
|
| 58 |
+
if new_k.startswith("blocks.") and ".mlp_v." in new_k:
|
| 59 |
+
new_k = new_k.replace(".mlp_v.", ".mlp_a.", 1)
|
| 60 |
+
remapped[new_k] = v
|
| 61 |
+
ema_ckp = remapped
|
| 62 |
+
model_state = model.state_dict()
|
| 63 |
+
load_info = model.load_state_dict(ema_ckp, strict=False)
|
| 64 |
+
|
| 65 |
+
print("Model weights loaded.")
|
| 66 |
+
ema.load_state_dict(ema_ckp, strict=False)
|
| 67 |
+
print("EMA weights loaded.")
|
| 68 |
+
if args.restart_from_checkpoint:
|
| 69 |
+
logger.info("Restarting training: epoch and step counters set to 0.")
|
| 70 |
+
else:
|
| 71 |
+
try:
|
| 72 |
+
if "opt" in checkpoint:
|
| 73 |
+
opt_ckp = {k.replace('_orig_mod.', ''): v for k, v in checkpoint["opt"].items()}
|
| 74 |
+
opt.load_state_dict(opt_ckp)
|
| 75 |
+
print("Optimizer state loaded.")
|
| 76 |
+
if "scaler" in checkpoint and scaler is not None:
|
| 77 |
+
scaler.load_state_dict(checkpoint["scaler"])
|
| 78 |
+
print("GradScaler state loaded.")
|
| 79 |
+
except ValueError as e:
|
| 80 |
+
print(f"[WARN] Skip loading opt and scaler")
|
| 81 |
+
if "epoch" in checkpoint:
|
| 82 |
+
start_epoch = checkpoint["epoch"] + 1
|
| 83 |
+
if "train_steps" in checkpoint:
|
| 84 |
+
train_steps = checkpoint["train_steps"]
|
| 85 |
+
logger.info(f"Resuming from epoch {start_epoch}, step {train_steps}")
|
| 86 |
+
|
| 87 |
+
return start_epoch, train_steps
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@torch.no_grad()
|
| 91 |
+
def update_ema(ema_model, model, decay=0.9999):
|
| 92 |
+
"""
|
| 93 |
+
Step the EMA model towards the current model.
|
| 94 |
+
"""
|
| 95 |
+
ema_params = OrderedDict(ema_model.named_parameters())
|
| 96 |
+
model_params = OrderedDict(model.named_parameters())
|
| 97 |
+
|
| 98 |
+
for name, param in model_params.items():
|
| 99 |
+
name = name.replace('_orig_mod.', '')
|
| 100 |
+
ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def requires_grad(model, flag=True):
|
| 104 |
+
"""
|
| 105 |
+
Set requires_grad flag for all parameters in a model.
|
| 106 |
+
"""
|
| 107 |
+
for p in model.parameters():
|
| 108 |
+
p.requires_grad = flag
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def cleanup():
|
| 112 |
+
"""
|
| 113 |
+
End DDP training.
|
| 114 |
+
"""
|
| 115 |
+
dist.destroy_process_group()
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def create_logger(logging_dir):
|
| 119 |
+
"""
|
| 120 |
+
Create a logger that writes to a log file and stdout.
|
| 121 |
+
"""
|
| 122 |
+
if dist.get_rank() == 0: # real logger
|
| 123 |
+
logging.basicConfig(
|
| 124 |
+
level=logging.INFO,
|
| 125 |
+
format='[\033[34m%(asctime)s\033[0m] %(message)s',
|
| 126 |
+
datefmt='%Y-%m-%d %H:%M:%S',
|
| 127 |
+
handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
|
| 128 |
+
)
|
| 129 |
+
logger = logging.getLogger(__name__)
|
| 130 |
+
else: # dummy logger (does nothing)
|
| 131 |
+
logger = logging.getLogger(__name__)
|
| 132 |
+
logger.addHandler(logging.NullHandler())
|
| 133 |
+
return logger
|
| 134 |
+
|
| 135 |
+
#################################################################################
|
| 136 |
+
# Training Loop #
|
| 137 |
+
#################################################################################
|
| 138 |
+
|
| 139 |
+
def main(args):
|
| 140 |
+
"""
|
| 141 |
+
Trains a new AVCDiT model.
|
| 142 |
+
"""
|
| 143 |
+
assert torch.cuda.is_available(), "Training currently requires at least one GPU."
|
| 144 |
+
|
| 145 |
+
# Setup DDP:
|
| 146 |
+
_, rank, device, _ = init_distributed()
|
| 147 |
+
# rank = dist.get_rank()
|
| 148 |
+
seed = args.global_seed * dist.get_world_size() + rank
|
| 149 |
+
torch.manual_seed(seed)
|
| 150 |
+
print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
|
| 151 |
+
with open("config/eval_config.yaml", "r") as f:
|
| 152 |
+
default_config = yaml.safe_load(f)
|
| 153 |
+
config = default_config
|
| 154 |
+
|
| 155 |
+
with open(args.config, "r") as f:
|
| 156 |
+
user_config = yaml.safe_load(f)
|
| 157 |
+
config.update(user_config)
|
| 158 |
+
|
| 159 |
+
# Setup an experiment folder:
|
| 160 |
+
os.makedirs(config['results_dir'], exist_ok=True) # Make results folder (holds all experiment subfolders)
|
| 161 |
+
experiment_dir = f"{config['results_dir']}/{config['run_name']}" # Create an experiment folder
|
| 162 |
+
checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints
|
| 163 |
+
if rank == 0:
|
| 164 |
+
os.makedirs(checkpoint_dir, exist_ok=True)
|
| 165 |
+
logger = create_logger(experiment_dir)
|
| 166 |
+
logger.info(f"Experiment directory created at {experiment_dir}")
|
| 167 |
+
else:
|
| 168 |
+
logger = create_logger(None)
|
| 169 |
+
|
| 170 |
+
# Create model:
|
| 171 |
+
tokenizer = SoundStream(C=32, D=16, n_q=8, codebook_size=1024).to(device)
|
| 172 |
+
tokenizer_path=config["tokenizer_a_path"]
|
| 173 |
+
checkpoint = torch.load(tokenizer_path, map_location=f"cuda:{device}")
|
| 174 |
+
tokenizer.load_state_dict(checkpoint["model_state"])
|
| 175 |
+
tokenizer.eval()
|
| 176 |
+
|
| 177 |
+
latent_size = config['image_size'] // 8
|
| 178 |
+
|
| 179 |
+
assert config['image_size'] % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
|
| 180 |
+
num_cond = config['context_size']
|
| 181 |
+
model = AVCDiT_models[config['model']](context_size=num_cond, input_size=latent_size, in_channels=4, mode="a").to(device)
|
| 182 |
+
|
| 183 |
+
ema = deepcopy(model).to(device)
|
| 184 |
+
requires_grad(ema, False)
|
| 185 |
+
|
| 186 |
+
lr = float(config.get('lr', 1e-4))
|
| 187 |
+
for param in model.parameters():
|
| 188 |
+
param.requires_grad = False
|
| 189 |
+
for param in model.x_embedder_a.parameters():
|
| 190 |
+
param.requires_grad = True
|
| 191 |
+
model.pos_embed_a_cond.requires_grad = True
|
| 192 |
+
model.pos_embed_a_pred.requires_grad = True
|
| 193 |
+
for param in model.final_layer_a.parameters():
|
| 194 |
+
param.requires_grad = True
|
| 195 |
+
for i, block in enumerate(model.blocks):
|
| 196 |
+
for name, param in block.named_parameters():
|
| 197 |
+
if name.startswith("mlp."):
|
| 198 |
+
param.requires_grad = True
|
| 199 |
+
|
| 200 |
+
opt = torch.optim.AdamW(
|
| 201 |
+
filter(lambda p: p.requires_grad, model.parameters()),
|
| 202 |
+
lr=lr, weight_decay=0
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
bfloat_enable = bool(hasattr(args, 'bfloat16') and args.bfloat16)
|
| 206 |
+
if bfloat_enable:
|
| 207 |
+
scaler = torch.amp.GradScaler()
|
| 208 |
+
|
| 209 |
+
start_epoch, train_steps = load_checkpoint_if_available(
|
| 210 |
+
model, ema, opt, scaler if bfloat_enable else None, config, device, logger, args
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
print("Trainable Parameters: ")
|
| 214 |
+
for name, param in model.named_parameters():
|
| 215 |
+
if param.requires_grad:
|
| 216 |
+
print(f" - {name}: {tuple(param.shape)}")
|
| 217 |
+
# =======================================================================================#
|
| 218 |
+
|
| 219 |
+
# ~40% speedup but might leads to worse performance depending on pytorch version
|
| 220 |
+
if args.torch_compile:
|
| 221 |
+
model = torch.compile(model)
|
| 222 |
+
model = DDP(model, device_ids=[device])
|
| 223 |
+
diffusion = create_diffusion(timestep_respacing="") # default: 1000 steps, linear noise schedule
|
| 224 |
+
# ,predict_xstart=True
|
| 225 |
+
logger.info(f"AVCDiT Parameters: {sum(p.numel() for p in model.parameters()):,}")
|
| 226 |
+
|
| 227 |
+
train_dataset = []
|
| 228 |
+
test_dataset = []
|
| 229 |
+
|
| 230 |
+
for dataset_name in config["datasets"]:
|
| 231 |
+
data_config = config["datasets"][dataset_name]
|
| 232 |
+
|
| 233 |
+
for data_split_type in ["train", "test"]:
|
| 234 |
+
if data_split_type in data_config:
|
| 235 |
+
goals_per_obs = int(data_config["goals_per_obs"])
|
| 236 |
+
if data_split_type == 'test':
|
| 237 |
+
goals_per_obs = 4 # standardize testing
|
| 238 |
+
|
| 239 |
+
if "distance" in data_config:
|
| 240 |
+
min_dist_cat=data_config["distance"]["min_dist_cat"]
|
| 241 |
+
max_dist_cat=data_config["distance"]["max_dist_cat"]
|
| 242 |
+
else:
|
| 243 |
+
min_dist_cat=config["distance"]["min_dist_cat"]
|
| 244 |
+
max_dist_cat=config["distance"]["max_dist_cat"]
|
| 245 |
+
|
| 246 |
+
if "len_traj_pred" in data_config:
|
| 247 |
+
len_traj_pred=data_config["len_traj_pred"]
|
| 248 |
+
else:
|
| 249 |
+
len_traj_pred=config["len_traj_pred"]
|
| 250 |
+
|
| 251 |
+
dataset = TrainingDataset(
|
| 252 |
+
data_folder=data_config["data_folder"],
|
| 253 |
+
data_split_folder=data_config[data_split_type],
|
| 254 |
+
dataset_name=dataset_name,
|
| 255 |
+
image_size=config["image_size"],
|
| 256 |
+
min_dist_cat=min_dist_cat,
|
| 257 |
+
max_dist_cat=max_dist_cat,
|
| 258 |
+
len_traj_pred=len_traj_pred,
|
| 259 |
+
context_size=config["context_size"],
|
| 260 |
+
normalize=config["normalize"],
|
| 261 |
+
goals_per_obs=goals_per_obs,
|
| 262 |
+
transform=transform,
|
| 263 |
+
predefined_index=None,
|
| 264 |
+
traj_stride=1,
|
| 265 |
+
sample_rate=config["sample_rate"],
|
| 266 |
+
input_sr=config["input_sr"],
|
| 267 |
+
evaluate=(data_split_type=="test")
|
| 268 |
+
)
|
| 269 |
+
if data_split_type == "train":
|
| 270 |
+
train_dataset.append(dataset)
|
| 271 |
+
else:
|
| 272 |
+
test_dataset.append(dataset)
|
| 273 |
+
print(f"Dataset: {dataset_name} ({data_split_type}), size: {len(dataset)}")
|
| 274 |
+
|
| 275 |
+
# combine all the datasets from different robots
|
| 276 |
+
print(f"Combining {len(train_dataset)} datasets.")
|
| 277 |
+
train_dataset = ConcatDataset(train_dataset)
|
| 278 |
+
test_dataset = ConcatDataset(test_dataset)
|
| 279 |
+
|
| 280 |
+
sampler = DistributedSampler(
|
| 281 |
+
train_dataset,
|
| 282 |
+
num_replicas=dist.get_world_size(),
|
| 283 |
+
rank=rank,
|
| 284 |
+
shuffle=True,
|
| 285 |
+
seed=args.global_seed
|
| 286 |
+
)
|
| 287 |
+
loader = DataLoader(
|
| 288 |
+
train_dataset,
|
| 289 |
+
batch_size=config['batch_size'],
|
| 290 |
+
shuffle=False,
|
| 291 |
+
sampler=sampler,
|
| 292 |
+
num_workers=config['num_workers'],
|
| 293 |
+
pin_memory=True,
|
| 294 |
+
drop_last=True,
|
| 295 |
+
persistent_workers=True
|
| 296 |
+
)
|
| 297 |
+
logger.info(f"Dataset contains {len(train_dataset):,} images")
|
| 298 |
+
|
| 299 |
+
# Prepare models for training:
|
| 300 |
+
model.train() # important! This enables embedding dropout for classifier-free guidance
|
| 301 |
+
ema.eval() # EMA model should always be in eval mode
|
| 302 |
+
|
| 303 |
+
# Variables for monitoring/logging purposes:
|
| 304 |
+
log_steps = 0
|
| 305 |
+
running_loss = 0
|
| 306 |
+
start_time = time()
|
| 307 |
+
|
| 308 |
+
logger.info(f"Training for {args.epochs} epochs...")
|
| 309 |
+
for epoch in range(start_epoch, args.epochs):
|
| 310 |
+
sampler.set_epoch(epoch)
|
| 311 |
+
steps_per_epoch = len(loader)
|
| 312 |
+
if rank == 0:
|
| 313 |
+
logger.info(f"Epoch {epoch} contains {steps_per_epoch} steps.")
|
| 314 |
+
logger.info(f"Beginning epoch {epoch}...")
|
| 315 |
+
|
| 316 |
+
for _, x, y, diff, rel_t in loader:
|
| 317 |
+
x = x.to(device, non_blocking=True)
|
| 318 |
+
y = y.to(device, non_blocking=True)
|
| 319 |
+
diff = diff.to(device, non_blocking=True) # [REWARD]
|
| 320 |
+
rel_t = rel_t.to(device, non_blocking=True)
|
| 321 |
+
|
| 322 |
+
with torch.amp.autocast('cuda', enabled=bfloat_enable, dtype=torch.bfloat16):
|
| 323 |
+
with torch.no_grad():
|
| 324 |
+
# Map input images to latent space + normalize latents:
|
| 325 |
+
B, T = x.shape[:2]
|
| 326 |
+
x = x.flatten(0,1)
|
| 327 |
+
x = tokenizer.encoder(x)
|
| 328 |
+
x = x.unflatten(0, (B, T))
|
| 329 |
+
|
| 330 |
+
num_goals = T - num_cond
|
| 331 |
+
x_start = x[:, num_cond:].flatten(0, 1)
|
| 332 |
+
x_cond = x[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x.shape[2], x.shape[3]).flatten(0, 1)
|
| 333 |
+
y = y.flatten(0, 1)
|
| 334 |
+
rel_t = rel_t.flatten(0, 1)
|
| 335 |
+
|
| 336 |
+
diff = diff.flatten(0, 1)
|
| 337 |
+
diff_tok = diff.unsqueeze(1).expand(-1, 16, -1)
|
| 338 |
+
x_start = torch.cat([x_start, diff_tok], dim=2)
|
| 339 |
+
|
| 340 |
+
t = torch.randint(0, diffusion.num_timesteps, (x_start.shape[0],), device=device)
|
| 341 |
+
|
| 342 |
+
model_kwargs = dict(y=y, x_cond=x_cond, rel_t=rel_t)
|
| 343 |
+
loss_dict = diffusion.training_losses(model, x_start, t, model_kwargs)
|
| 344 |
+
loss = loss_dict["loss"].mean()
|
| 345 |
+
|
| 346 |
+
if not bfloat_enable:
|
| 347 |
+
opt.zero_grad()
|
| 348 |
+
loss.backward()
|
| 349 |
+
opt.step()
|
| 350 |
+
else:
|
| 351 |
+
scaler.scale(loss).backward()
|
| 352 |
+
if config.get('grad_clip_val', 0) > 0:
|
| 353 |
+
scaler.unscale_(opt)
|
| 354 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config['grad_clip_val'])
|
| 355 |
+
scaler.step(opt)
|
| 356 |
+
scaler.update()
|
| 357 |
+
|
| 358 |
+
update_ema(ema, model.module)
|
| 359 |
+
|
| 360 |
+
# Log loss values:
|
| 361 |
+
running_loss += loss.detach().item()
|
| 362 |
+
log_steps += 1
|
| 363 |
+
train_steps += 1
|
| 364 |
+
if train_steps % args.log_every == 0:
|
| 365 |
+
# Measure training speed:
|
| 366 |
+
torch.cuda.synchronize()
|
| 367 |
+
end_time = time()
|
| 368 |
+
steps_per_sec = log_steps / (end_time - start_time)
|
| 369 |
+
samples_per_sec = dist.get_world_size()*x_cond.shape[0]*steps_per_sec
|
| 370 |
+
# Reduce loss history over all processes:
|
| 371 |
+
avg_loss = torch.tensor(running_loss / log_steps, device=device)
|
| 372 |
+
dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
|
| 373 |
+
avg_loss = avg_loss.item() / dist.get_world_size()
|
| 374 |
+
total_steps = len(loader) * args.epochs
|
| 375 |
+
progress_pct = train_steps / total_steps * 100
|
| 376 |
+
|
| 377 |
+
remaining_steps = total_steps - train_steps
|
| 378 |
+
eta_seconds = remaining_steps / steps_per_sec if steps_per_sec > 0 else 0
|
| 379 |
+
eta_hours = eta_seconds / 3600
|
| 380 |
+
|
| 381 |
+
logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}, Samples/Sec: {samples_per_sec:.2f}")
|
| 382 |
+
logger.info(f"Progress: {progress_pct:.2f}% | ETA: {eta_hours:.1f}h")
|
| 383 |
+
running_loss = 0
|
| 384 |
+
log_steps = 0
|
| 385 |
+
start_time = time()
|
| 386 |
+
|
| 387 |
+
# Save DiT checkpoint:
|
| 388 |
+
if train_steps % args.ckpt_every == 0 and train_steps > 0:
|
| 389 |
+
if rank == 0:
|
| 390 |
+
checkpoint = {
|
| 391 |
+
"model": model.module.state_dict(),
|
| 392 |
+
"ema": ema.state_dict(),
|
| 393 |
+
"opt": opt.state_dict(),
|
| 394 |
+
"args": args,
|
| 395 |
+
"epoch": epoch,
|
| 396 |
+
"train_steps": train_steps
|
| 397 |
+
}
|
| 398 |
+
if bfloat_enable:
|
| 399 |
+
checkpoint.update({"scaler": scaler.state_dict()})
|
| 400 |
+
checkpoint_path = f"{checkpoint_dir}/latest.pth.tar"
|
| 401 |
+
torch.save(checkpoint, checkpoint_path)
|
| 402 |
+
if train_steps % (10*args.ckpt_every) == 0 and train_steps > 0:
|
| 403 |
+
checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pth.tar"
|
| 404 |
+
torch.save(checkpoint, checkpoint_path)
|
| 405 |
+
logger.info(f"Saved checkpoint to {checkpoint_path}")
|
| 406 |
+
|
| 407 |
+
if train_steps % args.eval_every == 0 and train_steps > 0:
|
| 408 |
+
eval_start_time = time()
|
| 409 |
+
save_dir = os.path.join(experiment_dir, str(train_steps))
|
| 410 |
+
save_dir_train = os.path.join(experiment_dir, f"{train_steps}_train")
|
| 411 |
+
evaluate(ema, tokenizer, diffusion, test_dataset, rank, config["batch_size"], config["num_workers"], latent_size, device, save_dir_train, args.global_seed, bfloat_enable, num_cond, config["sample_rate"], config["input_sr"], logger)
|
| 412 |
+
dist.barrier()
|
| 413 |
+
eval_end_time = time()
|
| 414 |
+
eval_time = eval_end_time - eval_start_time
|
| 415 |
+
|
| 416 |
+
model.eval() # important! This disables randomized embedding dropout
|
| 417 |
+
# do any sampling/FID calculation/etc. with ema (or model) in eval mode ...
|
| 418 |
+
|
| 419 |
+
logger.info("Done!")
|
| 420 |
+
cleanup()
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
def denormalize_dis(ndata: float, min_v=-20.0, max_v=20.0, scale=0.15):
|
| 424 |
+
n01 = (float(ndata) + 1.0) / 2.0
|
| 425 |
+
raw = n01 * (max_v - min_v) + min_v
|
| 426 |
+
return raw * scale
|
| 427 |
+
|
| 428 |
+
@torch.no_grad()
|
| 429 |
+
def evaluate(model, vae, diffusion, test_dataloaders, rank, batch_size, num_workers, latent_size, device, save_dir, seed, bfloat_enable, num_cond, sample_rate, input_sr, logger):
|
| 430 |
+
sampler = DistributedSampler(
|
| 431 |
+
test_dataloaders,
|
| 432 |
+
num_replicas=dist.get_world_size(),
|
| 433 |
+
rank=rank,
|
| 434 |
+
shuffle=True,
|
| 435 |
+
seed=seed
|
| 436 |
+
)
|
| 437 |
+
loader = DataLoader(
|
| 438 |
+
test_dataloaders,
|
| 439 |
+
batch_size=batch_size,
|
| 440 |
+
shuffle=False,
|
| 441 |
+
sampler=sampler,
|
| 442 |
+
num_workers=num_workers,
|
| 443 |
+
pin_memory=True,
|
| 444 |
+
drop_last=True
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
down_resampler = torchaudio.transforms.Resample(orig_freq=input_sr, new_freq=sample_rate, lowpass_filter_width=64).to(device, dtype=torch.bfloat16) # [RESAMPLE]
|
| 448 |
+
mel_tf = build_mel_transform(
|
| 449 |
+
sample_rate=sample_rate,
|
| 450 |
+
n_fft=1024, win_length=1024, hop_length=256,
|
| 451 |
+
n_mels=80, power=1.0,
|
| 452 |
+
device=device,
|
| 453 |
+
)
|
| 454 |
+
# Run for 1 step
|
| 455 |
+
for _, x, y, diff, rel_t, x_orig in loader:
|
| 456 |
+
x = x.to(device)
|
| 457 |
+
y = y.to(device)
|
| 458 |
+
diff = diff.to(device).flatten(0, 1) # [REWARD]
|
| 459 |
+
rel_t = rel_t.to(device).flatten(0, 1)
|
| 460 |
+
x_orig = x_orig.to(device)
|
| 461 |
+
with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
|
| 462 |
+
B, T = x.shape[:2]
|
| 463 |
+
num_goals = T - num_cond
|
| 464 |
+
samples, diff_pred = model_forward_wrapper_a((model, diffusion, vae), x, y, num_timesteps=None, latent_size=latent_size, device=device, num_cond=num_cond, num_goals=num_goals, rel_t=rel_t)
|
| 465 |
+
|
| 466 |
+
decoded = down_resampler(samples)
|
| 467 |
+
|
| 468 |
+
x_start_pixels = x_orig[:, num_cond:].flatten(0, 1)
|
| 469 |
+
x_cond_pixels = x_orig[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x_orig.shape[2], x_orig.shape[3]).flatten(0, 1)
|
| 470 |
+
break
|
| 471 |
+
|
| 472 |
+
if rank == 0:
|
| 473 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 474 |
+
|
| 475 |
+
num_save = min(samples.shape[0], 10)
|
| 476 |
+
|
| 477 |
+
if diff is not None: # [REWARD]
|
| 478 |
+
mae = torch.mean(torch.abs(diff_pred - diff))
|
| 479 |
+
logger.info(f"Distance Diff MAE = {mae.item():.6f}")
|
| 480 |
+
mel_cosine_ls=[]
|
| 481 |
+
for i in range(num_save):
|
| 482 |
+
mel_cos = mel_cosine_stereo(x_start_pixels[i], decoded[i], sample_rate=sample_rate, mel_tf=mel_tf)
|
| 483 |
+
mel_cosine_ls.append(mel_cos)
|
| 484 |
+
ok = save_ref_hat_spectrogram_panel(
|
| 485 |
+
x_start_pixels[i], decoded[i],
|
| 486 |
+
out_path=f"{save_dir}/{i}_spectrograms.png",
|
| 487 |
+
n_fft=512, hop_length=160, win_length=400, pool=4,
|
| 488 |
+
title="gt vs pred"
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
torchaudio.save(f"{save_dir}/{i}_gen.wav", decoded[i].cpu().to(torch.float32), sample_rate=sample_rate)
|
| 492 |
+
torchaudio.save(f"{save_dir}/{i}_gt.wav", x_start_pixels[i].cpu().to(torch.float32), sample_rate=sample_rate)
|
| 493 |
+
torchaudio.save(f"{save_dir}/{i}_cond.wav", x_cond_pixels[i, -1].cpu().to(torch.float32), sample_rate=sample_rate)
|
| 494 |
+
|
| 495 |
+
logger.info("the first 10 mel cosine: " + ", ".join(f"{v:.6f}" for v in mel_cosine_ls))
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
def get_args_parser():
|
| 499 |
+
parser = argparse.ArgumentParser()
|
| 500 |
+
parser.add_argument("--config", type=str, required=True)
|
| 501 |
+
parser.add_argument("--epochs", type=int, default=300)
|
| 502 |
+
parser.add_argument("--global-seed", type=int, default=0)
|
| 503 |
+
parser.add_argument("--log-every", type=int, default=100)
|
| 504 |
+
parser.add_argument("--ckpt-every", type=int, default=2000)
|
| 505 |
+
parser.add_argument("--eval-every", type=int, default=5000)
|
| 506 |
+
parser.add_argument("--bfloat16", type=int, default=1)
|
| 507 |
+
parser.add_argument("--torch-compile", type=int, default=1)
|
| 508 |
+
parser.add_argument("--restart-from-checkpoint", type=int, default=0,
|
| 509 |
+
help="If 1, only load model weights and reset epoch/step to zero (cold start)")
|
| 510 |
+
return parser
|
| 511 |
+
|
| 512 |
+
if __name__ == "__main__":
|
| 513 |
+
args = get_args_parser().parse_args()
|
| 514 |
+
main(args)
|
train_avwm_stage3.py
ADDED
|
@@ -0,0 +1,532 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# References:
|
| 8 |
+
# NoMaD, GNM, ViNT: https://github.com/robodhruv/visualnav-transformer
|
| 9 |
+
# --------------------------------------------------------
|
| 10 |
+
|
| 11 |
+
from inference_avwm import model_forward_wrapper_av
|
| 12 |
+
import torch
|
| 13 |
+
# the first flag below was False when we tested this script but True makes A100 training a lot faster:
|
| 14 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 15 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 16 |
+
|
| 17 |
+
import matplotlib
|
| 18 |
+
matplotlib.use('Agg')
|
| 19 |
+
from collections import OrderedDict
|
| 20 |
+
from copy import deepcopy
|
| 21 |
+
from time import time
|
| 22 |
+
import argparse
|
| 23 |
+
import logging
|
| 24 |
+
import os
|
| 25 |
+
import matplotlib.pyplot as plt
|
| 26 |
+
import yaml
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
import torch.distributed as dist
|
| 30 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 31 |
+
from torch.utils.data import DataLoader, ConcatDataset
|
| 32 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 33 |
+
from diffusers.models import AutoencoderKL
|
| 34 |
+
|
| 35 |
+
from distributed import init_distributed
|
| 36 |
+
from models import AVCDiT_models
|
| 37 |
+
from diffusion import create_diffusion
|
| 38 |
+
from datasets import TrainingDataset
|
| 39 |
+
from misc import transform
|
| 40 |
+
from soundstream import SoundStream
|
| 41 |
+
import torchaudio
|
| 42 |
+
from eval_audio import build_mel_transform, mel_cosine_stereo, drms_avg_db_stereo, save_ref_hat_spectrogram_panel
|
| 43 |
+
|
| 44 |
+
#################################################################################
|
| 45 |
+
# Training Helper Functions #
|
| 46 |
+
#################################################################################
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def load_checkpoint_if_available(model, ema, opt, scaler, config, device, logger, args):
|
| 50 |
+
start_epoch = 0
|
| 51 |
+
train_steps = 0
|
| 52 |
+
latest_path = os.path.join(config['results_dir'], config['run_name'], "checkpoints", "latest.pth.tar")
|
| 53 |
+
if os.path.isfile(latest_path) or config.get('from_checkpoint', 0):
|
| 54 |
+
latest_path = latest_path if os.path.isfile(latest_path) else config.get('from_checkpoint', 0)
|
| 55 |
+
print("Loading model from ", latest_path)
|
| 56 |
+
checkpoint = torch.load(latest_path, map_location=f"cuda:{device}", weights_only=False)
|
| 57 |
+
|
| 58 |
+
ema_ckp = {k.replace('_orig_mod.', ''): v for k, v in checkpoint["ema"].items()}
|
| 59 |
+
model.load_state_dict(ema_ckp, strict=False)
|
| 60 |
+
print("Model weights loaded.")
|
| 61 |
+
ema.load_state_dict(ema_ckp, strict=False)
|
| 62 |
+
print("EMA weights loaded.")
|
| 63 |
+
|
| 64 |
+
if args.restart_from_checkpoint:
|
| 65 |
+
logger.info("Restarting training: epoch and step counters set to 0.")
|
| 66 |
+
else:
|
| 67 |
+
if "opt" in checkpoint:
|
| 68 |
+
opt_ckp = {k.replace('_orig_mod.', ''): v for k, v in checkpoint["opt"].items()}
|
| 69 |
+
opt.load_state_dict(opt_ckp)
|
| 70 |
+
print("Optimizer state loaded.")
|
| 71 |
+
if "scaler" in checkpoint and scaler is not None:
|
| 72 |
+
scaler.load_state_dict(checkpoint["scaler"])
|
| 73 |
+
print("GradScaler state loaded.")
|
| 74 |
+
if "epoch" in checkpoint:
|
| 75 |
+
start_epoch = checkpoint["epoch"] + 1
|
| 76 |
+
if "train_steps" in checkpoint:
|
| 77 |
+
train_steps = checkpoint["train_steps"]
|
| 78 |
+
logger.info(f"Resuming from epoch {start_epoch}, step {train_steps}")
|
| 79 |
+
|
| 80 |
+
return start_epoch, train_steps
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@torch.no_grad()
|
| 84 |
+
def update_ema(ema_model, model, decay=0.9999):
|
| 85 |
+
"""
|
| 86 |
+
Step the EMA model towards the current model.
|
| 87 |
+
"""
|
| 88 |
+
ema_params = OrderedDict(ema_model.named_parameters())
|
| 89 |
+
model_params = OrderedDict(model.named_parameters())
|
| 90 |
+
|
| 91 |
+
for name, param in model_params.items():
|
| 92 |
+
name = name.replace('_orig_mod.', '')
|
| 93 |
+
ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def requires_grad(model, flag=True):
|
| 97 |
+
"""
|
| 98 |
+
Set requires_grad flag for all parameters in a model.
|
| 99 |
+
"""
|
| 100 |
+
for p in model.parameters():
|
| 101 |
+
p.requires_grad = flag
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def cleanup():
|
| 105 |
+
"""
|
| 106 |
+
End DDP training.
|
| 107 |
+
"""
|
| 108 |
+
dist.destroy_process_group()
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def create_logger(logging_dir):
|
| 112 |
+
"""
|
| 113 |
+
Create a logger that writes to a log file and stdout.
|
| 114 |
+
"""
|
| 115 |
+
if dist.get_rank() == 0: # real logger
|
| 116 |
+
logging.basicConfig(
|
| 117 |
+
level=logging.INFO,
|
| 118 |
+
format='[\033[34m%(asctime)s\033[0m] %(message)s',
|
| 119 |
+
datefmt='%Y-%m-%d %H:%M:%S',
|
| 120 |
+
handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
|
| 121 |
+
)
|
| 122 |
+
logger = logging.getLogger(__name__)
|
| 123 |
+
else: # dummy logger (does nothing)
|
| 124 |
+
logger = logging.getLogger(__name__)
|
| 125 |
+
logger.addHandler(logging.NullHandler())
|
| 126 |
+
return logger
|
| 127 |
+
|
| 128 |
+
#################################################################################
|
| 129 |
+
# Training Loop #
|
| 130 |
+
#################################################################################
|
| 131 |
+
|
| 132 |
+
def main(args):
|
| 133 |
+
"""
|
| 134 |
+
Trains a new AVCDiT model.
|
| 135 |
+
"""
|
| 136 |
+
assert torch.cuda.is_available(), "Training currently requires at least one GPU."
|
| 137 |
+
|
| 138 |
+
# Setup DDP:
|
| 139 |
+
_, rank, device, _ = init_distributed()
|
| 140 |
+
# rank = dist.get_rank()
|
| 141 |
+
seed = args.global_seed * dist.get_world_size() + rank
|
| 142 |
+
torch.manual_seed(seed)
|
| 143 |
+
print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
|
| 144 |
+
with open("config/eval_config.yaml", "r") as f:
|
| 145 |
+
default_config = yaml.safe_load(f)
|
| 146 |
+
config = default_config
|
| 147 |
+
|
| 148 |
+
with open(args.config, "r") as f:
|
| 149 |
+
user_config = yaml.safe_load(f)
|
| 150 |
+
config.update(user_config)
|
| 151 |
+
|
| 152 |
+
# Setup an experiment folder:
|
| 153 |
+
os.makedirs(config['results_dir'], exist_ok=True) # Make results folder (holds all experiment subfolders)
|
| 154 |
+
experiment_dir = f"{config['results_dir']}/{config['run_name']}" # Create an experiment folder
|
| 155 |
+
checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints
|
| 156 |
+
if rank == 0:
|
| 157 |
+
os.makedirs(checkpoint_dir, exist_ok=True)
|
| 158 |
+
logger = create_logger(experiment_dir)
|
| 159 |
+
logger.info(f"Experiment directory created at {experiment_dir}")
|
| 160 |
+
else:
|
| 161 |
+
logger = create_logger(None)
|
| 162 |
+
|
| 163 |
+
# Create model:
|
| 164 |
+
tokenizer_v = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-ema").to(device)
|
| 165 |
+
|
| 166 |
+
tokenizer_a = SoundStream(C=32, D=16, n_q=8, codebook_size=1024).to(device)
|
| 167 |
+
tokenizer_a_path=config["tokenizer_a_path"]
|
| 168 |
+
tokenizer_a_checkpoint = torch.load(tokenizer_a_path, map_location=f"cuda:{device}")
|
| 169 |
+
tokenizer_a.load_state_dict(tokenizer_a_checkpoint["model_state"])
|
| 170 |
+
tokenizer_a.eval()
|
| 171 |
+
|
| 172 |
+
latent_size = config['image_size'] // 8
|
| 173 |
+
|
| 174 |
+
assert config['image_size'] % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
|
| 175 |
+
num_cond = config['context_size']
|
| 176 |
+
model = AVCDiT_models[config['model']](context_size=num_cond, input_size=latent_size, in_channels=4).to(device)
|
| 177 |
+
|
| 178 |
+
ema = deepcopy(model).to(device) # Create an EMA of the model for use after training
|
| 179 |
+
requires_grad(ema, False)
|
| 180 |
+
|
| 181 |
+
# Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
|
| 182 |
+
lr = float(config.get('lr', 1e-4))
|
| 183 |
+
opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
bfloat_enable = bool(hasattr(args, 'bfloat16') and args.bfloat16)
|
| 187 |
+
if bfloat_enable:
|
| 188 |
+
scaler = torch.amp.GradScaler()
|
| 189 |
+
|
| 190 |
+
start_epoch, train_steps = load_checkpoint_if_available(
|
| 191 |
+
model, ema, opt, scaler if bfloat_enable else None, config, device, logger, args
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
# ~40% speedup but might leads to worse performance depending on pytorch version
|
| 195 |
+
if args.torch_compile:
|
| 196 |
+
model = torch.compile(model)
|
| 197 |
+
model = DDP(model, device_ids=[device])
|
| 198 |
+
diffusion = create_diffusion(timestep_respacing="", dual=True) # default: 1000 steps, linear noise schedule
|
| 199 |
+
# ,predict_xstart=True
|
| 200 |
+
logger.info(f"AVCDiT Parameters: {sum(p.numel() for p in model.parameters()):,}")
|
| 201 |
+
|
| 202 |
+
train_dataset = []
|
| 203 |
+
test_dataset = []
|
| 204 |
+
|
| 205 |
+
for dataset_name in config["datasets"]:
|
| 206 |
+
data_config = config["datasets"][dataset_name]
|
| 207 |
+
|
| 208 |
+
for data_split_type in ["train", "test"]:
|
| 209 |
+
if data_split_type in data_config:
|
| 210 |
+
goals_per_obs = int(data_config["goals_per_obs"])
|
| 211 |
+
if data_split_type == 'test':
|
| 212 |
+
goals_per_obs = 4 # standardize testing
|
| 213 |
+
|
| 214 |
+
if "distance" in data_config:
|
| 215 |
+
min_dist_cat=data_config["distance"]["min_dist_cat"]
|
| 216 |
+
max_dist_cat=data_config["distance"]["max_dist_cat"]
|
| 217 |
+
else:
|
| 218 |
+
min_dist_cat=config["distance"]["min_dist_cat"]
|
| 219 |
+
max_dist_cat=config["distance"]["max_dist_cat"]
|
| 220 |
+
|
| 221 |
+
if "len_traj_pred" in data_config:
|
| 222 |
+
len_traj_pred=data_config["len_traj_pred"]
|
| 223 |
+
else:
|
| 224 |
+
len_traj_pred=config["len_traj_pred"]
|
| 225 |
+
|
| 226 |
+
dataset = TrainingDataset(
|
| 227 |
+
data_folder=data_config["data_folder"],
|
| 228 |
+
data_split_folder=data_config[data_split_type],
|
| 229 |
+
dataset_name=dataset_name,
|
| 230 |
+
image_size=config["image_size"],
|
| 231 |
+
min_dist_cat=min_dist_cat,
|
| 232 |
+
max_dist_cat=max_dist_cat,
|
| 233 |
+
len_traj_pred=len_traj_pred,
|
| 234 |
+
context_size=config["context_size"],
|
| 235 |
+
normalize=config["normalize"],
|
| 236 |
+
goals_per_obs=goals_per_obs,
|
| 237 |
+
transform=transform,
|
| 238 |
+
predefined_index=None,
|
| 239 |
+
traj_stride=1,
|
| 240 |
+
sample_rate=config["sample_rate"],
|
| 241 |
+
# target_len=7840 #TODO
|
| 242 |
+
input_sr=config["input_sr"],
|
| 243 |
+
evaluate=(data_split_type=="test")
|
| 244 |
+
)
|
| 245 |
+
if data_split_type == "train":
|
| 246 |
+
train_dataset.append(dataset)
|
| 247 |
+
else:
|
| 248 |
+
test_dataset.append(dataset)
|
| 249 |
+
print(f"Dataset: {dataset_name} ({data_split_type}), size: {len(dataset)}")
|
| 250 |
+
|
| 251 |
+
# combine all the datasets from different robots
|
| 252 |
+
print(f"Combining {len(train_dataset)} datasets.")
|
| 253 |
+
train_dataset = ConcatDataset(train_dataset)
|
| 254 |
+
test_dataset = ConcatDataset(test_dataset)
|
| 255 |
+
|
| 256 |
+
sampler = DistributedSampler(
|
| 257 |
+
train_dataset,
|
| 258 |
+
num_replicas=dist.get_world_size(),
|
| 259 |
+
rank=rank,
|
| 260 |
+
shuffle=True,
|
| 261 |
+
seed=args.global_seed
|
| 262 |
+
)
|
| 263 |
+
loader = DataLoader(
|
| 264 |
+
train_dataset,
|
| 265 |
+
batch_size=config['batch_size'],
|
| 266 |
+
shuffle=False,
|
| 267 |
+
sampler=sampler,
|
| 268 |
+
num_workers=config['num_workers'],
|
| 269 |
+
pin_memory=True,
|
| 270 |
+
drop_last=True,
|
| 271 |
+
persistent_workers=True
|
| 272 |
+
)
|
| 273 |
+
logger.info(f"Dataset contains {len(train_dataset):,} images")
|
| 274 |
+
|
| 275 |
+
# Prepare models for training:
|
| 276 |
+
model.train() # important! This enables embedding dropout for classifier-free guidance
|
| 277 |
+
ema.eval() # EMA model should always be in eval mode
|
| 278 |
+
|
| 279 |
+
# Variables for monitoring/logging purposes:
|
| 280 |
+
log_steps = 0
|
| 281 |
+
running_loss = 0
|
| 282 |
+
start_time = time()
|
| 283 |
+
|
| 284 |
+
logger.info(f"Training for {args.epochs} epochs...")
|
| 285 |
+
for epoch in range(start_epoch, args.epochs):
|
| 286 |
+
sampler.set_epoch(epoch)
|
| 287 |
+
steps_per_epoch = len(loader)
|
| 288 |
+
if rank == 0:
|
| 289 |
+
logger.info(f"Epoch {epoch} contains {steps_per_epoch} steps.")
|
| 290 |
+
logger.info(f"Beginning epoch {epoch}...")
|
| 291 |
+
|
| 292 |
+
for x_v, x_a, y, diff, rel_t in loader:
|
| 293 |
+
x_v = x_v.to(device, non_blocking=True)
|
| 294 |
+
x_a = x_a.to(device, non_blocking=True)
|
| 295 |
+
y = y.to(device, non_blocking=True)
|
| 296 |
+
diff = diff.to(device, non_blocking=True)
|
| 297 |
+
rel_t = rel_t.to(device, non_blocking=True)
|
| 298 |
+
|
| 299 |
+
with torch.amp.autocast('cuda', enabled=bfloat_enable, dtype=torch.bfloat16):
|
| 300 |
+
with torch.no_grad():
|
| 301 |
+
# Map input images to latent space + normalize latents:
|
| 302 |
+
B, T = x_v.shape[:2]
|
| 303 |
+
#=== vision observation encoding
|
| 304 |
+
x_v = x_v.flatten(0,1)
|
| 305 |
+
x_v = tokenizer_v.encode(x_v).latent_dist.sample().mul_(0.18215)
|
| 306 |
+
x_v = x_v.unflatten(0, (B, T))
|
| 307 |
+
#=== audio observation encoding
|
| 308 |
+
x_a = x_a.flatten(0,1)
|
| 309 |
+
x_a = tokenizer_a.encoder(x_a)
|
| 310 |
+
x_a = x_a.unflatten(0, (B, T))
|
| 311 |
+
|
| 312 |
+
num_goals = T - num_cond
|
| 313 |
+
#=== split into target and condition
|
| 314 |
+
x_v_start = x_v[:, num_cond:].flatten(0, 1)
|
| 315 |
+
x_v_cond = x_v[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x_v.shape[2], x_v.shape[3], x_v.shape[4]).flatten(0, 1)
|
| 316 |
+
x_a_start = x_a[:, num_cond:].flatten(0, 1)
|
| 317 |
+
x_a_cond = x_a[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x_a.shape[2], x_a.shape[3]).flatten(0, 1)
|
| 318 |
+
#===
|
| 319 |
+
y = y.flatten(0, 1)
|
| 320 |
+
rel_t = rel_t.flatten(0, 1)
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
diff = diff.flatten(0, 1) # [N, 1]
|
| 325 |
+
diff_tok = diff.unsqueeze(1).expand(-1, 16, -1) # [N, 64, 1]
|
| 326 |
+
x_a_start = torch.cat([x_a_start, diff_tok], dim=2) # [N, 64, 181]
|
| 327 |
+
|
| 328 |
+
t = torch.randint(0, diffusion.num_timesteps, (x_v_start.shape[0],), device=device)
|
| 329 |
+
model_kwargs = dict(y=y, x_v_cond=x_v_cond, x_a_cond=x_a_cond, rel_t=rel_t)
|
| 330 |
+
loss_dict = diffusion.training_losses(model, x_v_start, x_a_start, t, model_kwargs)
|
| 331 |
+
loss = loss_dict["loss"].mean()
|
| 332 |
+
|
| 333 |
+
if not bfloat_enable:
|
| 334 |
+
opt.zero_grad()
|
| 335 |
+
loss.backward()
|
| 336 |
+
opt.step()
|
| 337 |
+
else:
|
| 338 |
+
scaler.scale(loss).backward()
|
| 339 |
+
if config.get('grad_clip_val', 0) > 0:
|
| 340 |
+
scaler.unscale_(opt)
|
| 341 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config['grad_clip_val'])
|
| 342 |
+
scaler.step(opt)
|
| 343 |
+
scaler.update()
|
| 344 |
+
|
| 345 |
+
update_ema(ema, model.module)
|
| 346 |
+
|
| 347 |
+
# Log loss values:
|
| 348 |
+
running_loss += loss.detach().item()
|
| 349 |
+
log_steps += 1
|
| 350 |
+
train_steps += 1
|
| 351 |
+
if train_steps % args.log_every == 0:
|
| 352 |
+
# Measure training speed:
|
| 353 |
+
torch.cuda.synchronize()
|
| 354 |
+
end_time = time()
|
| 355 |
+
steps_per_sec = log_steps / (end_time - start_time)
|
| 356 |
+
samples_per_sec = dist.get_world_size()*x_v_cond.shape[0]*steps_per_sec
|
| 357 |
+
# Reduce loss history over all processes:
|
| 358 |
+
avg_loss = torch.tensor(running_loss / log_steps, device=device)
|
| 359 |
+
dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
|
| 360 |
+
avg_loss = avg_loss.item() / dist.get_world_size()
|
| 361 |
+
total_steps = len(loader) * args.epochs
|
| 362 |
+
progress_pct = train_steps / total_steps * 100
|
| 363 |
+
|
| 364 |
+
remaining_steps = total_steps - train_steps
|
| 365 |
+
eta_seconds = remaining_steps / steps_per_sec if steps_per_sec > 0 else 0
|
| 366 |
+
eta_hours = eta_seconds / 3600
|
| 367 |
+
|
| 368 |
+
logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}, Samples/Sec: {samples_per_sec:.2f}")
|
| 369 |
+
logger.info(f"Progress: {progress_pct:.2f}% | ETA: {eta_hours:.1f}h")
|
| 370 |
+
# Reset monitoring variables:
|
| 371 |
+
running_loss = 0
|
| 372 |
+
log_steps = 0
|
| 373 |
+
start_time = time()
|
| 374 |
+
|
| 375 |
+
# Save DiT checkpoint:
|
| 376 |
+
if train_steps % args.ckpt_every == 0 and train_steps > 0:
|
| 377 |
+
if rank == 0:
|
| 378 |
+
checkpoint = {
|
| 379 |
+
"model": model.module.state_dict(),
|
| 380 |
+
"ema": ema.state_dict(),
|
| 381 |
+
"opt": opt.state_dict(),
|
| 382 |
+
"args": args,
|
| 383 |
+
"epoch": epoch,
|
| 384 |
+
"train_steps": train_steps
|
| 385 |
+
}
|
| 386 |
+
if bfloat_enable:
|
| 387 |
+
checkpoint.update({"scaler": scaler.state_dict()})
|
| 388 |
+
checkpoint_path = f"{checkpoint_dir}/latest.pth.tar"
|
| 389 |
+
torch.save(checkpoint, checkpoint_path)
|
| 390 |
+
if train_steps % (10*args.ckpt_every) == 0 and train_steps > 0:
|
| 391 |
+
checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pth.tar"
|
| 392 |
+
torch.save(checkpoint, checkpoint_path)
|
| 393 |
+
logger.info(f"Saved checkpoint to {checkpoint_path}")
|
| 394 |
+
|
| 395 |
+
if train_steps % args.eval_every == 0 and train_steps > 0:
|
| 396 |
+
eval_start_time = time()
|
| 397 |
+
# validation / test set evaluation
|
| 398 |
+
save_dir = os.path.join(experiment_dir, str(train_steps))
|
| 399 |
+
sim_score_val = evaluate(ema, tokenizer_v, tokenizer_a, diffusion, test_dataset, rank, config["batch_size"], config["num_workers"], latent_size, device, save_dir, args.global_seed, bfloat_enable, num_cond, config["sample_rate"], config["input_sr"], logger)
|
| 400 |
+
dist.barrier()
|
| 401 |
+
eval_end_time = time()
|
| 402 |
+
eval_time = eval_end_time - eval_start_time
|
| 403 |
+
# logger.info(f"(step={train_steps:07d}) Val Perceptual Loss: {sim_score_val:.4f}, Train Perceptual Loss: {sim_score_train:.4f}, Eval Time: {eval_time:.2f}")
|
| 404 |
+
logger.info(f"(step={train_steps:07d}) Val Perceptual Loss: {sim_score_val:.4f}, Eval Time: {eval_time:.2f}")
|
| 405 |
+
|
| 406 |
+
model.eval() # important! This disables randomized embedding dropout
|
| 407 |
+
# do any sampling/FID calculation/etc. with ema (or model) in eval mode ...
|
| 408 |
+
|
| 409 |
+
logger.info("Done!")
|
| 410 |
+
cleanup()
|
| 411 |
+
|
| 412 |
+
def denormalize_dis(ndata: float, min_v=-20.0, max_v=20.0, scale=0.15):
|
| 413 |
+
n01 = (float(ndata) + 1.0) / 2.0
|
| 414 |
+
raw = n01 * (max_v - min_v) + min_v
|
| 415 |
+
return raw * scale
|
| 416 |
+
|
| 417 |
+
@torch.no_grad
|
| 418 |
+
def evaluate(model, vae, sstream, diffusion, test_dataloaders, rank, batch_size, num_workers, latent_size, device, save_dir, seed, bfloat_enable, num_cond, sample_rate, input_sr, logger):
|
| 419 |
+
sampler = DistributedSampler(
|
| 420 |
+
test_dataloaders,
|
| 421 |
+
num_replicas=dist.get_world_size(),
|
| 422 |
+
rank=rank,
|
| 423 |
+
shuffle=True,
|
| 424 |
+
seed=seed
|
| 425 |
+
)
|
| 426 |
+
loader = DataLoader(
|
| 427 |
+
test_dataloaders,
|
| 428 |
+
batch_size=batch_size,
|
| 429 |
+
shuffle=False,
|
| 430 |
+
sampler=sampler,
|
| 431 |
+
num_workers=num_workers,
|
| 432 |
+
pin_memory=True,
|
| 433 |
+
drop_last=True
|
| 434 |
+
)
|
| 435 |
+
from dreamsim import dreamsim
|
| 436 |
+
eval_model, _ = dreamsim(pretrained=True)
|
| 437 |
+
score = torch.tensor(0.).to(device)
|
| 438 |
+
n_samples = torch.tensor(0).to(device)
|
| 439 |
+
|
| 440 |
+
down_resampler = torchaudio.transforms.Resample(orig_freq=input_sr, new_freq=sample_rate, lowpass_filter_width=64).to(device, dtype=torch.bfloat16)
|
| 441 |
+
mel_tf = build_mel_transform(
|
| 442 |
+
sample_rate=sample_rate,
|
| 443 |
+
n_fft=1024, win_length=1024, hop_length=256,
|
| 444 |
+
n_mels=80, power=1.0,
|
| 445 |
+
device=device, # or ref.device
|
| 446 |
+
)
|
| 447 |
+
# Run for 1 step
|
| 448 |
+
for x_v, x_a, y, diff, rel_t, x_a_orig in loader:
|
| 449 |
+
x_v = x_v.to(device)
|
| 450 |
+
x_a = x_a.to(device)
|
| 451 |
+
x_a_orig = x_a_orig.to(device)
|
| 452 |
+
y = y.to(device)
|
| 453 |
+
diff = diff.to(device).flatten(0, 1)
|
| 454 |
+
rel_t = rel_t.to(device).flatten(0, 1)
|
| 455 |
+
with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
|
| 456 |
+
B, T = x_v.shape[:2]
|
| 457 |
+
num_goals = T - num_cond
|
| 458 |
+
samples_v, samples_a, diff_pred = model_forward_wrapper_av((model, diffusion, vae, sstream), (x_v, x_a), y, num_timesteps=None, latent_size=latent_size, device=device, num_cond=num_cond, num_goals=num_goals, rel_t=rel_t)
|
| 459 |
+
|
| 460 |
+
samples_a = down_resampler(samples_a) #
|
| 461 |
+
|
| 462 |
+
x_start_pixels = x_v[:, num_cond:].flatten(0, 1)
|
| 463 |
+
x_cond_pixels = x_v[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x_v.shape[2], x_v.shape[3], x_v.shape[4]).flatten(0, 1)
|
| 464 |
+
samples_v = samples_v * 0.5 + 0.5
|
| 465 |
+
x_start_pixels = x_start_pixels * 0.5 + 0.5
|
| 466 |
+
x_cond_pixels = x_cond_pixels * 0.5 + 0.5
|
| 467 |
+
res = eval_model(x_start_pixels, samples_v)
|
| 468 |
+
score += res.sum()
|
| 469 |
+
n_samples += len(res)
|
| 470 |
+
|
| 471 |
+
# x_start_audio = x_a[:, num_cond:].flatten(0, 1)
|
| 472 |
+
# x_cond_audio = x_a[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x_a.shape[2], x_a.shape[3]).flatten(0, 1)
|
| 473 |
+
x_start_audio = x_a_orig[:, num_cond:].flatten(0, 1)
|
| 474 |
+
x_cond_audio = x_a_orig[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x_a_orig.shape[2], x_a_orig.shape[3]).flatten(0, 1)
|
| 475 |
+
break
|
| 476 |
+
|
| 477 |
+
if rank == 0:
|
| 478 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 479 |
+
|
| 480 |
+
if diff is not None:
|
| 481 |
+
mae = torch.mean(torch.abs(diff_pred - diff))
|
| 482 |
+
logger.info(f"Distance Diff MAE = {mae.item():.6f}")
|
| 483 |
+
|
| 484 |
+
mel_cosine_ls=[]
|
| 485 |
+
for i in range(min(samples_v.shape[0], 10)):
|
| 486 |
+
_, ax = plt.subplots(1,3,dpi=256)
|
| 487 |
+
ax[0].imshow((x_cond_pixels[i, -1].permute(1,2,0).cpu().numpy()*255).astype('uint8'))
|
| 488 |
+
ax[1].imshow((x_start_pixels[i].permute(1,2,0).cpu().numpy()*255).astype('uint8'))
|
| 489 |
+
ax[2].imshow((samples_v[i].permute(1,2,0).cpu().float().numpy()*255).astype('uint8'))
|
| 490 |
+
plt.savefig(f'{save_dir}/{i}.png')
|
| 491 |
+
plt.close()
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
mel_cos = mel_cosine_stereo(x_start_audio[i], samples_a[i], sample_rate=sample_rate, mel_tf=mel_tf)
|
| 495 |
+
mel_cosine_ls.append(mel_cos)
|
| 496 |
+
ok = save_ref_hat_spectrogram_panel(
|
| 497 |
+
x_start_audio[i], samples_a[i],
|
| 498 |
+
out_path=f"{save_dir}/{i}_spectrograms.png",
|
| 499 |
+
n_fft=512, hop_length=160, win_length=400, pool=4,
|
| 500 |
+
title="gt vs pred"
|
| 501 |
+
)
|
| 502 |
+
|
| 503 |
+
# sr = int(16000 * 7840 / 2400) #TODO
|
| 504 |
+
torchaudio.save(f"{save_dir}/{i}_gen.wav", samples_a[i].cpu().to(torch.float32), sample_rate=sample_rate)
|
| 505 |
+
torchaudio.save(f"{save_dir}/{i}_gt.wav", x_start_audio[i].cpu().to(torch.float32), sample_rate=sample_rate)
|
| 506 |
+
torchaudio.save(f"{save_dir}/{i}_cond.wav", x_cond_audio[i, -1].cpu().to(torch.float32), sample_rate=sample_rate)
|
| 507 |
+
logger.info("the first 10 mel cosine: " + ", ".join(f"{v:.6f}" for v in mel_cosine_ls))
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
dist.all_reduce(score)
|
| 511 |
+
dist.all_reduce(n_samples)
|
| 512 |
+
sim_score = score/n_samples
|
| 513 |
+
return sim_score
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
def get_args_parser():
|
| 517 |
+
parser = argparse.ArgumentParser()
|
| 518 |
+
parser.add_argument("--config", type=str, required=True)
|
| 519 |
+
parser.add_argument("--epochs", type=int, default=300)
|
| 520 |
+
parser.add_argument("--global-seed", type=int, default=0)
|
| 521 |
+
parser.add_argument("--log-every", type=int, default=100)
|
| 522 |
+
parser.add_argument("--ckpt-every", type=int, default=2000)
|
| 523 |
+
parser.add_argument("--eval-every", type=int, default=5000)
|
| 524 |
+
parser.add_argument("--bfloat16", type=int, default=1)
|
| 525 |
+
parser.add_argument("--torch-compile", type=int, default=1)
|
| 526 |
+
parser.add_argument("--restart-from-checkpoint", type=int, default=0,
|
| 527 |
+
help="If 1, only load model weights and reset epoch/step to zero (cold start)")
|
| 528 |
+
return parser
|
| 529 |
+
|
| 530 |
+
if __name__ == "__main__":
|
| 531 |
+
args = get_args_parser().parse_args()
|
| 532 |
+
main(args)
|