# Dreamshape-01 / app.py
# Last commit: 0ac09d7 "Upload 3 files" by Vector857 (verified)
import logging
import os
import random
import time
import traceback
from io import BytesIO
import gradio as gr
import requests
from PIL import Image
from dotenv import load_dotenv
# Configure root logging once at import time, then load variables from a
# .env file (via python-dotenv) into the environment before the configuration
# below reads os.environ.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
load_dotenv()
def _env_number(name, default, cast, minimum=None):
    """Read a numeric setting from the environment.

    Args:
        name: Environment variable name.
        default: Value used when the variable is unset, blank, or unparsable.
        cast: Conversion callable (``int`` or ``float``).
        minimum: Optional lower bound; smaller values are clamped up to it.

    Returns:
        The parsed (and possibly clamped) value.

    A typo in the deployment config logs a warning and falls back to the
    default instead of crashing the whole app at import time.
    """
    raw = os.environ.get(name)
    if raw is None or not raw.strip():
        value = default
    else:
        try:
            value = cast(raw.strip())
        except ValueError:
            logging.getLogger(__name__).warning(
                f"Invalid value {raw!r} for {name}; falling back to {default}"
            )
            value = default
    if minimum is not None and value < minimum:
        logging.getLogger(__name__).warning(
            f"{name}={value} is below the minimum {minimum}; using {minimum}"
        )
        value = minimum
    return value


# API Configuration (new style: host + gen_image_path)
API_TOKEN = os.environ.get("token")
API_HOST = os.environ.get("host")
GEN_IMAGE_PATH = os.environ.get("gen_image_path")
MODEL_ID = os.environ.get("model_id")

# Polling / retry configuration (with sensible defaults).
# At least one attempt is always made (a count of 0 would never call the API).
# time.sleep() rejects negative durations, so the poll interval is clamped to >= 0.
MAX_RETRY_COUNT = _env_number("MAX_RETRY_COUNT", 3, int, minimum=1)
POLL_INTERVAL = _env_number("POLL_INTERVAL", 2.0, float, minimum=0.0)
MAX_POLL_TIME = _env_number("MAX_POLL_TIME", 300, int)

# Local-only test mode: skip the real API and return a random-colored image after
# a randomised delay. Useful for testing the UI / queueing flow without burning
# real model credits. Enable with FAKE_TEST=1 (also accepts "true"/"yes"/"on").
FAKE_TEST = os.environ.get("FAKE_TEST", "").strip().lower() in ("1", "true", "yes", "on")
# Supported aspect ratios ("width:height", sent to the API as wh_ratio).
# The order mirrors the earlier (width, height) resolution list; each entry is
# annotated with the resolution it replaced so the two stay cross-referenced.
WH_RATIO_OPTIONS = [
    "1:1",            # 2048×2048
    "4:3", "3:4",     # 2304×1728 / 1728×2304
    "16:9", "9:16",   # 2560×1440 / 1440×2560
    "3:2", "2:3",     # 2496×1664 / 1664×2496
    "21:9", "9:21",   # 3104×1312 / 1312×3104
    "9:7", "7:9",     # 2304×1792 / 1792×2304
]
def _log_startup_config() -> None:
    """Log the effective API / retry settings once at import time.

    The API token is intentionally not included. A warning is emitted when
    FAKE_TEST mode is on, so it is hard to miss in the logs.
    """
    logger.info(
        f"API configuration loaded: HOST={API_HOST}, "
        f"GEN_IMAGE_PATH={GEN_IMAGE_PATH}, MODEL_ID={MODEL_ID}"
    )
    logger.info(
        f"Retry configuration: MAX_RETRY_COUNT={MAX_RETRY_COUNT}, "
        f"POLL_INTERVAL={POLL_INTERVAL}s, MAX_POLL_TIME={MAX_POLL_TIME}s"
    )
    if FAKE_TEST:
        logger.warning("FAKE_TEST mode is ENABLED — no real API calls will be made.")


_log_startup_config()
class APIError(Exception):
    """Raised when the generation backend is misconfigured, unreachable,
    rejects a request, or returns a response this app cannot use."""
# Value of the "code" field in an API JSON response that signals success;
# any other value is treated as a failure by the request/poll helpers.
SUCCESS_CODE: int = 0
def _build_request_url(host=None, path=None) -> str:
    """Return the image-generation endpoint URL (host + gen_image_path).

    Args:
        host: Base host URL; defaults to API_HOST (``host`` env variable).
        path: Endpoint path; defaults to GEN_IMAGE_PATH (``gen_image_path``
            env variable). A leading slash is optional.

    Returns:
        str: ``host`` and ``path`` joined by exactly one slash.

    Raises:
        APIError: If either part is missing or blank.
    """
    host = API_HOST if host is None else host
    path = GEN_IMAGE_PATH if path is None else path
    # Env values may carry stray whitespace/newlines from .env files.
    host = (host or "").strip()
    path = (path or "").strip()
    if not host or not path:
        raise APIError("API host or gen_image_path is not configured. Please set the 'host' and 'gen_image_path' environment variables.")
    # Normalise the separator so "https://api/" + "/v1/gen" and
    # "https://api" + "v1/gen" both yield "https://api/v1/gen".
    return f"{host.rstrip('/')}/{path.lstrip('/')}"
def _build_result_url(task_id: str) -> str:
    """Return the polling URL ``<request-url>/results?task_id=<id>`` for a task.

    The task id is percent-encoded so ids containing characters such as
    ``&``, ``#``, ``?`` or spaces cannot corrupt the query string.

    Raises:
        APIError: If the host / gen_image_path configuration is missing.
    """
    from urllib.parse import urlencode

    query = urlencode({"task_id": task_id})
    return f"{_build_request_url()}/results?{query}"
def _headers() -> dict:
    """Build the HTTP headers (Bearer auth) sent with every API call.

    Raises:
        APIError: When no API token was provided via the ``token`` env variable.
    """
    token = API_TOKEN
    if token:
        return {"Authorization": f"Bearer {token}"}
    raise APIError("API token is not configured. Please set the 'token' environment variable.")
def create_request(
    prompt,
    wh_ratio,
    negative_prompt="",
    enable_prompt_refine=True,
    seed=-1,
    guidance_scale=5.0,
):
    """
    Submit an image generation request to the API.

    Args:
        prompt (str): Text prompt describing the image to generate.
        wh_ratio (str): Aspect ratio for the output image (e.g. "16:9"); must
            be one of WH_RATIO_OPTIONS.
        negative_prompt (str): Optional text describing what to avoid.
        enable_prompt_refine (bool): Whether to let the backend rewrite/expand
            the prompt before generation. Sent to the API as 0 / 1.
        seed (int): Generation seed. -1 means the backend will pick one
            randomly; any other integer fixes the seed for reproducible runs.
            Values that cannot be converted to int fall back to -1.
        guidance_scale (float): Classifier-free guidance strength. Higher
            values follow the prompt more strictly. Values that cannot be
            converted to float fall back to 5.0.

    Returns:
        str: Task ID (coerced to ``str`` in case the API returns a number).

    Raises:
        ValueError: If the prompt is empty or the aspect ratio is unsupported.
        APIError: If the API request fails, is rejected, returns a malformed
            payload, or does not contain a task ID.
    """
    # This log line runs before validation, so the slices are guarded: a None
    # prompt must reach the ValueError below rather than raise TypeError here.
    logger.info(
        f"Starting create_request with prompt='{(prompt or '')[:50]}...', "
        f"wh_ratio={wh_ratio}, enable_prompt_refine={enable_prompt_refine}, "
        f"seed={seed}, guidance_scale={guidance_scale}, "
        f"negative_prompt='{(negative_prompt or '')[:30]}...'"
    )

    if not prompt or not prompt.strip():
        logger.error("Empty prompt provided to create_request")
        raise ValueError("Prompt cannot be empty")

    if not wh_ratio or not isinstance(wh_ratio, str) or wh_ratio not in WH_RATIO_OPTIONS:
        logger.error(f"Invalid wh_ratio: {wh_ratio}. Valid options: {WH_RATIO_OPTIONS}")
        raise ValueError(f"Invalid aspect ratio. Must be one of: {', '.join(WH_RATIO_OPTIONS)}")

    try:
        seed_int = int(seed)
    except (TypeError, ValueError, OverflowError):  # OverflowError: int(float('inf'))
        logger.warning(f"Invalid seed value '{seed}', falling back to -1 (random)")
        seed_int = -1

    try:
        guidance_scale_f = float(guidance_scale)
    except (TypeError, ValueError):
        logger.warning(f"Invalid guidance_scale '{guidance_scale}', falling back to 5.0")
        guidance_scale_f = 5.0

    model_params = {
        "prompt": prompt,
        "wh_ratio": wh_ratio,
        "model_id": MODEL_ID,
        "n": 1,
        "negative_prompt": negative_prompt or "",
        "enable_prompt_refine": 1 if enable_prompt_refine else 0,
        "seed": seed_int,
        "guidance_scale": guidance_scale_f,
    }

    url = _build_request_url()
    retry_count = 0

    while retry_count < MAX_RETRY_COUNT:
        try:
            logger.info(
                f"Sending API request [attempt {retry_count + 1}/{MAX_RETRY_COUNT}] for prompt: '{prompt[:50]}...'"
            )
            response = requests.post(url, json=model_params, headers=_headers(), timeout=15)
            logger.info(f"API request response status: {response.status_code}")
            response.raise_for_status()

            try:
                response_json = response.json()
            except ValueError as exc:
                # Handled here so a bad body is not reported as a connection
                # failure (in requests >= 2.27 the JSON decode error is also a
                # RequestException and would otherwise hit that handler).
                logger.error(f"API returned a non-JSON response: {response.text[:500]}")
                raise APIError(f"API returned an invalid (non-JSON) response: {exc}") from exc

            if not isinstance(response_json, dict):
                logger.error(f"Unexpected API response payload: {response_json!r}")
                raise APIError(f"Unexpected API response: {response_json!r}")

            code = response_json.get("code")
            message = response_json.get("message", "")
            if code != SUCCESS_CODE:
                logger.error(f"API returned error code {code}: {message}")
                raise APIError(f"Failed to submit task (code={code}): {message}")

            # "result" may be missing or explicitly null; both mean "no task id".
            result = response_json.get("result")
            task_id = result.get("task_id") if isinstance(result, dict) else None
            if not task_id:
                logger.error(f"No task ID in API response: {response_json}")
                raise APIError(f"No task ID returned from API: {response_json}")

            # Callers slice the id (task_id[:8]), so always hand back a str.
            task_id = str(task_id)
            logger.info(f"Successfully created task with ID: {task_id}")
            return task_id

        except requests.exceptions.Timeout:
            retry_count += 1
            logger.warning(f"Request timed out. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
            if retry_count < MAX_RETRY_COUNT:  # no point sleeping after the last attempt
                time.sleep(1)

        except requests.exceptions.HTTPError as e:
            status_code = e.response.status_code
            error_message = f"HTTP error {status_code}"
            try:
                error_detail = e.response.json()
                error_message += f": {error_detail}"
                logger.error(f"API response error content: {error_detail}")
            except Exception:
                logger.error(f"Could not parse API error response as JSON. Raw content: {e.response.content[:500]}")

            if status_code == 401:
                logger.error(f"Authentication failed with API token. Status code: {status_code}")
                raise APIError("Authentication failed. Please check your API token.") from e
            if status_code == 429:
                retry_count += 1
                wait_time = min(2 ** retry_count, 10)  # exponential backoff, capped at 10s
                logger.warning(f"Rate limit exceeded. Waiting {wait_time}s before retry ({retry_count}/{MAX_RETRY_COUNT})...")
                if retry_count < MAX_RETRY_COUNT:
                    time.sleep(wait_time)
            elif 400 <= status_code < 500:
                # Other client errors will not succeed on retry.
                logger.error(f"Client error: {error_message}, Prompt: '{prompt[:50]}...', Status: {status_code}")
                raise APIError(error_message) from e
            else:
                retry_count += 1
                logger.warning(f"Server error: {error_message}. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
                if retry_count < MAX_RETRY_COUNT:
                    time.sleep(1)

        except requests.exceptions.RequestException as e:
            logger.error(f"Request error: {str(e)}")
            logger.debug(f"Request error details: {traceback.format_exc()}")
            raise APIError(f"Failed to connect to API: {str(e)}") from e

        except APIError:
            raise

        except Exception as e:
            logger.error(f"Unexpected error in create_request: {str(e)}")
            logger.error(f"Full traceback: {traceback.format_exc()}")
            raise APIError(f"Unexpected error: {str(e)}") from e

    logger.error(f"Failed to create request after {MAX_RETRY_COUNT} retries for prompt: '{prompt[:50]}...'")
    raise APIError(f"Failed after {MAX_RETRY_COUNT} retries")
def get_results(task_id):
    """
    Check the status of an image generation task.

    Args:
        task_id (str): The task ID to check.

    Returns:
        dict | None: The "result" object from the response. Returns None when
        the status could not be obtained this time, so the caller can poll
        again. Reasons include a timeout, a network error, a non-success code,
        a non-401 HTTP error, or a malformed payload.

    Raises:
        ValueError: If task_id is empty.
        APIError: For unrecoverable errors: authentication failure (HTTP 401)
            or missing API configuration (host / gen_image_path / token).
    """
    logger.debug(f"Checking status for task ID: {task_id}")

    if not task_id:
        logger.error("Empty task ID provided to get_results")
        raise ValueError("Task ID cannot be empty")

    url = _build_result_url(task_id)

    try:
        response = requests.get(url, headers=_headers(), timeout=10)
        logger.debug(f"Status check response code: {response.status_code}")
        response.raise_for_status()

        response_json = response.json()
        if not isinstance(response_json, dict):
            logger.warning(f"Unexpected status payload for task {task_id}: {response_json!r}")
            return None

        code = response_json.get("code")
        message = response_json.get("message", "")
        if code != SUCCESS_CODE:
            logger.warning(f"API returned non-success code {code} for task {task_id}: {message}")
            return None

        result = response_json.get("result")
        if result is not None and not isinstance(result, dict):
            # The poll loop calls result.get(...), so anything other than a
            # dict (or None) is treated as "no usable status yet".
            logger.warning(f"Unexpected 'result' payload for task {task_id}: {result!r}")
            return None
        return result

    except requests.exceptions.Timeout:
        logger.warning(f"Request timed out when checking task {task_id}")
        return None

    except requests.exceptions.HTTPError as e:
        status_code = e.response.status_code
        logger.warning(f"HTTP error {status_code} when checking task {task_id}")
        try:
            error_content = e.response.json()
            logger.error(f"Error response content: {error_content}")
        except Exception:
            logger.error(f"Could not parse error response as JSON. Raw content: {e.response.content[:500]}")

        if status_code == 401:
            logger.error(f"Authentication failed when checking task {task_id}")
            raise APIError(f"Authentication failed. Please check your API token when checking task {task_id}") from e
        if 400 <= status_code < 500:
            logger.error(f"Client error {status_code} when checking task {task_id}")
            return None
        logger.warning(f"Server error {status_code} when checking task {task_id}")
        return None

    except requests.exceptions.RequestException as e:
        logger.warning(f"Network error when checking task {task_id}: {str(e)}")
        logger.debug(f"Network error details: {traceback.format_exc()}")
        return None

    except APIError:
        # Configuration errors (e.g. a missing token raised by _headers()) are
        # not transient; swallowing them would make the caller keep polling
        # until its timeout without ever surfacing the real cause.
        raise

    except Exception as e:
        logger.error(f"Unexpected error when checking task {task_id}: {str(e)}")
        logger.error(f"Full traceback: {traceback.format_exc()}")
        return None
def download_image(image_url):
    """
    Download an image from a URL and return it as a PIL Image.

    Non-PNG formats (e.g. WebP, JPEG) are re-encoded to PNG in memory. The
    string-valued entries of the original ``image.info`` are then copied onto
    the returned image's ``info`` dict.

    Args:
        image_url (str): URL of the generated image.

    Returns:
        PIL.Image.Image: The downloaded (and, if needed, PNG-converted) image.

    Raises:
        ValueError: If image_url is empty.
        APIError: On a 4xx response, on data that cannot be decoded or
            converted, or after MAX_RETRY_COUNT failed attempts (timeouts,
            5xx responses, network errors).
    """
    logger.info(f"Starting download_image from URL: {image_url}")

    if not image_url:
        logger.error("Empty image URL provided to download_image")
        raise ValueError("Image URL cannot be empty when downloading image")

    retry_count = 0
    while retry_count < MAX_RETRY_COUNT:
        try:
            logger.info(f"Downloading image [attempt {retry_count + 1}/{MAX_RETRY_COUNT}] from {image_url}")
            response = requests.get(image_url, timeout=30)
            logger.debug(
                f"Image download response status: {response.status_code}, "
                f"Content-Type: {response.headers.get('Content-Type')}, "
                f"Content-Length: {response.headers.get('Content-Length')}"
            )
            response.raise_for_status()

            image = Image.open(BytesIO(response.content))
            logger.info(
                f"Image opened successfully. Format: {image.format}, "
                f"Size: {image.size[0]}x{image.size[1]}, Mode: {image.mode}"
            )

            # Only plain str -> str metadata is carried over; non-string
            # entries (e.g. bytes values) are skipped.
            original_metadata = {
                key: value
                for key, value in image.info.items()
                if isinstance(key, str) and isinstance(value, str)
            }
            logger.debug(f"Original image metadata: {original_metadata}")

            if image.format != 'PNG':
                logger.info(f"Converting image from {image.format} to PNG format")
                if image.mode in ('RGBA', 'LA'):
                    image_to_save = image
                elif 'A' in image.getbands() or 'transparency' in image.info:
                    # Alpha in other modes (e.g. PA) or palette/grayscale
                    # transparency kept in info (e.g. GIF) must become a real
                    # alpha channel; a plain RGB conversion would make
                    # transparent pixels opaque.
                    image_to_save = image.convert('RGBA')
                else:
                    image_to_save = image.convert('RGB')
                png_buffer = BytesIO()
                image_to_save.save(png_buffer, format='PNG')
                png_buffer.seek(0)
                image = Image.open(png_buffer)
                image.info.update(original_metadata)

            logger.info(f"Successfully downloaded and processed image: {image.size[0]}x{image.size[1]}")
            return image

        except requests.exceptions.Timeout:
            retry_count += 1
            logger.warning(f"Download timed out. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
            if retry_count < MAX_RETRY_COUNT:  # don't sleep after the final attempt
                time.sleep(1)

        except requests.exceptions.HTTPError as e:
            status_code = e.response.status_code
            logger.error(f"HTTP error {status_code} when downloading image from {image_url}")
            if 400 <= status_code < 500:
                raise APIError(f"HTTP error {status_code} when downloading image") from e
            retry_count += 1
            if retry_count < MAX_RETRY_COUNT:
                time.sleep(1)

        except requests.exceptions.RequestException as e:
            retry_count += 1
            logger.warning(f"Network error during image download: {str(e)}. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
            if retry_count < MAX_RETRY_COUNT:
                time.sleep(1)

        except Exception as e:
            logger.error(f"Error processing image from {image_url}: {str(e)}")
            logger.error(f"Full traceback: {traceback.format_exc()}")
            raise APIError(f"Failed to process image: {str(e)}") from e

    logger.error(f"Failed to download image from {image_url} after {MAX_RETRY_COUNT} retries")
    raise APIError(f"Failed to download image after {MAX_RETRY_COUNT} retries")
# Custom stylesheet passed to gr.Blocks(css=...) when the UI is built.
# The custom classes it styles are attached to components through
# elem_classes / elem_id in the UI code: panel-card, negative-prompt,
# guidance-slider, advanced-row, refine-toggle, seed-input, right-status,
# image-output and #status-bar. The status-pill / status-<kind> / status-dot
# classes come from _status_html().
# NOTE(review): several selectors target Gradio's generated DOM (data-testid
# attributes, the svelte-hashed `.wrap.svelte-1ipelgc` class, inner wrapper
# classes). They may silently stop matching after a Gradio upgrade, so
# re-check the styling when upgrading.
APPLE_CSS = """
/* ---- Apple-inspired minimalist UI ---- */
.gradio-container {
font-family: -apple-system, BlinkMacSystemFont, "SF Pro Display", "SF Pro Text",
"Helvetica Neue", "Segoe UI", Inter, sans-serif !important;
background: linear-gradient(180deg, #fbfbfd 0%, #f5f5f7 100%) !important;
/* Always use ~3/4 of the viewport, capped at 1600px on huge screens.
Using width AND max-width ensures the page is wide from first paint
instead of growing only after content loads. */
width: min(1600px, 92vw) !important;
max-width: 1600px !important;
margin: 0 auto !important;
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
color: #1d1d1f !important;
}
/* Cards / panels */
.panel-card {
background: #ffffff !important;
border-radius: 18px !important;
padding: 18px !important;
box-shadow: 0 1px 2px rgba(0,0,0,0.04), 0 8px 28px rgba(0,0,0,0.05) !important;
border: 1px solid rgba(0,0,0,0.05) !important;
}
/* Inputs - rounded with apple-blue focus ring */
textarea, input[type="text"], input[type="number"],
.gradio-container .form input,
.gradio-container .form textarea {
border-radius: 12px !important;
border: 1px solid #d2d2d7 !important;
background: #ffffff !important;
transition: border-color 0.15s ease, box-shadow 0.15s ease !important;
font-size: 15px !important;
}
textarea:focus, input:focus,
.gradio-container .form input:focus,
.gradio-container .form textarea:focus {
border-color: #0071e3 !important;
box-shadow: 0 0 0 4px rgba(0, 113, 227, 0.15) !important;
outline: none !important;
}
/* Dropdown */
.gradio-container .wrap.svelte-1ipelgc { border-radius: 12px !important; }
/* Block labels */
.gradio-container span[data-testid="block-label"],
.gradio-container .block-label,
.gradio-container label > span {
color: #6e6e73 !important;
font-weight: 500 !important;
font-size: 13px !important;
letter-spacing: 0.01em;
}
/* Buttons - pill shape, Apple blue. Scoped to the variant classes so it does
NOT bleed into internal buttons inside dropdowns, accordions, etc. */
.gradio-container button.primary,
.gradio-container button.secondary {
border-radius: 980px !important;
font-weight: 500 !important;
font-size: 15px !important;
padding: 12px 22px !important;
transition: transform 0.08s ease, box-shadow 0.18s ease, background 0.18s ease, opacity 0.15s ease !important;
border: none !important;
letter-spacing: 0.01em;
}
.gradio-container button.primary {
background: #0071e3 !important;
color: #ffffff !important;
box-shadow: 0 1px 2px rgba(0,113,227,0.25), 0 6px 16px rgba(0,113,227,0.22) !important;
}
.gradio-container button.primary:hover {
background: #0077ed !important;
transform: translateY(-1px);
box-shadow: 0 2px 4px rgba(0,113,227,0.3), 0 10px 22px rgba(0,113,227,0.28) !important;
}
.gradio-container button.primary:active { transform: translateY(0); }
.gradio-container button.secondary {
background: rgba(0,0,0,0.05) !important;
color: #1d1d1f !important;
box-shadow: none !important;
}
.gradio-container button.secondary:hover {
background: rgba(0,0,0,0.09) !important;
}
/* Make sure the dropdown's selected-value text never gets pill-clipped and
is properly aligned inside its rounded box. */
.gradio-container .wrap-inner,
.gradio-container .single-select,
.gradio-container .secondary-wrap {
border-radius: 12px !important;
}
.gradio-container .single-select input,
.gradio-container input[role="listbox"] {
border-radius: 12px !important;
padding: 10px 14px !important;
font-size: 15px !important;
}
/* Status pill */
#status-bar {
padding: 0 !important;
margin-top: 6px;
}
.status-pill {
display: inline-flex;
align-items: center;
gap: 9px;
background: #f5f5f7;
color: #1d1d1f;
padding: 10px 14px;
border-radius: 12px;
font-size: 13px;
font-weight: 500;
line-height: 1;
border: 1px solid rgba(0,0,0,0.04);
}
.status-dot {
width: 8px;
height: 8px;
border-radius: 50%;
background: #8e8e93;
flex-shrink: 0;
}
.status-info .status-dot { background: #8e8e93; }
.status-success { background: rgba(48,209,88,0.10); color: #0a7f2e; border-color: rgba(48,209,88,0.20); }
.status-success .status-dot { background: #30d158; }
.status-error { background: rgba(255,59,48,0.10); color: #b8261b; border-color: rgba(255,59,48,0.20); }
.status-error .status-dot { background: #ff3b30; }
.status-running { background: rgba(0,113,227,0.10); color: #0058b8; border-color: rgba(0,113,227,0.20); }
.status-running .status-dot {
background: #0071e3;
animation: pulse 1.4s ease-in-out infinite;
}
@keyframes pulse {
0%, 100% { opacity: 0.4; transform: scale(0.85); }
50% { opacity: 1.0; transform: scale(1.15); }
}
/* Image output frame — never crop the image; show full picture with letterbox. */
.image-output {
border-radius: 18px !important;
background: #f5f5f7 !important;
}
.image-output,
.image-output > div,
.image-output [data-testid="image"],
.image-output .image-container,
.image-output .image-frame,
.image-output .preview {
min-height: 440px !important;
display: flex !important;
align-items: center !important;
justify-content: center !important;
}
.image-output img {
border-radius: 14px !important;
object-fit: contain !important;
max-width: 100% !important;
max-height: 62vh !important;
width: auto !important;
height: auto !important;
}
/* Status pill placed inside the right column, above the image. */
.right-status {
display: flex;
justify-content: flex-start;
margin-bottom: 6px;
}
/* Accordion */
.gradio-container details {
border-radius: 14px !important;
border: 1px solid rgba(0,0,0,0.06) !important;
background: #ffffff !important;
}
/* Negative prompt — softer accent so it visually de-emphasises vs. the main prompt */
.negative-prompt textarea {
background: #fbfbfd !important;
border-color: #e3e3e8 !important;
}
.negative-prompt textarea:focus {
background: #ffffff !important;
}
/* Advanced options row — keeps the refine switch + seed input visually paired */
.advanced-row {
gap: 14px !important;
margin-top: 2px;
}
/* Refine toggle — iOS-style settings card.
Goal: title + helper text on the left, a polished pill switch on the right,
everything contained inside a single soft card so the helper text no longer
floats orphaned above the box. */
.refine-toggle {
background: linear-gradient(180deg, #ffffff 0%, #f5f5f7 100%) !important;
border-radius: 14px !important;
border: 1px solid rgba(0,0,0,0.06) !important;
padding: 14px 16px !important;
box-shadow: 0 1px 2px rgba(0,0,0,0.03) !important;
transition: border-color 0.18s ease, box-shadow 0.18s ease !important;
min-height: 88px;
display: flex !important;
flex-direction: column !important;
justify-content: center !important;
}
.refine-toggle:hover {
border-color: rgba(0,0,0,0.10) !important;
box-shadow: 0 1px 2px rgba(0,0,0,0.04), 0 4px 14px rgba(0,0,0,0.05) !important;
}
/* Strip default backgrounds from gradio's inner wrappers so only our card shows. */
.refine-toggle .form,
.refine-toggle .wrap,
.refine-toggle .form-wrap,
.refine-toggle > div {
background: transparent !important;
border: none !important;
padding: 0 !important;
margin: 0 !important;
box-shadow: none !important;
}
/* Helper / "info" text becomes a proper subtitle UNDER the toggle row. */
.refine-toggle [data-testid="block-info"],
.refine-toggle .info {
color: #6e6e73 !important;
font-size: 12px !important;
line-height: 1.4 !important;
margin: 8px 0 0 0 !important;
padding: 0 !important;
text-align: left !important;
order: 2 !important;
}
/* Force the gradio wrapper to stack: label-row first, info below. */
.refine-toggle .form,
.refine-toggle > div:not([data-testid="block-info"]):not(.info) {
display: flex !important;
flex-direction: column !important;
align-items: stretch !important;
}
/* Label row: title on the LEFT (full width), toggle pinned to the RIGHT. */
.refine-toggle label {
display: flex !important;
align-items: center !important;
justify-content: space-between !important;
flex-direction: row-reverse !important;
cursor: pointer !important;
margin: 0 !important;
padding: 0 !important;
gap: 14px !important;
width: 100% !important;
order: 1 !important;
}
.refine-toggle label > span {
color: #1d1d1f !important;
font-size: 15px !important;
font-weight: 600 !important;
letter-spacing: -0.01em;
flex: 1 1 auto !important;
text-align: left !important;
line-height: 1.3 !important;
}
/* Pill switch — bigger, smoother, more "Apple-like" */
.refine-toggle input[type="checkbox"] {
appearance: none;
-webkit-appearance: none;
width: 46px !important;
height: 28px !important;
border-radius: 999px !important;
background: #e5e5ea !important;
position: relative;
cursor: pointer;
transition: background 0.22s ease, box-shadow 0.22s ease;
border: none !important;
flex-shrink: 0 !important;
margin: 0 !important;
box-shadow: inset 0 0 1px rgba(0,0,0,0.06);
}
.refine-toggle input[type="checkbox"]::after {
content: "";
position: absolute;
top: 2px;
left: 2px;
width: 24px;
height: 24px;
border-radius: 50%;
background: #ffffff;
box-shadow: 0 2px 5px rgba(0,0,0,0.18), 0 0 1px rgba(0,0,0,0.05);
transition: transform 0.24s cubic-bezier(0.4, 0.0, 0.2, 1);
}
.refine-toggle input[type="checkbox"]:hover {
background: #dcdce0 !important;
}
.refine-toggle input[type="checkbox"]:checked {
background: #34c759 !important;
}
.refine-toggle input[type="checkbox"]:checked:hover {
background: #30b352 !important;
}
.refine-toggle input[type="checkbox"]:checked::after {
transform: translateX(18px);
}
.refine-toggle input[type="checkbox"]:focus-visible {
box-shadow: 0 0 0 4px rgba(0,113,227,0.20) !important;
}
.refine-toggle input[type="checkbox"]:active::after {
/* tiny squish on press, very iOS */
width: 28px;
}
.refine-toggle input[type="checkbox"]:checked:active::after {
transform: translateX(14px);
}
/* Seed number input — match the prompt/dropdown rounding */
.seed-input input[type="number"] {
border-radius: 12px !important;
padding: 10px 14px !important;
font-variant-numeric: tabular-nums;
}
/* Hide the native spinner buttons on number inputs for a cleaner look */
.seed-input input[type="number"]::-webkit-outer-spin-button,
.seed-input input[type="number"]::-webkit-inner-spin-button {
-webkit-appearance: none;
margin: 0;
}
.seed-input input[type="number"] {
-moz-appearance: textfield;
}
/* Keep the seed column visually aligned with the refine card next to it */
.seed-input {
align-self: stretch !important;
}
/* Guidance scale slider — Apple-blue track + softer thumb */
.guidance-slider input[type="range"] {
accent-color: #0071e3 !important;
}
.guidance-slider .head { padding-top: 0 !important; }
.guidance-slider {
margin-top: 4px;
}
/* Footer tagline */
.tagline {
text-align: center;
color: #6e6e73;
font-size: 12px;
margin: 18px 0 14px 0;
font-weight: 400;
}
.tagline a {
color: #0071e3;
text-decoration: none;
font-weight: 500;
transition: opacity 0.15s ease;
}
.tagline a:hover { opacity: 0.7; }
/* Footer links row (HuggingFace / GitHub / Twitter) */
.footer-links {
display: flex;
justify-content: center;
align-items: center;
gap: 26px;
flex-wrap: wrap;
font-size: 13px;
margin: 24px 0 6px 0;
}
/* When the vivago tagline follows the links row, tighten the gap. */
.footer-links + .tagline {
margin-top: 4px;
}
/* Hide gradio's default footer for a cleaner look */
footer { display: none !important; }
/* Mobile */
@media (max-width: 640px) {
.footer-links { gap: 18px; font-size: 12px; }
}
"""
# Gradio theme used alongside APPLE_CSS. It is built on the Soft base theme
# with a blue primary hue, slate neutrals, large radii, medium text, an
# Inter-first font stack, and Apple-blue (#0071e3) primary buttons and
# input focus styling.
_APPLE_FONT_STACK = [
    gr.themes.GoogleFont("Inter"),
    "ui-sans-serif",
    "-apple-system",
    "BlinkMacSystemFont",
    "Segoe UI",
    "Helvetica Neue",
    "sans-serif",
]

# Theme variable overrides applied via Theme.set(); grouped by component.
_APPLE_THEME_OVERRIDES = dict(
    # page + blocks
    body_background_fill="*neutral_50",
    block_background_fill="white",
    block_border_width="1px",
    block_label_text_weight="500",
    block_title_text_weight="600",
    # primary button
    button_primary_background_fill="#0071e3",
    button_primary_background_fill_hover="#0077ed",
    button_primary_text_color="white",
    button_primary_border_color="#0071e3",
    # secondary button
    button_secondary_background_fill="rgba(0,0,0,0.05)",
    button_secondary_background_fill_hover="rgba(0,0,0,0.09)",
    button_secondary_text_color="#1d1d1f",
    # inputs
    input_background_fill="white",
    input_border_color="#d2d2d7",
    input_border_color_focus="#0071e3",
    input_shadow_focus="0 0 0 4px rgba(0,113,227,0.15)",
)

APPLE_THEME = gr.themes.Soft(
    primary_hue=gr.themes.colors.blue,
    neutral_hue=gr.themes.colors.slate,
    radius_size=gr.themes.sizes.radius_lg,
    text_size=gr.themes.sizes.text_md,
    font=_APPLE_FONT_STACK,
).set(**_APPLE_THEME_OVERRIDES)
def _status_html(text: str, kind: str = "info") -> str:
    """Render a styled status pill for the status ``gr.HTML`` component.

    Args:
        text: Plain-text message. It is HTML-escaped because status messages
            can echo user-supplied values (e.g. a rejected aspect ratio) and,
            presumably, error text coming back from the API.
        kind: One of ``info | success | error | running``. It selects the
            ``status-<kind>`` CSS modifier class.

    Returns:
        str: The pill's HTML snippet.
    """
    from html import escape

    # quote=False for element text: apostrophes/double quotes in messages stay
    # as-is; the class attribute value is fully escaped.
    return (
        f'<div class="status-pill status-{escape(str(kind))}">'
        f'<span class="status-dot"></span>{escape(str(text), quote=False)}'
        f'</div>'
    )
def _queue_text(waiting: int, running: int) -> str:
    """Describe the shared queue from this user's point of view.

    Shown in the right-hand status pill while the user's own request is
    still waiting for a free slot. Counts that are zero or negative are left
    out. When both are, a generic "queued" message is returned instead.
    """
    segments = [
        f"{count} {label}"
        for count, label in ((waiting, "waiting"), (running, "generating"))
        if count > 0
    ]
    if not segments:
        return "Queued · waiting for an open slot…"
    return "In queue · " + " · ".join(segments)
def _count_waiting_events(queue_obj) -> int:
    """Total number of events still waiting across every concurrency-id queue."""
    total = 0
    per_id = getattr(queue_obj, "event_queue_per_concurrency_id", None) or {}
    for event_queue in per_id.values():
        # Newer gradio wraps the pending list in an EventQueue (.queue);
        # older/alternative versions may store the list directly.
        pending = getattr(event_queue, "queue", event_queue)
        try:
            total += len(pending)
        except (TypeError, AttributeError):
            # Not sized: fall back to counting by iteration, else skip it.
            try:
                total += sum(1 for _ in pending)
            except Exception:
                pass
    return total


def _count_running_events(queue_obj) -> int:
    """Number of events held in the non-idle ``active_jobs`` slots."""
    total = 0
    for slot in getattr(queue_obj, "active_jobs", None) or []:
        if slot is None:  # idle slot
            continue
        try:
            total += len(slot)
        except (TypeError, AttributeError):
            total += 1  # very old gradio: a single Event per slot
    return total


def _average_process_time(queue_obj) -> float:
    """Best-effort mean per-run processing time in seconds (0.0 when unknown).

    Sources are tried newest-gradio first; the first non-zero mean wins.
    """
    # Gradio 5.x: dict[BlockFunction, ProcessTime] exposing .avg_time
    per_fn = getattr(queue_obj, "process_time_per_fn", None)
    if isinstance(per_fn, dict) and per_fn:
        try:
            samples = [
                float(avg)
                for avg in (getattr(entry, "avg_time", None) for entry in per_fn.values())
                if avg
            ]
            mean = sum(samples) / len(samples) if samples else 0.0
        except Exception:
            mean = 0.0
        if mean:
            return mean

    # Older: dict[int, float]
    per_index = getattr(queue_obj, "process_time_per_fn_index", None)
    if isinstance(per_index, dict) and per_index:
        try:
            samples = [float(v) for v in per_index.values() if v]
            mean = sum(samples) / len(samples) if samples else 0.0
        except Exception:
            mean = 0.0
        if mean:
            return mean

    # Oldest: a single float on the queue object
    legacy = getattr(queue_obj, "avg_process_time", None)
    if legacy:
        try:
            return float(legacy)
        except Exception:
            return 0.0
    return 0.0


def _read_queue_stats(demo_obj) -> tuple[int, int, float]:
    """Best-effort read of gradio's internal queue.

    Gradio 5.x structure (gradio/queueing.py):
        demo._queue.event_queue_per_concurrency_id: dict[str, EventQueue]
            EventQueue.queue: list[Event]  <- actually-waiting events
        demo._queue.active_jobs: list[None | list[Event]]
            each slot is None (idle) or a list of currently-processing events.
        demo._queue.process_time_per_fn: dict[BlockFunction, ProcessTime]
            ProcessTime.avg_time: float

    Returns:
        (waiting, running, avg_secs). Any unexpected failure yields
        (0, 0, 0.0), so the UI degrades gracefully ("idle") if Gradio ever
        renames a field.
    """
    try:
        queue_obj = getattr(demo_obj, "_queue", None)
        if queue_obj is None:
            return 0, 0, 0.0
        waiting = _count_waiting_events(queue_obj)
        running = _count_running_events(queue_obj)
        return waiting, running, _average_process_time(queue_obj)
    except Exception as exc:
        logger.debug(f"Queue introspection failed: {exc}")
        return 0, 0, 0.0
def _fake_image_size(wh_ratio_value, long_side=1024):
    """Return a (width, height) matching a "W:H" ratio, long side = *long_side*.

    Falls back to a square when the ratio cannot be parsed or is non-positive.
    """
    try:
        width_part, height_part = str(wh_ratio_value).split(":")
        ratio_w, ratio_h = int(width_part), int(height_part)
    except (TypeError, ValueError):
        return long_side, long_side
    if ratio_w <= 0 or ratio_h <= 0:
        return long_side, long_side
    if ratio_w >= ratio_h:
        return long_side, max(1, round(long_side * ratio_h / ratio_w))
    return max(1, round(long_side * ratio_w / ratio_h)), long_side


def _fake_generation_iter(prompt: str, wh_ratio_value: str):
    """FAKE_TEST mode generator.

    Mimics the real flow's yield protocol without hitting any external API,
    which is useful for exercising the queue UI / status pill locally.

    Yields (image_or_None, status_html) tuples. The very first 'Sending
    request to API…' yield is emitted by the caller, so this iterator picks
    up from 'Request submitted' onwards. The final image is a random solid
    colour shaped like *wh_ratio_value* (long side 1024 px, square if the
    ratio can't be parsed), so the output panel's letterboxing can be checked
    for every aspect ratio. ``prompt`` is accepted for signature parity with
    the real flow and is not used.
    """
    time.sleep(random.uniform(0.4, 1.0))
    fake_id = f"{random.randint(0, 0xFFFFFFFF):08x}"
    yield None, _status_html(f"Request submitted · Task {fake_id}…", "running")

    # Simulate a generation that takes a random 8-22 s, reporting progress
    # once per POLL_INTERVAL like the real polling loop.
    target_secs = random.uniform(8.0, 22.0)
    start = time.time()
    while time.time() - start < target_secs:
        elapsed = int(time.time() - start)
        yield None, _status_html(f"Generating… {elapsed}s", "running")
        time.sleep(POLL_INTERVAL)

    yield None, _status_html("Downloading image…", "running")
    time.sleep(0.3)

    rgb = (
        random.randint(40, 220),
        random.randint(40, 220),
        random.randint(40, 220),
    )
    fake_image = Image.new("RGB", _fake_image_size(wh_ratio_value), color=rgb)
    logger.info(f"FAKE_TEST: returning random color image rgb={rgb}, took {target_secs:.1f}s")
    yield fake_image, _status_html("Image generated", "success")
def create_ui():
logger.info("Creating Gradio UI")
with gr.Blocks(
title="HiDream-O1-Image Generator",
theme=APPLE_THEME,
css=APPLE_CSS,
) as demo:
# Per-session state used to drive the right-side status pill while
# this user's request is sitting in the queue. `last_pill_state`
# caches the most recent pill HTML so the timer can return
# `gr.update()` (=no-op, no DOM replacement → no flicker) when the
# queue counts haven't actually changed between ticks.
queued_state = gr.State(False)
last_pill_state = gr.State("")
with gr.Row(equal_height=False):
with gr.Column(scale=1, elem_classes=["panel-card"]):
prompt = gr.Textbox(
label="Prompt",
placeholder="Describe the image you want to create...",
lines=5,
show_label=True,
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
placeholder="Things you want to avoid in the image (optional)...",
lines=2,
show_label=True,
elem_classes=["negative-prompt"],
)
wh_ratio = gr.Dropdown(
choices=WH_RATIO_OPTIONS,
value=WH_RATIO_OPTIONS[0],
label="Aspect Ratio",
info="Width : Height",
)
guidance_scale = gr.Slider(
minimum=1.0,
maximum=20.0,
step=0.1,
value=5.0,
label="Guidance Scale",
info="Higher values follow the prompt more strictly",
elem_classes=["guidance-slider"],
)
with gr.Row(elem_classes=["advanced-row"], equal_height=True):
enable_prompt_refine = gr.Checkbox(
value=True,
label="Prompt Refine",
info="Let the model rewrite & enrich your prompt",
elem_classes=["refine-toggle"],
scale=1,
)
seed = gr.Number(
value=-1,
label="Seed",
info="Use -1 for a random seed",
precision=0,
minimum=-1,
elem_classes=["seed-input"],
scale=1,
)
with gr.Row():
clear_btn = gr.Button("Clear", variant="secondary", scale=1)
generate_btn = gr.Button("Generate", variant="primary", scale=3)
with gr.Column(scale=1, elem_classes=["panel-card"]):
status_msg = gr.HTML(
value=_status_html("Ready", "info"),
elem_id="status-bar",
elem_classes=["right-status"],
)
output_image = gr.Image(
label="Generated Image",
format="png",
type="pil",
interactive=False,
show_download_button=True,
elem_classes=["image-output"],
)
def generate_with_status(
    prompt,
    wh_ratio_value,
    negative_prompt_value,
    enable_prompt_refine_value,
    seed_value,
    guidance_scale_value,
):
    """Generate one image for the UI, streaming status-pill updates.

    Generator consumed by the queued Generate click handler (via
    `_generate_wrapped`). Every step yields an ``(image, status_html)``
    pair; ``image`` is ``None`` on every yield except the final success
    yield, which carries the downloaded image.

    Flow: validate inputs -> (FAKE_TEST short-circuit) -> `create_request`
    -> poll `get_results` every POLL_INTERVAL seconds for at most
    MAX_POLL_TIME seconds -> `download_image` once the first sub-task
    reports success.

    Args:
        prompt: Positive prompt text; blank/None is rejected.
        wh_ratio_value: Aspect ratio; must be one of WH_RATIO_OPTIONS.
        negative_prompt_value: Optional negative prompt (None/"" -> "").
        enable_prompt_refine_value: Truthy to request prompt refinement.
        seed_value: Seed; None or unconvertible values fall back to -1.
        guidance_scale_value: Guidance scale; None or unconvertible values
            fall back to 5.0.

    Yields:
        ``(image_or_None, status_html)`` tuples. Errors do not propagate:
        APIError, ValueError and any other Exception are converted into a
        final error pill.
    """
    logger.info(
        f"Starting image generation with prompt='{(prompt or '')[:50]}...', "
        f"wh_ratio={wh_ratio_value}, "
        f"negative_prompt='{(negative_prompt_value or '')[:30]}...', "
        f"enable_prompt_refine={enable_prompt_refine_value}, "
        f"seed={seed_value}, guidance_scale={guidance_scale_value}"
    )
    yield None, _status_html("Sending request to API…", "running")
    try:
        if not prompt or not prompt.strip():
            logger.error("Empty prompt provided in UI")
            yield None, _status_html("Prompt cannot be empty", "error")
            return
        if wh_ratio_value not in WH_RATIO_OPTIONS:
            logger.error(f"Invalid aspect ratio selection: {wh_ratio_value}")
            yield None, _status_html(f"Invalid aspect ratio “{wh_ratio_value}”", "error")
            return
        # Fall back to the defaults for anything that does not convert
        # cleanly. OverflowError covers int(float("inf")) and float() of an
        # enormous int, which are conversion failures like any other.
        try:
            seed_int = int(seed_value) if seed_value is not None else -1
        except (TypeError, ValueError, OverflowError):
            seed_int = -1
        try:
            guidance_scale_f = float(guidance_scale_value) if guidance_scale_value is not None else 5.0
        except (TypeError, ValueError, OverflowError):
            guidance_scale_f = 5.0
        if FAKE_TEST:
            logger.info("FAKE_TEST mode active — bypassing real API call")
            yield from _fake_generation_iter(prompt, wh_ratio_value)
            return
        logger.info("Creating API request")
        task_id = create_request(
            prompt,
            wh_ratio_value,
            negative_prompt=negative_prompt_value or "",
            enable_prompt_refine=bool(enable_prompt_refine_value),
            seed=seed_int,
            guidance_scale=guidance_scale_f,
        )
        yield None, _status_html(f"Request submitted · Task {task_id[:8]}…", "running")
        # Measure elapsed time with a monotonic clock: wall-clock jumps
        # (NTP sync, manual changes) must not end polling early or stretch
        # it far beyond MAX_POLL_TIME.
        start_time = time.monotonic()
        logger.info(f"Starting to poll for results for task ID: {task_id}")
        while time.monotonic() - start_time < MAX_POLL_TIME:
            elapsed_time = time.monotonic() - start_time
            logger.debug(
                f"Polling for results - Task ID: {task_id}, Elapsed: {elapsed_time:.2f}s"
            )
            result = get_results(task_id)
            if not result:
                # Nothing usable this round; retry after the interval.
                time.sleep(POLL_INTERVAL)
                continue
            overall_status = result.get("status")
            sub_results = result.get("sub_task_results", []) or []
            if overall_status != 1:
                # Only overall status 1 lets us inspect the sub-task results.
                elapsed = int(time.monotonic() - start_time)
                yield None, _status_html(f"Generating… {elapsed}s", "running")
                time.sleep(POLL_INTERVAL)
                continue
            if not sub_results:
                logger.error(f"Task completed but no sub_task_results returned. Task ID: {task_id}")
                yield None, _status_html("Task completed but no results returned", "error")
                return
            # Only the first sub-task is inspected.
            sub = sub_results[0]
            sub_status = sub.get("task_status")
            if sub_status == 1:
                # Sub-task succeeded: fetch the image from its URL.
                logger.info(f"Task completed successfully - Task ID: {task_id}")
                image_url = sub.get("url")
                if not image_url:
                    logger.error(f"No image URL in successful response. Sub result: {sub}")
                    yield None, _status_html("No image URL in response", "error")
                    return
                yield None, _status_html("Downloading image…", "running")
                logger.info(f"Downloading image - Task ID: {task_id}, URL: {image_url}")
                image = download_image(image_url)
                if image:
                    logger.info(f"Image generation complete - Task ID: {task_id}")
                    yield image, _status_html("Image generated", "success")
                    return
                else:
                    logger.error(f"Failed to download image - Task ID: {task_id}, URL: {image_url}")
                    yield None, _status_html("Failed to download generated image", "error")
                    return
            elif sub_status == 3:
                # Sub-task failed: surface whichever error text the API gave.
                error_msg = sub.get("task_error") or sub.get("message") or "Unknown error"
                logger.error(
                    f"Task failed - Task ID: {task_id}, Sub status: {sub_status}, Error: {error_msg}"
                )
                yield None, _status_html(f"Task failed: {error_msg}", "error")
                return
            else:
                # Any other sub-status: keep waiting.
                elapsed = int(time.monotonic() - start_time)
                yield None, _status_html(f"Waiting… {elapsed}s", "running")
                time.sleep(POLL_INTERVAL)
        logger.error(
            f"Timeout waiting for task completion - Task ID: {task_id}, Max time: {MAX_POLL_TIME}s"
        )
        yield None, _status_html(f"Timed out after {MAX_POLL_TIME}s", "error")
    except APIError as e:
        logger.error(f"API Error during generation: {str(e)}")
        yield None, _status_html(f"API error: {str(e)}", "error")
    except ValueError as e:
        logger.error(f"Value Error during generation: {str(e)}")
        yield None, _status_html(f"Value error: {str(e)}", "error")
    except Exception as e:
        # logger.exception logs at ERROR level and appends the active
        # traceback, replacing the separate manual traceback log line.
        logger.exception(f"Unexpected error during image generation: {str(e)}")
        yield None, _status_html(f"Unexpected error: {str(e)}", "error")
def _enter_queue():
    """Run the immediate (queue=False) half of the Generate click.

    Takes a snapshot of the queue with `_read_queue_stats(demo)` and returns
    values for ``[output_image, status_msg, queued_state, last_pill_state]``:
    - clears the image;
    - shows the queue pill;
    - marks this session as queued;
    - caches the pill HTML so later timer ticks can skip identical redraws.
    Live updates after this snapshot come from the pill timer.
    """
    waiting_count, running_count, _unused = _read_queue_stats(demo)
    pill_html = _status_html(_queue_text(waiting_count, running_count), "running")
    cleared_image = None
    return cleared_image, pill_html, True, pill_html
def _generate_wrapped(
    prompt_value,
    wh_ratio_value,
    negative_prompt_value,
    enable_prompt_refine_value,
    seed_value,
    guidance_scale_value,
):
    """Run the queued half of the Generate click: relay generator output.

    Re-yields every ``(image, status_html)`` pair produced by
    `generate_with_status` and adds a third value for ``queued_state``:
    - ``False`` on the first pair, so the pill timer stops overwriting the
      status and the generator's own messages drive it from then on;
    - a no-op ``gr.update()`` on every later pair.
    """
    updates = generate_with_status(
        prompt_value,
        wh_ratio_value,
        negative_prompt_value,
        enable_prompt_refine_value,
        seed_value,
        guidance_scale_value,
    )
    for position, (image, status_html) in enumerate(updates):
        # gr.update() is only built for non-first pairs, as before.
        queued_flag = False if position == 0 else gr.update()
        yield image, status_html, queued_flag
generate_btn.click(
fn=_enter_queue,
inputs=None,
outputs=[output_image, status_msg, queued_state, last_pill_state],
queue=False,
show_progress="hidden",
).then(
fn=_generate_wrapped,
inputs=[prompt, wh_ratio, negative_prompt, enable_prompt_refine, seed, guidance_scale],
outputs=[output_image, status_msg, queued_state],
show_progress="minimal",
show_progress_on=[generate_btn],
)
def clear_outputs():
    """Reset the right-hand panel when the Clear button is pressed.

    Returns values for ``[output_image, status_msg, queued_state,
    last_pill_state]``:
    - no image;
    - the idle "Ready" pill;
    - not queued;
    - an empty pill cache.
    """
    logger.info("Clearing UI outputs")
    ready_pill = _status_html("Ready", "info")
    return (None, ready_pill, False, "")
clear_btn.click(
fn=clear_outputs,
inputs=None,
outputs=[output_image, status_msg, queued_state, last_pill_state],
)
# Live queue updates inside the right-side pill — only while THIS
# user is queued. Returns gr.update() (no-op) when nothing changed,
# which prevents the DOM from being replaced and avoids the pulse
# animation resetting (= no flicker).
pill_timer = gr.Timer(value=1.5, active=True)
def _tick_pill(queued_flag, last_html):
    """Timer callback that refreshes the queue pill while this session waits.

    Returns values for ``[status_msg, last_pill_state]``. Only when the
    session is queued AND the freshly rendered pill differs from the cached
    one is new HTML written (and cached). Otherwise ``status_msg`` receives
    a no-op ``gr.update()``, so the DOM node is not replaced (no flicker).
    """
    if queued_flag:
        waiting_count, running_count, _unused = _read_queue_stats(demo)
        rendered = _status_html(_queue_text(waiting_count, running_count), "running")
        if rendered != last_html:
            return rendered, rendered
    return gr.update(), last_html
pill_timer.tick(
fn=_tick_pill,
inputs=[queued_state, last_pill_state],
outputs=[status_msg, last_pill_state],
queue=False,
show_progress="hidden",
)
gr.HTML(
"""
<div class="tagline footer-links">
<a href="https://huggingface.co/HiDream-ai/HiDream-O1-Image" target="_blank">HuggingFace</a>
<a href="https://github.com/HiDream-ai/HiDream-O1-Image" target="_blank">GitHub</a>
<a href="https://x.com/vivago_ai" target="_blank">Twitter</a>
</div>
<div class="tagline">
For more features and the full experience, visit
<a href="https://vivago.ai/" target="_blank">vivago.ai</a>.
</div>
"""
)
logger.info("Gradio UI created successfully")
return demo
if __name__ == "__main__":
    # Script entry point. Build the Blocks UI, then enable the request queue:
    # at most 50 stored events, and a default per-listener concurrency limit
    # of 4. Then serve the UI with the API page hidden. When run as a script,
    # launch() presumably blocks until the server stops, so the last log line
    # marks shutdown.
    logger.info("Starting HiDream-O1-Image Generator application")
    demo = create_ui()
    logger.info("Launching Gradio interface with queue")
    queued_app = demo.queue(max_size=50, default_concurrency_limit=4)
    queued_app.launch(show_api=False)
    logger.info("Application shutdown")