| import torch |
|
|
def normalise2(tensor):
    """Rescale values from the [0, 1] range to [-1, 1].

    The input is mapped through ``x * 2 - 1`` and the result is clamped to
    [-1, 1], so inputs outside [0, 1] saturate at the bounds.
    """
    rescaled = tensor * 2 - 1.
    return torch.clamp(rescaled, min=-1, max=1)
|
|
def tfg_data(dataloader, face_hide_percentage, use_ref, use_audio):
    """Endlessly yield processed ``(img_batch, model_kwargs)`` pairs.

    ``dataloader`` is iterated repeatedly; each time a pass over it finishes,
    a new pass is started. Every batch is run through ``tfg_process_batch``
    with the given ``face_hide_percentage``, ``use_ref`` and ``use_audio``.

    If a full pass produces no batches (an empty dataloader, or a one-shot
    iterator that has already been consumed), the generator stops instead of
    spinning forever without ever yielding.
    """
    while True:
        produced_any = False
        for batch in dataloader:
            produced_any = True
            img_batch, model_kwargs = tfg_process_batch(
                batch, face_hide_percentage, use_ref, use_audio
            )
            yield img_batch, model_kwargs
        if not produced_any:
            # Restarting a source that yields nothing would busy-loop and
            # hang the consumer, so end the stream here.
            return
| |
|
|
def tfg_process_batch(batch, face_hide_percentage, use_ref=False, use_audio=False, sampling_use_gt_for_ref=False, noise=None):
    """Turn one raw batch into a normalised image batch plus model kwargs.

    ``batch["image"]`` must be 5-dimensional; its first two dimensions are
    merged into one and the result is passed through ``normalise2``. The
    conditioning inputs are always added via ``tfg_add_cond_inputs``;
    reference images and audio features are added only when ``use_ref`` /
    ``use_audio`` are set.

    Returns:
        A tuple ``(img_batch, model_kwargs)``.
    """
    frames = batch["image"]
    B, F, C, H, W = frames.shape
    # Merge the two leading dimensions so every frame is its own sample.
    img_batch = normalise2(frames.reshape(B * F, C, H, W).contiguous())

    model_kwargs = tfg_add_cond_inputs(img_batch, {}, face_hide_percentage, noise)
    if use_ref:
        model_kwargs = tfg_add_reference(batch, model_kwargs, sampling_use_gt_for_ref)
    if use_audio:
        model_kwargs = tfg_add_audio(batch, model_kwargs)
    return img_batch, model_kwargs
|
|
def tfg_add_reference(batch, model_kwargs, sampling_use_gt_for_ref=False):
    """Store the normalised reference images under ``model_kwargs["ref_img"]``.

    When ``sampling_use_gt_for_ref`` is true the reference is taken from
    ``batch["image"]``; otherwise it is taken from ``batch["ref_img"]``.
    Either source must be 5-dimensional; its two leading dimensions are
    merged before the result is passed through ``normalise2``.

    Returns:
        The same ``model_kwargs`` mapping, updated in place.
    """
    if sampling_use_gt_for_ref:
        gt_frames = batch["image"]
        B, F, C, H, W = gt_frames.shape
        flattened = gt_frames.reshape(B * F, C, H, W)
    else:
        ref_frames = batch["ref_img"]
        _, _, C, H, W = ref_frames.shape
        flattened = ref_frames.reshape(-1, C, H, W)

    model_kwargs["ref_img"] = normalise2(flattened.contiguous())
    return model_kwargs
|
|
def tfg_add_audio(batch, model_kwargs):
    """Copy the audio features from ``batch`` into ``model_kwargs``.

    ``batch["indiv_mels"]`` must be 5-dimensional. Its third dimension is
    squeezed and the two leading dimensions are merged, and the result is
    stored under ``"indiv_mels"``. If ``batch`` also has a ``"mel"`` entry it
    is passed through unchanged under the same key.

    Returns:
        The same ``model_kwargs`` mapping, updated in place.
    """
    per_frame_mels = batch["indiv_mels"]
    B, F, _, h, w = per_frame_mels.shape
    # Drop the third dimension (when it has size 1), then merge the two leading ones.
    model_kwargs["indiv_mels"] = per_frame_mels.squeeze(dim=2).reshape(B * F, h, w)

    if "mel" in batch:
        model_kwargs["mel"] = batch["mel"]
    return model_kwargs
|
|
def tfg_add_cond_inputs(img_batch, model_kwargs, face_hide_percentage, noise=None):
    """Add the masked conditioning image and its mask to ``model_kwargs``.

    A mask of shape ``(B, 1, H, W)`` is built that is 1 for every row from
    index ``int(H * (1 - face_hide_percentage))`` to the bottom of the image
    and 0 above it. The conditioning image keeps ``img_batch`` where the mask
    is 0 and uses ``noise`` where the mask is 1.

    Args:
        img_batch: 4-dimensional image tensor ``(B, C, H, W)``.
        model_kwargs: Mapping that receives the ``"cond_img"`` and ``"mask"``
            entries; it is updated in place.
        face_hide_percentage: Fraction of the image height, measured from the
            bottom, to hide. The value is not range-checked here.
        noise: Optional tensor with the same shape as ``img_batch``; when
            omitted, standard normal noise is drawn with ``torch.randn_like``.

    Returns:
        The same ``model_kwargs`` mapping.

    Raises:
        AssertionError: If ``noise`` does not have the shape of ``img_batch``.
    """
    B, C, H, W = img_batch.shape
    # Allocate the mask on the image's device; a default-device (CPU) mask
    # would make the blend below fail for images living on another device.
    mask = torch.zeros(B, 1, H, W, device=img_batch.device)
    mask_start_idx = int(H * (1 - face_hide_percentage))
    mask[:, :, mask_start_idx:, :] = 1.
    if noise is None:
        noise = torch.randn_like(img_batch)
    assert noise.shape == img_batch.shape, "Noise shape != Image shape"
    cond_img = img_batch * (1. - mask) + mask * noise

    model_kwargs["cond_img"] = cond_img
    model_kwargs["mask"] = mask
    return model_kwargs
|
|
|
|
def get_n_params(model):
    """Return the total number of scalar elements across ``model.parameters()``.

    Each parameter contributes the product of its dimensions (1 for a
    zero-dimensional parameter); a model without parameters yields 0.
    """
    return sum(param.numel() for param in model.parameters())