import logging
import random
import time

import torch
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import FileResponse, HTMLResponse
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq

# Try to import spaces for ZeroGPU support
try:
    import spaces
except ImportError:
    # Fallback for local development: a no-op GPU decorator
    class spaces:
        @staticmethod
        def GPU(fn):
            return fn

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

load_dotenv()

# Load model and processor
logger.info("Loading model and processor...")
model_id = "HiDream-ai/HiDream-O1-Image"
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    trust_remote_code=True,
)
if torch.cuda.is_available():
    model = model.to("cuda")
    logger.info("Model moved to CUDA")
else:
    logger.info("CUDA not available, running on CPU")

# FastAPI provides the HTTP layer: one POST route for generation and one GET
# route that serves the static frontend.
app = FastAPI()


@app.post("/api/generate")
@spaces.GPU
def generate(
    prompt: str,
    wh_ratio: str = "1:1",
    negative_prompt: str = "",
    enable_prompt_refine: bool = True,  # accepted for API compatibility; not used below
    seed: int = -1,
    guidance_scale: float = 5.0,
) -> FileResponse:
    """Generate an image with the local transformers model and return it as a PNG file."""
    logger.info(f"Generating for prompt: {prompt}")
    if seed != -1:
        torch.manual_seed(seed)
        random.seed(seed)

    # Prepare inputs.
    # Note: the exact usage depends on the specific HiDream model architecture;
    # this assumes a standard text-to-image style generation interface.
    inputs = processor(text=prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        # HiDream-O1 may take these parameters in the prompt or as kwargs.
        # They are passed here in case the custom modeling code supports them.
        output = model.generate(
            **inputs,
            max_new_tokens=1024,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            wh_ratio=wh_ratio,
        )

    # Convert the output to an image. HiDream models may return a PIL image
    # directly, a list of images, or an object with an `images` attribute.
    if isinstance(output, Image.Image):
        img = output
    elif isinstance(output, list) and output and isinstance(output[0], Image.Image):
        img = output[0]
    elif hasattr(output, "images") and output.images:
        img = output.images[0]
    else:
        # Fallback for token-based outputs: batch_decode returns text, never a
        # PIL image, so log what came back and return a placeholder instead.
        logger.info("Output is not a PIL image, attempting to decode...")
        decoded = processor.batch_decode(output, skip_special_tokens=True)[0]
        logger.warning(f"Generated output was text ({decoded[:80]!r}), creating placeholder.")
        img = Image.new("RGB", (1024, 1024), color=(50, 50, 150))

    out_path = f"generated_{int(time.time())}_{random.randint(0, 1000)}.png"
    img.save(out_path)
    return FileResponse(out_path, media_type="image/png")


@app.get("/")
async def index():
    with open("index.html") as f:
        return HTMLResponse(f.read())


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
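
# Example client call (a minimal sketch, assuming the server is running locally
# on port 7860 as launched above; the path and parameter names match the
# /api/generate route defined in this file):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/api/generate",
#       params={"prompt": "a lighthouse at dusk", "wh_ratio": "1:1", "seed": 42},
#   )
#   resp.raise_for_status()
#   with open("result.png", "wb") as f:
#       f.write(resp.content)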