Vector857 committed on
Commit
0ac09d7
·
verified ·
1 Parent(s): aecde72

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +1291 -0
  2. gitignore +2 -0
  3. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,1291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import random
4
+ import time
5
+ import traceback
6
+ from io import BytesIO
7
+
8
+ import gradio as gr
9
+ import requests
10
+ from PIL import Image
11
+ from dotenv import load_dotenv
12
+
13
# Configure root logging once at import time: INFO level, with the timestamp,
# logger name and level prefixed to every message.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)

# Load variables from a .env file (when one is found) into the process
# environment before the configuration constants read os.environ.
load_dotenv()
20
+
21
# API Configuration (new style: host + gen_image_path).
# Each value is read from the environment and is None when the variable is unset.
API_TOKEN = os.environ.get("token")
API_HOST = os.environ.get("host")
GEN_IMAGE_PATH = os.environ.get("gen_image_path")
MODEL_ID = os.environ.get("model_id")

# Polling / retry configuration (with sensible defaults).
# NOTE: int()/float() raise ValueError at import time if one of these
# variables is set to a non-numeric string.
MAX_RETRY_COUNT = int(os.environ.get("MAX_RETRY_COUNT", 3))
POLL_INTERVAL = float(os.environ.get("POLL_INTERVAL", 2.0))
MAX_POLL_TIME = int(os.environ.get("MAX_POLL_TIME", 300))

# Local-only test mode: skip the real API and return a random-colored image after
# a randomised delay. Useful for testing the UI / queueing flow without burning
# real model credits. Enable with FAKE_TEST=1 (also accepts "true", "yes" and
# "on"; the value is stripped and compared case-insensitively).
FAKE_TEST = os.environ.get("FAKE_TEST", "").strip().lower() in ("1", "true", "yes", "on")
36
+
37
# Predefined aspect ratios (wh_ratio) — kept in the same order as the original
# (width, height) list so each entry mirrors the previous resolution choice:
#   1:1  ←→ 2048×2048   4:3  ←→ 2304×1728   3:4  ←→ 1728×2304
#   16:9 ←→ 2560×1440   9:16 ←→ 1440×2560   3:2  ←→ 2496×1664
#   2:3  ←→ 1664×2496   21:9 ←→ 3104×1312   9:21 ←→ 1312×3104
#   9:7  ←→ 2304×1792   7:9  ←→ 1792×2304
# The square ratio comes first; every landscape ratio is then immediately
# followed by its portrait counterpart.
WH_RATIO_OPTIONS = ["1:1"] + [
    ratio
    for wide, tall in ((4, 3), (16, 9), (3, 2), (21, 9), (9, 7))
    for ratio in (f"{wide}:{tall}", f"{tall}:{wide}")
]
56
+
57
# Echo the effective configuration once at import time. The API token is
# deliberately not part of these messages.
logger.info(
    "API configuration loaded: "
    f"HOST={API_HOST}, GEN_IMAGE_PATH={GEN_IMAGE_PATH}, MODEL_ID={MODEL_ID}"
)
logger.info(
    "Retry configuration: "
    f"MAX_RETRY_COUNT={MAX_RETRY_COUNT}, POLL_INTERVAL={POLL_INTERVAL}s, "
    f"MAX_POLL_TIME={MAX_POLL_TIME}s"
)
if FAKE_TEST:
    logger.warning("FAKE_TEST mode is ENABLED — no real API calls will be made.")
65
+
66
+
67
class APIError(Exception):
    """Raised when a call to the image API fails or the API is misconfigured."""
70
+
71
+
72
# Value of the "code" field in an API response body that signals success;
# the request helpers in this file treat any other value as a failure.
SUCCESS_CODE = 0
74
+
75
+
76
def _build_request_url() -> str:
    """Return the full URL of the image-generation endpoint.

    The URL is assembled from the module-level ``API_HOST`` and
    ``GEN_IMAGE_PATH`` settings.

    Returns:
        str: Host and path joined by exactly one "/".

    Raises:
        APIError: If either the host or the path is not configured.
    """
    if not API_HOST or not GEN_IMAGE_PATH:
        raise APIError("API host or gen_image_path is not configured. Please set the 'host' and 'gen_image_path' environment variables.")
    # Normalise both sides of the join: a path configured without its leading
    # slash would otherwise be fused onto the host name ("https://hostv1/gen").
    return f"{API_HOST.rstrip('/')}/{GEN_IMAGE_PATH.lstrip('/')}"
80
+
81
+
82
def _build_result_url(task_id: str) -> str:
    """Return the URL used to poll the result of the task ``task_id``.

    The task id is percent-encoded so that characters such as "&", "#" or
    spaces cannot corrupt the query string; ids made only of letters, digits
    and "_.-~" are passed through unchanged.

    Raises:
        APIError: Propagated from ``_build_request_url`` when the API host or
            path is not configured.
    """
    # Function-scope import keeps this helper self-contained.
    from urllib.parse import quote

    return f"{_build_request_url()}/results?task_id={quote(str(task_id), safe='')}"
84
+
85
+
86
def _headers() -> dict:
    """Return the HTTP headers carrying the bearer token for API calls.

    Raises:
        APIError: If ``API_TOKEN`` is empty or unset.
    """
    if API_TOKEN:
        return {"Authorization": f"Bearer {API_TOKEN}"}
    raise APIError("API token is not configured. Please set the 'token' environment variable.")
90
+
91
+
92
def create_request(
    prompt,
    wh_ratio,
    negative_prompt="",
    enable_prompt_refine=True,
    seed=-1,
    guidance_scale=5.0,
):
    """
    Submit an image generation request to the API.

    The parameters are POSTed as JSON to the URL from ``_build_request_url()``.
    Timeouts, HTTP 429 and other non-4xx HTTP errors are retried until
    ``MAX_RETRY_COUNT`` attempts have been used; every other failure is
    raised immediately.

    Args:
        prompt (str): Text prompt describing the image to generate
        wh_ratio (str): Aspect ratio for the output image (e.g. "16:9")
        negative_prompt (str): Optional text describing what to avoid.
        enable_prompt_refine (bool): Whether to let the backend rewrite/expand
            the prompt before generation. Sent to the API as 0 / 1.
        seed (int): Generation seed. -1 means the backend will pick one
            randomly; any other integer fixes the seed for reproducible runs.
            Values that cannot be converted to int fall back to -1.
        guidance_scale (float): Classifier-free guidance strength. Higher
            values follow the prompt more strictly. Values that cannot be
            converted to float fall back to 5.0.

    Returns:
        str: Task ID

    Raises:
        ValueError: If the prompt is empty/blank or wh_ratio is not one of
            WH_RATIO_OPTIONS.
        APIError: If the API request fails
    """
    # Build the log preview defensively: a None prompt must reach the
    # validation below (ValueError) instead of crashing here on slicing.
    prompt_preview = (prompt or "")[:50]
    logger.info(
        f"Starting create_request with prompt='{prompt_preview}...', "
        f"wh_ratio={wh_ratio}, enable_prompt_refine={enable_prompt_refine}, "
        f"seed={seed}, guidance_scale={guidance_scale}, "
        f"negative_prompt='{(negative_prompt or '')[:30]}...'"
    )

    if not prompt or not prompt.strip():
        logger.error("Empty prompt provided to create_request")
        raise ValueError("Prompt cannot be empty")

    if not wh_ratio or not isinstance(wh_ratio, str) or wh_ratio not in WH_RATIO_OPTIONS:
        logger.error(f"Invalid wh_ratio: {wh_ratio}. Valid options: {WH_RATIO_OPTIONS}")
        raise ValueError(f"Invalid aspect ratio. Must be one of: {', '.join(WH_RATIO_OPTIONS)}")

    try:
        seed_int = int(seed)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers int(float("inf")), which int() does not report
        # as a ValueError.
        logger.warning(f"Invalid seed value '{seed}', falling back to -1 (random)")
        seed_int = -1

    try:
        guidance_scale_f = float(guidance_scale)
    except (TypeError, ValueError):
        logger.warning(f"Invalid guidance_scale '{guidance_scale}', falling back to 5.0")
        guidance_scale_f = 5.0

    model_params = {
        "prompt": prompt,
        "wh_ratio": wh_ratio,
        "model_id": MODEL_ID,
        "n": 1,
        "negative_prompt": negative_prompt or "",
        "enable_prompt_refine": 1 if enable_prompt_refine else 0,
        "seed": seed_int,
        "guidance_scale": guidance_scale_f,
    }

    url = _build_request_url()

    retry_count = 0
    while retry_count < MAX_RETRY_COUNT:
        try:
            logger.info(
                f"Sending API request [attempt {retry_count + 1}/{MAX_RETRY_COUNT}] for prompt: '{prompt_preview}...'"
            )
            response = requests.post(url, json=model_params, headers=_headers(), timeout=15)
            logger.info(f"API request response status: {response.status_code}")
            response.raise_for_status()

            response_json = response.json()
            code = response_json.get("code")
            message = response_json.get("message", "")
            if code != SUCCESS_CODE:
                logger.error(f"API returned error code {code}: {message}")
                raise APIError(f"Failed to submit task (code={code}): {message}")

            # "result" may be present but null; treat that like a missing
            # object so the "no task ID" error below is reported.
            task_id = (response_json.get("result") or {}).get("task_id")
            if not task_id:
                logger.error(f"No task ID in API response: {response_json}")
                raise APIError(f"No task ID returned from API: {response_json}")

            logger.info(f"Successfully created task with ID: {task_id}")
            return task_id

        except requests.exceptions.Timeout:
            retry_count += 1
            logger.warning(f"Request timed out. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
            time.sleep(1)

        except requests.exceptions.HTTPError as e:
            status_code = e.response.status_code
            error_message = f"HTTP error {status_code}"
            try:
                error_detail = e.response.json()
                error_message += f": {error_detail}"
                logger.error(f"API response error content: {error_detail}")
            except Exception:
                logger.error(f"Could not parse API error response as JSON. Raw content: {e.response.content[:500]}")

            if status_code == 401:
                logger.error(f"Authentication failed with API token. Status code: {status_code}")
                raise APIError("Authentication failed. Please check your API token.") from e
            elif status_code == 429:
                # Rate limited: exponential back-off, capped at 10 seconds.
                retry_count += 1
                wait_time = min(2 ** retry_count, 10)
                logger.warning(f"Rate limit exceeded. Waiting {wait_time}s before retry ({retry_count}/{MAX_RETRY_COUNT})...")
                time.sleep(wait_time)
            elif 400 <= status_code < 500:
                # Any other client error will not succeed on retry.
                logger.error(f"Client error: {error_message}, Prompt: '{prompt_preview}...', Status: {status_code}")
                raise APIError(error_message) from e
            else:
                retry_count += 1
                logger.warning(f"Server error: {error_message}. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
                time.sleep(1)

        except requests.exceptions.RequestException as e:
            logger.error(f"Request error: {str(e)}")
            logger.debug(f"Request error details: {traceback.format_exc()}")
            raise APIError(f"Failed to connect to API: {str(e)}") from e

        except APIError:
            # Already a well-formed API failure; do not re-wrap it below.
            raise

        except Exception as e:
            logger.error(f"Unexpected error in create_request: {str(e)}")
            logger.error(f"Full traceback: {traceback.format_exc()}")
            raise APIError(f"Unexpected error: {str(e)}") from e

    logger.error(f"Failed to create request after {MAX_RETRY_COUNT} retries for prompt: '{prompt_preview}...'")
    raise APIError(f"Failed after {MAX_RETRY_COUNT} retries")
231
+
232
+
233
def get_results(task_id):
    """
    Check the status of an image generation task.

    Args:
        task_id (str): The task ID to check

    Returns:
        dict: Task result information (the "result" object from the response), or None on transient failure.
        None is returned for timeouts, network errors, non-401 HTTP errors,
        non-success API codes and unexpected errors while reading the response.

    Raises:
        ValueError: If task_id is empty.
        APIError: For unrecoverable errors (authentication failure, or a
            missing host / path / token configuration).
    """
    logger.debug(f"Checking status for task ID: {task_id}")

    if not task_id:
        logger.error("Empty task ID provided to get_results")
        raise ValueError("Task ID cannot be empty")

    url = _build_result_url(task_id)
    # Build the headers outside the try block: the APIError raised for a
    # missing token is a configuration error and must propagate to the caller
    # rather than be swallowed by the generic handler and reported as a
    # transient failure.
    headers = _headers()

    try:
        response = requests.get(url, headers=headers, timeout=10)
        logger.debug(f"Status check response code: {response.status_code}")
        response.raise_for_status()
        response_json = response.json()

        code = response_json.get("code")
        message = response_json.get("message", "")
        if code != SUCCESS_CODE:
            logger.warning(f"API returned non-success code {code} for task {task_id}: {message}")
            return None

        return response_json.get("result")

    except requests.exceptions.Timeout:
        logger.warning(f"Request timed out when checking task {task_id}")
        return None

    except requests.exceptions.HTTPError as e:
        status_code = e.response.status_code
        logger.warning(f"HTTP error {status_code} when checking task {task_id}")
        try:
            error_content = e.response.json()
            logger.error(f"Error response content: {error_content}")
        except Exception:
            logger.error(f"Could not parse error response as JSON. Raw content: {e.response.content[:500]}")

        if status_code == 401:
            logger.error(f"Authentication failed when checking task {task_id}")
            raise APIError(f"Authentication failed. Please check your API token when checking task {task_id}") from e

        # Every other HTTP error is reported as "no result yet"; only the log
        # level differs between client and server errors.
        if 400 <= status_code < 500:
            logger.error(f"Client error {status_code} when checking task {task_id}")
        else:
            logger.warning(f"Server error {status_code} when checking task {task_id}")
        return None

    except requests.exceptions.RequestException as e:
        logger.warning(f"Network error when checking task {task_id}: {str(e)}")
        logger.debug(f"Network error details: {traceback.format_exc()}")
        return None

    except Exception as e:
        logger.error(f"Unexpected error when checking task {task_id}: {str(e)}")
        logger.error(f"Full traceback: {traceback.format_exc()}")
        return None
300
+
301
+
302
def download_image(image_url):
    """
    Download an image from a URL and return it as a PIL Image.
    Converts non-PNG formats (e.g. WebP) to PNG; string-valued entries of the
    original image's ``info`` dict are copied onto the converted image.

    Timeouts, network errors and non-4xx HTTP errors are retried until
    ``MAX_RETRY_COUNT`` attempts have been used.

    Args:
        image_url (str): URL of the image to fetch.

    Returns:
        PIL.Image.Image: The downloaded image (PNG format after conversion).

    Raises:
        ValueError: If image_url is empty.
        APIError: On a 4xx response, when the image cannot be processed, or
            when all attempts are exhausted.
    """
    logger.info(f"Starting download_image from URL: {image_url}")

    if not image_url:
        logger.error("Empty image URL provided to download_image")
        raise ValueError("Image URL cannot be empty when downloading image")

    retry_count = 0
    while retry_count < MAX_RETRY_COUNT:
        try:
            logger.info(f"Downloading image [attempt {retry_count + 1}/{MAX_RETRY_COUNT}] from {image_url}")
            response = requests.get(image_url, timeout=30)
            logger.debug(
                f"Image download response status: {response.status_code}, "
                f"Content-Type: {response.headers.get('Content-Type')}, "
                f"Content-Length: {response.headers.get('Content-Length')}"
            )
            response.raise_for_status()

            image = Image.open(BytesIO(response.content))
            logger.info(
                f"Image opened successfully. Format: {image.format}, "
                f"Size: {image.size[0]}x{image.size[1]}, Mode: {image.mode}"
            )

            # Keep only str -> str metadata entries; other value types are
            # not carried over to the converted image.
            original_metadata = {
                key: value
                for key, value in image.info.items()
                if isinstance(key, str) and isinstance(value, str)
            }
            logger.debug(f"Original image metadata: {original_metadata}")

            if image.format != 'PNG':
                logger.info(f"Converting image from {image.format} to PNG format")
                png_buffer = BytesIO()
                # Images with an alpha band are saved as-is; everything else
                # is converted to RGB first.
                if 'A' in image.getbands():
                    image_to_save = image
                else:
                    image_to_save = image.convert('RGB')
                image_to_save.save(png_buffer, format='PNG')
                png_buffer.seek(0)
                image = Image.open(png_buffer)
                for key, value in original_metadata.items():
                    image.info[key] = value

            logger.info(f"Successfully downloaded and processed image: {image.size[0]}x{image.size[1]}")
            return image

        except requests.exceptions.Timeout:
            retry_count += 1
            logger.warning(f"Download timed out. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
            time.sleep(1)

        except requests.exceptions.HTTPError as e:
            status_code = e.response.status_code
            logger.error(f"HTTP error {status_code} when downloading image from {image_url}")
            if 400 <= status_code < 500:
                # Client errors will not succeed on retry.
                raise APIError(f"HTTP error {status_code} when downloading image") from e
            retry_count += 1
            # Log the retry like the other retry paths do.
            logger.warning(f"Server error {status_code} during image download. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
            time.sleep(1)

        except requests.exceptions.RequestException as e:
            retry_count += 1
            logger.warning(f"Network error during image download: {str(e)}. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
            time.sleep(1)

        except Exception as e:
            logger.error(f"Error processing image from {image_url}: {str(e)}")
            logger.error(f"Full traceback: {traceback.format_exc()}")
            raise APIError(f"Failed to process image: {str(e)}") from e

    logger.error(f"Failed to download image from {image_url} after {MAX_RETRY_COUNT} retries")
    raise APIError(f"Failed to download image after {MAX_RETRY_COUNT} retries")
379
+
380
+
381
+ APPLE_CSS = """
382
+ /* ---- Apple-inspired minimalist UI ---- */
383
+ .gradio-container {
384
+ font-family: -apple-system, BlinkMacSystemFont, "SF Pro Display", "SF Pro Text",
385
+ "Helvetica Neue", "Segoe UI", Inter, sans-serif !important;
386
+ background: linear-gradient(180deg, #fbfbfd 0%, #f5f5f7 100%) !important;
387
+ /* Always use ~3/4 of the viewport, capped at 1600px on huge screens.
388
+ Using width AND max-width ensures the page is wide from first paint
389
+ instead of growing only after content loads. */
390
+ width: min(1600px, 92vw) !important;
391
+ max-width: 1600px !important;
392
+ margin: 0 auto !important;
393
+ -webkit-font-smoothing: antialiased;
394
+ -moz-osx-font-smoothing: grayscale;
395
+ color: #1d1d1f !important;
396
+ }
397
+
398
+ /* Cards / panels */
399
+ .panel-card {
400
+ background: #ffffff !important;
401
+ border-radius: 18px !important;
402
+ padding: 18px !important;
403
+ box-shadow: 0 1px 2px rgba(0,0,0,0.04), 0 8px 28px rgba(0,0,0,0.05) !important;
404
+ border: 1px solid rgba(0,0,0,0.05) !important;
405
+ }
406
+
407
+ /* Inputs - rounded with apple-blue focus ring */
408
+ textarea, input[type="text"], input[type="number"],
409
+ .gradio-container .form input,
410
+ .gradio-container .form textarea {
411
+ border-radius: 12px !important;
412
+ border: 1px solid #d2d2d7 !important;
413
+ background: #ffffff !important;
414
+ transition: border-color 0.15s ease, box-shadow 0.15s ease !important;
415
+ font-size: 15px !important;
416
+ }
417
+ textarea:focus, input:focus,
418
+ .gradio-container .form input:focus,
419
+ .gradio-container .form textarea:focus {
420
+ border-color: #0071e3 !important;
421
+ box-shadow: 0 0 0 4px rgba(0, 113, 227, 0.15) !important;
422
+ outline: none !important;
423
+ }
424
+
425
+ /* Dropdown */
426
+ .gradio-container .wrap.svelte-1ipelgc { border-radius: 12px !important; }
427
+
428
+ /* Block labels */
429
+ .gradio-container span[data-testid="block-label"],
430
+ .gradio-container .block-label,
431
+ .gradio-container label > span {
432
+ color: #6e6e73 !important;
433
+ font-weight: 500 !important;
434
+ font-size: 13px !important;
435
+ letter-spacing: 0.01em;
436
+ }
437
+
438
+ /* Buttons - pill shape, Apple blue. Scoped to the variant classes so it does
439
+ NOT bleed into internal buttons inside dropdowns, accordions, etc. */
440
+ .gradio-container button.primary,
441
+ .gradio-container button.secondary {
442
+ border-radius: 980px !important;
443
+ font-weight: 500 !important;
444
+ font-size: 15px !important;
445
+ padding: 12px 22px !important;
446
+ transition: transform 0.08s ease, box-shadow 0.18s ease, background 0.18s ease, opacity 0.15s ease !important;
447
+ border: none !important;
448
+ letter-spacing: 0.01em;
449
+ }
450
+ .gradio-container button.primary {
451
+ background: #0071e3 !important;
452
+ color: #ffffff !important;
453
+ box-shadow: 0 1px 2px rgba(0,113,227,0.25), 0 6px 16px rgba(0,113,227,0.22) !important;
454
+ }
455
+ .gradio-container button.primary:hover {
456
+ background: #0077ed !important;
457
+ transform: translateY(-1px);
458
+ box-shadow: 0 2px 4px rgba(0,113,227,0.3), 0 10px 22px rgba(0,113,227,0.28) !important;
459
+ }
460
+ .gradio-container button.primary:active { transform: translateY(0); }
461
+ .gradio-container button.secondary {
462
+ background: rgba(0,0,0,0.05) !important;
463
+ color: #1d1d1f !important;
464
+ box-shadow: none !important;
465
+ }
466
+ .gradio-container button.secondary:hover {
467
+ background: rgba(0,0,0,0.09) !important;
468
+ }
469
+
470
+ /* Make sure the dropdown's selected-value text never gets pill-clipped and
471
+ is properly aligned inside its rounded box. */
472
+ .gradio-container .wrap-inner,
473
+ .gradio-container .single-select,
474
+ .gradio-container .secondary-wrap {
475
+ border-radius: 12px !important;
476
+ }
477
+ .gradio-container .single-select input,
478
+ .gradio-container input[role="listbox"] {
479
+ border-radius: 12px !important;
480
+ padding: 10px 14px !important;
481
+ font-size: 15px !important;
482
+ }
483
+
484
+ /* Status pill */
485
+ #status-bar {
486
+ padding: 0 !important;
487
+ margin-top: 6px;
488
+ }
489
+ .status-pill {
490
+ display: inline-flex;
491
+ align-items: center;
492
+ gap: 9px;
493
+ background: #f5f5f7;
494
+ color: #1d1d1f;
495
+ padding: 10px 14px;
496
+ border-radius: 12px;
497
+ font-size: 13px;
498
+ font-weight: 500;
499
+ line-height: 1;
500
+ border: 1px solid rgba(0,0,0,0.04);
501
+ }
502
+ .status-dot {
503
+ width: 8px;
504
+ height: 8px;
505
+ border-radius: 50%;
506
+ background: #8e8e93;
507
+ flex-shrink: 0;
508
+ }
509
+ .status-info .status-dot { background: #8e8e93; }
510
+ .status-success { background: rgba(48,209,88,0.10); color: #0a7f2e; border-color: rgba(48,209,88,0.20); }
511
+ .status-success .status-dot { background: #30d158; }
512
+ .status-error { background: rgba(255,59,48,0.10); color: #b8261b; border-color: rgba(255,59,48,0.20); }
513
+ .status-error .status-dot { background: #ff3b30; }
514
+ .status-running { background: rgba(0,113,227,0.10); color: #0058b8; border-color: rgba(0,113,227,0.20); }
515
+ .status-running .status-dot {
516
+ background: #0071e3;
517
+ animation: pulse 1.4s ease-in-out infinite;
518
+ }
519
+ @keyframes pulse {
520
+ 0%, 100% { opacity: 0.4; transform: scale(0.85); }
521
+ 50% { opacity: 1.0; transform: scale(1.15); }
522
+ }
523
+
524
+ /* Image output frame — never crop the image; show full picture with letterbox. */
525
+ .image-output {
526
+ border-radius: 18px !important;
527
+ background: #f5f5f7 !important;
528
+ }
529
+ .image-output,
530
+ .image-output > div,
531
+ .image-output [data-testid="image"],
532
+ .image-output .image-container,
533
+ .image-output .image-frame,
534
+ .image-output .preview {
535
+ min-height: 440px !important;
536
+ display: flex !important;
537
+ align-items: center !important;
538
+ justify-content: center !important;
539
+ }
540
+ .image-output img {
541
+ border-radius: 14px !important;
542
+ object-fit: contain !important;
543
+ max-width: 100% !important;
544
+ max-height: 62vh !important;
545
+ width: auto !important;
546
+ height: auto !important;
547
+ }
548
+
549
+ /* Status pill placed inside the right column, above the image. */
550
+ .right-status {
551
+ display: flex;
552
+ justify-content: flex-start;
553
+ margin-bottom: 6px;
554
+ }
555
+
556
+ /* Accordion */
557
+ .gradio-container details {
558
+ border-radius: 14px !important;
559
+ border: 1px solid rgba(0,0,0,0.06) !important;
560
+ background: #ffffff !important;
561
+ }
562
+
563
+ /* Negative prompt — softer accent so it visually de-emphasises vs. the main prompt */
564
+ .negative-prompt textarea {
565
+ background: #fbfbfd !important;
566
+ border-color: #e3e3e8 !important;
567
+ }
568
+ .negative-prompt textarea:focus {
569
+ background: #ffffff !important;
570
+ }
571
+
572
+ /* Advanced options row — keeps the refine switch + seed input visually paired */
573
+ .advanced-row {
574
+ gap: 14px !important;
575
+ margin-top: 2px;
576
+ }
577
+
578
+ /* Refine toggle — iOS-style settings card.
579
+ Goal: title + helper text on the left, a polished pill switch on the right,
580
+ everything contained inside a single soft card so the helper text no longer
581
+ floats orphaned above the box. */
582
+ .refine-toggle {
583
+ background: linear-gradient(180deg, #ffffff 0%, #f5f5f7 100%) !important;
584
+ border-radius: 14px !important;
585
+ border: 1px solid rgba(0,0,0,0.06) !important;
586
+ padding: 14px 16px !important;
587
+ box-shadow: 0 1px 2px rgba(0,0,0,0.03) !important;
588
+ transition: border-color 0.18s ease, box-shadow 0.18s ease !important;
589
+ min-height: 88px;
590
+ display: flex !important;
591
+ flex-direction: column !important;
592
+ justify-content: center !important;
593
+ }
594
+ .refine-toggle:hover {
595
+ border-color: rgba(0,0,0,0.10) !important;
596
+ box-shadow: 0 1px 2px rgba(0,0,0,0.04), 0 4px 14px rgba(0,0,0,0.05) !important;
597
+ }
598
+
599
+ /* Strip default backgrounds from gradio's inner wrappers so only our card shows. */
600
+ .refine-toggle .form,
601
+ .refine-toggle .wrap,
602
+ .refine-toggle .form-wrap,
603
+ .refine-toggle > div {
604
+ background: transparent !important;
605
+ border: none !important;
606
+ padding: 0 !important;
607
+ margin: 0 !important;
608
+ box-shadow: none !important;
609
+ }
610
+
611
+ /* Helper / "info" text becomes a proper subtitle UNDER the toggle row. */
612
+ .refine-toggle [data-testid="block-info"],
613
+ .refine-toggle .info {
614
+ color: #6e6e73 !important;
615
+ font-size: 12px !important;
616
+ line-height: 1.4 !important;
617
+ margin: 8px 0 0 0 !important;
618
+ padding: 0 !important;
619
+ text-align: left !important;
620
+ order: 2 !important;
621
+ }
622
+
623
+ /* Force the gradio wrapper to stack: label-row first, info below. */
624
+ .refine-toggle .form,
625
+ .refine-toggle > div:not([data-testid="block-info"]):not(.info) {
626
+ display: flex !important;
627
+ flex-direction: column !important;
628
+ align-items: stretch !important;
629
+ }
630
+
631
+ /* Label row: title on the LEFT (full width), toggle pinned to the RIGHT. */
632
+ .refine-toggle label {
633
+ display: flex !important;
634
+ align-items: center !important;
635
+ justify-content: space-between !important;
636
+ flex-direction: row-reverse !important;
637
+ cursor: pointer !important;
638
+ margin: 0 !important;
639
+ padding: 0 !important;
640
+ gap: 14px !important;
641
+ width: 100% !important;
642
+ order: 1 !important;
643
+ }
644
+ .refine-toggle label > span {
645
+ color: #1d1d1f !important;
646
+ font-size: 15px !important;
647
+ font-weight: 600 !important;
648
+ letter-spacing: -0.01em;
649
+ flex: 1 1 auto !important;
650
+ text-align: left !important;
651
+ line-height: 1.3 !important;
652
+ }
653
+
654
+ /* Pill switch — bigger, smoother, more "Apple-like" */
655
+ .refine-toggle input[type="checkbox"] {
656
+ appearance: none;
657
+ -webkit-appearance: none;
658
+ width: 46px !important;
659
+ height: 28px !important;
660
+ border-radius: 999px !important;
661
+ background: #e5e5ea !important;
662
+ position: relative;
663
+ cursor: pointer;
664
+ transition: background 0.22s ease, box-shadow 0.22s ease;
665
+ border: none !important;
666
+ flex-shrink: 0 !important;
667
+ margin: 0 !important;
668
+ box-shadow: inset 0 0 1px rgba(0,0,0,0.06);
669
+ }
670
+ .refine-toggle input[type="checkbox"]::after {
671
+ content: "";
672
+ position: absolute;
673
+ top: 2px;
674
+ left: 2px;
675
+ width: 24px;
676
+ height: 24px;
677
+ border-radius: 50%;
678
+ background: #ffffff;
679
+ box-shadow: 0 2px 5px rgba(0,0,0,0.18), 0 0 1px rgba(0,0,0,0.05);
680
+ transition: transform 0.24s cubic-bezier(0.4, 0.0, 0.2, 1);
681
+ }
682
+ .refine-toggle input[type="checkbox"]:hover {
683
+ background: #dcdce0 !important;
684
+ }
685
+ .refine-toggle input[type="checkbox"]:checked {
686
+ background: #34c759 !important;
687
+ }
688
+ .refine-toggle input[type="checkbox"]:checked:hover {
689
+ background: #30b352 !important;
690
+ }
691
+ .refine-toggle input[type="checkbox"]:checked::after {
692
+ transform: translateX(18px);
693
+ }
694
+ .refine-toggle input[type="checkbox"]:focus-visible {
695
+ box-shadow: 0 0 0 4px rgba(0,113,227,0.20) !important;
696
+ }
697
+ .refine-toggle input[type="checkbox"]:active::after {
698
+ /* tiny squish on press, very iOS */
699
+ width: 28px;
700
+ }
701
+ .refine-toggle input[type="checkbox"]:checked:active::after {
702
+ transform: translateX(14px);
703
+ }
704
+
705
+ /* Seed number input — match the prompt/dropdown rounding */
706
+ .seed-input input[type="number"] {
707
+ border-radius: 12px !important;
708
+ padding: 10px 14px !important;
709
+ font-variant-numeric: tabular-nums;
710
+ }
711
+ /* Hide the native spinner buttons on number inputs for a cleaner look */
712
+ .seed-input input[type="number"]::-webkit-outer-spin-button,
713
+ .seed-input input[type="number"]::-webkit-inner-spin-button {
714
+ -webkit-appearance: none;
715
+ margin: 0;
716
+ }
717
+ .seed-input input[type="number"] {
718
+ -moz-appearance: textfield;
719
+ }
720
+ /* Keep the seed column visually aligned with the refine card next to it */
721
+ .seed-input {
722
+ align-self: stretch !important;
723
+ }
724
+
725
+ /* Guidance scale slider — Apple-blue track + softer thumb */
726
+ .guidance-slider input[type="range"] {
727
+ accent-color: #0071e3 !important;
728
+ }
729
+ .guidance-slider .head { padding-top: 0 !important; }
730
+ .guidance-slider {
731
+ margin-top: 4px;
732
+ }
733
+
734
+ /* Footer tagline */
735
+ .tagline {
736
+ text-align: center;
737
+ color: #6e6e73;
738
+ font-size: 12px;
739
+ margin: 18px 0 14px 0;
740
+ font-weight: 400;
741
+ }
742
+ .tagline a {
743
+ color: #0071e3;
744
+ text-decoration: none;
745
+ font-weight: 500;
746
+ transition: opacity 0.15s ease;
747
+ }
748
+ .tagline a:hover { opacity: 0.7; }
749
+
750
+ /* Footer links row (HuggingFace / GitHub / Twitter) */
751
+ .footer-links {
752
+ display: flex;
753
+ justify-content: center;
754
+ align-items: center;
755
+ gap: 26px;
756
+ flex-wrap: wrap;
757
+ font-size: 13px;
758
+ margin: 24px 0 6px 0;
759
+ }
760
+ /* When the vivago tagline follows the links row, tighten the gap. */
761
+ .footer-links + .tagline {
762
+ margin-top: 4px;
763
+ }
764
+
765
+ /* Hide gradio's default footer for a cleaner look */
766
+ footer { display: none !important; }
767
+
768
+ /* Mobile */
769
+ @media (max-width: 640px) {
770
+ .footer-links { gap: 18px; font-size: 12px; }
771
+ }
772
+ """
773
+
774
+
775
# Gradio theme for the app: the built-in Soft theme with blue/slate hues and an
# Inter-first font stack, followed by explicit overrides for block, button and
# input colours. The hex values used here (#0071e3, #0077ed, #d2d2d7, #1d1d1f)
# are the same ones that appear in the APPLE_CSS rules.
APPLE_THEME = gr.themes.Soft(
    primary_hue=gr.themes.colors.blue,
    neutral_hue=gr.themes.colors.slate,
    radius_size=gr.themes.sizes.radius_lg,
    text_size=gr.themes.sizes.text_md,
    # Font stack, in order of preference.
    font=[
        gr.themes.GoogleFont("Inter"),
        "ui-sans-serif",
        "-apple-system",
        "BlinkMacSystemFont",
        "Segoe UI",
        "Helvetica Neue",
        "sans-serif",
    ],
).set(
    # Page and block surfaces.
    body_background_fill="*neutral_50",
    block_background_fill="white",
    block_border_width="1px",
    block_label_text_weight="500",
    block_title_text_weight="600",
    # Primary button.
    button_primary_background_fill="#0071e3",
    button_primary_background_fill_hover="#0077ed",
    button_primary_text_color="white",
    button_primary_border_color="#0071e3",
    # Secondary button.
    button_secondary_background_fill="rgba(0,0,0,0.05)",
    button_secondary_background_fill_hover="rgba(0,0,0,0.09)",
    button_secondary_text_color="#1d1d1f",
    # Inputs, including the focus ring.
    input_background_fill="white",
    input_border_color="#d2d2d7",
    input_border_color_focus="#0071e3",
    input_shadow_focus="0 0 0 4px rgba(0,113,227,0.15)",
)
807
+
808
+
809
+ def _status_html(text: str, kind: str = "info") -> str:
810
+ """Render a styled status pill. kind: info | success | error | running"""
811
+ return (
812
+ f'<div class="status-pill status-{kind}">'
813
+ f'<span class="status-dot"></span>{text}'
814
+ f'</div>'
815
+ )
816
+
817
+
818
+ def _queue_text(waiting: int, running: int) -> str:
819
+ """Inline status text describing the queue from the user's POV.
820
+
821
+ Used inside the right-side status pill while THIS user's request is
822
+ sitting in the queue waiting for an open slot.
823
+ """
824
+ if waiting <= 0 and running <= 0:
825
+ return "Queued · waiting for an open slot…"
826
+ parts = []
827
+ if waiting > 0:
828
+ parts.append(f"{waiting} waiting")
829
+ if running > 0:
830
+ parts.append(f"{running} generating")
831
+ return f"In queue · {' · '.join(parts)}"
832
+
833
+
834
+ def _read_queue_stats(demo_obj) -> tuple[int, int, float]:
835
+ """Best-effort read of gradio's internal queue.
836
+
837
+ Gradio 5.x structure (gradio/queueing.py):
838
+ demo._queue.event_queue_per_concurrency_id: dict[str, EventQueue]
839
+ EventQueue.queue: list[Event] ← actually-waiting events
840
+ demo._queue.active_jobs: list[None | list[Event]]
841
+ each slot is None (idle) or a list of currently-processing events.
842
+ demo._queue.process_time_per_fn: dict[BlockFunction, ProcessTime]
843
+ ProcessTime.avg_time: float
844
+
845
+ Returns (waiting, running, avg_secs). Each lookup is wrapped in a try
846
+ so the UI degrades gracefully ("idle") if Gradio ever renames a field.
847
+ """
848
+ try:
849
+ q = getattr(demo_obj, "_queue", None)
850
+ if q is None:
851
+ return 0, 0, 0.0
852
+
853
+ # ---- Waiting: events sitting in EventQueue.queue ----
854
+ waiting = 0
855
+ events_per_cid = getattr(q, "event_queue_per_concurrency_id", None) or {}
856
+ for ev_q in events_per_cid.values():
857
+ # Newer gradio: ev_q is an EventQueue with a .queue list.
858
+ # Older/alt: ev_q might already be a list. Handle both.
859
+ inner = getattr(ev_q, "queue", ev_q)
860
+ try:
861
+ waiting += len(inner)
862
+ except (TypeError, AttributeError):
863
+ # Last resort: try iterating
864
+ try:
865
+ waiting += sum(1 for _ in inner)
866
+ except Exception:
867
+ continue
868
+
869
+ # ---- Running: count events held in active_jobs slots ----
870
+ # Each slot is None or a list[Event]; sum the list lengths.
871
+ running = 0
872
+ active = getattr(q, "active_jobs", None) or []
873
+ for slot in active:
874
+ if slot is None:
875
+ continue
876
+ try:
877
+ running += len(slot)
878
+ except (TypeError, AttributeError):
879
+ running += 1 # very old gradio: single Event per slot
880
+
881
+ # ---- Average per-run time (best effort across versions) ----
882
+ avg_secs = 0.0
883
+ # Gradio 5.x: dict[BlockFunction, ProcessTime] with .avg_time
884
+ ptpf = getattr(q, "process_time_per_fn", None)
885
+ if isinstance(ptpf, dict) and ptpf:
886
+ try:
887
+ vals = []
888
+ for v in ptpf.values():
889
+ avg_t = getattr(v, "avg_time", None)
890
+ if avg_t:
891
+ vals.append(float(avg_t))
892
+ if vals:
893
+ avg_secs = sum(vals) / len(vals)
894
+ except Exception:
895
+ avg_secs = 0.0
896
+ # Older: dict[int, float]
897
+ if not avg_secs:
898
+ ptpfi = getattr(q, "process_time_per_fn_index", None)
899
+ if isinstance(ptpfi, dict) and ptpfi:
900
+ try:
901
+ vals = [float(v) for v in ptpfi.values() if v]
902
+ if vals:
903
+ avg_secs = sum(vals) / len(vals)
904
+ except Exception:
905
+ avg_secs = 0.0
906
+ # Oldest: single float on the queue
907
+ if not avg_secs:
908
+ apt = getattr(q, "avg_process_time", None)
909
+ if apt:
910
+ try:
911
+ avg_secs = float(apt)
912
+ except Exception:
913
+ avg_secs = 0.0
914
+
915
+ return waiting, running, avg_secs
916
+ except Exception as exc:
917
+ logger.debug(f"Queue introspection failed: {exc}")
918
+ return 0, 0, 0.0
919
+
920
+
921
def _fake_generation_iter(prompt: str, wh_ratio_value: str):
    """Stand-in generator used when FAKE_TEST mode is on.

    Follows the same yield protocol as the real flow but never touches an
    external API, which makes it handy for exercising the queue UI and the
    status pill locally. ``prompt`` and ``wh_ratio_value`` are accepted for
    signature parity and are not read here.

    Yields:
        ``(image_or_None, status_html)`` tuples. The initial 'Sending request
        to API…' status is emitted by the caller, so this iterator starts at
        'Request submitted'. The final tuple carries a solid, randomly
        coloured 1024x1024 RGB image.
    """

    def _running(message: str):
        # Progress frame: no image yet, pill in the "running" style.
        return None, _status_html(message, "running")

    # Simulated latency of the submit call.
    time.sleep(random.uniform(0.4, 1.0))

    task_stub = format(random.randint(0, 0xFFFFFFFF), "08x")
    yield _running(f"Request submitted · Task {task_stub}…")

    # Simulated generation: tick an elapsed-seconds counter until the
    # randomly chosen duration has passed.
    target_secs = random.uniform(8.0, 22.0)
    began = time.time()
    while True:
        if time.time() - began >= target_secs:
            break
        seconds_so_far = int(time.time() - began)
        yield _running(f"Generating… {seconds_so_far}s")
        time.sleep(POLL_INTERVAL)

    yield _running("Downloading image…")
    time.sleep(0.3)

    rgb = tuple(random.randint(40, 220) for _ in range(3))
    placeholder = Image.new("RGB", (1024, 1024), color=rgb)
    logger.info(f"FAKE_TEST: returning random color image rgb={rgb}, took {target_secs:.1f}s")
    yield placeholder, _status_html("Image generated", "success")
954
+
955
+
956
def create_ui():
    """Build the Gradio Blocks app and return it (not yet launched).

    Layout: a two-column row. The left column holds the generation inputs
    (prompt, negative prompt, aspect ratio, guidance scale, prompt-refine
    toggle, seed) plus the Clear / Generate buttons; the right column holds
    the status pill and the output image. A footer with external links is
    appended below the row.

    Event wiring:
        * Generate click -> ``_enter_queue`` (unqueued) and then
          ``_generate_wrapped`` (queued generator that streams status/image).
        * Clear click -> ``clear_outputs``.
        * A 1.5 s timer -> ``_tick_pill``, which refreshes the pill with live
          queue counts while this session's request is still queued.

    Returns:
        The constructed ``gr.Blocks`` instance.
    """
    logger.info("Creating Gradio UI")
    with gr.Blocks(
        title="HiDream-O1-Image Generator",
        theme=APPLE_THEME,
        css=APPLE_CSS,
    ) as demo:
        # Per-session state used to drive the right-side status pill while
        # this user's request is sitting in the queue. `last_pill_state`
        # caches the most recent pill HTML so the timer can return
        # `gr.update()` (=no-op, no DOM replacement → no flicker) when the
        # queue counts haven't actually changed between ticks.
        queued_state = gr.State(False)
        last_pill_state = gr.State("")

        with gr.Row(equal_height=False):
            # ---- Left column: generation inputs ----
            with gr.Column(scale=1, elem_classes=["panel-card"]):
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the image you want to create...",
                    lines=5,
                    show_label=True,
                )

                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    placeholder="Things you want to avoid in the image (optional)...",
                    lines=2,
                    show_label=True,
                    elem_classes=["negative-prompt"],
                )

                wh_ratio = gr.Dropdown(
                    choices=WH_RATIO_OPTIONS,
                    value=WH_RATIO_OPTIONS[0],
                    label="Aspect Ratio",
                    info="Width : Height",
                )

                guidance_scale = gr.Slider(
                    minimum=1.0,
                    maximum=20.0,
                    step=0.1,
                    value=5.0,
                    label="Guidance Scale",
                    info="Higher values follow the prompt more strictly",
                    elem_classes=["guidance-slider"],
                )

                with gr.Row(elem_classes=["advanced-row"], equal_height=True):
                    enable_prompt_refine = gr.Checkbox(
                        value=True,
                        label="Prompt Refine",
                        info="Let the model rewrite & enrich your prompt",
                        elem_classes=["refine-toggle"],
                        scale=1,
                    )
                    seed = gr.Number(
                        value=-1,
                        label="Seed",
                        info="Use -1 for a random seed",
                        precision=0,
                        minimum=-1,
                        elem_classes=["seed-input"],
                        scale=1,
                    )

                with gr.Row():
                    clear_btn = gr.Button("Clear", variant="secondary", scale=1)
                    generate_btn = gr.Button("Generate", variant="primary", scale=3)

            # ---- Right column: status pill + generated image ----
            with gr.Column(scale=1, elem_classes=["panel-card"]):
                status_msg = gr.HTML(
                    value=_status_html("Ready", "info"),
                    elem_id="status-bar",
                    elem_classes=["right-status"],
                )
                output_image = gr.Image(
                    label="Generated Image",
                    format="png",
                    type="pil",
                    interactive=False,
                    show_download_button=True,
                    elem_classes=["image-output"],
                )

        def generate_with_status(
            prompt,
            wh_ratio_value,
            negative_prompt_value,
            enable_prompt_refine_value,
            seed_value,
            guidance_scale_value,
        ):
            """Run one generation and stream progress to the UI.

            Generator yielding ``(image_or_None, status_html)`` tuples:
            validates the inputs, submits the request via ``create_request``,
            polls ``get_results`` every ``POLL_INTERVAL`` seconds for at most
            ``MAX_POLL_TIME`` seconds, then downloads the image. Failures are
            not raised to the caller; each one ends the stream with an
            "error" status pill instead. In FAKE_TEST mode the API is
            bypassed and ``_fake_generation_iter`` supplies the frames.
            """
            logger.info(
                f"Starting image generation with prompt='{(prompt or '')[:50]}...', "
                f"wh_ratio={wh_ratio_value}, "
                f"negative_prompt='{(negative_prompt_value or '')[:30]}...', "
                f"enable_prompt_refine={enable_prompt_refine_value}, "
                f"seed={seed_value}, guidance_scale={guidance_scale_value}"
            )

            # Emitted before validation so the pill leaves the "in queue"
            # text as soon as this handler starts running.
            yield None, _status_html("Sending request to API…", "running")

            try:
                if not prompt or not prompt.strip():
                    logger.error("Empty prompt provided in UI")
                    yield None, _status_html("Prompt cannot be empty", "error")
                    return

                if wh_ratio_value not in WH_RATIO_OPTIONS:
                    logger.error(f"Invalid aspect ratio selection: {wh_ratio_value}")
                    yield None, _status_html(f"Invalid aspect ratio “{wh_ratio_value}”", "error")
                    return

                # Unparseable or missing seed falls back to -1 (the UI labels
                # -1 as "random seed").
                try:
                    seed_int = int(seed_value) if seed_value is not None else -1
                except (TypeError, ValueError):
                    seed_int = -1

                # Unparseable or missing guidance scale falls back to the
                # slider's default of 5.0.
                try:
                    guidance_scale_f = float(guidance_scale_value) if guidance_scale_value is not None else 5.0
                except (TypeError, ValueError):
                    guidance_scale_f = 5.0

                if FAKE_TEST:
                    logger.info("FAKE_TEST mode active — bypassing real API call")
                    yield from _fake_generation_iter(prompt, wh_ratio_value)
                    return

                logger.info("Creating API request")
                task_id = create_request(
                    prompt,
                    wh_ratio_value,
                    negative_prompt=negative_prompt_value or "",
                    enable_prompt_refine=bool(enable_prompt_refine_value),
                    seed=seed_int,
                    guidance_scale=guidance_scale_f,
                )
                yield None, _status_html(f"Request submitted · Task {task_id[:8]}…", "running")

                start_time = time.time()
                logger.info(f"Starting to poll for results for task ID: {task_id}")

                while time.time() - start_time < MAX_POLL_TIME:
                    elapsed_time = time.time() - start_time
                    logger.debug(
                        f"Polling for results - Task ID: {task_id}, Elapsed: {elapsed_time:.2f}s"
                    )

                    result = get_results(task_id)
                    if not result:
                        # Nothing usable came back; retry silently after the
                        # poll interval (no status frame is emitted here).
                        time.sleep(POLL_INTERVAL)
                        continue

                    overall_status = result.get("status")
                    sub_results = result.get("sub_task_results", []) or []

                    # Any overall status other than 1 is treated as "still
                    # generating".
                    if overall_status != 1:
                        elapsed = int(time.time() - start_time)
                        yield None, _status_html(f"Generating… {elapsed}s", "running")
                        time.sleep(POLL_INTERVAL)
                        continue

                    if not sub_results:
                        logger.error(f"Task completed but no sub_task_results returned. Task ID: {task_id}")
                        yield None, _status_html("Task completed but no results returned", "error")
                        return

                    # Only the first sub-task result is inspected.
                    sub = sub_results[0]
                    sub_status = sub.get("task_status")

                    # Sub-task status 1 = finished successfully.
                    if sub_status == 1:
                        logger.info(f"Task completed successfully - Task ID: {task_id}")

                        image_url = sub.get("url")
                        if not image_url:
                            logger.error(f"No image URL in successful response. Sub result: {sub}")
                            yield None, _status_html("No image URL in response", "error")
                            return

                        yield None, _status_html("Downloading image…", "running")

                        logger.info(f"Downloading image - Task ID: {task_id}, URL: {image_url}")
                        image = download_image(image_url)

                        if image:
                            logger.info(f"Image generation complete - Task ID: {task_id}")
                            yield image, _status_html("Image generated", "success")
                            return
                        else:
                            logger.error(f"Failed to download image - Task ID: {task_id}, URL: {image_url}")
                            yield None, _status_html("Failed to download generated image", "error")
                            return

                    # Sub-task status 3 = failed; surface the reported error.
                    elif sub_status == 3:
                        error_msg = sub.get("task_error") or sub.get("message") or "Unknown error"
                        logger.error(
                            f"Task failed - Task ID: {task_id}, Sub status: {sub_status}, Error: {error_msg}"
                        )
                        yield None, _status_html(f"Task failed: {error_msg}", "error")
                        return

                    # Any other sub-task status: keep waiting and poll again.
                    else:
                        elapsed = int(time.time() - start_time)
                        yield None, _status_html(f"Waiting… {elapsed}s", "running")
                        time.sleep(POLL_INTERVAL)

                # The while loop ran out without a terminal status.
                logger.error(
                    f"Timeout waiting for task completion - Task ID: {task_id}, Max time: {MAX_POLL_TIME}s"
                )
                yield None, _status_html(f"Timed out after {MAX_POLL_TIME}s", "error")

            except APIError as e:
                logger.error(f"API Error during generation: {str(e)}")
                yield None, _status_html(f"API error: {str(e)}", "error")

            except ValueError as e:
                logger.error(f"Value Error during generation: {str(e)}")
                yield None, _status_html(f"Value error: {str(e)}", "error")

            except Exception as e:
                # Last-resort boundary: log the traceback and show the
                # message in the pill rather than letting the handler crash.
                logger.error(f"Unexpected error during image generation: {str(e)}")
                logger.error(f"Full traceback: {traceback.format_exc()}")
                yield None, _status_html(f"Unexpected error: {str(e)}", "error")

        def _enter_queue():
            """Click handler #1 — runs IMMEDIATELY (queue=False).

            Flips queued_state to True and seeds the pill with a snapshot of
            the current queue. Subsequent live updates come from the timer.

            Returns values for (output_image, status_msg, queued_state,
            last_pill_state): the image is cleared and the same pill HTML is
            both displayed and cached.
            """
            waiting, running, _ = _read_queue_stats(demo)
            html = _status_html(_queue_text(waiting, running), "running")
            return None, html, True, html

        def _generate_wrapped(
            prompt_value,
            wh_ratio_value,
            negative_prompt_value,
            enable_prompt_refine_value,
            seed_value,
            guidance_scale_value,
        ):
            """Click handler #2 — queued.

            Wraps the existing generator and, on the FIRST yield, also flips
            queued_state to False so the timer stops touching the pill and
            lets the generator's own `yield`s drive it (Generating XXs → ...).

            Yields (image, status_html, queued_state_update) triples; after
            the first one the state slot is a no-op `gr.update()`.
            """
            first = True
            for image, status_html in generate_with_status(
                prompt_value,
                wh_ratio_value,
                negative_prompt_value,
                enable_prompt_refine_value,
                seed_value,
                guidance_scale_value,
            ):
                if first:
                    first = False
                    yield image, status_html, False
                else:
                    yield image, status_html, gr.update()

        # Two-step chain: the unqueued step updates the pill right away, the
        # queued step does the actual generation once a slot is free.
        generate_btn.click(
            fn=_enter_queue,
            inputs=None,
            outputs=[output_image, status_msg, queued_state, last_pill_state],
            queue=False,
            show_progress="hidden",
        ).then(
            fn=_generate_wrapped,
            inputs=[prompt, wh_ratio, negative_prompt, enable_prompt_refine, seed, guidance_scale],
            outputs=[output_image, status_msg, queued_state],
            show_progress="minimal",
            show_progress_on=[generate_btn],
        )

        def clear_outputs():
            """Reset the right column: no image, "Ready" pill, not queued,
            and an empty pill cache."""
            logger.info("Clearing UI outputs")
            return None, _status_html("Ready", "info"), False, ""

        clear_btn.click(
            fn=clear_outputs,
            inputs=None,
            outputs=[output_image, status_msg, queued_state, last_pill_state],
        )

        # Live queue updates inside the right-side pill — only while THIS
        # user is queued. Returns gr.update() (no-op) when nothing changed,
        # which prevents the DOM from being replaced and avoids the pulse
        # animation resetting (= no flicker).
        pill_timer = gr.Timer(value=1.5, active=True)

        def _tick_pill(queued_flag, last_html):
            """Timer callback: refresh the pill with current queue counts.

            Returns (status_msg update, last_pill_state value). The pill is
            left untouched when this session is not queued or when the
            freshly rendered HTML equals the cached one.
            """
            if not queued_flag:
                return gr.update(), last_html
            waiting, running, _ = _read_queue_stats(demo)
            new_html = _status_html(_queue_text(waiting, running), "running")
            if new_html == last_html:
                return gr.update(), last_html
            return new_html, new_html

        pill_timer.tick(
            fn=_tick_pill,
            inputs=[queued_state, last_pill_state],
            outputs=[status_msg, last_pill_state],
            queue=False,
            show_progress="hidden",
        )

        # Footer: external links row followed by the vivago.ai tagline.
        gr.HTML(
            """
            <div class="tagline footer-links">
                <a href="https://huggingface.co/HiDream-ai/HiDream-O1-Image" target="_blank">HuggingFace</a>
                <a href="https://github.com/HiDream-ai/HiDream-O1-Image" target="_blank">GitHub</a>
                <a href="https://x.com/vivago_ai" target="_blank">Twitter</a>
            </div>
            <div class="tagline">
                For more features and the full experience, visit
                <a href="https://vivago.ai/" target="_blank">vivago.ai</a>.
            </div>
            """
        )

    logger.info("Gradio UI created successfully")
    return demo
1284
+
1285
+
1286
# Script entry point: build the UI, enable the request queue, then serve.
if __name__ == "__main__":
    logger.info("Starting HiDream-O1-Image Generator application")
    demo = create_ui()

    logger.info("Launching Gradio interface with queue")
    # queue() hands back the same Blocks app, so configuring the queue and
    # launching can be written as two separate statements.
    demo.queue(max_size=50, default_concurrency_limit=4)
    demo.launch(show_api=False)

    logger.info("Application shutdown")
gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .gradio
2
+ __pycache__
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ requests>=2.28.0
3
+ Pillow>=9.0.0
4
+ python-dotenv>=1.0.0
5
+ numpy>=1.22.0
6
+ tqdm>=4.65.0