import gradio as gr
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
from diffusers import AsymmetricAutoencoderKL
import spaces

MODEL_ID = "AiArtLab/sdxs-1b"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# Fallback spatial downsampling factor of the VAE encoder, used only when the
# loaded config does not expose ``down_block_out_channels``.
DEFAULT_DOWNSAMPLE = 8

# Maps a PIL RGB image to a float tensor (C, H, W) normalized to [-1, 1].
# Built once at import instead of on every request.
_TO_TENSOR = T.Compose([
    T.ToTensor(),
    T.Normalize([0.5] * 3, [0.5] * 3),
])


# -------------------------
# Load VAE
# -------------------------
def load_vae(model_id: str = MODEL_ID) -> AsymmetricAutoencoderKL:
    """Load the VAE for ``model_id`` onto ``DEVICE`` in eval mode.

    The weights are looked up first at the repository root, then in a
    ``vae`` subfolder; the first location that loads successfully wins.

    Raises:
        RuntimeError: if neither location can be loaded. The last underlying
            error is chained as ``__cause__`` so its traceback is preserved.
    """
    last_err = None
    for subfolder in (None, "vae"):
        kwargs = {"torch_dtype": DTYPE}
        if subfolder is not None:
            kwargs["subfolder"] = subfolder
        try:
            vae = AsymmetricAutoencoderKL.from_pretrained(model_id, **kwargs)
            vae = vae.to(DEVICE)
            vae.eval()
        except Exception as e:  # broad on purpose: fall through to next location
            last_err = e
            continue
        print("VAE loaded on", DEVICE)
        return vae
    raise RuntimeError(f"Failed to load VAE: {last_err}") from last_err


_vae = None


def get_vae() -> AsymmetricAutoencoderKL:
    """Return the module-level VAE, loading it on first use."""
    global _vae
    if _vae is None:
        _vae = load_vae()
    return _vae


# -------------------------
# Encode / Decode
# -------------------------
def _downsample_factor(vae) -> int:
    """Return the encoder's total spatial downsampling factor."""
    channels = getattr(vae.config, "down_block_out_channels", None)
    # diffusers' Encoder halves the resolution in every down block except
    # the last one.
    return 2 ** (len(channels) - 1) if channels else DEFAULT_DOWNSAMPLE


def _pad_to_multiple(t: torch.Tensor, multiple: int) -> torch.Tensor:
    """Pad the bottom/right of an NCHW tensor so H and W divide ``multiple``."""
    h, w = t.shape[-2:]
    pad_h = -h % multiple
    pad_w = -w % multiple
    if pad_h or pad_w:
        # "replicate" (unlike "reflect") is valid for any size, even 1x1.
        t = F.pad(t, (0, pad_w, 0, pad_h), mode="replicate")
    return t


@spaces.GPU(duration=60)
def encode_decode(img: Image.Image) -> Image.Image:
    """Round-trip ``img`` through the VAE and return the decoded image.

    The input is padded to a multiple of the encoder's downsampling factor so
    that no edge pixels are discarded by the strided encoder. The decoded
    result is then cropped back to the input's extent, scaled by however much
    the decoder upsamples (the UI advertises this as x2).

    Args:
        img: Uploaded image; converted to RGB (any alpha is dropped).

    Returns:
        The decoded image as a PIL RGB image.

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload an image")

    vae = get_vae()
    img = img.convert("RGB")
    src_w, src_h = img.size

    t = _TO_TENSOR(img).unsqueeze(0).to(DEVICE, dtype=DTYPE)
    print("Input tensor:", t.shape, t.dtype, t.device)
    t = _pad_to_multiple(t, _downsample_factor(vae))

    with torch.no_grad():
        enc = vae.encode(t)
        # Posterior mean instead of a random sample: the same input always
        # produces the same output, which suits a true-to-source upscaler.
        lat = enc.latent_dist.mode()
        print("Latents:", lat.shape)
        dec = vae.decode(lat).sample

    # Remove the padding, scaled by the decoder's output/input size ratio.
    out_h = round(src_h * dec.shape[-2] / t.shape[-2])
    out_w = round(src_w * dec.shape[-1] / t.shape[-1])
    dec = dec[..., :out_h, :out_w]

    # Quantize in float32 to avoid fp16 rounding error near 255.
    x = (dec.float().clamp(-1, 1) + 1) * 127.5
    x = x.round().to(torch.uint8)
    x = x.squeeze(0).permute(1, 2, 0).cpu().numpy()

    out = Image.fromarray(x)
    print("Output size:", out.size)
    return out


# -------------------------
# UI
# -------------------------
with gr.Blocks(title="VAE True-to-Source Upscaler") as demo:
    gr.Markdown("# VAE-based 2x Upscaler\nSimple, blind, true-to-source upscaling without AI hallucinations.")
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(type="pil", label="Upload Image")
            upscale_btn = gr.Button("Upscale x2", variant="primary")
        with gr.Column():
            output_img = gr.Image(type="pil", label="Result (x2)")

    # Bind the button to the encode/decode handler
    upscale_btn.click(
        fn=encode_decode,
        inputs=input_img,
        outputs=output_img,
    )


# -------------------------
# Launch
# -------------------------
if __name__ == "__main__":
    demo.launch()