"""Gradio demo for a virtual try-on pipeline (ConditionGenerator + SPADEGenerator)."""

import argparse
import os
import shutil
import time
from collections import OrderedDict

import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from torchvision.utils import save_image
from PIL import Image
import gradio as gr
import torchgeometry as tgm

from cp_dataset_test import CPDatasetTest, CPDataLoader
from networks import ConditionGenerator, load_checkpoint, make_grid, make_grid_3d, get_val
from network_generator import SPADEGenerator
from utils import *


def get_opt():
    """Build and return the option namespace used throughout the demo.

    ``parse_args([])`` is intentional: every value comes from the declared
    defaults, so the script also runs in environments (notebooks, servers)
    where ``sys.argv`` is not meaningful.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu_ids", default="")
    parser.add_argument('--test_name', type=str, default='test')
    parser.add_argument("--dataroot", default="./data")
    parser.add_argument("--output_dir", type=str, default="./output")
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints')
    parser.add_argument('--tocg_checkpoint', type=str, default='./checkpoints/tocg.pth')
    parser.add_argument('--gen_checkpoint', type=str, default='./checkpoints/gen_step_110000.pth')
    # NOTE(review): action='store_true' combined with default=True means this
    # flag can never be switched off from the command line; kept as-is for
    # backward compatibility.
    parser.add_argument('--use_gradio', action='store_true', default=True)
    parser.add_argument("--fine_width", type=int, default=768)
    parser.add_argument("--fine_height", type=int, default=1024)
    parser.add_argument('--cond_G_ngf', type=int, default=96)
    parser.add_argument("--cond_G_input_width", type=int, default=192)
    parser.add_argument("--cond_G_input_height", type=int, default=256)
    parser.add_argument('--cond_G_num_layers', type=int, default=5)
    parser.add_argument('--norm_G', type=str, default='spectralaliasinstance')
    parser.add_argument('--ngf', type=int, default=64)
    parser.add_argument('--init_type', type=str, default='xavier')
    parser.add_argument('--init_variance', type=float, default=0.02)
    parser.add_argument('--semantic_nc', type=int, default=13)
    parser.add_argument('--output_nc', type=int, default=13)
    opt = parser.parse_args([])
    return opt


def load_checkpoint_G(model, checkpoint_path):
    """Load SPADE-generator weights, remapping legacy layer names.

    Older checkpoints store keys containing ``ace`` / ``.Spade``; these are
    rewritten to the current ``alias`` naming before loading.  Missing keys
    are tolerated (``strict=False``).
    """
    if not os.path.exists(checkpoint_path):
        print(f"Checkpoint path {checkpoint_path} does not exist!")
        return
    # map_location='cpu' keeps loading functional on CPU-only hosts; the
    # model is moved to the GPU afterwards when one is present.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # Checkpoints may be either a wrapper dict or a bare state dict.
    state_dict = checkpoint.get('generator_state_dict', checkpoint)
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_key = k.replace('ace', 'alias').replace('.Spade', '')
        new_state_dict[new_key] = v
    model.load_state_dict(new_state_dict, strict=False)
    if torch.cuda.is_available():
        # Guarded so the demo does not crash outright on machines without CUDA.
        model.cuda()
    print(f"Loaded checkpoint from {checkpoint_path}")


def run_single_test(opt, tocg, generator, garment_path, human_path, output_path):
    """Produce a placeholder try-on result by alpha-blending the two inputs.

    This is dummy logic meant to be replaced with the real inference path
    (the ``tocg`` and ``generator`` arguments are currently unused).
    """
    garment_img = Image.open(garment_path).convert("RGB")
    human_img = Image.open(human_path).convert("RGB")
    result = Image.blend(human_img.resize(garment_img.size), garment_img, alpha=0.5)
    result.save(output_path)
    print(f"Saved output to {output_path}")


def process_images_local(opt, tocg, generator, garm_img_path, human_img_path, output_dir):
    """Run one try-on inference and return the output file path, or None on failure."""
    os.makedirs(output_dir, exist_ok=True)
    # Timestamped name avoids clobbering earlier results in the same directory.
    output_filename = os.path.join(output_dir, f"output_{int(time.time())}.jpg")
    try:
        run_single_test(opt, tocg, generator, garm_img_path, human_img_path, output_filename)
        return output_filename
    except Exception as e:
        print(f"Local inference failed: {e}")
        return None


def gradio_interface(garm_img, human_img, opt, tocg, generator):
    """Gradio callback: validate uploads, stage them, and run inference.

    Returns ``(PIL.Image | None, path | None, status message)`` matching the
    three output widgets wired up in :func:`main`.
    """
    # NOTE(review): get_val() is called with no arguments and its result is
    # discarded — presumably for a side effect in `networks`; confirm its
    # signature, as an argument-requiring helper would raise here.
    get_val()
    print("Image processing initialized.")
    if not garm_img:
        return None, None, "Error: Please upload a garment image."
    if not human_img:
        return None, None, "Error: Please upload a human image."
    target_dir = opt.output_dir
    os.makedirs(target_dir, exist_ok=True)
    garm_img_path = os.path.join(target_dir, "garment.jpg")
    human_img_path = os.path.join(target_dir, "human.jpg")
    try:
        # gr.File inputs expose the uploaded temp file via `.name`.
        shutil.copy(garm_img.name, garm_img_path)
        shutil.copy(human_img.name, human_img_path)
        print(f"Copied images to {target_dir}")
    except Exception as e:
        return None, None, f"Error copying images: {str(e)}"
    try:
        output_path = process_images_local(opt, tocg, generator, garm_img_path,
                                           human_img_path, target_dir)
        if output_path:
            return Image.open(output_path), output_path, f"Success: Output saved to {output_path}"
        else:
            return None, None, "Error: Failed to generate output."
    except Exception as e:
        return None, None, f"Error processing images: {str(e)}"


def main():
    """Build the models, load checkpoints, and launch the Gradio UI."""
    opt = get_opt()
    print(opt)
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids

    tocg = ConditionGenerator(opt, input1_nc=4, input2_nc=opt.semantic_nc + 3,
                              output_nc=opt.output_nc, ngf=opt.cond_G_ngf,
                              norm_layer=nn.BatchNorm2d,
                              num_layers=opt.cond_G_num_layers)
    generator = SPADEGenerator(opt, 3 + 3 + 3)

    load_checkpoint(tocg, opt.tocg_checkpoint)
    load_checkpoint_G(generator, opt.gen_checkpoint)

    if opt.use_gradio:
        with gr.Blocks() as demo:
            gr.Markdown("## Virtual Fashion Fit")
            with gr.Row():
                with gr.Column():
                    garm_img = gr.File(label="Upload Garment Image", file_types=["image"])
                    garm_preview = gr.Image(label="Garment Preview")
                with gr.Column():
                    human_img = gr.File(label="Upload Human Image", file_types=["image"])
                    human_preview = gr.Image(label="Human Preview")
            submit = gr.Button("Run Try-On")
            output_image = gr.Image(label="Output Image")
            output_path = gr.Textbox(label="Output Path")
            output_text = gr.Textbox(label="Status")

            # Show a preview as soon as a file is chosen.
            garm_img.change(lambda x: x.name if x else None,
                            inputs=garm_img, outputs=garm_preview)
            human_img.change(lambda x: x.name if x else None,
                             inputs=human_img, outputs=human_preview)
            # Closure binds opt/tocg/generator so the callback keeps the
            # (garm, human) -> outputs signature Gradio expects.
            submit.click(fn=lambda garm_img, human_img: gradio_interface(
                             garm_img, human_img, opt, tocg, generator),
                         inputs=[garm_img, human_img],
                         outputs=[output_image, output_path, output_text])
        demo.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    main()