# NOTE: "Spaces: Runtime error" lines below the original header are HuggingFace
# Space status residue from scraping, not part of the program.
import argparse
import os
import shutil
import time
from collections import OrderedDict

import gradio as gr
import torch
import torch.nn as nn
import torchgeometry as tgm
import torchvision.transforms.functional as F
from PIL import Image
from torchvision.utils import save_image

from cp_dataset_test import CPDatasetTest, CPDataLoader
from network_generator import SPADEGenerator
from networks import ConditionGenerator, load_checkpoint, make_grid, make_grid_3d, get_val
from utils import *
def get_opt():
    """Assemble the default inference options.

    An empty argv is parsed deliberately so the defaults below always win,
    regardless of how the hosting process (e.g. the Gradio runtime) was
    launched.
    """
    parser = argparse.ArgumentParser()
    option_table = [
        # --- runtime / paths ---
        ("--gpu_ids", dict(default="")),
        ("--test_name", dict(type=str, default="test")),
        ("--dataroot", dict(default="./data")),
        ("--output_dir", dict(type=str, default="./output")),
        ("--checkpoint_dir", dict(type=str, default="checkpoints")),
        ("--tocg_checkpoint", dict(type=str, default="./checkpoints/tocg.pth")),
        ("--gen_checkpoint", dict(type=str, default="./checkpoints/gen_step_110000.pth")),
        ("--use_gradio", dict(action="store_true", default=True)),
        # --- image geometry ---
        ("--fine_width", dict(type=int, default=768)),
        ("--fine_height", dict(type=int, default=1024)),
        ("--cond_G_input_width", dict(type=int, default=192)),
        ("--cond_G_input_height", dict(type=int, default=256)),
        # --- network hyper-parameters ---
        ("--cond_G_ngf", dict(type=int, default=96)),
        ("--cond_G_num_layers", dict(type=int, default=5)),
        ("--norm_G", dict(type=str, default="spectralaliasinstance")),
        ("--ngf", dict(type=int, default=64)),
        ("--init_type", dict(type=str, default="xavier")),
        ("--init_variance", dict(type=float, default=0.02)),
        ("--semantic_nc", dict(type=int, default=13)),
        ("--output_nc", dict(type=int, default=13)),
    ]
    for flag, kwargs in option_table:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args([])
def load_checkpoint_G(model, checkpoint_path):
    """Load generator weights from *checkpoint_path* into *model*.

    Handles both raw state dicts and dicts wrapping one under
    'generator_state_dict', and translates legacy layer names to the
    current SPADEGenerator layout. Silently returns (with a message) when
    the file does not exist.

    Fixes vs. original: torch.load() without map_location and the
    unconditional model.cuda() both raise on CPU-only hosts; load to CPU
    first and only move to GPU when one is available.
    """
    if not os.path.exists(checkpoint_path):
        print(f"Checkpoint path {checkpoint_path} does not exist!")
        return
    # map_location keeps the load working on CPU-only machines; weights are
    # moved to the GPU afterwards only when CUDA is actually usable.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    state_dict = checkpoint.get('generator_state_dict', checkpoint)
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # Legacy checkpoints named the blocks 'ace' and nested a '.Spade'
        # sub-module; rename keys so they match the current module tree.
        new_key = k.replace('ace', 'alias').replace('.Spade', '')
        new_state_dict[new_key] = v
    # strict=False: tolerate keys that still don't line up after renaming.
    model.load_state_dict(new_state_dict, strict=False)
    if torch.cuda.is_available():
        model.cuda()
    print(f"Loaded checkpoint from {checkpoint_path}")
def run_single_test(opt, tocg, generator, garment_path, human_path, output_path):
    """Placeholder inference: 50/50 blend of the two inputs, saved to disk.

    NOTE(review): stand-in for the real try-on pipeline (per the original
    comment, to be replaced with the logic from test()); opt, tocg and
    generator are accepted for signature compatibility but unused here.
    """
    garment = Image.open(garment_path).convert("RGB")
    person = Image.open(human_path).convert("RGB")
    person_resized = person.resize(garment.size)
    blended = Image.blend(person_resized, garment, alpha=0.5)
    blended.save(output_path)
    print(f"Saved output to {output_path}")
def process_images_local(opt, tocg, generator, garm_img_path, human_img_path, output_dir):
    """Run one try-on pass and return the output file path, or None on failure.

    The result file name is timestamped so successive runs in the same
    directory do not overwrite each other.
    """
    os.makedirs(output_dir, exist_ok=True)
    destination = os.path.join(output_dir, f"output_{int(time.time())}.jpg")
    try:
        run_single_test(opt, tocg, generator, garm_img_path, human_img_path, destination)
    except Exception as e:
        # Best-effort: report and signal failure to the caller.
        print(f"Local inference failed: {e}")
        return None
    return destination
def gradio_interface(garm_img, human_img, opt, tocg, generator):
    """Gradio callback: stage the two uploads and run the try-on pipeline.

    Returns a (PIL image, output path, status message) triple; the first
    two entries are None whenever validation or processing fails, so the
    UI always receives three outputs.
    """
    get_val()
    print("Image processing initialized.")

    # Validate both uploads before touching the filesystem.
    if not garm_img:
        return None, None, "Error: Please upload a garment image."
    if not human_img:
        return None, None, "Error: Please upload a human image."

    # Stage the uploaded files under fixed names in the output directory.
    target_dir = opt.output_dir
    os.makedirs(target_dir, exist_ok=True)
    garm_img_path = os.path.join(target_dir, "garment.jpg")
    human_img_path = os.path.join(target_dir, "human.jpg")
    try:
        shutil.copy(garm_img.name, garm_img_path)
        shutil.copy(human_img.name, human_img_path)
        print(f"Copied images to {target_dir}")
    except Exception as e:
        return None, None, f"Error copying images: {str(e)}"

    try:
        output_path = process_images_local(opt, tocg, generator, garm_img_path, human_img_path, target_dir)
        if output_path:
            return Image.open(output_path), output_path, f"Success: Output saved to {output_path}"
        return None, None, "Error: Failed to generate output."
    except Exception as e:
        return None, None, f"Error processing images: {str(e)}"
def main():
    """Entry point: build both networks, load their checkpoints, and (when
    use_gradio is set, which it is by default) serve the Gradio demo."""
    opt = get_opt()
    print(opt)
    # Restrict visible GPUs before any CUDA work; the default "" hides all GPUs.
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids
    # Condition generator: input1 is 4 channels, input2 is semantic_nc + 3
    # channels (presumably segmentation + RGB — confirm against CPDatasetTest).
    tocg = ConditionGenerator(opt, input1_nc=4, input2_nc=opt.semantic_nc + 3, output_nc=opt.output_nc,
                              ngf=opt.cond_G_ngf, norm_layer=nn.BatchNorm2d, num_layers=opt.cond_G_num_layers)
    # Image generator fed a 9-channel (3 + 3 + 3) input.
    generator = SPADEGenerator(opt, 3 + 3 + 3)
    load_checkpoint(tocg, opt.tocg_checkpoint)        # project helper from networks
    load_checkpoint_G(generator, opt.gen_checkpoint)  # local loader with legacy-key renaming
    if opt.use_gradio:
        with gr.Blocks() as demo:
            gr.Markdown("## Virtual Fashion Fit")
            with gr.Row():
                with gr.Column():
                    garm_img = gr.File(label="Upload Garment Image", file_types=["image"])
                    garm_preview = gr.Image(label="Garment Preview")
                with gr.Column():
                    human_img = gr.File(label="Upload Human Image", file_types=["image"])
                    human_preview = gr.Image(label="Human Preview")
            submit = gr.Button("Run Try-On")
            output_image = gr.Image(label="Output Image")
            output_path = gr.Textbox(label="Output Path")
            output_text = gr.Textbox(label="Status")
            # Mirror each upload into its preview pane as soon as it changes.
            garm_img.change(lambda x: x.name if x else None, inputs=garm_img, outputs=garm_preview)
            human_img.change(lambda x: x.name if x else None, inputs=human_img, outputs=human_preview)
            # The closure binds opt and both models so the click handler keeps
            # the two-input signature Gradio expects.
            submit.click(fn=lambda garm_img, human_img: gradio_interface(garm_img, human_img, opt, tocg, generator),
                         inputs=[garm_img, human_img],
                         outputs=[output_image, output_path, output_text])
        # Bind on all interfaces on the standard HF-Spaces port.
        demo.launch(server_name="0.0.0.0", server_port=7860)
| if __name__ == "__main__": | |
| main() |