# Ubaida10's picture
# Update app.py
# 553cd96 verified
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from torchvision.utils import save_image
import argparse
import os
import time
from PIL import Image
import shutil
import gradio as gr
from cp_dataset_test import CPDatasetTest, CPDataLoader
from networks import ConditionGenerator, load_checkpoint, make_grid, make_grid_3d, get_val
from network_generator import SPADEGenerator
from utils import *
import torchgeometry as tgm
from collections import OrderedDict
def get_opt(argv=None):
    """Build the run configuration for the virtual try-on demo.

    Args:
        argv: optional list of command-line style arguments to parse. When
            omitted (the historical behaviour), an empty list is parsed, so
            every option takes its default and ``sys.argv`` is never read.

    Returns:
        argparse.Namespace holding paths, model hyper-parameters and the
        ``use_gradio`` switch.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu_ids", default="")
    parser.add_argument('--test_name', type=str, default='test')
    parser.add_argument("--dataroot", default="./data")
    parser.add_argument("--output_dir", type=str, default="./output")
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints')
    parser.add_argument('--tocg_checkpoint', type=str, default='./checkpoints/tocg.pth')
    parser.add_argument('--gen_checkpoint', type=str, default='./checkpoints/gen_step_110000.pth')
    # A store_true flag whose default is already True can never turn the UI
    # off, so --no_gradio supplies the "false" switch on the same dest.
    # --use_gradio stays accepted for compatibility.
    parser.add_argument('--use_gradio', action='store_true', default=True)
    parser.add_argument('--no_gradio', dest='use_gradio', action='store_false')
    parser.add_argument("--fine_width", type=int, default=768)
    parser.add_argument("--fine_height", type=int, default=1024)
    parser.add_argument('--cond_G_ngf', type=int, default=96)
    parser.add_argument("--cond_G_input_width", type=int, default=192)
    parser.add_argument("--cond_G_input_height", type=int, default=256)
    parser.add_argument('--cond_G_num_layers', type=int, default=5)
    parser.add_argument('--norm_G', type=str, default='spectralaliasinstance')
    parser.add_argument('--ngf', type=int, default=64)
    parser.add_argument('--init_type', type=str, default='xavier')
    parser.add_argument('--init_variance', type=float, default=0.02)
    parser.add_argument('--semantic_nc', type=int, default=13)
    parser.add_argument('--output_nc', type=int, default=13)
    return parser.parse_args([] if argv is None else argv)
def load_checkpoint_G(model, checkpoint_path):
    """Load generator weights from ``checkpoint_path`` into ``model``.

    The checkpoint may be a raw state dict or a dict holding it under
    ``'generator_state_dict'``. Every key is renamed (``'ace'`` -> ``'alias'``,
    ``'.Spade'`` removed) before loading, and the state dict's per-module
    ``_metadata`` is renamed the same way and preserved.

    Args:
        model: the ``torch.nn.Module`` to populate.
        checkpoint_path: path to a file readable by ``torch.load``.

    Returns:
        None. If the path does not exist, a message is printed and nothing is
        loaded. Loading is non-strict; missing/unexpected keys are printed.
        The model is moved to CUDA only when a CUDA device is available.
    """
    if not os.path.exists(checkpoint_path):
        print(f"Checkpoint path {checkpoint_path} does not exist!")
        return

    def _rename(key):
        # Map checkpoint layer names onto the current module names.
        return key.replace('ace', 'alias').replace('.Spade', '')

    # Deserialize onto the CPU so a GPU-saved checkpoint also loads on a
    # CPU-only host; the model is moved to the GPU afterwards if possible.
    # NOTE(review): torch.load unpickles the file - only load trusted files.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    state_dict = checkpoint.get('generator_state_dict', checkpoint)
    new_state_dict = OrderedDict((_rename(k), v) for k, v in state_dict.items())
    # Copying into a fresh OrderedDict drops the _metadata attribute that
    # Module.state_dict() attaches (load_state_dict reads it), so carry it over.
    metadata = getattr(state_dict, '_metadata', None)
    if metadata is not None:
        new_state_dict._metadata = OrderedDict((_rename(k), v) for k, v in metadata.items())

    result = model.load_state_dict(new_state_dict, strict=False)
    # strict=False ignores mismatches silently; report them so a wrong or
    # partially-matching checkpoint does not go unnoticed.
    if result.missing_keys:
        print(f"Warning: {len(result.missing_keys)} missing keys in {checkpoint_path}: "
              f"{result.missing_keys[:10]}")
    if result.unexpected_keys:
        print(f"Warning: {len(result.unexpected_keys)} unexpected keys in {checkpoint_path}: "
              f"{result.unexpected_keys[:10]}")
    if torch.cuda.is_available():
        model.cuda()
    print(f"Loaded checkpoint from {checkpoint_path}")
def run_single_test(opt, tocg, generator, garment_path, human_path, output_path):
    """Produce a placeholder try-on result and save it to ``output_path``.

    This stands in for real inference: ``opt``, ``tocg`` and ``generator`` are
    accepted but unused. Both images are converted to RGB. The human image is
    resized to the garment image's size, and the two are blended 50/50.
    """
    # TODO: replace this placeholder with the actual inference logic from test().
    # Context managers close the source files; convert() returns an
    # independent in-memory copy, so it stays usable after the file closes.
    with Image.open(garment_path) as garment_src:
        garment_img = garment_src.convert("RGB")
    with Image.open(human_path) as human_src:
        human_img = human_src.convert("RGB")
    result = Image.blend(human_img.resize(garment_img.size), garment_img, alpha=0.5)
    result.save(output_path)
    print(f"Saved output to {output_path}")
def process_images_local(opt, tocg, generator, garm_img_path, human_img_path, output_dir):
    """Run the try-on step and write the result into ``output_dir``.

    Args:
        opt, tocg, generator: forwarded unchanged to ``run_single_test``.
        garm_img_path: path of the garment image.
        human_img_path: path of the person image.
        output_dir: directory for the result; created if missing.

    Returns:
        Path of the written ``.jpg`` result, or ``None`` if inference raised
        (the error and its traceback are printed).
    """
    import traceback
    import uuid

    os.makedirs(output_dir, exist_ok=True)
    # A bare second-resolution timestamp collides for two requests in the
    # same second (the later one overwrites the earlier result); the random
    # suffix keeps every result file distinct.
    output_filename = os.path.join(
        output_dir, f"output_{int(time.time())}_{uuid.uuid4().hex[:8]}.jpg"
    )
    try:
        run_single_test(opt, tocg, generator, garm_img_path, human_img_path, output_filename)
    except Exception as e:
        # Boundary handler for the UI: report and signal failure with None.
        print(f"Local inference failed: {e}")
        traceback.print_exc()
        return None
    return output_filename
def _uploaded_file_path(upload):
    """Return the filesystem path held by a Gradio ``gr.File`` value.

    Gradio 4 (``type="filepath"``, its default) delivers a plain ``str`` path;
    older releases deliver a tempfile-like object exposing ``.name``.
    """
    if isinstance(upload, (str, os.PathLike)):
        return os.fspath(upload)
    return upload.name


def gradio_interface(garm_img, human_img, opt, tocg, generator):
    """Gradio callback: copy both uploads into ``opt.output_dir`` and run try-on.

    Args:
        garm_img: garment upload from a ``gr.File`` (str path or file object).
        human_img: person upload from a ``gr.File`` (str path or file object).
        opt: options namespace; only ``opt.output_dir`` is read here.
        tocg, generator: forwarded to ``process_images_local``.

    Returns:
        ``(image, output_path, status_message)``. On any failure the first two
        items are ``None`` and the message starts with ``"Error"``.
    """
    get_val()
    print("Image processing initialized.")
    if not garm_img:
        return None, None, "Error: Please upload a garment image."
    if not human_img:
        return None, None, "Error: Please upload a human image."
    target_dir = opt.output_dir
    os.makedirs(target_dir, exist_ok=True)
    # NOTE(review): fixed input names mean concurrent requests overwrite each
    # other's inputs; switch to per-request names if requests run in parallel.
    garm_img_path = os.path.join(target_dir, "garment.jpg")
    human_img_path = os.path.join(target_dir, "human.jpg")
    try:
        shutil.copy(_uploaded_file_path(garm_img), garm_img_path)
        shutil.copy(_uploaded_file_path(human_img), human_img_path)
        print(f"Copied images to {target_dir}")
    except Exception as e:
        return None, None, f"Error copying images: {str(e)}"
    try:
        output_path = process_images_local(opt, tocg, generator, garm_img_path, human_img_path, target_dir)
        if not output_path:
            return None, None, "Error: Failed to generate output."
        # Image.open is lazy and keeps the file open; load an in-memory copy
        # and close the handle before returning the image to the UI.
        with Image.open(output_path) as result_file:
            result_img = result_file.copy()
        return result_img, output_path, f"Success: Output saved to {output_path}"
    except Exception as e:
        return None, None, f"Error processing images: {str(e)}"
def main():
    """Build both networks, load their checkpoints and serve the Gradio UI.

    Options come from ``get_opt()``. The condition generator (``tocg``) and
    the SPADE generator are constructed and loaded from
    ``opt.tocg_checkpoint`` / ``opt.gen_checkpoint``. When ``opt.use_gradio``
    is true, a Blocks UI is launched on 0.0.0.0:7860.
    """
    opt = get_opt()
    print(opt)
    # CUDA_VISIBLE_DEVICES="" hides every GPU, so the default gpu_ids of ""
    # would break any later .cuda() call; only restrict when ids are given.
    if opt.gpu_ids:
        os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids

    tocg = ConditionGenerator(opt, input1_nc=4, input2_nc=opt.semantic_nc + 3, output_nc=opt.output_nc,
                              ngf=opt.cond_G_ngf, norm_layer=nn.BatchNorm2d, num_layers=opt.cond_G_num_layers)
    generator = SPADEGenerator(opt, 3 + 3 + 3)
    load_checkpoint(tocg, opt.tocg_checkpoint)
    load_checkpoint_G(generator, opt.gen_checkpoint)
    # The networks are used for inference only: switch BatchNorm (and any
    # dropout) layers to evaluation behaviour.
    tocg.eval()
    generator.eval()

    if not opt.use_gradio:
        return

    def _preview_path(upload):
        """Map a gr.File value to a path gr.Image can display, or None."""
        if not upload:
            return None
        # Gradio 4 passes a str path; older versions pass a tempfile wrapper.
        if isinstance(upload, (str, os.PathLike)):
            return os.fspath(upload)
        return upload.name

    def _run_try_on(garment_upload, human_upload):
        """Bind the loaded models into the Gradio submit callback."""
        return gradio_interface(garment_upload, human_upload, opt, tocg, generator)

    with gr.Blocks() as demo:
        gr.Markdown("## Virtual Fashion Fit")
        with gr.Row():
            with gr.Column():
                garm_img = gr.File(label="Upload Garment Image", file_types=["image"])
                garm_preview = gr.Image(label="Garment Preview")
            with gr.Column():
                human_img = gr.File(label="Upload Human Image", file_types=["image"])
                human_preview = gr.Image(label="Human Preview")
        submit = gr.Button("Run Try-On")
        output_image = gr.Image(label="Output Image")
        output_path = gr.Textbox(label="Output Path")
        output_text = gr.Textbox(label="Status")
        garm_img.change(_preview_path, inputs=garm_img, outputs=garm_preview)
        human_img.change(_preview_path, inputs=human_img, outputs=human_preview)
        submit.click(fn=_run_try_on,
                     inputs=[garm_img, human_img],
                     outputs=[output_image, output_path, output_text])
    demo.launch(server_name="0.0.0.0", server_port=7860)
if __name__ == "__main__":
    # Start the app only when this file is run as a script, not when imported.
    main()