Rawal Khirodkar committed on
Commit
dbdd74a
·
1 Parent(s): 8aae515

Switch pose detector from mmdet/RTMDet to DETR (transformers, Apache 2.0)

Browse files
app.py CHANGED
@@ -35,8 +35,7 @@ from sapiens.pose.datasets import UDPHeatmap, parse_pose_metainfo
35
  from sapiens.pose.evaluators import nms
36
  from sapiens.pose.models import init_model
37
 
38
- from detector_utils import adapt_mmdet_pipeline
39
- from mmdet.apis import inference_detector, init_detector
40
 
41
  from pose_render_utils import visualize_keypoints
42
 
@@ -72,9 +71,7 @@ POSE_MODELS = {
72
  }
73
  DEFAULT_SIZE = "1B"
74
 
75
- DETECTOR_REPO = "facebook/sapiens-pose-bbox-detector"
76
- DETECTOR_CKPT_FILENAME = "rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth"
77
- DETECTOR_CONFIG = os.path.join(ASSETS_DIR, "rtmdet_m_640-8xb32_coco-person_no_nms.py")
78
 
79
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
80
  BBOX_THR = 0.3
@@ -85,7 +82,7 @@ NMS_THR = 0.3
85
  # Model cache (load once, reuse across requests)
86
 
87
  _pose_model_cache: dict = {}
88
- _detector_cache = None
89
  _metainfo_cache = None
90
 
91
 
@@ -98,13 +95,12 @@ def _get_metainfo():
98
 
99
 
100
  def _get_detector():
101
- global _detector_cache
102
- if _detector_cache is None:
103
- ckpt = hf_hub_download(repo_id=DETECTOR_REPO, filename=DETECTOR_CKPT_FILENAME)
104
- det = init_detector(DETECTOR_CONFIG, ckpt, device=DEVICE)
105
- det.cfg = adapt_mmdet_pipeline(det.cfg)
106
- _detector_cache = det
107
- return _detector_cache
108
 
109
 
110
  def _get_pose_model(size: str):
@@ -133,15 +129,23 @@ print("[startup] ready.")
133
  # -----------------------------------------------------------------------------
134
  # Inference
135
 
136
- def _detect_persons(image_bgr: np.ndarray) -> np.ndarray:
137
- detector = _get_detector()
138
- det = inference_detector(detector, image_bgr)
139
- inst = det.pred_instances.cpu().numpy()
140
- bboxes = np.concatenate((inst.bboxes, inst.scores[:, None]), axis=1)
141
- bboxes = bboxes[(inst.labels == 0) & (inst.scores > BBOX_THR)]
142
- bboxes = bboxes[nms(bboxes, NMS_THR), :4] # x1,y1,x2,y2
 
 
 
 
 
 
 
 
143
  if len(bboxes) == 0:
144
- h, w = image_bgr.shape[:2]
145
  bboxes = np.array([[0, 0, w - 1, h - 1]], dtype=np.float32)
146
  return bboxes
147
 
@@ -181,7 +185,7 @@ def predict(image: Image.Image, size: str, kpt_thr: float):
181
  image_rgb = np.array(image.convert("RGB"))
182
  image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
183
 
184
- bboxes = _detect_persons(image_bgr)
185
  model = _get_pose_model(size)
186
  keypoints, scores = _estimate_pose(image_bgr, bboxes, model)
187
 
 
35
  from sapiens.pose.evaluators import nms
36
  from sapiens.pose.models import init_model
37
 
38
+ from transformers import DetrForObjectDetection, DetrImageProcessor
 
39
 
40
  from pose_render_utils import visualize_keypoints
41
 
 
71
  }
72
  DEFAULT_SIZE = "1B"
73
 
74
+ DETECTOR_MODEL_ID = "facebook/detr-resnet-50" # COCO person = label 1
 
 
75
 
76
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
77
  BBOX_THR = 0.3
 
82
  # Model cache (load once, reuse across requests)
83
 
84
  _pose_model_cache: dict = {}
85
+ _detector_cache: dict = {}
86
  _metainfo_cache = None
87
 
88
 
 
95
 
96
 
97
  def _get_detector():
98
+ if "model" not in _detector_cache:
99
+ proc = DetrImageProcessor.from_pretrained(DETECTOR_MODEL_ID)
100
+ model = DetrForObjectDetection.from_pretrained(DETECTOR_MODEL_ID).eval().to(DEVICE)
101
+ _detector_cache["proc"] = proc
102
+ _detector_cache["model"] = model
103
+ return _detector_cache["proc"], _detector_cache["model"]
 
104
 
105
 
106
  def _get_pose_model(size: str):
 
129
  # -----------------------------------------------------------------------------
130
  # Inference
131
 
132
+ def _detect_persons(image_rgb: np.ndarray) -> np.ndarray:
133
+ proc, model = _get_detector()
134
+ pil_img = Image.fromarray(image_rgb)
135
+ inputs = proc(images=pil_img, return_tensors="pt").to(DEVICE)
136
+ with torch.no_grad():
137
+ outputs = model(**inputs)
138
+ target_sizes = torch.tensor([image_rgb.shape[:2]], device=DEVICE) # (h, w)
139
+ results = proc.post_process_object_detection(
140
+ outputs, target_sizes=target_sizes, threshold=BBOX_THR
141
+ )[0]
142
+ person_mask = results["labels"] == 1 # COCO person
143
+ boxes = results["boxes"][person_mask].cpu().numpy() # (N, 4) x1,y1,x2,y2
144
+ scores = results["scores"][person_mask].cpu().numpy().reshape(-1, 1)
145
+ bboxes = np.concatenate([boxes, scores], axis=1) # (N, 5)
146
+ bboxes = bboxes[nms(bboxes, NMS_THR), :4]
147
  if len(bboxes) == 0:
148
+ h, w = image_rgb.shape[:2]
149
  bboxes = np.array([[0, 0, w - 1, h - 1]], dtype=np.float32)
150
  return bboxes
151
 
 
185
  image_rgb = np.array(image.convert("RGB"))
186
  image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
187
 
188
+ bboxes = _detect_persons(image_rgb)
189
  model = _get_pose_model(size)
190
  keypoints, scores = _estimate_pose(image_bgr, bboxes, model)
191
 
assets/rtmdet_m_640-8xb32_coco-person_no_nms.py DELETED
@@ -1,20 +0,0 @@
1
- _base_ = 'mmdet::rtmdet/rtmdet_m_8xb32-300e_coco.py'
2
-
3
- checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-m_8xb256-rsb-a1-600e_in1k-ecb3bbd9.pth' # noqa
4
-
5
- model = dict(
6
- backbone=dict(
7
- init_cfg=dict(
8
- type='Pretrained', prefix='backbone.', checkpoint=checkpoint)),
9
- bbox_head=dict(num_classes=1),
10
- test_cfg=dict(
11
- nms_pre=1000,
12
- min_bbox_size=0,
13
- score_thr=0.05,
14
- nms=None,
15
- max_per_img=100))
16
-
17
- train_dataloader = dict(dataset=dict(metainfo=dict(classes=('person', ))))
18
-
19
- val_dataloader = dict(dataset=dict(metainfo=dict(classes=('person', ))))
20
- test_dataloader = val_dataloader
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
detector_utils.py DELETED
@@ -1,196 +0,0 @@
1
- from typing import List, Optional, Sequence, Union
2
-
3
- import torch
4
- import cv2
5
- import numpy as np
6
- from mmcv.ops import RoIPool
7
- from mmengine.dataset import Compose, pseudo_collate
8
- from mmengine.device import get_device
9
- from mmengine.registry import init_default_scope
10
- from mmdet.apis import inference_detector, init_detector
11
- from mmdet.structures import DetDataSample, SampleList
12
- from mmdet.utils import get_test_pipeline_cfg
13
-
14
-
15
- ImagesType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
16
-
17
- def nms(dets: np.ndarray, thr: float):
18
- """Greedily select boxes with high confidence and overlap <= thr.
19
- Args:
20
- dets (np.ndarray): [[x1, y1, x2, y2, score]].
21
- thr (float): Retain overlap < thr.
22
- Returns:
23
- list: Indexes to keep.
24
- """
25
- if len(dets) == 0:
26
- return []
27
-
28
- x1 = dets[:, 0]
29
- y1 = dets[:, 1]
30
- x2 = dets[:, 2]
31
- y2 = dets[:, 3]
32
- scores = dets[:, 4]
33
-
34
- areas = (x2 - x1 + 1) * (y2 - y1 + 1)
35
- order = scores.argsort()[::-1]
36
-
37
- keep = []
38
- while len(order) > 0:
39
- i = order[0]
40
- keep.append(i)
41
- xx1 = np.maximum(x1[i], x1[order[1:]])
42
- yy1 = np.maximum(y1[i], y1[order[1:]])
43
- xx2 = np.minimum(x2[i], x2[order[1:]])
44
- yy2 = np.minimum(y2[i], y2[order[1:]])
45
-
46
- w = np.maximum(0.0, xx2 - xx1 + 1)
47
- h = np.maximum(0.0, yy2 - yy1 + 1)
48
- inter = w * h
49
- ovr = inter / (areas[i] + areas[order[1:]] - inter)
50
-
51
- inds = np.where(ovr <= thr)[0]
52
- order = order[inds + 1]
53
-
54
- return keep
55
-
56
- def adapt_mmdet_pipeline(cfg):
57
- """Converts pipeline types in MMDetection's test dataloader to use the
58
- 'mmdet' namespace.
59
-
60
- Args:
61
- cfg (ConfigDict): Configuration dictionary for MMDetection.
62
-
63
- Returns:
64
- ConfigDict: Configuration dictionary with updated pipeline types.
65
- """
66
- # use lazy import to avoid hard dependence on mmdet
67
- from mmdet.datasets import transforms
68
-
69
- if 'test_dataloader' not in cfg:
70
- return cfg
71
-
72
- pipeline = cfg.test_dataloader.dataset.pipeline
73
- for trans in pipeline:
74
- if trans['type'] in dir(transforms):
75
- trans['type'] = 'mmdet.' + trans['type']
76
-
77
- return cfg
78
-
79
-
80
- def inference_detector(
81
- model: torch.nn.Module,
82
- imgs: ImagesType,
83
- test_pipeline: Optional[Compose] = None,
84
- text_prompt: Optional[str] = None,
85
- custom_entities: bool = False,
86
- ) -> Union[DetDataSample, SampleList]:
87
- """Inference image(s) with the detector.
88
-
89
- Args:
90
- model (nn.Module): The loaded detector.
91
- imgs (str, ndarray, Sequence[str/ndarray]):
92
- Either image files or loaded images.
93
- test_pipeline (:obj:`Compose`): Test pipeline.
94
-
95
- Returns:
96
- :obj:`DetDataSample` or list[:obj:`DetDataSample`]:
97
- If imgs is a list or tuple, the same length list type results
98
- will be returned, otherwise return the detection results directly.
99
- """
100
- if isinstance(imgs, torch.Tensor):
101
- if imgs.is_cuda:
102
- imgs = imgs.cpu()
103
-
104
- # Remove batch dimension and transpose
105
- imgs = imgs.squeeze(0).permute(1, 2, 0).numpy()
106
-
107
- # Ensure the data type is appropriate (uint8 for most image processing functions)
108
- imgs = (imgs * 255).astype(np.uint8)
109
-
110
- if isinstance(imgs, (list, tuple)) or (isinstance(imgs, np.ndarray) and len(imgs.shape) == 4):
111
- is_batch = True
112
- else:
113
- imgs = [imgs]
114
- is_batch = False
115
-
116
- cfg = model.cfg
117
-
118
- if test_pipeline is None:
119
- cfg = cfg.copy()
120
- test_pipeline = get_test_pipeline_cfg(cfg)
121
- if isinstance(imgs[0], np.ndarray):
122
- # Calling this method across libraries will result
123
- # in module unregistered error if not prefixed with mmdet.
124
- test_pipeline[0].type = "mmdet.LoadImageFromNDArray"
125
-
126
- test_pipeline = Compose(test_pipeline)
127
-
128
- if model.data_preprocessor.device.type == "cpu":
129
- for m in model.modules():
130
- assert not isinstance(
131
- m, RoIPool
132
- ), "CPU inference with RoIPool is not supported currently."
133
-
134
- result_list = []
135
- for i, img in enumerate(imgs):
136
- # prepare data
137
- if isinstance(img, np.ndarray):
138
- # TODO: remove img_id.
139
- data_ = dict(img=img, img_id=0)
140
- else:
141
- # TODO: remove img_id.
142
- data_ = dict(img_path=img, img_id=0)
143
-
144
- if text_prompt:
145
- data_["text"] = text_prompt
146
- data_["custom_entities"] = custom_entities
147
-
148
- # build the data pipeline
149
- data_ = test_pipeline(data_)
150
-
151
- data_["inputs"] = [data_["inputs"]]
152
- data_["data_samples"] = [data_["data_samples"]]
153
-
154
- # forward the model
155
- with torch.no_grad(), torch.autocast(device_type=get_device(), dtype=torch.bfloat16):
156
- results = model.test_step(data_)[0]
157
-
158
- result_list.append(results)
159
-
160
- if not is_batch:
161
- return result_list[0]
162
- else:
163
- return result_list
164
-
165
-
166
- def process_one_image_bbox(pred_instance, det_cat_id, bbox_thr, nms_thr):
167
- bboxes = np.concatenate(
168
- (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1
169
- )
170
- bboxes = bboxes[
171
- np.logical_and(
172
- pred_instance.labels == det_cat_id,
173
- pred_instance.scores > bbox_thr,
174
- )
175
- ]
176
- bboxes = bboxes[nms(bboxes, nms_thr), :4]
177
- return bboxes
178
-
179
-
180
- def process_images_detector(imgs, detector):
181
- """Visualize predicted keypoints (and heatmaps) of one image."""
182
- # predict bbox
183
- det_results = inference_detector(detector, imgs)
184
- pred_instances = list(
185
- map(lambda det_result: det_result.pred_instances.numpy(), det_results)
186
- )
187
- bboxes_batch = list(
188
- map(
189
- lambda pred_instance: process_one_image_bbox(
190
- pred_instance, 0, 0.3, 0.3 ## argparse.Namespace(det_cat_id=0, bbox_thr=0.3, nms_thr=0.3),
191
- ),
192
- pred_instances,
193
- )
194
- )
195
-
196
- return bboxes_batch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,7 +1,6 @@
1
  gradio==4.42.0
2
  spaces
3
 
4
- # Pinned to versions verified working together (sapiens2 + mmdet stack).
5
  torch==2.7.1
6
  torchvision==0.22.1
7
 
@@ -21,7 +20,6 @@ termcolor
21
  accelerate
22
  rich
23
 
24
- # RTMDet person detector stack (mmcv 2.2.0 builds cleanly on Python 3.12 + torch 2.7).
25
- mmengine==0.10.7
26
- mmcv==2.2.0
27
- mmdet==3.3.0
 
1
  gradio==4.42.0
2
  spaces
3
 
 
4
  torch==2.7.1
5
  torchvision==0.22.1
6
 
 
20
  accelerate
21
  rich
22
 
23
+ # Person bbox detector: DETR via Hugging Face transformers (Apache 2.0, GPU-friendly).
24
+ transformers
25
+ timm