"""Gradio Space for See-through: single-image layer decomposition.

Preloads the LayerDiff and Marigold pipelines to CPU at import time, moves
them to GPU lazily on the first request (ZeroGPU pattern), and exposes a
Blocks UI that turns an uploaded illustration into a layered PSD plus a
gallery of the separated layers.
"""

import spaces
import gradio as gr
import os
import sys
import time
import tempfile
import shutil
import torch

# Make the repo root and its "common" folder importable before pulling in the
# project modules below.
_root = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, _root)
sys.path.insert(0, os.path.join(_root, "common"))

from PIL import Image

REPO_LAYERDIFF = "layerdifforg/seethroughv0.0.2_layerdiff3d"
REPO_DEPTH = "24yearsold/seethroughv0.0.1_marigold"


def _log(msg):
    """Print a timestamped, immediately-flushed log line."""
    print(f"[{time.strftime('%H:%M:%S')}] {msg}", flush=True)


# --------------- Preload models to CPU at startup ---------------
# Loading happens at import time so each request pays only the CPU->GPU
# transfer cost, not checkpoint download/deserialization.
_log("Preloading LayerDiff pipeline to CPU...")
from modules.layerdiffuse.diffusers_kdiffusion_sdxl import KDiffusionStableDiffusionXLPipeline
from modules.layerdiffuse.layerdiff3d import UNetFrameConditionModel
from modules.layerdiffuse.vae import TransparentVAE, TransparentVAEDecoder, TransparentVAEEncoder

_trans_vae = TransparentVAE.from_pretrained(REPO_LAYERDIFF, subfolder="trans_vae")
_unet_ld = UNetFrameConditionModel.from_pretrained(REPO_LAYERDIFF, subfolder="unet")
_layerdiff_pipe = KDiffusionStableDiffusionXLPipeline.from_pretrained(
    REPO_LAYERDIFF, trans_vae=_trans_vae, unet=_unet_ld, scheduler=None
)
_log("LayerDiff pipeline loaded to CPU.")

_log("Preloading Marigold pipeline to CPU...")
from modules.marigold import MarigoldDepthPipeline

_unet_mg = UNetFrameConditionModel.from_pretrained(REPO_DEPTH, subfolder="unet")
_marigold_pipe = MarigoldDepthPipeline.from_pretrained(REPO_DEPTH, unet=_unet_mg)
_log("Marigold pipeline loaded to CPU.")

# Tracks whether the one-time CPU->GPU transfer has already happened.
_models_on_gpu = False

from utils.inference_utils import apply_layerdiff, apply_marigold, further_extr
from utils.torch_utils import seed_everything
import utils.inference_utils as _inf


def _move_to_gpu():
    """Move both pipelines to CUDA in bfloat16 (idempotent).

    Also injects the pipelines into ``utils.inference_utils`` module globals
    so the ``apply_*`` helpers reuse them instead of loading their own copies.
    """
    global _models_on_gpu
    if _models_on_gpu:
        _log("Models already on GPU, skipping transfer.")
        return
    t0 = time.time()
    _log("Moving LayerDiff to CUDA bf16...")
    _layerdiff_pipe.vae.to(dtype=torch.bfloat16, device="cuda")
    _layerdiff_pipe.trans_vae.to(dtype=torch.bfloat16, device="cuda")
    _layerdiff_pipe.unet.to(dtype=torch.bfloat16, device="cuda")
    _layerdiff_pipe.text_encoder.to(dtype=torch.bfloat16, device="cuda")
    _layerdiff_pipe.text_encoder_2.to(dtype=torch.bfloat16, device="cuda")
    _log(f"LayerDiff on GPU ({time.time() - t0:.1f}s)")
    t0 = time.time()
    _log("Moving Marigold to CUDA bf16...")
    _marigold_pipe.to(device="cuda", dtype=torch.bfloat16)
    _log(f"Marigold on GPU ({time.time() - t0:.1f}s)")
    # Inject into inference_utils globals so apply_* functions skip their own loading
    _inf.layerdiff_pipeline = _layerdiff_pipe
    _inf.marigold_pipeline = _marigold_pipe
    _models_on_gpu = True


# Per-layer PNGs that should never appear in the result gallery.
_SKIP_TAGS = {"src_img", "src_head", "reconstruction"}


def _collect_layer_gallery(saved_dir):
    """Collect layer PNGs as (image, label) tuples for the gallery.

    Skips depth maps and the source/reconstruction images. Each image is
    fully loaded into memory because ``saved_dir`` lives inside a temp
    directory that inference() deletes before Gradio renders the gallery.
    """
    gallery = []
    for f in sorted(os.listdir(saved_dir)):
        if not f.endswith(".png"):
            continue
        tag = f[:-4]
        if tag.endswith("_depth") or tag in _SKIP_TAGS:
            continue
        img = Image.open(os.path.join(saved_dir, f))
        # Fix: force PIL's lazy loader to read the pixel data now; the
        # backing file is removed in inference()'s finally block, so a
        # deferred load would fail when Gradio serializes the gallery.
        img.load()
        gallery.append((img, tag))
    return gallery


@spaces.GPU(duration=120)
def inference(image: Image.Image, resolution: int = 768, seed: int = 42, tblr_split: bool = False):
    """Run the full layer-decomposition pipeline on one uploaded image.

    Args:
        image: Input illustration (non-square images are padded downstream).
        resolution: Working resolution; clamped to [64, 1280] and snapped to
            the nearest multiple of 64 for clean latent dimensions.
        seed: RNG seed for reproducible generation.
        tblr_split: Split left/right limbs into individual layers.

    Returns:
        Tuple ``(psd_path, gallery)``: path to the assembled PSD file and a
        list of ``(PIL.Image, label)`` tuples for the layer gallery.

    Raises:
        gr.Error: If no image was provided, or PSD assembly produced no file.
    """
    t_start = time.time()
    if image is None:
        raise gr.Error("Please upload an image.")
    # Snap to nearest multiple of 64 for clean latent dimensions
    resolution = max(64, min(resolution, 1280))
    resolution = round(resolution / 64) * 64
    _log(f"Resolution: {resolution}, Seed: {seed}, Image: {image.size}")
    _move_to_gpu()
    seed_everything(seed)
    tmpdir = tempfile.mkdtemp(prefix="seethrough_")
    try:
        input_path = os.path.join(tmpdir, "input.png")
        image.save(input_path)

        t0 = time.time()
        _log("Running LayerDiff...")
        apply_layerdiff(
            input_path,
            REPO_LAYERDIFF,
            save_dir=tmpdir,
            seed=seed,
            resolution=resolution,
        )
        _log(f"LayerDiff done ({time.time() - t0:.1f}s)")

        t0 = time.time()
        _log("Running Marigold depth...")
        apply_marigold(
            input_path,
            REPO_DEPTH,
            save_dir=tmpdir,
            seed=seed,
            resolution=resolution,
        )
        _log(f"Marigold done ({time.time() - t0:.1f}s)")

        saved = os.path.join(tmpdir, "input")
        # Collect gallery before PSD assembly (further_extr may modify files)
        gallery = _collect_layer_gallery(saved)

        t0 = time.time()
        _log("Running PSD assembly...")
        further_extr(saved, rotate=False, save_to_psd=True, tblr_split=tblr_split)
        _log(f"PSD assembly done ({time.time() - t0:.1f}s)")

        psd_path = saved + ".psd"
        if os.path.exists(psd_path):
            # Fix: use a unique temp file per request. A fixed filename in
            # gettempdir() lets concurrent requests clobber each other's
            # download file.
            fd, output_path = tempfile.mkstemp(prefix="seethrough_output_", suffix=".psd")
            os.close(fd)
            shutil.copy2(psd_path, output_path)
            _log(f"Total inference time: {time.time() - t_start:.1f}s")
            return output_path, gallery
        raise gr.Error("PSD generation failed — no output file produced.")
    finally:
        # Always drop the scratch directory, even on failure.
        shutil.rmtree(tmpdir, ignore_errors=True)


with gr.Blocks(title="See-through: Layer Decomposition") as demo:
    gr.Markdown(
        "# See-through: Single-image Layer Decomposition for Anime Characters\n\n"
        'GitHub | '
        'Paper (arXiv:2602.03749)\n\n'
        "Upload an anime character illustration to decompose it into "
        "fully-inpainted semantic layers with depth ordering, "
        "exported as a layered PSD file.\n\n"
        "**Note:** 768 resolution is recommended for ZeroGPU free tier. "
        "Higher resolutions may timeout or exhaust your daily quota. "
        'For best quality, clone the full repo '
        "and run `inference_psd.py` locally. "
        'We also have a ComfyUI Node Extension '
        "and other community extensions — check our repo for more details.\n\n"
        "**Disclaimer:** This demo uses a newer model checkpoint and "
        "may not fully reproduce identical results reported in the paper."
    )
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload image (non-square images will be padded)")
            resolution = gr.Slider(
                minimum=768,
                maximum=1280,
                value=768,
                step=64,
                label="Resolution",
                info="768 recommended for ZeroGPU free tier. Higher resolutions may timeout or use up your daily quota quickly.",
            )
            seed = gr.Slider(minimum=0, maximum=9999, value=42, step=1, label="Seed")
            tblr_split = gr.Checkbox(
                value=False,
                label="Split left/right arms & legs",
                info="Separate left and right limbs into individual layers. Useful if the default output glues them together.",
            )
            run_btn = gr.Button("Run", variant="primary")
        with gr.Column(scale=2):
            psd_output = gr.File(label="Download layered PSD")
            gallery_output = gr.Gallery(label="Separated layers", columns=4, height="auto")

    run_btn.click(
        fn=inference,
        inputs=[input_image, resolution, seed, tblr_split],
        outputs=[psd_output, gallery_output],
    )

    gr.Examples(
        examples=[["common/assets/test_image.png", 768, 42, False]],
        inputs=[input_image, resolution, seed, tblr_split],
    )

if __name__ == "__main__":
    demo.launch()