Spaces:
Running on Zero
Running on Zero
File size: 10,551 Bytes
6215e7d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 | import torch
from torch import nn
import typing as tp
from .conditioners import MultiConditioner
from .dit import DiffusionTransformer
from stable_audio_3.inference.distribution_shift import FluxDistributionShift, DistributionShift, LogSNRShift, IdentityDistributionShift
from time import time
class Profiler:
    """Minimal timer that records labelled checkpoints.

    Each ``tick(msg)`` appends a ``[timestamp, msg]`` pair to ``self.ticks``.
    ``repr()`` reports the time elapsed between consecutive ticks in
    milliseconds, framed by 80-character ``=`` rules.
    """

    def __init__(self, clock: tp.Callable[[], float] = time):
        """
        Args:
            clock: Zero-argument callable returning the current time in seconds.
                Defaults to ``time.time`` (wall clock). Pass e.g.
                ``time.perf_counter`` for a monotonic, higher-resolution clock.
        """
        self._clock = clock
        # The first entry is the start reference and carries no label.
        self.ticks = [[clock(), None]]

    def tick(self, msg):
        """Record a checkpoint labelled ``msg`` at the current clock time."""
        self.ticks.append([self._clock(), msg])

    def __repr__(self):
        bar = 80 * "="
        lines = [bar]
        for (prev_t, _), (cur_t, msg) in zip(self.ticks, self.ticks[1:]):
            # f-string formatting so non-str labels (ints, enums, ...) do not raise TypeError.
            lines.append(f"{msg}: {(cur_t - prev_t) * 1000:.2f}ms")
        lines.append(bar)
        return "\n".join(lines) + "\n\n\n"
class ConditionedDiffusionModelWrapper(nn.Module):
    """
    A diffusion model that takes in conditioning.

    Holds a backbone ``model`` plus a ``conditioner`` and routes a dict of
    conditioning tensors to the backbone's keyword arguments. The routing is
    set by the ``*_cond_ids`` / ``input_concat_ids`` lists given at
    construction time. ``conditioner`` is stored but not invoked in this class;
    presumably the ``cond`` dict passed to ``forward`` is produced by it
    (TODO confirm against callers).
    """

    def __init__(
        self,
        model: nn.Module,
        conditioner: MultiConditioner,
        io_channels,
        sample_rate,
        min_input_length: int,
        diffusion_objective: tp.Literal["v", "rectified_flow", "rf_denoiser"] = "v",
        distribution_shift_options=None,
        sampling_distribution_shift_options=None,
        mask_padding_attention: bool = False,
        use_effective_length_for_schedule: bool = False,
        pretransform: tp.Optional[nn.Module] = None,
        cross_attn_cond_ids: tp.Optional[tp.List[str]] = None,
        global_cond_ids: tp.Optional[tp.List[str]] = None,
        input_concat_ids: tp.Optional[tp.List[str]] = None,
        local_add_cond_ids: tp.Optional[tp.List[str]] = None,
        modular_local_cond_ids: tp.Optional[tp.List[str]] = None,
        prepend_cond_ids: tp.Optional[tp.List[str]] = None,
    ):
        """
        Args:
            model: Backbone called by ``forward`` with the gathered conditioning.
            conditioner: Stored as ``self.conditioner``; not called in this class.
            io_channels: Stored as-is (presumably the model's input/output
                channel count — TODO confirm).
            sample_rate: Stored as-is.
            min_input_length: Stored as-is; not read in this class.
            diffusion_objective: One of "v", "rectified_flow", "rf_denoiser".
            distribution_shift_options: Optional dict passed to
                ``_create_dist_shift``. The result is stored as
                ``self.dist_shift``, which is None when this is not given.
            sampling_distribution_shift_options: Optional dict for the
                inference-time schedule shift. When None, a default
                ``LogSNRShift(rate=0, anchor_logsnr=-6.2, logsnr_end=2.0)``
                is used.
            mask_padding_attention: Flag stored as an attribute; not read here.
            use_effective_length_for_schedule: Flag stored as an attribute;
                not read here.
            pretransform: Optional module stored as ``self.pretransform``.
            cross_attn_cond_ids, global_cond_ids, input_concat_ids,
            local_add_cond_ids, modular_local_cond_ids, prepend_cond_ids:
                Conditioning ids routed to the matching model inputs by
                ``get_conditioning_inputs``. ``None`` (the default) means an
                empty list.
        """
        super().__init__()
        self.model = model
        self.conditioner = conditioner
        self.io_channels = io_channels
        self.sample_rate = sample_rate
        self.diffusion_objective = diffusion_objective
        self.pretransform = pretransform
        # A fresh list per instance. A shared ``[]`` default would leak any
        # mutation (e.g. ``self.cross_attn_cond_ids.append(...)``) into every
        # other instance constructed with the defaults.
        self.cross_attn_cond_ids = [] if cross_attn_cond_ids is None else cross_attn_cond_ids
        self.global_cond_ids = [] if global_cond_ids is None else global_cond_ids
        self.input_concat_ids = [] if input_concat_ids is None else input_concat_ids
        self.local_add_cond_ids = [] if local_add_cond_ids is None else local_add_cond_ids
        self.modular_local_cond_ids = [] if modular_local_cond_ids is None else modular_local_cond_ids
        self.prepend_cond_ids = [] if prepend_cond_ids is None else prepend_cond_ids
        self.min_input_length = min_input_length
        self.mask_padding_attention = mask_padding_attention
        self.use_effective_length_for_schedule = use_effective_length_for_schedule

        self.dist_shift = None
        if distribution_shift_options is not None:
            self.dist_shift = self._create_dist_shift(distribution_shift_options)

        # Sampling dist_shift: separate config for inference-time schedule
        if sampling_distribution_shift_options is not None:
            self.sampling_dist_shift = self._create_dist_shift(sampling_distribution_shift_options)
        else:
            # Default: seq_len-invariant LogSNR shift matching legacy log_snr_sampling=True
            self.sampling_dist_shift = LogSNRShift(rate=0, anchor_logsnr=-6.2, logsnr_end=2.0)

    @staticmethod
    def _create_dist_shift(options: dict):
        """Create a distribution shift object from config options.

        ``options["type"]`` selects the class and defaults to "full":
        "none" -> IdentityDistributionShift() (other keys ignored),
        "flux" -> FluxDistributionShift, "full" -> DistributionShift,
        "logsnr" -> LogSNRShift. For the last three, every key other than
        "type" is passed to the class as a keyword argument.

        Raises:
            ValueError: if ``type`` is not one of the values above.
        """
        dist_shift_type = options.get("type", "full")
        dist_shift_kwargs = {k: v for k, v in options.items() if k != "type"}
        if dist_shift_type == "none":
            return IdentityDistributionShift()
        elif dist_shift_type == "flux":
            return FluxDistributionShift(**dist_shift_kwargs)
        elif dist_shift_type == "full":
            return DistributionShift(**dist_shift_kwargs)
        elif dist_shift_type == "logsnr":
            return LogSNRShift(**dist_shift_kwargs)
        else:
            raise ValueError(f"Unknown distribution shift type: {dist_shift_type}. Expected 'none', 'flux', 'full', or 'logsnr'.")

    def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False):
        """Gather conditioning tensors into keyword arguments for ``self.model``.

        Args:
            conditioning_tensors: Mapping from conditioning id to a sequence
                whose first element is the conditioning tensor. For
                cross-attention and prepend ids, the sequence is unpacked as
                exactly ``(tensor, mask)``.
            negative: If True, return only the cross-attention, global and
                input-concat inputs, under ``negative_``-prefixed keys
                (presumably for the negative branch of classifier-free
                guidance). Local-add, modular-local and prepend conditioning
                are computed but NOT returned in that case.

        Returns:
            Dict of keyword arguments. An input is None when its id list is
            empty; ``modular_local_cond`` is also None when none of its ids
            are present in ``conditioning_tensors``.

        Raises:
            KeyError: if a configured id (other than a modular-local id) is
                missing from ``conditioning_tensors``.
        """
        cross_attention_input = None
        cross_attention_masks = None
        global_cond = None
        input_concat_cond = None
        prepend_cond = None
        prepend_cond_mask = None
        local_add_cond = None
        modular_local_cond = None

        if self.cross_attn_cond_ids:
            # Concatenate all cross-attention inputs over the sequence dimension.
            # Assumes the cross-attention inputs are of shape (batch, seq, channels).
            cross_attention_input = []
            cross_attention_masks = []
            for key in self.cross_attn_cond_ids:
                cross_attn_in, cross_attn_mask = conditioning_tensors[key]
                # Rank-2 inputs get a length-1 sequence dimension; the mask is
                # unsqueezed too so the two stay aligned in the concat below.
                if len(cross_attn_in.shape) == 2:
                    cross_attn_in = cross_attn_in.unsqueeze(1)
                    cross_attn_mask = cross_attn_mask.unsqueeze(1)
                cross_attention_input.append(cross_attn_in)
                cross_attention_masks.append(cross_attn_mask)
            cross_attention_input = torch.cat(cross_attention_input, dim=1)
            cross_attention_masks = torch.cat(cross_attention_masks, dim=1)

        if self.global_cond_ids:
            # Concatenate all global conditioning inputs over the last (channel) dimension.
            # Assumes the global conditioning inputs are of shape (batch, channels).
            global_cond = torch.cat(
                [conditioning_tensors[key][0] for key in self.global_cond_ids], dim=-1
            )
            # A rank-3 result has dim 1 removed; torch's squeeze(1) is a no-op
            # unless that dimension has size 1.
            if len(global_cond.shape) == 3:
                global_cond = global_cond.squeeze(1)

        if self.input_concat_ids:
            # Concatenate all input-concat conditioning over dim 1.
            # Assumes the inputs are of shape (batch, channels, seq).
            input_concat_cond = torch.cat([conditioning_tensors[key][0] for key in self.input_concat_ids], dim=1)

        if self.local_add_cond_ids:
            # Concatenate all local-add conditioning over dim 1.
            # Assumes the inputs are of shape (batch, channels, seq).
            local_add_cond = torch.cat([conditioning_tensors[key][0] for key in self.local_add_cond_ids], dim=1)

        if self.modular_local_cond_ids:
            # Kept as a dict of tensors (not concatenated). Unlike the other
            # inputs, ids missing from conditioning_tensors are skipped
            # instead of raising.
            modular_local_cond = {
                key: conditioning_tensors[key][0]
                for key in self.modular_local_cond_ids
                if key in conditioning_tensors
            }
            # Only set if we have any conditioning
            if not modular_local_cond:
                modular_local_cond = None

        if self.prepend_cond_ids:
            # Concatenate all prepend conditioning (and masks) over dim 1.
            # Assumes the inputs are of shape (batch, seq, channels).
            prepend_conds = []
            prepend_cond_masks = []
            for key in self.prepend_cond_ids:
                prepend_cond_input, prepend_cond_mask = conditioning_tensors[key]
                prepend_conds.append(prepend_cond_input)
                prepend_cond_masks.append(prepend_cond_mask)
            prepend_cond = torch.cat(prepend_conds, dim=1)
            prepend_cond_mask = torch.cat(prepend_cond_masks, dim=1)

        if negative:
            return {
                "negative_cross_attn_cond": cross_attention_input,
                "negative_cross_attn_mask": cross_attention_masks,
                "negative_global_cond": global_cond,
                "negative_input_concat_cond": input_concat_cond
            }
        else:
            return {
                "cross_attn_cond": cross_attention_input,
                "cross_attn_mask": cross_attention_masks,
                "global_cond": global_cond,
                "input_concat_cond": input_concat_cond,
                "local_add_cond": local_add_cond,
                "modular_local_cond": modular_local_cond,
                "prepend_cond": prepend_cond,
                "prepend_cond_mask": prepend_cond_mask
            }

    def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
        """Call ``self.model(x, t, ...)`` with conditioning gathered from ``cond``.

        Extra ``kwargs`` are forwarded to the model. A key that duplicates one
        produced by ``get_conditioning_inputs`` makes the call raise TypeError.
        """
        return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs)
class DiTWrapper(nn.Module):
    """Adapter that owns a ``DiffusionTransformer`` and renames keyword inputs.

    The ``forward`` keyword names (``cross_attn_cond``, ``cross_attn_mask``,
    ``global_cond``, ``prepend_cond``, ...) appear to match the keys produced
    by ``ConditionedDiffusionModelWrapper.get_conditioning_inputs``.
    """

    def __init__(
        self,
        diffusion_objective: str,
        *args,
        **kwargs
    ):
        """
        Args:
            diffusion_objective: Stored on the wrapper and also passed to
                ``DiffusionTransformer`` as the ``diffusion_objective`` keyword.
            *args, **kwargs: Forwarded unchanged to ``DiffusionTransformer``.
        """
        super().__init__()
        self.diffusion_objective = diffusion_objective
        # Same call as before. Star-args are placed before the keyword to avoid
        # the misleading "keyword before *args" form.
        self.model = DiffusionTransformer(*args, diffusion_objective=diffusion_objective, **kwargs)

    def forward(self,
                x,
                t,
                cross_attn_cond=None,
                cross_attn_mask=None,
                negative_cross_attn_cond=None,
                negative_cross_attn_mask=None,
                input_concat_cond=None,
                local_add_cond=None,
                negative_input_concat_cond=None,
                global_cond=None,
                negative_global_cond=None,
                prepend_cond=None,
                prepend_cond_mask=None,
                cfg_scale=1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = True,
                rescale_cfg: bool = False,
                scale_phi: float = 0.0,
                **kwargs):
        """Run the wrapped transformer, renaming conditioning keywords.

        Two keywords are renamed on the way through: ``cross_attn_mask``
        becomes ``cross_attn_cond_mask``, and ``global_cond`` becomes
        ``global_embed``.

        ``negative_input_concat_cond``, ``negative_global_cond`` and
        ``rescale_cfg`` are accepted for interface compatibility but are NOT
        forwarded, so they are silently ignored. ``batch_cfg`` must be True
        and is not forwarded either. Any extra ``kwargs`` are passed through
        unchanged.

        Raises:
            ValueError: if ``batch_cfg`` is False. This is an explicit check
                rather than an ``assert``, so it still fires under ``python -O``.
        """
        if not batch_cfg:
            raise ValueError("batch_cfg must be True for DiTWrapper")
        return self.model(
            x,
            t,
            cross_attn_cond=cross_attn_cond,
            cross_attn_cond_mask=cross_attn_mask,
            negative_cross_attn_cond=negative_cross_attn_cond,
            negative_cross_attn_mask=negative_cross_attn_mask,
            input_concat_cond=input_concat_cond,
            prepend_cond=prepend_cond,
            prepend_cond_mask=prepend_cond_mask,
            cfg_scale=cfg_scale,
            cfg_dropout_prob=cfg_dropout_prob,
            scale_phi=scale_phi,
            global_embed=global_cond,
            local_add_cond=local_add_cond,
            **kwargs)
|