techfreakworm commited on
Commit
5d4511a
·
unverified ·
1 Parent(s): 14b904e

feat(app): gradio blocks entrypoint with bootstrap + event wiring + js shim

Browse files
Files changed (1) hide show
  1. app.py +203 -0
app.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """z-image-studio — Gradio entrypoint.
2
+
3
+ On HF Spaces, ``_bootstrap`` runs once on import to mirror the read-only preload
4
+ cache into a writable tree.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import os
9
+ import random
10
+ from pathlib import Path
11
+ from typing import Any
12
+
13
+ import gradio as gr
14
+
15
+ import backend
16
+ import lora as lora_mod # avoid shadowing the gr.File `lora_path` name
17
+ import models
18
+ import theme
19
+ import ui
20
+
21
+
22
+ # ----- HF Spaces bootstrap ---------------------------------------------------
23
+
24
+ def _bootstrap() -> None:
25
+ """Mirror the preload_from_hub cache once, then point HF env at the mirror."""
26
+ if not models.on_spaces():
27
+ return
28
+ src = Path(os.environ.get("HF_HOME", str(Path.home() / ".cache" / "huggingface")))
29
+ dst = Path.home() / "hf-cache-rw"
30
+ models.mirror_preload_hf_cache(src, dst)
31
+ os.environ["HF_HOME"] = str(dst)
32
+ os.environ["HF_HUB_CACHE"] = str(dst / "hub")
33
+
34
+
35
+ _bootstrap()
36
+
37
+
38
+ # ----- Eager backend boot ----------------------------------------------------
39
+
40
+ _BACKEND: backend.ZImageStudioBackend | None = None
41
+
42
+
43
+ def get_backend() -> backend.ZImageStudioBackend:
44
+ global _BACKEND
45
+ if _BACKEND is None:
46
+ _BACKEND = backend.ZImageStudioBackend()
47
+ return _BACKEND
48
+
49
+
50
+ # ----- Generation event handlers --------------------------------------------
51
+
52
+ def _maybe_random_seed(seed: int) -> int:
53
+ return seed if seed and seed > 0 else random.randint(1, 2_147_483_647)
54
+
55
+
56
+ def _coerce_lora(lora_path: str | None) -> Path | None:
57
+ if not lora_path:
58
+ return None
59
+ p = Path(lora_path)
60
+ lora_mod.sniff(p) # validate cheaply; raises LoRAValidationError if bad
61
+ return p
62
+
63
+
64
+ def _esrgan_path() -> str:
65
+ """Locate the preloaded RealESRGAN_x4plus.pth."""
66
+ from huggingface_hub import hf_hub_download
67
+ return hf_hub_download("xinntao/Real-ESRGAN", "RealESRGAN_x4plus.pth")
68
+
69
+
70
+ def on_t2i_generate(prompt, negative_prompt, model, steps, cfg,
71
+ width, height, seed, lora_path, lora_strength):
72
+ try:
73
+ lora_p = _coerce_lora(lora_path)
74
+ except lora_mod.LoRAValidationError as e:
75
+ raise gr.Error(str(e)) from e
76
+
77
+ params = dict(
78
+ prompt=prompt, negative_prompt=negative_prompt or "",
79
+ model=model, steps=int(steps), cfg=float(cfg),
80
+ width=int(width), height=int(height),
81
+ seed=_maybe_random_seed(int(seed)),
82
+ lora_path=lora_p, lora_strength=float(lora_strength),
83
+ )
84
+ image, meta = get_backend().generate(mode="t2i", params=params)
85
+ return image, meta
86
+
87
+
88
+ def on_controlnet_generate(prompt, input_image, preprocessor, controlnet_scale,
89
+ steps, seed, lora_path, lora_strength):
90
+ try:
91
+ lora_p = _coerce_lora(lora_path)
92
+ except lora_mod.LoRAValidationError as e:
93
+ raise gr.Error(str(e)) from e
94
+
95
+ params = dict(
96
+ prompt=prompt, input_image=input_image,
97
+ preprocessor=preprocessor, controlnet_scale=float(controlnet_scale),
98
+ steps=int(steps), seed=_maybe_random_seed(int(seed)),
99
+ lora_path=lora_p, lora_strength=float(lora_strength),
100
+ )
101
+ image, meta = get_backend().generate(mode="controlnet", params=params)
102
+ return image, meta
103
+
104
+
105
+ def on_upscale_generate(prompt, input_image, refine_steps, refine_denoise,
106
+ seed, lora_path, lora_strength):
107
+ try:
108
+ lora_p = _coerce_lora(lora_path)
109
+ except lora_mod.LoRAValidationError as e:
110
+ raise gr.Error(str(e)) from e
111
+
112
+ params = dict(
113
+ prompt=prompt or "masterpiece, 8k",
114
+ input_image=input_image,
115
+ refine_steps=int(refine_steps),
116
+ refine_denoise=float(refine_denoise),
117
+ seed=_maybe_random_seed(int(seed)),
118
+ lora_path=lora_p, lora_strength=float(lora_strength),
119
+ esrgan_model_path=_esrgan_path(),
120
+ )
121
+ image, meta = get_backend().generate(mode="upscale", params=params)
122
+ return image, meta
123
+
124
+
125
+ # ----- Blocks ----------------------------------------------------------------
126
+
127
+ HEADER_HTML = """
128
+ <div style="display:flex;justify-content:space-between;align-items:baseline;padding:8px 0 4px 0;">
129
+ <div style="font-family:'Geist',sans-serif;font-size:16px;font-weight:600;letter-spacing:-0.02em;">
130
+ z<span style="color:#FFB02E;">·</span>image studio
131
+ </div>
132
+ <div class="zis-status">ready</div>
133
+ </div>
134
+ """.strip()
135
+
136
+
137
+ _HEAD_JS = """
138
+ <script>
139
+ window.zis = {
140
+ setModel: function(name) {
141
+ document.querySelectorAll('.zis-model').forEach(el => {
142
+ el.classList.toggle('on', el.dataset.value === name);
143
+ });
144
+ const hidden = document.querySelector('#zis-model-state textarea, #zis-model-state input');
145
+ if (hidden) {
146
+ hidden.value = name;
147
+ hidden.dispatchEvent(new Event('input', { bubbles: true }));
148
+ }
149
+ }
150
+ };
151
+ // Tap-to-pin tooltips on mobile
152
+ document.addEventListener('touchstart', function(e) {
153
+ const tip = e.target.closest('.zis-info');
154
+ document.querySelectorAll('.zis-info.shown').forEach(el => {
155
+ if (el !== tip) el.classList.remove('shown');
156
+ });
157
+ if (tip) tip.classList.toggle('shown');
158
+ }, { passive: true });
159
+ </script>
160
+ """.strip()
161
+
162
+
163
+ def build_app() -> gr.Blocks:
164
+ with gr.Blocks(theme=theme.build_theme(), css=theme.CSS, head=_HEAD_JS, title="z-image-studio") as demo:
165
+ gr.HTML(HEADER_HTML)
166
+
167
+ with gr.Tabs():
168
+ with gr.Tab("Text → Image"):
169
+ t = ui.build_t2i_tab()
170
+ t["generate_btn"].click(
171
+ fn=on_t2i_generate,
172
+ inputs=[t["prompt"], t["negative_prompt"], t["model_state"],
173
+ t["steps"], t["cfg"], t["width"], t["height"], t["seed"],
174
+ t["lora_path"], t["lora_strength"]],
175
+ outputs=[t["output_image"], t["output_meta"]],
176
+ )
177
+
178
+ with gr.Tab("ControlNet"):
179
+ c = ui.build_controlnet_tab()
180
+ c["generate_btn"].click(
181
+ fn=on_controlnet_generate,
182
+ inputs=[c["prompt"], c["input_image"],
183
+ c["preprocessor"], c["controlnet_scale"],
184
+ c["steps"], c["seed"], c["lora_path"], c["lora_strength"]],
185
+ outputs=[c["output_image"], c["output_meta"]],
186
+ )
187
+
188
+ with gr.Tab("Upscale"):
189
+ u = ui.build_upscale_tab()
190
+ u["generate_btn"].click(
191
+ fn=on_upscale_generate,
192
+ inputs=[u["prompt"], u["input_image"],
193
+ u["refine_steps"], u["refine_denoise"],
194
+ u["seed"], u["lora_path"], u["lora_strength"]],
195
+ outputs=[u["output_image"], u["output_meta"]],
196
+ )
197
+ return demo
198
+
199
+
200
+ if __name__ == "__main__":
201
+ demo = build_app()
202
+ demo.queue(default_concurrency_limit=1)
203
+ demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))