| """ |
| Attribution: https://github.com/AIPI540/AIPI540-Deep-Learning-Applications/ |
| |
| Jon Reifschneider |
| Brinnae Bent |
| |
| """ |
|
|
| import streamlit as st |
| from PIL import Image |
| import numpy as np |
| import os |
| import numpy as np |
| import pandas as pd |
| import pandas as pd |
| import os |
| import json |
| import pandas as pd |
| import torch |
| import numpy as np |
| import pandas as pd |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import matplotlib.pyplot as plt |
| from ultralytics import YOLO |
| from PIL import Image, ImageDraw, ImageFont |
| import numpy as np |
| import cv2 |
| import pytesseract |
| from PIL import ImageEnhance |
| import numpy as np |
| import os |
| import json |
| from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments |
| from datasets import load_dataset |
| from transformers import DataCollatorForLanguageModeling |
| from PIL import Image, ImageEnhance |
| from io import StringIO |
|
|
|
|
def crop_image(model, original_image):
    """
    Crop the region of interest (table) from an image using a YOLO model.

    Inputs:
        model (YOLO): The YOLO model used for object detection.
        original_image (PIL.Image): The image to be processed.

    Returns:
        PIL.Image or None: The crop of the first detected box whose class id
        is 3 (treated here as the table class), or None when no such box is
        found.
    """
    # Detect on the image that was passed in rather than on a module-level
    # global, so the function works no matter where it is called from.
    image_array = np.array(original_image)
    results = model(image_array)

    for result in results:
        for box in result.boxes:
            if box.cls == 3:
                # Box corners come back as floats; crop() needs integers.
                x1, y1, x2, y2 = (int(value) for value in box.xyxy[0])
                return original_image.crop((x1, y1, x2, y2))

    # No table-class box in any result.
    return None
|
|
def process_image(model, image):
    """
    Process the uploaded image with YOLO model and draw bounding boxes with class-specific colors.

    Every detected box is outlined, and its class name is written in black on
    a filled strip of the same color placed just above the box's top-left
    corner. Class names missing from the color table are drawn in white.

    Inputs:
        model: The trained YOLO model
        image: The image uploaded through Streamlit (anything np.array accepts).

    Returns:
        PIL.Image: The processed image with bounding boxes and labels.
    """
    # Color per lower-cased class name.
    # NOTE(review): presumably RGB, since the array comes from a PIL image - confirm.
    colors = {'title': (255, 0, 0),
              'text': (0, 255, 0),
              'figure': (0, 0, 255),
              'table': (255, 255, 0),
              'list': (0, 255, 255)}
    fallback_color = (255, 255, 255)

    # np.array makes a copy, so drawing below does not touch the caller's image
    image_array = np.array(image)
    results = model(image_array)

    for result in results:
        boxes = result.boxes.cpu().numpy()
        for box in boxes:
            r = box.xyxy[0].astype(int)
            # box.cls is a one-element array; index it so int() receives a
            # scalar (int() on an ndim > 0 array is deprecated in NumPy).
            label = result.names[int(box.cls[0])]
            color = colors.get(label.lower(), fallback_color)

            # Outline of the detection
            cv2.rectangle(image_array, r[:2], r[2:], color, 2)

            # Filled strip sized to the label text, sitting on the box's top edge
            label_size, baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
            top_left = (r[0], r[1] - label_size[1] - baseline)
            bottom_right = (r[0] + label_size[0], r[1])
            cv2.rectangle(image_array, top_left, bottom_right, color, cv2.FILLED)
            cv2.putText(image_array, label, (r[0], r[1] - baseline),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)

    return Image.fromarray(image_array)
|
|
def improve_ocr_accuracy(img):
    """
    Prepare an image for OCR.

    The image is upscaled by a factor of 4 in both dimensions, its contrast is
    enhanced by a factor of 2, and an inverse binary threshold (cut-off 127,
    maximum value 255) is applied to the resulting pixel array.

    Inputs:
        img (PIL.Image): The input image to be processed.

    Returns:
        numpy.ndarray: The thresholded image as a numpy array.
    """
    upscaled_size = (img.width * 4, img.height * 4)
    upscaled = img.resize(upscaled_size)
    boosted = ImageEnhance.Contrast(upscaled).enhance(2)
    pixels = np.array(boosted)

    # cv2.threshold returns (value, image); only the image is needed
    return cv2.threshold(pixels, 127, 255, cv2.THRESH_BINARY_INV)[1]
|
|
def ocr_core(image):
    """
    Perform OCR on the given image and process the extracted text.

    pytesseract's word-level output is filtered to rows whose confidence is
    not -1, and the remaining words are joined into one string, each followed
    by a single space. Two markers are inserted along the way:

    * a newline before every word with word_num == 1 outside block 1;
    * a comma before every word whose horizontal gap to the previous row
      (difference of 'left' within the same block minus the previous row's
      'width') exceeds 80.

    When both apply, the comma precedes the newline.

    Inputs:
        image (numpy.ndarray): The preprocessed image as a numpy array.

    Returns:
        str: The extracted and formatted text from the image; an empty string
        when no words were recognized.
    """
    data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
    frame = pd.DataFrame(data)
    words = frame[frame['conf'] != -1]

    # Work on standalone Series instead of adding columns to the filtered
    # frame: this avoids chained-assignment warnings and also behaves on an
    # empty frame, where a row-wise apply() does not yield a Series.
    left_diff = words.groupby('block_num')['left'].diff().fillna(0).astype(int)
    prev_width = words['width'].shift(1).fillna(0).astype(int)
    spacing = left_diff - prev_width

    text = words['text']
    starts_new_line = (words['word_num'] == 1) & (words['block_num'] != 1)
    text = text.where(~starts_new_line, '\n' + text)
    # Applied second so that the comma ends up in front of any newline marker
    text = text.where(spacing <= 80, ',' + text)

    return ''.join(word + ' ' for word in text)
|
|
def generate_csv_from_text(tokenizer, model, ocr_text):
    """
    Turn OCR text into CSV text with the gpt model.

    The OCR text is encoded into PyTorch tensors, used as the prompt for a
    single generated sequence capped at a total length of 1000, and the first
    returned sequence is decoded with special tokens removed.

    Inputs:
        tokenizer: The tokenizer for the gpt model
        model: The gpt model used for csv
        ocr_text (str): The text extracted from OCR

    Returns:
        str: The generated CSV formatted text.
    """
    prompt_ids = tokenizer.encode(ocr_text, return_tensors='pt')
    generated = model.generate(prompt_ids, max_length=1000, num_return_sequences=1)

    # Only one sequence is requested, so the first entry is the whole answer
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
|
|
if __name__ == '__main__':
    # Script entry point: load the models, render the page, and run the
    # detection -> crop -> OCR -> CSV pipeline on the uploaded image.

    # Detection model plus the GPT-2 model/tokenizer, all read from ./models
    # relative to the current working directory.
    models_dir = os.path.join(os.getcwd(), 'models')
    gpt_model_dir = os.path.join(models_dir, 'gpt_model')

    model = YOLO(os.path.join(models_dir, 'trained_yolov8.pt'))
    gpt_model = GPT2LMHeadModel.from_pretrained(gpt_model_dir)
    tokenizer = GPT2Tokenizer.from_pretrained(gpt_model_dir)

    st.header('''
    Intelligent Document Processing: Table Extraction
    ''')

    header_img = Image.open('assets/header_img.png')
    st.image(header_img, use_column_width=True)

    st.subheader("Please upload an image of a scanned document with a table using the sidebar")

    with st.sidebar:
        user_image = st.file_uploader("Upload an image of a scanned document", type=["png", "jpg", "jpeg"])

    if user_image is not None:
        st.divider()
        image = Image.open(user_image)
        st.image(image, caption='Uploaded Image', use_column_width=True)

        st.divider()
        st.subheader("Document Classes:")
        processed_image = process_image(model, image)
        st.image(processed_image, caption='Processed Image', use_column_width=True)

        try:
            cropped_table = crop_image(model, image)
            if cropped_table is None:
                # crop_image returns None when no table box was detected;
                # fail here instead of relying on st.image(None) to raise.
                raise ValueError("no table detected in the uploaded image")

            st.divider()
            st.subheader("Table Cropped Image:")
            st.image(cropped_table, caption='Cropped Table', use_column_width=True)

            improved_image = improve_ocr_accuracy(cropped_table)
            st.divider()
            st.subheader("Improved Table Image:")
            st.image(improved_image, caption='Improved Table Image', use_column_width=True)

            ocr_text = ocr_core(improved_image)
            st.divider()
            st.subheader("OCR Text:")
            st.write(ocr_text)

            csv_output = generate_csv_from_text(tokenizer, gpt_model, ocr_text)
            st.divider()
            st.subheader("CSV Output:")
            st.write(csv_output.encode('utf-8'))
        except Exception as exc:
            # Catch Exception rather than using a bare except so that
            # BaseException-derived signals (KeyboardInterrupt, SystemExit,
            # ...) are not swallowed, and show the cause next to the hint.
            st.divider()
            st.subheader("Error:")
            st.write("Please upload a scanned document with a table")
            st.write(f"Details: {type(exc).__name__}: {exc}")
|
|
|
|
|
|
|
|