Spaces:
Running
Running
| import logging | |
| import os | |
| import random | |
| import time | |
| import traceback | |
| from io import BytesIO | |
| import gradio as gr | |
| import requests | |
| from PIL import Image | |
| from dotenv import load_dotenv | |
# Configure logging once at import time: INFO level, with timestamp, logger
# name and level prefixed to every message.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)
# Called before the os.environ lookups that follow, so values supplied via
# python-dotenv (presumably a local .env file) are visible to them.
load_dotenv()
# API Configuration (new style: host + gen_image_path)
# Each value is read from a lowercase environment variable and is None when
# that variable is unset; the request helpers check for that before use.
API_TOKEN = os.environ.get("token")
API_HOST = os.environ.get("host")
GEN_IMAGE_PATH = os.environ.get("gen_image_path")
MODEL_ID = os.environ.get("model_id")
# Polling / retry configuration (with sensible defaults)
# NOTE: these are parsed at import time, so a non-numeric value in any of the
# three variables raises ValueError while the module is loading.
MAX_RETRY_COUNT = int(os.environ.get("MAX_RETRY_COUNT", 3))
POLL_INTERVAL = float(os.environ.get("POLL_INTERVAL", 2.0))
MAX_POLL_TIME = int(os.environ.get("MAX_POLL_TIME", 300))
# Local-only test mode: skip the real API and return a random-colored image after
# a randomised delay. Useful for testing the UI / queueing flow without burning
# real model credits. Enable with FAKE_TEST=1 (also accepts "true", "yes" or
# "on"; the value is stripped and compared case-insensitively).
FAKE_TEST = os.environ.get("FAKE_TEST", "").strip().lower() in ("1", "true", "yes", "on")
# Aspect ratios (wh_ratio) accepted by create_request, listed in the same
# order as the (width, height) resolutions they replaced:
#   1:1  -> 2048x2048    4:3  -> 2304x1728    3:4  -> 1728x2304
#   16:9 -> 2560x1440    9:16 -> 1440x2560    3:2  -> 2496x1664
#   2:3  -> 1664x2496    21:9 -> 3104x1312    9:21 -> 1312x3104
#   9:7  -> 2304x1792    7:9  -> 1792x2304
# Each landscape ratio is immediately followed by its portrait counterpart.
WH_RATIO_OPTIONS = [
    "1:1",
    "4:3", "3:4",
    "16:9", "9:16",
    "3:2", "2:3",
    "21:9", "9:21",
    "9:7", "7:9",
]
# Record the effective configuration once at startup. The API token is
# deliberately not part of either message.
logger.info(
    f"API configuration loaded: HOST={API_HOST}, "
    f"GEN_IMAGE_PATH={GEN_IMAGE_PATH}, MODEL_ID={MODEL_ID}"
)
logger.info(
    f"Retry configuration: MAX_RETRY_COUNT={MAX_RETRY_COUNT}, "
    f"POLL_INTERVAL={POLL_INTERVAL}s, MAX_POLL_TIME={MAX_POLL_TIME}s"
)
if FAKE_TEST:
    # Make it obvious in the logs that generated images are placeholders.
    logger.warning("FAKE_TEST mode is ENABLED — no real API calls will be made.")
class APIError(Exception):
    """Raised when talking to the image generation API fails.

    The request helpers in this module raise it for misconfiguration,
    non-success API codes, unrecoverable HTTP errors and exhausted retries.
    """
# Status codes returned by the API
# Value of the JSON "code" field that marks a successful response; the request
# helpers in this file treat any other value as a failure.
SUCCESS_CODE = 0
def _build_request_url() -> str:
    """Return the full URL of the image generation endpoint.

    API_HOST and GEN_IMAGE_PATH are joined with exactly one slash between
    them, regardless of whether the host ends with a slash or the path starts
    with one.

    Returns:
        str: The endpoint URL.

    Raises:
        APIError: If API_HOST or GEN_IMAGE_PATH is not configured.
    """
    if not API_HOST or not GEN_IMAGE_PATH:
        raise APIError("API host or gen_image_path is not configured. Please set the 'host' and 'gen_image_path' environment variables.")
    # Normalise the boundary: a path configured without its leading slash used
    # to be fused directly onto the host name, producing an unusable URL.
    return f"{API_HOST.rstrip('/')}/{GEN_IMAGE_PATH.lstrip('/')}"
def _build_result_url(task_id: str) -> str:
    """Return the URL used to poll the result of *task_id*.

    The URL is the generation endpoint followed by ``/results`` with the task
    ID passed as the ``task_id`` query parameter. The ID is interpolated
    as-is (no URL encoding is applied).

    Raises:
        APIError: Propagated from _build_request_url when the API host or path
            is not configured.
    """
    endpoint = _build_request_url()
    return f"{endpoint}/results?task_id={task_id}"
def _headers() -> dict:
    """Return the HTTP headers carrying the bearer token for API calls.

    Raises:
        APIError: If API_TOKEN is not configured.
    """
    if API_TOKEN:
        return {"Authorization": f"Bearer {API_TOKEN}"}
    raise APIError("API token is not configured. Please set the 'token' environment variable.")
def create_request(
    prompt,
    wh_ratio,
    negative_prompt="",
    enable_prompt_refine=True,
    seed=-1,
    guidance_scale=5.0,
):
    """
    Submit an image generation request to the API.

    Retryable failures (timeouts, HTTP 429 and HTTP 5xx) are attempted up to
    MAX_RETRY_COUNT times in total; every other failure is raised immediately.

    Args:
        prompt (str): Text prompt describing the image to generate
        wh_ratio (str): Aspect ratio for the output image (e.g. "16:9")
        negative_prompt (str): Optional text describing what to avoid.
        enable_prompt_refine (bool): Whether to let the backend rewrite/expand
            the prompt before generation. Sent to the API as 0 / 1.
        seed (int): Generation seed. -1 means the backend will pick one
            randomly; any other integer fixes the seed for reproducible runs.
            Values that cannot be converted to int fall back to -1.
        guidance_scale (float): Classifier-free guidance strength. Higher
            values follow the prompt more strictly. Values that cannot be
            converted to float fall back to 5.0.
    Returns:
        str: Task ID
    Raises:
        ValueError: If the prompt is empty or not a string, or if wh_ratio is
            not one of WH_RATIO_OPTIONS.
        APIError: If the API request fails
    """
    # Build log-safe previews first: slicing None (or any other non-string)
    # would raise TypeError before the validation below can report the real
    # problem as a ValueError.
    prompt_preview = prompt[:50] if isinstance(prompt, str) else ""
    negative_preview = negative_prompt[:30] if isinstance(negative_prompt, str) else ""
    logger.info(
        f"Starting create_request with prompt='{prompt_preview}...', "
        f"wh_ratio={wh_ratio}, enable_prompt_refine={enable_prompt_refine}, "
        f"seed={seed}, guidance_scale={guidance_scale}, "
        f"negative_prompt='{negative_preview}...'"
    )

    if not isinstance(prompt, str) or not prompt.strip():
        logger.error("Empty prompt provided to create_request")
        raise ValueError("Prompt cannot be empty")
    if not wh_ratio or not isinstance(wh_ratio, str) or wh_ratio not in WH_RATIO_OPTIONS:
        logger.error(f"Invalid wh_ratio: {wh_ratio}. Valid options: {WH_RATIO_OPTIONS}")
        raise ValueError(f"Invalid aspect ratio. Must be one of: {', '.join(WH_RATIO_OPTIONS)}")

    # OverflowError covers int(float("inf")) and float() of a huge integer.
    try:
        seed_int = int(seed)
    except (TypeError, ValueError, OverflowError):
        logger.warning(f"Invalid seed value '{seed}', falling back to -1 (random)")
        seed_int = -1
    try:
        guidance_scale_f = float(guidance_scale)
    except (TypeError, ValueError, OverflowError):
        logger.warning(f"Invalid guidance_scale '{guidance_scale}', falling back to 5.0")
        guidance_scale_f = 5.0

    model_params = {
        "prompt": prompt,
        "wh_ratio": wh_ratio,
        "model_id": MODEL_ID,
        "n": 1,
        "negative_prompt": negative_prompt or "",
        "enable_prompt_refine": 1 if enable_prompt_refine else 0,
        "seed": seed_int,
        "guidance_scale": guidance_scale_f,
    }
    url = _build_request_url()

    for attempt in range(1, MAX_RETRY_COUNT + 1):
        # Delay applied before the next attempt; only the 429 branch changes it.
        delay = 1
        try:
            logger.info(
                f"Sending API request [attempt {attempt}/{MAX_RETRY_COUNT}] for prompt: '{prompt_preview}...'"
            )
            response = requests.post(url, json=model_params, headers=_headers(), timeout=15)
            logger.info(f"API request response status: {response.status_code}")
            response.raise_for_status()

            # A 2xx response with a non-JSON body is an API fault, not a
            # connection failure, so report it as such.
            try:
                response_json = response.json()
            except ValueError as e:
                logger.error(f"API returned a non-JSON response: {response.content[:500]}")
                raise APIError("API returned a non-JSON response") from e
            if not isinstance(response_json, dict):
                logger.error(f"Unexpected API response shape: {response_json}")
                raise APIError(f"Unexpected API response: {response_json}")

            code = response_json.get("code")
            message = response_json.get("message", "")
            if code != SUCCESS_CODE:
                logger.error(f"API returned error code {code}: {message}")
                raise APIError(f"Failed to submit task (code={code}): {message}")

            # "result" may be present but null; treat that like a missing task ID.
            result = response_json.get("result")
            task_id = result.get("task_id") if isinstance(result, dict) else None
            if not task_id:
                logger.error(f"No task ID in API response: {response_json}")
                raise APIError(f"No task ID returned from API: {response_json}")
            logger.info(f"Successfully created task with ID: {task_id}")
            return task_id
        except requests.exceptions.Timeout:
            logger.warning(f"Request timed out. Retrying ({attempt}/{MAX_RETRY_COUNT})...")
        except requests.exceptions.HTTPError as e:
            status_code = e.response.status_code
            error_message = f"HTTP error {status_code}"
            try:
                error_detail = e.response.json()
                error_message += f": {error_detail}"
                logger.error(f"API response error content: {error_detail}")
            except Exception:
                logger.error(f"Could not parse API error response as JSON. Raw content: {e.response.content[:500]}")
            if status_code == 401:
                logger.error(f"Authentication failed with API token. Status code: {status_code}")
                raise APIError("Authentication failed. Please check your API token.") from e
            elif status_code == 429:
                # Exponential backoff, capped at 10 seconds.
                delay = min(2 ** attempt, 10)
                logger.warning(f"Rate limit exceeded. Waiting {delay}s before retry ({attempt}/{MAX_RETRY_COUNT})...")
            elif 400 <= status_code < 500:
                # Other client errors will not succeed on retry.
                logger.error(f"Client error: {error_message}, Prompt: '{prompt_preview}...', Status: {status_code}")
                raise APIError(error_message) from e
            else:
                logger.warning(f"Server error: {error_message}. Retrying ({attempt}/{MAX_RETRY_COUNT})...")
        except requests.exceptions.RequestException as e:
            logger.error(f"Request error: {str(e)}")
            logger.debug(f"Request error details: {traceback.format_exc()}")
            raise APIError(f"Failed to connect to API: {str(e)}") from e
        except APIError:
            raise
        except Exception as e:
            logger.error(f"Unexpected error in create_request: {str(e)}")
            logger.error(f"Full traceback: {traceback.format_exc()}")
            raise APIError(f"Unexpected error: {str(e)}") from e

        # Only retryable failures reach this point. Skip the wait when no
        # attempt is left, so the final failure is reported without delay.
        if attempt < MAX_RETRY_COUNT:
            time.sleep(delay)

    logger.error(f"Failed to create request after {MAX_RETRY_COUNT} retries for prompt: '{prompt_preview}...'")
    raise APIError(f"Failed after {MAX_RETRY_COUNT} retries")
def get_results(task_id):
    """
    Poll the API once for the current state of a generation task.

    Every failure except an authentication error is treated as transient: it
    is logged and reported to the caller as None.

    Args:
        task_id (str): The task ID to check
    Returns:
        dict: The "result" object of the response, or None when the call
        failed or the API answered with a non-success code.
    Raises:
        ValueError: If task_id is empty.
        APIError: On HTTP 401, or when the API host, path or token is not
            configured.
    """
    logger.debug(f"Checking status for task ID: {task_id}")
    if not task_id:
        logger.error("Empty task ID provided to get_results")
        raise ValueError("Task ID cannot be empty")

    # Built outside the try block so configuration errors propagate.
    status_url = _build_result_url(task_id)
    try:
        reply = requests.get(status_url, headers=_headers(), timeout=10)
        logger.debug(f"Status check response code: {reply.status_code}")
        reply.raise_for_status()
        payload = reply.json()
        api_code = payload.get("code")
        api_message = payload.get("message", "")
        if api_code == SUCCESS_CODE:
            return payload.get("result")
        logger.warning(f"API returned non-success code {api_code} for task {task_id}: {api_message}")
        return None
    except requests.exceptions.Timeout:
        logger.warning(f"Request timed out when checking task {task_id}")
        return None
    except requests.exceptions.HTTPError as http_err:
        http_status = http_err.response.status_code
        logger.warning(f"HTTP error {http_status} when checking task {task_id}")
        # Log whatever body the server sent, parsed as JSON when possible.
        try:
            body = http_err.response.json()
            logger.error(f"Error response content: {body}")
        except Exception:
            logger.error(f"Could not parse error response as JSON. Raw content: {http_err.response.content[:500]}")
        if http_status == 401:
            # The only HTTP status that is escalated instead of swallowed.
            logger.error(f"Authentication failed when checking task {task_id}")
            raise APIError(f"Authentication failed. Please check your API token when checking task {task_id}")
        if 400 <= http_status < 500:
            logger.error(f"Client error {http_status} when checking task {task_id}")
        else:
            logger.warning(f"Server error {http_status} when checking task {task_id}")
        return None
    except requests.exceptions.RequestException as net_err:
        logger.warning(f"Network error when checking task {task_id}: {str(net_err)}")
        logger.debug(f"Network error details: {traceback.format_exc()}")
        return None
    except Exception as err:
        logger.error(f"Unexpected error when checking task {task_id}: {str(err)}")
        logger.error(f"Full traceback: {traceback.format_exc()}")
        return None
def download_image(image_url):
    """
    Download an image from a URL and return it as a PIL Image.
    Converts non-PNG formats (e.g. WebP) to PNG while preserving original metadata.

    Only string-keyed, string-valued entries of the source image's ``info``
    dict are carried over to the converted image. Timeouts, network errors and
    HTTP 5xx responses are retried up to MAX_RETRY_COUNT attempts in total.

    Args:
        image_url (str): Location of the image to fetch.
    Returns:
        PIL.Image.Image: The downloaded image, in PNG format.
    Raises:
        ValueError: If image_url is empty.
        APIError: On an HTTP 4xx response, when the payload cannot be
            processed as an image, or when all attempts are exhausted.
    """
    logger.info(f"Starting download_image from URL: {image_url}")
    if not image_url:
        logger.error("Empty image URL provided to download_image")
        raise ValueError("Image URL cannot be empty when downloading image")

    for attempt in range(1, MAX_RETRY_COUNT + 1):
        try:
            logger.info(f"Downloading image [attempt {attempt}/{MAX_RETRY_COUNT}] from {image_url}")
            response = requests.get(image_url, timeout=30)
            logger.debug(
                f"Image download response status: {response.status_code}, "
                f"Content-Type: {response.headers.get('Content-Type')}, "
                f"Content-Length: {response.headers.get('Content-Length')}"
            )
            response.raise_for_status()

            image = Image.open(BytesIO(response.content))
            logger.info(
                f"Image opened successfully. Format: {image.format}, "
                f"Size: {image.size[0]}x{image.size[1]}, Mode: {image.mode}"
            )

            # Keep only plain text metadata; other entries are not carried over.
            original_metadata = {
                key: value
                for key, value in image.info.items()
                if isinstance(key, str) and isinstance(value, str)
            }
            logger.debug(f"Original image metadata: {original_metadata}")

            if image.format != 'PNG':
                logger.info(f"Converting image from {image.format} to PNG format")
                png_buffer = BytesIO()
                # Images with an alpha band are saved as-is; everything else
                # is normalised to RGB first.
                if 'A' in image.getbands():
                    image_to_save = image
                else:
                    image_to_save = image.convert('RGB')
                image_to_save.save(png_buffer, format='PNG')
                png_buffer.seek(0)
                image = Image.open(png_buffer)
                # Re-attach the text metadata to the re-opened PNG.
                image.info.update(original_metadata)

            logger.info(f"Successfully downloaded and processed image: {image.size[0]}x{image.size[1]}")
            return image
        except requests.exceptions.Timeout:
            logger.warning(f"Download timed out. Retrying ({attempt}/{MAX_RETRY_COUNT})...")
        except requests.exceptions.HTTPError as e:
            status_code = e.response.status_code
            logger.error(f"HTTP error {status_code} when downloading image from {image_url}")
            if 400 <= status_code < 500:
                # Client errors will not succeed on retry.
                raise APIError(f"HTTP error {status_code} when downloading image") from e
            # Log the retry like the other retryable branches do.
            logger.warning(f"Server error {status_code} during image download. Retrying ({attempt}/{MAX_RETRY_COUNT})...")
        except requests.exceptions.RequestException as e:
            logger.warning(f"Network error during image download: {str(e)}. Retrying ({attempt}/{MAX_RETRY_COUNT})...")
        except Exception as e:
            logger.error(f"Error processing image from {image_url}: {str(e)}")
            logger.error(f"Full traceback: {traceback.format_exc()}")
            raise APIError(f"Failed to process image: {str(e)}") from e

        # Only retryable failures reach this point. Skip the wait when no
        # attempt is left, so the final failure is reported without delay.
        if attempt < MAX_RETRY_COUNT:
            time.sleep(1)

    logger.error(f"Failed to download image from {image_url} after {MAX_RETRY_COUNT} retries")
    raise APIError(f"Failed to download image after {MAX_RETRY_COUNT} retries")
| APPLE_CSS = """ | |
| /* ---- Apple-inspired minimalist UI ---- */ | |
| .gradio-container { | |
| font-family: -apple-system, BlinkMacSystemFont, "SF Pro Display", "SF Pro Text", | |
| "Helvetica Neue", "Segoe UI", Inter, sans-serif !important; | |
| background: linear-gradient(180deg, #fbfbfd 0%, #f5f5f7 100%) !important; | |
| /* Always use ~3/4 of the viewport, capped at 1600px on huge screens. | |
| Using width AND max-width ensures the page is wide from first paint | |
| instead of growing only after content loads. */ | |
| width: min(1600px, 92vw) !important; | |
| max-width: 1600px !important; | |
| margin: 0 auto !important; | |
| -webkit-font-smoothing: antialiased; | |
| -moz-osx-font-smoothing: grayscale; | |
| color: #1d1d1f !important; | |
| } | |
| /* Cards / panels */ | |
| .panel-card { | |
| background: #ffffff !important; | |
| border-radius: 18px !important; | |
| padding: 18px !important; | |
| box-shadow: 0 1px 2px rgba(0,0,0,0.04), 0 8px 28px rgba(0,0,0,0.05) !important; | |
| border: 1px solid rgba(0,0,0,0.05) !important; | |
| } | |
| /* Inputs - rounded with apple-blue focus ring */ | |
| textarea, input[type="text"], input[type="number"], | |
| .gradio-container .form input, | |
| .gradio-container .form textarea { | |
| border-radius: 12px !important; | |
| border: 1px solid #d2d2d7 !important; | |
| background: #ffffff !important; | |
| transition: border-color 0.15s ease, box-shadow 0.15s ease !important; | |
| font-size: 15px !important; | |
| } | |
| textarea:focus, input:focus, | |
| .gradio-container .form input:focus, | |
| .gradio-container .form textarea:focus { | |
| border-color: #0071e3 !important; | |
| box-shadow: 0 0 0 4px rgba(0, 113, 227, 0.15) !important; | |
| outline: none !important; | |
| } | |
| /* Dropdown */ | |
| .gradio-container .wrap.svelte-1ipelgc { border-radius: 12px !important; } | |
| /* Block labels */ | |
| .gradio-container span[data-testid="block-label"], | |
| .gradio-container .block-label, | |
| .gradio-container label > span { | |
| color: #6e6e73 !important; | |
| font-weight: 500 !important; | |
| font-size: 13px !important; | |
| letter-spacing: 0.01em; | |
| } | |
| /* Buttons - pill shape, Apple blue. Scoped to the variant classes so it does | |
| NOT bleed into internal buttons inside dropdowns, accordions, etc. */ | |
| .gradio-container button.primary, | |
| .gradio-container button.secondary { | |
| border-radius: 980px !important; | |
| font-weight: 500 !important; | |
| font-size: 15px !important; | |
| padding: 12px 22px !important; | |
| transition: transform 0.08s ease, box-shadow 0.18s ease, background 0.18s ease, opacity 0.15s ease !important; | |
| border: none !important; | |
| letter-spacing: 0.01em; | |
| } | |
| .gradio-container button.primary { | |
| background: #0071e3 !important; | |
| color: #ffffff !important; | |
| box-shadow: 0 1px 2px rgba(0,113,227,0.25), 0 6px 16px rgba(0,113,227,0.22) !important; | |
| } | |
| .gradio-container button.primary:hover { | |
| background: #0077ed !important; | |
| transform: translateY(-1px); | |
| box-shadow: 0 2px 4px rgba(0,113,227,0.3), 0 10px 22px rgba(0,113,227,0.28) !important; | |
| } | |
| .gradio-container button.primary:active { transform: translateY(0); } | |
| .gradio-container button.secondary { | |
| background: rgba(0,0,0,0.05) !important; | |
| color: #1d1d1f !important; | |
| box-shadow: none !important; | |
| } | |
| .gradio-container button.secondary:hover { | |
| background: rgba(0,0,0,0.09) !important; | |
| } | |
| /* Make sure the dropdown's selected-value text never gets pill-clipped and | |
| is properly aligned inside its rounded box. */ | |
| .gradio-container .wrap-inner, | |
| .gradio-container .single-select, | |
| .gradio-container .secondary-wrap { | |
| border-radius: 12px !important; | |
| } | |
| .gradio-container .single-select input, | |
| .gradio-container input[role="listbox"] { | |
| border-radius: 12px !important; | |
| padding: 10px 14px !important; | |
| font-size: 15px !important; | |
| } | |
| /* Status pill */ | |
| #status-bar { | |
| padding: 0 !important; | |
| margin-top: 6px; | |
| } | |
| .status-pill { | |
| display: inline-flex; | |
| align-items: center; | |
| gap: 9px; | |
| background: #f5f5f7; | |
| color: #1d1d1f; | |
| padding: 10px 14px; | |
| border-radius: 12px; | |
| font-size: 13px; | |
| font-weight: 500; | |
| line-height: 1; | |
| border: 1px solid rgba(0,0,0,0.04); | |
| } | |
| .status-dot { | |
| width: 8px; | |
| height: 8px; | |
| border-radius: 50%; | |
| background: #8e8e93; | |
| flex-shrink: 0; | |
| } | |
| .status-info .status-dot { background: #8e8e93; } | |
| .status-success { background: rgba(48,209,88,0.10); color: #0a7f2e; border-color: rgba(48,209,88,0.20); } | |
| .status-success .status-dot { background: #30d158; } | |
| .status-error { background: rgba(255,59,48,0.10); color: #b8261b; border-color: rgba(255,59,48,0.20); } | |
| .status-error .status-dot { background: #ff3b30; } | |
| .status-running { background: rgba(0,113,227,0.10); color: #0058b8; border-color: rgba(0,113,227,0.20); } | |
| .status-running .status-dot { | |
| background: #0071e3; | |
| animation: pulse 1.4s ease-in-out infinite; | |
| } | |
| @keyframes pulse { | |
| 0%, 100% { opacity: 0.4; transform: scale(0.85); } | |
| 50% { opacity: 1.0; transform: scale(1.15); } | |
| } | |
| /* Image output frame — never crop the image; show full picture with letterbox. */ | |
| .image-output { | |
| border-radius: 18px !important; | |
| background: #f5f5f7 !important; | |
| } | |
| .image-output, | |
| .image-output > div, | |
| .image-output [data-testid="image"], | |
| .image-output .image-container, | |
| .image-output .image-frame, | |
| .image-output .preview { | |
| min-height: 440px !important; | |
| display: flex !important; | |
| align-items: center !important; | |
| justify-content: center !important; | |
| } | |
| .image-output img { | |
| border-radius: 14px !important; | |
| object-fit: contain !important; | |
| max-width: 100% !important; | |
| max-height: 62vh !important; | |
| width: auto !important; | |
| height: auto !important; | |
| } | |
| /* Status pill placed inside the right column, above the image. */ | |
| .right-status { | |
| display: flex; | |
| justify-content: flex-start; | |
| margin-bottom: 6px; | |
| } | |
| /* Accordion */ | |
| .gradio-container details { | |
| border-radius: 14px !important; | |
| border: 1px solid rgba(0,0,0,0.06) !important; | |
| background: #ffffff !important; | |
| } | |
| /* Negative prompt — softer accent so it visually de-emphasises vs. the main prompt */ | |
| .negative-prompt textarea { | |
| background: #fbfbfd !important; | |
| border-color: #e3e3e8 !important; | |
| } | |
| .negative-prompt textarea:focus { | |
| background: #ffffff !important; | |
| } | |
| /* Advanced options row — keeps the refine switch + seed input visually paired */ | |
| .advanced-row { | |
| gap: 14px !important; | |
| margin-top: 2px; | |
| } | |
| /* Refine toggle — iOS-style settings card. | |
| Goal: title + helper text on the left, a polished pill switch on the right, | |
| everything contained inside a single soft card so the helper text no longer | |
| floats orphaned above the box. */ | |
| .refine-toggle { | |
| background: linear-gradient(180deg, #ffffff 0%, #f5f5f7 100%) !important; | |
| border-radius: 14px !important; | |
| border: 1px solid rgba(0,0,0,0.06) !important; | |
| padding: 14px 16px !important; | |
| box-shadow: 0 1px 2px rgba(0,0,0,0.03) !important; | |
| transition: border-color 0.18s ease, box-shadow 0.18s ease !important; | |
| min-height: 88px; | |
| display: flex !important; | |
| flex-direction: column !important; | |
| justify-content: center !important; | |
| } | |
| .refine-toggle:hover { | |
| border-color: rgba(0,0,0,0.10) !important; | |
| box-shadow: 0 1px 2px rgba(0,0,0,0.04), 0 4px 14px rgba(0,0,0,0.05) !important; | |
| } | |
| /* Strip default backgrounds from gradio's inner wrappers so only our card shows. */ | |
| .refine-toggle .form, | |
| .refine-toggle .wrap, | |
| .refine-toggle .form-wrap, | |
| .refine-toggle > div { | |
| background: transparent !important; | |
| border: none !important; | |
| padding: 0 !important; | |
| margin: 0 !important; | |
| box-shadow: none !important; | |
| } | |
| /* Helper / "info" text becomes a proper subtitle UNDER the toggle row. */ | |
| .refine-toggle [data-testid="block-info"], | |
| .refine-toggle .info { | |
| color: #6e6e73 !important; | |
| font-size: 12px !important; | |
| line-height: 1.4 !important; | |
| margin: 8px 0 0 0 !important; | |
| padding: 0 !important; | |
| text-align: left !important; | |
| order: 2 !important; | |
| } | |
| /* Force the gradio wrapper to stack: label-row first, info below. */ | |
| .refine-toggle .form, | |
| .refine-toggle > div:not([data-testid="block-info"]):not(.info) { | |
| display: flex !important; | |
| flex-direction: column !important; | |
| align-items: stretch !important; | |
| } | |
| /* Label row: title on the LEFT (full width), toggle pinned to the RIGHT. */ | |
| .refine-toggle label { | |
| display: flex !important; | |
| align-items: center !important; | |
| justify-content: space-between !important; | |
| flex-direction: row-reverse !important; | |
| cursor: pointer !important; | |
| margin: 0 !important; | |
| padding: 0 !important; | |
| gap: 14px !important; | |
| width: 100% !important; | |
| order: 1 !important; | |
| } | |
| .refine-toggle label > span { | |
| color: #1d1d1f !important; | |
| font-size: 15px !important; | |
| font-weight: 600 !important; | |
| letter-spacing: -0.01em; | |
| flex: 1 1 auto !important; | |
| text-align: left !important; | |
| line-height: 1.3 !important; | |
| } | |
| /* Pill switch — bigger, smoother, more "Apple-like" */ | |
| .refine-toggle input[type="checkbox"] { | |
| appearance: none; | |
| -webkit-appearance: none; | |
| width: 46px !important; | |
| height: 28px !important; | |
| border-radius: 999px !important; | |
| background: #e5e5ea !important; | |
| position: relative; | |
| cursor: pointer; | |
| transition: background 0.22s ease, box-shadow 0.22s ease; | |
| border: none !important; | |
| flex-shrink: 0 !important; | |
| margin: 0 !important; | |
| box-shadow: inset 0 0 1px rgba(0,0,0,0.06); | |
| } | |
| .refine-toggle input[type="checkbox"]::after { | |
| content: ""; | |
| position: absolute; | |
| top: 2px; | |
| left: 2px; | |
| width: 24px; | |
| height: 24px; | |
| border-radius: 50%; | |
| background: #ffffff; | |
| box-shadow: 0 2px 5px rgba(0,0,0,0.18), 0 0 1px rgba(0,0,0,0.05); | |
| transition: transform 0.24s cubic-bezier(0.4, 0.0, 0.2, 1); | |
| } | |
| .refine-toggle input[type="checkbox"]:hover { | |
| background: #dcdce0 !important; | |
| } | |
| .refine-toggle input[type="checkbox"]:checked { | |
| background: #34c759 !important; | |
| } | |
| .refine-toggle input[type="checkbox"]:checked:hover { | |
| background: #30b352 !important; | |
| } | |
| .refine-toggle input[type="checkbox"]:checked::after { | |
| transform: translateX(18px); | |
| } | |
| .refine-toggle input[type="checkbox"]:focus-visible { | |
| box-shadow: 0 0 0 4px rgba(0,113,227,0.20) !important; | |
| } | |
| .refine-toggle input[type="checkbox"]:active::after { | |
| /* tiny squish on press, very iOS */ | |
| width: 28px; | |
| } | |
| .refine-toggle input[type="checkbox"]:checked:active::after { | |
| transform: translateX(14px); | |
| } | |
| /* Seed number input — match the prompt/dropdown rounding */ | |
| .seed-input input[type="number"] { | |
| border-radius: 12px !important; | |
| padding: 10px 14px !important; | |
| font-variant-numeric: tabular-nums; | |
| } | |
| /* Hide the native spinner buttons on number inputs for a cleaner look */ | |
| .seed-input input[type="number"]::-webkit-outer-spin-button, | |
| .seed-input input[type="number"]::-webkit-inner-spin-button { | |
| -webkit-appearance: none; | |
| margin: 0; | |
| } | |
| .seed-input input[type="number"] { | |
| -moz-appearance: textfield; | |
| } | |
| /* Keep the seed column visually aligned with the refine card next to it */ | |
| .seed-input { | |
| align-self: stretch !important; | |
| } | |
| /* Guidance scale slider — Apple-blue track + softer thumb */ | |
| .guidance-slider input[type="range"] { | |
| accent-color: #0071e3 !important; | |
| } | |
| .guidance-slider .head { padding-top: 0 !important; } | |
| .guidance-slider { | |
| margin-top: 4px; | |
| } | |
| /* Footer tagline */ | |
| .tagline { | |
| text-align: center; | |
| color: #6e6e73; | |
| font-size: 12px; | |
| margin: 18px 0 14px 0; | |
| font-weight: 400; | |
| } | |
| .tagline a { | |
| color: #0071e3; | |
| text-decoration: none; | |
| font-weight: 500; | |
| transition: opacity 0.15s ease; | |
| } | |
| .tagline a:hover { opacity: 0.7; } | |
| /* Footer links row (HuggingFace / GitHub / Twitter) */ | |
| .footer-links { | |
| display: flex; | |
| justify-content: center; | |
| align-items: center; | |
| gap: 26px; | |
| flex-wrap: wrap; | |
| font-size: 13px; | |
| margin: 24px 0 6px 0; | |
| } | |
| /* When the vivago tagline follows the links row, tighten the gap. */ | |
| .footer-links + .tagline { | |
| margin-top: 4px; | |
| } | |
| /* Hide gradio's default footer for a cleaner look */ | |
| footer { display: none !important; } | |
| /* Mobile */ | |
| @media (max-width: 640px) { | |
| .footer-links { gap: 18px; font-size: 12px; } | |
| } | |
| """ | |
# Gradio theme object: the built-in Soft theme with blue/slate hues and an
# Inter-first font list, followed by explicit overrides for buttons and inputs.
# The hex colours below repeat the ones used in the APPLE_CSS stylesheet.
APPLE_THEME = gr.themes.Soft(
    primary_hue=gr.themes.colors.blue,
    neutral_hue=gr.themes.colors.slate,
    radius_size=gr.themes.sizes.radius_lg,
    text_size=gr.themes.sizes.text_md,
    # Inter is requested as a Google Font; the remaining entries are plain
    # font-family names listed after it.
    font=[
        gr.themes.GoogleFont("Inter"),
        "ui-sans-serif",
        "-apple-system",
        "BlinkMacSystemFont",
        "Segoe UI",
        "Helvetica Neue",
        "sans-serif",
    ],
).set(
    # Page and block surfaces
    body_background_fill="*neutral_50",
    block_background_fill="white",
    block_border_width="1px",
    block_label_text_weight="500",
    block_title_text_weight="600",
    # Primary button: solid blue
    button_primary_background_fill="#0071e3",
    button_primary_background_fill_hover="#0077ed",
    button_primary_text_color="white",
    button_primary_border_color="#0071e3",
    # Secondary button: translucent grey
    button_secondary_background_fill="rgba(0,0,0,0.05)",
    button_secondary_background_fill_hover="rgba(0,0,0,0.09)",
    button_secondary_text_color="#1d1d1f",
    # Inputs: white field, grey border, blue border and ring on focus
    input_background_fill="white",
    input_border_color="#d2d2d7",
    input_border_color_focus="#0071e3",
    input_shadow_focus="0 0 0 4px rgba(0,113,227,0.15)",
)
| def _status_html(text: str, kind: str = "info") -> str: | |
| """Render a styled status pill. kind: info | success | error | running""" | |
| return ( | |
| f'<div class="status-pill status-{kind}">' | |
| f'<span class="status-dot"></span>{text}' | |
| f'</div>' | |
| ) | |
| def _queue_text(waiting: int, running: int) -> str: | |
| """Inline status text describing the queue from the user's POV. | |
| Used inside the right-side status pill while THIS user's request is | |
| sitting in the queue waiting for an open slot. | |
| """ | |
| if waiting <= 0 and running <= 0: | |
| return "Queued · waiting for an open slot…" | |
| parts = [] | |
| if waiting > 0: | |
| parts.append(f"{waiting} waiting") | |
| if running > 0: | |
| parts.append(f"{running} generating") | |
| return f"In queue · {' · '.join(parts)}" | |
| def _read_queue_stats(demo_obj) -> tuple[int, int, float]: | |
| """Best-effort read of gradio's internal queue. | |
| Gradio 5.x structure (gradio/queueing.py): | |
| demo._queue.event_queue_per_concurrency_id: dict[str, EventQueue] | |
| EventQueue.queue: list[Event] ← actually-waiting events | |
| demo._queue.active_jobs: list[None | list[Event]] | |
| each slot is None (idle) or a list of currently-processing events. | |
| demo._queue.process_time_per_fn: dict[BlockFunction, ProcessTime] | |
| ProcessTime.avg_time: float | |
| Returns (waiting, running, avg_secs). Each lookup is wrapped in a try | |
| so the UI degrades gracefully ("idle") if Gradio ever renames a field. | |
| """ | |
| try: | |
| q = getattr(demo_obj, "_queue", None) | |
| if q is None: | |
| return 0, 0, 0.0 | |
| # ---- Waiting: events sitting in EventQueue.queue ---- | |
| waiting = 0 | |
| events_per_cid = getattr(q, "event_queue_per_concurrency_id", None) or {} | |
| for ev_q in events_per_cid.values(): | |
| # Newer gradio: ev_q is an EventQueue with a .queue list. | |
| # Older/alt: ev_q might already be a list. Handle both. | |
| inner = getattr(ev_q, "queue", ev_q) | |
| try: | |
| waiting += len(inner) | |
| except (TypeError, AttributeError): | |
| # Last resort: try iterating | |
| try: | |
| waiting += sum(1 for _ in inner) | |
| except Exception: | |
| continue | |
| # ---- Running: count events held in active_jobs slots ---- | |
| # Each slot is None or a list[Event]; sum the list lengths. | |
| running = 0 | |
| active = getattr(q, "active_jobs", None) or [] | |
| for slot in active: | |
| if slot is None: | |
| continue | |
| try: | |
| running += len(slot) | |
| except (TypeError, AttributeError): | |
| running += 1 # very old gradio: single Event per slot | |
| # ---- Average per-run time (best effort across versions) ---- | |
| avg_secs = 0.0 | |
| # Gradio 5.x: dict[BlockFunction, ProcessTime] with .avg_time | |
| ptpf = getattr(q, "process_time_per_fn", None) | |
| if isinstance(ptpf, dict) and ptpf: | |
| try: | |
| vals = [] | |
| for v in ptpf.values(): | |
| avg_t = getattr(v, "avg_time", None) | |
| if avg_t: | |
| vals.append(float(avg_t)) | |
| if vals: | |
| avg_secs = sum(vals) / len(vals) | |
| except Exception: | |
| avg_secs = 0.0 | |
| # Older: dict[int, float] | |
| if not avg_secs: | |
| ptpfi = getattr(q, "process_time_per_fn_index", None) | |
| if isinstance(ptpfi, dict) and ptpfi: | |
| try: | |
| vals = [float(v) for v in ptpfi.values() if v] | |
| if vals: | |
| avg_secs = sum(vals) / len(vals) | |
| except Exception: | |
| avg_secs = 0.0 | |
| # Oldest: single float on the queue | |
| if not avg_secs: | |
| apt = getattr(q, "avg_process_time", None) | |
| if apt: | |
| try: | |
| avg_secs = float(apt) | |
| except Exception: | |
| avg_secs = 0.0 | |
| return waiting, running, avg_secs | |
| except Exception as exc: | |
| logger.debug(f"Queue introspection failed: {exc}") | |
| return 0, 0, 0.0 | |
def _fake_generation_iter(prompt: str, wh_ratio_value: str):
    """Simulate one generation run for FAKE_TEST mode without any API call.

    Speaks the same yield protocol as the real flow, i.e. a stream of
    ``(image_or_None, status_html)`` tuples, so the queue UI and the status
    pill can be exercised locally. The caller emits the initial
    'Sending request to API…' status itself; this generator starts at
    'Request submitted'.

    Neither ``prompt`` nor ``wh_ratio_value`` is read by this function: the
    final item is always a 1024x1024 image filled with one random colour.
    """
    # Pretend the submit round-trip took a moment, then invent a task id.
    time.sleep(random.uniform(0.4, 1.0))
    task_tag = f"{random.randint(0, 0xFFFFFFFF):08x}"
    yield None, _status_html(f"Request submitted · Task {task_tag}…", "running")

    # Emit a progress tick every POLL_INTERVAL seconds until the randomly
    # chosen duration has passed.
    duration = random.uniform(8.0, 22.0)
    began = time.time()
    while True:
        if time.time() - began >= duration:
            break
        seconds = int(time.time() - began)
        yield None, _status_html(f"Generating… {seconds}s", "running")
        time.sleep(POLL_INTERVAL)

    yield None, _status_html("Downloading image…", "running")
    time.sleep(0.3)

    # Mid-range channel values keep the placeholder away from pure black/white.
    colour = tuple(random.randint(40, 220) for _ in range(3))
    placeholder = Image.new("RGB", (1024, 1024), color=colour)
    logger.info(f"FAKE_TEST: returning random color image rgb={colour}, took {duration:.1f}s")
    yield placeholder, _status_html("Image generated", "success")
def create_ui():
    """Build the Gradio Blocks app and wire up all of its event handlers.

    Layout: one row with two panel columns. The first column holds the
    prompt, negative prompt, aspect ratio, guidance scale, prompt-refine
    toggle, seed and the Clear / Generate buttons. The second column holds
    the status pill and the output image. A footer with project links is
    rendered below the row.

    Event flow for the Generate button:
      1. ``_enter_queue`` runs immediately (``queue=False``): it clears the
         image, marks the session as queued and shows a queue snapshot.
      2. ``_generate_wrapped`` then runs as a regular (queued) event and
         streams ``(image, status_html)`` updates; its first yield flips
         ``queued_state`` back to False.
      3. ``pill_timer`` ticks every 1.5 s and refreshes the pill with live
         queue counts, but only while ``queued_state`` is True.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet queued or launched).
    """
    logger.info("Creating Gradio UI")
    with gr.Blocks(
        title="HiDream-O1-Image Generator",
        theme=APPLE_THEME,
        css=APPLE_CSS,
    ) as demo:
        # Per-session state used to drive the right-side status pill while
        # this user's request is sitting in the queue. `last_pill_state`
        # caches the most recent pill HTML so the timer can return
        # `gr.update()` (=no-op, no DOM replacement → no flicker) when the
        # queue counts haven't actually changed between ticks.
        queued_state = gr.State(False)
        last_pill_state = gr.State("")

        with gr.Row(equal_height=False):
            with gr.Column(scale=1, elem_classes=["panel-card"]):
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the image you want to create...",
                    lines=5,
                    show_label=True,
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    placeholder="Things you want to avoid in the image (optional)...",
                    lines=2,
                    show_label=True,
                    elem_classes=["negative-prompt"],
                )
                wh_ratio = gr.Dropdown(
                    choices=WH_RATIO_OPTIONS,
                    value=WH_RATIO_OPTIONS[0],
                    label="Aspect Ratio",
                    info="Width : Height",
                )
                guidance_scale = gr.Slider(
                    minimum=1.0,
                    maximum=20.0,
                    step=0.1,
                    value=5.0,
                    label="Guidance Scale",
                    info="Higher values follow the prompt more strictly",
                    elem_classes=["guidance-slider"],
                )
                with gr.Row(elem_classes=["advanced-row"], equal_height=True):
                    enable_prompt_refine = gr.Checkbox(
                        value=True,
                        label="Prompt Refine",
                        info="Let the model rewrite & enrich your prompt",
                        elem_classes=["refine-toggle"],
                        scale=1,
                    )
                    seed = gr.Number(
                        value=-1,
                        label="Seed",
                        info="Use -1 for a random seed",
                        precision=0,
                        minimum=-1,
                        elem_classes=["seed-input"],
                        scale=1,
                    )
                with gr.Row():
                    clear_btn = gr.Button("Clear", variant="secondary", scale=1)
                    generate_btn = gr.Button("Generate", variant="primary", scale=3)

            with gr.Column(scale=1, elem_classes=["panel-card"]):
                status_msg = gr.HTML(
                    value=_status_html("Ready", "info"),
                    elem_id="status-bar",
                    elem_classes=["right-status"],
                )
                output_image = gr.Image(
                    label="Generated Image",
                    format="png",
                    type="pil",
                    interactive=False,
                    show_download_button=True,
                    elem_classes=["image-output"],
                )

        def generate_with_status(
            prompt,
            wh_ratio_value,
            negative_prompt_value,
            enable_prompt_refine_value,
            seed_value,
            guidance_scale_value,
        ):
            """Validate inputs, submit the task, poll it, and stream status.

            Generator yielding ``(image_or_None, status_html)`` tuples. The
            image slot stays None until the final success yield. Every
            failure path ends with an "error" status yield instead of
            raising, so the UI always gets a final message.
            """
            logger.info(
                f"Starting image generation with prompt='{(prompt or '')[:50]}...', "
                f"wh_ratio={wh_ratio_value}, "
                f"negative_prompt='{(negative_prompt_value or '')[:30]}...', "
                f"enable_prompt_refine={enable_prompt_refine_value}, "
                f"seed={seed_value}, guidance_scale={guidance_scale_value}"
            )
            yield None, _status_html("Sending request to API…", "running")
            # NOTE(review): several statuses below embed API- or
            # exception-provided text in HTML; confirm _status_html escapes
            # its message argument.
            try:
                if not prompt or not prompt.strip():
                    logger.error("Empty prompt provided in UI")
                    yield None, _status_html("Prompt cannot be empty", "error")
                    return
                if wh_ratio_value not in WH_RATIO_OPTIONS:
                    logger.error(f"Invalid aspect ratio selection: {wh_ratio_value}")
                    yield None, _status_html(f"Invalid aspect ratio “{wh_ratio_value}”", "error")
                    return

                # Coerce numeric inputs, falling back to the UI defaults.
                # OverflowError covers non-finite / oversized numbers, which
                # int() and float() report separately from ValueError.
                try:
                    seed_int = int(seed_value) if seed_value is not None else -1
                except (TypeError, ValueError, OverflowError):
                    seed_int = -1
                try:
                    guidance_scale_f = float(guidance_scale_value) if guidance_scale_value is not None else 5.0
                except (TypeError, ValueError, OverflowError):
                    guidance_scale_f = 5.0

                if FAKE_TEST:
                    logger.info("FAKE_TEST mode active — bypassing real API call")
                    yield from _fake_generation_iter(prompt, wh_ratio_value)
                    return

                logger.info("Creating API request")
                task_id = create_request(
                    prompt,
                    wh_ratio_value,
                    negative_prompt=negative_prompt_value or "",
                    enable_prompt_refine=bool(enable_prompt_refine_value),
                    seed=seed_int,
                    guidance_scale=guidance_scale_f,
                )
                yield None, _status_html(f"Request submitted · Task {task_id[:8]}…", "running")

                start_time = time.time()
                logger.info(f"Starting to poll for results for task ID: {task_id}")
                while time.time() - start_time < MAX_POLL_TIME:
                    elapsed_time = time.time() - start_time
                    logger.debug(
                        f"Polling for results - Task ID: {task_id}, Elapsed: {elapsed_time:.2f}s"
                    )
                    result = get_results(task_id)
                    if not result:
                        # Nothing usable came back; retry without touching the UI.
                        time.sleep(POLL_INTERVAL)
                        continue

                    overall_status = result.get("status")
                    sub_results = result.get("sub_task_results", []) or []

                    # This loop treats overall status 1 as "finished" and
                    # anything else as still in progress.
                    if overall_status != 1:
                        elapsed = int(time.time() - start_time)
                        yield None, _status_html(f"Generating… {elapsed}s", "running")
                        time.sleep(POLL_INTERVAL)
                        continue

                    if not sub_results:
                        logger.error(f"Task completed but no sub_task_results returned. Task ID: {task_id}")
                        yield None, _status_html("Task completed but no results returned", "error")
                        return

                    # Only the first sub-task result is inspected.
                    sub = sub_results[0]
                    sub_status = sub.get("task_status")
                    if sub_status == 1:
                        logger.info(f"Task completed successfully - Task ID: {task_id}")
                        image_url = sub.get("url")
                        if not image_url:
                            logger.error(f"No image URL in successful response. Sub result: {sub}")
                            yield None, _status_html("No image URL in response", "error")
                            return
                        yield None, _status_html("Downloading image…", "running")
                        logger.info(f"Downloading image - Task ID: {task_id}, URL: {image_url}")
                        image = download_image(image_url)
                        if image:
                            logger.info(f"Image generation complete - Task ID: {task_id}")
                            yield image, _status_html("Image generated", "success")
                            return
                        else:
                            logger.error(f"Failed to download image - Task ID: {task_id}, URL: {image_url}")
                            yield None, _status_html("Failed to download generated image", "error")
                            return
                    elif sub_status == 3:
                        # Sub-task status 3 is handled as a terminal failure.
                        error_msg = sub.get("task_error") or sub.get("message") or "Unknown error"
                        logger.error(
                            f"Task failed - Task ID: {task_id}, Sub status: {sub_status}, Error: {error_msg}"
                        )
                        yield None, _status_html(f"Task failed: {error_msg}", "error")
                        return
                    else:
                        elapsed = int(time.time() - start_time)
                        yield None, _status_html(f"Waiting… {elapsed}s", "running")
                        time.sleep(POLL_INTERVAL)

                # The while condition expired without a terminal result.
                logger.error(
                    f"Timeout waiting for task completion - Task ID: {task_id}, Max time: {MAX_POLL_TIME}s"
                )
                yield None, _status_html(f"Timed out after {MAX_POLL_TIME}s", "error")
            except APIError as e:
                logger.error(f"API Error during generation: {str(e)}")
                yield None, _status_html(f"API error: {str(e)}", "error")
            except ValueError as e:
                logger.error(f"Value Error during generation: {str(e)}")
                yield None, _status_html(f"Value error: {str(e)}", "error")
            except Exception as e:
                # logger.exception records the message and the active
                # traceback in a single ERROR record.
                logger.exception(f"Unexpected error during image generation: {str(e)}")
                yield None, _status_html(f"Unexpected error: {str(e)}", "error")

        def _enter_queue():
            """Click handler #1 — runs IMMEDIATELY (queue=False).

            Flips queued_state to True and seeds the pill with a snapshot of
            the current queue. Subsequent live updates come from the timer.
            Also clears the output image (first return value is None).
            """
            waiting, running, _ = _read_queue_stats(demo)
            html = _status_html(_queue_text(waiting, running), "running")
            return None, html, True, html

        def _generate_wrapped(
            prompt_value,
            wh_ratio_value,
            negative_prompt_value,
            enable_prompt_refine_value,
            seed_value,
            guidance_scale_value,
        ):
            """Click handler #2 — queued.

            Wraps the existing generator and, on the FIRST yield, also flips
            queued_state to False so the timer stops touching the pill and
            lets the generator's own `yield`s drive it (Generating XXs → ...).
            Later yields leave queued_state untouched via `gr.update()`.
            """
            first = True
            for image, status_html in generate_with_status(
                prompt_value,
                wh_ratio_value,
                negative_prompt_value,
                enable_prompt_refine_value,
                seed_value,
                guidance_scale_value,
            ):
                if first:
                    first = False
                    yield image, status_html, False
                else:
                    yield image, status_html, gr.update()

        generate_btn.click(
            fn=_enter_queue,
            inputs=None,
            outputs=[output_image, status_msg, queued_state, last_pill_state],
            queue=False,
            show_progress="hidden",
        ).then(
            fn=_generate_wrapped,
            inputs=[prompt, wh_ratio, negative_prompt, enable_prompt_refine, seed, guidance_scale],
            outputs=[output_image, status_msg, queued_state],
            show_progress="minimal",
            show_progress_on=[generate_btn],
        )

        def clear_outputs():
            """Reset the image, the status pill and both per-session states."""
            logger.info("Clearing UI outputs")
            return None, _status_html("Ready", "info"), False, ""

        clear_btn.click(
            fn=clear_outputs,
            inputs=None,
            outputs=[output_image, status_msg, queued_state, last_pill_state],
        )

        # Live queue updates inside the right-side pill — only while THIS
        # user is queued. Returns gr.update() (no-op) when nothing changed,
        # which prevents the DOM from being replaced and avoids the pulse
        # animation resetting (= no flicker).
        pill_timer = gr.Timer(value=1.5, active=True)

        def _tick_pill(queued_flag, last_html):
            """Refresh the pill with live queue counts while queued.

            Returns a (status_msg, last_pill_state) pair. On no-op ticks both
            slots are `gr.update()` so the cached HTML is not rewritten with
            the possibly stale value this tick was handed, which could
            otherwise overwrite a value set concurrently by another handler.
            """
            if not queued_flag:
                return gr.update(), gr.update()
            waiting, running, _ = _read_queue_stats(demo)
            new_html = _status_html(_queue_text(waiting, running), "running")
            if new_html == last_html:
                return gr.update(), gr.update()
            return new_html, new_html

        pill_timer.tick(
            fn=_tick_pill,
            inputs=[queued_state, last_pill_state],
            outputs=[status_msg, last_pill_state],
            queue=False,
            show_progress="hidden",
        )

        gr.HTML(
            """
            <div class="tagline footer-links">
            <a href="https://huggingface.co/HiDream-ai/HiDream-O1-Image" target="_blank">HuggingFace</a>
            <a href="https://github.com/HiDream-ai/HiDream-O1-Image" target="_blank">GitHub</a>
            <a href="https://x.com/vivago_ai" target="_blank">Twitter</a>
            </div>
            <div class="tagline">
            For more features and the full experience, visit
            <a href="https://vivago.ai/" target="_blank">vivago.ai</a>.
            </div>
            """
        )
    logger.info("Gradio UI created successfully")
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, enable the request queue, then serve.
    logger.info("Starting HiDream-O1-Image Generator application")
    demo = create_ui()
    logger.info("Launching Gradio interface with queue")
    # Up to 50 pending requests; handlers default to 4 concurrent runs.
    queued_app = demo.queue(max_size=50, default_concurrency_limit=4)
    queued_app.launch(show_api=False)
    # Logged once launch() has returned.
    logger.info("Application shutdown")