JaydeepR committed on
Commit
368a993
·
verified ·
1 Parent(s): b2901d9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +163 -0
app.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import pandas as pd
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import uuid
7
+
8
+
9
+ import sys
10
+ import os
11
+
12
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
13
+ from segmentation_model import load_model,transform_image, run_inference, save_input_image, save_objects_and_metadata, extract_object
14
+ from identification_model import load_yolov8_model, run_object_detection
15
+ from text_extraction_model import extract_text
16
+ from summarization_model import generate_description, summarize_text_and_image
17
+ from data_mapping import map_object_data, create_summary_table, save_mapping_to_json
18
+ from visualization import generate_output
19
+
20
+
21
# Load the required models.
#
# Streamlit re-executes this script from top to bottom on every user
# interaction, so calling the loaders directly would reload both models on each
# rerun. Wrapping them in ``st.cache_resource`` keeps one instance per server
# process and reuses it across reruns.
@st.cache_resource
def _load_segmentation_model():
    """Return the segmentation model, loading it only on the first call."""
    return load_model()


@st.cache_resource
def _load_detection_model():
    """Return the YOLOv8 detection model, loading it only on the first call."""
    return load_yolov8_model()


# Same global names as before so the rest of the script is unaffected.
model = _load_segmentation_model()
detection_model = _load_detection_model()
24
+
25
+ # def resize_image(image, size=(800, 800)):
26
+ # return image.resize(size, Image.ANTIALIAS)
27
+
28
def display_masks(outputs, image, threshold=0.5, mask_threshold=0.5):
    """Overlay confident segmentation masks on the image and extract the objects.

    Reads ``outputs[0]['masks']`` and ``outputs[0]['scores']``, and for every
    detection whose score exceeds ``threshold`` binarizes its mask, cuts the
    object out of ``image`` with ``extract_object`` and draws the mask on top
    of the image. The resulting figure is rendered with ``st.pyplot``.

    Args:
        outputs: Model output; element 0 must be a mapping with ``'masks'``
            and ``'scores'`` entries of equal length. Each mask must support
            ``.squeeze().cpu().numpy()`` (presumably torch tensors - confirm
            against ``run_inference``).
        image: The input image; anything ``np.array`` and ``extract_object``
            accept.
        threshold: Minimum detection score for a mask to be kept.
        mask_threshold: Cutoff used to binarize each soft mask. Defaults to
            the previously hard-coded value of 0.5.

    Returns:
        A list with one extracted object image per kept detection, in
        detection order. Empty when no score exceeds ``threshold``.
    """
    detections = outputs[0]
    masks = detections['masks']
    scores = detections['scores']

    fig, ax = plt.subplots()
    ax.imshow(np.array(image))

    extracted_objects = []

    for mask_tensor, score in zip(masks, scores):
        if not score > threshold:
            continue

        # Turn the soft mask into a binary uint8 mask.
        soft_mask = mask_tensor.squeeze().cpu().numpy()
        mask = np.where(soft_mask > mask_threshold, 1, 0).astype(np.uint8)

        extracted_objects.append(extract_object(image, mask))

        # Overlay the mask on the image.
        ax.imshow(mask, cmap='jet', alpha=0.5)

    st.pyplot(fig)
    # Figures created through pyplot stay registered until explicitly closed;
    # without this, every Streamlit rerun would leak one figure.
    plt.close(fig)

    return extracted_objects
50
+
51
+
52
+
53
# ---------------------------------------------------------------------------
# Streamlit page: upload an image, then run description, text extraction,
# summarization, segmentation and per-object analysis on it.
# ---------------------------------------------------------------------------
st.title("AI Pipeline: Image Segmentation, Object Detection, and Text Extraction")
st.sidebar.header("Options")
st.sidebar.text("Upload an image to start processing.")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # Show the raw upload, then decode it into an RGB PIL image for the models.
    st.image(uploaded_file, caption='Uploaded Image.', use_column_width=True)
    image = Image.open(uploaded_file).convert('RGB')

    # Unique master ID tying together every artifact produced for this image.
    master_id = str(uuid.uuid4())

    st.header("Results")

    # --- Whole-image description ---
    description = generate_description(image)
    st.subheader("Image Description")
    st.write("Generated Description:", description)

    # --- Whole-image text extraction ---
    extracted_text = extract_text(image)

    st.subheader("Extracted Text")
    if extracted_text:
        st.write(extracted_text)
    else:
        st.write("No text was detected")

    # --- Whole-image summary ---
    summary = summarize_text_and_image(description, extracted_text)
    st.subheader("Image Summary")
    st.write("Generated Summary:", summary)

    # Save the input image under its master ID.
    save_input_image(image, master_id)

    # --- Segmentation ---
    image_tensor = transform_image(image)
    outputs = run_inference(model, image_tensor)

    extracted_objects = display_masks(outputs, image)

    objects_data = []

    if extracted_objects:
        # Save the extracted objects and their metadata. The return value was
        # never used, so the call is kept only for its side effects.
        save_objects_and_metadata(extracted_objects, master_id)

        # --- Per-object analysis ---
        st.write("Extracted Objects:")
        for i, obj_img in enumerate(extracted_objects):
            st.image(obj_img, caption=f'Object {i+1}', use_column_width=True)

            obj_description = generate_description(obj_img)
            # Show this object's own description; the original code printed
            # the whole-image description here for every object.
            st.write("Generated Description:", obj_description)

            # Run object detection on the extracted object.
            detection_results = run_object_detection(detection_model, obj_img)
            st.write(f"Detection results for Object {i+1}:")
            st.json(detection_results)

            obj_text = extract_text(obj_img)
            if obj_text:
                st.write(f"Extracted Text for Object {i+1}:")
                st.json(obj_text)
            else:
                st.write("No text was detected")

            obj_summary = summarize_text_and_image(obj_description, obj_text)
            st.write(f"Object Summary:\n{obj_summary}")

            object_id = str(uuid.uuid4())
            object_data = map_object_data(object_id, obj_description, obj_text, obj_summary)
            objects_data.append(object_data)

        # --- Persist the object mapping ---
        data_mapping = create_summary_table(objects_data)
        output_path = os.path.join("data", "output", f"{master_id}_data_mapping.json")
        save_mapping_to_json(data_mapping, output_path)

        # Generate the final output image with annotations and summary table.
        # NOTE(review): all masks are passed here, while objects_data only
        # covers detections above the display threshold - confirm that
        # generate_output tolerates the two having different lengths.
        annotated_image_path, summary_table_path = generate_output(
            image, outputs[0]['masks'], objects_data, master_id
        )

        st.subheader("Final Output")

        # Display the annotated image.
        st.image(annotated_image_path, caption='Annotated Image', use_column_width=True)

        # Offer the summary table for download. A relative markdown link is
        # not served by Streamlit, so stream the file through a download
        # button whenever the path returned by generate_output exists.
        st.write("Summary Table:")
        if isinstance(summary_table_path, (str, os.PathLike)) and os.path.isfile(summary_table_path):
            with open(summary_table_path, "rb") as summary_file:
                st.download_button(
                    label="Download the summary table",
                    data=summary_file.read(),
                    file_name=os.path.basename(summary_table_path),
                )
        else:
            # Fall back to the original message when no file is available.
            st.write(f"Download the summary table [here](data/output/{master_id}_summary.csv)")

        # Display the mapped data.
        st.write("Mapped Data:")
        st.json(data_mapping)

    else:
        st.write("No objects were detected")
161
+
162
+
163
+