| |
| |
| |
| |
| import gradio as gr |
| import os |
| import torch |
| from transformers import AutoProcessor, MllamaForConditionalGeneration, TextStreamer |
| from PIL import Image |
| import csv |
| import spaces |
| |
# Deployment flags read from the process environment.
# True only when SPACES_ZERO_GPU is set to exactly "1".
IS_SPACES_ZERO = os.environ.get("SPACES_ZERO_GPU", "0") == "1"
# True whenever a SPACE_ID variable is present, whatever its value.
IS_SPACE = "SPACE_ID" in os.environ
# Hard-coded switch; the model-loading code reads it to choose the
# "/content/drive/MyDrive/models/" path over the hub repository.
IS_GDRVIE = False

# Prefer CUDA when torch reports it as available, otherwise run on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# True only when LOW_MEMORY is set to exactly "1".
LOW_MEMORY = os.environ.get("LOW_MEMORY", "0") == "1"
print(f"Using device: {device}")
print(f"Low memory mode: {LOW_MEMORY}")

# Access token passed to from_pretrained; None when the variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")
|
|
| |
# Load the vision-language model and its processor once, at import time.
model_name = "Llama-3.2-11B-Vision-Instruct"
if IS_GDRVIE:
    # Load from a local directory (the path looks like a Colab Google Drive
    # mount -- confirm against the deployment environment).
    model_path = "/content/drive/MyDrive/models/" + model_name
    model = MllamaForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    processor = AutoProcessor.from_pretrained(model_path)
else:
    # Load from the Hugging Face Hub, authenticating with HF_TOKEN if set.
    HF_TOKEN = os.environ.get("HF_TOKEN")
    model_name = "ruslanmv/Llama-3.2-11B-Vision-Instruct"
    model = MllamaForConditionalGeneration.from_pretrained(
        model_name,
        # `token` is the supported replacement for the deprecated
        # `use_auth_token` keyword.
        token=HF_TOKEN,
        torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
        device_map="auto" if device == "cuda" else None,
    )
    # With device_map="auto" the weights are already placed (and possibly
    # offloaded) by accelerate; calling .to() on such a model warns and
    # raises if any module was offloaded. Only move the model explicitly
    # when no device map was used.
    if device != "cuda":
        model.to(device)
    processor = AutoProcessor.from_pretrained(model_name, token=HF_TOKEN)

# Ask the model to tie its shared weights when it exposes that method.
if hasattr(model, "tie_weights"):
    model.tie_weights()
|
|
# Sample of the expected output layout; appended verbatim to the prompt.
example = (
    "Table 1:\n"
    "header1,header2,header3\n"
    "value1,value2,value3\n"
    "\n"
    "Table 2:\n"
    "header1,header2,header3\n"
    "value1,value2,value3\n"
)

# Instruction text sent to the model together with the uploaded image.
prompt_message = (
    "Please extract all tables from the image and generate CSV files.\n"
    "Each table should be separated using the format table_n.csv, where n is the table number.\n"
    "You must use CSV format with commas as the delimiter. Do not use markdown format. "
    "Ensure you use the original table headers and content from the image.\n"
    "Only answer with the CSV content. Dont explain the tables.\n"
    "An example of the formatting output is as follows:\n"
    + example
)
|
|
|
|
| |
def stream_response(inputs):
    """Generate a reply for *inputs* and yield the decoded text.

    Despite the name, nothing is yielded until generation has finished:
    the ``TextStreamer`` only echoes tokens to stdout while the model runs.
    One string is yielded per sequence in the generated batch.

    Args:
        inputs: Processor output (must contain ``input_ids``) that is
            unpacked into ``model.generate``.

    Yields:
        The newly generated text of each sequence, with special tokens
        removed and without the echoed prompt.
    """
    streamer = TextStreamer(tokenizer=processor.tokenizer)
    sequences = model.generate(
        **inputs,
        max_new_tokens=2000,
        do_sample=True,
        streamer=streamer,
    )
    # generate() returns prompt + continuation. Drop the prompt tokens so the
    # example tables embedded in the prompt can never be parsed as results,
    # rather than relying on exact-string removal of the prompt later on.
    prompt_length = inputs["input_ids"].shape[-1]
    for sequence in sequences:
        yield processor.decode(sequence[prompt_length:], skip_special_tokens=True)
|
|
|
|
@spaces.GPU
def predict(message, image):
    """Run the vision model on *image* with *message* and save the tables.

    Args:
        message: Instruction text placed in the user turn after the image.
        image: Image handed to the processor (the UI supplies a PIL image).

    Returns:
        The list of CSV file paths produced by ``extract_and_save_tables``.
    """
    # Single user turn: an image placeholder followed by the instruction.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": message},
            ],
        }
    ]
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)

    # The templated text already carries the model's special tokens, so the
    # tokenizer must not add them again (matches the documented Mllama usage);
    # otherwise the prompt starts with a duplicated begin-of-text token.
    inputs = processor(
        image,
        input_text,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(device)

    full_response = "".join(stream_response(inputs))
    return extract_and_save_tables(full_response)
|
|
| |
# Module-level list of CSV paths. Nothing in the visible code reads it:
# extract_and_save_tables() creates and returns its own local list of the
# same name, so this one stays empty.
files_list: list = []
|
|
def clean_full_response(full_response):
    """Strip the echoed prompt from the model output.

    Every exact occurrence of ``prompt_message`` is removed, then leading and
    trailing whitespace is trimmed from what remains.
    """
    without_prompt = full_response.replace(prompt_message, "")
    return without_prompt.strip()
|
|
def extract_and_save_tables(full_response):
    """Split the model output into tables and write each one to a CSV file.

    The response is cleaned with ``clean_full_response`` and split on the
    literal marker ``"Table "``. For every section after the first marker,
    the heading line is dropped, blank lines are skipped, and the remaining
    lines are parsed as CSV and written to ``table_<n>.csv`` in the current
    working directory (existing files are overwritten).

    Args:
        full_response: Raw text returned by the model.

    Returns:
        The file names written, in table order; empty when no ``"Table "``
        marker is present.
    """
    cleaned_response = clean_full_response(full_response)
    saved_files = []

    # Element 0 is whatever precedes the first "Table " marker, so skip it.
    sections = cleaned_response.split("Table ")[1:]
    for index, section in enumerate(sections, start=1):
        file_name = f"table_{index}.csv"
        # The first line is the remainder of the "Table n:" heading.
        data_lines = [
            line for line in section.strip().splitlines()[1:] if line.strip()
        ]
        # Parse with the csv module so quoted fields keep embedded commas
        # (e.g. "Smith, John"). Each line is parsed on its own so that a
        # stray quote cannot swallow the rows that follow it.
        rows = [
            next(csv.reader([line], skipinitialspace=True))
            for line in data_lines
        ]

        # Explicit UTF-8 so non-ASCII cell text does not depend on the
        # platform's default encoding.
        with open(file_name, mode="w", newline="", encoding="utf-8") as handle:
            csv.writer(handle).writerows(rows)

        saved_files.append(file_name)

    return saved_files
|
|
|
|
| |
def gradio_app():
    """Build the Gradio interface for table extraction and launch it.

    The interface takes one uploaded image and shows a status message plus
    the generated CSV files for download. ``launch(debug=True)`` is called
    before this function returns.
    """

    def process_image(image):
        """Run extraction on *image*; return ``(status_text, files)``."""
        # Submitting the form without an upload passes None; report it
        # instead of sending None into the model pipeline.
        if image is None:
            return "Please upload an image first.", None

        files = predict(prompt_message, image)
        # Do not claim success when nothing was written.
        if not files:
            return "No tables were found in the image.", None
        return "Tables extracted and saved as CSV files.", files

    image_input = gr.Image(type="pil", label="Upload Image")

    output_text = gr.Textbox(label="Extraction Status")
    file_output = gr.File(label="Download CSV files")

    iface = gr.Interface(
        fn=process_image,
        inputs=[image_input],
        outputs=[output_text, file_output],
        title="Table Extractor and CSV Converter",
        description="Upload an image to extract tables and download CSV files.",
        allow_flagging="never",
    )

    iface.launch(debug=True)
|
|
|
|
| |
# Start the UI only when this file is executed as a script, so that importing
# the module does not launch a web server.
if __name__ == "__main__":
    gradio_app()