Tony Neel commited on
Commit ·
bf7dfcc
1
Parent(s): 796780d
add test endpoints and make handler work
Browse files- README.md +12 -0
- __pycache__/handler.cpython-310.pyc +0 -0
- handler.py +58 -59
- test_flask.py +44 -0
- test_local.py +48 -0
README.md
CHANGED
|
@@ -8,8 +8,20 @@ Repository for SAM 2: Segment Anything in Images and Videos, a foundation model
|
|
| 8 |
|
| 9 |
The official code is publicly release in this [repo](https://github.com/facebookresearch/segment-anything-2/).
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
## Usage
|
| 12 |
|
|
|
|
|
|
|
| 13 |
For image prediction:
|
| 14 |
|
| 15 |
```python
|
|
|
|
| 8 |
|
| 9 |
The official code is publicly release in this [repo](https://github.com/facebookresearch/segment-anything-2/).
|
| 10 |
|
| 11 |
+
# SAM2 Small Inference Endpoint
|
| 12 |
+
|
| 13 |
+
This repository contains the code for running SAM2 (Segment Anything Model 2) small model as a Hugging Face inference endpoint.
|
| 14 |
+
|
| 15 |
+
## Model Details
|
| 16 |
+
|
| 17 |
+
- Model: SAM2 Hiera Small
|
| 18 |
+
- Source: facebook/sam2-hiera-small
|
| 19 |
+
- Type: Segmentation model
|
| 20 |
+
|
| 21 |
## Usage
|
| 22 |
|
| 23 |
+
Send a POST request with an image to get segmentation masks:
|
| 24 |
+
|
| 25 |
For image prediction:
|
| 26 |
|
| 27 |
```python
|
__pycache__/handler.cpython-310.pyc
ADDED
|
Binary file (2.2 kB). View file
|
|
|
handler.py
CHANGED
|
@@ -5,11 +5,34 @@ import numpy as np
|
|
| 5 |
from PIL import Image
|
| 6 |
import io
|
| 7 |
import base64
|
|
|
|
| 8 |
|
| 9 |
-
class EndpointHandler:
|
| 10 |
-
def __init__(self
|
| 11 |
-
"""Initialize the handler with
|
|
|
|
| 12 |
self.predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-small")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
def _load_image(self, image_data: Union[str, bytes]) -> Image.Image:
|
| 15 |
"""Load image from binary or base64 data"""
|
|
@@ -24,67 +47,43 @@ class EndpointHandler:
|
|
| 24 |
except Exception as e:
|
| 25 |
raise ValueError(f"Failed to load image: {str(e)}")
|
| 26 |
|
| 27 |
-
def __call__(self,
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
Dictionary containing masks and scores
|
| 37 |
-
"""
|
| 38 |
-
try:
|
| 39 |
-
# Handle different input formats
|
| 40 |
-
if isinstance(data, dict):
|
| 41 |
-
image_data = data.get("inputs", data)
|
| 42 |
-
# Get optional point prompts
|
| 43 |
-
point_coords = data.get("point_coords", None)
|
| 44 |
-
point_labels = data.get("point_labels", None)
|
| 45 |
-
else:
|
| 46 |
-
image_data = data
|
| 47 |
-
point_coords = None
|
| 48 |
-
point_labels = None
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
| 53 |
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
point_coords = np.array(point_coords)
|
| 61 |
-
point_labels = np.array(point_labels)
|
| 62 |
-
masks, scores, logits = self.predictor.predict(
|
| 63 |
point_coords=point_coords,
|
| 64 |
point_labels=point_labels
|
| 65 |
)
|
| 66 |
-
else:
|
| 67 |
-
# Default automatic mask generation
|
| 68 |
-
masks, scores, logits = self.predictor.predict()
|
| 69 |
-
|
| 70 |
-
# Convert outputs to JSON-serializable format
|
| 71 |
-
if masks is not None:
|
| 72 |
-
masks = [mask.tolist() for mask in masks]
|
| 73 |
-
scores = scores.tolist() if scores is not None else None
|
| 74 |
-
|
| 75 |
-
return {
|
| 76 |
-
"masks": masks,
|
| 77 |
-
"scores": scores,
|
| 78 |
-
"status": "success"
|
| 79 |
-
}
|
| 80 |
else:
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
|
|
|
| 85 |
|
| 86 |
-
|
|
|
|
| 87 |
return {
|
| 88 |
-
"
|
| 89 |
-
"
|
| 90 |
-
|
|
|
|
|
|
|
|
|
| 5 |
from PIL import Image
|
| 6 |
import io
|
| 7 |
import base64
|
| 8 |
+
from huggingface_hub import InferenceEndpoint
|
| 9 |
|
| 10 |
+
class EndpointHandler(InferenceEndpoint):
|
| 11 |
+
def __init__(self):
|
| 12 |
+
"""Initialize the handler with mock predictor for local testing"""
|
| 13 |
+
# Comment out real model for local testing
|
| 14 |
self.predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-small")
|
| 15 |
+
|
| 16 |
+
# Mock predictor for local testing
|
| 17 |
+
# class MockPredictor:
|
| 18 |
+
# def set_image(self, image):
|
| 19 |
+
# print(f"Mock: set_image called with shape {image.shape}")
|
| 20 |
+
|
| 21 |
+
# def predict(self, point_coords=None, point_labels=None):
|
| 22 |
+
# print("Mock: predict called")
|
| 23 |
+
# if point_coords is not None:
|
| 24 |
+
# print(f"Mock: with point coords {point_coords}")
|
| 25 |
+
# print(f"Mock: with point labels {point_labels}")
|
| 26 |
+
# # Return mock mask focused around the point
|
| 27 |
+
# mock_masks = [np.zeros((100, 100), dtype=bool) for _ in range(1)]
|
| 28 |
+
# mock_scores = np.array([0.95]) # Higher confidence for point prompt
|
| 29 |
+
# else:
|
| 30 |
+
# # Return multiple mock masks for automatic mode
|
| 31 |
+
# mock_masks = [np.zeros((100, 100), dtype=bool) for _ in range(3)]
|
| 32 |
+
# mock_scores = np.array([0.9, 0.8, 0.7])
|
| 33 |
+
# return mock_masks, mock_scores, None
|
| 34 |
+
|
| 35 |
+
self.predictor = MockPredictor()
|
| 36 |
|
| 37 |
def _load_image(self, image_data: Union[str, bytes]) -> Image.Image:
|
| 38 |
"""Load image from binary or base64 data"""
|
|
|
|
| 47 |
except Exception as e:
|
| 48 |
raise ValueError(f"Failed to load image: {str(e)}")
|
| 49 |
|
| 50 |
+
def __call__(self, image_bytes):
|
| 51 |
+
# Get point prompts if provided in request
|
| 52 |
+
if isinstance(image_bytes, dict):
|
| 53 |
+
point_coords = image_bytes.get('point_coords')
|
| 54 |
+
point_labels = image_bytes.get('point_labels')
|
| 55 |
+
image_bytes = image_bytes['image']
|
| 56 |
+
else:
|
| 57 |
+
point_coords = None
|
| 58 |
+
point_labels = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
+
# Convert bytes to image
|
| 61 |
+
image = Image.open(io.BytesIO(image_bytes))
|
| 62 |
+
if image.mode != 'RGB':
|
| 63 |
+
image = image.convert('RGB')
|
| 64 |
+
image_array = np.array(image)
|
| 65 |
|
| 66 |
+
# Run inference (will use mock predictor locally)
|
| 67 |
+
with torch.inference_mode():
|
| 68 |
+
if torch.cuda.is_available():
|
| 69 |
+
with torch.autocast("cuda", dtype=torch.bfloat16):
|
| 70 |
+
self.predictor.set_image(image_array)
|
| 71 |
+
masks, scores, _ = self.predictor.predict(
|
|
|
|
|
|
|
|
|
|
| 72 |
point_coords=point_coords,
|
| 73 |
point_labels=point_labels
|
| 74 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
else:
|
| 76 |
+
self.predictor.set_image(image_array)
|
| 77 |
+
masks, scores, _ = self.predictor.predict(
|
| 78 |
+
point_coords=point_coords,
|
| 79 |
+
point_labels=point_labels
|
| 80 |
+
)
|
| 81 |
|
| 82 |
+
# Format output
|
| 83 |
+
if masks is not None:
|
| 84 |
return {
|
| 85 |
+
"masks": [mask.tolist() for mask in masks],
|
| 86 |
+
"scores": scores.tolist() if scores is not None else None,
|
| 87 |
+
"status": "success"
|
| 88 |
+
}
|
| 89 |
+
return {"error": "No masks generated", "status": "error"}
|
test_flask.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from flask import Flask, request, jsonify
|
| 2 |
+
from handler import EndpointHandler
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
app = Flask(__name__)
|
| 6 |
+
|
| 7 |
+
# Initialize the handler
|
| 8 |
+
handler = EndpointHandler()
|
| 9 |
+
|
| 10 |
+
@app.route('/predict', methods=['POST'])
|
| 11 |
+
def predict():
|
| 12 |
+
if 'file' not in request.files:
|
| 13 |
+
return jsonify({'error': 'No file provided'}), 400
|
| 14 |
+
|
| 15 |
+
file = request.files['file']
|
| 16 |
+
if file.filename == '':
|
| 17 |
+
return jsonify({'error': 'No file selected'}), 400
|
| 18 |
+
|
| 19 |
+
# Read the file bytes
|
| 20 |
+
image_bytes = file.read()
|
| 21 |
+
|
| 22 |
+
# Get point prompts if provided
|
| 23 |
+
point_coords = request.form.get('point_coords')
|
| 24 |
+
point_labels = request.form.get('point_labels')
|
| 25 |
+
|
| 26 |
+
# Process with handler
|
| 27 |
+
try:
|
| 28 |
+
if point_coords and point_labels:
|
| 29 |
+
# Convert string inputs to lists
|
| 30 |
+
point_coords = eval(point_coords) # e.g. "[[500, 375]]"
|
| 31 |
+
point_labels = eval(point_labels) # e.g. "[1]"
|
| 32 |
+
result = handler({
|
| 33 |
+
'image': image_bytes,
|
| 34 |
+
'point_coords': point_coords,
|
| 35 |
+
'point_labels': point_labels
|
| 36 |
+
})
|
| 37 |
+
else:
|
| 38 |
+
result = handler(image_bytes)
|
| 39 |
+
return jsonify(result)
|
| 40 |
+
except Exception as e:
|
| 41 |
+
return jsonify({'error': str(e)}), 500
|
| 42 |
+
|
| 43 |
+
if __name__ == '__main__':
|
| 44 |
+
app.run(debug=True, port=5000)
|
test_local.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
def test_endpoint(image_path, point_coords=None, point_labels=None):
|
| 5 |
+
# URL for local Flask server
|
| 6 |
+
url = "http://localhost:5000/predict"
|
| 7 |
+
|
| 8 |
+
# Open image file
|
| 9 |
+
with open(image_path, 'rb') as f:
|
| 10 |
+
files = {'file': f}
|
| 11 |
+
data = {}
|
| 12 |
+
|
| 13 |
+
# Add point prompts if provided
|
| 14 |
+
if point_coords is not None and point_labels is not None:
|
| 15 |
+
data['point_coords'] = str(point_coords)
|
| 16 |
+
data['point_labels'] = str(point_labels)
|
| 17 |
+
|
| 18 |
+
# Make request
|
| 19 |
+
response = requests.post(url, files=files, data=data)
|
| 20 |
+
|
| 21 |
+
print(f"Status Code: {response.status_code}")
|
| 22 |
+
if response.status_code == 200:
|
| 23 |
+
result = response.json()
|
| 24 |
+
print("\nSuccess!")
|
| 25 |
+
print(f"Number of masks: {len(result['masks']) if 'masks' in result else 0}")
|
| 26 |
+
print(f"Scores: {result['scores'] if 'scores' in result else None}")
|
| 27 |
+
else:
|
| 28 |
+
print(f"Error: {response.text}")
|
| 29 |
+
|
| 30 |
+
if __name__ == "__main__":
|
| 31 |
+
# Test with your image
|
| 32 |
+
image_path = Path("images/20250121_gauge_0001.jpg")
|
| 33 |
+
if not image_path.exists():
|
| 34 |
+
print(f"Error: Image not found at {image_path}")
|
| 35 |
+
exit(1)
|
| 36 |
+
|
| 37 |
+
# Test without points
|
| 38 |
+
print("\nTesting without points...")
|
| 39 |
+
print(f"Testing with image: {image_path}")
|
| 40 |
+
test_endpoint(image_path)
|
| 41 |
+
|
| 42 |
+
# Test with points
|
| 43 |
+
print("\nTesting with points...")
|
| 44 |
+
test_endpoint(
|
| 45 |
+
image_path,
|
| 46 |
+
point_coords=[[500, 375]], # Example coordinates
|
| 47 |
+
point_labels=[1] # 1 for foreground
|
| 48 |
+
)
|