| import os |
| import gc |
| import re |
| import gradio as gr |
| import numpy as np |
| import torch |
| import json |
| import spaces |
| import config |
| import utils |
| import logging |
| import time |
| from datetime import datetime |
| from typing import List, Dict, Tuple, Optional |
| from PIL import Image, PngImagePlugin |
| from diffusers.models import AutoencoderKL |
| from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline |
| from transformers import pipeline as translation_pipeline |
|
|
| from config import ( |
| MODEL, |
| MIN_IMAGE_SIZE, |
| MAX_IMAGE_SIZE, |
| USE_TORCH_COMPILE, |
| ENABLE_CPU_OFFLOAD, |
| OUTPUT_DIR, |
| DEFAULT_NEGATIVE_PROMPT, |
| DEFAULT_ASPECT_RATIO, |
| examples, |
| sampler_list, |
| aspect_ratios, |
| style_list, |
| ) |
|
|
| |
# Root logging: INFO and above, each record stamped with time, logger name
# and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


# Environment-driven switches.
# Colab mode is on when utils detects Colab or IS_COLAB=1 is exported.
IS_COLAB = utils.is_google_colab() or os.environ.get("IS_COLAB") == "1"
# Read once at import time; not referenced by the code visible in this file.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Example caching is opt-in (CACHE_EXAMPLES=1) and only considered with CUDA.
CACHE_EXAMPLES = torch.cuda.is_available() and os.environ.get("CACHE_EXAMPLES") == "1"


# Global torch backend flags: deterministic cuDNN, no cuDNN benchmarking,
# TF32 allowed for CUDA matmuls.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = True


# First CUDA device when one is present, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device}")


# Translation pipeline built once at import time from the
# Helsinki-NLP/opus-mt-ko-en checkpoint; translate_if_korean() calls it.
translator = translation_pipeline(task="translation", model="Helsinki-NLP/opus-mt-ko-en")
|
|
class GenerationError(Exception):
    """Signal a user-facing problem with a generation request.

    Raised by the validation helpers in this module; ``generate`` catches it
    and re-raises it as a ``gr.Error``.
    """
|
|
# Hangul compatibility jamo (U+3131-U+314E consonants, U+314F-U+3163 vowels)
# and precomposed Hangul syllables (U+AC00-U+D7A3). Written with escapes so
# the pattern survives any re-encoding of this source file.
HANGUL_PATTERN = re.compile(r"[\u3131-\u314e\u314f-\u3163\uac00-\ud7a3]")


def translate_if_korean(prompt: str) -> str:
    """Translate the prompt to English when it contains Korean characters.

    Args:
        prompt: The raw prompt text.

    Returns:
        The translated text when Hangul is detected and translation succeeds;
        otherwise the prompt unchanged. Translation is best-effort: any
        failure is logged and the original prompt is returned.
    """
    if not HANGUL_PATTERN.search(prompt):
        return prompt

    logger.info("Korean detected in prompt. Translating to English...")
    try:
        translation = translator(prompt)[0]['translation_text']
    except Exception as e:
        # Deliberate best-effort: fall back to the untranslated prompt
        # rather than failing the whole generation request.
        logger.error(f"Translation error: {e}")
        return prompt

    logger.info(f"Translation result: {translation}")
    return translation
|
|
def validate_prompt(prompt: str) -> str:
    """Check the prompt, normalise it, and translate Korean text.

    Args:
        prompt: Text entered by the user.

    Returns:
        The prompt after the UTF-8 round-trip, the ``"!,"`` -> ``"! ,"``
        rewrite, the call to ``translate_if_korean`` and a final strip.

    Raises:
        GenerationError: If the prompt is not a string, cannot be round-tripped
            through UTF-8, or is empty / whitespace only.
    """
    if not isinstance(prompt, str):
        raise GenerationError("Prompt must be a string")

    try:
        # The encode/decode round-trip raises for text that is not valid UTF-8.
        cleaned = prompt.encode('utf-8').decode('utf-8')
        # Insert a space between "!" and a following comma.
        cleaned = cleaned.replace("!,", "! ,")
    except UnicodeError:
        raise GenerationError("Invalid characters in prompt")

    if len(cleaned) == 0 or cleaned.isspace():
        raise GenerationError("Prompt cannot be empty")

    return translate_if_korean(cleaned).strip()
|
|
def validate_dimensions(width: int, height: int) -> None:
    """Reject image sizes outside the configured bounds.

    Args:
        width: Requested image width.
        height: Requested image height.

    Raises:
        GenerationError: If either value is outside the inclusive range
            ``MIN_IMAGE_SIZE``..``MAX_IMAGE_SIZE``. Width is checked first.
    """
    width_in_range = MIN_IMAGE_SIZE <= width <= MAX_IMAGE_SIZE
    if not width_in_range:
        raise GenerationError(f"Width must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE}")

    height_in_range = MIN_IMAGE_SIZE <= height <= MAX_IMAGE_SIZE
    if not height_in_range:
        raise GenerationError(f"Height must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE}")
|
|
@spaces.GPU
def generate(
    prompt: str,
    negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
    seed: int = 0,
    custom_width: int = 1024,
    custom_height: int = 1024,
    guidance_scale: float = 6.0,
    num_inference_steps: int = 25,
    sampler: str = "Euler a",
    aspect_ratio_selector: str = DEFAULT_ASPECT_RATIO,
    style_selector: str = "(None)",
    use_upscaler: bool = False,
    upscaler_strength: float = 0.55,
    upscale_by: float = 1.5,
    add_quality_tags: bool = True,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> Tuple[List[str], Dict]:
    """Generate images for the given prompt and save them to ``OUTPUT_DIR``.

    Args:
        prompt: Text prompt; validated (and translated when it contains
            Korean) by ``validate_prompt``.
        negative_prompt: Text describing what to avoid.
        seed: Seed handed to ``utils.seed_everything``.
        custom_width: Requested width; bounds-checked, then resolved together
            with ``aspect_ratio_selector`` by ``utils.aspect_ratio_handler``.
        custom_height: Requested height; handled like ``custom_width``.
        guidance_scale: Passed through to the diffusion pipeline.
        num_inference_steps: Passed through to the diffusion pipeline.
        sampler: Sampler name resolved by ``utils.get_scheduler``.
        aspect_ratio_selector: Aspect-ratio choice from the UI.
        style_selector: Key into the module-level ``styles`` mapping.
        use_upscaler: When true, generate latents, upscale them and refine
            them with an img2img pass.
        upscaler_strength: ``strength`` of the img2img refinement pass.
        upscale_by: Scale factor applied to the latents.
        add_quality_tags: When true, prefix the prompt with quality tags.
        progress: Gradio progress tracker, updated while images are saved.

    Returns:
        A ``(image_paths, metadata)`` tuple: the paths returned by
        ``utils.save_image`` (empty if the pipeline produced no images) and
        the dictionary of generation parameters.

    Raises:
        gr.Error: For validation failures (message taken from the
            ``GenerationError``) and for any other failure during generation.
    """
    start_time = time.time()
    upscaler_pipe = None
    backup_scheduler = None
    # Defined up front so the return statement is valid even when the
    # pipeline yields no images.
    image_paths: List[str] = []

    try:
        # Start from a clean slate before allocating for this request.
        torch.cuda.empty_cache()
        gc.collect()

        # The module leaves ``pipe`` as None when CUDA is unavailable; fail
        # with a clear message instead of an AttributeError further down.
        if pipe is None:
            raise GenerationError(
                "Generation pipeline is not loaded (CUDA is not available)"
            )

        # Input validation.
        prompt = validate_prompt(prompt)
        if negative_prompt:
            negative_prompt = negative_prompt.encode('utf-8').decode('utf-8')

        validate_dimensions(custom_width, custom_height)

        # Seeding and target resolution.
        generator = utils.seed_everything(seed)
        width, height = utils.aspect_ratio_handler(
            aspect_ratio_selector,
            custom_width,
            custom_height,
        )

        # Prompt assembly: optional quality tags, then the style preset.
        if add_quality_tags:
            prompt = f"masterpiece, high score, great score, absurdres, {prompt}"

        prompt, negative_prompt = utils.preprocess_prompt(
            styles, style_selector, prompt, negative_prompt
        )

        width, height = utils.preprocess_image_dimensions(width, height)

        # Swap in the requested sampler; the original scheduler is restored
        # in the ``finally`` block because ``pipe`` is shared across requests.
        backup_scheduler = pipe.scheduler
        pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, sampler)

        if use_upscaler:
            # Build the img2img pipeline from the already-loaded components.
            upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)

        metadata = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "resolution": f"{width} x {height}",
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "style_preset": style_selector,
            "seed": seed,
            "sampler": sampler,
            "Model": "Animagine XL 4.0",
            "Model hash": "e3c47aedb0",
        }

        if use_upscaler:
            new_width = int(width * upscale_by)
            new_height = int(height * upscale_by)
            metadata["use_upscaler"] = {
                "upscale_method": "nearest-exact",
                "upscaler_strength": upscaler_strength,
                "upscale_by": upscale_by,
                "new_resolution": f"{new_width} x {new_height}",
            }
        else:
            metadata["use_upscaler"] = None

        logger.info(f"Starting generation with parameters: {json.dumps(metadata, indent=4)}")

        if use_upscaler:
            # Two-stage path: base pass to latents, upscale, img2img refine.
            latents = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                output_type="latent",
            ).images
            upscaled_latents = utils.upscale(latents, "nearest-exact", upscale_by)
            images = upscaler_pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=upscaled_latents,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                strength=upscaler_strength,
                generator=generator,
                output_type="pil",
            ).images
        else:
            images = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                output_type="pil",
            ).images

        # Persist every image and report progress to the UI.
        if images:
            total = len(images)
            for idx, image in enumerate(images, 1):
                progress(idx/total, desc="Saving images...")
                path = utils.save_image(image, metadata, OUTPUT_DIR, IS_COLAB)
                image_paths.append(path)
                logger.info(f"Image {idx}/{total} saved as {path}")

        generation_time = time.time() - start_time
        logger.info(f"Generation completed successfully in {generation_time:.2f} seconds")
        metadata["generation_time"] = f"{generation_time:.2f}s"

        return image_paths, metadata

    except GenerationError as e:
        logger.warning(f"Generation validation error: {str(e)}")
        raise gr.Error(str(e)) from e
    except Exception as e:
        logger.exception("Unexpected error during generation")
        raise gr.Error(f"Generation failed: {str(e)}") from e
    finally:
        torch.cuda.empty_cache()
        gc.collect()

        if upscaler_pipe is not None:
            del upscaler_pipe

        # Put the shared pipeline back the way it was found.
        if backup_scheduler is not None and pipe is not None:
            pipe.scheduler = backup_scheduler

        utils.free_memory()
|
|
| |
# Load the generation pipeline once at import time. Without CUDA, ``pipe``
# is left as None.
if not torch.cuda.is_available():
    logger.warning("CUDA not available, running on CPU")
    pipe = None
else:
    try:
        logger.info("Loading VAE and pipeline...")
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
        )
        pipe = utils.load_pipeline(MODEL, device, vae=vae)
        logger.info("Pipeline loaded successfully on GPU!")
    except Exception as load_error:
        # Any failure above (VAE download or pipeline load) triggers a retry
        # without the custom VAE.
        logger.error(f"Error loading VAE, falling back to default: {load_error}")
        pipe = utils.load_pipeline(MODEL, device)


# Map each style preset name to its (prompt, negative_prompt) pair.
styles = dict(
    (preset["name"], (preset["prompt"], preset["negative_prompt"]))
    for preset in style_list
)
|
|
| |
# Stylesheet passed to gr.Blocks(css=...). It covers the page body, the
# header block, links, the Discord button and the result gallery.
# The CSS comments inside the string had been corrupted by a bad re-encoding;
# they are restored to their original Korean text here.
custom_css = """
/* 배경 및 글자 색상 변경 */
body {
    background-color: #f7f9fc;
    color: #333;
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}

/* 헤더 스타일 */
.header {
    text-align: center;
    padding: 20px;
}
.header .title {
    font-size: 3em;
    font-weight: bold;
    color: #2c3e50;
}
.header .subtitle {
    font-size: 1.2em;
    color: #7f8c8d;
}
a {
    text-decoration: none;
    color: #3498db;
}

/* Discord 버튼 스타일 */
.discord-btn {
    display: flex;
    align-items: center;
    justify-content: center;
    padding: 10px 20px;
    background: #7289da;
    color: white;
    border-radius: 8px;
    font-weight: bold;
    margin-top: 20px;
}
.discord-btn:hover {
    background: #5b6eae;
}
.discord-icon {
    width: 24px;
    height: 24px;
    margin-right: 8px;
}

/* Gradio 갤러리 스타일 개선 */
.gradio-gallery {
    border: none;
    box-shadow: none;
}
"""
|
|
# Gradio UI: controls on the left, gallery and metadata on the right, then
# the event wiring that connects the controls to generate().
# NOTE(review): the container nesting below was reconstructed from a source
# whose indentation had been flattened; confirm the layout against the
# running app.
with gr.Blocks(css=custom_css, theme="default") as demo:
    # Page header.
    gr.HTML(
        """
        <div class="header">
            <div class="title">Multilingual Animagine</div>
        </div>
        """
    )

    with gr.Row():
        # Left column: inputs.
        with gr.Column(scale=3):
            with gr.Tabs():
                with gr.TabItem("Generate"):
                    with gr.Group():
                        prompt = gr.Textbox(
                            label="Prompt",
                            lines=4,
                            placeholder="Describe what you want to generate...",
                            # The Korean sentence tells users that Korean
                            # input is translated to English automatically.
                            info="Enter your image generation prompt here. 한글 입력 시 자동으로 영어로 번역됩니다.",
                        )
                        negative_prompt = gr.Textbox(
                            label="Negative Prompt",
                            lines=4,
                            placeholder="Describe what you want to avoid",
                            value=DEFAULT_NEGATIVE_PROMPT,
                            info="Specify elements you don't want in the image.",
                        )
                        add_quality_tags = gr.Checkbox(
                            label="Quality Tags",
                            value=True,
                            info="Automatically add quality-enhancing tags to your prompt.",
                        )
                    with gr.Accordion(label="More Settings", open=False):
                        with gr.Column():
                            aspect_ratio_selector = gr.Radio(
                                label="Aspect Ratio",
                                choices=aspect_ratios,
                                value=DEFAULT_ASPECT_RATIO,
                                container=True,
                                info="Choose the dimensions of your image.",
                            )
                            # Hidden until the "Custom" aspect ratio is chosen
                            # (see aspect_ratio_selector.change below).
                            with gr.Row(visible=False) as custom_resolution:
                                custom_width = gr.Slider(
                                    label="Width",
                                    minimum=MIN_IMAGE_SIZE,
                                    maximum=MAX_IMAGE_SIZE,
                                    step=8,
                                    value=1024,
                                    info=f"Image width (between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE})",
                                )
                                custom_height = gr.Slider(
                                    label="Height",
                                    minimum=MIN_IMAGE_SIZE,
                                    maximum=MAX_IMAGE_SIZE,
                                    step=8,
                                    value=1024,
                                    info=f"Image height (between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE})",
                                )
                    with gr.Accordion(label="Advanced Parameters", open=False):
                        with gr.Row():
                            style_selector = gr.Dropdown(
                                label="Style Preset",
                                choices=list(styles.keys()),
                                value="(None)",
                                info="Apply a predefined style to your generation.",
                            )
                            sampler = gr.Dropdown(
                                label="Sampler",
                                choices=sampler_list,
                                value="Euler a",
                                info="Different samplers can produce varying results.",
                            )
                        with gr.Row():
                            seed = gr.Slider(
                                label="Seed",
                                minimum=0,
                                maximum=utils.MAX_SEED,
                                step=1,
                                value=0,
                                info="Set a specific seed for reproducible results.",
                            )
                            randomize_seed = gr.Checkbox(
                                label="Randomize seed",
                                value=True,
                                info="Generate a new random seed for each image.",
                            )
                        with gr.Row():
                            guidance_scale = gr.Slider(
                                label="Guidance scale",
                                minimum=1,
                                maximum=12,
                                step=0.1,
                                value=6.0,
                                info="Higher values make the image more closely match your prompt.",
                            )
                            num_inference_steps = gr.Slider(
                                label="Inference steps",
                                minimum=1,
                                maximum=50,
                                step=1,
                                value=25,
                                info="More steps generally yield higher quality images.",
                            )
                        with gr.Row():
                            use_upscaler = gr.Checkbox(
                                label="Use Upscaler",
                                value=False,
                                info="Enable high-resolution upscaling.",
                            )
                            # Both sliders stay hidden until the checkbox
                            # above is ticked (see use_upscaler.change below).
                            upscaler_strength = gr.Slider(
                                label="Upscaler Strength",
                                minimum=0,
                                maximum=1,
                                step=0.05,
                                value=0.55,
                                visible=False,
                                info="Control how strongly the upscaler affects the image.",
                            )
                            upscale_by = gr.Slider(
                                label="Upscale By",
                                minimum=1,
                                maximum=1.5,
                                step=0.1,
                                value=1.5,
                                visible=False,
                                info="Multiplier for the final image resolution.",
                            )
                with gr.TabItem("Examples"):
                    gr.Markdown(
                        """
                        ### Example Prompts
                        - **Scenic Landscape:** A breathtaking view of a mountain landscape during sunrise.
                        - **Cyberpunk City:** A futuristic cyberpunk city with neon lights and towering skyscrapers.
                        - **Fantasy Character:** A majestic wizard with a long beard and glowing magical staff.
                        """
                    )
                    # Clicking an example only fills the prompt box.
                    # NOTE(review): no ``fn`` is supplied, so enabling
                    # CACHE_EXAMPLES probably cannot work here; confirm
                    # against the installed Gradio version.
                    gr.Examples(
                        examples=examples,
                        inputs=prompt,
                        outputs=[],
                        cache_examples=CACHE_EXAMPLES,
                    )
            run_button = gr.Button("Generate", variant="primary", elem_id="generate-button")

        # Right column: results.
        with gr.Column(scale=4):
            result = gr.Gallery(
                label="Generated Images",
                columns=2,
                height="600px",
                show_label=True,
                elem_classes="gradio-gallery",
            )
            with gr.Accordion(label="Generation Parameters", open=False):
                gr_metadata = gr.JSON(
                    label="Image Metadata",
                    show_label=True,
                )

    # Footer: Discord invitation button.
    with gr.Row():
        gr.HTML(
            """
            <div style="width:100%; display:flex; justify-content:center;">
                <a href="https://discord.gg/openfreeai" target="_blank" class="discord-btn">
                    <svg class="discord-icon" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 127.14 96.36">
                        <path fill="currentColor" d="M107.7,8.07A105.15,105.15,0,0,0,81.47,0a72.06,72.06,0,0,0-3.36,6.83A97.68,97.68,0,0,0,49,6.83,72.37,72.37,0,0,0,45.64,0,105.89,105.89,0,0,0,19.39,8.09C2.79,32.65-1.71,56.6.54,80.21h0A105.73,105.73,0,0,0,32.71,96.36,77.7,77.7,0,0,0,39.6,85.25a68.42,68.42,0,0,1-10.85-5.18c.91-.66,1.8-1.34,2.66-2a75.57,75.57,0,0,0,64.32,0c.87.71,1.76,1.39,2.66,2a68.68,68.68,0,0,1-10.87,5.19,77,77,0,0,0,6.89,11.1A105.25,105.25,0,0,0,126.6,80.22h0C129.24,52.84,122.09,29.11,107.7,8.07ZM42.45,65.69C36.18,65.69,31,60,31,53s5-12.74,11.43-12.74S54,46,53.89,53,48.84,65.69,42.45,65.69Zm42.24,0C78.41,65.69,73.25,60,73.25,53s5-12.74,11.44-12.74S96.23,46,96.12,53,91.08,65.69,84.69,65.69Z"/>
                    </svg>
                    <span class="discord-text">Join our Discord Server</span>
                </a>
            </div>
            """
        )

    # Show the upscaler sliders only while "Use Upscaler" is ticked.
    use_upscaler.change(
        fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
        inputs=use_upscaler,
        outputs=[upscaler_strength, upscale_by],
        queue=False,
        api_name=False,
    )
    # Show the width/height sliders only for the "Custom" aspect ratio.
    aspect_ratio_selector.change(
        fn=lambda x: gr.update(visible=x == "Custom"),
        inputs=aspect_ratio_selector,
        outputs=custom_resolution,
        queue=False,
        api_name=False,
    )

    # Generation chain: pick the seed, lock the button, generate, unlock.
    gr.on(
        triggers=[
            prompt.submit,
            negative_prompt.submit,
            run_button.click,
        ],
        fn=utils.randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=lambda: gr.update(interactive=False, value="Generating..."),
        outputs=run_button,
    ).then(
        fn=generate,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            custom_width,
            custom_height,
            guidance_scale,
            num_inference_steps,
            sampler,
            aspect_ratio_selector,
            style_selector,
            use_upscaler,
            upscaler_strength,
            upscale_by,
            add_quality_tags,
        ],
        outputs=[result, gr_metadata],
    ).then(
        fn=lambda: gr.update(interactive=True, value="Generate"),
        outputs=run_button,
    )
|
|
if __name__ == "__main__":
    # Enable the request queue (max_size=20), then start the server; debug
    # mode and the public share link are both tied to IS_COLAB.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch(debug=IS_COLAB, share=IS_COLAB)
|
|