JaydeepR commited on
Commit
5eb4eea
·
verified ·
1 Parent(s): 4dbcdff

Create segmentation_model.py

Browse files
Files changed (1) hide show
  1. segmentation_model.py +99 -0
segmentation_model.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as T
3
+ from torchvision.models.detection import maskrcnn_resnet50_fpn
4
+ from PIL import Image
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import uuid
8
+ import os
9
+ import cv2
10
+ import json
11
+
12
+
13
+ input_images_dir = 'data/input_images/'
14
+ segmented_objects_dir = 'data/segmented_objects/'
15
+ os.makedirs(input_images_dir, exist_ok=True)
16
+ os.makedirs(segmented_objects_dir, exist_ok=True)
17
+
18
+ #Loading the model
19
+
20
+ def load_model():
21
+ model = maskrcnn_resnet50_fpn(pretrained=True)
22
+ # Using a different backbone
23
+ #model = maskrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False, backbone_name='resnext50_32x4d')
24
+ model.eval()
25
+ """
26
+ We have set this to evaluation mode,
27
+ because we have loaded a pretrained model
28
+ so we must deactivate dropout layers and other
29
+ training-specific behaviors.
30
+ """
31
+ return model
32
+
33
+ model = load_model() #model initialization
34
+
35
+
36
+ def transform_image(image):
37
+ transform = T.Compose([
38
+ T.Resize((256, 256)), # Resize to match model input
39
+ T.ToTensor(), # Convert to torch tensor
40
+ T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalize
41
+ ])
42
+ return transform(image).unsqueeze(0) # Add batch dimension to get [1,C,H,W] #C is channels, RGB has 3, greyscale has 1
43
+
44
+
45
+ # # Test image transformation
46
+ # image_path = "D:\multiobject.jpeg" # Replace with the path to your image
47
+ # image_tensor = transform_image(image_path)
48
+
49
+ def run_inference(model,image_tensor):
50
+ with torch.no_grad():
51
+ outputs = model(image_tensor)
52
+ return outputs
53
+
54
+ def extract_object(image, mask):
55
+ img_np = np.array(image)
56
+
57
+ # Resize mask to match image dimensions
58
+ mask_resized = cv2.resize(mask, (img_np.shape[1], img_np.shape[0]), interpolation=cv2.INTER_NEAREST)
59
+
60
+ # Create an empty image with the same dimensions as the original image
61
+ object_img = np.zeros_like(img_np)
62
+
63
+ # Apply the mask to the image
64
+ for c in range(3): # Assuming image has 3 channels (RGB)
65
+ object_img[:, :, c] = img_np[:, :, c] * mask_resized
66
+
67
+ return Image.fromarray(object_img)
68
+
69
+ # def extract_object(image, mask):
70
+ # object_img = Image.fromarray((np.array(image) * mask[:, :, None]).astype(np.uint8))
71
+ # return object_img
72
+
73
+ # Save the input image
74
+ def save_input_image(image, master_id):
75
+ input_image_path = os.path.join(input_images_dir, f'{master_id}.png')
76
+ image.save(input_image_path)
77
+ return input_image_path
78
+
79
+ # Save the extracted objects and their metadata
80
+ def save_objects_and_metadata(extracted_objects, master_id):
81
+ object_metadata = []
82
+
83
+ for i, obj_img in enumerate(extracted_objects):
84
+ object_id = str(uuid.uuid4())
85
+ object_image_path = os.path.join(segmented_objects_dir, f'{object_id}.png')
86
+ obj_img.save(object_image_path)
87
+
88
+ metadata = {
89
+ 'object_id': object_id,
90
+ 'master_id': master_id,
91
+ 'object_image_path': object_image_path
92
+ }
93
+ object_metadata.append(metadata)
94
+
95
+ metadata_file = os.path.join(segmented_objects_dir, f'{master_id}_metadata.json')
96
+ with open(metadata_file, 'w') as f:
97
+ json.dump(object_metadata, f, indent=4)
98
+
99
+ return object_metadata