# SD-VITON-Inference / train_condition.py
import argparse
import os
import time

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from cp_dataset import CPDataset, CPDataLoader
from networks import (ConditionGenerator, VGGLoss, GANLoss, load_checkpoint,
                      save_checkpoint, define_D, make_grid as mkgrid,
                      make_grid_3d as mkgrid_3d)
from utils import *
def remove_overlap(seg_out, warped_cm):
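    """Suppress overlap between a warped cloth mask and other predicted regions.

    seg_out is the softmaxed segmentation prediction; channels 1:3 and 5: are summed
    and warped_cm is scaled by (1 - that sum), so pixels fully claimed by those
    channels are zeroed out of the warped cloth mask.
    """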
assert len(warped_cm.shape) == 4
warped_cm = warped_cm - (torch.cat([seg_out[:, 1:3, :, :], seg_out[:, 5:, :, :]], dim=1)).sum(dim=1, keepdim=True) * warped_cm
return warped_cm
def get_opt():
parser = argparse.ArgumentParser()
parser.add_argument("--name", default="test")
parser.add_argument("--gpu_ids", default="")
parser.add_argument('-j', '--workers', type=int, default=4)
parser.add_argument('-b', '--batch-size', type=int, default=8)
parser.add_argument('--fp16', action='store_true', help='use amp')
parser.add_argument("--dataroot", default="./data/")
parser.add_argument("--datamode", default="train")
parser.add_argument("--data_list", default="train_pairs.txt")
parser.add_argument("--fine_width", type=int, default=192)
parser.add_argument("--fine_height", type=int, default=256)
parser.add_argument('--checkpoint_dir', type=str, default='checkpoints', help='save checkpoint infos')
parser.add_argument('--tocg_checkpoint', type=str, default='', help='tocg checkpoint')
parser.add_argument("--display_count", type=int, default=100)
parser.add_argument("--save_count", type=int, default=1000)
parser.add_argument("--load_step", type=int, default=0)
parser.add_argument("--keep_step", type=int, default=300000)
parser.add_argument("--shuffle", action='store_true', help='shuffle input data')
parser.add_argument("--semantic_nc", type=int, default=13)
parser.add_argument("--output_nc", type=int, default=13)
parser.add_argument('--cond_G_ngf', type=int, default=96)
parser.add_argument('--cond_G_num_layers', type=int, default=5)
parser.add_argument("--warp_feature", choices=['encoder', 'T1'], default="T1")
parser.add_argument("--out_layer", choices=['relu', 'conv'], default="relu")
parser.add_argument('--Ddownx2', action='store_true', help="Downsample D's input to increase the receptive field")
parser.add_argument('--Ddropout', action='store_true', help="Apply dropout to D")
    parser.add_argument('--num_D', type=int, default=2, help='number of discriminators in the multi-scale D')
    parser.add_argument("--lasttvonly", action='store_true', help='apply the TV loss only to the last flow of each branch')
parser.add_argument("--interflowloss", action='store_true', help="Intermediate flow loss")
parser.add_argument('--G_lr', type=float, default=0.0002, help='Generator initial learning rate for adam')
parser.add_argument('--D_lr', type=float, default=0.0002, help='Discriminator initial learning rate for adam')
    parser.add_argument('--CElamda', type=float, default=10, help='weight for the cross-entropy segmentation loss')
parser.add_argument('--GANlambda', type=float, default=1)
parser.add_argument('--tvlambda_tvob', type=float, default=2)
parser.add_argument('--tvlambda_taco', type=float, default=2)
parser.add_argument('--upsample', type=str, default='bilinear', choices=['nearest', 'bilinear'])
parser.add_argument('--spectral', action='store_true', help="Apply spectral normalization to D")
parser.add_argument('--occlusion', action='store_true', help="Occlusion handling")
parser.add_argument('--resume', action='store_true', help='resume from the latest checkpoint')
opt = parser.parse_args()
return opt
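
# A minimal example invocation (a sketch, assuming a VITON-HD-style dataset under ./data/
# laid out the way CPDataset expects, and a single visible GPU):
#
#   python train_condition.py --name sd_viton_tocg --gpu_ids 0 \
#       --dataroot ./data/ --datamode train --data_list train_pairs.txt \
#       --occlusion --interflowloss --Ddownx2 --Ddropout --spectral
#
# All flags above are defined in get_opt(); adjust paths and toggles to your setup.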
def train(opt, train_loader, tocg, D, optimizer_G, optimizer_D):
tocg.cuda()
tocg.train()
D.cuda()
D.train()
# criterion
criterionL1 = nn.L1Loss()
criterionVGG = VGGLoss()
if opt.fp16:
criterionGAN = GANLoss(use_lsgan=True, tensor=torch.cuda.HalfTensor)
else:
criterionGAN = GANLoss(use_lsgan=True, tensor=torch.cuda.FloatTensor if opt.gpu_ids else torch.Tensor)
for step in tqdm(range(opt.load_step, opt.keep_step)):
iter_start_time = time.time()
inputs = train_loader.next_batch()
# input1
c_paired = inputs['cloth']['paired'].cuda()
cm_paired = inputs['cloth_mask']['paired'].cuda()
        cm_paired = (cm_paired > 0.5).float()  # binarize the cloth mask at 0.5
# input2
parse_agnostic = inputs['parse_agnostic'].cuda()
densepose = inputs['densepose'].cuda()
openpose = inputs['pose'].cuda()
# GT
label_onehot = inputs['parse_onehot'].cuda() # CE
label = inputs['parse'].cuda() # GAN loss
parse_cloth_mask = inputs['pcm'].cuda() # L1
im_c = inputs['parse_cloth'].cuda() # VGG
# visualization
im = inputs['image']
# tucked-out shirts style
lower_clothes_mask = inputs['lower_clothes_mask'].cuda()
clothes_no_loss_mask = inputs['clothes_no_loss_mask'].cuda()
# inputs
input1 = torch.cat([c_paired, cm_paired], 1)
input2 = torch.cat([parse_agnostic, densepose], 1)
# forward
flow_list_taco, fake_segmap, warped_cloth_paired_taco, warped_clothmask_paired_taco, flow_list_tvob, warped_cloth_paired_tvob, warped_clothmask_paired_tvob = tocg(input1, input2)
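        # tocg predicts a segmentation map plus two warping branches: the TACO flows
        # (flow_list_taco, 5-D, with an extra z component used by Eq.12 below) and the
        # TVOB flows (flow_list_tvob, 4-D), each with its warped cloth and cloth mask.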
# warped cloth mask one hot
warped_clothmask_paired_taco_onehot = torch.FloatTensor((warped_clothmask_paired_taco.detach().cpu().numpy() > 0.5).astype(float)).cuda()
# fake segmap cloth channel * warped clothmask
cloth_mask = torch.ones_like(fake_segmap.detach())
cloth_mask[:, 3:4, :, :] = warped_clothmask_paired_taco
fake_segmap = fake_segmap * cloth_mask
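        # Optional occlusion handling: suppress the parts of both warped cloth masks that
        # the predicted segmentation assigns to other regions, and fill the removed cloth
        # pixels with white (ones).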
if opt.occlusion:
warped_clothmask_paired_taco = remove_overlap(F.softmax(fake_segmap, dim=1), warped_clothmask_paired_taco)
warped_cloth_paired_taco = warped_cloth_paired_taco * warped_clothmask_paired_taco + torch.ones_like(warped_cloth_paired_taco) * (1 - warped_clothmask_paired_taco)
warped_clothmask_paired_tvob = remove_overlap(F.softmax(fake_segmap, dim=1), warped_clothmask_paired_tvob)
warped_cloth_paired_tvob = warped_cloth_paired_tvob * warped_clothmask_paired_tvob + torch.ones_like(warped_cloth_paired_tvob) * (1 - warped_clothmask_paired_tvob)
# loss warping
loss_l1_cloth = criterionL1(warped_clothmask_paired_taco, parse_cloth_mask)
loss_vgg = criterionVGG(warped_cloth_paired_taco, im_c)
# Eq.8 & Eq.9 of SD-VITON
inv_lower_clothes_mask = lower_clothes_mask * clothes_no_loss_mask
inv_lower_clothes_mask = 1. - inv_lower_clothes_mask
loss_l1_cloth += criterionL1(warped_clothmask_paired_tvob * inv_lower_clothes_mask, parse_cloth_mask * inv_lower_clothes_mask)
loss_vgg += criterionVGG(warped_cloth_paired_tvob * inv_lower_clothes_mask, im_c * inv_lower_clothes_mask)
# Eq.12 of SD-VITON
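        # The last TACO flow carries a z component (index 2); the loss below penalizes
        # |z_src_coordinate + flow_z + z_gt| with z_gt = z_gt_roi inside a 0.5x-downsampled
        # GT cloth ROI and z_gt = z_gt_non_roi elsewhere.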
roi_mask = torch.nn.functional.interpolate(parse_cloth_mask, scale_factor=0.5, mode='nearest')
non_roi_mask = 1. - roi_mask
flow_taco = flow_list_taco[-1]
z_gt_non_roi = -1
z_gt_roi = 1
z_src_coordinate = -1
z_dist_loss_non_roi = (torch.abs(z_src_coordinate + flow_taco[:, 0:1, :, :, 2] + z_gt_non_roi) * non_roi_mask).mean()
z_dist_loss_roi = (torch.abs(z_src_coordinate + flow_taco[:, 0:1, :, :, 2] + z_gt_roi) * roi_mask).mean()
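        # Total-variation smoothness on the flow fields: TACO flows have an extra leading
        # dim and a z channel, so spatial differences run over dims 2/3; TVOB flows are
        # (N, H, W, 2), so over dims 1/2.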
loss_tv_tvob = 0
loss_tv_taco = 0
if not opt.lasttvonly:
for flow in flow_list_taco:
y_tv = torch.abs(flow[:, :, 1:, :, :] - flow[:, :, :-1, :, :]).mean()
x_tv = torch.abs(flow[:, :, :, 1:, :] - flow[:, :, :, :-1, :]).mean()
loss_tv_taco = loss_tv_taco + y_tv + x_tv
for flow in flow_list_tvob:
y_tv = torch.abs(flow[:, 1:, :, :] - flow[:, :-1, :, :]).mean()
x_tv = torch.abs(flow[:, :, 1:, :] - flow[:, :, :-1, :]).mean()
loss_tv_tvob = loss_tv_tvob + y_tv + x_tv
        else:
            # Apply the TV loss only to the last (finest) flow of each branch.
            for flow in flow_list_taco[-1:]:
                # TACO flows are 5-D, so slice them consistently with the branch above.
                y_tv = torch.abs(flow[:, :, 1:, :, :] - flow[:, :, :-1, :, :]).mean()
                x_tv = torch.abs(flow[:, :, :, 1:, :] - flow[:, :, :, :-1, :]).mean()
                loss_tv_taco = loss_tv_taco + y_tv + x_tv
            for flow in flow_list_tvob[-1:]:
                y_tv = torch.abs(flow[:, 1:, :, :] - flow[:, :-1, :, :]).mean()
                x_tv = torch.abs(flow[:, :, 1:, :] - flow[:, :, :-1, :]).mean()
                loss_tv_tvob = loss_tv_tvob + y_tv + x_tv
N, _, iH, iW = c_paired.size()
# Intermediate flow loss
if opt.interflowloss:
layers_max_idx = len(flow_list_tvob) - 1
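            # Also supervise the intermediate (coarser) TVOB flows, upsampled to image
            # resolution; layer i is down-weighted by 1 / 2**(layers_max_idx - i).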
for i in range(len(flow_list_tvob) - 1):
flow = flow_list_tvob[i]
N, fH, fW, _ = flow.size()
grid = mkgrid(N, iH, iW)
grid_3d = mkgrid_3d(N, iH, iW)
flow = F.interpolate(flow.permute(0, 3, 1, 2), size=c_paired.shape[2:], mode=opt.upsample).permute(0, 2, 3, 1)
flow_norm = torch.cat([flow[:, :, :, 0:1] / ((fW - 1.0) / 2.0), flow[:, :, :, 1:2] / ((fH - 1.0) / 2.0)], 3)
warped_c = F.grid_sample(c_paired, flow_norm + grid, padding_mode='border')
warped_cm = F.grid_sample(cm_paired, flow_norm + grid, padding_mode='border')
warped_cm = remove_overlap(F.softmax(fake_segmap, dim=1), warped_cm)
# Eq.8 & Eq.9 of SD-VITON
loss_l1_cloth += criterionL1(warped_cm * inv_lower_clothes_mask, parse_cloth_mask * inv_lower_clothes_mask) / (2 ** (layers_max_idx - i))
loss_vgg += criterionVGG(warped_c * inv_lower_clothes_mask, im_c * inv_lower_clothes_mask) / (2 ** (layers_max_idx - i))
# loss segmentation
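        # Segmentation losses: cross-entropy against the GT parse plus an LSGAN term from
        # the discriminator, which is conditioned on both inputs concatenated with either
        # the softmaxed fake segmap or the real parse.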
# generator
CE_loss = cross_entropy2d(fake_segmap, label_onehot.transpose(0, 1)[0].long())
fake_segmap_softmax = torch.softmax(fake_segmap, 1)
pred_segmap = D(torch.cat((input1.detach(), input2.detach(), fake_segmap_softmax), dim=1))
loss_G_GAN = criterionGAN(pred_segmap, True)
# discriminator
fake_segmap_pred = D(torch.cat((input1.detach(), input2.detach(), fake_segmap_softmax.detach()), dim=1))
real_segmap_pred = D(torch.cat((input1.detach(), input2.detach(), label), dim=1))
loss_D_fake = criterionGAN(fake_segmap_pred, False)
loss_D_real = criterionGAN(real_segmap_pred, True)
# loss sum
loss_G = (10 * loss_l1_cloth + loss_vgg + opt.tvlambda_tvob * loss_tv_tvob + opt.tvlambda_taco * loss_tv_taco) + (CE_loss * opt.CElamda + loss_G_GAN * opt.GANlambda) + z_dist_loss_non_roi + z_dist_loss_roi
loss_D = loss_D_fake + loss_D_real
# step
optimizer_G.zero_grad()
loss_G.backward()
optimizer_G.step()
optimizer_D.zero_grad()
loss_D.backward()
optimizer_D.step()
# display
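        # Every 100 steps, write a side-by-side visualization strip (cloth, person image,
        # TVOB / TACO warped cloths, masks, densepose) to sample_fs_3/.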
        if step % 100 == 0:
a_0 = c_paired[0].cuda()
b_0 = im[0].cuda()
c_0 = warped_cloth_paired_tvob[0]
d_0 = warped_cloth_paired_taco[0]
e_0 = lower_clothes_mask
e_0 = torch.cat((e_0[0], e_0[0], e_0[0]), dim=0)
f_0 = densepose[0].cuda()
g_0 = clothes_no_loss_mask
g_0 = torch.cat((g_0[0], g_0[0], g_0[0]), dim=0)
h_0 = lower_clothes_mask * clothes_no_loss_mask
h_0 = torch.cat((h_0[0], h_0[0], h_0[0]), dim=0)
i_0 = inv_lower_clothes_mask
i_0 = torch.cat((i_0[0], i_0[0], i_0[0]), dim=0)
combine = torch.cat((a_0, b_0, c_0, d_0, e_0, f_0, g_0, h_0, i_0), dim=2)
cv_img = (combine.permute(1, 2, 0).detach().cpu().numpy() + 1) / 2
rgb = (cv_img * 255).astype(np.uint8)
bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
cv2.imwrite('sample_fs_3/' + str(step) + '.jpg', bgr)
if (step + 1) % opt.display_count == 0:
t = time.time() - iter_start_time
print("step: %8d, time: %.3f\nloss G: %.4f, L1_cloth loss: %.4f, VGG loss: %.4f, TV_tvob loss: %.4f, TV_taco loss: %.4f, CE: %.4f, G GAN: %.4f\nloss D: %.4f, D real: %.4f, D fake: %.4f, z_non_roi: %.4f, z_roi: %.4f"
                  % (step + 1, t, loss_G.item(), loss_l1_cloth.item(), loss_vgg.item(), loss_tv_tvob.item(), loss_tv_taco.item(), CE_loss.item(), loss_G_GAN.item(), loss_D.item(), loss_D_real.item(), loss_D_fake.item(), z_dist_loss_non_roi.item(), z_dist_loss_roi.item()), flush=True)
# Save checkpoint
if (step + 1) % opt.save_count == 0:
checkpoint = {
'step': step + 1,
'tocg_state_dict': tocg.state_dict(),
'D_state_dict': D.state_dict(),
'optimizer_G_state_dict': optimizer_G.state_dict(),
'optimizer_D_state_dict': optimizer_D.state_dict(),
}
checkpoint_path = os.path.join(opt.checkpoint_dir, opt.name, 'latest_checkpoint.pth')
torch.save(checkpoint, checkpoint_path)
print(f"Saved checkpoint at step {step + 1} to {checkpoint_path}")
def main():
opt = get_opt()
print(f"Start to train {opt.name}!")
os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids
os.makedirs('sample_fs_3', exist_ok=True)
os.makedirs(os.path.join(opt.checkpoint_dir, opt.name), exist_ok=True)
# Create train dataset & loader
train_dataset = CPDataset(opt)
train_loader = CPDataLoader(opt, train_dataset)
# Model
input1_nc = 4 # cloth + cloth-mask
input2_nc = opt.semantic_nc + 3 # parse_agnostic + densepose
    tocg = ConditionGenerator(opt, input1_nc=input1_nc, input2_nc=input2_nc, output_nc=opt.output_nc, ngf=opt.cond_G_ngf, norm_layer=nn.BatchNorm2d, num_layers=opt.cond_G_num_layers)
D = define_D(input_nc=input1_nc + input2_nc + opt.output_nc, Ddownx2=opt.Ddownx2, Ddropout=opt.Ddropout, n_layers_D=(opt.cond_G_num_layers - 2), spectral=opt.spectral, num_D=opt.num_D)
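    # D sees both condition inputs concatenated with a segmentation map (real or fake),
    # hence input1_nc + input2_nc + opt.output_nc input channels.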
# Move models to GPU
tocg.cuda()
D.cuda()
# Define optimizers
optimizer_G = torch.optim.Adam(tocg.parameters(), lr=opt.G_lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(D.parameters(), lr=opt.D_lr, betas=(0.5, 0.999))
# Load checkpoint if resuming
if opt.resume:
checkpoint_path = os.path.join(opt.checkpoint_dir, opt.name, 'latest_checkpoint.pth')
if os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path)
tocg.load_state_dict(checkpoint['tocg_state_dict'])
D.load_state_dict(checkpoint['D_state_dict'])
optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
opt.load_step = checkpoint['step']
print(f"Resuming from step {opt.load_step} using checkpoint {checkpoint_path}")
else:
print(f"No checkpoint found at {checkpoint_path}, starting from scratch")
# Train
train(opt, train_loader, tocg, D, optimizer_G, optimizer_D)
# Save final models
save_checkpoint(tocg, os.path.join(opt.checkpoint_dir, opt.name, 'tocg_final.pth'))
save_checkpoint(D, os.path.join(opt.checkpoint_dir, opt.name, 'D_final.pth'))
print(f"Finished training {opt.name}!")
if __name__ == "__main__":
main()