Mask Generation
sam2
Tony Neel committed on
Commit
796780d
·
1 Parent(s): 8ba5658
Files changed (2) hide show
  1. handler.py +81 -38
  2. requirements.txt +10 -5
handler.py CHANGED
@@ -1,47 +1,90 @@
1
- from typing import Dict, List, Any
2
- from transformers import SamModel, SamProcessor
3
  import torch
 
 
 
 
4
 
5
class EndpointHandler:
    def __init__(self, path=""):
        """Load the SAM model and processor from *path*, on GPU when available."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = SamModel.from_pretrained(path).to(self.device)
        self.processor = SamProcessor.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handle image segmentation requests

        Args:
            data: Dictionary containing:
                inputs: Image to segment.
                    NOTE(review): raw bytes were documented here, but
                    SamProcessor expects a PIL image / array — confirm the
                    caller decodes the payload first.
        Returns:
            List of dictionaries containing segmentation masks, one per
            predicted mask, with keys "mask" (nested list) and "score"
            (float IoU estimate).
        """
        # Get the image from the request
        raw_image = data.pop("inputs", data)

        # Process the image
        inputs = self.processor(raw_image, return_tensors="pt").to(self.device)

        # BUG FIX: transformers' SamModel has no `generate()` method — mask
        # prediction is the regular forward pass. Run it without gradients.
        with torch.no_grad():
            outputs = self.model(**inputs, multimask_output=True)

        # Move predictions to host memory for serialization
        masks = outputs.pred_masks.squeeze().cpu().numpy()
        scores = outputs.iou_scores.squeeze().cpu().numpy()

        # Format response: numpy arrays -> JSON-serializable lists / floats
        return [
            {"mask": mask.tolist(), "score": float(score)}
            for mask, score in zip(masks, scores)
        ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any, Union
2
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
3
  import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ import io
7
+ import base64
8
 
9
class EndpointHandler:
    def __init__(self, path=""):
        """Initialize the handler with a SAM2 image predictor.

        Args:
            path: Model repo or local directory to load. BUG FIX: the
                original ignored this parameter and always loaded the
                hard-coded checkpoint; now *path* is honored, falling back
                to facebook/sam2-hiera-small when empty.
        """
        model_id = path or "facebook/sam2-hiera-small"
        self.predictor = SAM2ImagePredictor.from_pretrained(model_id)
        # Autocast to bfloat16 is only valid/meaningful with CUDA present.
        self._use_cuda = torch.cuda.is_available()

    def _load_image(self, image_data: Union[str, bytes]) -> Image.Image:
        """Decode binary or base64-encoded image data into an RGB PIL image.

        Args:
            image_data: Raw image bytes, or a base64-encoded string of them.
        Returns:
            A PIL image in RGB mode.
        Raises:
            ValueError: If the payload cannot be decoded as an image.
        """
        try:
            # Strings are assumed to be base64-encoded image bytes
            if isinstance(image_data, str):
                image_data = base64.b64decode(image_data)

            image = Image.open(io.BytesIO(image_data))
            # BUG FIX: normalize mode so RGBA / palette / grayscale uploads
            # always yield an HxWx3 array for the predictor.
            return image.convert("RGB")
        except Exception as e:
            raise ValueError(f"Failed to load image: {str(e)}")

    def __call__(self, data: Union[Dict[str, Any], bytes]) -> Dict[str, Any]:
        """
        Handle incoming request data
        Args:
            data: Either raw bytes or dictionary containing:
                - image data (raw binary or base64) under "inputs"
                - optional point_coords: List of [x,y] coordinates for clicks
                - optional point_labels: List of 1 (foreground) or 0 (background)
        Returns:
            Dictionary containing "masks", "scores" and "status" on success,
            or "error" and "status" on failure.
        """
        try:
            # Accept both a JSON payload and a bare binary body
            if isinstance(data, dict):
                image_data = data.get("inputs", data)
                point_coords = data.get("point_coords", None)
                point_labels = data.get("point_labels", None)
            else:
                image_data = data
                point_coords = None
                point_labels = None

            # Load and convert image
            image_array = np.array(self._load_image(image_data))

            # BUG FIX: the original entered torch.autocast("cuda", ...)
            # unconditionally, which is invalid on CPU-only hosts; gate it
            # on CUDA availability via `enabled=`.
            with torch.inference_mode(), torch.autocast(
                "cuda", dtype=torch.bfloat16, enabled=self._use_cuda
            ):
                self.predictor.set_image(image_array)

                if point_coords is not None and point_labels is not None:
                    # Point-prompted prediction
                    masks, scores, _ = self.predictor.predict(
                        point_coords=np.array(point_coords),
                        point_labels=np.array(point_labels),
                    )
                else:
                    # NOTE(review): prompt-free predict() relies on the
                    # predictor's defaults — confirm SAM2ImagePredictor
                    # supports calling predict() with no prompts.
                    masks, scores, _ = self.predictor.predict()

            if masks is None:
                return {
                    "error": "No masks generated",
                    "status": "error"
                }

            # Convert outputs to JSON-serializable format
            return {
                "masks": [mask.tolist() for mask in masks],
                "scores": scores.tolist() if scores is not None else None,
                "status": "success"
            }

        except Exception as e:
            # Endpoint boundary: report failures as a structured payload
            # rather than letting the exception escape the handler.
            return {
                "error": str(e),
                "status": "error"
            }
requirements.txt CHANGED
@@ -1,5 +1,10 @@
1
- sam2
2
- transformers
3
- torch
4
- pillow
5
- numpy
 
 
 
 
 
 
1
+ sam2>=0.1.0
2
+ torch>=2.0.0
3
+ numpy>=1.24.0
4
+ Pillow>=10.0.0
5
+ transformers>=4.30.0
6
+ accelerate>=0.20.0
7
+ timm>=0.9.0
8
+ opencv-python>=4.8.0
9
+ scipy>=1.10.0
10
+ scikit-image>=0.21.0