import logging
import os
import random
import time
import traceback
from io import BytesIO

import gradio as gr
import requests
from PIL import Image
from dotenv import load_dotenv

# Configure logging once at import time; every record carries a timestamp,
# the logger name and the level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Load a local .env file (if present) before the os.environ lookups below.
load_dotenv()

# API Configuration (new style: host + gen_image_path)
# Each value is None when the corresponding environment variable is unset.
API_TOKEN = os.environ.get("token")
API_HOST = os.environ.get("host")
GEN_IMAGE_PATH = os.environ.get("gen_image_path")
MODEL_ID = os.environ.get("model_id")

# Polling / retry configuration (with sensible defaults)
# NOTE: int()/float() raise ValueError at import time if one of these
# variables is set to a non-numeric string.
MAX_RETRY_COUNT = int(os.environ.get("MAX_RETRY_COUNT", 3))
POLL_INTERVAL = float(os.environ.get("POLL_INTERVAL", 2.0))
MAX_POLL_TIME = int(os.environ.get("MAX_POLL_TIME", 300))

# Local-only test mode: skip the real API and return a random-colored image after
# a randomised delay. Useful for testing the UI / queueing flow without burning
# real model credits. Enable with FAKE_TEST=1 (also accepts "true"/"yes"/"on",
# case-insensitively).
FAKE_TEST = os.environ.get("FAKE_TEST", "").strip().lower() in ("1", "true", "yes", "on") # Predefined aspect ratios (wh_ratio) — kept in the same order as the original # (width, height) list so each entry mirrors the previous resolution choice: # 1:1 ←→ 2048×2048 4:3 ←→ 2304×1728 3:4 ←→ 1728×2304 # 16:9 ←→ 2560×1440 9:16 ←→ 1440×2560 3:2 ←→ 2496×1664 # 2:3 ←→ 1664×2496 21:9 ←→ 3104×1312 9:21 ←→ 1312×3104 # 9:7 ←→ 2304×1792 7:9 ←→ 1792×2304 WH_RATIO_OPTIONS = [ "1:1", "4:3", "3:4", "16:9", "9:16", "3:2", "2:3", "21:9", "9:21", "9:7", "7:9", ] logger.info( f"API configuration loaded: HOST={API_HOST}, GEN_IMAGE_PATH={GEN_IMAGE_PATH}, MODEL_ID={MODEL_ID}" ) logger.info( f"Retry configuration: MAX_RETRY_COUNT={MAX_RETRY_COUNT}, POLL_INTERVAL={POLL_INTERVAL}s, MAX_POLL_TIME={MAX_POLL_TIME}s" ) if FAKE_TEST: logger.warning("FAKE_TEST mode is ENABLED — no real API calls will be made.") class APIError(Exception): """Custom exception for API-related errors""" pass # Status codes returned by the API SUCCESS_CODE = 0 def _build_request_url() -> str: if not API_HOST or not GEN_IMAGE_PATH: raise APIError("API host or gen_image_path is not configured. Please set the 'host' and 'gen_image_path' environment variables.") return f"{API_HOST.rstrip('/')}{GEN_IMAGE_PATH}" def _build_result_url(task_id: str) -> str: return f"{_build_request_url()}/results?task_id={task_id}" def _headers() -> dict: if not API_TOKEN: raise APIError("API token is not configured. Please set the 'token' environment variable.") return {"Authorization": f"Bearer {API_TOKEN}"} def create_request( prompt, wh_ratio, negative_prompt="", enable_prompt_refine=True, seed=-1, guidance_scale=5.0, ): """ Submit an image generation request to the API. Args: prompt (str): Text prompt describing the image to generate wh_ratio (str): Aspect ratio for the output image (e.g. "16:9") negative_prompt (str): Optional text describing what to avoid. 
enable_prompt_refine (bool): Whether to let the backend rewrite/expand the prompt before generation. Sent to the API as 0 / 1. seed (int): Generation seed. -1 means the backend will pick one randomly; any other integer fixes the seed for reproducible runs. guidance_scale (float): Classifier-free guidance strength. Higher values follow the prompt more strictly. Returns: str: Task ID Raises: APIError: If the API request fails """ logger.info( f"Starting create_request with prompt='{prompt[:50]}...', " f"wh_ratio={wh_ratio}, enable_prompt_refine={enable_prompt_refine}, " f"seed={seed}, guidance_scale={guidance_scale}, " f"negative_prompt='{(negative_prompt or '')[:30]}...'" ) if not prompt or not prompt.strip(): logger.error("Empty prompt provided to create_request") raise ValueError("Prompt cannot be empty") if not wh_ratio or not isinstance(wh_ratio, str) or wh_ratio not in WH_RATIO_OPTIONS: logger.error(f"Invalid wh_ratio: {wh_ratio}. Valid options: {WH_RATIO_OPTIONS}") raise ValueError(f"Invalid aspect ratio. 
Must be one of: {', '.join(WH_RATIO_OPTIONS)}") try: seed_int = int(seed) except (TypeError, ValueError): logger.warning(f"Invalid seed value '{seed}', falling back to -1 (random)") seed_int = -1 try: guidance_scale_f = float(guidance_scale) except (TypeError, ValueError): logger.warning(f"Invalid guidance_scale '{guidance_scale}', falling back to 5.0") guidance_scale_f = 5.0 model_params = { "prompt": prompt, "wh_ratio": wh_ratio, "model_id": MODEL_ID, "n": 1, "negative_prompt": negative_prompt or "", "enable_prompt_refine": 1 if enable_prompt_refine else 0, "seed": seed_int, "guidance_scale": guidance_scale_f, } url = _build_request_url() retry_count = 0 while retry_count < MAX_RETRY_COUNT: try: logger.info( f"Sending API request [attempt {retry_count + 1}/{MAX_RETRY_COUNT}] for prompt: '{prompt[:50]}...'" ) response = requests.post(url, json=model_params, headers=_headers(), timeout=15) logger.info(f"API request response status: {response.status_code}") response.raise_for_status() response_json = response.json() code = response_json.get("code") message = response_json.get("message", "") if code != SUCCESS_CODE: logger.error(f"API returned error code {code}: {message}") raise APIError(f"Failed to submit task (code={code}): {message}") task_id = response_json.get("result", {}).get("task_id") if not task_id: logger.error(f"No task ID in API response: {response_json}") raise APIError(f"No task ID returned from API: {response_json}") logger.info(f"Successfully created task with ID: {task_id}") return task_id except requests.exceptions.Timeout: retry_count += 1 logger.warning(f"Request timed out. 
Retrying ({retry_count}/{MAX_RETRY_COUNT})...") time.sleep(1) except requests.exceptions.HTTPError as e: status_code = e.response.status_code error_message = f"HTTP error {status_code}" try: error_detail = e.response.json() error_message += f": {error_detail}" logger.error(f"API response error content: {error_detail}") except Exception: logger.error(f"Could not parse API error response as JSON. Raw content: {e.response.content[:500]}") if status_code == 401: logger.error(f"Authentication failed with API token. Status code: {status_code}") raise APIError("Authentication failed. Please check your API token.") elif status_code == 429: retry_count += 1 wait_time = min(2 ** retry_count, 10) logger.warning(f"Rate limit exceeded. Waiting {wait_time}s before retry ({retry_count}/{MAX_RETRY_COUNT})...") time.sleep(wait_time) elif 400 <= status_code < 500: logger.error(f"Client error: {error_message}, Prompt: '{prompt[:50]}...', Status: {status_code}") raise APIError(error_message) else: retry_count += 1 logger.warning(f"Server error: {error_message}. Retrying ({retry_count}/{MAX_RETRY_COUNT})...") time.sleep(1) except requests.exceptions.RequestException as e: logger.error(f"Request error: {str(e)}") logger.debug(f"Request error details: {traceback.format_exc()}") raise APIError(f"Failed to connect to API: {str(e)}") except APIError: raise except Exception as e: logger.error(f"Unexpected error in create_request: {str(e)}") logger.error(f"Full traceback: {traceback.format_exc()}") raise APIError(f"Unexpected error: {str(e)}") logger.error(f"Failed to create request after {MAX_RETRY_COUNT} retries for prompt: '{prompt[:50]}...'") raise APIError(f"Failed after {MAX_RETRY_COUNT} retries") def get_results(task_id): """ Check the status of an image generation task. Args: task_id (str): The task ID to check Returns: dict: Task result information (the "result" object from the response), or None on transient failure. Raises: APIError: For unrecoverable errors (e.g. 
authentication failure). """ logger.debug(f"Checking status for task ID: {task_id}") if not task_id: logger.error("Empty task ID provided to get_results") raise ValueError("Task ID cannot be empty") url = _build_result_url(task_id) try: response = requests.get(url, headers=_headers(), timeout=10) logger.debug(f"Status check response code: {response.status_code}") response.raise_for_status() response_json = response.json() code = response_json.get("code") message = response_json.get("message", "") if code != SUCCESS_CODE: logger.warning(f"API returned non-success code {code} for task {task_id}: {message}") return None return response_json.get("result") except requests.exceptions.Timeout: logger.warning(f"Request timed out when checking task {task_id}") return None except requests.exceptions.HTTPError as e: status_code = e.response.status_code logger.warning(f"HTTP error {status_code} when checking task {task_id}") try: error_content = e.response.json() logger.error(f"Error response content: {error_content}") except Exception: logger.error(f"Could not parse error response as JSON. Raw content: {e.response.content[:500]}") if status_code == 401: logger.error(f"Authentication failed when checking task {task_id}") raise APIError(f"Authentication failed. Please check your API token when checking task {task_id}") elif 400 <= status_code < 500: logger.error(f"Client error {status_code} when checking task {task_id}") return None else: logger.warning(f"Server error {status_code} when checking task {task_id}") return None except requests.exceptions.RequestException as e: logger.warning(f"Network error when checking task {task_id}: {str(e)}") logger.debug(f"Network error details: {traceback.format_exc()}") return None except Exception as e: logger.error(f"Unexpected error when checking task {task_id}: {str(e)}") logger.error(f"Full traceback: {traceback.format_exc()}") return None def download_image(image_url): """ Download an image from a URL and return it as a PIL Image. 
Converts non-PNG formats (e.g. WebP) to PNG while preserving original metadata. """ logger.info(f"Starting download_image from URL: {image_url}") if not image_url: logger.error("Empty image URL provided to download_image") raise ValueError("Image URL cannot be empty when downloading image") retry_count = 0 while retry_count < MAX_RETRY_COUNT: try: logger.info(f"Downloading image [attempt {retry_count + 1}/{MAX_RETRY_COUNT}] from {image_url}") response = requests.get(image_url, timeout=30) logger.debug( f"Image download response status: {response.status_code}, " f"Content-Type: {response.headers.get('Content-Type')}, " f"Content-Length: {response.headers.get('Content-Length')}" ) response.raise_for_status() image = Image.open(BytesIO(response.content)) logger.info( f"Image opened successfully. Format: {image.format}, " f"Size: {image.size[0]}x{image.size[1]}, Mode: {image.mode}" ) original_metadata = {} for key, value in image.info.items(): if isinstance(key, str) and isinstance(value, str): original_metadata[key] = value logger.debug(f"Original image metadata: {original_metadata}") if image.format != 'PNG': logger.info(f"Converting image from {image.format} to PNG format") png_buffer = BytesIO() if 'A' in image.getbands(): image_to_save = image else: image_to_save = image.convert('RGB') image_to_save.save(png_buffer, format='PNG') png_buffer.seek(0) image = Image.open(png_buffer) for key, value in original_metadata.items(): image.info[key] = value logger.info(f"Successfully downloaded and processed image: {image.size[0]}x{image.size[1]}") return image except requests.exceptions.Timeout: retry_count += 1 logger.warning(f"Download timed out. 
Retrying ({retry_count}/{MAX_RETRY_COUNT})...") time.sleep(1) except requests.exceptions.HTTPError as e: status_code = e.response.status_code logger.error(f"HTTP error {status_code} when downloading image from {image_url}") if 400 <= status_code < 500: raise APIError(f"HTTP error {status_code} when downloading image") else: retry_count += 1 time.sleep(1) except requests.exceptions.RequestException as e: retry_count += 1 logger.warning(f"Network error during image download: {str(e)}. Retrying ({retry_count}/{MAX_RETRY_COUNT})...") time.sleep(1) except Exception as e: logger.error(f"Error processing image from {image_url}: {str(e)}") logger.error(f"Full traceback: {traceback.format_exc()}") raise APIError(f"Failed to process image: {str(e)}") logger.error(f"Failed to download image from {image_url} after {MAX_RETRY_COUNT} retries") raise APIError(f"Failed to download image after {MAX_RETRY_COUNT} retries") APPLE_CSS = """ /* ---- Apple-inspired minimalist UI ---- */ .gradio-container { font-family: -apple-system, BlinkMacSystemFont, "SF Pro Display", "SF Pro Text", "Helvetica Neue", "Segoe UI", Inter, sans-serif !important; background: linear-gradient(180deg, #fbfbfd 0%, #f5f5f7 100%) !important; /* Always use ~3/4 of the viewport, capped at 1600px on huge screens. Using width AND max-width ensures the page is wide from first paint instead of growing only after content loads. 
*/ width: min(1600px, 92vw) !important; max-width: 1600px !important; margin: 0 auto !important; -webkit-font-smoothing: antialiased; -moz-osx-font-smoothing: grayscale; color: #1d1d1f !important; } /* Cards / panels */ .panel-card { background: #ffffff !important; border-radius: 18px !important; padding: 18px !important; box-shadow: 0 1px 2px rgba(0,0,0,0.04), 0 8px 28px rgba(0,0,0,0.05) !important; border: 1px solid rgba(0,0,0,0.05) !important; } /* Inputs - rounded with apple-blue focus ring */ textarea, input[type="text"], input[type="number"], .gradio-container .form input, .gradio-container .form textarea { border-radius: 12px !important; border: 1px solid #d2d2d7 !important; background: #ffffff !important; transition: border-color 0.15s ease, box-shadow 0.15s ease !important; font-size: 15px !important; } textarea:focus, input:focus, .gradio-container .form input:focus, .gradio-container .form textarea:focus { border-color: #0071e3 !important; box-shadow: 0 0 0 4px rgba(0, 113, 227, 0.15) !important; outline: none !important; } /* Dropdown */ .gradio-container .wrap.svelte-1ipelgc { border-radius: 12px !important; } /* Block labels */ .gradio-container span[data-testid="block-label"], .gradio-container .block-label, .gradio-container label > span { color: #6e6e73 !important; font-weight: 500 !important; font-size: 13px !important; letter-spacing: 0.01em; } /* Buttons - pill shape, Apple blue. Scoped to the variant classes so it does NOT bleed into internal buttons inside dropdowns, accordions, etc. 
*/ .gradio-container button.primary, .gradio-container button.secondary { border-radius: 980px !important; font-weight: 500 !important; font-size: 15px !important; padding: 12px 22px !important; transition: transform 0.08s ease, box-shadow 0.18s ease, background 0.18s ease, opacity 0.15s ease !important; border: none !important; letter-spacing: 0.01em; } .gradio-container button.primary { background: #0071e3 !important; color: #ffffff !important; box-shadow: 0 1px 2px rgba(0,113,227,0.25), 0 6px 16px rgba(0,113,227,0.22) !important; } .gradio-container button.primary:hover { background: #0077ed !important; transform: translateY(-1px); box-shadow: 0 2px 4px rgba(0,113,227,0.3), 0 10px 22px rgba(0,113,227,0.28) !important; } .gradio-container button.primary:active { transform: translateY(0); } .gradio-container button.secondary { background: rgba(0,0,0,0.05) !important; color: #1d1d1f !important; box-shadow: none !important; } .gradio-container button.secondary:hover { background: rgba(0,0,0,0.09) !important; } /* Make sure the dropdown's selected-value text never gets pill-clipped and is properly aligned inside its rounded box. 
*/ .gradio-container .wrap-inner, .gradio-container .single-select, .gradio-container .secondary-wrap { border-radius: 12px !important; } .gradio-container .single-select input, .gradio-container input[role="listbox"] { border-radius: 12px !important; padding: 10px 14px !important; font-size: 15px !important; } /* Status pill */ #status-bar { padding: 0 !important; margin-top: 6px; } .status-pill { display: inline-flex; align-items: center; gap: 9px; background: #f5f5f7; color: #1d1d1f; padding: 10px 14px; border-radius: 12px; font-size: 13px; font-weight: 500; line-height: 1; border: 1px solid rgba(0,0,0,0.04); } .status-dot { width: 8px; height: 8px; border-radius: 50%; background: #8e8e93; flex-shrink: 0; } .status-info .status-dot { background: #8e8e93; } .status-success { background: rgba(48,209,88,0.10); color: #0a7f2e; border-color: rgba(48,209,88,0.20); } .status-success .status-dot { background: #30d158; } .status-error { background: rgba(255,59,48,0.10); color: #b8261b; border-color: rgba(255,59,48,0.20); } .status-error .status-dot { background: #ff3b30; } .status-running { background: rgba(0,113,227,0.10); color: #0058b8; border-color: rgba(0,113,227,0.20); } .status-running .status-dot { background: #0071e3; animation: pulse 1.4s ease-in-out infinite; } @keyframes pulse { 0%, 100% { opacity: 0.4; transform: scale(0.85); } 50% { opacity: 1.0; transform: scale(1.15); } } /* Image output frame — never crop the image; show full picture with letterbox. 
*/ .image-output { border-radius: 18px !important; background: #f5f5f7 !important; } .image-output, .image-output > div, .image-output [data-testid="image"], .image-output .image-container, .image-output .image-frame, .image-output .preview { min-height: 440px !important; display: flex !important; align-items: center !important; justify-content: center !important; } .image-output img { border-radius: 14px !important; object-fit: contain !important; max-width: 100% !important; max-height: 62vh !important; width: auto !important; height: auto !important; } /* Status pill placed inside the right column, above the image. */ .right-status { display: flex; justify-content: flex-start; margin-bottom: 6px; } /* Accordion */ .gradio-container details { border-radius: 14px !important; border: 1px solid rgba(0,0,0,0.06) !important; background: #ffffff !important; } /* Negative prompt — softer accent so it visually de-emphasises vs. the main prompt */ .negative-prompt textarea { background: #fbfbfd !important; border-color: #e3e3e8 !important; } .negative-prompt textarea:focus { background: #ffffff !important; } /* Advanced options row — keeps the refine switch + seed input visually paired */ .advanced-row { gap: 14px !important; margin-top: 2px; } /* Refine toggle — iOS-style settings card. Goal: title + helper text on the left, a polished pill switch on the right, everything contained inside a single soft card so the helper text no longer floats orphaned above the box. 
*/ .refine-toggle { background: linear-gradient(180deg, #ffffff 0%, #f5f5f7 100%) !important; border-radius: 14px !important; border: 1px solid rgba(0,0,0,0.06) !important; padding: 14px 16px !important; box-shadow: 0 1px 2px rgba(0,0,0,0.03) !important; transition: border-color 0.18s ease, box-shadow 0.18s ease !important; min-height: 88px; display: flex !important; flex-direction: column !important; justify-content: center !important; } .refine-toggle:hover { border-color: rgba(0,0,0,0.10) !important; box-shadow: 0 1px 2px rgba(0,0,0,0.04), 0 4px 14px rgba(0,0,0,0.05) !important; } /* Strip default backgrounds from gradio's inner wrappers so only our card shows. */ .refine-toggle .form, .refine-toggle .wrap, .refine-toggle .form-wrap, .refine-toggle > div { background: transparent !important; border: none !important; padding: 0 !important; margin: 0 !important; box-shadow: none !important; } /* Helper / "info" text becomes a proper subtitle UNDER the toggle row. */ .refine-toggle [data-testid="block-info"], .refine-toggle .info { color: #6e6e73 !important; font-size: 12px !important; line-height: 1.4 !important; margin: 8px 0 0 0 !important; padding: 0 !important; text-align: left !important; order: 2 !important; } /* Force the gradio wrapper to stack: label-row first, info below. */ .refine-toggle .form, .refine-toggle > div:not([data-testid="block-info"]):not(.info) { display: flex !important; flex-direction: column !important; align-items: stretch !important; } /* Label row: title on the LEFT (full width), toggle pinned to the RIGHT. 
*/ .refine-toggle label { display: flex !important; align-items: center !important; justify-content: space-between !important; flex-direction: row-reverse !important; cursor: pointer !important; margin: 0 !important; padding: 0 !important; gap: 14px !important; width: 100% !important; order: 1 !important; } .refine-toggle label > span { color: #1d1d1f !important; font-size: 15px !important; font-weight: 600 !important; letter-spacing: -0.01em; flex: 1 1 auto !important; text-align: left !important; line-height: 1.3 !important; } /* Pill switch — bigger, smoother, more "Apple-like" */ .refine-toggle input[type="checkbox"] { appearance: none; -webkit-appearance: none; width: 46px !important; height: 28px !important; border-radius: 999px !important; background: #e5e5ea !important; position: relative; cursor: pointer; transition: background 0.22s ease, box-shadow 0.22s ease; border: none !important; flex-shrink: 0 !important; margin: 0 !important; box-shadow: inset 0 0 1px rgba(0,0,0,0.06); } .refine-toggle input[type="checkbox"]::after { content: ""; position: absolute; top: 2px; left: 2px; width: 24px; height: 24px; border-radius: 50%; background: #ffffff; box-shadow: 0 2px 5px rgba(0,0,0,0.18), 0 0 1px rgba(0,0,0,0.05); transition: transform 0.24s cubic-bezier(0.4, 0.0, 0.2, 1); } .refine-toggle input[type="checkbox"]:hover { background: #dcdce0 !important; } .refine-toggle input[type="checkbox"]:checked { background: #34c759 !important; } .refine-toggle input[type="checkbox"]:checked:hover { background: #30b352 !important; } .refine-toggle input[type="checkbox"]:checked::after { transform: translateX(18px); } .refine-toggle input[type="checkbox"]:focus-visible { box-shadow: 0 0 0 4px rgba(0,113,227,0.20) !important; } .refine-toggle input[type="checkbox"]:active::after { /* tiny squish on press, very iOS */ width: 28px; } .refine-toggle input[type="checkbox"]:checked:active::after { transform: translateX(14px); } /* Seed number input — match the prompt/dropdown 
rounding */ .seed-input input[type="number"] { border-radius: 12px !important; padding: 10px 14px !important; font-variant-numeric: tabular-nums; } /* Hide the native spinner buttons on number inputs for a cleaner look */ .seed-input input[type="number"]::-webkit-outer-spin-button, .seed-input input[type="number"]::-webkit-inner-spin-button { -webkit-appearance: none; margin: 0; } .seed-input input[type="number"] { -moz-appearance: textfield; } /* Keep the seed column visually aligned with the refine card next to it */ .seed-input { align-self: stretch !important; } /* Guidance scale slider — Apple-blue track + softer thumb */ .guidance-slider input[type="range"] { accent-color: #0071e3 !important; } .guidance-slider .head { padding-top: 0 !important; } .guidance-slider { margin-top: 4px; } /* Footer tagline */ .tagline { text-align: center; color: #6e6e73; font-size: 12px; margin: 18px 0 14px 0; font-weight: 400; } .tagline a { color: #0071e3; text-decoration: none; font-weight: 500; transition: opacity 0.15s ease; } .tagline a:hover { opacity: 0.7; } /* Footer links row (HuggingFace / GitHub / Twitter) */ .footer-links { display: flex; justify-content: center; align-items: center; gap: 26px; flex-wrap: wrap; font-size: 13px; margin: 24px 0 6px 0; } /* When the vivago tagline follows the links row, tighten the gap. 
*/ .footer-links + .tagline { margin-top: 4px; } /* Hide gradio's default footer for a cleaner look */ footer { display: none !important; } /* Mobile */ @media (max-width: 640px) { .footer-links { gap: 18px; font-size: 12px; } } """ APPLE_THEME = gr.themes.Soft( primary_hue=gr.themes.colors.blue, neutral_hue=gr.themes.colors.slate, radius_size=gr.themes.sizes.radius_lg, text_size=gr.themes.sizes.text_md, font=[ gr.themes.GoogleFont("Inter"), "ui-sans-serif", "-apple-system", "BlinkMacSystemFont", "Segoe UI", "Helvetica Neue", "sans-serif", ], ).set( body_background_fill="*neutral_50", block_background_fill="white", block_border_width="1px", block_label_text_weight="500", block_title_text_weight="600", button_primary_background_fill="#0071e3", button_primary_background_fill_hover="#0077ed", button_primary_text_color="white", button_primary_border_color="#0071e3", button_secondary_background_fill="rgba(0,0,0,0.05)", button_secondary_background_fill_hover="rgba(0,0,0,0.09)", button_secondary_text_color="#1d1d1f", input_background_fill="white", input_border_color="#d2d2d7", input_border_color_focus="#0071e3", input_shadow_focus="0 0 0 4px rgba(0,113,227,0.15)", ) def _status_html(text: str, kind: str = "info") -> str: """Render a styled status pill. kind: info | success | error | running""" return ( f'