CodeJackR committed on
Commit ·
69e7d30
1
Parent(s): b0c81ad
Update to new way of handling model
Browse files- handler.py +156 -44
handler.py
CHANGED
|
@@ -5,7 +5,7 @@ import base64
|
|
| 5 |
import numpy as np
|
| 6 |
from PIL import Image
|
| 7 |
import torch
|
| 8 |
-
from transformers import SamModel,
|
| 9 |
from typing import Dict, List, Any
|
| 10 |
import torch.nn.functional as F
|
| 11 |
|
|
@@ -20,20 +20,84 @@ class EndpointHandler():
|
|
| 20 |
"""
|
| 21 |
try:
|
| 22 |
# Load the model and processor from the local path
|
| 23 |
-
self.model = SamModel.from_pretrained(path).to(device)
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
except Exception as e:
|
| 26 |
# Fallback to loading from a known SAM model if local loading fails
|
| 27 |
-
print("Failed to load from local path: {}"
|
| 28 |
print("Attempting to load from facebook/sam-vit-base")
|
| 29 |
-
self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
|
| 30 |
-
self.processor =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
def __call__(self, data):
|
| 33 |
"""
|
| 34 |
Called on every HTTP request.
|
| 35 |
Handles both base64-encoded images and PIL images.
|
| 36 |
-
Returns a
|
| 37 |
"""
|
| 38 |
# 1. Parse and decode the input image
|
| 39 |
inputs = data.pop("inputs", None)
|
|
@@ -51,69 +115,117 @@ class EndpointHandler():
|
|
| 51 |
else:
|
| 52 |
raise TypeError("Unsupported input type. Expected a PIL Image or a base64 encoded string.")
|
| 53 |
|
| 54 |
-
# 2.
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
#
|
|
|
|
| 58 |
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
-
#
|
| 62 |
-
|
| 63 |
-
|
| 64 |
|
| 65 |
-
# 4. Process and select the best mask
|
| 66 |
try:
|
|
|
|
|
|
|
|
|
|
| 67 |
# Get predicted masks and scores
|
| 68 |
-
predicted_masks = outputs.pred_masks.cpu()
|
| 69 |
-
iou_scores = outputs.iou_scores.cpu()[
|
| 70 |
|
| 71 |
-
#
|
| 72 |
-
|
| 73 |
-
predicted_masks = predicted_masks.squeeze(1)
|
| 74 |
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
-
#
|
| 80 |
-
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
| 82 |
except Exception as e:
|
| 83 |
-
print("Error processing masks: {}"
|
| 84 |
-
# Fallback: create a simple mask
|
| 85 |
-
height, width = img.size[1], img.size[0]
|
| 86 |
mask_binary = np.zeros((height, width), dtype=np.uint8)
|
| 87 |
center_x, center_y = width // 2, height // 2
|
| 88 |
size = min(width, height) // 8
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
-
#
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
def main():
|
| 96 |
# This main function shows how a client would call the endpoint locally.
|
| 97 |
input_path = "/Users/rp7/Downloads/test.jpeg"
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
# 1. Prepare the payload with a base64-encoded image string
|
| 101 |
with open(input_path, "rb") as f:
|
| 102 |
img_bytes = f.read()
|
| 103 |
img_b64 = base64.b64encode(img_bytes).decode("utf-8")
|
| 104 |
-
payload = {"inputs": "data:image/jpeg;base64,{}"
|
| 105 |
|
| 106 |
-
# 2. Instantiate handler and get the
|
| 107 |
handler = EndpointHandler(path=".")
|
| 108 |
-
|
| 109 |
|
| 110 |
-
# 3.
|
| 111 |
-
if
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
| 115 |
else:
|
| 116 |
-
print("Failed to get
|
| 117 |
|
| 118 |
if __name__ == "__main__":
|
| 119 |
main()
|
|
|
|
| 5 |
import numpy as np
|
| 6 |
from PIL import Image
|
| 7 |
import torch
|
| 8 |
+
from transformers import SamModel, SamProcessor
|
| 9 |
from typing import Dict, List, Any
|
| 10 |
import torch.nn.functional as F
|
| 11 |
|
|
|
|
| 20 |
"""
|
| 21 |
try:
|
| 22 |
# Load the model and processor from the local path
|
| 23 |
+
self.model = SamModel.from_pretrained(path).to(device).eval()
|
| 24 |
+
# Load processor with do_resize=False to avoid resizing
|
| 25 |
+
self.processor = SamProcessor.from_pretrained(path)
|
| 26 |
+
# Override the processor's image processor to disable resizing
|
| 27 |
+
self.processor.image_processor.do_resize = False
|
| 28 |
+
self.processor.image_processor.do_rescale = True
|
| 29 |
+
self.processor.image_processor.do_normalize = True
|
| 30 |
+
|
| 31 |
except Exception as e:
|
| 32 |
# Fallback to loading from a known SAM model if local loading fails
|
| 33 |
+
print(f"Failed to load from local path: {e}")
|
| 34 |
print("Attempting to load from facebook/sam-vit-base")
|
| 35 |
+
self.model = SamModel.from_pretrained("facebook/sam-vit-base").to(device).eval()
|
| 36 |
+
self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
|
| 37 |
+
# Override the processor's image processor to disable resizing
|
| 38 |
+
self.processor.image_processor.do_resize = False
|
| 39 |
+
self.processor.image_processor.do_rescale = True
|
| 40 |
+
self.processor.image_processor.do_normalize = True
|
| 41 |
+
|
| 42 |
+
def generate_grid_points(self, width, height, points_per_side=32):
    """Build an evenly spaced grid of foreground prompt points over an image.

    Args:
        width: Image width in pixels; x coordinates span ``0 .. width - 1``.
        height: Image height in pixels; y coordinates span ``0 .. height - 1``.
        points_per_side: Number of grid positions along each axis, giving
            ``points_per_side ** 2`` points in total.

    Returns:
        A ``([points], [labels])`` tuple. ``points`` is a row-major list of
        ``[x, y]`` pairs (NumPy integers, truncated by ``dtype=int``) and
        ``labels`` is a same-length list of ``1`` (foreground).
    """
    # Both ends of each axis are included, so border pixels are sampled.
    xs = np.linspace(0, width - 1, points_per_side, dtype=int)
    ys = np.linspace(0, height - 1, points_per_side, dtype=int)

    # Row-major order: every x for the first y, then the next y, and so on.
    grid = [[col, row] for row in ys for col in xs]
    foreground = [1] * len(grid)

    # NOTE(review): the single outer list groups every point together;
    # confirm against the SamProcessor docs whether this produces one prompt
    # containing all points rather than one prompt per grid point.
    return [grid], [foreground]
|
| 57 |
+
|
| 58 |
+
def filter_masks(self, masks, iou_scores, score_threshold=0.88, stability_score_threshold=0.95):
    """Keep only the masks that pass both quality thresholds.

    A mask is kept when its IoU score is strictly greater than
    ``score_threshold`` and the value returned by
    ``self.calculate_stability_score`` for the thresholded mask
    (``mask > 0.0``) is strictly greater than ``stability_score_threshold``.
    No de-duplication of overlapping masks is performed here.

    Args:
        masks: Iterable of mask tensors (raw values; thresholded at ``0.0``).
        iou_scores: Iterable of per-mask scores, paired with ``masks`` by
            position. Each may be a zero-dimensional tensor or a plain number.
        score_threshold: Minimum (exclusive) IoU score.
        stability_score_threshold: Minimum (exclusive) stability score.

    Returns:
        ``(filtered_masks, filtered_scores)``: the surviving masks, unchanged,
        and their scores as Python floats, in input order.
    """
    filtered_masks = []
    filtered_scores = []

    for mask, score in zip(masks, iou_scores):
        # Cheap score test first so the stability computation is skipped
        # for masks that would be rejected anyway.
        if not score > score_threshold:
            continue

        stability_score = self.calculate_stability_score(mask > 0.0)
        if stability_score > stability_score_threshold:
            filtered_masks.append(mask)
            # float() accepts both 0-dim tensors and plain numbers, whereas
            # .item() would raise AttributeError on a Python float.
            filtered_scores.append(float(score))

    return filtered_masks, filtered_scores
|
| 74 |
+
|
| 75 |
+
def calculate_stability_score(self, mask):
    """Return how much of its bounding box a 2-D mask fills.

    Despite the name, this is a bounding-box fill ratio: the number of
    non-zero pixels divided by the area of the tightest axis-aligned box
    that contains them. A solid rectangle scores ``1.0``; sparse or
    irregular shapes score lower.

    Args:
        mask: 2-D tensor; any non-zero element counts as part of the mask.

    Returns:
        The fill ratio as a Python float in ``(0, 1]``, or ``0.0`` when the
        mask has no non-zero pixels.
    """
    # Use the same non-zero pixel set for both the area and the bounding
    # box so the two always agree, even if the mask is not strictly 0/1.
    coords = torch.nonzero(mask)
    pixel_count = coords.shape[0]
    if pixel_count == 0:
        return 0.0

    min_y, min_x = coords.min(dim=0).values.tolist()
    max_y, max_x = coords.max(dim=0).values.tolist()

    # The +1 makes the box inclusive, so bbox_area >= 1 whenever a pixel exists.
    bbox_area = (max_y - min_y + 1) * (max_x - min_x + 1)
    return pixel_count / bbox_area
|
| 95 |
|
| 96 |
def __call__(self, data):
|
| 97 |
"""
|
| 98 |
Called on every HTTP request.
|
| 99 |
Handles both base64-encoded images and PIL images.
|
| 100 |
+
Returns a list of segment masks.
|
| 101 |
"""
|
| 102 |
# 1. Parse and decode the input image
|
| 103 |
inputs = data.pop("inputs", None)
|
|
|
|
| 115 |
else:
|
| 116 |
raise TypeError("Unsupported input type. Expected a PIL Image or a base64 encoded string.")
|
| 117 |
|
| 118 |
+
# 2. Get image dimensions
|
| 119 |
+
width, height = img.size
|
| 120 |
+
|
| 121 |
+
# 3. Generate grid points for comprehensive segmentation
|
| 122 |
+
input_points, input_labels = self.generate_grid_points(width, height, points_per_side=16)
|
| 123 |
|
| 124 |
+
# 4. Process the image and points
|
| 125 |
+
inputs = self.processor(
|
| 126 |
+
img,
|
| 127 |
+
input_points=input_points,
|
| 128 |
+
input_labels=input_labels,
|
| 129 |
+
return_tensors="pt"
|
| 130 |
+
).to(device)
|
| 131 |
|
| 132 |
+
# 5. Generate masks
|
| 133 |
+
all_masks = []
|
| 134 |
+
all_scores = []
|
| 135 |
|
|
|
|
| 136 |
try:
|
| 137 |
+
with torch.no_grad():
|
| 138 |
+
outputs = self.model(**inputs)
|
| 139 |
+
|
| 140 |
# Get predicted masks and scores
|
| 141 |
+
predicted_masks = outputs.pred_masks.cpu() # Shape: [batch, num_queries, num_masks_per_query, H, W]
|
| 142 |
+
iou_scores = outputs.iou_scores.cpu() # Shape: [batch, num_queries, num_masks_per_query]
|
| 143 |
|
| 144 |
+
# Process masks from all queries
|
| 145 |
+
batch_size, num_queries, num_masks_per_query = predicted_masks.shape[:3]
|
|
|
|
| 146 |
|
| 147 |
+
for query_idx in range(num_queries):
|
| 148 |
+
query_masks = predicted_masks[0, query_idx] # [num_masks_per_query, H, W]
|
| 149 |
+
query_scores = iou_scores[0, query_idx] # [num_masks_per_query]
|
| 150 |
+
|
| 151 |
+
# Select best mask for this query
|
| 152 |
+
best_mask_idx = torch.argmax(query_scores)
|
| 153 |
+
if query_scores[best_mask_idx] > 0.5: # Only keep high-quality masks
|
| 154 |
+
best_mask = query_masks[best_mask_idx]
|
| 155 |
+
all_masks.append(best_mask)
|
| 156 |
+
all_scores.append(query_scores[best_mask_idx])
|
| 157 |
|
| 158 |
+
# Filter and deduplicate masks
|
| 159 |
+
if all_masks:
|
| 160 |
+
filtered_masks, filtered_scores = self.filter_masks(all_masks, all_scores)
|
| 161 |
+
else:
|
| 162 |
+
filtered_masks, filtered_scores = [], []
|
| 163 |
+
|
| 164 |
except Exception as e:
|
| 165 |
+
print(f"Error processing masks: {e}")
|
| 166 |
+
# Fallback: create a simple center mask
|
|
|
|
| 167 |
mask_binary = np.zeros((height, width), dtype=np.uint8)
|
| 168 |
center_x, center_y = width // 2, height // 2
|
| 169 |
size = min(width, height) // 8
|
| 170 |
+
y_start, y_end = max(0, center_y-size), min(height, center_y+size)
|
| 171 |
+
x_start, x_end = max(0, center_x-size), min(width, center_x+size)
|
| 172 |
+
mask_binary[y_start:y_end, x_start:x_end] = 255
|
| 173 |
+
|
| 174 |
+
output_img = Image.fromarray(mask_binary)
|
| 175 |
+
return [{'score': 0.5, 'label': 'fallback_segment', 'mask': output_img}]
|
| 176 |
|
| 177 |
+
# 6. Convert masks to PIL Images and prepare results
|
| 178 |
+
results = []
|
| 179 |
+
for i, (mask, score) in enumerate(zip(filtered_masks, filtered_scores)):
|
| 180 |
+
# Convert to binary mask
|
| 181 |
+
mask_binary = (mask > 0.0).numpy().astype(np.uint8) * 255
|
| 182 |
+
|
| 183 |
+
# Create PIL Image
|
| 184 |
+
output_img = Image.fromarray(mask_binary)
|
| 185 |
+
|
| 186 |
+
results.append({
|
| 187 |
+
'score': float(score),
|
| 188 |
+
'label': f'segment_{i}',
|
| 189 |
+
'mask': output_img
|
| 190 |
+
})
|
| 191 |
+
|
| 192 |
+
# If no segments found, return a fallback
|
| 193 |
+
if not results:
|
| 194 |
+
mask_binary = np.zeros((height, width), dtype=np.uint8)
|
| 195 |
+
output_img = Image.fromarray(mask_binary)
|
| 196 |
+
results.append({'score': 0.0, 'label': 'no_segments', 'mask': output_img})
|
| 197 |
+
|
| 198 |
+
return results
|
| 199 |
|
| 200 |
def main():
    """Run the handler locally the way a client would call the endpoint.

    Reads a hard-coded JPEG, sends it to ``EndpointHandler`` as a base64
    data URI, and writes every returned mask as a PNG into ``output_masks``.
    """
    import os

    input_path = "/Users/rp7/Downloads/test.jpeg"
    output_dir = "output_masks"

    # Make sure the destination exists before any mask is written.
    os.makedirs(output_dir, exist_ok=True)

    # 1. Build the request payload: the image as a base64 data URI.
    with open(input_path, "rb") as image_file:
        img_b64 = base64.b64encode(image_file.read()).decode("utf-8")
    payload = {"inputs": f"data:image/jpeg;base64,{img_b64}"}

    # 2. Create the handler from the current directory and run it once.
    handler = EndpointHandler(path=".")
    results = handler(payload)

    # 3. Bail out early when the handler returned nothing usable.
    if not (results and isinstance(results, list)):
        print("Failed to get valid masks from the handler.")
        return

    # 4. Save each mask, encoding its index and score in the file name.
    print(f"Found {len(results)} segments")
    for i, result in enumerate(results):
        if 'mask' not in result:
            continue
        output_path = os.path.join(output_dir, f"segment_{i}_score_{result['score']:.3f}.png")
        result['mask'].save(output_path)
        print(f"Saved {result['label']} (score: {result['score']:.3f}) to {output_path}")
|
| 229 |
|
| 230 |
if __name__ == "__main__":
|
| 231 |
main()
|