# simplevae2x / app.py — Hugging Face Space (AiArtLab/sdxs-1b VAE 2x upscaler)
import gradio as gr
import torch
import torchvision.transforms as T
from PIL import Image
from diffusers import AsymmetricAutoencoderKL
import spaces
# Hugging Face repo holding the asymmetric VAE weights.
MODEL_ID = "AiArtLab/sdxs-1b"
# Run on GPU when available; fall back to CPU otherwise.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 is only safe/fast on CUDA; keep fp32 on CPU.
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
# -------------------------
# Load VAE
# -------------------------
def load_vae(model_id: str = MODEL_ID):
    """Load the AsymmetricAutoencoderKL from *model_id*.

    Tries the repo root first, then the ``vae`` subfolder (the two common
    layouts for VAE weights on the Hub). The model is moved to ``DEVICE``
    and put in eval mode before being returned.

    Raises:
        RuntimeError: if neither layout loads; chained to the last
            underlying loading error for easier debugging.
    """
    # Initialize explicitly so the final raise never hits an unbound name.
    last_err: Exception | None = None
    for subfolder in (None, "vae"):
        try:
            # Build kwargs once instead of duplicating the call per layout.
            kwargs = {"torch_dtype": DTYPE}
            if subfolder is not None:
                kwargs["subfolder"] = subfolder
            vae = AsymmetricAutoencoderKL.from_pretrained(model_id, **kwargs)
            vae = vae.to(DEVICE)
            vae.eval()
            print("VAE loaded on", DEVICE)
            return vae
        except Exception as e:
            last_err = e
    # Chain the original exception so the traceback shows the real cause.
    raise RuntimeError(f"Failed to load VAE: {last_err}") from last_err
# Process-wide cached VAE instance; populated lazily by get_vae().
_vae = None


def get_vae():
    """Return the shared VAE, loading it the first time it is requested."""
    global _vae
    if _vae is not None:
        return _vae
    _vae = load_vae()
    return _vae
# -------------------------
# Encode / Decode
# -------------------------
@spaces.GPU(duration=60)
def encode_decode(img: Image.Image) -> Image.Image:
    """Round-trip *img* through the asymmetric VAE and return the result.

    The asymmetric decoder reconstructs at 2x the input resolution, which
    is what makes this a "blind" upscaler.

    Args:
        img: Input image from the Gradio widget (``None`` if nothing was
            uploaded).

    Returns:
        The decoded PIL image.

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload an image")
    vae = get_vae()
    img = img.convert("RGB")
    # Map pixels from [0, 1] to the [-1, 1] range the VAE expects.
    tfm = T.Compose([
        T.ToTensor(),
        T.Normalize([0.5] * 3, [0.5] * 3),
    ])
    t = tfm(img).unsqueeze(0).to(DEVICE, dtype=DTYPE)
    print("Input tensor:", t.shape, t.dtype, t.device)
    with torch.no_grad():
        enc = vae.encode(t)
        # Use the posterior mode (mean) instead of sample(): sampling injects
        # random Gaussian noise into the latents, making each run of this
        # "true-to-source" upscaler nondeterministic. mode() is the standard
        # deterministic choice for reconstruction.
        lat = enc.latent_dist.mode()
        print("Latents:", lat.shape)
        dec = vae.decode(lat).sample
    # Map back from [-1, 1] to uint8 [0, 255], CHW -> HWC for PIL.
    x = (dec.clamp(-1, 1) + 1) * 127.5
    x = x.round().to(torch.uint8)
    x = x.squeeze(0).permute(1, 2, 0).cpu().numpy()
    out = Image.fromarray(x)
    print("Output size:", out.size)
    return out
# -------------------------
# UI
# -------------------------
# Build the Gradio UI: an input image column, an upscale button, and an
# output column showing the decoded (2x) result.
with gr.Blocks(title="VAE True-to-Source Upscaler") as demo:
    gr.Markdown("# VAE-based 2x Upscaler\nSimple, blind, true-to-source upscaling without AI hallucinations.")
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(type="pil", label="Upload Image")
            upscale_btn = gr.Button("Upscale x2", variant="primary")
        with gr.Column():
            output_img = gr.Image(type="pil", label="Result (x2)")
    # Wire the button to the encode/decode round-trip.
    upscale_btn.click(
        fn=encode_decode,
        inputs=input_img,
        outputs=output_img
    )
# -------------------------
# Launch
# -------------------------
# Start the Gradio server when the file is run directly (Spaces entry point).
if __name__ == "__main__":
    demo.launch()