JaydeepR commited on
Commit
9b33231
·
verified ·
1 Parent(s): 94019f0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +175 -0
app.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import pandas as pd
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ from io import BytesIO
7
+ import uuid
8
+ import gc
9
+
10
+ import sys
11
+ import os
12
+
13
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
14
+ from segmentation_model import load_model,transform_image, run_inference, save_input_image, save_objects_and_metadata, extract_object
15
+ from identification_model import load_yolov8_model, run_object_detection
16
+ # from models.text_extraction_model import extract_text
17
+ # from models.summarization_model import summarize_text
18
+ # from utils.data_mapping import create_summary_table
19
+
20
+
21
+
22
+ model = load_model()
23
+ detection_model = load_yolov8_model()
24
+
25
+ def resize_image(image, size=(800, 800)):
26
+ return image.resize(size, Image.ANTIALIAS)
27
+
28
+ def display_masks(outputs, image, threshold=0.5):
29
+ masks = outputs[0]['masks']
30
+ scores = outputs[0]['scores']
31
+
32
+ fig, ax = plt.subplots()
33
+ ax.imshow(np.array(image))
34
+
35
+ extracted_objects = []
36
+
37
+ for i in range(len(scores)):
38
+ if scores[i] > threshold:
39
+ mask = masks[i].squeeze().cpu().numpy()
40
+ mask = np.where(mask > 0.5, 1, 0).astype(np.uint8)
41
+
42
+ object_img = extract_object(image,mask)
43
+ extracted_objects.append(object_img)
44
+ #Display the mask
45
+ ax.imshow(mask, cmap='jet', alpha=0.5) # Overlay mask on image
46
+
47
+ st.pyplot(fig)
48
+
49
+ return extracted_objects
50
+
51
+
52
+
53
+ st.title("Image Segmentation with Mask R-CNN and Object Detection")
54
+
55
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
56
+
57
+ if uploaded_file is not None:
58
+ # Convert uploaded file to PIL Image
59
+ image = uploaded_file
60
+ st.image(image, caption='Uploaded Image.', use_column_width=True)
61
+ image = Image.open(uploaded_file).convert('RGB')
62
+ # Generate a unique master ID for the image
63
+ master_id = str(uuid.uuid4())
64
+
65
+ # Save the input image
66
+ save_input_image(image, master_id)
67
+ # Transform image
68
+ image_tensor = transform_image(image)
69
+ outputs = run_inference(model, image_tensor)
70
+
71
+ extracted_objects = display_masks(outputs, image)
72
+
73
+ if extracted_objects:
74
+ # Save the extracted objects and their metadata
75
+ metadata = save_objects_and_metadata(extracted_objects, master_id)
76
+
77
+ # Display metadata as a JSON output
78
+ st.write("Metadata for extracted objects:")
79
+ #st.json(metadata)
80
+
81
+ # Display each extracted object
82
+ st.write("Extracted Objects:")
83
+ for i, obj_img in enumerate(extracted_objects):
84
+ st.image(obj_img, caption=f'Object {i+1}', use_column_width=True)
85
+ # Convert the object image to a numpy array for YOLO inference
86
+ obj_img_np = np.array(obj_img)
87
+ # Run object detection on each extracted object
88
+ detection_results = run_object_detection(detection_model, obj_img_np)
89
+ st.write(f"Detection results for Object {i+1}:")
90
+ st.json(detection_results)
91
+ else:
92
+ st.write("No objects were detected")
93
+
94
+
95
+ # del extracted_objects
96
+ # gc.collect()
97
+
98
+ # Display results
99
+ #display_masks(outputs, image)
100
+
101
+
102
+
103
+
104
+
105
+ if uploaded_file is not None:
106
+ image = Image.open(uploaded_file).convert("RGB")
107
+ st.image(image, caption='Uploaded Image.', use_column_width=True)
108
+
109
+ image_tensor = transform_image(image)
110
+ outputs = run_inference(model, image_tensor)
111
+
112
+ display_masks(outputs, image)
113
+
114
+
115
+
116
+ # def upload_image():
117
+ # uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
118
+ # if uploaded_file is not None:
119
+ # image = Image.open(uploaded_file)
120
+ # return image
121
+ # return None
122
+
123
+
124
+ # # def display_segmentation(image):
125
+ # # st.image(image, caption="Original Image", use_column_width=True)
126
+
127
+ # # Transform and run inference
128
+ # # image_tensor = transform_image(image)
129
+ # # outputs = run_inference(image_tensor)
130
+
131
+ # # # Save segmented objects
132
+ # # output_dir = 'segmented_objects/'
133
+ # # save_segmented_objects(image, outputs, output_dir)
134
+
135
+ # # segmented_images = [Image.open(f"{output_dir}object_{i+1}.png") for i in range(len(outputs[0]['scores']))]
136
+ # # for img in segmented_images:
137
+ # # st.image(img, caption="Segmented Object", use_column_width=True)
138
+
139
+
140
+
141
+
142
+ # def main():
143
+ # st.title("Image Processing Pipeline")
144
+
145
+ # # uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png"])
146
+ # # if uploaded_file:
147
+ # # image_path = f"data/input_images/{uploaded_file.name}"
148
+ # # image = Image.open(uploaded_file)
149
+ # # image.save(image_path) # Save the uploaded image for further processing
150
+ # # st.image(image, caption="Uploaded Image")
151
+
152
+ # # if st.button("Segment Image"):
153
+ # # segmented = segment_image(image_path)
154
+ # # st.image(segmented, caption="Segmented Image", use_column_width=True)
155
+
156
+ # # if st.button("Identify and Extract Objects"):
157
+ # # objects_data = identify_and_extract_objects(image_path)
158
+ # # extracted_objects = []
159
+
160
+ # # for obj_data in objects_data:
161
+ # # object_image = Image.open(obj_data['Image Path'])
162
+ # # text = extract_text(object_image)
163
+ # # summary = summarize_text(text)
164
+ # # obj_data['Text'] = text
165
+ # # obj_data['Summary'] = summary
166
+ # # extracted_objects.append(obj_data)
167
+
168
+ # # st.image(object_image, caption=f"Object {obj_data['ID']} - Label {obj_data['Label']}")
169
+
170
+ # # summary_file = create_summary_table(extracted_objects)
171
+ # # st.write(pd.DataFrame(extracted_objects))
172
+ # # st.download_button(label="Download Summary Table", data=open(summary_file).read(), file_name="summary.csv")
173
+
174
+ # if __name__ == "__main__":
175
+ # main()