Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| import torchvision.transforms as T | |
| from PIL import Image | |
| from diffusers import AsymmetricAutoencoderKL | |
| import spaces | |
# Hugging Face repo that hosts the (asymmetric) VAE weights.
MODEL_ID = "AiArtLab/sdxs-1b"
# Prefer GPU when present; fall back to CPU otherwise.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 halves memory on GPU; CPU inference generally requires fp32.
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
| # ------------------------- | |
| # Load VAE | |
| # ------------------------- | |
def load_vae(model_id=MODEL_ID):
    """Load an ``AsymmetricAutoencoderKL`` from *model_id*.

    Tries the repository root first, then the ``vae`` subfolder (the usual
    layout when the VAE ships inside a full diffusion checkpoint).

    Returns:
        The VAE moved to ``DEVICE`` and switched to eval mode.

    Raises:
        RuntimeError: if both locations fail; chained to the last
            underlying loading error.
    """
    last_err = None  # always bound, even if the loop body never raises
    for subfolder in (None, "vae"):
        # Build kwargs once so both attempts share a single call site.
        kwargs = {"torch_dtype": DTYPE}
        if subfolder is not None:
            kwargs["subfolder"] = subfolder
        try:
            vae = AsymmetricAutoencoderKL.from_pretrained(model_id, **kwargs)
            vae = vae.to(DEVICE)
            vae.eval()
            print("VAE loaded on", DEVICE)
            return vae
        except Exception as e:  # broad on purpose: we retry, then surface it
            last_err = e
    # Chain the original error so the full traceback is preserved.
    raise RuntimeError(f"Failed to load VAE: {last_err}") from last_err
# Process-wide cache so the (expensive) weight load happens at most once.
_vae = None


def get_vae():
    """Return the shared VAE instance, loading it lazily on first use."""
    global _vae
    if _vae is not None:
        return _vae
    _vae = load_vae()
    return _vae
| # ------------------------- | |
| # Encode / Decode | |
| # ------------------------- | |
def encode_decode(img: Image.Image) -> Image.Image:
    """Round-trip *img* through the asymmetric VAE and return the result.

    The asymmetric decoder emits a larger image than the encoder consumed,
    which is what produces the "true-to-source" 2x upscale.

    Raises:
        gr.Error: when no image was supplied.
    """
    if img is None:
        raise gr.Error("Please upload an image")
    # NOTE(review): on ZeroGPU Spaces this handler likely needs the
    # @spaces.GPU decorator (the module imports `spaces` but never uses
    # it) — confirm against the Space's hardware settings.
    vae = get_vae()
    img = img.convert("RGB")

    # The VAE downsamples by a factor of 8; crop to a multiple of 8 so
    # encode() cannot fail on odd sizes. Cropping (rather than resizing)
    # leaves the surviving pixels untouched.
    w, h = img.size
    w8, h8 = (w // 8) * 8, (h // 8) * 8
    if (w8, h8) != (w, h) and w8 and h8:
        img = img.crop((0, 0, w8, h8))

    tfm = T.Compose([
        T.ToTensor(),
        T.Normalize([0.5] * 3, [0.5] * 3),  # map [0, 1] -> [-1, 1]
    ])
    t = tfm(img).unsqueeze(0).to(DEVICE, dtype=DTYPE)
    print("Input tensor:", t.shape, t.dtype, t.device)

    with torch.no_grad():
        enc = vae.encode(t)
        # Use the distribution mode, not sample(): sampling injects random
        # noise on every call, contradicting the app's promise of
        # deterministic, true-to-source reconstruction.
        lat = enc.latent_dist.mode()
        print("Latents:", lat.shape)
        dec = vae.decode(lat).sample

    # [-1, 1] float CHW -> [0, 255] uint8 HWC for PIL.
    x = (dec.clamp(-1, 1) + 1) * 127.5
    x = x.round().to(torch.uint8)
    x = x.squeeze(0).permute(1, 2, 0).cpu().numpy()
    out = Image.fromarray(x)
    print("Output size:", out.size)
    return out
| # ------------------------- | |
| # UI | |
| # ------------------------- | |
# Gradio UI: one input image, one button, one output image.
with gr.Blocks(title="VAE True-to-Source Upscaler") as demo:
    gr.Markdown("# VAE-based 2x Upscaler\nSimple, blind, true-to-source upscaling without AI hallucinations.")
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(type="pil", label="Upload Image")
            upscale_btn = gr.Button("Upscale x2", variant="primary")
        with gr.Column():
            output_img = gr.Image(type="pil", label="Result (x2)")
    # Bind the button click to the VAE encode/decode round-trip.
    upscale_btn.click(
        fn=encode_decode,
        inputs=input_img,
        outputs=output_img
    )
| # ------------------------- | |
| # Launch | |
| # ------------------------- | |
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script.
    demo.launch()