---
license: other
extra_gated_fields:
  First Name: text
  Last Name: text
  Date of birth: date_picker
  Country: country
  Affiliation: text
  Job title:
    type: select
    options:
      - Student
      - Research Graduate
      - AI researcher
      - AI developer/engineer
      - Reporter
      - Other
  geo: ip_location
  By clicking Submit below I accept the terms of the license and acknowledge that the information I provide will be collected stored processed and shared in accordance with the Meta Privacy Policy: checkbox
extra_gated_description: >-
  The information you provide will be collected, stored, processed and shared in
  accordance with the [Meta Privacy
  Policy](https://www.facebook.com/privacy/policy/).
extra_gated_button_content: Submit
language:
- en
pipeline_tag: mask-generation
library_name: transformers
tags:
- sam3
---

SAM 3 is a unified foundation model for promptable segmentation in images and videos. It can detect, segment, and track objects using text or visual prompts such as points, boxes, and masks. Compared to its predecessor [SAM 2](https://github.com/facebookresearch/sam2), SAM 3 introduces the ability to exhaustively segment all instances of an open-vocabulary concept specified by a short text phrase or exemplars. Unlike prior work, SAM 3 can handle a vastly larger set of open-vocabulary prompts. It achieves 75-80% of human performance on our new [SA-CO benchmark](https://github.com/facebookresearch/sam3#sa-co-dataset), which contains 270K unique concepts, over 50 times more than existing benchmarks.

[Hugging Face 🤗 app](https://huggingface.co/spaces/akhaliq/sam3)

### Basic Usage

```python
import torch
#################################### For Image ####################################
from PIL import Image
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
# Load the model
model = build_sam3_image_model()
processor = Sam3Processor(model)
# Load an image
image = Image.open("<YOUR_IMAGE_PATH.jpg>")
inference_state = processor.set_image(image)
# Prompt the model with text
output = processor.set_text_prompt(state=inference_state, prompt="<YOUR_TEXT_PROMPT>")

# Get the masks, bounding boxes, and scores
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]

#################################### For Video ####################################

from sam3.model_builder import build_sam3_video_predictor

video_predictor = build_sam3_video_predictor()
video_path = "<YOUR_VIDEO_PATH>"  # a JPEG folder or an MP4 video file
# Start a session
response = video_predictor.handle_request(
    request=dict(
        type="start_session",
        resource_path=video_path,
    )
)
response = video_predictor.handle_request(
    request=dict(
        type="add_prompt",
        session_id=response["session_id"],
        frame_index=0,  # Arbitrary frame index
        text="<YOUR_TEXT_PROMPT>",
    )
)
output = response["outputs"]
```

The official code is publicly released in the [sam3 repo](https://github.com/facebookresearch/sam3).

## Usage with 🤗 Transformers

### SAM3 - Promptable Concept Segmentation (PCS) for Images

SAM3 performs Promptable Concept Segmentation (PCS) on images, taking text and/or image exemplars as prompts and returning segmentation masks for **all matching object instances** in the image.

#### Text-Only Prompts

```python
>>> from transformers import Sam3Processor, Sam3Model
>>> import torch
>>> from PIL import Image
>>> import requests

>>> device = "cuda" if torch.cuda.is_available() else "cpu"

>>> model = Sam3Model.from_pretrained("facebook/sam3").to(device)
>>> processor = Sam3Processor.from_pretrained("facebook/sam3")

>>> # Load image
>>> image_url = "http://images.cocodataset.org/val2017/000000077595.jpg"
>>> image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")

>>> # Segment using text prompt
>>> inputs = processor(images=image, text="ear", return_tensors="pt").to(device)

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> # Post-process results
>>> results = processor.post_process_instance_segmentation(
...     outputs,
...     threshold=0.5,
...     mask_threshold=0.5,
...     target_sizes=inputs.get("original_sizes").tolist()
... )[0]

>>> print(f"Found {len(results['masks'])} objects")
>>> # Results contain:
>>> # - masks: Binary masks resized to original image size
>>> # - boxes: Bounding boxes in absolute pixel coordinates (xyxy format)
>>> # - scores: Confidence scores
```

You can display masks using a simple helper like the following:

```python
import numpy as np
import matplotlib
from PIL import Image

def overlay_masks(image, masks):
    image = image.convert("RGBA")
    masks = 255 * masks.cpu().numpy().astype(np.uint8)

    n_masks = masks.shape[0]
    cmap = matplotlib.colormaps.get_cmap("rainbow").resampled(n_masks)
    colors = [
        tuple(int(c * 255) for c in cmap(i)[:3])
        for i in range(n_masks)
    ]

    for mask, color in zip(masks, colors):
        mask = Image.fromarray(mask)
        overlay = Image.new("RGBA", image.size, color + (0,))
        alpha = mask.point(lambda v: int(v * 0.5))
        overlay.putalpha(alpha)
        image = Image.alpha_composite(image, overlay)
    return image
```

Then you can save the resulting composite image or display it in a notebook:

```python
>>> composite = overlay_masks(image, results["masks"])
>>> composite.save("overlay.png")  # or display `composite` directly in a notebook
```

#### Single Bounding Box Prompt

Segment objects using a bounding box:

```python
>>> # Box in xyxy format: [x1, y1, x2, y2] in pixel coordinates
>>> # Example: laptop region
>>> box_xyxy = [100, 150, 500, 450]
>>> input_boxes = [[box_xyxy]]  # [batch, num_boxes, 4]
>>> input_boxes_labels = [[1]]  # 1 = positive box

>>> inputs = processor(
...     images=image,
...     input_boxes=input_boxes,
...     input_boxes_labels=input_boxes_labels,
...     return_tensors="pt"
... ).to(device)

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> # Post-process results
>>> results = processor.post_process_instance_segmentation(
...     outputs,
...     threshold=0.5,
...     mask_threshold=0.5,
...     target_sizes=inputs.get("original_sizes").tolist()
... )[0]
```

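The processors above expect boxes in xyxy pixel coordinates. If your annotations are in COCO-style xywh instead, a small conversion helper (illustrative only, not part of the SAM3 API) can bridge the gap:

```python
def xywh_to_xyxy(box):
    """Convert a COCO-style [x, y, width, height] box to [x1, y1, x2, y2]."""
    x, y, w, h = box
    return [x, y, x + w, y + h]

# Example: a 400x300 region starting at (100, 150),
# matching the laptop box used above
box_xyxy = xywh_to_xyxy([100, 150, 400, 300])  # -> [100, 150, 500, 450]
```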
#### Multiple Box Prompts (Positive and Negative)

Use multiple boxes with positive and negative labels to refine the concept:

```python
>>> # Load kitchen image
>>> kitchen_url = "http://images.cocodataset.org/val2017/000000136466.jpg"
>>> kitchen_image = Image.open(requests.get(kitchen_url, stream=True).raw).convert("RGB")

>>> # Define two positive boxes (e.g., dial and button on oven)
>>> # Boxes are in xyxy format [x1, y1, x2, y2] in pixel coordinates
>>> box1_xyxy = [59, 144, 76, 163]  # Dial box
>>> box2_xyxy = [87, 148, 104, 159]  # Button box
>>> input_boxes = [[box1_xyxy, box2_xyxy]]
>>> input_boxes_labels = [[1, 1]]  # Both positive

>>> inputs = processor(
...     images=kitchen_image,
...     input_boxes=input_boxes,
...     input_boxes_labels=input_boxes_labels,
...     return_tensors="pt"
... ).to(device)

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> # Post-process results
>>> results = processor.post_process_instance_segmentation(
...     outputs,
...     threshold=0.5,
...     mask_threshold=0.5,
...     target_sizes=inputs.get("original_sizes").tolist()
... )[0]
>>> overlay_masks(kitchen_image, results["masks"])
```

#### Combined Prompts (Text + Negative Box)

Use text prompts with negative visual prompts to refine the concept:

```python
>>> # Segment "handle" but exclude the oven handle using a negative box
>>> text = "handle"
>>> # Negative box covering oven handle area (xyxy): [40, 183, 318, 204]
>>> oven_handle_box = [40, 183, 318, 204]
>>> input_boxes = [[oven_handle_box]]

>>> inputs = processor(
...     images=kitchen_image,
...     text=text,
...     input_boxes=input_boxes,
...     input_boxes_labels=[[0]],  # 0 = negative (exclude this region)
...     return_tensors="pt"
... ).to(device)

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> # Post-process results
>>> results = processor.post_process_instance_segmentation(
...     outputs,
...     threshold=0.5,
...     mask_threshold=0.5,
...     target_sizes=inputs.get("original_sizes").tolist()
... )[0]
>>> # This will segment pot handles but exclude the oven handle
```

#### Batched Inference with Text Prompts

Process multiple images with different text prompts in a single batch:

```python
>>> cat_url = "http://images.cocodataset.org/val2017/000000077595.jpg"
>>> kitchen_url = "http://images.cocodataset.org/val2017/000000136466.jpg"
>>> images = [
...     Image.open(requests.get(cat_url, stream=True).raw).convert("RGB"),
...     Image.open(requests.get(kitchen_url, stream=True).raw).convert("RGB")
... ]

>>> text_prompts = ["ear", "dial"]

>>> inputs = processor(images=images, text=text_prompts, return_tensors="pt").to(device)

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> # Post-process results for both images
>>> results = processor.post_process_instance_segmentation(
...     outputs,
...     threshold=0.5,
...     mask_threshold=0.5,
...     target_sizes=inputs.get("original_sizes").tolist()
... )

>>> print(f"Image 1: {len(results[0]['masks'])} objects found")
>>> print(f"Image 2: {len(results[1]['masks'])} objects found")
```

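Batched post-processing returns one result dict per image, in prompt order. A tiny helper (hypothetical, for reporting only — `summarize_detections` is not part of the library) can pair each prompt with its instance count:

```python
def summarize_detections(results, prompts):
    """Map each text prompt to the number of instances found in its image."""
    return {prompt: len(result["masks"]) for prompt, result in zip(prompts, results)}

# Dummy results shaped like the post-processed output above
dummy_results = [{"masks": ["m0", "m1"]}, {"masks": ["m0"]}]
summarize_detections(dummy_results, ["ear", "dial"])  # -> {"ear": 2, "dial": 1}
```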
289
+ #### Batched Mixed Prompts
290
+
291
+ Use different prompt types for different images in the same batch:
292
+
293
+ ```python
294
+ >>> # Image 1: text prompt "laptop"
295
+ >>> # Image 2: visual prompt (dial box)
296
+ >>> box2_xyxy = [59, 144, 76, 163]
297
+
298
+ >>> inputs = processor(
299
+ ... images=images,
300
+ ... text=["laptop", None], # Only first image has text
301
+ ... input_boxes=[None, [box2_xyxy]], # Only second image has box
302
+ ... input_boxes_labels=[None, [1]], # Positive box for second image
303
+ ... return_tensors="pt"
304
+ ... ).to(device)
305
+
306
+ >>> with torch.no_grad():
307
+ ... outputs = model(**inputs)
308
+
309
+ >>> # Post-process results for both images
310
+ >>> results = processor.post_process_instance_segmentation(
311
+ ... outputs,
312
+ ... threshold=0.5,
313
+ ... mask_threshold=0.5,
314
+ ... target_sizes=inputs.get("original_sizes").tolist()
315
+ ... )
316
+ >>> # Both images processed in single forward pass
317
+ ```
318
+
#### Semantic Segmentation Output

SAM3 also provides semantic segmentation alongside instance masks:

```python
>>> inputs = processor(images=image, text="ear", return_tensors="pt").to(device)

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> # Instance segmentation masks
>>> instance_masks = torch.sigmoid(outputs.pred_masks)  # [batch, num_queries, H, W]

>>> # Semantic segmentation (single channel)
>>> semantic_seg = outputs.semantic_seg  # [batch, 1, H, W]

>>> print(f"Instance masks: {instance_masks.shape}")
>>> print(f"Semantic segmentation: {semantic_seg.shape}")
```

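If you need a hard foreground/background mask from the single-channel map, a plain threshold is enough. The sketch below is illustrative only (the `binarize` helper is not part of the API; 0.5 mirrors the `mask_threshold` used elsewhere in this card):

```python
def binarize(prob_map, threshold=0.5):
    """Turn a 2D probability map into a 0/1 mask (illustrative helper)."""
    return [[1 if p > threshold else 0 for p in row] for row in prob_map]

# Tiny 2x2 stand-in for an [H, W] probability map
binarize([[0.9, 0.2], [0.6, 0.4]])  # -> [[1, 0], [1, 0]]
```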
339
+ ### SAM3 Video - Promptable Concept Segmentation (PCS) for Videos
340
+
341
+ SAM3 Video performs Promptable Concept Segmentation (PCS) on videos, taking text as prompts and detecting and tracking **all matching object instances** across video frames.
342
+
343
+ #### Pre-loaded Video Inference
344
+
345
+ Process a video with all frames already available using text prompts:
346
+
347
+ ```python
348
+ >>> from transformers import Sam3VideoModel, Sam3VideoProcessor
349
+ >>> from accelerate import Accelerator
350
+ >>> import torch
351
+
352
+ >>> device = Accelerator().device
353
+ >>> model = Sam3VideoModel.from_pretrained("facebook/sam3").to(device, dtype=torch.bfloat16)
354
+ >>> processor = Sam3VideoProcessor.from_pretrained("facebook/sam3")
355
+
356
+ >>> # Load video frames
357
+ >>> from transformers.video_utils import load_video
358
+ >>> video_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/bedroom.mp4"
359
+ >>> video_frames, _ = load_video(video_url)
360
+
361
+ >>> # Initialize video inference session
362
+ >>> inference_session = processor.init_video_session(
363
+ ... video=video_frames,
364
+ ... inference_device=device,
365
+ ... processing_device="cpu",
366
+ ... video_storage_device="cpu",
367
+ ... dtype=torch.bfloat16,
368
+ ... )
369
+
370
+ >>> # Add text prompt to detect and track objects
371
+ >>> text = "person"
372
+ >>> inference_session = processor.add_text_prompt(
373
+ ... inference_session=inference_session,
374
+ ... text=text,
375
+ ... )
376
+
377
+ >>> # Process all frames in the video
378
+ >>> outputs_per_frame = {}
379
+ >>> for model_outputs in model.propagate_in_video_iterator(
380
+ ... inference_session=inference_session, max_frame_num_to_track=50
381
+ ... ):
382
+ ... processed_outputs = processor.postprocess_outputs(inference_session, model_outputs)
383
+ ... outputs_per_frame[model_outputs.frame_idx] = processed_outputs
384
+
385
+ >>> print(f"Processed {len(outputs_per_frame)} frames")
386
+ Processed 51 frames
387
+
388
+ >>> # Access results for a specific frame
389
+ >>> frame_0_outputs = outputs_per_frame[0]
390
+ >>> print(f"Detected {len(frame_0_outputs['object_ids'])} objects")
391
+ >>> print(f"Object IDs: {frame_0_outputs['object_ids'].tolist()}")
392
+ >>> print(f"Scores: {frame_0_outputs['scores'].tolist()}")
393
+ >>> print(f"Boxes shape (XYXY format, absolute coordinates): {frame_0_outputs['boxes'].shape}")
394
+ >>> print(f"Masks shape: {frame_0_outputs['masks'].shape}")
395
+ ```
396
+
#### Streaming Video Inference

For real-time applications, the Transformers implementation of SAM3 Video supports processing video frames as they arrive:

```python
>>> # Initialize session for streaming
>>> streaming_inference_session = processor.init_video_session(
...     inference_device=device,
...     processing_device="cpu",
...     video_storage_device="cpu",
...     dtype=torch.bfloat16,
... )

>>> # Add text prompt
>>> text = "person"
>>> streaming_inference_session = processor.add_text_prompt(
...     inference_session=streaming_inference_session,
...     text=text,
... )

>>> # Process frames one by one (streaming mode)
>>> streaming_outputs_per_frame = {}
>>> for frame_idx, frame in enumerate(video_frames[:50]):  # Process first 50 frames
...     # First, process the frame using the processor
...     inputs = processor(images=frame, device=device, return_tensors="pt")
...
...     # Process frame using streaming inference - pass the processed pixel_values
...     model_outputs = model(
...         inference_session=streaming_inference_session,
...         frame=inputs.pixel_values[0],  # Providing a processed frame enables streaming mode
...         reverse=False,
...     )
...
...     # Post-process outputs with original_sizes for proper resolution handling
...     processed_outputs = processor.postprocess_outputs(
...         streaming_inference_session,
...         model_outputs,
...         original_sizes=inputs.original_sizes,  # Required for streaming inference
...     )
...     streaming_outputs_per_frame[frame_idx] = processed_outputs
...
...     if (frame_idx + 1) % 10 == 0:
...         print(f"Processed {frame_idx + 1} frames...")

>>> print(f"✓ Streaming inference complete! Processed {len(streaming_outputs_per_frame)} frames")
✓ Streaming inference complete! Processed 50 frames

>>> # Access results
>>> frame_0_outputs = streaming_outputs_per_frame[0]
>>> print(f"Detected {len(frame_0_outputs['object_ids'])} objects in first frame")
>>> print(f"Boxes are in XYXY format (absolute pixel coordinates): {frame_0_outputs['boxes'].shape}")
>>> print(f"Masks are at original video resolution: {frame_0_outputs['masks'].shape}")
```

<div class="warning">

⚠️ **Note on Streaming Inference Quality**: Streaming inference disables hotstart heuristics that remove unmatched and duplicate objects, as these require access to future frames to make informed decisions. This may result in more false positive detections and duplicate object tracks compared to pre-loaded video inference. For best results, use pre-loaded video inference when all frames are available.

</div>

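One lightweight mitigation for the extra false positives in streaming mode is to re-filter each streamed frame's detections with a stricter confidence threshold. This is a sketch, not part of the library; `filter_by_score` is a hypothetical helper operating on the per-frame `object_ids` and `scores` outputs:

```python
def filter_by_score(object_ids, scores, min_score=0.7):
    """Keep only detections whose confidence clears min_score (illustrative post-filter)."""
    keep = [i for i, s in enumerate(scores) if s >= min_score]
    return [object_ids[i] for i in keep], [scores[i] for i in keep]

# Plain lists standing in for one frame's streamed outputs
ids, confs = filter_by_score([1, 2, 3], [0.9, 0.4, 0.75], min_score=0.7)
# -> ids == [1, 3], confs == [0.9, 0.75]
```

Tuning `min_score` trades recall for precision; duplicate tracks still require matching across frames, which only the pre-loaded path handles for you.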
### SAM3 Tracker - Promptable Visual Segmentation (PVS) for Images

Sam3Tracker performs Promptable Visual Segmentation (PVS) on images, taking interactive visual prompts (points, boxes, masks) to segment a **specific object instance** per prompt. It is an updated version of SAM2 that maintains the same API while providing improved performance, making it a drop-in replacement for SAM2 workflows.

#### Automatic Mask Generation with Pipeline

```python
>>> from transformers import pipeline

>>> generator = pipeline("mask-generation", model="facebook/sam3", device=0)
>>> image_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg"
>>> outputs = generator(image_url, points_per_batch=64)

>>> len(outputs["masks"])  # Number of masks generated
```

#### Basic Image Segmentation

##### Single Point Click

```python
>>> from transformers import Sam3TrackerProcessor, Sam3TrackerModel
>>> from accelerate import Accelerator
>>> import torch
>>> from PIL import Image
>>> import requests

>>> device = Accelerator().device

>>> model = Sam3TrackerModel.from_pretrained("facebook/sam3").to(device)
>>> processor = Sam3TrackerProcessor.from_pretrained("facebook/sam3")

>>> image_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg"
>>> raw_image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")

>>> input_points = [[[[500, 375]]]]  # Single point click, 4 dimensions (image_dim, object_dim, point_per_object_dim, coordinates)
>>> input_labels = [[[1]]]  # 1 for positive click, 0 for negative click, 3 dimensions (image_dim, object_dim, point_label)

>>> inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(model.device)

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])[0]

>>> # The model outputs multiple mask predictions ranked by quality score
>>> print(f"Generated {masks.shape[1]} masks with shape {masks.shape}")
```

##### Multiple Points for Refinement

```python
>>> # Add more positive points to refine the mask
>>> input_points = [[[[500, 375], [1125, 625]]]]  # Multiple points for refinement
>>> input_labels = [[[1, 1]]]  # Both positive clicks

>>> inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])[0]
```

##### Bounding Box Input

```python
>>> # Define bounding box as [x_min, y_min, x_max, y_max]
>>> input_boxes = [[[75, 275, 1725, 850]]]

>>> inputs = processor(images=raw_image, input_boxes=input_boxes, return_tensors="pt").to(device)

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])[0]
```

##### Multiple Objects Segmentation

```python
>>> # Define points for two different objects
>>> input_points = [[[[500, 375]], [[650, 750]]]]  # Points for two objects in same image
>>> input_labels = [[[1], [1]]]  # Positive clicks for both objects

>>> inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(model.device)

>>> with torch.no_grad():
...     outputs = model(**inputs, multimask_output=False)

>>> # Each object gets its own mask
>>> masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])[0]
>>> print(f"Generated masks for {masks.shape[0]} objects")
Generated masks for 2 objects
```

#### Batch Inference

```python
>>> # Load multiple images
>>> image_urls = [
...     "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg",
...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dog-sam.png"
... ]
>>> raw_images = [Image.open(requests.get(url, stream=True).raw).convert("RGB") for url in image_urls]

>>> # Single point per image
>>> input_points = [[[[500, 375]]], [[[770, 200]]]]  # One point for each image
>>> input_labels = [[[1]], [[1]]]  # Positive clicks for both images

>>> inputs = processor(images=raw_images, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(model.device)

>>> with torch.no_grad():
...     outputs = model(**inputs, multimask_output=False)

>>> # Post-process masks for each image
>>> all_masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"])
>>> print(f"Processed {len(all_masks)} images, each with {all_masks[0].shape[0]} objects")
```

### SAM3 Tracker Video - Promptable Visual Segmentation (PVS) for Videos

Sam3TrackerVideo performs Promptable Visual Segmentation (PVS) on videos, taking interactive visual prompts (points, boxes, masks) to track a **specific object instance** per prompt across video frames. It is an updated version of SAM2 Video that maintains the same API while providing improved performance, making it a drop-in replacement for SAM2 Video workflows.

#### Basic Video Tracking

```python
>>> from transformers import Sam3TrackerVideoModel, Sam3TrackerVideoProcessor
>>> from accelerate import Accelerator
>>> import torch

>>> device = Accelerator().device
>>> model = Sam3TrackerVideoModel.from_pretrained("facebook/sam3").to(device, dtype=torch.bfloat16)
>>> processor = Sam3TrackerVideoProcessor.from_pretrained("facebook/sam3")

>>> # Load video frames
>>> from transformers.video_utils import load_video
>>> video_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/bedroom.mp4"
>>> video_frames, _ = load_video(video_url)

>>> # Initialize video inference session
>>> inference_session = processor.init_video_session(
...     video=video_frames,
...     inference_device=device,
...     dtype=torch.bfloat16,
... )

>>> # Add click on first frame to select object
>>> ann_frame_idx = 0
>>> ann_obj_id = 1
>>> points = [[[[210, 350]]]]
>>> labels = [[[1]]]

>>> processor.add_inputs_to_inference_session(
...     inference_session=inference_session,
...     frame_idx=ann_frame_idx,
...     obj_ids=ann_obj_id,
...     input_points=points,
...     input_labels=labels,
... )

>>> # Segment the object on the first frame (optional, you can also propagate the masks through the video directly)
>>> outputs = model(
...     inference_session=inference_session,
...     frame_idx=ann_frame_idx,
... )
>>> video_res_masks = processor.post_process_masks(
...     [outputs.pred_masks], original_sizes=[[inference_session.video_height, inference_session.video_width]], binarize=False
... )[0]
>>> print(f"Segmentation shape: {video_res_masks.shape}")
Segmentation shape: torch.Size([1, 1, 480, 854])

>>> # Propagate through the entire video
>>> video_segments = {}
>>> for sam3_tracker_video_output in model.propagate_in_video_iterator(inference_session):
...     video_res_masks = processor.post_process_masks(
...         [sam3_tracker_video_output.pred_masks], original_sizes=[[inference_session.video_height, inference_session.video_width]], binarize=False
...     )[0]
...     video_segments[sam3_tracker_video_output.frame_idx] = video_res_masks

>>> print(f"Tracked object through {len(video_segments)} frames")
Tracked object through 180 frames
```

#### Multi-Object Video Tracking

Track multiple objects simultaneously across video frames:

```python
>>> # Reset for new tracking session
>>> inference_session.reset_inference_session()

>>> # Add multiple objects on the first frame
>>> ann_frame_idx = 0
>>> obj_ids = [2, 3]
>>> input_points = [[[[200, 300]], [[400, 150]]]]  # Points for two objects (batched)
>>> input_labels = [[[1], [1]]]

>>> processor.add_inputs_to_inference_session(
...     inference_session=inference_session,
...     frame_idx=ann_frame_idx,
...     obj_ids=obj_ids,
...     input_points=input_points,
...     input_labels=input_labels,
... )

>>> # Get masks for both objects on first frame (optional, you can also propagate the masks through the video directly)
>>> outputs = model(
...     inference_session=inference_session,
...     frame_idx=ann_frame_idx,
... )

>>> # Propagate both objects through video
>>> video_segments = {}
>>> for sam3_tracker_video_output in model.propagate_in_video_iterator(inference_session):
...     video_res_masks = processor.post_process_masks(
...         [sam3_tracker_video_output.pred_masks], original_sizes=[[inference_session.video_height, inference_session.video_width]], binarize=False
...     )[0]
...     video_segments[sam3_tracker_video_output.frame_idx] = {
...         obj_id: video_res_masks[i]
...         for i, obj_id in enumerate(inference_session.obj_ids)
...     }

>>> print(f"Tracked {len(inference_session.obj_ids)} objects through {len(video_segments)} frames")
Tracked 2 objects through 180 frames
```

#### Streaming Video Inference

For real-time applications, Sam3TrackerVideo supports processing video frames as they arrive:

```python
>>> # Initialize session for streaming
>>> inference_session = processor.init_video_session(
...     inference_device=device,
...     dtype=torch.bfloat16,
... )

>>> # Process frames one by one
>>> for frame_idx, frame in enumerate(video_frames[:10]):  # Process first 10 frames
...     inputs = processor(images=frame, device=device, return_tensors="pt")
...
...     if frame_idx == 0:
...         # Add point input on first frame
...         processor.add_inputs_to_inference_session(
...             inference_session=inference_session,
...             frame_idx=0,
...             obj_ids=1,
...             input_points=[[[[210, 350], [250, 220]]]],
...             input_labels=[[[1, 1]]],
...             original_size=inputs.original_sizes[0],  # Must be provided when using streaming video inference
...         )
...
...     # Process current frame
...     sam3_tracker_video_output = model(inference_session=inference_session, frame=inputs.pixel_values[0])
...
...     video_res_masks = processor.post_process_masks(
...         [sam3_tracker_video_output.pred_masks], original_sizes=inputs.original_sizes, binarize=False
...     )[0]
...     print(f"Frame {frame_idx}: mask shape {video_res_masks.shape}")
```