prithivMLmods committed on
Commit 7798d98 · verified · 1 Parent(s): eb9cd4d

Delete app.py

Files changed (1)
  1. app.py +0 -39
app.py DELETED
@@ -1,39 +0,0 @@
-import torch
-import gradio as gr
-from diffusers import FluxPipeline
-from pruna import SmashConfig, smash
-
-# Load pipeline
-pipe = FluxPipeline.from_pretrained(
-    "black-forest-labs/FLUX.1-dev",
-    torch_dtype=torch.bfloat16
-).to("cuda")
-
-# Smash optimization
-smash_config = SmashConfig()
-smash_config["cacher"] = "fora"
-smash_config["fora_interval"] = 3  # or 2 for even faster inference
-smash_config["compiler"] = "torch_compile"
-smash_config["torch_compile_mode"] = "max-autotune-no-cudagraphs"
-smash_config["quantizer"] = "torchao"
-smash_config["torchao_quant_type"] = "int8dq"  # you can also try fp8dq
-smash_config["torchao_excluded_modules"] = "norm+embedding"
-
-smashed_pipe = smash(pipe, smash_config)
-
-# Inference function
-def generate_image(prompt: str):
-    image = smashed_pipe(prompt).images[0]
-    return image
-
-# Gradio UI
-demo = gr.Interface(
-    fn=generate_image,
-    inputs=gr.Textbox(label="Enter your prompt", placeholder="e.g. a knitted purple prune"),
-    outputs=gr.Image(label="Generated Image"),
-    title="FLUX.1-dev with Pruna Smash ⚡",
-    description="Optimized inference with Fora caching, Torch Compile, and TorchAO quantization."
-)
-
-if __name__ == "__main__":
-    demo.launch(share=True)