import torch
import torch.nn as nn
import torch.nn.functional as F

from .layers import ResidualBlock, AttnBlock
from .utils import get_named_beta_schedule


def sinusoidal_embedding(n, d):
    """
    n: number of diffusion steps
    d: time embedding dimension
    """
    embedding = torch.tensor(
        [[i / 10000 ** (2 * j / d) for j in range(d)] for i in range(n)],
        dtype=torch.float32,
    )
    # Apply sin on the even embedding dimensions and cos on the odd ones.
    embedding[:, 0::2] = torch.sin(embedding[:, 0::2])
    embedding[:, 1::2] = torch.cos(embedding[:, 1::2])
    return embedding
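
# Usage sketch: with n_steps=1000 and time_emb_dim=128,
#   emb = sinusoidal_embedding(1000, 128)   # shape (1000, 128)
# row t of `emb` is the vector the frozen nn.Embedding below returns for step t.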


def _make_te(dim_in, dim_out):
    # Two-layer MLP that projects the sinusoidal time embedding into the
    # conditioning space used by the residual blocks.
    return nn.Sequential(
        nn.Linear(dim_in, dim_out),
        nn.SiLU(),
        nn.Linear(dim_out, dim_out),
    )


class UNet_with_time(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        input_frame = config.input_frame
        output_frame = config.output_frame
        n_steps = config.n_steps
        time_emb_dim = config.time_emb_dim
        cond_nc = config.cond_nc
        chs_mult = config.chs_mult
        n_res_blocks = config.n_res_blocks
        base_chs = config.base_chs
        use_attn_list = config.use_attn_list

        layer_depth = len(chs_mult)
        assert len(use_attn_list) == layer_depth, "use_attn_list must have the same length as chs_mult"
        assert input_frame >= output_frame, "input_frame must be greater than or equal to output_frame"

        self.filter_list = [base_chs * m for m in chs_mult]

        # Frozen sinusoidal time embedding, refined by a small learnable MLP.
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)
        self.time_embed_fc = _make_te(time_emb_dim, time_emb_dim)

        # Space-to-depth stem: (b, c, h, w) -> (b, 4c, h/2, w/2).
        self.input_layer = nn.PixelUnshuffle(downscale_factor=2)

        # Encoder: n_res_blocks residual blocks (plus optional attention) per
        # level, followed by one downsampling residual block.
        self.down_blocks = nn.ModuleList()
        in_c = input_frame * 4
        for i in range(layer_depth):
            out_c = self.filter_list[i]
            for _ in range(n_res_blocks):
                self.down_blocks.append(
                    ResidualBlock(in_c, in_c, cond_nc, time_emb_dim, down_flag=False, up_flag=False)
                )
            if use_attn_list[i]:
                self.down_blocks.append(AttnBlock(in_c, 4))
            self.down_blocks.append(
                ResidualBlock(in_c, out_c, cond_nc, time_emb_dim, down_flag=True, up_flag=False)
            )
            in_c = out_c

        # Bottleneck.
        self.mid_block1 = ResidualBlock(in_c, in_c, cond_nc, time_emb_dim, down_flag=False, up_flag=False)
        self.mid_attn = AttnBlock(in_c, 4)
        self.mid_block2 = ResidualBlock(in_c, in_c, cond_nc, time_emb_dim, down_flag=False, up_flag=False)

        # Decoder: mirrors the encoder; each residual block consumes a skip
        # connection, which doubles its input channels.
        self.up_blocks = nn.ModuleList()
        self.filter_list = [input_frame * 4] + self.filter_list[:-1]
        for i in reversed(range(layer_depth)):
            out_c = self.filter_list[i]
            self.up_blocks.append(
                ResidualBlock(in_c * 2, out_c, cond_nc, time_emb_dim, down_flag=False, up_flag=True)
            )
            if use_attn_list[i]:
                self.up_blocks.append(AttnBlock(out_c, 4))
            for _ in range(n_res_blocks):
                self.up_blocks.append(
                    ResidualBlock(out_c * 2, out_c, cond_nc, time_emb_dim, down_flag=False, up_flag=False)
                )
            in_c = out_c

        # Depth-to-space back to full resolution, then project to output frames.
        self.out_up = nn.PixelShuffle(upscale_factor=2)
        self.out_conv = nn.Conv2d(input_frame, output_frame, 3, padding=1)

    def forward(self, x, t, cond):
        """
        x: (b, in_c, h, w), noisy input (concatenated with some data)
        t: (b,), time step
        cond: (b, cond_nc, h, w), conditional input
        """
        t_emb = self.time_embed(t)
        t_emb = self.time_embed_fc(t_emb)

        x = self.input_layer(x)

        # Encoder: cache every residual-block output for the skip connections.
        skip_x = []
        for down_layer in self.down_blocks:
            if isinstance(down_layer, ResidualBlock):
                x = down_layer(x, cond, t_emb)
                skip_x.append(x)
            elif isinstance(down_layer, AttnBlock):
                x = down_layer(x)
            else:
                raise ValueError("Wrong layer type in down_blocks")

        x = self.mid_block1(x, cond, t_emb)
        x = self.mid_attn(x)
        x = self.mid_block2(x, cond, t_emb)

        # Decoder: pop the skips in reverse order and concatenate on channels.
        for up_layer in self.up_blocks:
            if isinstance(up_layer, ResidualBlock):
                skip_feat = skip_x.pop()
                x = torch.cat([x, skip_feat], dim=1)
                x = up_layer(x, cond, t_emb)
            elif isinstance(up_layer, AttnBlock):
                x = up_layer(x)
            else:
                raise ValueError("Wrong layer type in up_blocks")

        x = self.out_up(x)
        x = self.out_conv(x)

        return x


class DDPM(nn.Module):
    def __init__(self, backbone, output_shape, n_steps=1000, min_beta=1e-4, max_beta=0.02, device='cuda'):
        """
        output_shape: (C, H, W) of the generated sample
        """
        super().__init__()
        self.device = device
        self.backbone_model = backbone
        self.output_shape = output_shape
        self.n_steps = n_steps

        # Linear beta schedule; alpha_bars = cumprod(1 - beta) decays from ~1 toward 0.
        betas = get_named_beta_schedule("linear", n_steps, min_beta, max_beta)
        alphas = 1.0 - betas
        alpha_bars = torch.cumprod(alphas, dim=0)

        # Registered as buffers so they follow the module across devices without being trained.
        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alpha_bars', alpha_bars)

    def forward(self, x, t, cond):
        """
        x: (b, in_c, h, w), noisy input (concatenated with some data)
        t: (b,), time step
        cond: (b, cond_nc, h, w), conditional input
        """
        return self.backbone_model(x, t, cond)

    @torch.no_grad()
    def add_noise(self, x0, t, eta=None):
        """
        x0: (b, c, h, w), original data
        t: (b,), time step (0 <= t < n_steps)
        eta: optional pre-drawn Gaussian noise with the same shape as x0
        """
        b, c, h, w = x0.shape
        if eta is None:
            eta = torch.randn(b, c, h, w, device=x0.device)

        alpha_bar = self.alpha_bars[t]
        noisy_x = alpha_bar.sqrt().reshape(b, 1, 1, 1) * x0 + (1 - alpha_bar).sqrt().reshape(b, 1, 1, 1) * eta

        return noisy_x
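
    # add_noise above is the closed-form DDPM forward process
    #   q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I),
    # so a single call jumps from x_0 directly to any x_t.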

    def denoise(self, xt, t, cond):
        """
        xt: (b, in_c, h, w), noisy input (concatenated with some data)
        t: (b,), time step (0 <= t < n_steps)
        cond: (b, cond_nc, h, w), conditional input
        """
        pred_noise = self(xt, t, cond)
        return pred_noise
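
    # Training sketch (an assumption, not part of this module; it ignores the
    # optional input_cond concatenation used at sampling time):
    #   t = torch.randint(0, ddpm.n_steps, (b,), device=x0.device)
    #   eta = torch.randn_like(x0)
    #   x_t = ddpm.add_noise(x0, t, eta)
    #   loss = F.mse_loss(ddpm.denoise(x_t, t, cond), eta)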

    def _build_progress_iter(self, iterable, total, mode: str):
        """
        Wrap `iterable` in a progress display according to `mode`
        ("none", "text", or "tqdm"); returns (iterator, effective_mode).
        """
        mode = (mode or "none").lower()
        if mode == "tqdm":
            try:
                from tqdm import tqdm
                return tqdm(iterable, total=total, desc="DDPM sampling", leave=False), mode
            except ImportError:
                # Fall back to a plain iterator if tqdm is not installed.
                return iterable, "none"
        return iterable, mode

    @torch.no_grad()
    def sample_ddpm(self, cond, input_cond=None, verbose: str = "none", store_intermediate: bool = False):
        """
        cond: (b, cond_nc, h, w), conditional input
        input_cond: optional frames concatenated with the current sample before each denoising step
        verbose: "none", "text", or "tqdm" for progress display
        store_intermediate: if True, also return intermediate samples
        """
        self.backbone_model.eval()

        B = cond.shape[0]
        device = cond.device

        # Start from pure Gaussian noise and denoise step by step.
        x = torch.randn(B, *self.output_shape, device=device)

        progress_iter_raw = reversed(range(self.n_steps))
        progress_iter, mode = self._build_progress_iter(progress_iter_raw, self.n_steps, verbose)
        use_text = mode == "text"
        text_interval = max(1, self.n_steps // 10)

        frames = []
        for idx, t in enumerate(progress_iter):
            time_tensor = (torch.ones(B, device=device) * t).long()
            if input_cond is not None:
                input_ = torch.cat((x, input_cond), dim=1)
            else:
                input_ = x

            eta_theta = self.denoise(input_, time_tensor, cond)

            alpha_t = self.alphas[t]
            alpha_t_bar = self.alpha_bars[t]

            # Posterior mean: (x_t - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * eps_theta) / sqrt(alpha_t).
            a = 1 / alpha_t.sqrt()
            b = ((1 - alpha_t) / (1 - alpha_t_bar).sqrt()) * eta_theta
            x = a * (x - b)

            if t > 0:
                # Stochastic part with sigma_t = sqrt(beta_t); omitted at the final step.
                z = torch.randn(B, *self.output_shape, device=device)
                beta_t = self.betas[t]
                sigma_t = beta_t.sqrt()
                x = x + sigma_t * z

            if store_intermediate and ((idx % 50 == 0) or (t == 0)):
                out = ((x.clone() + 1) / 2).clamp(0, 1)
                frames.append(out.cpu().numpy())

            if use_text and (idx + 1) % text_interval == 0:
                print(f"DDPM sampling {idx + 1}/{self.n_steps}", flush=True)

        if mode == "tqdm" and hasattr(progress_iter, "close"):
            progress_iter.close()

        if store_intermediate:
            return x, frames
        return x

    @torch.no_grad()
    def sample_ddim(self, cond, input_cond=None, ddim_steps: int = 100, eta: float = 0.2, verbose: str = "none", store_intermediate: bool = False):
        """
        Deterministic/stochastic DDIM sampling.

        cond: (b, cond_nc, h, w), conditional input
        input_cond: optional frames concatenated with the current sample before each denoising step
        ddim_steps: number of sampling steps (<= n_steps)
        eta: 0 for deterministic DDIM; > 0 adds noise scaled by eta
        verbose: "none", "text", or "tqdm" for progress display
        store_intermediate: if True, also return intermediate samples
        """
        self.backbone_model.eval()

        B = cond.shape[0]
        device = cond.device
        ddim_steps = max(1, min(ddim_steps, self.n_steps))

        # Evenly spaced subset of the training timesteps, traversed in reverse.
        ddim_timesteps = torch.linspace(0, self.n_steps - 1, steps=ddim_steps, device=device).long()
        ddim_timesteps = torch.unique(ddim_timesteps, sorted=True)
        ddim_t_reverse = list(reversed(ddim_timesteps.tolist()))

        x = torch.randn(B, *self.output_shape, device=device)

        progress_iter_raw = enumerate(ddim_t_reverse)
        progress_iter, mode = self._build_progress_iter(progress_iter_raw, len(ddim_t_reverse), verbose)
        use_text = mode == "text"
        text_interval = max(1, len(ddim_t_reverse) // 10)

        frames = []
        for idx, t in progress_iter:
            time_tensor = torch.full((B,), t, device=device, dtype=torch.long)
            if input_cond is not None:
                input_ = torch.cat((x, input_cond), dim=1)
            else:
                input_ = x

            eps = self.denoise(input_, time_tensor, cond)

            alpha_bar_t = self.alpha_bars[t]
            sqrt_alpha_bar_t = alpha_bar_t.sqrt()
            sqrt_one_minus_alpha_bar_t = (1 - alpha_bar_t).sqrt()

            # Predict x_0 from the current sample and the predicted noise.
            x0_pred = (x - sqrt_one_minus_alpha_bar_t * eps) / sqrt_alpha_bar_t

            if idx + 1 < len(ddim_t_reverse):
                t_prev = ddim_t_reverse[idx + 1]
                alpha_bar_prev = self.alpha_bars[t_prev]
            else:
                alpha_bar_prev = torch.ones_like(alpha_bar_t)

            # DDIM variance; zero when eta == 0 (deterministic sampling).
            sigma_t = 0.0
            if eta > 0 and alpha_bar_prev < 1:
                sigma_t = eta * torch.sqrt(
                    (1 - alpha_bar_prev) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_prev)
                )
            sigma_t = torch.as_tensor(sigma_t, device=device, dtype=x.dtype)
            noise = torch.randn_like(x) if (eta > 0 and alpha_bar_prev < 1) else torch.zeros_like(x)

            # Move to t_prev: scaled x_0 prediction + direction pointing to x_t + optional noise.
            c_t = torch.sqrt(torch.clamp(1 - alpha_bar_prev - sigma_t ** 2, min=0.0))
            x = alpha_bar_prev.sqrt() * x0_pred + c_t * eps + sigma_t * noise

            if store_intermediate and ((idx % 25 == 0) or (t == 0)):
                out = ((x.clone() + 1) / 2).clamp(0, 1)
                frames.append(out.cpu().numpy())

            if use_text and (idx + 1) % text_interval == 0:
                print(f"DDIM sampling {idx + 1}/{len(ddim_t_reverse)}", flush=True)

        if mode == "tqdm" and hasattr(progress_iter, "close"):
            progress_iter.close()

        if store_intermediate:
            return x, frames
        return x

    # Default sampling entry point.
    sample = sample_ddpm
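

# Minimal smoke-test sketch. Assumptions: a config object exposing exactly the
# fields read in UNet_with_time.__init__, ResidualBlock/AttnBlock from .layers
# with the signatures used above, and get_named_beta_schedule returning a 1-D
# tensor of length n_steps. Run via `python -m <package>.<module>` because of
# the relative imports.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        input_frame=4, output_frame=1, n_steps=1000, time_emb_dim=128,
        cond_nc=3, chs_mult=[1, 2], n_res_blocks=1, base_chs=32,
        use_attn_list=[False, True],
    )
    unet = UNet_with_time(config)
    ddpm = DDPM(unet, output_shape=(config.output_frame, 64, 64), n_steps=config.n_steps, device='cpu')

    x0 = torch.randn(2, config.output_frame, 64, 64)
    t = torch.randint(0, ddpm.n_steps, (2,))
    print(ddpm.add_noise(x0, t).shape)  # expected: torch.Size([2, 1, 64, 64])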