import logging
import os
import random
import time
import traceback
from gradio import Server  # assumes a recent Gradio build that exposes this HTTP-app interface
from gradio.data_classes import FileData
from fastapi.responses import HTMLResponse
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image
from dotenv import load_dotenv
# Try to import spaces for ZeroGPU support
try:
    import spaces
except ImportError:
    # Fallback for local development: a no-op stand-in that, like the real
    # spaces.GPU decorator, works both bare (@spaces.GPU) and with arguments
    # (@spaces.GPU(duration=...)).
    class spaces:
        @staticmethod
        def GPU(fn=None, **kwargs):
            if fn is None:
                return lambda f: f
            return fn
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
load_dotenv()
# Load model and processor
logger.info("Loading model and processor...")
model_id = "HiDream-ai/HiDream-O1-Image"
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    trust_remote_code=True,
)
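# Alternative sketch (assumes the optional `accelerate` package is installed):
# let transformers place the weights automatically, which also enables CPU
# offload for checkpoints larger than GPU memory:
#   model = AutoModelForVision2Seq.from_pretrained(
#       model_id,
#       torch_dtype=torch.bfloat16,
#       device_map="auto",
#       trust_remote_code=True,
#   )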
if torch.cuda.is_available():
    model = model.to("cuda")
    logger.info("Model moved to CUDA")
else:
    logger.info("CUDA not available, running on CPU")
app = Server()
@app.api()
@spaces.GPU
def generate(
    prompt: str,
    wh_ratio: str = "1:1",
    negative_prompt: str = "",
    enable_prompt_refine: bool = True,
    seed: int = -1,
    guidance_scale: float = 5.0,
) -> FileData:
"""
Generate an image using the local transformers model.
"""
logger.info(f"Generating for prompt: {prompt}")
if seed != -1:
torch.manual_seed(seed)
random.seed(seed)
    # Prepare inputs.
    # Note: the exact usage depends on the specific HiDream model architecture;
    # we assume a standard text-to-image-style generation interface.
    inputs = processor(text=prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        # HiDream-O1 often takes parameters in the prompt or as kwargs.
        # We pass them here in case the custom modeling code supports them;
        # standard transformers generate() will reject unknown kwargs.
        output = model.generate(
            **inputs,
            max_new_tokens=1024,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            wh_ratio=wh_ratio,
            enable_prompt_refine=enable_prompt_refine,
        )
    # Process the output into a PIL image.
    # HiDream models often return a PIL image directly, a list of images, or
    # an object with an `images` attribute.
    if isinstance(output, Image.Image):
        img = output
    elif isinstance(output, list) and len(output) > 0 and isinstance(output[0], Image.Image):
        img = output[0]
    elif hasattr(output, "images") and output.images:
        img = output.images[0]
    else:
        # Fallback for token-based outputs. batch_decode returns strings, never
        # PIL images, so the best we can do is log the decoded text and return
        # a placeholder so the endpoint still responds with an image.
        logger.info("Output is not a PIL image, attempting to decode...")
        try:
            decoded = processor.batch_decode(output, skip_special_tokens=True)[0]
            logger.warning(f"Decoded output was text, not an image: {decoded[:200]}")
        except Exception:
            logger.warning("Could not decode model output; creating placeholder.")
        img = Image.new("RGB", (1024, 1024), color=(50, 50, 150))
    out_path = f"generated_{int(time.time())}_{random.randint(0, 1000)}.png"
    img.save(out_path)
    return FileData(path=out_path)
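# Variant sketch: write into the system temporary directory instead of the
# working directory so generated files don't accumulate (tempfile is stdlib):
#   import tempfile
#   out_path = os.path.join(tempfile.gettempdir(), f"generated_{int(time.time())}.png")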
@app.get("/")
async def index():
    # Serve the static frontend; use a context manager so the file handle
    # is closed promptly.
    with open("index.html") as f:
        return HTMLResponse(f.read())
if __name__ == "__main__":
    app.launch()
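# Minimal local smoke test, bypassing the HTTP layer entirely (the prompt and
# seed below are placeholder values, not from the model card):
#   result = generate("a watercolor fox in a misty forest", wh_ratio="16:9", seed=42)
#   print(result.path)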