# Hugging Face Spaces page header captured along with the source:
#   Spaces — status: Runtime error
# (The ImportError fixed below in the gradio app setup is the likely cause.)
import logging
import os
import random
import time
import traceback
from io import BytesIO

import gradio as gr
import torch
from dotenv import load_dotenv
from fastapi.responses import HTMLResponse
from gradio.data_classes import FileData
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

# Try to import spaces for ZeroGPU support.
try:
    import spaces
except ImportError:
    # Fallback for local development: a no-op stand-in so `@spaces.GPU`
    # decorations still work outside Hugging Face Spaces.
    class spaces:
        @staticmethod
        def GPU(fn):
            return fn

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

load_dotenv()

# Load model and processor once at import time so every request reuses them.
logger.info("Loading model and processor...")
model_id = "HiDream-ai/HiDream-O1-Image"
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    # bf16 on GPU roughly halves memory vs fp32; CPU stays fp32.
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    trust_remote_code=True
)
if torch.cuda.is_available():
    model = model.to("cuda")
    logger.info("Model moved to CUDA")
else:
    logger.info("CUDA not available, running on CPU")

# BUG FIX: the original did `from gradio import Server` and `app = Server()`.
# `Server` is not a public gradio export, so the import raised ImportError at
# startup — matching the "Runtime error" status on the Space. A gr.Blocks app
# provides the `.launch()` entry point used under `__main__`.
# TODO(review): wire `generate` into this UI (or mount `index()` on a FastAPI
# route) — the original never registered either, so this stays minimal.
app = gr.Blocks()
@spaces.GPU  # Request a GPU slice on ZeroGPU Spaces; no-op locally via the fallback class.
def generate(
    prompt: str,
    wh_ratio: str = "1:1",
    negative_prompt: str = "",
    enable_prompt_refine: bool = True,
    seed: int = -1,
    guidance_scale: float = 5.0
) -> FileData:
    """Generate an image with the locally loaded HiDream model.

    Args:
        prompt: Text description of the desired image.
        wh_ratio: Width:height ratio string (e.g. "1:1"); forwarded to the
            model's custom ``generate()`` if its remote code supports it.
        negative_prompt: Concepts to steer away from; forwarded as a kwarg.
        enable_prompt_refine: Accepted for interface compatibility but unused
            by this implementation.
        seed: RNG seed; -1 means "do not seed" (non-deterministic output).
        guidance_scale: Classifier-free guidance strength.

    Returns:
        FileData pointing at the PNG written to the working directory.
    """
    logger.info(f"Generating for prompt: {prompt}")
    if seed != -1:
        # Seed both torch (all devices) and the stdlib RNG for reproducibility.
        torch.manual_seed(seed)
        random.seed(seed)
    # Prepare inputs.
    # NOTE(review): the exact call signature depends on this model's
    # trust_remote_code implementation — the extra kwargs below are passed on
    # the assumption the custom generate() accepts them; confirm against the
    # model card.
    inputs = processor(text=prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=1024,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            wh_ratio=wh_ratio,
        )
    # Normalize the model output to a single PIL image.
    if isinstance(output, Image.Image):
        img = output
    elif isinstance(output, list) and len(output) > 0 and isinstance(output[0], Image.Image):
        img = output[0]
    elif hasattr(output, "images") and output.images:
        # Diffusers-style pipeline output object.
        img = output.images[0]
    else:
        # Fallback to the decoder for token-based outputs.
        logger.info("Output is not a PIL image, attempting to decode...")
        generated_output = processor.batch_decode(output, skip_special_tokens=True)[0]
        if isinstance(generated_output, Image.Image):
            img = generated_output
        else:
            # batch_decode yields a str here, so render a visible placeholder
            # rather than crashing the UI.
            logger.warning("Generated output was not a PIL image, creating placeholder.")
            img = Image.new("RGB", (1024, 1024), color=(50, 50, 150))
    # Unique-ish filename; the random suffix avoids collisions within a second.
    out_path = f"generated_{int(time.time())}_{random.randint(0, 1000)}.png"
    img.save(out_path)
    return FileData(path=out_path)
async def index() -> HTMLResponse:
    """Serve the static landing page from disk.

    NOTE(review): this handler is never registered on a route anywhere in
    this file — confirm it is wired up elsewhere, otherwise it is dead code.
    """
    # Use a context manager so the file handle is closed deterministically
    # (the original open(...).read() leaked the handle until GC), and pin the
    # encoding so the page renders the same on any host locale.
    with open("index.html", encoding="utf-8") as fh:
        return HTMLResponse(fh.read())
if __name__ == "__main__":
    # Start the Gradio server when the module is executed directly.
    app.launch()