import logging
import os
import random
import time
import traceback
from io import BytesIO

import gradio as gr
import requests
from PIL import Image
from dotenv import load_dotenv

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Pull variables from a local .env file (if present) into os.environ BEFORE
# any of the configuration lookups below run.
load_dotenv()

# API Configuration (new style: host + gen_image_path)
API_TOKEN = os.environ.get("token")
API_HOST = os.environ.get("host")
GEN_IMAGE_PATH = os.environ.get("gen_image_path")
MODEL_ID = os.environ.get("model_id")


def _env_number(name, default, cast, minimum):
    """Read a numeric setting from the environment with a safe fallback.

    An unset or blank variable yields ``default``. A value that *cast* cannot
    parse logs a warning and also yields ``default``, instead of crashing the
    app at import time. A value below *minimum* is clamped up to *minimum*,
    with a warning.

    Args:
        name (str): Environment variable to read.
        default: Value used when the variable is unset, blank or unparseable.
        cast (callable): Conversion applied to the raw string (``int``/``float``).
        minimum: Smallest accepted value.

    Returns:
        The configured value, converted with *cast*.
    """
    raw = os.environ.get(name, "")
    value = cast(default)
    if raw.strip():
        try:
            value = cast(raw)
        except ValueError:
            logger.warning(f"Invalid {name}={raw!r}; falling back to default {value}")
    # `not (value >= minimum)` also rejects float NaN, which compares False
    # against everything and would slip through a plain `value < minimum` check.
    if not value >= minimum:
        logger.warning(f"{name}={value} is out of range; clamping to {minimum}")
        value = cast(minimum)
    return value


# Polling / retry configuration (with sensible defaults).
# MAX_RETRY_COUNT is clamped to >= 1: the retry loops
# (`while retry_count < MAX_RETRY_COUNT`) would otherwise never run, and no
# request or download would ever be attempted.
MAX_RETRY_COUNT = _env_number("MAX_RETRY_COUNT", 3, int, 1)
# POLL_INTERVAL is passed straight to time.sleep(), which raises ValueError
# for negative durations, so it is clamped to >= 0.
POLL_INTERVAL = _env_number("POLL_INTERVAL", 2.0, float, 0.0)
MAX_POLL_TIME = _env_number("MAX_POLL_TIME", 300, int, 0)

# Local-only test mode: skip the real API and return a random-colored image after
# a randomised delay. Useful for testing the UI / queueing flow without burning
# real model credits. Enable with FAKE_TEST=1 (also accepts "true"/"yes"/"on").
FAKE_TEST = os.environ.get("FAKE_TEST", "").strip().lower() in ("1", "true", "yes", "on") # Predefined aspect ratios (wh_ratio) — kept in the same order as the original # (width, height) list so each entry mirrors the previous resolution choice: # 1:1 ←→ 2048×2048 4:3 ←→ 2304×1728 3:4 ←→ 1728×2304 # 16:9 ←→ 2560×1440 9:16 ←→ 1440×2560 3:2 ←→ 2496×1664 # 2:3 ←→ 1664×2496 21:9 ←→ 3104×1312 9:21 ←→ 1312×3104 # 9:7 ←→ 2304×1792 7:9 ←→ 1792×2304 WH_RATIO_OPTIONS = [ "1:1", "4:3", "3:4", "16:9", "9:16", "3:2", "2:3", "21:9", "9:21", "9:7", "7:9", ] logger.info( f"API configuration loaded: HOST={API_HOST}, GEN_IMAGE_PATH={GEN_IMAGE_PATH}, MODEL_ID={MODEL_ID}" ) logger.info( f"Retry configuration: MAX_RETRY_COUNT={MAX_RETRY_COUNT}, POLL_INTERVAL={POLL_INTERVAL}s, MAX_POLL_TIME={MAX_POLL_TIME}s" ) if FAKE_TEST: logger.warning("FAKE_TEST mode is ENABLED — no real API calls will be made.") class APIError(Exception): """Custom exception for API-related errors""" pass # Status codes returned by the API SUCCESS_CODE = 0 def _build_request_url() -> str: if not API_HOST or not GEN_IMAGE_PATH: raise APIError("API host or gen_image_path is not configured. Please set the 'host' and 'gen_image_path' environment variables.") return f"{API_HOST.rstrip('/')}{GEN_IMAGE_PATH}" def _build_result_url(task_id: str) -> str: return f"{_build_request_url()}/results?task_id={task_id}" def _headers() -> dict: if not API_TOKEN: raise APIError("API token is not configured. Please set the 'token' environment variable.") return {"Authorization": f"Bearer {API_TOKEN}"} def create_request( prompt, wh_ratio, negative_prompt="", enable_prompt_refine=True, seed=-1, guidance_scale=5.0, ): """ Submit an image generation request to the API. Args: prompt (str): Text prompt describing the image to generate wh_ratio (str): Aspect ratio for the output image (e.g. "16:9") negative_prompt (str): Optional text describing what to avoid. 
enable_prompt_refine (bool): Whether to let the backend rewrite/expand the prompt before generation. Sent to the API as 0 / 1. seed (int): Generation seed. -1 means the backend will pick one randomly; any other integer fixes the seed for reproducible runs. guidance_scale (float): Classifier-free guidance strength. Higher values follow the prompt more strictly. Returns: str: Task ID Raises: APIError: If the API request fails """ logger.info( f"Starting create_request with prompt='{prompt[:50]}...', " f"wh_ratio={wh_ratio}, enable_prompt_refine={enable_prompt_refine}, " f"seed={seed}, guidance_scale={guidance_scale}, " f"negative_prompt='{(negative_prompt or '')[:30]}...'" ) if not prompt or not prompt.strip(): logger.error("Empty prompt provided to create_request") raise ValueError("Prompt cannot be empty") if not wh_ratio or not isinstance(wh_ratio, str) or wh_ratio not in WH_RATIO_OPTIONS: logger.error(f"Invalid wh_ratio: {wh_ratio}. Valid options: {WH_RATIO_OPTIONS}") raise ValueError(f"Invalid aspect ratio. 
Must be one of: {', '.join(WH_RATIO_OPTIONS)}") try: seed_int = int(seed) except (TypeError, ValueError): logger.warning(f"Invalid seed value '{seed}', falling back to -1 (random)") seed_int = -1 try: guidance_scale_f = float(guidance_scale) except (TypeError, ValueError): logger.warning(f"Invalid guidance_scale '{guidance_scale}', falling back to 5.0") guidance_scale_f = 5.0 model_params = { "prompt": prompt, "wh_ratio": wh_ratio, "model_id": MODEL_ID, "n": 1, "negative_prompt": negative_prompt or "", "enable_prompt_refine": 1 if enable_prompt_refine else 0, "seed": seed_int, "guidance_scale": guidance_scale_f, } url = _build_request_url() retry_count = 0 while retry_count < MAX_RETRY_COUNT: try: logger.info( f"Sending API request [attempt {retry_count + 1}/{MAX_RETRY_COUNT}] for prompt: '{prompt[:50]}...'" ) response = requests.post(url, json=model_params, headers=_headers(), timeout=15) logger.info(f"API request response status: {response.status_code}") response.raise_for_status() response_json = response.json() code = response_json.get("code") message = response_json.get("message", "") if code != SUCCESS_CODE: logger.error(f"API returned error code {code}: {message}") raise APIError(f"Failed to submit task (code={code}): {message}") task_id = response_json.get("result", {}).get("task_id") if not task_id: logger.error(f"No task ID in API response: {response_json}") raise APIError(f"No task ID returned from API: {response_json}") logger.info(f"Successfully created task with ID: {task_id}") return task_id except requests.exceptions.Timeout: retry_count += 1 logger.warning(f"Request timed out. 
Retrying ({retry_count}/{MAX_RETRY_COUNT})...") time.sleep(1) except requests.exceptions.HTTPError as e: status_code = e.response.status_code error_message = f"HTTP error {status_code}" try: error_detail = e.response.json() error_message += f": {error_detail}" logger.error(f"API response error content: {error_detail}") except Exception: logger.error(f"Could not parse API error response as JSON. Raw content: {e.response.content[:500]}") if status_code == 401: logger.error(f"Authentication failed with API token. Status code: {status_code}") raise APIError("Authentication failed. Please check your API token.") elif status_code == 429: retry_count += 1 wait_time = min(2 ** retry_count, 10) logger.warning(f"Rate limit exceeded. Waiting {wait_time}s before retry ({retry_count}/{MAX_RETRY_COUNT})...") time.sleep(wait_time) elif 400 <= status_code < 500: logger.error(f"Client error: {error_message}, Prompt: '{prompt[:50]}...', Status: {status_code}") raise APIError(error_message) else: retry_count += 1 logger.warning(f"Server error: {error_message}. Retrying ({retry_count}/{MAX_RETRY_COUNT})...") time.sleep(1) except requests.exceptions.RequestException as e: logger.error(f"Request error: {str(e)}") logger.debug(f"Request error details: {traceback.format_exc()}") raise APIError(f"Failed to connect to API: {str(e)}") except APIError: raise except Exception as e: logger.error(f"Unexpected error in create_request: {str(e)}") logger.error(f"Full traceback: {traceback.format_exc()}") raise APIError(f"Unexpected error: {str(e)}") logger.error(f"Failed to create request after {MAX_RETRY_COUNT} retries for prompt: '{prompt[:50]}...'") raise APIError(f"Failed after {MAX_RETRY_COUNT} retries") def get_results(task_id): """ Check the status of an image generation task. Args: task_id (str): The task ID to check Returns: dict: Task result information (the "result" object from the response), or None on transient failure. Raises: APIError: For unrecoverable errors (e.g. 
authentication failure). """ logger.debug(f"Checking status for task ID: {task_id}") if not task_id: logger.error("Empty task ID provided to get_results") raise ValueError("Task ID cannot be empty") url = _build_result_url(task_id) try: response = requests.get(url, headers=_headers(), timeout=10) logger.debug(f"Status check response code: {response.status_code}") response.raise_for_status() response_json = response.json() code = response_json.get("code") message = response_json.get("message", "") if code != SUCCESS_CODE: logger.warning(f"API returned non-success code {code} for task {task_id}: {message}") return None return response_json.get("result") except requests.exceptions.Timeout: logger.warning(f"Request timed out when checking task {task_id}") return None except requests.exceptions.HTTPError as e: status_code = e.response.status_code logger.warning(f"HTTP error {status_code} when checking task {task_id}") try: error_content = e.response.json() logger.error(f"Error response content: {error_content}") except Exception: logger.error(f"Could not parse error response as JSON. Raw content: {e.response.content[:500]}") if status_code == 401: logger.error(f"Authentication failed when checking task {task_id}") raise APIError(f"Authentication failed. Please check your API token when checking task {task_id}") elif 400 <= status_code < 500: logger.error(f"Client error {status_code} when checking task {task_id}") return None else: logger.warning(f"Server error {status_code} when checking task {task_id}") return None except requests.exceptions.RequestException as e: logger.warning(f"Network error when checking task {task_id}: {str(e)}") logger.debug(f"Network error details: {traceback.format_exc()}") return None except Exception as e: logger.error(f"Unexpected error when checking task {task_id}: {str(e)}") logger.error(f"Full traceback: {traceback.format_exc()}") return None def download_image(image_url): """ Download an image from a URL and return it as a PIL Image. 
Converts non-PNG formats (e.g. WebP) to PNG while preserving original metadata. """ logger.info(f"Starting download_image from URL: {image_url}") if not image_url: logger.error("Empty image URL provided to download_image") raise ValueError("Image URL cannot be empty when downloading image") retry_count = 0 while retry_count < MAX_RETRY_COUNT: try: logger.info(f"Downloading image [attempt {retry_count + 1}/{MAX_RETRY_COUNT}] from {image_url}") response = requests.get(image_url, timeout=30) logger.debug( f"Image download response status: {response.status_code}, " f"Content-Type: {response.headers.get('Content-Type')}, " f"Content-Length: {response.headers.get('Content-Length')}" ) response.raise_for_status() image = Image.open(BytesIO(response.content)) logger.info( f"Image opened successfully. Format: {image.format}, " f"Size: {image.size[0]}x{image.size[1]}, Mode: {image.mode}" ) original_metadata = {} for key, value in image.info.items(): if isinstance(key, str) and isinstance(value, str): original_metadata[key] = value logger.debug(f"Original image metadata: {original_metadata}") if image.format != 'PNG': logger.info(f"Converting image from {image.format} to PNG format") png_buffer = BytesIO() if 'A' in image.getbands(): image_to_save = image else: image_to_save = image.convert('RGB') image_to_save.save(png_buffer, format='PNG') png_buffer.seek(0) image = Image.open(png_buffer) for key, value in original_metadata.items(): image.info[key] = value logger.info(f"Successfully downloaded and processed image: {image.size[0]}x{image.size[1]}") return image except requests.exceptions.Timeout: retry_count += 1 logger.warning(f"Download timed out. 
Retrying ({retry_count}/{MAX_RETRY_COUNT})...") time.sleep(1) except requests.exceptions.HTTPError as e: status_code = e.response.status_code logger.error(f"HTTP error {status_code} when downloading image from {image_url}") if 400 <= status_code < 500: raise APIError(f"HTTP error {status_code} when downloading image") else: retry_count += 1 time.sleep(1) except requests.exceptions.RequestException as e: retry_count += 1 logger.warning(f"Network error during image download: {str(e)}. Retrying ({retry_count}/{MAX_RETRY_COUNT})...") time.sleep(1) except Exception as e: logger.error(f"Error processing image from {image_url}: {str(e)}") logger.error(f"Full traceback: {traceback.format_exc()}") raise APIError(f"Failed to process image: {str(e)}") logger.error(f"Failed to download image from {image_url} after {MAX_RETRY_COUNT} retries") raise APIError(f"Failed to download image after {MAX_RETRY_COUNT} retries") APPLE_CSS = """ /* ---- Apple-inspired minimalist UI ---- */ .gradio-container { font-family: -apple-system, BlinkMacSystemFont, "SF Pro Display", "SF Pro Text", "Helvetica Neue", "Segoe UI", Inter, sans-serif !important; background: linear-gradient(180deg, #fbfbfd 0%, #f5f5f7 100%) !important; /* Always use ~3/4 of the viewport, capped at 1600px on huge screens. Using width AND max-width ensures the page is wide from first paint instead of growing only after content loads. 
*/ width: min(1600px, 92vw) !important; max-width: 1600px !important; margin: 0 auto !important; -webkit-font-smoothing: antialiased; -moz-osx-font-smoothing: grayscale; color: #1d1d1f !important; } /* Cards / panels */ .panel-card { background: #ffffff !important; border-radius: 18px !important; padding: 18px !important; box-shadow: 0 1px 2px rgba(0,0,0,0.04), 0 8px 28px rgba(0,0,0,0.05) !important; border: 1px solid rgba(0,0,0,0.05) !important; } /* Inputs - rounded with apple-blue focus ring */ textarea, input[type="text"], input[type="number"], .gradio-container .form input, .gradio-container .form textarea { border-radius: 12px !important; border: 1px solid #d2d2d7 !important; background: #ffffff !important; transition: border-color 0.15s ease, box-shadow 0.15s ease !important; font-size: 15px !important; } textarea:focus, input:focus, .gradio-container .form input:focus, .gradio-container .form textarea:focus { border-color: #0071e3 !important; box-shadow: 0 0 0 4px rgba(0, 113, 227, 0.15) !important; outline: none !important; } /* Dropdown */ .gradio-container .wrap.svelte-1ipelgc { border-radius: 12px !important; } /* Block labels */ .gradio-container span[data-testid="block-label"], .gradio-container .block-label, .gradio-container label > span { color: #6e6e73 !important; font-weight: 500 !important; font-size: 13px !important; letter-spacing: 0.01em; } /* Buttons - pill shape, Apple blue. Scoped to the variant classes so it does NOT bleed into internal buttons inside dropdowns, accordions, etc. 
*/ .gradio-container button.primary, .gradio-container button.secondary { border-radius: 980px !important; font-weight: 500 !important; font-size: 15px !important; padding: 12px 22px !important; transition: transform 0.08s ease, box-shadow 0.18s ease, background 0.18s ease, opacity 0.15s ease !important; border: none !important; letter-spacing: 0.01em; } .gradio-container button.primary { background: #0071e3 !important; color: #ffffff !important; box-shadow: 0 1px 2px rgba(0,113,227,0.25), 0 6px 16px rgba(0,113,227,0.22) !important; } .gradio-container button.primary:hover { background: #0077ed !important; transform: translateY(-1px); box-shadow: 0 2px 4px rgba(0,113,227,0.3), 0 10px 22px rgba(0,113,227,0.28) !important; } .gradio-container button.primary:active { transform: translateY(0); } .gradio-container button.secondary { background: rgba(0,0,0,0.05) !important; color: #1d1d1f !important; box-shadow: none !important; } .gradio-container button.secondary:hover { background: rgba(0,0,0,0.09) !important; } /* Make sure the dropdown's selected-value text never gets pill-clipped and is properly aligned inside its rounded box. 
*/ .gradio-container .wrap-inner, .gradio-container .single-select, .gradio-container .secondary-wrap { border-radius: 12px !important; } .gradio-container .single-select input, .gradio-container input[role="listbox"] { border-radius: 12px !important; padding: 10px 14px !important; font-size: 15px !important; } /* Status pill */ #status-bar { padding: 0 !important; margin-top: 6px; } .status-pill { display: inline-flex; align-items: center; gap: 9px; background: #f5f5f7; color: #1d1d1f; padding: 10px 14px; border-radius: 12px; font-size: 13px; font-weight: 500; line-height: 1; border: 1px solid rgba(0,0,0,0.04); } .status-dot { width: 8px; height: 8px; border-radius: 50%; background: #8e8e93; flex-shrink: 0; } .status-info .status-dot { background: #8e8e93; } .status-success { background: rgba(48,209,88,0.10); color: #0a7f2e; border-color: rgba(48,209,88,0.20); } .status-success .status-dot { background: #30d158; } .status-error { background: rgba(255,59,48,0.10); color: #b8261b; border-color: rgba(255,59,48,0.20); } .status-error .status-dot { background: #ff3b30; } .status-running { background: rgba(0,113,227,0.10); color: #0058b8; border-color: rgba(0,113,227,0.20); } .status-running .status-dot { background: #0071e3; animation: pulse 1.4s ease-in-out infinite; } @keyframes pulse { 0%, 100% { opacity: 0.4; transform: scale(0.85); } 50% { opacity: 1.0; transform: scale(1.15); } } /* Image output frame — never crop the image; show full picture with letterbox. 
*/ .image-output { border-radius: 18px !important; background: #f5f5f7 !important; } .image-output, .image-output > div, .image-output [data-testid="image"], .image-output .image-container, .image-output .image-frame, .image-output .preview { min-height: 440px !important; display: flex !important; align-items: center !important; justify-content: center !important; } .image-output img { border-radius: 14px !important; object-fit: contain !important; max-width: 100% !important; max-height: 62vh !important; width: auto !important; height: auto !important; } /* Status pill placed inside the right column, above the image. */ .right-status { display: flex; justify-content: flex-start; margin-bottom: 6px; } /* Accordion */ .gradio-container details { border-radius: 14px !important; border: 1px solid rgba(0,0,0,0.06) !important; background: #ffffff !important; } /* Negative prompt — softer accent so it visually de-emphasises vs. the main prompt */ .negative-prompt textarea { background: #fbfbfd !important; border-color: #e3e3e8 !important; } .negative-prompt textarea:focus { background: #ffffff !important; } /* Advanced options row — keeps the refine switch + seed input visually paired */ .advanced-row { gap: 14px !important; margin-top: 2px; } /* Refine toggle — iOS-style settings card. Goal: title + helper text on the left, a polished pill switch on the right, everything contained inside a single soft card so the helper text no longer floats orphaned above the box. 
*/ .refine-toggle { background: linear-gradient(180deg, #ffffff 0%, #f5f5f7 100%) !important; border-radius: 14px !important; border: 1px solid rgba(0,0,0,0.06) !important; padding: 14px 16px !important; box-shadow: 0 1px 2px rgba(0,0,0,0.03) !important; transition: border-color 0.18s ease, box-shadow 0.18s ease !important; min-height: 88px; display: flex !important; flex-direction: column !important; justify-content: center !important; } .refine-toggle:hover { border-color: rgba(0,0,0,0.10) !important; box-shadow: 0 1px 2px rgba(0,0,0,0.04), 0 4px 14px rgba(0,0,0,0.05) !important; } /* Strip default backgrounds from gradio's inner wrappers so only our card shows. */ .refine-toggle .form, .refine-toggle .wrap, .refine-toggle .form-wrap, .refine-toggle > div { background: transparent !important; border: none !important; padding: 0 !important; margin: 0 !important; box-shadow: none !important; } /* Helper / "info" text becomes a proper subtitle UNDER the toggle row. */ .refine-toggle [data-testid="block-info"], .refine-toggle .info { color: #6e6e73 !important; font-size: 12px !important; line-height: 1.4 !important; margin: 8px 0 0 0 !important; padding: 0 !important; text-align: left !important; order: 2 !important; } /* Force the gradio wrapper to stack: label-row first, info below. */ .refine-toggle .form, .refine-toggle > div:not([data-testid="block-info"]):not(.info) { display: flex !important; flex-direction: column !important; align-items: stretch !important; } /* Label row: title on the LEFT (full width), toggle pinned to the RIGHT. 
*/ .refine-toggle label { display: flex !important; align-items: center !important; justify-content: space-between !important; flex-direction: row-reverse !important; cursor: pointer !important; margin: 0 !important; padding: 0 !important; gap: 14px !important; width: 100% !important; order: 1 !important; } .refine-toggle label > span { color: #1d1d1f !important; font-size: 15px !important; font-weight: 600 !important; letter-spacing: -0.01em; flex: 1 1 auto !important; text-align: left !important; line-height: 1.3 !important; } /* Pill switch — bigger, smoother, more "Apple-like" */ .refine-toggle input[type="checkbox"] { appearance: none; -webkit-appearance: none; width: 46px !important; height: 28px !important; border-radius: 999px !important; background: #e5e5ea !important; position: relative; cursor: pointer; transition: background 0.22s ease, box-shadow 0.22s ease; border: none !important; flex-shrink: 0 !important; margin: 0 !important; box-shadow: inset 0 0 1px rgba(0,0,0,0.06); } .refine-toggle input[type="checkbox"]::after { content: ""; position: absolute; top: 2px; left: 2px; width: 24px; height: 24px; border-radius: 50%; background: #ffffff; box-shadow: 0 2px 5px rgba(0,0,0,0.18), 0 0 1px rgba(0,0,0,0.05); transition: transform 0.24s cubic-bezier(0.4, 0.0, 0.2, 1); } .refine-toggle input[type="checkbox"]:hover { background: #dcdce0 !important; } .refine-toggle input[type="checkbox"]:checked { background: #34c759 !important; } .refine-toggle input[type="checkbox"]:checked:hover { background: #30b352 !important; } .refine-toggle input[type="checkbox"]:checked::after { transform: translateX(18px); } .refine-toggle input[type="checkbox"]:focus-visible { box-shadow: 0 0 0 4px rgba(0,113,227,0.20) !important; } .refine-toggle input[type="checkbox"]:active::after { /* tiny squish on press, very iOS */ width: 28px; } .refine-toggle input[type="checkbox"]:checked:active::after { transform: translateX(14px); } /* Seed number input — match the prompt/dropdown 
rounding */ .seed-input input[type="number"] { border-radius: 12px !important; padding: 10px 14px !important; font-variant-numeric: tabular-nums; } /* Hide the native spinner buttons on number inputs for a cleaner look */ .seed-input input[type="number"]::-webkit-outer-spin-button, .seed-input input[type="number"]::-webkit-inner-spin-button { -webkit-appearance: none; margin: 0; } .seed-input input[type="number"] { -moz-appearance: textfield; } /* Keep the seed column visually aligned with the refine card next to it */ .seed-input { align-self: stretch !important; } /* Guidance scale slider — Apple-blue track + softer thumb */ .guidance-slider input[type="range"] { accent-color: #0071e3 !important; } .guidance-slider .head { padding-top: 0 !important; } .guidance-slider { margin-top: 4px; } /* Footer tagline */ .tagline { text-align: center; color: #6e6e73; font-size: 12px; margin: 18px 0 14px 0; font-weight: 400; } .tagline a { color: #0071e3; text-decoration: none; font-weight: 500; transition: opacity 0.15s ease; } .tagline a:hover { opacity: 0.7; } /* Footer links row (HuggingFace / GitHub / Twitter) */ .footer-links { display: flex; justify-content: center; align-items: center; gap: 26px; flex-wrap: wrap; font-size: 13px; margin: 24px 0 6px 0; } /* When the vivago tagline follows the links row, tighten the gap. 
*/ .footer-links + .tagline { margin-top: 4px; } /* Hide gradio's default footer for a cleaner look */ footer { display: none !important; } /* Mobile */ @media (max-width: 640px) { .footer-links { gap: 18px; font-size: 12px; } } """ APPLE_THEME = gr.themes.Soft( primary_hue=gr.themes.colors.blue, neutral_hue=gr.themes.colors.slate, radius_size=gr.themes.sizes.radius_lg, text_size=gr.themes.sizes.text_md, font=[ gr.themes.GoogleFont("Inter"), "ui-sans-serif", "-apple-system", "BlinkMacSystemFont", "Segoe UI", "Helvetica Neue", "sans-serif", ], ).set( body_background_fill="*neutral_50", block_background_fill="white", block_border_width="1px", block_label_text_weight="500", block_title_text_weight="600", button_primary_background_fill="#0071e3", button_primary_background_fill_hover="#0077ed", button_primary_text_color="white", button_primary_border_color="#0071e3", button_secondary_background_fill="rgba(0,0,0,0.05)", button_secondary_background_fill_hover="rgba(0,0,0,0.09)", button_secondary_text_color="#1d1d1f", input_background_fill="white", input_border_color="#d2d2d7", input_border_color_focus="#0071e3", input_shadow_focus="0 0 0 4px rgba(0,113,227,0.15)", ) def _status_html(text: str, kind: str = "info") -> str: """Render a styled status pill. kind: info | success | error | running""" return ( f'
' f'{text}' f'
' ) def _queue_text(waiting: int, running: int) -> str: """Inline status text describing the queue from the user's POV. Used inside the right-side status pill while THIS user's request is sitting in the queue waiting for an open slot. """ if waiting <= 0 and running <= 0: return "Queued · waiting for an open slot…" parts = [] if waiting > 0: parts.append(f"{waiting} waiting") if running > 0: parts.append(f"{running} generating") return f"In queue · {' · '.join(parts)}" def _read_queue_stats(demo_obj) -> tuple[int, int, float]: """Best-effort read of gradio's internal queue. Gradio 5.x structure (gradio/queueing.py): demo._queue.event_queue_per_concurrency_id: dict[str, EventQueue] EventQueue.queue: list[Event] ← actually-waiting events demo._queue.active_jobs: list[None | list[Event]] each slot is None (idle) or a list of currently-processing events. demo._queue.process_time_per_fn: dict[BlockFunction, ProcessTime] ProcessTime.avg_time: float Returns (waiting, running, avg_secs). Each lookup is wrapped in a try so the UI degrades gracefully ("idle") if Gradio ever renames a field. """ try: q = getattr(demo_obj, "_queue", None) if q is None: return 0, 0, 0.0 # ---- Waiting: events sitting in EventQueue.queue ---- waiting = 0 events_per_cid = getattr(q, "event_queue_per_concurrency_id", None) or {} for ev_q in events_per_cid.values(): # Newer gradio: ev_q is an EventQueue with a .queue list. # Older/alt: ev_q might already be a list. Handle both. inner = getattr(ev_q, "queue", ev_q) try: waiting += len(inner) except (TypeError, AttributeError): # Last resort: try iterating try: waiting += sum(1 for _ in inner) except Exception: continue # ---- Running: count events held in active_jobs slots ---- # Each slot is None or a list[Event]; sum the list lengths. 
running = 0 active = getattr(q, "active_jobs", None) or [] for slot in active: if slot is None: continue try: running += len(slot) except (TypeError, AttributeError): running += 1 # very old gradio: single Event per slot # ---- Average per-run time (best effort across versions) ---- avg_secs = 0.0 # Gradio 5.x: dict[BlockFunction, ProcessTime] with .avg_time ptpf = getattr(q, "process_time_per_fn", None) if isinstance(ptpf, dict) and ptpf: try: vals = [] for v in ptpf.values(): avg_t = getattr(v, "avg_time", None) if avg_t: vals.append(float(avg_t)) if vals: avg_secs = sum(vals) / len(vals) except Exception: avg_secs = 0.0 # Older: dict[int, float] if not avg_secs: ptpfi = getattr(q, "process_time_per_fn_index", None) if isinstance(ptpfi, dict) and ptpfi: try: vals = [float(v) for v in ptpfi.values() if v] if vals: avg_secs = sum(vals) / len(vals) except Exception: avg_secs = 0.0 # Oldest: single float on the queue if not avg_secs: apt = getattr(q, "avg_process_time", None) if apt: try: avg_secs = float(apt) except Exception: avg_secs = 0.0 return waiting, running, avg_secs except Exception as exc: logger.debug(f"Queue introspection failed: {exc}") return 0, 0, 0.0 def _fake_generation_iter(prompt: str, wh_ratio_value: str): """FAKE_TEST mode generator. Mimics the real flow's yield protocol without hitting any external API. Useful for exercising the queue UI / status pill locally. Yields (image_or_None, status_html) tuples. The very first 'Sending request to API…' yield is emitted by the caller, so this iterator picks up from 'Request submitted' onwards. 
""" time.sleep(random.uniform(0.4, 1.0)) fake_id = f"{random.randint(0, 0xFFFFFFFF):08x}" yield None, _status_html(f"Request submitted · Task {fake_id}…", "running") target_secs = random.uniform(8.0, 22.0) start = time.time() while time.time() - start < target_secs: elapsed = int(time.time() - start) yield None, _status_html(f"Generating… {elapsed}s", "running") time.sleep(POLL_INTERVAL) yield None, _status_html("Downloading image…", "running") time.sleep(0.3) rgb = ( random.randint(40, 220), random.randint(40, 220), random.randint(40, 220), ) fake_image = Image.new("RGB", (1024, 1024), color=rgb) logger.info(f"FAKE_TEST: returning random color image rgb={rgb}, took {target_secs:.1f}s") yield fake_image, _status_html("Image generated", "success") def create_ui(): logger.info("Creating Gradio UI") with gr.Blocks( title="HiDream-O1-Image Generator", theme=APPLE_THEME, css=APPLE_CSS, ) as demo: # Per-session state used to drive the right-side status pill while # this user's request is sitting in the queue. `last_pill_state` # caches the most recent pill HTML so the timer can return # `gr.update()` (=no-op, no DOM replacement → no flicker) when the # queue counts haven't actually changed between ticks. 
queued_state = gr.State(False) last_pill_state = gr.State("") with gr.Row(equal_height=False): with gr.Column(scale=1, elem_classes=["panel-card"]): prompt = gr.Textbox( label="Prompt", placeholder="Describe the image you want to create...", lines=5, show_label=True, ) negative_prompt = gr.Textbox( label="Negative Prompt", placeholder="Things you want to avoid in the image (optional)...", lines=2, show_label=True, elem_classes=["negative-prompt"], ) wh_ratio = gr.Dropdown( choices=WH_RATIO_OPTIONS, value=WH_RATIO_OPTIONS[0], label="Aspect Ratio", info="Width : Height", ) guidance_scale = gr.Slider( minimum=1.0, maximum=20.0, step=0.1, value=5.0, label="Guidance Scale", info="Higher values follow the prompt more strictly", elem_classes=["guidance-slider"], ) with gr.Row(elem_classes=["advanced-row"], equal_height=True): enable_prompt_refine = gr.Checkbox( value=True, label="Prompt Refine", info="Let the model rewrite & enrich your prompt", elem_classes=["refine-toggle"], scale=1, ) seed = gr.Number( value=-1, label="Seed", info="Use -1 for a random seed", precision=0, minimum=-1, elem_classes=["seed-input"], scale=1, ) with gr.Row(): clear_btn = gr.Button("Clear", variant="secondary", scale=1) generate_btn = gr.Button("Generate", variant="primary", scale=3) with gr.Column(scale=1, elem_classes=["panel-card"]): status_msg = gr.HTML( value=_status_html("Ready", "info"), elem_id="status-bar", elem_classes=["right-status"], ) output_image = gr.Image( label="Generated Image", format="png", type="pil", interactive=False, show_download_button=True, elem_classes=["image-output"], ) def generate_with_status( prompt, wh_ratio_value, negative_prompt_value, enable_prompt_refine_value, seed_value, guidance_scale_value, ): logger.info( f"Starting image generation with prompt='{(prompt or '')[:50]}...', " f"wh_ratio={wh_ratio_value}, " f"negative_prompt='{(negative_prompt_value or '')[:30]}...', " f"enable_prompt_refine={enable_prompt_refine_value}, " f"seed={seed_value}, 
guidance_scale={guidance_scale_value}" ) yield None, _status_html("Sending request to API…", "running") try: if not prompt or not prompt.strip(): logger.error("Empty prompt provided in UI") yield None, _status_html("Prompt cannot be empty", "error") return if wh_ratio_value not in WH_RATIO_OPTIONS: logger.error(f"Invalid aspect ratio selection: {wh_ratio_value}") yield None, _status_html(f"Invalid aspect ratio “{wh_ratio_value}”", "error") return try: seed_int = int(seed_value) if seed_value is not None else -1 except (TypeError, ValueError): seed_int = -1 try: guidance_scale_f = float(guidance_scale_value) if guidance_scale_value is not None else 5.0 except (TypeError, ValueError): guidance_scale_f = 5.0 if FAKE_TEST: logger.info("FAKE_TEST mode active — bypassing real API call") yield from _fake_generation_iter(prompt, wh_ratio_value) return logger.info("Creating API request") task_id = create_request( prompt, wh_ratio_value, negative_prompt=negative_prompt_value or "", enable_prompt_refine=bool(enable_prompt_refine_value), seed=seed_int, guidance_scale=guidance_scale_f, ) yield None, _status_html(f"Request submitted · Task {task_id[:8]}…", "running") start_time = time.time() logger.info(f"Starting to poll for results for task ID: {task_id}") while time.time() - start_time < MAX_POLL_TIME: elapsed_time = time.time() - start_time logger.debug( f"Polling for results - Task ID: {task_id}, Elapsed: {elapsed_time:.2f}s" ) result = get_results(task_id) if not result: time.sleep(POLL_INTERVAL) continue overall_status = result.get("status") sub_results = result.get("sub_task_results", []) or [] if overall_status != 1: elapsed = int(time.time() - start_time) yield None, _status_html(f"Generating… {elapsed}s", "running") time.sleep(POLL_INTERVAL) continue if not sub_results: logger.error(f"Task completed but no sub_task_results returned. 
Task ID: {task_id}") yield None, _status_html("Task completed but no results returned", "error") return sub = sub_results[0] sub_status = sub.get("task_status") if sub_status == 1: logger.info(f"Task completed successfully - Task ID: {task_id}") image_url = sub.get("url") if not image_url: logger.error(f"No image URL in successful response. Sub result: {sub}") yield None, _status_html("No image URL in response", "error") return yield None, _status_html("Downloading image…", "running") logger.info(f"Downloading image - Task ID: {task_id}, URL: {image_url}") image = download_image(image_url) if image: logger.info(f"Image generation complete - Task ID: {task_id}") yield image, _status_html("Image generated", "success") return else: logger.error(f"Failed to download image - Task ID: {task_id}, URL: {image_url}") yield None, _status_html("Failed to download generated image", "error") return elif sub_status == 3: error_msg = sub.get("task_error") or sub.get("message") or "Unknown error" logger.error( f"Task failed - Task ID: {task_id}, Sub status: {sub_status}, Error: {error_msg}" ) yield None, _status_html(f"Task failed: {error_msg}", "error") return else: elapsed = int(time.time() - start_time) yield None, _status_html(f"Waiting… {elapsed}s", "running") time.sleep(POLL_INTERVAL) logger.error( f"Timeout waiting for task completion - Task ID: {task_id}, Max time: {MAX_POLL_TIME}s" ) yield None, _status_html(f"Timed out after {MAX_POLL_TIME}s", "error") except APIError as e: logger.error(f"API Error during generation: {str(e)}") yield None, _status_html(f"API error: {str(e)}", "error") except ValueError as e: logger.error(f"Value Error during generation: {str(e)}") yield None, _status_html(f"Value error: {str(e)}", "error") except Exception as e: logger.error(f"Unexpected error during image generation: {str(e)}") logger.error(f"Full traceback: {traceback.format_exc()}") yield None, _status_html(f"Unexpected error: {str(e)}", "error") def _enter_queue(): """Click handler 
#1 — runs IMMEDIATELY (queue=False). Flips queued_state to True and seeds the pill with a snapshot of the current queue. Subsequent live updates come from the timer. """ waiting, running, _ = _read_queue_stats(demo) html = _status_html(_queue_text(waiting, running), "running") return None, html, True, html def _generate_wrapped( prompt_value, wh_ratio_value, negative_prompt_value, enable_prompt_refine_value, seed_value, guidance_scale_value, ): """Click handler #2 — queued. Wraps the existing generator and, on the FIRST yield, also flips queued_state to False so the timer stops touching the pill and lets the generator's own `yield`s drive it (Generating XXs → ...). """ first = True for image, status_html in generate_with_status( prompt_value, wh_ratio_value, negative_prompt_value, enable_prompt_refine_value, seed_value, guidance_scale_value, ): if first: first = False yield image, status_html, False else: yield image, status_html, gr.update() generate_btn.click( fn=_enter_queue, inputs=None, outputs=[output_image, status_msg, queued_state, last_pill_state], queue=False, show_progress="hidden", ).then( fn=_generate_wrapped, inputs=[prompt, wh_ratio, negative_prompt, enable_prompt_refine, seed, guidance_scale], outputs=[output_image, status_msg, queued_state], show_progress="minimal", show_progress_on=[generate_btn], ) def clear_outputs(): logger.info("Clearing UI outputs") return None, _status_html("Ready", "info"), False, "" clear_btn.click( fn=clear_outputs, inputs=None, outputs=[output_image, status_msg, queued_state, last_pill_state], ) # Live queue updates inside the right-side pill — only while THIS # user is queued. Returns gr.update() (no-op) when nothing changed, # which prevents the DOM from being replaced and avoids the pulse # animation resetting (= no flicker). 
# Live queue-position pill (pill_timer / _tick_pill), page footer, end of
# create_ui(), and the script entry point.
#
# pill_timer is a gr.Timer with value=1.5 and active=True, so it ticks every
# 1.5 s without needing to be switched on. Each tick calls
# _tick_pill(queued_flag, last_html) with the current (queued_state,
# last_pill_state) and writes the result to (status_msg, last_pill_state):
#   - queued_flag falsy -> return gr.update() so status_msg is left as-is,
#     and echo last_html back unchanged; _read_queue_stats is not called,
#     so ticks for sessions that are not waiting in the queue stay cheap.
#   - otherwise         -> read (waiting, running) via
#     _read_queue_stats(demo) and render a "running" pill with
#     _status_html(_queue_text(waiting, running), "running").
#     If the rendered HTML equals last_html, return gr.update() so the DOM
#     node is not replaced (keeps the pulse animation from restarting);
#     only a changed pill is written to both status_msg and last_pill_state.
# The tick listener is registered with queue=False so the ticks are not
# themselves held behind the queued generation jobs they report on, and
# show_progress="hidden" so no progress indicator flashes on every tick.
# NOTE(review): a tick that read queued_state=True just before
# _generate_wrapped's first yield sets it to False may still complete
# afterwards and briefly overwrite the generator's status with a queue pill
# until the generator's next yield -- confirm whether this is visible.
#
# After the footer HTML, create_ui() logs and returns `demo`. Under __main__
# the UI is built, the request queue is enabled (max_size=50,
# default_concurrency_limit=4) and the app is launched with show_api=False.
# "Application shutdown" is logged once launch() returns (presumably when the
# server stops, since launch() blocks when run as a script -- confirm).
pill_timer = gr.Timer(value=1.5, active=True) def _tick_pill(queued_flag, last_html): if not queued_flag: return gr.update(), last_html waiting, running, _ = _read_queue_stats(demo) new_html = _status_html(_queue_text(waiting, running), "running") if new_html == last_html: return gr.update(), last_html return new_html, new_html pill_timer.tick( fn=_tick_pill, inputs=[queued_state, last_pill_state], outputs=[status_msg, last_pill_state], queue=False, show_progress="hidden", ) gr.HTML( """
For more features and the full experience, visit vivago.ai.
""" ) logger.info("Gradio UI created successfully") return demo if __name__ == "__main__": logger.info("Starting HiDream-O1-Image Generator application") demo = create_ui() logger.info("Launching Gradio interface with queue") demo.queue(max_size=50, default_concurrency_limit=4).launch(show_api=False) logger.info("Application shutdown")