akhaliq HF Staff committed on
Commit
dcf4603
·
1 Parent(s): 10bdcf8

refactor: update model loading and generation logic to return FileData for HiDream-O1 integration

Browse files
Files changed (1) hide show
  1. app.py +31 -24
app.py CHANGED
@@ -9,7 +9,7 @@ import gradio as gr
9
  from gradio import Server
10
  from fastapi.responses import HTMLResponse
11
  import torch
12
- from transformers import AutoProcessor, AutoModelForImageTextToText
13
  from PIL import Image
14
  from dotenv import load_dotenv
15
 
@@ -31,11 +31,13 @@ logger = logging.getLogger(__name__)
31
 
32
  load_dotenv()
33
 
 
 
34
  # Load model and processor
35
  logger.info("Loading model and processor...")
36
  model_id = "HiDream-ai/HiDream-O1-Image"
37
  processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
38
- model = AutoModelForImageTextToText.from_pretrained(
39
  model_id,
40
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
41
  trust_remote_code=True
@@ -49,16 +51,16 @@ else:
49
 
50
  app = Server()
51
 
52
- @app.api("/generate")
53
  @spaces.GPU
54
- def generate_image_api(
55
  prompt: str,
56
  wh_ratio: str = "1:1",
57
  negative_prompt: str = "",
58
  enable_prompt_refine: bool = True,
59
  seed: int = -1,
60
  guidance_scale: float = 5.0
61
- ) -> str:
62
  """
63
  Generate an image using the local transformers model.
64
  """
@@ -74,34 +76,39 @@ def generate_image_api(
74
  inputs = processor(text=prompt, return_tensors="pt").to(model.device)
75
 
76
  with torch.no_grad():
77
- # This is a placeholder for the actual generation call.
78
- # Most transformers image-gen models use .generate() or a custom method.
79
- # Given AutoModelForImageTextToText, it might produce an image tensor.
80
  output = model.generate(
81
  **inputs,
82
- max_new_tokens=1024, # Adjust based on model specifics
83
- # Add other generation params like guidance_scale if supported
 
 
84
  )
85
 
86
  # Process the output to an image
87
- # NOTE: AutoModelForImageTextToText usually generates text tokens.
88
- # If HiDream-O1-Image generates image tokens, you may need a custom
89
- # decoder or a different AutoModel class (e.g. AutoModelForTextToImage).
90
- # This implementation assumes processor.batch_decode can handle the output.
91
- generated_output = processor.batch_decode(output, skip_special_tokens=True)[0]
92
-
93
- # If the output is actually an image (PIL or Tensor), we handle it here.
94
- # For now, let's assume it returns a PIL image or we can convert it.
95
- if isinstance(generated_output, Image.Image):
96
- img = generated_output
97
  else:
98
- # Fallback: create a dummy image if decoding fails to show something
99
- logger.warning("Generated output was not a PIL image, creating placeholder.")
100
- img = Image.new("RGB", (1024, 1024), color=(50, 50, 150))
 
 
 
 
 
 
 
101
 
102
  out_path = f"generated_{int(time.time())}_{random.randint(0, 1000)}.png"
103
  img.save(out_path)
104
- return out_path
105
 
106
  @app.get("/")
107
  async def index():
 
9
  from gradio import Server
10
  from fastapi.responses import HTMLResponse
11
  import torch
12
+ from transformers import AutoProcessor, AutoModel
13
  from PIL import Image
14
  from dotenv import load_dotenv
15
 
 
31
 
32
  load_dotenv()
33
 
34
+ from gradio.data_classes import FileData
35
+
36
  # Load model and processor
37
  logger.info("Loading model and processor...")
38
  model_id = "HiDream-ai/HiDream-O1-Image"
39
  processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
40
+ model = AutoModel.from_pretrained(
41
  model_id,
42
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
43
  trust_remote_code=True
 
51
 
52
  app = Server()
53
 
54
+ @app.api()
55
  @spaces.GPU
56
+ def generate(
57
  prompt: str,
58
  wh_ratio: str = "1:1",
59
  negative_prompt: str = "",
60
  enable_prompt_refine: bool = True,
61
  seed: int = -1,
62
  guidance_scale: float = 5.0
63
+ ) -> FileData:
64
  """
65
  Generate an image using the local transformers model.
66
  """
 
76
  inputs = processor(text=prompt, return_tensors="pt").to(model.device)
77
 
78
  with torch.no_grad():
79
+ # HiDream-O1 often takes parameters in the prompt or as kwargs
80
+ # We pass them here just in case the custom modeling code supports them
 
81
  output = model.generate(
82
  **inputs,
83
+ max_new_tokens=1024,
84
+ negative_prompt=negative_prompt,
85
+ guidance_scale=guidance_scale,
86
+ wh_ratio=wh_ratio,
87
  )
88
 
89
  # Process the output to an image
90
+ # HiDream models often return a PIL image directly or in a list
91
+ if isinstance(output, Image.Image):
92
+ img = output
93
+ elif isinstance(output, list) and len(output) > 0 and isinstance(output[0], Image.Image):
94
+ img = output[0]
95
+ elif hasattr(output, "images") and output.images:
96
+ img = output.images[0]
 
 
 
97
  else:
98
+ # Fallback to decoder for text-based or token-based models
99
+ logger.info("Output is not a PIL image, attempting to decode...")
100
+ generated_output = processor.batch_decode(output, skip_special_tokens=True)[0]
101
+
102
+ if isinstance(generated_output, Image.Image):
103
+ img = generated_output
104
+ else:
105
+ # Fallback: create a dummy image if decoding fails to show something
106
+ logger.warning("Generated output was not a PIL image, creating placeholder.")
107
+ img = Image.new("RGB", (1024, 1024), color=(50, 50, 150))
108
 
109
  out_path = f"generated_{int(time.time())}_{random.randint(0, 1000)}.png"
110
  img.save(out_path)
111
+ return FileData(path=out_path)
112
 
113
  @app.get("/")
114
  async def index():