| import spaces |
| import torch |
| import gradio as gr |
| from diffusers import StableDiffusionPipeline |
| from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig |
| import os |
|
|
| |
|
|
| |
# Hub id of the Stable Diffusion 1.5 weights. The original
# "runwayml/stable-diffusion-v1-5" repository was withdrawn from the Hub;
# "stable-diffusion-v1-5/stable-diffusion-v1-5" is the maintained mirror of the
# same checkpoint.
MODEL_ID = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# Load the pipeline in fp16 with the safety checker disabled and place it on
# CUDA at module level; the @spaces.GPU functions below use this instance.
pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    safety_checker=None,
    requires_safety_checker=False,
)
pipe.to("cuda")
# NOTE: no scheduler.set_timesteps() here -- the pipeline sets the scheduler's
# timesteps itself on every call from `num_inference_steps`, so a value set at
# import time would be overwritten anyway.
|
|
print("Starting AoT Compilation...")


@spaces.GPU(duration=1500)
def compile_optimized_unet():
    """Quantize the UNet to FP8 (best effort) and ahead-of-time compile it.

    Runs once at startup inside a long ``@spaces.GPU`` allocation, because
    export plus compilation is slow.

    Returns:
        The artifact produced by ``spaces.aoti_compile`` for ``pipe.unet``,
        to be installed with ``spaces.aoti_apply``.
    """
    # 1. Best-effort FP8 quantization: if it fails, the UNet stays in fp16 and
    #    compilation still proceeds.
    try:
        quantize_(pipe.unet, Float8DynamicActivationFloat8WeightConfig())
        print("✅ Applied FP8 quantization to UNet.")
    except Exception as e:
        print(f"⚠️ FP8 Quantization failed (may require specific hardware/libraries): {e}")

    # 2. Capture the UNet call exactly as the pipeline makes it. diffusers'
    #    StableDiffusionPipeline passes a 0-d timestep and keyword arguments
    #    (encoder_hidden_states=..., return_dict=False, ...), so a graph traced
    #    from a hand-built positional call would not match at inference time.
    #    guidance_scale=7.5 mirrors generate_image: classifier-free guidance
    #    doubles the UNet batch to 2.
    with spaces.aoti_capture(pipe.unet) as call:
        pipe("arbitrary example prompt", guidance_scale=7.5)

    # 3. Export with the captured arguments. No dynamic_shapes are given, so
    #    the graph is specialized to the captured shapes; inference must use
    #    the same resolution / batch / CFG setting as the capture call.
    exported = torch.export.export(
        pipe.unet,
        args=call.args,
        kwargs=call.kwargs,
    )

    # 4. Compile the exported model ahead of time.
    return spaces.aoti_compile(exported)


# Execute compilation during startup.
compiled_unet = compile_optimized_unet()
# 5. Swap the compiled model into the pipeline's UNet component.
spaces.aoti_apply(compiled_unet, pipe.unet)

print("✅ AoT Compilation completed successfully.")
| |
# --- 2. Inference Function (Running on GPU) ---

@spaces.GPU(duration=60)  # Standard duration for image generation
def generate_image(
    prompt: str,
    negative_prompt: str,
    steps: int,
    seed: int,
):
    """Generate images for *prompt* with the AoT-optimized pipeline.

    Args:
        prompt: Text prompt; must contain non-whitespace text.
        negative_prompt: Text describing what to avoid; may be empty.
        steps: Number of denoising steps (cast to ``int`` because UI
            widgets may deliver floats).
        seed: RNG seed. ``-1`` or an empty value (``None``) means "random".
            Cast to ``int`` because ``gr.Number`` yields floats by default
            and ``torch.Generator.manual_seed`` does not accept floats.

    Returns:
        The ``.images`` list produced by the pipeline call.

    Raises:
        gr.Error: If *prompt* is empty or whitespace-only.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Prompt cannot be empty.")

    # None (cleared field) and -1 both mean "no fixed seed".
    generator = None
    if seed is not None and int(seed) != -1:
        generator = torch.Generator(device="cuda").manual_seed(int(seed))

    steps = int(steps)

    # Run inference using the optimized pipeline. guidance_scale must stay > 1
    # so the UNet batch matches the shape captured during AoT compilation.
    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=steps,
        guidance_scale=7.5,
        generator=generator,
    ).images

    return result
| |
# --- 3. Gradio Interface ---

with gr.Blocks(title="Optimized Vision Model (AoT Powered)") as demo:
    gr.HTML(
        """
        <div style="text-align: center; max-width: 800px; margin: 0 auto;">
            <h1><a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">Built with anycoder</a></h1>
            <h2>High-Performance Creative VLM Simulator (AoT Optimized)</h2>
            <p>This demo simulates a creative Vision Language Model using AoT-compiled Stable Diffusion for lightning-fast image generation.</p>
        </div>
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Prompt (Input to VLM)",
                placeholder="A futuristic city painted by Van Gogh, highly detailed.",
                lines=3,
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt (What to avoid)",
                placeholder="Blurry, bad quality, low resolution",
                lines=2,
            )

            with gr.Accordion("Generation Settings", open=True):
                steps = gr.Slider(
                    minimum=10,
                    maximum=50,
                    step=1,
                    value=30,
                    label="Inference Steps (Higher = Slower/Better)",
                )
                # precision=0 makes Gradio round and deliver an int, which is
                # what torch.Generator.manual_seed requires.
                seed = gr.Number(
                    value=-1,
                    precision=0,
                    label="Seed (-1 for random)",
                )

            generate_btn = gr.Button("Generate Image (AoT Fast!)", variant="primary")

        with gr.Column(scale=2):
            output_gallery = gr.Gallery(
                label="Creative VLM Output",
                show_label=True,
                height=512,
                columns=2,
                object_fit="contain",
            )

    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, negative_prompt, steps, seed],
        outputs=output_gallery,
    )

    # Each example supplies all four inputs of generate_image (seed -1 =
    # random), so the examples stay valid if caching / run-on-click is enabled.
    gr.Examples(
        examples=[
            ["A majestic wolf standing on a snowy mountain peak, cinematic lighting", "ugly, deformed, low detail", 30, -1],
            ["Cyberpunk cat sitting in a neon-lit alley, 8k, digital art", "human, blurry, messy background", 40, -1],
            ["A vintage photograph of a space shuttle launching from a tropical island", "modern, cartoon, painting", 25, -1],
        ],
        inputs=[prompt, negative_prompt, steps, seed],
        outputs=output_gallery,
        fn=generate_image,
        cache_examples=False,
    )
| |
# Enable request queuing and start the server; queue() returns the same Blocks
# object, so the two calls can be chained.
demo.queue().launch()