Rawal Khirodkar committed on
Commit
c139808
·
1 Parent(s): 3ae907a

Initial sapiens2-pose Space (HF download at startup, all 4 sizes)

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ __pycache__
2
+ *.pyc
3
+ default.profraw
4
+ .DS_Store
5
+ *.log
README.md CHANGED
@@ -1,12 +1,22 @@
1
  ---
2
  title: Sapiens2 Pose
3
- emoji: 🐢
4
- colorFrom: indigo
5
  colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 6.13.0
8
  app_file: app.py
9
  pinned: false
 
 
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
1
  ---
2
  title: Sapiens2 Pose
3
+ emoji: 🧬
4
+ colorFrom: blue
5
  colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 4.42.0
8
  app_file: app.py
9
  pinned: false
10
+ license: other
11
+ license_name: sapiens2-license
12
+ license_link: https://github.com/facebookresearch/sapiens2/blob/main/LICENSE.md
13
  ---
14
 
15
+ # Sapiens2: Pose Estimation
16
+ ### ICLR 2026
17
+
18
+ Top-down 308-keypoint human pose estimation. Detects people with RTMDet, then runs Sapiens2 on each crop.
19
+
20
+ - **Code:** [github.com/facebookresearch/sapiens2](https://github.com/facebookresearch/sapiens2)
21
+ - **Models:** [Sapiens2 collection](https://huggingface.co/facebook/sapiens2)
22
+ - **Paper:** https://openreview.net/pdf?id=IVAlYCqdvW
app.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Sapiens2 pose-estimation Gradio Space.
2
+
3
+ Top-down 308-keypoint pose: RTMDet finds people, Sapiens2 estimates keypoints
4
+ on each crop, and we draw skeleton + keypoints with the GOLIATH palette.
5
+
6
+ All checkpoints are pulled from HuggingFace at startup so this Space repo
7
+ stays small. The eager pre-load below warms the cache for the detector and
8
+ all 4 pose sizes during boot, so user requests are instant.
9
+ """
10
+
11
+ # Block mmpretrain: mmdet's reid modules try `import mmpretrain` inside
12
+ # try/except ImportError, but mmpretrain's BLIP language_model.py raises
13
+ # TypeError (transformers API drift) — escapes the except and kills the process.
14
+ import sys
15
+ sys.modules["mmpretrain"] = None
16
+
17
+ import json
18
+ import os
19
+ import tempfile
20
+ from typing import List, Tuple
21
+
22
+ import cv2
23
+ import gradio as gr
24
+ import numpy as np
25
+ import spaces
26
+ import torch
27
+ from huggingface_hub import hf_hub_download
28
+ from PIL import Image
29
+
30
+ from sapiens.pose.datasets import UDPHeatmap, parse_pose_metainfo
31
+ from sapiens.pose.evaluators import nms
32
+ from sapiens.pose.models import init_model
33
+
34
+ from detector_utils import adapt_mmdet_pipeline
35
+ from mmdet.apis import inference_detector, init_detector
36
+
37
+ from pose_render_utils import visualize_keypoints
38
+
39
+
40
# -----------------------------------------------------------------------------
# Config

ASSETS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
CONFIGS_DIR = os.path.join(ASSETS_DIR, "configs")

# Sapiens2 pose checkpoints — fetched from HF model repos at startup.
# Each entry maps a size label to: the HF repo id, the checkpoint filename
# inside that repo, and the local config file describing the matching
# architecture (used by init_model).
POSE_MODELS = {
    "0.4B": {
        "repo": "facebook/sapiens2-pose-0.4b",
        "filename": "sapiens2_0.4b_pose.safetensors",
        "config": os.path.join(CONFIGS_DIR, "sapiens2_0.4b_keypoints308_shutterstock_goliath_3po-1024x768.py"),
    },
    "0.8B": {
        "repo": "facebook/sapiens2-pose-0.8b",
        "filename": "sapiens2_0.8b_pose.safetensors",
        "config": os.path.join(CONFIGS_DIR, "sapiens2_0.8b_keypoints308_shutterstock_goliath_3po-1024x768.py"),
    },
    "1B": {
        "repo": "facebook/sapiens2-pose-1b",
        "filename": "sapiens2_1b_pose.safetensors",
        "config": os.path.join(CONFIGS_DIR, "sapiens2_1b_keypoints308_shutterstock_goliath_3po-1024x768.py"),
    },
    "5B": {
        "repo": "facebook/sapiens2-pose-5b",
        "filename": "sapiens2_5b_pose.safetensors",
        "config": os.path.join(CONFIGS_DIR, "sapiens2_5b_keypoints308_shutterstock_goliath_3po-1024x768.py"),
    },
}
DEFAULT_SIZE = "1B"

# RTMDet person detector that provides boxes for the top-down pose model.
DETECTOR_REPO = "facebook/sapiens-pose-bbox-detector"
DETECTOR_CKPT_FILENAME = "rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth"
DETECTOR_CONFIG = os.path.join(ASSETS_DIR, "rtmdet_m_640-8xb32_coco-person_no_nms.py")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BBOX_THR = 0.3  # minimum detector confidence for a person box to be kept
NMS_THR = 0.3  # IoU threshold used by the post-detection NMS
78
+
79
+
80
# -----------------------------------------------------------------------------
# Model cache (load once, reuse across requests)

_pose_model_cache: dict = {}  # size label ("0.4B", ...) -> loaded pose model
_detector_cache = None  # lazily initialized RTMDet detector
_metainfo_cache = None  # parsed 308-keypoint metainfo dict
86
+
87
+
88
def _get_metainfo():
    """Return the 308-keypoint GOLIATH metainfo, parsed once and cached."""
    global _metainfo_cache
    if _metainfo_cache is not None:
        return _metainfo_cache
    keypoints_file = os.path.join(CONFIGS_DIR, "_base_", "keypoints308.py")
    _metainfo_cache = parse_pose_metainfo(dict(from_file=keypoints_file))
    return _metainfo_cache
94
+
95
+
96
def _get_detector():
    """Return the RTMDet person detector, downloading + loading it on first use."""
    global _detector_cache
    if _detector_cache is not None:
        return _detector_cache
    weights = hf_hub_download(repo_id=DETECTOR_REPO, filename=DETECTOR_CKPT_FILENAME)
    detector = init_detector(DETECTOR_CONFIG, weights, device=DEVICE)
    # Rewrite the mmdet test pipeline so it accepts in-memory arrays.
    detector.cfg = adapt_mmdet_pipeline(detector.cfg)
    _detector_cache = detector
    return _detector_cache
104
+
105
+
106
def _get_pose_model(size: str):
    """Return the Sapiens2 pose model for *size*, loading it on first request.

    The loaded model gets two extra attributes used downstream: `codec`
    (a UDPHeatmap decoder built from the training config) and
    `pose_metainfo` (the shared 308-keypoint metainfo).
    """
    cached = _pose_model_cache.get(size)
    if cached is not None:
        return cached

    spec = POSE_MODELS[size]
    weights = hf_hub_download(repo_id=spec["repo"], filename=spec["filename"])
    model = init_model(spec["config"], weights, device=DEVICE)

    # The config must describe a UDP heatmap codec; anything else is a bug.
    decoder_cfg = dict(model.cfg.codec)
    assert decoder_cfg.pop("type") == "UDPHeatmap"
    model.codec = UDPHeatmap(**decoder_cfg)
    model.pose_metainfo = _get_metainfo()

    _pose_model_cache[size] = model
    return model
117
+
118
+
119
# -----------------------------------------------------------------------------
# Eager pre-load: download + warm-load detector + all pose sizes at startup so
# the first user (and every user thereafter) gets an instant response.
# NOTE(review): this keeps all 4 model sizes resident simultaneously —
# confirm the Space's RAM/VRAM budget covers the 5B checkpoint too.
print("[startup] pre-loading detector + all pose sizes ...")
_get_detector()
for _size in POSE_MODELS:
    _get_pose_model(_size)
print("[startup] ready.")
127
+
128
+
129
+ # -----------------------------------------------------------------------------
130
+ # Inference
131
+
132
def _detect_persons(image_bgr: np.ndarray) -> np.ndarray:
    """Detect person bounding boxes in a BGR image.

    Runs RTMDet, keeps person-class (label 0) boxes above BBOX_THR, applies
    NMS, and returns an (N, 4) array of x1, y1, x2, y2. When nothing is
    detected, falls back to a single full-image box so the top-down pose
    model always has at least one crop.
    """
    detector = _get_detector()
    det = inference_detector(detector, image_bgr)
    inst = det.pred_instances.cpu().numpy()
    bboxes = np.concatenate((inst.bboxes, inst.scores[:, None]), axis=1)
    bboxes = bboxes[(inst.labels == 0) & (inst.scores > BBOX_THR)]
    # Only run NMS when there are candidates: the original unconditionally
    # indexed with nms(bboxes, ...), which can fail on an empty array, and
    # the full-image fallback below already covers the empty case.
    if len(bboxes) > 0:
        bboxes = bboxes[nms(bboxes, NMS_THR), :4]  # x1,y1,x2,y2
    if len(bboxes) == 0:
        h, w = image_bgr.shape[:2]
        bboxes = np.array([[0, 0, w - 1, h - 1]], dtype=np.float32)
    return bboxes
143
+
144
+
145
def _estimate_pose(image_bgr: np.ndarray, bboxes: np.ndarray, model) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """Run top-down pose estimation on every person box.

    Args:
        image_bgr: full image in BGR order (as produced by cv2).
        bboxes: (N, 4) array of x1, y1, x2, y2 person boxes.
        model: model from `_get_pose_model()` — relies on its attached
            `pipeline`, `data_preprocessor`, and `codec` attributes.

    Returns:
        Two N-length lists: per person a (K, 2) array of image-space keypoint
        coordinates and a (K,) array of keypoint scores.
    """
    # Preprocess each crop through the model's eval pipeline.
    inputs_list, samples_list = [], []
    for bbox in bboxes:
        data_info = dict(img=image_bgr, bbox=bbox[None], bbox_score=np.ones(1, dtype=np.float32))
        data = model.pipeline(data_info)
        data = model.data_preprocessor(data)
        inputs_list.append(data["inputs"])
        samples_list.append(data["data_samples"])

    # Batch all crops into a single forward pass.
    inputs = torch.cat(inputs_list, dim=0)
    with torch.no_grad():
        pred = model(inputs)  # (B, K, h, w) heatmaps

    pred = pred.cpu().numpy()
    keypoints, scores = [], []
    for i, sample in enumerate(samples_list):
        kpts_i, scr_i = model.codec.decode(pred[i])  # (1, K, 2), (1, K)
        meta = sample["meta"]
        # Map model-input coordinates back to the original image: normalize
        # by the network input size, rescale to the bbox size, then translate
        # to the bbox's top-left corner (center - half the scale).
        kpts_i = kpts_i / meta["input_size"] * meta["bbox_scale"] + meta["bbox_center"] - 0.5 * meta["bbox_scale"]
        keypoints.append(kpts_i[0])
        scores.append(scr_i[0])
    return keypoints, scores
167
+
168
+
169
+ # -----------------------------------------------------------------------------
170
+ # Gradio handler
171
+
172
@spaces.GPU(duration=120)
def predict(image: Image.Image, size: str, kpt_thr: float):
    """Gradio handler: detect people, estimate 308 keypoints, render results.

    Args:
        image: input PIL image, or None when the widget is empty.
        size: key into POSE_MODELS selecting the Sapiens2 variant.
        kpt_thr: score threshold below which keypoints are not drawn.

    Returns:
        (annotated PIL image, path to a JSON file listing per-person bbox,
        keypoints, and keypoint scores), or (None, None) without an image.
    """
    if image is None:
        return None, None

    # The mmdet detector consumes BGR; visualization uses the RGB copy.
    image_rgb = np.array(image.convert("RGB"))
    image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)

    bboxes = _detect_persons(image_bgr)
    model = _get_pose_model(size)
    keypoints, scores = _estimate_pose(image_bgr, bboxes, model)

    meta = model.pose_metainfo
    vis_rgb = visualize_keypoints(
        image=image_rgb,
        keypoints=keypoints,
        # Mark everything visible; low-score points are filtered via kpt_thr.
        keypoints_visible=[np.ones(len(s), dtype=bool) for s in scores],
        keypoint_scores=scores,
        radius=3,
        thickness=1,
        kpt_thr=kpt_thr,
        skeleton=meta["skeleton_links"],
        kpt_color=meta["keypoint_colors"],
        link_color=meta["skeleton_link_colors"],
    )

    # Raw numbers go to a temp JSON for download. delete=False is deliberate:
    # Gradio serves the file after this function returns.
    instances = [
        {
            "bbox": [float(v) for v in np.asarray(bbox).reshape(-1)[:4]],
            "keypoints": np.asarray(kpts, dtype=float).tolist(),
            "keypoint_scores": np.asarray(s, dtype=float).reshape(-1).tolist(),
        }
        for bbox, kpts, s in zip(bboxes, keypoints, scores)
    ]
    with tempfile.NamedTemporaryFile(delete=False, suffix=".json", mode="w") as f:
        json.dump({"instances": instances}, f)
        json_path = f.name

    return Image.fromarray(vis_rgb), json_path
211
+
212
+
213
# -----------------------------------------------------------------------------
# UI

# Example gallery: every image shipped under assets/images, sorted for a
# stable display order.
EXAMPLES = sorted(
    os.path.join(ASSETS_DIR, "images", n)
    for n in os.listdir(os.path.join(ASSETS_DIR, "images"))
    if n.lower().endswith((".jpg", ".jpeg", ".png"))
)

with gr.Blocks(title="Sapiens2 Pose", theme=gr.themes.Default()) as demo:
    gr.Markdown(
        "# Sapiens2: Pose Estimation\n"
        "### ICLR 2026\n"
        "Top-down 308-keypoint human pose. RTMDet finds people; Sapiens2 estimates keypoints.\n\n"
        "[Code](https://github.com/facebookresearch/sapiens2) · "
        "[Models](https://huggingface.co/facebook/sapiens2) · "
        "[Paper](https://openreview.net/pdf?id=IVAlYCqdvW)"
    )
    with gr.Row():
        # Left column: input image + controls; right column: outputs.
        with gr.Column():
            inp = gr.Image(label="Input", type="pil")
            with gr.Row():
                size = gr.Radio(
                    choices=list(POSE_MODELS.keys()),
                    value=DEFAULT_SIZE,
                    label="Model size",
                )
                thr = gr.Slider(0.0, 1.0, value=0.3, step=0.05, label="Keypoint threshold")
            run = gr.Button("Run", variant="primary")
            gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=14)
        with gr.Column():
            out_img = gr.Image(label="Pose-308 result", type="pil")
            out_json = gr.File(label="Keypoints (.json)")

    run.click(predict, inputs=[inp, size, thr], outputs=[out_img, out_json])
248
+
249
+
250
if __name__ == "__main__":
    if torch.cuda.is_available():
        # Allow TF32 matmuls/convs: faster on Ampere+ GPUs with negligible
        # accuracy impact for inference.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    demo.launch(share=False)
assets/configs/_base_/keypoints308.py ADDED
The diff for this file is too large to render. See raw diff
 
assets/configs/sapiens2_0.4b_keypoints308_shutterstock_goliath_3po-1024x768.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Training config: Sapiens2-0.4B, 308-keypoint top-down pose at 1024x768,
# trained on Shutterstock + Goliath + 3PO. At inference time the Space only
# consumes `codec` and `model`; the remaining settings drive training runs.

import os

# Roots are overridable via env vars so the same config works on any host.
_CHECKPOINT_ROOT = os.path.expanduser(
    os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host")
)
_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data"))

warmup_iters = 500
# num_iters = 2e4
num_iters = 1e4

# ------------------------------------------------------------------------------
vis_every_iters = 100
log_every_iters = 10
save_every_iters = 1000
val_every_iters = 1000

# # # # debug
# vis_every_iters = 1
# log_every_iters = 1
# val_every_iters = 2
# save_every_iters = 1000

load_from = None
resume = False

# ------------------------------------------------------------------
# Backbone architecture hyperparameters for the 0.4B variant.
model_name = "sapiens2_0.4b"
embed_dim = 1024
num_layers = 24
num_heads = 16

layer_decay_rate = 0.8
pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_0.4b_pretrain.safetensors"

##-----------------------------------------------------------------
image_size = (1024, 768)  ## height x width
patch_size = 16

sigma = 6  ## sigma is 2 for 256
scale = 4
num_keypoints = 308

# ------------------------------------------------------------------
use_fsdp = True
# use_fsdp = False

use_compile = True
# use_compile = False

## DDP config
if use_fsdp is False:
    accelerator_cfg = dict(
        type="DDP",
        log_with="tensorboard",
        # find_unused_parameters=True,
        gradient_accumulation_steps=1,  # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
        max_interval=num_iters,
        # mixed_precision="bf16", # Options: 'no','fp16','bf16' or 'fp8'.
        step_scheduler_with_optimizer=False,  ## schedule independent of n_gpus
    )

else:
    accelerator_cfg = dict(
        type="FSDP",
        log_with="tensorboard",
        gradient_accumulation_steps=1,  # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
        max_interval=num_iters,
        mixed_precision="bf16",  # Options: 'no','fp16','bf16' or 'fp8'.
        step_scheduler_with_optimizer=False,
        fsdp_cfg=dict(
            fsdp_version=2,  # DTensor-based engine
            state_dict_type="SHARDED_STATE_DICT",  # SHARDED_STATE_DICT | FULL_STATE_DICT
            mixed_precision=dict(
                param_dtype="bf16",
                reduce_dtype="bf16",
            ),
            cpu_ram_efficient_loading=False,
        ),
    )

if use_compile:
    accelerator_cfg["compile_cfg"] = dict(
        backend="inductor",
        mode="default",  # Options: "default", "reduce-overhead", "max-autotune"
        fullgraph=False,
        dynamic=False,
    )

# ------------------------------------------------------------------
randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
logger = dict(
    type="Logger",
    log_interval=log_every_iters,
)
checkpoint = dict(
    type="Checkpointer",
    save_interval=save_every_iters,
)

visualizer = dict(
    type="PoseVisualizer",
    vis_interval=vis_every_iters,
    vis_max_samples=4,
    vis_image_width=384,
    vis_image_height=512,
    num_keypoints=num_keypoints,
)


##-----------------------------------------------------------------
# Heatmap codec: encodes keypoints to training targets and decodes
# predicted heatmaps back to coordinates (UDP-aligned).
codec = dict(
    type="UDPHeatmap",
    input_size=(image_size[1], image_size[0]),  ## width x height
    heatmap_size=(int(image_size[1] / scale), int(image_size[0] / scale)),
    sigma=sigma,
)  ## sigma is 2 for 256

train_pipeline = [
    dict(type="PoseGetBBoxCenterScale"),
    dict(type="PoseRandomFlip", direction="horizontal"),  ## default prob is 0.5
    dict(type="PoseRandomHalfBody"),
    dict(type="PoseRandomBBoxTransform"),
    dict(type="PoseTopdownAffine", input_size=codec["input_size"], use_udp=True),
    dict(type="RandomPhotoMetricDistortion", prob=0.8),
    dict(
        type="PoseAlbumentation",
        transforms=[
            dict(type="Blur", p=0.1),
            dict(type="MedianBlur", p=0.1),
            dict(
                type="CoarseDropout",
                max_holes=1,
                max_height=0.4,
                max_width=0.4,
                min_holes=1,
                min_height=0.2,
                min_width=0.2,
                p=1.0,
            ),
        ],
    ),
    dict(type="PoseGenerateTarget", encoder=codec),
    dict(type="PosePackInputs"),
]

val_pipeline = [
    dict(type="PoseGetBBoxCenterScale"),
    dict(type="PoseTopdownAffine", input_size=codec["input_size"], use_udp=True),
    dict(type="PosePackInputs"),
]

test_pipeline = [
    dict(type="PoseGetBBoxCenterScale"),
    dict(type="PoseTopdownAffine", input_size=codec["input_size"], use_udp=True),
    dict(type="PosePackInputs"),
]

##------------------------------------------------------------------------
dataset_shutterstock_train = dict(
    type="Keypoints308ShutterstockDataset",
    ann_file=f"{_DATA_ROOT}/annotations/ingestion_102866/itw_shutterstock_body_keypoint_344_train:2025070300.json",
)

dataset_goliath_train = dict(
    type="Keypoints308GoliathDataset",
    ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/sociopticon_body_keypoint_344_train:2024093001.json",
    subsample_factor=8,
)

dataset_3po_train = dict(
    type="Keypoints308_3PODataset",
    ann_file=f"{_DATA_ROOT}/indices/3po/train.json",
    subsample_factor=2,
)

# train_datasets = [dataset_shutterstock_train]
# train_datasets = [dataset_goliath_train]
# train_datasets = [dataset_3po_train]
# Mixing ratio: 1x Goliath, 2x Shutterstock, 1x 3PO.
train_datasets = (
    [dataset_goliath_train] + 2 * [dataset_shutterstock_train] + [dataset_3po_train]
)

train_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    shuffle=True,
    dataset=dict(
        type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline
    ),
)

# ------------------------------------------------------------------------------
dataset_shutterstock_val = dict(
    type="Keypoints308ShutterstockEvalDataset",
    data_root=f"{_DATA_ROOT}/pose/data/shutterstock/test/images",
    ann_file=f"{_DATA_ROOT}/pose/data/shutterstock/test/annotations/person_keypoints_test2025_1k.json",
    test_mode=True,
    pipeline=val_pipeline,
)

val_dataloader = dict(
    batch_size=4,
    num_workers=4,
    persistent_workers=True,
    multiprocessing_context="spawn",  ## avoids fork error with airstore
    shuffle=False,
    dataset=dataset_shutterstock_val,
    collate_fn=dict(type="eval_collate"),
)

val_cfg = dict(
    val_interval=val_every_iters,
    flip_test=True,  ## left right flip
    evaluator=dict(
        type="Keypoints308Evaluator",
        decoder=codec,
        ann_file=f"{_DATA_ROOT}/pose/data/shutterstock/test/annotations/person_keypoints_test2025_1k.json",
    ),
)

# Alternative validation on the Goliath test split (disabled).
# dataset_goliath_val = dict(
#     type="Keypoints308GoliathEvalDataset",
#     data_root=f"{_DATA_ROOT}/pose/data/goliath/test_10000/images",
#     ann_file=f"{_DATA_ROOT}/pose/data/goliath/test_10000/annotations/person_keypoints_test2023.json",
#     test_mode=True,
#     # num_samples=10, ## debug
#     pipeline=val_pipeline,
# )

# val_dataloader = dict(
#     batch_size=4,
#     num_workers=4,
#     persistent_workers=True,
#     multiprocessing_context="spawn", ## avoids fork error with airstore
#     # num_workers=0, # debug
#     # persistent_workers=False, # debug
#     shuffle=False,
#     dataset=dataset_goliath_val,
#     collate_fn=dict(type="eval_collate"),
# )

# val_cfg = dict(
#     val_interval=val_every_iters,
#     flip_test=True, ## left right flip
#     evaluator=dict(
#         type="Keypoints308Evaluator",
#         decoder=codec,
#         ann_file=f"{_DATA_ROOT}/pose/data/goliath/test_10000/annotations/person_keypoints_test2023.json",
#     ),
# )

data_preprocessor = dict(
    type="ImagePreprocessor",
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,  ## convert from bgr to rgb for pretrained models
)

##-----------------------------------------------------------------
model = dict(
    type="PoseTopdownEstimator",
    backbone=dict(
        type="Sapiens2",
        arch=model_name,
        img_size=image_size,
        patch_size=patch_size,
        final_norm=True,
        use_tokenizer=False,
        with_cls_token=True,
        out_type="featmap",
        init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint),
    ),
    decode_head=dict(
        type="PoseHeatmapHead",
        in_channels=embed_dim,
        out_channels=num_keypoints,
        deconv_out_channels=(1024, 768),  ## this will 2x at each step. so total is 4x
        deconv_kernel_sizes=(4, 4),
        conv_out_channels=(512, 512, 256),
        conv_kernel_sizes=(1, 1, 1),
        loss_decode=dict(
            type="KeypointMSELoss", use_target_weight=True, loss_weight=10.0
        ),
        # loss_decode=dict(type='KeypointOHKMMSELoss', use_target_weight=True, topk=128), ## loss only for top 128 keypoints. for finetuning later.
    ),
)


##-----------------------------------------------------------------
optimizer = dict(
    type="AdamW",
    lr=5e-4,
    betas=(0.9, 0.999),
    weight_decay=0.1,
    paramwise_cfg=dict(
        num_layers=num_layers,
        layer_decay_rate=layer_decay_rate,
    ),
    fused=True,
)

# Linear warmup followed by a linear (power=1) polynomial decay.
scheduler = dict(
    type="SequentialLR",
    milestones=[warmup_iters],
    schedulers=[
        dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters),
        dict(
            type="PolynomialLR",
            total_iters=num_iters - warmup_iters,
            power=1.0,
        ),
    ],
)

clip_grad = dict(mode="norm", max_norm=2.0, norm_type=2.0)

runner_type = "PoseRunner"
assets/configs/sapiens2_0.8b_keypoints308_shutterstock_goliath_3po-1024x768.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Training config: Sapiens2-0.8B, 308-keypoint top-down pose at 1024x768,
# trained on Shutterstock + Goliath + 3PO. Identical to the 0.4B config
# except for the backbone hyperparameters, layer decay, and grad clipping.
# At inference time the Space only consumes `codec` and `model`.

import os

# Roots are overridable via env vars so the same config works on any host.
_CHECKPOINT_ROOT = os.path.expanduser(
    os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host")
)
_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data"))

warmup_iters = 500
# num_iters = 2e4
num_iters = 1e4

# ------------------------------------------------------------------------------
vis_every_iters = 100
log_every_iters = 10
save_every_iters = 1000
val_every_iters = 1000

# # # # debug
# vis_every_iters = 1
# log_every_iters = 1
# val_every_iters = 2
# save_every_iters = 1000

load_from = None
resume = False

# ------------------------------------------------------------------
# Backbone architecture hyperparameters for the 0.8B variant.
model_name = "sapiens2_0.8b"
embed_dim = 1280
num_layers = 32
num_heads = 16

layer_decay_rate = 0.85
pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_0.8b_pretrain.safetensors"

##-----------------------------------------------------------------
image_size = (1024, 768)  ## height x width
patch_size = 16

sigma = 6  ## sigma is 2 for 256
scale = 4
num_keypoints = 308

# ------------------------------------------------------------------
use_fsdp = True
# use_fsdp = False

use_compile = True
# use_compile = False

## DDP config
if use_fsdp is False:
    accelerator_cfg = dict(
        type="DDP",
        log_with="tensorboard",
        # find_unused_parameters=True,
        gradient_accumulation_steps=1,  # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
        max_interval=num_iters,
        # mixed_precision="bf16", # Options: 'no','fp16','bf16' or 'fp8'.
        step_scheduler_with_optimizer=False,  ## schedule independent of n_gpus
    )

else:
    accelerator_cfg = dict(
        type="FSDP",
        log_with="tensorboard",
        gradient_accumulation_steps=1,  # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
        max_interval=num_iters,
        mixed_precision="bf16",  # Options: 'no','fp16','bf16' or 'fp8'.
        step_scheduler_with_optimizer=False,
        fsdp_cfg=dict(
            fsdp_version=2,  # DTensor-based engine
            state_dict_type="SHARDED_STATE_DICT",  # SHARDED_STATE_DICT | FULL_STATE_DICT
            mixed_precision=dict(
                param_dtype="bf16",
                reduce_dtype="bf16",
            ),
            cpu_ram_efficient_loading=False,
        ),
    )

if use_compile:
    accelerator_cfg["compile_cfg"] = dict(
        backend="inductor",
        mode="default",  # Options: "default", "reduce-overhead", "max-autotune"
        fullgraph=False,
        dynamic=False,
    )

# ------------------------------------------------------------------
randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
logger = dict(
    type="Logger",
    log_interval=log_every_iters,
)
checkpoint = dict(
    type="Checkpointer",
    save_interval=save_every_iters,
)

visualizer = dict(
    type="PoseVisualizer",
    vis_interval=vis_every_iters,
    vis_max_samples=4,
    vis_image_width=384,
    vis_image_height=512,
    num_keypoints=num_keypoints,
)


##-----------------------------------------------------------------
# Heatmap codec: encodes keypoints to training targets and decodes
# predicted heatmaps back to coordinates (UDP-aligned).
codec = dict(
    type="UDPHeatmap",
    input_size=(image_size[1], image_size[0]),  ## width x height
    heatmap_size=(int(image_size[1] / scale), int(image_size[0] / scale)),
    sigma=sigma,
)  ## sigma is 2 for 256

train_pipeline = [
    dict(type="PoseGetBBoxCenterScale"),
    dict(type="PoseRandomFlip", direction="horizontal"),  ## default prob is 0.5
    dict(type="PoseRandomHalfBody"),
    dict(type="PoseRandomBBoxTransform"),
    dict(type="PoseTopdownAffine", input_size=codec["input_size"], use_udp=True),
    dict(type="RandomPhotoMetricDistortion", prob=0.8),
    dict(
        type="PoseAlbumentation",
        transforms=[
            dict(type="Blur", p=0.1),
            dict(type="MedianBlur", p=0.1),
            dict(
                type="CoarseDropout",
                max_holes=1,
                max_height=0.4,
                max_width=0.4,
                min_holes=1,
                min_height=0.2,
                min_width=0.2,
                p=1.0,
            ),
        ],
    ),
    dict(type="PoseGenerateTarget", encoder=codec),
    dict(type="PosePackInputs"),
]

val_pipeline = [
    dict(type="PoseGetBBoxCenterScale"),
    dict(type="PoseTopdownAffine", input_size=codec["input_size"], use_udp=True),
    dict(type="PosePackInputs"),
]

test_pipeline = [
    dict(type="PoseGetBBoxCenterScale"),
    dict(type="PoseTopdownAffine", input_size=codec["input_size"], use_udp=True),
    dict(type="PosePackInputs"),
]

##------------------------------------------------------------------------
dataset_shutterstock_train = dict(
    type="Keypoints308ShutterstockDataset",
    ann_file=f"{_DATA_ROOT}/annotations/ingestion_102866/itw_shutterstock_body_keypoint_344_train:2025070300.json",
)

dataset_goliath_train = dict(
    type="Keypoints308GoliathDataset",
    ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/sociopticon_body_keypoint_344_train:2024093001.json",
    subsample_factor=8,
)

dataset_3po_train = dict(
    type="Keypoints308_3PODataset",
    ann_file=f"{_DATA_ROOT}/indices/3po/train.json",
    subsample_factor=2,
)

# train_datasets = [dataset_shutterstock_train]
# train_datasets = [dataset_goliath_train]
# train_datasets = [dataset_3po_train]
# Mixing ratio: 1x Goliath, 2x Shutterstock, 1x 3PO.
train_datasets = (
    [dataset_goliath_train] + 2 * [dataset_shutterstock_train] + [dataset_3po_train]
)

train_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    shuffle=True,
    dataset=dict(
        type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline
    ),
)

# ------------------------------------------------------------------------------
dataset_shutterstock_val = dict(
    type="Keypoints308ShutterstockEvalDataset",
    data_root=f"{_DATA_ROOT}/pose/data/shutterstock/test/images",
    ann_file=f"{_DATA_ROOT}/pose/data/shutterstock/test/annotations/person_keypoints_test2025_1k.json",
    test_mode=True,
    pipeline=val_pipeline,
)

val_dataloader = dict(
    batch_size=4,
    num_workers=4,
    persistent_workers=True,
    multiprocessing_context="spawn",  ## avoids fork error with airstore
    shuffle=False,
    dataset=dataset_shutterstock_val,
    collate_fn=dict(type="eval_collate"),
)

val_cfg = dict(
    val_interval=val_every_iters,
    flip_test=True,  ## left right flip
    evaluator=dict(
        type="Keypoints308Evaluator",
        decoder=codec,
        ann_file=f"{_DATA_ROOT}/pose/data/shutterstock/test/annotations/person_keypoints_test2025_1k.json",
    ),
)

# Alternative validation on the Goliath test split (disabled).
# dataset_goliath_val = dict(
#     type="Keypoints308GoliathEvalDataset",
#     data_root=f"{_DATA_ROOT}/pose/data/goliath/test_10000/images",
#     ann_file=f"{_DATA_ROOT}/pose/data/goliath/test_10000/annotations/person_keypoints_test2023.json",
#     test_mode=True,
#     # num_samples=10, ## debug
#     pipeline=val_pipeline,
# )

# val_dataloader = dict(
#     batch_size=4,
#     num_workers=4,
#     persistent_workers=True,
#     multiprocessing_context="spawn", ## avoids fork error with airstore
#     # num_workers=0, # debug
#     # persistent_workers=False, # debug
#     shuffle=False,
#     dataset=dataset_goliath_val,
#     collate_fn=dict(type="eval_collate"),
# )

# val_cfg = dict(
#     val_interval=val_every_iters,
#     flip_test=True, ## left right flip
#     evaluator=dict(
#         type="Keypoints308Evaluator",
#         decoder=codec,
#         ann_file=f"{_DATA_ROOT}/pose/data/goliath/test_10000/annotations/person_keypoints_test2023.json",
#     ),
# )

data_preprocessor = dict(
    type="ImagePreprocessor",
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,  ## convert from bgr to rgb for pretrained models
)

##-----------------------------------------------------------------
model = dict(
    type="PoseTopdownEstimator",
    backbone=dict(
        type="Sapiens2",
        arch=model_name,
        img_size=image_size,
        patch_size=patch_size,
        final_norm=True,
        use_tokenizer=False,
        with_cls_token=True,
        out_type="featmap",
        init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint),
    ),
    decode_head=dict(
        type="PoseHeatmapHead",
        in_channels=embed_dim,
        out_channels=num_keypoints,
        deconv_out_channels=(1024, 768),  ## this will 2x at each step. so total is 4x
        deconv_kernel_sizes=(4, 4),
        conv_out_channels=(512, 512, 256),
        conv_kernel_sizes=(1, 1, 1),
        loss_decode=dict(
            type="KeypointMSELoss", use_target_weight=True, loss_weight=10.0
        ),
        # loss_decode=dict(type='KeypointOHKMMSELoss', use_target_weight=True, topk=128), ## loss only for top 128 keypoints. for finetuning later.
    ),
)


##-----------------------------------------------------------------
optimizer = dict(
    type="AdamW",
    lr=5e-4,
    betas=(0.9, 0.999),
    weight_decay=0.1,
    paramwise_cfg=dict(
        num_layers=num_layers,
        layer_decay_rate=layer_decay_rate,
    ),
    fused=True,
)

# Linear warmup followed by a linear (power=1) polynomial decay.
scheduler = dict(
    type="SequentialLR",
    milestones=[warmup_iters],
    schedulers=[
        dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters),
        dict(
            type="PolynomialLR",
            total_iters=num_iters - warmup_iters,
            power=1.0,
        ),
    ],
)

clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0)

runner_type = "PoseRunner"
assets/configs/sapiens2_1b_keypoints308_shutterstock_goliath_3po-1024x768.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+
9
+ _CHECKPOINT_ROOT = os.path.expanduser(
10
+ os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host")
11
+ )
12
+ _DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data"))
13
+
14
+ warmup_iters = 500
15
+ # num_iters = 4e4
16
+ num_iters = 2e4
17
+
18
+ # ------------------------------------------------------------------------------
19
+ vis_every_iters = 100
20
+ log_every_iters = 10
21
+ save_every_iters = 1000
22
+ val_every_iters = 1000
23
+
24
+ # # # # debug
25
+ # vis_every_iters = 1
26
+ # log_every_iters = 1
27
+ # val_every_iters = 2
28
+ # save_every_iters = 1000
29
+
30
+ load_from = None
31
+ resume = False
32
+
33
+ # ------------------------------------------------------------------
34
+ model_name = "sapiens2_1b"
35
+ embed_dim = 1536
36
+ num_layers = 40
37
+ num_heads = 24
38
+
39
+ layer_decay_rate = 0.9
40
+ pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_1b_pretrain.safetensors"
41
+
42
+ ##-----------------------------------------------------------------
43
+ image_size = (1024, 768) ## height x width
44
+ patch_size = 16
45
+
46
+ sigma = 6 ## sigma is 2 for 256
47
+ scale = 4
48
+ num_keypoints = 308
49
+
50
+ # ------------------------------------------------------------------
51
+ use_fsdp = True
52
+ # use_fsdp = False
53
+
54
+ use_compile = True
55
+ # use_compile = False
56
+
57
+ ## DDP config
58
+ if use_fsdp is False:
59
+ accelerator_cfg = dict(
60
+ type="DDP",
61
+ log_with="tensorboard",
62
+ # find_unused_parameters=True,
63
+ gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
64
+ max_interval=num_iters,
65
+ mixed_precision="bf16", # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
66
+ step_scheduler_with_optimizer=False, ## schedule independent of n_gpus
67
+ )
68
+
69
+ else:
70
+ accelerator_cfg = dict(
71
+ type="FSDP",
72
+ log_with="tensorboard",
73
+ gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
74
+ max_interval=num_iters,
75
+ mixed_precision="bf16", # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
76
+ step_scheduler_with_optimizer=False,
77
+ fsdp_cfg=dict(
78
+ fsdp_version=2, # DTensor-based engine
79
+ state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT
80
+ mixed_precision=dict(
81
+ param_dtype="bf16",
82
+ reduce_dtype="bf16",
83
+ ),
84
+ cpu_ram_efficient_loading=False,
85
+ ),
86
+ )
87
+
88
+ if use_compile:
89
+ accelerator_cfg["compile_cfg"] = dict(
90
+ backend="inductor",
91
+ mode="default", # Options: "default", "reduce-overhead", "max-autotune"
92
+ fullgraph=False,
93
+ dynamic=False,
94
+ )
95
+
96
+ # ------------------------------------------------------------------
97
+ randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
98
+ logger = dict(
99
+ type="Logger",
100
+ log_interval=log_every_iters,
101
+ )
102
+ checkpoint = dict(
103
+ type="Checkpointer",
104
+ save_interval=save_every_iters,
105
+ )
106
+
107
+ visualizer = dict(
108
+ type="PoseVisualizer",
109
+ vis_interval=vis_every_iters,
110
+ vis_max_samples=4,
111
+ vis_image_width=384,
112
+ vis_image_height=512,
113
+ num_keypoints=num_keypoints,
114
+ )
115
+
116
+
117
+ ##-----------------------------------------------------------------
118
+ codec = dict(
119
+ type="UDPHeatmap",
120
+ input_size=(image_size[1], image_size[0]), ## width x height
121
+ heatmap_size=(int(image_size[1] / scale), int(image_size[0] / scale)),
122
+ sigma=sigma,
123
+ ) ## sigma is 2 for 256
124
+
125
+ train_pipeline = [
126
+ dict(type="PoseGetBBoxCenterScale"),
127
+ dict(type="PoseRandomFlip", direction="horizontal"), ## default prob is 0.5
128
+ dict(type="PoseRandomHalfBody"),
129
+ dict(type="PoseRandomBBoxTransform"),
130
+ dict(type="PoseTopdownAffine", input_size=codec["input_size"], use_udp=True),
131
+ dict(type="RandomPhotoMetricDistortion", prob=0.8),
132
+ dict(
133
+ type="PoseAlbumentation",
134
+ transforms=[
135
+ dict(type="Blur", p=0.1),
136
+ dict(type="MedianBlur", p=0.1),
137
+ dict(
138
+ type="CoarseDropout",
139
+ max_holes=1,
140
+ max_height=0.4,
141
+ max_width=0.4,
142
+ min_holes=1,
143
+ min_height=0.2,
144
+ min_width=0.2,
145
+ p=1.0,
146
+ ),
147
+ ],
148
+ ),
149
+ dict(type="PoseGenerateTarget", encoder=codec),
150
+ dict(type="PosePackInputs"),
151
+ ]
152
+
153
+ val_pipeline = [
154
+ dict(type="PoseGetBBoxCenterScale"),
155
+ dict(type="PoseTopdownAffine", input_size=codec["input_size"], use_udp=True),
156
+ dict(type="PosePackInputs"),
157
+ ]
158
+
159
+ test_pipeline = [
160
+ dict(type="PoseGetBBoxCenterScale"),
161
+ dict(type="PoseTopdownAffine", input_size=codec["input_size"], use_udp=True),
162
+ dict(type="PosePackInputs"),
163
+ ]
164
+
165
+ ##------------------------------------------------------------------------
166
+ dataset_shutterstock_train = dict(
167
+ type="Keypoints308ShutterstockDataset",
168
+ ann_file=f"{_DATA_ROOT}/annotations/ingestion_102866/itw_shutterstock_body_keypoint_344_train:2025070300.json",
169
+ )
170
+
171
+ dataset_goliath_train = dict(
172
+ type="Keypoints308GoliathDataset",
173
+ ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/sociopticon_body_keypoint_344_train:2024093001.json",
174
+ subsample_factor=8,
175
+ )
176
+
177
+ dataset_3po_train = dict(
178
+ type="Keypoints308_3PODataset",
179
+ ann_file=f"{_DATA_ROOT}/indices/3po/train.json",
180
+ subsample_factor=2,
181
+ )
182
+
183
+ # train_datasets = [dataset_shutterstock_train]
184
+ # train_datasets = [dataset_goliath_train]
185
+ # train_datasets = [dataset_3po_train]
186
+ train_datasets = (
187
+ [dataset_goliath_train] + 2 * [dataset_shutterstock_train] + [dataset_3po_train]
188
+ )
189
+
190
+ train_dataloader = dict(
191
+ batch_size=1,
192
+ num_workers=4,
193
+ persistent_workers=True,
194
+ shuffle=True,
195
+ dataset=dict(
196
+ type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline
197
+ ),
198
+ )
199
+
200
+ # ------------------------------------------------------------------------------
201
+ dataset_shutterstock_val = dict(
202
+ type="Keypoints308ShutterstockEvalDataset",
203
+ data_root=f"{_DATA_ROOT}/pose/data/shutterstock/test/images",
204
+ ann_file=f"{_DATA_ROOT}/pose/data/shutterstock/test/annotations/person_keypoints_test2025_1k.json",
205
+ test_mode=True,
206
+ # num_samples=10, ## debug
207
+ pipeline=val_pipeline,
208
+ )
209
+
210
+ val_dataloader = dict(
211
+ batch_size=4,
212
+ num_workers=4,
213
+ persistent_workers=True,
214
+ multiprocessing_context="spawn", ## avoids fork error with airstore
215
+ # num_workers=0, # debug
216
+ # persistent_workers=False, # debug
217
+ shuffle=False,
218
+ dataset=dataset_shutterstock_val,
219
+ collate_fn=dict(type="eval_collate"),
220
+ )
221
+
222
+ val_cfg = dict(
223
+ val_interval=val_every_iters,
224
+ flip_test=True, ## left right flip
225
+ evaluator=dict(
226
+ type="Keypoints308Evaluator",
227
+ decoder=codec,
228
+ ann_file=f"{_DATA_ROOT}/pose/data/shutterstock/test/annotations/person_keypoints_test2025_1k.json",
229
+ ),
230
+ )
231
+
232
+ # dataset_goliath_val = dict(
233
+ # type="Keypoints308GoliathEvalDataset",
234
+ # data_root=f"{_DATA_ROOT}/pose/data/goliath/test_10000/images",
235
+ # ann_file=f"{_DATA_ROOT}/pose/data/goliath/test_10000/annotations/person_keypoints_test2023.json",
236
+ # test_mode=True,
237
+ # # num_samples=10, ## debug
238
+ # pipeline=val_pipeline,
239
+ # )
240
+
241
+ # val_dataloader = dict(
242
+ # batch_size=4,
243
+ # num_workers=4,
244
+ # persistent_workers=True,
245
+ # multiprocessing_context="spawn", ## avoids fork error with airstore
246
+ # # num_workers=0, # debug
247
+ # # persistent_workers=False, # debug
248
+ # shuffle=False,
249
+ # dataset=dataset_goliath_val,
250
+ # collate_fn=dict(type="eval_collate"),
251
+ # )
252
+
253
+ # val_cfg = dict(
254
+ # val_interval=val_every_iters,
255
+ # flip_test=True, ## left right flip
256
+ # evaluator=dict(
257
+ # type="Keypoints308Evaluator",
258
+ # decoder=codec,
259
+ # ann_file=f"{_DATA_ROOT}/pose/data/goliath/test_10000/annotations/person_keypoints_test2023.json",
260
+ # ),
261
+ # )
262
+
263
+ data_preprocessor = dict(
264
+ type="ImagePreprocessor",
265
+ mean=[123.675, 116.28, 103.53],
266
+ std=[58.395, 57.12, 57.375],
267
+ bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models
268
+ )
269
+
270
+ ##-----------------------------------------------------------------
271
+ model = dict(
272
+ type="PoseTopdownEstimator",
273
+ backbone=dict(
274
+ type="Sapiens2",
275
+ arch=model_name,
276
+ img_size=image_size,
277
+ patch_size=patch_size,
278
+ final_norm=True,
279
+ use_tokenizer=False,
280
+ with_cls_token=True,
281
+ out_type="featmap",
282
+ init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint),
283
+ ),
284
+ decode_head=dict(
285
+ type="PoseHeatmapHead",
286
+ in_channels=embed_dim,
287
+ out_channels=num_keypoints,
288
+ deconv_out_channels=(1536, 1024), ## this will 2x at each step. so total is 4x
289
+ deconv_kernel_sizes=(4, 4),
290
+ conv_out_channels=(768, 512, 256),
291
+ conv_kernel_sizes=(1, 1, 1),
292
+ loss_decode=dict(
293
+ type="KeypointMSELoss", use_target_weight=True, loss_weight=10.0
294
+ ),
295
+ # loss_decode=dict(type='KeypointOHKMMSELoss', use_target_weight=True, topk=128), ## loss only for top 128 keypoints. for finetuning later.
296
+ ),
297
+ )
298
+
299
+
300
+ ##-----------------------------------------------------------------
301
+ optimizer = dict(
302
+ type="AdamW",
303
+ lr=5e-4,
304
+ betas=(0.9, 0.999),
305
+ weight_decay=0.1,
306
+ paramwise_cfg=dict(
307
+ num_layers=num_layers,
308
+ layer_decay_rate=layer_decay_rate,
309
+ ),
310
+ fused=True,
311
+ )
312
+
313
+ scheduler = dict(
314
+ type="SequentialLR",
315
+ milestones=[warmup_iters],
316
+ schedulers=[
317
+ dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters),
318
+ dict(
319
+ type="PolynomialLR",
320
+ total_iters=num_iters - warmup_iters,
321
+ power=1.0,
322
+ ),
323
+ ],
324
+ )
325
+
326
+ clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0)
327
+
328
+ runner_type = "PoseRunner"
assets/configs/sapiens2_5b_keypoints308_shutterstock_goliath_3po-1024x768.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+
9
+ _CHECKPOINT_ROOT = os.path.expanduser(
10
+ os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host")
11
+ )
12
+ _DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data"))
13
+
14
+ warmup_iters = 500
15
+ # num_iters = 4e4
16
+ num_iters = 2e4 ## light finetune
17
+
18
+ # ------------------------------------------------------------------------------
19
+ vis_every_iters = 100
20
+ log_every_iters = 10
21
+ save_every_iters = 1000
22
+ val_every_iters = 1000
23
+
24
+ # # # # debug
25
+ # vis_every_iters = 1
26
+ # log_every_iters = 1
27
+ # val_every_iters = 2
28
+ # save_every_iters = 1000
29
+
30
+ load_from = None
31
+ resume = False
32
+
33
+ # ------------------------------------------------------------------
34
+ model_name = "sapiens2_5b"
35
+ embed_dim = 2432
36
+ num_layers = 56
37
+ num_heads = 32
38
+ layer_decay_rate = 0.94
39
+
40
+ pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_5b_pretrain.safetensors"
41
+
42
+ ##-----------------------------------------------------------------
43
+ image_size = (1024, 768) ## height x width
44
+ patch_size = 16
45
+
46
+ sigma = 6 ## sigma is 2 for 256
47
+ scale = 4
48
+ num_keypoints = 308
49
+
50
+ # ------------------------------------------------------------------
51
+ use_fsdp = True
52
+ # use_fsdp = False
53
+
54
+ use_compile = True
55
+ # use_compile = False
56
+
57
+ ## DDP config
58
+ if use_fsdp is False:
59
+ accelerator_cfg = dict(
60
+ type="DDP",
61
+ log_with="tensorboard",
62
+ # find_unused_parameters=True,
63
+ gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
64
+ max_interval=num_iters,
65
+ # mixed_precision="bf16", # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
66
+ step_scheduler_with_optimizer=False, ## schedule independent of n_gpus
67
+ )
68
+
69
+ else:
70
+ accelerator_cfg = dict(
71
+ type="FSDP",
72
+ log_with="tensorboard",
73
+ gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off.
74
+ max_interval=num_iters,
75
+ mixed_precision="bf16", # Options: ‘no’,‘fp16’,‘bf16’ or ‘fp8’.
76
+ step_scheduler_with_optimizer=False,
77
+ fsdp_cfg=dict(
78
+ fsdp_version=2, # DTensor-based engine
79
+ state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT
80
+ mixed_precision=dict(
81
+ param_dtype="bf16",
82
+ reduce_dtype="bf16",
83
+ ),
84
+ cpu_ram_efficient_loading=False,
85
+ ),
86
+ )
87
+
88
+ if use_compile:
89
+ accelerator_cfg["compile_cfg"] = dict(
90
+ backend="inductor",
91
+ mode="default", # Options: "default", "reduce-overhead", "max-autotune"
92
+ fullgraph=False,
93
+ dynamic=False,
94
+ )
95
+
96
+ # ------------------------------------------------------------------
97
+ randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
98
+ logger = dict(
99
+ type="Logger",
100
+ log_interval=log_every_iters,
101
+ )
102
+ checkpoint = dict(
103
+ type="Checkpointer",
104
+ save_interval=save_every_iters,
105
+ )
106
+
107
+ visualizer = dict(
108
+ type="PoseVisualizer",
109
+ vis_interval=vis_every_iters,
110
+ vis_max_samples=4,
111
+ vis_image_width=384,
112
+ vis_image_height=512,
113
+ num_keypoints=num_keypoints,
114
+ )
115
+
116
+
117
+ ##-----------------------------------------------------------------
118
+ codec = dict(
119
+ type="UDPHeatmap",
120
+ input_size=(image_size[1], image_size[0]), ## width x height
121
+ heatmap_size=(int(image_size[1] / scale), int(image_size[0] / scale)),
122
+ sigma=sigma,
123
+ ) ## sigma is 2 for 256
124
+
125
+ train_pipeline = [
126
+ dict(type="PoseGetBBoxCenterScale"),
127
+ dict(type="PoseRandomFlip", direction="horizontal"), ## default prob is 0.5
128
+ dict(type="PoseRandomHalfBody"),
129
+ dict(type="PoseRandomBBoxTransform"),
130
+ dict(type="PoseTopdownAffine", input_size=codec["input_size"], use_udp=True),
131
+ dict(type="RandomPhotoMetricDistortion", prob=0.8),
132
+ dict(
133
+ type="PoseAlbumentation",
134
+ transforms=[
135
+ dict(type="Blur", p=0.1),
136
+ dict(type="MedianBlur", p=0.1),
137
+ dict(
138
+ type="CoarseDropout",
139
+ max_holes=1,
140
+ max_height=0.4,
141
+ max_width=0.4,
142
+ min_holes=1,
143
+ min_height=0.2,
144
+ min_width=0.2,
145
+ p=1.0,
146
+ ),
147
+ ],
148
+ ),
149
+ dict(type="PoseGenerateTarget", encoder=codec),
150
+ dict(type="PosePackInputs"),
151
+ ]
152
+
153
+ val_pipeline = [
154
+ dict(type="PoseGetBBoxCenterScale"),
155
+ dict(type="PoseTopdownAffine", input_size=codec["input_size"], use_udp=True),
156
+ dict(type="PosePackInputs"),
157
+ ]
158
+
159
+ test_pipeline = [
160
+ dict(type="PoseGetBBoxCenterScale"),
161
+ dict(type="PoseTopdownAffine", input_size=codec["input_size"], use_udp=True),
162
+ dict(type="PosePackInputs"),
163
+ ]
164
+
165
+ ##------------------------------------------------------------------------
166
+ dataset_shutterstock_train = dict(
167
+ type="Keypoints308ShutterstockDataset",
168
+ ann_file=f"{_DATA_ROOT}/annotations/ingestion_102866/itw_shutterstock_body_keypoint_344_train:2025070300.json",
169
+ )
170
+
171
+ dataset_goliath_train = dict(
172
+ type="Keypoints308GoliathDataset",
173
+ ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/sociopticon_body_keypoint_344_train:2024093001.json",
174
+ subsample_factor=8,
175
+ )
176
+
177
+ dataset_3po_train = dict(
178
+ type="Keypoints308_3PODataset",
179
+ ann_file=f"{_DATA_ROOT}/indices/3po/train.json",
180
+ subsample_factor=2,
181
+ )
182
+
183
+ # train_datasets = [dataset_shutterstock_train]
184
+ # train_datasets = [dataset_goliath_train]
185
+ # train_datasets = [dataset_3po_train]
186
+ train_datasets = (
187
+ [dataset_goliath_train] + 2 * [dataset_shutterstock_train] + [dataset_3po_train]
188
+ )
189
+
190
+ train_dataloader = dict(
191
+ batch_size=1,
192
+ num_workers=4,
193
+ persistent_workers=True,
194
+ shuffle=True,
195
+ dataset=dict(
196
+ type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline
197
+ ),
198
+ )
199
+
200
+ # ------------------------------------------------------------------------------
201
+ dataset_shutterstock_val = dict(
202
+ type="Keypoints308ShutterstockEvalDataset",
203
+ data_root=f"{_DATA_ROOT}/pose/data/shutterstock/test/images",
204
+ ann_file=f"{_DATA_ROOT}/pose/data/shutterstock/test/annotations/person_keypoints_test2025_1k.json",
205
+ test_mode=True,
206
+ pipeline=val_pipeline,
207
+ )
208
+
209
+ val_dataloader = dict(
210
+ batch_size=4,
211
+ num_workers=4,
212
+ persistent_workers=True,
213
+ multiprocessing_context="spawn", ## avoids fork error with airstore
214
+ shuffle=False,
215
+ dataset=dataset_shutterstock_val,
216
+ collate_fn=dict(type="eval_collate"),
217
+ )
218
+
219
+ val_cfg = dict(
220
+ val_interval=val_every_iters,
221
+ flip_test=True, ## left right flip
222
+ evaluator=dict(
223
+ type="Keypoints308Evaluator",
224
+ decoder=codec,
225
+ ann_file=f"{_DATA_ROOT}/pose/data/shutterstock/test/annotations/person_keypoints_test2025_1k.json",
226
+ ),
227
+ )
228
+
229
+ # dataset_goliath_val = dict(
230
+ # type="Keypoints308GoliathEvalDataset",
231
+ # data_root=f"{_DATA_ROOT}/pose/data/goliath/test_10000/images",
232
+ # ann_file=f"{_DATA_ROOT}/pose/data/goliath/test_10000/annotations/person_keypoints_test2023.json",
233
+ # test_mode=True,
234
+ # # num_samples=10, ## debug
235
+ # pipeline=val_pipeline,
236
+ # )
237
+
238
+ # val_dataloader = dict(
239
+ # batch_size=4,
240
+ # num_workers=4,
241
+ # persistent_workers=True,
242
+ # multiprocessing_context="spawn", ## avoids fork error with airstore
243
+ # # num_workers=0, # debug
244
+ # # persistent_workers=False, # debug
245
+ # shuffle=False,
246
+ # dataset=dataset_goliath_val,
247
+ # collate_fn=dict(type="eval_collate"),
248
+ # )
249
+
250
+ # val_cfg = dict(
251
+ # val_interval=val_every_iters,
252
+ # flip_test=True, ## left right flip
253
+ # evaluator=dict(
254
+ # type="Keypoints308Evaluator",
255
+ # decoder=codec,
256
+ # ann_file=f"{_DATA_ROOT}/pose/data/goliath/test_10000/annotations/person_keypoints_test2023.json",
257
+ # ),
258
+ # )
259
+
260
+ data_preprocessor = dict(
261
+ type="ImagePreprocessor",
262
+ mean=[123.675, 116.28, 103.53],
263
+ std=[58.395, 57.12, 57.375],
264
+ bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models
265
+ )
266
+
267
+ ##-----------------------------------------------------------------
268
+ model = dict(
269
+ type="PoseTopdownEstimator",
270
+ backbone=dict(
271
+ type="Sapiens2",
272
+ arch=model_name,
273
+ img_size=image_size,
274
+ patch_size=patch_size,
275
+ final_norm=True,
276
+ use_tokenizer=False,
277
+ with_cls_token=True,
278
+ out_type="featmap",
279
+ init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint),
280
+ ),
281
+ decode_head=dict(
282
+ type="PoseHeatmapHead",
283
+ in_channels=embed_dim,
284
+ out_channels=num_keypoints,
285
+ deconv_out_channels=(1024, 768), ## this will 2x at each step. so total is 4x
286
+ deconv_kernel_sizes=(4, 4),
287
+ conv_out_channels=(512, 512, 256),
288
+ conv_kernel_sizes=(1, 1, 1),
289
+ loss_decode=dict(
290
+ type="KeypointMSELoss", use_target_weight=True, loss_weight=10.0
291
+ ),
292
+ # loss_decode=dict(type='KeypointOHKMMSELoss', use_target_weight=True, topk=128), ## loss only for top 128 keypoints. for finetuning later.
293
+ ),
294
+ )
295
+
296
+
297
+ ##-----------------------------------------------------------------
298
+ optimizer = dict(
299
+ type="AdamW",
300
+ # lr=5e-4,
301
+ lr=1e-4,
302
+ betas=(0.9, 0.999),
303
+ weight_decay=0.1,
304
+ paramwise_cfg=dict(
305
+ num_layers=num_layers,
306
+ layer_decay_rate=layer_decay_rate,
307
+ ),
308
+ fused=True,
309
+ )
310
+
311
+ scheduler = dict(
312
+ type="SequentialLR",
313
+ milestones=[warmup_iters],
314
+ schedulers=[
315
+ dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters),
316
+ dict(
317
+ type="PolynomialLR",
318
+ total_iters=num_iters - warmup_iters,
319
+ power=1.0,
320
+ ),
321
+ ],
322
+ )
323
+
324
+ clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0)
325
+
326
+ runner_type = "PoseRunner"
assets/images/68204.png ADDED

Git LFS Details

  • SHA256: 9b0268cb801ed164864a4b5f6d131e0ac5cc2fbd149a6467d5d0c97da47122c2
  • Pointer size: 132 Bytes
  • Size of remote file: 4.29 MB
assets/images/68210.png ADDED

Git LFS Details

  • SHA256: dbe5f80498af4ebd1ff09ae4184f37c20ba981e53bd554c3cc78d39ae0ee7fd7
  • Pointer size: 132 Bytes
  • Size of remote file: 3.93 MB
assets/images/68658.png ADDED

Git LFS Details

  • SHA256: 61a68b619bd17235e683324f2826ce0693322e45ab8c86f1c057851ecb333ac7
  • Pointer size: 132 Bytes
  • Size of remote file: 5.1 MB
assets/images/68666.png ADDED

Git LFS Details

  • SHA256: ea3047e6c2ccb485fdb3966aa2325e803cbf49c27c0bff00287b44bc16f18914
  • Pointer size: 132 Bytes
  • Size of remote file: 4.56 MB
assets/images/68691.png ADDED

Git LFS Details

  • SHA256: fae39e4055c1b297af7068cdddfeeba8d685363281b839d8c5afac1980204b57
  • Pointer size: 132 Bytes
  • Size of remote file: 3.74 MB
assets/images/68956.png ADDED

Git LFS Details

  • SHA256: eee1f27082b10999d0fa848121ecb06cda3386b1a864b9aa0f59ae78261f8908
  • Pointer size: 132 Bytes
  • Size of remote file: 4.15 MB
assets/images/pexels-amresh444-17315601.png ADDED

Git LFS Details

  • SHA256: 4e17ee1b229147e4b52e8348a6ef426bc9e9a2f90738e776e15b26b325abb9b3
  • Pointer size: 132 Bytes
  • Size of remote file: 3.5 MB
assets/images/pexels-gabby-k-6311686.png ADDED

Git LFS Details

  • SHA256: 3f10eded3fb05ab04b963f7b9fd2e183d8d4e81b20569b1c6b0653549639421f
  • Pointer size: 132 Bytes
  • Size of remote file: 3.65 MB
assets/images/pexels-julia-m-cameron-4145040.png ADDED

Git LFS Details

  • SHA256: 459cf0280667b028ffbca16aa11188780d7a0205c0defec02916ff3cbaeecb72
  • Pointer size: 132 Bytes
  • Size of remote file: 2.92 MB
assets/images/pexels-marcus-aurelius-6787357.png ADDED

Git LFS Details

  • SHA256: 7d35452f76492125eaf7d5783aa9fd6b0d5990ebe0579fe9dfd58a9d634f4955
  • Pointer size: 132 Bytes
  • Size of remote file: 3.3 MB
assets/images/pexels-mo-saeed-3616599-5409085.png ADDED

Git LFS Details

  • SHA256: 7c1ca7afd6c2a654e94ef59d5fb56fca4f3cde5fb5216f6b218c34a7b8c143dc
  • Pointer size: 132 Bytes
  • Size of remote file: 3.13 MB
assets/images/pexels-riedelmax-27355495.png ADDED

Git LFS Details

  • SHA256: 4141d2f5f718f162ea1f6710c06b28b5cb51fd69598fde35948f8f3491228164
  • Pointer size: 132 Bytes
  • Size of remote file: 3.73 MB
assets/images/pexels-sergeymakashin-5368660.png ADDED

Git LFS Details

  • SHA256: af8f5a8f26dd102d87d94c1be36ec903791fe8e6d951c68ebb9ebcfc6d7397bb
  • Pointer size: 132 Bytes
  • Size of remote file: 4.08 MB
assets/images/pexels-vinicius-wiesehofer-289347-4219918.png ADDED

Git LFS Details

  • SHA256: a6eef5eee15b81fe65ea95627e9a46040b9889466689b3c1ca6ed273e02fe84f
  • Pointer size: 132 Bytes
  • Size of remote file: 3.63 MB
assets/rtmdet_m_640-8xb32_coco-person_no_nms.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _base_ = 'mmdet::rtmdet/rtmdet_m_8xb32-300e_coco.py'
2
+
3
+ checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-m_8xb256-rsb-a1-600e_in1k-ecb3bbd9.pth' # noqa
4
+
5
+ model = dict(
6
+ backbone=dict(
7
+ init_cfg=dict(
8
+ type='Pretrained', prefix='backbone.', checkpoint=checkpoint)),
9
+ bbox_head=dict(num_classes=1),
10
+ test_cfg=dict(
11
+ nms_pre=1000,
12
+ min_bbox_size=0,
13
+ score_thr=0.05,
14
+ nms=None,
15
+ max_per_img=100))
16
+
17
+ train_dataloader = dict(dataset=dict(metainfo=dict(classes=('person', ))))
18
+
19
+ val_dataloader = dict(dataset=dict(metainfo=dict(classes=('person', ))))
20
+ test_dataloader = val_dataloader
classes_and_palettes.py ADDED
@@ -0,0 +1,1024 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ COCO_KPTS_COLORS = [
2
+ [51, 153, 255], # 0: nose
3
+ [51, 153, 255], # 1: left_eye
4
+ [51, 153, 255], # 2: right_eye
5
+ [51, 153, 255], # 3: left_ear
6
+ [51, 153, 255], # 4: right_ear
7
+ [0, 255, 0], # 5: left_shoulder
8
+ [255, 128, 0], # 6: right_shoulder
9
+ [0, 255, 0], # 7: left_elbow
10
+ [255, 128, 0], # 8: right_elbow
11
+ [0, 255, 0], # 9: left_wrist
12
+ [255, 128, 0], # 10: right_wrist
13
+ [0, 255, 0], # 11: left_hip
14
+ [255, 128, 0], # 12: right_hip
15
+ [0, 255, 0], # 13: left_knee
16
+ [255, 128, 0], # 14: right_knee
17
+ [0, 255, 0], # 15: left_ankle
18
+ [255, 128, 0], # 16: right_ankle
19
+ ]
20
+
21
+ COCO_WHOLEBODY_KPTS_COLORS = [
22
+ [51, 153, 255], # 0: nose
23
+ [51, 153, 255], # 1: left_eye
24
+ [51, 153, 255], # 2: right_eye
25
+ [51, 153, 255], # 3: left_ear
26
+ [51, 153, 255], # 4: right_ear
27
+ [0, 255, 0], # 5: left_shoulder
28
+ [255, 128, 0], # 6: right_shoulder
29
+ [0, 255, 0], # 7: left_elbow
30
+ [255, 128, 0], # 8: right_elbow
31
+ [0, 255, 0], # 9: left_wrist
32
+ [255, 128, 0], # 10: right_wrist
33
+ [0, 255, 0], # 11: left_hip
34
+ [255, 128, 0], # 12: right_hip
35
+ [0, 255, 0], # 13: left_knee
36
+ [255, 128, 0], # 14: right_knee
37
+ [0, 255, 0], # 15: left_ankle
38
+ [255, 128, 0], # 16: right_ankle
39
+ [255, 128, 0], # 17: left_big_toe
40
+ [255, 128, 0], # 18: left_small_toe
41
+ [255, 128, 0], # 19: left_heel
42
+ [255, 128, 0], # 20: right_big_toe
43
+ [255, 128, 0], # 21: right_small_toe
44
+ [255, 128, 0], # 22: right_heel
45
+ [255, 255, 255], # 23: face-0
46
+ [255, 255, 255], # 24: face-1
47
+ [255, 255, 255], # 25: face-2
48
+ [255, 255, 255], # 26: face-3
49
+ [255, 255, 255], # 27: face-4
50
+ [255, 255, 255], # 28: face-5
51
+ [255, 255, 255], # 29: face-6
52
+ [255, 255, 255], # 30: face-7
53
+ [255, 255, 255], # 31: face-8
54
+ [255, 255, 255], # 32: face-9
55
+ [255, 255, 255], # 33: face-10
56
+ [255, 255, 255], # 34: face-11
57
+ [255, 255, 255], # 35: face-12
58
+ [255, 255, 255], # 36: face-13
59
+ [255, 255, 255], # 37: face-14
60
+ [255, 255, 255], # 38: face-15
61
+ [255, 255, 255], # 39: face-16
62
+ [255, 255, 255], # 40: face-17
63
+ [255, 255, 255], # 41: face-18
64
+ [255, 255, 255], # 42: face-19
65
+ [255, 255, 255], # 43: face-20
66
+ [255, 255, 255], # 44: face-21
67
+ [255, 255, 255], # 45: face-22
68
+ [255, 255, 255], # 46: face-23
69
+ [255, 255, 255], # 47: face-24
70
+ [255, 255, 255], # 48: face-25
71
+ [255, 255, 255], # 49: face-26
72
+ [255, 255, 255], # 50: face-27
73
+ [255, 255, 255], # 51: face-28
74
+ [255, 255, 255], # 52: face-29
75
+ [255, 255, 255], # 53: face-30
76
+ [255, 255, 255], # 54: face-31
77
+ [255, 255, 255], # 55: face-32
78
+ [255, 255, 255], # 56: face-33
79
+ [255, 255, 255], # 57: face-34
80
+ [255, 255, 255], # 58: face-35
81
+ [255, 255, 255], # 59: face-36
82
+ [255, 255, 255], # 60: face-37
83
+ [255, 255, 255], # 61: face-38
84
+ [255, 255, 255], # 62: face-39
85
+ [255, 255, 255], # 63: face-40
86
+ [255, 255, 255], # 64: face-41
87
+ [255, 255, 255], # 65: face-42
88
+ [255, 255, 255], # 66: face-43
89
+ [255, 255, 255], # 67: face-44
90
+ [255, 255, 255], # 68: face-45
91
+ [255, 255, 255], # 69: face-46
92
+ [255, 255, 255], # 70: face-47
93
+ [255, 255, 255], # 71: face-48
94
+ [255, 255, 255], # 72: face-49
95
+ [255, 255, 255], # 73: face-50
96
+ [255, 255, 255], # 74: face-51
97
+ [255, 255, 255], # 75: face-52
98
+ [255, 255, 255], # 76: face-53
99
+ [255, 255, 255], # 77: face-54
100
+ [255, 255, 255], # 78: face-55
101
+ [255, 255, 255], # 79: face-56
102
+ [255, 255, 255], # 80: face-57
103
+ [255, 255, 255], # 81: face-58
104
+ [255, 255, 255], # 82: face-59
105
+ [255, 255, 255], # 83: face-60
106
+ [255, 255, 255], # 84: face-61
107
+ [255, 255, 255], # 85: face-62
108
+ [255, 255, 255], # 86: face-63
109
+ [255, 255, 255], # 87: face-64
110
+ [255, 255, 255], # 88: face-65
111
+ [255, 255, 255], # 89: face-66
112
+ [255, 255, 255], # 90: face-67
113
+ [255, 255, 255], # 91: left_hand_root
114
+ [255, 128, 0], # 92: left_thumb1
115
+ [255, 128, 0], # 93: left_thumb2
116
+ [255, 128, 0], # 94: left_thumb3
117
+ [255, 128, 0], # 95: left_thumb4
118
+ [255, 153, 255], # 96: left_forefinger1
119
+ [255, 153, 255], # 97: left_forefinger2
120
+ [255, 153, 255], # 98: left_forefinger3
121
+ [255, 153, 255], # 99: left_forefinger4
122
+ [102, 178, 255], # 100: left_middle_finger1
123
+ [102, 178, 255], # 101: left_middle_finger2
124
+ [102, 178, 255], # 102: left_middle_finger3
125
+ [102, 178, 255], # 103: left_middle_finger4
126
+ [255, 51, 51], # 104: left_ring_finger1
127
+ [255, 51, 51], # 105: left_ring_finger2
128
+ [255, 51, 51], # 106: left_ring_finger3
129
+ [255, 51, 51], # 107: left_ring_finger4
130
+ [0, 255, 0], # 108: left_pinky_finger1
131
+ [0, 255, 0], # 109: left_pinky_finger2
132
+ [0, 255, 0], # 110: left_pinky_finger3
133
+ [0, 255, 0], # 111: left_pinky_finger4
134
+ [255, 255, 255], # 112: right_hand_root
135
+ [255, 128, 0], # 113: right_thumb1
136
+ [255, 128, 0], # 114: right_thumb2
137
+ [255, 128, 0], # 115: right_thumb3
138
+ [255, 128, 0], # 116: right_thumb4
139
+ [255, 153, 255], # 117: right_forefinger1
140
+ [255, 153, 255], # 118: right_forefinger2
141
+ [255, 153, 255], # 119: right_forefinger3
142
+ [255, 153, 255], # 120: right_forefinger4
143
+ [102, 178, 255], # 121: right_middle_finger1
144
+ [102, 178, 255], # 122: right_middle_finger2
145
+ [102, 178, 255], # 123: right_middle_finger3
146
+ [102, 178, 255], # 124: right_middle_finger4
147
+ [255, 51, 51], # 125: right_ring_finger1
148
+ [255, 51, 51], # 126: right_ring_finger2
149
+ [255, 51, 51], # 127: right_ring_finger3
150
+ [255, 51, 51], # 128: right_ring_finger4
151
+ [0, 255, 0], # 129: right_pinky_finger1
152
+ [0, 255, 0], # 130: right_pinky_finger2
153
+ [0, 255, 0], # 131: right_pinky_finger3
154
+ [0, 255, 0], # 132: right_pinky_finger4
155
+ ]
156
+
157
+
158
# Per-keypoint draw colors for the 308-keypoint Goliath format used by
# Sapiens2 pose. One [R, G, B] triple per keypoint, index-aligned with
# GOLIATH_KEYPOINTS; consecutive keypoints of the same body region share
# one color.
# NOTE(review): channel order assumed RGB — confirm against the
# visualizer (OpenCV drawing would interpret these triples as BGR).
# Each run below is (color, number_of_consecutive_keypoints); the stale
# source comments numbered face indices up to 343 because teeth
# keypoints (original IDs 220-255) were removed from the format.
_GOLIATH_COLOR_RUNS = [
    ([51, 153, 255], 70),   # 0-69: body, feet, hands, arm/shoulder extras, neck
    ([255, 255, 255], 26),  # 70-95: nose bridge, chin, eyebrows
    ([192, 64, 128], 24),   # 96-119: left upper lash/eyelid/crease lines
    ([64, 32, 192], 24),    # 120-143: right upper lash/eyelid/crease lines
    ([64, 192, 128], 17),   # 144-160: left lower lash/eyelid lines
    ([64, 192, 32], 17),    # 161-177: right lower lash/eyelid lines
    ([0, 192, 0], 10),      # 178-187: nose tip and nostrils
    ([192, 0, 0], 16),      # 188-203: outer lip
    ([0, 192, 192], 16),    # 204-219: inner lip (teeth keypoints removed)
    ([200, 200, 0], 26),    # 220-245: left ear
    ([0, 200, 200], 26),    # 246-271: right ear
    ([128, 192, 64], 9),    # 272-280: left iris
    ([192, 32, 64], 9),     # 281-289: right iris
    ([192, 128, 64], 9),    # 290-298: left pupil
    ([32, 192, 192], 9),    # 299-307: right pupil
]

# Expand runs into one independent list per keypoint (list(color) avoids
# aliasing the same inner list across entries, matching the original
# literal table's behavior if a caller mutates a color in place).
GOLIATH_KPTS_COLORS = [
    list(color) for color, count in _GOLIATH_COLOR_RUNS for _ in range(count)
]
# Ordered keypoint names for the 308-keypoint Goliath format used by
# Sapiens2 pose. Index-aligned with GOLIATH_KPTS_COLORS.
# Layout: 0-20 COCO-style body + toes/heels; 21-41 right hand (wrist
# last); 42-62 left hand (wrist last); 63-69 arm/shoulder extras + neck;
# 70-219 face (eyebrows, eye lines, nose, lips — teeth removed);
# 220-271 ears; 272-307 irises and pupils.
GOLIATH_KEYPOINTS = [
    "nose",
    "left_eye",
    "right_eye",
    "left_ear",
    "right_ear",
    "left_shoulder",
    "right_shoulder",
    "left_elbow",
    "right_elbow",
    "left_hip",
    "right_hip",
    "left_knee",
    "right_knee",
    "left_ankle",
    "right_ankle",
    "left_big_toe",
    "left_small_toe",
    "left_heel",
    "right_big_toe",
    "right_small_toe",
    "right_heel",
    "right_thumb4",
    "right_thumb3",
    "right_thumb2",
    "right_thumb_third_joint",
    "right_forefinger4",
    "right_forefinger3",
    "right_forefinger2",
    "right_forefinger_third_joint",
    "right_middle_finger4",
    "right_middle_finger3",
    "right_middle_finger2",
    "right_middle_finger_third_joint",
    "right_ring_finger4",
    "right_ring_finger3",
    "right_ring_finger2",
    "right_ring_finger_third_joint",
    "right_pinky_finger4",
    "right_pinky_finger3",
    "right_pinky_finger2",
    "right_pinky_finger_third_joint",
    "right_wrist",
    "left_thumb4",
    "left_thumb3",
    "left_thumb2",
    "left_thumb_third_joint",
    "left_forefinger4",
    "left_forefinger3",
    "left_forefinger2",
    "left_forefinger_third_joint",
    "left_middle_finger4",
    "left_middle_finger3",
    "left_middle_finger2",
    "left_middle_finger_third_joint",
    "left_ring_finger4",
    "left_ring_finger3",
    "left_ring_finger2",
    "left_ring_finger_third_joint",
    "left_pinky_finger4",
    "left_pinky_finger3",
    "left_pinky_finger2",
    "left_pinky_finger_third_joint",
    "left_wrist",
    "left_olecranon",
    "right_olecranon",
    "left_cubital_fossa",
    "right_cubital_fossa",
    "left_acromion",
    "right_acromion",
    "neck",
    "center_of_glabella",
    "center_of_nose_root",
    "tip_of_nose_bridge",
    "midpoint_1_of_nose_bridge",
    "midpoint_2_of_nose_bridge",
    "midpoint_3_of_nose_bridge",
    "center_of_labiomental_groove",
    "tip_of_chin",
    "upper_startpoint_of_r_eyebrow",
    "lower_startpoint_of_r_eyebrow",
    "end_of_r_eyebrow",
    "upper_midpoint_1_of_r_eyebrow",
    "lower_midpoint_1_of_r_eyebrow",
    "upper_midpoint_2_of_r_eyebrow",
    "upper_midpoint_3_of_r_eyebrow",
    "lower_midpoint_2_of_r_eyebrow",
    "lower_midpoint_3_of_r_eyebrow",
    "upper_startpoint_of_l_eyebrow",
    "lower_startpoint_of_l_eyebrow",
    "end_of_l_eyebrow",
    "upper_midpoint_1_of_l_eyebrow",
    "lower_midpoint_1_of_l_eyebrow",
    "upper_midpoint_2_of_l_eyebrow",
    "upper_midpoint_3_of_l_eyebrow",
    "lower_midpoint_2_of_l_eyebrow",
    "lower_midpoint_3_of_l_eyebrow",
    "l_inner_end_of_upper_lash_line",
    "l_outer_end_of_upper_lash_line",
    "l_centerpoint_of_upper_lash_line",
    "l_midpoint_2_of_upper_lash_line",
    "l_midpoint_1_of_upper_lash_line",
    "l_midpoint_6_of_upper_lash_line",
    "l_midpoint_5_of_upper_lash_line",
    "l_midpoint_4_of_upper_lash_line",
    "l_midpoint_3_of_upper_lash_line",
    "l_outer_end_of_upper_eyelid_line",
    "l_midpoint_6_of_upper_eyelid_line",
    "l_midpoint_2_of_upper_eyelid_line",
    "l_midpoint_5_of_upper_eyelid_line",
    "l_centerpoint_of_upper_eyelid_line",
    "l_midpoint_4_of_upper_eyelid_line",
    "l_midpoint_1_of_upper_eyelid_line",
    "l_midpoint_3_of_upper_eyelid_line",
    "l_midpoint_6_of_upper_crease_line",
    "l_midpoint_2_of_upper_crease_line",
    "l_midpoint_5_of_upper_crease_line",
    "l_centerpoint_of_upper_crease_line",
    "l_midpoint_4_of_upper_crease_line",
    "l_midpoint_1_of_upper_crease_line",
    "l_midpoint_3_of_upper_crease_line",
    "r_inner_end_of_upper_lash_line",
    "r_outer_end_of_upper_lash_line",
    "r_centerpoint_of_upper_lash_line",
    "r_midpoint_1_of_upper_lash_line",
    "r_midpoint_2_of_upper_lash_line",
    "r_midpoint_3_of_upper_lash_line",
    "r_midpoint_4_of_upper_lash_line",
    "r_midpoint_5_of_upper_lash_line",
    "r_midpoint_6_of_upper_lash_line",
    "r_outer_end_of_upper_eyelid_line",
    "r_midpoint_3_of_upper_eyelid_line",
    "r_midpoint_1_of_upper_eyelid_line",
    "r_midpoint_4_of_upper_eyelid_line",
    "r_centerpoint_of_upper_eyelid_line",
    "r_midpoint_5_of_upper_eyelid_line",
    "r_midpoint_2_of_upper_eyelid_line",
    "r_midpoint_6_of_upper_eyelid_line",
    "r_midpoint_3_of_upper_crease_line",
    "r_midpoint_1_of_upper_crease_line",
    "r_midpoint_4_of_upper_crease_line",
    "r_centerpoint_of_upper_crease_line",
    "r_midpoint_5_of_upper_crease_line",
    "r_midpoint_2_of_upper_crease_line",
    "r_midpoint_6_of_upper_crease_line",
    "l_inner_end_of_lower_lash_line",
    "l_outer_end_of_lower_lash_line",
    "l_centerpoint_of_lower_lash_line",
    "l_midpoint_2_of_lower_lash_line",
    "l_midpoint_1_of_lower_lash_line",
    "l_midpoint_6_of_lower_lash_line",
    "l_midpoint_5_of_lower_lash_line",
    "l_midpoint_4_of_lower_lash_line",
    "l_midpoint_3_of_lower_lash_line",
    "l_outer_end_of_lower_eyelid_line",
    "l_midpoint_6_of_lower_eyelid_line",
    "l_midpoint_2_of_lower_eyelid_line",
    "l_midpoint_5_of_lower_eyelid_line",
    "l_centerpoint_of_lower_eyelid_line",
    "l_midpoint_4_of_lower_eyelid_line",
    "l_midpoint_1_of_lower_eyelid_line",
    "l_midpoint_3_of_lower_eyelid_line",
    "r_inner_end_of_lower_lash_line",
    "r_outer_end_of_lower_lash_line",
    "r_centerpoint_of_lower_lash_line",
    "r_midpoint_1_of_lower_lash_line",
    "r_midpoint_2_of_lower_lash_line",
    "r_midpoint_3_of_lower_lash_line",
    "r_midpoint_4_of_lower_lash_line",
    "r_midpoint_5_of_lower_lash_line",
    "r_midpoint_6_of_lower_lash_line",
    "r_outer_end_of_lower_eyelid_line",
    "r_midpoint_3_of_lower_eyelid_line",
    "r_midpoint_1_of_lower_eyelid_line",
    "r_midpoint_4_of_lower_eyelid_line",
    "r_centerpoint_of_lower_eyelid_line",
    "r_midpoint_5_of_lower_eyelid_line",
    "r_midpoint_2_of_lower_eyelid_line",
    "r_midpoint_6_of_lower_eyelid_line",
    "tip_of_nose",
    "bottom_center_of_nose",
    "r_outer_corner_of_nose",
    "l_outer_corner_of_nose",
    "inner_corner_of_r_nostril",
    "outer_corner_of_r_nostril",
    "upper_corner_of_r_nostril",
    "inner_corner_of_l_nostril",
    "outer_corner_of_l_nostril",
    "upper_corner_of_l_nostril",
    "r_outer_corner_of_mouth",
    "l_outer_corner_of_mouth",
    "center_of_cupid_bow",
    "center_of_lower_outer_lip",
    "midpoint_1_of_upper_outer_lip",
    "midpoint_2_of_upper_outer_lip",
    "midpoint_1_of_lower_outer_lip",
    "midpoint_2_of_lower_outer_lip",
    "midpoint_3_of_upper_outer_lip",
    "midpoint_4_of_upper_outer_lip",
    "midpoint_5_of_upper_outer_lip",
    "midpoint_6_of_upper_outer_lip",
    "midpoint_3_of_lower_outer_lip",
    "midpoint_4_of_lower_outer_lip",
    "midpoint_5_of_lower_outer_lip",
    "midpoint_6_of_lower_outer_lip",
    "r_inner_corner_of_mouth",
    "l_inner_corner_of_mouth",
    "center_of_upper_inner_lip",
    "center_of_lower_inner_lip",
    "midpoint_1_of_upper_inner_lip",
    "midpoint_2_of_upper_inner_lip",
    "midpoint_1_of_lower_inner_lip",
    "midpoint_2_of_lower_inner_lip",
    "midpoint_3_of_upper_inner_lip",
    "midpoint_4_of_upper_inner_lip",
    "midpoint_5_of_upper_inner_lip",
    "midpoint_6_of_upper_inner_lip",
    "midpoint_3_of_lower_inner_lip",
    "midpoint_4_of_lower_inner_lip",
    "midpoint_5_of_lower_inner_lip",
    "midpoint_6_of_lower_inner_lip",
    "l_top_end_of_inferior_crus",
    "l_top_end_of_superior_crus",
    "l_start_of_antihelix",
    "l_end_of_antihelix",
    "l_midpoint_1_of_antihelix",
    "l_midpoint_1_of_inferior_crus",
    "l_midpoint_2_of_antihelix",
    "l_midpoint_3_of_antihelix",
    "l_point_1_of_inner_helix",
    "l_point_2_of_inner_helix",
    "l_point_3_of_inner_helix",
    "l_point_4_of_inner_helix",
    "l_point_5_of_inner_helix",
    "l_point_6_of_inner_helix",
    "l_point_7_of_inner_helix",
    "l_highest_point_of_antitragus",
    "l_bottom_point_of_tragus",
    "l_protruding_point_of_tragus",
    "l_top_point_of_tragus",
    "l_start_point_of_crus_of_helix",
    "l_deepest_point_of_concha",
    "l_tip_of_ear_lobe",
    "l_midpoint_between_22_15",
    "l_bottom_connecting_point_of_ear_lobe",
    "l_top_connecting_point_of_helix",
    "l_point_8_of_inner_helix",
    "r_top_end_of_inferior_crus",
    "r_top_end_of_superior_crus",
    "r_start_of_antihelix",
    "r_end_of_antihelix",
    "r_midpoint_1_of_antihelix",
    "r_midpoint_1_of_inferior_crus",
    "r_midpoint_2_of_antihelix",
    "r_midpoint_3_of_antihelix",
    "r_point_1_of_inner_helix",
    "r_point_8_of_inner_helix",
    "r_point_3_of_inner_helix",
    "r_point_4_of_inner_helix",
    "r_point_5_of_inner_helix",
    "r_point_6_of_inner_helix",
    "r_point_7_of_inner_helix",
    "r_highest_point_of_antitragus",
    "r_bottom_point_of_tragus",
    "r_protruding_point_of_tragus",
    "r_top_point_of_tragus",
    "r_start_point_of_crus_of_helix",
    "r_deepest_point_of_concha",
    "r_tip_of_ear_lobe",
    "r_midpoint_between_22_15",
    "r_bottom_connecting_point_of_ear_lobe",
    "r_top_connecting_point_of_helix",
    "r_point_2_of_inner_helix",
    "l_center_of_iris",
    "l_border_of_iris_3",
    "l_border_of_iris_midpoint_1",
    "l_border_of_iris_12",
    "l_border_of_iris_midpoint_4",
    "l_border_of_iris_9",
    "l_border_of_iris_midpoint_3",
    "l_border_of_iris_6",
    "l_border_of_iris_midpoint_2",
    "r_center_of_iris",
    "r_border_of_iris_3",
    "r_border_of_iris_midpoint_1",
    "r_border_of_iris_12",
    "r_border_of_iris_midpoint_4",
    "r_border_of_iris_9",
    "r_border_of_iris_midpoint_3",
    "r_border_of_iris_6",
    "r_border_of_iris_midpoint_2",
    "l_center_of_pupil",
    "l_border_of_pupil_3",
    "l_border_of_pupil_midpoint_1",
    "l_border_of_pupil_12",
    "l_border_of_pupil_midpoint_4",
    "l_border_of_pupil_9",
    "l_border_of_pupil_midpoint_3",
    "l_border_of_pupil_6",
    "l_border_of_pupil_midpoint_2",
    "r_center_of_pupil",
    "r_border_of_pupil_3",
    "r_border_of_pupil_midpoint_1",
    "r_border_of_pupil_12",
    "r_border_of_pupil_midpoint_4",
    "r_border_of_pupil_9",
    "r_border_of_pupil_midpoint_3",
    "r_border_of_pupil_6",
    "r_border_of_pupil_midpoint_2",
]
+ GOLIATH_SKELETON_INFO = {
781
+ 0:
782
+ dict(link=('left_ankle', 'left_knee'), id=0, color=[0, 255, 0]),
783
+ 1:
784
+ dict(link=('left_knee', 'left_hip'), id=1, color=[0, 255, 0]),
785
+ 2:
786
+ dict(link=('right_ankle', 'right_knee'), id=2, color=[255, 128, 0]),
787
+ 3:
788
+ dict(link=('right_knee', 'right_hip'), id=3, color=[255, 128, 0]),
789
+ 4:
790
+ dict(link=('left_hip', 'right_hip'), id=4, color=[51, 153, 255]),
791
+ 5:
792
+ dict(link=('left_shoulder', 'left_hip'), id=5, color=[51, 153, 255]),
793
+ 6:
794
+ dict(link=('right_shoulder', 'right_hip'), id=6, color=[51, 153, 255]),
795
+ 7:
796
+ dict(
797
+ link=('left_shoulder', 'right_shoulder'),
798
+ id=7,
799
+ color=[51, 153, 255]),
800
+ 8:
801
+ dict(link=('left_shoulder', 'left_elbow'), id=8, color=[0, 255, 0]),
802
+ 9:
803
+ dict(
804
+ link=('right_shoulder', 'right_elbow'), id=9, color=[255, 128, 0]),
805
+ 10:
806
+ dict(link=('left_elbow', 'left_wrist'), id=10, color=[0, 255, 0]),
807
+ 11:
808
+ dict(link=('right_elbow', 'right_wrist'), id=11, color=[255, 128, 0]),
809
+ 12:
810
+ dict(link=('left_eye', 'right_eye'), id=12, color=[51, 153, 255]),
811
+ 13:
812
+ dict(link=('nose', 'left_eye'), id=13, color=[51, 153, 255]),
813
+ 14:
814
+ dict(link=('nose', 'right_eye'), id=14, color=[51, 153, 255]),
815
+ 15:
816
+ dict(link=('left_eye', 'left_ear'), id=15, color=[51, 153, 255]),
817
+ 16:
818
+ dict(link=('right_eye', 'right_ear'), id=16, color=[51, 153, 255]),
819
+ 17:
820
+ dict(link=('left_ear', 'left_shoulder'), id=17, color=[51, 153, 255]),
821
+ 18:
822
+ dict(
823
+ link=('right_ear', 'right_shoulder'), id=18, color=[51, 153, 255]),
824
+ 19:
825
+ dict(link=('left_ankle', 'left_big_toe'), id=19, color=[0, 255, 0]),
826
+ 20:
827
+ dict(link=('left_ankle', 'left_small_toe'), id=20, color=[0, 255, 0]),
828
+ 21:
829
+ dict(link=('left_ankle', 'left_heel'), id=21, color=[0, 255, 0]),
830
+ 22:
831
+ dict(
832
+ link=('right_ankle', 'right_big_toe'), id=22, color=[255, 128, 0]),
833
+ 23:
834
+ dict(
835
+ link=('right_ankle', 'right_small_toe'),
836
+ id=23,
837
+ color=[255, 128, 0]),
838
+ 24:
839
+ dict(link=('right_ankle', 'right_heel'), id=24, color=[255, 128, 0]),
840
+ 25:
841
+ dict(
842
+ link=('left_wrist', 'left_thumb_third_joint'), id=25, color=[255, 128,
843
+ 0]),
844
+ 26:
845
+ dict(link=('left_thumb_third_joint', 'left_thumb2'), id=26, color=[255, 128, 0]),
846
+ 27:
847
+ dict(link=('left_thumb2', 'left_thumb3'), id=27, color=[255, 128, 0]),
848
+ 28:
849
+ dict(link=('left_thumb3', 'left_thumb4'), id=28, color=[255, 128, 0]),
850
+ 29:
851
+ dict(
852
+ link=('left_wrist', 'left_forefinger_third_joint'),
853
+ id=29,
854
+ color=[255, 153, 255]),
855
+ 30:
856
+ dict(
857
+ link=('left_forefinger_third_joint', 'left_forefinger2'),
858
+ id=30,
859
+ color=[255, 153, 255]),
860
+ 31:
861
+ dict(
862
+ link=('left_forefinger2', 'left_forefinger3'),
863
+ id=31,
864
+ color=[255, 153, 255]),
865
+ 32:
866
+ dict(
867
+ link=('left_forefinger3', 'left_forefinger4'),
868
+ id=32,
869
+ color=[255, 153, 255]),
870
+ 33:
871
+ dict(
872
+ link=('left_wrist', 'left_middle_finger_third_joint'),
873
+ id=33,
874
+ color=[102, 178, 255]),
875
+ 34:
876
+ dict(
877
+ link=('left_middle_finger_third_joint', 'left_middle_finger2'),
878
+ id=34,
879
+ color=[102, 178, 255]),
880
+ 35:
881
+ dict(
882
+ link=('left_middle_finger2', 'left_middle_finger3'),
883
+ id=35,
884
+ color=[102, 178, 255]),
885
+ 36:
886
+ dict(
887
+ link=('left_middle_finger3', 'left_middle_finger4'),
888
+ id=36,
889
+ color=[102, 178, 255]),
890
+ 37:
891
+ dict(
892
+ link=('left_wrist', 'left_ring_finger_third_joint'),
893
+ id=37,
894
+ color=[255, 51, 51]),
895
+ 38:
896
+ dict(
897
+ link=('left_ring_finger_third_joint', 'left_ring_finger2'),
898
+ id=38,
899
+ color=[255, 51, 51]),
900
+ 39:
901
+ dict(
902
+ link=('left_ring_finger2', 'left_ring_finger3'),
903
+ id=39,
904
+ color=[255, 51, 51]),
905
+ 40:
906
+ dict(
907
+ link=('left_ring_finger3', 'left_ring_finger4'),
908
+ id=40,
909
+ color=[255, 51, 51]),
910
+ 41:
911
+ dict(
912
+ link=('left_wrist', 'left_pinky_finger_third_joint'),
913
+ id=41,
914
+ color=[0, 255, 0]),
915
+ 42:
916
+ dict(
917
+ link=('left_pinky_finger_third_joint', 'left_pinky_finger2'),
918
+ id=42,
919
+ color=[0, 255, 0]),
920
+ 43:
921
+ dict(
922
+ link=('left_pinky_finger2', 'left_pinky_finger3'),
923
+ id=43,
924
+ color=[0, 255, 0]),
925
+ 44:
926
+ dict(
927
+ link=('left_pinky_finger3', 'left_pinky_finger4'),
928
+ id=44,
929
+ color=[0, 255, 0]),
930
+ 45:
931
+ dict(
932
+ link=('right_wrist', 'right_thumb_third_joint'),
933
+ id=45,
934
+ color=[255, 128, 0]),
935
+ 46:
936
+ dict(
937
+ link=('right_thumb_third_joint', 'right_thumb2'), id=46, color=[255, 128, 0]),
938
+ 47:
939
+ dict(
940
+ link=('right_thumb2', 'right_thumb3'), id=47, color=[255, 128, 0]),
941
+ 48:
942
+ dict(
943
+ link=('right_thumb3', 'right_thumb4'), id=48, color=[255, 128, 0]),
944
+ 49:
945
+ dict(
946
+ link=('right_wrist', 'right_forefinger_third_joint'),
947
+ id=49,
948
+ color=[255, 153, 255]),
949
+ 50:
950
+ dict(
951
+ link=('right_forefinger_third_joint', 'right_forefinger2'),
952
+ id=50,
953
+ color=[255, 153, 255]),
954
+ 51:
955
+ dict(
956
+ link=('right_forefinger2', 'right_forefinger3'),
957
+ id=51,
958
+ color=[255, 153, 255]),
959
+ 52:
960
+ dict(
961
+ link=('right_forefinger3', 'right_forefinger4'),
962
+ id=52,
963
+ color=[255, 153, 255]),
964
+ 53:
965
+ dict(
966
+ link=('right_wrist', 'right_middle_finger_third_joint'),
967
+ id=53,
968
+ color=[102, 178, 255]),
969
+ 54:
970
+ dict(
971
+ link=('right_middle_finger_third_joint', 'right_middle_finger2'),
972
+ id=54,
973
+ color=[102, 178, 255]),
974
+ 55:
975
+ dict(
976
+ link=('right_middle_finger2', 'right_middle_finger3'),
977
+ id=55,
978
+ color=[102, 178, 255]),
979
+ 56:
980
+ dict(
981
+ link=('right_middle_finger3', 'right_middle_finger4'),
982
+ id=56,
983
+ color=[102, 178, 255]),
984
+ 57:
985
+ dict(
986
+ link=('right_wrist', 'right_ring_finger_third_joint'),
987
+ id=57,
988
+ color=[255, 51, 51]),
989
+ 58:
990
+ dict(
991
+ link=('right_ring_finger_third_joint', 'right_ring_finger2'),
992
+ id=58,
993
+ color=[255, 51, 51]),
994
+ 59:
995
+ dict(
996
+ link=('right_ring_finger2', 'right_ring_finger3'),
997
+ id=59,
998
+ color=[255, 51, 51]),
999
+ 60:
1000
+ dict(
1001
+ link=('right_ring_finger3', 'right_ring_finger4'),
1002
+ id=60,
1003
+ color=[255, 51, 51]),
1004
+ 61:
1005
+ dict(
1006
+ link=('right_wrist', 'right_pinky_finger_third_joint'),
1007
+ id=61,
1008
+ color=[0, 255, 0]),
1009
+ 62:
1010
+ dict(
1011
+ link=('right_pinky_finger_third_joint', 'right_pinky_finger2'),
1012
+ id=62,
1013
+ color=[0, 255, 0]),
1014
+ 63:
1015
+ dict(
1016
+ link=('right_pinky_finger2', 'right_pinky_finger3'),
1017
+ id=63,
1018
+ color=[0, 255, 0]),
1019
+ 64:
1020
+ dict(
1021
+ link=('right_pinky_finger3', 'right_pinky_finger4'),
1022
+ id=64,
1023
+ color=[0, 255, 0])
1024
+ }
detector_utils.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Sequence, Union
2
+
3
+ import torch
4
+ import cv2
5
+ import numpy as np
6
+ from mmcv.ops import RoIPool
7
+ from mmengine.dataset import Compose, pseudo_collate
8
+ from mmengine.device import get_device
9
+ from mmengine.registry import init_default_scope
10
+ from mmdet.apis import inference_detector, init_detector
11
+ from mmdet.structures import DetDataSample, SampleList
12
+ from mmdet.utils import get_test_pipeline_cfg
13
+
14
+
15
+ ImagesType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
16
+
17
def nms(dets: np.ndarray, thr: float):
    """Greedy non-maximum suppression over scored boxes.

    Repeatedly keeps the highest-scoring remaining box and discards every
    other candidate whose IoU with it exceeds ``thr``.

    Args:
        dets (np.ndarray): Shape (N, 5) array of [[x1, y1, x2, y2, score]].
        thr (float): IoU threshold; candidates with overlap > thr are dropped.

    Returns:
        list: Indices of the kept boxes, in descending score order.
    """
    if len(dets) == 0:
        return []

    coords = dets[:, :4]
    conf = dets[:, 4]
    # Legacy pixel-inclusive area convention: (x2 - x1 + 1) * (y2 - y1 + 1).
    box_area = (coords[:, 2] - coords[:, 0] + 1) * (coords[:, 3] - coords[:, 1] + 1)

    remaining = conf.argsort()[::-1]  # candidate indices, best score first
    selected = []
    while remaining.size > 0:
        top = remaining[0]
        selected.append(top)
        rest = remaining[1:]

        # Intersection of the chosen box with every remaining candidate.
        ix1 = np.maximum(coords[top, 0], coords[rest, 0])
        iy1 = np.maximum(coords[top, 1], coords[rest, 1])
        ix2 = np.minimum(coords[top, 2], coords[rest, 2])
        iy2 = np.minimum(coords[top, 3], coords[rest, 3])
        inter = np.maximum(0.0, ix2 - ix1 + 1) * np.maximum(0.0, iy2 - iy1 + 1)

        iou = inter / (box_area[top] + box_area[rest] - inter)
        # Survivors are candidates whose overlap with the chosen box is <= thr.
        remaining = rest[np.where(iou <= thr)[0]]

    return selected
55
+
56
def adapt_mmdet_pipeline(cfg):
    """Prefix MMDetection test-pipeline transform types with the 'mmdet' namespace.

    Needed when the config is consumed from another OpenMMLab scope, where
    unprefixed transform names would not resolve in the registry.

    Args:
        cfg (ConfigDict): MMDetection configuration dictionary.

    Returns:
        ConfigDict: The same config, with pipeline types rewritten in place.
    """
    # Lazy import so this module does not hard-require mmdet at import time.
    from mmdet.datasets import transforms

    if 'test_dataloader' not in cfg:
        return cfg

    known_transforms = dir(transforms)
    for step in cfg.test_dataloader.dataset.pipeline:
        if step['type'] in known_transforms:
            step['type'] = 'mmdet.' + step['type']

    return cfg
78
+
79
+
80
def inference_detector(
    model: torch.nn.Module,
    imgs: ImagesType,
    test_pipeline: Optional[Compose] = None,
    text_prompt: Optional[str] = None,
    custom_entities: bool = False,
) -> Union[DetDataSample, SampleList]:
    """Inference image(s) with the detector.

    NOTE: this deliberately shadows ``mmdet.apis.inference_detector``
    (imported above) so callers in this module get the autocast/bf16 and
    tensor-input handling added here.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str, ndarray, Sequence[str/ndarray]):
            Either image files or loaded images. A ``torch.Tensor`` is also
            accepted and converted to a uint8 HWC ndarray first.
        test_pipeline (:obj:`Compose`): Test pipeline. Built from the model
            config when not provided.
        text_prompt (str, optional): Optional text prompt forwarded to
            prompt-conditioned detectors via the data dict.
        custom_entities (bool): Forwarded alongside ``text_prompt``.

    Returns:
        :obj:`DetDataSample` or list[:obj:`DetDataSample`]:
            If imgs is a list or tuple, the same length list type results
            will be returned, otherwise return the detection results directly.
    """
    if isinstance(imgs, torch.Tensor):
        if imgs.is_cuda:
            imgs = imgs.cpu()

        # Remove batch dimension and transpose
        # assumes imgs is (1, C, H, W) with values in [0, 1] — TODO confirm at call sites
        imgs = imgs.squeeze(0).permute(1, 2, 0).numpy()

        # Ensure the data type is appropriate (uint8 for most image processing functions)
        imgs = (imgs * 255).astype(np.uint8)

    # A 4-D ndarray is treated as a batch of HWC images.
    if isinstance(imgs, (list, tuple)) or (isinstance(imgs, np.ndarray) and len(imgs.shape) == 4):
        is_batch = True
    else:
        imgs = [imgs]
        is_batch = False

    cfg = model.cfg

    if test_pipeline is None:
        # Copy so the pipeline rewrite below does not mutate the model's own config.
        cfg = cfg.copy()
        test_pipeline = get_test_pipeline_cfg(cfg)
        if isinstance(imgs[0], np.ndarray):
            # Calling this method across libraries will result
            # in module unregistered error if not prefixed with mmdet.
            test_pipeline[0].type = "mmdet.LoadImageFromNDArray"

        test_pipeline = Compose(test_pipeline)

    if model.data_preprocessor.device.type == "cpu":
        for m in model.modules():
            assert not isinstance(
                m, RoIPool
            ), "CPU inference with RoIPool is not supported currently."

    result_list = []
    for i, img in enumerate(imgs):
        # prepare data
        if isinstance(img, np.ndarray):
            # TODO: remove img_id.
            data_ = dict(img=img, img_id=0)
        else:
            # TODO: remove img_id.
            data_ = dict(img_path=img, img_id=0)

        if text_prompt:
            data_["text"] = text_prompt
            data_["custom_entities"] = custom_entities

        # build the data pipeline
        data_ = test_pipeline(data_)

        # Wrap in lists to form a batch of size 1 for test_step.
        data_["inputs"] = [data_["inputs"]]
        data_["data_samples"] = [data_["data_samples"]]

        # forward the model
        # bf16 autocast on the model's device for speed; assumes the
        # detector is numerically stable under bfloat16 — TODO confirm
        with torch.no_grad(), torch.autocast(device_type=get_device(), dtype=torch.bfloat16):
            results = model.test_step(data_)[0]

        result_list.append(results)

    if not is_batch:
        return result_list[0]
    else:
        return result_list
164
+
165
+
166
def process_one_image_bbox(pred_instance, det_cat_id, bbox_thr, nms_thr):
    """Filter one image's detections to a single category and apply NMS.

    Args:
        pred_instance: Detection result with ``bboxes`` (N, 4), ``scores`` (N,)
            and ``labels`` (N,) numpy attributes.
        det_cat_id: Category id to keep (0 == person for COCO-style detectors).
        bbox_thr: Minimum detection score.
        nms_thr: IoU threshold passed to :func:`nms`.

    Returns:
        np.ndarray: Kept boxes as (M, 4) xyxy, scores stripped.
    """
    scored_boxes = np.concatenate(
        (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1
    )
    keep_mask = np.logical_and(
        pred_instance.labels == det_cat_id,
        pred_instance.scores > bbox_thr,
    )
    scored_boxes = scored_boxes[keep_mask]
    # Drop the appended score column after suppression.
    return scored_boxes[nms(scored_boxes, nms_thr), :4]
178
+
179
+
180
def process_images_detector(imgs, detector):
    """Detect person bounding boxes for a batch of images.

    Args:
        imgs: Batch of images accepted by :func:`inference_detector`.
        detector: The loaded detection model.

    Returns:
        list[np.ndarray]: One (M, 4) xyxy box array per input image.
    """
    det_results = inference_detector(detector, imgs)
    bboxes_batch = []
    for det_result in det_results:
        instances = det_result.pred_instances.numpy()
        # Fixed thresholds: det_cat_id=0 (person), bbox_thr=0.3, nms_thr=0.3.
        bboxes_batch.append(process_one_image_bbox(instances, 0, 0.3, 0.3))
    return bboxes_batch
pose_render_utils.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import cv2
8
+ import numpy as np
9
+
10
+
11
def visualize_keypoints(
    image: np.ndarray,  # RGB uint8 H,W,3
    keypoints,  # list[(J,2)]
    keypoints_visible,  # list[(J,), {0/1}]
    keypoint_scores,  # list[(J,)]
    *,
    radius: int = 4,
    thickness: int = -1,
    color=(255, 0, 0),
    kpt_thr: float = 0.3,
    skeleton: list | None = None,  # [(i,j)]
    kpt_color: list | tuple | np.ndarray | None = None,
    link_color: list | tuple | np.ndarray | None = None,
    show_kpt_idx: bool = False,
) -> np.ndarray:
    """Draw per-person keypoints (and optional skeleton links) onto a copy of *image*.

    Each element of ``keypoints`` / ``keypoints_visible`` / ``keypoint_scores``
    corresponds to one detected person. Points below ``kpt_thr`` or marked
    invisible are skipped; out-of-bounds points/links are skipped too.
    Returns a new image; the input is not modified.
    """
    img = image.copy()
    H, W = img.shape[:2]

    # defaults
    if skeleton is None:
        skeleton = []  # points only
    if kpt_color is None:
        kpt_color = color
    if link_color is None:
        link_color = (0, 255, 0)

    # robust color normalization: supports tuple, list-of-tuples, np.ndarray (N,3) or (3,)
    def _as_color_list(c, n):
        # torch -> numpy
        if hasattr(c, "detach"):
            c = c.detach().cpu().numpy()
        # numpy -> array
        if isinstance(c, np.ndarray):
            if c.ndim == 2 and c.shape[1] == 3:  # (N,3) palette
                return [tuple(int(v) for v in row) for row in c.tolist()]
            if c.size == 3:  # single (3,)
                return [tuple(int(v) for v in c.tolist())] * max(1, n)
        # python containers
        if isinstance(c, (list, tuple)):
            # per-item palette: one triplet per keypoint/link
            if n and len(c) == n and isinstance(c[0], (list, tuple, np.ndarray)):
                out = []
                for cc in c:
                    cc = np.asarray(cc).reshape(-1)
                    assert cc.size == 3, "Each color must be length-3"
                    out.append(tuple(int(v) for v in cc.tolist()))
                return out
            # single triplet
            c_arr = np.asarray(c).reshape(-1)
            if c_arr.size == 3:
                return [tuple(int(v) for v in c_arr.tolist())] * max(1, n)
        # fallback: red
        return [(255, 0, 0)] * max(1, n)

    # J is taken from the first person; assumes all persons share the same
    # keypoint count — TODO confirm upstream guarantees this
    J = keypoints[0].shape[0] if keypoints else 0
    kpt_colors = _as_color_list(kpt_color, J)
    link_colors = _as_color_list(link_color, len(skeleton))

    def in_bounds(x, y):
        return 0 <= x < W and 0 <= y < H

    for kpts, vis, score in zip(keypoints, keypoints_visible, keypoint_scores):
        kpts = np.asarray(kpts, float)
        vis = np.asarray(vis).reshape(-1).astype(bool)
        score = np.asarray(score).reshape(-1)

        # links (draw in RGB; NO channel flip)
        for lk, (i, j) in enumerate(skeleton):
            if i >= len(kpts) or j >= len(kpts):
                continue
            if not (vis[i] and vis[j]):
                continue
            if score[i] < kpt_thr or score[j] < kpt_thr:
                continue

            x1, y1 = map(int, np.round(kpts[i]))
            x2, y2 = map(int, np.round(kpts[j]))
            if not (in_bounds(x1, y1) and in_bounds(x2, y2)):
                continue

            cv2.line(
                img,
                (x1, y1),
                (x2, y2),
                link_colors[lk % len(link_colors)],
                # thickness defaults to -1 (fill, circles); lines need >= 1
                thickness=max(1, thickness),
                lineType=cv2.LINE_AA,
            )

        # points
        for j_idx, (xy, v, s) in enumerate(zip(kpts, vis, score)):
            if not v or s < kpt_thr:
                continue
            x, y = map(int, np.round(xy))
            if not in_bounds(x, y):
                continue

            # clamp index so a short palette still colors trailing keypoints
            c = kpt_colors[min(j_idx, len(kpt_colors) - 1)]
            cv2.circle(img, (x, y), radius, c, thickness=-1, lineType=cv2.LINE_AA)
            if show_kpt_idx:
                cv2.putText(
                    img,
                    str(j_idx),
                    (x + radius, y - radius),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.4,
                    c,
                    1,
                    cv2.LINE_AA,
                )
    return img
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio==4.42.0
2
+ spaces
3
+ numpy
4
+ torch
5
+ torchvision
6
+ opencv-python
7
+ pillow
8
+ matplotlib
9
+ safetensors
10
+ huggingface_hub
11
+
12
+ # mmdet stack — needed for the RTMDet person detector
13
+ mmengine
14
+ mmcv==2.1.0
15
+ mmdet==3.2.0
16
+
17
+ # Sapiens2 itself (provides PoseTopdownEstimator + init_model + 308-keypoint configs)
18
+ sapiens @ git+https://github.com/facebookresearch/sapiens2.git