| import os |
| import sys |
| import spaces |
| import gradio as gr |
| import numpy as np |
| import torch |
| import random |
| import time |
| from PIL import Image |
| from huggingface_hub import hf_hub_download |
import subprocess

# Install flash-attn at startup (Hugging Face Spaces pattern: the wheel must
# be built against the runtime environment, so it cannot go through normal
# requirements.txt build isolation). List-form argv avoids shell string
# parsing, and sys.executable guarantees the package lands in the interpreter
# actually running this app. check=False preserves the original best-effort
# behavior: a failed install is tolerated rather than crashing the app.
subprocess.run(
    [sys.executable, "-m", "pip", "install",
     "flash-attn==2.7.3", "--no-build-isolation"],
    check=False,
)
|
|
| from star.models.config import load_config_from_json, STARMultiModalConfig |
| from star.models.model import STARMultiModal |
|
|
|
|
# UI string table keyed by language code ("zh" / "en"); looked up through
# get_text() using the module-level current_language. Both sub-dicts must
# expose the same set of keys, since the UI requests identical keys in
# either language (get_text falls back to returning the key itself).
TEXTS = {
    "zh": {
        "title": "🌟 STAR 多模态演示",
        "description": "基于STAR模型的多模态AI演示系统,支持文本生成图像、图像编辑和图像理解功能。",
        "please_load_model": "请先加载模型!",
        "please_upload_image": "请上传图像!",
        "generation_failed": "生成失败!",
        "generation_success_diffusion": "生成成功!",
        "generation_success_vq": "生成成功!",
        "edit_failed": "编辑失败!",
        "edit_success_diffusion": "编辑成功!",
        "edit_success_vq": "编辑成功!",
        "understanding_failed": "理解失败!",
        "generation_error": "生成过程中出错: ",
        "edit_error": "编辑过程中出错: ",
        "understanding_error": "理解过程中出错: ",
        "tab_text_to_image": "🖼️ 文本生成图像",
        "tab_image_edit": "🖌️ 图像编辑",
        "tab_image_understanding": "📝 图像理解",
        "text_prompt": "文本提示",
        "text_prompt_placeholder": "A whimsical scene featuring a small elf with pointed ears and a green hat, sipping orange juice through a long straw from a disproportionately large orange. Next to the elf, a curious squirrel perches on its hind legs, while an owl with wide, observant eyes watches intently from a branch overhead. The orange's vibrant color contrasts with the muted browns and greens of the surrounding forest foliage.",
        "advanced_params": "高级参数",
        "cfg_scale": "CFG Scale",
        "cfg_scale_info": "控制生成图像与文本的匹配程度",
        "top_k": "Top-K",
        "top_k_info": "采样时考虑的token数量",
        "top_p": "Top-P",
        "top_p_info": "核采样参数",
        "generate_image": "🎨 生成图像",
        "generated_image": "生成的图像",
        "generation_status": "生成状态",
        "input_image": "输入图像",
        "edit_instruction": "编辑指令",
        "edit_instruction_placeholder": "Remove the tiger in the water.",
        "edit_image": "✏️ 编辑图像",
        "edited_image": "编辑后的图像",
        "edit_status": "编辑状态",
        "question": "问题",
        "question_placeholder": "Please describe the content of this image",
        "max_generation_length": "最大生成长度",
        "understand_image": "🔍 理解图像",
        "understanding_result": "理解结果",
        "usage_instructions": "使用说明",
        "usage_step1": "1. **文本生成图像**: 输入文本描述,调整参数后点击生成",
        "usage_step2": "2. **图像编辑**: 上传图像并输入编辑指令",
        "usage_step3": "3. **图像理解**: 上传图像并提出问题",
        "language": "语言 / Language"
    },
    "en": {
        "title": "🌟 STAR Multi-Modal Demo",
        "description": "A multi-modal AI demonstration system based on STAR model, supporting text-to-image generation, image editing, and image understanding.",
        "please_load_model": "Please load the model first!",
        "please_upload_image": "Please upload an image!",
        "generation_failed": "Generation failed!",
        "generation_success_diffusion": "Generation successful! ",
        "generation_success_vq": "Generation successful! Using VQ decoder",
        "edit_failed": "Editing failed!",
        "edit_success_diffusion": "Editing successful! ",
        "edit_success_vq": "Editing successful! Using VQ decoder",
        "understanding_failed": "Understanding failed!",
        "generation_error": "Error during generation: ",
        "edit_error": "Error during editing: ",
        "understanding_error": "Error during understanding: ",
        "tab_text_to_image": "🖼️ Text to Image",
        "tab_image_edit": "🖌️ Image Editing",
        "tab_image_understanding": "📝 Image Understanding",
        "text_prompt": "Text Prompt",
        "text_prompt_placeholder": "A whimsical scene featuring a small elf with pointed ears and a green hat, sipping orange juice through a long straw from a disproportionately large orange. Next to the elf, a curious squirrel perches on its hind legs, while an owl with wide, observant eyes watches intently from a branch overhead. The orange's vibrant color contrasts with the muted browns and greens of the surrounding forest foliage.",
        "advanced_params": "Advanced Parameters",
        "cfg_scale": "CFG Scale",
        "cfg_scale_info": "Controls how closely the generated image matches the text",
        "top_k": "Top-K",
        "top_k_info": "Number of tokens to consider during sampling",
        "top_p": "Top-P",
        "top_p_info": "Nucleus sampling parameter",
        "generate_image": "🎨 Generate Image",
        "generated_image": "Generated Image",
        "generation_status": "Generation Status",
        "input_image": "Input Image",
        "edit_instruction": "Edit Instruction",
        "edit_instruction_placeholder": "Remove the tiger in the water.",
        "edit_image": "✏️ Edit Image",
        "edited_image": "Edited Image",
        "edit_status": "Edit Status",
        "question": "Question",
        "question_placeholder": "Please describe the content of this image",
        "max_generation_length": "Max Generation Length",
        "understand_image": "🔍 Understand Image",
        "understanding_result": "Understanding Result",
        "usage_instructions": "Usage Instructions",
        "usage_step1": "1. **Text to Image**: Enter text description, adjust parameters and click generate",
        "usage_step2": "2. **Image Editing**: Upload an image and enter editing instructions",
        "usage_step3": "3. **Image Understanding**: Upload an image and ask questions",
        "language": "语言 / Language"
    }
}
|
|
class MockArgs:
    """Stand-in for the training-script argparse namespace that
    STARMultiModal expects at construction time, holding inference defaults.
    """

    def __init__(self):
        self.data_type = "generation"
        self.diffusion_as_decoder = True  # decode via diffusion model rather than VQ only
        self.ori_inp_dit = "seq"
        self.grad_ckpt = False            # no gradient checkpointing needed at inference
        self.diffusion_resolution = 1024
        self.max_diff_seq_length = 256
        self.max_seq_length = 8192
        self.max_text_tokens = 512
        # presumably vision-tower pixel budgets (patch_size^2 * token count) — TODO confirm
        self.max_pixels = 28 * 28 * 576
        self.min_pixels = 28 * 28 * 16
        self.vq_image_size = 384          # matches image_size in _process_output_images
        self.vq_tokens = 576              # matches max_new_tokens used for image generation
|
|
|
|
def set_seed(seed=100):
    """Seed every RNG in use (python, numpy, torch, CUDA) for reproducibility.

    A non-positive seed leaves all generators untouched. The seed is returned
    unchanged so callers can log it.
    """
    if seed <= 0:
        return seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Trade cuDNN autotuning for deterministic kernels.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    return seed
|
|
|
|
def print_with_time(msg):
    """Print *msg* prefixed with a local-time "YYYY-MM-DD HH:MM:SS" stamp."""
    stamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    print(f"{stamp}: {msg}")
|
|
|
class STARInferencer:
    """Inference wrapper around STARMultiModal for the Gradio demo.

    The model is constructed on CPU; the ``@spaces.GPU``-decorated methods
    move it to CUDA in bfloat16 lazily on their first call, which is the
    pattern Hugging Face ZeroGPU Spaces require.
    """

    def __init__(self, model_config_path, checkpoint_path, vq_checkpoint, device="cpu"):
        """Build the model and load its weights.

        Args:
            model_config_path: path to the STAR JSON model configuration.
            checkpoint_path: path to the main model checkpoint (.pt).
            vq_checkpoint: path to the VQ pixel-encoder checkpoint.
            device: initial device for the weights ("cpu" for ZeroGPU).

        Raises:
            Exception: whatever _load_model raises on config/checkpoint errors.
        """
        self.device = device
        self.model_config_path = model_config_path
        self.checkpoint_path = checkpoint_path
        # Fixed typo: attribute was previously "vq_checkpint_path". It is only
        # read inside this class, so the rename is safe.
        self.vq_checkpoint_path = vq_checkpoint
        self.model = None
        self._load_model()

    def _create_mock_args(self):
        # The model constructor expects an argparse-style namespace.
        return MockArgs()

    def _ensure_on_gpu(self):
        """Move the model to CUDA in bfloat16 on first GPU call (ZeroGPU)."""
        if self.model.device.type == 'cpu':
            print_with_time("Moving model to GPU...")
            self.model.to('cuda')
            self.model.to(torch.bfloat16)
            print_with_time("Model moved to GPU")

    def _load_model(self):
        """Instantiate STARMultiModal, load checkpoint weights, set eval mode."""
        try:
            print_with_time("Loading model configuration...")
            config_data = load_config_from_json(self.model_config_path)
            model_config = STARMultiModalConfig(**config_data)

            # Point the sub-model paths at the hub repos / downloaded files.
            model_config.language_model.model_path = "Qwen/Qwen2.5-VL-7B-Instruct"
            model_config.pixel_encoder.model_path = self.vq_checkpoint_path
            model_config.pixel_decoder.model_path = "Alpha-VLLM/Lumina-Image-2.0"

            args = self._create_mock_args()

            print_with_time("Initializing model...")
            self.model = STARMultiModal(model_config, args)

            if os.path.exists(self.checkpoint_path):
                print_with_time(f"Loading checkpoint from {self.checkpoint_path}")
                with torch.no_grad():
                    # weights_only=False because the checkpoint may carry
                    # pickled non-tensor objects — only load trusted files.
                    checkpoint = torch.load(self.checkpoint_path, map_location='cpu', weights_only=False)
                    if 'state_dict' in checkpoint:
                        state_dict = checkpoint['state_dict']
                    else:
                        state_dict = checkpoint

                    if not isinstance(state_dict, dict):
                        raise ValueError("Invalid checkpoint format")

                    print_with_time(f"Checkpoint contains {len(state_dict)} parameters")
                    # strict=False tolerates missing/unexpected keys, e.g. the
                    # diffusion decoder that is loaded from the hub separately.
                    self.model.load_state_dict(state_dict, strict=False)

            print_with_time(f"Moving model to device: {self.device}")
            self.model.to(self.device)

            print_with_time("Setting model to eval mode...")
            self.model.eval()

            if torch.cuda.is_available():
                print_with_time(f"GPU memory after model loading: {torch.cuda.memory_allocated()/1024**3:.2f}GB")

            print_with_time("Model loaded successfully!")

        except Exception as e:
            print_with_time(f"Error loading model: {str(e)}")
            import traceback
            traceback.print_exc()
            raise  # bare raise preserves the original traceback

    @spaces.GPU(duration=210)
    def generate_image(self, prompt, num_images=1, cfg=20.0, topk=2000, topp=1.0, seed=0):
        """Text-to-image generation.

        Args:
            prompt: non-empty text description.
            num_images: number of images to sample.
            cfg: classifier-free-guidance weight, clamped to [1.0, 20.0].
            topk: top-k sampling cutoff, clamped to [100, 2000].
            topp: nucleus-sampling cutoff, clamped to [0.1, 1.0].
            seed: RNG seed; <= 0 leaves RNGs unseeded.

        Returns:
            dict with "vq_images"/"diff_images" (see _process_output_images),
            or None for invalid input.

        Raises:
            Exception: re-raises model errors after cleanup/logging.
        """
        self._ensure_on_gpu()
        set_seed(seed)

        print_with_time(f"Generating image for prompt: {prompt}")

        # Clamp sampling parameters into safe ranges; with this argument
        # order min/max also coerce NaN to a bound, so the previous separate
        # finiteness/range re-checks were dead code and are removed.
        cfg = max(1.0, min(20.0, float(cfg)))
        topk = max(100, min(2000, int(topk)))
        topp = max(0.1, min(1.0, float(topp)))

        print_with_time(f"Using validated params: cfg={cfg}, topk={topk}, topp={topp}")

        try:
            with torch.no_grad():
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    print_with_time(f"GPU memory before generation: {torch.cuda.memory_allocated()/1024**3:.2f}GB")

                if not isinstance(prompt, str) or len(prompt.strip()) == 0:
                    print_with_time("Warning: Invalid prompt")
                    return None

                print_with_time("Calling model.generate_images...")

                # 576 tokens == one 384x384 VQ image (matches MockArgs.vq_tokens).
                safe_max_tokens = 576

                output = self.model.generate_images(
                    prompt,
                    max_new_tokens=safe_max_tokens,
                    num_return_sequences=num_images,
                    cfg_weight=cfg,
                    topk_sample=topk,
                    topp_sample=topp,
                    reasoning=False,
                    return_dict=True
                )
                print_with_time("Model generation completed")

                if output is None:
                    print_with_time("Warning: Model returned None output")
                    return None

                print_with_time("Processing output images...")
                result = self._process_output_images(output, num_images)
                print_with_time("Image processing completed")
                return result
        except Exception as e:
            print_with_time(f"Error during image generation: {str(e)}")
            import traceback
            traceback.print_exc()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            raise

    @spaces.GPU(duration=210)
    def edit_image(self, image, instruction, num_images=1, cfg=20.0, topk=2000, topp=1.0, seed=0):
        """Instruction-based image editing; same sampling knobs as generate_image.

        Args:
            image: PIL image or numpy array (Gradio delivers numpy arrays).
            instruction: natural-language edit instruction.

        Returns:
            dict from _process_output_images, or None if the model returns None.
        """
        self._ensure_on_gpu()
        set_seed(seed)

        # Gradio hands images over as numpy arrays; the model wants PIL.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        print_with_time(f"Editing image with instruction: {instruction}")

        with torch.no_grad():
            output = self.model.generate_images_edit(
                [image],
                instruction,
                max_new_tokens=576,
                num_return_sequences=num_images,
                cfg_weight=cfg,
                topk_sample=topk,
                topp_sample=topp,
                return_dict=True
            )

        if output is None:
            return None

        return self._process_output_images(output, num_images)

    @spaces.GPU(duration=180)
    def understand_image(self, image, question, max_new_tokens=256):
        """Visual question answering: return the model's text answer for *question*."""
        self._ensure_on_gpu()

        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        print_with_time(f"Understanding image with question: {question}")

        with torch.no_grad():
            answer = self.model.inference_understand(
                image=image,
                question=question,
                max_new_tokens=max_new_tokens
            )

        return answer

    def _process_output_images(self, output, num_images):
        """Convert raw model output into PIL images.

        Accepts either a dict with "output_images" (VQ decoder output, values
        assumed in [-1, 1] given the mapping below) plus optional
        "diff_images" (passed through untouched), or a bare tensor/array.

        Returns:
            {"vq_images": list[PIL.Image] | None, "diff_images": Any | None};
            both None if processing fails.
        """
        image_size = 384  # VQ decoder output resolution

        try:
            if isinstance(output, dict):
                output_images = output.get("output_images")
                diff_images = output.get("diff_images")

                results = {}

                if output_images is not None:
                    if isinstance(output_images, torch.Tensor):
                        output_images = output_images.detach().cpu().numpy()

                    if output_images.size == 0:
                        print_with_time("Warning: Empty output_images array")
                        results["vq_images"] = None
                    else:
                        # Sanitize then map [-1, 1] -> [0, 255].
                        output_images = np.nan_to_num(output_images, nan=0.0, posinf=1.0, neginf=-1.0)
                        dec_vq = np.clip((output_images + 1) / 2 * 255, 0, 255)

                        if len(dec_vq.shape) == 3:
                            dec_vq = dec_vq.reshape(num_images, image_size, image_size, 3)

                        visual_img_vq = np.zeros((num_images, image_size, image_size, 3), dtype=np.uint8)
                        visual_img_vq[:, :, :] = dec_vq
                        imgs_vq = [Image.fromarray(visual_img_vq[j].astype(np.uint8)) for j in range(visual_img_vq.shape[0])]
                        results["vq_images"] = imgs_vq
                else:
                    # Previously the key was simply absent; set it explicitly
                    # so callers using result["vq_images"] cannot KeyError.
                    results["vq_images"] = None

                if diff_images is not None:
                    results["diff_images"] = diff_images
                else:
                    results["diff_images"] = None

                return results
            else:
                if isinstance(output, torch.Tensor):
                    output = output.detach().cpu().numpy()

                output = np.nan_to_num(output, nan=0.0, posinf=1.0, neginf=-1.0)
                dec = np.clip((output + 1) / 2 * 255, 0, 255)

                if len(dec.shape) == 3:
                    dec = dec.reshape(num_images, image_size, image_size, 3)

                visual_img = np.zeros((num_images, image_size, image_size, 3), dtype=np.uint8)
                visual_img[:, :, :] = dec
                imgs = [Image.fromarray(visual_img[j].astype(np.uint8)) for j in range(visual_img.shape[0])]
                return {"vq_images": imgs, "diff_images": None}

        except Exception as e:
            print_with_time(f"Error in _process_output_images: {str(e)}")
            return {"vq_images": None, "diff_images": None}
|
|
|
|
# Global model handle; populated once by initialize_model_on_startup() and
# read by the Gradio handler functions below.
inferencer = None
|
|
|
|
|
|
def save_language_setting(language):
    """Best-effort persistence of the UI language choice to a dotfile.

    Filesystem failures (read-only FS, permissions) are deliberately
    swallowed: losing the preference must never break the app. The previous
    bare ``except:`` also hid programming errors; only OS-level errors are
    ignored now, and the file handle is closed via a context manager.
    """
    try:
        with open('.language_setting', 'w', encoding='utf-8') as f:
            f.write(language)
    except OSError:
        pass
|
|
def update_interface_language(language):
    """Switch the active UI language and return refreshed component values.

    Return order matches the language_dropdown.change outputs list:
    [language_state, title markdown, description, prompt default,
    edit-instruction default, question default, usage markdown, status text].
    """
    global current_language
    current_language = language

    # Persist the choice (best-effort) so it can survive restarts.
    save_language_setting(language)

    usage_markdown = f"""
---
### {get_text("usage_instructions")}
{get_text("usage_step1")}
{get_text("usage_step2")}
{get_text("usage_step3")}
"""
    status = f"✅ Language switched to {language.upper()} successfully! / 语言已成功切换为{language.upper()}!"

    return [
        language,
        f"# {get_text('title')}",
        get_text("description"),
        get_text("text_prompt_placeholder"),
        get_text("edit_instruction_placeholder"),
        get_text("question_placeholder"),
        usage_markdown,
        status,
    ]
|
|
# Active UI language code ("en" or "zh"); mutated by the language handlers.
current_language = "en"
|
|
def get_text(key):
    """Look up *key* in the active language table; fall back to the key itself."""
    table = TEXTS[current_language]
    return table.get(key, key)
|
|
|
|
def auto_detect_device():
    """Return "cuda:<index>" when a GPU is visible, else "cpu", with logging."""
    if not torch.cuda.is_available():
        print_with_time("No CUDA device detected, using CPU")
        return "cpu"

    device = f"cuda:{torch.cuda.current_device()}"
    print_with_time(f"Detected CUDA device: {device}")
    print_with_time(f"GPU name: {torch.cuda.get_device_name()}")
    print_with_time(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB")
    return device
|
|
|
|
def initialize_model_on_startup():
    """Download the checkpoints from the Hub and build the global inferencer.

    The model is kept on CPU here; GPU moves happen lazily inside the
    @spaces.GPU-decorated methods.

    Returns:
        (success: bool, status message: str)
    """
    global inferencer

    # hf_hub_download returns the local cache path (downloading on first use).
    default_checkpoint = hf_hub_download(
        repo_id="MM-MVR/STAR-7B",
        filename="STAR-7B.pt"
    )
    default_config = "star/configs/STAR_Qwen2.5-VL-7B.json"
    vq_checkpoint = hf_hub_download(
        repo_id="MM-MVR/STAR-VQ",
        filename="VQ-Model.pt"
    )

    # Verify both required local files before touching the model code.
    required = (
        ("Model config file not found", default_config),
        ("Model checkpoint file not found", default_checkpoint),
    )
    for prefix, path in required:
        if not os.path.exists(path):
            print_with_time(f"⚠️ {prefix}: {path}")
            return False, f"{prefix}: {path}"

    try:
        print_with_time("Starting to load STAR model...")
        inferencer = STARInferencer(default_config, default_checkpoint, vq_checkpoint, 'cpu')
        print_with_time("✅ STAR model loaded successfully!")
        return True, "✅ STAR model loaded successfully!"
    except Exception as e:
        error_msg = f"❌ Model loading failed: {str(e)}"
        print_with_time(error_msg)
        return False, error_msg
|
|
|
|
|
|
|
|
def text_to_image(prompt, cfg_scale=1.0, topk=1000, topp=0.8):
    """Gradio handler: generate one image from *prompt*.

    Returns:
        (PIL image or None, status string)
    """
    if inferencer is None:
        return None, get_text("please_load_model")

    # Clamp UI-provided sampling parameters into safe ranges.
    cfg_scale = min(20.0, max(1.0, cfg_scale))
    topk = min(2000, max(100, int(topk)))
    topp = min(1.0, max(0.1, topp))
    seed = 100  # fixed seed so demo output is reproducible

    try:
        print_with_time(f"Starting generation with params: cfg={cfg_scale}, topk={topk}, topp={topp}, seed={seed}")
        result = inferencer.generate_image(prompt, cfg=cfg_scale, topk=topk, topp=topp, seed=seed)

        if result is None:
            return None, get_text("generation_failed")

        # Prefer the diffusion-decoded image; fall back to the VQ decode.
        diff = result.get("diff_images")
        if diff and len(diff) > 0:
            return diff[0], get_text("generation_success_diffusion")
        vq = result.get("vq_images")
        if vq and len(vq) > 0:
            return vq[0], get_text("generation_success_vq")
        return None, get_text("generation_failed")

    except Exception as e:
        return None, get_text("generation_error") + str(e)
|
|
|
|
def image_editing(image, instruction, cfg_scale=1.0, topk=1000, topp=0.8):
    """Gradio handler: apply *instruction* to *image*.

    Returns:
        (PIL image or None, status string)
    """
    if inferencer is None:
        return None, get_text("please_load_model")
    if image is None:
        return None, get_text("please_upload_image")

    # Clamp UI-provided sampling parameters into safe ranges.
    cfg_scale = min(20.0, max(1.0, cfg_scale))
    topk = min(2000, max(100, int(topk)))
    topp = min(1.0, max(0.1, topp))
    seed = 100  # fixed seed so demo output is reproducible

    try:
        print_with_time(f"Starting image editing with params: cfg={cfg_scale}, topk={topk}, topp={topp}, seed={seed}")
        result = inferencer.edit_image(image, instruction, cfg=cfg_scale, topk=topk, topp=topp, seed=seed)

        if result is None:
            return None, get_text("edit_failed")

        # Prefer the diffusion-decoded image; fall back to the VQ decode.
        diff = result.get("diff_images")
        if diff and len(diff) > 0:
            return diff[0], get_text("edit_success_diffusion")
        vq = result.get("vq_images")
        if vq and len(vq) > 0:
            return vq[0], get_text("edit_success_vq")
        return None, get_text("edit_failed")

    except Exception as e:
        return None, get_text("edit_error") + str(e)
|
|
|
|
def image_understanding(image, question, max_new_tokens=256):
    """Gradio handler: answer *question* about *image*; returns a string."""
    if inferencer is None:
        return get_text("please_load_model")
    if image is None:
        return get_text("please_upload_image")

    try:
        answer = inferencer.understand_image(image, question, max_new_tokens)
        # Empty/None answers are reported as a failure message instead.
        return answer or get_text("understanding_failed")
    except Exception as e:
        return get_text("understanding_error") + str(e)
|
|
|
|
def change_language(language):
    """Set the global UI language and return refreshed text for every static
    UI element, in a fixed 34-element order.

    NOTE(review): nothing in this file wires this function to an event —
    language_dropdown.change uses update_interface_language instead. It may
    still have external callers, so it is left untouched.
    NOTE(review): "random_seed" and "random_seed_info" are not keys in TEXTS,
    so get_text falls back to returning the key strings themselves — confirm
    whether these entries should exist or be dropped.
    """
    global current_language
    current_language = language

    return (
        get_text("title"),
        get_text("description"),
        get_text("tab_text_to_image"),
        get_text("text_prompt"),
        get_text("text_prompt_placeholder"),
        get_text("advanced_params"),
        get_text("cfg_scale"),
        get_text("cfg_scale_info"),
        get_text("top_k"),
        get_text("top_k_info"),
        get_text("top_p"),
        get_text("top_p_info"),
        get_text("random_seed"),
        get_text("random_seed_info"),
        get_text("generate_image"),
        get_text("generated_image"),
        get_text("generation_status"),
        get_text("tab_image_edit"),
        get_text("input_image"),
        get_text("edit_instruction"),
        get_text("edit_instruction_placeholder"),
        get_text("edit_image"),
        get_text("edited_image"),
        get_text("edit_status"),
        get_text("tab_image_understanding"),
        get_text("question"),
        get_text("question_placeholder"),
        get_text("max_generation_length"),
        get_text("understand_image"),
        get_text("understanding_result"),
        get_text("usage_instructions"),
        get_text("usage_step1"),
        get_text("usage_step2"),
        get_text("usage_step3")
    )
|
|
|
|
def load_example_image(image_path):
    """Open an example image for a Gradio component; None when unavailable.

    Missing files and unreadable images are tolerated (the UI simply starts
    without a pre-filled example).
    """
    try:
        if os.path.exists(image_path):
            return Image.open(image_path)
    except Exception as e:
        print(f"Error loading example image: {e}")
    return None
|
|
|
|
|
|
def create_interface():
    """Build and return the Gradio Blocks UI for the STAR demo.

    The model is loaded eagerly here so the first user request does not pay
    the model-construction cost.
    """
    print_with_time("Initializing STAR demo system...")
    # NOTE(review): model_loaded/status_message are not surfaced in the UI —
    # a load failure is only visible in the server logs.
    model_loaded, status_message = initialize_model_on_startup()

    with gr.Blocks(title="🌟 STAR Multi-Modal Demo", theme=gr.themes.Soft()) as demo:
        language_state = gr.State(value=current_language)
        title_md = gr.Markdown(f"# {get_text('title')}")
        desc_md = gr.Markdown(get_text("description"))

        with gr.Row():
            with gr.Column():
                language_dropdown = gr.Dropdown(
                    choices=[("English", "en"), ("中文", "zh")],
                    value=current_language,
                    label="Language / 语言",
                    interactive=True
                )

        with gr.Tabs():
            # --- Tab 1: text-to-image generation ---
            with gr.Tab(get_text("tab_text_to_image")) as txt_tab:
                with gr.Row():
                    with gr.Column():
                        txt_prompt = gr.Textbox(
                            label=get_text("text_prompt"),
                            value=get_text("text_prompt_placeholder"),
                            lines=3
                        )

                        with gr.Accordion(get_text("advanced_params"), open=False):
                            txt_cfg_scale = gr.Slider(
                                minimum=1.0, maximum=20.0, value=1.1, step=0.1,
                                label=get_text("cfg_scale"), info=get_text("cfg_scale_info")
                            )
                            txt_topk = gr.Slider(
                                minimum=100, maximum=2000, value=1000, step=50,
                                label=get_text("top_k"), info=get_text("top_k_info")
                            )
                            txt_topp = gr.Slider(
                                minimum=0.1, maximum=1.0, value=0.8, step=0.05,
                                label=get_text("top_p"), info=get_text("top_p_info")
                            )

                        txt_generate_btn = gr.Button(get_text("generate_image"), variant="primary")

                    with gr.Column():
                        txt_output_image = gr.Image(label=get_text("generated_image"))
                        txt_status = gr.Textbox(label=get_text("generation_status"), interactive=False)

            # --- Tab 2: instruction-based image editing ---
            with gr.Tab(get_text("tab_image_edit")) as edit_tab:
                with gr.Row():
                    with gr.Column():
                        edit_input_image = gr.Image(
                            label=get_text("input_image"),
                            value=load_example_image('assets/editing.png')
                        )
                        edit_instruction = gr.Textbox(
                            label=get_text("edit_instruction"),
                            value=get_text("edit_instruction_placeholder"),
                            lines=2
                        )

                        with gr.Accordion(get_text("advanced_params"), open=False):
                            edit_cfg_scale = gr.Slider(
                                minimum=1.0, maximum=20.0, value=1.1, step=0.1,
                                label=get_text("cfg_scale")
                            )
                            edit_topk = gr.Slider(
                                minimum=100, maximum=2000, value=1000, step=50,
                                label=get_text("top_k")
                            )
                            edit_topp = gr.Slider(
                                minimum=0.1, maximum=1.0, value=0.8, step=0.05,
                                label=get_text("top_p")
                            )

                        edit_btn = gr.Button(get_text("edit_image"), variant="primary")

                    with gr.Column():
                        edit_output_image = gr.Image(label=get_text("edited_image"))
                        edit_status = gr.Textbox(label=get_text("edit_status"), interactive=False)

            # --- Tab 3: image understanding (VQA) ---
            with gr.Tab(get_text("tab_image_understanding")) as understand_tab:
                with gr.Row():
                    with gr.Column():
                        understand_input_image = gr.Image(
                            label=get_text("input_image"),
                            value=load_example_image('assets/understand.png')
                        )
                        understand_question = gr.Textbox(
                            label=get_text("question"),
                            value=get_text("question_placeholder"),
                            lines=2
                        )

                        with gr.Accordion(get_text("advanced_params"), open=False):
                            understand_max_tokens = gr.Slider(
                                minimum=64, maximum=1024, value=256, step=64,
                                label=get_text("max_generation_length")
                            )

                        understand_btn = gr.Button(get_text("understand_image"), variant="primary")

                    with gr.Column():
                        understand_output = gr.Textbox(
                            label=get_text("understanding_result"),
                            lines=15,
                            interactive=False
                        )

        usage_md = gr.Markdown(
            f"""
---
### {get_text("usage_instructions")}
{get_text("usage_step1")}
{get_text("usage_step2")}
{get_text("usage_step3")}
"""
        )

        # Wire each button to its handler.
        txt_generate_btn.click(
            fn=text_to_image,
            inputs=[txt_prompt, txt_cfg_scale, txt_topk, txt_topp],
            outputs=[txt_output_image, txt_status]
        )

        edit_btn.click(
            fn=image_editing,
            inputs=[edit_input_image, edit_instruction, edit_cfg_scale, edit_topk, edit_topp],
            outputs=[edit_output_image, edit_status]
        )

        understand_btn.click(
            fn=image_understanding,
            inputs=[understand_input_image, understand_question, understand_max_tokens],
            outputs=understand_output
        )

        # Language switch refreshes only these components; other labels
        # (tabs, sliders, buttons) keep the language they were created with.
        language_dropdown.change(
            fn=update_interface_language,
            inputs=[language_dropdown],
            outputs=[language_state, title_md, desc_md, txt_prompt, edit_instruction, understand_question, usage_md, txt_status]
        )

    return demo
|
|
# Build the UI at import time (Hugging Face Spaces executes this module
# directly) and launch with a public share link and server-side error display.
demo = create_interface()

demo.launch(share=True, show_error=True)
|
|
|
|