prithivMLmods commited on
Commit
33cd763
·
verified ·
1 Parent(s): b413167

update app

Browse files
Files changed (1) hide show
  1. app.py +69 -66
app.py CHANGED
@@ -15,23 +15,15 @@ from PIL import Image, ImageOps
15
  import requests
16
 
17
  from transformers import (
18
- Qwen2VLForConditionalGeneration,
19
  Qwen2_5_VLForConditionalGeneration,
20
- AutoModelForImageTextToText,
21
  AutoProcessor,
22
  TextIteratorStreamer,
23
  )
24
-
25
  from transformers.image_utils import load_image
26
-
27
  from gradio.themes import Soft
28
  from gradio.themes.utils import colors, fonts, sizes
29
-
30
- from docling_core.types.doc import DoclingDocument, DocTagsDocument
31
-
32
- import re
33
- import ast
34
- import html
35
 
36
  # --- Theme and CSS Definition ---
37
 
@@ -107,44 +99,50 @@ css = """
107
  """
108
 
109
  # Constants for text generation
110
- MAX_MAX_NEW_TOKENS = 4096
111
- DEFAULT_MAX_NEW_TOKENS = 2048
112
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
113
 
114
- # Check for CUDA availability
115
- device = "cuda" if torch.cuda.is_available() else "cpu"
116
 
117
- # Load Nanonets-OCR2-3B
118
- MODEL_ID_3B = "nanonets/Nanonets-OCR2-3B"
119
- processor_3b = AutoProcessor.from_pretrained(MODEL_ID_3B, trust_remote_code=True)
120
- model_3b = Qwen2_5_VLForConditionalGeneration.from_pretrained(
121
- MODEL_ID_3B,
122
- torch_dtype=torch.float16,
123
  trust_remote_code=True,
 
124
  ).to(device).eval()
125
 
126
- # Load Nanonets-OCR2-1.5B-exp
127
- MODEL_ID_1_5B = "nanonets/Nanonets-OCR2-1.5B-exp"
128
- processor_1_5b = AutoProcessor.from_pretrained(MODEL_ID_1_5B, trust_remote_code=True)
129
- model_1_5b = AutoModelForImageTextToText.from_pretrained(
130
- MODEL_ID_1_5B,
131
- torch_dtype=torch.float16,
132
- trust_remote_code=True,
133
- attn_implementation="flash_attention_2"
134
- ).to(device).eval()
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  @spaces.GPU
137
  def generate_image(model_name: str, text: str, image: Image.Image,
138
- max_new_tokens: int = 1024,
139
- temperature: float = 0.6,
140
- top_p: float = 0.9,
141
- top_k: int = 50,
142
- repetition_penalty: float = 1.2):
143
- """Generation function for image input."""
144
  if model_name == "Nanonets-OCR2-3B":
145
- processor, model = processor_3b, model_3b
146
- elif model_name == "Nanonets-OCR2-1.5B-exp":
147
- processor, model = processor_1_5b, model_1_5b
148
  else:
149
  yield "Invalid model selected.", "Invalid model selected."
150
  return
@@ -152,18 +150,19 @@ def generate_image(model_name: str, text: str, image: Image.Image,
152
  if image is None:
153
  yield "Please upload an image.", "Please upload an image."
154
  return
155
-
156
  images = [image]
157
 
158
  messages = [
159
  {
160
  "role": "user",
161
- "content": [{"type": "image"}] + [{"type": "text", "text": text}]
 
 
162
  }
163
  ]
164
-
165
  prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
166
- inputs = processor(text=prompt, images=images, return_tensors="pt")
167
 
168
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
169
  generation_kwargs = {
@@ -175,38 +174,49 @@ def generate_image(model_name: str, text: str, image: Image.Image,
175
  "top_k": top_k,
176
  "repetition_penalty": repetition_penalty,
177
  }
 
 
 
 
 
 
178
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
179
  thread.start()
180
 
181
  buffer = ""
182
  for new_text in streamer:
183
- buffer += new_text.replace("<|im_end|>", "")
184
  yield buffer, buffer
185
 
 
 
 
 
186
  # Define examples for image inference
187
  image_examples = [
188
  ["Reconstruct the doc [table] as it is.", "images/0.png"],
189
  ["Describe the image!", "images/8.png"],
190
  ["OCR the image", "images/2.jpg"],
191
- ["Convert this page to docling", "images/1.png"],
192
- ["Convert this page to docling", "images/3.png"],
193
- ["Convert chart to OTSL.", "images/4.png"],
194
- ["Convert code to text", "images/5.jpg"],
195
- ["Convert this table to OTSL.", "images/6.jpg"],
196
- ["Convert formula to late.", "images/7.jpg"],
197
  ]
198
 
 
199
  # Create the Gradio Interface
200
  with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
201
- gr.Markdown("# **Multimodal OCR3**", elem_id="main-title")
202
  with gr.Row():
203
  with gr.Column(scale=2):
204
- # Image Inference Components
205
- image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
206
- image_upload = gr.Image(type="pil", label="Upload Image", height=290)
207
- image_submit = gr.Button("Submit", variant="primary")
208
- gr.Examples(examples=image_examples, inputs=[image_query, image_upload])
 
 
 
209
 
 
 
210
  with gr.Accordion("Advanced options", open=False):
211
  max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
212
  temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
@@ -216,19 +226,12 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
216
 
217
  with gr.Column(scale=3):
218
  gr.Markdown("## Output", elem_id="output-title")
219
- raw_output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=11, show_copy_button=True)
220
- with gr.Accordion("(Result.md)", open=True):
221
- formatted_output = gr.Markdown(label="(Result.md)")
222
-
223
- model_choice = gr.Radio(
224
- choices=["Nanonets-OCR2-3B", "Nanonets-OCR2-1.5B-exp"],
225
- label="Select Model",
226
- value="Nanonets-OCR2-3B"
227
- )
228
 
229
- image_submit.click(
230
  fn=generate_image,
231
- inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
232
  outputs=[raw_output, formatted_output]
233
  )
234
 
 
15
  import requests
16
 
17
  from transformers import (
 
18
  Qwen2_5_VLForConditionalGeneration,
19
+ AutoModelForCausalLM,
20
  AutoProcessor,
21
  TextIteratorStreamer,
22
  )
 
23
  from transformers.image_utils import load_image
 
24
  from gradio.themes import Soft
25
  from gradio.themes.utils import colors, fonts, sizes
26
+ from huggingface_hub import snapshot_download
 
 
 
 
 
27
 
28
  # --- Theme and CSS Definition ---
29
 
 
99
  """
100
 
101
  # Constants for text generation
102
+ MAX_MAX_NEW_TOKENS = 5120
103
+ DEFAULT_MAX_NEW_TOKENS = 3072
104
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
105
 
106
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
107
 
108
+ # Load Nanonets-OCR-s
109
+ MODEL_ID_M = "nanonets/Nanonets-OCR2-3B"
110
+ processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
111
+ model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
112
+ MODEL_ID_M,
 
113
  trust_remote_code=True,
114
+ torch_dtype=torch.float16
115
  ).to(device).eval()
116
 
117
+ # Load Dots.OCR
118
+ MODEL_ID_D = "rednote-hilab/dots.ocr"
119
+ model_path_d = "./models/dots-ocr-local"
120
+ snapshot_download(
121
+ repo_id=MODEL_ID_D,
122
+ local_dir=model_path_d,
123
+ local_dir_use_symlinks=False,
124
+ )
125
+ model_d = AutoModelForCausalLM.from_pretrained(
126
+ model_path_d,
127
+ attn_implementation="flash_attention_2" if "cuda" in device.type else "eager",
128
+ torch_dtype=torch.bfloat16,
129
+ device_map="auto",
130
+ trust_remote_code=True
131
+ )
132
+ processor_d = AutoProcessor.from_pretrained(
133
+ model_path_d,
134
+ trust_remote_code=True
135
+ )
136
+
137
 
138
  @spaces.GPU
139
  def generate_image(model_name: str, text: str, image: Image.Image,
140
+ max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float):
141
+ """Generate responses for image input using the selected model."""
 
 
 
 
142
  if model_name == "Nanonets-OCR2-3B":
143
+ processor, model = processor_m, model_m
144
+ elif model_name == "Dots.OCR":
145
+ processor, model = processor_d, model_d
146
  else:
147
  yield "Invalid model selected.", "Invalid model selected."
148
  return
 
150
  if image is None:
151
  yield "Please upload an image.", "Please upload an image."
152
  return
153
+
154
  images = [image]
155
 
156
  messages = [
157
  {
158
  "role": "user",
159
+ "content": [{"type": "image"}] * len(images) + [
160
+ {"type": "text", "text": text}
161
+ ]
162
  }
163
  ]
 
164
  prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
165
+ inputs = processor(text=prompt, images=images, return_tensors="pt").to(model.device)
166
 
167
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
168
  generation_kwargs = {
 
174
  "top_k": top_k,
175
  "repetition_penalty": repetition_penalty,
176
  }
177
+
178
+ # Dots.OCR uses a different generation parameter name for end-of-sequence
179
+ if "dots.ocr" in model.config.name_or_path.lower():
180
+ generation_kwargs["eos_token_id"] = processor.tokenizer.eos_token_id
181
+
182
+
183
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
184
  thread.start()
185
 
186
  buffer = ""
187
  for new_text in streamer:
188
+ buffer += new_text.replace("<|im_end|>", "").replace("</s>", "")
189
  yield buffer, buffer
190
 
191
+ # The formatted output is the same as the raw output in this version
192
+ yield buffer, buffer
193
+
194
+
195
  # Define examples for image inference
196
  image_examples = [
197
  ["Reconstruct the doc [table] as it is.", "images/0.png"],
198
  ["Describe the image!", "images/8.png"],
199
  ["OCR the image", "images/2.jpg"],
200
+ ["Convert this page to markdown", "images/1.png"],
 
 
 
 
 
201
  ]
202
 
203
+
204
  # Create the Gradio Interface
205
  with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
206
+ gr.Markdown("# **Multimodal Image OCR**", elem_id="main-title")
207
  with gr.Row():
208
  with gr.Column(scale=2):
209
+ model_choice = gr.Radio(
210
+ choices=["Nanonets-OCR2-3B", "Dots.OCR"],
211
+ label="Select Model",
212
+ value="Nanonets-OCR-s"
213
+ )
214
+ query_input = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
215
+ image_upload = gr.Image(type="pil", label="Upload Image", height=320)
216
+ submit_button = gr.Button("Submit", variant="primary")
217
 
218
+ gr.Examples(examples=image_examples, inputs=[query_input, image_upload])
219
+
220
  with gr.Accordion("Advanced options", open=False):
221
  max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
222
  temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
 
226
 
227
  with gr.Column(scale=3):
228
  gr.Markdown("## Output", elem_id="output-title")
229
+ raw_output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=18, show_copy_button=True)
230
+ formatted_output = gr.Markdown(label="Formatted Output (Result.md)")
 
 
 
 
 
 
 
231
 
232
+ submit_button.click(
233
  fn=generate_image,
234
+ inputs=[model_choice, query_input, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
235
  outputs=[raw_output, formatted_output]
236
  )
237