prithivMLmods committed on
Commit
1f756dc
·
verified ·
1 Parent(s): 34ac90e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -271
app.py CHANGED
@@ -1,288 +1,39 @@
1
- import spaces
2
- import gradio as gr
3
  import torch
4
- from PIL import Image
5
- from diffusers import FluxPipeline # Changed from DiffusionPipeline for better compatibility
6
- import random
7
- import uuid
8
- from typing import Tuple
9
- import numpy as np
10
- import time
11
- import zipfile
12
-
13
- # --- Pruna Imports ---
14
  from pruna import SmashConfig, smash
15
 
16
- DESCRIPTION = """## flux realism
17
- """
18
-
19
- def save_image(img):
20
- unique_name = str(uuid.uuid4()) + ".png"
21
- img.save(unique_name)
22
- return unique_name
23
-
24
- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
25
- if randomize_seed:
26
- seed = random.randint(0, MAX_SEED)
27
- return seed
28
 
29
- MAX_SEED = np.iinfo(np.int32).max
30
-
31
- # --- Model and Pipeline Setup ---
32
- base_model = "black-forest-labs/FLUX.1-dev"
33
- # Use FluxPipeline directly and move to CUDA before applying optimizations
34
- pipe = FluxPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
35
-
36
- lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
37
- trigger_word = "Super Realism"
38
-
39
- pipe.load_lora_weights(lora_repo)
40
- pipe.to("cuda")
41
-
42
- # --- Pruna Optimization ---
43
- print("Applying Pruna optimizations...")
44
  smash_config = SmashConfig()
45
  smash_config["cacher"] = "fora"
46
  smash_config["fora_interval"] = 3 # or 2 for even faster inference
47
  smash_config["compiler"] = "torch_compile"
48
  smash_config["torch_compile_mode"] = "max-autotune-no-cudagraphs"
49
  smash_config["quantizer"] = "torchao"
50
- smash_config["torchao_quant_type"] = "int8dq" # you can also try fp8dq
51
  smash_config["torchao_excluded_modules"] = "norm+embedding"
52
 
53
- # Apply smash to the pipeline
54
  smashed_pipe = smash(pipe, smash_config)
55
- print("Pruna optimizations applied successfully.")
56
-
57
-
58
- style_list = [
59
- {
60
- "name": "3840 x 2160",
61
- "prompt": "hyper-realistic 8K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
62
- "negative_prompt": "",
63
- },
64
- {
65
- "name": "2560 x 1440",
66
- "prompt": "hyper-realistic 4K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
67
- "negative_prompt": "",
68
- },
69
- {
70
- "name": "HD+",
71
- "prompt": "hyper-realistic 2K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
72
- "negative_prompt": "",
73
- },
74
- {
75
- "name": "Style Zero",
76
- "prompt": "{prompt}",
77
- "negative_prompt": "",
78
- },
79
- ]
80
-
81
- styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
82
- DEFAULT_STYLE_NAME = "3840 x 2160"
83
- STYLE_NAMES = list(styles.keys())
84
-
85
- def apply_style(style_name: str, positive: str) -> Tuple[str, str]:
86
- p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
87
- return p.replace("{prompt}", positive), n
88
-
89
- @spaces.GPU
90
- def generate(
91
- prompt: str,
92
- negative_prompt: str = "",
93
- use_negative_prompt: bool = False,
94
- seed: int = 0,
95
- width: int = 1024,
96
- height: int = 1024,
97
- guidance_scale: float = 3,
98
- randomize_seed: bool = False,
99
- style_name: str = DEFAULT_STYLE_NAME,
100
- num_inference_steps: int = 20, # Default value updated for faster inference
101
- num_images: int = 1,
102
- zip_images: bool = False,
103
- progress=gr.Progress(track_tqdm=True),
104
- ):
105
- positive_prompt, style_negative_prompt = apply_style(style_name, prompt)
106
-
107
- if use_negative_prompt:
108
- final_negative_prompt = style_negative_prompt + " " + negative_prompt
109
- else:
110
- final_negative_prompt = style_negative_prompt
111
-
112
- final_negative_prompt = final_negative_prompt.strip()
113
-
114
- if trigger_word:
115
- positive_prompt = f"{trigger_word} {positive_prompt}"
116
-
117
- seed = int(randomize_seed_fn(seed, randomize_seed))
118
- generator = torch.Generator(device="cuda").manual_seed(seed)
119
-
120
- start_time = time.time()
121
-
122
- # --- Use the smashed_pipe for generation ---
123
- images = smashed_pipe(
124
- prompt=positive_prompt,
125
- negative_prompt=final_negative_prompt if final_negative_prompt else None,
126
- width=width,
127
- height=height,
128
- guidance_scale=guidance_scale,
129
- num_inference_steps=num_inference_steps,
130
- num_images_per_prompt=num_images,
131
- generator=generator,
132
- output_type="pil",
133
- ).images
134
-
135
- end_time = time.time()
136
- duration = end_time - start_time
137
-
138
- image_paths = [save_image(img) for img in images]
139
-
140
- zip_path = None
141
- if zip_images:
142
- zip_name = str(uuid.uuid4()) + ".zip"
143
- with zipfile.ZipFile(zip_name, 'w') as zipf:
144
- for i, img_path in enumerate(image_paths):
145
- zipf.write(img_path, arcname=f"Img_{i}.png")
146
- zip_path = zip_name
147
-
148
- return image_paths, seed, f"{duration:.2f}", zip_path
149
-
150
- examples = [
151
- "Super Realism, High-resolution photograph, woman, UHD, photorealistic, shot on a Sony A7III --chaos 20 --ar 1:2 --style raw --stylize 250",
152
- "Woman in a red jacket, snowy, in the style of hyper-realistic portraiture, caninecore, mountainous vistas, timeless beauty, palewave, iconic, distinctive noses --ar 72:101 --stylize 750 --v 6",
153
- "Super Realism, Headshot of handsome young man, wearing dark gray sweater with buttons and big shawl collar, brown hair and short beard, serious look on his face, black background, soft studio lighting, portrait photography --ar 85:128 --v 6.0 --style",
154
- "Super-realism, Purple Dreamy, a medium-angle shot of a young woman with long brown hair, wearing a pair of eye-level glasses, stands in front of a backdrop of purple and white lights. The womans eyes are closed, her lips are slightly parted, as if she is looking up at the sky. Her hair is cascading over her shoulders, framing her face. She is wearing a sleeveless top, adorned with tiny white dots, and a gold chain necklace around her neck. Her left earrings are dangling from her ears, adding a pop of color to the scene."
155
- ]
156
-
157
- css = '''
158
- .gradio-container {
159
- max-width: 590px !important;
160
- margin: 0 auto !important;
161
- }
162
- h1 {
163
- text-align: center;
164
- }
165
- footer {
166
- visibility: hidden;
167
- }
168
- '''
169
-
170
- with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
171
- gr.Markdown(DESCRIPTION)
172
- with gr.Row():
173
- prompt = gr.Text(
174
- label="Prompt",
175
- show_label=False,
176
- max_lines=1,
177
- placeholder="Enter your prompt",
178
- container=False,
179
- )
180
- run_button = gr.Button("Run", scale=0, variant="primary")
181
- result = gr.Gallery(label="Result", columns=1, show_label=False, preview=True)
182
-
183
- with gr.Accordion("Additional Options", open=False):
184
- style_selection = gr.Dropdown(
185
- label="Quality Style",
186
- choices=STYLE_NAMES,
187
- value=DEFAULT_STYLE_NAME,
188
- interactive=True,
189
- )
190
- use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False)
191
- negative_prompt = gr.Text(
192
- label="Negative prompt",
193
- max_lines=1,
194
- placeholder="Enter a negative prompt",
195
- visible=False,
196
- )
197
- seed = gr.Slider(
198
- label="Seed",
199
- minimum=0,
200
- maximum=MAX_SEED,
201
- step=1,
202
- value=0,
203
- )
204
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
205
- with gr.Row():
206
- width = gr.Slider(
207
- label="Width",
208
- minimum=512,
209
- maximum=2048,
210
- step=64,
211
- value=1024,
212
- )
213
- height = gr.Slider(
214
- label="Height",
215
- minimum=512,
216
- maximum=2048,
217
- step=64,
218
- value=1024,
219
- )
220
- guidance_scale = gr.Slider(
221
- label="Guidance Scale",
222
- minimum=0.1,
223
- maximum=20.0,
224
- step=0.1,
225
- value=3.0,
226
- )
227
- num_inference_steps = gr.Slider(
228
- label="Number of inference steps",
229
- minimum=1,
230
- maximum=40,
231
- step=1,
232
- value=20, # Default value lowered for optimized performance
233
- )
234
- num_images = gr.Slider(
235
- label="Number of images",
236
- minimum=1,
237
- maximum=5,
238
- step=1,
239
- value=1,
240
- )
241
- zip_images = gr.Checkbox(label="Zip generated images", value=False)
242
-
243
- gr.Markdown("### Output Information")
244
- seed_display = gr.Textbox(label="Seed used", interactive=False)
245
- generation_time = gr.Textbox(label="Generation time (seconds)", interactive=False)
246
- zip_file = gr.File(label="Download ZIP")
247
-
248
- gr.Examples(
249
- examples=examples,
250
- inputs=prompt,
251
- outputs=[result, seed_display, generation_time, zip_file],
252
- fn=generate,
253
- cache_examples=False,
254
- )
255
 
256
- use_negative_prompt.change(
257
- fn=lambda x: gr.update(visible=x),
258
- inputs=use_negative_prompt,
259
- outputs=negative_prompt,
260
- api_name=False,
261
- )
262
 
263
- gr.on(
264
- triggers=[
265
- prompt.submit,
266
- run_button.click,
267
- ],
268
- fn=generate,
269
- inputs=[
270
- prompt,
271
- negative_prompt,
272
- use_negative_prompt,
273
- seed,
274
- width,
275
- height,
276
- guidance_scale,
277
- randomize_seed,
278
- style_selection,
279
- num_inference_steps,
280
- num_images,
281
- zip_images,
282
- ],
283
- outputs=[result, seed_display, generation_time, zip_file],
284
- api_name="run",
285
- )
286
 
287
  if __name__ == "__main__":
288
- demo.queue(max_size=120).launch(mcp_server=True, ssr_mode=False, show_error=True)
 
 
 
1
  import torch
2
+ import gradio as gr
3
+ from diffusers import FluxPipeline
 
 
 
 
 
 
 
 
4
  from pruna import SmashConfig, smash
5
 
6
# Load pipeline: FLUX.1-dev text-to-image model in bfloat16 (halves memory
# vs. fp32 while keeping enough dynamic range for diffusion weights).
# NOTE(review): .to("cuda") at import time assumes a GPU is available —
# this raises on CPU-only hosts; confirm the deployment target.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16
).to("cuda")
 
 
 
 
 
 
 
11
 
12
# Smash optimization: configure Pruna, then wrap the pipeline with the
# selected speed-ups (caching + compilation + quantization).
smash_config = SmashConfig()
# "fora" cacher — presumably reuses intermediate transformer outputs across
# denoising steps; every `fora_interval`-th step is computed fully.
# TODO(review): confirm exact semantics against the Pruna docs.
smash_config["cacher"] = "fora"
smash_config["fora_interval"] = 3  # or 2 for even faster inference
smash_config["compiler"] = "torch_compile"
smash_config["torch_compile_mode"] = "max-autotune-no-cudagraphs"
smash_config["quantizer"] = "torchao"
smash_config["torchao_quant_type"] = "int8dq"  # you can also try fp8dq
# Exclude norm and embedding modules from quantization — presumably because
# they are numerically sensitive; verify against Pruna/TorchAO guidance.
smash_config["torchao_excluded_modules"] = "norm+embedding"

# Apply the configured optimizations; `smashed_pipe` is the object all
# inference must go through from here on.
smashed_pipe = smash(pipe, smash_config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
# Inference function
def generate_image(prompt: str):
    """Run the optimized FLUX pipeline on *prompt* and return the first image."""
    result = smashed_pipe(prompt)
    return result.images[0]
 
 
28
 
29
# Gradio UI: single prompt textbox in, single generated image out,
# wired directly to generate_image.
demo = gr.Interface(
    fn=generate_image,
    inputs=gr.Textbox(label="Enter your prompt", placeholder="e.g. a knitted purple prune"),
    outputs=gr.Image(label="Generated Image"),
    title="FLUX.1-dev with Pruna Smash ⚡",
    description="Optimized inference with Fora caching, Torch Compile, and TorchAO quantization."
)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
# Script entry point: launch the Gradio server.
# NOTE(review): share=True requests a public gradio.live tunnel — likely
# unnecessary (and ignored) when hosted on Hugging Face Spaces; confirm.
if __name__ == "__main__":
    demo.launch(share=True)