techfreakworm commited on
Commit
aa2a834
·
unverified ·
1 Parent(s): 0cf8ffc

feat(progress): gradio progress(track_tqdm=True) on all three handlers

Browse files

DiffSynth uses tqdm for the diffusion step loop and RealESRGAN uses it for
the tile pass. Adding gr.Progress(track_tqdm=True) as a keyword parameter
on the three on_*_generate handlers lets Gradio auto-capture both without
threading a progress object through backend/modes.

Files changed (2) hide show
  1. app.py +34 -3
  2. tests/test_modes.py +20 -6
app.py CHANGED
@@ -106,7 +106,19 @@ def _esrgan_path() -> str:
106
  return hf_hub_download("lllyasviel/Annotators", "RealESRGAN_x4plus.pth")
107
 
108
 
109
- def on_t2i_generate(prompt, negative_prompt, model, steps, cfg, width, height, seed, lora_path, lora_strength):
 
 
 
 
 
 
 
 
 
 
 
 
110
  try:
111
  lora_p = _coerce_lora(lora_path)
112
  except lora_mod.LoRAValidationError as e:
@@ -128,7 +140,17 @@ def on_t2i_generate(prompt, negative_prompt, model, steps, cfg, width, height, s
128
  return image, meta
129
 
130
 
131
- def on_controlnet_generate(prompt, input_image, preprocessor, controlnet_scale, steps, seed, lora_path, lora_strength):
 
 
 
 
 
 
 
 
 
 
132
  try:
133
  lora_p = _coerce_lora(lora_path)
134
  except lora_mod.LoRAValidationError as e:
@@ -148,7 +170,16 @@ def on_controlnet_generate(prompt, input_image, preprocessor, controlnet_scale,
148
  return image, meta
149
 
150
 
151
- def on_upscale_generate(prompt, input_image, refine_steps, refine_denoise, seed, lora_path, lora_strength):
 
 
 
 
 
 
 
 
 
152
  try:
153
  lora_p = _coerce_lora(lora_path)
154
  except lora_mod.LoRAValidationError as e:
 
106
  return hf_hub_download("lllyasviel/Annotators", "RealESRGAN_x4plus.pth")
107
 
108
 
109
+ def on_t2i_generate(
110
+ prompt,
111
+ negative_prompt,
112
+ model,
113
+ steps,
114
+ cfg,
115
+ width,
116
+ height,
117
+ seed,
118
+ lora_path,
119
+ lora_strength,
120
+ progress=gr.Progress(track_tqdm=True), # noqa: B008
121
+ ):
122
  try:
123
  lora_p = _coerce_lora(lora_path)
124
  except lora_mod.LoRAValidationError as e:
 
140
  return image, meta
141
 
142
 
143
+ def on_controlnet_generate(
144
+ prompt,
145
+ input_image,
146
+ preprocessor,
147
+ controlnet_scale,
148
+ steps,
149
+ seed,
150
+ lora_path,
151
+ lora_strength,
152
+ progress=gr.Progress(track_tqdm=True), # noqa: B008
153
+ ):
154
  try:
155
  lora_p = _coerce_lora(lora_path)
156
  except lora_mod.LoRAValidationError as e:
 
170
  return image, meta
171
 
172
 
173
+ def on_upscale_generate(
174
+ prompt,
175
+ input_image,
176
+ refine_steps,
177
+ refine_denoise,
178
+ seed,
179
+ lora_path,
180
+ lora_strength,
181
+ progress=gr.Progress(track_tqdm=True), # noqa: B008
182
+ ):
183
  try:
184
  lora_p = _coerce_lora(lora_path)
185
  except lora_mod.LoRAValidationError as e:
tests/test_modes.py CHANGED
@@ -79,9 +79,16 @@ def test_t2i_swaps_transformer_via_pool_index(fake_pipe):
79
  modes.call_t2i(
80
  fake_pipe,
81
  params=dict(
82
- prompt="x", negative_prompt="", model="Base",
83
- steps=25, cfg=4.0, width=1024, height=1024, seed=0,
84
- lora_path=None, lora_strength=0.0,
 
 
 
 
 
 
 
85
  ),
86
  )
87
  assert fake_pipe.dit is base_dit
@@ -89,9 +96,16 @@ def test_t2i_swaps_transformer_via_pool_index(fake_pipe):
89
  modes.call_t2i(
90
  fake_pipe,
91
  params=dict(
92
- prompt="x", negative_prompt="", model="Turbo",
93
- steps=8, cfg=1.0, width=1024, height=1024, seed=0,
94
- lora_path=None, lora_strength=0.0,
 
 
 
 
 
 
 
95
  ),
96
  )
97
  assert fake_pipe.dit is turbo_dit
 
79
  modes.call_t2i(
80
  fake_pipe,
81
  params=dict(
82
+ prompt="x",
83
+ negative_prompt="",
84
+ model="Base",
85
+ steps=25,
86
+ cfg=4.0,
87
+ width=1024,
88
+ height=1024,
89
+ seed=0,
90
+ lora_path=None,
91
+ lora_strength=0.0,
92
  ),
93
  )
94
  assert fake_pipe.dit is base_dit
 
96
  modes.call_t2i(
97
  fake_pipe,
98
  params=dict(
99
+ prompt="x",
100
+ negative_prompt="",
101
+ model="Turbo",
102
+ steps=8,
103
+ cfg=1.0,
104
+ width=1024,
105
+ height=1024,
106
+ seed=0,
107
+ lora_path=None,
108
+ lora_strength=0.0,
109
  ),
110
  )
111
  assert fake_pipe.dit is turbo_dit