| import os |
| import os.path as osp |
|
|
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| import pdb |
| import cv2 |
|
|
|
|
def labels2image(all_indices, label_type='int_label', scale_schedule=None, vae=None):
    """Decode multi-scale label indices back to a BGR uint8 image.

    Module-level twin of ``BitwiseSelfCorrection.labels2image``. The original
    free function referenced an undefined ``self``, so every call raised
    ``NameError``; the VAE is now passed explicitly.

    Args:
        all_indices: per-scale index tensors accepted by
            ``vae.decode_from_indices``.
        label_type: index encoding, e.g. 'int_label' or 'bit_label';
            forwarded verbatim to the decoder.
        scale_schedule: list of (t, h, w) scales used during encoding.
        vae: object exposing ``decode_from_indices(indices, schedule, label_type)``
            returning ``(summed_codes, recons_imgs)``. Required.

    Returns:
        ``np.uint8`` array of shape (H, W, 3) in BGR channel order (cv2 convention).

    Raises:
        ValueError: if ``vae`` is not provided.
    """
    if vae is None:
        raise ValueError('labels2image requires a `vae` instance')
    summed_codes, recons_imgs = vae.decode_from_indices(all_indices, scale_schedule, label_type)
    recons_img = recons_imgs[0]
    # Map decoder output from [-1, 1] to [0, 1] ...
    recons_img = (recons_img + 1) / 2
    # ... then to 0-255 uint8, (C, H, W) -> (H, W, C), and RGB -> BGR.
    recons_img = recons_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)[:,:,::-1]
    return recons_img
|
|
def features2image(raw_features, vae=None):
    """Decode raw (un-quantized) VAE features back to a BGR uint8 image.

    Module-level twin of ``BitwiseSelfCorrection.features2image``. The
    original free function referenced an undefined ``self``, so every call
    raised ``NameError``; the VAE is now passed explicitly.

    Args:
        raw_features: encoder feature tensor; a singleton dim at position -3
            is squeezed away before decoding (5-D (B, C, 1, H, W) -> 4-D).
        vae: object exposing ``decode(features)`` returning a batch of images
            in [-1, 1]. Required.

    Returns:
        ``np.uint8`` array of shape (H, W, 3) in BGR channel order (cv2 convention).

    Raises:
        ValueError: if ``vae`` is not provided.
    """
    if vae is None:
        raise ValueError('features2image requires a `vae` instance')
    recons_imgs = vae.decode(raw_features.squeeze(-3))
    recons_img = recons_imgs[0]
    # Map decoder output from [-1, 1] to [0, 1] ...
    recons_img = (recons_img + 1) / 2
    # ... then to 0-255 uint8, (C, H, W) -> (H, W, C), and RGB -> BGR.
    recons_img = recons_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)[:,:,::-1]
    return recons_img
|
|
class BitwiseSelfCorrection(object):
    """Bitwise self-correction (BSC) over a multi-scale residual quantizer.

    Re-runs the VAE's lookup-free quantizer (LFQ) scale by scale on the raw
    encoder features, optionally flipping a random fraction of the
    ground-truth bit labels at the earliest scales so the downstream
    autoregressive model trains on inputs resembling its own (imperfect)
    predictions. Several ``*_flip_requant`` variants differ only in which
    per-scale inputs they collect (with/without the first "prefix" scale)
    and in whether noise injection is enabled.
    """

    def __init__(self, vae, args):
        # Number of earliest (coarsest) scales at which bit-flip noise is injected.
        self.noise_apply_layers = args.noise_apply_layers
        # If True, the flipped bits are re-quantized into codes that feed the
        # cumulative reconstruction; otherwise noise only perturbs the labels.
        self.noise_apply_requant = args.noise_apply_requant
        # Upper bound of the per-scale flip probability (sampled uniformly
        # with 0.01 granularity in flip_requant).
        self.noise_apply_strength = args.noise_apply_strength
        # If True, 2x2 pixel-unshuffle spatial dims into the channel dim.
        self.apply_spatial_patchify = args.apply_spatial_patchify
        self.vae = vae
        # When True, dump side-by-side reconstruction images and enter pdb.
        self.debug_bsc = args.debug_bsc


    def flip_requant(self, vae_scale_schedule, inp_B3HW, raw_features, device):
        """Quantize features scale by scale with optional bit-flip noise.

        Args:
            vae_scale_schedule: list of (t, h, w) scales, coarse to fine.
            inp_B3HW: input images; used only for debug visualization.
            raw_features: VAE encoder output, 4-D (B, C, H, W) or 5-D with a
                time axis.
            device: device on which the random flip mask is placed.

        Returns:
            (x_BLC_wo_prefix, gt_ms_idx_Bl): concatenated transformer inputs
            for scales 1..N-1 (the first scale is the prefix and is NOT
            included), and per-scale ground-truth bit labels flattened to
            (B, L, C).
        """
        with torch.amp.autocast('cuda', enabled = False):
            B = raw_features.shape[0]
            if raw_features.dim() == 4:
                # Insert a singleton time axis: (B, C, H, W) -> (B, C, 1, H, W).
                codes_out = raw_features.unsqueeze(2)
            else:
                codes_out = raw_features
            cum_var_input = 0
            gt_all_bit_indices = []
            pred_all_bit_indices = []
            x_BLC_wo_prefix = []
            for si, (pt, ph, pw) in enumerate(vae_scale_schedule):
                # Residual between full-resolution codes and the cumulative
                # reconstruction from all coarser scales processed so far.
                residual = codes_out - cum_var_input
                if si != len(vae_scale_schedule)-1:
                    # Downsample the residual to the current scale (the last
                    # scale already is full resolution).
                    residual = F.interpolate(residual, size=vae_scale_schedule[si], mode=self.vae.quantizer.z_interplote_down).contiguous()
                # NOTE(review): `loss` returned by lfq is unused here.
                quantized, _, bit_indices, loss = self.vae.quantizer.lfq(residual)
                gt_all_bit_indices.append(bit_indices)
                if si < self.noise_apply_layers:
                    # Sample a flip probability in [0, noise_apply_strength].
                    noise_apply_strength = np.random.randint(0, 100 * self.noise_apply_strength+1) * 0.01
                    mask = torch.rand(*bit_indices.shape).to(device) < noise_apply_strength
                    pred_bit_indices = bit_indices.clone()
                    # Flip the selected bits (labels are 0/1).
                    pred_bit_indices[mask] = 1 - pred_bit_indices[mask]
                    pred_all_bit_indices.append(pred_bit_indices)
                    if self.noise_apply_requant:
                        # Let the corrupted bits propagate into the cumulative input.
                        quantized = self.vae.quantizer.lfq.indices_to_codes(pred_bit_indices, label_type = 'bit_label')
                else:
                    pred_all_bit_indices.append(bit_indices)
                # Accumulate this scale's contribution at full resolution.
                cum_var_input = cum_var_input + F.interpolate(quantized, size=vae_scale_schedule[-1], mode=self.vae.quantizer.z_interplote_up).contiguous()
                if si < len(vae_scale_schedule)-1:
                    # Build the input fed to the transformer for the NEXT scale.
                    this_scale_input = F.interpolate(cum_var_input, size=vae_scale_schedule[si+1], mode=self.vae.quantizer.z_interplote_up).contiguous()
                    if self.apply_spatial_patchify:
                        # squeeze time axis, then (B, d, h, w) -> (B, 4d, h/2, w/2)
                        this_scale_input = torch.nn.functional.pixel_unshuffle(this_scale_input.squeeze(-3), 2)
                    # Flatten spatial dims to a token sequence: (B, C, L) -> (B, L, C).
                    x_BLC_wo_prefix.append(this_scale_input.reshape(*this_scale_input.shape[:2], -1).permute(0,2,1))

            if self.apply_spatial_patchify:
                gt_ms_idx_Bl = []
                for item in gt_all_bit_indices:
                    # drop singleton dim 1, move bit dim to channels -> (B, d, h, w)
                    item = item.squeeze(1).permute(0,3,1,2)
                    # (B, d, h, w) -> (B, 4d, h/2, w/2)
                    item = torch.nn.functional.pixel_unshuffle(item, 2)
                    # back to token sequence: (B, h*w/4, 4d)
                    item = item.permute(0,2,3,1).reshape(B, -1, 4*self.vae.codebook_dim)
                    gt_ms_idx_Bl.append(item)
            else:
                gt_ms_idx_Bl = [item.reshape(B, -1, self.vae.codebook_dim) for item in gt_all_bit_indices]
            x_BLC_wo_prefix = torch.cat(x_BLC_wo_prefix, 1)

            if self.debug_bsc:
                self.visualize(vae_scale_schedule, inp_B3HW, gt_all_bit_indices, pred_all_bit_indices)

        return x_BLC_wo_prefix, gt_ms_idx_Bl

    def my_flip_requant(self, vae_scale_schedule, inp_B3HW, raw_features, device):
        """Variant of :meth:`flip_requant` that keeps the prefix scale.

        Differences from :meth:`flip_requant`: noise injection is disabled
        (``my_noise_apply_layers = -1``), and a transformer input is appended
        for EVERY scale, interpolated to the CURRENT scale ``si`` instead of
        the next one — hence the "w_prefix" naming.
        """
        # < 0 makes the `si < my_noise_apply_layers` branch below unreachable,
        # i.e. no bit-flip noise in this variant.
        my_noise_apply_layers = -1
        with torch.amp.autocast('cuda', enabled = False):
            B = raw_features.shape[0]
            if raw_features.dim() == 4:
                # (B, C, H, W) -> (B, C, 1, H, W)
                codes_out = raw_features.unsqueeze(2)
            else:
                codes_out = raw_features
            cum_var_input = 0
            gt_all_bit_indices = []
            pred_all_bit_indices = []
            x_BLC_w_prefix = []
            for si, (pt, ph, pw) in enumerate(vae_scale_schedule):
                residual = codes_out - cum_var_input
                if si != len(vae_scale_schedule)-1:
                    residual = F.interpolate(residual, size=vae_scale_schedule[si], mode=self.vae.quantizer.z_interplote_down).contiguous()
                quantized, _, bit_indices, loss = self.vae.quantizer.lfq(residual)
                gt_all_bit_indices.append(bit_indices)
                # Dead branch: my_noise_apply_layers == -1, so the else path
                # always runs and labels are kept clean.
                if si < my_noise_apply_layers:
                    noise_apply_strength = np.random.randint(0, 100 * self.noise_apply_strength+1) * 0.01
                    mask = torch.rand(*bit_indices.shape).to(device) < noise_apply_strength
                    pred_bit_indices = bit_indices.clone()
                    pred_bit_indices[mask] = 1 - pred_bit_indices[mask]
                    pred_all_bit_indices.append(pred_bit_indices)
                    if self.noise_apply_requant:
                        quantized = self.vae.quantizer.lfq.indices_to_codes(pred_bit_indices, label_type = 'bit_label')
                else:
                    pred_all_bit_indices.append(bit_indices)
                # Accumulate this scale's contribution at full resolution.
                cum_var_input = cum_var_input + F.interpolate(quantized, size=vae_scale_schedule[-1], mode=self.vae.quantizer.z_interplote_up).contiguous()

                # Condition is always true: an input is collected for every
                # scale, at the CURRENT scale's size (keeps the prefix scale).
                if si <= len(vae_scale_schedule)-1:
                    this_scale_input = F.interpolate(cum_var_input, size=vae_scale_schedule[si], mode=self.vae.quantizer.z_interplote_up).contiguous()
                    if self.apply_spatial_patchify:
                        # squeeze time axis, then (B, d, h, w) -> (B, 4d, h/2, w/2)
                        this_scale_input = torch.nn.functional.pixel_unshuffle(this_scale_input.squeeze(-3), 2)
                    # (B, C, L) -> (B, L, C)
                    x_BLC_w_prefix.append(this_scale_input.reshape(*this_scale_input.shape[:2], -1).permute(0,2,1))

            if self.apply_spatial_patchify:
                gt_ms_idx_Bl = []
                for item in gt_all_bit_indices:
                    # drop singleton dim 1, move bit dim to channels -> (B, d, h, w)
                    item = item.squeeze(1).permute(0,3,1,2)
                    # (B, d, h, w) -> (B, 4d, h/2, w/2)
                    item = torch.nn.functional.pixel_unshuffle(item, 2)
                    # back to token sequence: (B, h*w/4, 4d)
                    item = item.permute(0,2,3,1).reshape(B, -1, 4*self.vae.codebook_dim)
                    gt_ms_idx_Bl.append(item)
            else:
                gt_ms_idx_Bl = [item.reshape(B, -1, self.vae.codebook_dim) for item in gt_all_bit_indices]
            x_BLC_w_prefix = torch.cat(x_BLC_w_prefix, 1)

            if self.debug_bsc:
                self.visualize(vae_scale_schedule, inp_B3HW, gt_all_bit_indices, pred_all_bit_indices)
        return x_BLC_w_prefix, gt_ms_idx_Bl

    def long_flip_requant(self, vae_scale_schedule, inp_B3HW, raw_features, device):
        """Variant collecting next-scale inputs PLUS one final full-scale input.

        Like :meth:`flip_requant` (next-scale inputs for scales 0..N-2) but
        additionally appends the cumulative reconstruction at the last scale,
        making the sequence one scale longer. Noise injection is disabled
        (``my_noise_apply_layers = -1``).
        """
        # < 0 makes the noise branch below unreachable.
        my_noise_apply_layers = -1
        with torch.amp.autocast('cuda', enabled = False):
            B = raw_features.shape[0]
            if raw_features.dim() == 4:
                # (B, C, H, W) -> (B, C, 1, H, W)
                codes_out = raw_features.unsqueeze(2)
            else:
                codes_out = raw_features
            cum_var_input = 0
            gt_all_bit_indices = []
            pred_all_bit_indices = []
            x_BLC_w_prefix = []
            for si, (pt, ph, pw) in enumerate(vae_scale_schedule):
                residual = codes_out - cum_var_input
                if si != len(vae_scale_schedule)-1:
                    residual = F.interpolate(residual, size=vae_scale_schedule[si], mode=self.vae.quantizer.z_interplote_down).contiguous()
                quantized, _, bit_indices, loss = self.vae.quantizer.lfq(residual)
                gt_all_bit_indices.append(bit_indices)
                # Dead branch: my_noise_apply_layers == -1 -> labels stay clean.
                if si < my_noise_apply_layers:
                    noise_apply_strength = np.random.randint(0, 100 * self.noise_apply_strength+1) * 0.01
                    mask = torch.rand(*bit_indices.shape).to(device) < noise_apply_strength
                    pred_bit_indices = bit_indices.clone()
                    pred_bit_indices[mask] = 1 - pred_bit_indices[mask]
                    pred_all_bit_indices.append(pred_bit_indices)
                    if self.noise_apply_requant:
                        quantized = self.vae.quantizer.lfq.indices_to_codes(pred_bit_indices, label_type = 'bit_label')
                else:
                    pred_all_bit_indices.append(bit_indices)
                cum_var_input = cum_var_input + F.interpolate(quantized, size=vae_scale_schedule[-1], mode=self.vae.quantizer.z_interplote_up).contiguous()

                # Scales 0..N-2: input for the NEXT scale (as in flip_requant).
                if si < len(vae_scale_schedule)-1:
                    this_scale_input = F.interpolate(cum_var_input, size=vae_scale_schedule[si+1], mode=self.vae.quantizer.z_interplote_up).contiguous()
                    if self.apply_spatial_patchify:
                        # squeeze time axis, then (B, d, h, w) -> (B, 4d, h/2, w/2)
                        this_scale_input = torch.nn.functional.pixel_unshuffle(this_scale_input.squeeze(-3), 2)
                    x_BLC_w_prefix.append(this_scale_input.reshape(*this_scale_input.shape[:2], -1).permute(0,2,1))

                # Last scale: also append the final cumulative reconstruction
                # at its own (full) resolution — the extra "long" entry.
                if si == len(vae_scale_schedule)-1:
                    this_scale_input = F.interpolate(cum_var_input, size=vae_scale_schedule[si], mode=self.vae.quantizer.z_interplote_up).contiguous()
                    if self.apply_spatial_patchify:
                        # squeeze time axis, then (B, d, h, w) -> (B, 4d, h/2, w/2)
                        this_scale_input = torch.nn.functional.pixel_unshuffle(this_scale_input.squeeze(-3), 2)
                    x_BLC_w_prefix.append(this_scale_input.reshape(*this_scale_input.shape[:2], -1).permute(0,2,1))

            if self.apply_spatial_patchify:
                gt_ms_idx_Bl = []
                for item in gt_all_bit_indices:
                    # drop singleton dim 1, move bit dim to channels -> (B, d, h, w)
                    item = item.squeeze(1).permute(0,3,1,2)
                    # (B, d, h, w) -> (B, 4d, h/2, w/2)
                    item = torch.nn.functional.pixel_unshuffle(item, 2)
                    # back to token sequence: (B, h*w/4, 4d)
                    item = item.permute(0,2,3,1).reshape(B, -1, 4*self.vae.codebook_dim)
                    gt_ms_idx_Bl.append(item)
            else:
                gt_ms_idx_Bl = [item.reshape(B, -1, self.vae.codebook_dim) for item in gt_all_bit_indices]
            x_BLC_w_prefix = torch.cat(x_BLC_w_prefix, 1)

            if self.debug_bsc:
                self.visualize(vae_scale_schedule, inp_B3HW, gt_all_bit_indices, pred_all_bit_indices)
        return x_BLC_w_prefix, gt_ms_idx_Bl


    def flow_flip_requant(self, vae_scale_schedule, inp_B3HW, raw_features, device):
        """Like :meth:`my_flip_requant` but also returns the raw feature sequence.

        Returns:
            (x_BLC_w_prefix, gt_ms_idx_Bl, raw_features_seq) where
            ``raw_features_seq`` is ``raw_features`` flattened to (B, L, C)
            for flow-style supervision.
        """
        # < 0 makes the noise branch below unreachable.
        my_noise_apply_layers = -1
        with torch.amp.autocast('cuda', enabled = False):
            B = raw_features.shape[0]
            if raw_features.dim() == 4:
                # (B, C, H, W) -> (B, C, 1, H, W)
                codes_out = raw_features.unsqueeze(2)
            else:
                codes_out = raw_features
            cum_var_input = 0
            gt_all_bit_indices = []
            pred_all_bit_indices = []
            x_BLC_w_prefix = []
            for si, (pt, ph, pw) in enumerate(vae_scale_schedule):
                residual = codes_out - cum_var_input
                if si != len(vae_scale_schedule)-1:
                    residual = F.interpolate(residual, size=vae_scale_schedule[si], mode=self.vae.quantizer.z_interplote_down).contiguous()
                quantized, _, bit_indices, loss = self.vae.quantizer.lfq(residual)
                gt_all_bit_indices.append(bit_indices)
                # Dead branch: my_noise_apply_layers == -1 -> labels stay clean.
                if si < my_noise_apply_layers:
                    noise_apply_strength = np.random.randint(0, 100 * self.noise_apply_strength+1) * 0.01
                    mask = torch.rand(*bit_indices.shape).to(device) < noise_apply_strength
                    pred_bit_indices = bit_indices.clone()
                    pred_bit_indices[mask] = 1 - pred_bit_indices[mask]
                    pred_all_bit_indices.append(pred_bit_indices)
                    if self.noise_apply_requant:
                        quantized = self.vae.quantizer.lfq.indices_to_codes(pred_bit_indices, label_type = 'bit_label')
                else:
                    pred_all_bit_indices.append(bit_indices)
                cum_var_input = cum_var_input + F.interpolate(quantized, size=vae_scale_schedule[-1], mode=self.vae.quantizer.z_interplote_up).contiguous()

                # Always true: collect an input for every scale at the CURRENT
                # scale's size (sequence includes the prefix scale).
                if si <= len(vae_scale_schedule)-1:
                    this_scale_input = F.interpolate(cum_var_input, size=vae_scale_schedule[si], mode=self.vae.quantizer.z_interplote_up).contiguous()
                    if self.apply_spatial_patchify:
                        # squeeze time axis, then (B, d, h, w) -> (B, 4d, h/2, w/2)
                        this_scale_input = torch.nn.functional.pixel_unshuffle(this_scale_input.squeeze(-3), 2)
                    # (B, C, L) -> (B, L, C)
                    x_BLC_w_prefix.append(this_scale_input.reshape(*this_scale_input.shape[:2], -1).permute(0,2,1))

            if self.apply_spatial_patchify:
                gt_ms_idx_Bl = []
                for item in gt_all_bit_indices:
                    # drop singleton dim 1, move bit dim to channels -> (B, d, h, w)
                    item = item.squeeze(1).permute(0,3,1,2)
                    # (B, d, h, w) -> (B, 4d, h/2, w/2)
                    item = torch.nn.functional.pixel_unshuffle(item, 2)
                    # back to token sequence: (B, h*w/4, 4d)
                    item = item.permute(0,2,3,1).reshape(B, -1, 4*self.vae.codebook_dim)
                    gt_ms_idx_Bl.append(item)
            else:
                gt_ms_idx_Bl = [item.reshape(B, -1, self.vae.codebook_dim) for item in gt_all_bit_indices]
            x_BLC_w_prefix = torch.cat(x_BLC_w_prefix, 1)

            if self.debug_bsc:
                self.flow_visualize(vae_scale_schedule, inp_B3HW, gt_all_bit_indices, pred_all_bit_indices,raw_features)
            # NOTE(review): raw_features_quantized is computed but never used
            # or returned — candidate for removal.
            raw_features_quantized, _, _, _ = self.vae.quantizer.lfq(codes_out)
            # Flatten raw features to a token sequence: (B, C, L) -> (B, L, C).
            raw_features_seq = raw_features.reshape(*raw_features.shape[:2], -1).permute(0,2,1)
        return x_BLC_w_prefix, gt_ms_idx_Bl,raw_features_seq

    def visualize(self, vae_scale_schedule, inp_B3HW, gt_all_bit_indices, pred_all_bit_indices):
        """Save a side-by-side JPEG (GT | GT-label recon | noisy-label recon) and break into pdb."""
        # [-1, 1] -> [0, 255]
        gt_img = (inp_B3HW.squeeze(-3) + 1) / 2 * 255
        # first sample: (C, H, W) -> (H, W, C) uint8, RGB -> BGR for cv2.
        gt_img = gt_img[0].permute(1,2,0).cpu().numpy().astype(np.uint8)[:,:,::-1]
        recons_img_2 = self.labels2image(gt_all_bit_indices, label_type='bit_label', scale_schedule=vae_scale_schedule)
        recons_img_3 = self.labels2image(pred_all_bit_indices, label_type='bit_label', scale_schedule=vae_scale_schedule)
        cat_image = np.concatenate([gt_img, recons_img_2, recons_img_3], axis=1)
        save_path = osp.abspath('gt-gt_indices-pred_indices_new.jpg')
        cv2.imwrite(save_path, cat_image)
        print(f'Save to {save_path}')
        print(cat_image.shape)
        # Intentional breakpoint: this path only runs when debug_bsc is set.
        import pdb; pdb.set_trace()


    def flow_visualize(self, vae_scale_schedule, inp_B3HW, gt_all_bit_indices, pred_all_bit_indices,raw_features):
        """Like :meth:`visualize`, with an extra panel decoded from raw features."""
        # [-1, 1] -> [0, 255]
        gt_img = (inp_B3HW.squeeze(-3) + 1) / 2 * 255
        # first sample: (C, H, W) -> (H, W, C) uint8, RGB -> BGR for cv2.
        gt_img = gt_img[0].permute(1,2,0).cpu().numpy().astype(np.uint8)[:,:,::-1]
        recons_img_2 = self.labels2image(gt_all_bit_indices, label_type='bit_label', scale_schedule=vae_scale_schedule)
        recons_img_3 = self.labels2image(pred_all_bit_indices, label_type='bit_label', scale_schedule=vae_scale_schedule)
        recons_img_4 = self.features2image(raw_features)
        cat_image = np.concatenate([gt_img, recons_img_2, recons_img_3,recons_img_4], axis=1)
        save_path = osp.abspath('gt-gt_indices-pred_indices-raw_features_new.jpg')
        cv2.imwrite(save_path, cat_image)
        print(f'Save to {save_path}')
        print(cat_image.shape)
        # Intentional breakpoint: this path only runs when debug_bsc is set.
        import pdb; pdb.set_trace()

    def labels2image(self,all_indices, label_type='int_label', scale_schedule=None):
        """Decode multi-scale label indices to a (H, W, 3) BGR uint8 image (first sample only)."""
        summed_codes, recons_imgs = self.vae.decode_from_indices(all_indices, scale_schedule, label_type)
        recons_img = recons_imgs[0]
        # [-1, 1] -> [0, 1] -> 0-255 uint8, (C, H, W) -> (H, W, C), RGB -> BGR.
        recons_img = (recons_img + 1) / 2
        recons_img = recons_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)[:,:,::-1]
        return recons_img


    def features2image(self,raw_features):
        """Decode raw encoder features to a (H, W, 3) BGR uint8 image (first sample only)."""
        recons_imgs = self.vae.decode(raw_features.squeeze(-3))
        recons_img = recons_imgs[0]
        # [-1, 1] -> [0, 1] -> 0-255 uint8, (C, H, W) -> (H, W, C), RGB -> BGR.
        recons_img = (recons_img + 1) / 2
        recons_img = recons_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)[:,:,::-1]
        return recons_img
| |