techfreakworm commited on
Commit
14b904e
·
unverified ·
1 Parent(s): 3b83775

feat(ui): per-tab gradio builders with labeled_label + custom model selector

Browse files
Files changed (2) hide show
  1. tests/test_ui.py +64 -1
  2. ui.py +156 -4
tests/test_ui.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import pytest
2
 
3
  import ui
@@ -41,5 +42,67 @@ def test_model_selector_html_defaults_to_turbo():
41
 
42
 
43
  def test_model_selector_html_escapes_current_value():
44
- out = ui.model_selector_html(current='<script>alert(1)</script>')
45
  assert "<script>" not in out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
  import pytest
3
 
4
  import ui
 
42
 
43
 
44
  def test_model_selector_html_escapes_current_value():
45
+ out = ui.model_selector_html(current="<script>alert(1)</script>")
46
  assert "<script>" not in out
47
+
48
+
49
+ @pytest.fixture(autouse=True)
50
+ def _blocks_ctx():
51
+ """Each builder must be called inside a gr.Blocks() context."""
52
+ with gr.Blocks():
53
+ yield
54
+
55
+
56
+ def test_build_t2i_tab_returns_components():
57
+ components = ui.build_t2i_tab()
58
+ expected = {
59
+ "prompt",
60
+ "negative_prompt",
61
+ "model_state",
62
+ "steps",
63
+ "cfg",
64
+ "width",
65
+ "height",
66
+ "seed",
67
+ "lora_path",
68
+ "lora_strength",
69
+ "generate_btn",
70
+ "output_image",
71
+ "output_meta",
72
+ }
73
+ assert expected.issubset(components.keys())
74
+
75
+
76
+ def test_build_controlnet_tab_returns_components():
77
+ components = ui.build_controlnet_tab()
78
+ expected = {
79
+ "prompt",
80
+ "input_image",
81
+ "preprocessor",
82
+ "controlnet_scale",
83
+ "steps",
84
+ "seed",
85
+ "lora_path",
86
+ "lora_strength",
87
+ "generate_btn",
88
+ "output_image",
89
+ "output_meta",
90
+ }
91
+ assert expected.issubset(components.keys())
92
+
93
+
94
+ def test_build_upscale_tab_returns_components():
95
+ components = ui.build_upscale_tab()
96
+ expected = {
97
+ "prompt",
98
+ "input_image",
99
+ "refine_steps",
100
+ "refine_denoise",
101
+ "seed",
102
+ "lora_path",
103
+ "lora_strength",
104
+ "generate_btn",
105
+ "output_image",
106
+ "output_meta",
107
+ }
108
+ assert expected.issubset(components.keys())
ui.py CHANGED
@@ -1,8 +1,14 @@
1
  """Gradio UI builders + small HTML helpers for the (i) tooltip pattern and the custom model selector."""
 
2
  from __future__ import annotations
3
 
4
  from html import escape
5
 
 
 
 
 
 
6
  GITHUB_MODEL_ZOO_URL = "https://github.com/Tongyi-MAI/Z-Image#model-zoo"
7
 
8
 
@@ -16,7 +22,7 @@ def labeled_label(text: str, info_text: str) -> str:
16
  return (
17
  f'<label class="zis-row-label">{escape(text)}'
18
  f'<span class="zis-info" data-info="{escape(info_text)}">i</span>'
19
- f'</label>'
20
  )
21
 
22
 
@@ -36,10 +42,10 @@ def model_selector_html(current: str = "Turbo") -> str:
36
  cls = "zis-model on" if name == current else "zis-model"
37
  cards.append(
38
  f'<button type="button" class="{cls}" data-value="{name}" '
39
- f'onclick="zis.setModel(\'{name}\')">'
40
  f'<span class="dot"></span>'
41
  f'<span class="name">{name}</span>'
42
- f'</button>'
43
  )
44
  for name in ("Edit", "Omni Base"):
45
  cards.append(
@@ -49,7 +55,153 @@ def model_selector_html(current: str = "Turbo") -> str:
49
  f'<span class="dot"></span>'
50
  f'<span class="name">{name}<span class="ext">↗</span></span>'
51
  f'<span class="soon-tag">soon</span>'
52
- f'</a>'
53
  )
54
  _ = current_safe # current is matched in cls above; this line keeps escape() exercised
55
  return f'<div class="zis-models">{"".join(cards)}</div>'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Gradio UI builders + small HTML helpers for the (i) tooltip pattern and the custom model selector."""
2
+
3
  from __future__ import annotations
4
 
5
  from html import escape
6
 
7
+ import gradio as gr
8
+
9
+ import preprocessors
10
+ from tooltips import TOOLTIPS
11
+
12
  GITHUB_MODEL_ZOO_URL = "https://github.com/Tongyi-MAI/Z-Image#model-zoo"
13
 
14
 
 
22
  return (
23
  f'<label class="zis-row-label">{escape(text)}'
24
  f'<span class="zis-info" data-info="{escape(info_text)}">i</span>'
25
+ f"</label>"
26
  )
27
 
28
 
 
42
  cls = "zis-model on" if name == current else "zis-model"
43
  cards.append(
44
  f'<button type="button" class="{cls}" data-value="{name}" '
45
+ f"onclick=\"zis.setModel('{name}')\">"
46
  f'<span class="dot"></span>'
47
  f'<span class="name">{name}</span>'
48
+ f"</button>"
49
  )
50
  for name in ("Edit", "Omni Base"):
51
  cards.append(
 
55
  f'<span class="dot"></span>'
56
  f'<span class="name">{name}<span class="ext">↗</span></span>'
57
  f'<span class="soon-tag">soon</span>'
58
+ f"</a>"
59
  )
60
  _ = current_safe # current is matched in cls above; this line keeps escape() exercised
61
  return f'<div class="zis-models">{"".join(cards)}</div>'
62
+
63
+
64
+ def build_t2i_tab() -> dict[str, gr.components.Component]:
65
+ with gr.Row():
66
+ with gr.Column(scale=4):
67
+ gr.HTML(labeled_label("Prompt", TOOLTIPS["prompt"]))
68
+ prompt = gr.Textbox(lines=4, show_label=False, placeholder="A latina model peeking through pine branches…")
69
+ gr.HTML(labeled_label("Negative prompt (Base only)", TOOLTIPS["negative_prompt"]))
70
+ negative_prompt = gr.Textbox(lines=2, show_label=False, placeholder="blurry, lowres, distorted")
71
+ gr.HTML(labeled_label("Model", TOOLTIPS["model"]))
72
+ model_state = gr.Textbox(value="Turbo", visible=False, elem_id="zis-model-state")
73
+ gr.HTML(model_selector_html(current="Turbo"))
74
+ with gr.Row():
75
+ with gr.Column():
76
+ gr.HTML(labeled_label("LoRA (optional)", TOOLTIPS["lora"]))
77
+ lora_path = gr.File(file_types=[".safetensors"], type="filepath", show_label=False)
78
+ with gr.Column():
79
+ gr.HTML(labeled_label("LoRA strength", TOOLTIPS["lora_strength"]))
80
+ lora_strength = gr.Slider(0.0, 1.5, value=0.8, step=0.05, show_label=False)
81
+ with gr.Row():
82
+ with gr.Column():
83
+ gr.HTML(labeled_label("Steps", TOOLTIPS["steps"]))
84
+ steps = gr.Slider(1, 50, value=8, step=1, show_label=False)
85
+ with gr.Column():
86
+ gr.HTML(labeled_label("CFG (Base only)", TOOLTIPS["cfg"]))
87
+ cfg = gr.Slider(0.5, 12.0, value=1.0, step=0.1, show_label=False)
88
+ with gr.Row():
89
+ with gr.Column():
90
+ gr.HTML(labeled_label("Width", TOOLTIPS["width"]))
91
+ width = gr.Slider(384, 1536, value=1024, step=64, show_label=False)
92
+ with gr.Column():
93
+ gr.HTML(labeled_label("Height", TOOLTIPS["height"]))
94
+ height = gr.Slider(384, 1536, value=1024, step=64, show_label=False)
95
+ with gr.Column():
96
+ gr.HTML(labeled_label("Seed (0 = random)", TOOLTIPS["seed"]))
97
+ seed = gr.Number(value=0, precision=0, show_label=False)
98
+ generate_btn = gr.Button("Generate", variant="primary")
99
+ with gr.Column(scale=5):
100
+ gr.HTML(labeled_label("Output", TOOLTIPS["output"]))
101
+ output_image = gr.Image(type="pil", height=512, show_download_button=True, show_label=False)
102
+ output_meta = gr.JSON(label="Meta", value={})
103
+ return dict(
104
+ prompt=prompt,
105
+ negative_prompt=negative_prompt,
106
+ model_state=model_state,
107
+ steps=steps,
108
+ cfg=cfg,
109
+ width=width,
110
+ height=height,
111
+ seed=seed,
112
+ lora_path=lora_path,
113
+ lora_strength=lora_strength,
114
+ generate_btn=generate_btn,
115
+ output_image=output_image,
116
+ output_meta=output_meta,
117
+ )
118
+
119
+
120
+ def build_controlnet_tab() -> dict[str, gr.components.Component]:
121
+ with gr.Row():
122
+ with gr.Column(scale=4):
123
+ gr.HTML(labeled_label("Prompt", TOOLTIPS["prompt"]))
124
+ prompt = gr.Textbox(lines=3, show_label=False)
125
+ gr.HTML(labeled_label("Control image", TOOLTIPS["controlnet_image"]))
126
+ input_image = gr.Image(type="pil", height=240, show_label=False)
127
+ with gr.Row():
128
+ with gr.Column():
129
+ gr.HTML(labeled_label("Preprocessor", TOOLTIPS["controlnet_preprocessor"]))
130
+ preprocessor = gr.Dropdown(list(preprocessors.MODES), value="Canny", show_label=False)
131
+ with gr.Column():
132
+ gr.HTML(labeled_label("ControlNet scale", TOOLTIPS["controlnet_scale"]))
133
+ controlnet_scale = gr.Slider(0.0, 2.0, value=1.0, step=0.05, show_label=False)
134
+ with gr.Row():
135
+ with gr.Column():
136
+ gr.HTML(labeled_label("LoRA (optional)", TOOLTIPS["lora"]))
137
+ lora_path = gr.File(file_types=[".safetensors"], type="filepath", show_label=False)
138
+ with gr.Column():
139
+ gr.HTML(labeled_label("LoRA strength", TOOLTIPS["lora_strength"]))
140
+ lora_strength = gr.Slider(0.0, 1.5, value=0.8, step=0.05, show_label=False)
141
+ with gr.Row():
142
+ with gr.Column():
143
+ gr.HTML(labeled_label("Steps", TOOLTIPS["steps"]))
144
+ steps = gr.Slider(1, 30, value=9, step=1, show_label=False)
145
+ with gr.Column():
146
+ gr.HTML(labeled_label("Seed (0 = random)", TOOLTIPS["seed"]))
147
+ seed = gr.Number(value=0, precision=0, show_label=False)
148
+ generate_btn = gr.Button("Generate", variant="primary")
149
+ with gr.Column(scale=5):
150
+ gr.HTML(labeled_label("Output", TOOLTIPS["output"]))
151
+ output_image = gr.Image(type="pil", height=512, show_download_button=True, show_label=False)
152
+ output_meta = gr.JSON(label="Meta", value={})
153
+ return dict(
154
+ prompt=prompt,
155
+ input_image=input_image,
156
+ preprocessor=preprocessor,
157
+ controlnet_scale=controlnet_scale,
158
+ steps=steps,
159
+ seed=seed,
160
+ lora_path=lora_path,
161
+ lora_strength=lora_strength,
162
+ generate_btn=generate_btn,
163
+ output_image=output_image,
164
+ output_meta=output_meta,
165
+ )
166
+
167
+
168
+ def build_upscale_tab() -> dict[str, gr.components.Component]:
169
+ with gr.Row():
170
+ with gr.Column(scale=4):
171
+ gr.HTML(labeled_label("Refinement prompt", TOOLTIPS["prompt"]))
172
+ prompt = gr.Textbox(value="masterpiece, 8k", lines=2, show_label=False)
173
+ gr.HTML(labeled_label("Input image", TOOLTIPS["upscale_image"]))
174
+ input_image = gr.Image(type="pil", height=240, show_label=False)
175
+ with gr.Row():
176
+ with gr.Column():
177
+ gr.HTML(labeled_label("Refine steps", TOOLTIPS["refine_steps"]))
178
+ refine_steps = gr.Slider(1, 20, value=5, step=1, show_label=False)
179
+ with gr.Column():
180
+ gr.HTML(labeled_label("Refine denoise", TOOLTIPS["refine_denoise"]))
181
+ refine_denoise = gr.Slider(0.0, 1.0, value=0.33, step=0.01, show_label=False)
182
+ with gr.Row():
183
+ with gr.Column():
184
+ gr.HTML(labeled_label("LoRA (optional)", TOOLTIPS["lora"]))
185
+ lora_path = gr.File(file_types=[".safetensors"], type="filepath", show_label=False)
186
+ with gr.Column():
187
+ gr.HTML(labeled_label("LoRA strength", TOOLTIPS["lora_strength"]))
188
+ lora_strength = gr.Slider(0.0, 1.5, value=0.8, step=0.05, show_label=False)
189
+ gr.HTML(labeled_label("Seed (0 = random)", TOOLTIPS["seed"]))
190
+ seed = gr.Number(value=0, precision=0, show_label=False)
191
+ generate_btn = gr.Button("Generate", variant="primary")
192
+ with gr.Column(scale=5):
193
+ gr.HTML(labeled_label("Output (2x upscaled)", TOOLTIPS["output"]))
194
+ output_image = gr.Image(type="pil", height=512, show_download_button=True, show_label=False)
195
+ output_meta = gr.JSON(label="Meta", value={})
196
+ return dict(
197
+ prompt=prompt,
198
+ input_image=input_image,
199
+ refine_steps=refine_steps,
200
+ refine_denoise=refine_denoise,
201
+ seed=seed,
202
+ lora_path=lora_path,
203
+ lora_strength=lora_strength,
204
+ generate_btn=generate_btn,
205
+ output_image=output_image,
206
+ output_meta=output_meta,
207
+ )