"""
app.py — OOTDiffusion Hugging Face Space

Place this file in the ROOT of your Space repo alongside the OOTDiffusion
source folders:
    ootd/, run/, preprocess/, checkpoints/

README.md front-matter required:
---
title: OOTDiffusion Virtual Try-On
emoji: 👗
colorFrom: purple
colorTo: pink
sdk: gradio
sdk_version: 4.16.0
app_file: app.py
pinned: false
license: cc-by-nc-sa-4.0
---
"""

import sys
import os
import random

# ── Path setup ────────────────────────────────────────────────────────────────
# Both the repo root and run/ are put on sys.path so that the `ootd` package
# (imported lazily in load_pipeline) can be resolved from this file's location.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
RUN_DIR = os.path.join(ROOT_DIR, "run")
sys.path.insert(0, ROOT_DIR)
sys.path.insert(0, RUN_DIR)

import torch
import numpy as np
import gradio as gr
from PIL import Image

# ── Device ────────────────────────────────────────────────────────────────────
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[OOTDiffusion] Device: {DEVICE}")

# ── Lazy-load models (loaded once on first request) ───────────────────────────
_pipe_hd = None  # VITON-HD   — half-body
_pipe_dc = None  # Dress Code — full-body


def load_pipeline(model_type: str):
    """Import and cache the requested OOTDiffusion pipeline.

    Args:
        model_type: "hd" selects the VITON-HD pipeline; any other value
            selects the Dress Code ("dc") pipeline.

    Returns:
        The cached pipeline object; it is constructed on the first call for
        each kind and reused afterwards.
    """
    global _pipe_hd, _pipe_dc

    if model_type == "hd":
        if _pipe_hd is None:
            # Imported here so the heavy model code is only loaded on demand.
            from ootd.inference_ootd_hd import OOTDiffusionHD

            print("[OOTDiffusion] Loading HD pipeline …")
            # NOTE(review): confirm the constructor's first argument is meant
            # to be the repo root path — its signature is not visible here.
            _pipe_hd = OOTDiffusionHD(ROOT_DIR)
        return _pipe_hd

    # dc
    if _pipe_dc is None:
        from ootd.inference_ootd_dc import OOTDiffusionDC

        print("[OOTDiffusion] Loading DC pipeline …")
        # NOTE(review): same constructor-argument caveat as the HD pipeline.
        _pipe_dc = OOTDiffusionDC(ROOT_DIR)
    return _pipe_dc


# ── Category mapping ──────────────────────────────────────────────────────────
CATEGORY_MAP = {
    "Upper-body": 0,
    "Lower-body": 1,
    "Dress": 2,
}

# The UI states that the category selector only applies to the "dc" model, so
# the "hd" model is always run with this category.
_HD_CATEGORY_LABEL = "Upper-body"

# Upper bound for locally generated random seeds.
_MAX_SEED = 2**31 - 1


def _to_rgb(image):
    """Return *image* as an RGB PIL image, converting numpy arrays first."""
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    return image.convert("RGB")


def _resolve_seed(seed) -> int:
    """Turn the UI seed value into a concrete integer seed.

    The UI advertises "-1 = random". How the pipeline itself treats -1 is not
    visible from this file, so the sentinel (and an emptied field, which
    arrives as None) is resolved to a fresh random seed here.
    """
    if seed is None or int(seed) == -1:
        return random.randint(0, _MAX_SEED)
    return int(seed)


# ── Main inference function ───────────────────────────────────────────────────
def run_tryon(
    model_image,
    cloth_image,
    model_type,
    category_label,
    n_samples,
    n_steps,
    guidance_scale,
    seed,
):
    """Run a virtual try-on and return the generated images for the gallery.

    Args:
        model_image: Person photo (PIL image or numpy array).
        cloth_image: Garment photo (PIL image or numpy array).
        model_type: "hd" or "dc"; anything other than "hd" is treated as "dc".
        category_label: A key of CATEGORY_MAP. Ignored when model_type is
            "hd", which always uses the upper-body category.
        n_samples: Number of images to generate.
        n_steps: Number of denoising steps.
        guidance_scale: Guidance scale forwarded to the pipeline.
        seed: Integer seed; -1 or None picks a random seed.

    Returns:
        The pipeline result if it is a list/tuple, otherwise a one-element
        list wrapping it.

    Raises:
        gr.Error: On missing inputs, an unknown category, a pipeline that
            fails to load, or a failing inference call.
    """
    if model_image is None:
        raise gr.Error("Please upload a model (person) image.")
    if cloth_image is None:
        raise gr.Error("Please upload a garment image.")

    # Convert to PIL just in case Gradio passes numpy arrays
    model_image = _to_rgb(model_image)
    cloth_image = _to_rgb(cloth_image)

    pipeline_kind = "hd" if model_type == "hd" else "dc"

    if pipeline_kind == "hd":
        # Keep the UI promise: the category dropdown does not affect "hd".
        category_idx = CATEGORY_MAP[_HD_CATEGORY_LABEL]
    else:
        try:
            category_idx = CATEGORY_MAP[category_label]
        except KeyError:
            valid = ", ".join(CATEGORY_MAP)
            raise gr.Error(
                f"Unknown garment category: {category_label!r}. "
                f"Choose one of: {valid}."
            ) from None

    try:
        pipe = load_pipeline(pipeline_kind)
    except Exception as e:
        raise gr.Error(
            f"Failed to load model: {e}\n"
            "Make sure checkpoints/ and ootd/ folders are present."
        ) from e

    try:
        # NOTE(review): the keyword names below are forwarded as originally
        # written; verify them against the pipeline's __call__ signature.
        result = pipe(
            model_type=pipeline_kind,
            category=category_idx,
            image_garm=cloth_image,
            image_vton=model_image,
            mask=None,
            image_ori=model_image,
            num_samples=int(n_samples),
            num_steps=int(n_steps),
            guidance_scale=guidance_scale,
            seed=_resolve_seed(seed),
        )
    except Exception as e:
        raise gr.Error(f"Inference failed: {e}") from e

    # result is expected to be a list of PIL Images
    if isinstance(result, (list, tuple)):
        return result
    return [result]


# ── Gradio UI ─────────────────────────────────────────────────────────────────
_HEADER_MD = """
# 👗 OOTDiffusion — Virtual Try-On
**[AAAI 2025]** Upload a *model photo* and a *garment image*, choose settings, and click **Run Try-On**.

> ⚠️ Non-commercial use only (CC-BY-NC-SA-4.0)
"""

_TIPS_MD = """
### Tips
- **HD model**: best for upper-body garments on half-body photos
- **DC model**: supports upper-body / lower-body / dress on full-body photos
- Increasing **steps** to 30–40 noticeably improves quality
- Set **seed = -1** for random results each run
"""

with gr.Blocks(title="OOTDiffusion Virtual Try-On", theme=gr.themes.Soft()) as demo:
    gr.Markdown(_HEADER_MD)

    with gr.Row():
        # ── Left column: inputs ───────────────────────────────────────────────
        with gr.Column(scale=1):
            model_img = gr.Image(
                label="Model Image (person)",
                type="pil",
                height=400,
            )
            cloth_img = gr.Image(
                label="Garment Image (clothing)",
                type="pil",
                height=400,
            )

        # ── Middle column: settings ───────────────────────────────────────────
        with gr.Column(scale=1):
            model_type = gr.Radio(
                choices=["hd", "dc"],
                value="hd",
                label="Model Type",
                info="hd = half-body (VITON-HD) | dc = full-body (Dress Code)",
            )
            category = gr.Dropdown(
                choices=list(CATEGORY_MAP.keys()),
                value="Upper-body",
                label="Garment Category",
                info="Only used when Model Type is 'dc'",
            )
            n_samples = gr.Slider(
                minimum=1,
                maximum=4,
                step=1,
                value=1,
                label="Number of Samples",
            )
            n_steps = gr.Slider(
                minimum=10,
                maximum=40,
                step=5,
                value=20,
                label="Denoising Steps",
                info="More steps = better quality but slower",
            )
            guidance_scale = gr.Slider(
                minimum=1.0,
                maximum=5.0,
                step=0.5,
                value=2.0,
                label="Guidance Scale",
            )
            seed = gr.Number(
                value=42,
                label="Seed (-1 = random)",
                precision=0,
            )
            run_btn = gr.Button("🚀 Run Try-On", variant="primary")

        # ── Right column: outputs ─────────────────────────────────────────────
        with gr.Column(scale=1):
            output_gallery = gr.Gallery(
                label="Try-On Results",
                columns=2,
                height=500,
                object_fit="contain",
            )
            gr.Markdown(_TIPS_MD)

    # ── Wire up the button ────────────────────────────────────────────────────
    run_btn.click(
        fn=run_tryon,
        inputs=[
            model_img,
            cloth_img,
            model_type,
            category,
            n_samples,
            n_steps,
            guidance_scale,
            seed,
        ],
        outputs=output_gallery,
    )

# ── Launch ────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    demo.launch()