cai-qi committed on
Commit
38ae4b8
·
verified ·
1 Parent(s): 1b38fa0

Upload 5 files

Browse files
Files changed (5) hide show
  1. .gitattributes +35 -35
  2. .gitignore +2 -0
  3. README.md +15 -14
  4. app.py +871 -0
  5. requirements.txt +6 -0
.gitattributes CHANGED
@@ -1,35 +1,35 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .gradio
2
+ __pycache__
README.md CHANGED
@@ -1,14 +1,15 @@
1
- ---
2
- title: HiDream O1 Image
3
- emoji: 🐢
4
- colorFrom: purple
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 6.14.0
8
- python_version: '3.13'
9
- app_file: app.py
10
- pinned: false
11
- license: mit
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
1
+ ---
2
+ title: HiDream O1 Image
3
+ emoji: 🚀
4
+ colorFrom: yellow
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 5.23.3
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ models:
12
+ - HiDream-ai/HiDream-O1-Image
13
+ ---
14
+
15
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,871 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import time
4
+ import traceback
5
+ from io import BytesIO
6
+
7
+ import gradio as gr
8
+ import requests
9
+ from PIL import Image
10
+ from dotenv import load_dotenv
11
+
12
+ logging.basicConfig(
13
+ level=logging.INFO,
14
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
15
+ )
16
+ logger = logging.getLogger(__name__)
17
+
18
+ load_dotenv()
19
+
20
+ # API Configuration (new style: host + gen_image_path)
21
+ API_TOKEN = os.environ.get("token")
22
+ API_HOST = os.environ.get("host")
23
+ GEN_IMAGE_PATH = os.environ.get("gen_image_path")
24
+ MODEL_ID = os.environ.get("model_id")
25
+
26
+ # Polling / retry configuration (with sensible defaults)
27
+ MAX_RETRY_COUNT = int(os.environ.get("MAX_RETRY_COUNT", 3))
28
+ POLL_INTERVAL = float(os.environ.get("POLL_INTERVAL", 2.0))
29
+ MAX_POLL_TIME = int(os.environ.get("MAX_POLL_TIME", 300))
30
+
31
+ # Predefined aspect ratios (wh_ratio) — kept in the same order as the original
32
+ # (width, height) list so each entry mirrors the previous resolution choice:
33
+ # 1:1 ←→ 2048×2048 4:3 ←→ 2304×1728 3:4 ←→ 1728×2304
34
+ # 16:9 ←→ 2560×1440 9:16 ←→ 1440×2560 3:2 ←→ 2496×1664
35
+ # 2:3 ←→ 1664×2496 21:9 ←→ 3104×1312 9:21 ←→ 1312×3104
36
+ # 9:7 ←→ 2304×1792 7:9 ←→ 1792×2304
37
+ WH_RATIO_OPTIONS = [
38
+ "1:1",
39
+ "4:3",
40
+ "3:4",
41
+ "16:9",
42
+ "9:16",
43
+ "3:2",
44
+ "2:3",
45
+ "21:9",
46
+ "9:21",
47
+ "9:7",
48
+ "7:9",
49
+ ]
50
+
51
+ logger.info(
52
+ f"API configuration loaded: HOST={API_HOST}, GEN_IMAGE_PATH={GEN_IMAGE_PATH}, MODEL_ID={MODEL_ID}"
53
+ )
54
+ logger.info(
55
+ f"Retry configuration: MAX_RETRY_COUNT={MAX_RETRY_COUNT}, POLL_INTERVAL={POLL_INTERVAL}s, MAX_POLL_TIME={MAX_POLL_TIME}s"
56
+ )
57
+
58
+
59
class APIError(Exception):
    """Error raised when a call to the image-generation API fails."""
62
+
63
+
64
+ # Status codes returned by the API
65
+ SUCCESS_CODE = 0
66
+
67
+
68
def _build_request_url() -> str:
    """Return the absolute URL of the image-generation endpoint.

    The URL is API_HOST joined to GEN_IMAGE_PATH by exactly one slash, so the
    two settings may be written with or without the separating slash.

    Raises:
        APIError: If API_HOST or GEN_IMAGE_PATH is empty or unset.
    """
    if not API_HOST or not GEN_IMAGE_PATH:
        raise APIError("API host or gen_image_path is not configured. Please set the 'host' and 'gen_image_path' environment variables.")
    # Normalise the joint: a path configured without its leading slash would
    # otherwise be glued straight onto the host name.
    return f"{API_HOST.rstrip('/')}/{GEN_IMAGE_PATH.lstrip('/')}"
72
+
73
+
74
def _build_result_url(task_id: str) -> str:
    """Return the URL used to poll the result of ``task_id``.

    The task ID is percent-encoded so that reserved characters (``&``, ``#``,
    spaces, ...) cannot corrupt the query string. IDs made only of letters,
    digits, ``-``, ``_``, ``.`` and ``~`` are left unchanged.

    Raises:
        APIError: Propagated from _build_request_url() when the endpoint is
            not configured.
    """
    # Function-scope import so this helper does not depend on edits elsewhere.
    from urllib.parse import quote

    return f"{_build_request_url()}/results?task_id={quote(str(task_id), safe='')}"
76
+
77
+
78
def _headers() -> dict:
    """Build the HTTP headers carrying the bearer token.

    Returns:
        dict: A single ``Authorization`` header built from API_TOKEN.

    Raises:
        APIError: If API_TOKEN is empty or unset.
    """
    if API_TOKEN:
        return {"Authorization": f"Bearer {API_TOKEN}"}
    raise APIError("API token is not configured. Please set the 'token' environment variable.")
82
+
83
+
84
def create_request(prompt, wh_ratio):
    """
    Submit an image generation request to the API.

    Timeouts, HTTP 429 and HTTP 5xx responses are retried until
    MAX_RETRY_COUNT attempts have been used; every other failure aborts
    immediately.

    Args:
        prompt (str): Text prompt describing the image to generate
        wh_ratio (str): Aspect ratio for the output image (e.g. "16:9");
            must be one of WH_RATIO_OPTIONS

    Returns:
        str: Task ID

    Raises:
        ValueError: If the prompt is empty or the aspect ratio is not supported
        APIError: If the API request fails
    """
    # Validate before logging: slicing a None prompt for the log line would
    # raise TypeError and mask the intended ValueError.
    if not prompt or not prompt.strip():
        logger.error("Empty prompt provided to create_request")
        raise ValueError("Prompt cannot be empty")

    if not wh_ratio or not isinstance(wh_ratio, str) or wh_ratio not in WH_RATIO_OPTIONS:
        logger.error(f"Invalid wh_ratio: {wh_ratio}. Valid options: {WH_RATIO_OPTIONS}")
        raise ValueError(f"Invalid aspect ratio. Must be one of: {', '.join(WH_RATIO_OPTIONS)}")

    logger.info(
        f"Starting create_request with prompt='{prompt[:50]}...', wh_ratio={wh_ratio}"
    )

    model_params = {
        "prompt": prompt,
        "wh_ratio": wh_ratio,
        "model_id": MODEL_ID,
        "n": 1,
    }

    url = _build_request_url()

    retry_count = 0
    while retry_count < MAX_RETRY_COUNT:
        try:
            logger.info(
                f"Sending API request [attempt {retry_count + 1}/{MAX_RETRY_COUNT}] for prompt: '{prompt[:50]}...'"
            )
            response = requests.post(url, json=model_params, headers=_headers(), timeout=15)
            logger.info(f"API request response status: {response.status_code}")
            response.raise_for_status()

            response_json = response.json()
            code = response_json.get("code")
            message = response_json.get("message", "")
            if code != SUCCESS_CODE:
                logger.error(f"API returned error code {code}: {message}")
                raise APIError(f"Failed to submit task (code={code}): {message}")

            # "result" may be present but null; treat that like a missing
            # task ID instead of failing with AttributeError on None.
            task_id = (response_json.get("result") or {}).get("task_id")
            if not task_id:
                logger.error(f"No task ID in API response: {response_json}")
                raise APIError(f"No task ID returned from API: {response_json}")

            logger.info(f"Successfully created task with ID: {task_id}")
            return task_id

        except requests.exceptions.Timeout:
            retry_count += 1
            logger.warning(f"Request timed out. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
            time.sleep(1)

        except requests.exceptions.HTTPError as e:
            status_code = e.response.status_code
            error_message = f"HTTP error {status_code}"
            try:
                error_detail = e.response.json()
                error_message += f": {error_detail}"
                logger.error(f"API response error content: {error_detail}")
            except Exception:
                # The error body is only extra diagnostics; a non-JSON body
                # must not hide the HTTP error itself.
                logger.error(f"Could not parse API error response as JSON. Raw content: {e.response.content[:500]}")

            if status_code == 401:
                logger.error(f"Authentication failed with API token. Status code: {status_code}")
                raise APIError("Authentication failed. Please check your API token.") from e
            elif status_code == 429:
                retry_count += 1
                # Exponential back-off, capped at 10 seconds.
                wait_time = min(2 ** retry_count, 10)
                logger.warning(f"Rate limit exceeded. Waiting {wait_time}s before retry ({retry_count}/{MAX_RETRY_COUNT})...")
                time.sleep(wait_time)
            elif 400 <= status_code < 500:
                # Other client errors will not succeed on retry.
                logger.error(f"Client error: {error_message}, Prompt: '{prompt[:50]}...', Status: {status_code}")
                raise APIError(error_message) from e
            else:
                retry_count += 1
                logger.warning(f"Server error: {error_message}. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
                time.sleep(1)

        except requests.exceptions.RequestException as e:
            logger.error(f"Request error: {str(e)}")
            logger.debug(f"Request error details: {traceback.format_exc()}")
            raise APIError(f"Failed to connect to API: {str(e)}") from e

        except APIError:
            # Already a domain error raised above; pass it through untouched.
            raise

        except Exception as e:
            logger.error(f"Unexpected error in create_request: {str(e)}")
            logger.error(f"Full traceback: {traceback.format_exc()}")
            raise APIError(f"Unexpected error: {str(e)}") from e

    logger.error(f"Failed to create request after {MAX_RETRY_COUNT} retries for prompt: '{prompt[:50]}...'")
    raise APIError(f"Failed after {MAX_RETRY_COUNT} retries")
190
+
191
+
192
def get_results(task_id):
    """
    Check the status of an image generation task.

    Args:
        task_id (str): The task ID to check

    Returns:
        dict: The "result" object from the response, or None when the check
        did not produce one (non-success API code, timeout, network error,
        HTTP error other than 401, or an unexpected error).

    Raises:
        ValueError: If task_id is empty.
        APIError: For unrecoverable errors (authentication failure, missing
            API configuration).
    """
    logger.debug(f"Checking status for task ID: {task_id}")

    if not task_id:
        logger.error("Empty task ID provided to get_results")
        raise ValueError("Task ID cannot be empty")

    url = _build_result_url(task_id)

    try:
        response = requests.get(url, headers=_headers(), timeout=10)
        logger.debug(f"Status check response code: {response.status_code}")
        response.raise_for_status()
        response_json = response.json()

        code = response_json.get("code")
        message = response_json.get("message", "")
        if code != SUCCESS_CODE:
            logger.warning(f"API returned non-success code {code} for task {task_id}: {message}")
            return None

        return response_json.get("result")

    except requests.exceptions.Timeout:
        logger.warning(f"Request timed out when checking task {task_id}")
        return None

    except requests.exceptions.HTTPError as e:
        status_code = e.response.status_code
        logger.warning(f"HTTP error {status_code} when checking task {task_id}")
        try:
            error_content = e.response.json()
            logger.error(f"Error response content: {error_content}")
        except Exception:
            # The error body is only extra diagnostics; ignore parse failures.
            logger.error(f"Could not parse error response as JSON. Raw content: {e.response.content[:500]}")

        if status_code == 401:
            logger.error(f"Authentication failed when checking task {task_id}")
            raise APIError(f"Authentication failed. Please check your API token when checking task {task_id}") from e
        elif 400 <= status_code < 500:
            logger.error(f"Client error {status_code} when checking task {task_id}")
            return None
        else:
            logger.warning(f"Server error {status_code} when checking task {task_id}")
            return None

    except requests.exceptions.RequestException as e:
        logger.warning(f"Network error when checking task {task_id}: {str(e)}")
        logger.debug(f"Network error details: {traceback.format_exc()}")
        return None

    except APIError:
        # Configuration errors from _headers() are unrecoverable; without this
        # clause the generic handler below would turn them into None.
        raise

    except Exception as e:
        logger.error(f"Unexpected error when checking task {task_id}: {str(e)}")
        logger.error(f"Full traceback: {traceback.format_exc()}")
        return None
259
+
260
+
261
def download_image(image_url):
    """
    Fetch the image at ``image_url`` and hand it back as a PIL Image.

    An image whose format is not PNG (e.g. WebP) is re-encoded as PNG, and the
    string-valued entries of the source image's ``info`` dict are copied onto
    the re-encoded image.

    Args:
        image_url (str): Location of the image to download

    Returns:
        PIL.Image.Image: The downloaded image

    Raises:
        ValueError: If image_url is empty
        APIError: On an HTTP 4xx response, when the payload cannot be
            processed, or once MAX_RETRY_COUNT attempts are used up
    """
    logger.info(f"Starting download_image from URL: {image_url}")

    if not image_url:
        logger.error("Empty image URL provided to download_image")
        raise ValueError("Image URL cannot be empty when downloading image")

    # Each failed attempt either raises or consumes exactly one iteration.
    for attempt in range(1, MAX_RETRY_COUNT + 1):
        try:
            logger.info(f"Downloading image [attempt {attempt}/{MAX_RETRY_COUNT}] from {image_url}")
            resp = requests.get(image_url, timeout=30)
            logger.debug(
                f"Image download response status: {resp.status_code}, "
                f"Content-Type: {resp.headers.get('Content-Type')}, "
                f"Content-Length: {resp.headers.get('Content-Length')}"
            )
            resp.raise_for_status()

            picture = Image.open(BytesIO(resp.content))
            logger.info(
                f"Image opened successfully. Format: {picture.format}, "
                f"Size: {picture.size[0]}x{picture.size[1]}, Mode: {picture.mode}"
            )

            # Only str -> str entries are carried over to the PNG copy.
            text_metadata = {
                name: entry
                for name, entry in picture.info.items()
                if isinstance(name, str) and isinstance(entry, str)
            }
            logger.debug(f"Original image metadata: {text_metadata}")

            if picture.format != 'PNG':
                logger.info(f"Converting image from {picture.format} to PNG format")
                # Images with an alpha band are saved as they are; everything
                # else is converted to RGB first.
                source = picture if 'A' in picture.getbands() else picture.convert('RGB')
                encoded = BytesIO()
                source.save(encoded, format='PNG')
                encoded.seek(0)
                picture = Image.open(encoded)
                picture.info.update(text_metadata)

            logger.info(f"Successfully downloaded and processed image: {picture.size[0]}x{picture.size[1]}")
            return picture

        except requests.exceptions.Timeout:
            logger.warning(f"Download timed out. Retrying ({attempt}/{MAX_RETRY_COUNT})...")
            time.sleep(1)

        except requests.exceptions.HTTPError as http_err:
            status = http_err.response.status_code
            logger.error(f"HTTP error {status} when downloading image from {image_url}")
            if 400 <= status < 500:
                raise APIError(f"HTTP error {status} when downloading image")
            time.sleep(1)

        except requests.exceptions.RequestException as net_err:
            logger.warning(f"Network error during image download: {str(net_err)}. Retrying ({attempt}/{MAX_RETRY_COUNT})...")
            time.sleep(1)

        except Exception as err:
            logger.error(f"Error processing image from {image_url}: {str(err)}")
            logger.error(f"Full traceback: {traceback.format_exc()}")
            raise APIError(f"Failed to process image: {str(err)}")

    logger.error(f"Failed to download image from {image_url} after {MAX_RETRY_COUNT} retries")
    raise APIError(f"Failed to download image after {MAX_RETRY_COUNT} retries")
338
+
339
+
340
+ APPLE_CSS = """
341
+ /* ---- Apple-inspired minimalist UI ---- */
342
+ .gradio-container {
343
+ font-family: -apple-system, BlinkMacSystemFont, "SF Pro Display", "SF Pro Text",
344
+ "Helvetica Neue", "Segoe UI", Inter, sans-serif !important;
345
+ background: linear-gradient(180deg, #fbfbfd 0%, #f5f5f7 100%) !important;
346
+ /* Always use ~3/4 of the viewport, capped at 1600px on huge screens.
347
+ Using width AND max-width ensures the page is wide from first paint
348
+ instead of growing only after content loads. */
349
+ width: min(1600px, 92vw) !important;
350
+ max-width: 1600px !important;
351
+ margin: 0 auto !important;
352
+ -webkit-font-smoothing: antialiased;
353
+ -moz-osx-font-smoothing: grayscale;
354
+ color: #1d1d1f !important;
355
+ }
356
+
357
+ /* Header — two-column: title left, intro + links right. Compact so the whole
358
+ form + image fit on one screen without scrolling. */
359
+ #app-header {
360
+ display: flex;
361
+ align-items: center;
362
+ justify-content: space-between;
363
+ gap: 40px;
364
+ padding: 28px 8px 18px 8px;
365
+ }
366
+ .header-left { flex-shrink: 0; }
367
+ .header-left h1 {
368
+ font-size: 46px;
369
+ font-weight: 700;
370
+ letter-spacing: -0.025em;
371
+ margin: 0;
372
+ /* line-height needs headroom for descenders (g/p/y) — at 1.0 the gradient
373
+ text mask clips them. 1.2 gives clean rendering without adding extra
374
+ visual whitespace because the H1 has no following sibling. */
375
+ line-height: 1.2;
376
+ padding-bottom: 2px;
377
+ /* Apple-Intelligence-style gradient: system blue → indigo → purple → pink.
378
+ Background is animated very slowly for a subtle "alive" feel. */
379
+ background: linear-gradient(
380
+ 120deg,
381
+ #0071e3 0%,
382
+ #5e5ce6 35%,
383
+ #af52de 70%,
384
+ #ff375f 100%
385
+ );
386
+ background-size: 200% 200%;
387
+ -webkit-background-clip: text;
388
+ -webkit-text-fill-color: transparent;
389
+ background-clip: text;
390
+ animation: title-shimmer 12s ease-in-out infinite;
391
+ }
392
+ @keyframes title-shimmer {
393
+ 0%, 100% { background-position: 0% 50%; }
394
+ 50% { background-position: 100% 50%; }
395
+ }
396
+ .header-right {
397
+ flex: 1;
398
+ text-align: right;
399
+ max-width: 640px;
400
+ }
401
+ .header-right .subtitle {
402
+ color: #6e6e73;
403
+ font-size: 15px;
404
+ font-weight: 400;
405
+ line-height: 1.5;
406
+ margin: 0 0 10px 0;
407
+ }
408
+ .header-right .links {
409
+ display: flex;
410
+ justify-content: flex-end;
411
+ gap: 22px;
412
+ flex-wrap: wrap;
413
+ font-size: 13px;
414
+ }
415
+ .header-right .links a {
416
+ color: #0071e3;
417
+ text-decoration: none;
418
+ font-weight: 500;
419
+ transition: opacity 0.15s ease;
420
+ }
421
+ .header-right .links a:hover { opacity: 0.7; }
422
+
423
+ @media (max-width: 820px) {
424
+ #app-header {
425
+ flex-direction: column;
426
+ align-items: flex-start;
427
+ gap: 14px;
428
+ padding: 24px 8px 14px 8px;
429
+ }
430
+ .header-right { text-align: left; max-width: none; }
431
+ .header-right .links { justify-content: flex-start; }
432
+ }
433
+
434
+ /* Cards / panels */
435
+ .panel-card {
436
+ background: #ffffff !important;
437
+ border-radius: 18px !important;
438
+ padding: 18px !important;
439
+ box-shadow: 0 1px 2px rgba(0,0,0,0.04), 0 8px 28px rgba(0,0,0,0.05) !important;
440
+ border: 1px solid rgba(0,0,0,0.05) !important;
441
+ }
442
+
443
+ /* Inputs - rounded with apple-blue focus ring */
444
+ textarea, input[type="text"], input[type="number"],
445
+ .gradio-container .form input,
446
+ .gradio-container .form textarea {
447
+ border-radius: 12px !important;
448
+ border: 1px solid #d2d2d7 !important;
449
+ background: #ffffff !important;
450
+ transition: border-color 0.15s ease, box-shadow 0.15s ease !important;
451
+ font-size: 15px !important;
452
+ }
453
+ textarea:focus, input:focus,
454
+ .gradio-container .form input:focus,
455
+ .gradio-container .form textarea:focus {
456
+ border-color: #0071e3 !important;
457
+ box-shadow: 0 0 0 4px rgba(0, 113, 227, 0.15) !important;
458
+ outline: none !important;
459
+ }
460
+
461
+ /* Dropdown */
462
+ .gradio-container .wrap.svelte-1ipelgc { border-radius: 12px !important; }
463
+
464
+ /* Block labels */
465
+ .gradio-container span[data-testid="block-label"],
466
+ .gradio-container .block-label,
467
+ .gradio-container label > span {
468
+ color: #6e6e73 !important;
469
+ font-weight: 500 !important;
470
+ font-size: 13px !important;
471
+ letter-spacing: 0.01em;
472
+ }
473
+
474
+ /* Buttons - pill shape, Apple blue. Scoped to the variant classes so it does
475
+ NOT bleed into internal buttons inside dropdowns, accordions, etc. */
476
+ .gradio-container button.primary,
477
+ .gradio-container button.secondary {
478
+ border-radius: 980px !important;
479
+ font-weight: 500 !important;
480
+ font-size: 15px !important;
481
+ padding: 12px 22px !important;
482
+ transition: transform 0.08s ease, box-shadow 0.18s ease, background 0.18s ease, opacity 0.15s ease !important;
483
+ border: none !important;
484
+ letter-spacing: 0.01em;
485
+ }
486
+ .gradio-container button.primary {
487
+ background: #0071e3 !important;
488
+ color: #ffffff !important;
489
+ box-shadow: 0 1px 2px rgba(0,113,227,0.25), 0 6px 16px rgba(0,113,227,0.22) !important;
490
+ }
491
+ .gradio-container button.primary:hover {
492
+ background: #0077ed !important;
493
+ transform: translateY(-1px);
494
+ box-shadow: 0 2px 4px rgba(0,113,227,0.3), 0 10px 22px rgba(0,113,227,0.28) !important;
495
+ }
496
+ .gradio-container button.primary:active { transform: translateY(0); }
497
+ .gradio-container button.secondary {
498
+ background: rgba(0,0,0,0.05) !important;
499
+ color: #1d1d1f !important;
500
+ box-shadow: none !important;
501
+ }
502
+ .gradio-container button.secondary:hover {
503
+ background: rgba(0,0,0,0.09) !important;
504
+ }
505
+
506
+ /* Make sure the dropdown's selected-value text never gets pill-clipped and
507
+ is properly aligned inside its rounded box. */
508
+ .gradio-container .wrap-inner,
509
+ .gradio-container .single-select,
510
+ .gradio-container .secondary-wrap {
511
+ border-radius: 12px !important;
512
+ }
513
+ .gradio-container .single-select input,
514
+ .gradio-container input[role="listbox"] {
515
+ border-radius: 12px !important;
516
+ padding: 10px 14px !important;
517
+ font-size: 15px !important;
518
+ }
519
+
520
+ /* Status pill */
521
+ #status-bar {
522
+ padding: 0 !important;
523
+ margin-top: 6px;
524
+ }
525
+ .status-pill {
526
+ display: inline-flex;
527
+ align-items: center;
528
+ gap: 9px;
529
+ background: #f5f5f7;
530
+ color: #1d1d1f;
531
+ padding: 10px 14px;
532
+ border-radius: 12px;
533
+ font-size: 13px;
534
+ font-weight: 500;
535
+ line-height: 1;
536
+ border: 1px solid rgba(0,0,0,0.04);
537
+ }
538
+ .status-dot {
539
+ width: 8px;
540
+ height: 8px;
541
+ border-radius: 50%;
542
+ background: #8e8e93;
543
+ flex-shrink: 0;
544
+ }
545
+ .status-info .status-dot { background: #8e8e93; }
546
+ .status-success { background: rgba(48,209,88,0.10); color: #0a7f2e; border-color: rgba(48,209,88,0.20); }
547
+ .status-success .status-dot { background: #30d158; }
548
+ .status-error { background: rgba(255,59,48,0.10); color: #b8261b; border-color: rgba(255,59,48,0.20); }
549
+ .status-error .status-dot { background: #ff3b30; }
550
+ .status-running { background: rgba(0,113,227,0.10); color: #0058b8; border-color: rgba(0,113,227,0.20); }
551
+ .status-running .status-dot {
552
+ background: #0071e3;
553
+ animation: pulse 1.4s ease-in-out infinite;
554
+ }
555
+ @keyframes pulse {
556
+ 0%, 100% { opacity: 0.4; transform: scale(0.85); }
557
+ 50% { opacity: 1.0; transform: scale(1.15); }
558
+ }
559
+
560
+ /* Image output frame — never crop the image; show full picture with letterbox. */
561
+ .image-output {
562
+ border-radius: 18px !important;
563
+ background: #f5f5f7 !important;
564
+ }
565
+ .image-output,
566
+ .image-output > div,
567
+ .image-output [data-testid="image"],
568
+ .image-output .image-container,
569
+ .image-output .image-frame,
570
+ .image-output .preview {
571
+ min-height: 440px !important;
572
+ display: flex !important;
573
+ align-items: center !important;
574
+ justify-content: center !important;
575
+ }
576
+ .image-output img {
577
+ border-radius: 14px !important;
578
+ object-fit: contain !important;
579
+ max-width: 100% !important;
580
+ max-height: 62vh !important;
581
+ width: auto !important;
582
+ height: auto !important;
583
+ }
584
+
585
+ /* Status pill placed inside the right column, above the image. */
586
+ .right-status {
587
+ display: flex;
588
+ justify-content: flex-start;
589
+ margin-bottom: 6px;
590
+ }
591
+
592
+ /* Accordion */
593
+ .gradio-container details {
594
+ border-radius: 14px !important;
595
+ border: 1px solid rgba(0,0,0,0.06) !important;
596
+ background: #ffffff !important;
597
+ }
598
+
599
+ /* Footer tagline */
600
+ .tagline {
601
+ text-align: center;
602
+ color: #6e6e73;
603
+ font-size: 12px;
604
+ margin: 18px 0 14px 0;
605
+ font-weight: 400;
606
+ }
607
+ .tagline a {
608
+ color: #0071e3;
609
+ text-decoration: none;
610
+ font-weight: 500;
611
+ }
612
+ .tagline a:hover { opacity: 0.7; }
613
+
614
+ /* Hide gradio's default footer for a cleaner look */
615
+ footer { display: none !important; }
616
+
617
+ /* Mobile */
618
+ @media (max-width: 640px) {
619
+ #app-header { padding: 32px 16px 16px 16px; }
620
+ #app-header h1 { font-size: 32px; }
621
+ #app-header p.subtitle { font-size: 15px; }
622
+ }
623
+ """
624
+
625
+
626
+ APPLE_THEME = gr.themes.Soft(
627
+ primary_hue=gr.themes.colors.blue,
628
+ neutral_hue=gr.themes.colors.slate,
629
+ radius_size=gr.themes.sizes.radius_lg,
630
+ text_size=gr.themes.sizes.text_md,
631
+ font=[
632
+ gr.themes.GoogleFont("Inter"),
633
+ "ui-sans-serif",
634
+ "-apple-system",
635
+ "BlinkMacSystemFont",
636
+ "Segoe UI",
637
+ "Helvetica Neue",
638
+ "sans-serif",
639
+ ],
640
+ ).set(
641
+ body_background_fill="*neutral_50",
642
+ block_background_fill="white",
643
+ block_border_width="1px",
644
+ block_label_text_weight="500",
645
+ block_title_text_weight="600",
646
+ button_primary_background_fill="#0071e3",
647
+ button_primary_background_fill_hover="#0077ed",
648
+ button_primary_text_color="white",
649
+ button_primary_border_color="#0071e3",
650
+ button_secondary_background_fill="rgba(0,0,0,0.05)",
651
+ button_secondary_background_fill_hover="rgba(0,0,0,0.09)",
652
+ button_secondary_text_color="#1d1d1f",
653
+ input_background_fill="white",
654
+ input_border_color="#d2d2d7",
655
+ input_border_color_focus="#0071e3",
656
+ input_shadow_focus="0 0 0 4px rgba(0,113,227,0.15)",
657
+ )
658
+
659
+
660
+ def _status_html(text: str, kind: str = "info") -> str:
661
+ """Render a styled status pill. kind: info | success | error | running"""
662
+ return (
663
+ f'<div class="status-pill status-{kind}">'
664
+ f'<span class="status-dot"></span>{text}'
665
+ f'</div>'
666
+ )
667
+
668
+
669
+ def create_ui():
670
+ logger.info("Creating Gradio UI")
671
+ with gr.Blocks(
672
+ title="HiDream-O1-Image Generator",
673
+ theme=APPLE_THEME,
674
+ css=APPLE_CSS,
675
+ ) as demo:
676
+ gr.HTML(
677
+ """
678
+ <div id="app-header">
679
+ <div class="header-left">
680
+ <h1>HiDream-O1-Image</h1>
681
+ </div>
682
+ <div class="header-right">
683
+ <p class="subtitle">A natively unified pixel-space image generative model.</p>
684
+ <div class="links">
685
+ <a href="https://huggingface.co/HiDream-ai/HiDream-O1-Image" target="_blank">HuggingFace</a>
686
+ <a href="https://github.com/HiDream-ai/HiDream-O1-Image" target="_blank">GitHub</a>
687
+ <a href="https://x.com/vivago_ai" target="_blank">Twitter</a>
688
+ </div>
689
+ </div>
690
+ </div>
691
+ """
692
+ )
693
+
694
+ with gr.Row(equal_height=False):
695
+ with gr.Column(scale=1, elem_classes=["panel-card"]):
696
+ prompt = gr.Textbox(
697
+ label="Prompt",
698
+ placeholder="Describe the image you want to create...",
699
+ lines=5,
700
+ show_label=True,
701
+ )
702
+
703
+ wh_ratio = gr.Dropdown(
704
+ choices=WH_RATIO_OPTIONS,
705
+ value=WH_RATIO_OPTIONS[0],
706
+ label="Aspect Ratio",
707
+ info="Width : Height",
708
+ )
709
+
710
+ with gr.Row():
711
+ clear_btn = gr.Button("Clear", variant="secondary", scale=1)
712
+ generate_btn = gr.Button("Generate", variant="primary", scale=3)
713
+
714
+ with gr.Column(scale=1, elem_classes=["panel-card"]):
715
+ status_msg = gr.HTML(
716
+ value=_status_html("Ready", "info"),
717
+ elem_id="status-bar",
718
+ elem_classes=["right-status"],
719
+ )
720
+ output_image = gr.Image(
721
+ label="Generated Image",
722
+ format="png",
723
+ type="pil",
724
+ interactive=False,
725
+ show_download_button=True,
726
+ elem_classes=["image-output"],
727
+ )
728
+
729
def generate_with_status(prompt, wh_ratio_value):
    """Generate an image through the remote API, streaming status to the UI.

    This is a generator used as the Generate button's click handler. Every
    ``yield`` is an ``(image, status_html)`` pair matching
    ``outputs=[output_image, status_msg]``; the image slot stays ``None``
    until the single successful final yield.

    Args:
        prompt: Prompt text from the textbox. Empty or whitespace-only
            prompts are rejected before any request is made.
        wh_ratio_value: Selected aspect ratio; must be one of
            ``WH_RATIO_OPTIONS``.

    Yields:
        tuple: ``(image_or_None, status_html)``.

    No exception escapes: ``APIError``, ``ValueError`` and any other
    ``Exception`` are logged and reported as an error status instead.
    """
    logger.info(
        f"Starting image generation with prompt='{(prompt or '')[:50]}...', wh_ratio={wh_ratio_value}"
    )

    try:
        # Validate first so the UI never claims a request is being sent
        # when the input is rejected locally.
        if not prompt or not prompt.strip():
            logger.error("Empty prompt provided in UI")
            yield None, _status_html("Prompt cannot be empty", "error")
            return

        if wh_ratio_value not in WH_RATIO_OPTIONS:
            logger.error(f"Invalid aspect ratio selection: {wh_ratio_value}")
            yield None, _status_html(f"Invalid aspect ratio “{wh_ratio_value}”", "error")
            return

        yield None, _status_html("Sending request to API…", "running")

        logger.info("Creating API request")
        task_id = create_request(prompt, wh_ratio_value)
        yield None, _status_html(f"Request submitted · Task {task_id[:8]}…", "running")

        # Use the monotonic clock for the deadline and the elapsed counter:
        # unlike time.time(), it cannot jump when the system clock is adjusted.
        start_time = time.monotonic()
        logger.info(f"Starting to poll for results for task ID: {task_id}")

        while time.monotonic() - start_time < MAX_POLL_TIME:
            elapsed_time = time.monotonic() - start_time
            logger.debug(
                f"Polling for results - Task ID: {task_id}, Elapsed: {elapsed_time:.2f}s"
            )

            result = get_results(task_id)
            if not result:
                # Nothing usable came back from this poll; retry after the interval.
                time.sleep(POLL_INTERVAL)
                continue

            overall_status = result.get("status")
            sub_results = result.get("sub_task_results", []) or []

            # Any overall status other than 1 is treated as "still in progress".
            if overall_status != 1:
                elapsed = int(time.monotonic() - start_time)
                yield None, _status_html(f"Generating… {elapsed}s", "running")
                time.sleep(POLL_INTERVAL)
                continue

            if not sub_results:
                logger.error(f"Task completed but no sub_task_results returned. Task ID: {task_id}")
                yield None, _status_html("Task completed but no results returned", "error")
                return

            # Only the first sub-task result is consulted.
            sub = sub_results[0]
            sub_status = sub.get("task_status")

            if sub_status == 1:
                logger.info(f"Task completed successfully - Task ID: {task_id}")

                image_url = sub.get("url")
                if not image_url:
                    logger.error(f"No image URL in successful response. Sub result: {sub}")
                    yield None, _status_html("No image URL in response", "error")
                    return

                yield None, _status_html("Downloading image…", "running")

                logger.info(f"Downloading image - Task ID: {task_id}, URL: {image_url}")
                image = download_image(image_url)

                if not image:
                    logger.error(f"Failed to download image - Task ID: {task_id}, URL: {image_url}")
                    yield None, _status_html("Failed to download generated image", "error")
                    return

                logger.info(f"Image generation complete - Task ID: {task_id}")
                yield image, _status_html("Image generated", "success")
                return

            if sub_status == 3:
                error_msg = sub.get("task_error") or sub.get("message") or "Unknown error"
                logger.error(
                    f"Task failed - Task ID: {task_id}, Sub status: {sub_status}, Error: {error_msg}"
                )
                # NOTE(review): error_msg comes from the API response and ends up
                # in HTML; confirm that _status_html escapes its message argument.
                yield None, _status_html(f"Task failed: {error_msg}", "error")
                return

            # Any other sub-task status: keep waiting.
            elapsed = int(time.monotonic() - start_time)
            yield None, _status_html(f"Waiting… {elapsed}s", "running")
            time.sleep(POLL_INTERVAL)

        logger.error(
            f"Timeout waiting for task completion - Task ID: {task_id}, Max time: {MAX_POLL_TIME}s"
        )
        yield None, _status_html(f"Timed out after {MAX_POLL_TIME}s", "error")

    except APIError as e:
        logger.error(f"API Error during generation: {str(e)}")
        yield None, _status_html(f"API error: {str(e)}", "error")

    except ValueError as e:
        logger.error(f"Value Error during generation: {str(e)}")
        yield None, _status_html(f"Value error: {str(e)}", "error")

    except Exception as e:
        # logger.exception logs at ERROR level and appends the active traceback.
        logger.exception(f"Unexpected error during image generation: {str(e)}")
        yield None, _status_html(f"Unexpected error: {str(e)}", "error")
835
+
836
+ generate_btn.click(
837
+ fn=generate_with_status,
838
+ inputs=[prompt, wh_ratio],
839
+ outputs=[output_image, status_msg],
840
+ show_progress="hidden",
841
+ )
842
+
843
def clear_outputs():
    """Reset the output panel.

    Returns:
        tuple: ``(None, status_html)`` - ``None`` for the image slot and a
        fresh "Ready" status for the status slot, in the order expected by
        ``outputs=[output_image, status_msg]``.
    """
    logger.info("Clearing UI outputs")
    ready_status = _status_html("Ready", "info")
    return None, ready_status
846
+
847
+ clear_btn.click(
848
+ fn=clear_outputs,
849
+ inputs=None,
850
+ outputs=[output_image, status_msg],
851
+ )
852
+
853
+ gr.HTML(
854
+ """
855
+ <div class="tagline">
856
+ For more features and the full experience, visit
857
+ <a href="https://vivago.ai/" target="_blank">vivago.ai</a>.
858
+ </div>
859
+ """
860
+ )
861
+
862
+ logger.info("Gradio UI created successfully")
863
+ return demo
864
+
865
+
866
if __name__ == "__main__":
    # Script entry point: build the UI, enable the request queue, then serve.
    logger.info("Starting HiDream-O1-Image Generator application")
    demo = create_ui()

    logger.info("Launching Gradio interface with queue")
    # queue() configures the app in place and returns it, so calling launch()
    # on `demo` afterwards is equivalent to chaining the two calls.
    demo.queue(max_size=50, default_concurrency_limit=4)
    demo.launch(show_api=False)

    # Logged once launch() has returned.
    logger.info("Application shutdown")
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ requests>=2.28.0
3
+ Pillow>=9.0.0
4
+ python-dotenv>=1.0.0
5
+ numpy>=1.22.0
6
+ tqdm>=4.65.0