diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..f6b1f326ca4ab7cf0c8798856f8fe0020ff82d58 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..08849db0ada0b32b5f44b9cf2de618f870885ab8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +__pycache__ +*.pyc +default.profraw +.DS_Store +*.log diff --git a/README.md b/README.md index db28bee1c26fe2aba8a4a602f946eb1079deab58..ae48a4ec87b97caf142addf82a65bed69240d1db 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,23 @@ --- title: Sapiens2 Normal -emoji: ๐Ÿข -colorFrom: green -colorTo: gray +emoji: ๐ŸงŠ +colorFrom: purple +colorTo: blue sdk: gradio -sdk_version: 6.13.0 +sdk_version: 4.42.0 app_file: app.py +python_version: "3.12" pinned: false +license: other +license_name: sapiens2-license +license_link: https://github.com/facebookresearch/sapiens2/blob/main/LICENSE.md --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# Sapiens2: Surface Normal Estimation +### ICLR 2026 + +Per-pixel surface-normal estimation (3-channel unit vectors in camera frame). + +- **Code:** [github.com/facebookresearch/sapiens2](https://github.com/facebookresearch/sapiens2) +- **Models:** [Sapiens2 collection](https://huggingface.co/facebook/sapiens2) +- **Paper:** https://openreview.net/pdf?id=IVAlYCqdvW diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..3d6c98bbe5f1151c5e28c6198580642af70232ea --- /dev/null +++ b/app.py @@ -0,0 +1,167 @@ +"""Sapiens2 surface-normal Gradio Space. + +Image โ†’ per-pixel surface normals. Visualized by RGB-encoding the unit-length +(x, y, z) normal: r = (x + 1) / 2, g = (y + 1) / 2, b = (z + 1) / 2. 
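+For example, a unit normal (0, 0, 1) encodes to (r, g, b) = (0.5, 0.5, 1.0) and (0, -1, 0) to (0.5, 0.0, 0.5); the raw float32 normals are also offered as an .npy download.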
+""" + +import sys +import os +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +import tempfile + +import cv2 +import gradio as gr +import numpy as np +import spaces +import torch +import torch.nn.functional as F +from PIL import Image + +from huggingface_hub import hf_hub_download +from sapiens.dense.models import NormalEstimator, init_model # NormalEstimator triggers registry +_ = NormalEstimator + + +# ----------------------------------------------------------------------------- +# Config + +ASSETS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") +CONFIGS_DIR = os.path.join(ASSETS_DIR, "configs") + +NORMAL_MODELS = { + "0.4B": { + "repo": "facebook/sapiens2-normal-0.4b", + "filename": "sapiens2_0.4b_normal.safetensors", + "config": os.path.join(CONFIGS_DIR, "sapiens2_0.4b_normal_metasim_render_people-1024x768.py"), + }, + "0.8B": { + "repo": "facebook/sapiens2-normal-0.8b", + "filename": "sapiens2_0.8b_normal.safetensors", + "config": os.path.join(CONFIGS_DIR, "sapiens2_0.8b_normal_metasim_render_people-1024x768.py"), + }, + "1B": { + "repo": "facebook/sapiens2-normal-1b", + "filename": "sapiens2_1b_normal.safetensors", + "config": os.path.join(CONFIGS_DIR, "sapiens2_1b_normal_metasim_render_people-1024x768.py"), + }, + "5B": { + "repo": "facebook/sapiens2-normal-5b", + "filename": "sapiens2_5b_normal.safetensors", + "config": os.path.join(CONFIGS_DIR, "sapiens2_5b_normal_metasim_render_people-1024x768.py"), + }, +} +DEFAULT_SIZE = "1B" + +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + + +# ----------------------------------------------------------------------------- +# Model cache + +_normal_model_cache: dict = {} + + +def _get_normal_model(size: str): + if size not in _normal_model_cache: + spec = NORMAL_MODELS[size] + ckpt = hf_hub_download(repo_id=spec["repo"], filename=spec["filename"]) + model = init_model(spec["config"], ckpt, device=DEVICE) + _normal_model_cache[size] = model + return _normal_model_cache[size] + + +print("[startup] pre-loading all normal sizes ...") +for _size in NORMAL_MODELS: + _get_normal_model(_size) +print("[startup] ready.") + + +# ----------------------------------------------------------------------------- +# Inference + +def _estimate_normal(image_bgr: np.ndarray, model) -> np.ndarray: + h0, w0 = image_bgr.shape[:2] + data = model.pipeline(dict(img=image_bgr)) + data = model.data_preprocessor(data) + inputs = data["inputs"] + if inputs.ndim == 3: + inputs = inputs.unsqueeze(0) + + with torch.no_grad(): + normals = model(inputs) # (1, 3, H, W) + + # Unit-length normalization, interpolate to original size, cast to numpy + normals = normals / normals.norm(dim=1, keepdim=True).clamp_min(1e-6) + normals = F.interpolate(normals, size=(h0, w0), mode="bilinear", align_corners=False) + normals = normals[0].cpu().float().numpy() # (3, H, W) in [-1, 1] + return normals.transpose(1, 2, 0) # (H, W, 3) + + +def _normal_to_rgb(normal_hwc: np.ndarray) -> np.ndarray: + rgb = (((normal_hwc + 1.0) / 2.0) * 255.0).clip(0, 255).astype(np.uint8) + return rgb[:, :, ::-1] # match training viz channel order + + +# ----------------------------------------------------------------------------- +# Gradio handler + +@spaces.GPU(duration=120) +def predict(image: Image.Image, size: str): + if image is None: + return None, None + + image_rgb = np.array(image.convert("RGB")) + image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR) + + model = _get_normal_model(size) + normals = _estimate_normal(image_bgr, model) # (H, W, 3) in [-1, 1] 
+ rgb = _normal_to_rgb(normals) + + with tempfile.NamedTemporaryFile(delete=False, suffix=".npy") as f: + np.save(f.name, normals.astype(np.float32)) + npy_path = f.name + + return Image.fromarray(rgb), npy_path + + +# ----------------------------------------------------------------------------- +# UI + +EXAMPLES = sorted( + os.path.join(ASSETS_DIR, "images", n) + for n in os.listdir(os.path.join(ASSETS_DIR, "images")) + if n.lower().endswith((".jpg", ".jpeg", ".png")) +) + +with gr.Blocks(title="Sapiens2 Normal", theme=gr.themes.Default()) as demo: + gr.Markdown( + "# Sapiens2: Surface Normal Estimation\n" + "### ICLR 2026\n" + "Per-pixel surface-normal estimation. Output is RGB-encoded (x, y, z โ†’ R, G, B).\n\n" + "[Code](https://github.com/facebookresearch/sapiens2) ยท " + "[Models](https://huggingface.co/facebook/sapiens2) ยท " + "[Paper](https://openreview.net/pdf?id=IVAlYCqdvW)" + ) + with gr.Row(): + with gr.Column(): + inp = gr.Image(label="Input", type="pil") + size = gr.Radio( + choices=list(NORMAL_MODELS.keys()), + value=DEFAULT_SIZE, + label="Model size", + ) + run = gr.Button("Run", variant="primary") + gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=14) + with gr.Column(): + out_img = gr.Image(label="Surface normal (RGB-encoded)", type="pil") + out_npy = gr.File(label="Raw normals (.npy float32 [-1, 1])") + + run.click(predict, inputs=[inp, size], outputs=[out_img, out_npy]) + + +if __name__ == "__main__": + if torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + demo.launch(share=False) diff --git a/assets/configs/sapiens2_0.4b_normal_metasim_render_people-1024x768.py b/assets/configs/sapiens2_0.4b_normal_metasim_render_people-1024x768.py new file mode 100644 index 0000000000000000000000000000000000000000..0c1b5ddfb333aea8bcce75923d8359093dc0e35b --- /dev/null +++ b/assets/configs/sapiens2_0.4b_normal_metasim_render_people-1024x768.py @@ -0,0 +1,304 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + +_CHECKPOINT_ROOT = os.path.expanduser( + os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host") +) +_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data")) + +warmup_iters = 500 +num_iters = 2e4 + +# ------------------------------------------------------------------------------ +vis_every_iters = 100 +log_every_iters = 10 +save_every_iters = 1000 +val_every_iters = 1000 + +# # debug +# vis_every_iters = 1 +# log_every_iters = 1 +# val_every_iters = 2 +# save_every_iters = 1000 + +load_from = None +resume = False + +# ------------------------------------------------------------------ +model_name = "sapiens2_0.4b" +embed_dim = 1024 +num_layers = 24 +num_heads = 16 + +layer_decay_rate = 0.8 +pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_0.4b_pretrain.safetensors" + +##----------------------------------------------------------------- +image_size = (1024, 768) ## height x width +patch_size = 16 + +# ------------------------------------------------------------------ +use_fsdp = True +# use_fsdp = False + +use_compile = True +# use_compile = False + +## DDP config +if use_fsdp is False: + accelerator_cfg = dict( + type="DDP", + log_with="tensorboard", + # find_unused_parameters=True, + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. 
+ max_interval=num_iters, + # mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, ## schedule independent of n_gpus + ) + +else: + accelerator_cfg = dict( + type="FSDP", + log_with="tensorboard", + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, + fsdp_cfg=dict( + fsdp_version=2, # DTensor-based engine + state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT + # state_dict_type="FULL_STATE_DICT", # TODO: resume from this is not working + mixed_precision=dict( + param_dtype="bf16", + reduce_dtype="bf16", + ), + cpu_ram_efficient_loading=False, + ), + ) + +if use_compile: + accelerator_cfg["compile_cfg"] = dict( + backend="inductor", + mode="default", # Options: "default", "reduce-overhead", "max-autotune" + fullgraph=False, + dynamic=False, + ) + +# ------------------------------------------------------------------ +randomness = dict(seed=0, deterministic=False, diff_rank_seed=True) +logger = dict( + type="Logger", + log_interval=log_every_iters, +) +checkpoint = dict( + type="Checkpointer", + save_interval=save_every_iters, +) + +visualizer = dict( + type="NormalVisualizer", + vis_interval=vis_every_iters, + vis_max_samples=8, + vis_image_width=384, + vis_image_height=512, +) + + +##----------------------------------------------------------------- +train_pipeline = [ + dict(type="PhotoMetricDistortion"), + dict(type="RandomDownUpSampleImage", scale_range=(0.1, 0.7), prob=0.2), + dict( + type="NormalRandomScale", + scale_min=0.5, + scale_max=2.0, + prob=0.3, + ), + dict( + type="NormalRandomCropContinuous", + ar_range=(0.5, 2.0), + area_range=(0.4, 1.0), + num_attempts=8, + prob=0.3, + ), + dict( + type="NormalRandomFlip", + prob=0.3, + ), + dict(type="NormalResize", height=1024, width=768), + dict( + type="RandomGaussianBlur", prob=0.3, kernel_size=(3, 3), sigma_range=(0.1, 2.0) + ), + dict(type="RandomGaussianNoise", prob=0.3, var_range=(5.0, 20.0)), + dict(type="RandomSolarize", prob=0.3, threshold=128), + dict(type="NormalGenerateTarget"), + dict( + type="NormalPackInputs", + meta_keys=( + "img_path", + "ori_shape", + ), + ), +] + +val_pipeline = [ + dict(type="NormalResize", height=1024, width=768, test_mode=True), + dict( + type="NormalPackInputs", + test_mode=True, + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "img_shape", + "pad_shape", + ), + ), +] + +test_pipeline = [ + dict(type="NormalResizePadImage", height=1024, width=768, pad_val=0), + dict( + type="NormalPackInputs", + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "padding_size", + ), + ), +] + +metasim_dataset = dict( + type="NormalMetaSimDataset", + airstore_template="airstore://codec_avatar_sapiens_metasim_v1_no_user_data", + json_path=f"{_DATA_ROOT}/seg/data/metasim/meta_data_v1.json", +) + +render_people_dataset = dict( + type="NormalRenderPeopleBodyDataset", ## body only + data_root=f"{_DATA_ROOT}/synthetic", + seg_data_root=f"{_DATA_ROOT}/RenderPeople/part_seg", +) + +multihuman_render_people_dataset = dict( + type="NormalRenderPeopleMultihumanDataset", + data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human", + normal_extension=".npz", + seg_data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human/part_seg", ## supervise on face for multihuman +) + +# 
train_datasets = 2 * [metasim_dataset] + [ +# render_people_dataset, +# multihuman_render_people_dataset, +# ] + +# train_datasets = [render_people_dataset] +# train_datasets = [multihuman_render_people_dataset] +train_datasets = [metasim_dataset] + +train_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + shuffle=True, + dataset=dict( + type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline + ), +) + +val_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + multiprocessing_context="spawn", + # num_workers=0, # debug + # persistent_workers=False, # debug + shuffle=False, + dataset=dict( + type="NormalRenderPeopleBodyDataset", ## body only + # num_samples=100, ## debug: only use N samples for validation + test_mode=True, + data_root=f"{_DATA_ROOT}/seg/data/metasim/evaluation", + pipeline=val_pipeline, + ), +) + +val_cfg = dict( + val_interval=val_every_iters, + evaluator=dict( + type="NormalEvaluator", + ), +) + +data_preprocessor = dict( + type="ImagePreprocessor", + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models +) + +##----------------------------------------------------------------- +model = dict( + type="NormalEstimator", + backbone=dict( + type="Sapiens2", + arch=model_name, + img_size=image_size, + patch_size=patch_size, + final_norm=True, + use_tokenizer=False, + with_cls_token=True, + out_type="featmap", + init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint), + ), + decode_head=dict( + type="NormalHead", + in_channels=embed_dim, + upsample_channels=[768, 512, 256, 128], ## 1K resolution + conv_out_channels=[64, 32, 16], + conv_kernel_sizes=[3, 3, 3], + loss_decode=[ + dict( + type="NormalCosineSimilarityLoss", + loss_weight=10.0, + ), + dict(type="L1Loss", loss_weight=1.0), + dict(type="NormalGradL1Loss", loss_weight=10.0), + ], + ), +) + + +##----------------------------------------------------------------- +optimizer = dict( + type="AdamW", + lr=5e-4, + betas=(0.9, 0.999), + weight_decay=0.1, + paramwise_cfg=dict( + num_layers=num_layers, + layer_decay_rate=layer_decay_rate, + ), + fused=True, +) + +scheduler = dict( + type="SequentialLR", + milestones=[warmup_iters], + schedulers=[ + dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters), + dict( + type="PolynomialLR", + total_iters=num_iters - warmup_iters, + power=1.0, + ), + ], +) + +clip_grad = dict(mode="norm", max_norm=2.0, norm_type=2.0) diff --git a/assets/configs/sapiens2_0.8b_normal_metasim_render_people-1024x768.py b/assets/configs/sapiens2_0.8b_normal_metasim_render_people-1024x768.py new file mode 100644 index 0000000000000000000000000000000000000000..eaf64f4635c457a8ded298326ed61c9647e93a48 --- /dev/null +++ b/assets/configs/sapiens2_0.8b_normal_metasim_render_people-1024x768.py @@ -0,0 +1,304 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
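+# Schedule overview (sketch; assumes the "LinearLR"/"PolynomialLR" types used at the
+# bottom of this config behave like the PyTorch schedulers of the same name): the LR
+# ramps linearly from lr * 1e-3 to lr over the first `warmup_iters` steps, then decays
+# linearly (power=1.0) towards 0 over the remaining `num_iters - warmup_iters` steps.
+# `layer_decay_rate` in `paramwise_cfg` presumably applies the usual layer-wise LR
+# decay, roughly lr_i ~ lr * rate ** (num_layers - i); with rate 0.85 and 32 layers the
+# earliest blocks then train at about 0.85 ** 32 ~ 5e-3 of the base LR. The decode
+# head's four `upsample_channels` stages presumably upsample the 64x48 patch-feature
+# map by 2x each, back to the full 1024x768 output.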
+ +import os + +_CHECKPOINT_ROOT = os.path.expanduser( + os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host") +) +_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data")) + +warmup_iters = 500 +num_iters = 1e4 + +# ------------------------------------------------------------------------------ +vis_every_iters = 100 +log_every_iters = 10 +save_every_iters = 1000 +val_every_iters = 1000 + +# # debug +# vis_every_iters = 1 +# log_every_iters = 1 +# val_every_iters = 2 +# save_every_iters = 1000 + +load_from = None +resume = False + +# ------------------------------------------------------------------ +model_name = "sapiens2_0.8b" +embed_dim = 1280 +num_layers = 32 +num_heads = 16 + +layer_decay_rate = 0.85 +pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_0.8b_pretrain.safetensors" + +##----------------------------------------------------------------- +image_size = (1024, 768) ## height x width +patch_size = 16 + +# ------------------------------------------------------------------ +use_fsdp = True +# use_fsdp = False + +use_compile = True +# use_compile = False + +## DDP config +if use_fsdp is False: + accelerator_cfg = dict( + type="DDP", + log_with="tensorboard", + # find_unused_parameters=True, + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + # mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, ## schedule independent of n_gpus + ) + +else: + accelerator_cfg = dict( + type="FSDP", + log_with="tensorboard", + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. 
+ step_scheduler_with_optimizer=False, + fsdp_cfg=dict( + fsdp_version=2, # DTensor-based engine + state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT + # state_dict_type="FULL_STATE_DICT", # TODO: resume from this is not working + mixed_precision=dict( + param_dtype="bf16", + reduce_dtype="bf16", + ), + cpu_ram_efficient_loading=False, + ), + ) + +if use_compile: + accelerator_cfg["compile_cfg"] = dict( + backend="inductor", + mode="default", # Options: "default", "reduce-overhead", "max-autotune" + fullgraph=False, + dynamic=False, + ) + +# ------------------------------------------------------------------ +randomness = dict(seed=0, deterministic=False, diff_rank_seed=True) +logger = dict( + type="Logger", + log_interval=log_every_iters, +) +checkpoint = dict( + type="Checkpointer", + save_interval=save_every_iters, +) + +visualizer = dict( + type="NormalVisualizer", + vis_interval=vis_every_iters, + vis_max_samples=8, + vis_image_width=384, + vis_image_height=512, +) + + +##----------------------------------------------------------------- +train_pipeline = [ + dict(type="PhotoMetricDistortion"), + dict(type="RandomDownUpSampleImage", scale_range=(0.1, 0.7), prob=0.2), + dict( + type="NormalRandomScale", + scale_min=0.5, + scale_max=2.0, + prob=0.3, + ), + dict( + type="NormalRandomCropContinuous", + ar_range=(0.5, 2.0), + area_range=(0.4, 1.0), + num_attempts=8, + prob=0.3, + ), + dict( + type="NormalRandomFlip", + prob=0.3, + ), + dict(type="NormalResize", height=1024, width=768), + dict( + type="RandomGaussianBlur", prob=0.3, kernel_size=(3, 3), sigma_range=(0.1, 2.0) + ), + dict(type="RandomGaussianNoise", prob=0.3, var_range=(5.0, 20.0)), + dict(type="RandomSolarize", prob=0.3, threshold=128), + dict(type="NormalGenerateTarget"), + dict( + type="NormalPackInputs", + meta_keys=( + "img_path", + "ori_shape", + ), + ), +] + +val_pipeline = [ + dict(type="NormalResize", height=1024, width=768, test_mode=True), + dict( + type="NormalPackInputs", + test_mode=True, + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "img_shape", + "pad_shape", + ), + ), +] + +test_pipeline = [ + dict(type="NormalResizePadImage", height=1024, width=768, pad_val=0), + dict( + type="NormalPackInputs", + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "padding_size", + ), + ), +] + +metasim_dataset = dict( + type="NormalMetaSimDataset", + airstore_template="airstore://codec_avatar_sapiens_metasim_v1_no_user_data", + json_path=f"{_DATA_ROOT}/seg/data/metasim/meta_data_v1.json", +) + +render_people_dataset = dict( + type="NormalRenderPeopleBodyDataset", ## body only + data_root=f"{_DATA_ROOT}/synthetic", + seg_data_root=f"{_DATA_ROOT}/RenderPeople/part_seg", +) + +multihuman_render_people_dataset = dict( + type="NormalRenderPeopleMultihumanDataset", + data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human", + normal_extension=".npz", + seg_data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human/part_seg", ## supervise on face for multihuman +) + +# train_datasets = 2 * [metasim_dataset] + [ +# render_people_dataset, +# multihuman_render_people_dataset, +# ] + +# train_datasets = [render_people_dataset] +# train_datasets = [multihuman_render_people_dataset] +train_datasets = [metasim_dataset] + +train_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + shuffle=True, + dataset=dict( + type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline + ), +) + +val_dataloader = dict( + 
batch_size=4, + num_workers=4, + persistent_workers=True, + multiprocessing_context="spawn", + # num_workers=0, # debug + # persistent_workers=False, # debug + shuffle=False, + dataset=dict( + type="NormalRenderPeopleBodyDataset", ## body only + # num_samples=100, ## debug: only use N samples for validation + test_mode=True, + data_root=f"{_DATA_ROOT}/seg/data/metasim/evaluation", + pipeline=val_pipeline, + ), +) + +val_cfg = dict( + val_interval=val_every_iters, + evaluator=dict( + type="NormalEvaluator", + ), +) + +data_preprocessor = dict( + type="ImagePreprocessor", + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models +) + +##----------------------------------------------------------------- +model = dict( + type="NormalEstimator", + backbone=dict( + type="Sapiens2", + arch=model_name, + img_size=image_size, + patch_size=patch_size, + final_norm=True, + use_tokenizer=False, + with_cls_token=True, + out_type="featmap", + init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint), + ), + decode_head=dict( + type="NormalHead", + in_channels=embed_dim, + upsample_channels=[768, 512, 256, 128], ## 1K resolution + conv_out_channels=[64, 32, 16], + conv_kernel_sizes=[3, 3, 3], + loss_decode=[ + dict( + type="NormalCosineSimilarityLoss", + loss_weight=10.0, + ), + dict(type="L1Loss", loss_weight=1.0), + dict(type="NormalGradL1Loss", loss_weight=10.0), + ], + ), +) + + +##----------------------------------------------------------------- +optimizer = dict( + type="AdamW", + lr=5e-4, + betas=(0.9, 0.999), + weight_decay=0.1, + paramwise_cfg=dict( + num_layers=num_layers, + layer_decay_rate=layer_decay_rate, + ), + fused=True, +) + +scheduler = dict( + type="SequentialLR", + milestones=[warmup_iters], + schedulers=[ + dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters), + dict( + type="PolynomialLR", + total_iters=num_iters - warmup_iters, + power=1.0, + ), + ], +) + +clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0) diff --git a/assets/configs/sapiens2_1b_normal_metasim_render_people-1024x768.py b/assets/configs/sapiens2_1b_normal_metasim_render_people-1024x768.py new file mode 100644 index 0000000000000000000000000000000000000000..bf76d49b8e0a08fda95e7e89ed7268535d7b8c9c --- /dev/null +++ b/assets/configs/sapiens2_1b_normal_metasim_render_people-1024x768.py @@ -0,0 +1,306 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + +_CHECKPOINT_ROOT = os.path.expanduser( + os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host") +) +_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data")) + +warmup_iters = 500 +num_iters = 4e4 ## 32 nodes, 8 gpus: 256 gpus. bs: 3, global bs: 768. num samples: 1e6. 1e6/768 = 1302. 1 epoch = 1e3 iters. 
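+# (Back-of-envelope check of the note above: 32 nodes x 8 GPUs = 256 GPUs; at a
+# per-GPU batch size of 3 the global batch is 256 * 3 = 768, so ~1e6 samples give
+# 1e6 / 768 ~ 1302 iterations per epoch -- rounded to 1e3 above -- and
+# num_iters = 4e4 corresponds to roughly 30-40 epochs.)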
+# num_iters = 1e4 ## light finetune + +# ------------------------------------------------------------------------------ +vis_every_iters = 100 +log_every_iters = 10 +save_every_iters = 1000 +val_every_iters = 1000 + +# # debug +# vis_every_iters = 1 +# log_every_iters = 1 +# val_every_iters = 2 +# save_every_iters = 1000 + +load_from = None +resume = False + +# ------------------------------------------------------------------ +model_name = "sapiens2_1b" +embed_dim = 1536 +num_layers = 40 +num_heads = 24 +layer_decay_rate = 0.9 + +pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_1b_pretrain.safetensors" + +##----------------------------------------------------------------- +image_size = (1024, 768) ## height x width +patch_size = 16 + +# ------------------------------------------------------------------ +use_fsdp = True +# use_fsdp = False + +use_compile = True +# use_compile = False + +## DDP config +if use_fsdp is False: + accelerator_cfg = dict( + type="DDP", + log_with="tensorboard", + # find_unused_parameters=True, + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + # mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, ## schedule independent of n_gpus + ) + +else: + accelerator_cfg = dict( + type="FSDP", + log_with="tensorboard", + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, + fsdp_cfg=dict( + fsdp_version=2, # DTensor-based engine + state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT + # state_dict_type="FULL_STATE_DICT", # TODO: resume from this is not working + mixed_precision=dict( + param_dtype="bf16", + reduce_dtype="bf16", + ), + cpu_ram_efficient_loading=False, + ), + ) + +if use_compile: + accelerator_cfg["compile_cfg"] = dict( + backend="inductor", + mode="default", # Options: "default", "reduce-overhead", "max-autotune" + fullgraph=False, + dynamic=False, + ) + +# ------------------------------------------------------------------ +randomness = dict(seed=0, deterministic=False, diff_rank_seed=True) +logger = dict( + type="Logger", + log_interval=log_every_iters, +) +checkpoint = dict( + type="Checkpointer", + save_interval=save_every_iters, +) + +visualizer = dict( + type="NormalVisualizer", + vis_interval=vis_every_iters, + vis_max_samples=4, + vis_image_width=384, + vis_image_height=512, +) + + +##----------------------------------------------------------------- +train_pipeline = [ + dict(type="PhotoMetricDistortion"), + dict(type="RandomDownUpSampleImage", scale_range=(0.1, 0.7), prob=0.2), + dict( + type="NormalRandomScale", + scale_min=0.5, + scale_max=2.0, + prob=0.3, + ), + dict( + type="NormalRandomCropContinuous", + ar_range=(0.5, 2.0), + area_range=(0.4, 1.0), + num_attempts=8, + prob=0.3, + ), + dict( + type="NormalRandomFlip", + prob=0.3, + ), + dict(type="NormalResize", height=1024, width=768), + dict( + type="RandomGaussianBlur", prob=0.3, kernel_size=(3, 3), sigma_range=(0.1, 2.0) + ), + dict(type="RandomGaussianNoise", prob=0.3, var_range=(5.0, 20.0)), + dict(type="RandomSolarize", prob=0.3, threshold=128), + dict(type="NormalGenerateTarget"), + dict( + type="NormalPackInputs", + meta_keys=( + "img_path", + "ori_shape", + ), + ), +] + +val_pipeline = [ + 
dict(type="NormalResize", height=1024, width=768, test_mode=True), + dict( + type="NormalPackInputs", + test_mode=True, + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "img_shape", + "pad_shape", + ), + ), +] + +test_pipeline = [ + dict(type="NormalResizePadImage", height=1024, width=768, pad_val=0), + dict( + type="NormalPackInputs", + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "padding_size", + ), + ), +] + +metasim_dataset = dict( + type="NormalMetaSimDataset", + airstore_template="airstore://codec_avatar_sapiens_metasim_v1_no_user_data", + json_path=f"{_DATA_ROOT}/seg/data/metasim/meta_data_v1.json", +) + +render_people_dataset = dict( + type="NormalRenderPeopleBodyDataset", ## body only + data_root=f"{_DATA_ROOT}/synthetic", + seg_data_root=f"{_DATA_ROOT}/RenderPeople/part_seg", +) + +multihuman_render_people_dataset = dict( + type="NormalRenderPeopleMultihumanDataset", + data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human", + normal_extension=".npz", + seg_data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human/part_seg", ## supervise on face for multihuman +) + +# train_datasets = 2 * [metasim_dataset] + [ +# render_people_dataset, +# multihuman_render_people_dataset, +# ] + +# train_datasets = [render_people_dataset] +# train_datasets = [multihuman_render_people_dataset] +train_datasets = [metasim_dataset] + +train_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + shuffle=True, + dataset=dict( + type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline + ), +) + +val_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + multiprocessing_context="spawn", + # num_workers=0, # debug + # persistent_workers=False, # debug + shuffle=False, + dataset=dict( + type="NormalRenderPeopleBodyDataset", ## body only + # num_samples=100, ## debug: only use N samples for validation + test_mode=True, + data_root=f"{_DATA_ROOT}/seg/data/metasim/evaluation", + pipeline=val_pipeline, + ), +) + +val_cfg = dict( + val_interval=val_every_iters, + evaluator=dict( + type="NormalEvaluator", + ), +) + +data_preprocessor = dict( + type="ImagePreprocessor", + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models +) + +##----------------------------------------------------------------- +model = dict( + type="NormalEstimator", + backbone=dict( + type="Sapiens2", + arch=model_name, + img_size=image_size, + patch_size=patch_size, + final_norm=True, + use_tokenizer=False, + # with_cls_token=False, + with_cls_token=True, + out_type="featmap", + init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint), + ), + decode_head=dict( + type="NormalHead", + in_channels=embed_dim, + upsample_channels=[768, 512, 256, 128], ## 1K resolution + conv_out_channels=[64, 32, 16], + conv_kernel_sizes=[3, 3, 3], + loss_decode=[ + dict( + type="NormalCosineSimilarityLoss", + loss_weight=10.0, + ), + dict(type="L1Loss", loss_weight=1.0), + dict(type="NormalGradL1Loss", loss_weight=10.0), + ], + ), +) + + +##----------------------------------------------------------------- +optimizer = dict( + type="AdamW", + lr=5e-4, + betas=(0.9, 0.999), + weight_decay=0.1, + paramwise_cfg=dict( + num_layers=num_layers, + layer_decay_rate=layer_decay_rate, + ), + fused=True, +) + +scheduler = dict( + type="SequentialLR", + milestones=[warmup_iters], + schedulers=[ + dict(type="LinearLR", start_factor=1e-3, 
total_iters=warmup_iters), + dict( + type="PolynomialLR", + total_iters=num_iters - warmup_iters, + power=1.0, + ), + ], +) + +clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0) diff --git a/assets/configs/sapiens2_5b_normal_metasim_render_people-1024x768.py b/assets/configs/sapiens2_5b_normal_metasim_render_people-1024x768.py new file mode 100644 index 0000000000000000000000000000000000000000..b2c0b330f08b9787b69defb916fa16e57cb7e10a --- /dev/null +++ b/assets/configs/sapiens2_5b_normal_metasim_render_people-1024x768.py @@ -0,0 +1,312 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + +_CHECKPOINT_ROOT = os.path.expanduser( + os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host") +) +_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data")) + +warmup_iters = 500 +num_iters = 4e4 ## 32 nodes, 8 gpus: 256 gpus. bs: 3, global bs: 768. num samples: 1e6. 1e6/768 = 1302. 1 epoch = 1e3 iters. +# num_iters = 1e4 ## light finetune + +# ------------------------------------------------------------------------------ +vis_every_iters = 100 +log_every_iters = 10 +save_every_iters = 1000 +val_every_iters = 1000 + +# # debug +# vis_every_iters = 1 +# log_every_iters = 1 +# val_every_iters = 2 +# save_every_iters = 1000 + +load_from = None +resume = False + +# ------------------------------------------------------------------ +model_name = "sapiens2_5b" +embed_dim = 2432 +num_layers = 56 +num_heads = 32 +layer_decay_rate = 0.94 + +pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_5b_pretrain.safetensors" + +##----------------------------------------------------------------- +image_size = (1024, 768) ## height x width +patch_size = 16 + +# ------------------------------------------------------------------ +use_fsdp = True +# use_fsdp = False + +use_compile = True +# use_compile = False + +## DDP config +if use_fsdp is False: + accelerator_cfg = dict( + type="DDP", + log_with="tensorboard", + # find_unused_parameters=True, + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + # mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, ## schedule independent of n_gpus + ) + +else: + accelerator_cfg = dict( + type="FSDP", + log_with="tensorboard", + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. 
+ step_scheduler_with_optimizer=False, + fsdp_cfg=dict( + fsdp_version=2, # DTensor-based engine + state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT + # state_dict_type="FULL_STATE_DICT", # TODO: resume from this is not working + mixed_precision=dict( + param_dtype="bf16", + reduce_dtype="bf16", + ), + cpu_ram_efficient_loading=False, + ), + # parallelism_cfg=dict( + # dp_shard_size=2, # Fully Sharded Data Parallel degree + # dp_replicate_size=1, # Data Parallel degree + # tp_size=1, # Tensor Parallel degree + # cp_size=4, # Context Parallel degree + # ), + ) + +if use_compile: + accelerator_cfg["compile_cfg"] = dict( + backend="inductor", + mode="default", # Options: "default", "reduce-overhead", "max-autotune" + fullgraph=False, + dynamic=False, + ) + +# ------------------------------------------------------------------ +randomness = dict(seed=0, deterministic=False, diff_rank_seed=True) +logger = dict( + type="Logger", + log_interval=log_every_iters, +) +checkpoint = dict( + type="Checkpointer", + save_interval=save_every_iters, +) + +visualizer = dict( + type="NormalVisualizer", + vis_interval=vis_every_iters, + vis_max_samples=4, + vis_image_width=384, + vis_image_height=512, +) + + +##----------------------------------------------------------------- +train_pipeline = [ + dict(type="PhotoMetricDistortion"), + dict(type="RandomDownUpSampleImage", scale_range=(0.1, 0.7), prob=0.2), + dict( + type="NormalRandomScale", + scale_min=0.5, + scale_max=2.0, + prob=0.3, + ), + dict( + type="NormalRandomCropContinuous", + ar_range=(0.5, 2.0), + area_range=(0.4, 1.0), + num_attempts=8, + prob=0.3, + ), + dict( + type="NormalRandomFlip", + prob=0.3, + ), + dict(type="NormalResize", height=1024, width=768), + dict( + type="RandomGaussianBlur", prob=0.3, kernel_size=(3, 3), sigma_range=(0.1, 2.0) + ), + dict(type="RandomGaussianNoise", prob=0.3, var_range=(5.0, 20.0)), + dict(type="RandomSolarize", prob=0.3, threshold=128), + dict(type="NormalGenerateTarget"), + dict( + type="NormalPackInputs", + meta_keys=( + "img_path", + "ori_shape", + ), + ), +] + +val_pipeline = [ + dict(type="NormalResize", height=1024, width=768, test_mode=True), + dict( + type="NormalPackInputs", + test_mode=True, + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "img_shape", + "pad_shape", + ), + ), +] + +test_pipeline = [ + dict(type="NormalResizePadImage", height=1024, width=768, pad_val=0), + dict( + type="NormalPackInputs", + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "img_shape", + ), + ), +] + +metasim_dataset = dict( + type="NormalMetaSimDataset", + airstore_template="airstore://codec_avatar_sapiens_metasim_v1_no_user_data", + json_path=f"{_DATA_ROOT}/seg/data/metasim/meta_data_v1.json", +) + +render_people_dataset = dict( + type="NormalRenderPeopleBodyDataset", ## body only + data_root=f"{_DATA_ROOT}/synthetic", + seg_data_root=f"{_DATA_ROOT}/RenderPeople/part_seg", +) + +multihuman_render_people_dataset = dict( + type="NormalRenderPeopleMultihumanDataset", + data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human", + normal_extension=".npz", + seg_data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human/part_seg", ## supervise on face for multihuman +) + +# train_datasets = 2 * [metasim_dataset] + [ +# render_people_dataset, +# multihuman_render_people_dataset, +# ] + +# train_datasets = [render_people_dataset] +# train_datasets = [multihuman_render_people_dataset] +train_datasets = [metasim_dataset] + 
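+# (The commented-out variants above illustrate how dataset mixing is expressed here:
+# repeating a dataset dict in `train_datasets`, e.g. the hypothetical
+# `train_datasets = 3 * [metasim_dataset] + [render_people_dataset]`, presumably
+# over-samples it by that factor once `CombinedDataset` combines the list.)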
+train_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + shuffle=True, + dataset=dict( + type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline + ), +) + +val_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + multiprocessing_context="spawn", + # num_workers=0, # debug + # persistent_workers=False, # debug + shuffle=False, + dataset=dict( + type="NormalRenderPeopleBodyDataset", ## body only + # num_samples=100, ## debug: only use N samples for validation + test_mode=True, + data_root=f"{_DATA_ROOT}/seg/data/metasim/evaluation", + pipeline=val_pipeline, + ), +) + +val_cfg = dict( + val_interval=val_every_iters, + evaluator=dict( + type="NormalEvaluator", + ), +) + +data_preprocessor = dict( + type="ImagePreprocessor", + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models +) + +##----------------------------------------------------------------- +model = dict( + type="NormalEstimator", + backbone=dict( + type="Sapiens2", + arch=model_name, + img_size=image_size, + patch_size=patch_size, + final_norm=True, + use_tokenizer=False, + with_cls_token=True, + out_type="featmap", + init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint), + ), + decode_head=dict( + type="NormalHead", + in_channels=embed_dim, + upsample_channels=[1536, 768, 512, 256], ## 1K resolution + conv_out_channels=[128, 64, 32], + conv_kernel_sizes=[3, 3, 3], + loss_decode=[ + dict( + type="NormalCosineSimilarityLoss", + loss_weight=10.0, + ), + dict(type="L1Loss", loss_weight=1.0), + dict(type="NormalGradL1Loss", loss_weight=10.0), + ], + ), +) + + +##----------------------------------------------------------------- +optimizer = dict( + type="AdamW", + # lr=5e-4, + lr=1e-4, + betas=(0.9, 0.999), + weight_decay=0.1, + paramwise_cfg=dict( + num_layers=num_layers, + layer_decay_rate=layer_decay_rate, + ), + fused=True, +) + +scheduler = dict( + type="SequentialLR", + milestones=[warmup_iters], + schedulers=[ + dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters), + dict( + type="PolynomialLR", + total_iters=num_iters - warmup_iters, + power=1.0, + ), + ], +) + +clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0) diff --git a/assets/images/68204.png b/assets/images/68204.png new file mode 100644 index 0000000000000000000000000000000000000000..6584b288fafd94166c2877b7e43a3f387016a434 --- /dev/null +++ b/assets/images/68204.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9b0268cb801ed164864a4b5f6d131e0ac5cc2fbd149a6467d5d0c97da47122c2 +size 4285020 diff --git a/assets/images/68210.png b/assets/images/68210.png new file mode 100644 index 0000000000000000000000000000000000000000..a0c34954cd7483f373026408b083c8195d165489 --- /dev/null +++ b/assets/images/68210.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dbe5f80498af4ebd1ff09ae4184f37c20ba981e53bd554c3cc78d39ae0ee7fd7 +size 3933143 diff --git a/assets/images/68658.png b/assets/images/68658.png new file mode 100644 index 0000000000000000000000000000000000000000..24dd9477f8cdb5d92d96db34a8932a0d24da334e --- /dev/null +++ b/assets/images/68658.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:61a68b619bd17235e683324f2826ce0693322e45ab8c86f1c057851ecb333ac7 +size 5096267 diff --git a/assets/images/68666.png b/assets/images/68666.png new file mode 100644 index 
0000000000000000000000000000000000000000..95e7ae11dc90d22afc15fa3b41cbfc60ac4cda91 --- /dev/null +++ b/assets/images/68666.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ea3047e6c2ccb485fdb3966aa2325e803cbf49c27c0bff00287b44bc16f18914 +size 4562681 diff --git a/assets/images/68691.png b/assets/images/68691.png new file mode 100644 index 0000000000000000000000000000000000000000..9c688716c962b891073e1feea115a7838f72fcba --- /dev/null +++ b/assets/images/68691.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fae39e4055c1b297af7068cdddfeeba8d685363281b839d8c5afac1980204b57 +size 3736765 diff --git a/assets/images/68956.png b/assets/images/68956.png new file mode 100644 index 0000000000000000000000000000000000000000..d8a83b85cdb8d999f65677278a28deaa08352a57 --- /dev/null +++ b/assets/images/68956.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eee1f27082b10999d0fa848121ecb06cda3386b1a864b9aa0f59ae78261f8908 +size 4147008 diff --git a/assets/images/pexels-amresh444-17315601.png b/assets/images/pexels-amresh444-17315601.png new file mode 100644 index 0000000000000000000000000000000000000000..8453dc7bb9885c43733d798c46c877779cc8ba15 --- /dev/null +++ b/assets/images/pexels-amresh444-17315601.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4e17ee1b229147e4b52e8348a6ef426bc9e9a2f90738e776e15b26b325abb9b3 +size 3503065 diff --git a/assets/images/pexels-gabby-k-6311686.png b/assets/images/pexels-gabby-k-6311686.png new file mode 100644 index 0000000000000000000000000000000000000000..9add365bb9485f5085155dbbbab2232a0b533449 --- /dev/null +++ b/assets/images/pexels-gabby-k-6311686.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f10eded3fb05ab04b963f7b9fd2e183d8d4e81b20569b1c6b0653549639421f +size 3651731 diff --git a/assets/images/pexels-julia-m-cameron-4145040.png b/assets/images/pexels-julia-m-cameron-4145040.png new file mode 100644 index 0000000000000000000000000000000000000000..ff67ab842aabe6ebb6c0d5c8630c2c081c7a40ba --- /dev/null +++ b/assets/images/pexels-julia-m-cameron-4145040.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:459cf0280667b028ffbca16aa11188780d7a0205c0defec02916ff3cbaeecb72 +size 2924608 diff --git a/assets/images/pexels-marcus-aurelius-6787357.png b/assets/images/pexels-marcus-aurelius-6787357.png new file mode 100644 index 0000000000000000000000000000000000000000..c48247aeacbdb3e0e81b2a0dd1376bf26b687817 --- /dev/null +++ b/assets/images/pexels-marcus-aurelius-6787357.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7d35452f76492125eaf7d5783aa9fd6b0d5990ebe0579fe9dfd58a9d634f4955 +size 3297473 diff --git a/assets/images/pexels-mo-saeed-3616599-5409085.png b/assets/images/pexels-mo-saeed-3616599-5409085.png new file mode 100644 index 0000000000000000000000000000000000000000..ac7017af6d97524a95a6e43940f95ce10193cc9f --- /dev/null +++ b/assets/images/pexels-mo-saeed-3616599-5409085.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7c1ca7afd6c2a654e94ef59d5fb56fca4f3cde5fb5216f6b218c34a7b8c143dc +size 3125143 diff --git a/assets/images/pexels-riedelmax-27355495.png b/assets/images/pexels-riedelmax-27355495.png new file mode 100644 index 0000000000000000000000000000000000000000..20a059e38001957319449e3c1892b3a9bae0ab94 --- /dev/null +++ b/assets/images/pexels-riedelmax-27355495.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:4141d2f5f718f162ea1f6710c06b28b5cb51fd69598fde35948f8f3491228164 +size 3732680 diff --git a/assets/images/pexels-sergeymakashin-5368660.png b/assets/images/pexels-sergeymakashin-5368660.png new file mode 100644 index 0000000000000000000000000000000000000000..5b0d1554db9cdcaec20efba4c0628a0ab55867f4 --- /dev/null +++ b/assets/images/pexels-sergeymakashin-5368660.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:af8f5a8f26dd102d87d94c1be36ec903791fe8e6d951c68ebb9ebcfc6d7397bb +size 4075879 diff --git a/assets/images/pexels-vinicius-wiesehofer-289347-4219918.png b/assets/images/pexels-vinicius-wiesehofer-289347-4219918.png new file mode 100644 index 0000000000000000000000000000000000000000..95aa28be407ccc9b63706bdfb961f99a67319dad --- /dev/null +++ b/assets/images/pexels-vinicius-wiesehofer-289347-4219918.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a6eef5eee15b81fe65ea95627e9a46040b9889466689b3c1ca6ed273e02fe84f +size 3627053 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..0ba4b783fbc0f60375983b976f52b3e7b5a71844 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,21 @@ +gradio==4.42.0 +spaces + +torch==2.7.1 +torchvision==0.22.1 + +numpy +opencv-python +pillow +matplotlib +safetensors +huggingface_hub + +# Sapiens2 deps (sapiens2 source is vendored under ./sapiens/, not pip-installed). +tqdm +scipy +iopath +prettytable +termcolor +accelerate +rich diff --git a/sapiens/__init__.py b/sapiens/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f0392aebb036336c71255a955b1bf055207ff8ad --- /dev/null +++ b/sapiens/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .version import __version__ +from .engine import * +from .backbones import * +from .dense import * +from .pose import * +from .registry import * + +__all__ = ["__version__"] diff --git a/sapiens/backbones/__init__.py b/sapiens/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ffd1afc09e0f2d57730ed79703b59d7ca167d5e0 --- /dev/null +++ b/sapiens/backbones/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .sapiens import Sapiens +from .sapiens2 import Sapiens2 + +__all__ = ["Sapiens", "Sapiens2"] diff --git a/sapiens/backbones/sapiens.py b/sapiens/backbones/sapiens.py new file mode 100644 index 0000000000000000000000000000000000000000..6f9225e5bb192e874edafc3c9cb4cc9af46469f2 --- /dev/null +++ b/sapiens/backbones/sapiens.py @@ -0,0 +1,611 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
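+# Shape note (sketch, assuming the 1024x768 / patch-16 setup used elsewhere in this
+# Space): with input_size=(1024, 768) and kernel_size=stride=16, PatchEmbed below uses
+# zero padding, so init_out_size is
+#   h_out = (1024 - 16) // 16 + 1 = 64,  w_out = (768 - 16) // 16 + 1 = 48,
+# i.e. 64 * 48 = 3072 patch tokens (plus one class token); with out_type="featmap" the
+# (B, N, C) tokens are reshaped into a (B, embed_dims, 64, 48) feature map.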
+ +import math +from typing import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F +from sapiens.engine.models.base_model import BaseModel +from sapiens.registry import MODELS +from torch.nn import Linear, Sequential + + +# ---------------------------------------------------------------------------- +def to_2tuple(x): + if isinstance(x, (str, bytes)): + return (x, x) + if isinstance(x, Sequence): + x = tuple(x) + if len(x) == 2: + return x + raise ValueError("Expected scalar or length-2 iterable") + return (x, x) + + +def resize_pos_embed( + pos_embed, src_shape, dst_shape, mode="bicubic", num_extra_tokens=1 +): + if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]: + return pos_embed + assert pos_embed.ndim == 3, "shape of pos_embed must be [1, L, C]" + _, L, C = pos_embed.shape + src_h, src_w = src_shape + assert L == src_h * src_w + num_extra_tokens, ( + f"The length of `pos_embed` ({L}) doesn't match the expected " + f"shape ({src_h}*{src_w}+{num_extra_tokens}). Please check the" + "`img_size` argument." + ) + extra_tokens = pos_embed[:, :num_extra_tokens] + + src_weight = pos_embed[:, num_extra_tokens:] + src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2) + + # The cubic interpolate algorithm only accepts float32 + dst_weight = F.interpolate( + src_weight.float(), size=dst_shape, align_corners=False, mode=mode + ) + dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2) + dst_weight = dst_weight.to(src_weight.dtype) + + return torch.cat((extra_tokens, dst_weight), dim=1) + + +# ---------------------------------------------------------------------------- +class PatchEmbed(nn.Module): + def __init__( + self, + in_channels=3, + embed_dims=768, + kernel_size=16, + stride=16, + padding="corner", + dilation=1, + bias=True, + input_size=None, + ): + super().__init__() + + self.embed_dims = embed_dims + if stride is None: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + padding = 0 + padding = to_2tuple(padding) + + self.projection = nn.Conv2d( + in_channels=in_channels, + out_channels=embed_dims, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + ) + + if input_size: + input_size = to_2tuple(input_size) + self.init_input_size = input_size + + h_out = ( + input_size[0] + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1 + ) // stride[0] + 1 + w_out = ( + input_size[1] + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1 + ) // stride[1] + 1 + self.init_out_size = (h_out, w_out) + else: + self.init_input_size = None + self.init_out_size = None + + def forward(self, x): + x = self.projection(x) + out_size = (x.shape[2], x.shape[3]) + x = x.flatten(2).transpose(1, 2) + return x, out_size + + +# ---------------------------------------------------------------------------- +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + inplace: bool = False, + data_format: str = "channels_last", + scale: float = 1e-5, + ): + super().__init__() + assert data_format in ( + "channels_last", + "channels_first", + ), "'data_format' could only be channels_last or channels_first." 
+ self.inplace = inplace + self.data_format = data_format + self.weight = nn.Parameter(torch.ones(dim) * scale) + + def forward(self, x) -> torch.Tensor: + if self.data_format == "channels_first": + shape = tuple((1, -1, *(1 for _ in range(x.dim() - 2)))) + else: + shape = tuple((*(1 for _ in range(x.dim() - 1)), -1)) + if self.inplace: + return x.mul_(self.weight.view(*shape)) + else: + return x * self.weight.view(*shape) + + +# ---------------------------------------------------------------------------- +class FFN(nn.Module): + def __init__( + self, + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + add_identity=True, + layer_scale_init_value=0.0, + ): + super().__init__() + assert num_fcs >= 2, f"num_fcs should be no less than 2. got {num_fcs}." + self.embed_dims = embed_dims + self.feedforward_channels = feedforward_channels + self.num_fcs = num_fcs + + layers = [] + in_channels = embed_dims + for _ in range(num_fcs - 1): + layers.append( + Sequential( + Linear(in_channels, feedforward_channels), + nn.GELU(), + nn.Dropout(ffn_drop), + ) + ) + in_channels = feedforward_channels + layers.append(Linear(feedforward_channels, embed_dims)) + layers.append(nn.Dropout(ffn_drop)) + self.layers = Sequential(*layers) + self.dropout_layer = nn.Identity() + self.add_identity = add_identity + + if layer_scale_init_value > 0: + self.gamma2 = LayerScale(embed_dims, scale=layer_scale_init_value) + else: + self.gamma2 = nn.Identity() + + def forward(self, x, identity=None): + out = self.layers(x) + out = self.gamma2(out) + if not self.add_identity: + return out + if identity is None: + identity = x + return identity + out + + +# ---------------------------------------------------------------------------- +class MultiheadAttention(nn.Module): + def __init__( + self, + embed_dims, + num_heads, + input_dims=None, + attn_drop=0.0, + proj_drop=0.0, + qkv_bias=True, + proj_bias=True, + v_shortcut=False, + ): + super(MultiheadAttention, self).__init__() + + self.input_dims = input_dims or embed_dims + self.embed_dims = embed_dims + self.num_heads = num_heads + self.v_shortcut = v_shortcut + + self.head_dims = embed_dims // num_heads + self.scaled_dot_product_attention = F.scaled_dot_product_attention + + self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = attn_drop + self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + self.gamma1 = nn.Identity() + + def forward(self, x): + B, N, _ = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, self.head_dims) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn_drop = self.attn_drop if self.training else 0.0 + x = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop) + x = x.transpose(1, 2).reshape(B, N, self.embed_dims) + + x = self.proj(x) + x = self.gamma1(self.proj_drop(x)) + + if self.v_shortcut: + x = v.squeeze(1) + x + return x + + +# ---------------------------------------------------------------------------- +class TransformerEncoderLayer(nn.Module): + def __init__( + self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0.0, + attn_drop_rate=0.0, + num_fcs=2, + qkv_bias=True, + ): + super(TransformerEncoderLayer, self).__init__() + + self.embed_dims = embed_dims + self.ln1 = nn.LayerNorm(self.embed_dims, eps=1e-6, elementwise_affine=True) + self.attn = MultiheadAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + 
qkv_bias=qkv_bias, + ) + + self.ln2 = nn.LayerNorm(self.embed_dims, eps=1e-6, elementwise_affine=True) + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + add_identity=True, + ) + + @property + def norm1(self): + return self.ln1 + + @property + def norm2(self): + return self.ln2 + + def forward(self, x): + x = x + self.attn(self.ln1(x)) + x = self.ffn(self.ln2(x), identity=x) + return x + + +# ---------------------------------------------------------------------------- +@MODELS.register_module() +class Sapiens(BaseModel): + arch_zoo = { + **dict.fromkeys( ## this is vit-large + ["0.3b", "sapiens_0.3b"], + { + "embed_dims": 1024, + "num_layers": 24, + "num_heads": 16, + "feedforward_channels": 1024 * 4, + }, + ), + **dict.fromkeys( ## this is vit-huge + ["0.6b", "sapiens_0.6b"], + { + "embed_dims": 1280, + "num_layers": 32, + "num_heads": 16, + "feedforward_channels": 1280 * 4, + }, + ), + **dict.fromkeys( ## this is vit-g + ["1b", "sapiens_1b"], + { + "embed_dims": 1536, + "num_layers": 40, + "num_heads": 24, + "feedforward_channels": 1536 * 4, + }, + ), + **dict.fromkeys( + ["2b", "sapiens_2b"], + { + "embed_dims": 1920, + "num_layers": 48, + "num_heads": 32, + "feedforward_channels": 1920 * 4, + }, + ), + } + num_extra_tokens = 1 # class token + OUT_TYPES = {"raw", "cls_token", "featmap"} + + def __init__( + self, + arch="base", + img_size=1024, + patch_size=16, + in_channels=3, + out_indices=-1, + drop_rate=0.0, + qkv_bias=True, + final_norm=True, + out_type="cls_token", + with_cls_token=True, + frozen_stages=-1, + interpolate_mode="bicubic", + patch_cfg=dict(), + layer_cfgs=dict(), + init_cfg=None, + ): + super(Sapiens, self).__init__(init_cfg=init_cfg) + + arch = arch.lower() + assert arch in set(self.arch_zoo), ( + f"Arch {arch} is not in default archs {set(self.arch_zoo)}" + ) + self.arch_settings = self.arch_zoo[arch] + + self.embed_dims = self.arch_settings["embed_dims"] + self.num_layers = self.arch_settings["num_layers"] + self.img_size = to_2tuple(img_size) + self.patch_size = patch_size + + # Set patch embedding + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + kernel_size=patch_size, + stride=patch_size, + bias=True, + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + # Set out type + if out_type not in self.OUT_TYPES: + raise ValueError( + f"Unsupported `out_type` {out_type}, please " + f"choose from {self.OUT_TYPES}" + ) + self.out_type = out_type + + # Set cls token + self.with_cls_token = with_cls_token + if with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + elif out_type != "cls_token": + self.cls_token = None + self.num_extra_tokens = 0 + else: + raise ValueError('with_cls_token must be True when `out_type="cls_token"`.') + + # Set position embedding + self.interpolate_mode = interpolate_mode + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + self.num_extra_tokens, self.embed_dims) + ) + self.drop_after_pos = nn.Dropout(p=drop_rate) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), ( + f'"out_indices" must by a sequence or int, get {type(out_indices)} instead.' 
+ ) + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_layers + index + assert 0 <= out_indices[i] <= self.num_layers, ( + f"Invalid out_indices {index}" + ) + self.out_indices = out_indices + + self.layers = nn.Sequential() + if isinstance(layer_cfgs, dict): + layer_cfgs = [layer_cfgs] * self.num_layers + for i in range(self.num_layers): + _layer_cfg = dict( + embed_dims=self.embed_dims, + num_heads=self.arch_settings["num_heads"], + feedforward_channels=self.arch_settings["feedforward_channels"], + drop_rate=drop_rate, + qkv_bias=qkv_bias, + ) + _layer_cfg.update(layer_cfgs[i]) + self.layers.append(TransformerEncoderLayer(**_layer_cfg)) + + self.frozen_stages = frozen_stages + self.pre_norm = nn.Identity() + + self.final_norm = final_norm + if final_norm: + self.ln1 = nn.LayerNorm(self.embed_dims, eps=1e-6, elementwise_affine=True) + + # freeze stages only when self.frozen_stages > 0 + if self.frozen_stages > 0: + self._freeze_stages() + + self._register_load_state_dict_pre_hook(self._prepare_pos_embed) + + self.init_weights() + + return + + def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + "pos_embed" + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + + from sapiens.engine.logger import Logger + + logger = Logger.get_current_instance() + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + + # Handle class token removal if needed + if not self.with_cls_token: + if ckpt_pos_embed_shape[1] == self.pos_embed.shape[1] + 1: + # Remove cls token from state dict if it's not used + state_dict[name] = state_dict[name][:, 1:] + ckpt_pos_embed_shape = state_dict[name].shape + elif ckpt_pos_embed_shape[1] % 2 == 1: + # Remove class token when interpolation is required + if rank == 0: + logger.info( + "Note: removing the class token from pretrained weights" + ) + state_dict[name] = state_dict[name][:, 1:] + ckpt_pos_embed_shape = state_dict[name].shape + + # Skip if shapes already match + if self.pos_embed.shape == ckpt_pos_embed_shape: + return + + if rank == 0: + logger.info( + f"Resize the pos_embed shape from {ckpt_pos_embed_shape} " + f"to {self.pos_embed.shape}." 
+ ) + + # Calculate grid dimensions + pos_h, pos_w = self.patch_embed.init_out_size + assert pos_h >= pos_w # for vertical aspect ratio or square + + # Number of non-extra tokens in checkpoint + num_vis = ckpt_pos_embed_shape[1] - self.num_extra_tokens + + # Determine original grid shape + side = int(math.sqrt(num_vis)) + factor = int(math.sqrt((num_vis * self.patch_size * self.patch_size) // 12)) + + # Set old grid based on aspect ratio detection + if side * side == num_vis: + old_grid = (side, side) # square grid + elif 4 * factor * 3 * factor == num_vis * self.patch_size * self.patch_size: + old_grid = ( + (factor * 4) // self.patch_size, + (factor * 3) // self.patch_size, + ) # 4:3 ratio + else: + if rank == 0: + logger.warning( + f"Original pos_embed tokens ({num_vis}) not square or 4:3 does not match current size" + ) + state_dict[name] = self.pos_embed + return + + # Resize position embedding + new_grid = (pos_h, pos_w) + state_dict[name] = resize_pos_embed( + state_dict[name], + old_grid, + new_grid, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens, + ) + + @property + def norm1(self): + return self.ln1 + + @property + def norm2(self): + return self.ln2 + + @staticmethod + def resize_pos_embed(*args, **kwargs): + """Interface for backward-compatibility.""" + return resize_pos_embed(*args, **kwargs) + + def _freeze_stages(self): + # freeze position embedding + if self.pos_embed is not None: + self.pos_embed.requires_grad = False + + # set dropout to eval model + self.drop_after_pos.eval() + # freeze patch embedding + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + # freeze pre-norm + for param in self.pre_norm.parameters(): + param.requires_grad = False + # freeze cls_token + if self.cls_token is not None: + self.cls_token.requires_grad = False + # freeze layers + for i in range(1, self.frozen_stages + 1): + m = self.layers[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + # freeze the last layer norm + if self.frozen_stages == len(self.layers): + if self.final_norm: + self.ln1.eval() + for param in self.ln1.parameters(): + param.requires_grad = False + + if self.out_type == "avg_featmap": + self.ln2.eval() + for param in self.ln2.parameters(): + param.requires_grad = False + + def forward(self, x): + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + + if self.cls_token is not None: + cls_token = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_token, x), dim=1) + + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens, + ) + x = self.drop_after_pos(x) + + x = self.pre_norm(x) ## B x (num tokens) x embed_dim + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.ln1(x) + + if i in self.out_indices: + outs.append(self._format_output(x, patch_resolution)) + + return tuple(outs) + + def _format_output(self, x, hw): + if self.out_type == "raw": + return x + if self.out_type == "cls_token": + return x[:, 0] + + patch_token = x[:, self.num_extra_tokens :] + if self.out_type == "featmap": + B = x.size(0) + # (B, N, C) -> (B, H, W, C) -> (B, C, H, W) + return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2) diff --git a/sapiens/backbones/sapiens2.py b/sapiens/backbones/sapiens2.py new file mode 100644 index 
0000000000000000000000000000000000000000..8661d8d7509abf0e986de404d70364da23a19040 --- /dev/null +++ b/sapiens/backbones/sapiens2.py @@ -0,0 +1,916 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from sapiens.engine.models.base_model import BaseModel +from sapiens.registry import MODELS +from torch import Tensor +from torch.nn.init import trunc_normal_ +from torch.utils.checkpoint import checkpoint + + +# ---------------------------------------------------------------------------- +def to_2tuple(x): + if isinstance(x, (str, bytes)): + return (x, x) + if isinstance(x, Sequence): + x = tuple(x) + if len(x) == 2: + return x + raise ValueError("Expected scalar or length-2 iterable") + return (x, x) + + +class RopePositionEmbedding(nn.Module): + def __init__( + self, + embed_dim: int, + *, + num_heads: int, + base: float | None = 100.0, + min_period: float | None = None, + max_period: float | None = None, + normalize_coords: Literal["min", "max", "separate"] = "separate", + shift_coords: float | None = None, + jitter_coords: float | None = None, + rescale_coords: float | None = None, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + ): + super().__init__() + assert embed_dim % (4 * num_heads) == 0 + both_periods = min_period is not None and max_period is not None + if (base is None and not both_periods) or (base is not None and both_periods): + raise ValueError( + "Either `base` or `min_period`+`max_period` must be provided." 
+ ) + + D_head = embed_dim // num_heads + self.base = base + self.min_period = min_period + self.max_period = max_period + self.D_head = D_head + self.normalize_coords = normalize_coords + self.shift_coords = shift_coords + self.jitter_coords = jitter_coords + self.rescale_coords = rescale_coords + + # Needs persistent=True because we do teacher.load_state_dict(student.state_dict()) to initialize the teacher + self.dtype = dtype or torch.float32 # Don't rely on self.periods.dtype + self.register_buffer( + "periods", + torch.empty(D_head // 4, device=device, dtype=self.dtype), + persistent=True, + ) + self._init_weights() + + def forward(self, *, H: int, W: int) -> tuple[Tensor, Tensor]: + device = self.periods.device + dtype = self.dtype + dd = {"device": device, "dtype": dtype} + # Prepare coords in range [-1, +1] + if self.normalize_coords == "max": + max_HW = max(H, W) + coords_h = torch.arange(0.5, H, **dd) / max_HW # [H] + coords_w = torch.arange(0.5, W, **dd) / max_HW # [W] + elif self.normalize_coords == "min": + min_HW = min(H, W) + coords_h = torch.arange(0.5, H, **dd) / min_HW # [H] + coords_w = torch.arange(0.5, W, **dd) / min_HW # [W] + elif self.normalize_coords == "separate": + coords_h = torch.arange(0.5, H, **dd) / H # [H] + coords_w = torch.arange(0.5, W, **dd) / W # [W] + else: + raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}") + coords = torch.stack( + torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1 + ) # [H, W, 2] + coords = coords.flatten(0, 1) # [HW, 2] + coords = 2.0 * coords - 1.0 # Shift range [0, 1] to [-1, +1] + + # Shift coords by adding a uniform value in [-shift, shift] + if self.training and self.shift_coords is not None: + shift_hw = torch.empty(2, **dd).uniform_( + -self.shift_coords, self.shift_coords + ) + coords += shift_hw[None, :] + + # Jitter coords by multiplying the range [-1, 1] by a log-uniform value in [1/jitter, jitter] + if self.training and self.jitter_coords is not None: + jitter_max = np.log(self.jitter_coords) + jitter_min = -jitter_max + jitter_hw = torch.empty(2, **dd).uniform_(jitter_min, jitter_max).exp() + coords *= jitter_hw[None, :] + + # Rescale coords by multiplying the range [-1, 1] by a log-uniform value in [1/rescale, rescale] + if self.training and self.rescale_coords is not None: + rescale_max = np.log(self.rescale_coords) + rescale_min = -rescale_max + rescale_hw = torch.empty(1, **dd).uniform_(rescale_min, rescale_max).exp() + coords *= rescale_hw + + # Prepare angles and sin/cos + angles = ( + 2 * math.pi * coords[:, :, None] / self.periods[None, None, :] + ) # [HW, 2, D//4] + angles = angles.flatten(1, 2) # [HW, D//2] + angles = angles.tile(2) # [HW, D] + cos = torch.cos(angles) # [HW, D] + sin = torch.sin(angles) # [HW, D] + + return (sin, cos) # 2 * [HW, D] + + def _init_weights(self): + device = self.periods.device + dtype = self.dtype + if self.base is not None: + periods = self.base ** ( + 2 + * torch.arange(self.D_head // 4, device=device, dtype=dtype) + / (self.D_head // 2) + ) # [D//4] + else: + base = self.max_period / self.min_period + exponents = torch.linspace( + 0, 1, self.D_head // 4, device=device, dtype=dtype + ) # [D//4] range [0, 1] + periods = base**exponents # range [1, max_period / min_period] + periods = periods / base # range [min_period / max_period, 1] + periods = periods * self.max_period # range [min_period, max_period] + self.periods.data = periods + + +# ------------------------------------------------------------------------------- +class Tokenizer(nn.Module): + 
"""Stacked window selfโ€‘attention that emits one token per window + by reโ€‘using TransformerEncoderLayer blocks.""" + + def __init__( + self, + embed_dims: int, + window_size: int = 4, + num_heads: int = 4, + num_tokenizer_layers: int = 1, + qkv_bias: bool = True, + use_qk_norm: bool = False, + chunk_size: int = 1024, # max windows per chunk + ): + super().__init__() + self.ws = window_size + self.chunk_size = chunk_size + + # local absolute positional embeddings for [CLS] + patch tokens + self.local_pos_embed = nn.Parameter( + torch.zeros(1, 1 + window_size * window_size, embed_dims) + ) + trunc_normal_(self.local_pos_embed, std=0.02) + + # build N identical TransformerEncoderLayer blocks + self.blocks = nn.ModuleList( + [ + TransformerEncoderLayer2( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=embed_dims * 4, # standard FFN size + qkv_bias=qkv_bias, + use_qk_norm=use_qk_norm, + ) + for _ in range(num_tokenizer_layers) + ] + ) + + # shared CLS token for pooling + self.w_cls = nn.Parameter(torch.zeros(1, 1, embed_dims)) + trunc_normal_(self.w_cls, std=0.02) + + def forward( + self, + x: torch.Tensor, + hw: Tuple[int, int], + ) -> Tuple[torch.Tensor, Tuple[int, int]]: + """Args: + x : B, N, C (N = H*W) + hw : (H, W) before reduction + Returns: + x_ : B, (H/ws)*(W/ws), C + hw_: (H/ws, W/ws) + """ + B, N, C = x.shape + H, W = hw + ws = self.ws + assert H % ws == 0 and W % ws == 0, ( + f"Image size {H}ร—{W} must be divisible by window {ws}." + ) + + # reshape tokens โ†’ nonโ€‘overlapping windows + x = x.view(B, H, W, C) + + ph, pw = H // ws, W // ws ## ints in eager mode + ph, pw = int(ph), int(pw) ## ints in scripting mode + x = x.view(B, ph, ws, pw, ws, C) # B, H/ws, ws, W/ws, ws, C + x = x.permute(0, 1, 3, 2, 4, 5) # B, H/ws, W/ws, ws, ws, C + x = x.contiguous().view(B * ph * pw, ws * ws, C) # (B*H/ws*W/ws), wsยฒ, C)) + + total_windows = x.size(0) + chunk_size = int(min(self.chunk_size, total_windows)) + token_out = x.new_empty(total_windows, C) + + use_ckpt = self.training and torch.is_grad_enabled() + + def _run_blocks(t: torch.Tensor) -> torch.Tensor: + for blk in self.blocks: + t = blk(t) + return t + + for i in range(0, total_windows, chunk_size): + chunk = x[i : i + chunk_size] # (m, wsยฒ, C) + m = chunk.size(0) + cls = self.w_cls.expand(m, -1, -1) # (m, 1, C) + chunk = torch.cat([cls, chunk], dim=1) # (m, 1+wsยฒ, C) + chunk = chunk + self.local_pos_embed # add local PE + + if use_ckpt: + chunk = checkpoint(_run_blocks, chunk, use_reentrant=False) + else: + chunk = _run_blocks(chunk) + + token_out[i : i + m] = chunk[:, 0] # take CLS out + + token = token_out.view(B, ph * pw, C) # (B, (H/ws)*(W + return token, (ph, pw) + + +# ------------------------------------------------------------------------------- +class GroupedQueryAttention(nn.Module): + def __init__( + self, + embed_dims, + num_heads, + num_kv_heads=None, + input_dims=None, + attn_drop=0.0, + proj_drop=0.0, + qkv_bias=True, + qk_scale=None, + proj_bias=True, + use_qk_norm=True, + v_shortcut=False, + layer_scale_init_value=0.0, + ): + super().__init__() + # Core dims + self.embed_dims = embed_dims + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads or num_heads + assert self.num_heads % self.num_kv_heads == 0, ( + "num_kv_heads must divide num_heads" + ) + self.head_dim = embed_dims // num_heads + self.input_dims = input_dims or embed_dims + # Features + self.attn_drop = attn_drop + self.v_shortcut = v_shortcut + self.use_qk_norm = use_qk_norm + + # Attention operation selection + if 
qk_scale is not None: + scale = qk_scale + else: + scale = self.head_dim**-0.5 + + assert qk_scale is None, "qk_scale is not supported" + self.attn_op = F.scaled_dot_product_attention + + # Q/K/V projections + self.wq = nn.Linear(self.input_dims, embed_dims, bias=qkv_bias) + self.wk = nn.Linear( + self.input_dims, self.num_kv_heads * self.head_dim, bias=qkv_bias + ) + self.wv = nn.Linear( + self.input_dims, self.num_kv_heads * self.head_dim, bias=qkv_bias + ) + + if self.use_qk_norm: + self.q_norm = nn.RMSNorm(self.head_dim, eps=1e-6) + self.k_norm = nn.RMSNorm(self.head_dim, eps=1e-6) + + # Output projection + dropout + self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + # Optional LayerScale + if layer_scale_init_value > 0: + self.gamma = LayerScale(embed_dims, scale=layer_scale_init_value) + else: + self.gamma = nn.Identity() + + def apply_rope( + self, q: Tensor, k: Tensor, rope: Tensor | Tuple[Tensor, Tensor] + ) -> Tuple[Tensor, Tensor]: + # All operations will use the dtype of rope, the output is cast back to the dtype of q and k + q_dtype = q.dtype + k_dtype = k.dtype + sin, cos = rope + rope_dtype = sin.dtype + q = q.to(dtype=rope_dtype) + k = k.to(dtype=rope_dtype) + N = q.shape[-2] + prefix = N - sin.shape[-2] ## extra tokens + assert prefix >= 0 + q_prefix = q[:, :, :prefix, :] + q = self._rope_apply(q[:, :, prefix:, :], sin, cos) # [B, head, hw, D//head] + q = torch.cat((q_prefix, q), dim=-2) # [B, head, N, D//head] + k_prefix = k[:, :, :prefix, :] + k = self._rope_apply(k[:, :, prefix:, :], sin, cos) # [B, head, hw, D//head] + k = torch.cat((k_prefix, k), dim=-2) # [B, head, N, D//head] + q = q.to(dtype=q_dtype) + k = k.to(dtype=k_dtype) + return q, k + + def _rope_rotate_half(self, x: Tensor) -> Tensor: + # x: [ x0 x1 x2 x3 x4 x5] + # out: [-x3 -x4 -x5 x0 x1 x2] + x1, x2 = x.chunk(2, dim=-1) + return torch.cat([-x2, x1], dim=-1) + + def _rope_apply(self, x: Tensor, sin: Tensor, cos: Tensor) -> Tensor: + # x: [..., D], eg [x0, x1, x2, x3, x4, x5] + # sin: [..., D], eg [sin0, sin1, sin2, sin0, sin1, sin2] + # cos: [..., D], eg [cos0, cos1, cos2, cos0, cos1, cos2] + return (x * cos) + (self._rope_rotate_half(x) * sin) + + def forward(self, x, rope=None): + B, N, _ = x.shape + # Q: (B, N, num_heads, head_dim) + q = self.wq(x).view(B, N, self.num_heads, self.head_dim) + # K/V: (B, N, num_kv_heads, head_dim) + k = self.wk(x).view(B, N, self.num_kv_heads, self.head_dim) + v = self.wv(x).view(B, N, self.num_kv_heads, self.head_dim) + + # (B, heads, N, head_dim) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + if self.use_qk_norm: + q = self.q_norm(q) + k = self.k_norm(k) + + # Repeat KV heads if group ratio >1 + if self.num_kv_heads != self.num_heads: + factor = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(factor, dim=1) + v = v.repeat_interleave(factor, dim=1) + + if rope is not None: + q, k = self.apply_rope(q, k, rope) + + # Scaled dot-product attention + attn_out = self.attn_op( + q, k, v, dropout_p=self.attn_drop if self.training else 0.0 + ) # (B, num_heads, N, head_dim) + + # Merge heads -> (B, N, embed_dims) + out = attn_out.permute(0, 2, 1, 3).reshape(B, N, self.embed_dims) + + # Output projection + drop + layer scale + out = self.proj(out) + out = self.gamma(self.proj_drop(out)) + + # Optional V-shortcut (only when MQA) + if self.v_shortcut and self.num_kv_heads == 1: + raise NotImplementedError + return out + + +# 
------------------------------------------------------------------------------- +class TransformerEncoderLayer2(nn.Module): + def __init__( + self, + embed_dims, + num_heads, + num_kv_heads=None, + feedforward_channels=None, + drop_rate=0.0, + attn_drop_rate=0.0, + layer_scale_init_value=0.0, + use_qk_norm=True, + qkv_bias=True, + ): + super(TransformerEncoderLayer2, self).__init__() + + self.embed_dims = embed_dims + self.ln1 = nn.RMSNorm(self.embed_dims, eps=1e-6) + self.attn = GroupedQueryAttention( + embed_dims=embed_dims, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + qkv_bias=qkv_bias, + layer_scale_init_value=layer_scale_init_value, + use_qk_norm=use_qk_norm, + ) + + self.ln2 = nn.RMSNorm(self.embed_dims, eps=1e-6) + self.ffn = SwiGLUFFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + ) + + @property + def norm1(self): + return self.ln1 + + @property + def norm2(self): + return self.ln2 + + def forward(self, x, rope=None): + x = x + self.attn(self.ln1(x), rope=rope) + x = self.ffn(self.ln2(x), identity=x) + return x + + +##----------------------------------- +@MODELS.register_module() +class Sapiens2(BaseModel): + arch_zoo = { + **dict.fromkeys( + ["sapiens2_0.1b"], + { + "embed_dims": 768, + "num_layers": 12, + "num_heads": 12, + "feedforward_channels": 768 * 4, + "num_tokenizer_layers": 2, + }, + ), + **dict.fromkeys( + ["sapiens2_0.4b"], + { + "embed_dims": 1024, + "num_layers": 24, + "num_heads": 16, + "feedforward_channels": 1024 * 4, + "num_tokenizer_layers": 2, + }, + ), + **dict.fromkeys( + ["sapiens2_0.8b"], + { + "embed_dims": 1280, + "num_layers": 32, + "num_heads": 16, + "feedforward_channels": 1280 * 4, + "num_tokenizer_layers": 3, + }, + ), + **dict.fromkeys( + ["sapiens2_1b"], + { + "embed_dims": 1536, + "num_layers": 40, + "num_heads": 24, + "feedforward_channels": 1536 * 4, + "num_tokenizer_layers": 4, + }, + ), + **dict.fromkeys( + ["sapiens2_5b"], + { + "embed_dims": 2432, + "num_layers": 56, + "num_heads": 32, + "feedforward_channels": 2432 * 4, + "num_tokenizer_layers": 6, + }, + ), + } + + num_extra_tokens = 1 # class token + OUT_TYPES = {"raw", "cls_token", "featmap"} + + def __init__( + self, + arch="sapiens2_1b", + img_size=(1024, 768), + patch_size=16, + in_channels=3, + out_indices=-1, + drop_rate=0.0, + window_size=4, + use_tokenizer=False, ## 4k resolution + use_qk_norm=True, + qkv_bias=True, + final_norm=True, + out_type="raw", + with_cls_token=True, + layer_scale_init_value=1e-4, ## non zero init to activate layerscale + frozen_stages=-1, + patch_cfg=dict(), + layer_cfgs=dict(), + pos_embed_rope_base: float = 100.0, + pos_embed_rope_min_period: float | None = None, + pos_embed_rope_max_period: float | None = None, + pos_embed_rope_normalize_coords: Literal["min", "max", "separate"] = "separate", + pos_embed_rope_shift_coords: float | None = None, + pos_embed_rope_jitter_coords: float | None = None, + pos_embed_rope_rescale_coords: float | None = None, + pos_embed_rope_dtype: str = "bf16", + n_storage_tokens: int = 8, + init_cfg=None, + ): + super(Sapiens2, self).__init__(init_cfg=init_cfg) + + arch = arch.lower() + assert arch in set(self.arch_zoo), ( + f"Arch {arch} is not in default archs {set(self.arch_zoo)}" + ) + self.arch_settings = self.arch_zoo[arch] + + self.embed_dims = self.arch_settings["embed_dims"] + self.num_layers = self.arch_settings["num_layers"] + self.patch_size = patch_size + + self.window_size = window_size + img_size = to_2tuple(img_size) + 
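+        # Worked example (illustrative numbers only, assuming the defaults above):
+        # with use_tokenizer=True, a 4096x3072 input and patch_size=16 give a
+        # 256x192 patch grid, which the 4x4 window tokenizer pools down to 64x48
+        # tokens -- the same grid a 1024x768 input yields without the tokenizer.
+        # encoder_img_size below records that reduced resolution so that
+        # patch_embed.init_out_size matches the token grid the blocks actually see.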
encoder_img_size = ( + (img_size[0] // window_size, img_size[1] // window_size) + if use_tokenizer + else img_size + ) + self.img_size = to_2tuple(encoder_img_size) + + # Set patch embedding + _patch_cfg = dict( + in_channels=in_channels, + input_size=self.img_size, + embed_dims=self.embed_dims, + kernel_size=patch_size, + stride=patch_size, + bias=True, + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + self.rope_embed = RopePositionEmbedding( + embed_dim=self.embed_dims, + num_heads=self.arch_settings["num_heads"], + base=pos_embed_rope_base, + min_period=pos_embed_rope_min_period, + max_period=pos_embed_rope_max_period, + normalize_coords=pos_embed_rope_normalize_coords, + shift_coords=pos_embed_rope_shift_coords, + jitter_coords=pos_embed_rope_jitter_coords, + rescale_coords=pos_embed_rope_rescale_coords, + dtype=torch.bfloat16 if pos_embed_rope_dtype == "bf16" else torch.float32, + ) + + # Set out type + if out_type not in self.OUT_TYPES: + raise ValueError( + f"Unsupported `out_type` {out_type}, please " + f"choose from {self.OUT_TYPES}" + ) + self.out_type = out_type + + if use_tokenizer == True: + self.tokenizer = Tokenizer( + embed_dims=self.embed_dims, + window_size=self.window_size, + num_heads=self.arch_settings["num_heads"], + num_tokenizer_layers=self.arch_settings["num_tokenizer_layers"], + qkv_bias=True, + use_qk_norm=False, + ) + else: + self.tokenizer = None + + # Set cls + storage tokens + self.with_cls_token = with_cls_token + if with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + elif out_type != "cls_token": + self.cls_token = None + self.num_extra_tokens = 0 + else: + raise ValueError('with_cls_token must be True when `out_type="cls_token"`.') + + ## registers + self.n_storage_tokens = int(n_storage_tokens) + self.storage_tokens = ( + nn.Parameter(torch.zeros(1, self.n_storage_tokens, self.embed_dims)) + if self.n_storage_tokens > 0 + else None + ) + # how many non-patch tokens are at the front + self.num_extra_tokens = ( + 1 if self.cls_token is not None else 0 + ) + self.n_storage_tokens + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), ( + f'"out_indices" must by a sequence or int, get {type(out_indices)} instead.' 
+ ) + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_layers + index + assert 0 <= out_indices[i] <= self.num_layers, ( + f"Invalid out_indices {index}" + ) + self.out_indices = out_indices + + self.blocks = nn.Sequential() + if isinstance(layer_cfgs, dict): + layer_cfgs = [layer_cfgs] * self.num_layers + + mhsa_early, mhsa_late = 8, 8 + for i in range(self.num_layers): + if i < mhsa_early or i >= self.num_layers - mhsa_late: + num_kv_heads = None ## use MHSA + else: + num_kv_heads = self.arch_settings["num_heads"] // 2 # Use GQA + + _layer_cfg = dict( + embed_dims=self.embed_dims, + num_heads=self.arch_settings["num_heads"], + num_kv_heads=num_kv_heads, + feedforward_channels=self.arch_settings["feedforward_channels"], + use_qk_norm=use_qk_norm, + layer_scale_init_value=layer_scale_init_value, + drop_rate=drop_rate, + qkv_bias=qkv_bias, + ) + _layer_cfg.update(layer_cfgs[i]) + self.blocks.append(TransformerEncoderLayer2(**_layer_cfg)) + + self.frozen_stages = frozen_stages + + self.final_norm = final_norm + if final_norm: + self.ln1 = nn.RMSNorm(self.embed_dims, eps=1e-6) + + # freeze stages only when self.frozen_stages > 0 + if self.frozen_stages > 0: + self._freeze_stages() + + ## load init weights + self.init_weights() + + return + + def init_weights(self): + if self.init_cfg is not None: + super(Sapiens2, self).init_weights() + return + + # Initialize class token and storagr token embeddings + if self.with_cls_token: + trunc_normal_(self.cls_token, std=0.02) + + if self.storage_tokens is not None: + trunc_normal_(self.storage_tokens, std=0.02) + + # Apply custom initialization to all submodules + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # Use a truncated normal distribution for linear layer weights + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + elif isinstance(m, (nn.LayerNorm, nn.RMSNorm)): + # Initialize normalization layers to act as an identity function + if hasattr(m, "bias") and m.bias is not None: + nn.init.constant_(m.bias, 0) + if hasattr(m, "weight") and m.weight is not None: + nn.init.constant_(m.weight, 1.0) + + elif isinstance(m, nn.Conv2d): + # Initialize conv layer weights like linear layers + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _freeze_stages(self): + ## freeze tokenizer + if self.frozen_stages >= 1 and self.tokenizer is not None: + self.tokenizer.eval() + for param in self.tokenizer.parameters(): + param.requires_grad = False + + # freeze patch embedding + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + # freeze cls_token + if self.cls_token is not None: + self.cls_token.requires_grad = False + if self.storage_tokens is not None: + self.storage_tokens.requires_grad = False + # freeze layers + for i in range(1, self.frozen_stages + 1): + m = self.blocks[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + # freeze the last layer norm + if self.frozen_stages == len(self.blocks): + if self.final_norm: + self.ln1.eval() + for param in self.ln1.parameters(): + param.requires_grad = False + + def forward(self, x): + B = x.shape[0] + + x, patch_resolution = self.patch_embed(x) # (B, 256*256, C) + if self.tokenizer is not None: + x, patch_resolution = self.tokenizer(x, patch_resolution) + + # prepend [CLS] and storage tokens + prepend = [] + if self.cls_token is not None: + 
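+            # Token layout after the concatenation below:
+            #   [cls] + n_storage_tokens register tokens + H*W patch tokens.
+            # GroupedQueryAttention.apply_rope() rotates only the trailing patch
+            # tokens and leaves this leading prefix untouched.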
prepend.append(self.cls_token.expand(B, -1, -1)) + if self.storage_tokens is not None: + prepend.append(self.storage_tokens.expand(B, -1, -1)) + if len(prepend) > 0: + x = torch.cat(prepend + [x], dim=1) + + rope_sincos = self.rope_embed(H=patch_resolution[0], W=patch_resolution[1]) + outs = [] + for i, layer in enumerate(self.blocks): + x = layer(x, rope=rope_sincos) + + if i == len(self.blocks) - 1 and self.final_norm: + x = self.ln1(x) + + if i in self.out_indices: + outs.append(self._format_output(x, patch_resolution)) + + return tuple(outs) + + def _format_output(self, x, hw): + if self.out_type == "raw": + return x + if self.out_type == "cls_token": + return x[:, 0] + + patch_token = x[:, self.num_extra_tokens :] + if self.out_type == "featmap": + B = x.size(0) + # (B, N, C) -> (B, H, W, C) -> (B, C, H, W) + return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2) + + @property + def norm1(self): + return self.ln1 + + +# ---------------------------------------------------------------------------- +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + inplace: bool = False, + data_format: str = "channels_last", + scale: float = 1e-5, + ): + super().__init__() + assert data_format in ( + "channels_last", + "channels_first", + ), "'data_format' could only be channels_last or channels_first." + self.inplace = inplace + self.data_format = data_format + self.weight = nn.Parameter(torch.ones(dim) * scale) + + def forward(self, x) -> torch.Tensor: + if self.data_format == "channels_first": + shape = tuple((1, -1, *(1 for _ in range(x.dim() - 2)))) + else: + shape = tuple((*(1 for _ in range(x.dim() - 1)), -1)) + if self.inplace: + return x.mul_(self.weight.view(*shape)) + else: + return x * self.weight.view(*shape) + + +# ---------------------------------------------------------------------------- +class PatchEmbed(nn.Module): + def __init__( + self, + in_channels=3, + embed_dims=768, + kernel_size=16, + stride=16, + padding="corner", + dilation=1, + bias=True, + input_size=None, + ): + super().__init__() + + self.embed_dims = embed_dims + if stride is None: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + padding = 0 + padding = to_2tuple(padding) + + self.projection = nn.Conv2d( + in_channels=in_channels, + out_channels=embed_dims, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + ) + + if input_size: + input_size = to_2tuple(input_size) + self.init_input_size = input_size + h_out = ( + input_size[0] + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1 + ) // stride[0] + 1 + w_out = ( + input_size[1] + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1 + ) // stride[1] + 1 + self.init_out_size = (h_out, w_out) + else: + self.init_input_size = None + self.init_out_size = None + + def forward(self, x): + x = self.projection(x) + out_size = (x.shape[2], x.shape[3]) + x = x.flatten(2).transpose(1, 2) + return x, out_size + + +# ---------------------------------------------------------------------------- +class SwiGLUFFN(nn.Module): + """SwiGLU FFN layer. 
+ https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py + """ # noqa + + def __init__( + self, + embed_dims: int, + feedforward_channels: Optional[int] = None, + out_dims: Optional[int] = None, + layer_scale_init_value: float = 0.0, + bias: bool = True, + add_identity: bool = True, + ) -> None: + super().__init__() + self.embed_dims = embed_dims + self.out_dims = out_dims or embed_dims + hidden_dims = feedforward_channels or embed_dims + + self.w12 = nn.Linear(self.embed_dims, 2 * hidden_dims, bias=bias) + self.w3 = nn.Linear(hidden_dims, self.out_dims, bias=bias) + + if layer_scale_init_value > 0: + self.gamma2 = LayerScale(dim=embed_dims, scale=layer_scale_init_value) + else: + self.gamma2 = nn.Identity() + + self.add_identity = add_identity + + def forward( + self, x: torch.Tensor, identity: Optional[torch.Tensor] = None + ) -> torch.Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + out = self.w3(hidden) + out = self.gamma2(out) + + if self.out_dims != self.embed_dims or not self.add_identity: + # due to the dimension inconsistence or user setting + # not to apply residual operation + return out + + if identity is None: + identity = x + return identity + out diff --git a/sapiens/backbones/standalone/sapiens.py b/sapiens/backbones/standalone/sapiens.py new file mode 100644 index 0000000000000000000000000000000000000000..3f7d04ae087cf514d7050bce34f887a3285a39a8 --- /dev/null +++ b/sapiens/backbones/standalone/sapiens.py @@ -0,0 +1,648 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Linear, Sequential + + +# ---------------------------------------------------------------------------- +def to_2tuple(x): + if isinstance(x, (str, bytes)): + return (x, x) + if isinstance(x, Sequence): + x = tuple(x) + if len(x) == 2: + return x + raise ValueError("Expected scalar or length-2 iterable") + return (x, x) + + +def resize_pos_embed( + pos_embed, src_shape, dst_shape, mode="bicubic", num_extra_tokens=1 +): + if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]: + return pos_embed + assert pos_embed.ndim == 3, "shape of pos_embed must be [1, L, C]" + _, L, C = pos_embed.shape + src_h, src_w = src_shape + assert L == src_h * src_w + num_extra_tokens, ( + f"The length of `pos_embed` ({L}) doesn't match the expected " + f"shape ({src_h}*{src_w}+{num_extra_tokens}). Please check the" + "`img_size` argument." 
+ ) + extra_tokens = pos_embed[:, :num_extra_tokens] + + src_weight = pos_embed[:, num_extra_tokens:] + src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2) + + # The cubic interpolate algorithm only accepts float32 + dst_weight = F.interpolate( + src_weight.float(), size=dst_shape, align_corners=False, mode=mode + ) + dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2) + dst_weight = dst_weight.to(src_weight.dtype) + + return torch.cat((extra_tokens, dst_weight), dim=1) + + +# ---------------------------------------------------------------------------- +class AdaptivePadding(nn.Module): + def __init__(self, kernel_size=1, stride=1, dilation=1, padding="corner"): + super().__init__() + assert padding in ("same", "corner") + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + self.padding = padding + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + + def get_pad_shape(self, input_shape): + input_h, input_w = input_shape + kernel_h, kernel_w = self.kernel_size + stride_h, stride_w = self.stride + output_h = math.ceil(input_h / stride_h) + output_w = math.ceil(input_w / stride_w) + pad_h = max( + (output_h - 1) * stride_h + (kernel_h - 1) * self.dilation[0] + 1 - input_h, + 0, + ) + pad_w = max( + (output_w - 1) * stride_w + (kernel_w - 1) * self.dilation[1] + 1 - input_w, + 0, + ) + return pad_h, pad_w + + def forward(self, x): + pad_h, pad_w = self.get_pad_shape(x.size()[-2:]) + if pad_h > 0 or pad_w > 0: + if self.padding == "corner": + x = F.pad(x, [0, pad_w, 0, pad_h]) + elif self.padding == "same": + x = F.pad( + x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] + ) + return x + + +# ---------------------------------------------------------------------------- +class PatchEmbed(nn.Module): + def __init__( + self, + in_channels=3, + embed_dims=768, + kernel_size=16, + stride=16, + padding="corner", + dilation=1, + bias=True, + input_size=None, + ): + super().__init__() + + self.embed_dims = embed_dims + if stride is None: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + if isinstance(padding, str): + self.adaptive_padding = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + ) + padding = 0 + else: + self.adaptive_padding = None + padding = to_2tuple(padding) + + self.projection = nn.Conv2d( + in_channels=in_channels, + out_channels=embed_dims, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + ) + + if input_size: + input_size = to_2tuple(input_size) + self.init_input_size = input_size + if self.adaptive_padding: + pad_h, pad_w = self.adaptive_padding.get_pad_shape(input_size) + input_h, input_w = input_size + input_h = input_h + pad_h + input_w = input_w + pad_w + input_size = (input_h, input_w) + + h_out = ( + input_size[0] + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1 + ) // stride[0] + 1 + w_out = ( + input_size[1] + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1 + ) // stride[1] + 1 + self.init_out_size = (h_out, w_out) + else: + self.init_input_size = None + self.init_out_size = None + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]: + if self.adaptive_padding: + x = self.adaptive_padding(x) + + x = self.projection(x) + out_size = (x.shape[2], x.shape[3]) + x = x.flatten(2).transpose(1, 2) + return x, out_size + + +# 
---------------------------------------------------------------------------- +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + inplace: bool = False, + data_format: str = "channels_last", + scale: float = 1e-5, + ): + super().__init__() + assert data_format in ( + "channels_last", + "channels_first", + ), "'data_format' could only be channels_last or channels_first." + self.inplace = inplace + self.data_format = data_format + self.weight = nn.Parameter(torch.ones(dim) * scale) + + def forward(self, x) -> torch.Tensor: + if self.data_format == "channels_first": + shape = tuple((1, -1, *(1 for _ in range(x.dim() - 2)))) + else: + shape = tuple((*(1 for _ in range(x.dim() - 1)), -1)) + if self.inplace: + return x.mul_(self.weight.view(*shape)) + else: + return x * self.weight.view(*shape) + + +# ---------------------------------------------------------------------------- +class FFN(nn.Module): + def __init__( + self, + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + add_identity=True, + layer_scale_init_value=0.0, + ): + super().__init__() + assert num_fcs >= 2, f"num_fcs should be no less than 2. got {num_fcs}." + self.embed_dims = embed_dims + self.feedforward_channels = feedforward_channels + self.num_fcs = num_fcs + + layers = [] + in_channels = embed_dims + for _ in range(num_fcs - 1): + layers.append( + Sequential( + Linear(in_channels, feedforward_channels), + nn.GELU(), + nn.Dropout(ffn_drop), + ) + ) + in_channels = feedforward_channels + layers.append(Linear(feedforward_channels, embed_dims)) + layers.append(nn.Dropout(ffn_drop)) + self.layers = Sequential(*layers) + self.dropout_layer = nn.Identity() + self.add_identity = add_identity + + if layer_scale_init_value > 0: + self.gamma2 = LayerScale(embed_dims, scale=layer_scale_init_value) + else: + self.gamma2 = nn.Identity() + + def forward(self, x, identity=None): + out = self.layers(x) + out = self.gamma2(out) + if not self.add_identity: + return out + if identity is None: + identity = x + return identity + out + + +# ---------------------------------------------------------------------------- +class MultiheadAttention(nn.Module): + def __init__( + self, + embed_dims, + num_heads, + input_dims=None, + attn_drop=0.0, + proj_drop=0.0, + qkv_bias=True, + proj_bias=True, + v_shortcut=False, + ): + super(MultiheadAttention, self).__init__() + + self.input_dims = input_dims or embed_dims + self.embed_dims = embed_dims + self.num_heads = num_heads + self.v_shortcut = v_shortcut + + self.head_dims = embed_dims // num_heads + self.scaled_dot_product_attention = F.scaled_dot_product_attention + + self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = attn_drop + self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + self.gamma1 = nn.Identity() + + def forward(self, x): + B, N, _ = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, self.head_dims) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn_drop = self.attn_drop if self.training else 0.0 + x = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop) + x = x.transpose(1, 2).reshape(B, N, self.embed_dims) + + x = self.proj(x) + x = self.gamma1(self.proj_drop(x)) + + if self.v_shortcut: + x = v.squeeze(1) + x + return x + + +# ---------------------------------------------------------------------------- +class TransformerEncoderLayer(nn.Module): + def __init__( + self, + embed_dims, + num_heads, + 
feedforward_channels, + drop_rate=0.0, + attn_drop_rate=0.0, + num_fcs=2, + qkv_bias=True, + ): + super(TransformerEncoderLayer, self).__init__() + + self.embed_dims = embed_dims + self.ln1 = nn.LayerNorm(self.embed_dims, eps=1e-6, elementwise_affine=True) + self.attn = MultiheadAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + qkv_bias=qkv_bias, + ) + + self.ln2 = nn.LayerNorm(self.embed_dims, eps=1e-6, elementwise_affine=True) + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + add_identity=True, + ) + + @property + def norm1(self): + return self.ln1 + + @property + def norm2(self): + return self.ln2 + + def forward(self, x): + x = x + self.attn(self.ln1(x)) + x = self.ffn(self.ln2(x), identity=x) + return x + + +# ---------------------------------------------------------------------------- +class Sapiens(nn.Module): + arch_zoo = { + **dict.fromkeys( ## this is vit-large + ["0.3b", "sapiens_0.3b"], + { + "embed_dims": 1024, + "num_layers": 24, + "num_heads": 16, + "feedforward_channels": 1024 * 4, + }, + ), + **dict.fromkeys( ## this is vit-huge + ["0.6b", "sapiens_0.6b"], + { + "embed_dims": 1280, + "num_layers": 32, + "num_heads": 16, + "feedforward_channels": 1280 * 4, + }, + ), + **dict.fromkeys( ## this is vit-g + ["1b", "sapiens_1b"], + { + "embed_dims": 1536, + "num_layers": 40, + "num_heads": 24, + "feedforward_channels": 1536 * 4, + }, + ), + **dict.fromkeys( + ["2b", "sapiens_2b"], + { + "embed_dims": 1920, + "num_layers": 48, + "num_heads": 32, + "feedforward_channels": 1920 * 4, + }, + ), + } + num_extra_tokens = 1 # class token + OUT_TYPES = {"raw", "cls_token", "featmap", "avg_featmap"} + + def __init__( + self, + arch="base", + img_size=224, + patch_size=16, + in_channels=3, + out_indices=-1, + drop_rate=0.0, + qkv_bias=True, + final_norm=True, + out_type="cls_token", + with_cls_token=True, + frozen_stages=-1, + interpolate_mode="bicubic", + patch_cfg=dict(), + layer_cfgs=dict(), + ): + super(Sapiens, self).__init__() + + arch = arch.lower() + assert arch in set(self.arch_zoo), ( + f"Arch {arch} is not in default archs {set(self.arch_zoo)}" + ) + self.arch_settings = self.arch_zoo[arch] + + self.embed_dims = self.arch_settings["embed_dims"] + self.num_layers = self.arch_settings["num_layers"] + self.img_size = to_2tuple(img_size) + self.patch_size = patch_size + + # Set patch embedding + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + kernel_size=patch_size, + stride=patch_size, + bias=True, + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + # Set out type + if out_type not in self.OUT_TYPES: + raise ValueError( + f"Unsupported `out_type` {out_type}, please " + f"choose from {self.OUT_TYPES}" + ) + self.out_type = out_type + + # Set cls token + self.with_cls_token = with_cls_token + if with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + elif out_type != "cls_token": + self.cls_token = None + self.num_extra_tokens = 0 + else: + raise ValueError('with_cls_token must be True when `out_type="cls_token"`.') + + # Set position embedding + self.interpolate_mode = interpolate_mode + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + self.num_extra_tokens, self.embed_dims) + ) + 
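+        # e.g. with the default img_size=224, patch_size=16 and a cls token this
+        # allocates 14*14 + 1 = 197 positions (shape (1, 197, embed_dims));
+        # _prepare_pos_embed() / resize_pos_embed() resample a checkpoint's grid
+        # when the shapes differ.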
self.drop_after_pos = nn.Dropout(p=drop_rate) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), ( + f'"out_indices" must by a sequence or int, get {type(out_indices)} instead.' + ) + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_layers + index + assert 0 <= out_indices[i] <= self.num_layers, ( + f"Invalid out_indices {index}" + ) + self.out_indices = out_indices + + self.layers = nn.Sequential() + if isinstance(layer_cfgs, dict): + layer_cfgs = [layer_cfgs] * self.num_layers + for i in range(self.num_layers): + _layer_cfg = dict( + embed_dims=self.embed_dims, + num_heads=self.arch_settings["num_heads"], + feedforward_channels=self.arch_settings["feedforward_channels"], + drop_rate=drop_rate, + qkv_bias=qkv_bias, + ) + _layer_cfg.update(layer_cfgs[i]) + self.layers.append(TransformerEncoderLayer(**_layer_cfg)) + + self.frozen_stages = frozen_stages + self.pre_norm = nn.Identity() + + self.final_norm = final_norm + if final_norm: + self.ln1 = nn.LayerNorm(self.embed_dims, eps=1e-6, elementwise_affine=True) + + # freeze stages only when self.frozen_stages > 0 + if self.frozen_stages > 0: + self._freeze_stages() + + self._register_load_state_dict_pre_hook(self._prepare_pos_embed) + + return + + def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + "pos_embed" + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + + # Handle class token removal if needed + if not self.with_cls_token: + if ckpt_pos_embed_shape[1] == self.pos_embed.shape[1] + 1: + # Remove cls token from state dict if it's not used + state_dict[name] = state_dict[name][:, 1:] + ckpt_pos_embed_shape = state_dict[name].shape + elif ckpt_pos_embed_shape[1] % 2 == 1: + # Remove class token when interpolation is required + state_dict[name] = state_dict[name][:, 1:] + ckpt_pos_embed_shape = state_dict[name].shape + + # Skip if shapes already match + if self.pos_embed.shape == ckpt_pos_embed_shape: + return + + # Calculate grid dimensions + pos_h, pos_w = self.patch_embed.init_out_size + assert pos_h >= pos_w # for vertical aspect ratio or square + + # Number of non-extra tokens in checkpoint + num_vis = ckpt_pos_embed_shape[1] - self.num_extra_tokens + + # Determine original grid shape + side = int(math.sqrt(num_vis)) + factor = int(math.sqrt((num_vis * self.patch_size * self.patch_size) // 12)) + + # Set old grid based on aspect ratio detection + if side * side == num_vis: + old_grid = (side, side) # square grid + elif 4 * factor * 3 * factor == num_vis * self.patch_size * self.patch_size: + old_grid = ( + (factor * 4) // self.patch_size, + (factor * 3) // self.patch_size, + ) # 4:3 ratio + else: + state_dict[name] = self.pos_embed + return + + # Resize position embedding + new_grid = (pos_h, pos_w) + state_dict[name] = resize_pos_embed( + state_dict[name], + old_grid, + new_grid, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens, + ) + + @property + def norm1(self): + return self.ln1 + + @property + def norm2(self): + return self.ln2 + + @staticmethod + def resize_pos_embed(*args, **kwargs): + """Interface for backward-compatibility.""" + return resize_pos_embed(*args, **kwargs) + + def _freeze_stages(self): + # freeze position embedding + if self.pos_embed is not None: + self.pos_embed.requires_grad = False + + # set dropout to eval model + self.drop_after_pos.eval() + # freeze patch embedding + self.patch_embed.eval() + for param in 
self.patch_embed.parameters(): + param.requires_grad = False + # freeze pre-norm + for param in self.pre_norm.parameters(): + param.requires_grad = False + # freeze cls_token + if self.cls_token is not None: + self.cls_token.requires_grad = False + # freeze layers + for i in range(1, self.frozen_stages + 1): + m = self.layers[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + # freeze the last layer norm + if self.frozen_stages == len(self.layers): + if self.final_norm: + self.ln1.eval() + for param in self.ln1.parameters(): + param.requires_grad = False + + if self.out_type == "avg_featmap": + self.ln2.eval() + for param in self.ln2.parameters(): + param.requires_grad = False + + def forward(self, x): + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + + if self.cls_token is not None: + cls_token = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_token, x), dim=1) + + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens, + ) + x = self.drop_after_pos(x) + + x = self.pre_norm(x) ## B x (num tokens) x embed_dim + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.ln1(x) + + if i in self.out_indices: + outs.append(self._format_output(x, patch_resolution)) + + return tuple(outs) + + def _format_output(self, x, hw): + if self.out_type == "raw": + return x + if self.out_type == "cls_token": + return x[:, 0] + + patch_token = x[:, self.num_extra_tokens :] + if self.out_type == "featmap": + B = x.size(0) + # (B, N, C) -> (B, H, W, C) -> (B, C, H, W) + return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2) diff --git a/sapiens/backbones/standalone/sapiens2.py b/sapiens/backbones/standalone/sapiens2.py new file mode 100644 index 0000000000000000000000000000000000000000..9682f0481e05736f2743688a964230c63c3a06ea --- /dev/null +++ b/sapiens/backbones/standalone/sapiens2.py @@ -0,0 +1,908 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
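+
+# numpy backs the log-uniform jitter/rescale sampling in RopePositionEmbedding
+# below (np.log); imported here so the standalone module is self-contained.
+import numpy as np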
+ +import math +from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn.init import trunc_normal_ +from torch.utils.checkpoint import checkpoint + + +# ---------------------------------------------------------------------------- +def to_2tuple(x): + if isinstance(x, (str, bytes)): + return (x, x) + if isinstance(x, Sequence): + x = tuple(x) + if len(x) == 2: + return x + raise ValueError("Expected scalar or length-2 iterable") + return (x, x) + + +class RopePositionEmbedding(nn.Module): + def __init__( + self, + embed_dim: int, + *, + num_heads: int, + base: float | None = 100.0, + min_period: float | None = None, + max_period: float | None = None, + normalize_coords: Literal["min", "max", "separate"] = "separate", + shift_coords: float | None = None, + jitter_coords: float | None = None, + rescale_coords: float | None = None, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + ): + super().__init__() + assert embed_dim % (4 * num_heads) == 0 + both_periods = min_period is not None and max_period is not None + if (base is None and not both_periods) or (base is not None and both_periods): + raise ValueError( + "Either `base` or `min_period`+`max_period` must be provided." + ) + + D_head = embed_dim // num_heads + self.base = base + self.min_period = min_period + self.max_period = max_period + self.D_head = D_head + self.normalize_coords = normalize_coords + self.shift_coords = shift_coords + self.jitter_coords = jitter_coords + self.rescale_coords = rescale_coords + + # Needs persistent=True because we do teacher.load_state_dict(student.state_dict()) to initialize the teacher + self.dtype = dtype or torch.float32 # Don't rely on self.periods.dtype + self.register_buffer( + "periods", + torch.empty(D_head // 4, device=device, dtype=self.dtype), + persistent=True, + ) + self._init_weights() + + def forward(self, *, H: int, W: int) -> tuple[Tensor, Tensor]: + device = self.periods.device + dtype = self.dtype + dd = {"device": device, "dtype": dtype} + # Prepare coords in range [-1, +1] + if self.normalize_coords == "max": + max_HW = max(H, W) + coords_h = torch.arange(0.5, H, **dd) / max_HW # [H] + coords_w = torch.arange(0.5, W, **dd) / max_HW # [W] + elif self.normalize_coords == "min": + min_HW = min(H, W) + coords_h = torch.arange(0.5, H, **dd) / min_HW # [H] + coords_w = torch.arange(0.5, W, **dd) / min_HW # [W] + elif self.normalize_coords == "separate": + coords_h = torch.arange(0.5, H, **dd) / H # [H] + coords_w = torch.arange(0.5, W, **dd) / W # [W] + else: + raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}") + coords = torch.stack( + torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1 + ) # [H, W, 2] + coords = coords.flatten(0, 1) # [HW, 2] + coords = 2.0 * coords - 1.0 # Shift range [0, 1] to [-1, +1] + + # Shift coords by adding a uniform value in [-shift, shift] + if self.training and self.shift_coords is not None: + shift_hw = torch.empty(2, **dd).uniform_( + -self.shift_coords, self.shift_coords + ) + coords += shift_hw[None, :] + + # Jitter coords by multiplying the range [-1, 1] by a log-uniform value in [1/jitter, jitter] + if self.training and self.jitter_coords is not None: + jitter_max = np.log(self.jitter_coords) + jitter_min = -jitter_max + jitter_hw = torch.empty(2, **dd).uniform_(jitter_min, jitter_max).exp() + coords *= jitter_hw[None, :] + + # Rescale coords by multiplying the range 
[-1, 1] by a log-uniform value in [1/rescale, rescale] + if self.training and self.rescale_coords is not None: + rescale_max = np.log(self.rescale_coords) + rescale_min = -rescale_max + rescale_hw = torch.empty(1, **dd).uniform_(rescale_min, rescale_max).exp() + coords *= rescale_hw + + # Prepare angles and sin/cos + angles = ( + 2 * math.pi * coords[:, :, None] / self.periods[None, None, :] + ) # [HW, 2, D//4] + angles = angles.flatten(1, 2) # [HW, D//2] + angles = angles.tile(2) # [HW, D] + cos = torch.cos(angles) # [HW, D] + sin = torch.sin(angles) # [HW, D] + + return (sin, cos) # 2 * [HW, D] + + def _init_weights(self): + device = self.periods.device + dtype = self.dtype + if self.base is not None: + periods = self.base ** ( + 2 + * torch.arange(self.D_head // 4, device=device, dtype=dtype) + / (self.D_head // 2) + ) # [D//4] + else: + base = self.max_period / self.min_period + exponents = torch.linspace( + 0, 1, self.D_head // 4, device=device, dtype=dtype + ) # [D//4] range [0, 1] + periods = base**exponents # range [1, max_period / min_period] + periods = periods / base # range [min_period / max_period, 1] + periods = periods * self.max_period # range [min_period, max_period] + self.periods.data = periods + + +# ------------------------------------------------------------------------------- +class Tokenizer(nn.Module): + """Stacked window selfโ€‘attention that emits one token per window + by reโ€‘using TransformerEncoderLayer blocks.""" + + def __init__( + self, + embed_dims: int, + window_size: int = 4, + num_heads: int = 4, + num_tokenizer_layers: int = 1, + qkv_bias: bool = True, + use_qk_norm: bool = False, + chunk_size: int = 1024, # max windows per chunk + ): + super().__init__() + self.ws = window_size + self.chunk_size = chunk_size + + # local absolute positional embeddings for [CLS] + patch tokens + self.local_pos_embed = nn.Parameter( + torch.zeros(1, 1 + window_size * window_size, embed_dims) + ) + trunc_normal_(self.local_pos_embed, std=0.02) + + # build N identical TransformerEncoderLayer blocks + self.blocks = nn.ModuleList( + [ + TransformerEncoderLayer2( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=embed_dims * 4, # standard FFN size + qkv_bias=qkv_bias, + use_qk_norm=use_qk_norm, + ) + for _ in range(num_tokenizer_layers) + ] + ) + + # shared CLS token for pooling + self.w_cls = nn.Parameter(torch.zeros(1, 1, embed_dims)) + trunc_normal_(self.w_cls, std=0.02) + + def forward( + self, + x: torch.Tensor, + hw: Tuple[int, int], + ) -> Tuple[torch.Tensor, Tuple[int, int]]: + """Args: + x : B, N, C (N = H*W) + hw : (H, W) before reduction + Returns: + x_ : B, (H/ws)*(W/ws), C + hw_: (H/ws, W/ws) + """ + B, N, C = x.shape + H, W = hw + ws = self.ws + assert H % ws == 0 and W % ws == 0, ( + f"Image size {H}ร—{W} must be divisible by window {ws}." 
+ ) + + # reshape tokens โ†’ nonโ€‘overlapping windows + x = x.view(B, H, W, C) + + ph, pw = H // ws, W // ws ## ints in eager mode + ph, pw = int(ph), int(pw) ## ints in scripting mode + x = x.view(B, ph, ws, pw, ws, C) # B, H/ws, ws, W/ws, ws, C + x = x.permute(0, 1, 3, 2, 4, 5) # B, H/ws, W/ws, ws, ws, C + x = x.contiguous().view(B * ph * pw, ws * ws, C) # (B*H/ws*W/ws), wsยฒ, C)) + + total_windows = x.size(0) + chunk_size = int(min(self.chunk_size, total_windows)) + token_out = x.new_empty(total_windows, C) + + use_ckpt = self.training and torch.is_grad_enabled() + + def _run_blocks(t: torch.Tensor) -> torch.Tensor: + for blk in self.blocks: + t = blk(t) + return t + + for i in range(0, total_windows, chunk_size): + chunk = x[i : i + chunk_size] # (m, wsยฒ, C) + m = chunk.size(0) + cls = self.w_cls.expand(m, -1, -1) # (m, 1, C) + chunk = torch.cat([cls, chunk], dim=1) # (m, 1+wsยฒ, C) + chunk = chunk + self.local_pos_embed # add local PE + + if use_ckpt: + chunk = checkpoint(_run_blocks, chunk, use_reentrant=False) + else: + chunk = _run_blocks(chunk) + + token_out[i : i + m] = chunk[:, 0] # take CLS out + + token = token_out.view(B, ph * pw, C) # (B, (H/ws)*(W + return token, (ph, pw) + + +# ------------------------------------------------------------------------------- +class GroupedQueryAttention(nn.Module): + def __init__( + self, + embed_dims, + num_heads, + num_kv_heads=None, + input_dims=None, + attn_drop=0.0, + proj_drop=0.0, + qkv_bias=True, + qk_scale=None, + proj_bias=True, + use_qk_norm=True, + v_shortcut=False, + layer_scale_init_value=0.0, + ): + super().__init__() + # Core dims + self.embed_dims = embed_dims + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads or num_heads + assert self.num_heads % self.num_kv_heads == 0, ( + "num_kv_heads must divide num_heads" + ) + self.head_dim = embed_dims // num_heads + self.input_dims = input_dims or embed_dims + # Features + self.attn_drop = attn_drop + self.v_shortcut = v_shortcut + self.use_qk_norm = use_qk_norm + + # Attention operation selection + if qk_scale is not None: + scale = qk_scale + else: + scale = self.head_dim**-0.5 + + assert qk_scale is None, "qk_scale is not supported" + self.attn_op = F.scaled_dot_product_attention + + # Q/K/V projections + self.wq = nn.Linear(self.input_dims, embed_dims, bias=qkv_bias) + self.wk = nn.Linear( + self.input_dims, self.num_kv_heads * self.head_dim, bias=qkv_bias + ) + self.wv = nn.Linear( + self.input_dims, self.num_kv_heads * self.head_dim, bias=qkv_bias + ) + + if self.use_qk_norm: + self.q_norm = nn.RMSNorm(self.head_dim, eps=1e-6) + self.k_norm = nn.RMSNorm(self.head_dim, eps=1e-6) + + # Output projection + dropout + self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + # Optional LayerScale + if layer_scale_init_value > 0: + self.gamma = LayerScale(embed_dims, scale=layer_scale_init_value) + else: + self.gamma = nn.Identity() + + def apply_rope( + self, q: Tensor, k: Tensor, rope: Tensor | Tuple[Tensor, Tensor] + ) -> Tuple[Tensor, Tensor]: + # All operations will use the dtype of rope, the output is cast back to the dtype of q and k + q_dtype = q.dtype + k_dtype = k.dtype + sin, cos = rope + rope_dtype = sin.dtype + q = q.to(dtype=rope_dtype) + k = k.to(dtype=rope_dtype) + N = q.shape[-2] + prefix = N - sin.shape[-2] ## extra tokens + assert prefix >= 0 + q_prefix = q[:, :, :prefix, :] + q = self._rope_apply(q[:, :, prefix:, :], sin, cos) # [B, head, hw, D//head] + q = torch.cat((q_prefix, q), dim=-2) # [B, 
head, N, D//head] + k_prefix = k[:, :, :prefix, :] + k = self._rope_apply(k[:, :, prefix:, :], sin, cos) # [B, head, hw, D//head] + k = torch.cat((k_prefix, k), dim=-2) # [B, head, N, D//head] + q = q.to(dtype=q_dtype) + k = k.to(dtype=k_dtype) + return q, k + + def _rope_rotate_half(self, x: Tensor) -> Tensor: + # x: [ x0 x1 x2 x3 x4 x5] + # out: [-x3 -x4 -x5 x0 x1 x2] + x1, x2 = x.chunk(2, dim=-1) + return torch.cat([-x2, x1], dim=-1) + + def _rope_apply(self, x: Tensor, sin: Tensor, cos: Tensor) -> Tensor: + # x: [..., D], eg [x0, x1, x2, x3, x4, x5] + # sin: [..., D], eg [sin0, sin1, sin2, sin0, sin1, sin2] + # cos: [..., D], eg [cos0, cos1, cos2, cos0, cos1, cos2] + return (x * cos) + (self._rope_rotate_half(x) * sin) + + def forward(self, x, rope=None): + B, N, _ = x.shape + # Q: (B, N, num_heads, head_dim) + q = self.wq(x).view(B, N, self.num_heads, self.head_dim) + # K/V: (B, N, num_kv_heads, head_dim) + k = self.wk(x).view(B, N, self.num_kv_heads, self.head_dim) + v = self.wv(x).view(B, N, self.num_kv_heads, self.head_dim) + + # (B, heads, N, head_dim) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + if self.use_qk_norm: + q = self.q_norm(q) + k = self.k_norm(k) + + # Repeat KV heads if group ratio >1 + if self.num_kv_heads != self.num_heads: + factor = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(factor, dim=1) + v = v.repeat_interleave(factor, dim=1) + + if rope is not None: + q, k = self.apply_rope(q, k, rope) + + # Scaled dot-product attention + attn_out = self.attn_op( + q, k, v, dropout_p=self.attn_drop if self.training else 0.0 + ) # (B, num_heads, N, head_dim) + + # Merge heads -> (B, N, embed_dims) + out = attn_out.permute(0, 2, 1, 3).reshape(B, N, self.embed_dims) + + # Output projection + drop + layer scale + out = self.proj(out) + out = self.gamma(self.proj_drop(out)) + + # Optional V-shortcut (only when MQA) + if self.v_shortcut and self.num_kv_heads == 1: + raise NotImplementedError + return out + + +# ------------------------------------------------------------------------------- +class TransformerEncoderLayer2(nn.Module): + def __init__( + self, + embed_dims, + num_heads, + num_kv_heads=None, + feedforward_channels=None, + drop_rate=0.0, + attn_drop_rate=0.0, + layer_scale_init_value=0.0, + use_qk_norm=True, + qkv_bias=True, + ): + super(TransformerEncoderLayer2, self).__init__() + + self.embed_dims = embed_dims + self.ln1 = nn.RMSNorm(self.embed_dims, eps=1e-6) + self.attn = GroupedQueryAttention( + embed_dims=embed_dims, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + qkv_bias=qkv_bias, + layer_scale_init_value=layer_scale_init_value, + use_qk_norm=use_qk_norm, + ) + + self.ln2 = nn.RMSNorm(self.embed_dims, eps=1e-6) + self.ffn = SwiGLUFFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + ) + + @property + def norm1(self): + return self.ln1 + + @property + def norm2(self): + return self.ln2 + + def forward(self, x, rope=None): + x = x + self.attn(self.ln1(x), rope=rope) + x = self.ffn(self.ln2(x), identity=x) + return x + + +##----------------------------------- +class Sapiens2(nn.Module): + arch_zoo = { + **dict.fromkeys( + ["sapiens2_0.1b"], + { + "embed_dims": 768, + "num_layers": 12, + "num_heads": 12, + "feedforward_channels": 768 * 4, + "num_tokenizer_layers": 2, + }, + ), + **dict.fromkeys( + ["sapiens2_0.4b"], + { + "embed_dims": 1024, + "num_layers": 24, + "num_heads": 16, + "feedforward_channels": 1024 * 4, + 
"num_tokenizer_layers": 2, + }, + ), + **dict.fromkeys( + ["sapiens2_0.8b"], + { + "embed_dims": 1280, + "num_layers": 32, + "num_heads": 16, + "feedforward_channels": 1280 * 4, + "num_tokenizer_layers": 3, + }, + ), + **dict.fromkeys( + ["sapiens2_1b"], + { + "embed_dims": 1536, + "num_layers": 40, + "num_heads": 24, + "feedforward_channels": 1536 * 4, + "num_tokenizer_layers": 4, + }, + ), + **dict.fromkeys( + ["sapiens2_5b"], + { + "embed_dims": 2432, + "num_layers": 56, + "num_heads": 32, + "feedforward_channels": 2432 * 4, + "num_tokenizer_layers": 6, + }, + ), + } + + num_extra_tokens = 1 # class token + OUT_TYPES = {"raw", "cls_token", "featmap"} + + def __init__( + self, + arch="sapiens2_1b", + img_size=(1024, 768), + patch_size=16, + in_channels=3, + out_indices=-1, + drop_rate=0.0, + window_size=4, + use_tokenizer=False, ## 4k resolution + use_qk_norm=True, + qkv_bias=True, + final_norm=True, + out_type="raw", + with_cls_token=True, + layer_scale_init_value=1e-4, ## non zero init to activate layerscale + frozen_stages=-1, + patch_cfg=dict(), + layer_cfgs=dict(), + pos_embed_rope_base: float = 100.0, + pos_embed_rope_min_period: float | None = None, + pos_embed_rope_max_period: float | None = None, + pos_embed_rope_normalize_coords: Literal["min", "max", "separate"] = "separate", + pos_embed_rope_shift_coords: float | None = None, + pos_embed_rope_jitter_coords: float | None = None, + pos_embed_rope_rescale_coords: float | None = None, + pos_embed_rope_dtype: str = "bf16", + n_storage_tokens: int = 8, + ): + super().__init__() + + arch = arch.lower() + assert arch in set(self.arch_zoo), ( + f"Arch {arch} is not in default archs {set(self.arch_zoo)}" + ) + self.arch_settings = self.arch_zoo[arch] + + self.embed_dims = self.arch_settings["embed_dims"] + self.num_layers = self.arch_settings["num_layers"] + self.patch_size = patch_size + + self.window_size = window_size + img_size = to_2tuple(img_size) + encoder_img_size = ( + (img_size[0] // window_size, img_size[1] // window_size) + if use_tokenizer + else img_size + ) + self.img_size = to_2tuple(encoder_img_size) + + # Set patch embedding + _patch_cfg = dict( + in_channels=in_channels, + input_size=self.img_size, + embed_dims=self.embed_dims, + kernel_size=patch_size, + stride=patch_size, + bias=True, + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + + self.rope_embed = RopePositionEmbedding( + embed_dim=self.embed_dims, + num_heads=self.arch_settings["num_heads"], + base=pos_embed_rope_base, + min_period=pos_embed_rope_min_period, + max_period=pos_embed_rope_max_period, + normalize_coords=pos_embed_rope_normalize_coords, + shift_coords=pos_embed_rope_shift_coords, + jitter_coords=pos_embed_rope_jitter_coords, + rescale_coords=pos_embed_rope_rescale_coords, + dtype=torch.bfloat16 if pos_embed_rope_dtype == "bf16" else torch.float32, + ) + + # Set out type + if out_type not in self.OUT_TYPES: + raise ValueError( + f"Unsupported `out_type` {out_type}, please " + f"choose from {self.OUT_TYPES}" + ) + self.out_type = out_type + + if use_tokenizer == True: + self.tokenizer = Tokenizer( + embed_dims=self.embed_dims, + window_size=self.window_size, + num_heads=self.arch_settings["num_heads"], + num_tokenizer_layers=self.arch_settings["num_tokenizer_layers"], + qkv_bias=True, + use_qk_norm=False, + ) + else: + self.tokenizer = None + + # Set cls + storage tokens + self.with_cls_token 
= with_cls_token + if with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + elif out_type != "cls_token": + self.cls_token = None + self.num_extra_tokens = 0 + else: + raise ValueError('with_cls_token must be True when `out_type="cls_token"`.') + + ## registers + self.n_storage_tokens = int(n_storage_tokens) + self.storage_tokens = ( + nn.Parameter(torch.zeros(1, self.n_storage_tokens, self.embed_dims)) + if self.n_storage_tokens > 0 + else None + ) + # how many non-patch tokens are at the front + self.num_extra_tokens = ( + 1 if self.cls_token is not None else 0 + ) + self.n_storage_tokens + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), ( + f'"out_indices" must by a sequence or int, get {type(out_indices)} instead.' + ) + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = self.num_layers + index + assert 0 <= out_indices[i] <= self.num_layers, ( + f"Invalid out_indices {index}" + ) + self.out_indices = out_indices + + self.blocks = nn.Sequential() + if isinstance(layer_cfgs, dict): + layer_cfgs = [layer_cfgs] * self.num_layers + + mhsa_early, mhsa_late = 8, 8 + for i in range(self.num_layers): + if i < mhsa_early or i >= self.num_layers - mhsa_late: + num_kv_heads = None ## use MHSA + else: + num_kv_heads = self.arch_settings["num_heads"] // 2 # Use GQA + + _layer_cfg = dict( + embed_dims=self.embed_dims, + num_heads=self.arch_settings["num_heads"], + num_kv_heads=num_kv_heads, + feedforward_channels=self.arch_settings["feedforward_channels"], + use_qk_norm=use_qk_norm, + layer_scale_init_value=layer_scale_init_value, + drop_rate=drop_rate, + qkv_bias=qkv_bias, + ) + _layer_cfg.update(layer_cfgs[i]) + self.blocks.append(TransformerEncoderLayer2(**_layer_cfg)) + + self.frozen_stages = frozen_stages + + self.final_norm = final_norm + if final_norm: + self.ln1 = nn.RMSNorm(self.embed_dims, eps=1e-6) + + # freeze stages only when self.frozen_stages > 0 + if self.frozen_stages > 0: + self._freeze_stages() + + ## load init weights + self.init_weights() + + return + + def init_weights(self): + # Initialize class token and storagr token embeddings + if self.with_cls_token: + trunc_normal_(self.cls_token, std=0.02) + + if self.storage_tokens is not None: + trunc_normal_(self.storage_tokens, std=0.02) + + # Apply custom initialization to all submodules + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # Use a truncated normal distribution for linear layer weights + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + elif isinstance(m, (nn.LayerNorm, nn.RMSNorm)): + # Initialize normalization layers to act as an identity function + if hasattr(m, "bias") and m.bias is not None: + nn.init.constant_(m.bias, 0) + if hasattr(m, "weight") and m.weight is not None: + nn.init.constant_(m.weight, 1.0) + + elif isinstance(m, nn.Conv2d): + # Initialize conv layer weights like linear layers + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _freeze_stages(self): + ## freeze tokenizer + if self.frozen_stages >= 1 and self.tokenizer is not None: + self.tokenizer.eval() + for param in self.tokenizer.parameters(): + param.requires_grad = False + + # freeze patch embedding + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + # freeze cls_token + if self.cls_token is not None: + self.cls_token.requires_grad = 
False + if self.storage_tokens is not None: + self.storage_tokens.requires_grad = False + # freeze layers + for i in range(1, self.frozen_stages + 1): + m = self.blocks[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + # freeze the last layer norm + if self.frozen_stages == len(self.blocks): + if self.final_norm: + self.ln1.eval() + for param in self.ln1.parameters(): + param.requires_grad = False + + def forward(self, x): + B = x.shape[0] + + x, patch_resolution = self.patch_embed(x) # (B, 256*256, C) + if self.tokenizer is not None: + x, patch_resolution = self.tokenizer(x, patch_resolution) + + # prepend [CLS] and storage tokens + prepend = [] + if self.cls_token is not None: + prepend.append(self.cls_token.expand(B, -1, -1)) + if self.storage_tokens is not None: + prepend.append(self.storage_tokens.expand(B, -1, -1)) + if len(prepend) > 0: + x = torch.cat(prepend + [x], dim=1) + + rope_sincos = self.rope_embed(H=patch_resolution[0], W=patch_resolution[1]) + outs = [] + for i, layer in enumerate(self.blocks): + x = layer(x, rope=rope_sincos) + + if i == len(self.blocks) - 1 and self.final_norm: + x = self.ln1(x) + + if i in self.out_indices: + outs.append(self._format_output(x, patch_resolution)) + + return tuple(outs) + + def _format_output(self, x, hw): + if self.out_type == "raw": + return x + if self.out_type == "cls_token": + return x[:, 0] + + patch_token = x[:, self.num_extra_tokens :] + if self.out_type == "featmap": + B = x.size(0) + # (B, N, C) -> (B, H, W, C) -> (B, C, H, W) + return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2) + + @property + def norm1(self): + return self.ln1 + + +# ---------------------------------------------------------------------------- +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + inplace: bool = False, + data_format: str = "channels_last", + scale: float = 1e-5, + ): + super().__init__() + assert data_format in ( + "channels_last", + "channels_first", + ), "'data_format' could only be channels_last or channels_first." 
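# LayerScale (descriptive note): a learnable per-channel gain applied to the output
# of a residual branch. Because `weight` is initialized to a small constant
# (layer_scale_init_value=1e-4 for the encoder blocks above), each attention/FFN
# branch contributes almost nothing at the start of training, so every block begins
# close to an identity mapping and the gains grow as training progresses.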
+ self.inplace = inplace + self.data_format = data_format + self.weight = nn.Parameter(torch.ones(dim) * scale) + + def forward(self, x) -> torch.Tensor: + if self.data_format == "channels_first": + shape = tuple((1, -1, *(1 for _ in range(x.dim() - 2)))) + else: + shape = tuple((*(1 for _ in range(x.dim() - 1)), -1)) + if self.inplace: + return x.mul_(self.weight.view(*shape)) + else: + return x * self.weight.view(*shape) + + +# ---------------------------------------------------------------------------- +class PatchEmbed(nn.Module): + def __init__( + self, + in_channels=3, + embed_dims=768, + kernel_size=16, + stride=16, + padding="corner", + dilation=1, + bias=True, + input_size=None, + ): + super().__init__() + + self.embed_dims = embed_dims + if stride is None: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + padding = 0 + padding = to_2tuple(padding) + + self.projection = nn.Conv2d( + in_channels=in_channels, + out_channels=embed_dims, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + ) + + if input_size: + input_size = to_2tuple(input_size) + self.init_input_size = input_size + h_out = ( + input_size[0] + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1 + ) // stride[0] + 1 + w_out = ( + input_size[1] + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1 + ) // stride[1] + 1 + self.init_out_size = (h_out, w_out) + else: + self.init_input_size = None + self.init_out_size = None + + def forward(self, x): + x = self.projection(x) + out_size = (x.shape[2], x.shape[3]) + x = x.flatten(2).transpose(1, 2) + return x, out_size + + +# ---------------------------------------------------------------------------- +class SwiGLUFFN(nn.Module): + """SwiGLU FFN layer. + https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py + """ # noqa + + def __init__( + self, + embed_dims: int, + feedforward_channels: Optional[int] = None, + out_dims: Optional[int] = None, + layer_scale_init_value: float = 0.0, + bias: bool = True, + add_identity: bool = True, + ) -> None: + super().__init__() + self.embed_dims = embed_dims + self.out_dims = out_dims or embed_dims + hidden_dims = feedforward_channels or embed_dims + + self.w12 = nn.Linear(self.embed_dims, 2 * hidden_dims, bias=bias) + self.w3 = nn.Linear(hidden_dims, self.out_dims, bias=bias) + + if layer_scale_init_value > 0: + self.gamma2 = LayerScale(dim=embed_dims, scale=layer_scale_init_value) + else: + self.gamma2 = nn.Identity() + + self.add_identity = add_identity + + def forward( + self, x: torch.Tensor, identity: Optional[torch.Tensor] = None + ) -> torch.Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + out = self.w3(hidden) + out = self.gamma2(out) + + if self.out_dims != self.embed_dims or not self.add_identity: + # due to the dimension inconsistence or user setting + # not to apply residual operation + return out + + if identity is None: + identity = x + return identity + out diff --git a/sapiens/dense/__init__.py b/sapiens/dense/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0f035f3cbea404e8b6a8b9281c1276fe2c5f1eaf --- /dev/null +++ b/sapiens/dense/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
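# Note on the path manipulation below: appending the bundled "src" directory to this
# package's __path__ lets sapiens.dense behave like a namespace package, so a
# submodule import such as sapiens.dense.<name> can also be resolved from
# sapiens/dense/src/<name> without moving files. The final import of ".src" then
# executes src/__init__.py, so whatever that module does at import time happens as
# soon as sapiens.dense is imported.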
+ +import pathlib +import pkgutil + +from .. import __version__ + +_src = pathlib.Path(__file__).with_name("src") +__path__ = pkgutil.extend_path(__path__, __name__) # allow namespace merge +__path__.append(str(_src)) +del pathlib, pkgutil, _src + + +# ----------------------------------------------------- +from importlib import import_module as _imp + +_pkg = _imp(__name__ + ".src") # runs src/__init__.py diff --git a/sapiens/dense/configs/albedo/render_people/sapiens2_0.4b_albedo_render_people-1024x768.py b/sapiens/dense/configs/albedo/render_people/sapiens2_0.4b_albedo_render_people-1024x768.py new file mode 100644 index 0000000000000000000000000000000000000000..ebbd38dab49717a1801256625110ba40892b4da1 --- /dev/null +++ b/sapiens/dense/configs/albedo/render_people/sapiens2_0.4b_albedo_render_people-1024x768.py @@ -0,0 +1,274 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + +_CHECKPOINT_ROOT = os.path.expanduser( + os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host") +) +_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data")) + +warmup_iters = 500 +num_iters = 2e4 ## 32 nodes, 8 gpus: 256 gpus. bs: 3, global bs: 768. num samples: 1e6. 1e6/768 = 1302. 1 epoch = 1e3 iters. + +# ------------------------------------------------------------------------------ +vis_every_iters = 100 +log_every_iters = 10 +save_every_iters = 1000 +val_every_iters = 1000 + +# # debug +# vis_every_iters = 1 +# log_every_iters = 1 +# val_every_iters = 10 +# save_every_iters = 1000 + +load_from = None +resume = False + +# ------------------------------------------------------------------ +model_name = "sapiens2_0.4b" +embed_dim = 1024 +num_layers = 24 +num_heads = 16 + +layer_decay_rate = 0.8 +pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_0.4b_pretrain.safetensors" + +##----------------------------------------------------------------- +image_size = (1024, 768) ## height x width +patch_size = 16 + +# ------------------------------------------------------------------ +use_fsdp = True +# use_fsdp = False + +use_compile = True +# use_compile = False + +## DDP config +if use_fsdp is False: + accelerator_cfg = dict( + type="DDP", + log_with="tensorboard", + # find_unused_parameters=True, + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, ## schedule independent of n_gpus + ) + +else: + accelerator_cfg = dict( + type="FSDP", + log_with="tensorboard", + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. 
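# Note on the checkpoint format in fsdp_cfg below (general FSDP behavior, not specific
# to this repo): SHARDED_STATE_DICT saves one shard per rank, which keeps saving cheap
# for large models, whereas FULL_STATE_DICT gathers all parameters into a single
# consolidated state dict (resuming from it is marked TODO here).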
+ step_scheduler_with_optimizer=False, + fsdp_cfg=dict( + fsdp_version=2, # DTensor-based engine + state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT + # state_dict_type="FULL_STATE_DICT", # TODO: resume from this is not working + mixed_precision=dict( + param_dtype="bf16", + reduce_dtype="bf16", + ), + cpu_ram_efficient_loading=False, + ), + ) + +if use_compile: + accelerator_cfg["compile_cfg"] = dict( + backend="inductor", + mode="default", # Options: "default", "reduce-overhead", "max-autotune" + fullgraph=False, + dynamic=False, + ) + +# ------------------------------------------------------------------ +randomness = dict(seed=0, deterministic=False, diff_rank_seed=True) +logger = dict( + type="Logger", + log_interval=log_every_iters, +) +checkpoint = dict( + type="Checkpointer", + save_interval=save_every_iters, +) + +visualizer = dict( + type="AlbedoVisualizer", + vis_interval=vis_every_iters, + vis_max_samples=4, + vis_image_width=384, + vis_image_height=512, +) + + +##----------------------------------------------------------------- +train_pipeline = [ + dict( + type="AlbedoRandomScale", + scale_min=0.5, + scale_max=2.0, + prob=0.3, + ), + dict( + type="AlbedoRandomCropContinuous", + ar_range=(0.5, 2.0), + area_range=(0.4, 1.0), + num_attempts=8, + prob=0.3, + ), + dict( + type="AlbedoRandomFlip", + prob=0.3, + ), + dict(type="AlbedoResize", height=1024, width=768), + dict(type="RandomGaussianNoise", prob=0.2, var_range=(5.0, 20.0)), + dict( + type="AlbedoPackInputs", + meta_keys=( + "img_path", + "ori_shape", + ), + ), +] + +val_pipeline = [ + dict(type="AlbedoResize", height=1024, width=768, test_mode=True), + dict( + type="AlbedoPackInputs", + test_mode=True, + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "img_shape", + "pad_shape", + ), + ), +] + +test_pipeline = [ + dict(type="AlbedoResizePadImage", height=1024, width=768, pad_val=0), + dict( + type="AlbedoPackInputs", + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "padding_size", + ), + ), +] + + +render_people_dataset = dict( + type="AlbedoRenderPeopleDataset", + data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_albedo", +) + +train_datasets = [render_people_dataset] + +train_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + shuffle=True, + dataset=dict( + type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline + ), +) + +val_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + multiprocessing_context="spawn", + # num_workers=0, # debug + # persistent_workers=False, # debug + shuffle=False, + dataset=dict( + type="AlbedoRenderPeopleDataset", + test_mode=True, + data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_albedo_test", + pipeline=val_pipeline, + ), +) + +val_cfg = dict( + val_interval=val_every_iters, + evaluator=dict( + type="AlbedoEvaluator", + ), +) + +data_preprocessor = dict( + type="ImagePreprocessor", + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models +) + +##----------------------------------------------------------------- +model = dict( + type="AlbedoEstimator", + backbone=dict( + type="Sapiens2", + arch=model_name, + img_size=image_size, + patch_size=patch_size, + final_norm=True, + use_tokenizer=False, + with_cls_token=True, + out_type="featmap", + init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint), + ), + decode_head=dict( + type="AlbedoHead", + 
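# Rough shape walk-through for this head (assuming each upsample stage doubles the
# spatial resolution, consistent with the "1K resolution" comment below): the backbone
# featmap for a 1024x768 input at patch_size=16 is 64x48; four upsample stages
# (768 -> 512 -> 256 -> 128 channels) bring it back to 1024x768, and the trailing
# 3x3 conv stack (64, 32, 16 channels) refines the full-resolution albedo prediction.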
in_channels=embed_dim, + upsample_channels=[768, 512, 256, 128], ## 1K resolution + conv_out_channels=[64, 32, 16], + conv_kernel_sizes=[3, 3, 3], + loss_decode=[ + dict(type="L1Loss", loss_weight=2.0), + dict(type="AlbedoGradL1Loss", loss_weight=2.0), + # dict(type="AlbedoLowFreqL1Loss", down_sample=32, loss_weight=1.0), + dict(type="AlbedoChromaticityL1Loss", loss_weight=1.0), + ], + ), +) + + +##----------------------------------------------------------------- +optimizer = dict( + type="AdamW", + lr=5e-4, + betas=(0.9, 0.999), + weight_decay=0.1, + paramwise_cfg=dict( + num_layers=num_layers, + layer_decay_rate=layer_decay_rate, + ), + fused=True, +) + +scheduler = dict( + type="SequentialLR", + milestones=[warmup_iters], + schedulers=[ + dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters), + dict( + type="PolynomialLR", + total_iters=num_iters - warmup_iters, + power=1.0, + ), + ], +) + +clip_grad = dict(mode="norm", max_norm=2.0, norm_type=2.0) diff --git a/sapiens/dense/configs/albedo/render_people/sapiens2_0.8b_albedo_render_people-1024x768.py b/sapiens/dense/configs/albedo/render_people/sapiens2_0.8b_albedo_render_people-1024x768.py new file mode 100644 index 0000000000000000000000000000000000000000..d1a34e1359b818571684501657b5283cc6dac801 --- /dev/null +++ b/sapiens/dense/configs/albedo/render_people/sapiens2_0.8b_albedo_render_people-1024x768.py @@ -0,0 +1,275 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + +_CHECKPOINT_ROOT = os.path.expanduser( + os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host") +) +_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data")) + +warmup_iters = 500 +num_iters = 2e4 ## 32 nodes, 8 gpus: 256 gpus. bs: 3, global bs: 768. num samples: 1e6. 1e6/768 = 1302. 1 epoch = 1e3 iters. + +# ------------------------------------------------------------------------------ +vis_every_iters = 100 +log_every_iters = 10 +save_every_iters = 1000 +# val_every_iters = 2000 +val_every_iters = 10000 + +# # debug +# vis_every_iters = 1 +# log_every_iters = 1 +# val_every_iters = 10 +# save_every_iters = 1000 + +load_from = None +resume = False + +# ------------------------------------------------------------------ +model_name = "sapiens2_0.8b" +embed_dim = 1280 +num_layers = 32 +num_heads = 16 + +layer_decay_rate = 0.85 +pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_0.8b_pretrain.safetensors" + +##----------------------------------------------------------------- +image_size = (1024, 768) ## height x width +patch_size = 16 + +# ------------------------------------------------------------------ +use_fsdp = True +# use_fsdp = False + +use_compile = True +# use_compile = False + +## DDP config +if use_fsdp is False: + accelerator_cfg = dict( + type="DDP", + log_with="tensorboard", + # find_unused_parameters=True, + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, ## schedule independent of n_gpus + ) + +else: + accelerator_cfg = dict( + type="FSDP", + log_with="tensorboard", + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. 
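# Aside on layer_decay_rate=0.85 (set near the top of this config and consumed by the
# optimizer's paramwise_cfg at the bottom): a common layer-wise LR decay scheme scales
# the learning rate of transformer block i as lr * decay ** (num_layers - i), so with
# 32 blocks the earliest block trains at roughly 0.85 ** 32 ~ 5.5e-3 of the base LR
# while the decode head keeps the full LR. The exact grouping used by paramwise_cfg
# may differ; this is only the usual pattern.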
+ max_interval=num_iters, + mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, + fsdp_cfg=dict( + fsdp_version=2, # DTensor-based engine + state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT + # state_dict_type="FULL_STATE_DICT", # TODO: resume from this is not working + mixed_precision=dict( + param_dtype="bf16", + reduce_dtype="bf16", + ), + cpu_ram_efficient_loading=False, + ), + ) + +if use_compile: + accelerator_cfg["compile_cfg"] = dict( + backend="inductor", + mode="default", # Options: "default", "reduce-overhead", "max-autotune" + fullgraph=False, + dynamic=False, + ) + +# ------------------------------------------------------------------ +randomness = dict(seed=0, deterministic=False, diff_rank_seed=True) +logger = dict( + type="Logger", + log_interval=log_every_iters, +) +checkpoint = dict( + type="Checkpointer", + save_interval=save_every_iters, +) + +visualizer = dict( + type="AlbedoVisualizer", + vis_interval=vis_every_iters, + vis_max_samples=4, + vis_image_width=384, + vis_image_height=512, +) + + +##----------------------------------------------------------------- +train_pipeline = [ + dict( + type="AlbedoRandomScale", + scale_min=0.5, + scale_max=2.0, + prob=0.3, + ), + dict( + type="AlbedoRandomCropContinuous", + ar_range=(0.5, 2.0), + area_range=(0.4, 1.0), + num_attempts=8, + prob=0.3, + ), + dict( + type="AlbedoRandomFlip", + prob=0.3, + ), + dict(type="AlbedoResize", height=1024, width=768), + dict(type="RandomGaussianNoise", prob=0.2, var_range=(5.0, 20.0)), + dict( + type="AlbedoPackInputs", + meta_keys=( + "img_path", + "ori_shape", + ), + ), +] + +val_pipeline = [ + dict(type="AlbedoResize", height=1024, width=768, test_mode=True), + dict( + type="AlbedoPackInputs", + test_mode=True, + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "img_shape", + "pad_shape", + ), + ), +] + +test_pipeline = [ + dict(type="AlbedoResizePadImage", height=1024, width=768, pad_val=0), + dict( + type="AlbedoPackInputs", + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "padding_size", + ), + ), +] + + +render_people_dataset = dict( + type="AlbedoRenderPeopleDataset", + data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_albedo", +) + +train_datasets = [render_people_dataset] + +train_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + shuffle=True, + dataset=dict( + type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline + ), +) + +val_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + multiprocessing_context="spawn", + # num_workers=0, # debug + # persistent_workers=False, # debug + shuffle=False, + dataset=dict( + type="AlbedoRenderPeopleDataset", + test_mode=True, + data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_albedo_test", + pipeline=val_pipeline, + ), +) + +val_cfg = dict( + val_interval=val_every_iters, + evaluator=dict( + type="AlbedoEvaluator", + ), +) + +data_preprocessor = dict( + type="ImagePreprocessor", + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models +) + +##----------------------------------------------------------------- +model = dict( + type="AlbedoEstimator", + backbone=dict( + type="Sapiens2", + arch=model_name, + img_size=image_size, + patch_size=patch_size, + final_norm=True, + use_tokenizer=False, + with_cls_token=True, + out_type="featmap", + 
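# out_type="featmap" (see Sapiens2._format_output above): the cls and storage tokens
# are stripped and the patch tokens are reshaped from (B, N, C) to (B, C, H/16, W/16),
# i.e. (B, 1280, 64, 48) for this 0.8b model on a 1024x768 input, which is exactly
# what the decode head below consumes via in_channels=embed_dim.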
init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint), + ), + decode_head=dict( + type="AlbedoHead", + in_channels=embed_dim, + upsample_channels=[768, 512, 256, 128], ## 1K resolution + conv_out_channels=[64, 32, 16], + conv_kernel_sizes=[3, 3, 3], + loss_decode=[ + dict(type="L1Loss", loss_weight=2.0), + dict(type="AlbedoGradL1Loss", loss_weight=2.0), + # dict(type="AlbedoLowFreqL1Loss", down_sample=32, loss_weight=1.0), + dict(type="AlbedoChromaticityL1Loss", loss_weight=1.0), + ], + ), +) + + +##----------------------------------------------------------------- +optimizer = dict( + type="AdamW", + lr=5e-4, + betas=(0.9, 0.999), + weight_decay=0.1, + paramwise_cfg=dict( + num_layers=num_layers, + layer_decay_rate=layer_decay_rate, + ), + fused=True, +) + +scheduler = dict( + type="SequentialLR", + milestones=[warmup_iters], + schedulers=[ + dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters), + dict( + type="PolynomialLR", + total_iters=num_iters - warmup_iters, + power=1.0, + ), + ], +) + +clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0) diff --git a/sapiens/dense/configs/albedo/render_people/sapiens2_1b_albedo_render_people-1024x768.py b/sapiens/dense/configs/albedo/render_people/sapiens2_1b_albedo_render_people-1024x768.py new file mode 100644 index 0000000000000000000000000000000000000000..4199adf42b6a30581d22976ee23536f4dfb12e4a --- /dev/null +++ b/sapiens/dense/configs/albedo/render_people/sapiens2_1b_albedo_render_people-1024x768.py @@ -0,0 +1,274 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + +_CHECKPOINT_ROOT = os.path.expanduser( + os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host") +) +_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data")) + +warmup_iters = 500 +num_iters = 4e4 ## 32 nodes, 8 gpus: 256 gpus. bs: 3, global bs: 768. num samples: 1e6. 1e6/768 = 1302. 1 epoch = 1e3 iters. + +# ------------------------------------------------------------------------------ +vis_every_iters = 100 +log_every_iters = 10 +save_every_iters = 1000 +val_every_iters = 1000 + +# # # debug +# vis_every_iters = 1 +# log_every_iters = 1 +# val_every_iters = 10 +# save_every_iters = 1000 + +load_from = None +resume = False + +# ------------------------------------------------------------------ +model_name = "sapiens2_1b" +embed_dim = 1536 +num_layers = 40 +num_heads = 24 +layer_decay_rate = 0.9 + +pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_1b_pretrain.safetensors" + +##----------------------------------------------------------------- +image_size = (1024, 768) ## height x width +patch_size = 16 + +# ------------------------------------------------------------------ +use_fsdp = True +# use_fsdp = False + +use_compile = True +# use_compile = False + +## DDP config +if use_fsdp is False: + accelerator_cfg = dict( + type="DDP", + log_with="tensorboard", + # find_unused_parameters=True, + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, ## schedule independent of n_gpus + ) + +else: + accelerator_cfg = dict( + type="FSDP", + log_with="tensorboard", + gradient_accumulation_steps=1, # only accumulation=1 is supported. 
Otherwise, the LR scheduler will be off. + max_interval=num_iters, + mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, + fsdp_cfg=dict( + fsdp_version=2, # DTensor-based engine + state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT + # state_dict_type="FULL_STATE_DICT", # TODO: resume from this is not working + mixed_precision=dict( + param_dtype="bf16", + reduce_dtype="bf16", + ), + cpu_ram_efficient_loading=False, + ), + ) + +if use_compile: + accelerator_cfg["compile_cfg"] = dict( + backend="inductor", + mode="default", # Options: "default", "reduce-overhead", "max-autotune" + fullgraph=False, + dynamic=False, + ) + +# ------------------------------------------------------------------ +randomness = dict(seed=0, deterministic=False, diff_rank_seed=True) +logger = dict( + type="Logger", + log_interval=log_every_iters, +) +checkpoint = dict( + type="Checkpointer", + save_interval=save_every_iters, +) + +visualizer = dict( + type="AlbedoVisualizer", + vis_interval=vis_every_iters, + vis_max_samples=4, + vis_image_width=384, + vis_image_height=512, +) + + +##----------------------------------------------------------------- +train_pipeline = [ + dict( + type="AlbedoRandomScale", + scale_min=0.5, + scale_max=2.0, + prob=0.3, + ), + dict( + type="AlbedoRandomCropContinuous", + ar_range=(0.5, 2.0), + area_range=(0.4, 1.0), + num_attempts=8, + prob=0.3, + ), + dict( + type="AlbedoRandomFlip", + prob=0.3, + ), + dict(type="AlbedoResize", height=1024, width=768), + dict(type="RandomGaussianNoise", prob=0.2, var_range=(5.0, 20.0)), + dict( + type="AlbedoPackInputs", + meta_keys=( + "img_path", + "ori_shape", + ), + ), +] + +val_pipeline = [ + dict(type="AlbedoResize", height=1024, width=768, test_mode=True), + dict( + type="AlbedoPackInputs", + test_mode=True, + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "img_shape", + "pad_shape", + ), + ), +] + +test_pipeline = [ + dict(type="AlbedoResizePadImage", height=1024, width=768, pad_val=0), + dict( + type="AlbedoPackInputs", + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "padding_size", + ), + ), +] + + +render_people_dataset = dict( + type="AlbedoRenderPeopleDataset", + data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_albedo", +) + +train_datasets = [render_people_dataset] + +train_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + shuffle=True, + dataset=dict( + type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline + ), +) + +val_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + multiprocessing_context="spawn", + # num_workers=0, # debug + # persistent_workers=False, # debug + shuffle=False, + dataset=dict( + type="AlbedoRenderPeopleDataset", + test_mode=True, + data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_albedo_test", + pipeline=val_pipeline, + ), +) + +val_cfg = dict( + val_interval=val_every_iters, + evaluator=dict( + type="AlbedoEvaluator", + ), +) + +data_preprocessor = dict( + type="ImagePreprocessor", + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models +) + +##----------------------------------------------------------------- +model = dict( + type="AlbedoEstimator", + backbone=dict( + type="Sapiens2", + arch=model_name, + img_size=image_size, + patch_size=patch_size, + final_norm=True, + use_tokenizer=False, + 
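# use_tokenizer=False: at this 1K training resolution the encoder attends over the full
# 64x48 patch grid directly. The window Tokenizer (see the backbone code above) is only
# intended for much larger inputs (the "4k resolution" case), where it pools each
# 4x4 patch window into a single token before the transformer blocks.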
with_cls_token=True, + out_type="featmap", + init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint), + ), + decode_head=dict( + type="AlbedoHead", + in_channels=embed_dim, + upsample_channels=[768, 512, 256, 128], ## 1K resolution + conv_out_channels=[64, 32, 16], + conv_kernel_sizes=[3, 3, 3], + loss_decode=[ + dict(type="L1Loss", loss_weight=2.0), + dict(type="AlbedoGradL1Loss", loss_weight=2.0), + # dict(type="AlbedoLowFreqL1Loss", down_sample=32, loss_weight=1.0), + dict(type="AlbedoChromaticityL1Loss", loss_weight=1.0), + ], + ), +) + + +##----------------------------------------------------------------- +optimizer = dict( + type="AdamW", + lr=5e-4, + betas=(0.9, 0.999), + weight_decay=0.1, + paramwise_cfg=dict( + num_layers=num_layers, + layer_decay_rate=layer_decay_rate, + ), + fused=True, +) + +scheduler = dict( + type="SequentialLR", + milestones=[warmup_iters], + schedulers=[ + dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters), + dict( + type="PolynomialLR", + total_iters=num_iters - warmup_iters, + power=1.0, + ), + ], +) + +clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0) diff --git a/sapiens/dense/configs/albedo/render_people/sapiens2_5b_albedo_render_people-1024x768.py b/sapiens/dense/configs/albedo/render_people/sapiens2_5b_albedo_render_people-1024x768.py new file mode 100644 index 0000000000000000000000000000000000000000..abd56d8327fa8922edb2c3074d36fe33ed641b1e --- /dev/null +++ b/sapiens/dense/configs/albedo/render_people/sapiens2_5b_albedo_render_people-1024x768.py @@ -0,0 +1,280 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + +_CHECKPOINT_ROOT = os.path.expanduser( + os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host") +) +_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data")) + +warmup_iters = 500 +num_iters = 4e4 ## 32 nodes, 8 gpus: 256 gpus. bs: 3, global bs: 768. num samples: 1e6. 1e6/768 = 1302. 1 epoch = 1e3 iters. + +# ------------------------------------------------------------------------------ +vis_every_iters = 100 +log_every_iters = 10 +save_every_iters = 1000 +# val_every_iters = 1000 +# val_every_iters = 20000 +val_every_iters = 40000 + +# # debug +# vis_every_iters = 1 +# log_every_iters = 1 +# val_every_iters = 10 +# save_every_iters = 1000 + +load_from = None +resume = False + +# ------------------------------------------------------------------ +model_name = "sapiens2_5b" +embed_dim = 2432 +num_layers = 56 +num_heads = 32 +layer_decay_rate = 0.94 + +pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_5b_pretrain.safetensors" + +##----------------------------------------------------------------- +image_size = (1024, 768) ## height x width +patch_size = 16 + +# ------------------------------------------------------------------ +use_fsdp = True +# use_fsdp = False + +use_compile = True +# use_compile = False + +## DDP config +if use_fsdp is False: + accelerator_cfg = dict( + type="DDP", + log_with="tensorboard", + # find_unused_parameters=True, + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. 
+ step_scheduler_with_optimizer=False, ## schedule independent of n_gpus + ) + +else: + accelerator_cfg = dict( + type="FSDP", + log_with="tensorboard", + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, + fsdp_cfg=dict( + fsdp_version=2, # DTensor-based engine + state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT + # state_dict_type="FULL_STATE_DICT", # TODO: resume from this is not working + mixed_precision=dict( + param_dtype="bf16", + reduce_dtype="bf16", + ), + cpu_ram_efficient_loading=False, + ), + # parallelism_cfg=dict( + # dp_shard_size=2, # Fully Sharded Data Parallel degree + # dp_replicate_size=1, # Data Parallel degree + # tp_size=1, # Tensor Parallel degree + # cp_size=4, # Context Parallel degree + # ), + ) + +if use_compile: + accelerator_cfg["compile_cfg"] = dict( + backend="inductor", + mode="default", # Options: "default", "reduce-overhead", "max-autotune" + fullgraph=False, + dynamic=False, + ) + +# ------------------------------------------------------------------ +randomness = dict(seed=0, deterministic=False, diff_rank_seed=True) +logger = dict( + type="Logger", + log_interval=log_every_iters, +) +checkpoint = dict( + type="Checkpointer", + save_interval=save_every_iters, +) + +visualizer = dict( + type="AlbedoVisualizer", + vis_interval=vis_every_iters, + vis_max_samples=4, + vis_image_width=384, + vis_image_height=512, +) + + +##----------------------------------------------------------------- +train_pipeline = [ + dict( + type="AlbedoRandomScale", + scale_min=0.5, + scale_max=2.0, + prob=0.3, + ), + dict( + type="AlbedoRandomCropContinuous", + ar_range=(0.5, 2.0), + area_range=(0.4, 1.0), + num_attempts=8, + prob=0.3, + ), + dict( + type="AlbedoRandomFlip", + prob=0.3, + ), + dict(type="AlbedoResize", height=1024, width=768), + dict(type="RandomGaussianNoise", prob=0.2, var_range=(5.0, 20.0)), + dict( + type="AlbedoPackInputs", + meta_keys=( + "img_path", + "ori_shape", + ), + ), +] + +val_pipeline = [ + dict(type="AlbedoResize", height=1024, width=768, test_mode=True), + dict( + type="AlbedoPackInputs", + test_mode=True, + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "img_shape", + "pad_shape", + ), + ), +] + +test_pipeline = [ + dict(type="AlbedoResizePadImage", height=1024, width=768, pad_val=0), + dict( + type="AlbedoPackInputs", + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "padding_size", + ), + ), +] + + +render_people_dataset = dict( + type="AlbedoRenderPeopleDataset", + data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_albedo", +) + +train_datasets = [render_people_dataset] + +train_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + shuffle=True, + dataset=dict( + type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline + ), +) + +val_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + multiprocessing_context="spawn", + shuffle=False, + dataset=dict( + type="AlbedoRenderPeopleDataset", + test_mode=True, + data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_albedo_test", + pipeline=val_pipeline, + ), +) + +val_cfg = dict( + val_interval=val_every_iters, + evaluator=dict( + type="AlbedoEvaluator", + ), +) + +data_preprocessor = dict( + type="ImagePreprocessor", + mean=[123.675, 116.28, 103.53], 
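# These are the standard ImageNet statistics expressed on a 0-255 scale:
# mean = (0.485, 0.456, 0.406) * 255 = (123.675, 116.28, 103.53) and
# std = (0.229, 0.224, 0.225) * 255 = (58.395, 57.12, 57.375), paired with the
# bgr_to_rgb flip so inputs match the pretrained backbone's normalization.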
+ std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models +) + +##----------------------------------------------------------------- +model = dict( + type="AlbedoEstimator", + backbone=dict( + type="Sapiens2", + arch=model_name, + img_size=image_size, + patch_size=patch_size, + final_norm=True, + use_tokenizer=False, + with_cls_token=True, + out_type="featmap", + init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint), + ), + decode_head=dict( + type="AlbedoHead", + in_channels=embed_dim, + upsample_channels=[1536, 768, 512, 256], ## 1K resolution + conv_out_channels=[64, 32, 16], + conv_kernel_sizes=[3, 3, 3], + loss_decode=[ + dict(type="L1Loss", loss_weight=2.0), + dict(type="AlbedoGradL1Loss", loss_weight=2.0), + # dict(type="AlbedoLowFreqL1Loss", down_sample=32, loss_weight=1.0), + dict(type="AlbedoChromaticityL1Loss", loss_weight=1.0), + ], + ), +) + + +##----------------------------------------------------------------- +optimizer = dict( + type="AdamW", + lr=5e-4, + betas=(0.9, 0.999), + weight_decay=0.1, + paramwise_cfg=dict( + num_layers=num_layers, + layer_decay_rate=layer_decay_rate, + ), + fused=True, +) + +scheduler = dict( + type="SequentialLR", + milestones=[warmup_iters], + schedulers=[ + dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters), + dict( + type="PolynomialLR", + total_iters=num_iters - warmup_iters, + power=1.0, + ), + ], +) + +clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0) diff --git a/sapiens/dense/configs/normal/metasim_render_people/sapiens2_0.4b_normal_metasim_render_people-1024x768.py b/sapiens/dense/configs/normal/metasim_render_people/sapiens2_0.4b_normal_metasim_render_people-1024x768.py new file mode 100644 index 0000000000000000000000000000000000000000..0c1b5ddfb333aea8bcce75923d8359093dc0e35b --- /dev/null +++ b/sapiens/dense/configs/normal/metasim_render_people/sapiens2_0.4b_normal_metasim_render_people-1024x768.py @@ -0,0 +1,304 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + +_CHECKPOINT_ROOT = os.path.expanduser( + os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host") +) +_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data")) + +warmup_iters = 500 +num_iters = 2e4 + +# ------------------------------------------------------------------------------ +vis_every_iters = 100 +log_every_iters = 10 +save_every_iters = 1000 +val_every_iters = 1000 + +# # debug +# vis_every_iters = 1 +# log_every_iters = 1 +# val_every_iters = 2 +# save_every_iters = 1000 + +load_from = None +resume = False + +# ------------------------------------------------------------------ +model_name = "sapiens2_0.4b" +embed_dim = 1024 +num_layers = 24 +num_heads = 16 + +layer_decay_rate = 0.8 +pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_0.4b_pretrain.safetensors" + +##----------------------------------------------------------------- +image_size = (1024, 768) ## height x width +patch_size = 16 + +# ------------------------------------------------------------------ +use_fsdp = True +# use_fsdp = False + +use_compile = True +# use_compile = False + +## DDP config +if use_fsdp is False: + accelerator_cfg = dict( + type="DDP", + log_with="tensorboard", + # find_unused_parameters=True, + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. 
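# Expanding the comment above (an interpretation, not verified against the trainer):
# the warmup/decay schedule at the bottom of this config is defined in iterations, so
# accumulating gradients over several iterations would let scheduler steps run ahead
# of optimizer updates and stretch the effective warmup and decay.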
+ max_interval=num_iters, + # mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, ## schedule independent of n_gpus + ) + +else: + accelerator_cfg = dict( + type="FSDP", + log_with="tensorboard", + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, + fsdp_cfg=dict( + fsdp_version=2, # DTensor-based engine + state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT + # state_dict_type="FULL_STATE_DICT", # TODO: resume from this is not working + mixed_precision=dict( + param_dtype="bf16", + reduce_dtype="bf16", + ), + cpu_ram_efficient_loading=False, + ), + ) + +if use_compile: + accelerator_cfg["compile_cfg"] = dict( + backend="inductor", + mode="default", # Options: "default", "reduce-overhead", "max-autotune" + fullgraph=False, + dynamic=False, + ) + +# ------------------------------------------------------------------ +randomness = dict(seed=0, deterministic=False, diff_rank_seed=True) +logger = dict( + type="Logger", + log_interval=log_every_iters, +) +checkpoint = dict( + type="Checkpointer", + save_interval=save_every_iters, +) + +visualizer = dict( + type="NormalVisualizer", + vis_interval=vis_every_iters, + vis_max_samples=8, + vis_image_width=384, + vis_image_height=512, +) + + +##----------------------------------------------------------------- +train_pipeline = [ + dict(type="PhotoMetricDistortion"), + dict(type="RandomDownUpSampleImage", scale_range=(0.1, 0.7), prob=0.2), + dict( + type="NormalRandomScale", + scale_min=0.5, + scale_max=2.0, + prob=0.3, + ), + dict( + type="NormalRandomCropContinuous", + ar_range=(0.5, 2.0), + area_range=(0.4, 1.0), + num_attempts=8, + prob=0.3, + ), + dict( + type="NormalRandomFlip", + prob=0.3, + ), + dict(type="NormalResize", height=1024, width=768), + dict( + type="RandomGaussianBlur", prob=0.3, kernel_size=(3, 3), sigma_range=(0.1, 2.0) + ), + dict(type="RandomGaussianNoise", prob=0.3, var_range=(5.0, 20.0)), + dict(type="RandomSolarize", prob=0.3, threshold=128), + dict(type="NormalGenerateTarget"), + dict( + type="NormalPackInputs", + meta_keys=( + "img_path", + "ori_shape", + ), + ), +] + +val_pipeline = [ + dict(type="NormalResize", height=1024, width=768, test_mode=True), + dict( + type="NormalPackInputs", + test_mode=True, + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "img_shape", + "pad_shape", + ), + ), +] + +test_pipeline = [ + dict(type="NormalResizePadImage", height=1024, width=768, pad_val=0), + dict( + type="NormalPackInputs", + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "padding_size", + ), + ), +] + +metasim_dataset = dict( + type="NormalMetaSimDataset", + airstore_template="airstore://codec_avatar_sapiens_metasim_v1_no_user_data", + json_path=f"{_DATA_ROOT}/seg/data/metasim/meta_data_v1.json", +) + +render_people_dataset = dict( + type="NormalRenderPeopleBodyDataset", ## body only + data_root=f"{_DATA_ROOT}/synthetic", + seg_data_root=f"{_DATA_ROOT}/RenderPeople/part_seg", +) + +multihuman_render_people_dataset = dict( + type="NormalRenderPeopleMultihumanDataset", + data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human", + normal_extension=".npz", + seg_data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human/part_seg", ## supervise on face for multihuman +) + +# 
train_datasets = 2 * [metasim_dataset] + [ +# render_people_dataset, +# multihuman_render_people_dataset, +# ] + +# train_datasets = [render_people_dataset] +# train_datasets = [multihuman_render_people_dataset] +train_datasets = [metasim_dataset] + +train_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + shuffle=True, + dataset=dict( + type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline + ), +) + +val_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + multiprocessing_context="spawn", + # num_workers=0, # debug + # persistent_workers=False, # debug + shuffle=False, + dataset=dict( + type="NormalRenderPeopleBodyDataset", ## body only + # num_samples=100, ## debug: only use N samples for validation + test_mode=True, + data_root=f"{_DATA_ROOT}/seg/data/metasim/evaluation", + pipeline=val_pipeline, + ), +) + +val_cfg = dict( + val_interval=val_every_iters, + evaluator=dict( + type="NormalEvaluator", + ), +) + +data_preprocessor = dict( + type="ImagePreprocessor", + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models +) + +##----------------------------------------------------------------- +model = dict( + type="NormalEstimator", + backbone=dict( + type="Sapiens2", + arch=model_name, + img_size=image_size, + patch_size=patch_size, + final_norm=True, + use_tokenizer=False, + with_cls_token=True, + out_type="featmap", + init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint), + ), + decode_head=dict( + type="NormalHead", + in_channels=embed_dim, + upsample_channels=[768, 512, 256, 128], ## 1K resolution + conv_out_channels=[64, 32, 16], + conv_kernel_sizes=[3, 3, 3], + loss_decode=[ + dict( + type="NormalCosineSimilarityLoss", + loss_weight=10.0, + ), + dict(type="L1Loss", loss_weight=1.0), + dict(type="NormalGradL1Loss", loss_weight=10.0), + ], + ), +) + + +##----------------------------------------------------------------- +optimizer = dict( + type="AdamW", + lr=5e-4, + betas=(0.9, 0.999), + weight_decay=0.1, + paramwise_cfg=dict( + num_layers=num_layers, + layer_decay_rate=layer_decay_rate, + ), + fused=True, +) + +scheduler = dict( + type="SequentialLR", + milestones=[warmup_iters], + schedulers=[ + dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters), + dict( + type="PolynomialLR", + total_iters=num_iters - warmup_iters, + power=1.0, + ), + ], +) + +clip_grad = dict(mode="norm", max_norm=2.0, norm_type=2.0) diff --git a/sapiens/dense/configs/normal/metasim_render_people/sapiens2_0.8b_normal_metasim_render_people-1024x768.py b/sapiens/dense/configs/normal/metasim_render_people/sapiens2_0.8b_normal_metasim_render_people-1024x768.py new file mode 100644 index 0000000000000000000000000000000000000000..eaf64f4635c457a8ded298326ed61c9647e93a48 --- /dev/null +++ b/sapiens/dense/configs/normal/metasim_render_people/sapiens2_0.8b_normal_metasim_render_people-1024x768.py @@ -0,0 +1,304 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
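# Illustrative sketch (not part of this config and not the repo's exact implementation):
# the NormalCosineSimilarityLoss / L1Loss / NormalGradL1Loss combination used by the
# normal heads in these configs can be approximated as below, with predictions and
# targets given as (B, 3, H, W) unit-normal maps.

import torch
import torch.nn.functional as F


def cosine_normal_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    pred = F.normalize(pred, dim=1, eps=1e-6)      # force unit length per pixel
    target = F.normalize(target, dim=1, eps=1e-6)
    cos = (pred * target).sum(dim=1)               # (B, H, W), 1 when aligned
    return (1.0 - cos).mean()


def grad_l1_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # penalize differences of horizontal and vertical image gradients
    dpx = pred[..., :, 1:] - pred[..., :, :-1]
    dtx = target[..., :, 1:] - target[..., :, :-1]
    dpy = pred[..., 1:, :] - pred[..., :-1, :]
    dty = target[..., 1:, :] - target[..., :-1, :]
    return (dpx - dtx).abs().mean() + (dpy - dty).abs().mean()


# weighted as in loss_decode above: 10 * cosine + 1 * L1 + 10 * gradient L1
def total_normal_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    return (
        10.0 * cosine_normal_loss(pred, target)
        + 1.0 * F.l1_loss(pred, target)
        + 10.0 * grad_l1_loss(pred, target)
    )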
+ +import os + +_CHECKPOINT_ROOT = os.path.expanduser( + os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host") +) +_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data")) + +warmup_iters = 500 +num_iters = 1e4 + +# ------------------------------------------------------------------------------ +vis_every_iters = 100 +log_every_iters = 10 +save_every_iters = 1000 +val_every_iters = 1000 + +# # debug +# vis_every_iters = 1 +# log_every_iters = 1 +# val_every_iters = 2 +# save_every_iters = 1000 + +load_from = None +resume = False + +# ------------------------------------------------------------------ +model_name = "sapiens2_0.8b" +embed_dim = 1280 +num_layers = 32 +num_heads = 16 + +layer_decay_rate = 0.85 +pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_0.8b_pretrain.safetensors" + +##----------------------------------------------------------------- +image_size = (1024, 768) ## height x width +patch_size = 16 + +# ------------------------------------------------------------------ +use_fsdp = True +# use_fsdp = False + +use_compile = True +# use_compile = False + +## DDP config +if use_fsdp is False: + accelerator_cfg = dict( + type="DDP", + log_with="tensorboard", + # find_unused_parameters=True, + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + # mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, ## schedule independent of n_gpus + ) + +else: + accelerator_cfg = dict( + type="FSDP", + log_with="tensorboard", + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. 
+ step_scheduler_with_optimizer=False, + fsdp_cfg=dict( + fsdp_version=2, # DTensor-based engine + state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT + # state_dict_type="FULL_STATE_DICT", # TODO: resume from this is not working + mixed_precision=dict( + param_dtype="bf16", + reduce_dtype="bf16", + ), + cpu_ram_efficient_loading=False, + ), + ) + +if use_compile: + accelerator_cfg["compile_cfg"] = dict( + backend="inductor", + mode="default", # Options: "default", "reduce-overhead", "max-autotune" + fullgraph=False, + dynamic=False, + ) + +# ------------------------------------------------------------------ +randomness = dict(seed=0, deterministic=False, diff_rank_seed=True) +logger = dict( + type="Logger", + log_interval=log_every_iters, +) +checkpoint = dict( + type="Checkpointer", + save_interval=save_every_iters, +) + +visualizer = dict( + type="NormalVisualizer", + vis_interval=vis_every_iters, + vis_max_samples=8, + vis_image_width=384, + vis_image_height=512, +) + + +##----------------------------------------------------------------- +train_pipeline = [ + dict(type="PhotoMetricDistortion"), + dict(type="RandomDownUpSampleImage", scale_range=(0.1, 0.7), prob=0.2), + dict( + type="NormalRandomScale", + scale_min=0.5, + scale_max=2.0, + prob=0.3, + ), + dict( + type="NormalRandomCropContinuous", + ar_range=(0.5, 2.0), + area_range=(0.4, 1.0), + num_attempts=8, + prob=0.3, + ), + dict( + type="NormalRandomFlip", + prob=0.3, + ), + dict(type="NormalResize", height=1024, width=768), + dict( + type="RandomGaussianBlur", prob=0.3, kernel_size=(3, 3), sigma_range=(0.1, 2.0) + ), + dict(type="RandomGaussianNoise", prob=0.3, var_range=(5.0, 20.0)), + dict(type="RandomSolarize", prob=0.3, threshold=128), + dict(type="NormalGenerateTarget"), + dict( + type="NormalPackInputs", + meta_keys=( + "img_path", + "ori_shape", + ), + ), +] + +val_pipeline = [ + dict(type="NormalResize", height=1024, width=768, test_mode=True), + dict( + type="NormalPackInputs", + test_mode=True, + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "img_shape", + "pad_shape", + ), + ), +] + +test_pipeline = [ + dict(type="NormalResizePadImage", height=1024, width=768, pad_val=0), + dict( + type="NormalPackInputs", + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "padding_size", + ), + ), +] + +metasim_dataset = dict( + type="NormalMetaSimDataset", + airstore_template="airstore://codec_avatar_sapiens_metasim_v1_no_user_data", + json_path=f"{_DATA_ROOT}/seg/data/metasim/meta_data_v1.json", +) + +render_people_dataset = dict( + type="NormalRenderPeopleBodyDataset", ## body only + data_root=f"{_DATA_ROOT}/synthetic", + seg_data_root=f"{_DATA_ROOT}/RenderPeople/part_seg", +) + +multihuman_render_people_dataset = dict( + type="NormalRenderPeopleMultihumanDataset", + data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human", + normal_extension=".npz", + seg_data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human/part_seg", ## supervise on face for multihuman +) + +# train_datasets = 2 * [metasim_dataset] + [ +# render_people_dataset, +# multihuman_render_people_dataset, +# ] + +# train_datasets = [render_people_dataset] +# train_datasets = [multihuman_render_people_dataset] +train_datasets = [metasim_dataset] + +train_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + shuffle=True, + dataset=dict( + type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline + ), +) + +val_dataloader = dict( + 
batch_size=4, + num_workers=4, + persistent_workers=True, + multiprocessing_context="spawn", + # num_workers=0, # debug + # persistent_workers=False, # debug + shuffle=False, + dataset=dict( + type="NormalRenderPeopleBodyDataset", ## body only + # num_samples=100, ## debug: only use N samples for validation + test_mode=True, + data_root=f"{_DATA_ROOT}/seg/data/metasim/evaluation", + pipeline=val_pipeline, + ), +) + +val_cfg = dict( + val_interval=val_every_iters, + evaluator=dict( + type="NormalEvaluator", + ), +) + +data_preprocessor = dict( + type="ImagePreprocessor", + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models +) + +##----------------------------------------------------------------- +model = dict( + type="NormalEstimator", + backbone=dict( + type="Sapiens2", + arch=model_name, + img_size=image_size, + patch_size=patch_size, + final_norm=True, + use_tokenizer=False, + with_cls_token=True, + out_type="featmap", + init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint), + ), + decode_head=dict( + type="NormalHead", + in_channels=embed_dim, + upsample_channels=[768, 512, 256, 128], ## 1K resolution + conv_out_channels=[64, 32, 16], + conv_kernel_sizes=[3, 3, 3], + loss_decode=[ + dict( + type="NormalCosineSimilarityLoss", + loss_weight=10.0, + ), + dict(type="L1Loss", loss_weight=1.0), + dict(type="NormalGradL1Loss", loss_weight=10.0), + ], + ), +) + + +##----------------------------------------------------------------- +optimizer = dict( + type="AdamW", + lr=5e-4, + betas=(0.9, 0.999), + weight_decay=0.1, + paramwise_cfg=dict( + num_layers=num_layers, + layer_decay_rate=layer_decay_rate, + ), + fused=True, +) + +scheduler = dict( + type="SequentialLR", + milestones=[warmup_iters], + schedulers=[ + dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters), + dict( + type="PolynomialLR", + total_iters=num_iters - warmup_iters, + power=1.0, + ), + ], +) + +clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0) diff --git a/sapiens/dense/configs/normal/metasim_render_people/sapiens2_1b_normal_metasim_render_people-1024x768.py b/sapiens/dense/configs/normal/metasim_render_people/sapiens2_1b_normal_metasim_render_people-1024x768.py new file mode 100644 index 0000000000000000000000000000000000000000..bf76d49b8e0a08fda95e7e89ed7268535d7b8c9c --- /dev/null +++ b/sapiens/dense/configs/normal/metasim_render_people/sapiens2_1b_normal_metasim_render_people-1024x768.py @@ -0,0 +1,306 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + +_CHECKPOINT_ROOT = os.path.expanduser( + os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host") +) +_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data")) + +warmup_iters = 500 +num_iters = 4e4 ## 32 nodes, 8 gpus: 256 gpus. bs: 3, global bs: 768. num samples: 1e6. 1e6/768 = 1302. 1 epoch = 1e3 iters. 
+# num_iters = 1e4 ## light finetune + +# ------------------------------------------------------------------------------ +vis_every_iters = 100 +log_every_iters = 10 +save_every_iters = 1000 +val_every_iters = 1000 + +# # debug +# vis_every_iters = 1 +# log_every_iters = 1 +# val_every_iters = 2 +# save_every_iters = 1000 + +load_from = None +resume = False + +# ------------------------------------------------------------------ +model_name = "sapiens2_1b" +embed_dim = 1536 +num_layers = 40 +num_heads = 24 +layer_decay_rate = 0.9 + +pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_1b_pretrain.safetensors" + +##----------------------------------------------------------------- +image_size = (1024, 768) ## height x width +patch_size = 16 + +# ------------------------------------------------------------------ +use_fsdp = True +# use_fsdp = False + +use_compile = True +# use_compile = False + +## DDP config +if use_fsdp is False: + accelerator_cfg = dict( + type="DDP", + log_with="tensorboard", + # find_unused_parameters=True, + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + # mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, ## schedule independent of n_gpus + ) + +else: + accelerator_cfg = dict( + type="FSDP", + log_with="tensorboard", + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, + fsdp_cfg=dict( + fsdp_version=2, # DTensor-based engine + state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT + # state_dict_type="FULL_STATE_DICT", # TODO: resume from this is not working + mixed_precision=dict( + param_dtype="bf16", + reduce_dtype="bf16", + ), + cpu_ram_efficient_loading=False, + ), + ) + +if use_compile: + accelerator_cfg["compile_cfg"] = dict( + backend="inductor", + mode="default", # Options: "default", "reduce-overhead", "max-autotune" + fullgraph=False, + dynamic=False, + ) + +# ------------------------------------------------------------------ +randomness = dict(seed=0, deterministic=False, diff_rank_seed=True) +logger = dict( + type="Logger", + log_interval=log_every_iters, +) +checkpoint = dict( + type="Checkpointer", + save_interval=save_every_iters, +) + +visualizer = dict( + type="NormalVisualizer", + vis_interval=vis_every_iters, + vis_max_samples=4, + vis_image_width=384, + vis_image_height=512, +) + + +##----------------------------------------------------------------- +train_pipeline = [ + dict(type="PhotoMetricDistortion"), + dict(type="RandomDownUpSampleImage", scale_range=(0.1, 0.7), prob=0.2), + dict( + type="NormalRandomScale", + scale_min=0.5, + scale_max=2.0, + prob=0.3, + ), + dict( + type="NormalRandomCropContinuous", + ar_range=(0.5, 2.0), + area_range=(0.4, 1.0), + num_attempts=8, + prob=0.3, + ), + dict( + type="NormalRandomFlip", + prob=0.3, + ), + dict(type="NormalResize", height=1024, width=768), + dict( + type="RandomGaussianBlur", prob=0.3, kernel_size=(3, 3), sigma_range=(0.1, 2.0) + ), + dict(type="RandomGaussianNoise", prob=0.3, var_range=(5.0, 20.0)), + dict(type="RandomSolarize", prob=0.3, threshold=128), + dict(type="NormalGenerateTarget"), + dict( + type="NormalPackInputs", + meta_keys=( + "img_path", + "ori_shape", + ), + ), +] + +val_pipeline = [ + 
dict(type="NormalResize", height=1024, width=768, test_mode=True), + dict( + type="NormalPackInputs", + test_mode=True, + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "img_shape", + "pad_shape", + ), + ), +] + +test_pipeline = [ + dict(type="NormalResizePadImage", height=1024, width=768, pad_val=0), + dict( + type="NormalPackInputs", + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "padding_size", + ), + ), +] + +metasim_dataset = dict( + type="NormalMetaSimDataset", + airstore_template="airstore://codec_avatar_sapiens_metasim_v1_no_user_data", + json_path=f"{_DATA_ROOT}/seg/data/metasim/meta_data_v1.json", +) + +render_people_dataset = dict( + type="NormalRenderPeopleBodyDataset", ## body only + data_root=f"{_DATA_ROOT}/synthetic", + seg_data_root=f"{_DATA_ROOT}/RenderPeople/part_seg", +) + +multihuman_render_people_dataset = dict( + type="NormalRenderPeopleMultihumanDataset", + data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human", + normal_extension=".npz", + seg_data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human/part_seg", ## supervise on face for multihuman +) + +# train_datasets = 2 * [metasim_dataset] + [ +# render_people_dataset, +# multihuman_render_people_dataset, +# ] + +# train_datasets = [render_people_dataset] +# train_datasets = [multihuman_render_people_dataset] +train_datasets = [metasim_dataset] + +train_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + shuffle=True, + dataset=dict( + type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline + ), +) + +val_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + multiprocessing_context="spawn", + # num_workers=0, # debug + # persistent_workers=False, # debug + shuffle=False, + dataset=dict( + type="NormalRenderPeopleBodyDataset", ## body only + # num_samples=100, ## debug: only use N samples for validation + test_mode=True, + data_root=f"{_DATA_ROOT}/seg/data/metasim/evaluation", + pipeline=val_pipeline, + ), +) + +val_cfg = dict( + val_interval=val_every_iters, + evaluator=dict( + type="NormalEvaluator", + ), +) + +data_preprocessor = dict( + type="ImagePreprocessor", + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models +) + +##----------------------------------------------------------------- +model = dict( + type="NormalEstimator", + backbone=dict( + type="Sapiens2", + arch=model_name, + img_size=image_size, + patch_size=patch_size, + final_norm=True, + use_tokenizer=False, + # with_cls_token=False, + with_cls_token=True, + out_type="featmap", + init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint), + ), + decode_head=dict( + type="NormalHead", + in_channels=embed_dim, + upsample_channels=[768, 512, 256, 128], ## 1K resolution + conv_out_channels=[64, 32, 16], + conv_kernel_sizes=[3, 3, 3], + loss_decode=[ + dict( + type="NormalCosineSimilarityLoss", + loss_weight=10.0, + ), + dict(type="L1Loss", loss_weight=1.0), + dict(type="NormalGradL1Loss", loss_weight=10.0), + ], + ), +) + + +##----------------------------------------------------------------- +optimizer = dict( + type="AdamW", + lr=5e-4, + betas=(0.9, 0.999), + weight_decay=0.1, + paramwise_cfg=dict( + num_layers=num_layers, + layer_decay_rate=layer_decay_rate, + ), + fused=True, +) + +scheduler = dict( + type="SequentialLR", + milestones=[warmup_iters], + schedulers=[ + dict(type="LinearLR", start_factor=1e-3, 
total_iters=warmup_iters), + dict( + type="PolynomialLR", + total_iters=num_iters - warmup_iters, + power=1.0, + ), + ], +) + +clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0) diff --git a/sapiens/dense/configs/normal/metasim_render_people/sapiens2_5b_normal_metasim_render_people-1024x768.py b/sapiens/dense/configs/normal/metasim_render_people/sapiens2_5b_normal_metasim_render_people-1024x768.py new file mode 100644 index 0000000000000000000000000000000000000000..b2c0b330f08b9787b69defb916fa16e57cb7e10a --- /dev/null +++ b/sapiens/dense/configs/normal/metasim_render_people/sapiens2_5b_normal_metasim_render_people-1024x768.py @@ -0,0 +1,312 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + +_CHECKPOINT_ROOT = os.path.expanduser( + os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host") +) +_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data")) + +warmup_iters = 500 +num_iters = 4e4 ## 32 nodes, 8 gpus: 256 gpus. bs: 3, global bs: 768. num samples: 1e6. 1e6/768 = 1302. 1 epoch = 1e3 iters. +# num_iters = 1e4 ## light finetune + +# ------------------------------------------------------------------------------ +vis_every_iters = 100 +log_every_iters = 10 +save_every_iters = 1000 +val_every_iters = 1000 + +# # debug +# vis_every_iters = 1 +# log_every_iters = 1 +# val_every_iters = 2 +# save_every_iters = 1000 + +load_from = None +resume = False + +# ------------------------------------------------------------------ +model_name = "sapiens2_5b" +embed_dim = 2432 +num_layers = 56 +num_heads = 32 +layer_decay_rate = 0.94 + +pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_5b_pretrain.safetensors" + +##----------------------------------------------------------------- +image_size = (1024, 768) ## height x width +patch_size = 16 + +# ------------------------------------------------------------------ +use_fsdp = True +# use_fsdp = False + +use_compile = True +# use_compile = False + +## DDP config +if use_fsdp is False: + accelerator_cfg = dict( + type="DDP", + log_with="tensorboard", + # find_unused_parameters=True, + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + # mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, ## schedule independent of n_gpus + ) + +else: + accelerator_cfg = dict( + type="FSDP", + log_with="tensorboard", + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. 
+ step_scheduler_with_optimizer=False, + fsdp_cfg=dict( + fsdp_version=2, # DTensor-based engine + state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT + # state_dict_type="FULL_STATE_DICT", # TODO: resume from this is not working + mixed_precision=dict( + param_dtype="bf16", + reduce_dtype="bf16", + ), + cpu_ram_efficient_loading=False, + ), + # parallelism_cfg=dict( + # dp_shard_size=2, # Fully Sharded Data Parallel degree + # dp_replicate_size=1, # Data Parallel degree + # tp_size=1, # Tensor Parallel degree + # cp_size=4, # Context Parallel degree + # ), + ) + +if use_compile: + accelerator_cfg["compile_cfg"] = dict( + backend="inductor", + mode="default", # Options: "default", "reduce-overhead", "max-autotune" + fullgraph=False, + dynamic=False, + ) + +# ------------------------------------------------------------------ +randomness = dict(seed=0, deterministic=False, diff_rank_seed=True) +logger = dict( + type="Logger", + log_interval=log_every_iters, +) +checkpoint = dict( + type="Checkpointer", + save_interval=save_every_iters, +) + +visualizer = dict( + type="NormalVisualizer", + vis_interval=vis_every_iters, + vis_max_samples=4, + vis_image_width=384, + vis_image_height=512, +) + + +##----------------------------------------------------------------- +train_pipeline = [ + dict(type="PhotoMetricDistortion"), + dict(type="RandomDownUpSampleImage", scale_range=(0.1, 0.7), prob=0.2), + dict( + type="NormalRandomScale", + scale_min=0.5, + scale_max=2.0, + prob=0.3, + ), + dict( + type="NormalRandomCropContinuous", + ar_range=(0.5, 2.0), + area_range=(0.4, 1.0), + num_attempts=8, + prob=0.3, + ), + dict( + type="NormalRandomFlip", + prob=0.3, + ), + dict(type="NormalResize", height=1024, width=768), + dict( + type="RandomGaussianBlur", prob=0.3, kernel_size=(3, 3), sigma_range=(0.1, 2.0) + ), + dict(type="RandomGaussianNoise", prob=0.3, var_range=(5.0, 20.0)), + dict(type="RandomSolarize", prob=0.3, threshold=128), + dict(type="NormalGenerateTarget"), + dict( + type="NormalPackInputs", + meta_keys=( + "img_path", + "ori_shape", + ), + ), +] + +val_pipeline = [ + dict(type="NormalResize", height=1024, width=768, test_mode=True), + dict( + type="NormalPackInputs", + test_mode=True, + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "img_shape", + "pad_shape", + ), + ), +] + +test_pipeline = [ + dict(type="NormalResizePadImage", height=1024, width=768, pad_val=0), + dict( + type="NormalPackInputs", + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "img_shape", + ), + ), +] + +metasim_dataset = dict( + type="NormalMetaSimDataset", + airstore_template="airstore://codec_avatar_sapiens_metasim_v1_no_user_data", + json_path=f"{_DATA_ROOT}/seg/data/metasim/meta_data_v1.json", +) + +render_people_dataset = dict( + type="NormalRenderPeopleBodyDataset", ## body only + data_root=f"{_DATA_ROOT}/synthetic", + seg_data_root=f"{_DATA_ROOT}/RenderPeople/part_seg", +) + +multihuman_render_people_dataset = dict( + type="NormalRenderPeopleMultihumanDataset", + data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human", + normal_extension=".npz", + seg_data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_multi_human/part_seg", ## supervise on face for multihuman +) + +# train_datasets = 2 * [metasim_dataset] + [ +# render_people_dataset, +# multihuman_render_people_dataset, +# ] + +# train_datasets = [render_people_dataset] +# train_datasets = [multihuman_render_people_dataset] +train_datasets = [metasim_dataset] + 
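+## The commented-out variant 2 * [metasim_dataset] + [...] lists metasim_dataset
+## twice; assuming CombinedDataset samples each entry in proportion to its size,
+## that roughly doubles its weight relative to the other datasets.
+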
+train_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + shuffle=True, + dataset=dict( + type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline + ), +) + +val_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + multiprocessing_context="spawn", + # num_workers=0, # debug + # persistent_workers=False, # debug + shuffle=False, + dataset=dict( + type="NormalRenderPeopleBodyDataset", ## body only + # num_samples=100, ## debug: only use N samples for validation + test_mode=True, + data_root=f"{_DATA_ROOT}/seg/data/metasim/evaluation", + pipeline=val_pipeline, + ), +) + +val_cfg = dict( + val_interval=val_every_iters, + evaluator=dict( + type="NormalEvaluator", + ), +) + +data_preprocessor = dict( + type="ImagePreprocessor", + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models +) + +##----------------------------------------------------------------- +model = dict( + type="NormalEstimator", + backbone=dict( + type="Sapiens2", + arch=model_name, + img_size=image_size, + patch_size=patch_size, + final_norm=True, + use_tokenizer=False, + with_cls_token=True, + out_type="featmap", + init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint), + ), + decode_head=dict( + type="NormalHead", + in_channels=embed_dim, + upsample_channels=[1536, 768, 512, 256], ## 1K resolution + conv_out_channels=[128, 64, 32], + conv_kernel_sizes=[3, 3, 3], + loss_decode=[ + dict( + type="NormalCosineSimilarityLoss", + loss_weight=10.0, + ), + dict(type="L1Loss", loss_weight=1.0), + dict(type="NormalGradL1Loss", loss_weight=10.0), + ], + ), +) + + +##----------------------------------------------------------------- +optimizer = dict( + type="AdamW", + # lr=5e-4, + lr=1e-4, + betas=(0.9, 0.999), + weight_decay=0.1, + paramwise_cfg=dict( + num_layers=num_layers, + layer_decay_rate=layer_decay_rate, + ), + fused=True, +) + +scheduler = dict( + type="SequentialLR", + milestones=[warmup_iters], + schedulers=[ + dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters), + dict( + type="PolynomialLR", + total_iters=num_iters - warmup_iters, + power=1.0, + ), + ], +) + +clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0) diff --git a/sapiens/dense/configs/pointmap/render_people/sapiens2_0.4b_pointmap_render_people-1024x768.py b/sapiens/dense/configs/pointmap/render_people/sapiens2_0.4b_pointmap_render_people-1024x768.py new file mode 100644 index 0000000000000000000000000000000000000000..88c1d63843e06af873f95fa8aab5c8135bdc4d9e --- /dev/null +++ b/sapiens/dense/configs/pointmap/render_people/sapiens2_0.4b_pointmap_render_people-1024x768.py @@ -0,0 +1,322 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + +_CHECKPOINT_ROOT = os.path.expanduser( + os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host") +) +_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data")) + +warmup_iters = 500 +num_iters = 2e4 ## 16 nodes, 8 gpus: 256 gpus. bs: 3, global bs: 768. num samples: 1e6. 1e6/768 = 1302. 1 epoch = 1e3 iters. 
+ +# ------------------------------------------------------------------------------ +vis_every_iters = 100 +log_every_iters = 10 +save_every_iters = 1000 +val_every_iters = 1000 + +# # debug +# vis_every_iters = 1 +# log_every_iters = 1 +# val_every_iters = 2 +# save_every_iters = 1000 + +load_from = None +resume = False + +# ------------------------------------------------------------------ +model_name = "sapiens2_0.4b" +embed_dim = 1024 +num_layers = 24 +num_heads = 16 + +layer_decay_rate = 0.8 +pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_0.4b_pretrain.safetensors" + +##----------------------------------------------------------------- +image_size = (1024, 768) ## height x width + +patch_size = 16 +num_tokens = (image_size[0] // patch_size) * (image_size[1] // patch_size) +canonical_focal_length = 768.0 + +# ------------------------------------------------------------------ +use_fsdp = True +# use_fsdp = False + +use_compile = True +# use_compile = False + +## DDP config +if use_fsdp is False: + accelerator_cfg = dict( + type="DDP", + log_with="tensorboard", + # find_unused_parameters=True, + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + # mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, ## schedule independent of n_gpus + ) + +else: + accelerator_cfg = dict( + type="FSDP", + log_with="tensorboard", + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + # mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, + fsdp_cfg=dict( + fsdp_version=2, # DTensor-based engine + state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT + cpu_ram_efficient_loading=False, + ), + ) + + ## Note: to merge sharded weight using FSDP + # accelerate merge-weights pytorch_model_fsdp_0/ . + +if use_compile: + accelerator_cfg["compile_cfg"] = dict( + backend="inductor", + mode="default", # Options: "default", "reduce-overhead", "max-autotune" + fullgraph=False, + dynamic=False, + ) + +# ------------------------------------------------------------------ +randomness = dict(seed=0, deterministic=False, diff_rank_seed=True) +logger = dict( + type="Logger", + log_interval=log_every_iters, +) +checkpoint = dict( + type="Checkpointer", + save_interval=save_every_iters, +) + +visualizer = dict( + type="PointmapVisualizer", + vis_interval=vis_every_iters, + vis_max_samples=4, + vis_image_width=384, + vis_image_height=512, +) + + +##----------------------------------------------------------------- +train_pipeline = [ + dict(type="PhotoMetricDistortion"), + dict( + type="PointmapRandomScale", + scale_min=0.5, + scale_max=2.0, + prob=0.3, + ), + dict( + type="PointmapRandomCropContinuous", + ar_range=(0.5, 2.0), + area_range=(0.4, 1.0), + num_attempts=8, + prob=0.3, + ), + dict( + type="PointmapRandomFlip", + prob=0.3, + ), + dict(type="PointmapResize", height=1024, width=768), + ## target is same res as output, otherwise we get artifacts. 
+ dict( + type="PointmapGenerateTarget", + canonical_focal_length=canonical_focal_length, + target_downsample_factor=1, + ), + dict( + type="PointmapPackInputs", + meta_keys=( + "img_path", + "ori_shape", + "img_shape", + "pad_shape", + "scale", + "flip", + "flip_direction", + "original_K", + "K", + "M", + ), + ), +] + +val_pipeline = [ + dict(type="PointmapResize", height=1024, width=768), + dict(type="PointmapGenerateTarget", canonical_focal_length=canonical_focal_length), + dict( + type="PointmapPackInputs", + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "img_shape", + "pad_shape", + "scale", + "padding_size", + "K", + "M", + ), + ), +] + +test_pipeline = [ + dict(type="PointmapResizePadImage", height=1024, width=768, pad_val=0), + dict( + type="PointmapPackInputs", + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "img_shape", + "pad_shape", + "scale", + "padding_size", + "K", + "M", + ), + ), +] + +render_people_dataset = dict( + type="PointmapRenderPeopleDataset", + data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_v2", +) + +train_datasets = [render_people_dataset] + +train_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + shuffle=True, + dataset=dict( + type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline + ), +) + +val_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + multiprocessing_context="spawn", + # num_workers=0, # debug + # persistent_workers=False, # debug + shuffle=False, + dataset=dict( + type="PointmapRenderPeopleDataset", + # num_samples=100, ## debug: only use N samples for validation + test_mode=True, + data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_v2_test", + pipeline=val_pipeline, + ), +) + +val_cfg = dict( + val_interval=val_every_iters, + evaluator=dict( + type="PointmapEvaluator", + ), +) + +data_preprocessor = dict( + type="ImagePreprocessor", + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models +) + +##----------------------------------------------------------------- +model = dict( + type="PointmapEstimator", + canonical_focal_length=canonical_focal_length, + backbone=dict( + type="Sapiens2", + arch=model_name, + img_size=image_size, + patch_size=patch_size, + final_norm=True, + use_tokenizer=False, + with_cls_token=True, + out_type="featmap", + init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint), + ), + decode_head=dict( + type="PointmapHead", + in_channels=embed_dim, + upsample_channels=[1536, 768, 512, 256], + conv_out_channels=[64, 32, 16], + conv_kernel_sizes=[3, 3, 3], + scale_conv_out_channels=(1536, 512, 128), + scale_conv_kernel_sizes=(1, 1, 1), + scale_final_layer=( + (num_tokens // ((2 * 2 * 2) * (2 * 2 * 2))) * 128, + 512, + 128, + 1, + ), ## scale regress + loss_decode=[ + dict(type="L1Loss", loss_weight=2.0), ## on pointmap, XYZ + dict( + type="MultiscaleL1Loss", + loss_weight=1.0, + scale_factor=2, + ), + dict(type="SiLogLoss", loss_weight=1.0), ## only applies silog loss + dict( + type="PointmapIntrinsicsConsistencyLoss", + loss_weight=1.0, + ), + dict( + type="PointmapShiftInvariantL1Loss", + loss_weight=1.0, + ), + dict(type="PointmapNormalLoss", loss_weight=2.0), + dict( + type="PointmapScaleL1Loss", loss_weight=4.0 + ), ## Canonical XYZ = scale * XYZ + ], + ), +) + + +##----------------------------------------------------------------- +optimizer = dict( + type="AdamW", + lr=5e-4, + betas=(0.9, 0.999), + 
weight_decay=0.1, + paramwise_cfg=dict( + num_layers=num_layers, + layer_decay_rate=layer_decay_rate, + ), + fused=True, +) + +scheduler = dict( + type="SequentialLR", + milestones=[warmup_iters], + schedulers=[ + dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters), + dict( + type="PolynomialLR", + total_iters=num_iters - warmup_iters, + power=1.0, + ), + ], +) + +clip_grad = dict(mode="norm", max_norm=2.0, norm_type=2.0) diff --git a/sapiens/dense/configs/pointmap/render_people/sapiens2_0.8b_pointmap_render_people-1024x768.py b/sapiens/dense/configs/pointmap/render_people/sapiens2_0.8b_pointmap_render_people-1024x768.py new file mode 100644 index 0000000000000000000000000000000000000000..34ee2e67d5db5f59611609ce1de196d892fed69d --- /dev/null +++ b/sapiens/dense/configs/pointmap/render_people/sapiens2_0.8b_pointmap_render_people-1024x768.py @@ -0,0 +1,325 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + +_CHECKPOINT_ROOT = os.path.expanduser( + os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host") +) +_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data")) + +warmup_iters = 500 +num_iters = 2e4 ## 16 nodes, 8 gpus: 256 gpus. bs: 3, global bs: 768. num samples: 1e6. 1e6/768 = 1302. 1 epoch = 1e3 iters. +# num_iters = 1e4 ## light finetune + +# ------------------------------------------------------------------------------ +vis_every_iters = 100 +log_every_iters = 10 +save_every_iters = 1000 +val_every_iters = 1000 + +# # debug +# vis_every_iters = 1 +# log_every_iters = 1 +# val_every_iters = 2 +# save_every_iters = 1000 + +load_from = None +resume = False + +# ------------------------------------------------------------------ +model_name = "sapiens2_0.8b" +embed_dim = 1280 +num_layers = 32 +num_heads = 16 + +layer_decay_rate = 0.85 +pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_0.8b_pretrain.safetensors" + +##----------------------------------------------------------------- +image_size = (1024, 768) ## height x width + +patch_size = 16 +num_tokens = (image_size[0] // patch_size) * (image_size[1] // patch_size) +canonical_focal_length = 768.0 + +# ------------------------------------------------------------------ +use_fsdp = True +# use_fsdp = False + +use_compile = True +# use_compile = False + +## DDP config +if use_fsdp is False: + accelerator_cfg = dict( + type="DDP", + log_with="tensorboard", + # find_unused_parameters=True, + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + # mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, ## schedule independent of n_gpus + ) + +else: + accelerator_cfg = dict( + type="FSDP", + log_with="tensorboard", + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + # mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. 
+ step_scheduler_with_optimizer=False, + fsdp_cfg=dict( + fsdp_version=2, # DTensor-based engine + state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT + # state_dict_type="FULL_STATE_DICT", # TODO: resume from this is not working + # mixed_precision=dict( + # param_dtype="bf16", + # reduce_dtype="bf16", + # ), + cpu_ram_efficient_loading=False, + ), + ) + +if use_compile: + accelerator_cfg["compile_cfg"] = dict( + backend="inductor", + mode="default", # Options: "default", "reduce-overhead", "max-autotune" + fullgraph=False, + dynamic=False, + ) + +# ------------------------------------------------------------------ +randomness = dict(seed=0, deterministic=False, diff_rank_seed=True) +logger = dict( + type="Logger", + log_interval=log_every_iters, +) +checkpoint = dict( + type="Checkpointer", + save_interval=save_every_iters, +) + +visualizer = dict( + type="PointmapVisualizer", + vis_interval=vis_every_iters, + vis_max_samples=4, + vis_image_width=384, + vis_image_height=512, +) + + +##----------------------------------------------------------------- +train_pipeline = [ + dict(type="PhotoMetricDistortion"), + dict( + type="PointmapRandomScale", + scale_min=0.5, + scale_max=2.0, + prob=0.3, + ), + dict( + type="PointmapRandomCropContinuous", + ar_range=(0.5, 2.0), + area_range=(0.4, 1.0), + num_attempts=8, + prob=0.3, + ), + dict( + type="PointmapRandomFlip", + prob=0.3, + ), + dict(type="PointmapResize", height=1024, width=768), + ## target is same res as output, otherwise we get artifacts. + dict( + type="PointmapGenerateTarget", + canonical_focal_length=canonical_focal_length, + target_downsample_factor=1, + ), + dict( + type="PointmapPackInputs", + meta_keys=( + "img_path", + "ori_shape", + "img_shape", + "pad_shape", + "scale", + "flip", + "flip_direction", + "original_K", + "K", + "M", + ), + ), +] + +val_pipeline = [ + dict(type="PointmapResize", height=1024, width=768), + dict(type="PointmapGenerateTarget", canonical_focal_length=canonical_focal_length), + dict( + type="PointmapPackInputs", + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "img_shape", + "pad_shape", + "scale", + "padding_size", + "K", + "M", + ), + ), +] + +test_pipeline = [ + dict(type="PointmapResizePadImage", height=1024, width=768, pad_val=0), + dict( + type="PointmapPackInputs", + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "img_shape", + "pad_shape", + "scale", + "padding_size", + "K", + "M", + ), + ), +] + +render_people_dataset = dict( + type="PointmapRenderPeopleDataset", + data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_v2", +) + +train_datasets = [render_people_dataset] + +train_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + shuffle=True, + dataset=dict( + type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline + ), +) + +val_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + multiprocessing_context="spawn", + # num_workers=0, # debug + # persistent_workers=False, # debug + shuffle=False, + dataset=dict( + type="PointmapRenderPeopleDataset", + # num_samples=100, ## debug: only use N samples for validation + test_mode=True, + data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_v2_test", + pipeline=val_pipeline, + ), +) + +val_cfg = dict( + val_interval=val_every_iters, + evaluator=dict( + type="PointmapEvaluator", + ), +) + +data_preprocessor = dict( + type="ImagePreprocessor", + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 
57.375], + bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models +) + +##----------------------------------------------------------------- +model = dict( + type="PointmapEstimator", + canonical_focal_length=canonical_focal_length, + backbone=dict( + type="Sapiens2", + arch=model_name, + img_size=image_size, + patch_size=patch_size, + final_norm=True, + use_tokenizer=False, + with_cls_token=True, + out_type="featmap", + init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint), + ), + decode_head=dict( + type="PointmapHead", + in_channels=embed_dim, + upsample_channels=[1536, 768, 512, 256], + conv_out_channels=[64, 32, 16], + conv_kernel_sizes=[3, 3, 3], + scale_conv_out_channels=(1536, 512, 128), + scale_conv_kernel_sizes=(1, 1, 1), + scale_final_layer=( + (num_tokens // ((2 * 2 * 2) * (2 * 2 * 2))) * 128, + 512, + 128, + 1, + ), ## scale regress + loss_decode=[ + dict(type="L1Loss", loss_weight=2.0), ## on pointmap, XYZ + dict( + type="MultiscaleL1Loss", + loss_weight=1.0, + scale_factor=2, + ), + dict(type="SiLogLoss", loss_weight=1.0), ## only applies silog loss + dict( + type="PointmapIntrinsicsConsistencyLoss", + loss_weight=1.0, + ), + dict( + type="PointmapShiftInvariantL1Loss", + loss_weight=1.0, + ), + dict(type="PointmapNormalLoss", loss_weight=2.0), + dict( + type="PointmapScaleL1Loss", loss_weight=4.0 + ), ## Canonical XYZ = scale * XYZ + ], + ), +) + + +##----------------------------------------------------------------- +optimizer = dict( + type="AdamW", + lr=5e-4, + betas=(0.9, 0.999), + weight_decay=0.1, + paramwise_cfg=dict( + num_layers=num_layers, + layer_decay_rate=layer_decay_rate, + ), + fused=True, +) + +scheduler = dict( + type="SequentialLR", + milestones=[warmup_iters], + schedulers=[ + dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters), + dict( + type="PolynomialLR", + total_iters=num_iters - warmup_iters, + power=1.0, + ), + ], +) + +clip_grad = dict(mode="norm", max_norm=2.0, norm_type=2.0) diff --git a/sapiens/dense/configs/pointmap/render_people/sapiens2_1b_pointmap_render_people-1024x768.py b/sapiens/dense/configs/pointmap/render_people/sapiens2_1b_pointmap_render_people-1024x768.py new file mode 100644 index 0000000000000000000000000000000000000000..bc33ffebd352ef7111f5979d71c9ae1c2878fb93 --- /dev/null +++ b/sapiens/dense/configs/pointmap/render_people/sapiens2_1b_pointmap_render_people-1024x768.py @@ -0,0 +1,319 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + +_CHECKPOINT_ROOT = os.path.expanduser( + os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host") +) +_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data")) + +warmup_iters = 500 +num_iters = 4e4 ## 32 nodes, 8 gpus: 256 gpus. bs: 3, global bs: 768. num samples: 1e6. 1e6/768 = 1302. 1 epoch = 1e3 iters. 
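+## (1e6 / 768 is ~1.3e3, so one epoch is roughly 1.3e3 iters and 4e4 iters is about 30 epochs.)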
+ +# ------------------------------------------------------------------------------ +vis_every_iters = 100 +log_every_iters = 10 +save_every_iters = 1000 +val_every_iters = 1000 + +# # debug +# vis_every_iters = 1 +# log_every_iters = 1 +# val_every_iters = 2 +# save_every_iters = 1000 + +load_from = None +resume = False + +# ------------------------------------------------------------------ +model_name = "sapiens2_1b" +embed_dim = 1536 +num_layers = 40 +num_heads = 24 +layer_decay_rate = 0.9 + +pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_1b_pretrain.safetensors" + +##----------------------------------------------------------------- +image_size = (1024, 768) ## height x width + +patch_size = 16 +num_tokens = (image_size[0] // patch_size) * (image_size[1] // patch_size) +canonical_focal_length = 768.0 + +# ------------------------------------------------------------------ +use_fsdp = True +# use_fsdp = False + +use_compile = True +# use_compile = False + +## DDP config +if use_fsdp is False: + accelerator_cfg = dict( + type="DDP", + log_with="tensorboard", + # find_unused_parameters=True, + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + # mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, ## schedule independent of n_gpus + ) + +else: + accelerator_cfg = dict( + type="FSDP", + log_with="tensorboard", + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + # mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, + fsdp_cfg=dict( + fsdp_version=2, # DTensor-based engine + state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT + # state_dict_type="FULL_STATE_DICT", # TODO: resume from this is not working + # mixed_precision=dict( + # param_dtype="bf16", + # reduce_dtype="bf16", + # ), + cpu_ram_efficient_loading=False, + ), + ) + +if use_compile: + accelerator_cfg["compile_cfg"] = dict( + backend="inductor", + mode="default", # Options: "default", "reduce-overhead", "max-autotune" + fullgraph=False, + dynamic=False, + ) + +# ------------------------------------------------------------------ +randomness = dict(seed=0, deterministic=False, diff_rank_seed=True) +logger = dict( + type="Logger", + log_interval=log_every_iters, +) +checkpoint = dict( + type="Checkpointer", + save_interval=save_every_iters, +) + +visualizer = dict( + type="PointmapVisualizer", + vis_interval=vis_every_iters, + vis_max_samples=4, + vis_image_width=384, + vis_image_height=512, +) + + +##----------------------------------------------------------------- +train_pipeline = [ + dict(type="PhotoMetricDistortion"), + dict( + type="PointmapRandomScale", + scale_min=0.5, + scale_max=2.0, + prob=0.3, + ), + dict( + type="PointmapRandomCropContinuous", + ar_range=(0.5, 2.0), + area_range=(0.4, 1.0), + num_attempts=8, + prob=0.3, + ), + dict( + type="PointmapRandomFlip", + prob=0.3, + ), + dict(type="PointmapResize", height=1024, width=768), + ## target is same res as output, otherwise we get artifacts. 
+ dict( + type="PointmapGenerateTarget", + canonical_focal_length=canonical_focal_length, + target_downsample_factor=1, + ), + dict( + type="PointmapPackInputs", + meta_keys=( + "img_path", + "ori_shape", + "img_shape", + "pad_shape", + "scale", + "flip", + "flip_direction", + "original_K", + "K", + "M", + ), + ), +] + +val_pipeline = [ + dict(type="PointmapResize", height=1024, width=768), + dict(type="PointmapGenerateTarget", canonical_focal_length=canonical_focal_length), + dict( + type="PointmapPackInputs", + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "img_shape", + "pad_shape", + "scale", + "padding_size", + "K", + "M", + ), + ), +] + +test_pipeline = [ + dict(type="PointmapResizePadImage", height=1024, width=768, pad_val=0), + dict( + type="PointmapPackInputs", + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "padding_size", + ), + ), +] + +render_people_dataset = dict( + type="PointmapRenderPeopleDataset", + data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_v2", +) + +train_datasets = [render_people_dataset] + +train_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + shuffle=True, + dataset=dict( + type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline + ), +) + +val_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + multiprocessing_context="spawn", + # num_workers=0, # debug + # persistent_workers=False, # debug + shuffle=False, + dataset=dict( + type="PointmapRenderPeopleDataset", + # num_samples=100, ## debug: only use N samples for validation + test_mode=True, + data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_v2_test", + pipeline=val_pipeline, + ), +) + +val_cfg = dict( + val_interval=val_every_iters, + evaluator=dict( + type="PointmapEvaluator", + ), +) + +data_preprocessor = dict( + type="ImagePreprocessor", + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models +) + +##----------------------------------------------------------------- +model = dict( + type="PointmapEstimator", + canonical_focal_length=canonical_focal_length, + backbone=dict( + type="Sapiens2", + arch=model_name, + img_size=image_size, + patch_size=patch_size, + final_norm=True, + use_tokenizer=False, + with_cls_token=True, + out_type="featmap", + init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint), + ), + decode_head=dict( + type="PointmapHead", + in_channels=embed_dim, + upsample_channels=[1536, 768, 512, 256], + conv_out_channels=[64, 32, 16], + conv_kernel_sizes=[3, 3, 3], + scale_conv_out_channels=(1536, 512, 128), + scale_conv_kernel_sizes=(1, 1, 1), + scale_final_layer=( + (num_tokens // ((2 * 2 * 2) * (2 * 2 * 2))) * 128, + 512, + 128, + 1, + ), ## scale regress + loss_decode=[ + dict(type="L1Loss", loss_weight=2.0), ## on pointmap, XYZ + dict( + type="MultiscaleL1Loss", + loss_weight=1.0, + scale_factor=2, + ), + dict(type="SiLogLoss", loss_weight=1.0), ## only applies silog loss + dict( + type="PointmapIntrinsicsConsistencyLoss", + loss_weight=1.0, + ), + dict( + type="PointmapShiftInvariantL1Loss", + loss_weight=1.0, + ), + dict(type="PointmapNormalLoss", loss_weight=2.0), + dict( + type="PointmapScaleL1Loss", loss_weight=4.0 + ), ## Canonical XYZ = scale * XYZ + ], + ), +) + + +##----------------------------------------------------------------- +optimizer = dict( + type="AdamW", + lr=5e-4, + betas=(0.9, 0.999), + weight_decay=0.1, + paramwise_cfg=dict( + 
num_layers=num_layers, + layer_decay_rate=layer_decay_rate, + ), + fused=True, +) + +scheduler = dict( + type="SequentialLR", + milestones=[warmup_iters], + schedulers=[ + dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters), + dict( + type="PolynomialLR", + total_iters=num_iters - warmup_iters, + power=1.0, + ), + ], +) + +clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0) diff --git a/sapiens/dense/configs/pointmap/render_people/sapiens2_5b_pointmap_render_people-1024x768.py b/sapiens/dense/configs/pointmap/render_people/sapiens2_5b_pointmap_render_people-1024x768.py new file mode 100644 index 0000000000000000000000000000000000000000..b387ac129731a5d5a76ed8589e0aa7dbbf1dcf2a --- /dev/null +++ b/sapiens/dense/configs/pointmap/render_people/sapiens2_5b_pointmap_render_people-1024x768.py @@ -0,0 +1,329 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + +_CHECKPOINT_ROOT = os.path.expanduser( + os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host") +) +_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data")) + +warmup_iters = 500 +num_iters = 4e4 ## 32 nodes, 8 gpus: 256 gpus. bs: 1, global bs: 256. num samples: 1e6. 1e6/256 = 3906. 1 epoch = 3906 iters. + +## debug +# warmup_iters = 100 +# num_iters = 300 + +# ------------------------------------------------------------------------------ +vis_every_iters = 100 +log_every_iters = 10 +save_every_iters = 1000 +val_every_iters = 1000 + +# # debug +# vis_every_iters = 1 +# log_every_iters = 1 +# val_every_iters = 10 + +load_from = None +resume = False + +# ------------------------------------------------------------------ +model_name = "sapiens2_5b" +embed_dim = 2432 +num_layers = 56 +num_heads = 32 +layer_decay_rate = 0.94 + +pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_5b_pretrain.safetensors" + + +##----------------------------------------------------------------- +image_size = (1024, 768) ## height x width + +patch_size = 16 +num_tokens = (image_size[0] // patch_size) * (image_size[1] // patch_size) +canonical_focal_length = 768.0 + +# ------------------------------------------------------------------ +use_fsdp = True +# use_fsdp = False + +use_compile = True +# use_compile = False + +## DDP config +if use_fsdp is False: + accelerator_cfg = dict( + type="DDP", + log_with="tensorboard", + # find_unused_parameters=True, + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + # mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, ## schedule independent of n_gpus + ) + +else: + accelerator_cfg = dict( + type="FSDP", + log_with="tensorboard", + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + # mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. 
+ step_scheduler_with_optimizer=False, + fsdp_cfg=dict( + fsdp_version=2, # DTensor-based engine + state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT + # state_dict_type="FULL_STATE_DICT", # TODO: resume from this is not working + # mixed_precision=dict( + # param_dtype="bf16", + # reduce_dtype="bf16", + # ), + cpu_ram_efficient_loading=False, + ), + ) + +if use_compile: + accelerator_cfg["compile_cfg"] = dict( + backend="inductor", + mode="default", # Options: "default", "reduce-overhead", "max-autotune" + fullgraph=False, + dynamic=False, + ) + +# ------------------------------------------------------------------ +randomness = dict(seed=0, deterministic=False, diff_rank_seed=True) +logger = dict( + type="Logger", + log_interval=log_every_iters, +) +checkpoint = dict( + type="Checkpointer", + save_interval=save_every_iters, +) + +visualizer = dict( + type="PointmapVisualizer", + vis_interval=vis_every_iters, + vis_max_samples=4, + vis_image_width=384, + vis_image_height=512, +) + + +##----------------------------------------------------------------- +train_pipeline = [ + dict(type="PhotoMetricDistortion"), + dict( + type="PointmapRandomScale", + scale_min=0.5, + scale_max=2.0, + prob=0.3, + ), + dict( + type="PointmapRandomCropContinuous", + ar_range=(0.5, 2.0), + area_range=(0.4, 1.0), + num_attempts=8, + prob=0.3, + ), + dict( + type="PointmapRandomFlip", + prob=0.3, + ), + dict(type="PointmapResize", height=1024, width=768), + ## target is same res as output, otherwise we get artifacts. + dict( + type="PointmapGenerateTarget", + canonical_focal_length=canonical_focal_length, + target_downsample_factor=1, + ), + dict( + type="PointmapPackInputs", + meta_keys=( + "img_path", + "ori_shape", + "img_shape", + "pad_shape", + "scale", + "flip", + "flip_direction", + "original_K", + "K", + "M", + ), + ), +] + +val_pipeline = [ + dict(type="PointmapResize", height=1024, width=768), + dict(type="PointmapGenerateTarget", canonical_focal_length=canonical_focal_length), + dict( + type="PointmapPackInputs", + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "img_shape", + "pad_shape", + "scale", + "padding_size", + "K", + "M", + ), + ), +] + +test_pipeline = [ + dict(type="PointmapResizePadImage", height=1024, width=768, pad_val=0), + dict( + type="PointmapPackInputs", + meta_keys=( + "img_path", + "orig_img_height", + "orig_img_width", + "img_shape", + "pad_shape", + "scale", + "padding_size", + "K", + "M", + ), + ), +] + +render_people_dataset = dict( + type="PointmapRenderPeopleDataset", + data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_v2", +) + +train_datasets = [render_people_dataset] + +train_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + shuffle=True, + dataset=dict( + type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline + ), +) + +val_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + shuffle=False, + dataset=dict( + type="PointmapRenderPeopleDataset", + # num_samples=100, ## only use N samples for validation + test_mode=True, + data_root=f"{_DATA_ROOT}/seg/data/render_people/synthetic_v2_test", + pipeline=val_pipeline, + ), +) + +val_cfg = dict( + val_interval=val_every_iters, + evaluator=dict( + type="PointmapEvaluator", + ), +) + +data_preprocessor = dict( + type="ImagePreprocessor", + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models +) + 
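+## The mean/std above are the standard ImageNet statistics on a 0-255 scale
+## ((0.485, 0.456, 0.406) and (0.229, 0.224, 0.225) multiplied by 255).
+## Rough sketch of the assumed per-image preprocessing:
+##     img = img[..., ::-1]          # bgr_to_rgb=True: BGR -> RGB
+##     img = (img - mean) / std      # per-channel normalization
+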
+##----------------------------------------------------------------- +model = dict( + type="PointmapEstimator", + canonical_focal_length=canonical_focal_length, + backbone=dict( + type="Sapiens2", + arch=model_name, + img_size=image_size, + patch_size=patch_size, + final_norm=True, + use_tokenizer=False, + with_cls_token=True, + out_type="featmap", + init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint), + ), + decode_head=dict( + type="PointmapHead", + in_channels=embed_dim, + # upsample_channels=[1536, 768, 512, 256], + # conv_out_channels=[64, 32, 16], + # conv_kernel_sizes=[3, 3, 3], + upsample_channels=[1536, 768, 768, 768], ## 1K resolution + conv_out_channels=[128, 64, 32], + conv_kernel_sizes=[3, 3, 3], + scale_conv_out_channels=(1536, 512, 128), + scale_conv_kernel_sizes=(1, 1, 1), + scale_final_layer=( + (num_tokens // ((2 * 2 * 2) * (2 * 2 * 2))) * 128, + 512, + 128, + 1, + ), ## scale regress + loss_decode=[ + dict(type="L1Loss", loss_weight=2.0), ## on pointmap, XYZ + dict( + type="MultiscaleL1Loss", + loss_weight=1.0, + scale_factor=2, + ), + dict(type="SiLogLoss", loss_weight=1.0), ## only applies silog loss + dict( + type="PointmapIntrinsicsConsistencyLoss", + loss_weight=1.0, + ), + dict( + type="PointmapShiftInvariantL1Loss", + loss_weight=1.0, + ), + dict(type="PointmapNormalLoss", loss_weight=2.0), + dict( + type="PointmapScaleL1Loss", loss_weight=4.0 + ), ## Canonical XYZ = scale * XYZ + ], + ), +) + + +##----------------------------------------------------------------- +optimizer = dict( + type="AdamW", + # lr=5e-4, + lr=1e-4, + betas=(0.9, 0.999), + weight_decay=0.1, + paramwise_cfg=dict( + num_layers=num_layers, + layer_decay_rate=layer_decay_rate, + ), + fused=True, +) + +scheduler = dict( + type="SequentialLR", + milestones=[warmup_iters], + schedulers=[ + dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters), + dict( + type="PolynomialLR", + total_iters=num_iters - warmup_iters, + power=1.0, + ), + ], +) + +clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0) diff --git a/sapiens/dense/configs/seg/shutterstock_goliath/sapiens2_0.4b_seg_shutterstock_goliath-1024x768.py b/sapiens/dense/configs/seg/shutterstock_goliath/sapiens2_0.4b_seg_shutterstock_goliath-1024x768.py new file mode 100644 index 0000000000000000000000000000000000000000..9cb31296e54e6e53e40894f79cda9f0f2cd02797 --- /dev/null +++ b/sapiens/dense/configs/seg/shutterstock_goliath/sapiens2_0.4b_seg_shutterstock_goliath-1024x768.py @@ -0,0 +1,364 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
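+
+## Sapiens2-0.4B body-part segmentation config: 29 classes at 1024x768 input,
+## finetuned from the sapiens2_0.4b pretraining checkpoint on a mix of dome /
+## Goliath-style and Shutterstock annotations, with BG-20k background
+## compositing during training (see SegRandomBackground below).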
+ +import os + +_CHECKPOINT_ROOT = os.path.expanduser( + os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host") +) +_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data")) + +warmup_iters = 500 +num_iters = 2e4 + +# ------------------------------------------------------------------------------ +vis_every_iters = 100 +log_every_iters = 10 +save_every_iters = 1000 +val_every_iters = 1000 + +# # # # debug +# vis_every_iters = 1 +# log_every_iters = 1 +# val_every_iters = 2 +# save_every_iters = 1000 + +load_from = None +resume = False + +# ------------------------------------------------------------------ +model_name = "sapiens2_0.4b" +embed_dim = 1024 +num_layers = 24 +num_heads = 16 + +layer_decay_rate = 0.8 +pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_0.4b_pretrain.safetensors" + +num_classes = 29 ## 29 classes +CLASS_WEIGHT = [ + 0.1, + 10, + 10, + 3, + 2, + 4, + 4, + 2, + 2, + 6, + 10, + 3, + 3, + 1, + 4, + 4, + 2, + 2, + 6, + 10, + 3, + 3, + 1, + 1, + 10, + 10, + 10, + 10, + 10, +] ## 29 classes + +##----------------------------------------------------------------- +image_size = (1024, 768) ## height x width +patch_size = 16 + +# ------------------------------------------------------------------ +# use_fsdp = True +use_fsdp = False + +use_compile = True +# use_compile = False + +## DDP config +if use_fsdp is False: + accelerator_cfg = dict( + type="DDP", + log_with="tensorboard", + # find_unused_parameters=True, + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, ## schedule independent of n_gpus + ) + +else: + accelerator_cfg = dict( + type="FSDP", + log_with="tensorboard", + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. 
+ step_scheduler_with_optimizer=False, + fsdp_cfg=dict( + fsdp_version=2, # DTensor-based engine + state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT + # state_dict_type="FULL_STATE_DICT", # TODO: resume from this is not working + mixed_precision=dict( + param_dtype="bf16", + reduce_dtype="bf16", + ), + cpu_ram_efficient_loading=False, + ), + ) + +if use_compile: + accelerator_cfg["compile_cfg"] = dict( + backend="inductor", + mode="default", # Options: "default", "reduce-overhead", "max-autotune" + fullgraph=False, + dynamic=False, + ) + +# ------------------------------------------------------------------ +randomness = dict(seed=0, deterministic=False, diff_rank_seed=True) +logger = dict( + type="Logger", + log_interval=log_every_iters, +) +checkpoint = dict( + type="Checkpointer", + save_interval=save_every_iters, +) + +visualizer = dict( + type="SegVisualizer", + vis_interval=vis_every_iters, + vis_max_samples=4, + vis_image_width=384, + vis_image_height=512, + class_palette_type="dome29", +) + + +##----------------------------------------------------------------- +train_pipeline = [ + dict( + type="SegRandomBackground", + prob=0.8, + skip_key="is_itw", + background_images_root=f"{_DATA_ROOT}/BG-20k/train", + ), + dict( + type="SegRandomResize", + base_height=1024, + base_width=768, + ratio_range=(0.4, 2.0), + keep_ratio=True, + ), + dict( + type="SegRandomCrop", + crop_height=1024, + crop_width=768, + prob=0.3, + cat_max_ratio=0.75, + ), + dict( + type="RandomGaussianBlur", prob=0.3, kernel_size=(3, 3), sigma_range=(0.1, 2.0) + ), + dict(type="RandomGaussianNoise", prob=0.3, var_range=(5.0, 20.0)), + dict( + type="SegRandomRotate", prob=0.5, degree=60, seg_pad_val=0 + ), ## the black pixels are set as background + dict( + type="SegRandomHorizontalFlip", + prob=0.5, + swap_seg_labels=[ + (5, 14), + (6, 15), + (7, 16), + (8, 17), + (9, 18), + (10, 19), + (11, 20), + (12, 21), + ], + ), ## for the 29 classes, + dict(type="PhotoMetricDistortion"), + dict(type="SegResize", height=1024, width=768, keep_ratio=False), + dict(type="SegPackInputs"), +] + +val_pipeline = [ + dict(type="SegResize", height=1024, width=768, keep_ratio=False, test_mode=True), + dict(type="SegPackInputs", test_mode=True), +] + +test_pipeline = [ + dict(type="SegResize", height=1024, width=768, keep_ratio=False, test_mode=True), + dict(type="SegPackInputs", test_mode=True), +] + +##------------------------------------------------------------------------ +dataset_dome_train = dict( + type="SegDomeClass29Dataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/sociopticon_body_segmentation_33_train:2024092600.json", +) + +dataset_shutterstock_train = dict( + type="SegShutterstockClass29Dataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/itw_shutterstock_body_segmentation_51_train:2024121600.json", +) + +dataset_ca3_wide_train = dict( + type="SegDomeClass29Dataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/ca3_wide_angle_body_segmentation_33_train:2024091700.json", +) + +dataset_caa_train = dict( + type="SegDomeClass29Dataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/cca_segmentation_33_train:2024092400.json", +) + +dataset_ca3_zoom_train = dict( + type="SegShutterstockClass29Dataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/ca3_zoom_in_body_segmentation_50_train:2024091700.json", +) + +dataset_lighticon_train = dict( + type="SegShutterstockClass29Dataset", + 
ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/lighticon_lightful_body_segmentation_51_train:2025021900.json", +) + +dataset_internal_train = dict( + type="SegInternalClass29Dataset", + # ann_file=f"{_DATA_ROOT}/annotations/stylized_sapiens/20250807/Internal_segmentation_32:2025080700.json", + ann_file=f"{_DATA_ROOT}/annotations/internal_dataset/20251103/internal_keypoint_344_segmentation_32_train:2025091500.json", +) + +train_datasets = [ + dataset_dome_train, + dataset_ca3_wide_train, + dataset_caa_train, + dataset_ca3_zoom_train, + dataset_lighticon_train, + dataset_internal_train, +] + 2 * [dataset_shutterstock_train] + +train_dataloader = dict( + batch_size=1, + num_workers=8, + persistent_workers=True, + shuffle=True, + dataset=dict( + type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline + ), +) + +val_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + multiprocessing_context="spawn", ## avoids fork error with airstore + # num_workers=0, # debug + # persistent_workers=False, # debug + shuffle=False, + dataset=dict( + type="SegShutterstockClass29Dataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/itw_shutterstock_body_segmentation_51_test:2024121600.json", + test_mode=True, + pipeline=val_pipeline, + ), + collate_fn=dict(type="eval_collate"), +) + +val_cfg = dict( + val_interval=val_every_iters, + evaluator=dict(type="SegEvaluator", class_names="dome29", nan_to_num=0.0), +) + +data_preprocessor = dict( + type="ImagePreprocessor", + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models +) + +##----------------------------------------------------------------- +model = dict( + type="SegEstimator", + backbone=dict( + type="Sapiens2", + arch=model_name, + img_size=image_size, + patch_size=patch_size, + final_norm=True, + use_tokenizer=False, + with_cls_token=True, + out_type="featmap", + init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint), + ), + decode_head=dict( + type="SegHead", + in_channels=embed_dim, + deconv_out_channels=( + 512, + 256, + 128, + 64, + ), ## this will 2x at each step. so total is 16x. 1K output. 
+ deconv_kernel_sizes=(4, 4, 4, 4), + conv_out_channels=(64, 64), + conv_kernel_sizes=(1, 1), + num_classes=num_classes, + loss_decode=[ + dict( + type="CrossEntropyLoss", + loss_weight=1.0, + reduction="none", + class_weight=CLASS_WEIGHT, + ignore_index=255, + ), + dict( + type="DiceLoss", + loss_weight=1.0, + reduction="none", + activate=True, + use_sigmoid=False, + include_background=False, + ignore_index=255, + ), + ], + ), +) + + +##----------------------------------------------------------------- +optimizer = dict( + type="AdamW", + lr=5e-4, + betas=(0.9, 0.999), + weight_decay=0.1, + paramwise_cfg=dict( + num_layers=num_layers, + layer_decay_rate=layer_decay_rate, + ), + fused=True, ## use fused AdamW +) + +scheduler = dict( + type="SequentialLR", + milestones=[warmup_iters], + schedulers=[ + dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters), + dict( + type="PolynomialLR", + total_iters=num_iters - warmup_iters, + power=1.0, + ), + ], +) + +clip_grad = dict(mode="norm", max_norm=2.0, norm_type=2.0) diff --git a/sapiens/dense/configs/seg/shutterstock_goliath/sapiens2_0.8b_seg_shutterstock_goliath-1024x768.py b/sapiens/dense/configs/seg/shutterstock_goliath/sapiens2_0.8b_seg_shutterstock_goliath-1024x768.py new file mode 100644 index 0000000000000000000000000000000000000000..253a0201fce497b8df6e6e6870984c2614f87d61 --- /dev/null +++ b/sapiens/dense/configs/seg/shutterstock_goliath/sapiens2_0.8b_seg_shutterstock_goliath-1024x768.py @@ -0,0 +1,368 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + +_CHECKPOINT_ROOT = os.path.expanduser( + os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host") +) +_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data")) + +warmup_iters = 500 +num_iters = 3e4 ## bs: 5; 16 gpus + +# ------------------------------------------------------------------------------ +vis_every_iters = 100 +log_every_iters = 10 +save_every_iters = 1000 +val_every_iters = 1000 + +# # # # debug +# vis_every_iters = 1 +# log_every_iters = 1 +# val_every_iters = 2 +# save_every_iters = 1000 + +load_from = None +resume = False + +# ------------------------------------------------------------------ +model_name = "sapiens2_0.8b" +embed_dim = 1280 +num_layers = 32 +num_heads = 16 + +layer_decay_rate = 0.85 +pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_0.8b_pretrain.safetensors" + +num_classes = 29 ## 29 classes +CLASS_WEIGHT = [ + 0.1, + 10, + 10, + 3, + 2, + 4, + 4, + 2, + 2, + 6, + 10, + 3, + 3, + 1, + 4, + 4, + 2, + 2, + 6, + 10, + 3, + 3, + 1, + 1, + 10, + 10, + 10, + 10, + 10, +] ## 29 classes + +##----------------------------------------------------------------- +image_size = (1024, 768) ## height x width +patch_size = 16 + +# ------------------------------------------------------------------ +# use_fsdp = True +use_fsdp = False + +use_compile = True +# use_compile = False + +## DDP config +if use_fsdp is False: + accelerator_cfg = dict( + type="DDP", + log_with="tensorboard", + # find_unused_parameters=True, + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + # mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. 
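# Minimal plain-PyTorch sketch of the warmup + polynomial-decay schedule that
# the `scheduler` dict above describes (the repo's runner builds this itself;
# the parameters here are just stand-ins).
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.AdamW(params, lr=5e-4, betas=(0.9, 0.999), weight_decay=0.1)
warmup = torch.optim.lr_scheduler.LinearLR(opt, start_factor=1e-3, total_iters=500)
decay = torch.optim.lr_scheduler.PolynomialLR(opt, total_iters=int(2e4) - 500, power=1.0)
sched = torch.optim.lr_scheduler.SequentialLR(opt, schedulers=[warmup, decay], milestones=[500])
# each training iteration: opt.step(); sched.step()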
+ step_scheduler_with_optimizer=False, ## schedule independent of n_gpus + ) + +else: + accelerator_cfg = dict( + type="FSDP", + log_with="tensorboard", + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, + fsdp_cfg=dict( + fsdp_version=2, # DTensor-based engine + state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT + # state_dict_type="FULL_STATE_DICT", # TODO: resume from this is not working + mixed_precision=dict( + param_dtype="bf16", + reduce_dtype="bf16", + ), + cpu_ram_efficient_loading=False, + ), + ) + + ## Note: to merge sharded weight using FSDP + # accelerate merge-weights pytorch_model_fsdp_0/ . + +if use_compile: + accelerator_cfg["compile_cfg"] = dict( + backend="inductor", + mode="default", # Options: "default", "reduce-overhead", "max-autotune" + fullgraph=False, + dynamic=False, + ) + +# ------------------------------------------------------------------ +randomness = dict(seed=0, deterministic=False, diff_rank_seed=True) +logger = dict( + type="Logger", + log_interval=log_every_iters, +) +checkpoint = dict( + type="Checkpointer", + save_interval=save_every_iters, +) + +visualizer = dict( + type="SegVisualizer", + vis_interval=vis_every_iters, + vis_max_samples=4, + vis_image_width=384, + vis_image_height=512, + class_palette_type="dome29", +) + + +##----------------------------------------------------------------- +train_pipeline = [ + dict( + type="SegRandomBackground", + prob=0.8, + skip_key="is_itw", + background_images_root=f"{_DATA_ROOT}/BG-20k/train", + ), + dict( + type="SegRandomResize", + base_height=1024, + base_width=768, + ratio_range=(0.4, 2.0), + keep_ratio=True, + ), + dict( + type="SegRandomCrop", + crop_height=1024, + crop_width=768, + prob=0.3, + cat_max_ratio=0.75, + ), + dict( + type="RandomGaussianBlur", prob=0.3, kernel_size=(3, 3), sigma_range=(0.1, 2.0) + ), + dict(type="RandomGaussianNoise", prob=0.3, var_range=(5.0, 20.0)), + dict( + type="SegRandomRotate", prob=0.5, degree=60, seg_pad_val=0 + ), ## the black pixels are set as background + dict( + type="SegRandomHorizontalFlip", + prob=0.5, + swap_seg_labels=[ + (5, 14), + (6, 15), + (7, 16), + (8, 17), + (9, 18), + (10, 19), + (11, 20), + (12, 21), + ], + ), ## for the 29 classes, + dict(type="PhotoMetricDistortion"), + dict(type="SegResize", height=1024, width=768, keep_ratio=False), + dict(type="SegPackInputs"), +] + +val_pipeline = [ + dict(type="SegResize", height=1024, width=768, keep_ratio=False, test_mode=True), + dict(type="SegPackInputs", test_mode=True), +] + +test_pipeline = [ + dict(type="SegResize", height=1024, width=768, keep_ratio=False, test_mode=True), + dict(type="SegPackInputs", test_mode=True), +] + +##------------------------------------------------------------------------ +dataset_dome_train = dict( + type="SegDomeClass29Dataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/sociopticon_body_segmentation_33_train:2024092600.json", +) + +dataset_shutterstock_train = dict( + type="SegShutterstockClass29Dataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/itw_shutterstock_body_segmentation_51_train:2024121600.json", +) + +dataset_ca3_wide_train = dict( + type="SegDomeClass29Dataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/ca3_wide_angle_body_segmentation_33_train:2024091700.json", +) + +dataset_caa_train = dict( + 
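# Rough plain-PyTorch counterpart of the compile_cfg above (illustrative; the
# training runner applies compilation to the real model internally):
import torch

net = torch.nn.Linear(8, 8)  # stand-in module
net = torch.compile(net, backend="inductor", mode="default", fullgraph=False, dynamic=False)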
type="SegDomeClass29Dataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/cca_segmentation_33_train:2024092400.json", +) + +dataset_ca3_zoom_train = dict( + type="SegShutterstockClass29Dataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/ca3_zoom_in_body_segmentation_50_train:2024091700.json", +) + +dataset_lighticon_train = dict( + type="SegShutterstockClass29Dataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/lighticon_lightful_body_segmentation_51_train:2025021900.json", +) + +dataset_internal_train = dict( + type="SegInternalClass29Dataset", + # ann_file=f"{_DATA_ROOT}/annotations/stylized_sapiens/20250807/Internal_segmentation_32:2025080700.json", + ann_file=f"{_DATA_ROOT}/annotations/internal_dataset/20251103/internal_keypoint_344_segmentation_32_train:2025091500.json", +) + +train_datasets = [ + dataset_dome_train, + dataset_ca3_wide_train, + dataset_caa_train, + dataset_ca3_zoom_train, + dataset_lighticon_train, + dataset_internal_train, +] + 2 * [dataset_shutterstock_train] + +train_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + shuffle=True, + dataset=dict( + type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline + ), +) + +val_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + multiprocessing_context="spawn", ## avoids fork error with airstore + # num_workers=0, # debug + # persistent_workers=False, # debug + shuffle=False, + dataset=dict( + type="SegShutterstockClass29Dataset", + # num_samples=40, ## only use N samples for validation + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/itw_shutterstock_body_segmentation_51_test:2024121600.json", + test_mode=True, + pipeline=val_pipeline, + ), + collate_fn=dict(type="eval_collate"), +) + +val_cfg = dict( + val_interval=val_every_iters, + evaluator=dict(type="SegEvaluator", class_names="dome29", nan_to_num=0.0), +) + +data_preprocessor = dict( + type="ImagePreprocessor", + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models +) + +##----------------------------------------------------------------- +model = dict( + type="SegEstimator", + backbone=dict( + type="Sapiens2", + arch=model_name, + img_size=image_size, + patch_size=patch_size, + final_norm=True, + use_tokenizer=False, + with_cls_token=True, + out_type="featmap", + init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint), + ), + decode_head=dict( + type="SegHead", + in_channels=embed_dim, + deconv_out_channels=( + 512, + 256, + 128, + 64, + ), ## this will 2x at each step. so total is 16x. 1K output. 
+ deconv_kernel_sizes=(4, 4, 4, 4), + conv_out_channels=(64, 64), + conv_kernel_sizes=(1, 1), + num_classes=num_classes, + loss_decode=[ + dict( + type="CrossEntropyLoss", + loss_weight=1.0, + reduction="none", + class_weight=CLASS_WEIGHT, + ignore_index=255, + ), + dict( + type="DiceLoss", + loss_weight=1.0, + reduction="none", + activate=True, + use_sigmoid=False, + include_background=False, + ignore_index=255, + ), + ], + ), +) + + +##----------------------------------------------------------------- +optimizer = dict( + type="AdamW", + lr=5e-4, + betas=(0.9, 0.999), + weight_decay=0.1, + paramwise_cfg=dict( + num_layers=num_layers, + layer_decay_rate=layer_decay_rate, + ), + fused=True, +) + +scheduler = dict( + type="SequentialLR", + milestones=[warmup_iters], + schedulers=[ + dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters), + dict( + type="PolynomialLR", + total_iters=num_iters - warmup_iters, + power=1.0, + ), + ], +) + +clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0) diff --git a/sapiens/dense/configs/seg/shutterstock_goliath/sapiens2_1b_seg_shutterstock_goliath-1024x768.py b/sapiens/dense/configs/seg/shutterstock_goliath/sapiens2_1b_seg_shutterstock_goliath-1024x768.py new file mode 100644 index 0000000000000000000000000000000000000000..577569d9bb2129f444d83047cb374b7a0780d679 --- /dev/null +++ b/sapiens/dense/configs/seg/shutterstock_goliath/sapiens2_1b_seg_shutterstock_goliath-1024x768.py @@ -0,0 +1,366 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + +_CHECKPOINT_ROOT = os.path.expanduser( + os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host") +) +_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data")) + +warmup_iters = 500 +num_iters = 2e4 +# num_iters = 4e4 + +# ------------------------------------------------------------------------------ +vis_every_iters = 100 +log_every_iters = 10 +save_every_iters = 1000 +val_every_iters = 1000 + +# # # # debug +# vis_every_iters = 1 +# log_every_iters = 1 +# val_every_iters = 2 +# save_every_iters = 1000 + +load_from = None +resume = False + +# ------------------------------------------------------------------ +model_name = "sapiens2_1b" +embed_dim = 1536 +num_layers = 40 +num_heads = 24 + +layer_decay_rate = 0.9 +pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_1b_pretrain.safetensors" + +num_classes = 29 ## 29 classes +CLASS_WEIGHT = [ + 0.1, + 10, + 10, + 3, + 2, + 4, + 4, + 2, + 2, + 6, + 10, + 3, + 3, + 1, + 4, + 4, + 2, + 2, + 6, + 10, + 3, + 3, + 1, + 1, + 10, + 10, + 10, + 10, + 10, +] ## 29 classes + +##----------------------------------------------------------------- +image_size = (1024, 768) ## height x width +patch_size = 16 + +# ------------------------------------------------------------------ +use_fsdp = True +# use_fsdp = False + +use_compile = True +# use_compile = False + +## DDP config +if use_fsdp is False: + accelerator_cfg = dict( + type="DDP", + log_with="tensorboard", + # find_unused_parameters=True, + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + # mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. 
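# Plain-PyTorch counterpart of the clip_grad setting above (illustrative):
# gradients are clipped to a global L2 norm of 4.0 before the optimizer step.
import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.1)
model(torch.randn(2, 4)).sum().backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=4.0, norm_type=2.0)
opt.step()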
+ step_scheduler_with_optimizer=False, ## schedule independent of n_gpus + ) + +else: + accelerator_cfg = dict( + type="FSDP", + log_with="tensorboard", + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, + fsdp_cfg=dict( + fsdp_version=2, # DTensor-based engine + state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT + # state_dict_type="FULL_STATE_DICT", # TODO: resume from this is not working + mixed_precision=dict( + param_dtype="bf16", + reduce_dtype="bf16", + ), + cpu_ram_efficient_loading=False, + ), + ) + +if use_compile: + accelerator_cfg["compile_cfg"] = dict( + backend="inductor", + mode="default", # Options: "default", "reduce-overhead", "max-autotune" + fullgraph=False, + dynamic=False, + ) + +# ------------------------------------------------------------------ +randomness = dict(seed=0, deterministic=False, diff_rank_seed=True) +logger = dict( + type="Logger", + log_interval=log_every_iters, +) +checkpoint = dict( + type="Checkpointer", + save_interval=save_every_iters, +) + +visualizer = dict( + type="SegVisualizer", + vis_interval=vis_every_iters, + vis_max_samples=4, + vis_image_width=384, + vis_image_height=512, + class_palette_type="dome29", +) + + +##----------------------------------------------------------------- +train_pipeline = [ + dict( + type="SegRandomBackground", + prob=0.8, + skip_key="is_itw", + background_images_root=f"{_DATA_ROOT}/BG-20k/train", + ), + dict( + type="SegRandomResize", + base_height=1024, + base_width=768, + ratio_range=(0.4, 2.0), + keep_ratio=True, + ), + dict( + type="SegRandomCrop", + crop_height=1024, + crop_width=768, + prob=0.3, + cat_max_ratio=0.75, + ), + dict( + type="RandomGaussianBlur", prob=0.3, kernel_size=(3, 3), sigma_range=(0.1, 2.0) + ), + dict(type="RandomGaussianNoise", prob=0.3, var_range=(5.0, 20.0)), + dict( + type="SegRandomRotate", prob=0.5, degree=60, seg_pad_val=0 + ), ## the black pixels are set as background + dict( + type="SegRandomHorizontalFlip", + prob=0.5, + swap_seg_labels=[ + (5, 14), + (6, 15), + (7, 16), + (8, 17), + (9, 18), + (10, 19), + (11, 20), + (12, 21), + ], + ), ## for the 29 classes, + dict(type="PhotoMetricDistortion"), + dict(type="SegResize", height=1024, width=768, keep_ratio=False), + dict(type="SegPackInputs"), +] + +val_pipeline = [ + dict(type="SegResize", height=1024, width=768, keep_ratio=False, test_mode=True), + dict(type="SegPackInputs", test_mode=True), +] + +test_pipeline = [ + dict(type="SegResize", height=1024, width=768, keep_ratio=False, test_mode=True), + dict(type="SegPackInputs", test_mode=True), +] + +##------------------------------------------------------------------------ +dataset_dome_train = dict( + type="SegDomeClass29Dataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/sociopticon_body_segmentation_33_train:2024092600.json", +) + +dataset_shutterstock_train = dict( + type="SegShutterstockClass29Dataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/itw_shutterstock_body_segmentation_51_train:2024121600.json", +) + +dataset_ca3_wide_train = dict( + type="SegDomeClass29Dataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/ca3_wide_angle_body_segmentation_33_train:2024091700.json", +) + +dataset_caa_train = dict( + type="SegDomeClass29Dataset", + 
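# Hedged guess at what SegRandomBackground does in the pipeline above: with
# probability 0.8, composite the masked person onto a random BG-20k image.
# Not the repo's implementation; the background is assumed pre-resized to the
# image size.
import numpy as np

def composite_background(img: np.ndarray, mask: np.ndarray, bg: np.ndarray) -> np.ndarray:
    alpha = (mask > 0)[..., None].astype(img.dtype)  # 1 on the person, 0 elsewhere
    return img * alpha + bg * (1 - alpha)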
ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/cca_segmentation_33_train:2024092400.json", +) + +dataset_ca3_zoom_train = dict( + type="SegShutterstockClass29Dataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/ca3_zoom_in_body_segmentation_50_train:2024091700.json", +) + +dataset_lighticon_train = dict( + type="SegShutterstockClass29Dataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/lighticon_lightful_body_segmentation_51_train:2025021900.json", +) + +dataset_internal_train = dict( + type="SegInternalClass29Dataset", + # ann_file=f"{_DATA_ROOT}/annotations/stylized_sapiens/20250807/Internal_segmentation_32:2025080700.json", + ann_file=f"{_DATA_ROOT}/annotations/internal_dataset/20251103/internal_keypoint_344_segmentation_32_train:2025091500.json", +) + +train_datasets = [ + dataset_dome_train, + dataset_ca3_wide_train, + dataset_caa_train, + dataset_ca3_zoom_train, + dataset_lighticon_train, + dataset_internal_train, +] + 2 * [dataset_shutterstock_train] + +train_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + shuffle=True, + dataset=dict( + type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline + ), +) + +val_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + multiprocessing_context="spawn", ## avoids fork error with airstore + # num_workers=0, # debug + # persistent_workers=False, # debug + shuffle=False, + dataset=dict( + type="SegShutterstockClass29Dataset", + # num_samples=40, ## only use N samples for validation + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/itw_shutterstock_body_segmentation_51_test:2024121600.json", + test_mode=True, + pipeline=val_pipeline, + ), + collate_fn=dict(type="eval_collate"), +) + +val_cfg = dict( + val_interval=val_every_iters, + evaluator=dict(type="SegEvaluator", class_names="dome29", nan_to_num=0.0), +) + +data_preprocessor = dict( + type="ImagePreprocessor", + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models +) + +##----------------------------------------------------------------- +model = dict( + type="SegEstimator", + backbone=dict( + type="Sapiens2", + arch=model_name, + img_size=image_size, + patch_size=patch_size, + final_norm=True, + use_tokenizer=False, + with_cls_token=True, + out_type="featmap", + init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint), + ), + decode_head=dict( + type="SegHead", + in_channels=embed_dim, + deconv_out_channels=( + 512, + 256, + 128, + 64, + ), ## this will 2x at each step. so total is 16x. 1K output. 
+ deconv_kernel_sizes=(4, 4, 4, 4), + conv_out_channels=(64, 64), + conv_kernel_sizes=(1, 1), + num_classes=num_classes, + loss_decode=[ + dict( + type="CrossEntropyLoss", + loss_weight=1.0, + reduction="none", + class_weight=CLASS_WEIGHT, + ignore_index=255, + ), + dict( + type="DiceLoss", + loss_weight=1.0, + reduction="none", + activate=True, + use_sigmoid=False, + include_background=False, + ignore_index=255, + ), + ], + ), +) + + +##----------------------------------------------------------------- +optimizer = dict( + type="AdamW", + lr=5e-4, + betas=(0.9, 0.999), + weight_decay=0.1, + paramwise_cfg=dict( + num_layers=num_layers, + layer_decay_rate=layer_decay_rate, + ), + fused=True, +) + +scheduler = dict( + type="SequentialLR", + milestones=[warmup_iters], + schedulers=[ + dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters), + dict( + type="PolynomialLR", + total_iters=num_iters - warmup_iters, + power=1.0, + ), + ], +) + +clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0) diff --git a/sapiens/dense/configs/seg/shutterstock_goliath/sapiens2_5b_seg_shutterstock_goliath-1024x768.py b/sapiens/dense/configs/seg/shutterstock_goliath/sapiens2_5b_seg_shutterstock_goliath-1024x768.py new file mode 100644 index 0000000000000000000000000000000000000000..ce945a09e4b252093b415a25b17c7baaaccb4582 --- /dev/null +++ b/sapiens/dense/configs/seg/shutterstock_goliath/sapiens2_5b_seg_shutterstock_goliath-1024x768.py @@ -0,0 +1,365 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + +_CHECKPOINT_ROOT = os.path.expanduser( + os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host") +) +_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data")) + +warmup_iters = 500 +num_iters = 5e4 ## for h200; bs is 4 + +# ------------------------------------------------------------------------------ +vis_every_iters = 100 +log_every_iters = 10 +save_every_iters = 2000 +# val_every_iters = 2000 +val_every_iters = 10000 + +# # # # debug +# vis_every_iters = 1 +# log_every_iters = 1 +# val_every_iters = 10 + +load_from = None +resume = False + +# ------------------------------------------------------------------ +model_name = "sapiens2_5b" +embed_dim = 2432 +num_layers = 56 +num_heads = 32 + +layer_decay_rate = 0.94 +pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_5b_pretrain.safetensors" + +num_classes = 29 ## 29 classes +CLASS_WEIGHT = [ + 0.1, + 10, + 10, + 3, + 2, + 4, + 4, + 2, + 2, + 6, + 10, + 3, + 3, + 1, + 4, + 4, + 2, + 2, + 6, + 10, + 3, + 3, + 1, + 1, + 10, + 10, + 10, + 10, + 10, +] ## 29 classes + +##----------------------------------------------------------------- +image_size = (1024, 768) ## height x width +patch_size = 16 + +# ------------------------------------------------------------------ +use_fsdp = True +# use_fsdp = False + +use_compile = True +# use_compile = False + +## DDP config +if use_fsdp is False: + accelerator_cfg = dict( + type="DDP", + log_with="tensorboard", + # find_unused_parameters=True, + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + # mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. 
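# Sketch of the class-weighted cross-entropy term in loss_decode above
# (illustrative; the DiceLoss term and the repo's own reduction handling are
# omitted, and the weight vector is truncated for brevity).
import torch
import torch.nn.functional as F

num_classes = 29
class_weight = torch.tensor([0.1, 10, 10, 3, 2] + [1.0] * 24)  # 29 entries
logits = torch.randn(1, num_classes, 8, 8)
target = torch.randint(0, num_classes, (1, 8, 8))
loss = F.cross_entropy(logits, target, weight=class_weight, ignore_index=255, reduction="none")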
+ step_scheduler_with_optimizer=False, ## schedule independent of n_gpus + ) + +else: + accelerator_cfg = dict( + type="FSDP", + log_with="tensorboard", + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, + fsdp_cfg=dict( + fsdp_version=2, # DTensor-based engine + state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT + # state_dict_type="FULL_STATE_DICT", # TODO: resume from this is not working + mixed_precision=dict( + param_dtype="bf16", + reduce_dtype="bf16", + ), + cpu_ram_efficient_loading=False, + ), + ) + + ## Note: to merge sharded weight using FSDP + # accelerate merge-weights pytorch_model_fsdp_0/ . + +if use_compile: + accelerator_cfg["compile_cfg"] = dict( + backend="inductor", + mode="default", # Options: "default", "reduce-overhead", "max-autotune" + fullgraph=False, + dynamic=False, + ) + +# ------------------------------------------------------------------ +randomness = dict(seed=0, deterministic=False, diff_rank_seed=True) +logger = dict( + type="Logger", + log_interval=log_every_iters, +) +checkpoint = dict( + type="Checkpointer", + save_interval=save_every_iters, +) + +visualizer = dict( + type="SegVisualizer", + vis_interval=vis_every_iters, + vis_max_samples=4, + vis_image_width=384, + vis_image_height=512, + class_palette_type="dome29", +) + + +##----------------------------------------------------------------- +train_pipeline = [ + dict( + type="SegRandomBackground", + prob=0.8, + skip_key="is_itw", + background_images_root=f"{_DATA_ROOT}/BG-20k/train", + ), + dict( + type="SegRandomResize", + base_height=1024, + base_width=768, + ratio_range=(0.4, 2.0), + keep_ratio=True, + ), + dict( + type="SegRandomCrop", + crop_height=1024, + crop_width=768, + prob=0.3, + cat_max_ratio=0.75, + ), + dict( + type="RandomGaussianBlur", prob=0.3, kernel_size=(3, 3), sigma_range=(0.1, 2.0) + ), + dict(type="RandomGaussianNoise", prob=0.3, var_range=(5.0, 20.0)), + dict( + type="SegRandomRotate", prob=0.5, degree=60, seg_pad_val=0 + ), ## the black pixels are set as background + dict( + type="SegRandomHorizontalFlip", + prob=0.5, + swap_seg_labels=[ + (5, 14), + (6, 15), + (7, 16), + (8, 17), + (9, 18), + (10, 19), + (11, 20), + (12, 21), + ], + ), ## for the 29 classes, + dict(type="PhotoMetricDistortion"), + dict(type="SegResize", height=1024, width=768, keep_ratio=False), + dict(type="SegPackInputs"), +] + +val_pipeline = [ + dict(type="SegResize", height=1024, width=768, keep_ratio=False), + dict(type="SegPackInputs"), +] + +test_pipeline = [ + dict(type="SegResize", height=1024, width=768, keep_ratio=False), + dict(type="SegPackInputs"), +] + +##------------------------------------------------------------------------ +dataset_dome_train = dict( + type="SegDomeClass29Dataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/sociopticon_body_segmentation_33_train:2024092600.json", +) + +dataset_shutterstock_train = dict( + type="SegShutterstockClass29Dataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/itw_shutterstock_body_segmentation_51_train:2024121600.json", +) + +dataset_ca3_wide_train = dict( + type="SegDomeClass29Dataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/ca3_wide_angle_body_segmentation_33_train:2024091700.json", +) + +dataset_caa_train = dict( + type="SegDomeClass29Dataset", + 
ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/cca_segmentation_33_train:2024092400.json", +) + +dataset_ca3_zoom_train = dict( + type="SegShutterstockClass29Dataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/ca3_zoom_in_body_segmentation_50_train:2024091700.json", +) + +dataset_lighticon_train = dict( + type="SegShutterstockClass29Dataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/lighticon_lightful_body_segmentation_51_train:2025021900.json", +) + +dataset_internal_train = dict( + type="SegInternalClass29Dataset", + # ann_file=f"{_DATA_ROOT}/annotations/stylized_sapiens/20250807/Internal_segmentation_32:2025080700.json", + ann_file=f"{_DATA_ROOT}/annotations/internal_dataset/20251103/internal_keypoint_344_segmentation_32_train:2025091500.json", +) + +train_datasets = [ + dataset_dome_train, + dataset_ca3_wide_train, + dataset_caa_train, + dataset_ca3_zoom_train, + dataset_lighticon_train, + dataset_internal_train, +] + 2 * [dataset_shutterstock_train] + +train_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + shuffle=True, + dataset=dict( + type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline + ), +) + +val_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + multiprocessing_context="spawn", ## avoids fork error with airstore + shuffle=False, + dataset=dict( + type="SegShutterstockClass29Dataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/itw_shutterstock_body_segmentation_51_test:2024121600.json", + test_mode=True, + pipeline=val_pipeline, + ), + collate_fn=dict(type="eval_collate"), +) + +val_cfg = dict( + val_interval=val_every_iters, + evaluator=dict(type="SegEvaluator", class_names="dome29", nan_to_num=0.0), +) + +data_preprocessor = dict( + type="ImagePreprocessor", + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models +) + +##----------------------------------------------------------------- +model = dict( + type="SegEstimator", + backbone=dict( + type="Sapiens2", + arch=model_name, + img_size=image_size, + patch_size=patch_size, + final_norm=True, + use_tokenizer=False, + with_cls_token=True, + out_type="featmap", + init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint), + ), + decode_head=dict( + type="SegHead", + in_channels=embed_dim, + deconv_out_channels=( + 512, + 256, + 128, + 64, + ), ## this will 2x at each step. so total is 16x. 1K output. 
+ deconv_kernel_sizes=(4, 4, 4, 4), + conv_out_channels=(64, 64), + conv_kernel_sizes=(1, 1), + num_classes=num_classes, + loss_decode=[ + dict( + type="CrossEntropyLoss", + loss_weight=1.0, + reduction="none", + class_weight=CLASS_WEIGHT, + ignore_index=255, + ), + dict( + type="DiceLoss", + loss_weight=1.0, + reduction="none", + activate=True, + use_sigmoid=False, + include_background=False, + ignore_index=255, + ), + ], + ), +) + + +##----------------------------------------------------------------- +optimizer = dict( + type="AdamW", + lr=5e-4, + betas=(0.9, 0.999), + weight_decay=0.1, + paramwise_cfg=dict( + num_layers=num_layers, + layer_decay_rate=layer_decay_rate, + ), + fused=True, +) + +scheduler = dict( + type="SequentialLR", + milestones=[warmup_iters], + schedulers=[ + dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters), + dict( + type="PolynomialLR", + total_iters=num_iters - warmup_iters, + power=1.0, + ), + ], +) + +clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0) diff --git a/sapiens/dense/scripts/albedo/train/sapiens2_0.4b/node.sh b/sapiens/dense/scripts/albedo/train/sapiens2_0.4b/node.sh new file mode 100755 index 0000000000000000000000000000000000000000..0c25c50a406da0693ad539396caa0e9f67bd9c18 --- /dev/null +++ b/sapiens/dense/scripts/albedo/train/sapiens2_0.4b/node.sh @@ -0,0 +1,58 @@ +#!/bin/bash + +cd "$(dirname "$(realpath "$0")")/../../../.." || exit + +#------------------------------------------------------------------------------- +DEVICES=0,1,2,3,4,5,6,7 +# DEVICES=0 + +#------------------------------------------------------------------------------- +TASK="albedo" +DATASET="render_people" +MODEL="sapiens2_0.4b_${TASK}_${DATASET}-1024x768" + +CONFIG_FILE="configs/${TASK}/$DATASET/${MODEL}.py" +TRAIN_BATCH_SIZE_PER_GPU=20 + +#------------------------------------------------------------------------------- +# mode='debug' +mode='multi-gpu' + +#------------------------------------------------------------------------------- +OUTPUT_DIR="Outputs/${TASK}/train/${MODEL}/node" +OUTPUT_DIR="$(echo "${OUTPUT_DIR}/$(date +"%m-%d-%Y_%H:%M:%S")")" + +#------------------------------------------------------------------------------- +OPTIONS="train_dataloader.batch_size=$TRAIN_BATCH_SIZE_PER_GPU" +OPTIONS="${OPTIONS}${LOAD_FROM:+ load_from=$LOAD_FROM}" +CMD_RESUME="${RESUME_FROM:+--resume $RESUME_FROM}" + +export TF_CPP_MIN_LOG_LEVEL=2 +PORT=$(( ((RANDOM<<15)|RANDOM) % 63001 + 2000 )) + +#------------------------------------------------------------------------------- +if [ "$mode" = "debug" ]; then + export TORCH_DISTRIBUTED_DEBUG=DETAIL + TRAIN_BATCH_SIZE_PER_GPU=1 + OPTIONS="train_dataloader.batch_size=${TRAIN_BATCH_SIZE_PER_GPU} train_dataloader.num_workers=0 train_dataloader.persistent_workers=False" + OPTIONS="${OPTIONS}${LOAD_FROM:+ load_from=$LOAD_FROM}" + + CUDA_VISIBLE_DEVICES=${DEVICES} python tools/train.py ${CONFIG_FILE} \ + --work-dir ${OUTPUT_DIR} \ + --cfg-options ${OPTIONS} \ + ${CMD_RESUME} + +elif [ "$mode" = "multi-gpu" ]; then + NUM_GPUS=$(echo $DEVICES | tr -s ',' ' ' | wc -w) + + LOG_FILE="${OUTPUT_DIR}/log.txt" + mkdir -p ${OUTPUT_DIR} + touch ${LOG_FILE} + + CUDA_VISIBLE_DEVICES=${DEVICES} PORT=${PORT} 'tools/dist_train.sh' ${CONFIG_FILE} \ + ${NUM_GPUS} \ + --work-dir ${OUTPUT_DIR} \ + --cfg-options ${OPTIONS} \ + ${CMD_RESUME} \ + | tee ${LOG_FILE} +fi diff --git a/sapiens/dense/scripts/albedo/train/sapiens2_0.8b/node.sh b/sapiens/dense/scripts/albedo/train/sapiens2_0.8b/node.sh new file mode 100755 index 
0000000000000000000000000000000000000000..07e194bcf0f5363c24ab032a6d69b76bc4cb816f --- /dev/null +++ b/sapiens/dense/scripts/albedo/train/sapiens2_0.8b/node.sh @@ -0,0 +1,59 @@ +#!/bin/bash + +cd "$(dirname "$(realpath "$0")")/../../../.." || exit + +#------------------------------------------------------------------------------- +DEVICES=0,1,2,3,4,5,6,7 +# DEVICES=0 + +#------------------------------------------------------------------------------- +TASK="albedo" +DATASET="render_people" +MODEL="sapiens2_0.8b_${TASK}_${DATASET}-1024x768" + +CONFIG_FILE="configs/${TASK}/$DATASET/${MODEL}.py" +TRAIN_BATCH_SIZE_PER_GPU=12 +LOAD_FROM='' + +#------------------------------------------------------------------------------- +# mode='debug' +mode='multi-gpu' + +#------------------------------------------------------------------------------- +OUTPUT_DIR="Outputs/${TASK}/train/${MODEL}/node" +OUTPUT_DIR="$(echo "${OUTPUT_DIR}/$(date +"%m-%d-%Y_%H:%M:%S")")" + +#------------------------------------------------------------------------------- +OPTIONS="train_dataloader.batch_size=$TRAIN_BATCH_SIZE_PER_GPU" +OPTIONS="${OPTIONS}${LOAD_FROM:+ load_from=$LOAD_FROM}" +CMD_RESUME="${RESUME_FROM:+--resume $RESUME_FROM}" + +export TF_CPP_MIN_LOG_LEVEL=2 +PORT=$(( ((RANDOM<<15)|RANDOM) % 63001 + 2000 )) + +#------------------------------------------------------------------------------- +if [ "$mode" = "debug" ]; then + export TORCH_DISTRIBUTED_DEBUG=DETAIL + TRAIN_BATCH_SIZE_PER_GPU=1 + OPTIONS="train_dataloader.batch_size=${TRAIN_BATCH_SIZE_PER_GPU} train_dataloader.num_workers=0 train_dataloader.persistent_workers=False" + OPTIONS="${OPTIONS}${LOAD_FROM:+ load_from=$LOAD_FROM}" + + CUDA_VISIBLE_DEVICES=${DEVICES} python tools/train.py ${CONFIG_FILE} \ + --work-dir ${OUTPUT_DIR} \ + --cfg-options ${OPTIONS} \ + ${CMD_RESUME} + +elif [ "$mode" = "multi-gpu" ]; then + NUM_GPUS=$(echo $DEVICES | tr -s ',' ' ' | wc -w) + + LOG_FILE="${OUTPUT_DIR}/log.txt" + mkdir -p ${OUTPUT_DIR} + touch ${LOG_FILE} + + CUDA_VISIBLE_DEVICES=${DEVICES} PORT=${PORT} 'tools/dist_train.sh' ${CONFIG_FILE} \ + ${NUM_GPUS} \ + --work-dir ${OUTPUT_DIR} \ + --cfg-options ${OPTIONS} \ + ${CMD_RESUME} \ + | tee ${LOG_FILE} +fi diff --git a/sapiens/dense/scripts/albedo/train/sapiens2_1b/node.sh b/sapiens/dense/scripts/albedo/train/sapiens2_1b/node.sh new file mode 100755 index 0000000000000000000000000000000000000000..ed6defa544fc403fb1d2bfe45e70b8945901cddb --- /dev/null +++ b/sapiens/dense/scripts/albedo/train/sapiens2_1b/node.sh @@ -0,0 +1,59 @@ +#!/bin/bash + +cd "$(dirname "$(realpath "$0")")/../../../.." 
|| exit + +#------------------------------------------------------------------------------- +DEVICES=0,1,2,3,4,5,6,7 +# DEVICES=0 + +#------------------------------------------------------------------------------- +TASK="albedo" +DATASET="render_people" +MODEL="sapiens2_1b_${TASK}_${DATASET}-1024x768" + +CONFIG_FILE="configs/${TASK}/$DATASET/${MODEL}.py" +TRAIN_BATCH_SIZE_PER_GPU=7 + + +#------------------------------------------------------------------------------- +# mode='debug' +mode='multi-gpu' + +#------------------------------------------------------------------------------- +OUTPUT_DIR="Outputs/${TASK}/train/${MODEL}/node" +OUTPUT_DIR="$(echo "${OUTPUT_DIR}/$(date +"%m-%d-%Y_%H:%M:%S")")" + +#------------------------------------------------------------------------------- +OPTIONS="train_dataloader.batch_size=$TRAIN_BATCH_SIZE_PER_GPU" +OPTIONS="${OPTIONS}${LOAD_FROM:+ load_from=$LOAD_FROM}" +CMD_RESUME="${RESUME_FROM:+--resume $RESUME_FROM}" + +export TF_CPP_MIN_LOG_LEVEL=2 +PORT=$(( ((RANDOM<<15)|RANDOM) % 63001 + 2000 )) + +#------------------------------------------------------------------------------- +if [ "$mode" = "debug" ]; then + export TORCH_DISTRIBUTED_DEBUG=DETAIL + TRAIN_BATCH_SIZE_PER_GPU=1 + OPTIONS="train_dataloader.batch_size=${TRAIN_BATCH_SIZE_PER_GPU} train_dataloader.num_workers=0 train_dataloader.persistent_workers=False" + OPTIONS="${OPTIONS}${LOAD_FROM:+ load_from=$LOAD_FROM}" + + CUDA_VISIBLE_DEVICES=${DEVICES} python tools/train.py ${CONFIG_FILE} \ + --work-dir ${OUTPUT_DIR} \ + --cfg-options ${OPTIONS} \ + ${CMD_RESUME} + +elif [ "$mode" = "multi-gpu" ]; then + NUM_GPUS=$(echo $DEVICES | tr -s ',' ' ' | wc -w) + + LOG_FILE="${OUTPUT_DIR}/log.txt" + mkdir -p ${OUTPUT_DIR} + touch ${LOG_FILE} + + CUDA_VISIBLE_DEVICES=${DEVICES} PORT=${PORT} 'tools/dist_train.sh' ${CONFIG_FILE} \ + ${NUM_GPUS} \ + --work-dir ${OUTPUT_DIR} \ + --cfg-options ${OPTIONS} \ + ${CMD_RESUME} \ + | tee ${LOG_FILE} +fi diff --git a/sapiens/dense/scripts/albedo/train/sapiens2_5b/node.sh b/sapiens/dense/scripts/albedo/train/sapiens2_5b/node.sh new file mode 100755 index 0000000000000000000000000000000000000000..75f23b7821c45afbda25b39a90841afb3a77f312 --- /dev/null +++ b/sapiens/dense/scripts/albedo/train/sapiens2_5b/node.sh @@ -0,0 +1,60 @@ +#!/bin/bash + +cd "$(dirname "$(realpath "$0")")/../../../.." 
|| exit + +#------------------------------------------------------------------------------- +DEVICES=0,1,2,3,4,5,6,7 +# DEVICES=0 + +#------------------------------------------------------------------------------- +TASK="albedo" +DATASET="render_people" +MODEL="sapiens2_5b_${TASK}_${DATASET}-1024x768" + +CONFIG_FILE="configs/${TASK}/$DATASET/${MODEL}.py" +TRAIN_BATCH_SIZE_PER_GPU=3 + +# LOAD_FROM="" + +#------------------------------------------------------------------------------- +# mode='debug' +mode='multi-gpu' + +#------------------------------------------------------------------------------- +OUTPUT_DIR="Outputs/${TASK}/train/${MODEL}/node" +OUTPUT_DIR="$(echo "${OUTPUT_DIR}/$(date +"%m-%d-%Y_%H:%M:%S")")" + +#------------------------------------------------------------------------------- +OPTIONS="train_dataloader.batch_size=$TRAIN_BATCH_SIZE_PER_GPU" +OPTIONS="${OPTIONS}${LOAD_FROM:+ load_from=$LOAD_FROM}" +CMD_RESUME="${RESUME_FROM:+--resume $RESUME_FROM}" + +export TF_CPP_MIN_LOG_LEVEL=2 +PORT=$(( ((RANDOM<<15)|RANDOM) % 63001 + 2000 )) + +#------------------------------------------------------------------------------- +if [ "$mode" = "debug" ]; then + export TORCH_DISTRIBUTED_DEBUG=DETAIL + TRAIN_BATCH_SIZE_PER_GPU=1 + OPTIONS="train_dataloader.batch_size=${TRAIN_BATCH_SIZE_PER_GPU} train_dataloader.num_workers=0 train_dataloader.persistent_workers=False" + OPTIONS="${OPTIONS}${LOAD_FROM:+ load_from=$LOAD_FROM}" + + CUDA_VISIBLE_DEVICES=${DEVICES} python tools/train.py ${CONFIG_FILE} \ + --work-dir ${OUTPUT_DIR} \ + --cfg-options ${OPTIONS} \ + ${CMD_RESUME} + +elif [ "$mode" = "multi-gpu" ]; then + NUM_GPUS=$(echo $DEVICES | tr -s ',' ' ' | wc -w) + + LOG_FILE="${OUTPUT_DIR}/log.txt" + mkdir -p ${OUTPUT_DIR} + touch ${LOG_FILE} + + CUDA_VISIBLE_DEVICES=${DEVICES} PORT=${PORT} 'tools/dist_train.sh' ${CONFIG_FILE} \ + ${NUM_GPUS} \ + --work-dir ${OUTPUT_DIR} \ + --cfg-options ${OPTIONS} \ + ${CMD_RESUME} \ + | tee ${LOG_FILE} +fi diff --git a/sapiens/dense/scripts/demo/albedo.sh b/sapiens/dense/scripts/demo/albedo.sh new file mode 100755 index 0000000000000000000000000000000000000000..10554f9a8e5000d5db4c74832660069b9fc88bce --- /dev/null +++ b/sapiens/dense/scripts/demo/albedo.sh @@ -0,0 +1,80 @@ +#!/bin/bash +# Run albedo (intrinsic) estimation on a directory of images. + +cd "$(dirname "$(realpath "$0")")/../.." 
|| exit +SAPIENS_CHECKPOINT_ROOT="${SAPIENS_CHECKPOINT_ROOT:-${HOME}/sapiens2_host}" + +#----------------------------set your input and output directories------------------------- +INPUT='./demo/data/itw_videos/reel1' +OUTPUT="${HOME}/Desktop/sapiens2/albedo/Outputs/vis/itw_videos/reel1" + +#--------------------------MODEL CARD (uncomment one)--------------------------------------- +# MODEL_NAME='sapiens2_0.4b'; CHECKPOINT="${SAPIENS_CHECKPOINT_ROOT}/albedo/sapiens2_0.4b_albedo.safetensors" +# MODEL_NAME='sapiens2_0.8b'; CHECKPOINT="${SAPIENS_CHECKPOINT_ROOT}/albedo/sapiens2_0.8b_albedo.safetensors" +MODEL_NAME='sapiens2_1b'; CHECKPOINT="${SAPIENS_CHECKPOINT_ROOT}/albedo/sapiens2_1b_albedo.safetensors" +# MODEL_NAME='sapiens2_5b'; CHECKPOINT="${SAPIENS_CHECKPOINT_ROOT}/albedo/sapiens2_5b_albedo.safetensors" + +DATASET='render_people' +MODEL="${MODEL_NAME}_albedo_${DATASET}-1024x768" +CONFIG_FILE="configs/albedo/${DATASET}/${MODEL}.py" +OUTPUT="${OUTPUT}/${MODEL_NAME}" + +##-------------------------------------inference-------------------------------------------- +RUN_FILE='tools/vis/vis_albedo.py' + +JOBS_PER_GPU=3; GPU_IDS=(0 1 2 3 4 5 6 7) +# JOBS_PER_GPU=1; GPU_IDS=(0) +TOTAL_JOBS=$((JOBS_PER_GPU * ${#GPU_IDS[@]})) + +IMAGE_LIST="${INPUT}/image_list.txt" +find "${INPUT}" -type f \( -iname \*.jpg -o -iname \*.jpeg -o -iname \*.png \) | sort > "${IMAGE_LIST}" + +if [ ! -s "${IMAGE_LIST}" ]; then + echo "No images found at ${INPUT}" + exit 1 +fi + +NUM_IMAGES=$(wc -l < "${IMAGE_LIST}") +IMAGES_PER_FILE=$((NUM_IMAGES / TOTAL_JOBS)) +EXTRA_IMAGES=$((NUM_IMAGES % TOTAL_JOBS)) + +export TF_CPP_MIN_LOG_LEVEL=2 +echo "Distributing ${NUM_IMAGES} image paths into ${TOTAL_JOBS} jobs." + +current_line=1 +for ((i=0; i "${TEXT_FILE}" + current_line=$((current_line + images_for_this_job)) + else + touch "${TEXT_FILE}" + fi +done + +for ((i=0; i "${IMAGE_LIST}" + +if [ ! -s "${IMAGE_LIST}" ]; then + echo "No images found at ${INPUT}" + exit 1 +fi + +NUM_IMAGES=$(wc -l < "${IMAGE_LIST}") +IMAGES_PER_FILE=$((NUM_IMAGES / TOTAL_JOBS)) +EXTRA_IMAGES=$((NUM_IMAGES % TOTAL_JOBS)) + +export TF_CPP_MIN_LOG_LEVEL=2 +echo "Distributing ${NUM_IMAGES} image paths into ${TOTAL_JOBS} jobs." + +current_line=1 +for ((i=0; i "${TEXT_FILE}" + current_line=$((current_line + images_for_this_job)) + else + touch "${TEXT_FILE}" + fi +done + +for ((i=0; i "${IMAGE_LIST}" + +if [ ! -s "${IMAGE_LIST}" ]; then + echo "No images found at ${INPUT}" + exit 1 +fi + +NUM_IMAGES=$(wc -l < "${IMAGE_LIST}") +IMAGES_PER_FILE=$((NUM_IMAGES / TOTAL_JOBS)) +EXTRA_IMAGES=$((NUM_IMAGES % TOTAL_JOBS)) + +export TF_CPP_MIN_LOG_LEVEL=2 +echo "Distributing ${NUM_IMAGES} image paths into ${TOTAL_JOBS} jobs." + +current_line=1 +for ((i=0; i "${TEXT_FILE}" + current_line=$((current_line + images_for_this_job)) + else + touch "${TEXT_FILE}" + fi +done + +for ((i=0; i "${IMAGE_LIST}" + +if [ ! -s "${IMAGE_LIST}" ]; then + echo "No images found at ${INPUT}" + exit 1 +fi + +NUM_IMAGES=$(wc -l < "${IMAGE_LIST}") +IMAGES_PER_FILE=$((NUM_IMAGES / TOTAL_JOBS)) +EXTRA_IMAGES=$((NUM_IMAGES % TOTAL_JOBS)) + +export TF_CPP_MIN_LOG_LEVEL=2 +echo "Distributing ${NUM_IMAGES} image paths into ${TOTAL_JOBS} jobs." 
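# NOTE: the per-GPU splitting loops in these demo scripts appear garbled in this
# diff (the "<" loop bounds were stripped during extraction). A plausible sketch
# of the intended logic follows; variable names match the scripts, but the chunk
# file naming is an assumption, not the original code.
current_line=1
for ((i=0; i<TOTAL_JOBS; i++)); do
  TEXT_FILE="${INPUT}/image_paths_$((i+1)).txt"   # assumed chunk-file name
  images_for_this_job=${IMAGES_PER_FILE}
  if [ "${i}" -lt "${EXTRA_IMAGES}" ]; then
    images_for_this_job=$((images_for_this_job + 1))   # spread the remainder over the first jobs
  fi
  if [ "${images_for_this_job}" -gt 0 ]; then
    sed -n "${current_line},$((current_line + images_for_this_job - 1))p" "${IMAGE_LIST}" > "${TEXT_FILE}"
    current_line=$((current_line + images_for_this_job))
  else
    touch "${TEXT_FILE}"
  fi
done
# A second loop then typically launches one inference process per chunk, pinned
# to GPU $((i / JOBS_PER_GPU)), and waits for all background jobs to finish.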
+ +current_line=1 +for ((i=0; i "${TEXT_FILE}" + current_line=$((current_line + images_for_this_job)) + else + touch "${TEXT_FILE}" + fi +done + +for ((i=0; i None: + self.num_samples = num_samples + super().__init__(**kwargs) + return + + def load_data_list(self) -> List[dict]: + data_list = [] + + self.rgb_dir = os.path.join(self.data_root, "rgb") + self.albedo_dir = os.path.join(self.data_root, "albedo") + self.mask_dir = os.path.join(self.data_root, "mask") + + print("\033[92mLoading {}!\033[0m".format(self.__class__.__name__)) + + # Create a set of common file names from all three directories + rgb_files = {x for x in os.listdir(self.rgb_dir) if x.endswith(".png")} + mask_files = {x for x in os.listdir(self.mask_dir) if x.endswith(".png")} + albedo_files = {x for x in os.listdir(self.albedo_dir) if x.endswith(".png")} + + common_names = sorted(rgb_files & mask_files & albedo_files) + + # Create data list using the common file names + data_list = [ + { + "rgb_path": os.path.join(self.rgb_dir, name), + "mask_path": os.path.join(self.mask_dir, name), + "albedo_path": os.path.join(self.albedo_dir, name), + } + for name in common_names + ] + + if self.num_samples is not None: + data_list = data_list[: self.num_samples] + + print( + "\033[92mDone! {}. Loaded total samples: {}. Test mode: {}\033[0m".format( + self.__class__.__name__, len(data_list), self.test_mode + ) + ) + + return data_list + + def get_data_info(self, idx): + data_info = copy.deepcopy(self.data_list[idx]) + try: + with suppress_stderr(): + img = cv2.imread(data_info["rgb_path"]) ## bgr image is default + albedo = cv2.imread(data_info["albedo_path"]) + albedo = cv2.cvtColor(albedo, cv2.COLOR_BGR2RGB) ## RGB image + mask = cv2.imread(data_info["mask_path"]) + + except Exception: + return None + + mask = mask[:, :, 0] ## H x W + + if mask is None or mask.sum() < 16: ## min pixels is 16 + return None + + # Normalize albedo to the range 0-1 + albedo = albedo.astype(float) / 255.0 + + # Check if 98% of the pixels are the same color + albedo_mask = albedo[mask > 0] + std_per_channel = np.std(albedo_mask, axis=0) + if np.max(std_per_channel) < 0.02: # Threshold can be adjusted + return None + + rows = np.any(mask, axis=1) + cols = np.any(mask, axis=0) + + # Find the bounding box's bounds + y1, y2 = np.where(rows)[0][[0, -1]] + x1, x2 = np.where(cols)[0][[0, -1]] + + bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4) + + data_info = { + "img": img, + "img_id": os.path.basename(data_info["rgb_path"]), + "img_path": data_info["rgb_path"], + "gt_albedo": albedo, ## rgb format + "mask": mask, + "id": idx, + "bbox": bbox, + "bbox_score": np.ones(1, dtype=np.float32), + } + + return data_info diff --git a/sapiens/dense/src/datasets/albedo/albedo_render_people_dataset.py b/sapiens/dense/src/datasets/albedo/albedo_render_people_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1b2edcc5f989083484fcbc66ef71a3153faf44cd --- /dev/null +++ b/sapiens/dense/src/datasets/albedo/albedo_render_people_dataset.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +from sapiens.registry import DATASETS + +from .albedo_base_dataset import AlbedoBaseDataset + + +##----------------------------------------------------------------------- +@DATASETS.register_module() +class AlbedoRenderPeopleDataset(AlbedoBaseDataset): + def __init__(self, **kwargs): + super().__init__(**kwargs) + return diff --git a/sapiens/dense/src/datasets/normal/__init__.py b/sapiens/dense/src/datasets/normal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f54270e9a40117af486d1187a64bb42a657cc906 --- /dev/null +++ b/sapiens/dense/src/datasets/normal/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .normal_base_dataset import NormalBaseDataset +from .normal_hi4d_dataset import NormalHi4dDataset +from .normal_metasim_dataset import NormalMetaSimDataset +from .normal_render_people_body_dataset import NormalRenderPeopleBodyDataset +from .normal_render_people_multihuman_dataset import NormalRenderPeopleMultihumanDataset +from .normal_thuman_dataset import NormalTHumanDataset + +__all__ = [ + "NormalBaseDataset", + "NormalHi4dDataset", + "NormalMetaSimDataset", + "NormalRenderPeopleBodyDataset", + "NormalRenderPeopleMultihumanDataset", + "NormalTHumanDataset", +] diff --git a/sapiens/dense/src/datasets/normal/normal_base_dataset.py b/sapiens/dense/src/datasets/normal/normal_base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e3839bc13b97a993272cb5cc87e0cefde9f0a168 --- /dev/null +++ b/sapiens/dense/src/datasets/normal/normal_base_dataset.py @@ -0,0 +1,185 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import contextlib +import copy +import os +import sys +from typing import List + +import cv2 +import numpy as np +from sapiens.engine.datasets import BaseDataset +from sapiens.registry import DATASETS + + +@contextlib.contextmanager +def suppress_stderr(): + devnull_fd = os.open(os.devnull, os.O_WRONLY) + stderr_fd = sys.stderr.fileno() + sys.stderr.flush() + saved_stderr_fd = os.dup(stderr_fd) + + os.dup2(devnull_fd, stderr_fd) + os.close(devnull_fd) + + try: + yield + finally: + os.dup2(saved_stderr_fd, stderr_fd) + os.close(saved_stderr_fd) + + +##----------------------------------------------------------------------- +@DATASETS.register_module() +class NormalBaseDataset(BaseDataset): + def __init__( + self, + seg_data_root: str = None, + num_samples: int = None, + normal_extension: str = ".npy", + **kwargs, + ) -> None: + self.seg_data_root = seg_data_root + self.num_samples = num_samples + self.normal_extension = normal_extension + assert self.normal_extension in [".npy", ".npz"] + super().__init__(**kwargs) + + return + + def load_data_list(self) -> List[dict]: + data_list = [] + + self.rgb_dir = os.path.join(self.data_root, "rgb") + self.normal_dir = os.path.join(self.data_root, "normal") + self.mask_dir = os.path.join(self.data_root, "mask") + + print("\033[92mLoading {}!\033[0m".format(self.__class__.__name__)) + + # Create a set of common file names from all three directories + rgb_files = {x for x in os.listdir(self.rgb_dir) if x.endswith(".png")} + mask_files = {x for x in os.listdir(self.mask_dir) if x.endswith(".png")} + normal_files = { + x.replace(self.normal_extension, ".png") + for x in os.listdir(self.normal_dir) + if x.endswith(self.normal_extension) + } + + # Find the intersection of file names between images, masks, and normals + common_names = sorted(rgb_files & mask_files & normal_files) + + # Create data list using the common file names + data_list = [ + { + "rgb_path": os.path.join(self.rgb_dir, name), + "mask_path": os.path.join(self.mask_dir, name), + "normal_path": os.path.join( + self.normal_dir, name.replace(".png", self.normal_extension) + ), + } + for name in common_names + ] + + if self.num_samples is not None: + data_list = data_list[: self.num_samples] + + print( + "\033[92mDone! {}. Loaded total samples: {}. 
Test mode: {}\033[0m".format( + self.__class__.__name__, len(data_list), self.test_mode + ) + ) + + return data_list + + def get_data_info(self, idx): + data_info = copy.deepcopy(self.data_list[idx]) + try: + with suppress_stderr(): + img = cv2.imread(data_info["rgb_path"]) ## bgr image is default + normal = ( + np.load(data_info["normal_path"]) + if self.normal_extension == ".npy" + else np.load(data_info["normal_path"])["normal"] + ) + mask = cv2.imread(data_info["mask_path"]) + except Exception as e: + return None + + mask = mask[:, :, 0] ## H x W + + if self.seg_data_root is not None: + mask = self.ignore_face_pixels(mask, data_info["rgb_path"]) + + if mask is None or mask.sum() < 16: ## min pixels is 16 + return None + + ## remove any nan normals from valid pixels + nan_normals = np.isnan(normal).any(axis=2) + if np.any(nan_normals): + mask[nan_normals] = 0 + + ## check if the normal are normalized + normal_valid = normal[mask > 0] + norm_normal_valid = np.linalg.norm(normal_valid, axis=1) + tolerance = 1e-6 + is_normalized = np.all(norm_normal_valid > 1 - tolerance) & np.all( + norm_normal_valid < 1 + tolerance + ) + + if not is_normalized: + norms = np.linalg.norm(normal, axis=2, keepdims=True) + norms = np.maximum(norms, 1e-6) # Adding epsilon to avoid division by zero + normal = normal / norms + + rows = np.any(mask, axis=1) + cols = np.any(mask, axis=0) + + # Find the bounding box's bounds + y1, y2 = np.where(rows)[0][[0, -1]] + x1, x2 = np.where(cols)[0][[0, -1]] + + bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4) + + data_info = { + "img": img, + "img_id": os.path.basename(data_info["rgb_path"]), + "img_path": data_info["rgb_path"], + "gt_normal": normal, + "mask": mask, + "id": idx, + "bbox": bbox, + "bbox_score": np.ones(1, dtype=np.float32), + } + + return data_info + + ## filter our face/hair pixels for supervision + def ignore_face_pixels(self, mask, img_path): + seg_name = os.path.basename(img_path).replace(".png", "_seg.npy") + seg_path = os.path.join(self.seg_data_root, seg_name) + + if not os.path.exists(seg_path): + return None + + try: + seg = np.load(seg_path) ## part segmentation. 28 classes + except Exception as e: + return None + + if seg.shape[0] != mask.shape[0] or seg.shape[1] != mask.shape[1]: + return None + + ## modify mask such that only body is considered. remove face_neck + hair. label 2 and 3 + mask[seg == 2] = 0 + mask[seg == 3] = 0 + mask[seg == 23] = 0 + mask[seg == 24] = 0 + mask[seg == 25] = 0 + mask[seg == 26] = 0 + mask[seg == 27] = 0 + + return mask diff --git a/sapiens/dense/src/datasets/normal/normal_hi4d_dataset.py b/sapiens/dense/src/datasets/normal/normal_hi4d_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..47167361cb8dbc95e6274fc208030ca68e199adf --- /dev/null +++ b/sapiens/dense/src/datasets/normal/normal_hi4d_dataset.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +from sapiens.registry import DATASETS + +from .normal_base_dataset import NormalBaseDataset + + +##----------------------------------------------------------------------- +@DATASETS.register_module() +class NormalHi4dDataset(NormalBaseDataset): + def __init__(self, **kwargs): + super().__init__(**kwargs) + return diff --git a/sapiens/dense/src/datasets/normal/normal_metasim_dataset.py b/sapiens/dense/src/datasets/normal/normal_metasim_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..8e885333f4a8a1c6e76ae4acf187365cd9f29746 --- /dev/null +++ b/sapiens/dense/src/datasets/normal/normal_metasim_dataset.py @@ -0,0 +1,147 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import io +import json +import os +import random +from typing import Any, Iterator, List + +import cv2 +import numpy as np +import torch +from iopath.common.file_io import PathManager +from sapiens.registry import DATASETS + +from .normal_base_dataset import NormalBaseDataset, suppress_stderr + +try: + from airstore.client.airstore_tabular import AIRStorePathHandler +except: + pass + + +##----------------------------------------------------------------------- +@DATASETS.register_module() +class NormalMetaSimDataset(NormalBaseDataset): + def __init__(self, airstore_template=None, json_path=None, **kwargs) -> None: + self.airstore_template = airstore_template + self.json_path = json_path + self._cached_iterator = None + self.global_rank = int(os.environ.get("RANK", 0)) + self.world_size = int(os.environ.get("WORLD_SIZE", 1)) + super().__init__(**kwargs) + return + + def load_data_list(self) -> List[dict]: + data_list = [] + + self.path_manager = PathManager() + self.path_manager.register_handler(AIRStorePathHandler()) + + print("\033[92mLoading {}!\033[0m".format(self.__class__.__name__)) + + ## load json file + with open(self.json_path, "r") as f: + data_list = json.load(f) + + if self.num_samples is not None: + data_list = data_list[: self.num_samples] + + print( + "\033[92mDone! {}. 
Loaded total samples: {}\033[0m".format( + self.__class__.__name__, len(data_list) + ) + ) + + return data_list + + def _open_iterator(self) -> Iterator[Any]: + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + num_workers = 1 + worker_id = 0 + else: + num_workers = worker_info.num_workers + worker_id = worker_info.id + + airstore_rank = self.global_rank * num_workers + worker_id + random_seed = ( + random.randint(0, 100000) + random.randint(0, 100000) * airstore_rank + ) + return self.path_manager.opent( + self.airstore_template, seed=random_seed, enable_shuffle=True + ) + + def get_data_info(self, idx): + data_info = copy.deepcopy(self.data_list[idx]) + + if self._cached_iterator is None: + self._cached_iterator = self._open_iterator() + + while True: + try: + with suppress_stderr(): + row = next(self._cached_iterator) + img_buf = np.frombuffer(row["image"], dtype=np.uint8) + img = cv2.imdecode(img_buf, cv2.IMREAD_COLOR) ## this in BGR format + + mask_buf = np.frombuffer(row["mask"], dtype=np.uint8) + mask = cv2.imdecode( + mask_buf, cv2.IMREAD_GRAYSCALE + ) ## this in BGR format + + normal = np.load(io.BytesIO(row["normal"]))["normal"] + break + + except Exception as e: + print(f"Error loading data: {e}, {data_info['rgb_path']}") + return None + + if mask.sum() < 16: + return None + + ## remove any nan normals from valid pixels + nan_normals = np.isnan(normal).any(axis=2) + if np.any(nan_normals): + mask[nan_normals] = 0 + + ## check if the normal are normalized + normal_valid = normal[mask > 0] + norm_normal_valid = np.linalg.norm(normal_valid, axis=1) + + tolerance = 1e-6 + is_normalized = np.all(norm_normal_valid > 1 - tolerance) & np.all( + norm_normal_valid < 1 + tolerance + ) + + if not is_normalized: + norms = np.linalg.norm(normal, axis=2, keepdims=True) + norms = np.maximum(norms, 1e-6) # Avoid division by zero + normal = normal / norms + + rows = np.any(mask, axis=1) + cols = np.any(mask, axis=0) + + # Find the bounding box's bounds + y1, y2 = np.where(rows)[0][[0, -1]] + x1, x2 = np.where(cols)[0][[0, -1]] + + bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4) + + data_info = { + "img": img, + "img_id": os.path.basename(data_info["rgb_path"]), + "img_path": data_info["rgb_path"], + "gt_normal": normal, + "mask": mask, + "id": idx, + "bbox": bbox, + "bbox_score": np.ones(1, dtype=np.float32), + } + + return data_info diff --git a/sapiens/dense/src/datasets/normal/normal_render_people_body_dataset.py b/sapiens/dense/src/datasets/normal/normal_render_people_body_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7e9e30417c729ea7a362c2e76b76601cc548cbfb --- /dev/null +++ b/sapiens/dense/src/datasets/normal/normal_render_people_body_dataset.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
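# Illustrative sketch of the stream seeding used by NormalMetaSimDataset._open_iterator
# above: each (DDP rank, DataLoader worker) pair derives a distinct shuffle seed from its
# global worker index, so the shuffled airstore streams differ across workers. The helper
# name is hypothetical.
import os
import random

import torch

def airstore_rank_and_seed():
    global_rank = int(os.environ.get("RANK", 0))
    info = torch.utils.data.get_worker_info()        # None when called outside a worker
    num_workers = info.num_workers if info is not None else 1
    worker_id = info.id if info is not None else 0
    rank = global_rank * num_workers + worker_id      # unique per (process, worker)
    seed = random.randint(0, 100000) + random.randint(0, 100000) * rank
    return rank, seed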
+ +import os + +import cv2 +import numpy as np +from sapiens.registry import DATASETS + +from .normal_base_dataset import NormalBaseDataset + + +##----------------------------------------------------------------------- +@DATASETS.register_module() +class NormalRenderPeopleBodyDataset(NormalBaseDataset): + def ignore_face_pixels(self, mask, img_path): + seg_name = os.path.basename(img_path).replace(".png", "_seg.npy") + seg_path = os.path.join(self.seg_data_root, seg_name) + + if not os.path.exists(seg_path): + return None + + seg = np.load(seg_path) ## part segmentation. 28 classes + + if ( + mask.shape[0] != 4096 + or mask.shape[1] != 3072 + or seg.shape[0] != 1024 + or seg.shape[1] != 768 + ): + return None + + ## nearest neighbor upsample seg + seg = cv2.resize( + seg, (mask.shape[1], mask.shape[0]), interpolation=cv2.INTER_NEAREST + ) + + ## modify mask such that only body is considered. remove face_neck + hair. label 2 and 3 + mask[seg == 2] = 0 + mask[seg == 3] = 0 + mask[seg == 23] = 0 + mask[seg == 24] = 0 + mask[seg == 25] = 0 + mask[seg == 26] = 0 + mask[seg == 27] = 0 + + return mask diff --git a/sapiens/dense/src/datasets/normal/normal_render_people_multihuman_dataset.py b/sapiens/dense/src/datasets/normal/normal_render_people_multihuman_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4b1c2d79467a43f936aec93ad6b7d9ef4cf9598c --- /dev/null +++ b/sapiens/dense/src/datasets/normal/normal_render_people_multihuman_dataset.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from sapiens.registry import DATASETS + +from .normal_base_dataset import NormalBaseDataset + + +##----------------------------------------------------------------------- +@DATASETS.register_module() +class NormalRenderPeopleMultihumanDataset(NormalBaseDataset): + def __init__(self, **kwargs): + super().__init__(**kwargs) + return diff --git a/sapiens/dense/src/datasets/normal/normal_thuman_dataset.py b/sapiens/dense/src/datasets/normal/normal_thuman_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7a963de600d25251f8f6712b2f556af0e440d794 --- /dev/null +++ b/sapiens/dense/src/datasets/normal/normal_thuman_dataset.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from sapiens.registry import DATASETS + +from .normal_base_dataset import NormalBaseDataset + + +##----------------------------------------------------------------------- +@DATASETS.register_module() +class NormalTHumanDataset(NormalBaseDataset): + def __init__(self, **kwargs): + super().__init__(**kwargs) + return diff --git a/sapiens/dense/src/datasets/pointmap/__init__.py b/sapiens/dense/src/datasets/pointmap/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..93537aa93c2c5ce6e370627ad0826dfdc5234d15 --- /dev/null +++ b/sapiens/dense/src/datasets/pointmap/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
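# Illustrative note on the override above: the RenderPeople part segmentations are stored
# at 1024x768 while the alpha masks are 4096x3072, so the label map must be upsampled with
# nearest-neighbour interpolation (blending integer class ids would produce invalid labels).
# Minimal sketch, helper name hypothetical:
import cv2
import numpy as np

def upsample_label_map(seg: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
    return cv2.resize(seg, (target_w, target_h), interpolation=cv2.INTER_NEAREST)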
+ +from .pointmap_base_dataset import PointmapBaseDataset +from .pointmap_render_people_dataset import PointmapRenderPeopleDataset + + +__all__ = ["PointmapBaseDataset", "PointmapRenderPeopleDataset"] diff --git a/sapiens/dense/src/datasets/pointmap/pointmap_base_dataset.py b/sapiens/dense/src/datasets/pointmap/pointmap_base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..532ae792dd4ef7c9271c03de1044b161fd3daaf4 --- /dev/null +++ b/sapiens/dense/src/datasets/pointmap/pointmap_base_dataset.py @@ -0,0 +1,153 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import copy +import os +import sys +from typing import List + +import cv2 +import numpy as np +from sapiens.engine.datasets import BaseDataset +from sapiens.registry import DATASETS + + +@contextlib.contextmanager +def suppress_stderr(): + devnull_fd = os.open(os.devnull, os.O_WRONLY) + stderr_fd = sys.stderr.fileno() + sys.stderr.flush() + saved_stderr_fd = os.dup(stderr_fd) + + os.dup2(devnull_fd, stderr_fd) + os.close(devnull_fd) + + try: + yield + finally: + os.dup2(saved_stderr_fd, stderr_fd) + os.close(saved_stderr_fd) + + +##----------------------------------------------------------------------- +@DATASETS.register_module() +class PointmapBaseDataset(BaseDataset): + def __init__(self, num_samples=None, **kwargs) -> None: + self.num_samples = num_samples + super().__init__(**kwargs) + return + + def load_data_list(self) -> List[dict]: + data_list = [] + + self.rgb_dir = os.path.join(self.data_root, "rgb") + self.mask_dir = os.path.join(self.data_root, "mask") + self.depth_dir = os.path.join(self.data_root, "depth") + self.K_dir = os.path.join(self.data_root, "camera_intrinsics") + self.M_dir = os.path.join( + self.data_root, "camera_extrinsics" + ) ## cv camera extrinsics + + print(f"\033[92mLoading {self.__class__.__name__}!\033[0m") + + # Create a set of common file names from all three directories + rgb_files = {x for x in os.listdir(self.rgb_dir) if x.endswith(".png")} + mask_files = {x for x in os.listdir(self.mask_dir) if x.endswith(".png")} + depth_files = { + x.replace(".npy", ".png") + for x in os.listdir(self.depth_dir) + if x.endswith(".npy") + } + K_files = { + x.replace(".txt", ".png") + for x in os.listdir(self.K_dir) + if x.endswith(".txt") + } + M_files = { + x.replace(".txt", ".png") + for x in os.listdir(self.M_dir) + if x.endswith(".txt") + } + + # Find the intersection of file names between images, masks, and normals + common_names = rgb_files & mask_files & depth_files & K_files & M_files + + # Create data list using the common file names + data_list = [ + { + "rgb_path": os.path.join(self.rgb_dir, name), + "mask_path": os.path.join(self.mask_dir, name), + "depth_path": os.path.join( + self.depth_dir, name.replace(".png", ".npy") + ), + "K_path": os.path.join(self.K_dir, name.replace(".png", ".txt")), + "M_path": os.path.join(self.M_dir, name.replace(".png", ".txt")), + } + for name in sorted(common_names) + ] + + if self.num_samples is not None: + data_list = data_list[: self.num_samples] + + print( + "\033[92mDone! {}. Loaded total samples: {}. 
Test mode: {}\033[0m".format( + self.__class__.__name__, len(data_list), self.test_mode + ) + ) + return data_list + + def get_data_info(self, idx): + data_info = copy.deepcopy(self.data_list[idx]) + try: + with suppress_stderr(): + image = cv2.imread(data_info["rgb_path"]) ## bgr image is default + mask = cv2.imread(data_info["mask_path"]) + depth = np.load(data_info["depth_path"]) ## H x W, ## is not in 0 to 1 + K = np.loadtxt(data_info["K_path"]) ## intrinsics, 3 x 3 + M = np.loadtxt(data_info["M_path"]) ## extrinsics, 4 x 4 + + except Exception as e: + return None + + mask = mask[:, :, 0] ## + + if image is None or mask is None or depth is None: + return None + + ## remove any nan depth from valid pixels + nan_depth = np.isnan(depth) + if np.any(nan_depth): + mask[nan_depth] = 0 + + if mask.sum() < 10: + return None + + rows = np.any(mask, axis=1) + cols = np.any(mask, axis=0) + + # Find the bounding box's bounds + y1, y2 = np.where(rows)[0][[0, -1]] + x1, x2 = np.where(cols)[0][[0, -1]] + + bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4) + + data_info = { + "img": image, + "id": idx, + "orig_img_height": image.shape[0], + "orig_img_width": image.shape[1], + "img_id": os.path.basename(data_info["rgb_path"]), + "img_path": data_info["rgb_path"], + "gt_depth": depth, + "K": K, + "M": M, + "mask": mask, + "bbox": bbox, + "bbox_score": np.ones(1, dtype=np.float32), + } + + return data_info diff --git a/sapiens/dense/src/datasets/pointmap/pointmap_render_people_dataset.py b/sapiens/dense/src/datasets/pointmap/pointmap_render_people_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..30c76fefa2999beed15d806587b14b3ebeeee6a5 --- /dev/null +++ b/sapiens/dense/src/datasets/pointmap/pointmap_render_people_dataset.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from sapiens.registry import DATASETS + +from .pointmap_base_dataset import PointmapBaseDataset + + +##----------------------------------------------------------------------- +@DATASETS.register_module() +class PointmapRenderPeopleDataset(PointmapBaseDataset): + def __init__(self, **kwargs): + super().__init__(**kwargs) + return diff --git a/sapiens/dense/src/datasets/seg/__init__.py b/sapiens/dense/src/datasets/seg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..248c95735ba3f56bde8a639263cda5d4b14b70b8 --- /dev/null +++ b/sapiens/dense/src/datasets/seg/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
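# Sketch of how a per-pixel pointmap can be recovered from the gt_depth and 3x3 intrinsics
# K returned by PointmapBaseDataset.get_data_info above (standard pinhole unprojection).
# Whether the training target is expressed in the camera or world frame (via M) is not
# visible in this file, so the camera-frame output below is an illustrative assumption.
import numpy as np

def depth_to_pointmap(depth: np.ndarray, K: np.ndarray) -> np.ndarray:
    """depth: H x W metric depth, K: 3 x 3 intrinsics -> H x W x 3 camera-frame points."""
    h, w = depth.shape
    u, v = np.meshgrid(np.arange(w), np.arange(h))
    x = (u - K[0, 2]) / K[0, 0] * depth
    y = (v - K[1, 2]) / K[1, 1] * depth
    return np.stack([x, y, depth], axis=-1)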
+ +from .seg_base_dataset import SegBaseDataset +from .seg_dome_dataset import SegDomeClass29Dataset +from .seg_internal_dataset import SegInternalClass29Dataset +from .seg_shutterstock_dataset import SegShutterstockClass29Dataset +from .seg_utils import DOME_CLASSES_29 + +__all__ = [ + "SegBaseDataset", + "SegDomeClass29Dataset", + "SegShutterstockClass29Dataset", + "SegInternalClass29Dataset", + "DOME_CLASSES_29", +] diff --git a/sapiens/dense/src/datasets/seg/seg_base_dataset.py b/sapiens/dense/src/datasets/seg/seg_base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..185d3ebadd9641eb0ce2644dfd5d18fe37587da6 --- /dev/null +++ b/sapiens/dense/src/datasets/seg/seg_base_dataset.py @@ -0,0 +1,133 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import copy +import io +import json +import os +from typing import List + +import numpy as np +from PIL import Image +from sapiens.engine.datasets import BaseDataset +from sapiens.registry import DATASETS + +with open(os.devnull, "w") as f, contextlib.redirect_stderr(f): + try: + from care.data.io import typed + from typedio.file_system.airstore_client import register_airstore_in_fsspec + + register_airstore_in_fsspec() + except Exception: + pass + + +##----------------------------------------------------------------------- +@DATASETS.register_module() +class SegBaseDataset(BaseDataset): + def __init__( + self, + ann_file=None, + num_samples=None, + classes=None, + palette=None, + source_to_target_index_mapping=None, + **kwargs, + ) -> None: + self.ann_file = ann_file + self.classes = classes + self.palette = palette + self.source_to_target_index_mapping = source_to_target_index_mapping + self.num_samples = num_samples + self.path_template = ( + "airstoreds://rlr_detection_services_ml_datasets_no_user_data" + ) + super().__init__(**kwargs) + + def _read_from_airstore(self, asset: str, sid: str) -> io.BytesIO: + with typed.open(self.path_template + f"/{asset}?sampleId={sid}").open() as f: + data = io.BytesIO(f.read()) + return data + + def load_data_list(self) -> List[dict]: + data_list = [] + + with open(self.ann_file, "rb") as f: + raw = f.read() + raw_data = json.loads(raw) + + print("\033[92mLoading {}!\033[0m".format(self.__class__.__name__)) + data_list = [] + for sample in raw_data: + dp = { + "airstore_id": sample["sample_id"], + "session_id": str(sample["session_id"]), + "camera_id": str(sample["camera_id"]), + "frame_id": str(sample["frame_number"]), + } + if sample.get("box-default") is not None: + dp["box"] = sample["box-default"] + data_list.append(dp) + + data_list = sorted( + data_list, key=lambda y: (y["session_id"], y["camera_id"], y["frame_id"]) + ) + + if self.num_samples is not None: + data_list = data_list[: self.num_samples] + + print( + "\033[92mDone! {}. Loaded total samples: {}. Test mode: {}\033[0m".format( + self.__class__.__name__, len(data_list), self.test_mode + ) + ) + return data_list + + def get_data_info(self, idx): + data_info = copy.deepcopy(self.data_list[idx]) + try: + img = Image.open( + self._read_from_airstore("image", data_info["airstore_id"]) + ) ## pillow image + segmentation = Image.open( + self._read_from_airstore("segmentation", data_info["airstore_id"]) + ) + + except Exception as e: + print( + f"Error loading image/seg {data_info['airstore_id']} in {self.__class__.__name__}. Retrying!" 
+ ) + + return None + + # Important: Convert RGB to BGR, the pretrained model preprocessor will convert this to rgb again + img = np.array(img) ## rgb image + img = img[:, :, ::-1] + segmentation = np.array(segmentation) + + ##------remove the extra classes--- + if self.source_to_target_index_mapping is not None: + segmentation = np.vectorize( + lambda x: self.source_to_target_index_mapping.get(x, 255) + )(segmentation) + + ## get bbox + mask = (segmentation > 0).astype("uint8") ## 2D binary mask + if mask.sum() < 8 and self.test_mode is False: # too small mask + return None + + data_info = { + "img": img, + "img_id": "", + "img_path": data_info["airstore_id"], + "gt_seg": segmentation, + "id": idx, + "orig_img_height": img.shape[0], + "orig_img_width": img.shape[1], + } + + return data_info diff --git a/sapiens/dense/src/datasets/seg/seg_dome_dataset.py b/sapiens/dense/src/datasets/seg/seg_dome_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..04cbef31e546a69a89f179f33d3c119f4ef6f3be --- /dev/null +++ b/sapiens/dense/src/datasets/seg/seg_dome_dataset.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from sapiens.registry import DATASETS + +from .seg_base_dataset import SegBaseDataset +from .seg_utils import DOME_CLASSES_29, DOME_CLASSES_34, DOME_MAPPING_34_to_29 + +CLASSES = [DOME_CLASSES_29[i]["name"] for i in range(len(DOME_CLASSES_29))] +PALETTE = [DOME_CLASSES_29[i]["color"] for i in range(len(DOME_CLASSES_29))] + +SOURCE_TO_TARGET_INDEX_MAPPING = { + i: DOME_MAPPING_34_to_29[i]["target_class_idx"] + for i in range(len(DOME_CLASSES_34)) + if DOME_MAPPING_34_to_29[i]["target_class_idx"] is not None +} + + +##----------------------------------------------------------------------- +@DATASETS.register_module() +class SegDomeClass29Dataset(SegBaseDataset): + def __init__(self, **kwargs): + super().__init__( + classes=CLASSES, + palette=PALETTE, + source_to_target_index_mapping=SOURCE_TO_TARGET_INDEX_MAPPING, + **kwargs, + ) + return diff --git a/sapiens/dense/src/datasets/seg/seg_internal_dataset.py b/sapiens/dense/src/datasets/seg/seg_internal_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..645bf494eaaee2f52bc254a01c082770449c8f87 --- /dev/null +++ b/sapiens/dense/src/datasets/seg/seg_internal_dataset.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
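# The remapping above applies np.vectorize with a per-pixel Python dict lookup; an
# equivalent and typically much faster formulation builds a flat lookup table once.
# Sketch only, assuming source labels fit in uint8 (true for the 32/34/52-class tables
# used by these datasets):
import numpy as np

def remap_labels(seg: np.ndarray, mapping: dict, ignore_index: int = 255) -> np.ndarray:
    lut = np.full(256, ignore_index, dtype=np.uint8)  # unmapped labels fall back to ignore_index
    for src, dst in mapping.items():
        lut[src] = dst
    return lut[seg]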
+ +from sapiens.registry import DATASETS + +from .seg_base_dataset import SegBaseDataset +from .seg_utils import DOME_CLASSES_29, INTERNAL_CLASSES_32, INTERNAL_MAPPING_32_to_29 + +CLASSES = [DOME_CLASSES_29[i]["name"] for i in range(len(DOME_CLASSES_29))] +PALETTE = [DOME_CLASSES_29[i]["color"] for i in range(len(DOME_CLASSES_29))] + +SOURCE_TO_TARGET_INDEX_MAPPING = { + i: INTERNAL_MAPPING_32_to_29[i]["target_class_idx"] + for i in range(len(INTERNAL_CLASSES_32)) + if INTERNAL_MAPPING_32_to_29[i]["target_class_idx"] is not None +} + + +##----------------------------------------------------------------------- +@DATASETS.register_module() +class SegInternalClass29Dataset(SegBaseDataset): + def __init__(self, **kwargs): + super().__init__( + classes=CLASSES, + palette=PALETTE, + source_to_target_index_mapping=SOURCE_TO_TARGET_INDEX_MAPPING, + **kwargs, + ) + self.is_itw = True + return + + def get_data_info(self, idx): + data = super().get_data_info(idx) + + if data is None: + return None + data["is_itw"] = self.is_itw + return data diff --git a/sapiens/dense/src/datasets/seg/seg_shutterstock_dataset.py b/sapiens/dense/src/datasets/seg/seg_shutterstock_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c6c8faed9084a7dcda3a2338c5ad0243ece02ed5 --- /dev/null +++ b/sapiens/dense/src/datasets/seg/seg_shutterstock_dataset.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from sapiens.registry import DATASETS + +from .seg_base_dataset import SegBaseDataset +from .seg_utils import DOME_CLASSES_29, DOME_CLASSES_52, DOME_MAPPING_52_to_29 + +CLASSES = [DOME_CLASSES_29[i]["name"] for i in range(len(DOME_CLASSES_29))] +PALETTE = [DOME_CLASSES_29[i]["color"] for i in range(len(DOME_CLASSES_29))] + +SOURCE_TO_TARGET_INDEX_MAPPING = { + i: DOME_MAPPING_52_to_29[i]["target_class_idx"] + for i in range(len(DOME_CLASSES_52)) + if DOME_MAPPING_52_to_29[i]["target_class_idx"] is not None +} + + +##----------------------------------------------------------------------- +@DATASETS.register_module() +class SegShutterstockClass29Dataset(SegBaseDataset): + def __init__(self, **kwargs): + super().__init__( + classes=CLASSES, + palette=PALETTE, + source_to_target_index_mapping=SOURCE_TO_TARGET_INDEX_MAPPING, + **kwargs, + ) + self.is_itw = "itw_" in self.ann_file + return + + def get_data_info(self, idx): + data = super().get_data_info(idx) + if data is None: + return None + data["is_itw"] = self.is_itw + return data diff --git a/sapiens/dense/src/datasets/seg/seg_utils.py b/sapiens/dense/src/datasets/seg/seg_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1b4d6c36fb1acc8482341f754f5f503ffc4cb3eb --- /dev/null +++ b/sapiens/dense/src/datasets/seg/seg_utils.py @@ -0,0 +1,804 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
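# The CLASSES/PALETTE lists built in the dataset modules above come straight from the
# OrderedDict tables defined below. A small illustrative helper (not part of the repo)
# for turning a predicted label map into an RGB visualisation with those palettes:
import numpy as np

def colorize_labels(seg, classes):
    """seg: H x W integer labels; classes: OrderedDict of {idx: {"name", "color"}}."""
    palette = np.array([classes[i]["color"] for i in range(len(classes))], dtype=np.uint8)
    return palette[seg]  # H x W x 3 RGB image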
+ +# ruff: noqa +# flake8: noqa +# fmt: off +from collections import OrderedDict + +INTERNAL_CLASSES_32 = OrderedDict( + [ + (0, {"name": "Background", "color": [0, 0, 0]}), + (1, {"name": "Face_Neck", "color": [0, 0, 0]}), + (2, {"name": "Nose", "color": [0, 0, 0]}), + (3, {"name": "Left_Ear", "color": [0, 0, 0]}), + (4, {"name": "Right_Ear", "color": [0, 0, 0]}), + (5, {"name": "Left_Iris", "color": [0, 0, 0]}), + (6, {"name": "Left_Sclera", "color": [0, 0, 0]}), + (7, {"name": "Right_Iris", "color": [0, 0, 0]}), + (8, {"name": "Right_Sclera", "color": [0, 0, 0]}), + (9, {"name": "Left_Eyebrow", "color": [0, 0, 0]}), + (10, {"name": "Right_Eyebrow", "color": [0, 0, 0]}), + (11, {"name": "Lower_Lip", "color": [0, 0, 0]}), + (12, {"name": "Upper_Lip", "color": [0, 0, 0]}), + (13, {"name": "Lower_Teeth", "color": [0, 0, 0]}), + (14, {"name": "Upper_Teeth", "color": [0, 0, 0]}), + (15, {"name": "Tongue", "color": [0, 0, 0]}), + (16, {"name": "Hair", "color": [0, 0, 0]}), + (17, {"name": "Left_Arm", "color": [0, 0, 0]}), + (18, {"name": "Left_Leg", "color": [0, 0, 0]}), + (19, {"name": "Right_Arm", "color": [0, 0, 0]}), + (20, {"name": "Right_Leg", "color": [0, 0, 0]}), + (21, {"name": "Torso", "color": [0, 0, 0]}), + (22, {"name": "Left_Hand", "color": [0, 0, 0]}), + (23, {"name": "Right_Hand", "color": [0, 0, 0]}), + (24, {"name": "Upper_Clothing", "color": [0, 0, 0]}), + (25, {"name": "Lower_Clothing", "color": [0, 0, 0]}), + (26, {"name": "Left_Shoe_Sock", "color": [0, 0, 0]}), + (27, {"name": "Right_Shoe_Sock", "color": [0, 0, 0]}), + (28, {"name": "Apparel", "color": [0, 0, 0]}), + (29, {"name": "Glasses", "color": [0, 0, 0]}), + (30, {"name": "Chair", "color": [0, 0, 0]}), + (31, {"name": "Headset", "color": [0, 0, 0]}), + (32, {"name": "Occluder", "color": [0, 0, 0]}), + ] +) + + +DOME_CLASSES_34 = OrderedDict( + [ + (0, {"name": "Background", "color": [50, 50, 50]}), + (1, {"name": "Apparel", "color": [255, 218, 0]}), + (2, {"name": "Chair", "color": [102, 204, 0]}), + (3, {"name": "Eyeglass_Frame", "color": [14, 0, 204]}), + (4, {"name": "Eyeglass_Lenses", "color": [0, 204, 160]}), + (5, {"name": "Face_Neck", "color": [128, 200, 255]}), + (6, {"name": "Hair", "color": [255, 0, 109]}), + (7, {"name": "Headset", "color": [0, 255, 36]}), + (8, {"name": "Left_Foot", "color": [189, 0, 204]}), + (9, {"name": "Left_Hand", "color": [255, 0, 218]}), + (10, {"name": "Left_Lower_Arm", "color": [0, 160, 204]}), + (11, {"name": "Left_Lower_Leg", "color": [0, 255, 145]}), + (12, {"name": "Left_Shoe", "color": [204, 0, 131]}), + (13, {"name": "Left_Sock", "color": [182, 0, 255]}), + (14, {"name": "Left_Upper_Arm", "color": [255, 109, 0]}), + (15, {"name": "Left_Upper_Leg", "color": [0, 255, 255]}), + (16, {"name": "Lower_Clothing", "color": [72, 0, 255]}), + (17, {"name": "Lower_Spandex", "color": [204, 43, 0]}), + (18, {"name": "Right_Foot", "color": [204, 131, 0]}), + (19, {"name": "Right_Hand", "color": [255, 0, 0]}), + (20, {"name": "Right_Lower_Arm", "color": [72, 255, 0]}), + (21, {"name": "Right_Lower_Leg", "color": [189, 204, 0]}), + (22, {"name": "Right_Shoe", "color": [182, 255, 0]}), + (23, {"name": "Right_Sock", "color": [102, 0, 204]}), + (24, {"name": "Right_Upper_Arm", "color": [32, 72, 204]}), + (25, {"name": "Right_Upper_Leg", "color": [0, 145, 255]}), + (26, {"name": "Torso", "color": [14, 204, 0]}), + (27, {"name": "Upper_Clothing", "color": [0, 128, 72]}), + (28, {"name": "Visible_Badge", "color": [204, 0, 43]}), + (29, {"name": "Lower_Lip", "color": [235, 205, 119]}), + (30, 
{"name": "Upper_Lip", "color": [115, 227, 112]}), + (31, {"name": "Lower_Teeth", "color": [157, 113, 143]}), + (32, {"name": "Upper_Teeth", "color": [132, 93, 50]}), + (33, {"name": "Tongue", "color": [82, 21, 114]}), + ] +) + +DOME_CLASSES_29 = OrderedDict( + [ + (0, {"name": "Background", "color": [50, 50, 50]}), + (1, {"name": "Apparel", "color": [255, 218, 0]}), + (2, {"name": "Eyeglass", "color": [14, 204, 182]}), + (3, {"name": "Face_Neck", "color": [128, 200, 255]}), + (4, {"name": "Hair", "color": [255, 0, 109]}), + (5, {"name": "Left_Foot", "color": [189, 0, 204]}), + (6, {"name": "Left_Hand", "color": [255, 0, 218]}), + (7, {"name": "Left_Lower_Arm", "color": [0, 160, 204]}), + (8, {"name": "Left_Lower_Leg", "color": [0, 255, 145]}), + (9, {"name": "Left_Shoe", "color": [204, 0, 131]}), + (10, {"name": "Left_Sock", "color": [182, 0, 255]}), + (11, {"name": "Left_Upper_Arm", "color": [255, 109, 0]}), + (12, {"name": "Left_Upper_Leg", "color": [0, 255, 255]}), + (13, {"name": "Lower_Clothing", "color": [72, 0, 255]}), + (14, {"name": "Right_Foot", "color": [204, 131, 0]}), + (15, {"name": "Right_Hand", "color": [255, 0, 0]}), + (16, {"name": "Right_Lower_Arm", "color": [72, 255, 0]}), + (17, {"name": "Right_Lower_Leg", "color": [189, 204, 0]}), + (18, {"name": "Right_Shoe", "color": [182, 255, 0]}), + (19, {"name": "Right_Sock", "color": [102, 0, 204]}), + (20, {"name": "Right_Upper_Arm", "color": [32, 72, 204]}), + (21, {"name": "Right_Upper_Leg", "color": [0, 145, 255]}), + (22, {"name": "Torso", "color": [14, 204, 0]}), + (23, {"name": "Upper_Clothing", "color": [0, 128, 72]}), + (24, {"name": "Lower_Lip", "color": [235, 205, 119]}), + (25, {"name": "Upper_Lip", "color": [115, 227, 112]}), + (26, {"name": "Lower_Teeth", "color": [157, 113, 143]}), + (27, {"name": "Upper_Teeth", "color": [132, 93, 50]}), + (28, {"name": "Tongue", "color": [82, 21, 114]}), + ] +) + +DOME_CLASSES_28 = OrderedDict( + [ + (0, {"name": "Background", "color": [50, 50, 50]}), + (1, {"name": "Apparel", "color": [255, 218, 0]}), + (2, {"name": "Face_Neck", "color": [128, 200, 255]}), + (3, {"name": "Hair", "color": [255, 0, 109]}), + (4, {"name": "Left_Foot", "color": [189, 0, 204]}), + (5, {"name": "Left_Hand", "color": [255, 0, 218]}), + (6, {"name": "Left_Lower_Arm", "color": [0, 160, 204]}), + (7, {"name": "Left_Lower_Leg", "color": [0, 255, 145]}), + (8, {"name": "Left_Shoe", "color": [204, 0, 131]}), + (9, {"name": "Left_Sock", "color": [182, 0, 255]}), + (10, {"name": "Left_Upper_Arm", "color": [255, 109, 0]}), + (11, {"name": "Left_Upper_Leg", "color": [0, 255, 255]}), + (12, {"name": "Lower_Clothing", "color": [72, 0, 255]}), + (13, {"name": "Right_Foot", "color": [204, 131, 0]}), + (14, {"name": "Right_Hand", "color": [255, 0, 0]}), + (15, {"name": "Right_Lower_Arm", "color": [72, 255, 0]}), + (16, {"name": "Right_Lower_Leg", "color": [189, 204, 0]}), + (17, {"name": "Right_Shoe", "color": [182, 255, 0]}), + (18, {"name": "Right_Sock", "color": [102, 0, 204]}), + (19, {"name": "Right_Upper_Arm", "color": [32, 72, 204]}), + (20, {"name": "Right_Upper_Leg", "color": [0, 145, 255]}), + (21, {"name": "Torso", "color": [14, 204, 0]}), + (22, {"name": "Upper_Clothing", "color": [0, 128, 72]}), + (23, {"name": "Lower_Lip", "color": [235, 205, 119]}), + (24, {"name": "Upper_Lip", "color": [115, 227, 112]}), + (25, {"name": "Lower_Teeth", "color": [157, 113, 143]}), + (26, {"name": "Upper_Teeth", "color": [132, 93, 50]}), + (27, {"name": "Tongue", "color": [82, 21, 114]}), + ] +) + +DOME_CLASSES_52 = 
OrderedDict( + [ + (0, {"name": "Background", "color": [50, 50, 50]}), + (1, {"name": "Hairnet", "color": [255, 105, 180]}), # Hot pink + (2, {"name": "Upper_Gum", "color": [220, 20, 60]}), # Crimson + (3, {"name": "Palate", "color": [255, 182, 193]}), # Light pink + (4, {"name": "Lower_Gum", "color": [178, 34, 34]}), # Firebrick + (5, {"name": "Floor_of_Mouth", "color": [255, 160, 122]}), # Light salmon + (6, {"name": "Cheeks_interior", "color": [219, 112, 147]}), # Pale violet red + (7, {"name": "Eyelash", "color": [0, 0, 0]}), # Black + (8, {"name": "Eyebrow", "color": [101, 67, 33]}), # Brown + (9, {"name": "Headset", "color": [0, 255, 36]}), # Bright green + (10, {"name": "Left_Foot", "color": [189, 0, 204]}), + (11, {"name": "Left_Lower_Leg", "color": [0, 255, 145]}), + (12, {"name": "Left_Shoe", "color": [204, 0, 131]}), + (13, {"name": "Left_Sock", "color": [182, 0, 255]}), + (14, {"name": "Left_Upper_Leg", "color": [0, 255, 255]}), + (15, {"name": "Right_Foot", "color": [204, 131, 0]}), + (16, {"name": "Right_Lower_Leg", "color": [189, 204, 0]}), + (17, {"name": "Right_Shoe", "color": [182, 255, 0]}), + (18, {"name": "Right_Sock", "color": [102, 0, 204]}), + (19, {"name": "Right_Upper_Leg", "color": [0, 145, 255]}), + (20, {"name": "Torso", "color": [14, 204, 0]}), + (21, {"name": "Upper_Clothing", "color": [0, 128, 72]}), + (22, {"name": "Lower_Clothing", "color": [72, 0, 255]}), + (23, {"name": "Face_Neck_Skin", "color": [138, 190, 255]}), + (24, {"name": "Tongue", "color": [82, 21, 114]}), + (25, {"name": "Eyeglass_Frame", "color": [14, 0, 204]}), + (26, {"name": "Eyeglass_Lenses", "color": [0, 204, 160]}), + (27, {"name": "Badge", "color": [255, 165, 0]}), # Orange + (28, {"name": "Right_Hand", "color": [255, 0, 0]}), + (29, {"name": "Right_Lower_Arm", "color": [72, 255, 0]}), + (30, {"name": "Right_Upper_Arm", "color": [32, 72, 204]}), + (31, {"name": "Left_Hand", "color": [255, 0, 218]}), + (32, {"name": "Left_Lower_Arm", "color": [0, 160, 204]}), + (33, {"name": "Left_Upper_Arm", "color": [255, 109, 0]}), + (34, {"name": "Apparel", "color": [255, 218, 0]}), + (35, {"name": "Upper_Teeth", "color": [132, 93, 50]}), + (36, {"name": "Lower_Teeth", "color": [157, 113, 143]}), + (37, {"name": "Lower_Lip", "color": [235, 205, 119]}), + (38, {"name": "Upper_Lip", "color": [115, 227, 112]}), + (39, {"name": "Hair", "color": [255, 0, 109]}), + (40, {"name": "Chair", "color": [139, 69, 19]}), # Saddle brown + (41, {"name": "Right_thumbNail", "color": [255, 102, 102]}), + (42, {"name": "Right_indexNail", "color": [255, 102, 102]}), + (43, {"name": "Right_middleNail", "color": [255, 102, 102]}), + (44, {"name": "Right_ringNail", "color": [255, 102, 102]}), + (45, {"name": "Right_pinkyNail", "color": [255, 102, 102]}), + (46, {"name": "Left_thumbNail", "color": [255, 102, 102]}), + (47, {"name": "Left_indexNail", "color": [255, 102, 102]}), + (48, {"name": "Left_middleNail", "color": [255, 102, 102]}), + (49, {"name": "Left_ringNail", "color": [255, 102, 102]}), + (50, {"name": "Left_pinkyNail", "color": [255, 102, 102]}), + (51, {"name": "Occluder", "color": [255, 0, 255]}), + ] +) + +INTERNAL_MAPPING_32_to_29 = { + 0: {"name": "Background", "target_class_idx": 0, "target_class_name": "Background"}, + 1: {"name": "Face_Neck", "target_class_idx": 3, "target_class_name": "Face_Neck"}, + 2: {"name": "Nose", "target_class_idx": 3, "target_class_name": "Face_Neck"}, + 3: {"name": "Left_Ear", "target_class_idx": 3, "target_class_name": "Face_Neck"}, + 4: {"name": "Right_Ear", 
"target_class_idx": 3, "target_class_name": "Face_Neck"}, + 5: {"name": "Left_Iris", "target_class_idx": 3, "target_class_name": "Face_Neck"}, + 6: {"name": "Left_Sclera", "target_class_idx": 3, "target_class_name": "Face_Neck"}, + 7: {"name": "Right_Iris", "target_class_idx": 3, "target_class_name": "Face_Neck"}, + 8: {"name": "Right_Sclera", "target_class_idx": 3, "target_class_name": "Face_Neck"}, + 9: {"name": "Left_Eyebrow", "target_class_idx": 3, "target_class_name": "Face_Neck"}, + 10: {"name": "Right_Eyebrow", "target_class_idx": 3, "target_class_name": "Face_Neck"}, + 11: {"name": "Lower_Lip", "target_class_idx": 24, "target_class_name": "Lower_Lip"}, + 12: {"name": "Upper_Lip", "target_class_idx": 25, "target_class_name": "Upper_Lip"}, + 13: {"name": "Lower_Teeth", "target_class_idx": 26, "target_class_name": "Lower_Teeth"}, + 14: {"name": "Upper_Teeth", "target_class_idx": 27, "target_class_name": "Upper_Teeth"}, + 15: {"name": "Tongue", "target_class_idx": 28, "target_class_name": "Tongue"}, + 16: {"name": "Hair", "target_class_idx": 4, "target_class_name": "Hair"}, + 17: {"name": "Left_Arm", "target_class_idx": None, "target_class_name": None}, + 18: {"name": "Left_Leg", "target_class_idx": None, "target_class_name": None}, + 19: {"name": "Right_Arm", "target_class_idx": None, "target_class_name": None}, + 20: {"name": "Right_Leg", "target_class_idx": None, "target_class_name": None}, + 21: {"name": "Torso", "target_class_idx": 22, "target_class_name": "Torso"}, + 22: {"name": "Left_Hand", "target_class_idx": 6, "target_class_name": "Left_Hand"}, + 23: {"name": "Right_Hand", "target_class_idx": 15, "target_class_name": "Right_Hand"}, + 24: {"name": "Upper_Clothing", "target_class_idx": 23, "target_class_name": "Upper_Clothing"}, + 25: {"name": "Lower_Clothing", "target_class_idx": 13, "target_class_name": "Lower_Clothing"}, + 26: {"name": "Left_Shoe_Sock", "target_class_idx": None, "target_class_name": None}, + 27: {"name": "Right_Shoe_Sock","target_class_idx": None, "target_class_name": None}, + 28: {"name": "Apparel", "target_class_idx": 1, "target_class_name": "Apparel"}, + 29: {"name": "Glasses", "target_class_idx": 2, "target_class_name": "Eyeglass"}, + 30: {"name": "Chair", "target_class_idx": 0, "target_class_name": "Background"}, + 31: {"name": "Headset", "target_class_idx": None, "target_class_name": None}, + 32: {"name": "Occluder", "target_class_idx": 0, "target_class_name": "Background"}, +} + + +## source class idx: {name, target class idx, target class name} +DOME_MAPPING_52_to_29 = { + 0: {"name": "Background", "target_class_idx": 0, "target_class_name": "Background"}, + 1: {"name": "Hairnet", "target_class_idx": 1, "target_class_name": "Apparel"}, + 2: {"name": "Upper_Gum", "target_class_idx": 3, "target_class_name": "Face_Neck"}, + 3: {"name": "Palate", "target_class_idx": 3, "target_class_name": "Face_Neck"}, + 4: {"name": "Lower_Gum", "target_class_idx": 3, "target_class_name": "Face_Neck"}, + 5: { + "name": "Floor_of_Mouth", + "target_class_idx": 3, + "target_class_name": "Face_Neck", + }, + 6: { + "name": "Cheeks_interior", + "target_class_idx": 3, + "target_class_name": "Face_Neck", + }, + 7: {"name": "Eyelash", "target_class_idx": 3, "target_class_name": "Face_Neck"}, + 8: {"name": "Eyebrow", "target_class_idx": 3, "target_class_name": "Face_Neck"}, + 9: { + "name": "Headset", + "target_class_idx": None, + "target_class_name": None, + }, # not mapped + 10: {"name": "Left_Foot", "target_class_idx": 5, "target_class_name": "Left_Foot"}, + 11: { + "name": 
"Left_Lower_Leg", + "target_class_idx": 8, + "target_class_name": "Left_Lower_Leg", + }, + 12: {"name": "Left_Shoe", "target_class_idx": 9, "target_class_name": "Left_Shoe"}, + 13: {"name": "Left_Sock", "target_class_idx": 10, "target_class_name": "Left_Sock"}, + 14: { + "name": "Left_Upper_Leg", + "target_class_idx": 12, + "target_class_name": "Left_Upper_Leg", + }, + 15: { + "name": "Right_Foot", + "target_class_idx": 14, + "target_class_name": "Right_Foot", + }, + 16: { + "name": "Right_Lower_Leg", + "target_class_idx": 17, + "target_class_name": "Right_Lower_Leg", + }, + 17: { + "name": "Right_Shoe", + "target_class_idx": 18, + "target_class_name": "Right_Shoe", + }, + 18: { + "name": "Right_Sock", + "target_class_idx": 19, + "target_class_name": "Right_Sock", + }, + 19: { + "name": "Right_Upper_Leg", + "target_class_idx": 21, + "target_class_name": "Right_Upper_Leg", + }, + 20: {"name": "Torso", "target_class_idx": 22, "target_class_name": "Torso"}, + 21: { + "name": "Upper_Clothing", + "target_class_idx": 23, + "target_class_name": "Upper_Clothing", + }, + 22: { + "name": "Lower_Clothing", + "target_class_idx": 13, + "target_class_name": "Lower_Clothing", + }, + 23: { + "name": "Face_Neck_Skin", + "target_class_idx": 3, + "target_class_name": "Face_Neck", + }, + 24: {"name": "Tongue", "target_class_idx": 28, "target_class_name": "Tongue"}, + 25: { + "name": "Eyeglass_Frame", + "target_class_idx": 2, + "target_class_name": "Eyeglass", + }, + 26: { + "name": "Eyeglass_Lenses", + "target_class_idx": 2, + "target_class_name": "Eyeglass", + }, + 27: { + "name": "Badge", + "target_class_idx": None, + "target_class_name": None, + }, # don't care + 28: { + "name": "Right_Hand", + "target_class_idx": 15, + "target_class_name": "Right_Hand", + }, + 29: { + "name": "Right_Lower_Arm", + "target_class_idx": 16, + "target_class_name": "Right_Lower_Arm", + }, + 30: { + "name": "Right_Upper_Arm", + "target_class_idx": 20, + "target_class_name": "Right_Upper_Arm", + }, + 31: {"name": "Left_Hand", "target_class_idx": 6, "target_class_name": "Left_Hand"}, + 32: { + "name": "Left_Lower_Arm", + "target_class_idx": 7, + "target_class_name": "Left_Lower_Arm", + }, + 33: { + "name": "Left_Upper_Arm", + "target_class_idx": 11, + "target_class_name": "Left_Upper_Arm", + }, + 34: {"name": "Apparel", "target_class_idx": 1, "target_class_name": "Apparel"}, + 35: { + "name": "Upper_Teeth", + "target_class_idx": 27, + "target_class_name": "Upper_Teeth", + }, + 36: { + "name": "Lower_Teeth", + "target_class_idx": 26, + "target_class_name": "Lower_Teeth", + }, + 37: {"name": "Lower_Lip", "target_class_idx": 24, "target_class_name": "Lower_Lip"}, + 38: {"name": "Upper_Lip", "target_class_idx": 25, "target_class_name": "Upper_Lip"}, + 39: {"name": "Hair", "target_class_idx": 4, "target_class_name": "Hair"}, + 40: {"name": "Chair", "target_class_idx": 0, "target_class_name": "Background"}, + 41: { + "name": "Right_thumbNail", + "target_class_idx": 15, + "target_class_name": "Right_Hand", + }, + 42: { + "name": "Right_indexNail", + "target_class_idx": 15, + "target_class_name": "Right_Hand", + }, + 43: { + "name": "Right_middleNail", + "target_class_idx": 15, + "target_class_name": "Right_Hand", + }, + 44: { + "name": "Right_ringNail", + "target_class_idx": 15, + "target_class_name": "Right_Hand", + }, + 45: { + "name": "Right_pinkyNail", + "target_class_idx": 15, + "target_class_name": "Right_Hand", + }, + 46: { + "name": "Left_thumbNail", + "target_class_idx": 6, + "target_class_name": "Left_Hand", + }, + 47: { + 
"name": "Left_indexNail", + "target_class_idx": 6, + "target_class_name": "Left_Hand", + }, + 48: { + "name": "Left_middleNail", + "target_class_idx": 6, + "target_class_name": "Left_Hand", + }, + 49: { + "name": "Left_ringNail", + "target_class_idx": 6, + "target_class_name": "Left_Hand", + }, + 50: { + "name": "Left_pinkyNail", + "target_class_idx": 6, + "target_class_name": "Left_Hand", + }, + 51: {"name": "Occluder", "target_class_idx": 0, "target_class_name": "Background"}, +} + +DOME_MAPPING_52_to_28 = { + 0: {"name": "Background", "target_class_idx": 0, "target_class_name": "Background"}, + 1: {"name": "Hairnet", "target_class_idx": 1, "target_class_name": "Apparel"}, + 2: {"name": "Upper_Gum", "target_class_idx": 2, "target_class_name": "Face_Neck"}, + 3: {"name": "Palate", "target_class_idx": 2, "target_class_name": "Face_Neck"}, + 4: {"name": "Lower_Gum", "target_class_idx": 2, "target_class_name": "Face_Neck"}, + 5: { + "name": "Floor_of_Mouth", + "target_class_idx": 2, + "target_class_name": "Face_Neck", + }, + 6: { + "name": "Cheeks_interior", + "target_class_idx": 2, + "target_class_name": "Face_Neck", + }, + 7: {"name": "Eyelash", "target_class_idx": 2, "target_class_name": "Face_Neck"}, + 8: {"name": "Eyebrow", "target_class_idx": 2, "target_class_name": "Face_Neck"}, + 9: {"name": "Headset", "target_class_idx": None, "target_class_name": None}, + 10: {"name": "Left_Foot", "target_class_idx": 4, "target_class_name": "Left_Foot"}, + 11: { + "name": "Left_Lower_Leg", + "target_class_idx": 7, + "target_class_name": "Left_Lower_Leg", + }, + 12: {"name": "Left_Shoe", "target_class_idx": 8, "target_class_name": "Left_Shoe"}, + 13: {"name": "Left_Sock", "target_class_idx": 9, "target_class_name": "Left_Sock"}, + 14: { + "name": "Left_Upper_Leg", + "target_class_idx": 11, + "target_class_name": "Left_Upper_Leg", + }, + 15: { + "name": "Right_Foot", + "target_class_idx": 13, + "target_class_name": "Right_Foot", + }, + 16: { + "name": "Right_Lower_Leg", + "target_class_idx": 16, + "target_class_name": "Right_Lower_Leg", + }, + 17: { + "name": "Right_Shoe", + "target_class_idx": 17, + "target_class_name": "Right_Shoe", + }, + 18: { + "name": "Right_Sock", + "target_class_idx": 18, + "target_class_name": "Right_Sock", + }, + 19: { + "name": "Right_Upper_Leg", + "target_class_idx": 20, + "target_class_name": "Right_Upper_Leg", + }, + 20: {"name": "Torso", "target_class_idx": 21, "target_class_name": "Torso"}, + 21: { + "name": "Upper_Clothing", + "target_class_idx": 22, + "target_class_name": "Upper_Clothing", + }, + 22: { + "name": "Lower_Clothing", + "target_class_idx": 12, + "target_class_name": "Lower_Clothing", + }, + 23: { + "name": "Face_Neck_Skin", + "target_class_idx": 2, + "target_class_name": "Face_Neck", + }, + 24: {"name": "Tongue", "target_class_idx": 27, "target_class_name": "Tongue"}, + 25: { + "name": "Eyeglass_Frame", + "target_class_idx": 2, + "target_class_name": "Face_Neck", + }, + 26: { + "name": "Eyeglass_Lenses", + "target_class_idx": 2, + "target_class_name": "Face_Neck", + }, + 27: { + "name": "Badge", + "target_class_idx": 1, + "target_class_name": "Apparel", + }, # accessory โ†’ Apparel + 28: { + "name": "Right_Hand", + "target_class_idx": 14, + "target_class_name": "Right_Hand", + }, + 29: { + "name": "Right_Lower_Arm", + "target_class_idx": 15, + "target_class_name": "Right_Lower_Arm", + }, + 30: { + "name": "Right_Upper_Arm", + "target_class_idx": 19, + "target_class_name": "Right_Upper_Arm", + }, + 31: {"name": "Left_Hand", "target_class_idx": 5, 
"target_class_name": "Left_Hand"}, + 32: { + "name": "Left_Lower_Arm", + "target_class_idx": 6, + "target_class_name": "Left_Lower_Arm", + }, + 33: { + "name": "Left_Upper_Arm", + "target_class_idx": 10, + "target_class_name": "Left_Upper_Arm", + }, + 34: {"name": "Apparel", "target_class_idx": 1, "target_class_name": "Apparel"}, + 35: { + "name": "Upper_Teeth", + "target_class_idx": 26, + "target_class_name": "Upper_Teeth", + }, + 36: { + "name": "Lower_Teeth", + "target_class_idx": 25, + "target_class_name": "Lower_Teeth", + }, + 37: {"name": "Lower_Lip", "target_class_idx": 23, "target_class_name": "Lower_Lip"}, + 38: {"name": "Upper_Lip", "target_class_idx": 24, "target_class_name": "Upper_Lip"}, + 39: {"name": "Hair", "target_class_idx": 3, "target_class_name": "Hair"}, + 40: { + "name": "Chair", + "target_class_idx": 0, + "target_class_name": "Background", + }, # treated as background + 41: { + "name": "Right_thumbNail", + "target_class_idx": 14, + "target_class_name": "Right_Hand", + }, + 42: { + "name": "Right_indexNail", + "target_class_idx": 14, + "target_class_name": "Right_Hand", + }, + 43: { + "name": "Right_middleNail", + "target_class_idx": 14, + "target_class_name": "Right_Hand", + }, + 44: { + "name": "Right_ringNail", + "target_class_idx": 14, + "target_class_name": "Right_Hand", + }, + 45: { + "name": "Right_pinkyNail", + "target_class_idx": 14, + "target_class_name": "Right_Hand", + }, + 46: { + "name": "Left_thumbNail", + "target_class_idx": 5, + "target_class_name": "Left_Hand", + }, + 47: { + "name": "Left_indexNail", + "target_class_idx": 5, + "target_class_name": "Left_Hand", + }, + 48: { + "name": "Left_middleNail", + "target_class_idx": 5, + "target_class_name": "Left_Hand", + }, + 49: { + "name": "Left_ringNail", + "target_class_idx": 5, + "target_class_name": "Left_Hand", + }, + 50: { + "name": "Left_pinkyNail", + "target_class_idx": 5, + "target_class_name": "Left_Hand", + }, + 51: {"name": "Occluder", "target_class_idx": 0, "target_class_name": "Background"}, +} + +DOME_MAPPING_34_to_29 = { + 0: {"name": "Background", "target_class_idx": 0, "target_class_name": "Background"}, + 1: {"name": "Apparel", "target_class_idx": 1, "target_class_name": "Apparel"}, + 2: {"name": "Chair", "target_class_idx": 0, "target_class_name": "Background"}, + 3: { + "name": "Eyeglass_Frame", + "target_class_idx": 2, + "target_class_name": "Eyeglass", + }, + 4: { + "name": "Eyeglass_Lenses", + "target_class_idx": 2, + "target_class_name": "Eyeglass", + }, + 5: {"name": "Face_Neck", "target_class_idx": 3, "target_class_name": "Face_Neck"}, + 6: {"name": "Hair", "target_class_idx": 4, "target_class_name": "Hair"}, + 7: {"name": "Headset", "target_class_idx": None, "target_class_name": None}, + 8: {"name": "Left_Foot", "target_class_idx": 5, "target_class_name": "Left_Foot"}, + 9: {"name": "Left_Hand", "target_class_idx": 6, "target_class_name": "Left_Hand"}, + 10: { + "name": "Left_Lower_Arm", + "target_class_idx": 7, + "target_class_name": "Left_Lower_Arm", + }, + 11: { + "name": "Left_Lower_Leg", + "target_class_idx": 8, + "target_class_name": "Left_Lower_Leg", + }, + 12: {"name": "Left_Shoe", "target_class_idx": 9, "target_class_name": "Left_Shoe"}, + 13: {"name": "Left_Sock", "target_class_idx": 10, "target_class_name": "Left_Sock"}, + 14: { + "name": "Left_Upper_Arm", + "target_class_idx": 11, + "target_class_name": "Left_Upper_Arm", + }, + 15: { + "name": "Left_Upper_Leg", + "target_class_idx": 12, + "target_class_name": "Left_Upper_Leg", + }, + 16: { + "name": 
"Lower_Clothing", + "target_class_idx": 13, + "target_class_name": "Lower_Clothing", + }, + 17: { + "name": "Lower_Spandex", + "target_class_idx": 13, + "target_class_name": "Lower_Clothing", + }, + 18: { + "name": "Right_Foot", + "target_class_idx": 14, + "target_class_name": "Right_Foot", + }, + 19: { + "name": "Right_Hand", + "target_class_idx": 15, + "target_class_name": "Right_Hand", + }, + 20: { + "name": "Right_Lower_Arm", + "target_class_idx": 16, + "target_class_name": "Right_Lower_Arm", + }, + 21: { + "name": "Right_Lower_Leg", + "target_class_idx": 17, + "target_class_name": "Right_Lower_Leg", + }, + 22: { + "name": "Right_Shoe", + "target_class_idx": 18, + "target_class_name": "Right_Shoe", + }, + 23: { + "name": "Right_Sock", + "target_class_idx": 19, + "target_class_name": "Right_Sock", + }, + 24: { + "name": "Right_Upper_Arm", + "target_class_idx": 20, + "target_class_name": "Right_Upper_Arm", + }, + 25: { + "name": "Right_Upper_Leg", + "target_class_idx": 21, + "target_class_name": "Right_Upper_Leg", + }, + 26: {"name": "Torso", "target_class_idx": 22, "target_class_name": "Torso"}, + 27: { + "name": "Upper_Clothing", + "target_class_idx": 23, + "target_class_name": "Upper_Clothing", + }, + 28: {"name": "Visible_Badge", "target_class_idx": None, "target_class_name": None}, + 29: {"name": "Lower_Lip", "target_class_idx": 24, "target_class_name": "Lower_Lip"}, + 30: {"name": "Upper_Lip", "target_class_idx": 25, "target_class_name": "Upper_Lip"}, + 31: { + "name": "Lower_Teeth", + "target_class_idx": 26, + "target_class_name": "Lower_Teeth", + }, + 32: { + "name": "Upper_Teeth", + "target_class_idx": 27, + "target_class_name": "Upper_Teeth", + }, + 33: {"name": "Tongue", "target_class_idx": 28, "target_class_name": "Tongue"}, +} + +# ## loss weights: 29 classes +# 0 Background 0.1 +# 1 Apparel 10 +# 2 Eyeglass 10 +# 4 Face_Neck 3 +# 5 Hair 2 +# 6 Left_Foot 4 +# 7 Left_Hand 4 +# 8 Left_Lower_Arm 2 +# 9 Left_Lower_Leg 2 +# 10 Left_Shoe 6 +# 11 Left_Sock 10 +# 12 Left_Upper_Arm 3 +# 13 Left_Upper_Leg 3 +# 14 Lower_Clothing 1 +# 15 Right_Foot 4 +# 16 Right_Hand 4 +# 17 Right_Lower_Arm 2 +# 18 Right_Lower_Leg 2 +# 19 Right_Shoe 6 +# 20 Right_Sock 10 +# 21 Right_Upper_Arm 3 +# 22 Right_Upper_Leg 3 +# 23 Torso 1 +# 24 Upper_Clothing 1 +# 25 Lower_Lip 10 +# 26 Upper_Lip 10 +# 27 Lower_Teeth 10 +# 28 Upper_Teeth 10 +# 29 Tongue 10 + +# ## loss weights: 34 classes +# 0 Background 0.1 +# 1 Apparel 10 +# 2 Chair 3 +# 3 Eyeglass_Frame 10 +# 4 Eyeglass_Lenses 10 +# 5 Face_Neck 3 +# 6 Hair 2 +# 7 Headset 2 +# 8 Left_Foot 4 +# 9 Left_Hand 4 +# 10 Left_Lower_Arm 2 +# 11 Left_Lower_Leg 2 +# 12 Left_Shoe 6 +# 13 Left_Sock 10 +# 14 Left_Upper_Arm 3 +# 15 Left_Upper_Leg 3 +# 16 Lower_Clothing 1 +# 17 Lower_Spandex 0.5 +# 18 Right_Foot 4 +# 19 Right_Hand 4 +# 20 Right_Lower_Arm 2 +# 21 Right_Lower_Leg 2 +# 22 Right_Shoe 6 +# 23 Right_Sock 10 +# 24 Right_Upper_Arm 3 +# 25 Right_Upper_Leg 3 +# 26 Torso 1 +# 27 Upper_Clothing 1 +# 28 Visible_Badge 10 +# 29 Lower_Lip 10 +# 30 Upper_Lip 10 +# 31 Lower_Teeth 10 +# 32 Upper_Teeth 10 +# 33 Tongue 10 diff --git a/sapiens/dense/src/datasets/transforms/__init__.py b/sapiens/dense/src/datasets/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..82de8eb8ce8f826b736ee6b05b9306994a967187 --- /dev/null +++ b/sapiens/dense/src/datasets/transforms/__init__.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .albedo_transforms import ( + AlbedoPackInputs, + AlbedoRandomCrop, + AlbedoRandomCropContinuous, + AlbedoRandomFlip, + AlbedoRandomScale, + AlbedoResize, + AlbedoResizePadImage, +) +from .normal_transforms import ( + NormalGenerateTarget, + NormalPackInputs, + NormalRandomCrop, + NormalRandomCropContinuous, + NormalRandomFlip, + NormalRandomScale, + NormalResize, + NormalResizePadImage, +) +from .pointmap_transforms import ( + PointmapGenerateTarget, + PointmapPackInputs, + PointmapRandomCrop, + PointmapRandomCropContinuous, + PointmapRandomFlip, + PointmapRandomScale, + PointmapResize, + PointmapResizePadImage, +) +from .seg_transforms import ( + SegPackInputs, + SegRandomBackground, + SegRandomCrop, + SegRandomHorizontalFlip, + SegRandomResize, + SegRandomRotate, + SegResize, +) + +__all__ = [ + "SegRandomBackground", + "SegRandomResize", + "SegPackInputs", + "SegRandomCrop", + "SegRandomHorizontalFlip", + "SegRandomRotate", + "SegResize", + "PointmapGenerateTarget", + "PointmapPackInputs", + "PointmapRandomCrop", + "PointmapRandomCropContinuous", + "PointmapRandomFlip", + "PointmapRandomScale", + "PointmapResize", + "PointmapResizePadImage", + "NormalGenerateTarget", + "NormalPackInputs", + "NormalRandomCrop", + "NormalRandomCropContinuous", + "NormalRandomFlip", + "NormalRandomScale", + "NormalResize", + "NormalResizePadImage", + "AlbedoPackInputs", + "AlbedoRandomCrop", + "AlbedoRandomCropContinuous", + "AlbedoRandomFlip", + "AlbedoRandomScale", + "AlbedoResize", + "AlbedoResizePadImage", +] diff --git a/sapiens/dense/src/datasets/transforms/albedo_transforms.py b/sapiens/dense/src/datasets/transforms/albedo_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..393081538c19ef3c3889c362dfc92ddd20500f9e --- /dev/null +++ b/sapiens/dense/src/datasets/transforms/albedo_transforms.py @@ -0,0 +1,505 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +import random +from typing import Dict, List, Optional, Tuple, Union + +import cv2 +import numpy as np +from sapiens.engine.datasets import BaseTransform, to_tensor +from sapiens.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class AlbedoRandomScale(BaseTransform): + def __init__( + self, + scale_min: float = 0.5, + scale_max: float = 2.0, + prob: float = 0.5, + ): + super().__init__() + assert 0 < scale_min <= scale_max, ( + f"Invalid scale range: ({scale_min}, {scale_max})" + ) + self.scale_min = scale_min + self.scale_max = scale_max + self.prob = prob + + def _random_scale_factor(self) -> float: + """Sample a random scale factor in [scale_min, scale_max].""" + return np.random.uniform(self.scale_min, self.scale_max) + + def transform(self, results: dict) -> dict: + if np.random.rand() >= self.prob: + return results + + img = results["img"] + orig_h, orig_w = img.shape[:2] + + # 1. Sample a random scale factor + s = self._random_scale_factor() + + # 2. Compute the new size + new_w = int(round(orig_w * s)) + new_h = int(round(orig_h * s)) + + # 3. Resize the image + img_resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR) + results["img"] = img_resized + results["img_shape"] = (new_h, new_w) + + # 4. Resize mask, depth, etc. 
using INTER_NEAREST + if "mask" in results: + mask_resized = cv2.resize( + results["mask"].astype(np.uint8), + (new_w, new_h), + interpolation=cv2.INTER_NEAREST, + ) + results["mask"] = mask_resized + + if "gt_albedo" in results: + albedo_resized = cv2.resize( + results["gt_albedo"], (new_w, new_h), interpolation=cv2.INTER_LINEAR + ) + results["gt_albedo"] = albedo_resized + + return results + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"scale_min={self.scale_min}, " + f"scale_max={self.scale_max})" + ) + + +@TRANSFORMS.register_module() +class AlbedoRandomCrop(BaseTransform): + def __init__(self, crop_sizes: List[Tuple[int, int]], prob: float = 0.5): + super().__init__() + assert isinstance(crop_sizes, list) and len(crop_sizes) > 0, ( + "crop_sizes must be a non-empty list of (h, w) tuples." + ) + for size in crop_sizes: + assert len(size) == 2 and size[0] > 0 and size[1] > 0, ( + f"Invalid crop size: {size}" + ) + + self.crop_sizes = crop_sizes + self.prob = prob + + def __repr__(self): + return ( + f"{self.__class__.__name__}(crop_sizes={self.crop_sizes}, prob={self.prob})" + ) + + def _get_crop_bbox(self, img: np.ndarray, crop_h: int, crop_w: int) -> tuple: + """Randomly generate a crop bounding box for an image given target (h, w).""" + h, w = img.shape[:2] + + # Ensure the target crop is not bigger than the image + crop_h = min(crop_h, h) + crop_w = min(crop_w, w) + + margin_h = h - crop_h + margin_w = w - crop_w + + # Random top-left corner + offset_h = np.random.randint(0, margin_h + 1) + offset_w = np.random.randint(0, margin_w + 1) + + y1, y2 = offset_h, offset_h + crop_h + x1, x2 = offset_w, offset_w + crop_w + return (y1, y2, x1, x2) + + def _crop_img(self, img: np.ndarray, crop_bbox: tuple) -> np.ndarray: + """Crop image to bbox = (y1, y2, x1, x2).""" + y1, y2, x1, x2 = crop_bbox + return img[y1:y2, x1:x2, ...] 
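# Illustrative only: the TRANSFORMS registry used throughout this file suggests an
# mmengine-style pipeline config in which the albedo transforms (those defined above and
# the Resize/Pack transforms that follow) are chained as dicts. The values below are
# examples, not the repo's actual training pipeline:
albedo_train_pipeline = [
    dict(type="AlbedoRandomScale", scale_min=0.5, scale_max=2.0, prob=0.5),
    dict(type="AlbedoRandomCrop", crop_sizes=[(1024, 768)], prob=0.5),
    dict(type="AlbedoRandomFlip", prob=0.5),
    dict(type="AlbedoResize", height=1024, width=768),
    dict(type="AlbedoPackInputs"),
]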
+ + def transform(self, results: dict) -> dict: + # Decide whether to apply cropping + if np.random.rand() >= self.prob: + return results # skip cropping + + img = results["img"] + crop_h, crop_w = random.choice(self.crop_sizes) + crop_bbox = self._get_crop_bbox(img, crop_h, crop_w) + cropped_img = self._crop_img(img, crop_bbox) + + results["img"] = cropped_img + results["img_shape"] = cropped_img.shape[:2] + + # Crop other maps if they exist + for key in ["gt_albedo", "mask"]: + if key in results: + results[key] = self._crop_img(results[key], crop_bbox) + + return results + + +@TRANSFORMS.register_module() +class AlbedoRandomCropContinuous(BaseTransform): + def __init__( + self, + ar_range: Tuple[float, float] = (0.5, 2.0), + area_range: Tuple[float, float] = (0.1, 1.0), + num_attempts: int = 10, + prob: float = 0.5, + ): + super().__init__() + assert ar_range[0] > 0 and ar_range[1] >= ar_range[0], ( + f"Invalid ar_range={ar_range}" + ) + assert area_range[0] > 0 and area_range[1] >= area_range[0], ( + f"Invalid area_range={area_range}" + ) + self.ar_range = ar_range + self.area_range = area_range + self.num_attempts = num_attempts + self.prob = prob + + def __repr__(self): + return ( + f"{self.__class__.__name__}(ar_range={self.ar_range}, " + f"area_range={self.area_range}, " + f"num_attempts={self.num_attempts}, " + f"prob={self.prob})" + ) + + def transform(self, results: Dict) -> Dict: + """Apply the random aspect-ratio crop if conditions are met.""" + if not (random.random() < self.prob): + return results # skip cropping + + img = results["img"] + orig_h, orig_w = img.shape[:2] + img_area = orig_h * orig_w + + # Try up to num_attempts times to find a valid crop + for attempt in range(self.num_attempts): + # 1) Sample aspect ratio in [ar_min, ar_max] + ar = random.uniform(*self.ar_range) # aspect ratio + + # 2) Sample area fraction in [area_min, area_max] + area_frac = random.uniform(*self.area_range) + target_area = area_frac * img_area + + # 3) Solve for crop_h, crop_w + crop_h = math.sqrt(target_area / ar) + crop_w = ar * crop_h + + # 4) Check feasibility: both must be <= orig dims + if crop_w <= orig_w and crop_h <= orig_h: + # 5) Random top-left corner + crop_h = int(round(crop_h)) + crop_w = int(round(crop_w)) + margin_h = orig_h - crop_h + margin_w = orig_w - crop_w + y1 = random.randint(0, margin_h + 1) + x1 = random.randint(0, margin_w + 1) + + y2 = y1 + crop_h + x2 = x1 + crop_w + + # We found a valid crop + crop_bbox = (y1, y2, x1, x2) + break + else: + # If we never broke out, no valid crop found; skip + # (or we could do a fallback like no crop) + return results + + # --- We do the actual cropping now --- + def _crop(img_: np.ndarray, bbox: tuple) -> np.ndarray: + (yy1, yy2, xx1, xx2) = bbox + return img_[yy1:yy2, xx1:xx2, ...] + + # Crop the main image + cropped_img = _crop(img, crop_bbox) + results["img"] = cropped_img + results["img_shape"] = cropped_img.shape[:2] + + # Crop depth/mask if present + for key in ["gt_albedo", "mask"]: + if key in results: + results[key] = _crop(results[key], crop_bbox) + + return results + + +@TRANSFORMS.register_module() +class AlbedoResize(BaseTransform): + def __init__(self, height, width, test_mode: bool = False) -> None: + super().__init__() + self.target_height = height + self.target_width = width + self.test_mode = test_mode + + def transform(self, results: Dict) -> Dict: + img = results["img"] + orig_height, orig_width = img.shape[:2] + + # 1. 
Compute the scale factor to maintain aspect ratio + scale_w = self.target_width / orig_width + scale_h = self.target_height / orig_height + scale_factor = min(scale_w, scale_h) + + # 2. Determine new (width, height) after aspect-preserving resize + new_width = int(round(orig_width * scale_factor)) + new_height = int(round(orig_height * scale_factor)) + + # 3. Resize the image + resized_img = cv2.resize( + img, (new_width, new_height), interpolation=cv2.INTER_LINEAR + ) + + # 4. Create a black canvas of final size [H, W] + final_img = np.zeros( + (self.target_height, self.target_width, resized_img.shape[2]) + if resized_img.ndim == 3 + else (self.target_height, self.target_width), + dtype=resized_img.dtype, + ) + + # 5. Compute offsets to center the resized image + offset_x = (self.target_width - new_width) // 2 + offset_y = (self.target_height - new_height) // 2 + + # 6. Copy resized image into the canvas + if final_img.ndim == 3: # color image + final_img[ + offset_y : offset_y + new_height, offset_x : offset_x + new_width, : + ] = resized_img + else: # single-channel image + final_img[ + offset_y : offset_y + new_height, offset_x : offset_x + new_width + ] = resized_img + + # 7. Replace `results['img']` with our padded image + results["img"] = final_img + results["img_shape"] = final_img.shape[:2] + + # 8. Do the same for mask & gt_depth + # (using nearest interpolation, then padding to center) + if "mask" in results and self.test_mode is False: + mask_resized = cv2.resize( + results["mask"].astype(np.uint8), + (new_width, new_height), + interpolation=cv2.INTER_NEAREST, + ) + final_mask = np.zeros( + (self.target_height, self.target_width), dtype=mask_resized.dtype + ) + final_mask[ + offset_y : offset_y + new_height, offset_x : offset_x + new_width + ] = mask_resized + results["mask"] = final_mask + + if "gt_albedo" in results and self.test_mode is False: + albedo_resized = cv2.resize( + results["gt_albedo"], + (new_width, new_height), + interpolation=cv2.INTER_LINEAR, + ) + final_albedo = np.zeros( + (self.target_height, self.target_width, 3), dtype=albedo_resized.dtype + ) + final_albedo[ + offset_y : offset_y + new_height, offset_x : offset_x + new_width, : + ] = albedo_resized + results["gt_albedo"] = final_albedo + + return results + + +@TRANSFORMS.register_module() +class AlbedoRandomFlip(BaseTransform): + def __init__(self, prob=0.5) -> None: + super().__init__() + self.prob = prob + + def _flip(self, results: dict) -> None: + """Flip images, masks, depth maps and adjust camera parameters.""" + # flip image + results["img"] = cv2.flip(results["img"], 1) # 1 for horizontal flip + + # flip seg map and depth (horizontal flip) + results["mask"] = cv2.flip(results["mask"], 1) + + if "gt_albedo" in results: + gt_albedo = results["gt_albedo"] + gt_albedo = cv2.flip(gt_albedo, 1) # 1 for horizontal flip + results["gt_albedo"] = gt_albedo + + def transform(self, results: Dict) -> Optional[Union[Dict, Tuple[List, List]]]: + if np.random.rand() < self.prob: + self._flip(results) + return results + + +@TRANSFORMS.register_module() +class AlbedoPackInputs(BaseTransform): + def __init__( + self, + test_mode: bool = False, + meta_keys=( + "img_path", + "ori_shape", + "img_shape", + ), + ): + self.test_mode = test_mode + self.meta_keys = meta_keys + + def transform(self, results: dict) -> dict: + packed_results = dict() + if "img" in results: + img = results["img"] + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + if not img.flags.c_contiguous: + img = 
to_tensor(np.ascontiguousarray(img.transpose(2, 0, 1))) + else: + img = img.transpose(2, 0, 1) + img = to_tensor(img).contiguous() + packed_results["inputs"] = img + + data_sample = dict() + + if "gt_albedo" in results: + mask = results["mask"] > 0 ## boolean mask + + ## min number of valid pixels is 16 + if mask.sum() < 16 and self.test_mode is False: + return None + + if (mask.sum() / (mask.shape[0] * mask.shape[1]) > 0.96) and ( + self.test_mode is False + ): + return None + + ##----------------------------------------- + mask = to_tensor(mask[None, ...].copy()) ## 1 x H x W + data_sample["mask"] = mask + + gt_albedo = results["gt_albedo"].astype(np.float32) ## H x W x 3 + gt_albedo = gt_albedo.transpose(2, 0, 1) # H x W x 3 -> 3 x H x W + data_sample["gt_albedo"] = to_tensor(gt_albedo.copy()) + + img_meta = {} + for key in self.meta_keys: + if key in results: + if isinstance(results[key], (int, float)): + img_meta[key] = np.float32(results[key]) + elif isinstance(results[key], np.ndarray): + img_meta[key] = results[key].astype(np.float32) + else: + img_meta[key] = results[key] + data_sample["meta"] = img_meta + packed_results["data_samples"] = data_sample + + return packed_results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f"(meta_keys={self.meta_keys})" + return repr_str + + +@TRANSFORMS.register_module() +class AlbedoResizePadImage(BaseTransform): + def __init__( + self, + height: int = 1024, + width: int = 768, + pad_val: Optional[int] = 0, + padding_mode: str = "constant", + ) -> None: + self.height = height + self.width = width + self.pad_val = pad_val + assert padding_mode in ["constant", "edge", "reflect", "symmetric"] + self.padding_mode = padding_mode + + def _resize_maintain_aspect_ratio(self, img, target_size): + """Resize image maintaining aspect ratio and return padding sizes.""" + original_height, original_width = img.shape[:2] + target_width, target_height = target_size + + # Calculate scaling factors + scale_w = target_width / original_width + scale_h = target_height / original_height + scale = min(scale_w, scale_h) # Use the smaller scaling factor + + # Calculate new dimensions + new_width = int(original_width * scale) + new_height = int(original_height * scale) + + # Resize image + resized_img = cv2.resize( + img, (new_width, new_height), interpolation=cv2.INTER_LINEAR + ) + + # Calculate padding + pad_width = target_width - new_width + pad_height = target_height - new_height + + padding_left = pad_width // 2 + padding_right = pad_width - padding_left + padding_top = pad_height // 2 + padding_bottom = pad_height - padding_top + + return resized_img, (padding_left, padding_right, padding_top, padding_bottom) + + def _pad_img(self, results: dict) -> None: + """Resize image maintaining aspect ratio and pad to target size.""" + img = results["img"] + target_size = (self.width, self.height) # (width, height) + + # Resize image maintaining aspect ratio + resized_img, padding_size = self._resize_maintain_aspect_ratio(img, target_size) + + # Prepare padding value + pad_val = self.pad_val + + # Pad image + padding_left, padding_right, padding_top, padding_bottom = padding_size + if resized_img.ndim == 3: + padded_img = np.pad( + resized_img, + ((padding_top, padding_bottom), (padding_left, padding_right), (0, 0)), + mode=self.padding_mode, + constant_values=pad_val, + ) + else: + padded_img = np.pad( + resized_img, + ((padding_top, padding_bottom), (padding_left, padding_right)), + mode=self.padding_mode, + constant_values=pad_val, + ) + 
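+ # padding_size is kept as (left, right, top, bottom) in pixels, so the
+ # aspect-preserving resize and centring can be undone later if needed.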
+ # Update results dictionary + results["img"] = padded_img + results["pad_shape"] = padded_img.shape + results["pad_fixed_size"] = target_size + results["img_shape"] = padded_img.shape[:2] + results["padding_size"] = padding_size + + def transform(self, results: dict) -> dict: + self._pad_img(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f"height={self.height}, " + repr_str += f"width={self.width}, " + repr_str += f"pad_val={self.pad_val}, " + repr_str += f"padding_mode={self.padding_mode})" + return repr_str diff --git a/sapiens/dense/src/datasets/transforms/normal_transforms.py b/sapiens/dense/src/datasets/transforms/normal_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..94ce8915c7a3ae5abf7cc518908c0e794a2fa3a1 --- /dev/null +++ b/sapiens/dense/src/datasets/transforms/normal_transforms.py @@ -0,0 +1,539 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +import random +from typing import Dict, List, Optional, Tuple, Union + +import cv2 +import numpy as np +from sapiens.engine.datasets import BaseTransform, to_tensor +from sapiens.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class NormalRandomScale(BaseTransform): + def __init__( + self, + scale_min: float = 0.5, + scale_max: float = 2.0, + prob: float = 0.5, + interpolation: int = cv2.INTER_LINEAR, + ): + super().__init__() + assert 0 < scale_min <= scale_max, ( + f"Invalid scale range: ({scale_min}, {scale_max})" + ) + self.scale_min = scale_min + self.scale_max = scale_max + self.interpolation = interpolation + self.prob = prob + + def _random_scale_factor(self) -> float: + """Sample a random scale factor in [scale_min, scale_max].""" + return np.random.uniform(self.scale_min, self.scale_max) + + def transform(self, results: dict) -> dict: + if np.random.rand() >= self.prob: + return results + + img = results["img"] + orig_h, orig_w = img.shape[:2] + + # 1. Sample a random scale factor + s = self._random_scale_factor() + + # 2. Compute the new size + new_w = int(round(orig_w * s)) + new_h = int(round(orig_h * s)) + + # 3. Resize the image + img_resized = cv2.resize(img, (new_w, new_h), interpolation=self.interpolation) + results["img"] = img_resized + results["img_shape"] = (new_h, new_w) + + # 4. Resize mask, depth, etc. using INTER_NEAREST + if "mask" in results: + mask_resized = cv2.resize( + results["mask"].astype(np.uint8), + (new_w, new_h), + interpolation=cv2.INTER_NEAREST, + ) + results["mask"] = mask_resized + + if "gt_normal" in results: + normal_resized = cv2.resize( + results["gt_normal"], + (new_w, new_h), + interpolation=cv2.INTER_NEAREST, + ) + results["gt_normal"] = normal_resized + + return results + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"scale_min={self.scale_min}, " + f"scale_max={self.scale_max})" + ) + + +@TRANSFORMS.register_module() +class NormalRandomCrop(BaseTransform): + def __init__(self, crop_sizes: List[Tuple[int, int]], prob: float = 0.5): + super().__init__() + assert isinstance(crop_sizes, list) and len(crop_sizes) > 0, ( + "crop_sizes must be a non-empty list of (h, w) tuples." 
+ ) + for size in crop_sizes: + assert len(size) == 2 and size[0] > 0 and size[1] > 0, ( + f"Invalid crop size: {size}" + ) + + self.crop_sizes = crop_sizes + self.prob = prob + + def __repr__(self): + return ( + f"{self.__class__.__name__}(crop_sizes={self.crop_sizes}, prob={self.prob})" + ) + + def _get_crop_bbox(self, img: np.ndarray, crop_h: int, crop_w: int) -> tuple: + """Randomly generate a crop bounding box for an image given target (h, w).""" + h, w = img.shape[:2] + + # Ensure the target crop is not bigger than the image + crop_h = min(crop_h, h) + crop_w = min(crop_w, w) + + margin_h = h - crop_h + margin_w = w - crop_w + + # Random top-left corner + offset_h = np.random.randint(0, margin_h + 1) + offset_w = np.random.randint(0, margin_w + 1) + + y1, y2 = offset_h, offset_h + crop_h + x1, x2 = offset_w, offset_w + crop_w + return (y1, y2, x1, x2) + + def _crop_img(self, img: np.ndarray, crop_bbox: tuple) -> np.ndarray: + """Crop image to bbox = (y1, y2, x1, x2).""" + y1, y2, x1, x2 = crop_bbox + return img[y1:y2, x1:x2, ...] + + def transform(self, results: dict) -> dict: + # Decide whether to apply cropping + if np.random.rand() >= self.prob: + return results # skip cropping + + img = results["img"] + # Pick one (h, w) from the list of possible crop sizes + + crop_h, crop_w = random.choice(self.crop_sizes) + + # Generate the crop bounding box + crop_bbox = self._get_crop_bbox(img, crop_h, crop_w) + + # Apply to the main image + cropped_img = self._crop_img(img, crop_bbox) + + results["img"] = cropped_img + results["img_shape"] = cropped_img.shape[:2] + + # Crop other maps if they exist + for key in ["gt_normal", "mask"]: + if key in results: + results[key] = self._crop_img(results[key], crop_bbox) + + return results + + +@TRANSFORMS.register_module() +class NormalRandomCropContinuous(BaseTransform): + def __init__( + self, + ar_range: Tuple[float, float] = (0.5, 2.0), + area_range: Tuple[float, float] = (0.1, 1.0), + num_attempts: int = 10, + prob: float = 0.5, + ): + super().__init__() + assert ar_range[0] > 0 and ar_range[1] >= ar_range[0], ( + f"Invalid ar_range={ar_range}" + ) + assert area_range[0] > 0 and area_range[1] >= area_range[0], ( + f"Invalid area_range={area_range}" + ) + self.ar_range = ar_range + self.area_range = area_range + self.num_attempts = num_attempts + self.prob = prob + + def __repr__(self): + return ( + f"{self.__class__.__name__}(ar_range={self.ar_range}, " + f"area_range={self.area_range}, " + f"num_attempts={self.num_attempts}, " + f"prob={self.prob})" + ) + + def transform(self, results: Dict) -> Dict: + """Apply the random aspect-ratio crop if conditions are met.""" + if not (random.random() < self.prob): + return results # skip cropping + + img = results["img"] + orig_h, orig_w = img.shape[:2] + img_area = orig_h * orig_w + + # Try up to num_attempts times to find a valid crop + for attempt in range(self.num_attempts): + # 1) Sample aspect ratio in [ar_min, ar_max] + ar = random.uniform(*self.ar_range) # aspect ratio + + # 2) Sample area fraction in [area_min, area_max] + area_frac = random.uniform(*self.area_range) + target_area = area_frac * img_area + + # 3) Solve for crop_h, crop_w + crop_h = math.sqrt(target_area / ar) + crop_w = ar * crop_h + + # 4) Check feasibility: both must be <= orig dims + if crop_w <= orig_w and crop_h <= orig_h: + # 5) Random top-left corner + crop_h = int(round(crop_h)) + crop_w = int(round(crop_w)) + margin_h = orig_h - crop_h + margin_w = orig_w - crop_w + y1 = random.randint(0, margin_h + 1) + x1 = 
random.randint(0, margin_w + 1) + + y2 = y1 + crop_h + x2 = x1 + crop_w + + # We found a valid crop + crop_bbox = (y1, y2, x1, x2) + break + else: + # If we never broke out, no valid crop found; skip + # (or we could do a fallback like no crop) + return results + + # --- We do the actual cropping now --- + def _crop(img_: np.ndarray, bbox: tuple) -> np.ndarray: + (yy1, yy2, xx1, xx2) = bbox + return img_[yy1:yy2, xx1:xx2, ...] + + # Crop the main image + cropped_img = _crop(img, crop_bbox) + results["img"] = cropped_img + results["img_shape"] = cropped_img.shape[:2] + + # Crop depth/mask if present + for key in ["gt_normal", "mask"]: + if key in results: + results[key] = _crop(results[key], crop_bbox) + + return results + + +@TRANSFORMS.register_module() +class NormalResize(BaseTransform): + def __init__(self, height, width, test_mode: bool = False) -> None: + super().__init__() + self.target_height = height + self.target_width = width + self.test_mode = test_mode + + def transform(self, results: Dict) -> Dict: + img = results["img"] + orig_height, orig_width = img.shape[:2] + + # 1. Compute the scale factor to maintain aspect ratio + scale_w = self.target_width / orig_width + scale_h = self.target_height / orig_height + scale_factor = min(scale_w, scale_h) + + # 2. Determine new (width, height) after aspect-preserving resize + new_width = int(round(orig_width * scale_factor)) + new_height = int(round(orig_height * scale_factor)) + + # 3. Resize the image + resized_img = cv2.resize( + img, (new_width, new_height), interpolation=cv2.INTER_LINEAR + ) + + # 4. Create a black canvas of final size [H, W] + final_img = np.zeros( + (self.target_height, self.target_width, resized_img.shape[2]) + if resized_img.ndim == 3 + else (self.target_height, self.target_width), + dtype=resized_img.dtype, + ) + + # 5. Compute offsets to center the resized image + offset_x = (self.target_width - new_width) // 2 + offset_y = (self.target_height - new_height) // 2 + + # 6. Copy resized image into the canvas + if final_img.ndim == 3: # color image + final_img[ + offset_y : offset_y + new_height, offset_x : offset_x + new_width, : + ] = resized_img + else: # single-channel image + final_img[ + offset_y : offset_y + new_height, offset_x : offset_x + new_width + ] = resized_img + + # 7. Replace `results['img']` with our padded image + results["img"] = final_img + results["img_shape"] = final_img.shape[:2] + + # 8. 
Do the same for mask & gt_depth + # (using nearest interpolation, then padding to center) + if "mask" in results and self.test_mode is False: + mask_resized = cv2.resize( + results["mask"].astype(np.uint8), + (new_width, new_height), + interpolation=cv2.INTER_NEAREST, + ) + final_mask = np.zeros( + (self.target_height, self.target_width), dtype=mask_resized.dtype + ) + final_mask[ + offset_y : offset_y + new_height, offset_x : offset_x + new_width + ] = mask_resized + results["mask"] = final_mask + + if "gt_normal" in results and self.test_mode is False: + normal_resized = cv2.resize( + results["gt_normal"], + (new_width, new_height), + interpolation=cv2.INTER_NEAREST, + ) + final_normal = np.zeros( + (self.target_height, self.target_width, 3), dtype=normal_resized.dtype + ) + final_normal[ + offset_y : offset_y + new_height, offset_x : offset_x + new_width, : + ] = normal_resized + results["gt_normal"] = final_normal + + return results + + +@TRANSFORMS.register_module() +class NormalRandomFlip(BaseTransform): + def __init__(self, prob=0.5) -> None: + super().__init__() + self.prob = prob + + def _flip(self, results: dict) -> None: + """Flip images, masks, depth maps and adjust camera parameters.""" + # flip image + results["img"] = cv2.flip(results["img"], 1) # 1 for horizontal flip + img_shape = results["img"].shape[:2] + + # flip seg map and depth (horizontal flip) + results["mask"] = cv2.flip(results["mask"], 1) + + if "gt_normal" in results: + gt_normal = results["gt_normal"] + gt_normal = cv2.flip(gt_normal, 1) # 1 for horizontal flip + gt_normal[:, :, 0] = -gt_normal[:, :, 0] # flip x + results["gt_normal"] = gt_normal + + def transform(self, results: Dict) -> Optional[Union[Dict, Tuple[List, List]]]: + if np.random.rand() < self.prob: + self._flip(results) + return results + + +@TRANSFORMS.register_module() +class NormalGenerateTarget(BaseTransform): + def __init__(self, background_val: int = -1000): + self.background_val = background_val + return + + def transform(self, results: dict) -> dict: + if "gt_normal" not in results.keys(): + return results + + gt_normal = results["gt_normal"] + mask = results["mask"] + norms = np.linalg.norm(gt_normal, axis=2, keepdims=True) + gt_normal = gt_normal / (norms + 1e-6) + gt_normal[mask == 0] = self.background_val + results["gt_normal"] = gt_normal + return results + + def __repr__(self): + return self.__class__.__name__ + + +@TRANSFORMS.register_module() +class NormalPackInputs(BaseTransform): + def __init__( + self, + test_mode: bool = False, + meta_keys=( + "img_path", + "ori_shape", + "img_shape", + ), + ): + self.test_mode = test_mode + self.meta_keys = meta_keys + + def transform(self, results: dict) -> dict: + packed_results = dict() + if "img" in results: + img = results["img"] + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + if not img.flags.c_contiguous: + img = to_tensor(np.ascontiguousarray(img.transpose(2, 0, 1))) + else: + img = img.transpose(2, 0, 1) + img = to_tensor(img).contiguous() + packed_results["inputs"] = img + + data_sample = dict() + + if "gt_normal" in results: + mask = results["mask"] > 0 ## boolean mask + + ## min number of valid pixels is 16 + if mask.sum() < 16 and self.test_mode is False: + return None + + if (mask.sum() / (mask.shape[0] * mask.shape[1]) > 0.96) and ( + self.test_mode is False + ): + return None + + ##----------------------------------------- + mask = to_tensor(mask[None, ...].copy()) ## 1 x H x W + data_sample["mask"] = mask + + gt_normal = results["gt_normal"].astype(np.float32) ## 
H x W x 3 + gt_normal = gt_normal.transpose(2, 0, 1) # H x W x 3 -> 3 x H x W + data_sample["gt_normal"] = to_tensor(gt_normal.copy()) + + img_meta = {} + for key in self.meta_keys: + if key in results: + if isinstance(results[key], (int, float)): + img_meta[key] = np.float32(results[key]) + elif isinstance(results[key], np.ndarray): + img_meta[key] = results[key].astype(np.float32) + else: + img_meta[key] = results[key] + data_sample["meta"] = img_meta + packed_results["data_samples"] = data_sample + + return packed_results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f"(meta_keys={self.meta_keys})" + return repr_str + + +@TRANSFORMS.register_module() +class NormalResizePadImage(BaseTransform): + def __init__( + self, + height: int = 1024, + width: int = 768, + pad_val: Optional[int] = 0, + padding_mode: str = "constant", + ) -> None: + self.height = height + self.width = width + self.pad_val = pad_val + assert padding_mode in ["constant", "edge", "reflect", "symmetric"] + self.padding_mode = padding_mode + + def _resize_maintain_aspect_ratio(self, img, target_size): + """Resize image maintaining aspect ratio and return padding sizes.""" + original_height, original_width = img.shape[:2] + target_width, target_height = target_size + + # Calculate scaling factors + scale_w = target_width / original_width + scale_h = target_height / original_height + scale = min(scale_w, scale_h) # Use the smaller scaling factor + + # Calculate new dimensions + new_width = int(original_width * scale) + new_height = int(original_height * scale) + + # Resize image + resized_img = cv2.resize( + img, (new_width, new_height), interpolation=cv2.INTER_LINEAR + ) + + # Calculate padding + pad_width = target_width - new_width + pad_height = target_height - new_height + + padding_left = pad_width // 2 + padding_right = pad_width - padding_left + padding_top = pad_height // 2 + padding_bottom = pad_height - padding_top + + return resized_img, (padding_left, padding_right, padding_top, padding_bottom) + + def _pad_img(self, results: dict) -> None: + """Resize image maintaining aspect ratio and pad to target size.""" + img = results["img"] + target_size = (self.width, self.height) # (width, height) + + # Resize image maintaining aspect ratio + resized_img, padding_size = self._resize_maintain_aspect_ratio(img, target_size) + + # Prepare padding value + pad_val = self.pad_val + + # Pad image + padding_left, padding_right, padding_top, padding_bottom = padding_size + if resized_img.ndim == 3: + padded_img = np.pad( + resized_img, + ((padding_top, padding_bottom), (padding_left, padding_right), (0, 0)), + mode=self.padding_mode, + constant_values=pad_val, + ) + else: + padded_img = np.pad( + resized_img, + ((padding_top, padding_bottom), (padding_left, padding_right)), + mode=self.padding_mode, + constant_values=pad_val, + ) + + # Update results dictionary + results["img"] = padded_img + results["pad_shape"] = padded_img.shape + results["pad_fixed_size"] = target_size + results["img_shape"] = padded_img.shape[:2] + results["padding_size"] = padding_size + + def transform(self, results: dict) -> dict: + self._pad_img(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f"height={self.height}, " + repr_str += f"width={self.width}, " + repr_str += f"pad_val={self.pad_val}, " + repr_str += f"padding_mode={self.padding_mode})" + return repr_str diff --git a/sapiens/dense/src/datasets/transforms/pointmap_transforms.py 
b/sapiens/dense/src/datasets/transforms/pointmap_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6fa4feb1aead0e86e99546ae326e9dda73e3eef5 --- /dev/null +++ b/sapiens/dense/src/datasets/transforms/pointmap_transforms.py @@ -0,0 +1,726 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +import random +from typing import Dict, List, Optional, Tuple, Union + +import cv2 +import numpy as np +from sapiens.engine.datasets import BaseTransform, to_tensor +from sapiens.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class PointmapRandomScale(BaseTransform): + def __init__( + self, + scale_min: float = 0.5, + scale_max: float = 2.0, + prob: float = 0.5, + interpolation: int = cv2.INTER_LINEAR, + ): + super().__init__() + assert 0 < scale_min <= scale_max, ( + f"Invalid scale range: ({scale_min}, {scale_max})" + ) + self.scale_min = scale_min + self.scale_max = scale_max + self.interpolation = interpolation + self.prob = prob + + def _random_scale_factor(self) -> float: + """Sample a random scale factor in [scale_min, scale_max].""" + return np.random.uniform(self.scale_min, self.scale_max) + + def transform(self, results: dict) -> dict: + if np.random.rand() >= self.prob: + return results + + img = results["img"] + orig_h, orig_w = img.shape[:2] + + # 1. Sample a random scale factor + s = self._random_scale_factor() + + # 2. Compute the new size + new_w = int(round(orig_w * s)) + new_h = int(round(orig_h * s)) + + # 3. Resize the image + img_resized = cv2.resize(img, (new_w, new_h), interpolation=self.interpolation) + results["img"] = img_resized + results["img_shape"] = (new_h, new_w) + + # 4. Resize mask, depth, etc. using INTER_NEAREST + if "mask" in results: + mask_resized = cv2.resize( + results["mask"].astype(np.uint8), + (new_w, new_h), + interpolation=cv2.INTER_NEAREST, + ) + results["mask"] = mask_resized + + if "gt_depth" in results: + depth_resized = cv2.resize( + results["gt_depth"], + (new_w, new_h), + interpolation=cv2.INTER_NEAREST, + ) + results["gt_depth"] = depth_resized + + # 5. Update camera intrinsics if present + if "K" in results: + K_new = results["K"].copy() + # Scale fx, fy + K_new[0, 0] *= s # fx + K_new[1, 1] *= s # fy + # Shift principal point + K_new[0, 2] *= s # cx + K_new[1, 2] *= s # cy + results["K"] = K_new + + return results + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"scale_min={self.scale_min}, " + f"scale_max={self.scale_max})" + ) + + +@TRANSFORMS.register_module() +class PointmapRandomCrop(BaseTransform): + def __init__(self, crop_sizes: List[Tuple[int, int]], prob: float = 0.5): + super().__init__() + assert isinstance(crop_sizes, list) and len(crop_sizes) > 0, ( + "crop_sizes must be a non-empty list of (h, w) tuples." 
+ ) + for size in crop_sizes: + assert len(size) == 2 and size[0] > 0 and size[1] > 0, ( + f"Invalid crop size: {size}" + ) + + self.crop_sizes = crop_sizes + self.prob = prob + + def __repr__(self): + return ( + f"{self.__class__.__name__}(crop_sizes={self.crop_sizes}, prob={self.prob})" + ) + + def _get_crop_bbox(self, img: np.ndarray, crop_h: int, crop_w: int) -> tuple: + """Randomly generate a crop bounding box for an image given target (h, w).""" + h, w = img.shape[:2] + + # Ensure the target crop is not bigger than the image + crop_h = min(crop_h, h) + crop_w = min(crop_w, w) + + margin_h = h - crop_h + margin_w = w - crop_w + + # Random top-left corner + offset_h = np.random.randint(0, margin_h + 1) + offset_w = np.random.randint(0, margin_w + 1) + + y1, y2 = offset_h, offset_h + crop_h + x1, x2 = offset_w, offset_w + crop_w + return (y1, y2, x1, x2) + + def _crop_img(self, img: np.ndarray, crop_bbox: tuple) -> np.ndarray: + """Crop image to bbox = (y1, y2, x1, x2).""" + y1, y2, x1, x2 = crop_bbox + return img[y1:y2, x1:x2, ...] + + def transform(self, results: dict) -> dict: + # Decide whether to apply cropping + if np.random.rand() >= self.prob: + return results # skip cropping + + img = results["img"] + # Pick one (h, w) from the list of possible crop sizes + + crop_h, crop_w = random.choice(self.crop_sizes) + + # Generate the crop bounding box + crop_bbox = self._get_crop_bbox(img, crop_h, crop_w) + + # Apply to the main image + cropped_img = self._crop_img(img, crop_bbox) + + results["img"] = cropped_img + results["img_shape"] = cropped_img.shape[:2] + + # Crop other maps if they exist + for key in ["gt_depth", "mask"]: + if key in results: + results[key] = self._crop_img(results[key], crop_bbox) + + # Adjust intrinsics if present + if "K" in results: + K_new = results["K"].copy() + y1, y2, x1, x2 = crop_bbox + # Shift principal point + K_new[0, 2] -= x1 + K_new[1, 2] -= y1 + results["K"] = K_new + + return results + + +@TRANSFORMS.register_module() +class PointmapRandomCropContinuous(BaseTransform): + def __init__( + self, + ar_range: Tuple[float, float] = (0.5, 2.0), + area_range: Tuple[float, float] = (0.1, 1.0), + num_attempts: int = 10, + prob: float = 0.5, + ): + super().__init__() + assert ar_range[0] > 0 and ar_range[1] >= ar_range[0], ( + f"Invalid ar_range={ar_range}" + ) + assert area_range[0] > 0 and area_range[1] >= area_range[0], ( + f"Invalid area_range={area_range}" + ) + self.ar_range = ar_range + self.area_range = area_range + self.num_attempts = num_attempts + self.prob = prob + + def __repr__(self): + return ( + f"{self.__class__.__name__}(ar_range={self.ar_range}, " + f"area_range={self.area_range}, " + f"num_attempts={self.num_attempts}, " + f"prob={self.prob})" + ) + + def transform(self, results: Dict) -> Dict: + """Apply the random aspect-ratio crop if conditions are met.""" + if not (random.random() < self.prob): + return results # skip cropping + + img = results["img"] + orig_h, orig_w = img.shape[:2] + img_area = orig_h * orig_w + + # Try up to num_attempts times to find a valid crop + for attempt in range(self.num_attempts): + # 1) Sample aspect ratio in [ar_min, ar_max] + ar = random.uniform(*self.ar_range) # aspect ratio + + # 2) Sample area fraction in [area_min, area_max] + area_frac = random.uniform(*self.area_range) + target_area = area_frac * img_area + + # 3) Solve for crop_h, crop_w + crop_h = math.sqrt(target_area / ar) + crop_w = ar * crop_h + + # 4) Check feasibility: both must be <= orig dims + if crop_w <= orig_w and crop_h <= orig_h: 
+ # 5) Random top-left corner + crop_h = int(round(crop_h)) + crop_w = int(round(crop_w)) + margin_h = orig_h - crop_h + margin_w = orig_w - crop_w + y1 = random.randint(0, margin_h + 1) + x1 = random.randint(0, margin_w + 1) + + y2 = y1 + crop_h + x2 = x1 + crop_w + + # We found a valid crop + crop_bbox = (y1, y2, x1, x2) + break + else: + # If we never broke out, no valid crop found; skip + # (or we could do a fallback like no crop) + return results + + # --- We do the actual cropping now --- + def _crop(img_: np.ndarray, bbox: tuple) -> np.ndarray: + (yy1, yy2, xx1, xx2) = bbox + return img_[yy1:yy2, xx1:xx2, ...] + + # Crop the main image + cropped_img = _crop(img, crop_bbox) + results["img"] = cropped_img + results["img_shape"] = cropped_img.shape[:2] + + # Crop depth/mask if present + for key in ["gt_depth", "mask"]: + if key in results: + results[key] = _crop(results[key], crop_bbox) + + # Adjust intrinsics if present + if "K" in results: + K_new = results["K"].copy() + # Shift principal point + y1, y2, x1, x2 = crop_bbox + K_new[0, 2] -= x1 + K_new[1, 2] -= y1 + results["K"] = K_new + + return results + + +@TRANSFORMS.register_module() +class PointmapResize(BaseTransform): + def __init__(self, height, width) -> None: + super().__init__() + self.target_height = height + self.target_width = width + + def transform(self, results: Dict) -> Dict: + img = results["img"] + orig_height, orig_width = img.shape[:2] + + # 1. Compute the scale factor to maintain aspect ratio + scale_w = self.target_width / orig_width + scale_h = self.target_height / orig_height + scale_factor = min(scale_w, scale_h) + + # 2. Determine new (width, height) after aspect-preserving resize + new_width = int(round(orig_width * scale_factor)) + new_height = int(round(orig_height * scale_factor)) + + # 3. Resize the image + resized_img = cv2.resize( + img, (new_width, new_height), interpolation=cv2.INTER_LINEAR + ) + + # 4. Create a black canvas of final size [H, W] + final_img = np.zeros( + (self.target_height, self.target_width, resized_img.shape[2]) + if resized_img.ndim == 3 + else (self.target_height, self.target_width), + dtype=resized_img.dtype, + ) + + # 5. Compute offsets to center the resized image + offset_x = (self.target_width - new_width) // 2 + offset_y = (self.target_height - new_height) // 2 + + # 6. Copy resized image into the canvas + if final_img.ndim == 3: # color image + final_img[ + offset_y : offset_y + new_height, offset_x : offset_x + new_width, : + ] = resized_img + else: # single-channel image + final_img[ + offset_y : offset_y + new_height, offset_x : offset_x + new_width + ] = resized_img + + # 7. Replace `results['img']` with our padded image + results["img"] = final_img + results["img_shape"] = final_img.shape[:2] + + # 8. 
Do the same for mask & gt_depth + # (using nearest interpolation, then padding to center) + if "mask" in results: + mask_resized = cv2.resize( + results["mask"].astype(np.uint8), + (new_width, new_height), + interpolation=cv2.INTER_NEAREST, + ) + final_mask = np.zeros( + (self.target_height, self.target_width), dtype=mask_resized.dtype + ) + final_mask[ + offset_y : offset_y + new_height, offset_x : offset_x + new_width + ] = mask_resized + results["mask"] = final_mask + + if "gt_depth" in results: + depth_resized = cv2.resize( + results["gt_depth"], + (new_width, new_height), + interpolation=cv2.INTER_NEAREST, + ) + final_depth = np.zeros( + (self.target_height, self.target_width), dtype=depth_resized.dtype + ) + final_depth[ + offset_y : offset_y + new_height, offset_x : offset_x + new_width + ] = depth_resized + results["gt_depth"] = final_depth + + # 9. Adjust camera intrinsics K accordingly + if "K" in results: + K_new = results["K"].copy() + # Scale fx, fy + K_new[0, 0] *= scale_factor # fx + K_new[1, 1] *= scale_factor # fy + + # Scale and then shift principal point by offsets + K_new[0, 2] = K_new[0, 2] * scale_factor + offset_x + K_new[1, 2] = K_new[1, 2] * scale_factor + offset_y + + results["K"] = K_new + + return results + + +@TRANSFORMS.register_module() +class PointmapRandomFlip(BaseTransform): + def __init__(self, prob=0.5) -> None: + super().__init__() + self.prob = prob + + def _flip(self, results: dict) -> None: + """Flip images, masks, depth maps and adjust camera parameters.""" + # flip image + results["img"] = cv2.flip(results["img"], 1) # 1 for horizontal flip + img_shape = results["img"].shape[:2] + + # flip seg map and depth (horizontal flip) + results["mask"] = cv2.flip(results["mask"], 1) + + if "gt_depth" in results: + results["gt_depth"] = cv2.flip(results["gt_depth"], 1) + + # adjust camera parameters + if "K" in results: + # Flip the principal point for the left-right flipped image + results["K"][0, 2] = img_shape[1] - results["K"][0, 2] - 1 + + if "M" in results: + # Flip the sign of the first column of the extrinsics matrix + results["M"][0, :] = -results["M"][0, :] + + def transform(self, results: Dict) -> Optional[Union[Dict, Tuple[List, List]]]: + if np.random.rand() < self.prob: + self._flip(results) + return results + + +@TRANSFORMS.register_module() +class PointmapGenerateTarget(BaseTransform): + def __init__(self, canonical_focal_length=768, target_downsample_factor=None): + self.canonical_focal_length = canonical_focal_length + self.target_downsample_factor = target_downsample_factor + return + + def transform(self, results: dict) -> dict: + if "gt_depth" not in results.keys(): + return results + + ## only downsample gt_depth, mask and K + if self.target_downsample_factor is not None: + assert isinstance(self.target_downsample_factor, int) + + gt_depth = results["gt_depth"] + mask = results["mask"] + K = results["K"] + + gt_depth = cv2.resize( + gt_depth, + None, + fx=1 / self.target_downsample_factor, + fy=1 / self.target_downsample_factor, + interpolation=cv2.INTER_NEAREST, + ) + mask = cv2.resize( + mask, + None, + fx=1 / self.target_downsample_factor, + fy=1 / self.target_downsample_factor, + interpolation=cv2.INTER_NEAREST, + ) + + K[0, 0] = K[0, 0] / self.target_downsample_factor + K[1, 1] = K[1, 1] / self.target_downsample_factor + K[0, 2] = K[0, 2] / self.target_downsample_factor + K[1, 2] = K[1, 2] / self.target_downsample_factor + + results["gt_depth"] = gt_depth + results["mask"] = mask + results["K"] = K + + if "uv_map" in results: + 
uv_map = results["uv_map"] + uv_map = cv2.resize( + uv_map, + None, + fx=1 / self.target_downsample_factor, + fy=1 / self.target_downsample_factor, + interpolation=cv2.INTER_LINEAR, + ) + results["uv_map"] = uv_map + + gt_depth = results["gt_depth"] ## no normalization + mask = results["mask"] + + fx = results["K"][0, 0] + fy = results["K"][1, 1] + cx = results["K"][0, 2] + cy = results["K"][1, 2] + + scale = 1.0 + if self.canonical_focal_length is not None: + scale = self.canonical_focal_length / fx + + cols, rows = np.meshgrid( + np.arange(gt_depth.shape[1]), np.arange(gt_depth.shape[0]) + ) + X = (cols - cx) * gt_depth / fx + Y = (rows - cy) * gt_depth / fy + Z = gt_depth + + # # # ##-----------debug----------------------- + # image = results['img'] + # K = results['K'] + # mask = results['mask'] > 0 + + # # Set random seed + # seed = np.random.randint(0, 10000) + + # # Project to image plane + # x = K[0,0] * X/Z + K[0,2] # new_fx * X/Z + cx + # y = K[1,1] * Y/Z + K[1,2] # new_fy * Y/Z + cy + + # # Round to nearest pixel and clip to image bounds + # x = np.clip(np.round(x), 0, image.shape[1]-1).astype(int) + # y = np.clip(np.round(y), 0, image.shape[0]-1).astype(int) + + # # Create visualization + # debug_img = image.copy() + # # Draw all valid projected points in green + # debug_img[y[mask], x[mask]] = [0, 255, 0] # Set projected points to green + # debug_img = np.concatenate([image, debug_img], axis=1) + + # # Save debug image + # cv2.imwrite(f'seed{seed}.jpg', debug_img) + # # ----------------------------------------- + + # Scale the coordinates. isotropic scaling + X = X * scale + Y = Y * scale + Z = Z * scale + + results["original_K"] = results["K"].copy() + results["scale"] = scale + + if self.canonical_focal_length is not None: + # New camera intrinsics + new_K = results["K"].copy() + new_K[0, 0] = fx * scale # new fx + new_K[1, 1] = fy * scale # new fy + new_K[0, 2] = cx * scale + new_K[1, 2] = cy * scale + results["K"] = new_K + + gt_pointmap = np.stack([X, Y, Z], axis=-1) + results["gt_depth"] = Z ## canonical depth + + ## preserve range by removing invalid points + gt_pointmap[mask == 0] = 0 + results["gt_pointmap"] = gt_pointmap + + return results + + def __repr__(self): + return self.__class__.__name__ + + +@TRANSFORMS.register_module() +class PointmapPackInputs(BaseTransform): + def __init__( + self, + meta_keys=( + "img_path", + "ori_shape", + "img_shape", + "pad_shape", + "scale_factor", + "flip", + "flip_direction", + "K", + "original_K", + "M", + ), + ): + self.meta_keys = meta_keys + + def transform(self, results: dict) -> dict: + packed_results = dict() + if "img" in results: + img = results["img"] + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + if not img.flags.c_contiguous: + img = to_tensor(np.ascontiguousarray(img.transpose(2, 0, 1))) + else: + img = img.transpose(2, 0, 1) + img = to_tensor(img).contiguous() + packed_results["inputs"] = img + + data_sample = dict() + + if "gt_pointmap" in results: + mask = results["mask"] > 0 ## boolean mask + + ## min number of valid pixels is 4 + if mask.sum() < 16: + return None + + if mask.sum() / (mask.shape[0] * mask.shape[1]) > 0.96: + return None + + ## clipping inside camera + min_depth = results["gt_pointmap"][results["mask"] > 0, 2].min() + if min_depth < 0.04: + return None + + gt_mean_depth = results["gt_pointmap"][results["mask"] > 0, 2].mean() + + # ##------------------debug------------------ + # ## print the min, max and mean of the X, Y, Z coordinates + # X = results["gt_pointmap"][results["mask"] > 0, 
0] + # Y = results["gt_pointmap"][results["mask"] > 0, 1] + # Z = results["gt_pointmap"][results["mask"] > 0, 2] + # inv_Z = 1 / Z + + # print("scale:", results["scale"]) + # print("X min:", X.min(), "X max:", X.max(), "X mean:", X.mean()) + # print("Y min:", Y.min(), "Y max:", Y.max(), "Y mean:", Y.mean()) + # print("Z min:", Z.min(), "Z max:", Z.max(), "Z mean:", Z.mean()) + # print( + # "inv_Z min:", + # inv_Z.min(), + # "inv_Z max:", + # inv_Z.max(), + # "inv_Z mean:", + # inv_Z.mean(), + # ) + # print() + ##----------------------------------------- + mask = to_tensor(mask[None, ...].copy()) ## 1 x H x W + data_sample["mask"] = mask + + gt_pointmap = results["gt_pointmap"].astype(np.float32) ## H x W x 3 + gt_pointmap = gt_pointmap.transpose(2, 0, 1) ## H x W x 3 -> 3 x H x W + data_sample["gt_pointmap"] = to_tensor(gt_pointmap.copy()) + data_sample["gt_mean_depth"] = to_tensor( + gt_mean_depth[None, None, None].copy() + ) + + img_meta = {} + for key in self.meta_keys: + if key in results: + if isinstance(results[key], (int, float)): + img_meta[key] = np.float32(results[key]) + elif isinstance(results[key], np.ndarray): + img_meta[key] = results[key].astype(np.float32) + else: + img_meta[key] = results[key] + data_sample["meta"] = img_meta + packed_results["data_samples"] = data_sample + + return packed_results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f"(meta_keys={self.meta_keys})" + return repr_str + + +@TRANSFORMS.register_module() +class PointmapResizePadImage(BaseTransform): + def __init__( + self, + height: int = 1024, + width: int = 768, + pad_val: Optional[int] = 0, + padding_mode: str = "constant", + ) -> None: + self.height = height + self.width = width + self.pad_val = pad_val + assert padding_mode in ["constant", "edge", "reflect", "symmetric"] + self.padding_mode = padding_mode + + def _resize_maintain_aspect_ratio(self, img, target_size): + """Resize image maintaining aspect ratio and return padding sizes.""" + original_height, original_width = img.shape[:2] + target_width, target_height = target_size + + # Calculate scaling factors + scale_w = target_width / original_width + scale_h = target_height / original_height + scale = min(scale_w, scale_h) # Use the smaller scaling factor + + # Calculate new dimensions + new_width = int(original_width * scale) + new_height = int(original_height * scale) + + # Resize image + resized_img = cv2.resize( + img, (new_width, new_height), interpolation=cv2.INTER_LINEAR + ) + + # Calculate padding + pad_width = target_width - new_width + pad_height = target_height - new_height + + padding_left = pad_width // 2 + padding_right = pad_width - padding_left + padding_top = pad_height // 2 + padding_bottom = pad_height - padding_top + + return resized_img, (padding_left, padding_right, padding_top, padding_bottom) + + def _pad_img(self, results: dict) -> None: + """Resize image maintaining aspect ratio and pad to target size.""" + img = results["img"] + target_size = (self.width, self.height) # (width, height) + + # Resize image maintaining aspect ratio + resized_img, padding_size = self._resize_maintain_aspect_ratio(img, target_size) + + # Prepare padding value + pad_val = self.pad_val + + # Pad image + padding_left, padding_right, padding_top, padding_bottom = padding_size + if resized_img.ndim == 3: + padded_img = np.pad( + resized_img, + ((padding_top, padding_bottom), (padding_left, padding_right), (0, 0)), + mode=self.padding_mode, + constant_values=pad_val, + ) + else: + padded_img = np.pad( + 
resized_img, + ((padding_top, padding_bottom), (padding_left, padding_right)), + mode=self.padding_mode, + constant_values=pad_val, + ) + + # Update results dictionary + results["img"] = padded_img + results["pad_shape"] = padded_img.shape + results["pad_fixed_size"] = target_size + results["img_shape"] = padded_img.shape[:2] + results["padding_size"] = padding_size + + def transform(self, results: dict) -> dict: + self._pad_img(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f"height={self.height}, " + repr_str += f"width={self.width}, " + repr_str += f"pad_val={self.pad_val}, " + repr_str += f"padding_mode={self.padding_mode})" + return repr_str diff --git a/sapiens/dense/src/datasets/transforms/seg_transforms.py b/sapiens/dense/src/datasets/transforms/seg_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..8d42d47856e1e1288c8859164062101658b23a51 --- /dev/null +++ b/sapiens/dense/src/datasets/transforms/seg_transforms.py @@ -0,0 +1,479 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import random +import warnings + +import cv2 +import numpy as np +from sapiens.engine.datasets import BaseTransform, to_tensor +from sapiens.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class SegRandomRotate(BaseTransform): + def __init__( + self, + prob=0.5, + degree=60, + pad_val=0, + seg_pad_val=255, + ): + super().__init__() + self.prob = prob + assert prob >= 0 and prob <= 1 + assert degree > 0, f"degree {degree} should be positive" + self.degree = (-degree, degree) + assert len(self.degree) == 2, ( + f"degree {self.degree} should be a tuple of (min, max)" + ) + self.pad_val = pad_val + self.seg_pad_val = seg_pad_val + + def transform(self, results: dict) -> dict: + if random.random() > self.prob: + return results + + degree = random.uniform(min(*self.degree), max(*self.degree)) + img = results["img"] + gt_seg = results["gt_seg"] + h, w = img.shape[:2] + center = (w / 2, h / 2) + + M = cv2.getRotationMatrix2D(center, degree, 1.0) + + results["img"] = cv2.warpAffine( + img, M, (w, h), flags=cv2.INTER_LINEAR, borderValue=self.pad_val + ) + + # Rotate the segmentation map + results["gt_seg"] = cv2.warpAffine( + gt_seg, + M, + (w, h), + flags=cv2.INTER_NEAREST, + borderValue=self.seg_pad_val, + ) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += ( + f"(prob={self.prob}, " + f"degree={self.degree}, " + f"pad_val={self.pal_val}, " + f"seg_pad_val={self.seg_pad_val}, " + ) + return repr_str + + +@TRANSFORMS.register_module() +class SegRandomHorizontalFlip(BaseTransform): + def __init__(self, prob=0.5, swap_seg_labels=None): + super().__init__() + self.prob = prob + self.swap_seg_labels = swap_seg_labels + + def transform(self, results: dict) -> dict: + if random.random() > self.prob: + return results + + img = results["img"] + gt_seg = results["gt_seg"] + img = cv2.flip(img, 1) + gt_seg = cv2.flip(gt_seg, 1) + temp = gt_seg.copy() + if self.swap_seg_labels is not None: + for pair in self.swap_seg_labels: + assert len(pair) == 2 + gt_seg[temp == pair[0]] = pair[1] + gt_seg[temp == pair[1]] = pair[0] + + results["img"] = img + results["gt_seg"] = gt_seg + return results + + +@TRANSFORMS.register_module() +class SegPackInputs(BaseTransform): + def __init__( + self, + test_mode: bool = False, + 
meta_keys=( + "img_path", + "ori_shape", + "img_shape", + "pad_shape", + "flip", + ), + ): + super().__init__() + self.test_mode = test_mode + self.meta_keys = meta_keys + + def transform(self, results: dict) -> dict: + packed_results = dict() + if "img" in results: + img = results["img"] + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + if not img.flags.c_contiguous: + img = to_tensor(np.ascontiguousarray(img.transpose(2, 0, 1))) + else: + img = img.transpose(2, 0, 1) + img = to_tensor(img).contiguous() + packed_results["inputs"] = img + + data_sample = dict() + + if "gt_seg" in results: + assert len(results["gt_seg"].shape) == 2 # H x W + mask = (results["gt_seg"] > 0) * (results["gt_seg"] != 255) + if ( + mask.sum() / (mask.shape[0] * mask.shape[1]) < 0.01 + and self.test_mode == False + ): + return None + + data_sample["gt_seg"] = to_tensor( + results["gt_seg"][None, ...].astype(np.int64) + ) + + img_meta = {} + for key in self.meta_keys: + if key in results: + if isinstance(results[key], (int, float)): + img_meta[key] = np.float32(results[key]) + elif isinstance(results[key], np.ndarray): + img_meta[key] = results[key].astype(np.float32) + else: + img_meta[key] = results[key] + data_sample["meta"] = img_meta + packed_results["data_samples"] = data_sample + + return packed_results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f"(meta_keys={self.meta_keys})" + return repr_str + + +@TRANSFORMS.register_module() +class SegRandomResize(BaseTransform): + def __init__( + self, + base_height=1024, + base_width=768, + ratio_range=(0.4, 2.0), + keep_ratio=True, + ): + super().__init__() + self.base_height = base_height + self.base_width = base_width + self.ratio_range = ratio_range + self.keep_ratio = keep_ratio + self.resizer = SegResize( + height=self.base_height, width=self.base_width, keep_ratio=keep_ratio + ) + + def transform(self, results: dict) -> dict: + min_ratio, max_ratio = self.ratio_range + ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio + self.resizer.height = int(self.base_height * ratio) + self.resizer.width = int(self.base_width * ratio) + return self.resizer.transform(results) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"base_height={self.base_height}, " + f"base_width={self.base_width}, " + f"ratio_range={self.ratio_range}, " + f"keep_ratio={self.keep_ratio})" + ) + + +@TRANSFORMS.register_module() +class SegResize(BaseTransform): + def __init__( + self, + height=1024, + width=768, + keep_ratio=False, + test_mode: bool = False, + ): + super().__init__() + self.height = height + self.width = width + self.keep_ratio = keep_ratio + self.test_mode = test_mode + + def transform(self, results: dict) -> dict: + img = results["img"] + h, w = img.shape[:2] + + target_height = self.height + target_width = self.width + + if self.keep_ratio is True: + scale_factor = min(target_width / w, target_height / h) + new_w = int(round(w * scale_factor)) + new_h = int(round(h * scale_factor)) + + else: + new_w = target_width + new_h = target_height + + dsize = (new_w, new_h) + + # Use INTER_AREA for shrinking and INTER_CUBIC for enlarging + # to get antialiased results. 
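+ # The label map below keeps INTER_NEAREST so class ids are never blended.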
+ img_interpolation = cv2.INTER_AREA if new_w < w else cv2.INTER_CUBIC + + # Update the results dictionary + results["img"] = cv2.resize(img, dsize, interpolation=img_interpolation) + + ## resize gt seg if training + if "gt_seg" in results and self.test_mode is False: + gt_seg = results["gt_seg"] + results["gt_seg"] = cv2.resize( + gt_seg, dsize, interpolation=cv2.INTER_NEAREST + ) + return results + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"height={self.height}, " + f"width={self.width}, " + f"keep_ratio={self.keep_ratio})" + ) + + +@TRANSFORMS.register_module() +class SegRandomCrop(BaseTransform): + def __init__( + self, + crop_height=1024, + crop_width=768, + prob=0.5, + cat_max_ratio=0.75, + ignore_index=255, + ): + super().__init__() + self.crop_height = crop_height + self.crop_width = crop_width + self.prob = prob + self.cat_max_ratio = cat_max_ratio + self.ignore_index = ignore_index + + def _generate_crop_bbox(self, img: np.ndarray) -> tuple: + """Randomly get a crop bounding box.""" + margin_h = max(img.shape[0] - self.crop_height, 0) + margin_w = max(img.shape[1] - self.crop_width, 0) + offset_h = np.random.randint(0, margin_h + 1) + offset_w = np.random.randint(0, margin_w + 1) + crop_y1, crop_y2 = offset_h, offset_h + self.crop_height + crop_x1, crop_x2 = offset_w, offset_w + self.crop_width + return crop_y1, crop_y2, crop_x1, crop_x2 + + def _crop(self, img: np.ndarray, crop_bbox: tuple) -> np.ndarray: + """Crop from ``img``""" + crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox + img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...] + return img + + def transform(self, results: dict) -> dict: + """Transform function to randomly crop images and segmentation maps.""" + if random.random() > self.prob: + return results + + img = results["img"] + gt_seg = results["gt_seg"] + h, w = img.shape[:2] + + # Pad the image if it's smaller than the crop size + pad_h = max(self.crop_height - h, 0) + pad_w = max(self.crop_width - w, 0) + if pad_h > 0 or pad_w > 0: + img = cv2.copyMakeBorder( + img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0 + ) + gt_seg = cv2.copyMakeBorder( + gt_seg, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=self.ignore_index + ) + + padded_img = img + padded_gt_seg = gt_seg + + crop_bbox = self._generate_crop_bbox(padded_img) + if self.cat_max_ratio < 1.0: + # Repeat 10 times to find a valid crop + for _ in range(10): + seg_temp = self._crop(padded_gt_seg, crop_bbox) + labels, cnt = np.unique(seg_temp, return_counts=True) + # Filter out the ignore_index + cnt = cnt[labels != self.ignore_index] + if len(cnt) > 1 and np.max(cnt) / np.sum(cnt) < self.cat_max_ratio: + break # Found a valid crop + crop_bbox = self._generate_crop_bbox(padded_img) + + # Crop the image and segmentation map + results["img"] = self._crop(padded_img, crop_bbox) + results["gt_seg"] = self._crop(padded_gt_seg, crop_bbox) + + return results + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"crop_height={self.crop_height}, " + f"crop_width={self.crop_width}, " + f"prob={self.prob}, " + f"cat_max_ratio={self.cat_max_ratio}, " + f"ignore_index={self.ignore_index})" + ) + + +@TRANSFORMS.register_module() +class SegRandomBackground(BaseTransform): + def __init__( + self, + prob: float = 0.25, + skip_key: str = "is_itw", + background_images_root: str = "", + ): + super().__init__() + + self.prob = prob + self.skip_key = skip_key + self.background_images_root = background_images_root + self.background_images = sorted( + [ + image_name + for image_name in 
os.listdir(background_images_root) + if image_name.endswith(".jpg") + ] + ) + + def transform(self, results: dict) -> dict: + if random.random() > self.prob: + return results + + if self.skip_key in results and results[self.skip_key]: + return results + + image = results["img"] ## bgr image + + if "gt_seg" in results: + gt_seg = results["gt_seg"] + mask = (gt_seg > 0).astype(np.uint8) + + elif "mask" in results: + mask = results["mask"] + mask = (mask > 0).astype(np.uint8) + + else: + warnings.warn( + f"foreground mask not found in results, skip random background!" + ) + return results + + background_image_path = os.path.join( + self.background_images_root, random.choice(self.background_images) + ) + background_image = cv2.imread(background_image_path) ## bgr image + + ##----------------------------- + background_height = background_image.shape[0] + background_width = background_image.shape[1] + + image_height = image.shape[0] + image_width = image.shape[1] + + new_background_height = image_height + new_background_width = int( + new_background_height * background_width / background_height + ) + background_image = cv2.resize( + background_image, (new_background_width, new_background_height) + ) + + # Crop the background image to the width of the original image + if new_background_width > image_width: + start_x = (new_background_width - image_width) // 2 + end_x = start_x + image_width + background_image = background_image[:, start_x:end_x] + + if ( + background_image.shape[0] != image_height + or background_image.shape[1] != image_width + ): + background_image = cv2.resize(background_image, (image_width, image_height)) + + # Use the segmentation mask as an alpha channel. + alpha_norm = mask.astype(np.float32) # Values 0 or 1. + alpha_mask = np.stack([alpha_norm] * 3, axis=-1) + composite = alpha_mask * image + (1 - alpha_mask) * background_image + composite = composite.astype(np.uint8) + + # Apply color transfer using the Reinhard algorithm. 
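+ # (matches the foreground's Lab colour statistics to the new background so
+ # the composite looks less like a cut-and-paste)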
+ composite = self.reinhard_alpha(composite, alpha_norm) + results["img"] = composite + + return results + + def reinhard_alpha(self, comp_img, alpha_mask): + """ + # Reinhard color transfer algorithm with alpha mask support + # paper: https://www.cs.tau.ac.il/~turkel/imagepapers/ColorTransfer.pdf + # alpha mask in range [0, 1] + """ + + # Convert to LAB color space + comp_lab = cv2.cvtColor(comp_img, cv2.COLOR_BGR2Lab) + + # Calculate weighted mean and std for background and foreground + bg_weights = 1 - alpha_mask + fg_weights = alpha_mask + + bg_mean, bg_std = self.weighted_mean_std(comp_lab, bg_weights) + fg_mean, fg_std = self.weighted_mean_std(comp_lab, fg_weights) + + # Avoid division by zero + fg_std = np.maximum(fg_std, 1e-6) + + ratio = (bg_std / fg_std).reshape(-1) + offset = (bg_mean - fg_mean * bg_std / fg_std).reshape(-1) + + # Apply color transfer + trans_lab = cv2.convertScaleAbs(comp_lab * ratio + offset) + trans_img = cv2.cvtColor(trans_lab, cv2.COLOR_Lab2BGR) + + # Blend the transferred image with the original image using the alpha mask + alpha_mask_3d = np.repeat(alpha_mask[:, :, np.newaxis], 3, axis=2) + trans_comp = ( + trans_img * alpha_mask_3d + comp_img * (1 - alpha_mask_3d) + ).astype(np.uint8) + + return trans_comp + + def weighted_mean_std(self, img, weights): + # Ensure weights have the same shape as img + weights_3d = np.repeat(weights[:, :, np.newaxis], img.shape[2], axis=2) + + # Calculate weighted mean + total_weights = np.sum(weights_3d, axis=(0, 1)) + mean = np.sum(img * weights_3d, axis=(0, 1)) / total_weights + + # Calculate weighted standard deviation + variance = np.sum(((img - mean) ** 2) * weights_3d, axis=(0, 1)) / total_weights + std = np.sqrt(variance) + + return mean, std diff --git a/sapiens/dense/src/evaluators/__init__.py b/sapiens/dense/src/evaluators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..02e9caf6ba53cc91042a74c79af80eda6010f8cd --- /dev/null +++ b/sapiens/dense/src/evaluators/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .albedo_evaluator import AlbedoEvaluator +from .normal_evaluator import NormalEvaluator +from .pointmap_evaluator import PointmapEvaluator +from .seg_evaluator import SegEvaluator + +__all__ = [ + "PointmapEvaluator", + "SegEvaluator", + "NormalEvaluator", + "AlbedoEvaluator", +] diff --git a/sapiens/dense/src/evaluators/albedo_evaluator.py b/sapiens/dense/src/evaluators/albedo_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..133037eed63ffac6920dfb6eac65b2e6d1e7055c --- /dev/null +++ b/sapiens/dense/src/evaluators/albedo_evaluator.py @@ -0,0 +1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
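+
+ # AlbedoEvaluator accumulates per-sample sums (L1, L2, gradient-L1, masked
+ # SSIM and valid-pixel counts) across processes and reduces them to MAE,
+ # RMSE, PSNR, SSIM and gradient-L1 in evaluate().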
+ +from typing import Dict + +import torch +import torch.nn.functional as F +from prettytable import PrettyTable +from sapiens.engine.evaluators import BaseEvaluator +from sapiens.registry import MODELS + + +@MODELS.register_module() +class AlbedoEvaluator(BaseEvaluator): + def __init__(self): + super().__init__() + self._psnr_data_range: float | None = None # set on first batch + + @staticmethod + def _gaussian_kernel(ks: int = 11, sigma: float = 1.5, device=None, dtype=None): + ax = torch.arange(ks, device=device, dtype=dtype) - (ks - 1) / 2.0 + xx, yy = torch.meshgrid(ax, ax, indexing="xy") + k = torch.exp(-(xx * xx + yy * yy) / (2 * sigma * sigma)) + k = k / k.sum() + return k + + @torch.no_grad() + def _masked_ssim_sum( + self, + pred: torch.Tensor, + gt: torch.Tensor, + mask: torch.Tensor, + data_range: float, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + pred, gt: (3, H, W), mask: (H, W) bool/0-1 + Returns (sum_ssim, count_ssim) across valid windows. + """ + eps = 1e-8 + C1 = (0.01 * data_range) ** 2 + C2 = (0.03 * data_range) ** 2 + + if pred.dtype == torch.bfloat16: + pred = pred.float() + gt = gt.float() + mask = mask.float() + + x = pred.unsqueeze(0) # (1,3,H,W) + y = gt.unsqueeze(0) # (1,3,H,W) + m = mask.unsqueeze(0).unsqueeze(0).to(dtype=x.dtype) # (1,1,H,W) + + B, C, H, W = x.shape + k = self._gaussian_kernel(ks=11, sigma=1.5, device=x.device, dtype=x.dtype) + pad = 11 // 2 + k_img = k.view(1, 1, 11, 11) + k_ch = k_img.repeat(C, 1, 1, 1) # grouped conv kernel + + # local normalization with mask + m_conv = F.conv2d(m, k_img, padding=pad) # (1,1,H,W) + m_conv = torch.clamp(m_conv, min=eps) + + def _conv(z): + return F.conv2d(z, k_ch, padding=pad, groups=C) + + x_m = x * m + y_m = y * m + + mu_x = _conv(x_m) / m_conv + mu_y = _conv(y_m) / m_conv + + x2_m = (x * x) * m + y2_m = (y * y) * m + xy_m = (x * y) * m + + sigma_x2 = _conv(x2_m) / m_conv - mu_x * mu_x + sigma_y2 = _conv(y2_m) / m_conv - mu_y * mu_y + sigma_xy = _conv(xy_m) / m_conv - mu_x * mu_y + + num = (2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2) + den = (mu_x * mu_x + mu_y * mu_y + C1) * (sigma_x2 + sigma_y2 + C2) + ssim_map_ch = num / (den + eps) # (1,C,H,W) + ssim_map = ssim_map_ch.mean( + dim=1, keepdim=True + ) # average over channels -> (1,1,H,W) + + # Only count windows with sufficient valid support + valid_win = (m_conv > 0.5).squeeze(0).squeeze(0) # (H,W) + sum_ssim = ssim_map.squeeze(0).squeeze(0)[valid_win].to(torch.float64).sum() + cnt_ssim = valid_win.to(torch.float64).sum() + return sum_ssim, cnt_ssim + + @torch.no_grad() + def process(self, predictions: torch.Tensor, data_samples: dict, accelerator=None): + """ + Args: + predictions: Tensor, predicted albedo (B, 3, H_low, W_low) + data_samples: dict with keys: + - "mask": (B, 1, H, W) >0 is valid + - "gt_albedo": (B, 3, H, W) + """ + assert accelerator is not None, "evaluation process expects an accelerator" + pred_albedos = predictions # (B,3,h,w) + gt_masks = data_samples["mask"] # (B,1,H,W) + gt_albedos = data_samples["gt_albedo"] # (B,3,H,W) + + # align spatial size + if pred_albedos.shape[2:] != gt_albedos.shape[2:]: + pred_albedos = F.interpolate( + input=pred_albedos, + size=gt_albedos.shape[2:], + mode="bilinear", + align_corners=False, + antialias=False, + ) + + # set PSNR range (once) + if self._psnr_data_range is None: + mx = gt_albedos.detach().max() + self._psnr_data_range = 255.0 if mx > 1.5 else 1.0 + + B = gt_albedos.shape[0] + per_sample_vecs = [] # each: [sum_l1, sum_l2, N_pix, sum_grad_l1, N_grad, sum_ssim, N_ssim] + + for i in 
range(B): + mask = gt_masks[i, 0] > 0 + n_valid = int(mask.sum().item()) + assert n_valid > 0, "no valid pixels found" + + gt = gt_albedos[i] # (3,H,W) + pr = pred_albedos[i] # (3,H,W) + + # --- Pixel MAE / RMSE accumulators (average over channels per pixel) --- + diff = pr - gt + l1_pix = diff.abs().mean(dim=0) # (H,W) + l2_pix = (diff * diff).mean(dim=0) # (H,W) + + sum_l1 = l1_pix[mask].to(torch.float64).sum().unsqueeze(0) # (1,) + sum_l2 = l2_pix[mask].to(torch.float64).sum().unsqueeze(0) # (1,) + N_pix = torch.tensor( + [float(n_valid)], dtype=torch.float64, device=pr.device + ) + + # --- Gradient L1 (simple forward differences; mask both sides) --- + # horizontal + mask_h = mask[:, 1:] & mask[:, :-1] + dx_pr = pr[:, :, 1:] - pr[:, :, :-1] + dx_gt = gt[:, :, 1:] - gt[:, :, :-1] + grad_l1_h = (dx_pr - dx_gt).abs().mean(dim=0) # (H,W-1) + sum_grad_h = grad_l1_h[mask_h].to(torch.float64).sum() + N_grad_h = mask_h.to(torch.float64).sum() + + # vertical + mask_v = mask[1:, :] & mask[:-1, :] + dy_pr = pr[:, 1:, :] - pr[:, :-1, :] + dy_gt = gt[:, 1:, :] - gt[:, :-1, :] + grad_l1_v = (dy_pr - dy_gt).abs().mean(dim=0) # (H-1,W) + sum_grad_v = grad_l1_v[mask_v].to(torch.float64).sum() + N_grad_v = mask_v.to(torch.float64).sum() + + sum_grad_l1 = (sum_grad_h + sum_grad_v).unsqueeze(0) # (1,) + N_grad = (N_grad_h + N_grad_v).unsqueeze(0) # (1,) + + # --- SSIM (masked, with Gaussian window) --- + sum_ssim, cnt_ssim = self._masked_ssim_sum( + pr, gt, mask, data_range=float(self._psnr_data_range) + ) + sum_ssim = sum_ssim.unsqueeze(0) # (1,) + N_ssim = cnt_ssim.unsqueeze(0) # (1,) + + vec = torch.cat( + [sum_l1, sum_l2, N_pix, sum_grad_l1, N_grad, sum_ssim, N_ssim], dim=0 + ) + per_sample_vecs.append(vec) + + pack = torch.stack(per_sample_vecs, dim=0) # (B_local, 7) + gpack = accelerator.gather_for_metrics(pack) # (B_global_step, 7) + step_totals = gpack.sum(dim=0) # (7,) + + if accelerator.is_main_process: + self.results.append(step_totals) + return + + def evaluate(self, logger=None, accelerator=None) -> Dict[str, float]: + """ + Returns: + Dict[str, float]: { + 'albedo_mae', 'albedo_rmse', 'albedo_psnr', + 'albedo_ssim', 'albedo_grad_l1' + } + """ + assert accelerator is not None, "evaluation aggregation expects an accelerator" + + if not accelerator.is_main_process: + self.reset() + return {} + + if not self.results: + if logger is not None: + logger.info("No results to evaluate.") + return {} + + totals = torch.stack(self.results, dim=0).sum(dim=0) # (7,) + idx = 0 + sum_l1 = totals[idx] + idx += 1 + sum_l2 = totals[idx] + idx += 1 + N_pix = totals[idx] + idx += 1 + sum_grad_l1 = totals[idx] + idx += 1 + N_grad = totals[idx] + idx += 1 + sum_ssim = totals[idx] + idx += 1 + N_ssim = totals[idx] + idx += 1 + + # Core metrics + mae = (sum_l1 / N_pix).item() + mse = (sum_l2 / N_pix).clamp_min(1e-12) + rmse = torch.sqrt(mse).item() + + # PSNR + L2 = float(self._psnr_data_range or 1.0) ** 2 + psnr = (10.0 * torch.log10(torch.tensor(L2, dtype=torch.float64) / mse)).item() + + # SSIM (mean over valid windows) + ssim = (sum_ssim / torch.clamp_min(N_ssim, 1.0)).item() + + # Gradient L1 + grad_l1 = (sum_grad_l1 / torch.clamp_min(N_grad, 1.0)).item() + + metrics: Dict[str, float] = { + "mae": float(mae), + "rmse": float(rmse), + "psnr": float(psnr), + "ssim": float(ssim), + "grad_l1": float(grad_l1), + } + + table = PrettyTable() + table.field_names = list(metrics.keys()) + table.add_row([f"{float(v):.5f}" for v in metrics.values()]) + if logger is not None: + logger.info("\n" + table.get_string()) + + 
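+        # Note: PSNR is derived from the globally pooled MSE (all valid pixels in
+        # the dataset), not averaged per image, and SSIM is the mean over valid
+        # Gaussian windows rather than over images.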
self.reset() + return metrics diff --git a/sapiens/dense/src/evaluators/normal_evaluator.py b/sapiens/dense/src/evaluators/normal_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..2ef3297a7f1debcc6f36b93004d8fb74b13804de --- /dev/null +++ b/sapiens/dense/src/evaluators/normal_evaluator.py @@ -0,0 +1,185 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from collections import defaultdict +from typing import Dict, List, Sequence + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from prettytable import PrettyTable +from sapiens.engine.evaluators import BaseEvaluator +from sapiens.registry import MODELS + + +@MODELS.register_module() +class NormalEvaluator(BaseEvaluator): + def __init__( + self, + angle_thresholds: list[float] = [5.0, 11.25, 22.5, 30.0], + hist_bin_size_deg: float = 0.5, + hist_max_deg: float = 180.0, + ): + super().__init__() + self.angle_thresholds = angle_thresholds + self.hist_bin_size_deg = float(hist_bin_size_deg) + self.hist_max_deg = float(hist_max_deg) + + # number of histogram bins, edges computed on demand + self._num_bins = int( + torch.ceil(torch.tensor(self.hist_max_deg / self.hist_bin_size_deg)).item() + ) + + @torch.no_grad() + def process(self, predictions: torch.Tensor, data_samples: dict, accelerator=None): + """ + Process a single batch of predictions and ground truth data. + + Args: + predictions (tuple): A tuple containing the predicted pointmap and scale. + data_samples (List[Dict]): A list of dictionaries, each containing ground truth data. + """ + assert accelerator is not None, "evaluation process expects an accelerator" + pred_normals = predictions ## pred normals, B x 3 x H_low x W_low + gt_masks = data_samples["mask"] # B x 1 x H x W + gt_normals = data_samples["gt_normal"] # B x 3 x H x W + + if pred_normals.shape[2:] != gt_normals.shape[2:]: + pred_normals = F.interpolate( + input=pred_normals, + size=gt_normals.shape[2:], + mode="bilinear", + align_corners=False, + antialias=False, + ) + + ## normalize + eps = 1e-6 + pred_normals = pred_normals / pred_normals.norm(dim=1, keepdim=True).clamp_min( + eps + ) + gt_normals = gt_normals / gt_normals.norm(dim=1, keepdim=True).clamp_min(eps) + + B = gt_normals.shape[0] + HN = self._num_bins + + # packed vector layout: + # [ sum_angle, sum_angle2, N, counts( 0 + n_valid = int(valid.sum().item()) + + assert n_valid > 0, "no valid pixels found" + + gt = gt_normals[i].permute(1, 2, 0)[valid] # (N,3) + pr = pred_normals[i].permute(1, 2, 0)[valid] # (N,3) + + dot = (gt * pr).sum(dim=1) # (N,) + dot = dot.clamp(-1.0, 1.0) + angle = torch.acos(dot) * (180.0 / torch.pi) # (N,) + + ## sums + sum_angle = angle.sum().to(torch.float64).unsqueeze(0) # shape (1,) + sum_angle2 = ( + (angle * angle).sum().to(torch.float64).unsqueeze(0) + ) # shape (1,) + N_tensor = torch.tensor( + [float(n_valid)], dtype=torch.float64, device=pred_normals.device + ) # (1,) + + ## thresholds + th_counts = torch.stack( + [(angle < t).sum().to(torch.float64) for t in self.angle_thresholds], + dim=0, + ) + + ## histogram + idx = torch.floor(angle / self.hist_bin_size_deg).long().clamp_(0, HN - 1) + hist = torch.bincount(idx, minlength=HN).to(torch.float64) + + vec = torch.cat( + [sum_angle, sum_angle2, N_tensor, th_counts, hist], dim=0 + ) # (K,) + per_sample_vecs.append(vec) + + # (B_local, K) + pack = 
torch.stack(per_sample_vecs, dim=0) + gpack = accelerator.gather_for_metrics(pack) # (B_global_this_step, K) + step_totals = gpack.sum(dim=0) # (K,) + + if accelerator.is_main_process: + self.results.append(step_totals) # store one vector per step on rank-0 + return + + def evaluate(self, logger=None, accelerator=None) -> Dict[str, float]: + """ + Compute and log the final metrics after processing all batches. + + Returns: + Dict[str, float]: A dictionary of the final computed metrics. + """ + assert accelerator is not None, "evaluation aggregation expects an accelerator" + + if not accelerator.is_main_process: + self.reset() + return {} + + if not self.results: + if logger is not None: + logger.info("No results to evaluate.") + return {} + + totals_vec = torch.stack(self.results, dim=0).sum(dim=0) # (K,) + A = len(self.angle_thresholds) + HN = self._num_bins + + idx = 0 + sum_angle = totals_vec[idx] + idx += 1 + sum_angle2 = totals_vec[idx] + idx += 1 + n_total = totals_vec[idx] + idx += 1 + ang_counts = totals_vec[idx : idx + A] + idx += A + hist_counts = totals_vec[idx : idx + HN] + idx += HN + + # Core metrics + mae = (sum_angle / n_total).item() + rmse = torch.sqrt(sum_angle2 / n_total).item() + within = (ang_counts / n_total * 100.0).tolist() + + # Global median from histogram (bin center) + cdf = torch.cumsum(hist_counts, dim=0) + mid = 0.5 * n_total + bin_idx = torch.searchsorted(cdf, mid).clamp(max=HN - 1).item() + bin_lo = bin_idx * self.hist_bin_size_deg + bin_hi = (bin_idx + 1) * self.hist_bin_size_deg + median = 0.5 * (bin_lo + bin_hi) + + # Assemble metrics dict + metrics: Dict[str, float] = { + "normal_mae": mae, + "normal_median_deg": float(median), + "normal_rmse": rmse, + } + for j, t in enumerate(self.angle_thresholds): + suf = str(t).replace(".", "_") + metrics[f"within_{suf}_deg"] = float(within[j]) + + # Pretty print + table = PrettyTable() + table.field_names = list(metrics.keys()) + table.add_row([f"{float(v):.5f}" for v in metrics.values()]) + if logger is not None: + logger.info("\n" + table.get_string()) + + self.reset() + return metrics diff --git a/sapiens/dense/src/evaluators/pointmap_evaluator.py b/sapiens/dense/src/evaluators/pointmap_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..4f36ce8bbfb8352187d1a3327bb162d04afcbc6a --- /dev/null +++ b/sapiens/dense/src/evaluators/pointmap_evaluator.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from collections import defaultdict +from typing import Dict, List, Sequence + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from prettytable import PrettyTable +from sapiens.engine.evaluators import BaseEvaluator +from sapiens.registry import MODELS + + +@MODELS.register_module() +class PointmapEvaluator(BaseEvaluator): + def __init__( + self, + distance_thresholds: List[float] = [0.05, 0.10, 0.20], + angle_thresholds: List[float] = [5.0, 11.25, 22.5, 30.0], + ): + super().__init__() + self.distance_thresholds = distance_thresholds + self.angle_thresholds = angle_thresholds + + def _compute_surface_normals( + self, point_map: torch.Tensor, valid_mask: torch.Tensor + ) -> torch.Tensor: + """ + Compute surface normals from a point map. + + Args: + point_map (torch.Tensor): Point map of shape (H, W, 3). + valid_mask (torch.Tensor): Boolean mask of valid points of shape (H, W). 
+ + Returns: + torch.Tensor: Surface normals for valid points of shape (N, 3). + """ + points_np = point_map.cpu().numpy().astype(np.float32) + + grad_x = cv2.Sobel(points_np, cv2.CV_32F, 1, 0, ksize=5) + grad_y = cv2.Sobel(points_np, cv2.CV_32F, 0, 1, ksize=5) + + normals = np.cross(grad_x, grad_y) + norms = np.linalg.norm(normals, axis=2, keepdims=True) + normals = normals / (norms + 1e-6) + + normals_tensor = torch.from_numpy(normals).to(point_map.device) + valid_normals = normals_tensor[valid_mask] + + return valid_normals + + @torch.no_grad() + def process(self, predictions: tuple, data_samples: dict, accelerator=None): + """ + Process a single batch of predictions and ground truth data. + + Args: + predictions (tuple): A tuple containing the predicted pointmap and scale. + data_samples (List[Dict]): A list of dictionaries, each containing ground truth data. + """ + assert accelerator is not None, "evaluation process expects an accelerator" + pred_pointmaps, _ = predictions ## gt pointmaps are canonicalized + gt_masks = data_samples["mask"] # B x 1 x H x W + gt_pointmaps = data_samples["gt_pointmap"] # B x 3 x H x W + + if pred_pointmaps.shape[2:] != gt_pointmaps.shape[2:]: + pred_pointmaps = F.interpolate( + input=pred_pointmaps, + size=gt_pointmaps.shape[2:], + mode="bilinear", + align_corners=False, + antialias=False, + ) + + B = gt_pointmaps.shape[0] + D = len(self.distance_thresholds) + A = len(self.angle_thresholds) + per_sample_vecs = [] # (B_local, K) + + for i in range(B): + pred_pm = pred_pointmaps[i].permute(1, 2, 0) # (H, W, 3) + gt_pm = gt_pointmaps[i].permute(1, 2, 0) # (H, W, 3) + valid = gt_masks[i][0] > 0 # (H, W) bool + + # Keep shapes consistent even if there are no valid points + if valid.any(): + gt_pts = gt_pm[valid] # (N, 3) + pred_pts = pred_pm[valid] # (N, 3) + diff = pred_pts - gt_pts # (N, 3) + + # distances & axis errors + distances = torch.norm(diff, dim=-1) # (N,) + axis_err = torch.abs(diff) # (N, 3) + + # normals (computed from *full* maps, then masked) + gt_normals = self._compute_surface_normals(gt_pm, valid) # (N, 3) + pred_normals = self._compute_surface_normals(pred_pm, valid) # (N, 3) + dot = (gt_normals * pred_normals).sum(dim=1).clamp(-1.0, 1.0) + angles = torch.acos(dot) * (180.0 / np.pi) # (N,) + + num_points = float(distances.shape[0]) + + # base sums + l2_sum = distances.sum() + x_abs_sum = axis_err[:, 0].sum() + y_abs_sum = axis_err[:, 1].sum() + z_abs_sum = axis_err[:, 2].sum() + squared_dist_sum = (distances**2).sum() + angle_sum = angles.sum() + squared_angle_sum = (angles**2).sum() + + # threshold counts + dist_counts = [(distances < t).sum() for t in self.distance_thresholds] + ang_counts = [(angles < t).sum() for t in self.angle_thresholds] + + else: + # all zeros when no valid points + l2_sum = x_abs_sum = y_abs_sum = z_abs_sum = 0.0 + squared_dist_sum = angle_sum = squared_angle_sum = 0.0 + num_points = 0.0 + dist_counts = [0.0] * D + ang_counts = [0.0] * A + + # assemble fixed-length vector on the same device + vec_list = [ + l2_sum, + x_abs_sum, + y_abs_sum, + z_abs_sum, + squared_dist_sum, + angle_sum, + squared_angle_sum, + num_points, + *dist_counts, + *ang_counts, + ] + vec = torch.tensor( + [float(v) for v in vec_list], + device=pred_pointmaps.device, + dtype=torch.float64, + ) # stable accumulation + per_sample_vecs.append(vec) + + # (B_local, K) + pack = torch.stack(per_sample_vecs, dim=0) + + # Global per-step totals via Accelerate (dedups final step automatically) + gpack = accelerator.gather_for_metrics(pack) # 
(B_global_this_step, K) + step_totals = gpack.sum(dim=0) # (K,) + + if accelerator.is_main_process: + self.results.append(step_totals) # store one vector per step on rank-0 + return + + def evaluate(self, logger=None, accelerator=None) -> Dict[str, float]: + """ + Compute and log the final metrics after processing all batches. + + Returns: + Dict[str, float]: A dictionary of the final computed metrics. + """ + assert accelerator is not None, "evaluation aggregation expects an accelerator" + + if not accelerator.is_main_process: + self.reset() + return {} + + if not self.results: + if logger is not None: + logger.info("No results to evaluate.") + return {} + + totals_vec = torch.stack(self.results, dim=0).sum(dim=0) # (K,) + D = len(self.distance_thresholds) + A = len(self.angle_thresholds) + + # unpack + idx = 0 + l2_sum = totals_vec[idx] + idx += 1 + x_abs_sum = totals_vec[idx] + idx += 1 + y_abs_sum = totals_vec[idx] + idx += 1 + z_abs_sum = totals_vec[idx] + idx += 1 + squared_dist_sum = totals_vec[idx] + idx += 1 + angle_sum = totals_vec[idx] + idx += 1 + squared_angle_sum = totals_vec[idx] + idx += 1 + num_points = totals_vec[idx] + idx += 1 + + dist_counts = totals_vec[idx : idx + D] + idx += D + ang_counts = totals_vec[idx : idx + A] + idx += A + + total_points = float(num_points.item()) + if total_points <= 0: + if logger is not None: + logger.info("No valid points found to evaluate.") + self.reset() + return {} + + # metrics + metrics: Dict[str, float] = {} + metrics["l2_mean"] = (l2_sum / total_points).item() + metrics["x_mae"] = (x_abs_sum / total_points).item() + metrics["y_mae"] = (y_abs_sum / total_points).item() + metrics["z_mae"] = (z_abs_sum / total_points).item() + metrics["rmse"] = torch.sqrt(squared_dist_sum / total_points).item() + metrics["normal_mae"] = (angle_sum / total_points).item() + metrics["normal_rmse"] = torch.sqrt(squared_angle_sum / total_points).item() + + for i, t in enumerate(self.distance_thresholds): + out_key = f"within_{int(t * 100):02d}_cm" + metrics[out_key] = (dist_counts[i] / total_points).item() + + for j, t in enumerate(self.angle_thresholds): + suf = str(t).replace(".", "_") + out_key = f"within_{suf}_deg" + metrics[out_key] = (ang_counts[j] / total_points).item() + + # pretty print + table = PrettyTable() + table.field_names = list(metrics.keys()) + table.add_row([f"{float(val):.5f}" for val in metrics.values()]) + if logger is not None: + logger.info("\n" + table.get_string()) + + self.reset() + return metrics diff --git a/sapiens/dense/src/evaluators/seg_evaluator.py b/sapiens/dense/src/evaluators/seg_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..80fc9b4a8b78cc660217da1a30c9de13a469d536 --- /dev/null +++ b/sapiens/dense/src/evaluators/seg_evaluator.py @@ -0,0 +1,245 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
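+# Per-class areas in this evaluator come from torch.histc over masked label maps;
+# a compact single-image reference for that bookkeeping (the name and signature
+# below are illustrative only and are not used by SegEvaluator):
+def _areas_sketch(pred_label, label, num_classes: int, ignore_index: int = 255):
+    import torch
+
+    keep = label != ignore_index
+    pred_label, label = pred_label[keep], label[keep]
+    intersect = pred_label[pred_label == label]
+
+    def hist(x):
+        return torch.histc(x.float(), bins=num_classes, min=0, max=num_classes - 1)
+
+    area_i, area_p, area_l = hist(intersect), hist(pred_label), hist(label)
+    return area_i, area_p + area_l - area_i, area_p, area_l  # intersect, union, pred, label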
+ +import os +from collections import defaultdict, OrderedDict +from typing import Dict, List, Optional, Sequence + +import numpy as np +import torch +import torch.nn.functional as F +from prettytable import PrettyTable +from sapiens.engine.evaluators import BaseEvaluator +from sapiens.registry import MODELS + +from ..datasets.seg.seg_dome_dataset import DOME_CLASSES_29 + + +@MODELS.register_module() +class SegEvaluator(BaseEvaluator): + def __init__( + self, + class_names="dome29", + ignore_index: int = 255, + iou_metrics: List[str] = ["mIoU"], + nan_to_num: Optional[int] = None, + beta: int = 1, + ): + super().__init__() + self.ignore_index = ignore_index + self.class_names = ( + self.extract_class(DOME_CLASSES_29) if class_names == "dome29" else None + ) + self.metrics = iou_metrics + self.nan_to_num = nan_to_num + self.beta = beta + + def extract_class(self, class_names): + return [class_info["name"] for _, class_info in class_names.items()] + + @torch.no_grad() + def process(self, pred_logits, data_samples: dict, accelerator=None): + assert accelerator is not None, "evaluation process expects an accelerator" + num_classes = pred_logits.shape[1] + + ai_list, au_list, apl_list, al_list = [], [], [], [] + for i in range(len(pred_logits)): + pred_logit = pred_logits[i] # C x H x W + gt_label = data_samples[i]["gt_seg"].squeeze() # H x W + + if pred_logit.shape[2:] != gt_label.shape: + pred_logit = F.interpolate( + input=pred_logit.unsqueeze(0), + size=gt_label.shape, + mode="bilinear", + align_corners=False, + antialias=False, + ).squeeze(0) + + pred_label = pred_logit.argmax(dim=0) # H x W + + a_i, a_u, a_pl, a_l = self.intersect_and_union( + pred_label, gt_label, num_classes, self.ignore_index + ) + + ai_list.append(a_i) + au_list.append(a_u) + apl_list.append(a_pl) + al_list.append(a_l) + + # Local per-batch tensors: (B_local, C) + ai = torch.stack(ai_list, dim=0) + au = torch.stack(au_list, dim=0) + apl = torch.stack(apl_list, dim=0) + al = torch.stack(al_list, dim=0) + + # Pack as (B_local, 4, C) so gather concatenates along the batch dim. 
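+        # gather_for_metrics concatenates along dim 0 across processes and drops the
+        # duplicated samples Accelerate pads onto the final batch, so summing the
+        # gathered (B_global, 4, C) pack yields exact dataset-level areas.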
+ pack = torch.stack([ai, au, apl, al], dim=1) # (B_local, 4, C) + gpack = accelerator.gather_for_metrics(pack) # (B_global_this_step, 4, C) + batch_tot = gpack.sum(dim=0) # (4, C) global for this step + + ai_g, au_g, apl_g, al_g = batch_tot[0], batch_tot[1], batch_tot[2], batch_tot[3] + + # Only rank-0 appends real totals for this batch + if accelerator.is_main_process: + self.results.append((ai_g, au_g, apl_g, al_g)) + + return + + def evaluate(self, logger=None, accelerator=None) -> Dict[str, float]: + assert accelerator is not None, "evaluation aggregation expects an accelerator" + + if not accelerator.is_main_process: + self.reset() + return {} + + if not self.results: + if logger is not None: + logger.info("No results to evaluate.") + return {} + + per_field = list(zip(*self.results)) # [(ai_b), (au_b), (apl_b), (al_b)] + totals = [torch.stack(x, dim=0).sum(dim=0) for x in per_field] + + ( + total_area_intersect, + total_area_union, + total_area_pred_label, + total_area_label, + ) = totals # tensors already reduced across ranks + + ret_metrics = self.total_area_to_metrics( + total_area_intersect, + total_area_union, + total_area_pred_label, + total_area_label, + self.metrics, + self.nan_to_num, + self.beta, + ) + + ret_metrics_summary = OrderedDict( + { + ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2) + for ret_metric, ret_metric_value in ret_metrics.items() + } + ) + metrics = dict() + for key, val in ret_metrics_summary.items(): + if key == "aAcc": + metrics[key] = val + else: + metrics["m" + key] = val + + # each class table + ret_metrics.pop("aAcc", None) + ret_metrics_class = OrderedDict( + { + ret_metric: np.round(ret_metric_value * 100, 2) + for ret_metric, ret_metric_value in ret_metrics.items() + } + ) + if self.class_names is not None: + ret_metrics_class.update({"Class": self.class_names}) + ret_metrics_class.move_to_end("Class", last=False) + class_table_data = PrettyTable() + for key, val in ret_metrics_class.items(): + class_table_data.add_column(key, val) + + logger.info("\n" + class_table_data.get_string()) + + self.reset() + return metrics + + def intersect_and_union( + self, + pred_label: torch.tensor, + label: torch.tensor, + num_classes: int, + ignore_index: int, + ): + mask = label != ignore_index + pred_label = pred_label[mask] + label = label[mask] + + intersect = pred_label[pred_label == label] + area_intersect = torch.histc( + intersect.float(), bins=(num_classes), min=0, max=num_classes - 1 + ) + area_pred_label = torch.histc( + pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1 + ) + area_label = torch.histc( + label.float(), bins=(num_classes), min=0, max=num_classes - 1 + ) + area_union = area_pred_label + area_label - area_intersect + return area_intersect, area_union, area_pred_label, area_label + + def total_area_to_metrics( + self, + total_area_intersect: np.ndarray, + total_area_union: np.ndarray, + total_area_pred_label: np.ndarray, + total_area_label: np.ndarray, + metrics: List[str] = ["mIoU"], + nan_to_num: Optional[int] = None, + beta: int = 1, + ): + def f_score(precision, recall, beta=1): + score = ( + (1 + beta**2) * (precision * recall) / ((beta**2 * precision) + recall) + ) + return score + + if isinstance(metrics, str): + metrics = [metrics] + allowed_metrics = ["mIoU", "mDice", "mFscore"] + if not set(metrics).issubset(set(allowed_metrics)): + raise KeyError(f"metrics {metrics} is not supported") + + all_acc = total_area_intersect.sum() / total_area_label.sum() + ret_metrics = OrderedDict({"aAcc": all_acc}) + 
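+        # With intersect = TP, pred_area = TP + FP, label_area = TP + FN:
+        #   IoU  = TP / (TP + FP + FN)     = intersect / union
+        #   Dice = 2*TP / (2*TP + FP + FN) = 2 * intersect / (pred_area + label_area)
+        #   Acc  = TP / (TP + FN)          = intersect / label_area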
for metric in metrics: + if metric == "mIoU": + iou = total_area_intersect / total_area_union + acc = total_area_intersect / total_area_label + ret_metrics["IoU"] = iou + ret_metrics["Acc"] = acc + elif metric == "mDice": + dice = ( + 2 + * total_area_intersect + / (total_area_pred_label + total_area_label) + ) + acc = total_area_intersect / total_area_label + ret_metrics["Dice"] = dice + ret_metrics["Acc"] = acc + elif metric == "mFscore": + precision = total_area_intersect / total_area_pred_label + recall = total_area_intersect / total_area_label + f_value = torch.tensor( + [f_score(x[0], x[1], beta) for x in zip(precision, recall)] + ) + ret_metrics["Fscore"] = f_value + ret_metrics["Precision"] = precision + ret_metrics["Recall"] = recall + + ret_metrics = { + metric: ( + value.detach().cpu().numpy() + if isinstance(value, torch.Tensor) + else value + ) + for metric, value in ret_metrics.items() + } + if nan_to_num is not None: + ret_metrics = OrderedDict( + { + metric: np.nan_to_num(metric_value, nan=nan_to_num) + for metric, metric_value in ret_metrics.items() + } + ) + return ret_metrics diff --git a/sapiens/dense/src/models/__init__.py b/sapiens/dense/src/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f7cc0bef718fe87be4c8419b261ca639ccd950fb --- /dev/null +++ b/sapiens/dense/src/models/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .core import * +from .heads import * +from .losses import * +from .init_model import init_model diff --git a/sapiens/dense/src/models/core/__init__.py b/sapiens/dense/src/models/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eb71a6fb864cd90c4dce2996689bc24ec037ad2c --- /dev/null +++ b/sapiens/dense/src/models/core/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .albedo_estimator import AlbedoEstimator +from .normal_estimator import NormalEstimator +from .pointmap_estimator import PointmapEstimator +from .seg_estimator import SegEstimator + +__all__ = [ + "PointmapEstimator", + "SegEstimator", + "NormalEstimator", + "AlbedoEstimator", +] diff --git a/sapiens/dense/src/models/core/albedo_estimator.py b/sapiens/dense/src/models/core/albedo_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..08e988a73d287ea050e48bd5e7cb6f5f3675a903 --- /dev/null +++ b/sapiens/dense/src/models/core/albedo_estimator.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
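+# All estimators in this package pair a backbone (whose forward returns per-depth
+# feature lists, index 0 = final layer) with a decode head built from the config.
+# Hedged inference sketch (paths and variable names are placeholders only):
+#
+#   model = init_model("path/to/albedo_config.py", "path/to/ckpt.safetensors")
+#   with torch.no_grad():
+#       albedo = model(images)  # images: (B, 3, H, W) float tensor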
+ +from typing import List, Optional, Tuple, Union + +import torch +from sapiens.engine.models import BaseModel +from sapiens.registry import MODELS +from torch import Tensor + + +@MODELS.register_module() +class AlbedoEstimator(BaseModel): + def __init__( + self, + backbone: dict = None, + decode_head: dict = None, + init_cfg: dict = None, + ): + BaseModel.__init__(self, init_cfg=init_cfg) + self.backbone = MODELS.build(backbone) + self.decode_head = MODELS.build(decode_head) + + def loss( + self, outputs: Union[Tensor, Tuple[Tensor, ...]], data_samples: dict + ) -> Tuple[dict, Tensor]: + losses, normals = self.decode_head.loss(outputs, data_samples) + parsed_losses, log_vars = self.parse_losses(losses) + log_vars["outputs"] = normals + return parsed_losses, log_vars + + def forward(self, inputs: Tensor) -> Tensor: + ## backbone forward returns a list of tensors at different depths, 0 is the final layer + if self.training: + x = torch.utils.checkpoint.checkpoint( + self.backbone, inputs, use_reentrant=False + )[0] + else: + x = self.backbone(inputs)[0] + + x = self.decode_head(x) + return x diff --git a/sapiens/dense/src/models/core/normal_estimator.py b/sapiens/dense/src/models/core/normal_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..0cd75caf91691ea7630eb449179c37ed3fbec6b9 --- /dev/null +++ b/sapiens/dense/src/models/core/normal_estimator.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional, Tuple, Union + +import torch +from sapiens.engine.models import BaseModel +from sapiens.registry import MODELS +from torch import Tensor + + +@MODELS.register_module() +class NormalEstimator(BaseModel): + def __init__( + self, + backbone: dict = None, + decode_head: dict = None, + init_cfg: dict = None, + ): + BaseModel.__init__(self, init_cfg=init_cfg) + self.backbone = MODELS.build(backbone) + self.decode_head = MODELS.build(decode_head) + + def loss( + self, outputs: Union[Tensor, Tuple[Tensor, ...]], data_samples: dict + ) -> Tuple[dict, Tensor]: + losses, normals = self.decode_head.loss(outputs, data_samples) + parsed_losses, log_vars = self.parse_losses(losses) + log_vars["outputs"] = normals + return parsed_losses, log_vars + + def forward(self, inputs: Tensor) -> Tensor: + ## backbone forward returns a list of tensors at different depths, 0 is the final layer + if self.training: + x = torch.utils.checkpoint.checkpoint( + self.backbone, inputs, use_reentrant=False + )[0] + else: + x = self.backbone(inputs)[0] + + x = self.decode_head(x) + return x diff --git a/sapiens/dense/src/models/core/pointmap_estimator.py b/sapiens/dense/src/models/core/pointmap_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..410f77477426805c81099fdbf04037ca57f17caa --- /dev/null +++ b/sapiens/dense/src/models/core/pointmap_estimator.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
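+# During training the backbone call is wrapped in torch.utils.checkpoint so its
+# activations are recomputed in the backward pass instead of stored; evaluation
+# uses the plain call. Equivalent minimal pattern (illustrative only):
+#
+#   feats = (torch.utils.checkpoint.checkpoint(backbone, x, use_reentrant=False)
+#            if training else backbone(x))[0]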
+ +from typing import List, Optional, Tuple, Union + +import torch +from sapiens.engine.models import BaseModel +from sapiens.registry import MODELS +from torch import Tensor + + +@MODELS.register_module() +class PointmapEstimator(BaseModel): + def __init__( + self, + backbone: dict = None, + decode_head: dict = None, + canonical_focal_length: float = 768.0, + init_cfg: dict = None, + ): + BaseModel.__init__(self, init_cfg=init_cfg) + self.backbone = MODELS.build(backbone) + self.decode_head = MODELS.build(decode_head) + self.canonical_focal_length = canonical_focal_length + + def loss( + self, outputs: Union[Tensor, Tuple[Tensor, ...]], data_samples: dict + ) -> Tuple[dict, Tensor]: + losses, pointmaps = self.decode_head.loss(outputs, data_samples) + parsed_losses, log_vars = self.parse_losses(losses) + log_vars["outputs"] = pointmaps + return parsed_losses, log_vars + + def forward(self, inputs: Tensor) -> Tensor: + ## backbone forward returns a list of tensors at different depths, 0 is the final layer + if self.training: + x = torch.utils.checkpoint.checkpoint( + self.backbone, inputs, use_reentrant=False + )[0] + else: + x = self.backbone(inputs)[0] + + x = self.decode_head(x) + return x diff --git a/sapiens/dense/src/models/core/seg_estimator.py b/sapiens/dense/src/models/core/seg_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..321fe0a99ed3727e28be1dd5a8ea1b50ee4b14c3 --- /dev/null +++ b/sapiens/dense/src/models/core/seg_estimator.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional, Tuple, Union + +import torch +from sapiens.engine.models import BaseModel +from sapiens.registry import MODELS +from torch import Tensor + + +@MODELS.register_module() +class SegEstimator(BaseModel): + def __init__( + self, + backbone: dict = None, + decode_head: dict = None, + init_cfg: dict = None, + ): + BaseModel.__init__(self, init_cfg=init_cfg) + self.backbone = MODELS.build(backbone) + self.decode_head = MODELS.build(decode_head) + + def loss( + self, outputs: Union[Tensor, Tuple[Tensor, ...]], data_samples: dict + ) -> Tuple[dict, Tensor]: + losses, seg_logits = self.decode_head.loss(outputs, data_samples) + parsed_losses, log_vars = self.parse_losses(losses) + log_vars["outputs"] = seg_logits + return parsed_losses, log_vars + + def forward(self, inputs: Tensor) -> Tensor: + ## backbone forward returns a list of tensors at different depths, 0 is the final layer + if self.training: + x = torch.utils.checkpoint.checkpoint( + self.backbone, inputs, use_reentrant=False + )[0] + else: + x = self.backbone(inputs)[0] + + x = self.decode_head(x) ## B x num_classes x H x W + return x diff --git a/sapiens/dense/src/models/heads/__init__.py b/sapiens/dense/src/models/heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..68af1864e547bc33267e4cceea1b9e944759b646 --- /dev/null +++ b/sapiens/dense/src/models/heads/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
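+# PointmapHead, NormalHead and AlbedoHead share a PixelShuffle-based upsampling
+# decoder (AlbedoHead swaps InstanceNorm2d for a channel-wise RMSNorm wrapper),
+# while SegHead upsamples with transposed convolutions.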
+ +from .albedo_head import AlbedoHead +from .normal_head import NormalHead +from .pointmap_head import PointmapHead +from .seg_head import SegHead + +__all__ = ["PointmapHead", "SegHead", "NormalHead", "AlbedoHead"] diff --git a/sapiens/dense/src/models/heads/albedo_head.py b/sapiens/dense/src/models/heads/albedo_head.py new file mode 100644 index 0000000000000000000000000000000000000000..5b5ff63ccbc943e2cf11ef253a6e74595b6e65b2 --- /dev/null +++ b/sapiens/dense/src/models/heads/albedo_head.py @@ -0,0 +1,174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from sapiens.registry import MODELS +from torch import Tensor + + +class RMSNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6, affine: bool = True): + super().__init__() + self.norm = nn.RMSNorm(num_channels, eps=eps, elementwise_affine=affine) + + def forward(self, x): + # x: (N, C, H, W) โ†’ normalize across C per spatial location + x = x.movedim(1, -1).contiguous() # (N, H, W, C) + x = self.norm(x) # RMSNorm over last dim = C + return x.movedim(-1, 1).contiguous() # (N, C, H, W) + + +@MODELS.register_module() +class AlbedoHead(nn.Module): + def __init__( + self, + in_channels: int = 768, + channels: int = 16, + upsample_channels: List[int] = [768, 384, 192, 96], + conv_out_channels: Optional[Sequence[int]] = None, + conv_kernel_sizes: Optional[Sequence[int]] = None, + loss_decode=dict(type="L1Loss", loss_weight=1.0), + **kwargs, + ): + super().__init__(**kwargs) + self.in_channels = in_channels + self.channels = channels + self._build_network(upsample_channels, conv_out_channels, conv_kernel_sizes) + in_channels = ( + conv_out_channels[-1] if conv_out_channels else upsample_channels[-1] + ) + self.conv_albedo = nn.Conv2d(in_channels, 3, kernel_size=1) + + if isinstance(loss_decode, dict): + self.loss_decode = MODELS.build(loss_decode) + elif isinstance(loss_decode, (list, tuple)): + self.loss_decode = nn.ModuleList() + for loss in loss_decode: + self.loss_decode.append(MODELS.build(loss)) + else: + raise TypeError( + f"loss_decode must be a dict or sequence of dict,\ + but got {type(loss_decode)}" + ) + + self._init_weights() + + def _build_network( + self, + upsample_channels: List[int], + conv_out_channels: Optional[Sequence[int]], + conv_kernel_sizes: Optional[Sequence[int]], + ) -> None: + in_channels = self.in_channels + + self.input_conv = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), + RMSNorm2d( + in_channels + ), # nn.InstanceNorm2d(in_channels), # Normalize first + nn.SiLU(inplace=True), + ) + + # Progressive upsampling blocks + up_blocks = [] + cur_ch = in_channels + for out_ch in upsample_channels: + up_blocks.append( + nn.Sequential( + nn.Conv2d(cur_ch, out_ch * 4, kernel_size=3, padding=1), + nn.PixelShuffle(2), # โ†‘ spatial ร—2 + RMSNorm2d(out_ch), # nn.InstanceNorm2d(out_ch), + nn.SiLU(inplace=True), + ) + ) + + cur_ch = out_ch + self.upsample_blocks = nn.Sequential(*up_blocks) + + # optional extra conv layers + conv_layers = [] + if conv_out_channels and conv_kernel_sizes: + for out_ch, k in zip(conv_out_channels, conv_kernel_sizes): + conv_layers.extend( + [ + nn.Conv2d(cur_ch, out_ch, k, padding=(k - 1) // 2), + RMSNorm2d(out_ch), # nn.InstanceNorm2d(out_ch), + 
nn.SiLU(inplace=True), + ] + ) + cur_ch = out_ch + + self.conv_layers = nn.Sequential(*conv_layers) + + def _init_weights(self) -> None: + """Initialize network weights.""" + for m in self.modules(): + if isinstance(m, nn.Conv2d): + weight_dtype = m.weight.dtype + weight = nn.init.kaiming_normal_( + m.weight.float(), mode="fan_out", nonlinearity="relu" + ) + m.weight.data = weight.to(weight_dtype) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + weight_dtype = m.weight.dtype + weight = nn.init.kaiming_normal_( + m.weight.float(), mode="fan_in", nonlinearity="linear" + ) + m.weight.data = weight.to(weight_dtype) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward(self, x: Union[Tensor, Tuple[Tensor]]) -> Tensor: + x_albedo = self.input_conv(x) + x_albedo = self.upsample_blocks(x_albedo) + x_albedo = self.conv_layers(x_albedo) + albedo = self.conv_albedo(x_albedo) + return albedo + + def loss( + self, + outputs: Tuple[Tensor], + data_samples: dict, + ) -> dict: + pred_albedo = outputs + gt_albedo = data_samples["gt_albedo"] ## B x 3 x H x W + gt_mask = data_samples["mask"] ## B x 1 x H x W + + if pred_albedo.shape[2:] != gt_albedo.shape[2:]: + pred_albedo = F.interpolate( + input=pred_albedo, + size=gt_albedo.shape[2:], + mode="bilinear", + align_corners=False, + antialias=False, + ) + + ##--------------------------------- + loss = dict() + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + + for loss_decode in losses_decode: + this_loss = loss_decode( + pred_albedo, + gt_albedo, + valid_mask=gt_mask, + ) + + if loss_decode.loss_name not in loss: + loss[loss_decode.loss_name] = this_loss + else: + loss[loss_decode.loss_name] += this_loss + + return loss, pred_albedo diff --git a/sapiens/dense/src/models/heads/normal_head.py b/sapiens/dense/src/models/heads/normal_head.py new file mode 100644 index 0000000000000000000000000000000000000000..d1310fa48aca82f2aece474d102a6a27a9b27853 --- /dev/null +++ b/sapiens/dense/src/models/heads/normal_head.py @@ -0,0 +1,161 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
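+# Each upsampling block below is Conv(c_in -> 4*c_out, 3x3) -> PixelShuffle(2) ->
+# InstanceNorm2d -> SiLU, doubling the spatial size; the default four blocks give
+# a 16x upsampling before the final 1x1 conv predicts the 3-channel normal map.
+# Illustrative shape check (standalone, not used by this module):
+#
+#   import torch, torch.nn as nn
+#   block = nn.Sequential(nn.Conv2d(768, 384 * 4, 3, padding=1), nn.PixelShuffle(2))
+#   block(torch.zeros(1, 768, 32, 24)).shape  # -> torch.Size([1, 384, 64, 48])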
+ +from typing import List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from sapiens.registry import MODELS +from torch import Tensor + + +@MODELS.register_module() +class NormalHead(nn.Module): + def __init__( + self, + in_channels: int = 768, + channels: int = 16, + upsample_channels: List[int] = [768, 384, 192, 96], + conv_out_channels: Optional[Sequence[int]] = None, + conv_kernel_sizes: Optional[Sequence[int]] = None, + loss_decode=dict(type="L1Loss", loss_weight=1.0), + **kwargs, + ): + super().__init__(**kwargs) + self.in_channels = in_channels + self.channels = channels + self._build_network(upsample_channels, conv_out_channels, conv_kernel_sizes) + # final conv layer to predict normal + in_channels = ( + conv_out_channels[-1] if conv_out_channels else upsample_channels[-1] + ) + self.conv_normal = nn.Conv2d(in_channels, 3, kernel_size=1) + + if isinstance(loss_decode, dict): + self.loss_decode = MODELS.build(loss_decode) + elif isinstance(loss_decode, (list, tuple)): + self.loss_decode = nn.ModuleList() + for loss in loss_decode: + self.loss_decode.append(MODELS.build(loss)) + else: + raise TypeError( + f"loss_decode must be a dict or sequence of dict,\ + but got {type(loss_decode)}" + ) + + self._init_weights() + + def _build_network( + self, + upsample_channels: List[int], + conv_out_channels: Optional[Sequence[int]], + conv_kernel_sizes: Optional[Sequence[int]], + ) -> None: + in_channels = self.in_channels + + self.input_conv = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), + nn.InstanceNorm2d(in_channels), # Normalize first + nn.SiLU(inplace=True), + ) + + # Progressive upsampling blocks + up_blocks = [] + cur_ch = in_channels + for out_ch in upsample_channels: + up_blocks.append( + nn.Sequential( + nn.Conv2d(cur_ch, out_ch * 4, kernel_size=3, padding=1), + nn.PixelShuffle(2), # โ†‘ spatial ร—2 + nn.InstanceNorm2d(out_ch), + nn.SiLU(inplace=True), + ) + ) + + cur_ch = out_ch + self.upsample_blocks = nn.Sequential(*up_blocks) + + # optional extra conv layers + conv_layers = [] + if conv_out_channels and conv_kernel_sizes: + for out_ch, k in zip(conv_out_channels, conv_kernel_sizes): + conv_layers.extend( + [ + nn.Conv2d(cur_ch, out_ch, k, padding=(k - 1) // 2), + nn.InstanceNorm2d(out_ch), + nn.SiLU(inplace=True), + ] + ) + cur_ch = out_ch + + self.conv_layers = nn.Sequential(*conv_layers) + + def _init_weights(self) -> None: + """Initialize network weights.""" + for m in self.modules(): + if isinstance(m, nn.Conv2d): + weight_dtype = m.weight.dtype + weight = nn.init.kaiming_normal_( + m.weight.float(), mode="fan_out", nonlinearity="relu" + ) + m.weight.data = weight.to(weight_dtype) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + weight_dtype = m.weight.dtype + weight = nn.init.kaiming_normal_( + m.weight.float(), mode="fan_in", nonlinearity="linear" + ) + m.weight.data = weight.to(weight_dtype) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward(self, x: Union[Tensor, Tuple[Tensor]]) -> Tensor: + x_normal = self.input_conv(x) + x_normal = self.upsample_blocks(x_normal) + x_normal = self.conv_layers(x_normal) + normal = self.conv_normal(x_normal) + return normal + + def loss( + self, + outputs: Tuple[Tensor], + data_samples: dict, + ) -> dict: + pred_normal = outputs + gt_normal = data_samples["gt_normal"] ## B x 3 x H x W + gt_mask = data_samples["mask"] ## B x 1 x H x W + + if pred_normal.shape[2:] != gt_normal.shape[2:]: + 
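+            # Bilinearly upsample the prediction to the ground-truth resolution so the
+            # masked loss is evaluated at full label resolution (mask and gt share it).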
pred_normal = F.interpolate( + input=pred_normal, + size=gt_normal.shape[2:], + mode="bilinear", + align_corners=False, + antialias=False, + ) + + ##--------------------------------- + loss = dict() + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + + for loss_decode in losses_decode: + this_loss = loss_decode( + pred_normal, + gt_normal, + valid_mask=gt_mask, + ) + + if loss_decode.loss_name not in loss: + loss[loss_decode.loss_name] = this_loss + else: + loss[loss_decode.loss_name] += this_loss + + return loss, pred_normal diff --git a/sapiens/dense/src/models/heads/pointmap_head.py b/sapiens/dense/src/models/heads/pointmap_head.py new file mode 100644 index 0000000000000000000000000000000000000000..c3de93f79058610ec93b7072a16af3505032196c --- /dev/null +++ b/sapiens/dense/src/models/heads/pointmap_head.py @@ -0,0 +1,277 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from sapiens.registry import MODELS +from torch import Tensor + + +@MODELS.register_module() +class PointmapHead(nn.Module): + def __init__( + self, + in_channels: int = 768, + channels: int = 16, + upsample_channels: List[int] = [768, 384, 192, 96], + conv_out_channels: Optional[Sequence[int]] = None, + conv_kernel_sizes: Optional[Sequence[int]] = None, + scale_conv_out_channels: Optional[Sequence[int]] = (1536, 512, 128), + scale_conv_kernel_sizes: Optional[Sequence[int]] = (1, 1, 1), + scale_final_layer: Optional[Sequence[int]] = (48 * 128, 512, 64, 1), + loss_decode=dict(type="L1Loss", loss_weight=1.0), + **kwargs, + ): + super().__init__(**kwargs) + self.in_channels = in_channels + self.channels = channels + + self._build_network(upsample_channels, conv_out_channels, conv_kernel_sizes) + if scale_conv_out_channels is not None: + self.scale_conv_layers = self._make_regression_conv_layers( + in_channels=self.in_channels, + layer_out_channels=scale_conv_out_channels, + layer_kernel_sizes=scale_conv_kernel_sizes, + ) + self.scale_final_layer = self._make_final_layer(scale_final_layer) + + else: + self.scale_conv_layers = None + self.scale_final_layer = None + + # final conv layer to predict pointmap + in_channels = ( + conv_out_channels[-1] if conv_out_channels else upsample_channels[-1] + ) + self.conv_pointmap = nn.Conv2d(in_channels, 3, kernel_size=1) + + if isinstance(loss_decode, dict): + self.loss_decode = MODELS.build(loss_decode) + elif isinstance(loss_decode, (list, tuple)): + self.loss_decode = nn.ModuleList() + for loss in loss_decode: + self.loss_decode.append(MODELS.build(loss)) + else: + raise TypeError( + f"loss_decode must be a dict or sequence of dict,\ + but got {type(loss_decode)}" + ) + + self._init_weights() + + def _build_network( + self, + upsample_channels: List[int], + conv_out_channels: Optional[Sequence[int]], + conv_kernel_sizes: Optional[Sequence[int]], + ) -> None: + in_channels = self.in_channels + + self.input_conv = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), + nn.InstanceNorm2d(in_channels), # Normalize first + nn.SiLU(inplace=True), + ) + + # Progressive upsampling blocks + up_blocks = [] + cur_ch = in_channels + for out_ch in upsample_channels: + up_blocks.append( + nn.Sequential( + 
nn.Conv2d(cur_ch, out_ch * 4, kernel_size=3, padding=1), + nn.PixelShuffle(2), # โ†‘ spatial ร—2 + nn.InstanceNorm2d(out_ch), + nn.SiLU(inplace=True), + ) + ) + + cur_ch = out_ch + self.upsample_blocks = nn.Sequential(*up_blocks) + + # optional extra conv layers + conv_layers = [] + if conv_out_channels and conv_kernel_sizes: + for out_ch, k in zip(conv_out_channels, conv_kernel_sizes): + conv_layers.extend( + [ + nn.Conv2d(cur_ch, out_ch, k, padding=(k - 1) // 2), + nn.InstanceNorm2d(out_ch), + nn.SiLU(inplace=True), + ] + ) + cur_ch = out_ch + + self.conv_layers = nn.Sequential(*conv_layers) + + def _make_final_layer(self, final_layer: Sequence[int]) -> nn.Module: + """Create final layer by given parameters.""" + layers = [nn.Flatten()] + in_features = final_layer[0] + + for i in range(1, len(final_layer)): + layers.append(nn.Linear(in_features, final_layer[i])) + if i < len(final_layer) - 1: # No activation after the last layer + layers.append(nn.SiLU()) + in_features = final_layer[i] + + return nn.Sequential(*layers) + + def _make_regression_conv_layers( + self, + in_channels: int, + layer_out_channels: Sequence[int], + layer_kernel_sizes: Sequence[int], + ) -> nn.Module: + """Create convolutional layers by given parameters.""" + + layers = [] + for out_channels, kernel_size in zip(layer_out_channels, layer_kernel_sizes): + stride = 2 # Set stride to 2 to reduce resolution by half + padding = (kernel_size - 1) // 2 + layers.append( + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + ) + layers.append(nn.InstanceNorm2d(out_channels)) + layers.append(nn.SiLU(inplace=True)) + + in_channels = out_channels + + return nn.Sequential(*layers) + + def _init_weights(self) -> None: + """Initialize network weights.""" + for m in self.modules(): + if isinstance(m, nn.Conv2d): + weight_dtype = m.weight.dtype + weight = nn.init.kaiming_normal_( + m.weight.float(), mode="fan_out", nonlinearity="relu" + ) + m.weight.data = weight.to(weight_dtype) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + weight_dtype = m.weight.dtype + weight = nn.init.kaiming_normal_( + m.weight.float(), mode="fan_in", nonlinearity="linear" + ) + m.weight.data = weight.to(weight_dtype) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward(self, x: Union[Tensor, Tuple[Tensor]]) -> Tensor: + x_pointmap = self.input_conv(x) + x_pointmap = self.upsample_blocks(x_pointmap) + x_pointmap = self.conv_layers(x_pointmap) + pointmap = self.conv_pointmap(x_pointmap) + + if self.scale_conv_layers is not None: + x_scale = self.scale_conv_layers(x) + scale = self.scale_final_layer( + x_scale + ) ## B x 1. scale = f_c / f_actual. in pixel spac of fx + else: + scale = None + + return pointmap, scale + + def loss( + self, + outputs: Tuple[Tensor], + data_samples: dict, + ) -> dict: + pred_pointmap, pred_scale = outputs + gt_pointmap = data_samples["gt_pointmap"] ## B x 3 x H x W + gt_mean_depth = data_samples["gt_mean_depth"] ## B x 1 x 1 x 1 + + # gt_K = data_samples["meta"]["K"] ## B x 3 x 3 + gt_original_K = data_samples["meta"]["original_K"] ## B x 3 x 3 + gt_scale = data_samples["meta"]["scale"].view(-1, 1) ## B x 1 + gt_mask = data_samples["mask"] ## B x 1 x H x W + + if pred_pointmap.shape[2:] != gt_pointmap.shape[2:]: + print( + "Warning: this is not recommended in pointmap, you may get artifacts!" 
+ ) + print( + f"pred_pointmap size: {pred_pointmap.shape}, gt_pointmap size: {gt_pointmap.shape}" + ) + pred_pointmap = F.interpolate( + input=pred_pointmap, + size=gt_pointmap.shape[2:], + mode="bilinear", + align_corners=False, + antialias=False, + ) + + ##--------------------------------- + loss = dict() + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + + ## B x 1 x H x W + pred_depth = pred_pointmap[:, 2].unsqueeze(dim=1) ## B x 1 x H x W + gt_depth = gt_pointmap[:, 2].unsqueeze(dim=1) ## B x 1 x H x W + + for loss_decode in losses_decode: + ## pointmap consistency loss + if loss_decode.loss_name == "loss_K_consistency": + this_loss = loss_decode( + pred_pointmap, + gt_pointmap, + valid_mask=gt_mask, + intrinsics=gt_original_K, ## Caution: using original K for consistency loss. since X/Z and Y/Z ratio is the same + ) + elif loss_decode.loss_name == "loss_silog": + this_loss = loss_decode( + pred_depth, + gt_depth, + valid_mask=gt_mask, + ) + elif loss_decode.loss_name == "loss_normal": + this_loss = loss_decode( + pred_pointmap, + gt_pointmap, + valid_mask=gt_mask, + scale=gt_scale, + ) + elif loss_decode.loss_name == "loss_scale_l1": + this_loss = loss_decode(pred_scale, gt_scale) + elif loss_decode.loss_name in [ + "loss_l1", + "loss_shift_invariant", + "loss_multiscale_l1_2", + "loss_multiscale_l1_4", + ]: + this_loss = loss_decode( + pred_pointmap / gt_mean_depth, + gt_pointmap / gt_mean_depth, + valid_mask=gt_mask, + ) + this_loss = torch.clamp(this_loss, max=4.0) + + else: + raise NotImplementedError( + f"loss {loss_decode.loss_name} is not implemented" + ) + + if loss_decode.loss_name not in loss: + loss[loss_decode.loss_name] = this_loss + else: + loss[loss_decode.loss_name] += this_loss + + return loss, (pred_pointmap, pred_scale) diff --git a/sapiens/dense/src/models/heads/seg_head.py b/sapiens/dense/src/models/heads/seg_head.py new file mode 100644 index 0000000000000000000000000000000000000000..95ca1dcb72df511246891790509adbee4fd8c245 --- /dev/null +++ b/sapiens/dense/src/models/heads/seg_head.py @@ -0,0 +1,257 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from sapiens.registry import MODELS +from torch import Tensor + + +@MODELS.register_module() +class SegHead(nn.Module): + def __init__( + self, + in_channels: int = 768, + deconv_out_channels: Optional[Sequence[int]] = (256, 256, 256), + deconv_kernel_sizes: Optional[Sequence[int]] = (4, 4, 4), + conv_out_channels: Optional[Sequence[int]] = None, + conv_kernel_sizes: Optional[Sequence[int]] = None, + num_classes: int = 29, + loss_decode=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0), + **kwargs, + ): + super().__init__(**kwargs) + self.in_channels = in_channels + + if deconv_out_channels: + if deconv_kernel_sizes is None or len(deconv_out_channels) != len( + deconv_kernel_sizes + ): + raise ValueError( + '"deconv_out_channels" and "deconv_kernel_sizes" should ' + "be integer sequences with the same length. 
Got " + f"mismatched lengths {deconv_out_channels} and " + f"{deconv_kernel_sizes}" + ) + + self.deconv_layers = self._make_deconv_layers( + in_channels=in_channels, + layer_out_channels=deconv_out_channels, + layer_kernel_sizes=deconv_kernel_sizes, + ) + in_channels = deconv_out_channels[-1] + else: + self.deconv_layers = nn.Identity() + + if conv_out_channels: + if conv_kernel_sizes is None or len(conv_out_channels) != len( + conv_kernel_sizes + ): + raise ValueError( + '"conv_out_channels" and "conv_kernel_sizes" should ' + "be integer sequences with the same length. Got " + f"mismatched lengths {conv_out_channels} and " + f"{conv_kernel_sizes}" + ) + + self.conv_layers = self._make_conv_layers( + in_channels=in_channels, + layer_out_channels=conv_out_channels, + layer_kernel_sizes=conv_kernel_sizes, + ) + in_channels = conv_out_channels[-1] + else: + self.conv_layers = nn.Identity() + + self.num_classes = num_classes + self.conv_seg = nn.Conv2d(in_channels, self.num_classes, kernel_size=1) + + if isinstance(loss_decode, dict): + self.loss_decode = MODELS.build(loss_decode) + elif isinstance(loss_decode, (list, tuple)): + self.loss_decode = nn.ModuleList() + for loss in loss_decode: + self.loss_decode.append(MODELS.build(loss)) + else: + raise TypeError( + f"loss_decode must be a dict or sequence of dict,\ + but got {type(loss_decode)}" + ) + + self._init_weights() + + def _make_conv_layers( + self, + in_channels: int, + layer_out_channels: Sequence[int], + layer_kernel_sizes: Sequence[int], + ) -> nn.Module: + """Create convolutional layers by given parameters.""" + + layers = [] + for out_channels, kernel_size in zip(layer_out_channels, layer_kernel_sizes): + padding = (kernel_size - 1) // 2 + layers.append( + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + padding=padding, + ) + ) + layers.append(nn.InstanceNorm2d(out_channels)) + layers.append(nn.SiLU(inplace=True)) + in_channels = out_channels + + return nn.Sequential(*layers) + + def _make_deconv_layers( + self, + in_channels: int, + layer_out_channels: Sequence[int], + layer_kernel_sizes: Sequence[int], + ) -> nn.Module: + """Create deconvolutional layers by given parameters.""" + layers = [] + for out_channels, kernel_size in zip(layer_out_channels, layer_kernel_sizes): + if kernel_size == 4: + padding = 1 + output_padding = 0 + elif kernel_size == 3: + padding = 1 + output_padding = 1 + elif kernel_size == 2: + padding = 0 + output_padding = 0 + else: + raise ValueError( + f"Unsupported kernel size {kernel_size} for" + "deconvlutional layers in " + f"{self.__class__.__name__}" + ) + layers.append( + nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=2, + padding=padding, + output_padding=output_padding, + bias=False, + ) + ) + layers.append(nn.InstanceNorm2d(out_channels)) + layers.append(nn.SiLU(inplace=True)) + in_channels = out_channels + + return nn.Sequential(*layers) + + def _init_weights(self) -> None: + """Initialize network weights.""" + for m in self.modules(): + if isinstance(m, nn.Conv2d): + weight_dtype = m.weight.dtype + weight = nn.init.kaiming_normal_( + m.weight.float(), mode="fan_out", nonlinearity="relu" + ) + m.weight.data = weight.to(weight_dtype) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + weight_dtype = m.weight.dtype + weight = nn.init.kaiming_normal_( + m.weight.float(), mode="fan_in", nonlinearity="linear" + ) + m.weight.data = 
weight.to(weight_dtype) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward(self, x: Union[Tensor, Tuple[Tensor]]) -> Tensor: + x = self.deconv_layers(x) + x = self.conv_layers(x) + out = self.conv_seg(x) + return out + + def loss( + self, + seg_logits: Tensor, + data_samples: dict, + ) -> dict: + seg_labels = data_samples["gt_seg"] ## B x 1 x H x W; torch.int64 + + if seg_logits.shape[2:] != seg_labels.shape[2:]: + seg_logits = F.interpolate( + input=seg_logits, + size=seg_labels.shape[2:], + mode="bilinear", + align_corners=False, + ) + + seg_labels = seg_labels.squeeze(dim=1) ## remove the fake dimension, B x H x W + + loss = dict() + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + for loss_decode in losses_decode: + if loss_decode.loss_name not in loss: + loss[loss_decode.loss_name] = loss_decode( + seg_logits, + seg_labels, + ) + else: + loss[loss_decode.loss_name] += loss_decode( + seg_logits, + seg_labels, + ) + loss["acc_seg"] = self.accuracy(seg_logits, seg_labels) + + return loss, seg_logits + + def accuracy(self, pred, target, topk=1, thresh=None, ignore_index=255): + assert isinstance(topk, (int, tuple)) + if isinstance(topk, int): + topk = (topk,) + return_single = True + else: + return_single = False + + maxk = max(topk) + if pred.size(0) == 0: + accu = [pred.new_tensor(0.0) for i in range(len(topk))] + return accu[0] if return_single else accu + assert pred.ndim == target.ndim + 1 + assert pred.size(0) == target.size(0) + assert maxk <= pred.size(1), ( + f"maxk {maxk} exceeds pred dimension {pred.size(1)}" + ) + pred_value, pred_label = pred.topk(maxk, dim=1) + # transpose to shape (maxk, N, ...) + pred_label = pred_label.transpose(0, 1) + correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label)) + if thresh is not None: + # Only prediction values larger than thresh are counted as correct + correct = correct & (pred_value > thresh).t() + + if ignore_index is not None: + correct = correct[:, target != ignore_index] + res = [] + eps = torch.finfo(torch.float32).eps + for k in topk: + # Avoid causing ZeroDivisionError when all pixels + # of an image are ignored + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + eps + if ignore_index is not None: + total_num = target[target != ignore_index].numel() + eps + else: + total_num = target.numel() + eps + res.append(correct_k.mul_(100.0 / total_num)) + return res[0] if return_single else res diff --git a/sapiens/dense/src/models/init_model.py b/sapiens/dense/src/models/init_model.py new file mode 100644 index 0000000000000000000000000000000000000000..03b9201a9e35c0c70ac31a5dada88bac7a6a2112 --- /dev/null +++ b/sapiens/dense/src/models/init_model.py @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
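A minimal usage sketch for the `init_model` helper defined in this file below. The config and checkpoint paths are hypothetical placeholders, and the import path is inferred from the repo layout, so adjust both to the installed package.

```python
import torch
# Import path assumed from sapiens/dense/src/models/init_model.py; adjust to the installed package.
from sapiens.dense.src.models.init_model import init_model

# Hypothetical paths, for illustration only.
model = init_model(
    config="configs/sapiens2_1b_normal.py",
    checkpoint="checkpoints/sapiens2_1b_normal.safetensors",
    device="cuda:0" if torch.cuda.is_available() else "cpu",
)
# The returned module carries its config, preprocessing, and test pipeline
# (model.cfg, model.data_preprocessor, model.pipeline) and is already moved
# to the requested device and set to eval mode.
```

Note that the helper also strips the backbone's `init_cfg`, so pretrained backbone weights are not separately loaded before the full checkpoint is applied.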
+ +from pathlib import Path +from typing import Optional, Union + +import torch +from safetensors.torch import load_file +from sapiens.engine.config import Config +from sapiens.engine.datasets import Compose +from sapiens.registry import MODELS + + +def init_model( + config: Union[str, Path], + checkpoint: Optional[Union[str, Path]] = None, + device: str = "cuda:0", +): + assert isinstance(config, (str, Path)) + assert checkpoint is None or isinstance(checkpoint, (str, Path)) + + config = Config.fromfile(config) + + ## avoid loading the pretrained backbone weights + if "init_cfg" in config.model["backbone"]: + config.model["backbone"].pop("init_cfg") + + model = MODELS.build(config.model) + data_preprocessor = MODELS.build(config.data_preprocessor) + + if checkpoint is not None: + if str(checkpoint).endswith(".safetensors"): + state_dict = load_file(checkpoint, device="cpu") + else: # Handle .pth and .bin files + checkpoint_data = torch.load( + checkpoint, map_location="cpu", weights_only=False + ) + state_dict = ( + checkpoint_data["state_dict"] + if "state_dict" in checkpoint_data + else checkpoint_data["model"] + ) + + incompat = model.load_state_dict(state_dict, strict=False) + + if incompat.missing_keys: + print(f"Missing keys: {incompat.missing_keys}") + + if incompat.unexpected_keys: + print(f"Unexpected keys: {incompat.unexpected_keys}") + + print(f"\033[96mModel loaded from {checkpoint}\033[0m") + + model.cfg = config + model.data_preprocessor = data_preprocessor + model.pipeline = Compose(config.test_pipeline) + + model.to(device) + model.eval() + + return model diff --git a/sapiens/dense/src/models/losses/__init__.py b/sapiens/dense/src/models/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1b2ff9d5bde7f341e5f4a64f2d0466a25f579614 --- /dev/null +++ b/sapiens/dense/src/models/losses/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .albedo_loss import AlbedoGradL1Loss +from .l1_loss import L1Loss +from .normal_loss import NormalCosineSimilarityLoss, NormalGradL1Loss +from .pointmap_loss import ( + PointmapIntrinsicsConsistencyLoss, + PointmapNormalLoss, + PointmapScaleL1Loss, + PointmapShiftInvariantL1Loss, +) +from .seg_loss import CrossEntropyLoss, DiceLoss +from .silog_loss import SiLogLoss + +__all__ = [ + "L1Loss", + "PointmapIntrinsicsConsistencyLoss", + "PointmapNormalLoss", + "PointmapScaleL1Loss", + "PointmapShiftInvariantL1Loss", + "SiLogLoss", + "CrossEntropyLoss", + "DiceLoss", + "NormalCosineSimilarityLoss", + "NormalGradL1Loss", + "AlbedoGradL1Loss", +] diff --git a/sapiens/dense/src/models/losses/albedo_loss.py b/sapiens/dense/src/models/losses/albedo_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..91485f5759c99759da879600044b8e97200fc952 --- /dev/null +++ b/sapiens/dense/src/models/losses/albedo_loss.py @@ -0,0 +1,170 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from sapiens.registry import MODELS + + +@MODELS.register_module() +class AlbedoChromaticityL1Loss(nn.Module): + """ + Brightness-invariant color loss in RGB. + Per-pixel chroma = RGB / (R+G+B); then L1 on chroma. 
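    Example: a pixel (0.2, 0.3, 0.5) and the same pixel at double brightness
    (0.4, 0.6, 1.0) both map to chroma (0.2, 0.3, 0.5), so uniform brightness
    changes incur no loss while hue/saturation shifts do.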
+ Expects pred/target in [0,1] sRGB, shape (B,3,H,W). + """ + + def __init__( + self, + loss_weight: float = 1.0, + eps: float = 1e-6, + loss_name: str = "loss_chroma_l1", + ): + super().__init__() + self.loss_weight = loss_weight + self.eps = eps + self._loss_name = loss_name + + @staticmethod + def _chroma(x, eps): + s = x.sum(dim=1, keepdim=True) # (B,1,H,W) + return x / (s + eps) + + def forward( + self, pred: torch.Tensor, target: torch.Tensor, valid_mask: torch.Tensor + ) -> torch.Tensor: + pred = pred.clamp(0, 1) + + pc = self._chroma(pred, self.eps) + gc = self._chroma(target, self.eps) + + m = valid_mask.to(pred.dtype) # (B,1,H,W), broadcasts over C + per_pix = (pc - gc).abs() # (B,3,H,W) + num = (per_pix * m).sum() + den = m.sum().clamp(min=1) + loss = (num / den) * self.loss_weight + return torch.nan_to_num(loss, nan=0.0, posinf=0.0, neginf=0.0) + + @property + def loss_name(self): + return self._loss_name + + +@MODELS.register_module() +class AlbedoLowFreqL1Loss(nn.Module): + """ + Compares heavy low-pass versions; leaves edges alone. + Expects pred, gt in [0,1] sRGB (range doesn't matter for the low-pass itself). + """ + + def __init__( + self, + loss_weight=1.0, + down_sample: int = 32, + loss_name="loss_low_l1", + ): + super().__init__() + self.loss_weight = loss_weight + self._loss_name = loss_name + self.down_sample = down_sample + + def _lowpass_masked(self, x, m): + # x:(B,3,H,W), m:(B,1,H,W) in {0,1} + if self.down_sample == 1: + return x + xm = x * m + x_ds = F.interpolate(xm, scale_factor=1.0 / self.down_sample, mode="area") + m_ds = F.interpolate(m, scale_factor=1.0 / self.down_sample, mode="area") + x_up = F.interpolate( + x_ds, size=x.shape[-2:], mode="bilinear", align_corners=False + ) + m_up = F.interpolate( + m_ds, size=x.shape[-2:], mode="bilinear", align_corners=False + ) + return x_up / (m_up + 1e-6) + + def forward(self, pred, target, valid_mask): + assert pred.shape == target.shape, ( + f"The shapes of pred ({pred.shape}) and target ({target.shape}) are mismatched" + ) + + if valid_mask.sum() == 0: + return 0.0 * pred.sum() + + pred_lp = self._lowpass_masked(pred, valid_mask.to(pred.dtype)) + target_lp = self._lowpass_masked(target, valid_mask.to(pred.dtype)) + + if pred_lp.dtype != target_lp.dtype: + target_lp = target_lp.to(pred_lp.dtype) + + loss = F.l1_loss(pred_lp, target_lp, reduction="none") * valid_mask + loss = loss.sum() / valid_mask.sum().clamp(min=1) + loss = loss * self.loss_weight + + # Convert nan to 0 + loss = torch.nan_to_num(loss, nan=0.0, posinf=0.0, neginf=0.0) + return loss + + @property + def loss_name(self): + return self._loss_name + + +@MODELS.register_module() +class AlbedoGradL1Loss(nn.Module): + def __init__( + self, + loss_weight=1.0, + loss_name="loss_grad_l1", + ): + super().__init__() + self.loss_weight = loss_weight + self._loss_name = loss_name + + def forward(self, pred, target, valid_mask): + assert pred.shape == target.shape, ( + f"The shapes of pred ({pred.shape}) and target ({target.shape}) are mismatched" + ) + if valid_mask.sum() == 0: + return 0.0 * pred.sum() + + ## pred is B x C x H x W + ## target is B x C x H x W + ## valid_mask is B x 1 x H x W + # Compute gradients + pred_dx = pred[:, :, :, 1:] - pred[:, :, :, :-1] + pred_dy = pred[:, :, 1:, :] - pred[:, :, :-1, :] + target_dx = target[:, :, :, 1:] - target[:, :, :, :-1] + target_dy = target[:, :, 1:, :] - target[:, :, :-1, :] + + if target_dx.dtype != pred_dx.dtype or target_dy.dtype != pred_dy.dtype: + # Convert to the same dtype as pred + target_dx = 
target_dx.to(pred_dx.dtype) + target_dy = target_dy.to(pred_dy.dtype) + + # Adjust valid mask for gradients + valid_mask_dx = valid_mask[:, :, :, :-1] * valid_mask[:, :, :, 1:] + valid_mask_dy = valid_mask[:, :, :-1, :] * valid_mask[:, :, 1:, :] + + # Compute edge-aware loss + loss_dx = F.mse_loss(pred_dx, target_dx, reduction="none") + loss_dy = F.mse_loss(pred_dy, target_dy, reduction="none") + + # Apply valid mask before reduction + loss_dx = (loss_dx * valid_mask_dx).sum() / valid_mask_dx.sum().clamp(min=1) + loss_dy = (loss_dy * valid_mask_dy).sum() / valid_mask_dy.sum().clamp(min=1) + + # Combine losses + loss = (loss_dx + loss_dy) * 0.5 * self.loss_weight + + # Convert nan to 0 + loss = torch.nan_to_num(loss, nan=0.0, posinf=0.0, neginf=0.0) + return loss + + @property + def loss_name(self): + return self._loss_name diff --git a/sapiens/dense/src/models/losses/l1_loss.py b/sapiens/dense/src/models/losses/l1_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..ee624c59f9f17e394813ada4d1814c9739be910a --- /dev/null +++ b/sapiens/dense/src/models/losses/l1_loss.py @@ -0,0 +1,106 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from sapiens.registry import MODELS + + +@MODELS.register_module() +class L1Loss(nn.Module): + def __init__(self, reduction="mean", loss_weight=1.0, loss_name="loss_l1"): + super().__init__() + self.reduction = reduction + self.loss_weight = loss_weight + self._loss_name = loss_name + + def forward(self, pred, target, valid_mask): + assert pred.shape == target.shape, ( + f"The shapes of pred ({pred.shape}) and target ({target.shape}) are mismatched" + ) + + if valid_mask.sum() == 0: + return 0.0 * pred.sum() + + loss = F.l1_loss(pred, target, reduction="none") * valid_mask ## B x C x H x W + loss = loss.sum() / valid_mask.sum().clamp(min=1) + loss = loss * self.loss_weight + + ## convert nan to 0 + loss = torch.nan_to_num(loss, nan=0.0, posinf=0.0, neginf=0.0) + + return loss + + @property + def loss_name(self): + return self._loss_name + + +@MODELS.register_module() +class MultiscaleL1Loss(nn.Module): + def __init__( + self, + reduction="mean", + loss_weight=1.0, + loss_name="loss_multiscale_l1", + scale_factor=2, + interpolate_mode="bilinear", + align_corners=False, + ): + super().__init__() + self.reduction = reduction + self.loss_weight = loss_weight + self._loss_name = loss_name + f"_{scale_factor}" + self.scale_factor = scale_factor + self.interpolate_mode = interpolate_mode + self.align_corners = align_corners + assert self.interpolate_mode in ["bilinear", "bicubic"] + + def forward(self, pred, target, valid_mask): + assert pred.shape == target.shape, ( + f"The shapes of pred ({pred.shape}) and target ({target.shape}) are mismatched" + ) + + if valid_mask.dtype == torch.bool: + valid_mask = valid_mask.float() + + if valid_mask.sum() == 0: + return 0.0 * pred.sum() + + # Upsample pred and target using the specified interpolation mode + pred_scaled = F.interpolate( + pred, + scale_factor=self.scale_factor, + mode=self.interpolate_mode, + align_corners=self.align_corners, + ) + target_scaled = F.interpolate( + target, + scale_factor=self.scale_factor, + mode=self.interpolate_mode, + align_corners=self.align_corners, + ) + + # Upsample valid_mask using nearest neighbor to preserve binary values + 
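        # (Bilinear or area interpolation would produce fractional mask values at
        # object boundaries and let invalid pixels leak into the masked mean below;
        # nearest keeps the mask strictly binary.)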
valid_mask_scaled = F.interpolate( + valid_mask, scale_factor=self.scale_factor, mode="nearest" + ) + + loss_scaled = ( + F.l1_loss(pred_scaled, target_scaled, reduction="none") * valid_mask_scaled + ) # B x C x H' x W' + + loss_scaled = loss_scaled.sum() / valid_mask_scaled.sum().clamp(min=1) + loss_scaled = loss_scaled * self.loss_weight + + # Convert any NaN or Inf values to 0 + loss_scaled = torch.nan_to_num(loss_scaled, nan=0.0, posinf=0.0, neginf=0.0) + return loss_scaled + + @property + def loss_name(self): + return self._loss_name diff --git a/sapiens/dense/src/models/losses/normal_loss.py b/sapiens/dense/src/models/losses/normal_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..51bfc8aff20980175d3b54e46dba8e7a7e6d9605 --- /dev/null +++ b/sapiens/dense/src/models/losses/normal_loss.py @@ -0,0 +1,203 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import random + +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from sapiens.registry import MODELS + +from .utils import weight_reduce_loss + + +@MODELS.register_module() +class NormalCosineSimilarityLoss(nn.Module): + def __init__( + self, + reduction="mean", + loss_weight=1.0, + eps=-100, + thres_eps=1e-6, + loss_name="loss_cosine_sim", + ): + super().__init__() + self.reduction = reduction + self.loss_weight = loss_weight + self.eps = eps + self.thres_eps = thres_eps + self._loss_name = loss_name + + def forward( + self, + pred, + target, + valid_mask=None, + weight=None, + avg_factor=None, + reduction_override=None, + ): + assert pred.shape == target.shape, ( + f"The shapes of pred ({pred.shape}) and target ({target.shape}) are mismatched" + ) + assert reduction_override in ( + None, + "none", + "mean", + "sum", + ), "Invalid reduction_override value" + + reduction = reduction_override if reduction_override else self.reduction + + if valid_mask is None: + valid_mask = (target > self.eps).detach().float() + valid_mask = valid_mask[:, 0, :, :].unsqueeze(1) ## B x 1 x H x W + + if valid_mask.sum() == 0: + return 0.0 * pred.sum() + + # Normalize predictions and targets to unit vectors + pred_norm = F.normalize(pred, p=2, dim=1, eps=self.thres_eps) + target_norm = F.normalize(target, p=2, dim=1, eps=self.thres_eps) + + # Compute cosine similarity + cos_sim = (pred_norm * target_norm).sum(dim=1, keepdim=True) + loss = ( + 1 - cos_sim + ) * valid_mask # Ensuring loss is computed only for valid pixels. 
B x 1 x H x W + + if reduction == "mean": + loss = loss.sum() / valid_mask.sum().clamp(min=1) + elif reduction == "sum": + loss = loss.sum() + elif reduction == "none": + pass # Keep per-pixel loss + else: + raise ValueError(f"Invalid reduction type: {reduction}") + + loss = ( + weight_reduce_loss(loss, weight, reduction, avg_factor) * self.loss_weight + ) + + ## convert nan to 0 + loss = torch.nan_to_num( + loss, + nan=torch.tensor(0, dtype=pred.dtype, device=pred.device), + posinf=torch.tensor(0, dtype=pred.dtype, device=pred.device), + neginf=torch.tensor(0, dtype=pred.dtype, device=pred.device), + ) + + return loss + + @property + def loss_name(self): + """Returns the name of this loss function.""" + return self._loss_name + + +@MODELS.register_module() +class NormalGradL1Loss(nn.Module): + def __init__( + self, + reduction="mean", + loss_weight=1.0, + eps=-100, + thres_eps=1e-6, + loss_name="loss_grad_l1", + ): + super().__init__() + self.reduction = reduction + self.loss_weight = loss_weight + self.eps = eps + self.thres_eps = thres_eps + self._loss_name = loss_name + + def forward( + self, + pred, + target, + valid_mask=None, + weight=None, + avg_factor=None, + reduction_override=None, + ): + assert pred.shape == target.shape, ( + f"The shapes of pred ({pred.shape}) and target ({target.shape}) are mismatched" + ) + assert reduction_override in ( + None, + "none", + "mean", + "sum", + ), "Invalid reduction_override value" + + reduction = reduction_override if reduction_override else self.reduction + if valid_mask is None: + valid_mask = (target > self.eps).detach().float() + valid_mask = valid_mask[:, 0, :, :].unsqueeze(1) ## B x 1 x H x W + + if valid_mask.sum() == 0: + return 0.0 * pred.sum() + + ## pred is B x C x H x W + ## target is B x C x H x W + # Compute gradients + pred_dx = pred[:, :, :, 1:] - pred[:, :, :, :-1] + pred_dy = pred[:, :, 1:, :] - pred[:, :, :-1, :] + target_dx = target[:, :, :, 1:] - target[:, :, :, :-1] + target_dy = target[:, :, 1:, :] - target[:, :, :-1, :] + + if target_dx.dtype != pred_dx.dtype or target_dy.dtype != pred_dy.dtype: + # Convert to the same dtype as pred + target_dx = target_dx.to(pred_dx.dtype) + target_dy = target_dy.to(pred_dy.dtype) + + # Adjust valid mask for gradients + valid_mask_dx = valid_mask[:, :, :, :-1] * valid_mask[:, :, :, 1:] + valid_mask_dy = valid_mask[:, :, :-1, :] * valid_mask[:, :, 1:, :] + + # Compute edge-aware loss + loss_dx = F.mse_loss(pred_dx, target_dx, reduction="none") + loss_dy = F.mse_loss(pred_dy, target_dy, reduction="none") + + # Apply valid mask before reduction + loss_dx = (loss_dx * valid_mask_dx).sum() / valid_mask_dx.sum().clamp(min=1) + loss_dy = (loss_dy * valid_mask_dy).sum() / valid_mask_dy.sum().clamp(min=1) + + # Combine losses + loss = loss_dx + loss_dy + + if reduction == "mean": + pass # Loss is already averaged + elif reduction == "sum": + loss = loss * pred.shape[0] # Multiply by batch size to get sum + elif reduction == "none": + loss = loss.expand( + pred.shape[0], -1, -1, -1 + ) # Expand loss to match input shape + else: + raise ValueError(f"Invalid reduction type: {reduction}") + + loss = ( + weight_reduce_loss(loss, weight, reduction, avg_factor) * self.loss_weight + ) + + # Convert nan to 0 + loss = torch.nan_to_num( + loss, + nan=torch.tensor(0, dtype=pred.dtype, device=pred.device), + posinf=torch.tensor(0, dtype=pred.dtype, device=pred.device), + neginf=torch.tensor(0, dtype=pred.dtype, device=pred.device), + ) + + return loss + + @property + def loss_name(self): + """Returns the 
name of this loss function.""" + return self._loss_name diff --git a/sapiens/dense/src/models/losses/pointmap_loss.py b/sapiens/dense/src/models/losses/pointmap_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..22dd93ab4504ebaae6712d752a72957d8c46bfa7 --- /dev/null +++ b/sapiens/dense/src/models/losses/pointmap_loss.py @@ -0,0 +1,435 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import random + +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from sapiens.registry import MODELS + +from .utils import weight_reduce_loss + + +@MODELS.register_module() +class PointmapShiftInvariantL1Loss(nn.Module): + """L1 loss that is invariant to global translation of the point cloud""" + + def __init__( + self, + reduction="mean", + loss_weight=1.0, + eps=-100, + loss_name="loss_shift_invariant", + ): + super().__init__() + self.reduction = reduction + self.loss_weight = loss_weight + self.eps = eps + self._loss_name = loss_name + + def forward( + self, + pred, + target, + valid_mask=None, + weight=None, + avg_factor=None, + reduction_override=None, + ): + assert pred.shape == target.shape, ( + f"The shapes of pred ({pred.shape}) and target ({target.shape}) are mismatched" + ) + assert reduction_override in ( + None, + "none", + "mean", + "sum", + ), f"Invalid reduction: {reduction_override}" + + reduction = reduction_override if reduction_override else self.reduction + + if valid_mask is None: + valid_mask = (target > self.eps).detach().float() + valid_mask = valid_mask[:, 0, :, :].unsqueeze(1) # B x 1 x H x W + + if valid_mask.sum() == 0: + return 0.0 * pred.sum() + + # Compute centroids for each batch + pred_centroid = (pred * valid_mask).sum(dim=(2, 3)) / ( + valid_mask.sum(dim=(2, 3)) + 1e-6 + ) # (B, 3) + target_centroid = (target * valid_mask).sum(dim=(2, 3)) / ( + valid_mask.sum(dim=(2, 3)) + 1e-6 + ) # (B, 3) + + # Center both point clouds + pred_centered = pred - pred_centroid.view(pred.shape[0], 3, 1, 1) + target_centered = target - target_centroid.view(target.shape[0], 3, 1, 1) + + # Compute L1 loss on centered points + loss = torch.abs(pred_centered - target_centered) * valid_mask + + # Apply reduction + if reduction == "mean": + loss = loss.sum() / (valid_mask.sum().clamp(min=1)) + elif reduction == "sum": + loss = loss.sum() + + loss = ( + weight_reduce_loss(loss, weight, reduction, avg_factor) * self.loss_weight + ) + + # Handle numerical instabilities + loss = torch.nan_to_num( + loss, + nan=torch.tensor(0, dtype=pred.dtype, device=pred.device), + posinf=torch.tensor(0, dtype=pred.dtype, device=pred.device), + neginf=torch.tensor(0, dtype=pred.dtype, device=pred.device), + ) + + return loss + + @property + def loss_name(self): + """Returns the name of this loss function.""" + return self._loss_name + + +def visualize_normals(pred_normals, target_normals, valid_mask): + # Set random seed + seed = np.random.randint(0, 100000) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + # Convert tensors to numpy arrays + pred = pred_normals.detach().cpu().numpy() + target = target_normals.detach().cpu().numpy() + mask = valid_mask.detach().cpu().numpy() + + def process_normal_map(normal_map, mask): + # Take the first sample in batch + normal_map = normal_map[0] # (3, H, W) + mask = mask[0, 0] # (H, W) + + # Print unique values for each 
channel + print("\nNormal map statistics:") + print("-" * 50) + for i, axis in enumerate(["X", "Y", "Z"]): + unique_vals = np.unique(normal_map[i]) + print(f"\n{axis}-component unique values:") + print(f"Min: {normal_map[i].min():.4f}") + print(f"Max: {normal_map[i].max():.4f}") + print(f"Number of unique values: {len(unique_vals)}") + print( + f"Sample of unique values: {unique_vals[:10]}" + ) # Show first 10 unique values + + # Print statistics for valid regions only + valid_normals = normal_map[:, mask > 0] + print("\nValid regions statistics:") + print("-" * 50) + for i, axis in enumerate(["X", "Y", "Z"]): + if valid_normals.size > 0: + print(f"\n{axis}-component (valid regions):") + print(f"Min: {valid_normals[i].min():.4f}") + print(f"Max: {valid_normals[i].max():.4f}") + print(f"Mean: {valid_normals[i].mean():.4f}") + print(f"Std: {valid_normals[i].std():.4f}") + + # Calculate and print vector magnitudes + magnitudes = np.linalg.norm(normal_map, axis=0) + print("\nNormal vector magnitudes:") + print(f"Min magnitude: {magnitudes.min():.4f}") + print(f"Max magnitude: {magnitudes.max():.4f}") + print(f"Mean magnitude: {magnitudes.mean():.4f}") + + # Convert to RGB space [0, 255] + # Map from [-1, 1] to [0, 1] + normal_rgb = (normal_map + 1.0) * 0.5 + normal_rgb = np.clip(normal_rgb * 255, 0, 255).astype(np.uint8) + + # Transpose from (3, H, W) to (H, W, 3) + normal_rgb = np.transpose(normal_rgb, (1, 2, 0)) + + # Apply mask + normal_rgb[mask == 0] = 0 + + return normal_rgb + + # Process both prediction and ground truth + pred_vis = process_normal_map(pred, mask) + gt_vis = process_normal_map(target, mask) + + # Save images + cv2.imwrite(f"seed_{seed}_pred.jpg", cv2.cvtColor(pred_vis, cv2.COLOR_RGB2BGR)) + cv2.imwrite(f"seed_{seed}_gt.jpg", cv2.cvtColor(gt_vis, cv2.COLOR_RGB2BGR)) + + return pred_vis, gt_vis + + +@MODELS.register_module() +class PointmapNormalLoss(nn.Module): + def __init__( + self, + reduction="mean", + loss_weight=1.0, + eps=-100, + l1_weight=1.0, + loss_name="loss_normal", + ): + super().__init__() + self.reduction = reduction + self.loss_weight = loss_weight + self.eps = eps + self.l1_weight = l1_weight # Weight for L1 loss component + self._loss_name = loss_name + + def compute_normals(self, pointmap, valid_mask, scale=1.0, amplification=100.0): + """ + Compute surface normals from pointmap using neighboring points + Args: + pointmap: (B, 3, H, W) tensor of XYZ coordinates + valid_mask: (B, 1, H, W) binary mask + Returns: + normals: (B, 3, H, W) tensor of normalized surface normals + """ + ## scale canonical Z to metric Z + scaled_pointmap = pointmap.clone() + scale = scale.view(-1, 1, 1, 1) + scaled_pointmap = ( + pointmap / scale + ) ## scale canonical pointmap to metric pointmap + + # Pad pointmap for computing neighbors + pointmap_pad = F.pad(scaled_pointmap, (1, 1, 1, 1), mode="reflect") + center = pointmap_pad[:, :, 1:-1, 1:-1] + + # Get vectors to neighbors (use central differences) + dx = pointmap_pad[:, :, 1:-1, 2:] - center # right - center + dy = pointmap_pad[:, :, 2:, 1:-1] - center # down - center + + dx = dx * amplification # Scale up x gradient + dy = dy * amplification # Scale up y gradient + + # Compute cross product for normal vectors + normal = torch.cross(dx, dy, dim=1) # (B, 3, H, W) + + # Normalize the vectors + normal_magnitude = torch.norm(normal, dim=1, keepdim=True) + normal = normal / normal_magnitude.clamp(min=1e-6) + # normal = torch.where( + # mask, normal / (normal_magnitude + self.thres_eps), torch.zeros_like(normal) + # ) + + # Apply 
valid mask and handle invalid normals + magnitude_is_valid = normal_magnitude > 1e-6 + valid_normals = magnitude_is_valid & (valid_mask > 0) + normal = normal * valid_normals.float() + + return normal, valid_normals + + def forward( + self, + pred, + target, + valid_mask=None, + scale=None, + weight=None, + avg_factor=None, + reduction_override=None, + ): + """ + Args: + pred (Tensor): Predicted pointmap of shape (B, 3, H, W) + target (Tensor): Target pointmap of shape (B, 3, H, W) + valid_mask (Tensor, optional): Validity mask of shape (B, 1, H, W) + """ + assert pred.shape == target.shape + assert reduction_override in (None, "none", "mean", "sum") + + reduction = reduction_override if reduction_override else self.reduction + + ## metric Z = canonical Z / scale + if valid_mask is None: + valid_mask = (target > self.eps).detach().float() + valid_mask = valid_mask[:, 0, :, :].unsqueeze(1) + + if valid_mask.sum() == 0: + return 0.0 * pred.sum() + + # Compute normals with validity masks + pred_normals, _ = self.compute_normals(pred, valid_mask, scale) + target_normals, target_valid = self.compute_normals(target, valid_mask, scale) + # Combined validity mask for normal comparison + valid = target_valid ## only use gt pointmap + + # Cosine similarity loss (1 - cos) + cos_similarity = torch.sum(pred_normals * target_normals, dim=1, keepdim=True) + cos_similarity = torch.clamp(cos_similarity, min=-1.0, max=1.0) + normal_loss = (1 - cos_similarity) * valid.float() + + # L1 loss on normal vectors + l1_loss = ( + torch.abs(pred_normals - target_normals).mean(dim=1, keepdim=True) + * valid.float() + ) + + # Combine losses + combined_loss = normal_loss + self.l1_weight * l1_loss + + # Apply reduction + if reduction == "mean": + loss = combined_loss.sum() / (valid.sum().clamp(min=1)) + elif reduction == "sum": + loss = combined_loss.sum() + else: + loss = combined_loss + + loss = ( + weight_reduce_loss(loss, weight, reduction, avg_factor) * self.loss_weight + ) + + # ----------debug----------- + # visualize_normals(pred_normals, target_normals, valid) + + return loss + + @property + def loss_name(self): + return self._loss_name + + +@MODELS.register_module() +class PointmapIntrinsicsConsistencyLoss(nn.Module): + def __init__( + self, + reduction="mean", + loss_weight=1.0, + eps=-100, + loss_name="loss_K_consistency", + ): + super().__init__() + self.reduction = reduction + self.loss_weight = loss_weight + self.eps = eps + self._loss_name = loss_name + + def forward( + self, + pred, + target, + valid_mask=None, + intrinsics=None, + weight=None, + avg_factor=None, + reduction_override=None, + ): + assert pred.shape == target.shape, ( + f"The shapes of pred ({pred.shape}) and target ({target.shape}) are mismatched" + ) + assert reduction_override in ( + None, + "none", + "mean", + "sum", + ), "Invalid reduction_override value" + assert intrinsics is not None, ( + "intrinsics must be provided for PointmapIntrinsicsConsistencyLoss" + ) + assert intrinsics.shape[1] == 3 and intrinsics.shape[2] == 3, ( + "intrinsics must be a B x 3 x 3 tensor" + ) + + reduction = reduction_override if reduction_override else self.reduction + + if valid_mask is None: + valid_mask = (target > self.eps).detach().float() + valid_mask = valid_mask[:, 0, :, :].unsqueeze(1) ## B x 1 x H x W + + if valid_mask.sum() == 0: + return 0.0 * pred.sum() + + B, C, H, W = pred.shape + device = pred.device + valid_mask = valid_mask.squeeze(1) ## B x H x W + + pred_X = pred[:, 0, :, :] ## B x H x W + pred_Y = pred[:, 1, :, :] ## B x H x W + 
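        # The consistency check below follows the pinhole model:
        #   X = (u - cx) * Z / fx,   Y = (v - cy) * Z / fy
        # i.e. the predicted X/Y channels must agree with the X/Y implied by the
        # predicted Z channel and the camera intrinsics K.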
pred_Z = pred[:, 2, :, :] ## B x H x W + + cols = torch.arange(W, device=device).repeat(B, H, 1) # B x H x W + rows = ( + torch.arange(H, device=device).repeat(B, W, 1).transpose(1, 2) + ) # B x H x W + + # Compute x and y from z and K + x = ( + (cols - intrinsics[:, 0, 2].view(B, 1, 1)) + * pred_Z + / intrinsics[:, 0, 0].view(B, 1, 1) + ) # B x H x W + y = ( + (rows - intrinsics[:, 1, 2].view(B, 1, 1)) + * pred_Z + / intrinsics[:, 1, 1].view(B, 1, 1) + ) # B x H x W + + # # ##---------------to debug consistency with target------------------- + # target_X = target[:, 0, :, :] + # target_Y = target[:, 1, :, :] + # x = (cols - intrinsics[:, 0, 2].view(B, 1, 1)) * target_Z / intrinsics[:, 0, 0].view(B, 1, 1) + # y = (rows - intrinsics[:, 1, 2].view(B, 1, 1)) * target_Z / intrinsics[:, 1, 1].view(B, 1, 1) + + # loss_X = torch.abs(target_X - x) * valid_mask + # loss_Y = torch.abs(target_Y - y) * valid_mask + + ##----------------------------------------------------- + # Loss calculations + loss_X = torch.abs(pred_X - x) * valid_mask + loss_Y = torch.abs(pred_Y - y) * valid_mask + + # Apply reduction + if self.reduction == "mean": + loss = (loss_X + loss_Y).sum() / valid_mask.sum().clamp(min=1) + elif self.reduction == "sum": + loss = (loss_X + loss_Y).sum() + elif self.reduction == "none": + loss = loss_X + loss_Y + else: + raise ValueError("Unsupported reduction type") + + # Handle NaN values + loss = torch.nan_to_num(loss, nan=0.0) * self.loss_weight + + return loss + + @property + def loss_name(self): + return self._loss_name + + +@MODELS.register_module() +class PointmapScaleL1Loss(nn.Module): + """L1 loss that is invariant to global translation of the point cloud""" + + def __init__(self, loss_weight=1.0, loss_name="loss_scale_l1", **kwargs): + super().__init__() + self.loss_weight = loss_weight + self._loss_name = loss_name + + def forward(self, pred, target): + loss = F.l1_loss(pred, target, reduction="mean") * self.loss_weight + return loss + + @property + def loss_name(self): + """Returns the name of this loss function.""" + return self._loss_name diff --git a/sapiens/dense/src/models/losses/seg_loss.py b/sapiens/dense/src/models/losses/seg_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..b47ab9fcb6a4d7403be2a4be479d10f1a1a630ab --- /dev/null +++ b/sapiens/dense/src/models/losses/seg_loss.py @@ -0,0 +1,248 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
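Before the segmentation losses, a standalone sketch (pure PyTorch, no repo imports) of the pinhole back-projection that `PointmapIntrinsicsConsistencyLoss` above relies on: given the predicted Z channel and intrinsics K, it recovers the X/Y values the predicted X/Y channels are compared against.

```python
import torch

def backproject_xy(z: torch.Tensor, K: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Recover X, Y from a depth/Z map of shape (B, H, W) and intrinsics (B, 3, 3)."""
    B, H, W = z.shape
    cols = torch.arange(W, dtype=z.dtype).expand(B, H, W)                # u (pixel column)
    rows = torch.arange(H, dtype=z.dtype).view(1, H, 1).expand(B, H, W)  # v (pixel row)
    fx, fy = K[:, 0, 0].view(B, 1, 1), K[:, 1, 1].view(B, 1, 1)
    cx, cy = K[:, 0, 2].view(B, 1, 1), K[:, 1, 2].view(B, 1, 1)
    x = (cols - cx) * z / fx
    y = (rows - cy) * z / fy
    return x, y
```

The loss is then the masked L1 between the predicted X/Y channels and these re-projected values.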
+ +from typing import Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from sapiens.registry import MODELS + +from .utils import weight_reduce_loss + + +def cross_entropy( + pred, + label, + weight=None, + class_weight=None, + reduction="mean", + avg_factor=None, + ignore_index=255, +): + loss = F.cross_entropy( + pred, label, weight=class_weight, reduction="none", ignore_index=ignore_index + ) + + if avg_factor is None and reduction == "mean": + if class_weight is None: + avg_factor = label.numel() + + else: + # label_weights = torch.tensor( + # [class_weight[cls] for cls in label], device=class_weight.device + # ) + label_weights = class_weight[label] + avg_factor = label_weights.sum() + + if weight is not None: + weight = weight.float() + loss = weight_reduce_loss( + loss, weight=weight, reduction=reduction, avg_factor=avg_factor + ) + + return loss + + +@MODELS.register_module() +class CrossEntropyLoss(nn.Module): + def __init__( + self, + reduction="mean", + class_weight=None, + loss_weight=1.0, + loss_name="loss_ce", + ignore_index=255, + ): + super().__init__() + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = class_weight + self.cls_criterion = cross_entropy + self._loss_name = loss_name + self.ignore_index = ignore_index + + def forward( + self, + cls_score, + label, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs, + ): + """Forward function.""" + assert reduction_override in (None, "none", "mean", "sum") + reduction = reduction_override if reduction_override else self.reduction + if self.class_weight is not None: + class_weight = cls_score.new_tensor(self.class_weight) + else: + class_weight = None + loss_cls = self.loss_weight * self.cls_criterion( + cls_score, + label, + weight, + class_weight=class_weight, + reduction=reduction, + avg_factor=avg_factor, + ignore_index=self.ignore_index, + **kwargs, + ) + return loss_cls + + @property + def loss_name(self): + return self._loss_name + + +# ------------------------------------------------------------------------------- +@MODELS.register_module() +class DiceLoss(nn.Module): + def __init__( + self, + use_sigmoid: bool = False, # set True for binary single-logit heads + activate: bool = True, # apply sigmoid/softmax inside the loss + reduction: str = "mean", # "none" | "mean" | "sum" + naive_dice: bool = False, # if True: (2TP)/(P+G); else: (2TP)/(||P||^2+||G||^2) + loss_weight: float = 1.0, + include_background: bool = False, + ignore_index: Union[int, None] = 255, + eps: float = 1e-3, + loss_name: str = "loss_dice", + ): + super().__init__() + self.use_sigmoid = use_sigmoid + self.activate = activate + self.reduction = reduction + self.naive_dice = naive_dice + self.loss_weight = loss_weight + self.ignore_index = ignore_index + self.include_background = include_background + self.eps = eps + self._loss_name = loss_name + + def _to_one_hot_and_mask(self, target: torch.Tensor, num_classes: int): + assert target.dtype == torch.long, ( + f"target must be torch.long, got {target.dtype}" + ) + B, H, W = target.shape + if self.ignore_index is None: + valid_mask = torch.ones( + (B, 1, H, W), dtype=torch.bool, device=target.device + ) + t = target + else: + valid_mask = (target != self.ignore_index).unsqueeze(1) + t = torch.where( + target == self.ignore_index, torch.zeros_like(target), target + ) + + if num_classes == 1: + # keep 0/1 labels; do NOT clamp to [0,0] + one_hot = (t == 1).float().unsqueeze(1) # (B,1,H,W) + else: + # now it's safe to clamp to 
[0, C-1] + t = t.clamp(min=0, max=max(0, num_classes - 1)) + one_hot = F.one_hot(t, num_classes=num_classes).permute(0, 3, 1, 2).float() + return one_hot, valid_mask + + def _dice_per_sample(self, pred, target_oh, valid_mask): + valid = valid_mask.to(dtype=pred.dtype) + pred_m = pred * valid + gt_m = target_oh * valid + + # Exclude background for multi-class if requested + if ( + (not self.include_background) + and (gt_m.size(1) > 1) + and (not self.use_sigmoid) + ): + pred_m = pred_m[:, 1:, ...] + gt_m = gt_m[:, 1:, ...] + + dims = (2, 3) + inter = (pred_m * gt_m).sum(dims) + if self.naive_dice: + p_sum = pred_m.sum(dims) + g_sum = gt_m.sum(dims) + dice_c = (2 * inter + self.eps) / (p_sum + g_sum + self.eps) + else: + p_sq = (pred_m * pred_m).sum(dims) + g_sq = (gt_m * gt_m).sum(dims) + dice_c = (2 * inter + self.eps) / (p_sq + g_sq + self.eps) + + present = (gt_m.sum(dims) > 0).to(pred.dtype) # classes present in GT + denom = present.sum(dim=1).clamp(min=1) + return (dice_c * present).sum(dim=1) / denom + + def forward( + self, + pred: torch.Tensor, # (B, C, H, W) logits + target: torch.Tensor, # (B, H, W) long + weight: Union[torch.Tensor, None] = None, # optional per-sample weights (B,) + avg_factor: Union[int, None] = None, + reduction_override: Union[str, None] = None, + ): + assert pred.dim() == 4 and target.dim() == 3, ( + "Shapes must be (B,C,H,W) and (B,H,W)." + ) + B, C, H, W = pred.shape + reduction = reduction_override or self.reduction + + # activate if requested + if self.activate: + if C == 1 or self.use_sigmoid: + pred = pred.sigmoid() + else: + pred = pred.softmax(dim=1) + + # one-hot + valid mask + target_oh, valid = self._to_one_hot_and_mask(target, num_classes=C) + + # per-sample macro dice + dice_ps = self._dice_per_sample(pred, target_oh, valid) # (B,) + loss = 1.0 - dice_ps + + if self.ignore_index is None: + sample_valid = torch.ones( + (loss.shape[0],), dtype=loss.dtype, device=loss.device + ) + else: + sample_valid = valid.view(valid.size(0), -1).any(dim=1).to(dtype=loss.dtype) + + gt_m = (target_oh * valid).sum(dim=(2, 3)) > 0 # (B, C) + if (not self.include_background) and (C > 1) and (not self.use_sigmoid): + gt_m = gt_m[:, 1:] + has_present = gt_m.any(dim=1).to(loss.dtype) # (B,) + effective_mask = sample_valid * has_present + + if weight is None: + eff_weight = effective_mask + else: + if weight.shape != loss.shape: + raise ValueError( + f"weight must be shape {loss.shape}, got {weight.shape}" + ) + eff_weight = effective_mask * weight + + if reduction == "mean": + mean_denom = ( + avg_factor + if (avg_factor is not None) + else eff_weight.sum().clamp(min=1).item() + ) + else: + mean_denom = None + + loss = self.loss_weight * weight_reduce_loss( + loss * eff_weight, weight=None, reduction=reduction, avg_factor=mean_denom + ) + + return loss + + @property + def loss_name(self): + return self._loss_name diff --git a/sapiens/dense/src/models/losses/silog_loss.py b/sapiens/dense/src/models/losses/silog_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..145844fd0eb7fb2b2b3d7d9f3c97c00b372ff52b --- /dev/null +++ b/sapiens/dense/src/models/losses/silog_loss.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
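`SiLogLoss` below is the scale-invariant logarithmic depth error with a 0.5 variance-focus term: for d = log(gt) − log(pred), the per-image loss is sqrt(mean(d²) − 0.5·mean(d)²). A toy check of the formula (standalone sketch, not the repo's code path):

```python
import torch

pred = torch.tensor([1.0, 2.0, 4.0])
gt = torch.tensor([1.2, 1.9, 4.4])
d = gt.clamp(min=1e-4).log() - pred.clamp(min=1e-4).log()
loss = torch.sqrt((d ** 2).mean() - 0.5 * d.mean() ** 2)
# mean(d^2) - 0.5*mean(d)^2 = var(d) + 0.5*mean(d)^2, so a uniform log-offset
# (a global scale error) is penalized at half the weight of per-pixel variation.
print(loss)
```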
+ +from typing import Optional, Union + +import torch +import torch.nn as nn +from sapiens.registry import MODELS +from torch import Tensor + +from .utils import weight_reduce_loss + + +def silog_loss( + pred: Tensor, + target: Tensor, + valid_mask: Optional[Tensor] = None, + weight: Optional[Tensor] = None, + eps: float = 1e-4, + clamp_eps: float = 1e-4, + reduction: Union[str, None] = "mean", + avg_factor: Optional[int] = None, +) -> Tensor: + pred, target = pred.flatten(1), target.flatten(1) + + if valid_mask is None: + valid_mask = (target > eps).detach().float() + + valid_mask = valid_mask.flatten(1) + diff_log = torch.log(target.clamp(min=clamp_eps)) - torch.log( + pred.clamp(min=clamp_eps) + ) + + valid_mask = (valid_mask > 0) & (target > eps).detach() & (~torch.isnan(diff_log)) + diff_log[~valid_mask] = 0.0 + valid_mask = valid_mask.float() + + diff_log_sq_mean = (diff_log.pow(2) * valid_mask).sum(dim=1) / valid_mask.sum( + dim=1 + ).clamp(min=clamp_eps) + diff_log_mean = (diff_log * valid_mask).sum(dim=1) / valid_mask.sum(dim=1).clamp( + min=clamp_eps + ) + + loss = torch.sqrt(diff_log_sq_mean - 0.5 * diff_log_mean.pow(2)) + + if weight is not None: + weight = weight.float() + + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + +@MODELS.register_module() +class SiLogLoss(nn.Module): + def __init__( + self, reduction="mean", loss_weight=1.0, eps=-100, loss_name="loss_silog" + ): + super().__init__() + self.reduction = reduction + self.loss_weight = loss_weight + self.eps = eps + self._loss_name = loss_name + + def forward( + self, + pred, + target, + valid_mask=None, + weight=None, + avg_factor=None, + reduction_override=None, + ): + assert pred.shape == target.shape, ( + "the shapes of pred " + f"({pred.shape}) and target ({target.shape}) are mismatch" + ) + + assert reduction_override in (None, "none", "mean", "sum") + reduction = reduction_override if reduction_override else self.reduction + + if valid_mask is None: + valid_mask = (target > self.eps).detach().float() + valid_mask = valid_mask[:, 0, :, :].unsqueeze(1) ## B x 1 x H x W + + loss = self.loss_weight * silog_loss( + pred, + target, + valid_mask=valid_mask, + weight=weight, + eps=self.eps, + clamp_eps=1e-4, + reduction=reduction, + avg_factor=avg_factor, + ) + + return loss + + @property + def loss_name(self): + return self._loss_name diff --git a/sapiens/dense/src/models/losses/utils.py b/sapiens/dense/src/models/losses/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..85ca3590b6a3e2123e84cc573ea5bddd80c264c8 --- /dev/null +++ b/sapiens/dense/src/models/losses/utils.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +# pyre-strict + +import functools +from typing import Callable + +import torch +from torch import Tensor + + +def reduce_loss(loss: Tensor, reduction: str) -> Tensor: + """Reduce the loss tensor based on reduction type. + + Args: + loss: Loss tensor to reduce. + reduction: Reduction type ('none', 'mean', or 'sum'). + + Returns: + Reduced loss tensor. 
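    Example: reduce_loss(torch.tensor([1.0, 3.0]), "mean") returns tensor(2.).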
+ """ + match reduction: + case "none": + return loss + case "mean": + return loss.mean() + case "sum": + return loss.sum() + case _: + raise ValueError(f"Unknown reduction type: {reduction}") + + +def weight_reduce_loss( + loss: Tensor, + weight: Tensor | None = None, + reduction: str = "mean", + avg_factor: float | None = None, +) -> Tensor: + """Apply weight and reduction to loss tensor. + + Args: + loss: Loss tensor. + weight: Optional element-wise weight. + reduction: Reduction type ('none', 'mean', or 'sum'). + avg_factor: Optional averaging factor. + + Returns: + Weighted and reduced loss tensor. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + assert weight.dim() == loss.dim() + if weight.dim() > 1: + assert weight.size(1) == 1 or weight.size(1) == loss.size(1) + loss = loss * weight + + # if avg_factor is not specified, just reduce the loss + if avg_factor is None: + loss = reduce_loss(loss, reduction) + else: + # if reduction is mean, then average the loss by avg_factor + if reduction == "mean": + # Avoid causing ZeroDivisionError when avg_factor is 0.0, + # i.e., all labels of an image belong to ignore index. + eps = torch.finfo(torch.float32).eps + loss = loss.sum() / (avg_factor + eps) + # if reduction is 'none', then do nothing, otherwise raise an error + elif reduction != "none": + raise ValueError('avg_factor can not be used with reduction="sum"') + return loss + + +def weighted_loss( + loss_func: Callable[..., Tensor], +) -> Callable[..., Tensor]: + """Decorator to add weight and reduction support to a loss function. + + Args: + loss_func: Loss function to wrap. + + Returns: + Wrapped loss function with weight and reduction support. + """ + + @functools.wraps(loss_func) + def wrapper( + pred: Tensor, + target: Tensor, + weight: Tensor | None = None, + reduction: str = "mean", + avg_factor: float | None = None, + **kwargs: object, + ) -> Tensor: + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + return wrapper diff --git a/sapiens/dense/src/visualizers/__init__.py b/sapiens/dense/src/visualizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb44856d8f4497727c9872642d1bbfe971b702dd --- /dev/null +++ b/sapiens/dense/src/visualizers/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .albedo_visualizer import AlbedoVisualizer +from .normal_visualizer import NormalVisualizer +from .pointmap_visualizer import PointmapVisualizer +from .seg_visualizer import SegVisualizer + +__all__ = [ + "PointmapVisualizer", + "SegVisualizer", + "NormalVisualizer", + "AlbedoVisualizer", +] diff --git a/sapiens/dense/src/visualizers/albedo_visualizer.py b/sapiens/dense/src/visualizers/albedo_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..b6d7392eccfb85b051c23ca5f7c5af76a19585bc --- /dev/null +++ b/sapiens/dense/src/visualizers/albedo_visualizer.py @@ -0,0 +1,111 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import os +from pathlib import Path + +import cv2 +import numpy as np +import torch +from sapiens.registry import VISUALIZERS +from torch import nn + + +@VISUALIZERS.register_module() +class AlbedoVisualizer(nn.Module): + def __init__( + self, + output_dir: str, + vis_interval: int = 100, + vis_max_samples: int = 4, + vis_image_width: int = 384, + vis_image_height: int = 512, + ): + super().__init__() + self.output_dir = Path(output_dir) + self.output_dir.mkdir(parents=True, exist_ok=True) + self.vis_max_samples = vis_max_samples + self.vis_interval = vis_interval + self.vis_image_width = vis_image_width + self.vis_image_height = vis_image_height + + def add_batch(self, data_batch: dict, logs: dict, step: int): + pred_albedos = logs["outputs"] + pred_albedos = pred_albedos.detach().cpu() # B x 3 x H x W + gt_albedos = ( + data_batch["data_samples"]["gt_albedo"].detach().cpu() + ) # B x 3 x H x W + masks = data_batch["data_samples"]["mask"].detach().cpu() # B x 1 x H x + inputs = data_batch["inputs"].detach().cpu() # B x 3 x H x W + + if pred_albedos.dtype == torch.bfloat16: + inputs = inputs.float() + pred_albedos = pred_albedos.float() + + pred_albedos = pred_albedos.cpu().detach().numpy() ## B x 3 x H x W + pred_albedos = pred_albedos.transpose((0, 2, 3, 1)) ## B x H x W x 3 + + gt_albedos = gt_albedos.cpu().detach().numpy() ## B x 3 x H x W + gt_albedos = gt_albedos.transpose((0, 2, 3, 1)) ## B x H x W x 3 + + batch_size = min(len(inputs), self.vis_max_samples) + inputs = inputs[:batch_size] + pred_albedos = pred_albedos[:batch_size] ## B x 3 x H x W + gt_albedos = gt_albedos[:batch_size] ## B x 3 x H x W + masks = masks[:batch_size] ## B x 1 x H x W + + prefix = os.path.join(self.output_dir, "train") + suffix = str(step).zfill(6) + suffix += "_" + data_batch["data_samples"]["meta"]["img_path"][0].split("/")[ + -1 + ].replace(".png", "") + vis_images = [] + + for i, (input, gt_albedo, mask, pred_albedo) in enumerate( + zip(inputs, gt_albedos, masks, pred_albedos) + ): + image = input.permute(1, 2, 0).cpu().numpy() ## bgr image + image = np.ascontiguousarray(image.copy()) + mask = mask[0].numpy() > 0 ## H x W + + if ( + pred_albedo.shape[0] != image.shape[0] + or pred_albedo.shape[1] != image.shape[1] + ): + image = cv2.resize( + image, + (pred_albedo.shape[1], pred_albedo.shape[0]), + interpolation=cv2.INTER_LINEAR, + ) + + vis_gt_albedo = (gt_albedo * 255).astype(np.uint8) ## rgb + vis_pred_albedo = (pred_albedo * 255).astype(np.uint8) ## rgb + + vis_gt_albedo = cv2.cvtColor(vis_gt_albedo, cv2.COLOR_RGB2BGR) + vis_pred_albedo = cv2.cvtColor(vis_pred_albedo, cv2.COLOR_RGB2BGR) + + vis_image = np.concatenate( + [ + image, + vis_gt_albedo, + vis_pred_albedo, + ], + axis=1, + ) + vis_image = cv2.resize( + vis_image, + (3 * self.vis_image_width, self.vis_image_height), + interpolation=cv2.INTER_AREA, + ) + vis_images.append(vis_image) + + grid_image = np.concatenate(vis_images, axis=0) + + # Save the grid image to a file + grid_out_file = "{}_{}.jpg".format(prefix, suffix) + cv2.imwrite(grid_out_file, grid_image) + + return diff --git a/sapiens/dense/src/visualizers/normal_visualizer.py b/sapiens/dense/src/visualizers/normal_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..eecc876e40e1299e6db69922f851526592592350 --- /dev/null +++ b/sapiens/dense/src/visualizers/normal_visualizer.py @@ -0,0 +1,115 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +from pathlib import Path + +import cv2 +import numpy as np +import torch +from sapiens.registry import VISUALIZERS +from torch import nn + + +@VISUALIZERS.register_module() +class NormalVisualizer(nn.Module): + def __init__( + self, + output_dir: str, + vis_interval: int = 100, + vis_max_samples: int = 4, + vis_image_width: int = 384, + vis_image_height: int = 512, + ): + super().__init__() + self.output_dir = Path(output_dir) + self.output_dir.mkdir(parents=True, exist_ok=True) + self.vis_max_samples = vis_max_samples + self.vis_interval = vis_interval + self.vis_image_width = vis_image_width + self.vis_image_height = vis_image_height + + def vis_normal(self, normal_map, mask=None): + normal_map[mask == 0] = np.nan + normal_map_vis = (((normal_map + 1) / 2) * 255).astype(np.uint8) + ## bgr to rgb + normal_map_vis = normal_map_vis[:, :, ::-1] + return normal_map_vis + + def add_batch(self, data_batch: dict, logs: dict, step: int): + pred_normals = logs["outputs"] + pred_normals = pred_normals.detach().cpu() # B x 3 x H x W + gt_normals = ( + data_batch["data_samples"]["gt_normal"].detach().cpu() + ) # B x 3 x H x W + masks = data_batch["data_samples"]["mask"].detach().cpu() # B x 1 x H x + inputs = data_batch["inputs"].detach().cpu() # B x 3 x H x W + + if pred_normals.dtype == torch.bfloat16: + inputs = inputs.float() + pred_normals = pred_normals.float() + + pred_normals = pred_normals.cpu().detach().numpy() ## B x 3 x H x W + pred_normals = pred_normals.transpose((0, 2, 3, 1)) ## B x H x W x 3 + batch_size = min(len(inputs), self.vis_max_samples) + + inputs = inputs[:batch_size] + pred_normals = pred_normals[:batch_size] ## B x 3 x H x W + gt_normals = gt_normals[:batch_size] ## B x 3 x H x W + masks = masks[:batch_size] ## B x 1 x H x W + + prefix = os.path.join(self.output_dir, "train") + suffix = str(step).zfill(6) + suffix += "_" + data_batch["data_samples"]["meta"]["img_path"][0].split("/")[ + -1 + ].replace(".png", "") + vis_images = [] + + for i, (input, gt_normal, mask, pred_normal) in enumerate( + zip(inputs, gt_normals, masks, pred_normals) + ): + image = input.permute(1, 2, 0).cpu().numpy() ## bgr image + image = np.ascontiguousarray(image.copy()) + + gt_normal = gt_normal.numpy() ## 3 x H x W + gt_normal = gt_normal.transpose((1, 2, 0)) ## H x W x 3 + mask = mask[0].numpy() > 0 ## H x W + + if ( + pred_normal.shape[0] != image.shape[0] + or pred_normal.shape[1] != image.shape[1] + ): + image = cv2.resize( + image, + (pred_normal.shape[1], pred_normal.shape[0]), + interpolation=cv2.INTER_LINEAR, + ) + + vis_gt_normal = self.vis_normal(gt_normal, mask) + vis_pred_normal = self.vis_normal(pred_normal, mask) + + vis_image = np.concatenate( + [ + image, + vis_gt_normal, + vis_pred_normal, + ], + axis=1, + ) + vis_image = cv2.resize( + vis_image, + (3 * self.vis_image_width, self.vis_image_height), + interpolation=cv2.INTER_AREA, + ) + vis_images.append(vis_image) + + grid_image = np.concatenate(vis_images, axis=0) + + # Save the grid image to a file + grid_out_file = "{}_{}.jpg".format(prefix, suffix) + cv2.imwrite(grid_out_file, grid_image) + + return diff --git a/sapiens/dense/src/visualizers/pointmap_visualizer.py b/sapiens/dense/src/visualizers/pointmap_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5b1a504874e43128b7c3100beb94c5899a29fc3e --- /dev/null +++ 
b/sapiens/dense/src/visualizers/pointmap_visualizer.py @@ -0,0 +1,225 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +from pathlib import Path + +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import torch +from sapiens.registry import VISUALIZERS +from torch import nn + + +@VISUALIZERS.register_module() +class PointmapVisualizer(nn.Module): + def __init__( + self, + output_dir: str, + vis_interval: int = 100, + vis_max_samples: int = 4, + vis_image_width: int = 384, + vis_image_height: int = 512, + ): + super().__init__() + self.output_dir = Path(output_dir) + self.output_dir.mkdir(parents=True, exist_ok=True) + self.vis_max_samples = vis_max_samples + self.vis_interval = vis_interval + self.vis_image_width = vis_image_width + self.vis_image_height = vis_image_height + self.cmap = plt.get_cmap("turbo") + self.error_cmap = plt.get_cmap("hot") + + def vis_point_map(self, point_map, mask=None): + depth_map = point_map[:, :, 2] ### x,y,z. z is the depth + img = self.vis_depth_map(depth_map, mask=mask) + return img + + def vis_depth_map(self, depth, mask=None, background_color=100): + if mask is None: + inverse_depth = 1 / depth + inverse_depth_normalized = (inverse_depth - inverse_depth.min()) / ( + inverse_depth.max() - inverse_depth.min() + ) + color_depth = (self.cmap(inverse_depth_normalized)[..., :3] * 255).astype( + np.uint8 + ) + + ## convert RGB to BGR to save with cv2 + color_depth = color_depth[..., ::-1] + return color_depth + + depth_foreground = depth[mask > 0] + processed_depth = np.full( + (mask.shape[0], mask.shape[1], 3), background_color, dtype=np.uint8 + ) + + if len(depth_foreground) == 0: + return processed_depth + + inverse_depth_foreground = 1 / depth_foreground + + # Visualize inverse depth instead of depth, clipped to [0.1m;250m] range for better visualization. + max_inverse_depth = min(inverse_depth_foreground.max(), 1 / 0.1) + min_inverse_depth = max(1 / 250, inverse_depth_foreground.min()) + inverse_depth_foreground_normalized = ( + inverse_depth_foreground - min_inverse_depth + ) / (max_inverse_depth - min_inverse_depth) + + color_depth = ( + self.cmap(inverse_depth_foreground_normalized)[..., :3] * 255 + ).astype(np.uint8) + processed_depth[mask] = color_depth + + ## convert RGB to BGR to save with cv2 + processed_depth = processed_depth[..., ::-1] + + return processed_depth + + def vis_normal_from_point_map(self, point_map, mask=None, kernel_size=7): + depth_map = point_map[:, :, 2] ### x,y,z. 
z is the depth + + if mask.sum() == 0: + return np.full((mask.shape[0], mask.shape[1], 3), 100, dtype=np.uint8) + + depth_foreground = depth_map[mask > 0] + min_val, max_val = np.min(depth_foreground), np.max(depth_foreground) + + depth_normalized = np.full(mask.shape, np.inf) + depth_normalized[mask > 0] = 1 - ( + (depth_map[mask > 0] - min_val) / (max_val - min_val) + ) + + grad_x = cv2.Sobel( + depth_normalized.astype(np.float32), cv2.CV_32F, 1, 0, ksize=kernel_size + ) + grad_y = cv2.Sobel( + depth_normalized.astype(np.float32), cv2.CV_32F, 0, 1, ksize=kernel_size + ) + normals = np.dstack((-grad_x, -grad_y, np.full(grad_x.shape, -1))) + + normals_mag = np.linalg.norm(normals, axis=2, keepdims=True) + normals_normalized = normals / (normals_mag + 1e-5) + normal_vis = ((normals_normalized + 1) / 2 * 255).astype(np.uint8) + return normal_vis[:, :, ::-1] + + def vis_l1_error(self, gt_pointmap, pred_pointmap, mask=None, background_color=100): + """Visualize L1 error between ground truth and predicted pointmaps.""" + if mask is None: + mask = np.ones_like(gt_pointmap[:, :, 0], dtype=bool) + + error_map = np.full( + (mask.shape[0], mask.shape[1], 3), background_color, dtype=np.uint8 + ) + + # Calculate L1 error for valid points + l1_error = np.abs(gt_pointmap - pred_pointmap) # H x W x 3 + l1_error = np.mean(l1_error, axis=2) # Average across XYZ dimensions, H x W + + if np.sum(mask) > 0: + error_foreground = l1_error[mask] + + # Normalize error for visualization + error_normalized = (error_foreground - error_foreground.min()) / ( + error_foreground.max() - error_foreground.min() + 1e-6 + ) + + # Convert to color using hot colormap + error_colored = (self.error_cmap(error_normalized)[..., :3] * 255).astype( + np.uint8 + ) + error_map[mask] = error_colored + + # Convert to BGR for OpenCV + error_map = error_map[..., ::-1] + + return error_map + + def add_batch(self, data_batch: dict, logs: dict, step: int): + (pred_pointmaps, _) = logs["outputs"] + pred_pointmaps = pred_pointmaps.detach().cpu() # B x 3 x H x W + gt_pointmaps = ( + data_batch["data_samples"]["gt_pointmap"].detach().cpu() + ) # B x 3 x H x + masks = data_batch["data_samples"]["mask"].detach().cpu() # B x 1 x H x + inputs = data_batch["inputs"].detach().cpu() # B x 3 x H x W + + if pred_pointmaps.dtype == torch.bfloat16: + inputs = inputs.float() + pred_pointmaps = pred_pointmaps.float() + + pred_pointmaps = pred_pointmaps.cpu().detach().numpy() ## B x 3 x H x W + pred_pointmaps = pred_pointmaps.transpose((0, 2, 3, 1)) ## B x H x W x 3 + batch_size = min(len(inputs), self.vis_max_samples) + + inputs = inputs[:batch_size] + pred_pointmaps = pred_pointmaps[:batch_size] ## B x 3 x H x W + gt_pointmaps = gt_pointmaps[:batch_size] ## B x 3 x H x W + masks = masks[:batch_size] ## B x 1 x H x W + + prefix = os.path.join(self.output_dir, "train") + suffix = str(step).zfill(6) + suffix += "_" + data_batch["data_samples"]["meta"]["img_path"][0].split("/")[ + -1 + ].replace(".png", "") + vis_images = [] + + for i, (input, gt_pointmap, mask, pred_pointmap) in enumerate( + zip(inputs, gt_pointmaps, masks, pred_pointmaps) + ): + image = input.permute(1, 2, 0).cpu().numpy() ## bgr image + image = np.ascontiguousarray(image.copy()) + + gt_pointmap = gt_pointmap.numpy() ## 3 x H x W + gt_pointmap = gt_pointmap.transpose((1, 2, 0)) ## H x W x 3 + mask = mask[0].numpy() > 0 ## H x W + + ## resize predpoint to image size + if ( + pred_pointmap.shape[0] != image.shape[0] + or pred_pointmap.shape[1] != image.shape[1] + ): + image = cv2.resize( + 
image, + (pred_pointmap.shape[1], pred_pointmap.shape[0]), + interpolation=cv2.INTER_LINEAR, + ) + + vis_gt_pointmap = self.vis_point_map(gt_pointmap, mask) + vis_pred_pointmap = self.vis_point_map(pred_pointmap, mask) + + vis_gt_normal = self.vis_normal_from_point_map(gt_pointmap, mask) + vis_pred_normal = self.vis_normal_from_point_map(pred_pointmap, mask) + + vis_error = self.vis_l1_error(gt_pointmap, pred_pointmap, mask) + + vis_image = np.concatenate( + [ + image, + vis_gt_pointmap, + vis_gt_normal, + vis_pred_pointmap, + vis_pred_normal, + vis_error, + ], + axis=1, + ) + vis_image = cv2.resize( + vis_image, + (6 * self.vis_image_width, self.vis_image_height), + interpolation=cv2.INTER_AREA, + ) + vis_images.append(vis_image) + + grid_image = np.concatenate(vis_images, axis=0) + + # Save the grid image to a file + grid_out_file = "{}_{}.jpg".format(prefix, suffix) + cv2.imwrite(grid_out_file, grid_image) + + return diff --git a/sapiens/dense/src/visualizers/seg_visualizer.py b/sapiens/dense/src/visualizers/seg_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..f6b8f884dace0782a9e88cb2e966efb86801a105 --- /dev/null +++ b/sapiens/dense/src/visualizers/seg_visualizer.py @@ -0,0 +1,194 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +from pathlib import Path + +import cv2 +import numpy as np +import torch +from sapiens.registry import VISUALIZERS +from torch import nn + +from ..datasets import DOME_CLASSES_29 + + +@VISUALIZERS.register_module() +class SegVisualizer(nn.Module): + def __init__( + self, + output_dir: str = None, + vis_interval: int = 100, + vis_max_samples: int = 4, + vis_image_width: int = 384, + vis_image_height: int = 512, + class_palette_type="dome29", + overlay_opacity: float = 0.5, # 0..1; 1 = only mask colors + with_labels: bool = True, + ): + super().__init__() + if output_dir is not None: + self.output_dir = Path(output_dir) + self.output_dir.mkdir(parents=True, exist_ok=True) + + self.vis_max_samples = vis_max_samples + self.vis_interval = vis_interval + self.vis_image_width = vis_image_width + self.vis_image_height = vis_image_height + self.class_palette_type = class_palette_type + self.overlay_opacity = float(np.clip(overlay_opacity, 0.0, 1.0)) + self.with_labels = with_labels + self.class_palette = None + self.class_names = {} + + if self.class_palette_type == "dome29": + self.class_palette = self._build_palette_from_dict(DOME_CLASSES_29) + self.class_names = { + cid: meta.get("name", f"class_{cid}") + for cid, meta in DOME_CLASSES_29.items() + } + else: + self.class_palette = (np.random.rand(256, 3) * 255).astype(np.uint8) + + def _build_palette_from_dict(self, class_dict) -> np.ndarray: + max_id = max(int(k) for k in class_dict.keys()) + pal = np.zeros((max(max_id + 1, 256), 3), dtype=np.uint8) + for cid, meta in class_dict.items(): + col = meta.get("color", [0, 0, 0]) + pal[int(cid)] = np.array(col, dtype=np.uint8) + return pal # RGB format + + def _get_center_loc(self, mask: np.ndarray): + """ + Finds the center of the largest contour in a binary mask. + This is a robust method using moments. 
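+        The centroid comes from the spatial moments of the largest contour:
+        cx = M10 / M00 and cy = M01 / M00, where M00 is the contour area.
+        A degenerate contour (M00 == 0) has no well-defined centroid, so
+        None is returned in that case.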
+ """ + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + if not contours: + return None + largest_contour = max(contours, key=cv2.contourArea) + M = cv2.moments(largest_contour) + if M["m00"] == 0: + return None + cx = int(M["m10"] / M["m00"]) + cy = int(M["m01"] / M["m00"]) + return (cx, cy) + + def _draw_labels(self, image: np.ndarray, label_map: np.ndarray) -> np.ndarray: + """Draws class labels on the image at the center of each segment.""" + unique_labels = np.unique(label_map) + for class_id in unique_labels: + if class_id == 0 or class_id not in self.class_names: + continue # Skip background or unknown classes + + class_name = self.class_names[class_id] + mask = (label_map == class_id).astype(np.uint8) + loc = self._get_center_loc(mask) + if loc is None: + continue + + font = cv2.FONT_HERSHEY_SIMPLEX + scale = 0.05 + fontScale = min(image.shape[0], image.shape[1]) / (75 / scale) + fontColor = (255, 255, 255) # White text + thickness = 1 + rectangleThickness = 1 + + (label_width, label_height), baseline = cv2.getTextSize( + class_name, font, fontScale, thickness + ) + + x, y = loc + x = max(x - label_width // 2, 0) + y_text = y + label_height // 2 + rect_start_pt = (x, y - label_height // 2 - baseline) + rect_end_pt = (x + label_width, y + label_height // 2 + baseline) + class_color_rgb = self.class_palette[class_id] + class_color_bgr = tuple(int(c) for c in class_color_rgb[::-1]) + cv2.rectangle(image, rect_start_pt, rect_end_pt, class_color_bgr, -1) + cv2.rectangle( + image, rect_start_pt, rect_end_pt, (0, 0, 0), rectangleThickness + ) + cv2.putText( + image, class_name, (x, y_text), font, fontScale, fontColor, thickness + ) + return image + + def _visualize_segmentation( + self, image_bgr: np.ndarray, label_map: np.ndarray + ) -> np.ndarray: + if image_bgr.dtype != np.uint8: + raise ValueError("Input image must be uint8 for visualization.") + palette_bgr = self.class_palette[:, ::-1] + color_mask = palette_bgr[label_map] + + if self.with_labels: + color_mask = self._draw_labels(color_mask, label_map) + + overlay = cv2.addWeighted( + image_bgr, + 1 - self.overlay_opacity, + color_mask, + self.overlay_opacity, + 0, + ) + + return overlay + + def add_batch(self, data_batch: dict, logs: dict, step: int): + inputs = data_batch["inputs"].detach().cpu() # B x 3 x H x W + pred_logits = logs["outputs"].detach().cpu() # B x num_classes x H x W + gt_labels = (data_batch["data_samples"]["gt_seg"].detach().cpu()).squeeze( + dim=1 + ) ## B x H x W; + + if pred_logits.dtype == torch.bfloat16: + inputs = inputs.float() + pred_logits = pred_logits.float() + + pred_labels = pred_logits.argmax(dim=1) ## B x H x W + batch_size = min(len(inputs), self.vis_max_samples) + + inputs = inputs[:batch_size] + gt_labels = gt_labels[:batch_size] ## B x 1 x H x W + pred_labels = pred_labels[:batch_size] ## B x H x W + + prefix = os.path.join(self.output_dir, "train") + suffix = str(step).zfill(6) + vis_images = [] + + for i, (input, gt_label, pred_label) in enumerate( + zip(inputs, gt_labels, pred_labels) + ): + image = input.permute(1, 2, 0).cpu().numpy() ## bgr image + image = np.ascontiguousarray(image.copy()).astype(np.uint8) + + gt_label = gt_label.numpy().astype(np.uint8) ## H x W + pred_label = pred_label.numpy().astype(np.uint8) ## H x W + + vis_gt_seg = self._visualize_segmentation(image, gt_label) + vis_pred_seg = self._visualize_segmentation(image, pred_label) + + vis_image = np.concatenate( + [ + image, + vis_gt_seg, + vis_pred_seg, + ], + axis=1, + ) + vis_image = 
cv2.resize( + vis_image, + (3 * self.vis_image_width, self.vis_image_height), + interpolation=cv2.INTER_AREA, + ) + vis_images.append(vis_image) + + grid_image = np.concatenate(vis_images, axis=0) + grid_out_file = "{}_{}.jpg".format(prefix, suffix) + cv2.imwrite(grid_out_file, grid_image) + + return diff --git a/sapiens/dense/tools/deployment/pytorch2torchscript.py b/sapiens/dense/tools/deployment/pytorch2torchscript.py new file mode 100644 index 0000000000000000000000000000000000000000..48db755ef71902efc2465b06e24a8c7883ba35aa --- /dev/null +++ b/sapiens/dense/tools/deployment/pytorch2torchscript.py @@ -0,0 +1,188 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# originally copied from https://www.internalfb.com/code/fbsource/[671aa4920700]/fbcode/xrcia/projects/sapiens/experimental_ghe_import/sapiens2/sapiens/seg/tools/deployment/pytorch2torchscript.py?lines=1-204 + +import argparse +import os + +import torch +import torch._C +import torch.serialization +from sapiens.dense.models import init_model + +torch.manual_seed(3) +TORCH_MINIMUM_VERSION = "1.8.0" + + +def digit_version(version_str: str) -> list[int]: + """Convert a version string into a list of integers for comparison. + + This function parses version strings with complex formats and converts them into + comparable numeric arrays. It handles standard version numbers (like '1.2.3') + as well as release candidates (containing 'rc'). + + For standard version components, each number is directly converted to an integer. + For release candidates (e.g., '2rc1'), the function treats them as slightly + earlier than the final release by: + - Converting the number before 'rc' to (number - 1) + - Appending the rc number as an additional version component + + Examples: + '1.2.3' -> [1, 2, 3] + '0.1.2rc1' -> [0, 1, 1, 1] # 2rc1 becomes [1, 1] + '2.0rc1' -> [2, -1, 1] # 0rc1 becomes [-1, 1] + + Args: + version_str (str): The version string to convert. + + Returns: + list[int]: A list of integers representing the version for comparison. + """ + digit_version = [] + for x in version_str.split("."): # Split the version string by '.' + if x.isdigit(): # Check if the part is a digit + digit_version.append(int(x)) # Append the digit as an integer + elif x.find("rc") != -1: # Check if the part contains 'rc' + patch_version = x.split("rc") # Split the part by 'rc' + digit_version.append( + int(patch_version[0]) - 1 + ) # Append the number before 'rc' minus 1 + digit_version.append(int(patch_version[1])) # Append the number after 'rc' + return digit_version + + +def check_torch_version() -> None: + """Validate that the installed PyTorch version meets the minimum requirement. + + Raises: + RuntimeError: If the installed PyTorch version is below TORCH_MINIMUM_VERSION. + """ + torch_version = digit_version(torch.__version__) + if torch_version < digit_version(TORCH_MINIMUM_VERSION): + raise RuntimeError( + f"Torch=={torch.__version__} is not supported for converting to " + f"torchscript. Please install pytorch>={TORCH_MINIMUM_VERSION}." + ) + + +def pytorch2torchscript( + model: torch.nn.Module, + input_shape: tuple[int, int, int, int], + device: str, + show_graph: bool = False, + output_file: str = "tmp.pt", + verify: bool = False, +) -> None: + """Export Pytorch model to TorchScript model and verify the outputs are + same between Pytorch and TorchScript. 
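+
+    A minimal sketch of direct use (the config and checkpoint paths below are
+    illustrative placeholders, not files shipped with this repo):
+
+        >>> model = init_model("normal_config.py", "normal_ckpt.pth", device="cuda:0")
+        >>> pytorch2torchscript(model, (1, 3, 1024, 768), "cuda:0", output_file="normal.pt")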
+ + Args: + model (nn.Module): Pytorch model we want to export. + input_shape (tuple): Use this input shape to construct + the corresponding dummy input and execute the model. + show_graph (bool): Whether print the computation graph. Default: False. + output_file (string): The path to where we store the + output TorchScript model. Default: `tmp.pt`. + verify (bool): Whether compare the outputs between + Pytorch and TorchScript. Default: False. + """ + + # Clear CUDA cache before starting conversion + if device == "cuda" or device.startswith("cuda:"): + torch.cuda.empty_cache() + print(f"Cleared CUDA cache before conversion") + + # replace the original forward with forward_dummy + # model.forward = model.forward_dummy + model.eval() + + # Use no_grad context to avoid storing gradients during tracing + # Create inputs inside the context to minimize memory footprint + with torch.no_grad(): + inputs = torch.rand(input_shape).to(device) + traced_model = torch.jit.trace( + model, + example_inputs=inputs, + check_trace=verify, + ) + # Explicitly delete inputs and clear cache to free memory + del inputs + if device == "cuda" or device.startswith("cuda:"): + torch.cuda.empty_cache() + + if show_graph: + print(traced_model.graph) + + # Clear CUDA cache before saving to free up memory + if device == "cuda" or device.startswith("cuda:"): + torch.cuda.empty_cache() + print(f"Cleared CUDA cache before saving") + + traced_model.save(output_file) + print(f"Successfully exported TorchScript model: {output_file}") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Convert .pth checkpoint to TorchScript" + ) + parser.add_argument("config", help="test config file path") + parser.add_argument("--checkpoint", help="Checkpoint file") + parser.add_argument( + "--show-graph", action="store_true", help="show TorchScript graph" + ) + parser.add_argument( + "--verify", action="store_true", help="verify the TorchScript model" + ) + parser.add_argument("--output-file", type=str, default="tmp.pt") + parser.add_argument( + "--shape", + type=int, + nargs="+", + default=[1024, 768], + help="input image size (height, width)", + ) + parser.add_argument("--device", default="cuda:0", help="Device used for inference") + args = parser.parse_args() + return args + + +def main() -> None: + args = parse_args() + check_torch_version() + + if len(args.shape) == 1: + input_shape = (1, 3, args.shape[0], args.shape[0]) + elif len(args.shape) == 2: + input_shape = ( + 1, + 3, + ) + tuple(args.shape) + else: + raise ValueError("invalid input shape") + + # build the model, load checkpoint + model = init_model(args.config, args.checkpoint, device=args.device) + + ## create the output directory if it does not exist + output_dir = os.path.dirname(args.output_file) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # convert the PyTorch model to TorchScript model + pytorch2torchscript( + model, + input_shape=input_shape, + device=args.device, + show_graph=args.show_graph, + output_file=args.output_file, + verify=args.verify, + ) + + +if __name__ == "__main__": + main() diff --git a/sapiens/dense/tools/deployment/torch_optimization.py b/sapiens/dense/tools/deployment/torch_optimization.py new file mode 100644 index 0000000000000000000000000000000000000000..1b234953da716f47e04d5448c81bf138ea2998bd --- /dev/null +++ b/sapiens/dense/tools/deployment/torch_optimization.py @@ -0,0 +1,422 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""PyTorch model optimization utilities for exporting and compiling models. + +This module provides utilities for: +- Converting SyncBatchNorm layers to standard BatchNorm +- Benchmarking model performance +- Exporting models with torch.export +- Compiling models with torch.compile +""" + +import argparse +from typing import Any + +import numpy as np +import torch +from sapiens.dense.models import init_model +from torch import nn + + +# ============================================================================= +# BatchNorm Conversion Utilities +# ============================================================================= + + +class _BatchNormXd(nn.modules.batchnorm._BatchNorm): + """A general BatchNorm layer without input dimension check. + + Reproduced from @kapily's work: + (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547) + + The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc + is `_check_input_dim` that is designed for tensor sanity checks. + The check has been bypassed in this class for the convenience of converting + SyncBatchNorm. + """ + + def _check_input_dim(self, input: torch.Tensor) -> None: + return + + +def revert_sync_batchnorm(module: nn.Module) -> nn.Module: + """Convert all SyncBatchNorm layers in the model to BatchNormXd layers. + + Adapted from @kapily's work: + (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547) + + Args: + module: The module containing `SyncBatchNorm` layers. + + Returns: + The converted module with `BatchNormXd` layers. + """ + module_output = module + module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm] + + if isinstance(module, tuple(module_checklist)): + module_output = _BatchNormXd( + module.num_features, + module.eps, + module.momentum, + module.affine, + module.track_running_stats, + ) + if module.affine: + with torch.no_grad(): + module_output.weight = module.weight + module_output.bias = module.bias + module_output.running_mean = module.running_mean + module_output.running_var = module.running_var + module_output.num_batches_tracked = module.num_batches_tracked + module_output.training = module.training + if hasattr(module, "qconfig"): + module_output.qconfig = module.qconfig + + for name, child in module.named_children(): + try: + module_output.add_module(name, revert_sync_batchnorm(child)) + except Exception: + print(f"Failed to convert {child} from SyncBN to BN!") + + del module + return module_output + + +def convert_batchnorm(module: nn.Module) -> nn.Module: + """Convert SyncBatchNorm to BatchNorm2d and optionally SiLU to ReLU. + + Args: + module: The module to convert. + + Returns: + The converted module. 
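+
+    Example (illustrative; any module tree containing SyncBatchNorm works):
+
+        >>> from torch import nn
+        >>> model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.SyncBatchNorm(8))
+        >>> model = revert_sync_batchnorm(model)  # SyncBatchNorm -> _BatchNormXd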
+ """ + module_output = module + + if isinstance(module, torch.nn.SyncBatchNorm): + module_output = torch.nn.BatchNorm2d( + module.num_features, + module.eps, + module.momentum, + module.affine, + module.track_running_stats, + ) + if module.affine: + module_output.weight.data = module.weight.data.clone().detach() + module_output.bias.data = module.bias.data.clone().detach() + module_output.weight.requires_grad = module.weight.requires_grad + module_output.bias.requires_grad = module.bias.requires_grad + module_output.running_mean = module.running_mean + module_output.running_var = module.running_var + module_output.num_batches_tracked = module.num_batches_tracked + + if isinstance(module, torch.nn.SiLU): + module_output = torch.nn.ReLU(inplace=True) + + for name, child in module.named_children(): + module_output.add_module(name, convert_batchnorm(child)) + + del module + return module_output + + +# ============================================================================= +# Benchmarking Utilities +# ============================================================================= + + +def benchmark_model( + model: nn.Module, + inputs: dict[str, Any], + model_name: str = "", + num_warmup: int = 3, + num_iterations: int = 10, +) -> float: + """Benchmark model inference time. + + Args: + model: The model to benchmark. + inputs: Dictionary containing 'imgs' tensor. + model_name: Name for logging purposes. + num_warmup: Number of warmup iterations (not counted). + num_iterations: Number of timed iterations. + + Returns: + Mean inference time per sample in milliseconds. + """ + imgs = ( + inputs["imgs"][0, ...].unsqueeze(0) + if model_name.lower() == "original" + else inputs["imgs"] + ) + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for benchmarking") + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + times = [] + + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + + with torch.cuda.stream(stream), torch.no_grad(): + # Warmup + for _ in range(num_warmup): + model(imgs) + torch.cuda.synchronize() + + # Timed iterations + for _ in range(num_iterations): + torch.cuda.synchronize() + start_event.record() + model(imgs) + end_event.record() + torch.cuda.synchronize() + times.append(start_event.elapsed_time(end_event)) + + torch.cuda.current_stream().wait_stream(stream) + + mean_time = np.mean(times) / imgs.shape[0] + print(f"Benchmark results for '{model_name}':") + print(f" Average time per sample: {mean_time:.2f} ms") + print(f" Total time ({num_iterations} iterations): {sum(times):.2f} ms") + print(f" Individual times: {[f'{t:.2f}' for t in times]}") + + return mean_time + + +# ============================================================================= +# Input Generation +# ============================================================================= + + +def create_demo_inputs(input_shape: tuple[int, int, int, int]) -> dict[str, Any]: + """Create demo inputs for model testing and export. + + Args: + input_shape: Tuple of (N, C, H, W) for input dimensions. + + Returns: + Dictionary with 'imgs' tensor and 'img_metas' list. 
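+
+    Note: unlike ``revert_sync_batchnorm``, this also swaps ``SiLU`` activations
+    for in-place ``ReLU``, which changes the module's numerics.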
+ """ + n, c, h, w = input_shape + rng = np.random.RandomState(0) + imgs = rng.rand(*input_shape) + + img_metas = [ + { + "img_shape": (h, w, c), + "ori_shape": (h, w, c), + "pad_shape": (h, w, c), + "filename": ".png", + "scale_factor": 1.0, + "flip": False, + } + for _ in range(n) + ] + + return { + "imgs": torch.FloatTensor(imgs), + "img_metas": img_metas, + } + + +# ============================================================================= +# Model Export and Compilation +# ============================================================================= + + +class _ToDeviceTransformer(torch.fx.Transformer): + """FX Transformer to move operations to a specific device.""" + + def __init__(self, module: nn.Module, device: str): + super().__init__(module) + self.target_device = torch.device(device) + + def call_function(self, target, args, kwargs): + if "device" not in kwargs: + return super().call_function(target, args, kwargs) + + kwargs = dict(kwargs) + kwargs["device"] = self.target_device + return super().call_function(target, args, kwargs) + + +def compile_and_export_model( + model: nn.Module, + inputs: dict[str, Any], + output_file: str = "compiled_model.pt", + max_batch_size: int = 32, + dtype: torch.dtype = torch.bfloat16, +) -> None: + """Export model using torch.export and optionally compile with torch.compile. + + Args: + model: The model to export. + inputs: Demo inputs for tracing. + output_file: Path to save the exported model. + max_batch_size: Maximum batch size for dynamic shapes. + dtype: Data type for the model. + """ + inputs["imgs"] = inputs["imgs"].to(dtype) + imgs = inputs["imgs"] + model.eval() + + # Define dynamic shapes + dynamic_batch = torch.export.Dim("batch", min=1, max=max_batch_size) + dynamic_h = torch.export.Dim("h", min=1024, max=2048) + dynamic_w = torch.export.Dim("w", min=768, max=1536) + dynamic_shapes = {"inputs": {0: dynamic_batch, 2: dynamic_h, 3: dynamic_w}} + + # Export model + exported_model = torch.export.export( + model, + args=(imgs,), + kwargs={}, + dynamic_shapes=dynamic_shapes, + ) + torch.export.save(exported_model, output_file) + print(f"Model exported to: {output_file}") + + if not torch.cuda.is_available(): + return + # Compile and benchmark + device = "cuda:0" + model = torch.export.load(output_file).module().to(device) + model = _ToDeviceTransformer(model, device).transform() + imgs = imgs.to(device) + inputs["imgs"] = inputs["imgs"].to(device) + + _compile_and_benchmark(model, imgs, inputs) + + +def _compile_and_benchmark( + model: nn.Module, + imgs: torch.Tensor, + inputs: dict[str, Any], +) -> None: + """Compile model and benchmark different compilation modes. + + Args: + model: Model to compile. + imgs: Input images tensor. + inputs: Full inputs dictionary for benchmarking. 
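+
+    Example (illustrative shape):
+
+        >>> demo = create_demo_inputs((2, 3, 1024, 768))
+        >>> tuple(demo["imgs"].shape), len(demo["img_metas"])
+        ((2, 3, 1024, 768), 2)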
+ """ + modes = {"default": "default"} + best_mode = None + min_mean = float("inf") + + for mode_name, mode in modes.items(): + print(f"Compiling model with '{mode_name}' mode...") + + compiled_model = torch.compile(model, mode=mode) + + # Warmup + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream), torch.no_grad(): + for _ in range(3): + compiled_model(imgs) + torch.cuda.synchronize() + torch.cuda.current_stream().wait_stream(stream) + + mean_time = benchmark_model(compiled_model, inputs, model_name=mode_name) + if mean_time < min_mean: + min_mean = mean_time + best_mode = mode_name + + print(f"Best compilation mode: {best_mode}") + + +# ============================================================================= +# CLI Interface +# ============================================================================= + + +def parse_args() -> argparse.Namespace: + """Parse command line arguments. + + Returns: + Parsed arguments. + """ + parser = argparse.ArgumentParser( + description="Export and optimize a model for deployment" + ) + parser.add_argument("config", help="Model config file path") + parser.add_argument("--checkpoint", help="Checkpoint file path") + parser.add_argument( + "--shape", + type=int, + nargs="+", + default=[1024, 768], + help="Input image size as (height, width)", + ) + parser.add_argument( + "--output-file", + "--output-dir", + type=str, + required=True, + help="Output file path for exported model", + ) + parser.add_argument( + "--max-batch-size", + type=int, + default=32, + help="Maximum batch size for dynamic export", + ) + parser.add_argument( + "--fp16", + action="store_true", + help="Use fp16 instead of bfloat16", + ) + + return parser.parse_args() + + +def main() -> None: + """Main entry point for model optimization CLI.""" + args = parse_args() + + # Determine input shape + if len(args.shape) == 1: + input_shape = (16, 3, args.shape[0], args.shape[0]) + elif len(args.shape) == 2: + input_shape = (16, 3, args.shape[0], args.shape[1]) + else: + raise ValueError("Shape must be 1 or 2 integers (height, width)") + + # Clamp batch size + max_batch_size = args.max_batch_size + input_shape = (max(1, min(input_shape[0], max_batch_size)), *input_shape[1:]) + + # Initialize model + model = init_model(args.config, args.checkpoint, device="cpu") + model.eval() + model = revert_sync_batchnorm(model) + + # Create demo inputs + demo_inputs = create_demo_inputs(input_shape) + + # Set dtype + dtype = torch.half if args.fp16 else torch.bfloat16 + model.to(dtype) + demo_inputs["imgs"] = demo_inputs["imgs"].to(dtype) + + # Export and compile + compile_and_export_model( + model, + demo_inputs, + output_file=args.output_file, + max_batch_size=max_batch_size, + dtype=dtype, + ) + + +if __name__ == "__main__": + main() diff --git a/sapiens/dense/tools/dist_test.sh b/sapiens/dense/tools/dist_test.sh new file mode 100755 index 0000000000000000000000000000000000000000..06fe6c67e46c4ffc2458d561c93465782ba49399 --- /dev/null +++ b/sapiens/dense/tools/dist_test.sh @@ -0,0 +1,23 @@ +CONFIG=$1 +CHECKPOINT=$2 +GPUS=$3 +NNODES=${NNODES:-1} +NODE_RANK=${NODE_RANK:-0} +PORT=${PORT:-29500} +MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH + +export OMP_NUM_THREADS=1 +export MKL_NUM_THREADS=1 + +torchrun \ + --nnodes=$NNODES \ + --node_rank=$NODE_RANK \ + --master_addr=$MASTER_ADDR \ + --nproc_per_node=$GPUS \ + --master_port=$PORT \ + $(dirname "$0")/test.py \ + $CONFIG \ + $CHECKPOINT \ + ${@:4} diff --git 
a/sapiens/dense/tools/dist_train.sh b/sapiens/dense/tools/dist_train.sh new file mode 100755 index 0000000000000000000000000000000000000000..c5e50b63e5e3208c89e7ea6420da7c1838c33b85 --- /dev/null +++ b/sapiens/dense/tools/dist_train.sh @@ -0,0 +1,21 @@ +CONFIG=$1 +GPUS=$2 +NNODES=${NNODES:-1} +NODE_RANK=${NODE_RANK:-0} +PORT=${PORT:-29500} +MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH + +export OMP_NUM_THREADS=1 +export MKL_NUM_THREADS=1 + +torchrun \ + --nnodes=$NNODES \ + --node_rank=$NODE_RANK \ + --master_addr=$MASTER_ADDR \ + --nproc_per_node=$GPUS \ + --master_port=$PORT \ + $(dirname "$0")/train.py \ + $CONFIG \ + ${@:3} diff --git a/sapiens/dense/tools/test.py b/sapiens/dense/tools/test.py new file mode 100644 index 0000000000000000000000000000000000000000..95e926ed9982c03b16712c66fc34047851a16b78 --- /dev/null +++ b/sapiens/dense/tools/test.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +import torch + +# pyre-ignore[21]: Cannot find module `sapiens.engine.config` +from sapiens.engine.config import Config, DictAction +from sapiens.engine.runners import * + + +def parse_args(argv: list[str] | None = None) -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Train a segmentor") + parser.add_argument("config", help="train config file path") + parser.add_argument("checkpoint", help="checkpoint file") + parser.add_argument("--work-dir", help="the dir to save logs and models") + parser.add_argument( + "--cfg-options", + nargs="+", + action=DictAction, + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file. If the value to " + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + "Note that the quotation marks are necessary and that no white space " + "is allowed.", + ) + parser.add_argument("--local_rank", "--local-rank", type=int, default=0) + args = parser.parse_args(argv) + if "LOCAL_RANK" not in os.environ: + os.environ["LOCAL_RANK"] = str(args.local_rank) + + return args + + +def main(argv: list[str] | None = None) -> None: + args = parse_args(argv) + + # load config + cfg = Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + cfg.work_dir = args.work_dir + cfg.load_from = args.checkpoint + + # set train to false + cfg.train_dataloader = None + + # start testing + runner_type = cfg.get("runner_type", "BaseRunner") + runner = eval(runner_type).from_cfg(cfg) + runner.test() + + +if __name__ == "__main__": + main() diff --git a/sapiens/dense/tools/train.py b/sapiens/dense/tools/train.py new file mode 100644 index 0000000000000000000000000000000000000000..3c918a6fb689976814b361afaee3a45eeb61ca00 --- /dev/null +++ b/sapiens/dense/tools/train.py @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
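+
+# Illustrative single-process invocations (the config path and work dir are
+# hypothetical; dist_train.sh wraps this script with torchrun for multi-GPU):
+#   python tools/train.py configs/my_normal_config.py --work-dir work_dirs/run1
+#   python tools/train.py configs/my_normal_config.py --work-dir work_dirs/run1 --resume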
+ +import argparse +import os + +import torch +import torchvision + +# pyre-ignore[21]: Cannot find module `sapiens.engine.config` +from sapiens.engine.config import Config, DictAction +from sapiens.engine.runners import * + +torch.set_float32_matmul_precision("high") # A100 gpus +torchvision.disable_beta_transforms_warning() # Disable the beta transforms warning + + +def parse_args(argv: list[str] | None = None) -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Train a segmentor") + parser.add_argument("config", help="train config file path") + parser.add_argument("--work-dir", help="the dir to save logs and models") + parser.add_argument( + "--resume", + nargs="?", + type=str, + const="auto", + help="If specify checkpint path, resume from it, while if not " + "specify, try to auto resume from the latest checkpoint " + "in the work directory.", + ) + parser.add_argument( + "--cfg-options", + nargs="+", + action=DictAction, + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file. If the value to " + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + "Note that the quotation marks are necessary and that no white space " + "is allowed.", + ) + parser.add_argument("--local_rank", "--local-rank", type=int, default=0) + args = parser.parse_args(argv) + if "LOCAL_RANK" not in os.environ: + os.environ["LOCAL_RANK"] = str(args.local_rank) + + return args + + +def main(argv: list[str] | None = None) -> None: + args = parse_args(argv) + + # load config + cfg = Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + cfg.work_dir = args.work_dir + + # resume training + if args.resume is not None: + cfg.resume = True + cfg.load_from = args.resume + + # start training + runner_type = cfg.get("runner_type", "BaseRunner") + runner = eval(runner_type).from_cfg(cfg) + runner.train() + + +if __name__ == "__main__": + main() diff --git a/sapiens/dense/tools/vis/vis_albedo.py b/sapiens/dense/tools/vis/vis_albedo.py new file mode 100644 index 0000000000000000000000000000000000000000..494d2f43c6792ae4208891caf8d906e5aae04038 --- /dev/null +++ b/sapiens/dense/tools/vis/vis_albedo.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
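+
+# Illustrative invocation (all paths are hypothetical; --seg_dir should contain
+# per-image foreground masks stored as .npy, *_seg.npy, or image files):
+#   python tools/vis/vis_albedo.py albedo_config.py albedo_ckpt.pth \
+#       --input demo/images --output demo/albedo --seg_dir demo/seg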
+ +import os +from argparse import ArgumentParser + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from sapiens.dense.models import init_model +from tqdm import tqdm + + +def main(): + parser = ArgumentParser() + parser.add_argument("config", help="Config file") + parser.add_argument("checkpoint", help="Checkpoint file") + parser.add_argument("--input", help="Input image dir") + parser.add_argument("--output", default=None, help="Path to output dir") + parser.add_argument( + "--seg_dir", "--seg-dir", default=None, help="Path to segmentation dir" + ) + parser.add_argument("--device", default="cuda:0", help="Device used for inference") + args = parser.parse_args() + + model = init_model(args.config, args.checkpoint, device=args.device) + os.makedirs(args.output, exist_ok=True) + + # Get image list + if os.path.isdir(args.input): + input_dir = args.input + image_names = [ + name + for name in sorted(os.listdir(input_dir)) + if name.endswith((".jpg", ".png", ".jpeg")) + ] + else: + with open(args.input, "r") as f: + image_paths = [line.strip() for line in f if line.strip()] + image_names = [os.path.basename(path) for path in image_paths] + input_dir = os.path.dirname(image_paths[0]) + + seg_dir = args.seg_dir + + for image_name in tqdm(image_names, total=len(image_names)): + image_path = os.path.join(input_dir, image_name) + image = cv2.imread(image_path) + + mask_path = os.path.join( + seg_dir, + image_name.replace(".png", ".npy") + .replace(".jpg", ".npy") + .replace(".jpeg", ".npy"), + ) + + mask_path_candidates = [ + mask_path, # npy + mask_path.replace(".npy", "_seg.npy"), # npy, seg probs + os.path.join(seg_dir, image_name), # png or jpg + ] + + mask = np.ones_like(image[:, :, 0], dtype=bool) + for mask_path in mask_path_candidates: + if not os.path.exists(mask_path): + continue + if mask_path.endswith("_seg.npy"): + mask = np.load(mask_path) ## H x W, float; class labels + mask = mask > 0 ## skip the bg class + elif mask_path.endswith(".npy"): + mask = np.load(mask_path) ## H x W, boolean + else: + mask = cv2.imread(mask_path)[:, :, 0] ## H x W, uint8 + mask = mask > 0 + break + + ##------------------------------------------ + data = model.pipeline(dict(img=image)) ## resize and pad + data = model.data_preprocessor(data) ## normalize, add batch dim and cast + inputs, data_samples = data["inputs"], data["data_samples"] + + ## pointmap is 1 x 3 x H x W, scale is 1 x 1 + with torch.no_grad(): + albedo = model(inputs) + albedo = albedo.clamp(0, 1) # clamp to [0, 1] + + # ------------------------------------------ + pad_left, pad_right, pad_top, pad_bottom = data_samples["meta"]["padding_size"] + albedo = albedo[ + :, + :, + pad_top : inputs.shape[2] - pad_bottom, + pad_left : inputs.shape[3] - pad_right, + ] + + albedo = F.interpolate( + albedo, + size=(image.shape[0], image.shape[1]), + mode="bilinear", + align_corners=False, + ) + + albedo = albedo.squeeze(0).cpu().numpy().transpose(1, 2, 0) ## H x W x 3 + albedo = (albedo * 255).astype(np.uint8) + albedo[mask == 0] = [100, 100, 100] + albedo = cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR) + vis_image = np.concatenate([image, albedo], axis=1) + save_path = os.path.join(args.output, image_name) + cv2.imwrite(save_path, vis_image) + + +if __name__ == "__main__": + main() diff --git a/sapiens/dense/tools/vis/vis_normal.py b/sapiens/dense/tools/vis/vis_normal.py new file mode 100644 index 0000000000000000000000000000000000000000..c267bf967c135b71c3c388768f57d1e3a59dcb62 --- /dev/null +++ 
b/sapiens/dense/tools/vis/vis_normal.py @@ -0,0 +1,138 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +from argparse import ArgumentParser + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from sapiens.dense.models import init_model +from tqdm import tqdm + + +def main(): + parser = ArgumentParser() + parser.add_argument("config", help="Config file") + parser.add_argument("checkpoint", help="Checkpoint file") + parser.add_argument("--input", help="Input image dir") + parser.add_argument("--output", default=None, help="Path to output dir") + parser.add_argument( + "--no-black-background", + "--no_black_background", + action="store_true", + help="No black background", + ) + parser.add_argument( + "--seg_dir", "--seg-dir", default=None, help="Path to segmentation dir" + ) + parser.add_argument("--device", default="cuda:0", help="Device used for inference") + parser.add_argument( + "--no-save-predictions", + action="store_true", + help="If provided, do not save .npy prediction files", + ) + + args = parser.parse_args() + + model = init_model(args.config, args.checkpoint, device=args.device) + os.makedirs(args.output, exist_ok=True) + + # Get image list + if os.path.isdir(args.input): + input_dir = args.input + image_names = [ + name + for name in sorted(os.listdir(input_dir)) + if name.endswith((".jpg", ".png", ".jpeg")) + ] + else: + with open(args.input, "r") as f: + image_paths = [line.strip() for line in f if line.strip()] + image_names = [os.path.basename(path) for path in image_paths] + input_dir = os.path.dirname(image_paths[0]) + + seg_dir = args.seg_dir + + for image_name in tqdm(image_names, total=len(image_names)): + image_path = os.path.join(input_dir, image_name) + image = cv2.imread(image_path) + + mask_path = os.path.join( + seg_dir, + image_name.replace(".png", ".npy") + .replace(".jpg", ".npy") + .replace(".jpeg", ".npy"), + ) + + mask_path_candidates = [ + mask_path, # npy + mask_path.replace(".npy", "_seg.npy"), # npy, seg probs + os.path.join(seg_dir, image_name), # png or jpg + ] + + mask = np.ones_like(image[:, :, 0], dtype=bool) + for mask_path in mask_path_candidates: + if not os.path.exists(mask_path): + continue + if mask_path.endswith("_seg.npy"): + mask = np.load(mask_path) ## H x W, float; class labels + mask = mask > 0 ## skip the bg class + elif mask_path.endswith(".npy"): + mask = np.load(mask_path) ## H x W, boolean + else: + mask = cv2.imread(mask_path)[:, :, 0] ## H x W, uint8 + mask = mask > 0 + break + + ##------------------------------------------ + data = model.pipeline(dict(img=image)) ## resize and pad + data = model.data_preprocessor(data) ## normalize, add batch dim and cast + inputs, data_samples = data["inputs"], data["data_samples"] + + with torch.no_grad(): + normal = model(inputs) # normal is 1 x 3 x H x W + normal = normal / torch.norm(normal, dim=1, keepdim=True).clamp( + min=1e-8 + ) # normalize to unit length + + # ------------------------------------------ + pad_left, pad_right, pad_top, pad_bottom = data_samples["meta"]["padding_size"] + normal = normal[ + :, + :, + pad_top : inputs.shape[2] - pad_bottom, + pad_left : inputs.shape[3] - pad_right, + ] + + normal = F.interpolate( + normal, + size=(image.shape[0], image.shape[1]), + mode="bilinear", + align_corners=False, + ) + + normal = normal.squeeze(0).cpu().numpy().transpose(1, 2, 0) ## H 
x W x 3 + + if not args.no_save_predictions: + base_path = os.path.join(args.output, image_name.rsplit(".")[0]) + np.save(f"{base_path}.npy", normal) + + normal[mask == 0] = -1 + normal_vis = ((normal + 1) / 2 * 255).astype(np.uint8) + normal_vis = normal_vis[:, :, ::-1] + + if args.no_black_background: + normal_vis[mask == 0] = image[mask == 0] + + vis_image = np.concatenate([image, normal_vis], axis=1) + save_path = os.path.join(args.output, image_name) + cv2.imwrite(save_path, vis_image) + + +if __name__ == "__main__": + main() diff --git a/sapiens/dense/tools/vis/vis_pointmap.py b/sapiens/dense/tools/vis/vis_pointmap.py new file mode 100644 index 0000000000000000000000000000000000000000..e6ac81ec3b27bb3a07603d9915f645e71421f6e0 --- /dev/null +++ b/sapiens/dense/tools/vis/vis_pointmap.py @@ -0,0 +1,320 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +from argparse import ArgumentParser + +import cv2 +import numpy as np +import torch + +try: + import open3d as o3d +except ImportError as e: + raise ImportError( + "open3d is required for pointmap visualization. " + "Install with: pip install open3d (or `pip install -e .[pointmap]`)" + ) from e +import torch.nn.functional as F +import torchvision +import torchvision.transforms.functional as TVF +from matplotlib import pyplot as plt +from sapiens.dense.models import init_model +from tqdm import tqdm + +cmap = plt.get_cmap("turbo") +torchvision.disable_beta_transforms_warning() + + +# ------------------------------------------------------------------------------- +def process_depth_map_with_bounds(depth_map, mask, min_val, max_val): + """Render depth as turbo colormap using pre-computed (global) min/max bounds.""" + processed_depth = np.full((mask.shape[0], mask.shape[1], 3), 100, dtype=np.uint8) + + if min_val is None or max_val is None or not np.any(mask): + return processed_depth + + depth_foreground = depth_map[mask > 0] + if len(depth_foreground) == 0: + return processed_depth + + inverse_depth_foreground = 1 / np.clip(depth_foreground, 1e-6, None) + max_inverse_depth = min(1 / max(min_val, 1e-6), 1 / 0.1) + min_inverse_depth = max(1 / 250, 1 / max(max_val, 1e-6)) + + inverse_depth_foreground_normalized = ( + inverse_depth_foreground - min_inverse_depth + ) / (max_inverse_depth - min_inverse_depth) + inverse_depth_foreground_normalized = np.clip( + inverse_depth_foreground_normalized, 0, 1 + ) + + color_depth = (cmap(inverse_depth_foreground_normalized)[..., :3] * 255).astype( + np.uint8 + ) + processed_depth[mask] = color_depth + processed_depth = processed_depth[..., ::-1] ## convert RGB to BGR to save with cv2 + return np.array(processed_depth, dtype=np.uint8) + + +def compute_surface_normals(depth_map, mask, min_val, max_val, kernel_size=7): + """Compute surface normals from depth map.""" + depth_normalized = np.full(mask.shape, np.inf) + depth_normalized[mask > 0] = 1 - ( + (depth_map[mask > 0] - min_val) / (max_val - min_val) + ) + + grad_x = cv2.Sobel( + depth_normalized.astype(np.float32), cv2.CV_32F, 1, 0, ksize=kernel_size + ) + grad_y = cv2.Sobel( + depth_normalized.astype(np.float32), cv2.CV_32F, 0, 1, ksize=kernel_size + ) + normals = np.dstack((-grad_x, -grad_y, np.full(grad_x.shape, -1))) + + normals_mag = np.linalg.norm(normals, axis=2, keepdims=True) + normals_normalized = normals / (normals_mag + 1e-5) + normal_vis = ((normals_normalized + 1) / 2 * 
255).astype(np.uint8) + return normal_vis[:, :, ::-1] + + +def resize_pointmap( + pointmap, + target_height, + target_width, + smooth=False, + blur_ks=3, + blur_sigma=0.8, + smooth_iters=4, +): + assert pointmap.dim() == 4 and pointmap.shape[1] == 3, "pointmap must be 1x3xHxW" + H, W = pointmap.shape[2], pointmap.shape[3] + up = (target_height > H) or (target_width > W) + + if smooth and up: + for _ in range(smooth_iters): + pointmap = TVF.gaussian_blur( + pointmap, kernel_size=[blur_ks, blur_ks], sigma=[blur_sigma, blur_sigma] + ) + + pointmap = F.interpolate( + pointmap, + size=(target_height, target_width), + mode="bilinear", + align_corners=False, + antialias=False, + ) ## 1 x 3 x H x W + + return pointmap + + +def load_image_and_mask(input_dir, seg_dir, image_name): + """Load BGR image and boolean foreground mask. Mask defaults to all-True if missing.""" + image_path = os.path.join(input_dir, image_name) + image = cv2.imread(image_path) + + mask = np.ones_like(image[:, :, 0], dtype=bool) + if seg_dir is None: + return image, mask + + mask_base = ( + image_name.replace(".png", ".npy") + .replace(".jpg", ".npy") + .replace(".jpeg", ".npy") + ) + npy_path = os.path.join(seg_dir, mask_base) + candidates = [ + npy_path, + npy_path.replace(".npy", "_seg.npy"), + os.path.join(seg_dir, image_name), + ] + + for mp in candidates: + if not os.path.exists(mp): + continue + if mp.endswith("_seg.npy"): + m = np.load(mp) ## H x W, float; class labels + mask = m > 0 + elif mp.endswith(".npy"): + mask = np.load(mp) ## H x W, boolean + else: + mask = cv2.imread(mp)[:, :, 0] > 0 + break + + return image, mask + + +def main(): + parser = ArgumentParser() + parser.add_argument("config", help="Config file") + parser.add_argument("checkpoint", help="Checkpoint file") + parser.add_argument("--input", help="Input image dir") + parser.add_argument("--output", default=None, help="Path to output dir") + parser.add_argument( + "--seg_dir", "--seg-dir", default=None, help="Path to segmentation dir" + ) + parser.add_argument("--device", default="cuda:0", help="Device used for inference") + parser.add_argument( + "--no-black-background", + "--no_black_background", + action="store_true", + help="No black background", + ) + parser.add_argument( + "--no-save-predictions", + action="store_true", + help="If provided, do not save .ply prediction files", + ) + parser.add_argument( + "--with-normal", + "--with_normal", + action="store_true", + help="Also render surface-normal panel (default off โ€” output is image | depth only).", + ) + + args = parser.parse_args() + + os.makedirs(args.output, exist_ok=True) + + # Get image list + if os.path.isdir(args.input): + input_dir = args.input + image_names = [ + name + for name in sorted(os.listdir(input_dir)) + if name.endswith((".jpg", ".png", ".jpeg")) + ] + else: + with open(args.input, "r") as f: + image_paths = [line.strip() for line in f if line.strip()] + image_names = [os.path.basename(path) for path in image_paths] + input_dir = os.path.dirname(image_paths[0]) + + seg_dir = args.seg_dir + + # Defer model init until we actually need to run inference (skip-if-cached). 
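+    # Note on the two-pass flow below: pass 1 runs inference (or reuses a cached
+    # <name>_depth.npy) and records per-frame 1st/99th depth percentiles; pass 2
+    # re-renders every frame with the aggregated min(p1)/max(p99) bounds so the
+    # depth colormap stays consistent across the whole sequence.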
+ model = None + + # ============== Pass 1: inference (cached) + percentile collection ============== + per_frame_percentiles = [] # list of (p1, p99) of foreground depth per frame + for image_name in tqdm(image_names, desc="pass 1: inference"): + base_path = os.path.join(args.output, image_name.rsplit(".")[0]) + depth_npy_path = f"{base_path}_depth.npy" + + image, mask = load_image_and_mask(input_dir, seg_dir, image_name) + + if os.path.exists(depth_npy_path): + depth = np.load(depth_npy_path).astype(np.float32) + else: + if model is None: + model = init_model(args.config, args.checkpoint, device=args.device) + + data = model.pipeline(dict(img=image)) ## resize and pad + data = model.data_preprocessor(data) ## normalize, add batch dim and cast + inputs, data_samples = data["inputs"], data["data_samples"] + + ## pointmap is 1 x 3 x H x W, scale is 1 x 1 + with torch.no_grad(): + pointmap, scale = model(inputs) + pointmap = pointmap / scale ## convert pointmap to metric + + assert ( + pointmap.shape[0] == 1 + and pointmap.shape[2] == inputs.shape[2] + and pointmap.shape[3] == inputs.shape[3] + ) + + pad_left, pad_right, pad_top, pad_bottom = data_samples["meta"][ + "padding_size" + ] + pointmap = pointmap[ + :, + :, + pad_top : inputs.shape[2] - pad_bottom, + pad_left : inputs.shape[3] - pad_right, + ] + + pointmap = resize_pointmap( + pointmap, + target_height=mask.shape[0], + target_width=mask.shape[1], + smooth=True, + ) + pointmap = pointmap.squeeze(0).cpu().numpy().transpose(1, 2, 0) ## H x W x 3 + + depth = pointmap[:, :, 2].astype(np.float32) + np.save(depth_npy_path, depth.astype(np.float16)) + + if not args.no_save_predictions: + points = pointmap[mask > 0].reshape(-1, 3) ## N x 3 + pc = o3d.geometry.PointCloud() + colors = image[mask > 0] / 255.0 + colors = colors[:, [2, 1, 0]] # Convert BGR to RGB + + pc.points = o3d.utility.Vector3dVector(points) + pc.colors = o3d.utility.Vector3dVector(colors) + sphere = o3d.geometry.TriangleMesh.create_sphere( + radius=0.05, resolution=20 + ) + sphere.translate([0, 0, 0]) # Center the sphere at the origin + sphere.paint_uniform_color([0, 0, 1]) # Color the sphere blue + sphere_pc = sphere.sample_points_poisson_disk(number_of_points=500) + sphere_pc.colors = o3d.utility.Vector3dVector( + [[0, 0, 1] for _ in range(len(sphere_pc.points))] + ) + pc = pc + sphere_pc + o3d.io.write_point_cloud(f"{base_path}.ply", pc) + + depth_fg = depth[mask > 0] + if len(depth_fg) > 0: + p1, p99 = np.percentile(depth_fg, [1, 99]) + per_frame_percentiles.append((float(p1), float(p99))) + + # ============== Aggregate global bounds: min(p1) and max(p99) ============== + if per_frame_percentiles: + arr = np.array(per_frame_percentiles) + global_min = float(arr[:, 0].min()) + global_max = float(arr[:, 1].max()) + print( + f"global depth bounds across {len(per_frame_percentiles)} frames: " + f"min={global_min:.3f}, max={global_max:.3f}" + ) + else: + global_min, global_max = None, None + print("warning: no foreground found in any frame, skipping render pass") + + # ============== Pass 2: render with global bounds (no inference) ============== + for image_name in tqdm(image_names, desc="pass 2: render"): + base_path = os.path.join(args.output, image_name.rsplit(".")[0]) + depth_npy_path = f"{base_path}_depth.npy" + + if not os.path.exists(depth_npy_path): + continue + + image, mask = load_image_and_mask(input_dir, seg_dir, image_name) + if not np.any(mask): + continue + + depth = np.load(depth_npy_path).astype(np.float32) + + processed_depth = 
process_depth_map_with_bounds( + depth, mask, global_min, global_max + ) + panels = [image, processed_depth] + if args.with_normal: + normal_vis = compute_surface_normals(depth, mask, global_min, global_max) + panels.append(normal_vis) + + if args.no_black_background: + for p in panels[1:]: + p[mask == 0] = image[mask == 0] + + vis_image = np.concatenate(panels, axis=1) + cv2.imwrite(f"{base_path}{os.path.splitext(image_name)[1]}", vis_image) + + +if __name__ == "__main__": + main() diff --git a/sapiens/dense/tools/vis/vis_seg.py b/sapiens/dense/tools/vis/vis_seg.py new file mode 100644 index 0000000000000000000000000000000000000000..0337d84d8abc58f20497c221f471baf3970bfc88 --- /dev/null +++ b/sapiens/dense/tools/vis/vis_seg.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +from argparse import ArgumentParser + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from sapiens.dense.models import init_model +from sapiens.dense.visualizers import SegVisualizer +from tqdm import tqdm + + +def main(): + parser = ArgumentParser() + parser.add_argument("config", help="Config file") + parser.add_argument("checkpoint", help="Checkpoint file") + parser.add_argument("--input", help="Input image dir") + parser.add_argument("--output", default=None, help="Path to output dir") + parser.add_argument( + "--class_palette_type", default="dome29", help="Color palette for 29 classes" + ) + parser.add_argument( + "--save_pred", action="store_true", help="Save prediction to file" + ) + parser.add_argument("--device", default="cuda:0", help="Device used for inference") + args = parser.parse_args() + + model = init_model(args.config, args.checkpoint, device=args.device) + os.makedirs(args.output, exist_ok=True) + + # Get image list + if os.path.isdir(args.input): + input_dir = args.input + image_names = [ + name + for name in sorted(os.listdir(input_dir)) + if name.endswith((".jpg", ".png", ".jpeg")) + ] + else: + with open(args.input, "r") as f: + image_paths = [line.strip() for line in f if line.strip()] + image_names = [os.path.basename(path) for path in image_paths] + input_dir = os.path.dirname(image_paths[0]) + + class_palette_type = args.class_palette_type + visualizer = SegVisualizer(class_palette_type=class_palette_type, with_labels=False) + + for i, image_name in tqdm(enumerate(image_names), total=len(image_names)): + image_path = os.path.join(input_dir, image_name) + image = cv2.imread(image_path) + + data = model.pipeline(dict(img=image)) ## resize and pad + data = model.data_preprocessor(data) ## normalize, add batch dim and cast + inputs = data["inputs"] + + ## pointmap is 1 x 3 x H x W, scale is 1 x 1 + with torch.no_grad(): + seg_logits = model(inputs) + + # resize prediction to image size + seg_logits = F.interpolate( + seg_logits, size=image.shape[:2], mode="bilinear" + ) ## 1 x C x H x W + pred_labels = seg_logits.argmax(dim=1).cpu().numpy() ## 1 x H x W + pred_labels = pred_labels.squeeze(0) ## H x W + + vis_seg = visualizer._visualize_segmentation(image, pred_labels) + vis_image = np.concatenate([image, vis_seg], axis=1) + base_path = os.path.join(args.output, image_name.rsplit(".")[0]) + cv2.imwrite(f"{base_path}.{image_name.rsplit('.')[1]}", vis_image) + + if args.save_pred: + np.save(f"{base_path}_seg.npy", pred_labels) + + +if __name__ == "__main__": + main() diff --git 
a/sapiens/engine/__init__.py b/sapiens/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c3377db59421ab3387ba875e3bd41b9f6ab27603 --- /dev/null +++ b/sapiens/engine/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .optim import * +from .runners import * +from .logger import * +from .visualizers import * +from .datasets import * +from .models import * +from .config import * diff --git a/sapiens/engine/config.py b/sapiens/engine/config.py new file mode 100644 index 0000000000000000000000000000000000000000..c8bd8d296c0ff334a0c7c7c06929080404d10f9b --- /dev/null +++ b/sapiens/engine/config.py @@ -0,0 +1,258 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import argparse +import ast +import json +from importlib import util as _iu +from pathlib import Path +from types import ModuleType, SimpleNamespace + + +class Config(SimpleNamespace): + """Dot-dict wrapper around a Python-module config.""" + + DELETE_KEY = "_delete_" + + def __init__(self, **kwargs): + super().__init__(**kwargs) # keep SimpleNamespace behaviour + object.__setattr__( + self, + "_cfg_dict", # use object.__setattr__ to avoid + json.loads(json.dumps(kwargs)), + ) # recursion hook + + @classmethod + def fromfile(cls, path: str | Path) -> "Config": + path = Path(path).expanduser() + spec = _iu.spec_from_file_location(path.stem, path) + mod = _iu.module_from_spec(spec) + spec.loader.exec_module(mod) # type: ignore + data = { + k: v + for k, v in vars(mod).items() + if not k.startswith("_") and not isinstance(v, ModuleType) + } + data["filename"] = str(path) + return cls(**data) + + def merge_from_dict( + self, options: dict[str, object], allow_list_keys: bool = True + ) -> None: + """Merge a dict of dotted keys (or nested dict) into this Config.""" + # Lazily ensure _cfg_dict exists if __init__ was bypassed + if not hasattr(self, "_cfg_dict"): + object.__setattr__(self, "_cfg_dict", self.to_dict()) + + # 1. expand dotted keys -> nested dict + nested_patch: dict = {} + for dotted_key, value in options.items(): + node = nested_patch + parts = dotted_key.split(".") + for p in parts[:-1]: + node = node.setdefault(p, {}) + node[parts[-1]] = value + + # 2. deep-merge + merged = self._merge_a_into_b( + nested_patch, self._cfg_dict, allow_list_keys=allow_list_keys + ) + object.__setattr__(self, "_cfg_dict", merged) + + # 3. 
rebuild attributes so dot-access sees the new values + self.__dict__.update(merged) + + # ------------------------------------------------------------------ + @staticmethod + def _merge_a_into_b(a: dict, b: dict | list, *, allow_list_keys: bool = False): + """Deep-merge *a* into *b* (returns a **new** structure).""" + from copy import deepcopy + + # If the target is a list and we allow list keys --------------- + if allow_list_keys and isinstance(b, list): + b = deepcopy(b) + for k, v in a.items(): + if not k.isdigit(): + raise TypeError(f"Expected int-like key for list merge, got {k!r}") + idx = int(k) + if idx >= len(b): + raise IndexError(f"Index {idx} out of range for list.") + b[idx] = Config._merge_a_into_b( + v, b[idx], allow_list_keys=allow_list_keys + ) + return b + + # Otherwise we are merging dicts ------------------------------- + b = deepcopy(b) + for k, v in a.items(): + # Deletion flag (_delete_ = True) -------------------------- + if isinstance(v, dict) and v.pop(Config.DELETE_KEY, False): + b[k] = Config._merge_a_into_b(v, {}, allow_list_keys) + continue + + if ( + k in b + and isinstance(b[k], (dict, list)) + and isinstance(v, (dict, list)) + ): + b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys=allow_list_keys) + else: + b[k] = deepcopy(v) + return b + + def __getitem__(self, item): + return self.__dict__[item] + + def __setitem__(self, key, value): + self.__dict__[key] = value + + def get(self, key, default=None): + """Get a value by key with an optional default if key doesn't exist.""" + return self.__dict__.get(key, default) + + def to_dict(self) -> dict: + def _rec(obj): + if isinstance(obj, Config): + return {k: _rec(v) for k, v in obj.__dict__.items()} + if isinstance(obj, dict): + return {k: _rec(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return type(obj)(_rec(x) for x in obj) + return obj + + return _rec(self) + + +# --------------------------------------------------------------------------- +class DictAction(argparse.Action): + def __call__(self, _parser, namespace, values, _option_string=None): + out: dict[str, object] = {} + for kv in values: + if "=" not in kv: + raise ValueError(f"--cfg-options expected key=value, got {kv}") + key, val = kv.split("=", 1) + for parser_fn in (json.loads, ast.literal_eval, str): + try: + val = parser_fn(val) + break + except Exception: + continue + cur = out + *parents, leaf = key.split(".") + for p in parents: + cur = cur.setdefault(p, {}) + cur[leaf] = val + setattr(namespace, self.dest, out) + + +# --------------------------------------------------------------------------- +_INDENT = 4 + + +def _indent_block(txt: str, n: int = _INDENT) -> str: + pad = " " * n + return "\n".join(pad + ln if ln else ln for ln in txt.splitlines()) + + +def _format_basic(k, v, mapping: bool) -> str: + v_str = repr(v) if isinstance(v, str) else str(v) + if mapping: + k_str = repr(k) if isinstance(k, str) else str(k) + return f"{k_str}: {v_str}" + return f"{k}={v_str}" + + +def _format_list_tuple(obj, mapping_key=None, mapping=False): + left, right = ("[", "]") if isinstance(obj, list) else ("(", ")") + body = [] + for item in obj: + if isinstance(item, dict): + body.append("dict(" + _indent_block(_format_dict(item), _INDENT) + "),") + elif isinstance(item, (list, tuple)): + body.append(_indent_block(_format_list_tuple(item), _INDENT) + ",") + else: + body.append(repr(item) + ",") + inner = "\n".join(body) + if mapping_key is None: + return _indent_block(left + "\n" + inner + "\n" + right, _INDENT) + if mapping: + 
k_str = repr(mapping_key) if isinstance(mapping_key, str) else str(mapping_key) + return f"{k_str}: {left}\n{_indent_block(inner, _INDENT)}\n{right}" + return f"{mapping_key}={left}\n{_indent_block(inner, _INDENT)}\n{right}" + + +def _contains_non_identifier(keys): + return any((not str(k).isidentifier()) for k in keys) + + +def _format_dict(d: dict, outer: bool = False) -> str: + lines = [] + mapping = _contains_non_identifier(d) + if mapping and not outer: + lines.append("{") + for idx, (k, v) in enumerate(sorted(d.items(), key=lambda x: str(x[0]))): + is_last = idx == len(d) - 1 + suffix = "" if outer or is_last else "," + if isinstance(v, dict): + inner = _format_dict(v) + if mapping: + k_str = repr(k) if isinstance(k, str) else str(k) + line = f"{k_str}: dict(\n{_indent_block(inner, _INDENT)}\n){suffix}" + else: + line = f"{k}=dict(\n{_indent_block(inner, _INDENT)}\n){suffix}" + elif isinstance(v, (list, tuple)): + line = _format_list_tuple(v, mapping_key=k, mapping=mapping) + suffix + else: + line = _format_basic(k, v, mapping) + suffix + lines.append(_indent_block(line, _INDENT)) + if mapping and not outer: + lines.append("}") + return "\n".join(lines) + + +# ----------------------------------------------------------------- +def pretty_text(cfg_dict: dict) -> str: + body = _format_dict(cfg_dict, outer=True) + + import textwrap + + body = textwrap.dedent(body) + + try: + from yapf.yapflib.yapf_api import FormatCode + + yapf_style = dict( + based_on_style="pep8", + blank_line_before_nested_class_or_def=True, + split_before_expression_after_opening_paren=True, + ) + body, _ = FormatCode(body, style_config=yapf_style) + except ImportError: + pass + except Exception as exc: # keep raw body if yapf fails + print(f"[pretty_text] yapf failed: {exc}\nReturning un-formatted text.") + return body + + +def print_cfg(cfg_dict): + try: + from rich.console import Console + from rich.syntax import Syntax + + console = Console() + code_str = pretty_text(cfg_dict) + syntax_block = Syntax(code_str, "python", theme="ansi_dark", word_wrap=False) + console.print(syntax_block, style="green") + return "---" + except ImportError: + from pprint import pformat + + GREEN = "\033[92m" + RESET = "\033[0m" + print(GREEN + pformat(cfg_dict, sort_dicts=False) + RESET) + return "---" diff --git a/sapiens/engine/datasets/__init__.py b/sapiens/engine/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..acea85b9e795aa9ebeaa26a09b03916fbcf9f7ad --- /dev/null +++ b/sapiens/engine/datasets/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .base_dataset import BaseDataset, Compose +from .combined_dataset import CombinedDataset +from .data_preprocessors import ImagePreprocessor +from .transforms import * + + +__all__ = [ + "CombinedDataset", + "BaseDataset", + "ImagePreprocessor", +] diff --git a/sapiens/engine/datasets/base_dataset.py b/sapiens/engine/datasets/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b6f95e9888413d08e62be228ddeb816aa5983da6 --- /dev/null +++ b/sapiens/engine/datasets/base_dataset.py @@ -0,0 +1,101 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
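The dataset plumbing that follows (`Compose`, `BaseDataset`) builds transforms from registry config dicts and applies them in order, returning `None` when a sample should be dropped and refetched. A minimal usage sketch, assuming the `sapiens` package from this diff is importable; the parameters are example values and the transforms used (`ImageResize`, `ImagePackInputs`) are registered later in this diff:

```python
import numpy as np

from sapiens.engine.datasets import Compose

# Build the pipeline from registry dicts, exactly as BaseDataset does internally.
pipeline = Compose([
    dict(type="ImageResize", image_height=1024, image_width=768),
    dict(type="ImagePackInputs", meta_keys=["sample_idx"]),
])

sample = {"image": np.zeros((720, 540, 3), dtype=np.uint8), "sample_idx": 0}
packed = pipeline(sample)  # a None result would mean "drop this sample and refetch"

print(packed["inputs"].shape)          # torch.Size([3, 1024, 768]), CHW tensor
print(sorted(packed["data_samples"]))  # ['image', 'sample_idx']
```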
+ +import copy +import warnings +from abc import abstractmethod +from typing import Any, Callable, List, Optional, Sequence, Union + +import numpy as np +from sapiens.registry import TRANSFORMS +from torch.utils.data import Dataset + + +class Compose: + def __init__(self, transforms: Optional[Sequence[Union[dict, Callable]]]): + self.transforms = [] + for t in transforms or []: + if isinstance(t, dict): + t = TRANSFORMS.build(t) + if not callable(t): + raise TypeError(f"Transform must be callable, got {type(t)}") + self.transforms.append(t) + + def __call__(self, data: dict) -> Optional[dict]: + for t in self.transforms: + data = t(data) + if data is None: + return None + return data + + def __repr__(self): + return f"{self.__class__.__name__}({self.transforms})" + + +# ------------------------------------------------------------------------------- +class BaseDataset(Dataset): + def __init__( + self, + data_root: Optional[str] = "", + pipeline: List[Union[dict, Callable]] = [], + test_mode: bool = False, + max_refetch: int = 1000, + ): + self.data_root = data_root + self.test_mode = test_mode + self.max_refetch = max_refetch + self.pipeline = Compose(pipeline) + self.data_list = self.load_data_list() + + def get_data_info(self, idx: int) -> dict: + data_info = copy.deepcopy(self.data_list[idx]) + if idx >= 0: + data_info["sample_idx"] = idx + else: + data_info["sample_idx"] = len(self) + idx + return data_info + + def __getitem__(self, idx: int) -> dict: + if self.test_mode: + data_info = self.get_data_info(idx) + if data_info is None: + warnings.warn( + f"Test time pipeline should not get `None` data_sample, index:{idx}, using idx=0 as default" + ) + return self.__getitem__(idx=0) + data = self.pipeline(data_info) + if data is None: + warnings.warn( + f"Test time pipeline outputs `None` for index:{idx}, using idx=0 as default" + ) + return self.__getitem__(idx=0) + + return data + + for _ in range(self.max_refetch + 1): + data = self.prepare_data(idx) + if data is None: + idx = self._rand_another() + continue + return data + + raise Exception(f"Cannot find valid data after {self.max_refetch}! ") + + @abstractmethod + def load_data_list(self) -> List[dict]: + pass + + def _rand_another(self) -> int: + return np.random.randint(0, len(self)) + + def __len__(self) -> int: + return len(self.data_list) + + def prepare_data(self, idx) -> Any: + data_info = self.get_data_info(idx) + if data_info is None: + return None + return self.pipeline(data_info) diff --git a/sapiens/engine/datasets/combined_dataset.py b/sapiens/engine/datasets/combined_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..548ed0264501faeb5ffce188a0a82794754e0e57 --- /dev/null +++ b/sapiens/engine/datasets/combined_dataset.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
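`CombinedDataset` below concatenates several registry-built datasets and routes a global sample index to the owning subset plus a local index. A self-contained sketch of that index arithmetic (illustrative only; the helper name and the lengths are made up):

```python
from typing import List, Tuple


def get_subset_index(index: int, lens: List[int]) -> Tuple[int, int]:
    """Map a global index to (subset index, index within that subset)."""
    total = sum(lens)
    if index >= total or index < -total:
        raise ValueError(f"index {index} out of bounds for length {total}.")
    if index < 0:
        index += total
    subset_index = 0
    while index >= lens[subset_index]:
        index -= lens[subset_index]
        subset_index += 1
    return subset_index, index


assert get_subset_index(0, [10, 5]) == (0, 0)    # first sample of first subset
assert get_subset_index(12, [10, 5]) == (1, 2)   # third sample of second subset
assert get_subset_index(-1, [10, 5]) == (1, 4)   # negative indices wrap around
```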
+ +from typing import Any, Callable, List, Tuple, Union + +from sapiens.registry import DATASETS + +from .base_dataset import BaseDataset + + +@DATASETS.register_module() +class CombinedDataset(BaseDataset): + def __init__( + self, datasets: list, pipeline: List[Union[dict, Callable]] = [], **kwargs + ): + self.datasets = [] + for cfg in datasets: + dataset = DATASETS.build(cfg) + self.datasets.append(dataset) + + self._lens = [len(dataset) for dataset in self.datasets] + self._len = sum(self._lens) + + super(CombinedDataset, self).__init__(pipeline=pipeline, **kwargs) + assert len(self.datasets) > 0 + return + + def __len__(self): + return self._len + + def _get_subset_index(self, index: int) -> Tuple[int, int]: + if index >= len(self) or index < -len(self): + raise ValueError(f"index {index} out of bounds for length {len(self)}.") + + if index < 0: + index = index + len(self) + + subset_index = 0 + while index >= self._lens[subset_index]: + index -= self._lens[subset_index] + subset_index += 1 + return subset_index, index + + def prepare_data(self, idx: int) -> Any: + data_info = self.get_data_info(idx) + + if data_info is None: + return None + + for transform in self.pipeline.transforms: + data_info = transform(data_info) + + return data_info + + def get_data_info(self, idx: int) -> dict: + subset_idx, sample_idx = self._get_subset_index(idx) + data_info = self.datasets[subset_idx][sample_idx] + return data_info diff --git a/sapiens/engine/datasets/data_preprocessors/__init__.py b/sapiens/engine/datasets/data_preprocessors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6949c3d7a22416dee0663ed9112b36628f63baee --- /dev/null +++ b/sapiens/engine/datasets/data_preprocessors/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .base_preprocessor import BasePreprocessor +from .image_preprocessor import ImagePreprocessor + +__all__ = ["BasePreprocessor", "ImagePreprocessor"] diff --git a/sapiens/engine/datasets/data_preprocessors/base_preprocessor.py b/sapiens/engine/datasets/data_preprocessors/base_preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..85b47ac20d6fc0d22a4c0228decff63b2206cba7 --- /dev/null +++ b/sapiens/engine/datasets/data_preprocessors/base_preprocessor.py @@ -0,0 +1,94 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
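`BasePreprocessor.stack_batch`, defined below, pads a list of CHW tensors up to a common, divisor-rounded spatial size (never padding channels) and stacks them into one batch. A standalone sketch of the same padding logic; the function name is hypothetical and only torch is required:

```python
import torch
import torch.nn.functional as F


def stack_batch_sketch(tensors, pad_size_divisor=1, pad_value=0):
    shapes = torch.tensor([list(t.shape) for t in tensors])
    # Per-dimension maximum, rounded up to the divisor.
    max_dims = torch.ceil(torch.max(shapes, dim=0)[0] / pad_size_divisor) * pad_size_divisor
    pad_amounts = max_dims - shapes
    pad_amounts[:, 0] = 0  # never pad the channel dimension
    padded = []
    for i, t in enumerate(tensors):
        pad_tuple = []
        for j in range(t.ndim - 1, -1, -1):  # F.pad expects last-dim-first (left, right) pairs
            pad_tuple.extend([0, int(pad_amounts[i, j])])
        padded.append(F.pad(t, pad_tuple, value=pad_value))
    return torch.stack(padded)


batch = stack_batch_sketch([torch.ones(3, 30, 40), torch.ones(3, 25, 37)], pad_size_divisor=16)
print(batch.shape)  # torch.Size([2, 3, 32, 48])
```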
+ +from collections.abc import Sequence +from typing import Any, List, Mapping, Optional, Type, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from sapiens.registry import MODELS + +CastData = Union[torch.Tensor, Mapping, Sequence, str, bytes, None] + + +# ------------------------------------------------------------------------------- +@MODELS.register_module() +class BasePreprocessor(nn.Module): + def __init__(self, non_blocking: Optional[bool] = False): + super().__init__() + self._non_blocking = non_blocking + self._device = torch.device("cpu") + + def is_seq_of( + self, + seq: Any, + expected_type: Union[Type, tuple], + seq_type: Optional[Type] = None, + ) -> bool: + if seq_type is None: + exp_seq_type = Sequence + else: + assert isinstance(seq_type, type) + exp_seq_type = seq_type + if not isinstance(seq, exp_seq_type): + return False + for item in seq: + if not isinstance(item, expected_type): + return False + return True + + def stack_batch( + self, + tensor_list: List[torch.Tensor], + pad_size_divisor: int = 1, + pad_value: Union[int, float] = 0, + ) -> torch.Tensor: + if not tensor_list: + raise ValueError("tensor_list cannot be empty") + if len({t.ndim for t in tensor_list}) != 1: + raise ValueError("All tensors must have same number of dimensions") + + ndim = tensor_list[0].ndim + shapes = torch.tensor([list(t.shape) for t in tensor_list]) + max_dims = ( + torch.ceil(torch.max(shapes, dim=0)[0] / pad_size_divisor) + * pad_size_divisor + ) + + # Don't pad channel dimension + pad_amounts = max_dims - shapes + pad_amounts[:, 0] = 0 + + if pad_amounts.sum() == 0: + return torch.stack(tensor_list) + + # Create padding tuples and pad tensors + padded = [] + for i, tensor in enumerate(tensor_list): + pad_tuple = [] + for j in range(ndim - 1, -1, -1): # Reverse order for F.pad + pad_tuple.extend([0, int(pad_amounts[i, j])]) + padded.append(F.pad(tensor, pad_tuple, value=pad_value)) + + return torch.stack(padded) + + def cast_data(self, data: CastData, device: torch.device) -> CastData: + if isinstance(data, Mapping): + return {key: self.cast_data(data[key], device) for key in data} + elif isinstance(data, (str, bytes)) or data is None: + return data + elif isinstance(data, Sequence): + return type(data)(self.cast_data(sample, device) for sample in data) + elif isinstance(data, (torch.Tensor)): + return data.to(device, non_blocking=self._non_blocking) + else: + return data + + def forward(self, data: dict, training: bool = False) -> Union[dict, list]: + raise NotImplementedError( + "The forward method must be implemented by a subclass." + ) diff --git a/sapiens/engine/datasets/data_preprocessors/image_preprocessor.py b/sapiens/engine/datasets/data_preprocessors/image_preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..de3314cc2aefbc2742333af46ee03caa0ed6c82f --- /dev/null +++ b/sapiens/engine/datasets/data_preprocessors/image_preprocessor.py @@ -0,0 +1,139 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
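A minimal usage sketch for the `ImagePreprocessor` defined below, assuming the `sapiens` package from this diff is importable. The mean/std are the usual ImageNet statistics, used here purely as example values:

```python
import torch

from sapiens.engine.datasets import ImagePreprocessor

preprocessor = ImagePreprocessor(
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,       # swap channel order before normalizing
    pad_size_divisor=16,   # pad H/W up to a multiple of 16
)

# A dummy uint8 batch in NCHW layout.
data = {"inputs": torch.randint(0, 256, (2, 3, 30, 40), dtype=torch.uint8)}
out = preprocessor(data)
print(out["inputs"].shape)  # torch.Size([2, 3, 32, 48]), float, normalized
```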
+ +import math +from typing import Optional, Sequence, Union + +import torch +import torch.nn.functional as F +from sapiens.registry import MODELS + +from .base_preprocessor import BasePreprocessor + + +@MODELS.register_module() +class ImagePreprocessor(BasePreprocessor): + def __init__( + self, + mean: Optional[Sequence[Union[float, int]]] = None, + std: Optional[Sequence[Union[float, int]]] = None, + pad_size_divisor: int = 1, + pad_value: Union[float, int] = 0, + bgr_to_rgb: bool = False, + rgb_to_bgr: bool = False, + non_blocking: Optional[bool] = False, + ): + super().__init__(non_blocking) + self._validate_params(mean, std, bgr_to_rgb, rgb_to_bgr) + self._setup_normalization(mean, std) + self._channel_conversion = bgr_to_rgb or rgb_to_bgr + self.pad_size_divisor = pad_size_divisor + self.pad_value = pad_value + + def _validate_params(self, mean, std, bgr_to_rgb, rgb_to_bgr): + if bgr_to_rgb and rgb_to_bgr: + raise ValueError("Cannot set both bgr_to_rgb and rgb_to_bgr to True") + if (mean is None) != (std is None): + raise ValueError("mean and std must both be None or both be provided") + + def _setup_normalization(self, mean, std): + if mean is None: + self._enable_normalize = False + return + + if len(mean) not in [1, 3] or len(std) not in [1, 3]: + raise ValueError("mean and std must have 1 or 3 values") + + self._enable_normalize = True + self.register_buffer("mean", torch.tensor(mean).view(-1, 1, 1), False) + self.register_buffer("std", torch.tensor(std).view(-1, 1, 1), False) + + def _process_single_image(self, img: torch.Tensor) -> torch.Tensor: + if img.dtype not in [torch.uint8, torch.float16, torch.float32, torch.float64]: + raise TypeError(f"Unsupported image dtype: {img.dtype}") + + # Handle batched input (NCHW) + if img.dim() == 4: + if img.shape[1] != 3: + raise ValueError(f"Expected 3 channels in dim=1, got {img.shape}") + img = img.float() + if self._channel_conversion: + img = img[:, [2, 1, 0], ...] # BGR<->RGB + if self._enable_normalize: + img = (img - self.mean[None]) / self.std[None] + return img + + # Handle single image (CHW) + elif img.dim() == 3: + if img.shape[0] != 3: + raise ValueError(f"Expected 3 channels in dim=0, got {img.shape}") + img = img.float() + if self._channel_conversion: + img = img[[2, 1, 0], ...] 
+ if self._enable_normalize: + img = (img - self.mean) / self.std + return img + + else: + raise ValueError(f"Expected 3D or 4D tensor, got shape {img.shape}") + + def _pad_tensor(self, tensor: torch.Tensor) -> torch.Tensor: + if self.pad_size_divisor <= 1: + return tensor + + h, w = tensor.shape[-2:] + target_h = math.ceil(h / self.pad_size_divisor) * self.pad_size_divisor + target_w = math.ceil(w / self.pad_size_divisor) * self.pad_size_divisor + + pad_h = target_h - h + pad_w = target_w - w + + if pad_h == 0 and pad_w == 0: + return tensor + + return F.pad(tensor, (0, pad_w, 0, pad_h), "constant", self.pad_value) + + def forward(self, data: dict) -> dict: + data = self.cast_data(data, device=self.mean.device) + inputs = data["inputs"] + + if self.is_seq_of(inputs, torch.Tensor): + # Process list of individual images + processed_imgs = [self._process_single_image(img) for img in inputs] + batch_inputs = self.stack_batch( + processed_imgs, self.pad_size_divisor, self.pad_value + ) + elif isinstance(inputs, torch.Tensor): + # Process batched tensor + if inputs.dim() == 4: + batch_inputs = self._process_single_image(inputs) + batch_inputs = self._pad_tensor(batch_inputs) + elif inputs.dim() == 5: + # inputs: (B, V, C, H, W) + B, V, C, H, W = inputs.shape + flat_inputs = inputs.view(B * V, C, H, W) + + processed = self._process_single_image(flat_inputs) + processed = self._pad_tensor(processed) + + batch_inputs = processed.view( + B, V, C, processed.shape[-2], processed.shape[-1] + ) + elif inputs.dim() == 3: + # Single image (C, H, W), unsqueeze to (1, C, H, W) + img = inputs.unsqueeze(0) + processed = self._process_single_image(img) + batch_inputs = self._pad_tensor(processed) + else: + raise ValueError( + f"Expected 3D, 4D or 5D tensor, got shape {inputs.shape}" + ) + else: + raise TypeError(f"Expected tensor or list of tensors, got {type(inputs)}") + + data["inputs"] = batch_inputs + data.setdefault("data_samples", None) + return data diff --git a/sapiens/engine/datasets/transforms/__init__.py b/sapiens/engine/datasets/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a1f8ec957f8e5a1dd78a84f649d51b8dbf1dbbe3 --- /dev/null +++ b/sapiens/engine/datasets/transforms/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .base_transform import BaseTransform, to_tensor +from .common_transforms import ( + ImagePackInputs, + ImageResize, + PhotoMetricDistortion, + RandomPhotoMetricDistortion, +) + +__all__ = [ + "to_tensor", + "BaseTransform", + "ImageResize", + "ImagePackInputs", + "PhotoMetricDistortion", + "RandomPhotoMetricDistortion", +] diff --git a/sapiens/engine/datasets/transforms/base_transform.py b/sapiens/engine/datasets/transforms/base_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..4728e752e9204397ac89ee3da47d83c6c490bfcb --- /dev/null +++ b/sapiens/engine/datasets/transforms/base_transform.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +from abc import ABCMeta, abstractmethod +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch + + +def to_tensor( + data: Union[torch.Tensor, np.ndarray, Sequence, int, float], +) -> torch.Tensor: + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, np.ndarray): + return torch.from_numpy(data) + elif isinstance(data, Sequence): + return torch.tensor(data) + elif isinstance(data, int): + return torch.LongTensor([data]) + elif isinstance(data, float): + return torch.FloatTensor([data]) + else: + raise TypeError(f"type {type(data)} cannot be converted to tensor.") + + +class BaseTransform(metaclass=ABCMeta): + def __call__(self, results: Dict) -> Optional[Union[Dict, Tuple[List, List]]]: + return self.transform(results) + + @abstractmethod + def transform(self, results: Dict) -> Optional[Union[Dict, Tuple[List, List]]]: + pass diff --git a/sapiens/engine/datasets/transforms/common_transforms.py b/sapiens/engine/datasets/transforms/common_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..082e3cbfe673c251109aef5767c5eb8d5995baaa --- /dev/null +++ b/sapiens/engine/datasets/transforms/common_transforms.py @@ -0,0 +1,447 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import random +from typing import Dict, List, Optional, Sequence + +import cv2 +import numpy as np +import torchvision.transforms as T +from sapiens.registry import TRANSFORMS + +from .base_transform import BaseTransform, to_tensor + + +@TRANSFORMS.register_module() +class ImageResize(BaseTransform): + def __init__( + self, + image_height: int, + image_width: int, + ): + self.image_height = image_height + self.image_width = image_width + + def transform(self, results: Dict) -> Optional[Dict]: + image = results["image"] + image = cv2.resize( + image, (self.image_width, self.image_height), interpolation=cv2.INTER_AREA + ) + results["image"] = image + + return results + + +@TRANSFORMS.register_module() +class ImagePackInputs(BaseTransform): + def __init__(self, meta_keys: List[str]): + self.meta_keys = meta_keys + self.to_tensor = T.ToTensor() + + def transform(self, results: Dict) -> Optional[Dict]: + packed_results = dict() + + raw_image = results["image"] + image = raw_image.copy() + if len(image.shape) < 3: + image = np.expand_dims(image, -1) + if not image.flags.c_contiguous: + image = to_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) + else: + image = image.transpose(2, 0, 1) + image = to_tensor(image).contiguous() + packed_results["inputs"] = image + + data_samples = dict() + + # Pack the specified meta keys + for key in self.meta_keys: + if key in results: + data_samples[key] = results[key] + + data_samples["image"] = self.to_tensor(raw_image) + packed_results["data_samples"] = data_samples + + return packed_results + + +@TRANSFORMS.register_module() +class PhotoMetricDistortion(BaseTransform): + def __init__( + self, + brightness_delta: int = 32, + contrast_range: Sequence[float] = (0.5, 1.5), + saturation_range: Sequence[float] = (0.5, 1.5), + hue_delta: int = 18, + ): + self.brightness_delta = brightness_delta + self.contrast_lower, self.contrast_upper = contrast_range + self.saturation_lower, self.saturation_upper = saturation_range + self.hue_delta = hue_delta + + def convert(self, img: np.ndarray, alpha: int = 1, beta: int = 0) -> np.ndarray: + img = 
img.astype(np.float32) * alpha + beta + img = np.clip(img, 0, 255) + return img.astype(np.uint8) + + def brightness(self, img: np.ndarray) -> np.ndarray: + if random.randint(0, 1): + return self.convert( + img, beta=random.uniform(-self.brightness_delta, self.brightness_delta) + ) + return img + + def contrast(self, img: np.ndarray) -> np.ndarray: + if random.randint(0, 1): + return self.convert( + img, alpha=random.uniform(self.contrast_lower, self.contrast_upper) + ) + return img + + def saturation(self, img: np.ndarray) -> np.ndarray: + if random.randint(0, 1): + img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) + img[:, :, 1] = self.convert( + img[:, :, 1], + alpha=random.uniform(self.saturation_lower, self.saturation_upper), + ) + img = cv2.cvtColor(img, cv2.COLOR_HSV2BGR) + return img + + def hue(self, img: np.ndarray) -> np.ndarray: + if random.randint(0, 1): + img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) + img[:, :, 0] = ( + img[:, :, 0].astype(int) + + random.randint(-self.hue_delta, self.hue_delta) + ) % 180 + img = cv2.cvtColor(img, cv2.COLOR_HSV2BGR) + return img + + def transform(self, results: dict) -> dict: + img = results["img"] + # random brightness + img = self.brightness(img) + + # mode == 0 --> do random contrast first + # mode == 1 --> do random contrast last + mode = random.randint(0, 1) + if mode == 1: + img = self.contrast(img) + + # random saturation + img = self.saturation(img) + + # random hue + img = self.hue(img) + + # random contrast + if mode == 0: + img = self.contrast(img) + + results["img"] = img + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += ( + f"(brightness_delta={self.brightness_delta}, " + f"contrast_range=({self.contrast_lower}, " + f"{self.contrast_upper}), " + f"saturation_range=({self.saturation_lower}, " + f"{self.saturation_upper}), " + f"hue_delta={self.hue_delta})" + ) + return repr_str + + +@TRANSFORMS.register_module() +class RandomPhotoMetricDistortion(PhotoMetricDistortion): + def __init__( + self, + prob: float = 0.5, + **kwargs, + ): + super().__init__(**kwargs) + self.prob = prob + + def transform(self, results: Dict) -> Optional[Dict]: + if np.random.rand() > self.prob: + return results + return super().transform(results) + + +@TRANSFORMS.register_module() +class RandomDownUpSampleImage(BaseTransform): + _INTERP_LIST = [ + cv2.INTER_NEAREST, + cv2.INTER_LINEAR, + cv2.INTER_CUBIC, + cv2.INTER_AREA, + cv2.INTER_LANCZOS4, + ] + + def __init__(self, scale_range=(0.1, 0.5), prob=0.4): + super().__init__() + self.scale_range = scale_range + self.prob = prob + + def transform(self, results: dict) -> dict: + if np.random.rand() > self.prob: + return results # Skip with probability (1 - prob) + + img = results["img"] + orig_h, orig_w = img.shape[:2] + + # Pick a random factor in [min_scale, max_scale] + min_scale, max_scale = self.scale_range + scale_factor = np.random.uniform(min_scale, max_scale) + + # Randomly select interpolation modes for downsampling and upsampling + down_interp = random.choice(self._INTERP_LIST) + up_interp = random.choice(self._INTERP_LIST) + + # Compute downsample size + down_w = max(1, int(orig_w * scale_factor)) + down_h = max(1, int(orig_h * scale_factor)) + + # Downsample + img_down = cv2.resize(img, (down_w, down_h), interpolation=down_interp) + img_up = cv2.resize(img_down, (orig_w, orig_h), interpolation=up_interp) + + # Replace the original image with the heavily down-up-sampled version + results["img"] = img_up + return results + + def __repr__(self): + return ( + 
f"{self.__class__.__name__}(scale_range={self.scale_range}, " + f"prob={self.prob})" + ) + + +@TRANSFORMS.register_module() +class RandomGaussianBlur(BaseTransform): + def __init__(self, prob=0.4, kernel_size=(3, 3), sigma_range=(0.1, 2.0)): + super().__init__() + self.prob = prob + self.kernel_size = kernel_size + self.sigma_range = sigma_range + + def transform(self, results: dict) -> dict: + if np.random.rand() > self.prob: + return results + + img = results["img"] + if self.sigma_range is not None: + sigma = np.random.uniform(self.sigma_range[0], self.sigma_range[1]) + else: + sigma = 0 # OpenCV auto-calculates + + blurred = cv2.GaussianBlur(img, self.kernel_size, sigma) + results["img"] = blurred + return results + + def __repr__(self): + return ( + f"{self.__class__.__name__}(prob={self.prob}, " + f"kernel_size={self.kernel_size}, sigma_range={self.sigma_range})" + ) + + +@TRANSFORMS.register_module() +class RandomJPEGCompression(BaseTransform): + def __init__(self, prob=0.4, quality_range=(30, 60)): + super().__init__() + self.prob = prob + self.quality_range = quality_range + + def transform(self, results: dict) -> dict: + if np.random.rand() > self.prob: + return results + + img = results["img"] + q_min, q_max = self.quality_range + quality = np.random.randint(q_min, q_max + 1) + + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality] + success, enc_img = cv2.imencode(".jpg", img, encode_param) + if success: + dec_img = cv2.imdecode(enc_img, cv2.IMREAD_COLOR) + results["img"] = dec_img + return results + + def __repr__(self): + return ( + f"{self.__class__.__name__}(prob={self.prob}, " + f"quality_range={self.quality_range})" + ) + + +@TRANSFORMS.register_module() +class RandomGaussianNoise(BaseTransform): + def __init__(self, prob=0.4, mean=0.0, var_range=(5.0, 20.0)): + super().__init__() + self.prob = prob + self.mean = mean + self.var_range = var_range + + def transform(self, results: dict) -> dict: + if np.random.rand() > self.prob: + return results + + img = results["img"].astype(np.float32) + var = np.random.uniform(self.var_range[0], self.var_range[1]) + sigma = var**0.5 + + noise = np.random.normal(self.mean, sigma, img.shape).astype(np.float32) + noisy_img = img + noise + noisy_img = np.clip(noisy_img, 0, 255).astype(np.uint8) + + results["img"] = noisy_img + return results + + def __repr__(self): + return ( + f"{self.__class__.__name__}(prob={self.prob}, " + f"mean={self.mean}, var_range={self.var_range})" + ) + + +@TRANSFORMS.register_module() +class RandomGamma(BaseTransform): + def __init__(self, prob=0.4, gamma_range=(0.7, 1.3)): + super().__init__() + self.prob = prob + self.gamma_range = gamma_range + + def transform(self, results: dict) -> dict: + if np.random.rand() > self.prob: + return results + + img = results["img"] + gamma = np.random.uniform(self.gamma_range[0], self.gamma_range[1]) + + # Build a lookup table for [0..255] + table = ( + np.array([(i / 255.0) ** gamma * 255 for i in range(256)]) + .clip(0, 255) + .astype(np.uint8) + ) + img_corrected = cv2.LUT(img, table) + results["img"] = img_corrected + return results + + def __repr__(self): + return ( + f"{self.__class__.__name__}(prob={self.prob}, " + f"gamma_range={self.gamma_range})" + ) + + +@TRANSFORMS.register_module() +class RandomGrayscale(BaseTransform): + def __init__(self, prob=0.4): + super().__init__() + self.prob = prob + + def transform(self, results: dict) -> dict: + if np.random.rand() > self.prob: + return results + + img = results["img"] + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 
+ gray_3ch = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR) + results["img"] = gray_3ch + return results + + def __repr__(self): + return f"{self.__class__.__name__}(prob={self.prob})" + + +@TRANSFORMS.register_module() +class RandomChannelShuffle(BaseTransform): + def __init__(self, prob=0.4): + super().__init__() + self.prob = prob + + def transform(self, results: dict) -> dict: + if np.random.rand() > self.prob: + return results + + img = results["img"] + channels = [0, 1, 2] + np.random.shuffle(channels) + img = img[..., channels] + results["img"] = img + return results + + def __repr__(self): + return f"{self.__class__.__name__}(prob={self.prob})" + + +@TRANSFORMS.register_module() +class RandomInvert(BaseTransform): + def __init__(self, prob=0.4): + super().__init__() + self.prob = prob + + def transform(self, results: dict) -> dict: + if np.random.rand() > self.prob: + return results + + img = results["img"] + results["img"] = 255 - img + return results + + def __repr__(self): + return f"{self.__class__.__name__}(prob={self.prob})" + + +@TRANSFORMS.register_module() +class RandomSolarize(BaseTransform): + def __init__(self, prob=0.4, threshold=128): + super().__init__() + self.prob = prob + self.threshold = threshold + + def transform(self, results: dict) -> dict: + if np.random.rand() > self.prob: + return results + + img = results["img"] + mask = img > self.threshold + img[mask] = 255 - img[mask] + results["img"] = img + return results + + def __repr__(self): + return ( + f"{self.__class__.__name__}(prob={self.prob}, threshold={self.threshold})" + ) + + +@TRANSFORMS.register_module() +class RandomPosterize(BaseTransform): + def __init__(self, prob=0.4, bits=(2, 5)): + super().__init__() + self.prob = prob + self.bits = bits + + def transform(self, results: dict) -> dict: + if np.random.rand() > self.prob: + return results + + img = results["img"] + # pick random bits + bits_chosen = random.randint(self.bits[0], self.bits[1]) + shift = 8 - bits_chosen + img = (img >> shift) << shift + results["img"] = img + return results + + def __repr__(self): + return f"{self.__class__.__name__}(prob={self.prob}, bits={self.bits})" diff --git a/sapiens/engine/evaluators/__init__.py b/sapiens/engine/evaluators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f19020eb66d2783207a825eb64801beb9fc5feea --- /dev/null +++ b/sapiens/engine/evaluators/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .base_evaluator import BaseEvaluator +from .eval_collate import eval_collate + +__all__ = ["eval_collate", "BaseEvaluator"] diff --git a/sapiens/engine/evaluators/base_evaluator.py b/sapiens/engine/evaluators/base_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..bb880d5b96a9dfd10c812cc782314f6a47b4ad73 --- /dev/null +++ b/sapiens/engine/evaluators/base_evaluator.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Any, Dict, List, Optional, Union + +import torch +import torch.distributed as dist +from sapiens.registry import MODELS + + +@MODELS.register_module() +class BaseEvaluator: + def __init__(self, dtype: torch.dtype = torch.float32): + assert torch.cuda.is_available(), "CUDA is required for evaluation" + self.device = torch.device("cuda", torch.cuda.current_device()) + self.dtype = dtype + self.results = [] + + def reset(self): + self.results: List[Union[Dict[str, Any], List[Any], tuple]] = [] + + def process(self, outputs, data_samples): + raise NotImplementedError + + def evaluate(self): + raise NotImplementedError diff --git a/sapiens/engine/evaluators/eval_collate.py b/sapiens/engine/evaluators/eval_collate.py new file mode 100644 index 0000000000000000000000000000000000000000..5a4bd5577d667fc7db7c4f21790a8a2397e76dd7 --- /dev/null +++ b/sapiens/engine/evaluators/eval_collate.py @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from sapiens.registry import MODELS +from torch.utils.data import default_collate + + +@MODELS.register_module() +def eval_collate(batch: list): + passthrough_keys = {"data_samples"} + collated_data, passthrough_data = [], {key: [] for key in passthrough_keys} + for item in batch: + item_for_collation = { + k: v for k, v in item.items() if k not in passthrough_keys + } + for key in passthrough_keys: + passthrough_data[key].append(item[key]) + collated_data.append(item_for_collation) + final_batch = default_collate(collated_data) + final_batch.update(passthrough_data) + return final_batch diff --git a/sapiens/engine/logger.py b/sapiens/engine/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..f9d33bc6d1daccae77d62c72df809ad70e3b54f2 --- /dev/null +++ b/sapiens/engine/logger.py @@ -0,0 +1,189 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import sys +from datetime import datetime +from logging import Logger as BaseLogger, LogRecord +from pathlib import Path +from typing import Optional, Union + +from sapiens.registry import LOGGERS +from termcolor import colored + + +class FilterDuplicateWarning(logging.Filter): + def __init__(self, name: str = "sapiens"): + super().__init__(name) + self.seen: set = set() + + def filter(self, record: LogRecord) -> bool: + """Filter the repeated warning message.""" + if record.levelno != logging.WARNING: + return True + + if record.msg not in self.seen: + self.seen.add(record.msg) + return True + return False + + +class Formatter(logging.Formatter): + """Colorful format forLogger.""" + + _color_mapping: dict = dict( + ERROR="red", WARNING="yellow", INFO="white", DEBUG="green" + ) + + def __init__(self, color: bool = True, blink: bool = False, **kwargs): + super().__init__(**kwargs) + assert not (not color and blink), ( + "blink should only be available when color is True" + ) + + # Get prefix format according to color. + error_prefix = self._get_prefix("ERROR", color, blink=True) + warn_prefix = self._get_prefix("WARNING", color, blink=True) + info_prefix = self._get_prefix("INFO", color, blink) + debug_prefix = self._get_prefix("DEBUG", color, blink) + + # Config output format. 
+ self.err_format = ( + f"%(asctime)s - %(name)s - {error_prefix} - " + "%(pathname)s - %(funcName)s - %(lineno)d - " + "%(message)s" + ) + self.warn_format = f"%(asctime)s - %(name)s - {warn_prefix} - %(message)s" + self.info_format = f"%(asctime)s - %(name)s - {info_prefix} - %(message)s" + self.debug_format = f"%(asctime)s - %(name)s - {debug_prefix} - %(message)s" + + def _get_prefix(self, level: str, color: bool, blink=False) -> str: + """Get the prefix of the target log level.""" + if color: + attrs = ["underline"] + if blink: + attrs.append("blink") + prefix = colored(level, self._color_mapping[level], attrs=attrs) + else: + prefix = level + return prefix + + def format(self, record: LogRecord) -> str: + """Override the logging.Formatter.format method.""" + if record.levelno == logging.ERROR: + self._style._fmt = self.err_format + elif record.levelno == logging.WARNING: + self._style._fmt = self.warn_format + elif record.levelno == logging.INFO: + self._style._fmt = self.info_format + elif record.levelno == logging.DEBUG: + self._style._fmt = self.debug_format + + result = logging.Formatter.format(self, record) + return result + + +@LOGGERS.register_module() +class Logger(BaseLogger): + """Standalone logger.""" + + _instances = {} + + def __init__( + self, + name: str = "sapiens", + logger_name: str = "sapiens", + log_file: Optional[str] = None, + log_level: Union[int, str] = "INFO", + file_mode: str = "w", + log_interval: int = 10, + dir: Optional[Union[str, Path]] = None, + **kwargs, + ): + super().__init__(logger_name) + + self._name = name + self._log_interval = log_interval + time_log_file = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}" + + self._log_file = log_file or ( + Path(dir).mkdir(parents=True, exist_ok=True) + or Path(dir) / time_log_file / f"{time_log_file}.log" + if dir + else None + ) + # pyre-ignore + self._log_dir = os.path.dirname(self._log_file) if self._log_file else None + + log_level = ( + logging._nameToLevel[log_level] if isinstance(log_level, str) else log_level + ) + + self._add_stream_handler(log_level, logger_name) + self._add_file_handler(log_level, logger_name, file_mode) + + Logger._instances[name] = self + + def _add_stream_handler(self, level: int, logger_name: str): + handler = logging.StreamHandler(sys.stdout) + handler.setLevel(level) + handler.setFormatter(Formatter(color=True, datefmt="%m/%d %H:%M:%S")) + handler.addFilter(FilterDuplicateWarning(logger_name)) + self.addHandler(handler) + + def _add_file_handler(self, level: int, logger_name: str, mode: str): + if self._log_file is None: + return + os.makedirs(os.path.dirname(self._log_file), exist_ok=True) + handler = logging.FileHandler(self._log_file, mode) + handler.setLevel(level) + handler.setFormatter(Formatter(color=False, datefmt="%Y/%m/%d %H:%M:%S")) + handler.addFilter(FilterDuplicateWarning(logger_name)) + self.addHandler(handler) + + @property + def log_file(self): + return self._log_file + + @classmethod + def get_instance(cls, name: str, **kwargs) -> "Logger": + """Get or create a logger instance.""" + if name not in cls._instances: + cls._instances[name] = cls(name, **kwargs) + return cls._instances[name] + + @classmethod + def get_current_instance(cls) -> "Logger": + """Get the most recently created logger instance.""" + if not cls._instances: + cls.get_instance("lca") + return list(cls._instances.values())[-1] + + +def print_log( + msg, logger: Optional[Union[Logger, str]] = None, level=logging.INFO +) -> None: + if logger is None: + print(msg) + elif isinstance(logger, 
logging.Logger): + logger.log(level, msg) + elif logger == "silent": + pass + elif logger == "current": + logger_instance = Logger.get_current_instance() + logger_instance.log(level, msg) + elif isinstance(logger, str): + if logger in Logger._instances: + logger_instance = Logger.get_instance(logger) + logger_instance.log(level, msg) + else: + raise ValueError(f"Logger: {logger} has not been created!") + else: + raise TypeError( + "`logger` should be either a logging.Logger object, str, " + f'"silent", "current" or None, but got {type(logger)}' + ) diff --git a/sapiens/engine/models/__init__.py b/sapiens/engine/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1532c81c3da3dc0eb01028db208a77403aae0cdc --- /dev/null +++ b/sapiens/engine/models/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .base_model import BaseModel + +__all__ = ["BaseModel"] diff --git a/sapiens/engine/models/base_model.py b/sapiens/engine/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3933dbe0b513a32d8c67e9d90a51a98bb47f3f14 --- /dev/null +++ b/sapiens/engine/models/base_model.py @@ -0,0 +1,165 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import abc +import math +import socket +import warnings +from collections import OrderedDict +from typing import Any, Dict, List, Optional, Tuple, Union + +import safetensors.torch +import torch +from sapiens.registry import MODELS +from torch import nn, Tensor + + +def is_list_of(seq: Any, expected_type: Union[type, tuple[type, ...]]) -> bool: + """Check if sequence is list of expected type.""" + if not isinstance(seq, list): + return False + return all(isinstance(item, expected_type) for item in seq) + + +def _no_grad_trunc_normal_( + tensor: Tensor, mean: float, std: float, a: float, b: float +) -> Tensor: + def norm_cdf(x): + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. 
" + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + with torch.no_grad(): + lower = norm_cdf((a - mean) / std) + upper = norm_cdf((b - mean) / std) + tensor.uniform_(2 * lower - 1, 2 * upper - 1) + tensor.erfinv_() + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_( + tensor: Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 +) -> Tensor: + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +# ---------------------------------------------------------------------------- +@MODELS.register_module() +class BaseModel(nn.Module, abc.ABC): + def __init__( + self, + init_cfg: Optional[dict] = None, + ): + super().__init__() + self.init_cfg = init_cfg + + def init_weights(self): + if self.init_cfg is None: + return + + if not isinstance(self.init_cfg, dict): + raise TypeError(f"init_cfg must be a dict, got {type(self.init_cfg)}") + + init_type = self.init_cfg.get("type", "") + + if init_type == "Pretrained": + checkpoint_path = self.init_cfg.get("checkpoint") + if checkpoint_path is None: + raise ValueError( + "checkpoint path must be provided for Pretrained init_cfg" + ) + self._load_checkpoint(checkpoint_path) + elif init_type == "": + raise ValueError("init_cfg must specify a 'type' field") + else: + raise ValueError(f"Unsupported init_cfg type: {init_type}") + + def _load_checkpoint(self, checkpoint_path: str): + """Load model weights from checkpoint.""" + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + servername = socket.gethostname().split(".")[0] + + if rank == 0: + from sapiens.engine.logger import Logger + + logger = Logger.get_current_instance() + logger.info(f"Loading checkpoint from {checkpoint_path} on {servername}.") + + try: + if checkpoint_path.endswith(".safetensors"): + state_dict = safetensors.torch.load_file(checkpoint_path) + else: + checkpoint = torch.load( + checkpoint_path, map_location="cpu", weights_only=False + ) + + # Handle different checkpoint formats + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + elif "model" in checkpoint: + state_dict = checkpoint["model"] + elif "teacher" in checkpoint: + state_dict = checkpoint["teacher"] + # Remove 'backbone.' 
prefix from state_dict keys if present + state_dict = { + key.replace("backbone.", "", 1) + if key.startswith("backbone.") + else key: value + for key, value in state_dict.items() + } + else: + state_dict = checkpoint + + # Load state dict with strict=False to allow partial loading + missing_keys, unexpected_keys = self.load_state_dict( + state_dict, strict=False + ) + + if missing_keys and rank == 0: + logger.warning(f"Missing keys when loading checkpoint: {missing_keys}") + if unexpected_keys and rank == 0: + logger.warning( + f"Unexpected keys when loading checkpoint: {unexpected_keys}" + ) + + if rank == 0: + logger.info(f"Checkpoint {checkpoint_path} loaded successfully!") + + except Exception as e: + raise RuntimeError(f"Failed to load checkpoint from {checkpoint_path}: {e}") + + @abc.abstractmethod + def forward(self, inputs: torch.Tensor): + """Forward function.""" + pass + + def parse_losses( + self, losses: Dict[str, torch.Tensor] + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + log_vars = [] + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, torch.Tensor): + log_vars.append([loss_name, loss_value.mean()]) + elif is_list_of(loss_value, torch.Tensor): + log_vars.append([loss_name, sum(_loss.mean() for _loss in loss_value)]) + else: + raise TypeError(f"{loss_name} is not a tensor or list of tensors") + + loss = sum(value for key, value in log_vars if "loss" in key) + log_vars.insert(0, ["loss", loss]) + log_vars = OrderedDict(log_vars) # type: ignore + + return loss, log_vars # type: ignore diff --git a/sapiens/engine/optim/__init__.py b/sapiens/engine/optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ffa28a16e777b02cae577c64e473e442e5549c1a --- /dev/null +++ b/sapiens/engine/optim/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .lr_scheduler import * +from .optimizer import * diff --git a/sapiens/engine/optim/lr_scheduler.py b/sapiens/engine/optim/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..30514048e738e847383954b43545711c112d176f --- /dev/null +++ b/sapiens/engine/optim/lr_scheduler.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from sapiens.registry import SCHEDULERS +from torch.optim.lr_scheduler import ( + _LRScheduler, + ConstantLR, + CosineAnnealingLR, + ExponentialLR, + LinearLR, + MultiStepLR, + PolynomialLR, + SequentialLR as _SequentialLR, + StepLR, +) + +SCHEDULERS.register_module(name="LinearLR")(LinearLR) +SCHEDULERS.register_module(name="PolynomialLR")(PolynomialLR) +SCHEDULERS.register_module(name="CosineAnnealingLR")(CosineAnnealingLR) +SCHEDULERS.register_module(name="ConstantLR")(ConstantLR) +SCHEDULERS.register_module(name="StepLR")(StepLR) +SCHEDULERS.register_module(name="MultiStepLR")(MultiStepLR) +SCHEDULERS.register_module(name="ExponentialLR")(ExponentialLR) + + +# ------------------------------------------------------------------------- # +@SCHEDULERS.register_module(name="SequentialLR") +class SequentialLR(_SequentialLR): + """SequentialLR that accepts inner schedulers as config dicts. 
+ + Example (iteration based): + + ```python + warmup_iters = 400 + param_scheduler = dict( + type="SequentialLR", + milestones=[warmup_iters], + schedulers=[ + dict(type="LinearLR", start_factor=1e-3, + total_iters=warmup_iters), + dict(type="PolynomialLR", total_iters=num_iters-warmup_iters, + power=1.0), + ], + ) + ``` + """ + + def __init__( + self, + optimizer, + schedulers, + milestones, + last_epoch: int = -1, + ): + built = [ + s + if isinstance(s, _LRScheduler) + else SCHEDULERS.build(s, optimizer=optimizer) + for s in schedulers + ] + super().__init__( + optimizer, + schedulers=built, + milestones=milestones, + last_epoch=last_epoch, + ) diff --git a/sapiens/engine/optim/optimizer.py b/sapiens/engine/optim/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..989ee34e2c71146438d18e1e6f8f0ab32545b1fc --- /dev/null +++ b/sapiens/engine/optim/optimizer.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from sapiens.registry import OPTIMIZERS +from torch.optim import Adam, AdamW, SGD + +OPTIMIZERS.register_module(name="AdamW")(AdamW) +OPTIMIZERS.register_module(name="Adam")(Adam) +OPTIMIZERS.register_module(name="SGD")(SGD) diff --git a/sapiens/engine/runners/__init__.py b/sapiens/engine/runners/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..106e7b10d1b330327e59787795eeced9bbbcccb9 --- /dev/null +++ b/sapiens/engine/runners/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .base_runner import BaseRunner + +__all__ = ["BaseRunner"] diff --git a/sapiens/engine/runners/base_runner.py b/sapiens/engine/runners/base_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..0f9b2c576ee07b75bc63279c30ee51446f20f108 --- /dev/null +++ b/sapiens/engine/runners/base_runner.py @@ -0,0 +1,984 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
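A hypothetical launch sketch for the `BaseRunner` defined below, wiring together `Config.fromfile`, `DictAction` overrides, and the runner's keyword-only constructor. The config path is a placeholder and this is not an entry point shipped in this diff; it only illustrates how the pieces are meant to compose:

```python
import argparse

from sapiens.engine import BaseRunner, Config, DictAction


def main() -> None:
    parser = argparse.ArgumentParser(description="Train with a Python config")
    parser.add_argument("config", help="path to a Python config file")
    parser.add_argument("--cfg-options", nargs="+", action=DictAction, default={},
                        help="override config entries, e.g. optimizer.lr=1e-4")
    args = parser.parse_args()

    cfg = Config.fromfile(args.config)     # execute the config module into a Config
    cfg.merge_from_dict(args.cfg_options)  # apply dotted-key overrides

    # BaseRunner expects the flattened config entries (model, work_dir,
    # accelerator_cfg, optimizer, ...) as keyword arguments; extras are ignored.
    runner = BaseRunner(cfg=cfg.to_dict(), **cfg.to_dict())
    runner.train()


if __name__ == "__main__":
    main()
```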
+ +import datetime +import os +import random +import reprlib +import time +from collections import defaultdict +from pathlib import Path +from typing import Any, Dict + +import numpy as np +import torch +import torch.distributed as dist +from accelerate import Accelerator +from accelerate.parallelism_config import ParallelismConfig +from accelerate.utils import ( + DistributedDataParallelKwargs, + FullyShardedDataParallelPlugin, + TorchDynamoPlugin, +) +from safetensors.torch import load_file +from sapiens.registry import ( + DATASETS, + LOGGERS, + MODELS, + OPTIMIZERS, + SCHEDULERS, + VISUALIZERS, +) +from torch import nn +from torch.distributed.fsdp import MixedPrecisionPolicy +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from torch.utils.data import DataLoader + +from ..config import pretty_text + +_repr = reprlib.Repr() +_repr.maxlist = 10 + + +# --------------------------------------------------------------------------- +class BaseRunner: + def __init__( + self, + *, + model: dict | nn.Module, + work_dir: str, + train_dataloader: dict | DataLoader | None = None, + val_dataloader: dict | None = None, + val_cfg: dict | None = None, + data_preprocessor: dict | None = None, + accelerator_cfg: Dict[str, Any], + optimizer: dict | torch.optim.Optimizer, + scheduler: dict | None = None, + clip_grad: Dict[str, Any] | None = None, + logger: dict | None = None, + checkpoint: dict | None = None, + visualizer: dict | None = None, + randomness: Dict[str, Any] | None = None, + cfg: Dict[str, Any] | None = None, + **_ignored, + ) -> None: + self.cfg = cfg + self.work_dir = Path(work_dir).resolve() + self.work_dir.mkdir(parents=True, exist_ok=True) + self._init_env() + self._set_seed(randomness or {}) + self._init_logger(logger=logger) + self._log_config() + self._init_accelerator(accelerator_cfg) + + # train dataloader + self.train_dataloader = None + if train_dataloader is not None: + train_dataset = DATASETS.build(train_dataloader["dataset"]) + self.train_dataloader = DataLoader( + train_dataset, + batch_size=train_dataloader.get("batch_size", 1), + shuffle=train_dataloader.get("shuffle", True), + num_workers=train_dataloader.get("num_workers", 0), + persistent_workers=train_dataloader.get("persistent_workers", True), + pin_memory=train_dataloader.get("pin_memory", True), + ) + + # val dataloader + self.val_dataloader = None + if val_dataloader is not None and val_cfg is not None: + val_dataset = DATASETS.build(val_dataloader["dataset"]) + collate_fn_cfg = val_dataloader.get("collate_fn") + collate_fn_obj = ( + MODELS.get(collate_fn_cfg["type"]) if collate_fn_cfg else None + ) + self.val_dataloader = DataLoader( + val_dataset, + batch_size=val_dataloader.get("batch_size", 1), + shuffle=val_dataloader.get("shuffle", False), + num_workers=val_dataloader.get("num_workers", 0), + persistent_workers=val_dataloader.get("persistent_workers", True), + pin_memory=val_dataloader.get("pin_memory", True), + collate_fn=collate_fn_obj, + multiprocessing_context=val_dataloader.get( + "multiprocessing_context", None + ), + ) + self.val_cfg = val_cfg + self.val_every = self.val_cfg.get("val_interval", 100) + self.evaluator = MODELS.build(self.val_cfg["evaluator"]) + + self.data_preprocessor = MODELS.build(data_preprocessor) # data_preprocessor + self.model = MODELS.build(model) + + # optimizer, scheduler, clip_grad + self.optimizer = self._build_optimizer(optimizer) + self.scheduler = SCHEDULERS.build(scheduler, optimizer=self.optimizer) + self.clip_grad = clip_grad # clip_grad + + 
self.visualizer = None + if self.train_dataloader is not None: + self.visualizer = ( + VISUALIZERS.build( + {**visualizer, "output_dir": self.work_dir / "vis_data"} + ) + if visualizer + else None + ) + + # prepare + self._prepare_accelerator() + self._print_model() + + ## logging params + self.log_every = self.logger._log_interval if self.logger else 0 + self.save_every = (checkpoint or {}).get("save_interval", 0) + self.vis_every = self.visualizer.vis_interval if self.visualizer else 0 + + # -------------------------------------------------------------------------- + def train(self) -> None: + self.model.train() + + data_iter = iter(self.train_dataloader) + + while self.iter < self.max_iters: + t = time.time() + + if not self.gpu_profiler_disabled: + self.gpu_profiler.before_step() + + try: + data_batch = next(data_iter) + except StopIteration: + data_iter = iter(self.train_dataloader) + data_batch = next(data_iter) + data_time = time.time() - t + + # ------------------------------------------------------ + with self.accelerator.autocast(), self.accelerator.accumulate(self.model): + t = time.time() + + loss, logs = self.forward(data_batch) + self.accelerator.backward(loss) # backward + + # step + grad_norm = self._clip_gradients() + self.optimizer.step() + self.scheduler.step() + self.optimizer.zero_grad() + + iter_time = time.time() - t + + # ------------------------------------------------------ + self.iter += 1 + + if not self.gpu_profiler_disabled: + self.gpu_profiler.after_step() + + # ------------------------------------------------------ + if self.save_every and self.iter % self.save_every == 0 and self.iter > 0: + self._save_checkpoint(f"iter_{self.iter}") + + # ------------------------------------------------------ + if ( + self.visualizer + and self.iter % self.vis_every == 0 + and self.accelerator.is_main_process + ): + self.visualizer.add_batch(data_batch, logs, step=self.iter) + self.logger.info(f"\033[96mVisualized iter {self.iter}\033[0m") + + # ------------------------------------------------------ + if self.val_dataloader is not None and self.iter % self.val_every == 0: + val_metrics = self.val() + logs["val_metrics"] = val_metrics + + if self.accelerator.is_main_process: + self._log_iter( + logs=logs, + iter_time=iter_time, + data_time=data_time, + grad_norm=grad_norm, + ) + + # ------------------------------------------------- + self._save_checkpoint("final") + + self.accelerator.save_model(self.model, self.work_dir / "checkpoints") + self.accelerator.end_training() + + if self.accelerator.is_main_process: + self.logger.info("\033[92mTraining finished โœ”\033[0m") + + # ------------------------------------------------------------------------- + def forward(self, data_batch: dict) -> tuple[float, dict]: + data_batch = self.data_preprocessor(data_batch) # preprocess + inputs, data_samples = data_batch["inputs"], data_batch["data_samples"] + + if self.pc is not None: + pred = self.model(inputs, cp=self.accelerator.maybe_context_parallel) + else: + pred = self.model(inputs) # forward + + loss, logs = self.raw_model.loss(pred, data_samples) + return loss, logs + + # ------------------------------------------------------------------------- + def test(self) -> None: + if self.accelerator.is_main_process: + self.logger.info(f"\033[95mStarting test...\033[0m") + + self.model.eval() + self.evaluator.reset() + for i, data_batch in enumerate(self.val_dataloader): + data_batch = self.data_preprocessor(data_batch) # preprocess + inputs, data_samples = data_batch["inputs"], 
data_batch["data_samples"] + + with torch.no_grad(): + if self.pc is not None: + pred = self.model( + inputs, cp=self.accelerator.maybe_context_parallel + ) + else: + pred = self.model(inputs) # forward + + if self.accelerator.is_main_process and i > 0 and i % 100 == 0: + self.logger.info( + f"\033[95mTest: {i}/{len(self.val_dataloader)}: batch_size: {len(data_batch['inputs'])}\033[0m" + ) + self.evaluator.process( + pred, data_samples, accelerator=self.accelerator + ) ## accelerator used to gather and dedup in val + + # metrics eval on main process + metrics = self.evaluator.evaluate( + logger=self.logger, accelerator=self.accelerator + ) + if self.accelerator.is_main_process: + self.logger.info( + f"\033[95mTest: {', '.join([f'{k}: {v:.4f}' for k, v in metrics.items()])}\033[0m" + ) + self.logger.info(f"\033[95mTesting finished โœ”\033[0m") + + # ------------------------------------------------------------------------- + def val(self) -> None: + self.model.eval() + + if self.accelerator.is_main_process: + self.logger.info(f"\033[95mValidating iter {self.iter}\033[0m") + + self.evaluator.reset() + for i, data_batch in enumerate(self.val_dataloader): + data_batch = self.data_preprocessor(data_batch) # preprocess + inputs, data_samples = data_batch["inputs"], data_batch["data_samples"] + + with torch.no_grad(): + if self.pc is not None: + pred = self.model( + inputs, cp=self.accelerator.maybe_context_parallel + ) + else: + pred = self.model(inputs) # forward + + if self.accelerator.is_main_process and i > 0 and i % 100 == 0: + self.logger.info( + f"\033[95mVal: {i}/{len(self.val_dataloader)}: batch_size: {len(data_batch['inputs'])}\033[0m" + ) + + self.evaluator.process(pred, data_samples, accelerator=self.accelerator) + metric = self.evaluator.evaluate( + logger=self.logger, accelerator=self.accelerator + ) + self.model.train() + return metric + + # -------------------------------------------------------------------------- + def _clip_gradients(self) -> float | None: + if not self.clip_grad or not self.accelerator.sync_gradients: + return None + + max_norm = float(self.clip_grad.get("max_norm", 1.0)) + norm_type = float(self.clip_grad.get("norm_type", 2.0)) + total_norm = self.accelerator.clip_grad_norm_( + self.model.parameters(), max_norm, norm_type + ) + + return total_norm + + def _log_iter(self, *, logs, iter_time, data_time, grad_norm=None): + """Call once per iteration; prints every `self._log_every` steps.""" + log_payload = {} + if "val_metrics" in logs: + val_metrics = logs.pop("val_metrics") + log_payload.update(val_metrics) + self.logger.info( + f"\033[95mVal-Iter[{self.iter}]: {', '.join([f'{k}: {v:.4f}' for k, v in val_metrics.items()])}\033[0m" + ) + + ## aggregate losses and metrics + for key in logs: + if key.startswith("loss_") or key.startswith("acc_"): + self._loss_acc[key] += float(logs[key].item()) + + self._time_acc += iter_time + self._data_acc += data_time + + if isinstance(grad_norm, torch.Tensor): + grad_norm = grad_norm.item() + + if grad_norm is not None: + self._grad_acc += float(grad_norm) + + # log every `self._log_every` steps + if ( + self.log_every > 0 + and (self.iter % self.log_every == 0 or self.iter == self.max_iters - 1) + and self.iter > 0 + ): + k = self.log_every + avg_losses = { + key: val / k + for key, val in self._loss_acc.items() + if key.startswith("loss_") + } + total_avg_loss = sum(avg_losses.values()) + avg_time = self._time_acc / k + avg_data_time = self._data_acc / k + avg_grad = self._grad_acc / k if self._grad_acc else 0.0 + + 
eta_secs = avg_time * (self.max_iters - self.iter) + eta = str(datetime.timedelta(seconds=int(eta_secs))) + mem_mb = int(torch.cuda.max_memory_allocated() / 1024 / 1024) + + loss_str_parts = [f"{key}: {val:.4f}" for key, val in avg_losses.items()] + loss_str = f"loss: {total_avg_loss:.4f} {' '.join(loss_str_parts)}" + + acc_str = "" + for key, val in self._loss_acc.items(): + if key.startswith("acc_"): + acc_str += f"{key}: {val / k:.4f} " + if acc_str: + loss_str += f" {acc_str}" + + if ( + self.optimizer.param_groups[0]["lr"] + != self.optimizer.param_groups[-1]["lr"] + ): + decayed_lr = self.optimizer.param_groups[0]["lr"] + lr = self.optimizer.param_groups[-1]["lr"] + + lr_str = f"lr: {lr:.2e} decay_lr: {decayed_lr:.2e}" + else: + lr_str = f"lr: {self.optimizer.param_groups[0]['lr']:.2e}" + + self.logger.info( + f"Iter(train) [{self.iter}/{self.max_iters}]: " + f"{lr_str} " + f"eta: {eta} " + f"data_time: {avg_data_time:.2f} " + f"iter_time: {avg_time:.2f} " + f"memory: {mem_mb} " + f"grad_norm: {avg_grad:.2f} " + f"{loss_str}" + ) + + log_payload.update( + { + "loss": total_avg_loss, + "lr": self.optimizer.param_groups[0]["lr"], + "grad_norm": avg_grad, + "iter_time": avg_time, + "data_time": avg_data_time, + **avg_losses, # Add individual average losses + } + ) + + self.accelerator.log(log_payload, step=self.iter) + self._loss_acc.clear() + self._time_acc = self._data_acc = self._grad_acc = 0.0 + + # -------------------------------------------------------------------------- + def _save_checkpoint(self, tag: str) -> None: + checkpoint_dir = self.work_dir / "checkpoints" / tag + self.accelerator.save_state(output_dir=checkpoint_dir) + + if self.accelerator.is_main_process: + self.logger.info( + f"\033[92mCheckpoint saved โžœ {os.path.basename(checkpoint_dir)}\033[0m" + ) + + # -------------------------------------------------------------------------- + def state_dict(self) -> Dict[str, Any]: + """ + Custom state to be saved by Accelerator. + """ + return {"iter": torch.tensor(self.iter, dtype=torch.int64, device="cpu")} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """ + Load custom state saved by Accelerator. 
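+        Expects the mapping produced by `state_dict` above, i.e. {"iter": int64 tensor}.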
+ """ + self.iter = int(state_dict["iter"]) + + def _init_env(self): + """Setup distributed environment variables if not already set.""" + if "RANK" not in os.environ: + os.environ.setdefault("WORLD_SIZE", "1") + os.environ.setdefault("RANK", "0") + os.environ.setdefault("LOCAL_RANK", "0") + os.environ.setdefault("LOCAL_WORLD_SIZE", "1") + if "MASTER_ADDR" not in os.environ: + os.environ["MASTER_ADDR"] = f"127.0.0.{random.randint(1, 255)}" + if "MASTER_PORT" not in os.environ: + os.environ["MASTER_PORT"] = str(random.randint(1024, 65535)) + + def _init_accelerator(self, accelerator_cfg) -> None: + """Initialize Accelerator.""" + self.accelerator_cfg = accelerator_cfg.copy() + + compile_cfg = accelerator_cfg.pop("compile_cfg", {}) + dynamo_plugin = TorchDynamoPlugin(**compile_cfg) if compile_cfg else None + + self.dist_type = accelerator_cfg.pop("type", "DDP").upper() # "DDP" | "FSDP" + fsdp_cfg = accelerator_cfg.pop("fsdp_cfg", {}) + parallelism_cfg = accelerator_cfg.pop("parallelism_cfg", {}) + self.max_iters = int(accelerator_cfg.pop("max_interval", 1e4)) + self.pc = None + + find_unused_parameters = bool( + accelerator_cfg.pop("find_unused_parameters", False) + ) + + common_kwargs = dict( + project_dir=self.work_dir, + dynamo_plugin=dynamo_plugin, + **accelerator_cfg, + ) + + if self.dist_type == "FSDP": + policy_name = fsdp_cfg.pop("auto_wrap_policy", "none") + min_params = fsdp_cfg.pop("auto_wrap_min_num_params", 1e6) + + if policy_name == "size_based": + fsdp_cfg["min_num_params"] = min_params + elif policy_name == "transformer": + fsdp_cfg["auto_wrap_policy"] = transformer_auto_wrap_policy + + mp_cfg = fsdp_cfg.pop("mixed_precision", None) + + if mp_cfg: + _DTYPE = { + "bf16": torch.bfloat16, + "fp16": torch.float16, + "fp32": torch.float32, + } + fsdp_cfg["mixed_precision_policy"] = MixedPrecisionPolicy( + param_dtype=_DTYPE.get(mp_cfg.get("param_dtype", "fp32")), + reduce_dtype=_DTYPE.get(mp_cfg.get("reduce_dtype", "fp32")), + ) + fsdp_plugin = FullyShardedDataParallelPlugin(**fsdp_cfg) + + # https://docs.axolotl.ai/docs/nd_parallelism.html + self.pc = ( + ParallelismConfig( + **parallelism_cfg, + ) + if parallelism_cfg + else None + ) + + self.accelerator = Accelerator( + parallelism_config=self.pc, fsdp_plugin=fsdp_plugin, **common_kwargs + ) + + else: # DDP (default) + if find_unused_parameters: + common_kwargs["kwargs_handlers"] = [ + DistributedDataParallelKwargs(find_unused_parameters=True) + ] + self.accelerator = Accelerator(**common_kwargs) + + if self.logger is not None: + self.accelerator.init_trackers(self.logger._log_dir) + + def _prepare_accelerator(self) -> None: + self.iter = 0 + self._loss_acc = defaultdict(float) + self._time_acc = self._data_acc = self._grad_acc = 0.0 + + self.accelerator.register_for_checkpointing(self) + + load_from = self.cfg.get("load_from", None) # path or None + resume = self.cfg.get("resume", False) + + if load_from and not resume: + self._load_checkpoint(load_from) + + ## train + val + if self.train_dataloader is not None and self.val_dataloader is not None: + ( + self.model, + self.optimizer, + self.train_dataloader, + self.scheduler, + self.val_dataloader, + self.evaluator, + ) = self.accelerator.prepare( + self.model, + self.optimizer, + self.train_dataloader, + self.scheduler, + self.val_dataloader, + self.evaluator, + ) + ## train only + elif self.train_dataloader is not None and self.val_dataloader is None: + self.model, self.optimizer, self.train_dataloader, self.scheduler = ( + self.accelerator.prepare( + self.model, 
self.optimizer, self.train_dataloader, self.scheduler + ) + ) + ## val only + elif self.train_dataloader is None and self.val_dataloader is not None: + ( + self.model, + self.optimizer, + self.scheduler, + self.val_dataloader, + self.evaluator, + ) = self.accelerator.prepare( + self.model, + self.optimizer, + self.scheduler, + self.val_dataloader, + self.evaluator, + ) + + ## data_preprocessor + if self.data_preprocessor is not None: + model_dtype = None + if self.accelerator.mixed_precision == "fp16": + model_dtype = torch.float16 + elif self.accelerator.mixed_precision == "bf16": + model_dtype = torch.bfloat16 + + # Move to device and cast dtype simultaneously. + self.data_preprocessor = self.data_preprocessor.to( + device=self.accelerator.device, dtype=model_dtype + ) + + if load_from and resume: + self._resume(load_from) + + gradient_accumulation_steps = self.accelerator.gradient_accumulation_steps + + if gradient_accumulation_steps > 1: + self.logger.warning( + f"Gradient accumulation with {gradient_accumulation_steps} steps is not supported. " + "LR schedule will be off from expected." + ) + + self.raw_model = self.accelerator.unwrap_model(self.model) + + # -------------------------------------------------------------------------- + def _build_optimizer(self, optimizer): + optimizer_cfg = optimizer.copy() + paramwise_cfg = optimizer_cfg.pop("paramwise_cfg", None) + + if paramwise_cfg: + # Add base lr and weight_decay for the helper to use + paramwise_cfg["lr"] = optimizer_cfg.get("lr") + paramwise_cfg["weight_decay"] = optimizer_cfg.get("weight_decay") + params = self._generate_param_groups(paramwise_cfg) + + if "weight_decay" in optimizer_cfg: + optimizer_cfg["weight_decay"] = float( + optimizer_cfg["weight_decay"] or 0.0 + ) + + optimizer_cls = OPTIMIZERS.get(optimizer_cfg.pop("type")) + return optimizer_cls(params, **optimizer_cfg) + else: + return OPTIMIZERS.build(optimizer, params=self.model.parameters()) + + def _get_layer_id_for_sapiens(self, var_name: str, num_max_layer: int) -> int: + """Assigns a layer ID to each parameter for layer-wise decay.""" + # remove fsdp prefix + if "_fsdp_wrapped_module" in var_name: + var_name = var_name.replace("_fsdp_wrapped_module.", "") + + if var_name in ( + "backbone.cls_token", + "backbone.mask_token", + "backbone.pos_embed", + "backbone.storage_tokens", + ): + return 0 + elif var_name.startswith("backbone.patch_embed"): + return 0 + elif var_name.startswith("backbone.tokenizer"): + return 0 + elif var_name.startswith("backbone.layers") or var_name.startswith( + "backbone.blocks" + ): + try: + # e.g., backbone.layers.10.norm.weight -> 10 + layer_id = int(var_name.split(".")[2]) + return layer_id + 1 + except (ValueError, IndexError): + # Fallback for unexpected layer name format + return num_max_layer - 1 + else: + # All other parameters (e.g., decode_head, final norm) get the highest LR + return num_max_layer - 1 + + def _generate_param_groups(self, paramwise_cfg: dict) -> list: + """Generates parameter groups using sapiens specific layer decay logic.""" + base_lr = float(paramwise_cfg.get("lr", 0.0)) + base_wd = float(paramwise_cfg.get("weight_decay") or 0.0) + + # Layer decay is optional. If rate==1.0 or num_layers missing -> no layer decay. 
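+        # Worked example (hypothetical values): layer_decay_rate=0.9 and num_layers=24
+        # (so num_layers becomes 26 below). A parameter in backbone.layers.10 maps to
+        # layer_id=11, so lr_scale = 0.9 ** (26 - 11 - 1) ~= 0.23, while decode_head
+        # parameters (layer_id=25) keep lr_scale = 0.9 ** 0 = 1.0.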
+ layer_decay_rate = float(paramwise_cfg.get("layer_decay_rate", 1.0)) + num_layers_cfg = paramwise_cfg.get("num_layers") + use_layer_decay = (layer_decay_rate != 1.0) and (num_layers_cfg is not None) + if use_layer_decay: + num_layers = int(num_layers_cfg) + 2 + + param_groups = [] + params_map = {} # Key: (lr, wd) -> list[(name, param)] + + for name, param in self.model.named_parameters(): + if not param.requires_grad: + continue + + # --- Weight decay per-parameter --- + if len(param.shape) == 1 or name.endswith(".bias") or "pos_embed" in name: + this_weight_decay = 0.0 + else: + this_weight_decay = base_wd + + # --- Learning rate scaling (optional layer-decay) --- + if use_layer_decay: + layer_id = self._get_layer_id_for_sapiens(name, num_layers) + lr_scale = layer_decay_rate ** (num_layers - layer_id - 1) + this_lr = base_lr * lr_scale + else: + this_lr = base_lr + + key = (this_lr, this_weight_decay) + params_map.setdefault(key, []).append((name, param)) + + # materialize groups + for (lr, wd), named_params in params_map.items(): + params = [p for _, p in named_params] + param_groups.append({"params": params, "lr": lr, "weight_decay": wd}) + + if ( + self.logger + and self.accelerator.is_main_process + and self.train_dataloader is not None + ): + # Create a new dictionary to group parameters by LR only for logging + lr_groups = {} + for (lr, _), named_params in params_map.items(): + if lr not in lr_groups: + lr_groups[lr] = [] + lr_groups[lr].extend(named_params) + + log_str = "\033[96mOptimizer parameter groups created:\n" + + # Sort by learning rate and log one line per LR + for lr, named_params in sorted(lr_groups.items()): + num_tensors = len(named_params) + num_params = sum(p.numel() for name, p in named_params) + + param_names = [name for name, p in named_params] + example_names = ", ".join(param_names[: min(4, len(param_names))]) + + if len(param_names) > 4: + example_names += ", ..." + + # Use formatting to align columns + log_str += ( + f" - decayed_lr: {lr:<11.4e} | tensors: {num_tensors:<4} | " + f"params: {num_params / 1e6:<6.2f}M | names: {example_names}\n" + ) + log_str += "\033[0m" + self.logger.info(log_str) + + return param_groups + + # Only loads model weights, not training state. This is to handle the torch.compile preload case. 
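+    # Resolution order: an explicit .safetensors/.pth/.bin path is used directly; for a
+    # directory, model.safetensors, model.pth and pytorch_model.bin are tried at the top
+    # level and then one subdirectory deep. pos_embed tensors are loaded even when their
+    # shapes differ; other mismatched shapes are skipped and reported.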
+ def _load_checkpoint(self, load_from: str | os.PathLike): + load_from = Path(load_from) + weights_file = None + + if load_from.is_file() and load_from.name.endswith( + (".safetensors", ".pth", ".bin") + ): + weights_file = load_from + + elif load_from.is_dir(): + candidates = ["model.safetensors", "model.pth", "pytorch_model.bin"] + for name in candidates: + if (load_from / name).exists(): + weights_file = load_from / name + break + + if not weights_file: + for d in load_from.glob("*"): + if d.is_dir(): + for name in candidates: + if (d / name).exists(): + weights_file = d / name + break + if weights_file: + break + + if not weights_file or not weights_file.exists(): + raise FileNotFoundError( + f"Could not find a valid .safetensors, .pth, or .bin file in {load_from}" + ) + + if self.accelerator.is_main_process: + self.logger.info(f"Loading model weights from: {weights_file}") + + if str(weights_file).endswith(".safetensors"): + state_dict = load_file(str(weights_file), device="cpu") + else: # Handle .pth and .bin files + checkpoint = torch.load( + str(weights_file), map_location="cpu", weights_only=False + ) + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + elif "model" in checkpoint: + state_dict = checkpoint["model"] + else: + state_dict = checkpoint + + model_state_dict = self.model.state_dict() + compatible_state_dict = {} + mismatched_keys = [] + + for key, checkpoint_tensor in state_dict.items(): + if key in model_state_dict: + model_tensor = model_state_dict[key] + + # Check if the shapes match or if its pos_embed + if checkpoint_tensor.shape == model_tensor.shape or "pos_embed" in key: + compatible_state_dict[key] = checkpoint_tensor + else: + # If shapes do not match, record it and skip loading + mismatched_keys.append( + f"- {key}: " + f"checkpoint has shape {checkpoint_tensor.shape}, " + f"model has shape {model_tensor.shape}" + ) + + incompat = self.model.load_state_dict(compatible_state_dict, strict=False) + + if self.accelerator.is_main_process: + if mismatched_keys: + log_str = "\n".join(mismatched_keys) + self.logger.warning( + "\033[31mSize Mismatch (these weights were NOT loaded): \n" + f"{log_str}\033[0m" + ) + + if incompat.missing_keys: + self.logger.warning( + "\033[38;5;208mMissing keys (in model, NOT in checkpoint): \n" + + "\n".join(incompat.missing_keys) + + "\033[0m" + ) + if incompat.unexpected_keys: + self.logger.warning( + "\033[38;5;208mUnexpected keys (in checkpoint, NOT in model): \n" + + "\n".join(_repr.repr(k) for k in incompat.unexpected_keys) + + "\033[0m" + ) + + self.logger.info("Model weights loaded successfully โœ”") + + def _resume(self, load_from: str | os.PathLike): + # If a file is provided, use its parent directory as the checkpoint directory + if str(load_from).endswith((".safetensors", ".pth", ".bin")): + load_from = Path(load_from).parent + + load_from = str(load_from) + + if self.accelerator.is_main_process: + self.logger.info(f"Resuming state from: {load_from}") + + self.accelerator.load_state(load_from) + + if self.accelerator.is_main_process: + self.logger.info("Training state resumed โœ”") + + # -------------------------------------------------------------------------- + def _init_logger(self, logger) -> None: + self.logger = None + if os.environ.get("RANK", "0") == "0": + self.logger = LOGGERS.build({**logger, "dir": self.work_dir}) + + # -------------------------------------------------------------------------- + def _log_config(self) -> None: + if os.environ.get("RANK", "0") == "0": + file = 
os.path.join(self.work_dir, os.path.basename(self.cfg["filename"])) + with open(file, "w", encoding="utf-8") as f: + f.write(pretty_text(self.cfg)) + from pygments import highlight + from pygments.formatters import TerminalFormatter + from pygments.lexers import PythonLexer + + self.logger.info( + highlight( + pretty_text(self.cfg), + PythonLexer(), + TerminalFormatter(style="monokai"), + ) + ) + + # -------------------------------------------------------------------------- + def _set_seed(self, rnd: Dict[str, Any]): + seed = int(rnd.get("seed", 0)) + deterministic = bool(rnd.get("deterministic", False)) + diff_rank_seed = bool(rnd.get("diff_rank_seed", True)) + + rank = 0 + if diff_rank_seed: + if dist.is_initialized(): + rank = dist.get_rank() + else: + rank = int(os.environ.get("RANK", "0")) + seed += rank + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + if deterministic: + torch.use_deterministic_algorithms(True) + torch.backends.cudnn.benchmark = False + + # ------------------------------------------------------------------------- + def _get_model_summary_str(self, model, max_depth=5): + """Creates a concise, dependency-free summary of a PyTorch model, grouping identical repeating layers.""" + summary_lines = [] + + def VRAM_repr(num_params): + if num_params > 1e9: + return f"{num_params / 1e9:,.2f}B" + if num_params > 1e6: + return f"{num_params / 1e6:,.2f}M" + if num_params > 1e3: + return f"{num_params / 1e3:,.2f}K" + return str(num_params) + + def recurse(module, prefix="", depth=0): + if depth > max_depth: + return + children = list(module.named_children()) + i = 0 + while i < len(children): + name, child = children[i] + # Count identical sequential modules + num_repeats = 1 + for j in range(i + 1, len(children)): + next_name, next_child = children[j] + if isinstance(next_child, type(child)) and str(next_child) == str( + child + ): + num_repeats += 1 + else: + break + + is_last = (i + num_repeats - 1) == (len(children) - 1) + connector = "`-- " if is_last else "|-- " + child_params = sum(p.numel() for p in child.parameters()) + + if num_repeats > 1: + last_name_in_block = children[i + num_repeats - 1][0] + block_name = f"{name}..{last_name_in_block}" + total_params = child_params * num_repeats + summary_lines.append( + f"{prefix}{connector}{block_name} ({type(child).__name__} x {num_repeats}): " + f"{VRAM_repr(total_params)} params" + ) + else: + summary_lines.append( + f"{prefix}{connector}{name} ({type(child).__name__}): {VRAM_repr(child_params)} params" + ) + new_prefix = prefix + (" " if is_last else "| ") + recurse(child, prefix=new_prefix, depth=depth + 1) + + i += num_repeats + + total_params = sum(p.numel() for p in model.parameters()) + summary_lines.append(f"Total params: {VRAM_repr(total_params)}") + recurse(model) + return "\n".join(summary_lines) + + def _print_model(self) -> None: + if not self.logger or not self.accelerator.is_main_process: + return + + tot, trainable = 0, 0 + for p in self.raw_model.parameters(): + n = p.numel() + tot += n + trainable += n if p.requires_grad else 0 + + self.logger.info( + f"\033[92mModel Architecture:\n{self._get_model_summary_str(self.raw_model, max_depth=5)}\033[0m" + ) + self.logger.info( + f"\033[92mParameters: {tot / 1e6:.2f} M total | {trainable / 1e6:.2f} M learnable\033[0m" + ) + + if ( + self.accelerator_cfg["type"] == "DDP" + and "compile_cfg" not in self.accelerator_cfg + ): + try: + from fvcore.nn import FlopCountAnalysis + + dummy_input = torch.randn( + 1, 
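+                    # FLOPs are measured on a single dummy 3x1024x768 input (batch size 1).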
3, 1024, 768, device=self.accelerator.device + ) + + flops = FlopCountAnalysis(self.raw_model, dummy_input) + gflops = flops.total() / 1e9 + self.logger.info(f"\033[92mFLOPs (GMac): {gflops:.2f} GFLOPs\033[0m") + except Exception as e: + self.logger.warning(f"Could not calculate FLOPs: {e}") + + if self.train_dataloader is not None: + unique_lrs = sorted({g["lr"] for g in self.optimizer.param_groups}) + lr_str = ", ".join(f"{v:.4e}" for v in unique_lrs) + self.logger.info(f"\033[92mInitial Learning Rate(s): {lr_str}\033[0m") + + # -------------------------------------------------------------------------- + @classmethod + def from_cfg(cls, cfg): + return cls( + model=cfg.model, + work_dir=cfg.work_dir, + train_dataloader=cfg.train_dataloader, + val_dataloader=getattr(cfg, "val_dataloader", None), + val_cfg=getattr(cfg, "val_cfg", None), + data_preprocessor=cfg.data_preprocessor, + accelerator_cfg=cfg.accelerator_cfg, + optimizer=cfg.optimizer, + scheduler=getattr(cfg, "scheduler", None), + clip_grad=getattr(cfg, "clip_grad", None), + logger=getattr(cfg, "logger", None), + checkpoint=getattr(cfg, "checkpoint", None), + visualizer=getattr(cfg, "visualizer", None), + randomness=getattr(cfg, "randomness", None), + cfg=cfg.to_dict(), + ) diff --git a/sapiens/engine/visualizers/__init__.py b/sapiens/engine/visualizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..efb9673a98375a08ddcfb65bf86dc6df302fa6cd --- /dev/null +++ b/sapiens/engine/visualizers/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .base_visualizer import BaseVisualizer + +__all__ = ["BaseVisualizer"] diff --git a/sapiens/engine/visualizers/base_visualizer.py b/sapiens/engine/visualizers/base_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..a4baf22d54e0336d40917ed4eb7666b8d39d9896 --- /dev/null +++ b/sapiens/engine/visualizers/base_visualizer.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
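+
+# BaseVisualizer writes side-by-side (input | prediction) grids during training.
+# Every `vis_interval` iterations the runner does, roughly:
+#     visualizer = VISUALIZERS.build({**cfg, "output_dir": work_dir / "vis_data"})
+#     visualizer.add_batch(data_batch, logs, step=iter)
+# Up to `vis_max_samples` image/output pairs are optionally downsampled by
+# `vis_downsample`, tiled into a near-square grid, and saved as <output_dir>/<step>.jpg.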
+ +import math +from pathlib import Path + +import cv2 +import numpy as np +import torch +from sapiens.registry import VISUALIZERS +from torch import nn + + +@VISUALIZERS.register_module() +class BaseVisualizer(nn.Module): + def __init__( + self, + output_dir: str, + vis_interval: int = 100, + vis_max_samples: int = 16, + vis_downsample: int = 2, + ): + super().__init__() + self.output_dir = Path(output_dir) + self.output_dir.mkdir(parents=True, exist_ok=True) + self.vis_max_samples = vis_max_samples + self.vis_interval = vis_interval + self.vis_downsample = vis_downsample + + def add_batch(self, data_batch: dict, logs: dict, step: int): + images = data_batch["data_samples"]["image"].detach().cpu() + outputs = logs["outputs"].detach().cpu() # (B, C, H, W) + + if outputs.dtype == torch.bfloat16: + images = images.float() + outputs = outputs.float() + + batch_size = min(len(images), self.vis_max_samples) + + save_images = [] + + for i in range(batch_size): + gt_image = images[i].permute(1, 2, 0).cpu().numpy() * 255 + pred_image = outputs[i].permute(1, 2, 0).cpu().numpy() * 255 + + gt_image = np.clip(gt_image, 0, 255).astype(np.uint8) + pred_image = np.clip(pred_image, 0, 255).astype(np.uint8) + + image_height, image_width = gt_image.shape[:2] + + if self.vis_downsample > 1: + image_height = int(image_height / self.vis_downsample) + image_width = int(image_width / self.vis_downsample) + + gt_image = cv2.resize( + gt_image, + (image_width, image_height), + interpolation=cv2.INTER_AREA, + ) + pred_image = cv2.resize( + pred_image, + (image_width, image_height), + interpolation=cv2.INTER_AREA, + ) + + save_image = np.concatenate([gt_image, pred_image], axis=1) + save_images.append(save_image) + + out_file = self.output_dir / f"{step:06d}.jpg" + image_height, image_width = save_images[0].shape[:2] + cols = int(math.ceil(math.sqrt(batch_size))) + rows = int(math.ceil(batch_size / cols)) + + canvas_height = rows * image_height + canvas_width = cols * image_width + + canvas = np.zeros((canvas_height, canvas_width, 3), dtype=np.uint8) + + for idx, image in enumerate(save_images): + row = idx // cols + col = idx % cols + canvas[ + row * image_height : (row + 1) * image_height, + col * image_width : (col + 1) * image_width, + ] = image + + ## downsample canvas by 2x + canvas = cv2.resize( + canvas, + (canvas_width, canvas_height), + interpolation=cv2.INTER_AREA, + ) + cv2.imwrite(out_file, canvas) + + return diff --git a/sapiens/pose/__init__.py b/sapiens/pose/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0f035f3cbea404e8b6a8b9281c1276fe2c5f1eaf --- /dev/null +++ b/sapiens/pose/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import pathlib +import pkgutil + +from .. 
import __version__ + +_src = pathlib.Path(__file__).with_name("src") +__path__ = pkgutil.extend_path(__path__, __name__) # allow namespace merge +__path__.append(str(_src)) +del pathlib, pkgutil, _src + + +# ----------------------------------------------------- +from importlib import import_module as _imp + +_pkg = _imp(__name__ + ".src") # runs src/__init__.py diff --git a/sapiens/pose/configs/_base_/keypoints308.py b/sapiens/pose/configs/_base_/keypoints308.py new file mode 100644 index 0000000000000000000000000000000000000000..111f7183b1d8ad544b6e2760d7f69975ba15d016 --- /dev/null +++ b/sapiens/pose/configs/_base_/keypoints308.py @@ -0,0 +1,3902 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +coco_wholebody_info = dict( + dataset_name="coco_wholebody", + paper_info=dict( + author="Jin, Sheng and Xu, Lumin and Xu, Jin and " + "Wang, Can and Liu, Wentao and " + "Qian, Chen and Ouyang, Wanli and Luo, Ping", + title="Whole-Body Human Pose Estimation in the Wild", + container="Proceedings of the European Conference on Computer Vision (ECCV)", + year="2020", + homepage="https://github.com/jin-s13/COCO-WholeBody/", + ), + keypoint_info={ + 0: dict(name="nose", id=0, color=[51, 153, 255], type="upper", swap=""), + 1: dict( + name="left_eye", id=1, color=[51, 153, 255], type="upper", swap="right_eye" + ), + 2: dict( + name="right_eye", id=2, color=[51, 153, 255], type="upper", swap="left_eye" + ), + 3: dict( + name="left_ear", id=3, color=[51, 153, 255], type="upper", swap="right_ear" + ), + 4: dict( + name="right_ear", id=4, color=[51, 153, 255], type="upper", swap="left_ear" + ), + 5: dict( + name="left_shoulder", + id=5, + color=[0, 255, 0], + type="upper", + swap="right_shoulder", + ), + 6: dict( + name="right_shoulder", + id=6, + color=[255, 128, 0], + type="upper", + swap="left_shoulder", + ), + 7: dict( + name="left_elbow", id=7, color=[0, 255, 0], type="upper", swap="right_elbow" + ), + 8: dict( + name="right_elbow", + id=8, + color=[255, 128, 0], + type="upper", + swap="left_elbow", + ), + 9: dict( + name="left_wrist", id=9, color=[0, 255, 0], type="upper", swap="right_wrist" + ), + 10: dict( + name="right_wrist", + id=10, + color=[255, 128, 0], + type="upper", + swap="left_wrist", + ), + 11: dict( + name="left_hip", id=11, color=[0, 255, 0], type="lower", swap="right_hip" + ), + 12: dict( + name="right_hip", id=12, color=[255, 128, 0], type="lower", swap="left_hip" + ), + 13: dict( + name="left_knee", id=13, color=[0, 255, 0], type="lower", swap="right_knee" + ), + 14: dict( + name="right_knee", + id=14, + color=[255, 128, 0], + type="lower", + swap="left_knee", + ), + 15: dict( + name="left_ankle", + id=15, + color=[0, 255, 0], + type="lower", + swap="right_ankle", + ), + 16: dict( + name="right_ankle", + id=16, + color=[255, 128, 0], + type="lower", + swap="left_ankle", + ), + 17: dict( + name="left_big_toe", + id=17, + color=[255, 128, 0], + type="lower", + swap="right_big_toe", + ), + 18: dict( + name="left_small_toe", + id=18, + color=[255, 128, 0], + type="lower", + swap="right_small_toe", + ), + 19: dict( + name="left_heel", + id=19, + color=[255, 128, 0], + type="lower", + swap="right_heel", + ), + 20: dict( + name="right_big_toe", + id=20, + color=[255, 128, 0], + type="lower", + swap="left_big_toe", + ), + 21: dict( + name="right_small_toe", + id=21, + color=[255, 128, 0], + type="lower", + swap="left_small_toe", + ), 
+ 22: dict( + name="right_heel", + id=22, + color=[255, 128, 0], + type="lower", + swap="left_heel", + ), + 23: dict(name="face-0", id=23, color=[255, 255, 255], type="", swap="face-16"), + 24: dict(name="face-1", id=24, color=[255, 255, 255], type="", swap="face-15"), + 25: dict(name="face-2", id=25, color=[255, 255, 255], type="", swap="face-14"), + 26: dict(name="face-3", id=26, color=[255, 255, 255], type="", swap="face-13"), + 27: dict(name="face-4", id=27, color=[255, 255, 255], type="", swap="face-12"), + 28: dict(name="face-5", id=28, color=[255, 255, 255], type="", swap="face-11"), + 29: dict(name="face-6", id=29, color=[255, 255, 255], type="", swap="face-10"), + 30: dict(name="face-7", id=30, color=[255, 255, 255], type="", swap="face-9"), + 31: dict(name="face-8", id=31, color=[255, 255, 255], type="", swap=""), + 32: dict(name="face-9", id=32, color=[255, 255, 255], type="", swap="face-7"), + 33: dict(name="face-10", id=33, color=[255, 255, 255], type="", swap="face-6"), + 34: dict(name="face-11", id=34, color=[255, 255, 255], type="", swap="face-5"), + 35: dict(name="face-12", id=35, color=[255, 255, 255], type="", swap="face-4"), + 36: dict(name="face-13", id=36, color=[255, 255, 255], type="", swap="face-3"), + 37: dict(name="face-14", id=37, color=[255, 255, 255], type="", swap="face-2"), + 38: dict(name="face-15", id=38, color=[255, 255, 255], type="", swap="face-1"), + 39: dict(name="face-16", id=39, color=[255, 255, 255], type="", swap="face-0"), + 40: dict(name="face-17", id=40, color=[255, 255, 255], type="", swap="face-26"), + 41: dict(name="face-18", id=41, color=[255, 255, 255], type="", swap="face-25"), + 42: dict(name="face-19", id=42, color=[255, 255, 255], type="", swap="face-24"), + 43: dict(name="face-20", id=43, color=[255, 255, 255], type="", swap="face-23"), + 44: dict(name="face-21", id=44, color=[255, 255, 255], type="", swap="face-22"), + 45: dict(name="face-22", id=45, color=[255, 255, 255], type="", swap="face-21"), + 46: dict(name="face-23", id=46, color=[255, 255, 255], type="", swap="face-20"), + 47: dict(name="face-24", id=47, color=[255, 255, 255], type="", swap="face-19"), + 48: dict(name="face-25", id=48, color=[255, 255, 255], type="", swap="face-18"), + 49: dict(name="face-26", id=49, color=[255, 255, 255], type="", swap="face-17"), + 50: dict(name="face-27", id=50, color=[255, 255, 255], type="", swap=""), + 51: dict(name="face-28", id=51, color=[255, 255, 255], type="", swap=""), + 52: dict(name="face-29", id=52, color=[255, 255, 255], type="", swap=""), + 53: dict(name="face-30", id=53, color=[255, 255, 255], type="", swap=""), + 54: dict(name="face-31", id=54, color=[255, 255, 255], type="", swap="face-35"), + 55: dict(name="face-32", id=55, color=[255, 255, 255], type="", swap="face-34"), + 56: dict(name="face-33", id=56, color=[255, 255, 255], type="", swap=""), + 57: dict(name="face-34", id=57, color=[255, 255, 255], type="", swap="face-32"), + 58: dict(name="face-35", id=58, color=[255, 255, 255], type="", swap="face-31"), + 59: dict(name="face-36", id=59, color=[255, 255, 255], type="", swap="face-45"), + 60: dict(name="face-37", id=60, color=[255, 255, 255], type="", swap="face-44"), + 61: dict(name="face-38", id=61, color=[255, 255, 255], type="", swap="face-43"), + 62: dict(name="face-39", id=62, color=[255, 255, 255], type="", swap="face-42"), + 63: dict(name="face-40", id=63, color=[255, 255, 255], type="", swap="face-47"), + 64: dict(name="face-41", id=64, color=[255, 255, 255], type="", swap="face-46"), + 65: 
dict(name="face-42", id=65, color=[255, 255, 255], type="", swap="face-39"), + 66: dict(name="face-43", id=66, color=[255, 255, 255], type="", swap="face-38"), + 67: dict(name="face-44", id=67, color=[255, 255, 255], type="", swap="face-37"), + 68: dict(name="face-45", id=68, color=[255, 255, 255], type="", swap="face-36"), + 69: dict(name="face-46", id=69, color=[255, 255, 255], type="", swap="face-41"), + 70: dict(name="face-47", id=70, color=[255, 255, 255], type="", swap="face-40"), + 71: dict(name="face-48", id=71, color=[255, 255, 255], type="", swap="face-54"), + 72: dict(name="face-49", id=72, color=[255, 255, 255], type="", swap="face-53"), + 73: dict(name="face-50", id=73, color=[255, 255, 255], type="", swap="face-52"), + 74: dict(name="face-51", id=74, color=[255, 255, 255], type="", swap=""), + 75: dict(name="face-52", id=75, color=[255, 255, 255], type="", swap="face-50"), + 76: dict(name="face-53", id=76, color=[255, 255, 255], type="", swap="face-49"), + 77: dict(name="face-54", id=77, color=[255, 255, 255], type="", swap="face-48"), + 78: dict(name="face-55", id=78, color=[255, 255, 255], type="", swap="face-59"), + 79: dict(name="face-56", id=79, color=[255, 255, 255], type="", swap="face-58"), + 80: dict(name="face-57", id=80, color=[255, 255, 255], type="", swap=""), + 81: dict(name="face-58", id=81, color=[255, 255, 255], type="", swap="face-56"), + 82: dict(name="face-59", id=82, color=[255, 255, 255], type="", swap="face-55"), + 83: dict(name="face-60", id=83, color=[255, 255, 255], type="", swap="face-64"), + 84: dict(name="face-61", id=84, color=[255, 255, 255], type="", swap="face-63"), + 85: dict(name="face-62", id=85, color=[255, 255, 255], type="", swap=""), + 86: dict(name="face-63", id=86, color=[255, 255, 255], type="", swap="face-61"), + 87: dict(name="face-64", id=87, color=[255, 255, 255], type="", swap="face-60"), + 88: dict(name="face-65", id=88, color=[255, 255, 255], type="", swap="face-67"), + 89: dict(name="face-66", id=89, color=[255, 255, 255], type="", swap=""), + 90: dict(name="face-67", id=90, color=[255, 255, 255], type="", swap="face-65"), + 91: dict( + name="left_hand_root", + id=91, + color=[255, 255, 255], + type="", + swap="right_hand_root", + ), + 92: dict( + name="left_thumb1", id=92, color=[255, 128, 0], type="", swap="right_thumb1" + ), + 93: dict( + name="left_thumb2", id=93, color=[255, 128, 0], type="", swap="right_thumb2" + ), + 94: dict( + name="left_thumb3", id=94, color=[255, 128, 0], type="", swap="right_thumb3" + ), + 95: dict( + name="left_thumb4", id=95, color=[255, 128, 0], type="", swap="right_thumb4" + ), + 96: dict( + name="left_forefinger1", + id=96, + color=[255, 153, 255], + type="", + swap="right_forefinger1", + ), + 97: dict( + name="left_forefinger2", + id=97, + color=[255, 153, 255], + type="", + swap="right_forefinger2", + ), + 98: dict( + name="left_forefinger3", + id=98, + color=[255, 153, 255], + type="", + swap="right_forefinger3", + ), + 99: dict( + name="left_forefinger4", + id=99, + color=[255, 153, 255], + type="", + swap="right_forefinger4", + ), + 100: dict( + name="left_middle_finger1", + id=100, + color=[102, 178, 255], + type="", + swap="right_middle_finger1", + ), + 101: dict( + name="left_middle_finger2", + id=101, + color=[102, 178, 255], + type="", + swap="right_middle_finger2", + ), + 102: dict( + name="left_middle_finger3", + id=102, + color=[102, 178, 255], + type="", + swap="right_middle_finger3", + ), + 103: dict( + name="left_middle_finger4", + id=103, + color=[102, 178, 255], + type="", + 
swap="right_middle_finger4", + ), + 104: dict( + name="left_ring_finger1", + id=104, + color=[255, 51, 51], + type="", + swap="right_ring_finger1", + ), + 105: dict( + name="left_ring_finger2", + id=105, + color=[255, 51, 51], + type="", + swap="right_ring_finger2", + ), + 106: dict( + name="left_ring_finger3", + id=106, + color=[255, 51, 51], + type="", + swap="right_ring_finger3", + ), + 107: dict( + name="left_ring_finger4", + id=107, + color=[255, 51, 51], + type="", + swap="right_ring_finger4", + ), + 108: dict( + name="left_pinky_finger1", + id=108, + color=[0, 255, 0], + type="", + swap="right_pinky_finger1", + ), + 109: dict( + name="left_pinky_finger2", + id=109, + color=[0, 255, 0], + type="", + swap="right_pinky_finger2", + ), + 110: dict( + name="left_pinky_finger3", + id=110, + color=[0, 255, 0], + type="", + swap="right_pinky_finger3", + ), + 111: dict( + name="left_pinky_finger4", + id=111, + color=[0, 255, 0], + type="", + swap="right_pinky_finger4", + ), + 112: dict( + name="right_hand_root", + id=112, + color=[255, 255, 255], + type="", + swap="left_hand_root", + ), + 113: dict( + name="right_thumb1", + id=113, + color=[255, 128, 0], + type="", + swap="left_thumb1", + ), + 114: dict( + name="right_thumb2", + id=114, + color=[255, 128, 0], + type="", + swap="left_thumb2", + ), + 115: dict( + name="right_thumb3", + id=115, + color=[255, 128, 0], + type="", + swap="left_thumb3", + ), + 116: dict( + name="right_thumb4", + id=116, + color=[255, 128, 0], + type="", + swap="left_thumb4", + ), + 117: dict( + name="right_forefinger1", + id=117, + color=[255, 153, 255], + type="", + swap="left_forefinger1", + ), + 118: dict( + name="right_forefinger2", + id=118, + color=[255, 153, 255], + type="", + swap="left_forefinger2", + ), + 119: dict( + name="right_forefinger3", + id=119, + color=[255, 153, 255], + type="", + swap="left_forefinger3", + ), + 120: dict( + name="right_forefinger4", + id=120, + color=[255, 153, 255], + type="", + swap="left_forefinger4", + ), + 121: dict( + name="right_middle_finger1", + id=121, + color=[102, 178, 255], + type="", + swap="left_middle_finger1", + ), + 122: dict( + name="right_middle_finger2", + id=122, + color=[102, 178, 255], + type="", + swap="left_middle_finger2", + ), + 123: dict( + name="right_middle_finger3", + id=123, + color=[102, 178, 255], + type="", + swap="left_middle_finger3", + ), + 124: dict( + name="right_middle_finger4", + id=124, + color=[102, 178, 255], + type="", + swap="left_middle_finger4", + ), + 125: dict( + name="right_ring_finger1", + id=125, + color=[255, 51, 51], + type="", + swap="left_ring_finger1", + ), + 126: dict( + name="right_ring_finger2", + id=126, + color=[255, 51, 51], + type="", + swap="left_ring_finger2", + ), + 127: dict( + name="right_ring_finger3", + id=127, + color=[255, 51, 51], + type="", + swap="left_ring_finger3", + ), + 128: dict( + name="right_ring_finger4", + id=128, + color=[255, 51, 51], + type="", + swap="left_ring_finger4", + ), + 129: dict( + name="right_pinky_finger1", + id=129, + color=[0, 255, 0], + type="", + swap="left_pinky_finger1", + ), + 130: dict( + name="right_pinky_finger2", + id=130, + color=[0, 255, 0], + type="", + swap="left_pinky_finger2", + ), + 131: dict( + name="right_pinky_finger3", + id=131, + color=[0, 255, 0], + type="", + swap="left_pinky_finger3", + ), + 132: dict( + name="right_pinky_finger4", + id=132, + color=[0, 255, 0], + type="", + swap="left_pinky_finger4", + ), + }, + skeleton_info={ + 0: dict(link=("left_ankle", "left_knee"), id=0, color=[0, 255, 0]), + 1: 
dict(link=("left_knee", "left_hip"), id=1, color=[0, 255, 0]), + 2: dict(link=("right_ankle", "right_knee"), id=2, color=[255, 128, 0]), + 3: dict(link=("right_knee", "right_hip"), id=3, color=[255, 128, 0]), + 4: dict(link=("left_hip", "right_hip"), id=4, color=[51, 153, 255]), + 5: dict(link=("left_shoulder", "left_hip"), id=5, color=[51, 153, 255]), + 6: dict(link=("right_shoulder", "right_hip"), id=6, color=[51, 153, 255]), + 7: dict(link=("left_shoulder", "right_shoulder"), id=7, color=[51, 153, 255]), + 8: dict(link=("left_shoulder", "left_elbow"), id=8, color=[0, 255, 0]), + 9: dict(link=("right_shoulder", "right_elbow"), id=9, color=[255, 128, 0]), + 10: dict(link=("left_elbow", "left_wrist"), id=10, color=[0, 255, 0]), + 11: dict(link=("right_elbow", "right_wrist"), id=11, color=[255, 128, 0]), + 12: dict(link=("left_eye", "right_eye"), id=12, color=[51, 153, 255]), + 13: dict(link=("nose", "left_eye"), id=13, color=[51, 153, 255]), + 14: dict(link=("nose", "right_eye"), id=14, color=[51, 153, 255]), + 15: dict(link=("left_eye", "left_ear"), id=15, color=[51, 153, 255]), + 16: dict(link=("right_eye", "right_ear"), id=16, color=[51, 153, 255]), + 17: dict(link=("left_ear", "left_shoulder"), id=17, color=[51, 153, 255]), + 18: dict(link=("right_ear", "right_shoulder"), id=18, color=[51, 153, 255]), + 19: dict(link=("left_ankle", "left_big_toe"), id=19, color=[0, 255, 0]), + 20: dict(link=("left_ankle", "left_small_toe"), id=20, color=[0, 255, 0]), + 21: dict(link=("left_ankle", "left_heel"), id=21, color=[0, 255, 0]), + 22: dict(link=("right_ankle", "right_big_toe"), id=22, color=[255, 128, 0]), + 23: dict(link=("right_ankle", "right_small_toe"), id=23, color=[255, 128, 0]), + 24: dict(link=("right_ankle", "right_heel"), id=24, color=[255, 128, 0]), + 25: dict(link=("left_hand_root", "left_thumb1"), id=25, color=[255, 128, 0]), + 26: dict(link=("left_thumb1", "left_thumb2"), id=26, color=[255, 128, 0]), + 27: dict(link=("left_thumb2", "left_thumb3"), id=27, color=[255, 128, 0]), + 28: dict(link=("left_thumb3", "left_thumb4"), id=28, color=[255, 128, 0]), + 29: dict( + link=("left_hand_root", "left_forefinger1"), id=29, color=[255, 153, 255] + ), + 30: dict( + link=("left_forefinger1", "left_forefinger2"), id=30, color=[255, 153, 255] + ), + 31: dict( + link=("left_forefinger2", "left_forefinger3"), id=31, color=[255, 153, 255] + ), + 32: dict( + link=("left_forefinger3", "left_forefinger4"), id=32, color=[255, 153, 255] + ), + 33: dict( + link=("left_hand_root", "left_middle_finger1"), id=33, color=[102, 178, 255] + ), + 34: dict( + link=("left_middle_finger1", "left_middle_finger2"), + id=34, + color=[102, 178, 255], + ), + 35: dict( + link=("left_middle_finger2", "left_middle_finger3"), + id=35, + color=[102, 178, 255], + ), + 36: dict( + link=("left_middle_finger3", "left_middle_finger4"), + id=36, + color=[102, 178, 255], + ), + 37: dict( + link=("left_hand_root", "left_ring_finger1"), id=37, color=[255, 51, 51] + ), + 38: dict( + link=("left_ring_finger1", "left_ring_finger2"), id=38, color=[255, 51, 51] + ), + 39: dict( + link=("left_ring_finger2", "left_ring_finger3"), id=39, color=[255, 51, 51] + ), + 40: dict( + link=("left_ring_finger3", "left_ring_finger4"), id=40, color=[255, 51, 51] + ), + 41: dict( + link=("left_hand_root", "left_pinky_finger1"), id=41, color=[0, 255, 0] + ), + 42: dict( + link=("left_pinky_finger1", "left_pinky_finger2"), id=42, color=[0, 255, 0] + ), + 43: dict( + link=("left_pinky_finger2", "left_pinky_finger3"), id=43, color=[0, 255, 0] + ), + 44: 
dict( + link=("left_pinky_finger3", "left_pinky_finger4"), id=44, color=[0, 255, 0] + ), + 45: dict(link=("right_hand_root", "right_thumb1"), id=45, color=[255, 128, 0]), + 46: dict(link=("right_thumb1", "right_thumb2"), id=46, color=[255, 128, 0]), + 47: dict(link=("right_thumb2", "right_thumb3"), id=47, color=[255, 128, 0]), + 48: dict(link=("right_thumb3", "right_thumb4"), id=48, color=[255, 128, 0]), + 49: dict( + link=("right_hand_root", "right_forefinger1"), id=49, color=[255, 153, 255] + ), + 50: dict( + link=("right_forefinger1", "right_forefinger2"), + id=50, + color=[255, 153, 255], + ), + 51: dict( + link=("right_forefinger2", "right_forefinger3"), + id=51, + color=[255, 153, 255], + ), + 52: dict( + link=("right_forefinger3", "right_forefinger4"), + id=52, + color=[255, 153, 255], + ), + 53: dict( + link=("right_hand_root", "right_middle_finger1"), + id=53, + color=[102, 178, 255], + ), + 54: dict( + link=("right_middle_finger1", "right_middle_finger2"), + id=54, + color=[102, 178, 255], + ), + 55: dict( + link=("right_middle_finger2", "right_middle_finger3"), + id=55, + color=[102, 178, 255], + ), + 56: dict( + link=("right_middle_finger3", "right_middle_finger4"), + id=56, + color=[102, 178, 255], + ), + 57: dict( + link=("right_hand_root", "right_ring_finger1"), id=57, color=[255, 51, 51] + ), + 58: dict( + link=("right_ring_finger1", "right_ring_finger2"), + id=58, + color=[255, 51, 51], + ), + 59: dict( + link=("right_ring_finger2", "right_ring_finger3"), + id=59, + color=[255, 51, 51], + ), + 60: dict( + link=("right_ring_finger3", "right_ring_finger4"), + id=60, + color=[255, 51, 51], + ), + 61: dict( + link=("right_hand_root", "right_pinky_finger1"), id=61, color=[0, 255, 0] + ), + 62: dict( + link=("right_pinky_finger1", "right_pinky_finger2"), + id=62, + color=[0, 255, 0], + ), + 63: dict( + link=("right_pinky_finger2", "right_pinky_finger3"), + id=63, + color=[0, 255, 0], + ), + 64: dict( + link=("right_pinky_finger3", "right_pinky_finger4"), + id=64, + color=[0, 255, 0], + ), + }, + joint_weights=[1.0] * 133, + # 'https://github.com/jin-s13/COCO-WholeBody/blob/master/' + # 'evaluation/myeval_wholebody.py#L175' + sigmas=[ + 0.026, + 0.025, + 0.025, + 0.035, + 0.035, + 0.079, + 0.079, + 0.072, + 0.072, + 0.062, + 0.062, + 0.107, + 0.107, + 0.087, + 0.087, + 0.089, + 0.089, + 0.068, + 0.066, + 0.066, + 0.092, + 0.094, + 0.094, + 0.042, + 0.043, + 0.044, + 0.043, + 0.040, + 0.035, + 0.031, + 0.025, + 0.020, + 0.023, + 0.029, + 0.032, + 0.037, + 0.038, + 0.043, + 0.041, + 0.045, + 0.013, + 0.012, + 0.011, + 0.011, + 0.012, + 0.012, + 0.011, + 0.011, + 0.013, + 0.015, + 0.009, + 0.007, + 0.007, + 0.007, + 0.012, + 0.009, + 0.008, + 0.016, + 0.010, + 0.017, + 0.011, + 0.009, + 0.011, + 0.009, + 0.007, + 0.013, + 0.008, + 0.011, + 0.012, + 0.010, + 0.034, + 0.008, + 0.008, + 0.009, + 0.008, + 0.008, + 0.007, + 0.010, + 0.008, + 0.009, + 0.009, + 0.009, + 0.007, + 0.007, + 0.008, + 0.011, + 0.008, + 0.008, + 0.008, + 0.01, + 0.008, + 0.029, + 0.022, + 0.035, + 0.037, + 0.047, + 0.026, + 0.025, + 0.024, + 0.035, + 0.018, + 0.024, + 0.022, + 0.026, + 0.017, + 0.021, + 0.021, + 0.032, + 0.02, + 0.019, + 0.022, + 0.031, + 0.029, + 0.022, + 0.035, + 0.037, + 0.047, + 0.026, + 0.025, + 0.024, + 0.035, + 0.018, + 0.024, + 0.022, + 0.026, + 0.017, + 0.021, + 0.021, + 0.032, + 0.02, + 0.019, + 0.022, + 0.031, + ], +) + +dataset_info = dict( + dataset_name="goliath", + min_visible_keypoints=8, + image_height=4096, + image_width=2668, + original_keypoint_info={ + 0: "nose", + 1: 
"left_eye", + 2: "right_eye", + 3: "left_ear", + 4: "right_ear", + 5: "left_shoulder", + 6: "right_shoulder", + 7: "left_elbow", + 8: "right_elbow", + 9: "left_hip", + 10: "right_hip", + 11: "left_knee", + 12: "right_knee", + 13: "left_ankle", + 14: "right_ankle", + 15: "left_big_toe_tip", + 16: "left_small_toe_tip", + 17: "left_heel", + 18: "right_big_toe_tip", + 19: "right_small_toe_tip", + 20: "right_heel", + 21: "right_thumb_tip", + 22: "right_thumb_first_joint", + 23: "right_thumb_second_joint", + 24: "right_thumb_third_joint", + 25: "right_index_tip", + 26: "right_index_first_joint", + 27: "right_index_second_joint", + 28: "right_index_third_joint", + 29: "right_middle_tip", + 30: "right_middle_first_joint", + 31: "right_middle_second_joint", + 32: "right_middle_third_joint", + 33: "right_ring_tip", + 34: "right_ring_first_joint", + 35: "right_ring_second_joint", + 36: "right_ring_third_joint", + 37: "right_pinky_tip", + 38: "right_pinky_first_joint", + 39: "right_pinky_second_joint", + 40: "right_pinky_third_joint", + 41: "right_wrist", + 42: "left_thumb_tip", + 43: "left_thumb_first_joint", + 44: "left_thumb_second_joint", + 45: "left_thumb_third_joint", + 46: "left_index_tip", + 47: "left_index_first_joint", + 48: "left_index_second_joint", + 49: "left_index_third_joint", + 50: "left_middle_tip", + 51: "left_middle_first_joint", + 52: "left_middle_second_joint", + 53: "left_middle_third_joint", + 54: "left_ring_tip", + 55: "left_ring_first_joint", + 56: "left_ring_second_joint", + 57: "left_ring_third_joint", + 58: "left_pinky_tip", + 59: "left_pinky_first_joint", + 60: "left_pinky_second_joint", + 61: "left_pinky_third_joint", + 62: "left_wrist", + 63: "left_olecranon", + 64: "right_olecranon", + 65: "left_cubital_fossa", + 66: "right_cubital_fossa", + 67: "left_acromion", + 68: "right_acromion", + 69: "neck", + 70: "center_of_glabella", + 71: "center_of_nose_root", + 72: "tip_of_nose_bridge", + 73: "midpoint_1_of_nose_bridge", + 74: "midpoint_2_of_nose_bridge", + 75: "midpoint_3_of_nose_bridge", + 76: "center_of_labiomental_groove", + 77: "tip_of_chin", + 78: "upper_startpoint_of_r_eyebrow", + 79: "lower_startpoint_of_r_eyebrow", + 80: "end_of_r_eyebrow", + 81: "upper_midpoint_1_of_r_eyebrow", + 82: "lower_midpoint_1_of_r_eyebrow", + 83: "upper_midpoint_2_of_r_eyebrow", + 84: "upper_midpoint_3_of_r_eyebrow", + 85: "lower_midpoint_2_of_r_eyebrow", + 86: "lower_midpoint_3_of_r_eyebrow", + 87: "upper_startpoint_of_l_eyebrow", + 88: "lower_startpoint_of_l_eyebrow", + 89: "end_of_l_eyebrow", + 90: "upper_midpoint_1_of_l_eyebrow", + 91: "lower_midpoint_1_of_l_eyebrow", + 92: "upper_midpoint_2_of_l_eyebrow", + 93: "upper_midpoint_3_of_l_eyebrow", + 94: "lower_midpoint_2_of_l_eyebrow", + 95: "lower_midpoint_3_of_l_eyebrow", + 96: "l_inner_end_of_upper_lash_line", + 97: "l_outer_end_of_upper_lash_line", + 98: "l_centerpoint_of_upper_lash_line", + 99: "l_midpoint_2_of_upper_lash_line", + 100: "l_midpoint_1_of_upper_lash_line", + 101: "l_midpoint_6_of_upper_lash_line", + 102: "l_midpoint_5_of_upper_lash_line", + 103: "l_midpoint_4_of_upper_lash_line", + 104: "l_midpoint_3_of_upper_lash_line", + 105: "l_outer_end_of_upper_eyelid_line", + 106: "l_midpoint_6_of_upper_eyelid_line", + 107: "l_midpoint_2_of_upper_eyelid_line", + 108: "l_midpoint_5_of_upper_eyelid_line", + 109: "l_centerpoint_of_upper_eyelid_line", + 110: "l_midpoint_4_of_upper_eyelid_line", + 111: "l_midpoint_1_of_upper_eyelid_line", + 112: "l_midpoint_3_of_upper_eyelid_line", + 113: "l_midpoint_6_of_upper_crease_line", + 114: 
"l_midpoint_2_of_upper_crease_line", + 115: "l_midpoint_5_of_upper_crease_line", + 116: "l_centerpoint_of_upper_crease_line", + 117: "l_midpoint_4_of_upper_crease_line", + 118: "l_midpoint_1_of_upper_crease_line", + 119: "l_midpoint_3_of_upper_crease_line", + 120: "r_inner_end_of_upper_lash_line", + 121: "r_outer_end_of_upper_lash_line", + 122: "r_centerpoint_of_upper_lash_line", + 123: "r_midpoint_1_of_upper_lash_line", + 124: "r_midpoint_2_of_upper_lash_line", + 125: "r_midpoint_3_of_upper_lash_line", + 126: "r_midpoint_4_of_upper_lash_line", + 127: "r_midpoint_5_of_upper_lash_line", + 128: "r_midpoint_6_of_upper_lash_line", + 129: "r_outer_end_of_upper_eyelid_line", + 130: "r_midpoint_3_of_upper_eyelid_line", + 131: "r_midpoint_1_of_upper_eyelid_line", + 132: "r_midpoint_4_of_upper_eyelid_line", + 133: "r_centerpoint_of_upper_eyelid_line", + 134: "r_midpoint_5_of_upper_eyelid_line", + 135: "r_midpoint_2_of_upper_eyelid_line", + 136: "r_midpoint_6_of_upper_eyelid_line", + 137: "r_midpoint_3_of_upper_crease_line", + 138: "r_midpoint_1_of_upper_crease_line", + 139: "r_midpoint_4_of_upper_crease_line", + 140: "r_centerpoint_of_upper_crease_line", + 141: "r_midpoint_5_of_upper_crease_line", + 142: "r_midpoint_2_of_upper_crease_line", + 143: "r_midpoint_6_of_upper_crease_line", + 144: "l_inner_end_of_lower_lash_line", + 145: "l_outer_end_of_lower_lash_line", + 146: "l_centerpoint_of_lower_lash_line", + 147: "l_midpoint_2_of_lower_lash_line", + 148: "l_midpoint_1_of_lower_lash_line", + 149: "l_midpoint_6_of_lower_lash_line", + 150: "l_midpoint_5_of_lower_lash_line", + 151: "l_midpoint_4_of_lower_lash_line", + 152: "l_midpoint_3_of_lower_lash_line", + 153: "l_outer_end_of_lower_eyelid_line", + 154: "l_midpoint_6_of_lower_eyelid_line", + 155: "l_midpoint_2_of_lower_eyelid_line", + 156: "l_midpoint_5_of_lower_eyelid_line", + 157: "l_centerpoint_of_lower_eyelid_line", + 158: "l_midpoint_4_of_lower_eyelid_line", + 159: "l_midpoint_1_of_lower_eyelid_line", + 160: "l_midpoint_3_of_lower_eyelid_line", + 161: "r_inner_end_of_lower_lash_line", + 162: "r_outer_end_of_lower_lash_line", + 163: "r_centerpoint_of_lower_lash_line", + 164: "r_midpoint_1_of_lower_lash_line", + 165: "r_midpoint_2_of_lower_lash_line", + 166: "r_midpoint_3_of_lower_lash_line", + 167: "r_midpoint_4_of_lower_lash_line", + 168: "r_midpoint_5_of_lower_lash_line", + 169: "r_midpoint_6_of_lower_lash_line", + 170: "r_outer_end_of_lower_eyelid_line", + 171: "r_midpoint_3_of_lower_eyelid_line", + 172: "r_midpoint_1_of_lower_eyelid_line", + 173: "r_midpoint_4_of_lower_eyelid_line", + 174: "r_centerpoint_of_lower_eyelid_line", + 175: "r_midpoint_5_of_lower_eyelid_line", + 176: "r_midpoint_2_of_lower_eyelid_line", + 177: "r_midpoint_6_of_lower_eyelid_line", + 178: "tip_of_nose", + 179: "bottom_center_of_nose", + 180: "r_outer_corner_of_nose", + 181: "l_outer_corner_of_nose", + 182: "inner_corner_of_r_nostril", + 183: "outer_corner_of_r_nostril", + 184: "upper_corner_of_r_nostril", + 185: "inner_corner_of_l_nostril", + 186: "outer_corner_of_l_nostril", + 187: "upper_corner_of_l_nostril", + 188: "r_outer_corner_of_mouth", + 189: "l_outer_corner_of_mouth", + 190: "center_of_cupid_bow", + 191: "center_of_lower_outer_lip", + 192: "midpoint_1_of_upper_outer_lip", + 193: "midpoint_2_of_upper_outer_lip", + 194: "midpoint_1_of_lower_outer_lip", + 195: "midpoint_2_of_lower_outer_lip", + 196: "midpoint_3_of_upper_outer_lip", + 197: "midpoint_4_of_upper_outer_lip", + 198: "midpoint_5_of_upper_outer_lip", + 199: "midpoint_6_of_upper_outer_lip", + 200: 
"midpoint_3_of_lower_outer_lip", + 201: "midpoint_4_of_lower_outer_lip", + 202: "midpoint_5_of_lower_outer_lip", + 203: "midpoint_6_of_lower_outer_lip", + 204: "r_inner_corner_of_mouth", + 205: "l_inner_corner_of_mouth", + 206: "center_of_upper_inner_lip", + 207: "center_of_lower_inner_lip", + 208: "midpoint_1_of_upper_inner_lip", + 209: "midpoint_2_of_upper_inner_lip", + 210: "midpoint_1_of_lower_inner_lip", + 211: "midpoint_2_of_lower_inner_lip", + 212: "midpoint_3_of_upper_inner_lip", + 213: "midpoint_4_of_upper_inner_lip", + 214: "midpoint_5_of_upper_inner_lip", + 215: "midpoint_6_of_upper_inner_lip", + 216: "midpoint_3_of_lower_inner_lip", + 217: "midpoint_4_of_lower_inner_lip", + 218: "midpoint_5_of_lower_inner_lip", + 219: "midpoint_6_of_lower_inner_lip", + 220: "teeth", + 221: "teeth", + 222: "teeth", + 223: "teeth", + 224: "teeth", + 225: "teeth", + 226: "teeth", + 227: "teeth", + 228: "teeth", + 229: "teeth", + 230: "teeth", + 231: "teeth", + 232: "teeth", + 233: "teeth", + 234: "teeth", + 235: "teeth", + 236: "teeth", + 237: "teeth", + 238: "teeth", + 239: "teeth", + 240: "teeth", + 241: "teeth", + 242: "teeth", + 243: "teeth", + 244: "teeth", + 245: "teeth", + 246: "teeth", + 247: "teeth", + 248: "teeth", + 249: "teeth", + 250: "teeth", + 251: "teeth", + 252: "teeth", + 253: "teeth", + 254: "teeth", + 255: "teeth", + 256: "l_top_end_of_inferior_crus", + 257: "l_top_end_of_superior_crus", + 258: "l_start_of_antihelix", + 259: "l_end_of_antihelix", + 260: "l_midpoint_1_of_antihelix", + 261: "l_midpoint_1_of_inferior_crus", + 262: "l_midpoint_2_of_antihelix", + 263: "l_midpoint_3_of_antihelix", + 264: "l_point_1_of_inner_helix", + 265: "l_point_2_of_inner_helix", + 266: "l_point_3_of_inner_helix", + 267: "l_point_4_of_inner_helix", + 268: "l_point_5_of_inner_helix", + 269: "l_point_6_of_inner_helix", + 270: "l_point_7_of_inner_helix", + 271: "l_highest_point_of_antitragus", + 272: "l_bottom_point_of_tragus", + 273: "l_protruding_point_of_tragus", + 274: "l_top_point_of_tragus", + 275: "l_start_point_of_crus_of_helix", + 276: "l_deepest_point_of_concha", + 277: "l_tip_of_ear_lobe", + 278: "l_midpoint_between_22_15", + 279: "l_bottom_connecting_point_of_ear_lobe", + 280: "l_top_connecting_point_of_helix", + 281: "l_point_8_of_inner_helix", + 282: "r_top_end_of_inferior_crus", + 283: "r_top_end_of_superior_crus", + 284: "r_start_of_antihelix", + 285: "r_end_of_antihelix", + 286: "r_midpoint_1_of_antihelix", + 287: "r_midpoint_1_of_inferior_crus", + 288: "r_midpoint_2_of_antihelix", + 289: "r_midpoint_3_of_antihelix", + 290: "r_point_1_of_inner_helix", + 291: "r_point_8_of_inner_helix", + 292: "r_point_3_of_inner_helix", + 293: "r_point_4_of_inner_helix", + 294: "r_point_5_of_inner_helix", + 295: "r_point_6_of_inner_helix", + 296: "r_point_7_of_inner_helix", + 297: "r_highest_point_of_antitragus", + 298: "r_bottom_point_of_tragus", + 299: "r_protruding_point_of_tragus", + 300: "r_top_point_of_tragus", + 301: "r_start_point_of_crus_of_helix", + 302: "r_deepest_point_of_concha", + 303: "r_tip_of_ear_lobe", + 304: "r_midpoint_between_22_15", + 305: "r_bottom_connecting_point_of_ear_lobe", + 306: "r_top_connecting_point_of_helix", + 307: "r_point_2_of_inner_helix", + 308: "l_center_of_iris", + 309: "l_border_of_iris_3", + 310: "l_border_of_iris_midpoint_1", + 311: "l_border_of_iris_12", + 312: "l_border_of_iris_midpoint_4", + 313: "l_border_of_iris_9", + 314: "l_border_of_iris_midpoint_3", + 315: "l_border_of_iris_6", + 316: "l_border_of_iris_midpoint_2", + 317: "r_center_of_iris", + 318: 
"r_border_of_iris_3", + 319: "r_border_of_iris_midpoint_1", + 320: "r_border_of_iris_12", + 321: "r_border_of_iris_midpoint_4", + 322: "r_border_of_iris_9", + 323: "r_border_of_iris_midpoint_3", + 324: "r_border_of_iris_6", + 325: "r_border_of_iris_midpoint_2", + 326: "l_center_of_pupil", + 327: "l_border_of_pupil_3", + 328: "l_border_of_pupil_midpoint_1", + 329: "l_border_of_pupil_12", + 330: "l_border_of_pupil_midpoint_4", + 331: "l_border_of_pupil_9", + 332: "l_border_of_pupil_midpoint_3", + 333: "l_border_of_pupil_6", + 334: "l_border_of_pupil_midpoint_2", + 335: "r_center_of_pupil", + 336: "r_border_of_pupil_3", + 337: "r_border_of_pupil_midpoint_1", + 338: "r_border_of_pupil_12", + 339: "r_border_of_pupil_midpoint_4", + 340: "r_border_of_pupil_9", + 341: "r_border_of_pupil_midpoint_3", + 342: "r_border_of_pupil_6", + 343: "r_border_of_pupil_midpoint_2", + }, + keypoint_info={ + 0: dict(name="nose", id=0, color=[51, 153, 255], type="upper", swap=""), + 1: dict( + name="left_eye", id=1, color=[51, 153, 255], type="upper", swap="right_eye" + ), + 2: dict( + name="right_eye", id=2, color=[51, 153, 255], type="upper", swap="left_eye" + ), + 3: dict( + name="left_ear", id=3, color=[51, 153, 255], type="upper", swap="right_ear" + ), + 4: dict( + name="right_ear", id=4, color=[51, 153, 255], type="upper", swap="left_ear" + ), + 5: dict( + name="left_shoulder", + id=5, + color=[51, 153, 255], + type="upper", + swap="right_shoulder", + ), + 6: dict( + name="right_shoulder", + id=6, + color=[51, 153, 255], + type="upper", + swap="left_shoulder", + ), + 7: dict( + name="left_elbow", + id=7, + color=[51, 153, 255], + type="upper", + swap="right_elbow", + ), + 8: dict( + name="right_elbow", + id=8, + color=[51, 153, 255], + type="upper", + swap="left_elbow", + ), + 9: dict( + name="left_hip", id=9, color=[51, 153, 255], type="lower", swap="right_hip" + ), + 10: dict( + name="right_hip", id=10, color=[51, 153, 255], type="lower", swap="left_hip" + ), + 11: dict( + name="left_knee", + id=11, + color=[51, 153, 255], + type="lower", + swap="right_knee", + ), + 12: dict( + name="right_knee", + id=12, + color=[51, 153, 255], + type="lower", + swap="left_knee", + ), + 13: dict( + name="left_ankle", + id=13, + color=[51, 153, 255], + type="lower", + swap="right_ankle", + ), + 14: dict( + name="right_ankle", + id=14, + color=[51, 153, 255], + type="lower", + swap="left_ankle", + ), + 15: dict( + name="left_big_toe", + id=15, + color=[51, 153, 255], + type="lower", + swap="right_big_toe", + ), + 16: dict( + name="left_small_toe", + id=16, + color=[51, 153, 255], + type="lower", + swap="right_small_toe", + ), + 17: dict( + name="left_heel", + id=17, + color=[51, 153, 255], + type="lower", + swap="right_heel", + ), + 18: dict( + name="right_big_toe", + id=18, + color=[51, 153, 255], + type="lower", + swap="left_big_toe", + ), + 19: dict( + name="right_small_toe", + id=19, + color=[51, 153, 255], + type="lower", + swap="left_small_toe", + ), + 20: dict( + name="right_heel", + id=20, + color=[51, 153, 255], + type="lower", + swap="left_heel", + ), + 21: dict( + name="right_thumb4", + id=21, + color=[51, 153, 255], + type="upper", + swap="left_thumb4", + ), + 22: dict( + name="right_thumb3", + id=22, + color=[51, 153, 255], + type="upper", + swap="left_thumb3", + ), + 23: dict( + name="right_thumb2", + id=23, + color=[51, 153, 255], + type="upper", + swap="left_thumb2", + ), + 24: dict( + name="right_thumb_third_joint", + id=24, + color=[51, 153, 255], + type="upper", + swap="left_thumb_third_joint", + ), + 25: 
dict( + name="right_forefinger4", + id=25, + color=[51, 153, 255], + type="upper", + swap="left_forefinger4", + ), + 26: dict( + name="right_forefinger3", + id=26, + color=[51, 153, 255], + type="upper", + swap="left_forefinger3", + ), + 27: dict( + name="right_forefinger2", + id=27, + color=[51, 153, 255], + type="upper", + swap="left_forefinger2", + ), + 28: dict( + name="right_forefinger_third_joint", + id=28, + color=[51, 153, 255], + type="upper", + swap="left_forefinger_third_joint", + ), + 29: dict( + name="right_middle_finger4", + id=29, + color=[51, 153, 255], + type="upper", + swap="left_middle_finger4", + ), + 30: dict( + name="right_middle_finger3", + id=30, + color=[51, 153, 255], + type="upper", + swap="left_middle_finger3", + ), + 31: dict( + name="right_middle_finger2", + id=31, + color=[51, 153, 255], + type="upper", + swap="left_middle_finger2", + ), + 32: dict( + name="right_middle_finger_third_joint", + id=32, + color=[51, 153, 255], + type="upper", + swap="left_middle_finger_third_joint", + ), + 33: dict( + name="right_ring_finger4", + id=33, + color=[51, 153, 255], + type="upper", + swap="left_ring_finger4", + ), + 34: dict( + name="right_ring_finger3", + id=34, + color=[51, 153, 255], + type="upper", + swap="left_ring_finger3", + ), + 35: dict( + name="right_ring_finger2", + id=35, + color=[51, 153, 255], + type="upper", + swap="left_ring_finger2", + ), + 36: dict( + name="right_ring_finger_third_joint", + id=36, + color=[51, 153, 255], + type="upper", + swap="left_ring_finger_third_joint", + ), + 37: dict( + name="right_pinky_finger4", + id=37, + color=[51, 153, 255], + type="upper", + swap="left_pinky_finger4", + ), + 38: dict( + name="right_pinky_finger3", + id=38, + color=[51, 153, 255], + type="upper", + swap="left_pinky_finger3", + ), + 39: dict( + name="right_pinky_finger2", + id=39, + color=[51, 153, 255], + type="upper", + swap="left_pinky_finger2", + ), + 40: dict( + name="right_pinky_finger_third_joint", + id=40, + color=[51, 153, 255], + type="upper", + swap="left_pinky_finger_third_joint", + ), + 41: dict( + name="right_wrist", + id=41, + color=[51, 153, 255], + type="upper", + swap="left_wrist", + ), + 42: dict( + name="left_thumb4", + id=42, + color=[51, 153, 255], + type="upper", + swap="right_thumb4", + ), + 43: dict( + name="left_thumb3", + id=43, + color=[51, 153, 255], + type="upper", + swap="right_thumb3", + ), + 44: dict( + name="left_thumb2", + id=44, + color=[51, 153, 255], + type="upper", + swap="right_thumb2", + ), + 45: dict( + name="left_thumb_third_joint", + id=45, + color=[51, 153, 255], + type="upper", + swap="right_thumb_third_joint", + ), ## doesn't match with wholebody naming + 46: dict( + name="left_forefinger4", + id=46, + color=[51, 153, 255], + type="upper", + swap="right_forefinger4", + ), + 47: dict( + name="left_forefinger3", + id=47, + color=[51, 153, 255], + type="upper", + swap="right_forefinger3", + ), + 48: dict( + name="left_forefinger2", + id=48, + color=[51, 153, 255], + type="upper", + swap="right_forefinger2", + ), + 49: dict( + name="left_forefinger_third_joint", + id=49, + color=[51, 153, 255], + type="upper", + swap="right_forefinger_third_joint", + ), + 50: dict( + name="left_middle_finger4", + id=50, + color=[51, 153, 255], + type="upper", + swap="right_middle_finger4", + ), + 51: dict( + name="left_middle_finger3", + id=51, + color=[51, 153, 255], + type="upper", + swap="right_middle_finger3", + ), + 52: dict( + name="left_middle_finger2", + id=52, + color=[51, 153, 255], + type="upper", + swap="right_middle_finger2", 
+ ), + 53: dict( + name="left_middle_finger_third_joint", + id=53, + color=[51, 153, 255], + type="upper", + swap="right_middle_finger_third_joint", + ), + 54: dict( + name="left_ring_finger4", + id=54, + color=[51, 153, 255], + type="upper", + swap="right_ring_finger4", + ), + 55: dict( + name="left_ring_finger3", + id=55, + color=[51, 153, 255], + type="upper", + swap="right_ring_finger3", + ), + 56: dict( + name="left_ring_finger2", + id=56, + color=[51, 153, 255], + type="upper", + swap="right_ring_finger2", + ), + 57: dict( + name="left_ring_finger_third_joint", + id=57, + color=[51, 153, 255], + type="upper", + swap="right_ring_finger_third_joint", + ), + 58: dict( + name="left_pinky_finger4", + id=58, + color=[51, 153, 255], + type="upper", + swap="right_pinky_finger4", + ), + 59: dict( + name="left_pinky_finger3", + id=59, + color=[51, 153, 255], + type="upper", + swap="right_pinky_finger3", + ), + 60: dict( + name="left_pinky_finger2", + id=60, + color=[51, 153, 255], + type="upper", + swap="right_pinky_finger2", + ), + 61: dict( + name="left_pinky_finger_third_joint", + id=61, + color=[51, 153, 255], + type="upper", + swap="right_pinky_finger_third_joint", + ), + 62: dict( + name="left_wrist", + id=62, + color=[51, 153, 255], + type="upper", + swap="right_wrist", + ), + 63: dict( + name="left_olecranon", + id=63, + color=[51, 153, 255], + type="", + swap="right_olecranon", + ), + 64: dict( + name="right_olecranon", + id=64, + color=[51, 153, 255], + type="", + swap="left_olecranon", + ), + 65: dict( + name="left_cubital_fossa", + id=65, + color=[51, 153, 255], + type="", + swap="right_cubital_fossa", + ), + 66: dict( + name="right_cubital_fossa", + id=66, + color=[51, 153, 255], + type="", + swap="left_cubital_fossa", + ), + 67: dict( + name="left_acromion", + id=67, + color=[51, 153, 255], + type="", + swap="right_acromion", + ), + 68: dict( + name="right_acromion", + id=68, + color=[51, 153, 255], + type="", + swap="left_acromion", + ), + 69: dict(name="neck", id=69, color=[51, 153, 255], type="", swap=""), + 70: dict( + name="center_of_glabella", id=70, color=[255, 255, 255], type="", swap="" + ), + 71: dict( + name="center_of_nose_root", id=71, color=[255, 255, 255], type="", swap="" + ), + 72: dict( + name="tip_of_nose_bridge", id=72, color=[255, 255, 255], type="", swap="" + ), + 73: dict( + name="midpoint_1_of_nose_bridge", + id=73, + color=[255, 255, 255], + type="", + swap="", + ), + 74: dict( + name="midpoint_2_of_nose_bridge", + id=74, + color=[255, 255, 255], + type="", + swap="", + ), + 75: dict( + name="midpoint_3_of_nose_bridge", + id=75, + color=[255, 255, 255], + type="", + swap="", + ), + 76: dict( + name="center_of_labiomental_groove", + id=76, + color=[255, 255, 255], + type="", + swap="", + ), + 77: dict(name="tip_of_chin", id=77, color=[255, 255, 255], type="", swap=""), + 78: dict( + name="upper_startpoint_of_r_eyebrow", + id=78, + color=[255, 255, 255], + type="", + swap="upper_startpoint_of_l_eyebrow", + ), + 79: dict( + name="lower_startpoint_of_r_eyebrow", + id=79, + color=[255, 255, 255], + type="", + swap="lower_startpoint_of_l_eyebrow", + ), + 80: dict( + name="end_of_r_eyebrow", + id=80, + color=[255, 255, 255], + type="", + swap="end_of_l_eyebrow", + ), + 81: dict( + name="upper_midpoint_1_of_r_eyebrow", + id=81, + color=[255, 255, 255], + type="", + swap="upper_midpoint_1_of_l_eyebrow", + ), + 82: dict( + name="lower_midpoint_1_of_r_eyebrow", + id=82, + color=[255, 255, 255], + type="", + swap="lower_midpoint_1_of_l_eyebrow", + ), + 83: dict( + 
name="upper_midpoint_2_of_r_eyebrow", + id=83, + color=[255, 255, 255], + type="", + swap="upper_midpoint_3_of_l_eyebrow", + ), + 84: dict( + name="upper_midpoint_3_of_r_eyebrow", + id=84, + color=[255, 255, 255], + type="", + swap="upper_midpoint_2_of_l_eyebrow", + ), + 85: dict( + name="lower_midpoint_2_of_r_eyebrow", + id=85, + color=[255, 255, 255], + type="", + swap="lower_midpoint_3_of_l_eyebrow", + ), + 86: dict( + name="lower_midpoint_3_of_r_eyebrow", + id=86, + color=[255, 255, 255], + type="", + swap="lower_midpoint_2_of_l_eyebrow", + ), + 87: dict( + name="upper_startpoint_of_l_eyebrow", + id=87, + color=[255, 255, 255], + type="", + swap="upper_startpoint_of_r_eyebrow", + ), + 88: dict( + name="lower_startpoint_of_l_eyebrow", + id=88, + color=[255, 255, 255], + type="", + swap="lower_startpoint_of_r_eyebrow", + ), + 89: dict( + name="end_of_l_eyebrow", + id=89, + color=[255, 255, 255], + type="", + swap="end_of_r_eyebrow", + ), + 90: dict( + name="upper_midpoint_1_of_l_eyebrow", + id=90, + color=[255, 255, 255], + type="", + swap="upper_midpoint_1_of_r_eyebrow", + ), + 91: dict( + name="lower_midpoint_1_of_l_eyebrow", + id=91, + color=[255, 255, 255], + type="", + swap="lower_midpoint_1_of_r_eyebrow", + ), + 92: dict( + name="upper_midpoint_2_of_l_eyebrow", + id=92, + color=[255, 255, 255], + type="", + swap="upper_midpoint_3_of_r_eyebrow", + ), + 93: dict( + name="upper_midpoint_3_of_l_eyebrow", + id=93, + color=[255, 255, 255], + type="", + swap="upper_midpoint_2_of_r_eyebrow", + ), + 94: dict( + name="lower_midpoint_2_of_l_eyebrow", + id=94, + color=[255, 255, 255], + type="", + swap="lower_midpoint_3_of_r_eyebrow", + ), + 95: dict( + name="lower_midpoint_3_of_l_eyebrow", + id=95, + color=[255, 255, 255], + type="", + swap="lower_midpoint_2_of_r_eyebrow", + ), + 96: dict( + name="l_inner_end_of_upper_lash_line", + id=96, + color=[192, 64, 128], + type="", + swap="r_inner_end_of_upper_lash_line", + ), + 97: dict( + name="l_outer_end_of_upper_lash_line", + id=97, + color=[192, 64, 128], + type="", + swap="r_outer_end_of_upper_lash_line", + ), + 98: dict( + name="l_centerpoint_of_upper_lash_line", + id=98, + color=[192, 64, 128], + type="", + swap="r_centerpoint_of_upper_lash_line", + ), + 99: dict( + name="l_midpoint_2_of_upper_lash_line", + id=99, + color=[192, 64, 128], + type="", + swap="r_midpoint_1_of_upper_lash_line", + ), + 100: dict( + name="l_midpoint_1_of_upper_lash_line", + id=100, + color=[192, 64, 128], + type="", + swap="r_midpoint_2_of_upper_lash_line", + ), + 101: dict( + name="l_midpoint_6_of_upper_lash_line", + id=101, + color=[192, 64, 128], + type="", + swap="r_midpoint_3_of_upper_lash_line", + ), + 102: dict( + name="l_midpoint_5_of_upper_lash_line", + id=102, + color=[192, 64, 128], + type="", + swap="r_midpoint_4_of_upper_lash_line", + ), + 103: dict( + name="l_midpoint_4_of_upper_lash_line", + id=103, + color=[192, 64, 128], + type="", + swap="r_midpoint_5_of_upper_lash_line", + ), + 104: dict( + name="l_midpoint_3_of_upper_lash_line", + id=104, + color=[192, 64, 128], + type="", + swap="r_midpoint_6_of_upper_lash_line", + ), + 105: dict( + name="l_outer_end_of_upper_eyelid_line", + id=105, + color=[192, 64, 128], + type="", + swap="r_outer_end_of_upper_eyelid_line", + ), + 106: dict( + name="l_midpoint_6_of_upper_eyelid_line", + id=106, + color=[192, 64, 128], + type="", + swap="r_midpoint_3_of_upper_eyelid_line", + ), + 107: dict( + name="l_midpoint_2_of_upper_eyelid_line", + id=107, + color=[192, 64, 128], + type="", + 
swap="r_midpoint_1_of_upper_eyelid_line", + ), + 108: dict( + name="l_midpoint_5_of_upper_eyelid_line", + id=108, + color=[192, 64, 128], + type="", + swap="r_midpoint_4_of_upper_eyelid_line", + ), + 109: dict( + name="l_centerpoint_of_upper_eyelid_line", + id=109, + color=[192, 64, 128], + type="", + swap="r_centerpoint_of_upper_eyelid_line", + ), + 110: dict( + name="l_midpoint_4_of_upper_eyelid_line", + id=110, + color=[192, 64, 128], + type="", + swap="r_midpoint_5_of_upper_eyelid_line", + ), + 111: dict( + name="l_midpoint_1_of_upper_eyelid_line", + id=111, + color=[192, 64, 128], + type="", + swap="r_midpoint_2_of_upper_eyelid_line", + ), + 112: dict( + name="l_midpoint_3_of_upper_eyelid_line", + id=112, + color=[192, 64, 128], + type="", + swap="r_midpoint_6_of_upper_eyelid_line", + ), + 113: dict( + name="l_midpoint_6_of_upper_crease_line", + id=113, + color=[192, 64, 128], + type="", + swap="r_midpoint_3_of_upper_crease_line", + ), + 114: dict( + name="l_midpoint_2_of_upper_crease_line", + id=114, + color=[192, 64, 128], + type="", + swap="r_midpoint_1_of_upper_crease_line", + ), + 115: dict( + name="l_midpoint_5_of_upper_crease_line", + id=115, + color=[192, 64, 128], + type="", + swap="r_midpoint_4_of_upper_crease_line", + ), + 116: dict( + name="l_centerpoint_of_upper_crease_line", + id=116, + color=[192, 64, 128], + type="", + swap="r_centerpoint_of_upper_crease_line", + ), + 117: dict( + name="l_midpoint_4_of_upper_crease_line", + id=117, + color=[192, 64, 128], + type="", + swap="r_midpoint_5_of_upper_crease_line", + ), + 118: dict( + name="l_midpoint_1_of_upper_crease_line", + id=118, + color=[192, 64, 128], + type="", + swap="r_midpoint_2_of_upper_crease_line", + ), + 119: dict( + name="l_midpoint_3_of_upper_crease_line", + id=119, + color=[192, 64, 128], + type="", + swap="r_midpoint_6_of_upper_crease_line", + ), + 120: dict( + name="r_inner_end_of_upper_lash_line", + id=120, + color=[64, 32, 192], + type="", + swap="l_inner_end_of_upper_lash_line", + ), + 121: dict( + name="r_outer_end_of_upper_lash_line", + id=121, + color=[64, 32, 192], + type="", + swap="l_outer_end_of_upper_lash_line", + ), + 122: dict( + name="r_centerpoint_of_upper_lash_line", + id=122, + color=[64, 32, 192], + type="", + swap="l_centerpoint_of_upper_lash_line", + ), + 123: dict( + name="r_midpoint_1_of_upper_lash_line", + id=123, + color=[64, 32, 192], + type="", + swap="l_midpoint_2_of_upper_lash_line", + ), + 124: dict( + name="r_midpoint_2_of_upper_lash_line", + id=124, + color=[64, 32, 192], + type="", + swap="l_midpoint_1_of_upper_lash_line", + ), + 125: dict( + name="r_midpoint_3_of_upper_lash_line", + id=125, + color=[64, 32, 192], + type="", + swap="l_midpoint_6_of_upper_lash_line", + ), + 126: dict( + name="r_midpoint_4_of_upper_lash_line", + id=126, + color=[64, 32, 192], + type="", + swap="l_midpoint_5_of_upper_lash_line", + ), + 127: dict( + name="r_midpoint_5_of_upper_lash_line", + id=127, + color=[64, 32, 192], + type="", + swap="l_midpoint_4_of_upper_lash_line", + ), + 128: dict( + name="r_midpoint_6_of_upper_lash_line", + id=128, + color=[64, 32, 192], + type="", + swap="l_midpoint_3_of_upper_lash_line", + ), + 129: dict( + name="r_outer_end_of_upper_eyelid_line", + id=129, + color=[64, 32, 192], + type="", + swap="l_outer_end_of_upper_eyelid_line", + ), + 130: dict( + name="r_midpoint_3_of_upper_eyelid_line", + id=130, + color=[64, 32, 192], + type="", + swap="l_midpoint_6_of_upper_eyelid_line", + ), + 131: dict( + name="r_midpoint_1_of_upper_eyelid_line", + id=131, + color=[64, 
32, 192], + type="", + swap="l_midpoint_2_of_upper_eyelid_line", + ), + 132: dict( + name="r_midpoint_4_of_upper_eyelid_line", + id=132, + color=[64, 32, 192], + type="", + swap="l_midpoint_5_of_upper_eyelid_line", + ), + 133: dict( + name="r_centerpoint_of_upper_eyelid_line", + id=133, + color=[64, 32, 192], + type="", + swap="l_centerpoint_of_upper_eyelid_line", + ), + 134: dict( + name="r_midpoint_5_of_upper_eyelid_line", + id=134, + color=[64, 32, 192], + type="", + swap="l_midpoint_4_of_upper_eyelid_line", + ), + 135: dict( + name="r_midpoint_2_of_upper_eyelid_line", + id=135, + color=[64, 32, 192], + type="", + swap="l_midpoint_1_of_upper_eyelid_line", + ), + 136: dict( + name="r_midpoint_6_of_upper_eyelid_line", + id=136, + color=[64, 32, 192], + type="", + swap="l_midpoint_3_of_upper_eyelid_line", + ), + 137: dict( + name="r_midpoint_3_of_upper_crease_line", + id=137, + color=[64, 32, 192], + type="", + swap="l_midpoint_6_of_upper_crease_line", + ), + 138: dict( + name="r_midpoint_1_of_upper_crease_line", + id=138, + color=[64, 32, 192], + type="", + swap="l_midpoint_2_of_upper_crease_line", + ), + 139: dict( + name="r_midpoint_4_of_upper_crease_line", + id=139, + color=[64, 32, 192], + type="", + swap="l_midpoint_5_of_upper_crease_line", + ), + 140: dict( + name="r_centerpoint_of_upper_crease_line", + id=140, + color=[64, 32, 192], + type="", + swap="l_centerpoint_of_upper_crease_line", + ), + 141: dict( + name="r_midpoint_5_of_upper_crease_line", + id=141, + color=[64, 32, 192], + type="", + swap="l_midpoint_4_of_upper_crease_line", + ), + 142: dict( + name="r_midpoint_2_of_upper_crease_line", + id=142, + color=[64, 32, 192], + type="", + swap="l_midpoint_1_of_upper_crease_line", + ), + 143: dict( + name="r_midpoint_6_of_upper_crease_line", + id=143, + color=[64, 32, 192], + type="", + swap="l_midpoint_3_of_upper_crease_line", + ), + 144: dict( + name="l_inner_end_of_lower_lash_line", + id=144, + color=[64, 192, 128], + type="", + swap="r_inner_end_of_lower_lash_line", + ), + 145: dict( + name="l_outer_end_of_lower_lash_line", + id=145, + color=[64, 192, 128], + type="", + swap="r_outer_end_of_lower_lash_line", + ), + 146: dict( + name="l_centerpoint_of_lower_lash_line", + id=146, + color=[64, 192, 128], + type="", + swap="r_centerpoint_of_lower_lash_line", + ), + 147: dict( + name="l_midpoint_2_of_lower_lash_line", + id=147, + color=[64, 192, 128], + type="", + swap="r_midpoint_1_of_lower_lash_line", + ), + 148: dict( + name="l_midpoint_1_of_lower_lash_line", + id=148, + color=[64, 192, 128], + type="", + swap="r_midpoint_2_of_lower_lash_line", + ), + 149: dict( + name="l_midpoint_6_of_lower_lash_line", + id=149, + color=[64, 192, 128], + type="", + swap="r_midpoint_3_of_lower_lash_line", + ), + 150: dict( + name="l_midpoint_5_of_lower_lash_line", + id=150, + color=[64, 192, 128], + type="", + swap="r_midpoint_4_of_lower_lash_line", + ), + 151: dict( + name="l_midpoint_4_of_lower_lash_line", + id=151, + color=[64, 192, 128], + type="", + swap="r_midpoint_5_of_lower_lash_line", + ), + 152: dict( + name="l_midpoint_3_of_lower_lash_line", + id=152, + color=[64, 192, 128], + type="", + swap="r_midpoint_6_of_lower_lash_line", + ), + 153: dict( + name="l_outer_end_of_lower_eyelid_line", + id=153, + color=[64, 192, 128], + type="", + swap="r_outer_end_of_lower_eyelid_line", + ), + 154: dict( + name="l_midpoint_6_of_lower_eyelid_line", + id=154, + color=[64, 192, 128], + type="", + swap="r_midpoint_3_of_lower_eyelid_line", + ), + 155: dict( + name="l_midpoint_2_of_lower_eyelid_line", + 
id=155, + color=[64, 192, 128], + type="", + swap="r_midpoint_1_of_lower_eyelid_line", + ), + 156: dict( + name="l_midpoint_5_of_lower_eyelid_line", + id=156, + color=[64, 192, 128], + type="", + swap="r_midpoint_4_of_lower_eyelid_line", + ), + 157: dict( + name="l_centerpoint_of_lower_eyelid_line", + id=157, + color=[64, 192, 128], + type="", + swap="r_centerpoint_of_lower_eyelid_line", + ), + 158: dict( + name="l_midpoint_4_of_lower_eyelid_line", + id=158, + color=[64, 192, 128], + type="", + swap="r_midpoint_5_of_lower_eyelid_line", + ), + 159: dict( + name="l_midpoint_1_of_lower_eyelid_line", + id=159, + color=[64, 192, 128], + type="", + swap="r_midpoint_2_of_lower_eyelid_line", + ), + 160: dict( + name="l_midpoint_3_of_lower_eyelid_line", + id=160, + color=[64, 192, 128], + type="", + swap="r_midpoint_6_of_lower_eyelid_line", + ), + 161: dict( + name="r_inner_end_of_lower_lash_line", + id=161, + color=[64, 192, 32], + type="", + swap="l_inner_end_of_lower_lash_line", + ), + 162: dict( + name="r_outer_end_of_lower_lash_line", + id=162, + color=[64, 192, 32], + type="", + swap="l_outer_end_of_lower_lash_line", + ), + 163: dict( + name="r_centerpoint_of_lower_lash_line", + id=163, + color=[64, 192, 32], + type="", + swap="l_centerpoint_of_lower_lash_line", + ), + 164: dict( + name="r_midpoint_1_of_lower_lash_line", + id=164, + color=[64, 192, 32], + type="", + swap="l_midpoint_2_of_lower_lash_line", + ), + 165: dict( + name="r_midpoint_2_of_lower_lash_line", + id=165, + color=[64, 192, 32], + type="", + swap="l_midpoint_1_of_lower_lash_line", + ), + 166: dict( + name="r_midpoint_3_of_lower_lash_line", + id=166, + color=[64, 192, 32], + type="", + swap="l_midpoint_6_of_lower_lash_line", + ), + 167: dict( + name="r_midpoint_4_of_lower_lash_line", + id=167, + color=[64, 192, 32], + type="", + swap="l_midpoint_5_of_lower_lash_line", + ), + 168: dict( + name="r_midpoint_5_of_lower_lash_line", + id=168, + color=[64, 192, 32], + type="", + swap="l_midpoint_4_of_lower_lash_line", + ), + 169: dict( + name="r_midpoint_6_of_lower_lash_line", + id=169, + color=[64, 192, 32], + type="", + swap="l_midpoint_3_of_lower_lash_line", + ), + 170: dict( + name="r_outer_end_of_lower_eyelid_line", + id=170, + color=[64, 192, 32], + type="", + swap="l_outer_end_of_lower_eyelid_line", + ), + 171: dict( + name="r_midpoint_3_of_lower_eyelid_line", + id=171, + color=[64, 192, 32], + type="", + swap="l_midpoint_6_of_lower_eyelid_line", + ), + 172: dict( + name="r_midpoint_1_of_lower_eyelid_line", + id=172, + color=[64, 192, 32], + type="", + swap="l_midpoint_2_of_lower_eyelid_line", + ), + 173: dict( + name="r_midpoint_4_of_lower_eyelid_line", + id=173, + color=[64, 192, 32], + type="", + swap="l_midpoint_5_of_lower_eyelid_line", + ), + 174: dict( + name="r_centerpoint_of_lower_eyelid_line", + id=174, + color=[64, 192, 32], + type="", + swap="l_centerpoint_of_lower_eyelid_line", + ), + 175: dict( + name="r_midpoint_5_of_lower_eyelid_line", + id=175, + color=[64, 192, 32], + type="", + swap="l_midpoint_4_of_lower_eyelid_line", + ), + 176: dict( + name="r_midpoint_2_of_lower_eyelid_line", + id=176, + color=[64, 192, 32], + type="", + swap="l_midpoint_1_of_lower_eyelid_line", + ), + 177: dict( + name="r_midpoint_6_of_lower_eyelid_line", + id=177, + color=[64, 192, 32], + type="", + swap="l_midpoint_3_of_lower_eyelid_line", + ), + 178: dict(name="tip_of_nose", id=178, color=[0, 192, 0], type="", swap=""), + 179: dict( + name="bottom_center_of_nose", id=179, color=[0, 192, 0], type="", swap="" + ), + 180: dict( + 
name="r_outer_corner_of_nose", + id=180, + color=[0, 192, 0], + type="", + swap="l_outer_corner_of_nose", + ), + 181: dict( + name="l_outer_corner_of_nose", + id=181, + color=[0, 192, 0], + type="", + swap="r_outer_corner_of_nose", + ), + 182: dict( + name="inner_corner_of_r_nostril", + id=182, + color=[0, 192, 0], + type="", + swap="inner_corner_of_l_nostril", + ), + 183: dict( + name="outer_corner_of_r_nostril", + id=183, + color=[0, 192, 0], + type="", + swap="outer_corner_of_l_nostril", + ), + 184: dict( + name="upper_corner_of_r_nostril", + id=184, + color=[0, 192, 0], + type="", + swap="upper_corner_of_l_nostril", + ), + 185: dict( + name="inner_corner_of_l_nostril", + id=185, + color=[0, 192, 0], + type="", + swap="inner_corner_of_r_nostril", + ), + 186: dict( + name="outer_corner_of_l_nostril", + id=186, + color=[0, 192, 0], + type="", + swap="outer_corner_of_r_nostril", + ), + 187: dict( + name="upper_corner_of_l_nostril", + id=187, + color=[0, 192, 0], + type="", + swap="upper_corner_of_r_nostril", + ), + 188: dict( + name="r_outer_corner_of_mouth", + id=188, + color=[192, 0, 0], + type="", + swap="l_outer_corner_of_mouth", + ), + 189: dict( + name="l_outer_corner_of_mouth", + id=189, + color=[192, 0, 0], + type="", + swap="r_outer_corner_of_mouth", + ), + 190: dict( + name="center_of_cupid_bow", id=190, color=[192, 0, 0], type="", swap="" + ), + 191: dict( + name="center_of_lower_outer_lip", + id=191, + color=[192, 0, 0], + type="", + swap="", + ), + 192: dict( + name="midpoint_1_of_upper_outer_lip", + id=192, + color=[192, 0, 0], + type="", + swap="midpoint_2_of_upper_outer_lip", + ), + 193: dict( + name="midpoint_2_of_upper_outer_lip", + id=193, + color=[192, 0, 0], + type="", + swap="midpoint_1_of_upper_outer_lip", + ), + 194: dict( + name="midpoint_1_of_lower_outer_lip", + id=194, + color=[192, 0, 0], + type="", + swap="midpoint_2_of_lower_outer_lip", + ), + 195: dict( + name="midpoint_2_of_lower_outer_lip", + id=195, + color=[192, 0, 0], + type="", + swap="midpoint_1_of_lower_outer_lip", + ), + 196: dict( + name="midpoint_3_of_upper_outer_lip", + id=196, + color=[192, 0, 0], + type="", + swap="midpoint_6_of_upper_outer_lip", + ), + 197: dict( + name="midpoint_4_of_upper_outer_lip", + id=197, + color=[192, 0, 0], + type="", + swap="midpoint_5_of_upper_outer_lip", + ), + 198: dict( + name="midpoint_5_of_upper_outer_lip", + id=198, + color=[192, 0, 0], + type="", + swap="midpoint_4_of_upper_outer_lip", + ), + 199: dict( + name="midpoint_6_of_upper_outer_lip", + id=199, + color=[192, 0, 0], + type="", + swap="midpoint_3_of_upper_outer_lip", + ), + 200: dict( + name="midpoint_3_of_lower_outer_lip", + id=200, + color=[192, 0, 0], + type="", + swap="midpoint_6_of_lower_outer_lip", + ), + 201: dict( + name="midpoint_4_of_lower_outer_lip", + id=201, + color=[192, 0, 0], + type="", + swap="midpoint_5_of_lower_outer_lip", + ), + 202: dict( + name="midpoint_5_of_lower_outer_lip", + id=202, + color=[192, 0, 0], + type="", + swap="midpoint_4_of_lower_outer_lip", + ), + 203: dict( + name="midpoint_6_of_lower_outer_lip", + id=203, + color=[192, 0, 0], + type="", + swap="midpoint_3_of_lower_outer_lip", + ), + 204: dict( + name="r_inner_corner_of_mouth", + id=204, + color=[0, 192, 192], + type="", + swap="l_inner_corner_of_mouth", + ), + 205: dict( + name="l_inner_corner_of_mouth", + id=205, + color=[0, 192, 192], + type="", + swap="r_inner_corner_of_mouth", + ), + 206: dict( + name="center_of_upper_inner_lip", + id=206, + color=[0, 192, 192], + type="", + swap="", + ), + 207: dict( + 
name="center_of_lower_inner_lip", + id=207, + color=[0, 192, 192], + type="", + swap="", + ), + 208: dict( + name="midpoint_1_of_upper_inner_lip", + id=208, + color=[0, 192, 192], + type="", + swap="midpoint_2_of_upper_inner_lip", + ), + 209: dict( + name="midpoint_2_of_upper_inner_lip", + id=209, + color=[0, 192, 192], + type="", + swap="midpoint_1_of_upper_inner_lip", + ), + 210: dict( + name="midpoint_1_of_lower_inner_lip", + id=210, + color=[0, 192, 192], + type="", + swap="midpoint_2_of_lower_inner_lip", + ), + 211: dict( + name="midpoint_2_of_lower_inner_lip", + id=211, + color=[0, 192, 192], + type="", + swap="midpoint_1_of_lower_inner_lip", + ), + 212: dict( + name="midpoint_3_of_upper_inner_lip", + id=212, + color=[0, 192, 192], + type="", + swap="midpoint_6_of_upper_inner_lip", + ), + 213: dict( + name="midpoint_4_of_upper_inner_lip", + id=213, + color=[0, 192, 192], + type="", + swap="midpoint_5_of_upper_inner_lip", + ), + 214: dict( + name="midpoint_5_of_upper_inner_lip", + id=214, + color=[0, 192, 192], + type="", + swap="midpoint_4_of_upper_inner_lip", + ), + 215: dict( + name="midpoint_6_of_upper_inner_lip", + id=215, + color=[0, 192, 192], + type="", + swap="midpoint_3_of_upper_inner_lip", + ), + 216: dict( + name="midpoint_3_of_lower_inner_lip", + id=216, + color=[0, 192, 192], + type="", + swap="midpoint_6_of_lower_inner_lip", + ), + 217: dict( + name="midpoint_4_of_lower_inner_lip", + id=217, + color=[0, 192, 192], + type="", + swap="midpoint_5_of_lower_inner_lip", + ), + 218: dict( + name="midpoint_5_of_lower_inner_lip", + id=218, + color=[0, 192, 192], + type="", + swap="midpoint_4_of_lower_inner_lip", + ), + 219: dict( + name="midpoint_6_of_lower_inner_lip", + id=219, + color=[0, 192, 192], + type="", + swap="midpoint_3_of_lower_inner_lip", + ), + 220: dict(name="teeth_1", id=220, color=[51, 153, 255], type="", swap=""), + 221: dict(name="teeth_2", id=221, color=[51, 153, 255], type="", swap=""), + 222: dict(name="teeth_3", id=222, color=[51, 153, 255], type="", swap=""), + 223: dict(name="teeth_4", id=223, color=[51, 153, 255], type="", swap=""), + 224: dict(name="teeth_5", id=224, color=[51, 153, 255], type="", swap=""), + 225: dict(name="teeth_6", id=225, color=[51, 153, 255], type="", swap=""), + 226: dict(name="teeth_7", id=226, color=[51, 153, 255], type="", swap=""), + 227: dict(name="teeth_8", id=227, color=[51, 153, 255], type="", swap=""), + 228: dict(name="teeth_9", id=228, color=[51, 153, 255], type="", swap=""), + 229: dict(name="teeth_10", id=229, color=[51, 153, 255], type="", swap=""), + 230: dict(name="teeth_11", id=230, color=[51, 153, 255], type="", swap=""), + 231: dict(name="teeth_12", id=231, color=[51, 153, 255], type="", swap=""), + 232: dict(name="teeth_13", id=232, color=[51, 153, 255], type="", swap=""), + 233: dict(name="teeth_14", id=233, color=[51, 153, 255], type="", swap=""), + 234: dict(name="teeth_15", id=234, color=[51, 153, 255], type="", swap=""), + 235: dict(name="teeth_16", id=235, color=[51, 153, 255], type="", swap=""), + 236: dict(name="teeth_17", id=236, color=[51, 153, 255], type="", swap=""), + 237: dict(name="teeth_18", id=237, color=[51, 153, 255], type="", swap=""), + 238: dict(name="teeth_19", id=238, color=[51, 153, 255], type="", swap=""), + 239: dict(name="teeth_20", id=239, color=[51, 153, 255], type="", swap=""), + 240: dict(name="teeth_21", id=240, color=[51, 153, 255], type="", swap=""), + 241: dict(name="teeth_22", id=241, color=[51, 153, 255], type="", swap=""), + 242: dict(name="teeth_23", id=242, color=[51, 
153, 255], type="", swap=""), + 243: dict(name="teeth_24", id=243, color=[51, 153, 255], type="", swap=""), + 244: dict(name="teeth_25", id=244, color=[51, 153, 255], type="", swap=""), + 245: dict(name="teeth_26", id=245, color=[51, 153, 255], type="", swap=""), + 246: dict(name="teeth_27", id=246, color=[51, 153, 255], type="", swap=""), + 247: dict(name="teeth_28", id=247, color=[51, 153, 255], type="", swap=""), + 248: dict(name="teeth_29", id=248, color=[51, 153, 255], type="", swap=""), + 249: dict(name="teeth_30", id=249, color=[51, 153, 255], type="", swap=""), + 250: dict(name="teeth_31", id=250, color=[51, 153, 255], type="", swap=""), + 251: dict(name="teeth_32", id=251, color=[51, 153, 255], type="", swap=""), + 252: dict(name="teeth_33", id=252, color=[51, 153, 255], type="", swap=""), + 253: dict(name="teeth_34", id=253, color=[51, 153, 255], type="", swap=""), + 254: dict(name="teeth_35", id=254, color=[51, 153, 255], type="", swap=""), + 255: dict(name="teeth_36", id=255, color=[51, 153, 255], type="", swap=""), + 256: dict( + name="l_top_end_of_inferior_crus", + id=256, + color=[200, 200, 0], + type="", + swap="r_top_end_of_inferior_crus", + ), + 257: dict( + name="l_top_end_of_superior_crus", + id=257, + color=[200, 200, 0], + type="", + swap="r_top_end_of_superior_crus", + ), + 258: dict( + name="l_start_of_antihelix", + id=258, + color=[200, 200, 0], + type="", + swap="r_start_of_antihelix", + ), + 259: dict( + name="l_end_of_antihelix", + id=259, + color=[200, 200, 0], + type="", + swap="r_end_of_antihelix", + ), + 260: dict( + name="l_midpoint_1_of_antihelix", + id=260, + color=[200, 200, 0], + type="", + swap="r_midpoint_1_of_antihelix", + ), + 261: dict( + name="l_midpoint_1_of_inferior_crus", + id=261, + color=[200, 200, 0], + type="", + swap="r_midpoint_1_of_inferior_crus", + ), + 262: dict( + name="l_midpoint_2_of_antihelix", + id=262, + color=[200, 200, 0], + type="", + swap="r_midpoint_2_of_antihelix", + ), + 263: dict( + name="l_midpoint_3_of_antihelix", + id=263, + color=[200, 200, 0], + type="", + swap="r_midpoint_3_of_antihelix", + ), + 264: dict( + name="l_point_1_of_inner_helix", + id=264, + color=[200, 200, 0], + type="", + swap="r_point_1_of_inner_helix", + ), + 265: dict( + name="l_point_2_of_inner_helix", + id=265, + color=[200, 200, 0], + type="", + swap="r_point_8_of_inner_helix", + ), + 266: dict( + name="l_point_3_of_inner_helix", + id=266, + color=[200, 200, 0], + type="", + swap="r_point_3_of_inner_helix", + ), + 267: dict( + name="l_point_4_of_inner_helix", + id=267, + color=[200, 200, 0], + type="", + swap="r_point_4_of_inner_helix", + ), + 268: dict( + name="l_point_5_of_inner_helix", + id=268, + color=[200, 200, 0], + type="", + swap="r_point_5_of_inner_helix", + ), + 269: dict( + name="l_point_6_of_inner_helix", + id=269, + color=[200, 200, 0], + type="", + swap="r_point_6_of_inner_helix", + ), + 270: dict( + name="l_point_7_of_inner_helix", + id=270, + color=[200, 200, 0], + type="", + swap="r_point_7_of_inner_helix", + ), + 271: dict( + name="l_highest_point_of_antitragus", + id=271, + color=[200, 200, 0], + type="", + swap="r_highest_point_of_antitragus", + ), + 272: dict( + name="l_bottom_point_of_tragus", + id=272, + color=[200, 200, 0], + type="", + swap="r_bottom_point_of_tragus", + ), + 273: dict( + name="l_protruding_point_of_tragus", + id=273, + color=[200, 200, 0], + type="", + swap="r_protruding_point_of_tragus", + ), + 274: dict( + name="l_top_point_of_tragus", + id=274, + color=[200, 200, 0], + type="", + 
swap="r_top_point_of_tragus", + ), + 275: dict( + name="l_start_point_of_crus_of_helix", + id=275, + color=[200, 200, 0], + type="", + swap="r_start_point_of_crus_of_helix", + ), + 276: dict( + name="l_deepest_point_of_concha", + id=276, + color=[200, 200, 0], + type="", + swap="r_deepest_point_of_concha", + ), + 277: dict( + name="l_tip_of_ear_lobe", + id=277, + color=[200, 200, 0], + type="", + swap="r_tip_of_ear_lobe", + ), + 278: dict( + name="l_midpoint_between_22_15", + id=278, + color=[200, 200, 0], + type="", + swap="r_midpoint_between_22_15", + ), + 279: dict( + name="l_bottom_connecting_point_of_ear_lobe", + id=279, + color=[200, 200, 0], + type="", + swap="r_bottom_connecting_point_of_ear_lobe", + ), + 280: dict( + name="l_top_connecting_point_of_helix", + id=280, + color=[200, 200, 0], + type="", + swap="r_top_connecting_point_of_helix", + ), + 281: dict( + name="l_point_8_of_inner_helix", + id=281, + color=[200, 200, 0], + type="", + swap="r_point_2_of_inner_helix", + ), + 282: dict( + name="r_top_end_of_inferior_crus", + id=282, + color=[0, 200, 200], + type="", + swap="l_top_end_of_inferior_crus", + ), + 283: dict( + name="r_top_end_of_superior_crus", + id=283, + color=[0, 200, 200], + type="", + swap="l_top_end_of_superior_crus", + ), + 284: dict( + name="r_start_of_antihelix", + id=284, + color=[0, 200, 200], + type="", + swap="l_start_of_antihelix", + ), + 285: dict( + name="r_end_of_antihelix", + id=285, + color=[0, 200, 200], + type="", + swap="l_end_of_antihelix", + ), + 286: dict( + name="r_midpoint_1_of_antihelix", + id=286, + color=[0, 200, 200], + type="", + swap="l_midpoint_1_of_antihelix", + ), + 287: dict( + name="r_midpoint_1_of_inferior_crus", + id=287, + color=[0, 200, 200], + type="", + swap="l_midpoint_1_of_inferior_crus", + ), + 288: dict( + name="r_midpoint_2_of_antihelix", + id=288, + color=[0, 200, 200], + type="", + swap="l_midpoint_2_of_antihelix", + ), + 289: dict( + name="r_midpoint_3_of_antihelix", + id=289, + color=[0, 200, 200], + type="", + swap="l_midpoint_3_of_antihelix", + ), + 290: dict( + name="r_point_1_of_inner_helix", + id=290, + color=[0, 200, 200], + type="", + swap="l_point_1_of_inner_helix", + ), + 291: dict( + name="r_point_8_of_inner_helix", + id=291, + color=[0, 200, 200], + type="", + swap="l_point_2_of_inner_helix", + ), + 292: dict( + name="r_point_3_of_inner_helix", + id=292, + color=[0, 200, 200], + type="", + swap="l_point_3_of_inner_helix", + ), + 293: dict( + name="r_point_4_of_inner_helix", + id=293, + color=[0, 200, 200], + type="", + swap="l_point_4_of_inner_helix", + ), + 294: dict( + name="r_point_5_of_inner_helix", + id=294, + color=[0, 200, 200], + type="", + swap="l_point_5_of_inner_helix", + ), + 295: dict( + name="r_point_6_of_inner_helix", + id=295, + color=[0, 200, 200], + type="", + swap="l_point_6_of_inner_helix", + ), + 296: dict( + name="r_point_7_of_inner_helix", + id=296, + color=[0, 200, 200], + type="", + swap="l_point_7_of_inner_helix", + ), + 297: dict( + name="r_highest_point_of_antitragus", + id=297, + color=[0, 200, 200], + type="", + swap="l_highest_point_of_antitragus", + ), + 298: dict( + name="r_bottom_point_of_tragus", + id=298, + color=[0, 200, 200], + type="", + swap="l_bottom_point_of_tragus", + ), + 299: dict( + name="r_protruding_point_of_tragus", + id=299, + color=[0, 200, 200], + type="", + swap="l_protruding_point_of_tragus", + ), + 300: dict( + name="r_top_point_of_tragus", + id=300, + color=[0, 200, 200], + type="", + swap="l_top_point_of_tragus", + ), + 301: dict( + 
name="r_start_point_of_crus_of_helix", + id=301, + color=[0, 200, 200], + type="", + swap="l_start_point_of_crus_of_helix", + ), + 302: dict( + name="r_deepest_point_of_concha", + id=302, + color=[0, 200, 200], + type="", + swap="l_deepest_point_of_concha", + ), + 303: dict( + name="r_tip_of_ear_lobe", + id=303, + color=[0, 200, 200], + type="", + swap="l_tip_of_ear_lobe", + ), + 304: dict( + name="r_midpoint_between_22_15", + id=304, + color=[0, 200, 200], + type="", + swap="l_midpoint_between_22_15", + ), + 305: dict( + name="r_bottom_connecting_point_of_ear_lobe", + id=305, + color=[0, 200, 200], + type="", + swap="l_bottom_connecting_point_of_ear_lobe", + ), + 306: dict( + name="r_top_connecting_point_of_helix", + id=306, + color=[0, 200, 200], + type="", + swap="l_top_connecting_point_of_helix", + ), + 307: dict( + name="r_point_2_of_inner_helix", + id=307, + color=[0, 200, 200], + type="", + swap="l_point_8_of_inner_helix", + ), + 308: dict( + name="l_center_of_iris", + id=308, + color=[128, 192, 64], + type="", + swap="r_center_of_iris", + ), + 309: dict( + name="l_border_of_iris_3", + id=309, + color=[128, 192, 64], + type="", + swap="r_border_of_iris_9", + ), + 310: dict( + name="l_border_of_iris_midpoint_1", + id=310, + color=[128, 192, 64], + type="", + swap="r_border_of_iris_midpoint_4", + ), + 311: dict( + name="l_border_of_iris_12", + id=311, + color=[128, 192, 64], + type="", + swap="r_border_of_iris_12", + ), + 312: dict( + name="l_border_of_iris_midpoint_4", + id=312, + color=[128, 192, 64], + type="", + swap="r_border_of_iris_midpoint_1", + ), + 313: dict( + name="l_border_of_iris_9", + id=313, + color=[128, 192, 64], + type="", + swap="r_border_of_iris_3", + ), + 314: dict( + name="l_border_of_iris_midpoint_3", + id=314, + color=[128, 192, 64], + type="", + swap="r_border_of_iris_midpoint_2", + ), + 315: dict( + name="l_border_of_iris_6", + id=315, + color=[128, 192, 64], + type="", + swap="r_border_of_iris_6", + ), + 316: dict( + name="l_border_of_iris_midpoint_2", + id=316, + color=[128, 192, 64], + type="", + swap="r_border_of_iris_midpoint_3", + ), + 317: dict( + name="r_center_of_iris", + id=317, + color=[192, 32, 64], + type="", + swap="l_center_of_iris", + ), + 318: dict( + name="r_border_of_iris_3", + id=318, + color=[192, 32, 64], + type="", + swap="l_border_of_iris_9", + ), + 319: dict( + name="r_border_of_iris_midpoint_1", + id=319, + color=[192, 32, 64], + type="", + swap="l_border_of_iris_midpoint_4", + ), + 320: dict( + name="r_border_of_iris_12", + id=320, + color=[192, 32, 64], + type="", + swap="l_border_of_iris_12", + ), + 321: dict( + name="r_border_of_iris_midpoint_4", + id=321, + color=[192, 32, 64], + type="", + swap="l_border_of_iris_midpoint_1", + ), + 322: dict( + name="r_border_of_iris_9", + id=322, + color=[192, 32, 64], + type="", + swap="l_border_of_iris_3", + ), + 323: dict( + name="r_border_of_iris_midpoint_3", + id=323, + color=[192, 32, 64], + type="", + swap="l_border_of_iris_midpoint_2", + ), + 324: dict( + name="r_border_of_iris_6", + id=324, + color=[192, 32, 64], + type="", + swap="l_border_of_iris_6", + ), + 325: dict( + name="r_border_of_iris_midpoint_2", + id=325, + color=[192, 32, 64], + type="", + swap="l_border_of_iris_midpoint_3", + ), + 326: dict( + name="l_center_of_pupil", + id=326, + color=[192, 128, 64], + type="", + swap="r_center_of_pupil", + ), + 327: dict( + name="l_border_of_pupil_3", + id=327, + color=[192, 128, 64], + type="", + swap="r_border_of_pupil_9", + ), + 328: dict( + name="l_border_of_pupil_midpoint_1", + 
id=328, + color=[192, 128, 64], + type="", + swap="r_border_of_pupil_midpoint_4", + ), + 329: dict( + name="l_border_of_pupil_12", + id=329, + color=[192, 128, 64], + type="", + swap="r_border_of_pupil_12", + ), + 330: dict( + name="l_border_of_pupil_midpoint_4", + id=330, + color=[192, 128, 64], + type="", + swap="r_border_of_pupil_midpoint_1", + ), + 331: dict( + name="l_border_of_pupil_9", + id=331, + color=[192, 128, 64], + type="", + swap="r_border_of_pupil_3", + ), + 332: dict( + name="l_border_of_pupil_midpoint_3", + id=332, + color=[192, 128, 64], + type="", + swap="r_border_of_pupil_midpoint_2", + ), + 333: dict( + name="l_border_of_pupil_6", + id=333, + color=[192, 128, 64], + type="", + swap="r_border_of_pupil_6", + ), + 334: dict( + name="l_border_of_pupil_midpoint_2", + id=334, + color=[192, 128, 64], + type="", + swap="r_border_of_pupil_midpoint_3", + ), + 335: dict( + name="r_center_of_pupil", + id=335, + color=[32, 192, 192], + type="", + swap="l_center_of_pupil", + ), + 336: dict( + name="r_border_of_pupil_3", + id=336, + color=[32, 192, 192], + type="", + swap="l_border_of_pupil_9", + ), + 337: dict( + name="r_border_of_pupil_midpoint_1", + id=337, + color=[32, 192, 192], + type="", + swap="l_border_of_pupil_midpoint_4", + ), + 338: dict( + name="r_border_of_pupil_12", + id=338, + color=[32, 192, 192], + type="", + swap="l_border_of_pupil_12", + ), + 339: dict( + name="r_border_of_pupil_midpoint_4", + id=339, + color=[32, 192, 192], + type="", + swap="l_border_of_pupil_midpoint_1", + ), + 340: dict( + name="r_border_of_pupil_9", + id=340, + color=[32, 192, 192], + type="", + swap="l_border_of_pupil_3", + ), + 341: dict( + name="r_border_of_pupil_midpoint_3", + id=341, + color=[32, 192, 192], + type="", + swap="l_border_of_pupil_midpoint_2", + ), + 342: dict( + name="r_border_of_pupil_6", + id=342, + color=[32, 192, 192], + type="", + swap="l_border_of_pupil_6", + ), + 343: dict( + name="r_border_of_pupil_midpoint_2", + id=343, + color=[32, 192, 192], + type="", + swap="l_border_of_pupil_midpoint_3", + ), + }, + remove_teeth=True, ## 36 of them, 344 - 36 = 308 + skeleton_info={ + 0: dict(link=("left_ankle", "left_knee"), id=0, color=[0, 255, 0]), + 1: dict(link=("left_knee", "left_hip"), id=1, color=[0, 255, 0]), + 2: dict(link=("right_ankle", "right_knee"), id=2, color=[255, 128, 0]), + 3: dict(link=("right_knee", "right_hip"), id=3, color=[255, 128, 0]), + 4: dict(link=("left_hip", "right_hip"), id=4, color=[51, 153, 255]), + 5: dict(link=("left_shoulder", "left_hip"), id=5, color=[51, 153, 255]), + 6: dict(link=("right_shoulder", "right_hip"), id=6, color=[51, 153, 255]), + 7: dict(link=("left_shoulder", "right_shoulder"), id=7, color=[51, 153, 255]), + 8: dict(link=("left_shoulder", "left_elbow"), id=8, color=[0, 255, 0]), + 9: dict(link=("right_shoulder", "right_elbow"), id=9, color=[255, 128, 0]), + 10: dict(link=("left_elbow", "left_wrist"), id=10, color=[0, 255, 0]), + 11: dict(link=("right_elbow", "right_wrist"), id=11, color=[255, 128, 0]), + 12: dict(link=("left_eye", "right_eye"), id=12, color=[51, 153, 255]), + 13: dict(link=("nose", "left_eye"), id=13, color=[51, 153, 255]), + 14: dict(link=("nose", "right_eye"), id=14, color=[51, 153, 255]), + 15: dict(link=("left_eye", "left_ear"), id=15, color=[51, 153, 255]), + 16: dict(link=("right_eye", "right_ear"), id=16, color=[51, 153, 255]), + 17: dict(link=("left_ear", "left_shoulder"), id=17, color=[51, 153, 255]), + 18: dict(link=("right_ear", "right_shoulder"), id=18, color=[51, 153, 255]), + 19: 
dict(link=("left_ankle", "left_big_toe"), id=19, color=[0, 255, 0]), + 20: dict(link=("left_ankle", "left_small_toe"), id=20, color=[0, 255, 0]), + 21: dict(link=("left_ankle", "left_heel"), id=21, color=[0, 255, 0]), + 22: dict(link=("right_ankle", "right_big_toe"), id=22, color=[255, 128, 0]), + 23: dict(link=("right_ankle", "right_small_toe"), id=23, color=[255, 128, 0]), + 24: dict(link=("right_ankle", "right_heel"), id=24, color=[255, 128, 0]), + 25: dict( + link=("left_wrist", "left_thumb_third_joint"), id=25, color=[255, 128, 0] + ), + 26: dict( + link=("left_thumb_third_joint", "left_thumb2"), id=26, color=[255, 128, 0] + ), + 27: dict(link=("left_thumb2", "left_thumb3"), id=27, color=[255, 128, 0]), + 28: dict(link=("left_thumb3", "left_thumb4"), id=28, color=[255, 128, 0]), + 29: dict( + link=("left_wrist", "left_forefinger_third_joint"), + id=29, + color=[255, 153, 255], + ), + 30: dict( + link=("left_forefinger_third_joint", "left_forefinger2"), + id=30, + color=[255, 153, 255], + ), + 31: dict( + link=("left_forefinger2", "left_forefinger3"), id=31, color=[255, 153, 255] + ), + 32: dict( + link=("left_forefinger3", "left_forefinger4"), id=32, color=[255, 153, 255] + ), + 33: dict( + link=("left_wrist", "left_middle_finger_third_joint"), + id=33, + color=[102, 178, 255], + ), + 34: dict( + link=("left_middle_finger_third_joint", "left_middle_finger2"), + id=34, + color=[102, 178, 255], + ), + 35: dict( + link=("left_middle_finger2", "left_middle_finger3"), + id=35, + color=[102, 178, 255], + ), + 36: dict( + link=("left_middle_finger3", "left_middle_finger4"), + id=36, + color=[102, 178, 255], + ), + 37: dict( + link=("left_wrist", "left_ring_finger_third_joint"), + id=37, + color=[255, 51, 51], + ), + 38: dict( + link=("left_ring_finger_third_joint", "left_ring_finger2"), + id=38, + color=[255, 51, 51], + ), + 39: dict( + link=("left_ring_finger2", "left_ring_finger3"), id=39, color=[255, 51, 51] + ), + 40: dict( + link=("left_ring_finger3", "left_ring_finger4"), id=40, color=[255, 51, 51] + ), + 41: dict( + link=("left_wrist", "left_pinky_finger_third_joint"), + id=41, + color=[0, 255, 0], + ), + 42: dict( + link=("left_pinky_finger_third_joint", "left_pinky_finger2"), + id=42, + color=[0, 255, 0], + ), + 43: dict( + link=("left_pinky_finger2", "left_pinky_finger3"), id=43, color=[0, 255, 0] + ), + 44: dict( + link=("left_pinky_finger3", "left_pinky_finger4"), id=44, color=[0, 255, 0] + ), + 45: dict( + link=("right_wrist", "right_thumb_third_joint"), id=45, color=[255, 128, 0] + ), + 46: dict( + link=("right_thumb_third_joint", "right_thumb2"), id=46, color=[255, 128, 0] + ), + 47: dict(link=("right_thumb2", "right_thumb3"), id=47, color=[255, 128, 0]), + 48: dict(link=("right_thumb3", "right_thumb4"), id=48, color=[255, 128, 0]), + 49: dict( + link=("right_wrist", "right_forefinger_third_joint"), + id=49, + color=[255, 153, 255], + ), + 50: dict( + link=("right_forefinger_third_joint", "right_forefinger2"), + id=50, + color=[255, 153, 255], + ), + 51: dict( + link=("right_forefinger2", "right_forefinger3"), + id=51, + color=[255, 153, 255], + ), + 52: dict( + link=("right_forefinger3", "right_forefinger4"), + id=52, + color=[255, 153, 255], + ), + 53: dict( + link=("right_wrist", "right_middle_finger_third_joint"), + id=53, + color=[102, 178, 255], + ), + 54: dict( + link=("right_middle_finger_third_joint", "right_middle_finger2"), + id=54, + color=[102, 178, 255], + ), + 55: dict( + link=("right_middle_finger2", "right_middle_finger3"), + id=55, + color=[102, 178, 255], + 
), + 56: dict( + link=("right_middle_finger3", "right_middle_finger4"), + id=56, + color=[102, 178, 255], + ), + 57: dict( + link=("right_wrist", "right_ring_finger_third_joint"), + id=57, + color=[255, 51, 51], + ), + 58: dict( + link=("right_ring_finger_third_joint", "right_ring_finger2"), + id=58, + color=[255, 51, 51], + ), + 59: dict( + link=("right_ring_finger2", "right_ring_finger3"), + id=59, + color=[255, 51, 51], + ), + 60: dict( + link=("right_ring_finger3", "right_ring_finger4"), + id=60, + color=[255, 51, 51], + ), + 61: dict( + link=("right_wrist", "right_pinky_finger_third_joint"), + id=61, + color=[0, 255, 0], + ), + 62: dict( + link=("right_pinky_finger_third_joint", "right_pinky_finger2"), + id=62, + color=[0, 255, 0], + ), + 63: dict( + link=("right_pinky_finger2", "right_pinky_finger3"), + id=63, + color=[0, 255, 0], + ), + 64: dict( + link=("right_pinky_finger3", "right_pinky_finger4"), + id=64, + color=[0, 255, 0], + ), + }, + joint_weights=[1.0] * 344, + body_keypoint_names=[ + "nose", + "left_eye", + "right_eye", + "left_ear", + "right_ear", + "left_shoulder", + "right_shoulder", + "left_elbow", + "right_elbow", + "left_wrist", + "right_wrist", + "left_hip", + "right_hip", + "left_knee", + "right_knee", + "left_ankle", + "right_ankle", + ], + foot_keypoint_names=[ + "left_big_toe", + "left_small_toe", + "left_heel", + "right_big_toe", + "right_small_toe", + "right_heel", + ], + left_hand_keypoint_names=[ + "left_thumb4", + "left_thumb3", + "left_thumb2", + "left_thumb_third_joint", + "left_forefinger4", + "left_forefinger3", + "left_forefinger2", + "left_forefinger_third_joint", + "left_middle_finger4", + "left_middle_finger3", + "left_middle_finger2", + "left_middle_finger_third_joint", + "left_ring_finger4", + "left_ring_finger3", + "left_ring_finger2", + "left_ring_finger_third_joint", + "left_pinky_finger4", + "left_pinky_finger3", + "left_pinky_finger2", + "left_pinky_finger_third_joint", + ], + right_hand_keypoint_names=[ + "right_thumb4", + "right_thumb3", + "right_thumb2", + "right_thumb_third_joint", + "right_forefinger4", + "right_forefinger3", + "right_forefinger2", + "right_forefinger_third_joint", + "right_middle_finger4", + "right_middle_finger3", + "right_middle_finger2", + "right_middle_finger_third_joint", + "right_ring_finger4", + "right_ring_finger3", + "right_ring_finger2", + "right_ring_finger_third_joint", + "right_pinky_finger4", + "right_pinky_finger3", + "right_pinky_finger2", + "right_pinky_finger_third_joint", + ], + ## 7 of them + extra_keypoint_names=[ + "neck", + "left_olecranon", + "right_olecranon", + "left_cubital_fossa", + "right_cubital_fossa", + "left_acromion", + "right_acromion", + ], + face_keypoint_names=[ + "center_of_glabella", + "center_of_nose_root", + "tip_of_nose_bridge", + "midpoint_1_of_nose_bridge", + "midpoint_2_of_nose_bridge", + "midpoint_3_of_nose_bridge", + "center_of_labiomental_groove", + "tip_of_chin", + "upper_startpoint_of_r_eyebrow", + "lower_startpoint_of_r_eyebrow", + "end_of_r_eyebrow", + "upper_midpoint_1_of_r_eyebrow", + "lower_midpoint_1_of_r_eyebrow", + "upper_midpoint_2_of_r_eyebrow", + "upper_midpoint_3_of_r_eyebrow", + "lower_midpoint_2_of_r_eyebrow", + "lower_midpoint_3_of_r_eyebrow", + "upper_startpoint_of_l_eyebrow", + "lower_startpoint_of_l_eyebrow", + "end_of_l_eyebrow", + "upper_midpoint_1_of_l_eyebrow", + "lower_midpoint_1_of_l_eyebrow", + "upper_midpoint_2_of_l_eyebrow", + "upper_midpoint_3_of_l_eyebrow", + "lower_midpoint_2_of_l_eyebrow", + "lower_midpoint_3_of_l_eyebrow", + 
"l_inner_end_of_upper_lash_line", + "l_outer_end_of_upper_lash_line", + "l_centerpoint_of_upper_lash_line", + "l_midpoint_2_of_upper_lash_line", + "l_midpoint_1_of_upper_lash_line", + "l_midpoint_6_of_upper_lash_line", + "l_midpoint_5_of_upper_lash_line", + "l_midpoint_4_of_upper_lash_line", + "l_midpoint_3_of_upper_lash_line", + "l_outer_end_of_upper_eyelid_line", + "l_midpoint_6_of_upper_eyelid_line", + "l_midpoint_2_of_upper_eyelid_line", + "l_midpoint_5_of_upper_eyelid_line", + "l_centerpoint_of_upper_eyelid_line", + "l_midpoint_4_of_upper_eyelid_line", + "l_midpoint_1_of_upper_eyelid_line", + "l_midpoint_3_of_upper_eyelid_line", + "l_midpoint_6_of_upper_crease_line", + "l_midpoint_2_of_upper_crease_line", + "l_midpoint_5_of_upper_crease_line", + "l_centerpoint_of_upper_crease_line", + "l_midpoint_4_of_upper_crease_line", + "l_midpoint_1_of_upper_crease_line", + "l_midpoint_3_of_upper_crease_line", + "r_inner_end_of_upper_lash_line", + "r_outer_end_of_upper_lash_line", + "r_centerpoint_of_upper_lash_line", + "r_midpoint_1_of_upper_lash_line", + "r_midpoint_2_of_upper_lash_line", + "r_midpoint_3_of_upper_lash_line", + "r_midpoint_4_of_upper_lash_line", + "r_midpoint_5_of_upper_lash_line", + "r_midpoint_6_of_upper_lash_line", + "r_outer_end_of_upper_eyelid_line", + "r_midpoint_3_of_upper_eyelid_line", + "r_midpoint_1_of_upper_eyelid_line", + "r_midpoint_4_of_upper_eyelid_line", + "r_centerpoint_of_upper_eyelid_line", + "r_midpoint_5_of_upper_eyelid_line", + "r_midpoint_2_of_upper_eyelid_line", + "r_midpoint_6_of_upper_eyelid_line", + "r_midpoint_3_of_upper_crease_line", + "r_midpoint_1_of_upper_crease_line", + "r_midpoint_4_of_upper_crease_line", + "r_centerpoint_of_upper_crease_line", + "r_midpoint_5_of_upper_crease_line", + "r_midpoint_2_of_upper_crease_line", + "r_midpoint_6_of_upper_crease_line", + "l_inner_end_of_lower_lash_line", + "l_outer_end_of_lower_lash_line", + "l_centerpoint_of_lower_lash_line", + "l_midpoint_2_of_lower_lash_line", + "l_midpoint_1_of_lower_lash_line", + "l_midpoint_6_of_lower_lash_line", + "l_midpoint_5_of_lower_lash_line", + "l_midpoint_4_of_lower_lash_line", + "l_midpoint_3_of_lower_lash_line", + "l_outer_end_of_lower_eyelid_line", + "l_midpoint_6_of_lower_eyelid_line", + "l_midpoint_2_of_lower_eyelid_line", + "l_midpoint_5_of_lower_eyelid_line", + "l_centerpoint_of_lower_eyelid_line", + "l_midpoint_4_of_lower_eyelid_line", + "l_midpoint_1_of_lower_eyelid_line", + "l_midpoint_3_of_lower_eyelid_line", + "r_inner_end_of_lower_lash_line", + "r_outer_end_of_lower_lash_line", + "r_centerpoint_of_lower_lash_line", + "r_midpoint_1_of_lower_lash_line", + "r_midpoint_2_of_lower_lash_line", + "r_midpoint_3_of_lower_lash_line", + "r_midpoint_4_of_lower_lash_line", + "r_midpoint_5_of_lower_lash_line", + "r_midpoint_6_of_lower_lash_line", + "r_outer_end_of_lower_eyelid_line", + "r_midpoint_3_of_lower_eyelid_line", + "r_midpoint_1_of_lower_eyelid_line", + "r_midpoint_4_of_lower_eyelid_line", + "r_centerpoint_of_lower_eyelid_line", + "r_midpoint_5_of_lower_eyelid_line", + "r_midpoint_2_of_lower_eyelid_line", + "r_midpoint_6_of_lower_eyelid_line", + "tip_of_nose", + "bottom_center_of_nose", + "r_outer_corner_of_nose", + "l_outer_corner_of_nose", + "inner_corner_of_r_nostril", + "outer_corner_of_r_nostril", + "upper_corner_of_r_nostril", + "inner_corner_of_l_nostril", + "outer_corner_of_l_nostril", + "upper_corner_of_l_nostril", + "r_outer_corner_of_mouth", + "l_outer_corner_of_mouth", + "center_of_cupid_bow", + "center_of_lower_outer_lip", + 
"midpoint_1_of_upper_outer_lip", + "midpoint_2_of_upper_outer_lip", + "midpoint_1_of_lower_outer_lip", + "midpoint_2_of_lower_outer_lip", + "midpoint_3_of_upper_outer_lip", + "midpoint_4_of_upper_outer_lip", + "midpoint_5_of_upper_outer_lip", + "midpoint_6_of_upper_outer_lip", + "midpoint_3_of_lower_outer_lip", + "midpoint_4_of_lower_outer_lip", + "midpoint_5_of_lower_outer_lip", + "midpoint_6_of_lower_outer_lip", + "r_inner_corner_of_mouth", + "l_inner_corner_of_mouth", + "center_of_upper_inner_lip", + "center_of_lower_inner_lip", + "midpoint_1_of_upper_inner_lip", + "midpoint_2_of_upper_inner_lip", + "midpoint_1_of_lower_inner_lip", + "midpoint_2_of_lower_inner_lip", + "midpoint_3_of_upper_inner_lip", + "midpoint_4_of_upper_inner_lip", + "midpoint_5_of_upper_inner_lip", + "midpoint_6_of_upper_inner_lip", + "midpoint_3_of_lower_inner_lip", + "midpoint_4_of_lower_inner_lip", + "midpoint_5_of_lower_inner_lip", + "midpoint_6_of_lower_inner_lip", + "l_top_end_of_inferior_crus", + "l_top_end_of_superior_crus", + "l_start_of_antihelix", + "l_end_of_antihelix", + "l_midpoint_1_of_antihelix", + "l_midpoint_1_of_inferior_crus", + "l_midpoint_2_of_antihelix", + "l_midpoint_3_of_antihelix", + "l_point_1_of_inner_helix", + "l_point_2_of_inner_helix", + "l_point_3_of_inner_helix", + "l_point_4_of_inner_helix", + "l_point_5_of_inner_helix", + "l_point_6_of_inner_helix", + "l_point_7_of_inner_helix", + "l_highest_point_of_antitragus", + "l_bottom_point_of_tragus", + "l_protruding_point_of_tragus", + "l_top_point_of_tragus", + "l_start_point_of_crus_of_helix", + "l_deepest_point_of_concha", + "l_tip_of_ear_lobe", + "l_midpoint_between_22_15", + "l_bottom_connecting_point_of_ear_lobe", + "l_top_connecting_point_of_helix", + "l_point_8_of_inner_helix", + "r_top_end_of_inferior_crus", + "r_top_end_of_superior_crus", + "r_start_of_antihelix", + "r_end_of_antihelix", + "r_midpoint_1_of_antihelix", + "r_midpoint_1_of_inferior_crus", + "r_midpoint_2_of_antihelix", + "r_midpoint_3_of_antihelix", + "r_point_1_of_inner_helix", + "r_point_8_of_inner_helix", + "r_point_3_of_inner_helix", + "r_point_4_of_inner_helix", + "r_point_5_of_inner_helix", + "r_point_6_of_inner_helix", + "r_point_7_of_inner_helix", + "r_highest_point_of_antitragus", + "r_bottom_point_of_tragus", + "r_protruding_point_of_tragus", + "r_top_point_of_tragus", + "r_start_point_of_crus_of_helix", + "r_deepest_point_of_concha", + "r_tip_of_ear_lobe", + "r_midpoint_between_22_15", + "r_bottom_connecting_point_of_ear_lobe", + "r_top_connecting_point_of_helix", + "r_point_2_of_inner_helix", + "l_center_of_iris", + "l_border_of_iris_3", + "l_border_of_iris_midpoint_1", + "l_border_of_iris_12", + "l_border_of_iris_midpoint_4", + "l_border_of_iris_9", + "l_border_of_iris_midpoint_3", + "l_border_of_iris_6", + "l_border_of_iris_midpoint_2", + "r_center_of_iris", + "r_border_of_iris_3", + "r_border_of_iris_midpoint_1", + "r_border_of_iris_12", + "r_border_of_iris_midpoint_4", + "r_border_of_iris_9", + "r_border_of_iris_midpoint_3", + "r_border_of_iris_6", + "r_border_of_iris_midpoint_2", + "l_center_of_pupil", + "l_border_of_pupil_3", + "l_border_of_pupil_midpoint_1", + "l_border_of_pupil_12", + "l_border_of_pupil_midpoint_4", + "l_border_of_pupil_9", + "l_border_of_pupil_midpoint_3", + "l_border_of_pupil_6", + "l_border_of_pupil_midpoint_2", + "r_center_of_pupil", + "r_border_of_pupil_3", + "r_border_of_pupil_midpoint_1", + "r_border_of_pupil_12", + "r_border_of_pupil_midpoint_4", + "r_border_of_pupil_9", + "r_border_of_pupil_midpoint_3", + 
"r_border_of_pupil_6", + "r_border_of_pupil_midpoint_2", + ], +) + +##------------------------------------------------------------------------------------------------------------------ +### remove teeth keypoints +if dataset_info["remove_teeth"] == True: + ## get teeth ids + teeth_keypoint_ids = [ + keypoint_id + for keypoint_id, info in dataset_info["keypoint_info"].items() + if info["name"].startswith("teeth_") + ] + min_teeth_keypoint_id = min(teeth_keypoint_ids) + max_teeth_keypoint_id = max(teeth_keypoint_ids) + + dataset_info["teeth_keypoint_ids"] = teeth_keypoint_ids + + ## remove teeth keypoints from keypoint_info + keypoint_info_ = dataset_info["keypoint_info"] + + keypoint_info = {} + for keypoint_id, info in keypoint_info_.items(): + if keypoint_id in teeth_keypoint_ids: + continue + + if keypoint_id < min_teeth_keypoint_id: + keypoint_info[keypoint_id] = info + + if keypoint_id > max_teeth_keypoint_id: + keypoint_id = keypoint_id - len(teeth_keypoint_ids) + info["id"] = keypoint_id + keypoint_info[keypoint_id] = info + + dataset_info["keypoint_info"] = keypoint_info + + ## update joint_weights + dataset_info["joint_weights"] = (len(keypoint_info_) - len(teeth_keypoint_ids)) * [ + 1.0 + ] + +##------------------------------------------------------------------------------------------------------------------ +## reconfigure in the order of coco_whole_body +coco_wholebody_keypoint_info = { + keypoint_info["name"]: keypoint_info + for (keypoint_index, keypoint_info) in coco_wholebody_info["keypoint_info"].items() +} +coco_wholebody_to_goliath_mapping = {} ## coco_wholebody_index to goliath_index +coco_wholebody_to_goliath_keypoint_info = {} + +## find out common keypoints between goliath and coco_whole_body +for keypoint_index, keypoint_info in dataset_info["keypoint_info"].items(): + keypoint_name = keypoint_info["name"] + keypoint_index_ = keypoint_info["id"] + assert keypoint_index == keypoint_index_ + + if keypoint_name in coco_wholebody_keypoint_info.keys(): + coco_wholebody_to_goliath_keypoint_info[keypoint_name] = ( + coco_wholebody_keypoint_info[keypoint_name] + ) + coco_wholebody_to_goliath_mapping[ + coco_wholebody_keypoint_info[keypoint_name]["id"] + ] = keypoint_info["id"] + +dataset_info["coco_wholebody_to_goliath_mapping"] = ( + coco_wholebody_to_goliath_mapping ## store the cocowholebody indices +) +dataset_info["coco_wholebody_to_goliath_keypoint_info"] = ( + coco_wholebody_to_goliath_keypoint_info +) + +##------------------------------------------------------------------------------------------------------------------ +coco_wholebody_sigmas = {} + +## compute the coco_wholebody_sigmas +for keypoint_index, keypoint_info in coco_wholebody_info["keypoint_info"].items(): + coco_wholebody_sigmas[keypoint_info["name"]] = coco_wholebody_info["sigmas"][ + keypoint_info["id"] + ] + +default_sigma = 0.010 ## for mostly face keypoints +dataset_info["sigmas"] = [default_sigma] * len(dataset_info["keypoint_info"]) + +## we copy sigmas from coco_wholebody. 
Rest are assigned as below: +custom_sigmas = { + "left_thumb_third_joint": 0.022, + "left_forefinger_third_joint": 0.026, + "left_middle_finger_third_joint": 0.018, + "left_ring_finger_third_joint": 0.017, + "left_pinky_finger_third_joint": 0.02, + "right_thumb_third_joint": 0.022, + "right_forefinger_third_joint": 0.026, + "right_middle_finger_third_joint": 0.018, + "right_ring_finger_third_joint": 0.017, + "right_pinky_finger_third_joint": 0.02, + "neck": 0.079, ## same as shoulder + "left_olecranon": 0.072, ## same as elbow + "right_olecranon": 0.072, ## same as elbow + "left_cubital_fossa": 0.072, ## same as elbow + "right_cubital_fossa": 0.072, ## same as elbow + "left_acromion": 0.079, ## same as shoulder + "right_acromion": 0.079, ## same as shoulder +} + +## copy custom sigmas +for keypoint_name, sigma in custom_sigmas.items(): + keypoint_id = -1 + + ## search for keypoint id from keypoint name + for keypoint_id_ in dataset_info["keypoint_info"].keys(): + if dataset_info["keypoint_info"][keypoint_id_]["name"] == keypoint_name: + keypoint_id = keypoint_id_ + break + + if keypoint_id != -1: + keypoint_info = dataset_info["keypoint_info"][keypoint_id] + assert keypoint_info["name"] == keypoint_name + assert keypoint_info["id"] == keypoint_id + dataset_info["sigmas"][keypoint_info["id"]] = sigma + +## copy coco_wholebody sigmas +for keypoint_index, keypoint_info in dataset_info["keypoint_info"].items(): + if keypoint_info["name"] in coco_wholebody_sigmas.keys(): + dataset_info["sigmas"][keypoint_info["id"]] = coco_wholebody_sigmas[ + keypoint_info["name"] + ] diff --git a/sapiens/pose/configs/keypoints308/shutterstock_goliath_3po/sapiens2_0.4b_keypoints308_shutterstock_goliath_3po-1024x768.py b/sapiens/pose/configs/keypoints308/shutterstock_goliath_3po/sapiens2_0.4b_keypoints308_shutterstock_goliath_3po-1024x768.py new file mode 100644 index 0000000000000000000000000000000000000000..0824ada878c8a5255350219ded0247c35554ca07 --- /dev/null +++ b/sapiens/pose/configs/keypoints308/shutterstock_goliath_3po/sapiens2_0.4b_keypoints308_shutterstock_goliath_3po-1024x768.py @@ -0,0 +1,325 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
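Before the training configs that follow, a note on orientation: in the dataset-info file above, `coco_wholebody_to_goliath_mapping` is keyed by COCO-WholeBody keypoint indices, with the matching Goliath-308 indices as values. A minimal sketch of consuming such a mapping, assuming a hypothetical `(133, 3)` array of COCO-WholeBody `(x, y, score)` predictions; the helper below is illustrative and not part of the repository:

```python
import numpy as np

def coco_wholebody_to_goliath(coco_kpts: np.ndarray, mapping: dict[int, int],
                              num_goliath: int = 308) -> np.ndarray:
    """Scatter (133, 3) COCO-WholeBody keypoints into a (308, 3) Goliath-order array.

    Goliath keypoints without a COCO-WholeBody counterpart stay zeroed (score 0).
    """
    goliath_kpts = np.zeros((num_goliath, coco_kpts.shape[1]), dtype=np.float32)
    for coco_idx, goliath_idx in mapping.items():
        goliath_kpts[goliath_idx] = coco_kpts[coco_idx]
    return goliath_kpts
```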
+ +import os + +_CHECKPOINT_ROOT = os.path.expanduser( + os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host") +) +_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data")) + +warmup_iters = 500 +# num_iters = 2e4 +num_iters = 1e4 + +# ------------------------------------------------------------------------------ +vis_every_iters = 100 +log_every_iters = 10 +save_every_iters = 1000 +val_every_iters = 1000 + +# # # # debug +# vis_every_iters = 1 +# log_every_iters = 1 +# val_every_iters = 2 +# save_every_iters = 1000 + +load_from = None +resume = False + +# ------------------------------------------------------------------ +model_name = "sapiens2_0.4b" +embed_dim = 1024 +num_layers = 24 +num_heads = 16 + +layer_decay_rate = 0.8 +pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_0.4b_pretrain.safetensors" + +##----------------------------------------------------------------- +image_size = (1024, 768) ## height x width +patch_size = 16 + +sigma = 6 ## sigma is 2 for 256 +scale = 4 +num_keypoints = 308 + +# ------------------------------------------------------------------ +use_fsdp = True +# use_fsdp = False + +use_compile = True +# use_compile = False + +## DDP config +if use_fsdp is False: + accelerator_cfg = dict( + type="DDP", + log_with="tensorboard", + # find_unused_parameters=True, + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + # mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, ## schedule independent of n_gpus + ) + +else: + accelerator_cfg = dict( + type="FSDP", + log_with="tensorboard", + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. 
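+        # fsdp_cfg below pins the per-parameter and all-reduce dtypes for sharding; with
+        # SHARDED_STATE_DICT, checkpoints are written as per-rank shards rather than one file.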
+ step_scheduler_with_optimizer=False, + fsdp_cfg=dict( + fsdp_version=2, # DTensor-based engine + state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT + mixed_precision=dict( + param_dtype="bf16", + reduce_dtype="bf16", + ), + cpu_ram_efficient_loading=False, + ), + ) + +if use_compile: + accelerator_cfg["compile_cfg"] = dict( + backend="inductor", + mode="default", # Options: "default", "reduce-overhead", "max-autotune" + fullgraph=False, + dynamic=False, + ) + +# ------------------------------------------------------------------ +randomness = dict(seed=0, deterministic=False, diff_rank_seed=True) +logger = dict( + type="Logger", + log_interval=log_every_iters, +) +checkpoint = dict( + type="Checkpointer", + save_interval=save_every_iters, +) + +visualizer = dict( + type="PoseVisualizer", + vis_interval=vis_every_iters, + vis_max_samples=4, + vis_image_width=384, + vis_image_height=512, + num_keypoints=num_keypoints, +) + + +##----------------------------------------------------------------- +codec = dict( + type="UDPHeatmap", + input_size=(image_size[1], image_size[0]), ## width x height + heatmap_size=(int(image_size[1] / scale), int(image_size[0] / scale)), + sigma=sigma, +) ## sigma is 2 for 256 + +train_pipeline = [ + dict(type="PoseGetBBoxCenterScale"), + dict(type="PoseRandomFlip", direction="horizontal"), ## default prob is 0.5 + dict(type="PoseRandomHalfBody"), + dict(type="PoseRandomBBoxTransform"), + dict(type="PoseTopdownAffine", input_size=codec["input_size"], use_udp=True), + dict(type="RandomPhotoMetricDistortion", prob=0.8), + dict( + type="PoseAlbumentation", + transforms=[ + dict(type="Blur", p=0.1), + dict(type="MedianBlur", p=0.1), + dict( + type="CoarseDropout", + max_holes=1, + max_height=0.4, + max_width=0.4, + min_holes=1, + min_height=0.2, + min_width=0.2, + p=1.0, + ), + ], + ), + dict(type="PoseGenerateTarget", encoder=codec), + dict(type="PosePackInputs"), +] + +val_pipeline = [ + dict(type="PoseGetBBoxCenterScale"), + dict(type="PoseTopdownAffine", input_size=codec["input_size"], use_udp=True), + dict(type="PosePackInputs"), +] + +test_pipeline = [ + dict(type="PoseGetBBoxCenterScale"), + dict(type="PoseTopdownAffine", input_size=codec["input_size"], use_udp=True), + dict(type="PosePackInputs"), +] + +##------------------------------------------------------------------------ +dataset_shutterstock_train = dict( + type="Keypoints308ShutterstockDataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_102866/itw_shutterstock_body_keypoint_344_train:2025070300.json", +) + +dataset_goliath_train = dict( + type="Keypoints308GoliathDataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/sociopticon_body_keypoint_344_train:2024093001.json", + subsample_factor=8, +) + +dataset_3po_train = dict( + type="Keypoints308_3PODataset", + ann_file=f"{_DATA_ROOT}/indices/3po/train.json", + subsample_factor=2, +) + +# train_datasets = [dataset_shutterstock_train] +# train_datasets = [dataset_goliath_train] +# train_datasets = [dataset_3po_train] +train_datasets = ( + [dataset_goliath_train] + 2 * [dataset_shutterstock_train] + [dataset_3po_train] +) + +train_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + shuffle=True, + dataset=dict( + type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline + ), +) + +# ------------------------------------------------------------------------------ +dataset_shutterstock_val = dict( + type="Keypoints308ShutterstockEvalDataset", + 
data_root=f"{_DATA_ROOT}/pose/data/shutterstock/test/images", + ann_file=f"{_DATA_ROOT}/pose/data/shutterstock/test/annotations/person_keypoints_test2025_1k.json", + test_mode=True, + pipeline=val_pipeline, +) + +val_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + multiprocessing_context="spawn", ## avoids fork error with airstore + shuffle=False, + dataset=dataset_shutterstock_val, + collate_fn=dict(type="eval_collate"), +) + +val_cfg = dict( + val_interval=val_every_iters, + flip_test=True, ## left right flip + evaluator=dict( + type="Keypoints308Evaluator", + decoder=codec, + ann_file=f"{_DATA_ROOT}/pose/data/shutterstock/test/annotations/person_keypoints_test2025_1k.json", + ), +) + +# dataset_goliath_val = dict( +# type="Keypoints308GoliathEvalDataset", +# data_root=f"{_DATA_ROOT}/pose/data/goliath/test_10000/images", +# ann_file=f"{_DATA_ROOT}/pose/data/goliath/test_10000/annotations/person_keypoints_test2023.json", +# test_mode=True, +# # num_samples=10, ## debug +# pipeline=val_pipeline, +# ) + +# val_dataloader = dict( +# batch_size=4, +# num_workers=4, +# persistent_workers=True, +# multiprocessing_context="spawn", ## avoids fork error with airstore +# # num_workers=0, # debug +# # persistent_workers=False, # debug +# shuffle=False, +# dataset=dataset_goliath_val, +# collate_fn=dict(type="eval_collate"), +# ) + +# val_cfg = dict( +# val_interval=val_every_iters, +# flip_test=True, ## left right flip +# evaluator=dict( +# type="Keypoints308Evaluator", +# decoder=codec, +# ann_file=f"{_DATA_ROOT}/pose/data/goliath/test_10000/annotations/person_keypoints_test2023.json", +# ), +# ) + +data_preprocessor = dict( + type="ImagePreprocessor", + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models +) + +##----------------------------------------------------------------- +model = dict( + type="PoseTopdownEstimator", + backbone=dict( + type="Sapiens2", + arch=model_name, + img_size=image_size, + patch_size=patch_size, + final_norm=True, + use_tokenizer=False, + with_cls_token=True, + out_type="featmap", + init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint), + ), + decode_head=dict( + type="PoseHeatmapHead", + in_channels=embed_dim, + out_channels=num_keypoints, + deconv_out_channels=(1024, 768), ## this will 2x at each step. so total is 4x + deconv_kernel_sizes=(4, 4), + conv_out_channels=(512, 512, 256), + conv_kernel_sizes=(1, 1, 1), + loss_decode=dict( + type="KeypointMSELoss", use_target_weight=True, loss_weight=10.0 + ), + # loss_decode=dict(type='KeypointOHKMMSELoss', use_target_weight=True, topk=128), ## loss only for top 128 keypoints. for finetuning later. 
+ ), +) + + +##----------------------------------------------------------------- +optimizer = dict( + type="AdamW", + lr=5e-4, + betas=(0.9, 0.999), + weight_decay=0.1, + paramwise_cfg=dict( + num_layers=num_layers, + layer_decay_rate=layer_decay_rate, + ), + fused=True, +) + +scheduler = dict( + type="SequentialLR", + milestones=[warmup_iters], + schedulers=[ + dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters), + dict( + type="PolynomialLR", + total_iters=num_iters - warmup_iters, + power=1.0, + ), + ], +) + +clip_grad = dict(mode="norm", max_norm=2.0, norm_type=2.0) + +runner_type = "PoseRunner" diff --git a/sapiens/pose/configs/keypoints308/shutterstock_goliath_3po/sapiens2_0.8b_keypoints308_shutterstock_goliath_3po-1024x768.py b/sapiens/pose/configs/keypoints308/shutterstock_goliath_3po/sapiens2_0.8b_keypoints308_shutterstock_goliath_3po-1024x768.py new file mode 100644 index 0000000000000000000000000000000000000000..b08eab49596567caf683e8e792bed7e613cd35d8 --- /dev/null +++ b/sapiens/pose/configs/keypoints308/shutterstock_goliath_3po/sapiens2_0.8b_keypoints308_shutterstock_goliath_3po-1024x768.py @@ -0,0 +1,325 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + +_CHECKPOINT_ROOT = os.path.expanduser( + os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host") +) +_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data")) + +warmup_iters = 500 +# num_iters = 2e4 +num_iters = 1e4 + +# ------------------------------------------------------------------------------ +vis_every_iters = 100 +log_every_iters = 10 +save_every_iters = 1000 +val_every_iters = 1000 + +# # # # debug +# vis_every_iters = 1 +# log_every_iters = 1 +# val_every_iters = 2 +# save_every_iters = 1000 + +load_from = None +resume = False + +# ------------------------------------------------------------------ +model_name = "sapiens2_0.8b" +embed_dim = 1280 +num_layers = 32 +num_heads = 16 + +layer_decay_rate = 0.85 +pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_0.8b_pretrain.safetensors" + +##----------------------------------------------------------------- +image_size = (1024, 768) ## height x width +patch_size = 16 + +sigma = 6 ## sigma is 2 for 256 +scale = 4 +num_keypoints = 308 + +# ------------------------------------------------------------------ +use_fsdp = True +# use_fsdp = False + +use_compile = True +# use_compile = False + +## DDP config +if use_fsdp is False: + accelerator_cfg = dict( + type="DDP", + log_with="tensorboard", + # find_unused_parameters=True, + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + # mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, ## schedule independent of n_gpus + ) + +else: + accelerator_cfg = dict( + type="FSDP", + log_with="tensorboard", + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. 
+ step_scheduler_with_optimizer=False, + fsdp_cfg=dict( + fsdp_version=2, # DTensor-based engine + state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT + mixed_precision=dict( + param_dtype="bf16", + reduce_dtype="bf16", + ), + cpu_ram_efficient_loading=False, + ), + ) + +if use_compile: + accelerator_cfg["compile_cfg"] = dict( + backend="inductor", + mode="default", # Options: "default", "reduce-overhead", "max-autotune" + fullgraph=False, + dynamic=False, + ) + +# ------------------------------------------------------------------ +randomness = dict(seed=0, deterministic=False, diff_rank_seed=True) +logger = dict( + type="Logger", + log_interval=log_every_iters, +) +checkpoint = dict( + type="Checkpointer", + save_interval=save_every_iters, +) + +visualizer = dict( + type="PoseVisualizer", + vis_interval=vis_every_iters, + vis_max_samples=4, + vis_image_width=384, + vis_image_height=512, + num_keypoints=num_keypoints, +) + + +##----------------------------------------------------------------- +codec = dict( + type="UDPHeatmap", + input_size=(image_size[1], image_size[0]), ## width x height + heatmap_size=(int(image_size[1] / scale), int(image_size[0] / scale)), + sigma=sigma, +) ## sigma is 2 for 256 + +train_pipeline = [ + dict(type="PoseGetBBoxCenterScale"), + dict(type="PoseRandomFlip", direction="horizontal"), ## default prob is 0.5 + dict(type="PoseRandomHalfBody"), + dict(type="PoseRandomBBoxTransform"), + dict(type="PoseTopdownAffine", input_size=codec["input_size"], use_udp=True), + dict(type="RandomPhotoMetricDistortion", prob=0.8), + dict( + type="PoseAlbumentation", + transforms=[ + dict(type="Blur", p=0.1), + dict(type="MedianBlur", p=0.1), + dict( + type="CoarseDropout", + max_holes=1, + max_height=0.4, + max_width=0.4, + min_holes=1, + min_height=0.2, + min_width=0.2, + p=1.0, + ), + ], + ), + dict(type="PoseGenerateTarget", encoder=codec), + dict(type="PosePackInputs"), +] + +val_pipeline = [ + dict(type="PoseGetBBoxCenterScale"), + dict(type="PoseTopdownAffine", input_size=codec["input_size"], use_udp=True), + dict(type="PosePackInputs"), +] + +test_pipeline = [ + dict(type="PoseGetBBoxCenterScale"), + dict(type="PoseTopdownAffine", input_size=codec["input_size"], use_udp=True), + dict(type="PosePackInputs"), +] + +##------------------------------------------------------------------------ +dataset_shutterstock_train = dict( + type="Keypoints308ShutterstockDataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_102866/itw_shutterstock_body_keypoint_344_train:2025070300.json", +) + +dataset_goliath_train = dict( + type="Keypoints308GoliathDataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/sociopticon_body_keypoint_344_train:2024093001.json", + subsample_factor=8, +) + +dataset_3po_train = dict( + type="Keypoints308_3PODataset", + ann_file=f"{_DATA_ROOT}/indices/3po/train.json", + subsample_factor=2, +) + +# train_datasets = [dataset_shutterstock_train] +# train_datasets = [dataset_goliath_train] +# train_datasets = [dataset_3po_train] +train_datasets = ( + [dataset_goliath_train] + 2 * [dataset_shutterstock_train] + [dataset_3po_train] +) + +train_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + shuffle=True, + dataset=dict( + type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline + ), +) + +# ------------------------------------------------------------------------------ +dataset_shutterstock_val = dict( + type="Keypoints308ShutterstockEvalDataset", + 
data_root=f"{_DATA_ROOT}/pose/data/shutterstock/test/images", + ann_file=f"{_DATA_ROOT}/pose/data/shutterstock/test/annotations/person_keypoints_test2025_1k.json", + test_mode=True, + pipeline=val_pipeline, +) + +val_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + multiprocessing_context="spawn", ## avoids fork error with airstore + shuffle=False, + dataset=dataset_shutterstock_val, + collate_fn=dict(type="eval_collate"), +) + +val_cfg = dict( + val_interval=val_every_iters, + flip_test=True, ## left right flip + evaluator=dict( + type="Keypoints308Evaluator", + decoder=codec, + ann_file=f"{_DATA_ROOT}/pose/data/shutterstock/test/annotations/person_keypoints_test2025_1k.json", + ), +) + +# dataset_goliath_val = dict( +# type="Keypoints308GoliathEvalDataset", +# data_root=f"{_DATA_ROOT}/pose/data/goliath/test_10000/images", +# ann_file=f"{_DATA_ROOT}/pose/data/goliath/test_10000/annotations/person_keypoints_test2023.json", +# test_mode=True, +# # num_samples=10, ## debug +# pipeline=val_pipeline, +# ) + +# val_dataloader = dict( +# batch_size=4, +# num_workers=4, +# persistent_workers=True, +# multiprocessing_context="spawn", ## avoids fork error with airstore +# # num_workers=0, # debug +# # persistent_workers=False, # debug +# shuffle=False, +# dataset=dataset_goliath_val, +# collate_fn=dict(type="eval_collate"), +# ) + +# val_cfg = dict( +# val_interval=val_every_iters, +# flip_test=True, ## left right flip +# evaluator=dict( +# type="Keypoints308Evaluator", +# decoder=codec, +# ann_file=f"{_DATA_ROOT}/pose/data/goliath/test_10000/annotations/person_keypoints_test2023.json", +# ), +# ) + +data_preprocessor = dict( + type="ImagePreprocessor", + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models +) + +##----------------------------------------------------------------- +model = dict( + type="PoseTopdownEstimator", + backbone=dict( + type="Sapiens2", + arch=model_name, + img_size=image_size, + patch_size=patch_size, + final_norm=True, + use_tokenizer=False, + with_cls_token=True, + out_type="featmap", + init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint), + ), + decode_head=dict( + type="PoseHeatmapHead", + in_channels=embed_dim, + out_channels=num_keypoints, + deconv_out_channels=(1024, 768), ## this will 2x at each step. so total is 4x + deconv_kernel_sizes=(4, 4), + conv_out_channels=(512, 512, 256), + conv_kernel_sizes=(1, 1, 1), + loss_decode=dict( + type="KeypointMSELoss", use_target_weight=True, loss_weight=10.0 + ), + # loss_decode=dict(type='KeypointOHKMMSELoss', use_target_weight=True, topk=128), ## loss only for top 128 keypoints. for finetuning later. 
+ ), +) + + +##----------------------------------------------------------------- +optimizer = dict( + type="AdamW", + lr=5e-4, + betas=(0.9, 0.999), + weight_decay=0.1, + paramwise_cfg=dict( + num_layers=num_layers, + layer_decay_rate=layer_decay_rate, + ), + fused=True, +) + +scheduler = dict( + type="SequentialLR", + milestones=[warmup_iters], + schedulers=[ + dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters), + dict( + type="PolynomialLR", + total_iters=num_iters - warmup_iters, + power=1.0, + ), + ], +) + +clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0) + +runner_type = "PoseRunner" diff --git a/sapiens/pose/configs/keypoints308/shutterstock_goliath_3po/sapiens2_1b_keypoints308_shutterstock_goliath_3po-1024x768.py b/sapiens/pose/configs/keypoints308/shutterstock_goliath_3po/sapiens2_1b_keypoints308_shutterstock_goliath_3po-1024x768.py new file mode 100644 index 0000000000000000000000000000000000000000..6b93c40b283cd77fe1da735720d3fa02adabd7de --- /dev/null +++ b/sapiens/pose/configs/keypoints308/shutterstock_goliath_3po/sapiens2_1b_keypoints308_shutterstock_goliath_3po-1024x768.py @@ -0,0 +1,328 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + +_CHECKPOINT_ROOT = os.path.expanduser( + os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host") +) +_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data")) + +warmup_iters = 500 +# num_iters = 4e4 +num_iters = 2e4 + +# ------------------------------------------------------------------------------ +vis_every_iters = 100 +log_every_iters = 10 +save_every_iters = 1000 +val_every_iters = 1000 + +# # # # debug +# vis_every_iters = 1 +# log_every_iters = 1 +# val_every_iters = 2 +# save_every_iters = 1000 + +load_from = None +resume = False + +# ------------------------------------------------------------------ +model_name = "sapiens2_1b" +embed_dim = 1536 +num_layers = 40 +num_heads = 24 + +layer_decay_rate = 0.9 +pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_1b_pretrain.safetensors" + +##----------------------------------------------------------------- +image_size = (1024, 768) ## height x width +patch_size = 16 + +sigma = 6 ## sigma is 2 for 256 +scale = 4 +num_keypoints = 308 + +# ------------------------------------------------------------------ +use_fsdp = True +# use_fsdp = False + +use_compile = True +# use_compile = False + +## DDP config +if use_fsdp is False: + accelerator_cfg = dict( + type="DDP", + log_with="tensorboard", + # find_unused_parameters=True, + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, ## schedule independent of n_gpus + ) + +else: + accelerator_cfg = dict( + type="FSDP", + log_with="tensorboard", + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. 
+ step_scheduler_with_optimizer=False, + fsdp_cfg=dict( + fsdp_version=2, # DTensor-based engine + state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT + mixed_precision=dict( + param_dtype="bf16", + reduce_dtype="bf16", + ), + cpu_ram_efficient_loading=False, + ), + ) + +if use_compile: + accelerator_cfg["compile_cfg"] = dict( + backend="inductor", + mode="default", # Options: "default", "reduce-overhead", "max-autotune" + fullgraph=False, + dynamic=False, + ) + +# ------------------------------------------------------------------ +randomness = dict(seed=0, deterministic=False, diff_rank_seed=True) +logger = dict( + type="Logger", + log_interval=log_every_iters, +) +checkpoint = dict( + type="Checkpointer", + save_interval=save_every_iters, +) + +visualizer = dict( + type="PoseVisualizer", + vis_interval=vis_every_iters, + vis_max_samples=4, + vis_image_width=384, + vis_image_height=512, + num_keypoints=num_keypoints, +) + + +##----------------------------------------------------------------- +codec = dict( + type="UDPHeatmap", + input_size=(image_size[1], image_size[0]), ## width x height + heatmap_size=(int(image_size[1] / scale), int(image_size[0] / scale)), + sigma=sigma, +) ## sigma is 2 for 256 + +train_pipeline = [ + dict(type="PoseGetBBoxCenterScale"), + dict(type="PoseRandomFlip", direction="horizontal"), ## default prob is 0.5 + dict(type="PoseRandomHalfBody"), + dict(type="PoseRandomBBoxTransform"), + dict(type="PoseTopdownAffine", input_size=codec["input_size"], use_udp=True), + dict(type="RandomPhotoMetricDistortion", prob=0.8), + dict( + type="PoseAlbumentation", + transforms=[ + dict(type="Blur", p=0.1), + dict(type="MedianBlur", p=0.1), + dict( + type="CoarseDropout", + max_holes=1, + max_height=0.4, + max_width=0.4, + min_holes=1, + min_height=0.2, + min_width=0.2, + p=1.0, + ), + ], + ), + dict(type="PoseGenerateTarget", encoder=codec), + dict(type="PosePackInputs"), +] + +val_pipeline = [ + dict(type="PoseGetBBoxCenterScale"), + dict(type="PoseTopdownAffine", input_size=codec["input_size"], use_udp=True), + dict(type="PosePackInputs"), +] + +test_pipeline = [ + dict(type="PoseGetBBoxCenterScale"), + dict(type="PoseTopdownAffine", input_size=codec["input_size"], use_udp=True), + dict(type="PosePackInputs"), +] + +##------------------------------------------------------------------------ +dataset_shutterstock_train = dict( + type="Keypoints308ShutterstockDataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_102866/itw_shutterstock_body_keypoint_344_train:2025070300.json", +) + +dataset_goliath_train = dict( + type="Keypoints308GoliathDataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/sociopticon_body_keypoint_344_train:2024093001.json", + subsample_factor=8, +) + +dataset_3po_train = dict( + type="Keypoints308_3PODataset", + ann_file=f"{_DATA_ROOT}/indices/3po/train.json", + subsample_factor=2, +) + +# train_datasets = [dataset_shutterstock_train] +# train_datasets = [dataset_goliath_train] +# train_datasets = [dataset_3po_train] +train_datasets = ( + [dataset_goliath_train] + 2 * [dataset_shutterstock_train] + [dataset_3po_train] +) + +train_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + shuffle=True, + dataset=dict( + type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline + ), +) + +# ------------------------------------------------------------------------------ +dataset_shutterstock_val = dict( + type="Keypoints308ShutterstockEvalDataset", + 
data_root=f"{_DATA_ROOT}/pose/data/shutterstock/test/images", + ann_file=f"{_DATA_ROOT}/pose/data/shutterstock/test/annotations/person_keypoints_test2025_1k.json", + test_mode=True, + # num_samples=10, ## debug + pipeline=val_pipeline, +) + +val_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + multiprocessing_context="spawn", ## avoids fork error with airstore + # num_workers=0, # debug + # persistent_workers=False, # debug + shuffle=False, + dataset=dataset_shutterstock_val, + collate_fn=dict(type="eval_collate"), +) + +val_cfg = dict( + val_interval=val_every_iters, + flip_test=True, ## left right flip + evaluator=dict( + type="Keypoints308Evaluator", + decoder=codec, + ann_file=f"{_DATA_ROOT}/pose/data/shutterstock/test/annotations/person_keypoints_test2025_1k.json", + ), +) + +# dataset_goliath_val = dict( +# type="Keypoints308GoliathEvalDataset", +# data_root=f"{_DATA_ROOT}/pose/data/goliath/test_10000/images", +# ann_file=f"{_DATA_ROOT}/pose/data/goliath/test_10000/annotations/person_keypoints_test2023.json", +# test_mode=True, +# # num_samples=10, ## debug +# pipeline=val_pipeline, +# ) + +# val_dataloader = dict( +# batch_size=4, +# num_workers=4, +# persistent_workers=True, +# multiprocessing_context="spawn", ## avoids fork error with airstore +# # num_workers=0, # debug +# # persistent_workers=False, # debug +# shuffle=False, +# dataset=dataset_goliath_val, +# collate_fn=dict(type="eval_collate"), +# ) + +# val_cfg = dict( +# val_interval=val_every_iters, +# flip_test=True, ## left right flip +# evaluator=dict( +# type="Keypoints308Evaluator", +# decoder=codec, +# ann_file=f"{_DATA_ROOT}/pose/data/goliath/test_10000/annotations/person_keypoints_test2023.json", +# ), +# ) + +data_preprocessor = dict( + type="ImagePreprocessor", + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models +) + +##----------------------------------------------------------------- +model = dict( + type="PoseTopdownEstimator", + backbone=dict( + type="Sapiens2", + arch=model_name, + img_size=image_size, + patch_size=patch_size, + final_norm=True, + use_tokenizer=False, + with_cls_token=True, + out_type="featmap", + init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint), + ), + decode_head=dict( + type="PoseHeatmapHead", + in_channels=embed_dim, + out_channels=num_keypoints, + deconv_out_channels=(1536, 1024), ## this will 2x at each step. so total is 4x + deconv_kernel_sizes=(4, 4), + conv_out_channels=(768, 512, 256), + conv_kernel_sizes=(1, 1, 1), + loss_decode=dict( + type="KeypointMSELoss", use_target_weight=True, loss_weight=10.0 + ), + # loss_decode=dict(type='KeypointOHKMMSELoss', use_target_weight=True, topk=128), ## loss only for top 128 keypoints. for finetuning later. 
+ ), +) + + +##----------------------------------------------------------------- +optimizer = dict( + type="AdamW", + lr=5e-4, + betas=(0.9, 0.999), + weight_decay=0.1, + paramwise_cfg=dict( + num_layers=num_layers, + layer_decay_rate=layer_decay_rate, + ), + fused=True, +) + +scheduler = dict( + type="SequentialLR", + milestones=[warmup_iters], + schedulers=[ + dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters), + dict( + type="PolynomialLR", + total_iters=num_iters - warmup_iters, + power=1.0, + ), + ], +) + +clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0) + +runner_type = "PoseRunner" diff --git a/sapiens/pose/configs/keypoints308/shutterstock_goliath_3po/sapiens2_5b_keypoints308_shutterstock_goliath_3po-1024x768.py b/sapiens/pose/configs/keypoints308/shutterstock_goliath_3po/sapiens2_5b_keypoints308_shutterstock_goliath_3po-1024x768.py new file mode 100644 index 0000000000000000000000000000000000000000..3e12721c79d54c9cb2408a505f22a1c88b3fa38f --- /dev/null +++ b/sapiens/pose/configs/keypoints308/shutterstock_goliath_3po/sapiens2_5b_keypoints308_shutterstock_goliath_3po-1024x768.py @@ -0,0 +1,326 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + +_CHECKPOINT_ROOT = os.path.expanduser( + os.environ.get("SAPIENS_CHECKPOINT_ROOT", "~/sapiens2_host") +) +_DATA_ROOT = os.path.expanduser(os.environ.get("DATA_ROOT", "~/sapiens_data")) + +warmup_iters = 500 +# num_iters = 4e4 +num_iters = 2e4 ## light finetune + +# ------------------------------------------------------------------------------ +vis_every_iters = 100 +log_every_iters = 10 +save_every_iters = 1000 +val_every_iters = 1000 + +# # # # debug +# vis_every_iters = 1 +# log_every_iters = 1 +# val_every_iters = 2 +# save_every_iters = 1000 + +load_from = None +resume = False + +# ------------------------------------------------------------------ +model_name = "sapiens2_5b" +embed_dim = 2432 +num_layers = 56 +num_heads = 32 +layer_decay_rate = 0.94 + +pretrained_checkpoint = f"{_CHECKPOINT_ROOT}/pretrain/sapiens2_5b_pretrain.safetensors" + +##----------------------------------------------------------------- +image_size = (1024, 768) ## height x width +patch_size = 16 + +sigma = 6 ## sigma is 2 for 256 +scale = 4 +num_keypoints = 308 + +# ------------------------------------------------------------------ +use_fsdp = True +# use_fsdp = False + +use_compile = True +# use_compile = False + +## DDP config +if use_fsdp is False: + accelerator_cfg = dict( + type="DDP", + log_with="tensorboard", + # find_unused_parameters=True, + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + # mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. + step_scheduler_with_optimizer=False, ## schedule independent of n_gpus + ) + +else: + accelerator_cfg = dict( + type="FSDP", + log_with="tensorboard", + gradient_accumulation_steps=1, # only accumulation=1 is supported. Otherwise, the LR scheduler will be off. + max_interval=num_iters, + mixed_precision="bf16", # Options: โ€˜noโ€™,โ€˜fp16โ€™,โ€˜bf16โ€™ or โ€˜fp8โ€™. 
+ step_scheduler_with_optimizer=False, + fsdp_cfg=dict( + fsdp_version=2, # DTensor-based engine + state_dict_type="SHARDED_STATE_DICT", # SHARDED_STATE_DICT | FULL_STATE_DICT + mixed_precision=dict( + param_dtype="bf16", + reduce_dtype="bf16", + ), + cpu_ram_efficient_loading=False, + ), + ) + +if use_compile: + accelerator_cfg["compile_cfg"] = dict( + backend="inductor", + mode="default", # Options: "default", "reduce-overhead", "max-autotune" + fullgraph=False, + dynamic=False, + ) + +# ------------------------------------------------------------------ +randomness = dict(seed=0, deterministic=False, diff_rank_seed=True) +logger = dict( + type="Logger", + log_interval=log_every_iters, +) +checkpoint = dict( + type="Checkpointer", + save_interval=save_every_iters, +) + +visualizer = dict( + type="PoseVisualizer", + vis_interval=vis_every_iters, + vis_max_samples=4, + vis_image_width=384, + vis_image_height=512, + num_keypoints=num_keypoints, +) + + +##----------------------------------------------------------------- +codec = dict( + type="UDPHeatmap", + input_size=(image_size[1], image_size[0]), ## width x height + heatmap_size=(int(image_size[1] / scale), int(image_size[0] / scale)), + sigma=sigma, +) ## sigma is 2 for 256 + +train_pipeline = [ + dict(type="PoseGetBBoxCenterScale"), + dict(type="PoseRandomFlip", direction="horizontal"), ## default prob is 0.5 + dict(type="PoseRandomHalfBody"), + dict(type="PoseRandomBBoxTransform"), + dict(type="PoseTopdownAffine", input_size=codec["input_size"], use_udp=True), + dict(type="RandomPhotoMetricDistortion", prob=0.8), + dict( + type="PoseAlbumentation", + transforms=[ + dict(type="Blur", p=0.1), + dict(type="MedianBlur", p=0.1), + dict( + type="CoarseDropout", + max_holes=1, + max_height=0.4, + max_width=0.4, + min_holes=1, + min_height=0.2, + min_width=0.2, + p=1.0, + ), + ], + ), + dict(type="PoseGenerateTarget", encoder=codec), + dict(type="PosePackInputs"), +] + +val_pipeline = [ + dict(type="PoseGetBBoxCenterScale"), + dict(type="PoseTopdownAffine", input_size=codec["input_size"], use_udp=True), + dict(type="PosePackInputs"), +] + +test_pipeline = [ + dict(type="PoseGetBBoxCenterScale"), + dict(type="PoseTopdownAffine", input_size=codec["input_size"], use_udp=True), + dict(type="PosePackInputs"), +] + +##------------------------------------------------------------------------ +dataset_shutterstock_train = dict( + type="Keypoints308ShutterstockDataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_102866/itw_shutterstock_body_keypoint_344_train:2025070300.json", +) + +dataset_goliath_train = dict( + type="Keypoints308GoliathDataset", + ann_file=f"{_DATA_ROOT}/annotations/ingestion_90942/sociopticon_body_keypoint_344_train:2024093001.json", + subsample_factor=8, +) + +dataset_3po_train = dict( + type="Keypoints308_3PODataset", + ann_file=f"{_DATA_ROOT}/indices/3po/train.json", + subsample_factor=2, +) + +# train_datasets = [dataset_shutterstock_train] +# train_datasets = [dataset_goliath_train] +# train_datasets = [dataset_3po_train] +train_datasets = ( + [dataset_goliath_train] + 2 * [dataset_shutterstock_train] + [dataset_3po_train] +) + +train_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + shuffle=True, + dataset=dict( + type="CombinedDataset", datasets=train_datasets, pipeline=train_pipeline + ), +) + +# ------------------------------------------------------------------------------ +dataset_shutterstock_val = dict( + type="Keypoints308ShutterstockEvalDataset", + 
data_root=f"{_DATA_ROOT}/pose/data/shutterstock/test/images", + ann_file=f"{_DATA_ROOT}/pose/data/shutterstock/test/annotations/person_keypoints_test2025_1k.json", + test_mode=True, + pipeline=val_pipeline, +) + +val_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + multiprocessing_context="spawn", ## avoids fork error with airstore + shuffle=False, + dataset=dataset_shutterstock_val, + collate_fn=dict(type="eval_collate"), +) + +val_cfg = dict( + val_interval=val_every_iters, + flip_test=True, ## left right flip + evaluator=dict( + type="Keypoints308Evaluator", + decoder=codec, + ann_file=f"{_DATA_ROOT}/pose/data/shutterstock/test/annotations/person_keypoints_test2025_1k.json", + ), +) + +# dataset_goliath_val = dict( +# type="Keypoints308GoliathEvalDataset", +# data_root=f"{_DATA_ROOT}/pose/data/goliath/test_10000/images", +# ann_file=f"{_DATA_ROOT}/pose/data/goliath/test_10000/annotations/person_keypoints_test2023.json", +# test_mode=True, +# # num_samples=10, ## debug +# pipeline=val_pipeline, +# ) + +# val_dataloader = dict( +# batch_size=4, +# num_workers=4, +# persistent_workers=True, +# multiprocessing_context="spawn", ## avoids fork error with airstore +# # num_workers=0, # debug +# # persistent_workers=False, # debug +# shuffle=False, +# dataset=dataset_goliath_val, +# collate_fn=dict(type="eval_collate"), +# ) + +# val_cfg = dict( +# val_interval=val_every_iters, +# flip_test=True, ## left right flip +# evaluator=dict( +# type="Keypoints308Evaluator", +# decoder=codec, +# ann_file=f"{_DATA_ROOT}/pose/data/goliath/test_10000/annotations/person_keypoints_test2023.json", +# ), +# ) + +data_preprocessor = dict( + type="ImagePreprocessor", + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, ## convert from bgr to rgb for pretrained models +) + +##----------------------------------------------------------------- +model = dict( + type="PoseTopdownEstimator", + backbone=dict( + type="Sapiens2", + arch=model_name, + img_size=image_size, + patch_size=patch_size, + final_norm=True, + use_tokenizer=False, + with_cls_token=True, + out_type="featmap", + init_cfg=dict(type="Pretrained", checkpoint=pretrained_checkpoint), + ), + decode_head=dict( + type="PoseHeatmapHead", + in_channels=embed_dim, + out_channels=num_keypoints, + deconv_out_channels=(1024, 768), ## this will 2x at each step. so total is 4x + deconv_kernel_sizes=(4, 4), + conv_out_channels=(512, 512, 256), + conv_kernel_sizes=(1, 1, 1), + loss_decode=dict( + type="KeypointMSELoss", use_target_weight=True, loss_weight=10.0 + ), + # loss_decode=dict(type='KeypointOHKMMSELoss', use_target_weight=True, topk=128), ## loss only for top 128 keypoints. for finetuning later. 
+ ), +) + + +##----------------------------------------------------------------- +optimizer = dict( + type="AdamW", + # lr=5e-4, + lr=1e-4, + betas=(0.9, 0.999), + weight_decay=0.1, + paramwise_cfg=dict( + num_layers=num_layers, + layer_decay_rate=layer_decay_rate, + ), + fused=True, +) + +scheduler = dict( + type="SequentialLR", + milestones=[warmup_iters], + schedulers=[ + dict(type="LinearLR", start_factor=1e-3, total_iters=warmup_iters), + dict( + type="PolynomialLR", + total_iters=num_iters - warmup_iters, + power=1.0, + ), + ], +) + +clip_grad = dict(mode="norm", max_norm=4.0, norm_type=2.0) + +runner_type = "PoseRunner" diff --git a/sapiens/pose/scripts/demo/keypoints308.sh b/sapiens/pose/scripts/demo/keypoints308.sh new file mode 100755 index 0000000000000000000000000000000000000000..551e46ea348ee3c2eda0c60c6c7977a0f50cb6d7 --- /dev/null +++ b/sapiens/pose/scripts/demo/keypoints308.sh @@ -0,0 +1,97 @@ +#!/bin/bash +# Run 308-keypoint pose estimation on a directory of images. + +cd "$(dirname "$(realpath "$0")")/../.." || exit +SAPIENS_CHECKPOINT_ROOT="${SAPIENS_CHECKPOINT_ROOT:-${HOME}/sapiens2_host}" + +#----------------------------set your input and output directories------------------------- +INPUT='./demo/data/itw_videos/reel1' +OUTPUT="${HOME}/Desktop/sapiens2/pose/Outputs/vis/itw_videos/reel1" + +#--------------------------MODEL CARD (uncomment one)--------------------------------------- +# MODEL_NAME='sapiens2_0.4b'; CHECKPOINT="${SAPIENS_CHECKPOINT_ROOT}/pose/sapiens2_0.4b_pose.safetensors" +# MODEL_NAME='sapiens2_0.8b'; CHECKPOINT="${SAPIENS_CHECKPOINT_ROOT}/pose/sapiens2_0.8b_pose.safetensors" +MODEL_NAME='sapiens2_1b'; CHECKPOINT="${SAPIENS_CHECKPOINT_ROOT}/pose/sapiens2_1b_pose.safetensors" +# MODEL_NAME='sapiens2_5b'; CHECKPOINT="${SAPIENS_CHECKPOINT_ROOT}/pose/sapiens2_5b_pose.safetensors" + +DATASET='shutterstock_goliath_3po' +MODEL="${MODEL_NAME}_keypoints308_${DATASET}-1024x768" +CONFIG_FILE="configs/keypoints308/${DATASET}/${MODEL}.py" +OUTPUT="${OUTPUT}/${MODEL_NAME}" + +# Person detector (for bbox) +DETECTION_CONFIG_FILE='tools/vis/rtmdet_m_640-8xb32_coco-person.py' +DETECTION_CHECKPOINT="${SAPIENS_CHECKPOINT_ROOT}/detector/rtmdet_m.pth" + +#---------------------------VISUALIZATION PARAMS-------------------------------------------- +LINE_THICKNESS=8 +RADIUS=8 +KPT_THRES=0.3 + +##-------------------------------------inference-------------------------------------------- +RUN_FILE='tools/vis/vis_pose.py' + +# Number of inference jobs per GPU and which GPUs to use +JOBS_PER_GPU=2; GPU_IDS=(0 1 2 3 4 5 6 7) +# JOBS_PER_GPU=1; GPU_IDS=(0) +TOTAL_JOBS=$((JOBS_PER_GPU * ${#GPU_IDS[@]})) + +# Find images and partition across jobs +IMAGE_LIST="${INPUT}/image_list.txt" +find "${INPUT}" -type f \( -iname \*.jpg -o -iname \*.jpeg -o -iname \*.png \) | sort > "${IMAGE_LIST}" + +if [ ! -s "${IMAGE_LIST}" ]; then + echo "No images found at ${INPUT}" + exit 1 +fi + +NUM_IMAGES=$(wc -l < "${IMAGE_LIST}") +IMAGES_PER_FILE=$((NUM_IMAGES / TOTAL_JOBS)) +EXTRA_IMAGES=$((NUM_IMAGES % TOTAL_JOBS)) + +export TF_CPP_MIN_LOG_LEVEL=2 +echo "Distributing ${NUM_IMAGES} image paths into ${TOTAL_JOBS} jobs." 
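+# Partitioning: each of the TOTAL_JOBS workers receives roughly IMAGES_PER_FILE paths, with the
+# EXTRA_IMAGES remainder spread over the first workers so the whole image list is covered.
+# Each chunk is written to its own text file and consumed by a separate inference process.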
+ +current_line=1 +for ((i=0; i "${TEXT_FILE}" + current_line=$((current_line + images_for_this_job)) + else + touch "${TEXT_FILE}" + fi +done + +# Launch parallel inference +for ((i=0; i None: + super().__init__() + self.input_size = input_size + self.heatmap_size = heatmap_size + self.sigma = sigma + self.radius_factor = radius_factor + self.heatmap_type = heatmap_type + self.blur_kernel_size = blur_kernel_size + self.scale_factor = ( + (np.array(input_size) - 1) / (np.array(heatmap_size) - 1) + ).astype(np.float32) + + if self.heatmap_type not in {"gaussian", "combined"}: + raise ValueError( + f"{self.__class__.__name__} got invalid `heatmap_type` value" + f"{self.heatmap_type}. Should be one of " + '{"gaussian", "combined"}' + ) + + def encode( + self, keypoints: np.ndarray, keypoints_visible: Optional[np.ndarray] = None + ) -> dict: + assert keypoints.shape[0] == 1, ( + f"{self.__class__.__name__} only support single-instance keypoint encoding" + ) + + if keypoints_visible is None: + keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32) + + if self.heatmap_type == "gaussian": + heatmaps, keypoint_weights = generate_udp_gaussian_heatmaps( + heatmap_size=self.heatmap_size, + keypoints=keypoints / self.scale_factor, + keypoints_visible=keypoints_visible, + sigma=self.sigma, + ) + elif self.heatmap_type == "combined": + heatmaps, keypoint_weights = generate_offset_heatmap( + heatmap_size=self.heatmap_size, + keypoints=keypoints / self.scale_factor, + keypoints_visible=keypoints_visible, + radius_factor=self.radius_factor, + ) + else: + raise ValueError( + f"{self.__class__.__name__} got invalid `heatmap_type` value" + f"{self.heatmap_type}. Should be one of " + '{"gaussian", "combined"}' + ) + + encoded = dict(heatmaps=heatmaps, keypoint_weights=keypoint_weights) + + return encoded + + def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + heatmaps = encoded.copy() + + if self.heatmap_type == "gaussian": + keypoints, scores = get_heatmap_maximum(heatmaps) + # unsqueeze the instance dimension for single-instance results + keypoints = keypoints[None] + scores = scores[None] + + keypoints = refine_keypoints_dark_udp( + keypoints, heatmaps, blur_kernel_size=self.blur_kernel_size + ) + + elif self.heatmap_type == "combined": + _K, H, W = heatmaps.shape + K = _K // 3 + + for cls_heatmap in heatmaps[::3]: + # Apply Gaussian blur on classification maps + ks = 2 * self.blur_kernel_size + 1 + cv2.GaussianBlur(cls_heatmap, (ks, ks), 0, cls_heatmap) + + # valid radius + radius = self.radius_factor * max(W, H) + + x_offset = heatmaps[1::3].flatten() * radius + y_offset = heatmaps[2::3].flatten() * radius + keypoints, scores = get_heatmap_maximum(heatmaps=heatmaps[::3]) + index = (keypoints[..., 0] + keypoints[..., 1] * W).flatten() + index += W * H * np.arange(0, K) + index = index.astype(int) + keypoints += np.stack((x_offset[index], y_offset[index]), axis=-1) + # unsqueeze the instance dimension for single-instance results + keypoints = keypoints[None].astype(np.float32) + scores = scores[None] + + W, H = self.heatmap_size + keypoints = keypoints / [W - 1, H - 1] * self.input_size + + return keypoints, scores diff --git a/sapiens/pose/src/datasets/codecs/utils/__init__.py b/sapiens/pose/src/datasets/codecs/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..763b5cee6b934d4d7f38b77a4648ad8d9fb2050e --- /dev/null +++ b/sapiens/pose/src/datasets/codecs/utils/__init__.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .gaussian_heatmap import ( + generate_gaussian_heatmaps, + generate_udp_gaussian_heatmaps, + generate_unbiased_gaussian_heatmaps, +) +from .instance_property import ( + get_diagonal_lengths, + get_instance_bbox, + get_instance_root, +) +from .offset_heatmap import generate_displacement_heatmap, generate_offset_heatmap +from .post_processing import ( + batch_heatmap_nms, + gaussian_blur, + gaussian_blur1d, + get_heatmap_maximum, + get_simcc_maximum, + get_simcc_normalized, +) +from .refinement import ( + refine_keypoints, + refine_keypoints_dark, + refine_keypoints_dark_udp, + refine_simcc_dark, +) + +__all__ = [ + "generate_gaussian_heatmaps", + "generate_udp_gaussian_heatmaps", + "generate_unbiased_gaussian_heatmaps", + "gaussian_blur", + "get_heatmap_maximum", + "get_simcc_maximum", + "generate_offset_heatmap", + "batch_heatmap_nms", + "refine_keypoints", + "refine_keypoints_dark", + "refine_keypoints_dark_udp", + "generate_displacement_heatmap", + "refine_simcc_dark", + "gaussian_blur1d", + "get_diagonal_lengths", + "get_instance_root", + "get_instance_bbox", + "get_simcc_normalized", +] diff --git a/sapiens/pose/src/datasets/codecs/utils/gaussian_heatmap.py b/sapiens/pose/src/datasets/codecs/utils/gaussian_heatmap.py new file mode 100644 index 0000000000000000000000000000000000000000..322cb739d261043743545cedd60521ef9b0df7c5 --- /dev/null +++ b/sapiens/pose/src/datasets/codecs/utils/gaussian_heatmap.py @@ -0,0 +1,231 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from itertools import product +from typing import Tuple, Union + +import numpy as np + + +def generate_gaussian_heatmaps( + heatmap_size: Tuple[int, int], + keypoints: np.ndarray, + keypoints_visible: np.ndarray, + sigma: Union[float, Tuple[float], np.ndarray], +) -> Tuple[np.ndarray, np.ndarray]: + """Generate gaussian heatmaps of keypoints. + + Args: + heatmap_size (Tuple[int, int]): Heatmap size in [W, H] + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + keypoints_visible (np.ndarray): Keypoint visibilities in shape + (N, K) + sigma (float or List[float]): A list of sigma values of the Gaussian + heatmap for each instance. 
If sigma is given as a single float + value, it will be expanded into a tuple + + Returns: + tuple: + - heatmaps (np.ndarray): The generated heatmap in shape + (K, H, W) where [W, H] is the `heatmap_size` + - keypoint_weights (np.ndarray): The target weights in shape + (N, K) + """ + + N, K, _ = keypoints.shape + W, H = heatmap_size + + heatmaps = np.zeros((K, H, W), dtype=np.float32) + keypoint_weights = keypoints_visible.copy() + + if isinstance(sigma, (int, float)): + sigma = (sigma,) * N + + for n in range(N): + # 3-sigma rule + radius = sigma[n] * 3 + + # xy grid + gaussian_size = 2 * radius + 1 + x = np.arange(0, gaussian_size, 1, dtype=np.float32) + y = x[:, None] + x0 = y0 = gaussian_size // 2 + + for k in range(K): + # skip unlabled keypoints + if keypoints_visible[n, k] < 0.5: + continue + + # get gaussian center coordinates + mu = (keypoints[n, k] + 0.5).astype(np.int64) + + # check that the gaussian has in-bounds part + left, top = (mu - radius).astype(np.int64) + right, bottom = (mu + radius + 1).astype(np.int64) + + if left >= W or top >= H or right < 0 or bottom < 0: + keypoint_weights[n, k] = 0 + continue + + # The gaussian is not normalized, + # we want the center value to equal 1 + gaussian = np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma[n] ** 2)) + + # valid range in gaussian + g_x1 = max(0, -left) + g_x2 = min(W, right) - left + g_y1 = max(0, -top) + g_y2 = min(H, bottom) - top + + # valid range in heatmap + h_x1 = max(0, left) + h_x2 = min(W, right) + h_y1 = max(0, top) + h_y2 = min(H, bottom) + + heatmap_region = heatmaps[k, h_y1:h_y2, h_x1:h_x2] + gaussian_regsion = gaussian[g_y1:g_y2, g_x1:g_x2] + + _ = np.maximum(heatmap_region, gaussian_regsion, out=heatmap_region) + + return heatmaps, keypoint_weights + + +def generate_unbiased_gaussian_heatmaps( + heatmap_size: Tuple[int, int], + keypoints: np.ndarray, + keypoints_visible: np.ndarray, + sigma: float, +) -> Tuple[np.ndarray, np.ndarray]: + """Generate gaussian heatmaps of keypoints using `Dark Pose`_. + + Args: + heatmap_size (Tuple[int, int]): Heatmap size in [W, H] + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + keypoints_visible (np.ndarray): Keypoint visibilities in shape + (N, K) + + Returns: + tuple: + - heatmaps (np.ndarray): The generated heatmap in shape + (K, H, W) where [W, H] is the `heatmap_size` + - keypoint_weights (np.ndarray): The target weights in shape + (N, K) + + .. _`Dark Pose`: https://arxiv.org/abs/1910.06278 + """ + + N, K, _ = keypoints.shape + W, H = heatmap_size + + heatmaps = np.zeros((K, H, W), dtype=np.float32) + keypoint_weights = keypoints_visible.copy() + + # 3-sigma rule + radius = sigma * 3 + + # xy grid + x = np.arange(0, W, 1, dtype=np.float32) + y = np.arange(0, H, 1, dtype=np.float32)[:, None] + + for n, k in product(range(N), range(K)): + # skip unlabled keypoints + if keypoints_visible[n, k] < 0.5: + continue + + mu = keypoints[n, k] + # check that the gaussian has in-bounds part + left, top = mu - radius + right, bottom = mu + radius + 1 + + if left >= W or top >= H or right < 0 or bottom < 0: + keypoint_weights[n, k] = 0 + continue + + gaussian = np.exp(-((x - mu[0]) ** 2 + (y - mu[1]) ** 2) / (2 * sigma**2)) + + _ = np.maximum(gaussian, heatmaps[k], out=heatmaps[k]) + + return heatmaps, keypoint_weights + + +def generate_udp_gaussian_heatmaps( + heatmap_size: Tuple[int, int], + keypoints: np.ndarray, + keypoints_visible: np.ndarray, + sigma: float, +) -> Tuple[np.ndarray, np.ndarray]: + """Generate gaussian heatmaps of keypoints using `UDP`_. 
+ + Args: + heatmap_size (Tuple[int, int]): Heatmap size in [W, H] + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + keypoints_visible (np.ndarray): Keypoint visibilities in shape + (N, K) + sigma (float): The sigma value of the Gaussian heatmap + + Returns: + tuple: + - heatmaps (np.ndarray): The generated heatmap in shape + (K, H, W) where [W, H] is the `heatmap_size` + - keypoint_weights (np.ndarray): The target weights in shape + (N, K) + + .. _`UDP`: https://arxiv.org/abs/1911.07524 + """ + + N, K, _ = keypoints.shape + W, H = heatmap_size + + heatmaps = np.zeros((K, H, W), dtype=np.float32) + keypoint_weights = keypoints_visible.copy() + + # 3-sigma rule + radius = sigma * 3 + + # xy grid + gaussian_size = 2 * radius + 1 + x = np.arange(0, gaussian_size, 1, dtype=np.float32) + y = x[:, None] + + for n, k in product(range(N), range(K)): + # skip unlabled keypoints + if keypoints_visible[n, k] < 0.5: + continue + + mu = (keypoints[n, k] + 0.5).astype(np.int64) + # check that the gaussian has in-bounds part + left, top = (mu - radius).astype(np.int64) + right, bottom = (mu + radius + 1).astype(np.int64) + + if left >= W or top >= H or right < 0 or bottom < 0: + keypoint_weights[n, k] = 0 + continue + + mu_ac = keypoints[n, k] + x0 = y0 = gaussian_size // 2 + x0 += mu_ac[0] - mu[0] + y0 += mu_ac[1] - mu[1] + gaussian = np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma**2)) + + # valid range in gaussian + g_x1 = max(0, -left) + g_x2 = min(W, right) - left + g_y1 = max(0, -top) + g_y2 = min(H, bottom) - top + + # valid range in heatmap + h_x1 = max(0, left) + h_x2 = min(W, right) + h_y1 = max(0, top) + h_y2 = min(H, bottom) + + heatmap_region = heatmaps[k, h_y1:h_y2, h_x1:h_x2] + gaussian_regsion = gaussian[g_y1:g_y2, g_x1:g_x2] + + _ = np.maximum(heatmap_region, gaussian_regsion, out=heatmap_region) + + return heatmaps, keypoint_weights diff --git a/sapiens/pose/src/datasets/codecs/utils/instance_property.py b/sapiens/pose/src/datasets/codecs/utils/instance_property.py new file mode 100644 index 0000000000000000000000000000000000000000..a64872f48389ed5bbb0b853224369d137172b488 --- /dev/null +++ b/sapiens/pose/src/datasets/codecs/utils/instance_property.py @@ -0,0 +1,119 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import numpy as np + + +def get_instance_root( + keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray] = None, + root_type: str = "kpt_center", +) -> np.ndarray: + """Calculate the coordinates and visibility of instance roots. 
+ + Args: + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + keypoints_visible (np.ndarray): Keypoint visibilities in shape + (N, K) + root_type (str): Calculation of instance roots which should + be one of the following options: + + - ``'kpt_center'``: The roots' coordinates are the mean + coordinates of visible keypoints + - ``'bbox_center'``: The roots' are the center of bounding + boxes outlined by visible keypoints + + Defaults to ``'kpt_center'`` + + Returns: + tuple + - roots_coordinate(np.ndarray): Coordinates of instance roots in + shape [N, D] + - roots_visible(np.ndarray): Visibility of instance roots in + shape [N] + """ + + roots_coordinate = np.zeros((keypoints.shape[0], 2), dtype=np.float32) + roots_visible = np.ones((keypoints.shape[0]), dtype=np.float32) * 2 + + for i in range(keypoints.shape[0]): + # collect visible keypoints + if keypoints_visible is not None: + visible_keypoints = keypoints[i][keypoints_visible[i] > 0] + else: + visible_keypoints = keypoints[i] + if visible_keypoints.size == 0: + roots_visible[i] = 0 + continue + + # compute the instance root with visible keypoints + if root_type == "kpt_center": + roots_coordinate[i] = visible_keypoints.mean(axis=0) + roots_visible[i] = 1 + elif root_type == "bbox_center": + roots_coordinate[i] = ( + visible_keypoints.max(axis=0) + visible_keypoints.min(axis=0) + ) / 2.0 + roots_visible[i] = 1 + else: + raise ValueError( + f"the value of `root_type` must be 'kpt_center' or " + f"'bbox_center', but got '{root_type}'" + ) + + return roots_coordinate, roots_visible + + +def get_instance_bbox( + keypoints: np.ndarray, keypoints_visible: Optional[np.ndarray] = None +) -> np.ndarray: + """Calculate the pseudo instance bounding box from visible keypoints. The + bounding boxes are in the xyxy format. + + Args: + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + keypoints_visible (np.ndarray): Keypoint visibilities in shape + (N, K) + + Returns: + np.ndarray: bounding boxes in [N, 4] + """ + bbox = np.zeros((keypoints.shape[0], 4), dtype=np.float32) + for i in range(keypoints.shape[0]): + if keypoints_visible is not None: + visible_keypoints = keypoints[i][keypoints_visible[i] > 0] + else: + visible_keypoints = keypoints[i] + if visible_keypoints.size == 0: + continue + + bbox[i, :2] = visible_keypoints.min(axis=0) + bbox[i, 2:] = visible_keypoints.max(axis=0) + return bbox + + +def get_diagonal_lengths( + keypoints: np.ndarray, keypoints_visible: Optional[np.ndarray] = None +) -> np.ndarray: + """Calculate the diagonal length of instance bounding box from visible + keypoints. + + Args: + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + keypoints_visible (np.ndarray): Keypoint visibilities in shape + (N, K) + + Returns: + np.ndarray: bounding box diagonal length in [N] + """ + pseudo_bbox = get_instance_bbox(keypoints, keypoints_visible) + pseudo_bbox = pseudo_bbox.reshape(-1, 2, 2) + h_w_diff = pseudo_bbox[:, 1] - pseudo_bbox[:, 0] + diagonal_length = np.sqrt(np.power(h_w_diff, 2).sum(axis=1)) + + return diagonal_length diff --git a/sapiens/pose/src/datasets/codecs/utils/offset_heatmap.py b/sapiens/pose/src/datasets/codecs/utils/offset_heatmap.py new file mode 100644 index 0000000000000000000000000000000000000000..776d1eb76fd41582626c75070f54871e9e0bb081 --- /dev/null +++ b/sapiens/pose/src/datasets/codecs/utils/offset_heatmap.py @@ -0,0 +1,154 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from itertools import product +from typing import Tuple + +import numpy as np + + +def generate_offset_heatmap( + heatmap_size: Tuple[int, int], + keypoints: np.ndarray, + keypoints_visible: np.ndarray, + radius_factor: float, +) -> Tuple[np.ndarray, np.ndarray]: + """Generate offset heatmaps of keypoints, where each keypoint is + represented by 3 maps: one pixel-level class label map (1 for keypoint and + 0 for non-keypoint) and 2 pixel-level offset maps for x and y directions + respectively. + + Args: + heatmap_size (Tuple[int, int]): Heatmap size in [W, H] + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + keypoints_visible (np.ndarray): Keypoint visibilities in shape + (N, K) + radius_factor (float): The radius factor of the binary label + map. The positive region is defined as the neighbor of the + keypoint with the radius :math:`r=radius_factor*max(W, H)` + + Returns: + tuple: + - heatmap (np.ndarray): The generated heatmap in shape + (K*3, H, W) where [W, H] is the `heatmap_size` + - keypoint_weights (np.ndarray): The target weights in shape + (K,) + """ + + N, K, _ = keypoints.shape + W, H = heatmap_size + + heatmaps = np.zeros((K, 3, H, W), dtype=np.float32) + keypoint_weights = keypoints_visible.copy() + + # xy grid + x = np.arange(0, W, 1) + y = np.arange(0, H, 1)[:, None] + + # positive area radius in the classification map + radius = radius_factor * max(W, H) + + for n, k in product(range(N), range(K)): + if keypoints_visible[n, k] < 0.5: + continue + + mu = keypoints[n, k] + + x_offset = (mu[0] - x) / radius + y_offset = (mu[1] - y) / radius + + heatmaps[k, 0] = np.where(x_offset**2 + y_offset**2 <= 1, 1.0, 0.0) + heatmaps[k, 1] = x_offset + heatmaps[k, 2] = y_offset + + heatmaps = heatmaps.reshape(K * 3, H, W) + + return heatmaps, keypoint_weights + + +def generate_displacement_heatmap( + heatmap_size: Tuple[int, int], + keypoints: np.ndarray, + keypoints_visible: np.ndarray, + roots: np.ndarray, + roots_visible: np.ndarray, + diagonal_lengths: np.ndarray, + radius: float, +): + """Generate displacement heatmaps of keypoints, where each keypoint is + represented by 3 maps: one pixel-level class label map (1 for keypoint and + 0 for non-keypoint) and 2 pixel-level offset maps for x and y directions + respectively. + + Args: + heatmap_size (Tuple[int, int]): Heatmap size in [W, H] + keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + keypoints_visible (np.ndarray): Keypoint visibilities in shape + (N, K) + roots (np.ndarray): Coordinates of instance centers in shape (N, D). + The displacement fields of each instance will locate around its + center. + roots_visible (np.ndarray): Roots visibilities in shape (N,) + diagonal_lengths (np.ndarray): Diaginal length of the bounding boxes + of each instance in shape (N,) + radius (float): The radius factor of the binary label + map. 
The positive region is defined as the neighbor of the + keypoint with the radius :math:`r=radius_factor*max(W, H)` + + Returns: + tuple: + - displacements (np.ndarray): The generated displacement map in + shape (K*2, H, W) where [W, H] is the `heatmap_size` + - displacement_weights (np.ndarray): The target weights in shape + (K*2, H, W) + """ + N, K, _ = keypoints.shape + W, H = heatmap_size + + displacements = np.zeros((K * 2, H, W), dtype=np.float32) + displacement_weights = np.zeros((K * 2, H, W), dtype=np.float32) + instance_size_map = np.zeros((H, W), dtype=np.float32) + + for n in range(N): + if ( + roots_visible[n] < 1 + or (roots[n, 0] < 0 or roots[n, 1] < 0) + or (roots[n, 0] >= W or roots[n, 1] >= H) + ): + continue + + diagonal_length = diagonal_lengths[n] + + for k in range(K): + if ( + keypoints_visible[n, k] < 1 + or keypoints[n, k, 0] < 0 + or keypoints[n, k, 1] < 0 + or keypoints[n, k, 0] >= W + or keypoints[n, k, 1] >= H + ): + continue + + start_x = max(int(roots[n, 0] - radius), 0) + start_y = max(int(roots[n, 1] - radius), 0) + end_x = min(int(roots[n, 0] + radius), W) + end_y = min(int(roots[n, 1] + radius), H) + + for x in range(start_x, end_x): + for y in range(start_y, end_y): + if ( + displacements[2 * k, y, x] != 0 + or displacements[2 * k + 1, y, x] != 0 + ): + if diagonal_length > instance_size_map[y, x]: + # keep the gt displacement of smaller instance + continue + + displacement_weights[2 * k : 2 * k + 2, y, x] = 1 / diagonal_length + displacements[2 * k : 2 * k + 2, y, x] = keypoints[n, k] - [x, y] + instance_size_map[y, x] = diagonal_length + + return displacements, displacement_weights diff --git a/sapiens/pose/src/datasets/codecs/utils/post_processing.py b/sapiens/pose/src/datasets/codecs/utils/post_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..e845530673ec161f20e04b28d5d66bc935930b71 --- /dev/null +++ b/sapiens/pose/src/datasets/codecs/utils/post_processing.py @@ -0,0 +1,227 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from itertools import product +from typing import Tuple + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor + + +def get_simcc_normalized(batch_pred_simcc, sigma=None): + """Normalize the predicted SimCC. + + Args: + batch_pred_simcc (torch.Tensor): The predicted SimCC. + sigma (float): The sigma of the Gaussian distribution. + + Returns: + torch.Tensor: The normalized SimCC. + """ + B, K, _ = batch_pred_simcc.shape + + # Scale and clamp the tensor + if sigma is not None: + batch_pred_simcc = batch_pred_simcc / (sigma * np.sqrt(np.pi * 2)) + batch_pred_simcc = batch_pred_simcc.clamp(min=0) + + # Compute the binary mask + mask = (batch_pred_simcc.amax(dim=-1) > 1).reshape(B, K, 1) + + # Normalize the tensor using the maximum value + norm = batch_pred_simcc / batch_pred_simcc.amax(dim=-1).reshape(B, K, 1) + + # Apply normalization + batch_pred_simcc = torch.where(mask, norm, batch_pred_simcc) + + return batch_pred_simcc + + +def get_simcc_maximum( + simcc_x: np.ndarray, simcc_y: np.ndarray +) -> Tuple[np.ndarray, np.ndarray]: + """Get maximum response location and value from simcc representations. 
+ + Note: + instance number: N + num_keypoints: K + heatmap height: H + heatmap width: W + + Args: + simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx) + simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy) + + Returns: + tuple: + - locs (np.ndarray): locations of maximum heatmap responses in shape + (K, 2) or (N, K, 2) + - vals (np.ndarray): values of maximum heatmap responses in shape + (K,) or (N, K) + """ + + assert isinstance(simcc_x, np.ndarray), "simcc_x should be numpy.ndarray" + assert isinstance(simcc_y, np.ndarray), "simcc_y should be numpy.ndarray" + assert simcc_x.ndim == 2 or simcc_x.ndim == 3, f"Invalid shape {simcc_x.shape}" + assert simcc_y.ndim == 2 or simcc_y.ndim == 3, f"Invalid shape {simcc_y.shape}" + assert simcc_x.ndim == simcc_y.ndim, f"{simcc_x.shape} != {simcc_y.shape}" + + if simcc_x.ndim == 3: + N, K, Wx = simcc_x.shape + simcc_x = simcc_x.reshape(N * K, -1) + simcc_y = simcc_y.reshape(N * K, -1) + else: + N = None + + x_locs = np.argmax(simcc_x, axis=1) + y_locs = np.argmax(simcc_y, axis=1) + locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32) + max_val_x = np.amax(simcc_x, axis=1) + max_val_y = np.amax(simcc_y, axis=1) + + mask = max_val_x > max_val_y + max_val_x[mask] = max_val_y[mask] + vals = max_val_x + locs[vals <= 0.0] = -1 + + if N: + locs = locs.reshape(N, K, 2) + vals = vals.reshape(N, K) + + return locs, vals + + +def get_heatmap_maximum(heatmaps: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Get maximum response location and value from heatmaps. + + Note: + batch_size: B + num_keypoints: K + heatmap height: H + heatmap width: W + + Args: + heatmaps (np.ndarray): Heatmaps in shape (K, H, W) or (B, K, H, W) + + Returns: + tuple: + - locs (np.ndarray): locations of maximum heatmap responses in shape + (K, 2) or (B, K, 2) + - vals (np.ndarray): values of maximum heatmap responses in shape + (K,) or (B, K) + """ + assert isinstance(heatmaps, np.ndarray), "heatmaps should be numpy.ndarray" + assert heatmaps.ndim == 3 or heatmaps.ndim == 4, f"Invalid shape {heatmaps.shape}" + + if heatmaps.ndim == 3: + K, H, W = heatmaps.shape + B = None + heatmaps_flatten = heatmaps.reshape(K, -1) + else: + B, K, H, W = heatmaps.shape + heatmaps_flatten = heatmaps.reshape(B * K, -1) + + y_locs, x_locs = np.unravel_index(np.argmax(heatmaps_flatten, axis=1), shape=(H, W)) + locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32) + vals = np.amax(heatmaps_flatten, axis=1) + locs[vals <= 0.0] = -1 + + if B: + locs = locs.reshape(B, K, 2) + vals = vals.reshape(B, K) + + return locs, vals + + +def gaussian_blur(heatmaps: np.ndarray, kernel: int = 11) -> np.ndarray: + """Modulate heatmap distribution with Gaussian. + + Note: + - num_keypoints: K + - heatmap height: H + - heatmap width: W + + Args: + heatmaps (np.ndarray[K, H, W]): model predicted heatmaps. + kernel (int): Gaussian kernel size (K) for modulation, which should + match the heatmap gaussian sigma when training. + K=17 for sigma=3 and k=11 for sigma=2. + + Returns: + np.ndarray ([K, H, W]): Modulated heatmap distribution. 
+ """ + assert kernel % 2 == 1 + + border = (kernel - 1) // 2 + K, H, W = heatmaps.shape + + for k in range(K): + origin_max = np.max(heatmaps[k]) + dr = np.zeros((H + 2 * border, W + 2 * border), dtype=np.float32) + dr[border:-border, border:-border] = heatmaps[k].copy() + dr = cv2.GaussianBlur(dr, (kernel, kernel), 0) + heatmaps[k] = dr[border:-border, border:-border].copy() + heatmaps[k] *= origin_max / np.max(heatmaps[k]) + return heatmaps + + +def gaussian_blur1d(simcc: np.ndarray, kernel: int = 11) -> np.ndarray: + """Modulate simcc distribution with Gaussian. + + Note: + - num_keypoints: K + - simcc length: Wx + + Args: + simcc (np.ndarray[K, Wx]): model predicted simcc. + kernel (int): Gaussian kernel size (K) for modulation, which should + match the simcc gaussian sigma when training. + K=17 for sigma=3 and k=11 for sigma=2. + + Returns: + np.ndarray ([K, Wx]): Modulated simcc distribution. + """ + assert kernel % 2 == 1 + + border = (kernel - 1) // 2 + N, K, Wx = simcc.shape + + for n, k in product(range(N), range(K)): + origin_max = np.max(simcc[n, k]) + dr = np.zeros((1, Wx + 2 * border), dtype=np.float32) + dr[0, border:-border] = simcc[n, k].copy() + dr = cv2.GaussianBlur(dr, (kernel, 1), 0) + simcc[n, k] = dr[0, border:-border].copy() + simcc[n, k] *= origin_max / np.max(simcc[n, k]) + return simcc + + +def batch_heatmap_nms(batch_heatmaps: Tensor, kernel_size: int = 5): + """Apply NMS on a batch of heatmaps. + + Args: + batch_heatmaps (Tensor): batch heatmaps in shape (B, K, H, W) + kernel_size (int): The kernel size of the NMS which should be + a odd integer. Defaults to 5 + + Returns: + Tensor: The batch heatmaps after NMS. + """ + + assert isinstance(kernel_size, int) and kernel_size % 2 == 1, ( + f"The kernel_size should be an odd integer, got {kernel_size}" + ) + + padding = (kernel_size - 1) // 2 + + maximum = F.max_pool2d(batch_heatmaps, kernel_size, stride=1, padding=padding) + maximum_indicator = torch.eq(batch_heatmaps, maximum) + batch_heatmaps = batch_heatmaps * maximum_indicator.float() + + return batch_heatmaps diff --git a/sapiens/pose/src/datasets/codecs/utils/refinement.py b/sapiens/pose/src/datasets/codecs/utils/refinement.py new file mode 100644 index 0000000000000000000000000000000000000000..9ed66b7dd11e4f7a7b3522846693a8cf2fc77d76 --- /dev/null +++ b/sapiens/pose/src/datasets/codecs/utils/refinement.py @@ -0,0 +1,223 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from itertools import product + +import numpy as np + +from .post_processing import gaussian_blur, gaussian_blur1d + + +def refine_keypoints(keypoints: np.ndarray, heatmaps: np.ndarray) -> np.ndarray: + """Refine keypoint predictions by moving from the maximum towards the + second maximum by 0.25 pixel. The operation is in-place. 
+ + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + - heatmap size: [W, H] + + Args: + keypoints (np.ndarray): The keypoint coordinates in shape (N, K, D) + heatmaps (np.ndarray): The heatmaps in shape (K, H, W) + + Returns: + np.ndarray: Refine keypoint coordinates in shape (N, K, D) + """ + N, K = keypoints.shape[:2] + H, W = heatmaps.shape[1:] + + for n, k in product(range(N), range(K)): + x, y = keypoints[n, k, :2].astype(int) + + if 1 < x < W - 1 and 0 < y < H: + dx = heatmaps[k, y, x + 1] - heatmaps[k, y, x - 1] + else: + dx = 0.0 + + if 1 < y < H - 1 and 0 < x < W: + dy = heatmaps[k, y + 1, x] - heatmaps[k, y - 1, x] + else: + dy = 0.0 + + keypoints[n, k] += np.sign([dx, dy], dtype=np.float32) * 0.25 + + return keypoints + + +def refine_keypoints_dark( + keypoints: np.ndarray, heatmaps: np.ndarray, blur_kernel_size: int +) -> np.ndarray: + """Refine keypoint predictions using distribution aware coordinate + decoding. See `Dark Pose`_ for details. The operation is in-place. + + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + - heatmap size: [W, H] + + Args: + keypoints (np.ndarray): The keypoint coordinates in shape (N, K, D) + heatmaps (np.ndarray): The heatmaps in shape (K, H, W) + blur_kernel_size (int): The Gaussian blur kernel size of the heatmap + modulation + + Returns: + np.ndarray: Refine keypoint coordinates in shape (N, K, D) + + .. _`Dark Pose`: https://arxiv.org/abs/1910.06278 + """ + N, K = keypoints.shape[:2] + H, W = heatmaps.shape[1:] + + # modulate heatmaps + heatmaps = gaussian_blur(heatmaps, blur_kernel_size) + np.maximum(heatmaps, 1e-10, heatmaps) + np.log(heatmaps, heatmaps) + + for n, k in product(range(N), range(K)): + x, y = keypoints[n, k, :2].astype(int) + if 1 < x < W - 2 and 1 < y < H - 2: + dx = 0.5 * (heatmaps[k, y, x + 1] - heatmaps[k, y, x - 1]) + dy = 0.5 * (heatmaps[k, y + 1, x] - heatmaps[k, y - 1, x]) + + dxx = 0.25 * ( + heatmaps[k, y, x + 2] - 2 * heatmaps[k, y, x] + heatmaps[k, y, x - 2] + ) + dxy = 0.25 * ( + heatmaps[k, y + 1, x + 1] + - heatmaps[k, y - 1, x + 1] + - heatmaps[k, y + 1, x - 1] + + heatmaps[k, y - 1, x - 1] + ) + dyy = 0.25 * ( + heatmaps[k, y + 2, x] - 2 * heatmaps[k, y, x] + heatmaps[k, y - 2, x] + ) + derivative = np.array([[dx], [dy]]) + hessian = np.array([[dxx, dxy], [dxy, dyy]]) + if dxx * dyy - dxy**2 != 0: + hessianinv = np.linalg.inv(hessian) + offset = -hessianinv @ derivative + offset = np.squeeze(np.array(offset.T), axis=0) + keypoints[n, k, :2] += offset + return keypoints + + +def refine_keypoints_dark_udp( + keypoints: np.ndarray, heatmaps: np.ndarray, blur_kernel_size: int +) -> np.ndarray: + """Refine keypoint predictions using distribution aware coordinate decoding + for UDP. See `UDP`_ for details. The operation is in-place. + + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + - heatmap size: [W, H] + + Args: + keypoints (np.ndarray): The keypoint coordinates in shape (N, K, D) + heatmaps (np.ndarray): The heatmaps in shape (K, H, W) + blur_kernel_size (int): The Gaussian blur kernel size of the heatmap + modulation + + Returns: + np.ndarray: Refine keypoint coordinates in shape (N, K, D) + + .. 
_`UDP`: https://arxiv.org/abs/1911.07524 + """ + N, K = keypoints.shape[:2] + H, W = heatmaps.shape[1:] + + # modulate heatmaps + heatmaps = gaussian_blur(heatmaps, blur_kernel_size) + np.clip(heatmaps, 1e-3, 50.0, heatmaps) + np.log(heatmaps, heatmaps) + + heatmaps_pad = np.pad(heatmaps, ((0, 0), (1, 1), (1, 1)), mode="edge").flatten() + + for n in range(N): + index = keypoints[n, :, 0] + 1 + (keypoints[n, :, 1] + 1) * (W + 2) + index += (W + 2) * (H + 2) * np.arange(0, K) + index = index.astype(int).reshape(-1, 1) + i_ = heatmaps_pad[index] + ix1 = heatmaps_pad[index + 1] + iy1 = heatmaps_pad[index + W + 2] + ix1y1 = heatmaps_pad[index + W + 3] + ix1_y1_ = heatmaps_pad[index - W - 3] + ix1_ = heatmaps_pad[index - 1] + iy1_ = heatmaps_pad[index - 2 - W] + + dx = 0.5 * (ix1 - ix1_) + dy = 0.5 * (iy1 - iy1_) + derivative = np.concatenate([dx, dy], axis=1) + derivative = derivative.reshape(K, 2, 1) + + dxx = ix1 - 2 * i_ + ix1_ + dyy = iy1 - 2 * i_ + iy1_ + dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_) + hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1) + hessian = hessian.reshape(K, 2, 2) + hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2)) + keypoints[n] -= np.einsum("imn,ink->imk", hessian, derivative).squeeze() + + return keypoints + + +def refine_simcc_dark( + keypoints: np.ndarray, simcc: np.ndarray, blur_kernel_size: int +) -> np.ndarray: + """SimCC version. Refine keypoint predictions using distribution aware + coordinate decoding for UDP. See `UDP`_ for details. The operation is in- + place. + + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + + Args: + keypoints (np.ndarray): The keypoint coordinates in shape (N, K, D) + simcc (np.ndarray): The heatmaps in shape (N, K, Wx) + blur_kernel_size (int): The Gaussian blur kernel size of the heatmap + modulation + + Returns: + np.ndarray: Refine keypoint coordinates in shape (N, K, D) + + .. _`UDP`: https://arxiv.org/abs/1911.07524 + """ + N = simcc.shape[0] + + # modulate simcc + simcc = gaussian_blur1d(simcc, blur_kernel_size) + np.clip(simcc, 1e-3, 50.0, simcc) + np.log(simcc, simcc) + + simcc = np.pad(simcc, ((0, 0), (0, 0), (2, 2)), "edge") + + for n in range(N): + px = (keypoints[n] + 2.5).astype(np.int64).reshape(-1, 1) # K, 1 + + dx0 = np.take_along_axis(simcc[n], px, axis=1) # K, 1 + dx1 = np.take_along_axis(simcc[n], px + 1, axis=1) + dx_1 = np.take_along_axis(simcc[n], px - 1, axis=1) + dx2 = np.take_along_axis(simcc[n], px + 2, axis=1) + dx_2 = np.take_along_axis(simcc[n], px - 2, axis=1) + + dx = 0.5 * (dx1 - dx_1) + dxx = 1e-9 + 0.25 * (dx2 - 2 * dx0 + dx_2) + + offset = dx / dxx + keypoints[n] -= offset.reshape(-1) + + return keypoints diff --git a/sapiens/pose/src/datasets/keypoints308_3po_dataset.py b/sapiens/pose/src/datasets/keypoints308_3po_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a33d21a737226434fa35a47dc815988696a86f3d --- /dev/null +++ b/sapiens/pose/src/datasets/keypoints308_3po_dataset.py @@ -0,0 +1,217 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
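
The decode path for the codecs above pairs get_heatmap_maximum with the DARK refinement: argmax gives integer-pixel peak locations, and the log-space Taylor step recovers the sub-pixel offset. A minimal sketch on a synthetic Gaussian heatmap (the import paths are an assumption about how the codecs/utils package is exposed; adjust to your checkout):

import numpy as np

# assumed import locations -- both modules are introduced in this patch
from src.datasets.codecs.utils.post_processing import get_heatmap_maximum
from src.datasets.codecs.utils.refinement import refine_keypoints_dark

K, H, W = 3, 64, 48
yy, xx = np.mgrid[0:H, 0:W].astype(np.float32)
gt = np.array([[12.3, 40.7], [30.1, 10.2], [24.0, 24.0]], dtype=np.float32)  # (K, 2) as (x, y)
heatmaps = np.exp(
    -((xx[None] - gt[:, 0, None, None]) ** 2 + (yy[None] - gt[:, 1, None, None]) ** 2)
    / (2 * 2.0 ** 2)
).astype(np.float32)

locs, _ = get_heatmap_maximum(heatmaps)            # (K, 2) integer-precision peaks
keypoints = locs[None].copy()                      # (N=1, K, 2)
keypoints = refine_keypoints_dark(keypoints, heatmaps.copy(), blur_kernel_size=11)
print(np.abs(keypoints[0] - gt).max())             # sub-pixel error, well below one pixel
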
+ +import copy +import io +import json +import os +import pickle # nosec +from contextlib import redirect_stderr +from copy import deepcopy +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +import cv2 +import numpy as np +from PIL import Image +from sapiens.registry import DATASETS + +from .pose_base_dataset import PoseBaseDataset + +with open(os.devnull, "w") as f, redirect_stderr(f): + try: + from care.data.io import typed + except Exception: + pass + + +@DATASETS.register_module() +class Keypoints308_3PODataset(PoseBaseDataset): + METAINFO: dict = dict( + from_file=os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))), + "configs", + "_base_", + "keypoints308.py", + ) + ) + + def __init__(self, subsample_factor: int = 1, **kwargs) -> None: + self.subsample_factor = subsample_factor + super().__init__(**kwargs) + + self.remove_teeth = self.metainfo["remove_teeth"] + if self.remove_teeth: + self.teeth_ids = self.metainfo["teeth_keypoint_ids"] + + return + + def load_data_list(self) -> List[dict]: + """Load data list from 344 body points.""" + self._register_airstore_handler() + with open(self.ann_file, "rb") as f: + raw = f.read() + raw_data = json.loads(raw) # samples=5,267,269 + + data_list = [] + for i, sample in enumerate(raw_data): + dp = { + "airstore_id": sample["sample_id"], + "img_id": i, + "person_id": sample["person_id"], + } + if sample.get("box-default") is not None: + dp["box"] = sample["box-default"] + data_list.append(dp) + return data_list + + def _register_airstore_handler(self) -> None: + from typedio.file_system.airstore_client import register_airstore_in_fsspec + + register_airstore_in_fsspec() + self.path_template = ( + "airstoreds://codec_avatar_sapiens_3po_ala_v1_train_no_user_data" + ) + self.airstore = True + + def _read_from_airstore(self, asset: str, sid: str) -> io.BytesIO: + with typed.open(self.path_template + f"/{asset}?sampleId={sid}").open() as f: + data = io.BytesIO(f.read()) + return data + + def __len__(self) -> int: + return len(self.data_list) // self.subsample_factor + + def get_data_info(self, idx): + if self.subsample_factor > 1: + idx = idx * self.subsample_factor + np.random.randint( + 0, self.subsample_factor + ) + idx = idx % len(self.data_list) + + data_info = copy.deepcopy(self.data_list[idx]) + + try: + img = Image.open( + self._read_from_airstore("image", data_info["airstore_id"]) + ) ## pillow image + keypoints_list = pickle.loads( + self._read_from_airstore( + "annotations", data_info["airstore_id"] + ).getvalue() + ) # shape 3 x 344 + except Exception as e: + print(f"Error loading data: {e}") + return None + + keypoints_np = None + for keypoints_info in keypoints_list: + if keypoints_info["person_id"] == data_info["person_id"]: + keypoints_np = np.array(keypoints_info["keypoints_2d"]) # 70 x 3 + break + + if keypoints_np is None: + return None + + ## convert to 344 keypoints + keypoints_344_np = np.zeros((344, 3)) + keypoints_344_np[:70] = keypoints_np ## 344 x 3 + keypoints_344_np = keypoints_344_np.reshape(1, -1, 3) # shape 1 x 344 x 3 + + img = np.array(img) ## RGB image + img = img[ + :, :, ::-1 + ] # Convert RGB to BGR, the model preprocessor will convert this to rgb again + + img_w, img_h = img.shape[1], img.shape[0] + + # process keypoints + keypoints = keypoints_344_np[:, :, :2] # shape 1 x 344 x 2 + keypoints_visible = keypoints_344_np[:, :, 2] > 0 # shape 1 x 344 + + # Identify keypoints that are out of bounds for x (width) and y (height) + out_of_bounds_w = np.logical_or( 
+ keypoints[0, :, 0] <= 0, keypoints[0, :, 0] >= img_w + ) + out_of_bounds_h = np.logical_or( + keypoints[0, :, 1] <= 0, keypoints[0, :, 1] >= img_h + ) + + # Update keypoints_visible based on the out-of-bounds keypoints + keypoints_visible[0, out_of_bounds_w | out_of_bounds_h] = 0 + keypoints[keypoints_visible == 0] = 0 + + ## remove teeth keypoints + if self.remove_teeth: + # Use numpy's boolean indexing to remove keypoints + mask = np.ones(keypoints.shape[1], dtype=bool) + mask[self.teeth_ids] = False + keypoints = keypoints[:, mask, :] + keypoints_visible = keypoints_visible[:, mask] + + # Default bounding box to the full image size + bbox = np.array([0, 0, img_w, img_h], dtype=np.float32).reshape(1, 4) + + if np.any(keypoints_visible): # If any keypoints are visible + visible_keypoints = keypoints[0][ + keypoints_visible[0] == 1 + ] # Filter out the invisible keypoints + + # Get the bounding box encompassing the keypoints + x_min, y_min = np.clip( + np.min(visible_keypoints, axis=0), [0, 0], [img_w, img_h] + ) + x_max, y_max = np.clip( + np.max(visible_keypoints, axis=0), [0, 0], [img_w, img_h] + ) + + bbox = np.array([x_min, y_min, x_max, y_max], dtype=np.float32).reshape( + 1, 4 + ) + + num_keypoints = np.count_nonzero(keypoints_visible) + + data_info = { + "img": img, + "img_id": data_info["img_id"], + "img_path": "", + "airstore_id": data_info["airstore_id"], + "bbox": bbox, + "bbox_score": np.ones(1, dtype=np.float32), + "num_keypoints": num_keypoints, + "keypoints": keypoints, + "keypoints_visible": keypoints_visible, + "iscrowd": 0, + "segmentation": None, + "id": idx, + "category_id": 1, + } + + # Some codebase needs `sample_idx` of data information. Here we convert + # the idx to a positive number and save it in data information. + if idx >= 0: + data_info["sample_idx"] = idx + else: + data_info["sample_idx"] = len(self) + idx + + # Add metainfo items that are required in the pipeline and the model + metainfo_keys = [ + "upper_body_ids", + "lower_body_ids", + "flip_pairs", + "dataset_keypoint_weights", + "flip_indices", + "skeleton_links", + ] + + for key in metainfo_keys: + assert key not in data_info, ( + f'"{key}" is a reserved key for `metainfo`, but already ' + "exists in the `data_info`." + ) + + data_info[key] = deepcopy(self.metainfo[key]) + + return data_info diff --git a/sapiens/pose/src/datasets/keypoints308_goliath_dataset.py b/sapiens/pose/src/datasets/keypoints308_goliath_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..44f5cd48b5cb72a014c380f1c54b09703a457623 --- /dev/null +++ b/sapiens/pose/src/datasets/keypoints308_goliath_dataset.py @@ -0,0 +1,217 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
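
The out-of-bounds masking and visible-keypoint bounding box above are shared by all of the airstore-backed datasets in this patch. A condensed, self-contained restatement on dummy data (the helper name is illustrative; the real code additionally zeroes the coordinates of invisible keypoints):

import numpy as np

def visible_keypoint_bbox(keypoints, keypoints_visible, img_w, img_h):
    # mark points outside the image as invisible, then fit an xyxy box to what remains
    oob = (
        (keypoints[0, :, 0] <= 0) | (keypoints[0, :, 0] >= img_w)
        | (keypoints[0, :, 1] <= 0) | (keypoints[0, :, 1] >= img_h)
    )
    keypoints_visible = keypoints_visible.copy()
    keypoints_visible[0, oob] = 0
    bbox = np.array([0, 0, img_w, img_h], dtype=np.float32).reshape(1, 4)  # fallback: full image
    if np.any(keypoints_visible):
        vis = keypoints[0][keypoints_visible[0] == 1]
        x_min, y_min = np.clip(vis.min(axis=0), [0, 0], [img_w, img_h])
        x_max, y_max = np.clip(vis.max(axis=0), [0, 0], [img_w, img_h])
        bbox = np.array([x_min, y_min, x_max, y_max], dtype=np.float32).reshape(1, 4)
    return bbox, keypoints_visible

kpts = np.array([[[10.0, 20.0], [650.0, 30.0], [100.0, 200.0]]], dtype=np.float32)  # (1, 3, 2)
vis = np.ones((1, 3), dtype=np.float32)
bbox, vis = visible_keypoint_bbox(kpts, vis, img_w=640, img_h=480)
print(bbox)  # [[ 10.  20. 100. 200.]] -- the point at x=650 is dropped before fitting the box
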
+ +import copy +import io +import json +import os +from contextlib import redirect_stderr +from copy import deepcopy +from typing import List + +import cv2 +import numpy as np +from PIL import Image +from sapiens.registry import DATASETS + +from .pose_base_dataset import PoseBaseDataset + +with open(os.devnull, "w") as f, redirect_stderr(f): + try: + from care.data.io import typed + except Exception: + pass + + +@DATASETS.register_module() +class Keypoints308GoliathDataset(PoseBaseDataset): + METAINFO: dict = dict( + from_file=os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))), + "configs", + "_base_", + "keypoints308.py", + ) + ) + + def __init__(self, subsample_factor: int = 1, **kwargs) -> None: + self.subsample_factor = subsample_factor + super().__init__(**kwargs) + self.remove_teeth = self.metainfo["remove_teeth"] + if self.remove_teeth: + self.teeth_ids = self.metainfo["teeth_keypoint_ids"] + + return + + def __len__(self) -> int: + return len(self.data_list) // self.subsample_factor + + def load_data_list(self) -> List[dict]: + """Load data list from 344 body points.""" + self._register_airstore_handler() + with open(self.ann_file, "rb") as f: + raw = f.read() + raw_data = json.loads(raw) # samples=5,267,269 + + data_list = [] + for i, sample in enumerate(raw_data): + if "sample_id" not in sample: + sample["sample_id"] = sample["airstore_id"] + + dp = { + "airstore_id": sample["sample_id"], + "img_id": i, + } + if sample.get("box-default") is not None: + dp["box"] = sample["box-default"] + data_list.append(dp) + return data_list + + def _register_airstore_handler(self) -> None: + from typedio.file_system.airstore_client import register_airstore_in_fsspec + + register_airstore_in_fsspec() + self.path_template = ( + "airstoreds://rlr_detection_services_ml_datasets_no_user_data" + ) + self.airstore = True + + def _read_from_airstore(self, asset: str, sid: str) -> io.BytesIO: + with typed.open(self.path_template + f"/{asset}?sampleId={sid}").open() as f: + data = io.BytesIO(f.read()) + return data + + def get_data_info(self, idx): + if self.subsample_factor > 1: + idx = idx * self.subsample_factor + np.random.randint( + 0, self.subsample_factor + ) + idx = idx % len(self.data_list) + + data_info = copy.deepcopy(self.data_list[idx]) + + try: + img = Image.open( + self._read_from_airstore("image", data_info["airstore_id"]) + ) ## pillow image + keypoints_np = np.load( + self._read_from_airstore("keypoint", data_info["airstore_id"]) + ) # shape 3 x 344 + except Exception as e: + print(f"Error loading data: {e}") + return None + + img = np.array(img) ## RGB image + img = img[ + :, :, ::-1 + ] # Convert RGB to BGR, the model preprocessor will convert this to rgb again + + img_w, img_h = img.shape[1], img.shape[0] + + # process keypoints + keypoints = keypoints_np[:2].T.reshape(1, -1, 2) # shape 1 x 344 x 2 + keypoints_visible = np.where(keypoints_np[2].T > 0, 1, 0).reshape( + 1, -1 + ) # shape 1 x 344 + + # Identify keypoints that are out of bounds for x (width) and y (height) + out_of_bounds_w = np.logical_or( + keypoints[0, :, 0] <= 0, keypoints[0, :, 0] >= img_w + ) + out_of_bounds_h = np.logical_or( + keypoints[0, :, 1] <= 0, keypoints[0, :, 1] >= img_h + ) + + # Update keypoints_visible based on the out-of-bounds keypoints + keypoints_visible[0, out_of_bounds_w | out_of_bounds_h] = 0 + keypoints[keypoints_visible == 0] = 0 + + ## remove teeth keypoints + if self.remove_teeth: + # Use numpy's boolean indexing to remove keypoints + mask = 
np.ones(keypoints.shape[1], dtype=bool) + mask[self.teeth_ids] = False + keypoints = keypoints[:, mask, :] + keypoints_visible = keypoints_visible[:, mask] + + # Default bounding box to the full image size + bbox = np.array([0, 0, img_w, img_h], dtype=np.float32).reshape(1, 4) + + if np.any(keypoints_visible): # If any keypoints are visible + visible_keypoints = keypoints[0][ + keypoints_visible[0] == 1 + ] # Filter out the invisible keypoints + + # Get the bounding box encompassing the keypoints + x_min, y_min = np.clip( + np.min(visible_keypoints, axis=0), [0, 0], [img_w, img_h] + ) + x_max, y_max = np.clip( + np.max(visible_keypoints, axis=0), [0, 0], [img_w, img_h] + ) + + bbox = np.array([x_min, y_min, x_max, y_max], dtype=np.float32).reshape( + 1, 4 + ) + + num_keypoints = np.count_nonzero(keypoints_visible) + + ## atleast 8 vis keypoints + if num_keypoints < self.metainfo["min_visible_keypoints"]: + random_idx = np.random.randint(0, len(self.data_list)) + return self.get_data_info(random_idx) + + ## check body keypoints additionally + num_body_keypoints = np.count_nonzero(keypoints_visible[0, :21]) + if num_body_keypoints < 6: + return None + + ## ignore greyscale images for training + B, G, R = cv2.split(img) + if np.array_equal(B, G) and np.array_equal(B, R): + random_idx = np.random.randint(0, len(self.data_list)) + return self.get_data_info(random_idx) + + data_info = { + "img": img, + "img_id": data_info["img_id"], + "img_path": "", + "airstore_id": data_info["airstore_id"], + "bbox": bbox, + "bbox_score": np.ones(1, dtype=np.float32), + "num_keypoints": num_keypoints, + "keypoints": keypoints, ## 1 x 308 x 2 + "keypoints_visible": keypoints_visible, ## 1 x 308 + "iscrowd": 0, + "segmentation": None, + "id": idx, + "category_id": 1, + } + + if idx >= 0: + data_info["sample_idx"] = idx + else: + data_info["sample_idx"] = len(self) + idx + + # Add metainfo items that are required in the pipeline and the model + metainfo_keys = [ + "upper_body_ids", + "lower_body_ids", + "flip_pairs", + "dataset_keypoint_weights", + "flip_indices", + "skeleton_links", + ] + + for key in metainfo_keys: + assert key not in data_info, ( + f'"{key}" is a reserved key for `metainfo`, but already ' + "exists in the `data_info`." + ) + + data_info[key] = deepcopy(self.metainfo[key]) + + return data_info diff --git a/sapiens/pose/src/datasets/keypoints308_goliath_eval_dataset.py b/sapiens/pose/src/datasets/keypoints308_goliath_eval_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..adc4ae6c1777f5c9d44b84ac34adc0eef1cfc567 --- /dev/null +++ b/sapiens/pose/src/datasets/keypoints308_goliath_eval_dataset.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
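
Two training-time sample filters recur in the Goliath and Shutterstock loaders: a sample is redrawn when too few keypoints are labelled, and greyscale frames are skipped. A sketch of the same checks in isolation (the threshold of 8 mirrors the "atleast 8 vis keypoints" comment; the function name is illustrative):

import cv2
import numpy as np

def should_resample(img_bgr, keypoints_visible, min_visible=8):
    # too few labelled keypoints: draw a different sample instead
    if np.count_nonzero(keypoints_visible) < min_visible:
        return True
    # greyscale image (all three channels identical): skipped during training
    b, g, r = cv2.split(img_bgr)
    if np.array_equal(b, g) and np.array_equal(b, r):
        return True
    return False

grey = np.dstack([np.full((4, 4), 128, dtype=np.uint8)] * 3)
print(should_resample(grey, np.ones((1, 308))))  # True -- greyscale frame
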
+ +import copy +import io +import json +import os +from contextlib import redirect_stderr +from copy import deepcopy +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +import cv2 +import numpy as np +from PIL import Image +from sapiens.registry import DATASETS + +from .pose_base_dataset import PoseBaseDataset + + +@DATASETS.register_module() +class Keypoints308GoliathEvalDataset(PoseBaseDataset): + def parse_data_info(self, raw_data_info: dict) -> Optional[dict]: + ann = raw_data_info["raw_ann_info"] + img = raw_data_info["raw_img_info"] + + img_path = os.path.join(self.data_root, img["file_name"]) + img_w, img_h = img["width"], img["height"] + + # get bbox in shape [1, 4], formatted as xywh + x, y, w, h = ann["bbox"] + x1 = np.clip(x, 0, img_w - 1) + y1 = np.clip(y, 0, img_h - 1) + x2 = np.clip(x + w, 0, img_w - 1) + y2 = np.clip(y + h, 0, img_h - 1) + + bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4) + + # keypoints in shape [1, K, 2] and keypoints_visible in [1, K] + _keypoints = np.array(ann["goliath_wholebody_kpts"]).reshape( + 1, -1, 3 + ) ## 1 z 308 x 3 + keypoints = _keypoints[..., :2] + keypoints_visible = np.minimum(1, _keypoints[..., 2] > 0) + + num_keypoints = ann["num_keypoints"] + + data_info = { + "img_id": ann["image_id"], + "img_path": img_path, + "bbox": bbox, + "bbox_score": np.ones(1, dtype=np.float32), + "num_keypoints": num_keypoints, + "keypoints": keypoints, + "keypoints_visible": keypoints_visible, + "iscrowd": ann["iscrowd"], + "segmentation": None, + "id": ann["id"], + "category_id": ann["category_id"], + "raw_ann_info": copy.deepcopy(ann), + } + + return data_info diff --git a/sapiens/pose/src/datasets/keypoints308_shutterstock_dataset.py b/sapiens/pose/src/datasets/keypoints308_shutterstock_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a60b6446950257558b4e73854de2cb6f903c010f --- /dev/null +++ b/sapiens/pose/src/datasets/keypoints308_shutterstock_dataset.py @@ -0,0 +1,210 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
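
parse_data_info converts the COCO-style xywh box into a clipped xyxy box of shape (1, 4); the same conversion reappears in the other eval datasets and in PoseBaseDataset further down. A quick sanity check of the clipping (illustrative helper):

import numpy as np

def xywh_to_clipped_xyxy(bbox, img_w, img_h):
    x, y, w, h = bbox
    x1, y1 = np.clip(x, 0, img_w - 1), np.clip(y, 0, img_h - 1)
    x2, y2 = np.clip(x + w, 0, img_w - 1), np.clip(y + h, 0, img_h - 1)
    return np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4)

print(xywh_to_clipped_xyxy((620, 80, 300, 500), img_w=640, img_h=480))
# [[620.  80. 639. 479.]] -- the box is clamped to the image extent
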
+ +import copy +import io +import json +import os +from contextlib import redirect_stderr +from copy import deepcopy +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +import cv2 +import numpy as np +from PIL import Image +from sapiens.registry import DATASETS + +from .pose_base_dataset import PoseBaseDataset + +with open(os.devnull, "w") as f, redirect_stderr(f): + try: + from care.data.io import typed + except Exception: + pass + + +@DATASETS.register_module() +class Keypoints308ShutterstockDataset(PoseBaseDataset): + METAINFO: dict = dict( + from_file=os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))), + "configs", + "_base_", + "keypoints308.py", + ) + ) + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + self.remove_teeth = self.metainfo["remove_teeth"] + if self.remove_teeth: + self.teeth_ids = self.metainfo["teeth_keypoint_ids"] + + return + + def load_data_list(self) -> List[dict]: + """Load data list from 344 body points.""" + self._register_airstore_handler() + with open(self.ann_file, "rb") as f: + raw = f.read() + raw_data = json.loads(raw) # samples=5,267,269 + + data_list = [] + for i, sample in enumerate(raw_data): + if "sample_id" not in sample: + sample["sample_id"] = sample["airstore_id"] + + dp = { + "airstore_id": sample["sample_id"], + "img_id": i, + } + if sample.get("box-default") is not None: + dp["box"] = sample["box-default"] + data_list.append(dp) + return data_list + + def _register_airstore_handler(self) -> None: + from typedio.file_system.airstore_client import register_airstore_in_fsspec + + register_airstore_in_fsspec() + self.path_template = ( + "airstoreds://rlr_detection_services_ml_datasets_no_user_data" + ) + self.airstore = True + + def _read_from_airstore(self, asset: str, sid: str) -> io.BytesIO: + with typed.open(self.path_template + f"/{asset}?sampleId={sid}").open() as f: + data = io.BytesIO(f.read()) + return data + + def get_data_info(self, idx): + data_info = copy.deepcopy(self.data_list[idx]) + + try: + img = Image.open( + self._read_from_airstore("image", data_info["airstore_id"]) + ) ## pillow image + keypoints_np = np.load( + self._read_from_airstore("keypoint", data_info["airstore_id"]) + ) # shape 3 x 344 + except Exception as e: + print(f"Error loading data: {e}") + return None + + img = np.array(img) ## RGB image + img = img[ + :, :, ::-1 + ] # Convert RGB to BGR, the model preprocessor will convert this to rgb again + + img_w, img_h = img.shape[1], img.shape[0] + + # process keypoints + keypoints = keypoints_np[:2].T.reshape(1, -1, 2) # shape 1 x 344 x 2 + keypoints_visible = np.where(keypoints_np[2].T > 0, 1, 0).reshape( + 1, -1 + ) # shape 1 x 344 + + # Identify keypoints that are out of bounds for x (width) and y (height) + out_of_bounds_w = np.logical_or( + keypoints[0, :, 0] <= 0, keypoints[0, :, 0] >= img_w + ) + out_of_bounds_h = np.logical_or( + keypoints[0, :, 1] <= 0, keypoints[0, :, 1] >= img_h + ) + + # Update keypoints_visible based on the out-of-bounds keypoints + keypoints_visible[0, out_of_bounds_w | out_of_bounds_h] = 0 # shape 1 x 344 + keypoints[keypoints_visible == 0] = 0 + + ## remove teeth keypoints + if self.remove_teeth: + # Use numpy's boolean indexing to remove keypoints + mask = np.ones(keypoints.shape[1], dtype=bool) + mask[self.teeth_ids] = False + keypoints = keypoints[:, mask, :] + keypoints_visible = keypoints_visible[:, mask] + + # Default bounding box to the full image size + bbox = np.array([0, 0, img_w, img_h], 
dtype=np.float32).reshape(1, 4) + + if np.any(keypoints_visible): # If any keypoints are visible + visible_keypoints = keypoints[0][ + keypoints_visible[0] == 1 + ] # Filter out the invisible keypoints + + # Get the bounding box encompassing the keypoints + x_min, y_min = np.clip( + np.min(visible_keypoints, axis=0), [0, 0], [img_w, img_h] + ) + x_max, y_max = np.clip( + np.max(visible_keypoints, axis=0), [0, 0], [img_w, img_h] + ) + + bbox = np.array([x_min, y_min, x_max, y_max], dtype=np.float32).reshape( + 1, 4 + ) + + num_keypoints = np.count_nonzero(keypoints_visible) + + ## atleast 8 vis keypoints + if num_keypoints < self.metainfo["min_visible_keypoints"]: + random_idx = np.random.randint(0, len(self.data_list)) + return self.get_data_info(random_idx) + + ## check body keypoints additionally + num_body_keypoints = np.count_nonzero(keypoints_visible[0, :21]) + if num_body_keypoints < 6: + return None + + ## ignore greyscale images for training + B, G, R = cv2.split(img) + if np.array_equal(B, G) and np.array_equal(B, R): + random_idx = np.random.randint(0, len(self.data_list)) + return self.get_data_info(random_idx) + + data_info = { + "img": img, + "img_id": data_info["img_id"], + "img_path": "", + "airstore_id": data_info["airstore_id"], + "bbox": bbox, + "bbox_score": np.ones(1, dtype=np.float32), + "num_keypoints": num_keypoints, + "keypoints": keypoints, + "keypoints_visible": keypoints_visible, + "iscrowd": 0, + "segmentation": None, + "id": idx, + "category_id": 1, + } + + # Some codebase needs `sample_idx` of data information. Here we convert + # the idx to a positive number and save it in data information. + if idx >= 0: + data_info["sample_idx"] = idx + else: + data_info["sample_idx"] = len(self) + idx + + # Add metainfo items that are required in the pipeline and the model + metainfo_keys = [ + "upper_body_ids", + "lower_body_ids", + "flip_pairs", + "dataset_keypoint_weights", + "flip_indices", + "skeleton_links", + ] + + for key in metainfo_keys: + assert key not in data_info, ( + f'"{key}" is a reserved key for `metainfo`, but already ' + "exists in the `data_info`." + ) + + data_info[key] = deepcopy(self.metainfo[key]) + + return data_info diff --git a/sapiens/pose/src/datasets/keypoints308_shutterstock_eval_dataset.py b/sapiens/pose/src/datasets/keypoints308_shutterstock_eval_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..07bc23ed8361110ffa439db3db0aa60de56cad2d --- /dev/null +++ b/sapiens/pose/src/datasets/keypoints308_shutterstock_eval_dataset.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
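
The raw airstore annotation is a 3 x 344 array (rows x, y, confidence; one column per keypoint); the loader transposes it into the (1, K, 2) keypoint and (1, K) visibility layout used everywhere else. Illustrative reshaping on dummy data:

import numpy as np

kpts_raw = np.zeros((3, 344), dtype=np.float32)   # rows: x, y, confidence
kpts_raw[:, 0] = [120.0, 95.5, 0.9]               # keypoint 0: labelled
kpts_raw[:, 1] = [64.0, 48.0, 0.0]                # keypoint 1: unlabelled (confidence 0)

keypoints = kpts_raw[:2].T.reshape(1, -1, 2)                          # (1, 344, 2)
keypoints_visible = np.where(kpts_raw[2].T > 0, 1, 0).reshape(1, -1)  # (1, 344)
print(keypoints.shape, keypoints_visible[0, :2])  # (1, 344, 2) [1 0]
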
+ +import copy +import io +import json +import os +from contextlib import redirect_stderr +from copy import deepcopy +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +import cv2 +import numpy as np +from PIL import Image +from sapiens.registry import DATASETS + +from .pose_base_dataset import PoseBaseDataset + + +@DATASETS.register_module() +class Keypoints308ShutterstockEvalDataset(PoseBaseDataset): + def parse_data_info(self, raw_data_info: dict) -> Optional[dict]: + ann = raw_data_info["raw_ann_info"] + img = raw_data_info["raw_img_info"] + + img_path = os.path.join(self.data_root, img["file_name"]) + img_w, img_h = img["width"], img["height"] + + # get bbox in shape [1, 4], formatted as xywh + x, y, w, h = ann["bbox"] + x1 = np.clip(x, 0, img_w - 1) + y1 = np.clip(y, 0, img_h - 1) + x2 = np.clip(x + w, 0, img_w - 1) + y2 = np.clip(y + h, 0, img_h - 1) + + bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4) + + # keypoints in shape [1, K, 2] and keypoints_visible in [1, K] + _keypoints = np.array(ann["goliath_wholebody_kpts"]).reshape( + 1, -1, 3 + ) ## 1 z 308 x 3 + keypoints = _keypoints[..., :2] + keypoints_visible = np.minimum(1, _keypoints[..., 2] > 0) + + num_keypoints = ann["num_keypoints"] + + data_info = { + "img_id": ann["image_id"], + "img_path": img_path, + "bbox": bbox, + "bbox_score": np.ones(1, dtype=np.float32), + "num_keypoints": num_keypoints, + "keypoints": keypoints, + "keypoints_visible": keypoints_visible, + "iscrowd": ann["iscrowd"], + "segmentation": None, + "id": ann["id"], + "category_id": ann["category_id"], + "raw_ann_info": copy.deepcopy(ann), + } + + return data_info diff --git a/sapiens/pose/src/datasets/pose_base_dataset.py b/sapiens/pose/src/datasets/pose_base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a78f557ecceeffdd13c3f2231141c0c3b0393cee --- /dev/null +++ b/sapiens/pose/src/datasets/pose_base_dataset.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import json +import os +from copy import deepcopy +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +import cv2 +import numpy as np +from sapiens.engine.datasets import BaseDataset +from sapiens.registry import DATASETS + +from .utils import parse_pose_metainfo + + +@DATASETS.register_module() +class PoseBaseDataset(BaseDataset): + METAINFO: dict = dict(from_file="configs/_base_/keypoints308.py") + + def __init__( + self, + ann_file: str = "", + num_samples: int = None, + bbox_file: Optional[str] = None, + **kwargs, + ): + self.bbox_file = bbox_file + self.ann_file = ann_file + self.num_samples = num_samples + + self.metainfo = parse_pose_metainfo(self.METAINFO) + super().__init__(**kwargs) + + if self.num_samples is not None: + self.data_list = self.data_list[:num_samples] + + print( + "\033[96mLoaded {} samples for {}, Test mode: {}\033[0m".format( + self.__len__(), self.__class__.__name__, self.test_mode + ) + ) + return + + def prepare_data(self, idx) -> Any: + data_info = self.get_data_info(idx) + transformed_data_info = self.pipeline(data_info) + + if transformed_data_info is None: + return None + + ## pipeline is set to empty when using concatenation of datasets. 
+ if ( + self.test_mode == False + and "data_samples" in transformed_data_info + and "gt_instance_labels" in transformed_data_info["data_samples"] + and "keypoints_visible" + in transformed_data_info["data_samples"].gt_instance_labels + ): + num_transformed_keypoints = ( + transformed_data_info["data_samples"] + .gt_instance_labels["keypoints_visible"] + .sum() + .item() + ) ## after cropping + + ## minimum visible keypoints for coco_wholebody is 8 + if self.metainfo["dataset_name"] == "coco_wholebody": + if num_transformed_keypoints < 8: + return None + + ## absolute minimum visible keypoints is 3 + if num_transformed_keypoints < 3: + return None + + return transformed_data_info + + def get_data_info(self, idx: int) -> dict: + data_info = super().get_data_info(idx) + data_info["img"] = cv2.imread(data_info["img_path"]) + + # Add metainfo items that are required in the pipeline and the model + metainfo_keys = [ + "upper_body_ids", + "lower_body_ids", + "flip_pairs", + "dataset_keypoint_weights", + "flip_indices", + "skeleton_links", + ] + + for key in metainfo_keys: + assert key not in data_info, ( + f'"{key}" is a reserved key for `metainfo`, but already ' + "exists in the `data_info`." + ) + + data_info[key] = deepcopy(self.metainfo[key]) + + return data_info + + def load_data_list(self) -> List[dict]: + if self.bbox_file: + data_list = self._load_detection_results() + else: + instance_list, _ = self._load_annotations() + data_list = self._get_topdown_data_infos(instance_list) + + return data_list + + def _load_annotations(self) -> Tuple[List[dict], List[dict]]: + from xtcocotools.coco import COCO # lazy: only needed for COCO-format ann files + + assert os.path.exists(self.ann_file), "Annotation file does not exist" + self.coco = COCO(self.ann_file) + self.metainfo["CLASSES"] = self.coco.loadCats(self.coco.getCatIds()) + + instance_list = [] + image_list = [] + + for img_id in self.coco.getImgIds(): + img = self.coco.loadImgs(img_id)[0] + img.update( + { + "img_id": img_id, + "img_path": os.path.join(self.data_root, img["file_name"]), + } + ) + image_list.append(img) + ann_ids = self.coco.getAnnIds(imgIds=img_id) + for ann in self.coco.loadAnns(ann_ids): + instance_info = self.parse_data_info( + dict(raw_ann_info=ann, raw_img_info=img) + ) + + # skip invalid instance annotation. 
+ if not instance_info: + continue + + instance_list.append(instance_info) + return instance_list, image_list + + def parse_data_info(self, raw_data_info: dict) -> Optional[dict]: + ann = raw_data_info["raw_ann_info"] + img = raw_data_info["raw_img_info"] + + # filter invalid instance + if "bbox" not in ann or "keypoints" not in ann: + return None + + img_w, img_h = img["width"], img["height"] + + # get bbox in shape [1, 4], formatted as xywh + x, y, w, h = ann["bbox"] + x1 = np.clip(x, 0, img_w - 1) + y1 = np.clip(y, 0, img_h - 1) + x2 = np.clip(x + w, 0, img_w - 1) + y2 = np.clip(y + h, 0, img_h - 1) + + bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4) + + # keypoints in shape [1, K, 2] and keypoints_visible in [1, K] + _keypoints = np.array(ann["keypoints"], dtype=np.float32).reshape(1, -1, 3) + keypoints = _keypoints[..., :2] + keypoints_visible = np.minimum(1, _keypoints[..., 2]) + + if "num_keypoints" in ann: + num_keypoints = ann["num_keypoints"] + else: + num_keypoints = np.count_nonzero(keypoints.max(axis=2)) + + data_info = { + "img_id": ann["image_id"], + "img_path": img["img_path"], + "bbox": bbox, + "bbox_score": np.ones(1, dtype=np.float32), + "num_keypoints": num_keypoints, + "keypoints": keypoints, + "keypoints_visible": keypoints_visible, + "iscrowd": ann.get("iscrowd", 0), + "segmentation": ann.get("segmentation", None), + "id": ann["id"], + "category_id": ann["category_id"], + "raw_ann_info": copy.deepcopy(ann), + } + + if "crowdIndex" in img: + data_info["crowd_index"] = img["crowdIndex"] + + return data_info + + @staticmethod + def _is_valid_instance(data_info: Dict) -> bool: + # crowd annotation + if "iscrowd" in data_info and data_info["iscrowd"]: + return False + # invalid keypoints + if "num_keypoints" in data_info and data_info["num_keypoints"] == 0: + return False + # invalid bbox + if "bbox" in data_info: + bbox = data_info["bbox"][0] + w, h = bbox[2:4] - bbox[:2] + if w <= 0 or h <= 0: + return False + # invalid keypoints + if "keypoints" in data_info: + if np.max(data_info["keypoints"]) <= 0: + return False + return True + + def _get_topdown_data_infos(self, instance_list: List[Dict]) -> List[Dict]: + data_list_tp = list(filter(self._is_valid_instance, instance_list)) + return data_list_tp + + def _load_detection_results(self) -> List[dict]: + raise NotImplementedError diff --git a/sapiens/pose/src/datasets/transforms/__init__.py b/sapiens/pose/src/datasets/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9858bdfa443f0357093d18c41f147cdf936db351 --- /dev/null +++ b/sapiens/pose/src/datasets/transforms/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .bbox_transforms import * +from .pose_transforms import * diff --git a/sapiens/pose/src/datasets/transforms/bbox_transforms.py b/sapiens/pose/src/datasets/transforms/bbox_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..e5b94baf3ba4cb3c7582de0fe937508d435dce25 --- /dev/null +++ b/sapiens/pose/src/datasets/transforms/bbox_transforms.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
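
For COCO-format annotation files, _get_topdown_data_infos keeps only instances that pass _is_valid_instance. A self-contained mirror of that filter, simplified to assume all keys are present (the real staticmethod also tolerates missing keys):

import numpy as np

def is_valid_instance(info):
    if info["iscrowd"]:                      # crowd annotations are dropped
        return False
    if info["num_keypoints"] == 0:           # no labelled keypoints
        return False
    w, h = info["bbox"][0, 2:4] - info["bbox"][0, :2]
    if w <= 0 or h <= 0:                     # degenerate box
        return False
    if np.max(info["keypoints"]) <= 0:       # all-zero keypoints
        return False
    return True

good = dict(iscrowd=0, num_keypoints=5,
            bbox=np.array([[10.0, 10.0, 50.0, 90.0]]), keypoints=np.ones((1, 17, 2)))
bad = dict(good, num_keypoints=0, keypoints=np.zeros((1, 17, 2)))
print(is_valid_instance(good), is_valid_instance(bad))  # True False
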
+ +import math +from typing import Tuple + +import cv2 +import numpy as np + + +def bbox_xyxy2xywh(bbox_xyxy: np.ndarray) -> np.ndarray: + bbox_xywh = bbox_xyxy.copy() + bbox_xywh[:, 2] = bbox_xywh[:, 2] - bbox_xywh[:, 0] + bbox_xywh[:, 3] = bbox_xywh[:, 3] - bbox_xywh[:, 1] + + return bbox_xywh + + +def bbox_xywh2xyxy(bbox_xywh: np.ndarray) -> np.ndarray: + bbox_xyxy = bbox_xywh.copy() + bbox_xyxy[:, 2] = bbox_xyxy[:, 2] + bbox_xyxy[:, 0] + bbox_xyxy[:, 3] = bbox_xyxy[:, 3] + bbox_xyxy[:, 1] + + return bbox_xyxy + + +def bbox_xyxy2cs( + bbox: np.ndarray, padding: float = 1.0 +) -> Tuple[np.ndarray, np.ndarray]: + # convert single bbox from (4, ) to (1, 4) + dim = bbox.ndim + if dim == 1: + bbox = bbox[None, :] + + x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3]) + center = np.hstack([x1 + x2, y1 + y2]) * 0.5 + scale = np.hstack([x2 - x1, y2 - y1]) * padding + + if dim == 1: + center = center[0] + scale = scale[0] + + return center, scale + + +def bbox_xywh2cs( + bbox: np.ndarray, padding: float = 1.0 +) -> Tuple[np.ndarray, np.ndarray]: + # convert single bbox from (4, ) to (1, 4) + dim = bbox.ndim + if dim == 1: + bbox = bbox[None, :] + + x, y, w, h = np.hsplit(bbox, [1, 2, 3]) + center = np.hstack([x + w * 0.5, y + h * 0.5]) + scale = np.hstack([w, h]) * padding + + if dim == 1: + center = center[0] + scale = scale[0] + + return center, scale + + +def bbox_cs2xyxy( + center: np.ndarray, scale: np.ndarray, padding: float = 1.0 +) -> np.ndarray: + dim = center.ndim + assert scale.ndim == dim + + if dim == 1: + center = center[None, :] + scale = scale[None, :] + + wh = scale / padding + xy = center - 0.5 * wh + bbox = np.hstack((xy, xy + wh)) + + if dim == 1: + bbox = bbox[0] + + return bbox + + +def bbox_cs2xywh( + center: np.ndarray, scale: np.ndarray, padding: float = 1.0 +) -> np.ndarray: + dim = center.ndim + assert scale.ndim == dim + + if dim == 1: + center = center[None, :] + scale = scale[None, :] + + wh = scale / padding + xy = center - 0.5 * wh + bbox = np.hstack((xy, wh)) + + if dim == 1: + bbox = bbox[0] + + return bbox + + +def get_udp_warp_matrix( + center: np.ndarray, + scale: np.ndarray, + rot: float, + output_size: Tuple[int, int], +) -> np.ndarray: + assert len(center) == 2 + assert len(scale) == 2 + assert len(output_size) == 2 + + input_size = center * 2 + rot_rad = np.deg2rad(rot) + warp_mat = np.zeros((2, 3), dtype=np.float32) + scale_x = (output_size[0] - 1) / scale[0] + scale_y = (output_size[1] - 1) / scale[1] + warp_mat[0, 0] = math.cos(rot_rad) * scale_x + warp_mat[0, 1] = -math.sin(rot_rad) * scale_x + warp_mat[0, 2] = scale_x * ( + -0.5 * input_size[0] * math.cos(rot_rad) + + 0.5 * input_size[1] * math.sin(rot_rad) + + 0.5 * scale[0] + ) + warp_mat[1, 0] = math.sin(rot_rad) * scale_y + warp_mat[1, 1] = math.cos(rot_rad) * scale_y + warp_mat[1, 2] = scale_y * ( + -0.5 * input_size[0] * math.sin(rot_rad) + - 0.5 * input_size[1] * math.cos(rot_rad) + + 0.5 * scale[1] + ) + return warp_mat + + +def get_warp_matrix( + center: np.ndarray, + scale: np.ndarray, + rot: float, + output_size: Tuple[int, int], + shift: Tuple[float, float] = (0.0, 0.0), + inv: bool = False, +) -> np.ndarray: + assert len(center) == 2 + assert len(scale) == 2 + assert len(output_size) == 2 + assert len(shift) == 2 + + shift = np.array(shift) + src_w = scale[0] + dst_w = output_size[0] + dst_h = output_size[1] + + rot_rad = np.deg2rad(rot) + src_dir = _rotate_point(np.array([0.0, src_w * -0.5]), rot_rad) + dst_dir = np.array([0.0, dst_w * -0.5]) + + src = np.zeros((3, 2), dtype=np.float32) + 
src[0, :] = center + scale * shift + src[1, :] = center + src_dir + scale * shift + src[2, :] = _get_3rd_point(src[0, :], src[1, :]) + + dst = np.zeros((3, 2), dtype=np.float32) + dst[0, :] = [dst_w * 0.5, dst_h * 0.5] + dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir + dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :]) + + if inv: + warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + return warp_mat diff --git a/sapiens/pose/src/datasets/transforms/pose_transforms.py b/sapiens/pose/src/datasets/transforms/pose_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..9c009768868d7a61fce6da05bd7956a2ba3232b4 --- /dev/null +++ b/sapiens/pose/src/datasets/transforms/pose_transforms.py @@ -0,0 +1,684 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import warnings +from copy import deepcopy +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import cv2 +import numpy as np +from sapiens.engine.datasets import BaseTransform, to_tensor +from sapiens.registry import TRANSFORMS +from scipy.stats import truncnorm + +from ..codecs.udp_heatmap import UDPHeatmap +from .bbox_transforms import bbox_xyxy2cs, get_udp_warp_matrix, get_warp_matrix + +try: + warnings.filterwarnings( + "ignore", + message=r"Error fetching version info", + category=UserWarning, + module=r"^albumentations\.check_version$", + ) + + import albumentations + +except ImportError: + albumentations = None + +Number = Union[int, float] + + +@TRANSFORMS.register_module() +class PoseGenerateTarget(BaseTransform): + def __init__( + self, + encoder: None, + multilevel: bool = False, + use_dataset_keypoint_weights: bool = False, + ) -> None: + super().__init__() + self.encoder_cfg = deepcopy(encoder) + self.multilevel = multilevel + self.use_dataset_keypoint_weights = use_dataset_keypoint_weights + encoder_type = self.encoder_cfg.pop("type") + assert encoder_type == "UDPHeatmap", "Only UDPHeatmap is supported" + self.encoder = UDPHeatmap(**self.encoder_cfg) + + def transform(self, results: Dict) -> Optional[dict]: + if results.get("transformed_keypoints", None) is not None: + keypoints = results["transformed_keypoints"] ## N x K x 2 + elif results.get("keypoints", None) is not None: + keypoints = results["keypoints"] + else: + raise ValueError( + "GenerateTarget requires 'transformed_keypoints' or" + " 'keypoints' in the results." 
+ ) + + keypoints_visible = results["keypoints_visible"] ## N x K + + auxiliary_encode_kwargs = { + key: results[key] for key in self.encoder.auxiliary_encode_keys + } + encoded = self.encoder.encode( + keypoints=keypoints, + keypoints_visible=keypoints_visible, + **auxiliary_encode_kwargs, + ) + + if self.use_dataset_keypoint_weights and "keypoint_weights" in encoded: + if isinstance(encoded["keypoint_weights"], list): + for w in encoded["keypoint_weights"]: + w *= results["dataset_keypoint_weights"] + else: + encoded["keypoint_weights"] *= results["dataset_keypoint_weights"] + + results.update(encoded) + + if results.get("keypoint_weights", None) is not None: + results["transformed_keypoints_visible"] = results["keypoint_weights"] + elif results.get("keypoints", None) is not None: + results["transformed_keypoints_visible"] = results["keypoints_visible"] + else: + raise ValueError( + "GenerateTarget requires 'keypoint_weights' or" + " 'keypoints_visible' in the results." + ) + + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f"(encoder={str(self.encoder_cfg)}, " + repr_str += f"use_dataset_keypoint_weights={self.use_dataset_keypoint_weights})" + return repr_str + + +@TRANSFORMS.register_module() +class PoseTopdownAffine(BaseTransform): + def __init__(self, input_size: Tuple[int, int], use_udp: bool = True) -> None: + super().__init__() + + assert len(input_size) == 2, f"Invalid input_size {input_size}" + + self.input_size = input_size + self.use_udp = use_udp + + @staticmethod + def _fix_aspect_ratio(bbox_scale: np.ndarray, aspect_ratio: float): + w, h = np.hsplit(bbox_scale, [1]) + bbox_scale = np.where( + w > h * aspect_ratio, + np.hstack([w, w / aspect_ratio]), + np.hstack([h * aspect_ratio, h]), + ) + return bbox_scale + + def transform(self, results: Dict) -> Optional[dict]: + w, h = self.input_size + warp_size = (int(w), int(h)) + + # reshape bbox to fixed aspect ratio + results["bbox_scale"] = self._fix_aspect_ratio( + results["bbox_scale"], aspect_ratio=w / h + ) + + assert results["bbox_center"].shape[0] == 1, ( + "Top-down heatmap only supports single instance. Got invalid " + f"shape of bbox_center {results['bbox_center'].shape}." 
+            )
+
+        center = results["bbox_center"][0]
+        scale = results["bbox_scale"][0]
+        if "bbox_rotation" in results:
+            rot = results["bbox_rotation"][0]
+        else:
+            rot = 0.0
+
+        if self.use_udp:
+            warp_mat = get_udp_warp_matrix(center, scale, rot, output_size=(w, h))
+        else:
+            warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
+
+        # estimate overall scale from the affine matrix
+        sx = np.linalg.norm(warp_mat[0, :2])
+        sy = np.linalg.norm(warp_mat[1, :2])
+        scale_factor = min(sx, sy)
+
+        # choose interpolation: area for downscaling, cubic for upscaling
+        interp = cv2.INTER_AREA if scale_factor < 1.0 else cv2.INTER_CUBIC
+
+        results["img"] = cv2.warpAffine(
+            results["img"], warp_mat, warp_size, flags=interp
+        )  ## H x W x 3
+
+        if results.get("keypoints", None) is not None:
+            transformed_keypoints = results["keypoints"].copy()
+            # Only transform (x, y) coordinates
+            transformed_keypoints[..., :2] = cv2.transform(
+                results["keypoints"][..., :2], warp_mat
+            )
+
+            ## if transformed_keypoints are out of bounds, set them to zero
+            out_of_bounds = (
+                (transformed_keypoints[..., 0] < 0)
+                | (transformed_keypoints[..., 0] >= w)
+                | (transformed_keypoints[..., 1] < 0)
+                | (transformed_keypoints[..., 1] >= h)
+            )  ## N x K
+
+            transformed_keypoints[out_of_bounds] = 0  # mask out-of-bound keypoints
+            results["transformed_keypoints"] = transformed_keypoints
+
+            ## set the visibility of out-of-bound keypoints to 0
+            results["keypoints_visible"] = results["keypoints_visible"].copy()
+            results["keypoints_visible"][out_of_bounds] = 0
+
+        results["input_size"] = (w, h)
+
+        return results
+
+    def __repr__(self) -> str:
+        repr_str = self.__class__.__name__
+        repr_str += f"(input_size={self.input_size}, "
+        repr_str += f"use_udp={self.use_udp})"
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class PoseGetBBoxCenterScale(BaseTransform):
+    def __init__(self, padding: float = 1.25) -> None:
+        super().__init__()
+        self.padding = padding
+
+    def transform(self, results: Dict) -> Optional[dict]:
+        if "bbox_center" in results and "bbox_scale" in results:
+            warnings.warn(
+                'Use the existing "bbox_center" and "bbox_scale"'
+                ". The padding will still be applied."
+            )
+            results["bbox_scale"] *= self.padding
+
+        else:
+            bbox = results["bbox"]
+            center, scale = bbox_xyxy2cs(bbox, padding=self.padding)
+
+            results["bbox_center"] = center
+            results["bbox_scale"] = scale
+
+        return results
+
+    def __repr__(self) -> str:
+        repr_str = self.__class__.__name__ + f"(padding={self.padding})"
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class PoseRandomFlip(BaseTransform):
+    def __init__(
+        self,
+        prob: Union[float, List[float]] = 0.5,
+        direction: str = "horizontal",
+    ) -> None:
+        if isinstance(prob, list):
+            assert all(isinstance(p, float) for p in prob)
+            assert 0 <= sum(prob) <= 1
+        elif isinstance(prob, float):
+            assert 0 <= prob <= 1
+        else:
+            raise ValueError(
+                f"prob must be a float or a list of floats, but got `{type(prob)}`."
+            )
+        self.prob = prob
+        self.direction = direction
+
+    def flip_bbox(
+        self,
+        bbox: np.ndarray,
+        image_size: Tuple[int, int],
+        bbox_format: str = "xyxy",
+        direction: str = "horizontal",
+    ) -> np.ndarray:
+        format_options = {"xywh", "xyxy", "center"}
+        assert bbox_format in format_options, (
+            f'Invalid bbox format "{bbox_format}". 
Options are {format_options}' + ) + + bbox_flipped = bbox.copy() + w, h = image_size + + if direction == "horizontal": + if bbox_format == "xywh" or bbox_format == "center": + bbox_flipped[..., 0] = w - bbox[..., 0] - 1 + elif bbox_format == "xyxy": + bbox_flipped[..., ::2] = w - bbox[..., ::2] - 1 + elif direction == "vertical": + if bbox_format == "xywh" or bbox_format == "center": + bbox_flipped[..., 1] = h - bbox[..., 1] - 1 + elif bbox_format == "xyxy": + bbox_flipped[..., 1::2] = h - bbox[..., 1::2] - 1 + elif direction == "diagonal": + if bbox_format == "xywh" or bbox_format == "center": + bbox_flipped[..., :2] = [w, h] - bbox[..., :2] - 1 + elif bbox_format == "xyxy": + bbox_flipped[...] = [w, h, w, h] - bbox - 1 + + return bbox_flipped + + def flip_keypoints( + self, + keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray], + image_size: Tuple[int, int], + flip_indices: List[int], + direction: str = "horizontal", + ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + assert keypoints.shape[:-1] == keypoints_visible.shape, ( + f"Mismatched shapes of keypoints {keypoints.shape} and " + f"keypoints_visible {keypoints_visible.shape}" + ) + + direction_options = {"horizontal"} + assert direction in direction_options, ( + f'Invalid flipping direction "{direction}". Options are {direction_options}' + ) + + # swap the symmetric keypoint pairs + if direction == "horizontal" or direction == "vertical": + keypoints = keypoints[..., flip_indices, :] + if keypoints_visible is not None: + keypoints_visible = keypoints_visible[..., flip_indices] + + # flip the keypoints + w, h = image_size + if direction == "horizontal": + keypoints[..., 0] = w - 1 - keypoints[..., 0] + elif direction == "vertical": + keypoints[..., 1] = h - 1 - keypoints[..., 1] + else: + keypoints = [w, h] - keypoints - 1 + + return keypoints, keypoints_visible + + def transform(self, results: dict) -> dict: + if np.random.rand() > self.prob: + results["flip"] = False + results["flip_direction"] = "" + return results + + flip_dir = "horizontal" + results["flip"] = True + results["flip_direction"] = flip_dir + + h, w = results["img"].shape[:2] + results["img"] = cv2.flip(results["img"], 1) # horizontal flip + + # flip bboxes + if results.get("bbox", None) is not None: + results["bbox"] = self.flip_bbox( + results["bbox"], + image_size=(w, h), + bbox_format="xyxy", + direction=flip_dir, + ) + + if results.get("bbox_center", None) is not None: + results["bbox_center"] = self.flip_bbox( + results["bbox_center"], + image_size=(w, h), + bbox_format="center", + direction=flip_dir, + ) + + # flip keypoints + if results.get("keypoints", None) is not None: + keypoints, keypoints_visible = self.flip_keypoints( + results["keypoints"], + results.get("keypoints_visible", None), + image_size=(w, h), + flip_indices=results["flip_indices"], + direction=flip_dir, + ) + + results["keypoints"] = keypoints + results["keypoints_visible"] = keypoints_visible + + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f"(prob={self.prob}, " + repr_str += f"direction={self.direction})" + return repr_str + + +@TRANSFORMS.register_module() +class PoseRandomHalfBody(BaseTransform): + def __init__( + self, + min_total_keypoints: int = 9, + min_upper_keypoints: int = 2, + min_lower_keypoints: int = 3, + padding: float = 1.5, + prob: float = 0.3, + upper_prioritized_prob: float = 0.7, + ) -> None: + super().__init__() + self.min_total_keypoints = min_total_keypoints + self.min_upper_keypoints = 
min_upper_keypoints + self.min_lower_keypoints = min_lower_keypoints + self.padding = padding + self.prob = prob + self.upper_prioritized_prob = upper_prioritized_prob + + def _get_half_body_bbox( + self, keypoints: np.ndarray, half_body_ids: List[int] + ) -> Tuple[np.ndarray, np.ndarray]: + selected_keypoints = keypoints[half_body_ids] + center = selected_keypoints.mean(axis=0)[:2] + x1, y1 = selected_keypoints.min(axis=0) + x2, y2 = selected_keypoints.max(axis=0) + w = x2 - x1 + h = y2 - y1 + scale = np.array([w, h], dtype=center.dtype) * self.padding + return center, scale + + def _random_select_half_body( + self, + keypoints_visible: np.ndarray, + upper_body_ids: List[int], + lower_body_ids: List[int], + ) -> List[Optional[List[int]]]: + half_body_ids = [] + + for visible in keypoints_visible: + if visible.sum() < self.min_total_keypoints: + indices = None + elif np.random.rand() > self.prob: + indices = None + else: + upper_valid_ids = [i for i in upper_body_ids if visible[i] > 0] + lower_valid_ids = [i for i in lower_body_ids if visible[i] > 0] + + num_upper = len(upper_valid_ids) + num_lower = len(lower_valid_ids) + + prefer_upper = np.random.rand() < self.upper_prioritized_prob + if ( + num_upper < self.min_upper_keypoints + and num_lower < self.min_lower_keypoints + ): + indices = None + elif num_lower < self.min_lower_keypoints: + indices = upper_valid_ids + elif num_upper < self.min_upper_keypoints: + indices = lower_valid_ids + else: + indices = upper_valid_ids if prefer_upper else lower_valid_ids + + half_body_ids.append(indices) + + return half_body_ids + + def transform(self, results: Dict) -> Optional[dict]: + half_body_ids = self._random_select_half_body( + keypoints_visible=results["keypoints_visible"], + upper_body_ids=results["upper_body_ids"], + lower_body_ids=results["lower_body_ids"], + ) + + bbox_center = [] + bbox_scale = [] + + for i, indices in enumerate(half_body_ids): + if indices is None: + bbox_center.append(results["bbox_center"][i]) + bbox_scale.append(results["bbox_scale"][i]) + else: + _center, _scale = self._get_half_body_bbox( + results["keypoints"][i], indices + ) + bbox_center.append(_center) + bbox_scale.append(_scale) + + results["bbox_center"] = np.stack(bbox_center) + results["bbox_scale"] = np.stack(bbox_scale) + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f"(min_total_keypoints={self.min_total_keypoints}, " + repr_str += f"min_upper_keypoints={self.min_upper_keypoints}, " + repr_str += f"min_lower_keypoints={self.min_lower_keypoints}, " + repr_str += f"padding={self.padding}, " + repr_str += f"prob={self.prob}, " + repr_str += f"upper_prioritized_prob={self.upper_prioritized_prob})" + return repr_str + + +@TRANSFORMS.register_module() +class PoseRandomBBoxTransform(BaseTransform): + def __init__( + self, + shift_factor: float = 0.16, + shift_prob: float = 0.3, + scale_factor: Tuple[float, float] = (0.5, 1.5), + scale_prob: float = 1.0, + rotate_factor: float = 80.0, + rotate_prob: float = 0.6, + ) -> None: + super().__init__() + + self.shift_factor = shift_factor + self.shift_prob = shift_prob + self.scale_factor = scale_factor + self.scale_prob = scale_prob + self.rotate_factor = rotate_factor + self.rotate_prob = rotate_prob + + @staticmethod + def _truncnorm( + low: float = -1.0, high: float = 1.0, size: tuple = () + ) -> np.ndarray: + """Sample from a truncated normal distribution.""" + return truncnorm.rvs(low, high, size=size).astype(np.float32) + + def _get_transform_params(self, 
num_bboxes: int) -> Tuple: + # Get shift parameters + offset = self._truncnorm(size=(num_bboxes, 2)) * self.shift_factor + offset = np.where(np.random.rand(num_bboxes, 1) < self.shift_prob, offset, 0.0) + + # Get scaling parameters + scale_min, scale_max = self.scale_factor + mu = (scale_max + scale_min) * 0.5 + sigma = (scale_max - scale_min) * 0.5 + scale = self._truncnorm(size=(num_bboxes, 1)) * sigma + mu + scale = np.where(np.random.rand(num_bboxes, 1) < self.scale_prob, scale, 1.0) + + # Get rotation parameters + rotate = self._truncnorm(size=(num_bboxes,)) * self.rotate_factor + rotate = np.where(np.random.rand(num_bboxes) < self.rotate_prob, rotate, 0.0) + + return offset, scale, rotate + + def transform(self, results: Dict) -> Optional[dict]: + bbox_scale = results["bbox_scale"] + num_bboxes = bbox_scale.shape[0] + + offset, scale, rotate = self._get_transform_params(num_bboxes) + + results["bbox_center"] += offset * bbox_scale + results["bbox_scale"] *= scale + results["bbox_rotation"] = rotate + + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f"(shift_prob={self.shift_prob}, " + repr_str += f"shift_factor={self.shift_factor}, " + repr_str += f"scale_prob={self.scale_prob}, " + repr_str += f"scale_factor={self.scale_factor}, " + repr_str += f"rotate_prob={self.rotate_prob}, " + repr_str += f"rotate_factor={self.rotate_factor})" + return repr_str + + +@TRANSFORMS.register_module() +class PoseAlbumentation(BaseTransform): + def __init__(self, transforms: List[dict], keymap: Optional[dict] = None) -> None: + if albumentations is None: + raise RuntimeError("albumentations is not installed") + self.transforms = transforms + self.aug = albumentations.Compose( + [self.albu_builder(t) for t in self.transforms] + ) + + if not keymap: + self.keymap_to_albu = { + "img": "image", + } + else: + self.keymap_to_albu = keymap + + def albu_builder(self, cfg: dict) -> albumentations: + assert isinstance(cfg, dict) and "type" in cfg + args = cfg.copy() + + obj_type = args.pop("type") + if isinstance(obj_type, str): + if albumentations is None: + raise RuntimeError("albumentations is not installed") + try: + from torch.distributed import get_rank + + rank = get_rank() + except (ImportError, RuntimeError): + rank = 0 + obj_cls = getattr(albumentations, obj_type) + elif isinstance(obj_type, type): + obj_cls = obj_type + else: + raise TypeError(f"type must be a str, but got {type(obj_type)}") + + if "transforms" in args: + args["transforms"] = [ + self.albu_builder(transform) for transform in args["transforms"] + ] + return obj_cls(**args) + + def transform(self, results: dict) -> dict: + results_albu = {} + for k, v in self.keymap_to_albu.items(): + assert k in results, ( + f"The `{k}` is required to perform albumentations transforms" + ) + results_albu[v] = results[k] + + # Apply albumentations transforms + results_albu = self.aug(**results_albu) + + # map the albu results back to the original format + for k, v in self.keymap_to_albu.items(): + results[k] = results_albu[v] + + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + f"(transforms={self.transforms})" + return repr_str + + +@TRANSFORMS.register_module() +class PosePackInputs(BaseTransform): + def __init__( + self, + meta_keys=( + "id", + "img_id", + "img_path", + "category_id", + "crowd_index", + "ori_shape", + "img_shape", + "input_size", + "input_center", + "input_scale", + "bbox_center", + "bbox_scale", + "bbox_score", + "flip", + "flip_direction", + 
"flip_indices", + "raw_ann_info", + ), + pack_transformed=False, + ): + self.meta_keys = meta_keys + self.pack_transformed = pack_transformed + + def transform(self, results: dict) -> dict: + packed_results = dict() + if "img" in results: + img = results["img"] + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + if not img.flags.c_contiguous: + img = to_tensor(np.ascontiguousarray(img.transpose(2, 0, 1))) + else: + img = img.transpose(2, 0, 1) + img = to_tensor(img).contiguous() + packed_results["inputs"] = img + + data_sample = dict() + if "keypoints" in results: + keypoints = results["keypoints"].astype(np.float32) + keypoints_visible = results["keypoints_visible"].astype(np.float32) + data_sample["keypoints"] = keypoints + data_sample["keypoints_visible"] = keypoints_visible + + ## update keypoints weights with if keypoints within bounds + if "keypoint_weights" in results and "transformed_keypoints" in results: + transformed_keypoints = results["transformed_keypoints"] # 1 x K x 3 + h, w = img.shape[1:] + + keypoints_in_bounds = ( + keypoints_visible + * (transformed_keypoints[..., 0] >= 0) + * (transformed_keypoints[..., 1] >= 0) + * (transformed_keypoints[..., 0] < w) + * (transformed_keypoints[..., 1] < h) + ) + data_sample["keypoint_weights"] = ( + keypoints_in_bounds * results["keypoint_weights"] + ) + if "heatmaps" in results: + data_sample["heatmaps"] = results["heatmaps"] ## K x heatmap_H x heatmap_W + + img_meta = {} + for key in self.meta_keys: + if key in results: + if isinstance(results[key], (int, float)): + img_meta[key] = np.float32(results[key]) + elif isinstance(results[key], np.ndarray): + img_meta[key] = results[key].astype(np.float32) + else: + img_meta[key] = results[key] + + data_sample["meta"] = img_meta + packed_results["data_samples"] = data_sample + + return packed_results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f"(meta_keys={self.meta_keys})" + return repr_str diff --git a/sapiens/pose/src/datasets/utils.py b/sapiens/pose/src/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0641bbb0890752b9fc0483ddcaa9b10aab39bc54 --- /dev/null +++ b/sapiens/pose/src/datasets/utils.py @@ -0,0 +1,131 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import numpy as np +from sapiens.engine import Config + + +def parse_pose_metainfo(metainfo: dict): + if "from_file" in metainfo: + cfg_file = metainfo["from_file"] + metainfo = Config.fromfile(cfg_file).dataset_info + + # check data integrity + assert "dataset_name" in metainfo + assert "keypoint_info" in metainfo + assert "skeleton_info" in metainfo + # assert 'joint_weights' in metainfo + # assert 'sigmas' in metainfo + + # parse metainfo + parsed = dict( + dataset_name=None, + num_keypoints=None, + keypoint_id2name={}, + keypoint_name2id={}, + upper_body_ids=[], + lower_body_ids=[], + flip_indices=[], + flip_pairs=[], + keypoint_colors=[], + num_skeleton_links=None, + skeleton_links=[], + skeleton_link_colors=[], + dataset_keypoint_weights=None, + sigmas=None, + ) + + parsed["dataset_name"] = metainfo["dataset_name"] + + if "remove_teeth" in metainfo: + parsed["remove_teeth"] = metainfo["remove_teeth"] + + if "min_visible_keypoints" in metainfo: + parsed["min_visible_keypoints"] = metainfo["min_visible_keypoints"] + + if "teeth_keypoint_ids" in metainfo: + parsed["teeth_keypoint_ids"] = metainfo["teeth_keypoint_ids"] + + if "coco_wholebody_to_goliath_mapping" in metainfo: + parsed["coco_wholebody_to_goliath_mapping"] = metainfo[ + "coco_wholebody_to_goliath_mapping" + ] + + if "coco_wholebody_to_goliath_keypoint_info" in metainfo: + parsed["coco_wholebody_to_goliath_keypoint_info"] = metainfo[ + "coco_wholebody_to_goliath_keypoint_info" + ] + + if "idx_to_original_idx_mapping" in metainfo: + parsed["idx_to_original_idx_mapping"] = metainfo["idx_to_original_idx_mapping"] + + # parse keypoint information + parsed["num_keypoints"] = len(metainfo["keypoint_info"]) + + for kpt_id, kpt in metainfo["keypoint_info"].items(): + kpt_name = kpt["name"] + parsed["keypoint_id2name"][kpt_id] = kpt_name + parsed["keypoint_name2id"][kpt_name] = kpt_id + parsed["keypoint_colors"].append(kpt.get("color", [255, 128, 0])) + + kpt_type = kpt.get("type", "") + if kpt_type == "upper": + parsed["upper_body_ids"].append(kpt_id) + elif kpt_type == "lower": + parsed["lower_body_ids"].append(kpt_id) + + swap_kpt = kpt.get("swap", "") + if swap_kpt == kpt_name or swap_kpt == "": + parsed["flip_indices"].append(kpt_name) + else: + parsed["flip_indices"].append(swap_kpt) + pair = (swap_kpt, kpt_name) + if pair not in parsed["flip_pairs"]: + parsed["flip_pairs"].append(pair) + + # parse skeleton information + parsed["num_skeleton_links"] = len(metainfo["skeleton_info"]) + for _, sk in metainfo["skeleton_info"].items(): + parsed["skeleton_links"].append(sk["link"]) + parsed["skeleton_link_colors"].append(sk.get("color", [96, 96, 255])) + + # parse extra information + if "joint_weights" in metainfo: + parsed["dataset_keypoint_weights"] = np.array( + metainfo["joint_weights"], dtype=np.float32 + ) + if "sigmas" in metainfo: + parsed["sigmas"] = np.array(metainfo["sigmas"], dtype=np.float32) + + if "stats_info" in metainfo: + parsed["stats_info"] = {} + for name, val in metainfo["stats_info"].items(): + parsed["stats_info"][name] = np.array(val, dtype=np.float32) + + # formatting + def _map(src, mapping: dict): + if isinstance(src, (list, tuple)): + cls = type(src) + return cls(_map(s, mapping) for s in src) + else: + return mapping[src] + + parsed["flip_pairs"] = _map( + parsed["flip_pairs"], mapping=parsed["keypoint_name2id"] + ) + parsed["flip_indices"] = _map( + parsed["flip_indices"], mapping=parsed["keypoint_name2id"] + ) + parsed["skeleton_links"] = _map( + parsed["skeleton_links"], 
mapping=parsed["keypoint_name2id"] + ) + + parsed["keypoint_colors"] = np.array(parsed["keypoint_colors"], dtype=np.uint8) + parsed["skeleton_link_colors"] = np.array( + parsed["skeleton_link_colors"], dtype=np.uint8 + ) + + return parsed diff --git a/sapiens/pose/src/evaluators/__init__.py b/sapiens/pose/src/evaluators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7ae3b2fceb5c0181134218ec806d3bd05e8e3624 --- /dev/null +++ b/sapiens/pose/src/evaluators/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .keypoints308_evaluator import Keypoints308Evaluator, nms, oks_nms + +__all__ = ["Keypoints308Evaluator", "nms", "oks_nms"] diff --git a/sapiens/pose/src/evaluators/keypoints308_evaluator.py b/sapiens/pose/src/evaluators/keypoints308_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..cce1ce43b4537919c6d1409cadf98b0d6a79578b --- /dev/null +++ b/sapiens/pose/src/evaluators/keypoints308_evaluator.py @@ -0,0 +1,762 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os +import sys +import tempfile +from collections import defaultdict, OrderedDict +from typing import Dict, List, Optional, Sequence + +import numpy as np +import torch +from sapiens.engine.evaluators import BaseEvaluator +from sapiens.registry import MODELS + +from ..datasets.codecs import UDPHeatmap +from ..datasets.codecs.utils import get_heatmap_maximum +from ..datasets.utils import parse_pose_metainfo + +## get the keypoints ids +try: + this_file = os.path.abspath(__file__) + root_dir = os.path.abspath(os.path.join(this_file, "..", "..", "..")) + sys.path.append(str(os.path.join(root_dir))) + from configs._base_.keypoints308 import dataset_info as KEYPOINTS308_INFO + + KEYPOINTS308_INFO["name2id"] = {} + for keypoint_id, keypoint_info in KEYPOINTS308_INFO["keypoint_info"].items(): + KEYPOINTS308_INFO["name2id"][keypoint_info["name"]] = keypoint_id + + KEYPOINTS308_INFO["body_keypoint_ids"] = [ + KEYPOINTS308_INFO["name2id"][name] + for name in KEYPOINTS308_INFO["body_keypoint_names"] + ] + + KEYPOINTS308_INFO["foot_keypoint_ids"] = [ + KEYPOINTS308_INFO["name2id"][name] + for name in KEYPOINTS308_INFO["foot_keypoint_names"] + ] + + KEYPOINTS308_INFO["face_keypoint_ids"] = [ + KEYPOINTS308_INFO["name2id"][name] + for name in KEYPOINTS308_INFO["face_keypoint_names"] + ] + + KEYPOINTS308_INFO["left_hand_keypoint_ids"] = [ + KEYPOINTS308_INFO["name2id"][name] + for name in KEYPOINTS308_INFO["left_hand_keypoint_names"] + ] + + KEYPOINTS308_INFO["right_hand_keypoint_ids"] = [ + KEYPOINTS308_INFO["name2id"][name] + for name in KEYPOINTS308_INFO["right_hand_keypoint_names"] + ] + +except Exception as e: + pass + + +@MODELS.register_module() +class Keypoints308Evaluator(BaseEvaluator): + body_num = 17 + foot_num = 6 + face_num = 238 + left_hand_num = 20 + right_hand_num = 20 + remaining_extra_num = 7 ## total to 308 + + def __init__( + self, + decoder: Optional[dict] = None, + ann_file: Optional[str] = None, + use_area: bool = True, + iou_type: str = "keypoints", + score_mode: str = "bbox_keypoint", + keypoint_score_thr: float = 0.2, + nms_mode: str = "oks_nms", + nms_thr: float = 0.9, + ): + from xtcocotools.coco 
import COCO # lazy: only needed for COCO-format ann files + + super().__init__() + self.num_keypoints = 308 + decoder_type = decoder.pop("type") + assert decoder_type == "UDPHeatmap", "Only UDPHeatmap is supported" + self.decoder = UDPHeatmap(**decoder) + self.coco = COCO(ann_file) + + self.dataset_meta = parse_pose_metainfo( + dict(from_file="configs/_base_/keypoints308.py") + ) + + self.body_keypoint_ids = KEYPOINTS308_INFO["body_keypoint_ids"] + self.foot_keypoint_ids = KEYPOINTS308_INFO["foot_keypoint_ids"] + self.face_keypoint_ids = KEYPOINTS308_INFO["face_keypoint_ids"] + self.left_hand_keypoint_ids = KEYPOINTS308_INFO["left_hand_keypoint_ids"] + self.right_hand_keypoint_ids = KEYPOINTS308_INFO["right_hand_keypoint_ids"] + + assert len(self.body_keypoint_ids) == self.body_num + assert len(self.foot_keypoint_ids) == self.foot_num + assert len(self.face_keypoint_ids) == self.face_num + assert len(self.left_hand_keypoint_ids) == self.left_hand_num + assert len(self.right_hand_keypoint_ids) == self.right_hand_num + + self.use_area = use_area + self.iou_type = iou_type + + allowed_score_modes = ["bbox", "bbox_keypoint", "bbox_rle", "keypoint"] + if score_mode not in allowed_score_modes: + raise ValueError( + "`score_mode` should be one of 'bbox', 'bbox_keypoint', " + f"'bbox_rle', but got {score_mode}" + ) + self.score_mode = score_mode + self.keypoint_score_thr = keypoint_score_thr + + allowed_nms_modes = ["oks_nms"] + if nms_mode not in allowed_nms_modes: + raise ValueError( + "`nms_mode` should be one of 'oks_nms', but got {nms_mode}" + ) + self.nms_mode = nms_mode + self.nms_thr = nms_thr + + @torch.no_grad() + def process(self, predictions: torch.Tensor, data_samples: dict, accelerator=None): + assert accelerator is not None, "evaluation process expects an accelerator" + + if predictions.dtype == torch.bfloat16: + predictions = predictions.float() + + pred_heatmaps = predictions.cpu().numpy() ## B x K x heatmap_H x heatmap_W + ( + keypoints_list, + scores_list, + ids_list, + img_ids_list, + areas_list, + bbox_scores_list, + ) = ([], [], [], [], [], []) + + for i in range(pred_heatmaps.shape[0]): + pred_heatmap = pred_heatmaps[i] + meta_sample = data_samples[i]["meta"] # Assuming 'meta' is a list of dicts + + keypoints, keypoint_scores = self.decoder.decode( + pred_heatmap + ) ## kps in crop image + + ## convert to global image size + bbox_center = meta_sample["bbox_center"] ## 1 x 2 + bbox_scale = meta_sample["bbox_scale"] ## 1 x 2 + input_size = np.array(meta_sample["input_size"]) ## 2, 768 x 1024 + area = np.prod(meta_sample["bbox_scale"]) + + keypoints = ( + keypoints / input_size * bbox_scale + bbox_center - 0.5 * bbox_scale + ) + keypoints_list.append(keypoints) + scores_list.append(keypoint_scores) + ids_list.append(int(meta_sample["id"])) + img_ids_list.append(int(meta_sample["img_id"])) + bbox_scores_list.append(meta_sample["bbox_score"]) + areas_list.append(area) + + if not areas_list: + areas_list = [0.0] * len(keypoints_list) + + results_to_gather = { + "keypoints": torch.tensor( + np.array(keypoints_list), device=predictions.device + ), + "keypoint_scores": torch.tensor( + np.array(scores_list), device=predictions.device + ), + "id": torch.tensor(ids_list, device=predictions.device), + "img_id": torch.tensor(img_ids_list, device=predictions.device), + "areas": torch.tensor(areas_list, device=predictions.device), + "bbox_scores": torch.tensor( + np.array(bbox_scores_list), device=predictions.device + ), + } + gathered_results = 
accelerator.gather_for_metrics(results_to_gather) + + if accelerator.is_main_process: + keypoints_all = gathered_results["keypoints"].cpu().numpy() + scores_all = gathered_results["keypoint_scores"].cpu().numpy() + ids_all = gathered_results["id"].cpu().tolist() + img_ids_all = gathered_results["img_id"].cpu().tolist() + areas_all = gathered_results["areas"].cpu().tolist() + bbox_scores_all = gathered_results["bbox_scores"].cpu().tolist() + + for i in range(len(keypoints_all)): + pred = { + "id": ids_all[i], + "img_id": img_ids_all[i], + "keypoints": keypoints_all[i], + "keypoint_scores": scores_all[i], + "areas": areas_all[i], + "category_id": 1, # Defaulting category_id + "bbox_scores": bbox_scores_all[i], + } + # Assuming self.results is the master list for the evaluator + self.results.append(pred) + return + + def evaluate(self, logger=None, accelerator=None) -> Dict[str, float]: + assert accelerator is not None, "evaluation aggregation expects an accelerator" + + if not accelerator.is_main_process: + self.reset() + return {} + + if not self.results: + if logger is not None: + logger.info("No results to evaluate.") + return {} + + kpts = defaultdict(list) + + print("len of results: ", len(self.results)) + for pred in self.results: + img_id = pred["img_id"] + for idx in range(len(pred["keypoints"])): + instance = { + "id": pred["id"], + "img_id": pred["img_id"], + "category_id": pred["category_id"], + "keypoints": pred["keypoints"][idx], ## K x 2 + "keypoint_scores": pred["keypoint_scores"][idx], ## K + "bbox_score": pred["bbox_scores"][idx], + } + + # use keypoint to calculate bbox and get area + keypoints = pred["keypoints"][idx] + area = (np.max(keypoints[:, 0]) - np.min(keypoints[:, 0])) * ( + np.max(keypoints[:, 1]) - np.min(keypoints[:, 1]) + ) + instance["area"] = area + + kpts[img_id].append(instance) + + # sort keypoint results according to id and remove duplicate ones + kpts = self._sort_and_unique_bboxes(kpts, key="id") + valid_kpts = defaultdict(list) + num_keypoints = self.num_keypoints + + assert len(self.dataset_meta["sigmas"]) == num_keypoints + + for img_id, instances in kpts.items(): + for instance in instances: + # concatenate the keypoint coordinates and scores + instance["keypoints"] = np.concatenate( + [instance["keypoints"], instance["keypoint_scores"][:, None]], + axis=-1, + ) + if self.score_mode == "bbox_keypoint": + bbox_score = instance["bbox_score"] + mean_kpt_score = 0 + valid_num = 0 + for kpt_idx in range(num_keypoints): + kpt_score = instance["keypoint_scores"][kpt_idx] + if kpt_score > self.keypoint_score_thr: + mean_kpt_score += kpt_score + valid_num += 1 + if valid_num != 0: + mean_kpt_score /= valid_num + instance["score"] = bbox_score * mean_kpt_score + # perform nms + nms = oks_nms if self.nms_mode == "oks_nms" else None + keep = nms(instances, self.nms_thr, sigmas=self.dataset_meta["sigmas"]) + valid_kpts[img_id] = [instances[_keep] for _keep in keep] + + tmp_dir = tempfile.TemporaryDirectory() + outfile_prefix = os.path.join(tmp_dir.name, "results") + self.results2json(valid_kpts, outfile_prefix=outfile_prefix) + + # evaluation results + eval_results = OrderedDict() + logger.info(f"Evaluating {self.__class__.__name__}...") + info_str = self._do_python_keypoint_eval(outfile_prefix) + name_value = OrderedDict(info_str) + eval_results.update(name_value) + + if tmp_dir is not None: + tmp_dir.cleanup() + + if logger is not None: + logger.info(info_str) + + self.reset() + + return eval_results + + def results2json(self, keypoints: Dict[int, list], 
outfile_prefix: str = "") -> str: + # the results with category_id + cat_id = 1 + cat_results = [] + + for _, img_kpts in keypoints.items(): + _keypoints = np.array([img_kpt["keypoints"] for img_kpt in img_kpts]) + num_keypoints = self.dataset_meta["num_keypoints"] + + _body_keypoints = _keypoints[ + :, self.body_keypoint_ids + ].copy() ## get only body keypoints + _foot_keypoints = _keypoints[ + :, self.foot_keypoint_ids + ].copy() ## get only foot keypoints + _face_keypoints = _keypoints[ + :, self.face_keypoint_ids + ].copy() ## get only face keypoints + _left_hand_keypoints = _keypoints[ + :, self.left_hand_keypoint_ids + ].copy() ## get only left hand keypoints + _right_hand_keypoints = _keypoints[ + :, self.right_hand_keypoint_ids + ].copy() ## get only right hand keypoints + + _keypoints = _keypoints.reshape(-1, num_keypoints * 3) ## flatten + _body_keypoints = _body_keypoints.reshape(-1, self.body_num * 3) ## flatten + _foot_keypoints = _foot_keypoints.reshape(-1, self.foot_num * 3) ## flatten + _face_keypoints = _face_keypoints.reshape(-1, self.face_num * 3) ## flatten + _left_hand_keypoints = _left_hand_keypoints.reshape( + -1, self.left_hand_num * 3 + ) ## flatten + _right_hand_keypoints = _right_hand_keypoints.reshape( + -1, self.right_hand_num * 3 + ) ## flatten + + result = [ + { + "image_id": img_kpt["img_id"], + "category_id": cat_id, + "goliath_wholebody_kpts": _keypoint.tolist(), ## all keypoints. Modified in xtcocotools + "keypoints": _body_keypoint.tolist(), ## xtcocotools treats this as body keypoints, 17 default + "foot_kpts": _foot_keypoint.tolist(), + "face_kpts": _face_keypoint.tolist(), + "lefthand_kpts": _left_hand_keypoint.tolist(), + "righthand_kpts": _right_hand_keypoint.tolist(), + "score": float(img_kpt["score"]), + } + for img_kpt, _keypoint, _body_keypoint, _foot_keypoint, _face_keypoint, _left_hand_keypoint, _right_hand_keypoint in zip( + img_kpts, + _keypoints, + _body_keypoints, + _foot_keypoints, + _face_keypoints, + _left_hand_keypoints, + _right_hand_keypoints, + ) + ] + + cat_results.extend(result) + + res_file = f"{outfile_prefix}.keypoints.json" + json.dump(cat_results, open(res_file, "w"), sort_keys=True, indent=4) + + def _do_python_keypoint_eval(self, outfile_prefix: str) -> list: + """Do keypoint evaluation using COCOAPI. + + Args: + outfile_prefix (str): The filename prefix of the json files. If the + prefix is "somepath/xxx", the json files will be named + "somepath/xxx.keypoints.json", + + Returns: + list: a list of tuples. Each tuple contains the evaluation stats + name and corresponding stats value. 
+ """ + from xtcocotools.cocoeval import COCOeval # lazy: only needed during eval + + res_file = f"{outfile_prefix}.keypoints.json" + coco_det = self.coco.loadRes(res_file) + sigmas = self.dataset_meta["sigmas"] + + coco_eval = COCOeval( + self.coco, + coco_det, + "keypoints_body", + sigmas[self.body_keypoint_ids], + use_area=True, + ) + coco_eval.params.useSegm = None + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + + coco_eval = COCOeval( + self.coco, + coco_det, + "keypoints_foot", + sigmas[self.foot_keypoint_ids], + use_area=True, + ) + coco_eval.params.useSegm = None + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + + coco_eval = COCOeval( + self.coco, + coco_det, + "keypoints_face", + sigmas[self.face_keypoint_ids], + use_area=True, + ) + coco_eval.params.useSegm = None + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + + coco_eval = COCOeval( + self.coco, + coco_det, + "keypoints_lefthand", + sigmas[self.left_hand_keypoint_ids], + use_area=True, + ) + coco_eval.params.useSegm = None + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + + coco_eval = COCOeval( + self.coco, + coco_det, + "keypoints_righthand", + sigmas[self.right_hand_keypoint_ids], + use_area=True, + ) + coco_eval.params.useSegm = None + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + + coco_eval = COCOeval( + self.coco, coco_det, "keypoints_wholebody_goliath", sigmas, use_area=True + ) + coco_eval.params.useSegm = None + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + + stats_names = [ + "AP", + "AP .5", + "AP .75", + "AP (M)", + "AP (L)", + "AR", + "AR .5", + "AR .75", + "AR (M)", + "AR (L)", + ] + + info_str = list(zip(stats_names, coco_eval.stats)) + + return info_str + + def _sort_and_unique_bboxes( + self, kpts: Dict[int, list], key: str = "id" + ) -> Dict[int, list]: + for img_id, persons in kpts.items(): + # deal with bottomup-style output + if isinstance(kpts[img_id][0][key], Sequence): + return kpts + num = len(persons) + kpts[img_id] = sorted(kpts[img_id], key=lambda x: x[key]) + for i in range(num - 1, 0, -1): + if kpts[img_id][i][key] == kpts[img_id][i - 1][key]: + del kpts[img_id][i] + + return kpts + + +# ------------------------------------------------------------------------------- +def nms(dets: np.ndarray, thr: float) -> List[int]: + """Greedily select boxes with high confidence and overlap <= thr. + + Args: + dets (np.ndarray): [[x1, y1, x2, y2, score]]. + thr (float): Retain overlap < thr. + + Returns: + list: Indexes to keep. 
+ """ + if len(dets) == 0: + return [] + + x1 = dets[:, 0] + y1 = dets[:, 1] + x2 = dets[:, 2] + y2 = dets[:, 3] + scores = dets[:, 4] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while len(order) > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= thr)[0] + order = order[inds + 1] + + return keep + + +def oks_iou( + g: np.ndarray, + d: np.ndarray, + a_g: float, + a_d: np.ndarray, + sigmas: Optional[np.ndarray] = None, + vis_thr: Optional[float] = None, +) -> np.ndarray: + if sigmas is None: + sigmas = ( + np.array( + [ + 0.26, + 0.25, + 0.25, + 0.35, + 0.35, + 0.79, + 0.79, + 0.72, + 0.72, + 0.62, + 0.62, + 1.07, + 1.07, + 0.87, + 0.87, + 0.89, + 0.89, + ] + ) + / 10.0 + ) + vars = (sigmas * 2) ** 2 + xg = g[0::3] + yg = g[1::3] + vg = g[2::3] + ious = np.zeros(len(d), dtype=np.float32) + for n_d in range(0, len(d)): + xd = d[n_d, 0::3] + yd = d[n_d, 1::3] + vd = d[n_d, 2::3] + dx = xd - xg + dy = yd - yg + e = (dx**2 + dy**2) / vars / ((a_g + a_d[n_d]) / 2 + np.spacing(1)) / 2 + if vis_thr is not None: + ind = list((vg > vis_thr) & (vd > vis_thr)) + e = e[ind] + ious[n_d] = np.sum(np.exp(-e)) / len(e) if len(e) != 0 else 0.0 + return ious + + +def oks_nms( + kpts_db: List[dict], + thr: float, + sigmas: Optional[np.ndarray] = None, + vis_thr: Optional[float] = None, + score_per_joint: bool = False, +): + if len(kpts_db) == 0: + return [] + + if score_per_joint: + scores = np.array([k["score"].mean() for k in kpts_db]) + else: + scores = np.array([k["score"] for k in kpts_db]) + + kpts = np.array([k["keypoints"].flatten() for k in kpts_db]) + areas = np.array([k["area"] for k in kpts_db]) + + order = scores.argsort()[::-1] + + keep = [] + while len(order) > 0: + i = order[0] + keep.append(i) + + oks_ovr = oks_iou( + kpts[i], kpts[order[1:]], areas[i], areas[order[1:]], sigmas, vis_thr + ) + + inds = np.where(oks_ovr <= thr)[0] + order = order[inds + 1] + + keep = np.array(keep) + + return keep + + +def _calc_distances( + preds: np.ndarray, gts: np.ndarray, mask: np.ndarray, norm_factor: np.ndarray +) -> np.ndarray: + """Calculate the normalized distances between preds and target. + + Note: + - instance number: N + - keypoint number: K + - keypoint dimension: D (normally, D=2 or D=3) + + Args: + preds (np.ndarray[N, K, D]): Predicted keypoint location. + gts (np.ndarray[N, K, D]): Groundtruth keypoint location. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + norm_factor (np.ndarray[N, D]): Normalization factor. + Typical value is heatmap_size. + + Returns: + np.ndarray[K, N]: The normalized distances. \ + If target keypoints are missing, the distance is -1. 
+ """ + N, K, _ = preds.shape + # set mask=0 when norm_factor==0 + _mask = mask.copy() + _mask[np.where((norm_factor == 0).sum(1))[0], :] = False + + distances = np.full((N, K), -1, dtype=np.float32) + # handle invalid values + norm_factor[np.where(norm_factor <= 0)] = 1e6 + distances[_mask] = np.linalg.norm( + ((preds - gts) / norm_factor[:, None, :])[_mask], axis=-1 + ) + return distances.T + + +def _distance_acc(distances: np.ndarray, thr: float = 0.5) -> float: + """Return the percentage below the distance threshold, while ignoring + distances values with -1. + + Note: + - instance number: N + + Args: + distances (np.ndarray[N, ]): The normalized distances. + thr (float): Threshold of the distances. + + Returns: + float: Percentage of distances below the threshold. \ + If all target keypoints are missing, return -1. + """ + distance_valid = distances != -1 + num_distance_valid = distance_valid.sum() + if num_distance_valid > 0: + return (distances[distance_valid] < thr).sum() / num_distance_valid + return -1 + + +def keypoint_pck_accuracy( + pred: np.ndarray, + gt: np.ndarray, + mask: np.ndarray, + thr: np.ndarray, + norm_factor: np.ndarray, +) -> tuple: + """Calculate the pose accuracy of PCK for each individual keypoint and the + averaged accuracy across all keypoints for coordinates. + + Note: + PCK metric measures accuracy of the localization of the body joints. + The distances between predicted positions and the ground-truth ones + are typically normalized by the bounding box size. + The threshold (thr) of the normalized distance is commonly set + as 0.05, 0.1 or 0.2 etc. + + - instance number: N + - keypoint number: K + + Args: + pred (np.ndarray[N, K, 2]): Predicted keypoint location. + gt (np.ndarray[N, K, 2]): Groundtruth keypoint location. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + thr (float): Threshold of PCK calculation. + norm_factor (np.ndarray[N, 2]): Normalization factor for H&W. + + Returns: + tuple: A tuple containing keypoint accuracy. + + - acc (np.ndarray[K]): Accuracy of each keypoint. + - avg_acc (float): Averaged accuracy across all keypoints. + - cnt (int): Number of valid keypoints. + """ + distances = _calc_distances(pred, gt, mask, norm_factor) + acc = np.array([_distance_acc(d, thr) for d in distances]) + valid_acc = acc[acc >= 0] + cnt = len(valid_acc) + avg_acc = valid_acc.mean() if cnt > 0 else 0.0 + return acc, avg_acc, cnt + + +def pose_pck_accuracy( + output: np.ndarray, + target: np.ndarray, + mask: np.ndarray, + thr: float = 0.05, + normalize: Optional[np.ndarray] = None, +) -> tuple: + """Calculate the pose accuracy of PCK for each individual keypoint and the + averaged accuracy across all keypoints from heatmaps. + + Note: + PCK metric measures accuracy of the localization of the body joints. + The distances between predicted positions and the ground-truth ones + are typically normalized by the bounding box size. + The threshold (thr) of the normalized distance is commonly set + as 0.05, 0.1 or 0.2 etc. + + - batch_size: N + - num_keypoints: K + - heatmap height: H + - heatmap width: W + + Args: + output (np.ndarray[N, K, H, W]): Model output heatmaps. + target (np.ndarray[N, K, H, W]): Groundtruth heatmaps. + mask (np.ndarray[N, K]): Visibility of the target. False for invisible + joints, and True for visible. Invisible joints will be ignored for + accuracy calculation. + thr (float): Threshold of PCK calculation. 
Default 0.05. + normalize (np.ndarray[N, 2]): Normalization factor for H&W. + + Returns: + tuple: A tuple containing keypoint accuracy. + + - np.ndarray[K]: Accuracy of each keypoint. + - float: Averaged accuracy across all keypoints. + - int: Number of valid keypoints. + """ + N, K, H, W = output.shape + if K == 0: + return None, 0, 0 + if normalize is None: + normalize = np.tile(np.array([[H, W]]), (N, 1)) + + pred, _ = get_heatmap_maximum(output) + gt, _ = get_heatmap_maximum(target) + return keypoint_pck_accuracy(pred, gt, mask, thr, normalize) diff --git a/sapiens/pose/src/models/__init__.py b/sapiens/pose/src/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f7cc0bef718fe87be4c8419b261ca639ccd950fb --- /dev/null +++ b/sapiens/pose/src/models/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .core import * +from .heads import * +from .losses import * +from .init_model import init_model diff --git a/sapiens/pose/src/models/core/__init__.py b/sapiens/pose/src/models/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ba3c5c5d82fdbdb4a7eb9adb7660308c866fe729 --- /dev/null +++ b/sapiens/pose/src/models/core/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .pose_topdown_estimator import PoseTopdownEstimator + +__all__ = ["PoseTopdownEstimator"] diff --git a/sapiens/pose/src/models/core/pose_topdown_estimator.py b/sapiens/pose/src/models/core/pose_topdown_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..98a8afca801ea90ef28761bc7e1154f44a282c7a --- /dev/null +++ b/sapiens/pose/src/models/core/pose_topdown_estimator.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
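+
+# PoseTopdownEstimator (below) wires a backbone to a heatmap decode head: forward()
+# takes the final-layer feature map from the backbone and returns per-keypoint
+# heatmaps; loss() defers to the decode head and logs the predicted heatmaps under
+# "outputs".
+#
+# Minimal build sketch via the registry (the config dicts and channel counts are
+# illustrative placeholders, not the shipped configs):
+#     model = MODELS.build(dict(
+#         type="PoseTopdownEstimator",
+#         backbone=dict(type=...),  # any registered backbone
+#         decode_head=dict(
+#             type="PoseHeatmapHead",
+#             in_channels=...,
+#             out_channels=308,  # e.g. the 308-keypoint setup used by the evaluator
+#             loss_decode=dict(type="KeypointMSELoss"),
+#         ),
+#     ))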
+ +from typing import List, Optional, Tuple, Union + +import torch +from sapiens.engine.models import BaseModel +from sapiens.registry import MODELS +from torch import Tensor + + +@MODELS.register_module() +class PoseTopdownEstimator(BaseModel): + def __init__( + self, + backbone: dict = None, + decode_head: dict = None, + init_cfg: dict = None, + train_cfg: dict = {"use_checkpoint": False}, + ): + BaseModel.__init__(self, init_cfg=init_cfg) + self.backbone = MODELS.build(backbone) + self.decode_head = MODELS.build(decode_head) + self.train_cfg = train_cfg + + def loss( + self, outputs: Union[Tensor, Tuple[Tensor, ...]], data_samples: dict + ) -> Tuple[dict, Tensor]: + losses, preds = self.decode_head.loss(outputs, data_samples) + parsed_losses, log_vars = self.parse_losses(losses) + log_vars["outputs"] = preds + return parsed_losses, log_vars + + def forward(self, inputs: Tensor) -> Tensor: + ## backbone forward returns a list of tensors at different depths, 0 is the final layer + if self.training and self.train_cfg.get("use_checkpoint", False): + x = torch.utils.checkpoint.checkpoint( + self.backbone, inputs, use_reentrant=False + )[0] + else: + x = self.backbone(inputs)[0] + x = self.decode_head(x) + return x diff --git a/sapiens/pose/src/models/heads/__init__.py b/sapiens/pose/src/models/heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..df61e3d33d498c7a43fded88dbcbd6aef72a9eae --- /dev/null +++ b/sapiens/pose/src/models/heads/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .pose_heatmap_head import PoseHeatmapHead + +__all__ = ["PoseHeatmapHead"] diff --git a/sapiens/pose/src/models/heads/pose_heatmap_head.py b/sapiens/pose/src/models/heads/pose_heatmap_head.py new file mode 100644 index 0000000000000000000000000000000000000000..405773cb90dc1bc2b66b49e5c072dd6e6909d3e5 --- /dev/null +++ b/sapiens/pose/src/models/heads/pose_heatmap_head.py @@ -0,0 +1,206 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn + +# from sapiens.pose.evaluation.functional import pose_pck_accuracy +from sapiens.registry import MODELS +from torch import nn, Tensor + +from ...evaluators.keypoints308_evaluator import pose_pck_accuracy + + +@MODELS.register_module() +class PoseHeatmapHead(nn.Module): + def __init__( + self, + in_channels: Union[int, Sequence[int]], + out_channels: int, + deconv_out_channels: Optional[Sequence[int]] = (256, 256, 256), + deconv_kernel_sizes: Optional[Sequence[int]] = (4, 4, 4), + conv_out_channels: Optional[Sequence[int]] = None, + conv_kernel_sizes: Optional[Sequence[int]] = None, + loss_decode: dict = None, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + if deconv_out_channels: + if deconv_kernel_sizes is None or len(deconv_out_channels) != len( + deconv_kernel_sizes + ): + raise ValueError( + '"deconv_out_channels" and "deconv_kernel_sizes" should ' + "be integer sequences with the same length. 
Got " + f"mismatched lengths {deconv_out_channels} and " + f"{deconv_kernel_sizes}" + ) + + self.deconv_layers = self._make_deconv_layers( + in_channels=in_channels, + layer_out_channels=deconv_out_channels, + layer_kernel_sizes=deconv_kernel_sizes, + ) + in_channels = deconv_out_channels[-1] + else: + self.deconv_layers = nn.Identity() + + if conv_out_channels: + if conv_kernel_sizes is None or len(conv_out_channels) != len( + conv_kernel_sizes + ): + raise ValueError( + '"conv_out_channels" and "conv_kernel_sizes" should ' + "be integer sequences with the same length. Got " + f"mismatched lengths {conv_out_channels} and " + f"{conv_kernel_sizes}" + ) + + self.conv_layers = self._make_conv_layers( + in_channels=in_channels, + layer_out_channels=conv_out_channels, + layer_kernel_sizes=conv_kernel_sizes, + ) + in_channels = conv_out_channels[-1] + else: + self.conv_layers = nn.Identity() + + self.conv_pose = nn.Conv2d(in_channels, out_channels, kernel_size=1) + + self.loss_decode = MODELS.build(loss_decode) + self._init_weights() + + def _init_weights(self) -> None: + """Initialize network weights.""" + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + weight_dtype = m.weight.dtype + weight = nn.init.kaiming_normal_( + m.weight.float(), mode="fan_out", nonlinearity="relu" + ) + m.weight.data = weight.to(weight_dtype) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + weight_dtype = m.weight.dtype + weight = nn.init.kaiming_normal_( + m.weight.float(), mode="fan_in", nonlinearity="linear" + ) + m.weight.data = weight.to(weight_dtype) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.InstanceNorm2d): + if m.weight is not None: + nn.init.ones_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.RMSNorm): + if hasattr(m, "weight"): + nn.init.ones_(m.weight) + + def _make_conv_layers( + self, + in_channels: int, + layer_out_channels: Sequence[int], + layer_kernel_sizes: Sequence[int], + ) -> nn.Module: + """Create convolutional layers by given parameters.""" + layers = [] + for out_channels, kernel_size in zip(layer_out_channels, layer_kernel_sizes): + padding = (kernel_size - 1) // 2 + layers.append( + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + padding=padding, + ) + ) + layers.append(nn.InstanceNorm2d(out_channels)) + layers.append(nn.SiLU(inplace=True)) + in_channels = out_channels + + return nn.Sequential(*layers) + + def _make_deconv_layers( + self, + in_channels: int, + layer_out_channels: Sequence[int], + layer_kernel_sizes: Sequence[int], + ) -> nn.Module: + """Create deconvolutional layers by given parameters.""" + layers = [] + for out_channels, kernel_size in zip(layer_out_channels, layer_kernel_sizes): + if kernel_size == 4: + padding = 1 + output_padding = 0 + elif kernel_size == 3: + padding = 1 + output_padding = 1 + elif kernel_size == 2: + padding = 0 + output_padding = 0 + else: + raise ValueError( + f"Unsupported kernel size {kernel_size} for" + "deconvlutional layers in " + f"{self.__class__.__name__}" + ) + layers.append( + nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=2, + padding=padding, + output_padding=output_padding, + bias=False, + ) + ) + layers.append(nn.InstanceNorm2d(out_channels)) + layers.append(nn.SiLU(inplace=True)) + in_channels = out_channels + + return nn.Sequential(*layers) + + def forward(self, 
x: Union[Tensor, Tuple[Tensor]]) -> Tensor: + x = self.deconv_layers(x) + x = self.conv_layers(x) + x = self.conv_pose(x) + return x + + def loss( + self, + pred_heatmaps: Tensor, + data_samples: dict, + ) -> dict: + gt_heatmaps = data_samples["heatmaps"] # B x K x H x W + keypoint_weights = data_samples["keypoint_weights"] # B x 1 x K + keypoint_weights = keypoint_weights.squeeze(dim=1) # B x K + + if pred_heatmaps.dtype != gt_heatmaps.dtype: + pred_heatmaps = pred_heatmaps.to(gt_heatmaps.dtype) + + ##--------------------------------- + losses = dict() + loss = self.loss_decode(pred_heatmaps, gt_heatmaps, keypoint_weights) + losses.update(loss_kpt=loss) + + # calculate accuracy + _, avg_acc, _ = pose_pck_accuracy( + output=pred_heatmaps.detach().cpu().float().numpy(), + target=gt_heatmaps.detach().cpu().float().numpy(), + mask=keypoint_weights.detach().cpu().float().numpy() > 0, + ) + + acc_pose = torch.tensor(avg_acc, device=gt_heatmaps.device) + losses.update(acc_pose=acc_pose) + return losses, pred_heatmaps diff --git a/sapiens/pose/src/models/init_model.py b/sapiens/pose/src/models/init_model.py new file mode 100644 index 0000000000000000000000000000000000000000..03b9201a9e35c0c70ac31a5dada88bac7a6a2112 --- /dev/null +++ b/sapiens/pose/src/models/init_model.py @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from pathlib import Path +from typing import Optional, Union + +import torch +from safetensors.torch import load_file +from sapiens.engine.config import Config +from sapiens.engine.datasets import Compose +from sapiens.registry import MODELS + + +def init_model( + config: Union[str, Path], + checkpoint: Optional[Union[str, Path]] = None, + device: str = "cuda:0", +): + assert isinstance(config, (str, Path)) + assert checkpoint is None or isinstance(checkpoint, (str, Path)) + + config = Config.fromfile(config) + + ## avoid loading the pretrained backbone weights + if "init_cfg" in config.model["backbone"]: + config.model["backbone"].pop("init_cfg") + + model = MODELS.build(config.model) + data_preprocessor = MODELS.build(config.data_preprocessor) + + if checkpoint is not None: + if str(checkpoint).endswith(".safetensors"): + state_dict = load_file(checkpoint, device="cpu") + else: # Handle .pth and .bin files + checkpoint_data = torch.load( + checkpoint, map_location="cpu", weights_only=False + ) + state_dict = ( + checkpoint_data["state_dict"] + if "state_dict" in checkpoint_data + else checkpoint_data["model"] + ) + + incompat = model.load_state_dict(state_dict, strict=False) + + if incompat.missing_keys: + print(f"Missing keys: {incompat.missing_keys}") + + if incompat.unexpected_keys: + print(f"Unexpected keys: {incompat.unexpected_keys}") + + print(f"\033[96mModel loaded from {checkpoint}\033[0m") + + model.cfg = config + model.data_preprocessor = data_preprocessor + model.pipeline = Compose(config.test_pipeline) + + model.to(device) + model.eval() + + return model diff --git a/sapiens/pose/src/models/losses/__init__.py b/sapiens/pose/src/models/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..86d6bc105a189d8ea7d025fdcbfacaa2c9eb42aa --- /dev/null +++ b/sapiens/pose/src/models/losses/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .heatmap_loss import KeypointMSELoss, KeypointOHKMMSELoss + +__all__ = [ + "KeypointMSELoss", + "KeypointOHKMMSELoss", +] diff --git a/sapiens/pose/src/models/losses/heatmap_loss.py b/sapiens/pose/src/models/losses/heatmap_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..f78babef930edf56cd8386a1fe85f4c32d423b52 --- /dev/null +++ b/sapiens/pose/src/models/losses/heatmap_loss.py @@ -0,0 +1,134 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from sapiens.registry import MODELS +from torch import Tensor + + +@MODELS.register_module() +class KeypointMSELoss(nn.Module): + def __init__( + self, + use_target_weight: bool = False, + skip_empty_channel: bool = False, + loss_weight: float = 1.0, + ): + super().__init__() + self.use_target_weight = use_target_weight + self.skip_empty_channel = skip_empty_channel + self.loss_weight = loss_weight + + def forward( + self, + output: Tensor, + target: Tensor, + target_weights: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + ) -> Tensor: + _mask = self._get_mask(target, target_weights, mask) + if _mask is None: + loss = F.mse_loss(output, target) + else: + _loss = F.mse_loss(output, target, reduction="none") + loss = (_loss * _mask).mean() + + return loss * self.loss_weight + + def _get_mask( + self, target: Tensor, target_weights: Optional[Tensor], mask: Optional[Tensor] + ) -> Optional[Tensor]: + # Given spatial mask + if mask is not None: + # check mask has matching type with target + assert mask.ndim == target.ndim and all( + d_m == d_t or d_m == 1 for d_m, d_t in zip(mask.shape, target.shape) + ), f"mask and target have mismatched shapes {mask.shape} v.s.{target.shape}" + + # Mask by target weights (keypoint-wise mask) + if target_weights is not None: + # check target weight has matching shape with target + assert ( + target_weights.ndim in (2, 4) + and target_weights.shape == target.shape[: target_weights.ndim] + ), ( + "target_weights and target have mismatched shapes " + f"{target_weights.shape} v.s. 
{target.shape}" + ) + + ndim_pad = target.ndim - target_weights.ndim + _mask = target_weights.view(target_weights.shape + (1,) * ndim_pad) + + if mask is None: + mask = _mask + else: + mask = mask * _mask + + # Mask by ``skip_empty_channel`` + if self.skip_empty_channel: + _mask = (target != 0).flatten(2).any(dim=2) + ndim_pad = target.ndim - _mask.ndim + _mask = _mask.view(_mask.shape + (1,) * ndim_pad) + + if mask is None: + mask = _mask + else: + mask = mask * _mask + + return mask + + +@MODELS.register_module() +class KeypointOHKMMSELoss(nn.Module): + def __init__( + self, use_target_weight: bool = False, topk: int = 8, loss_weight: float = 1.0 + ): + super().__init__() + assert topk > 0 + self.criterion = nn.MSELoss(reduction="none") + self.use_target_weight = use_target_weight + self.topk = topk + self.loss_weight = loss_weight + + def _ohkm(self, losses: Tensor) -> Tensor: + ohkm_loss = 0.0 + B = losses.shape[0] + for i in range(B): + sub_loss = losses[i] + _, topk_idx = torch.topk(sub_loss, k=self.topk, dim=0, sorted=False) + tmp_loss = torch.gather(sub_loss, 0, topk_idx) + ohkm_loss += torch.sum(tmp_loss) / self.topk + ohkm_loss /= B + return ohkm_loss + + def forward(self, output: Tensor, target: Tensor, target_weights: Tensor) -> Tensor: + num_keypoints = output.size(1) + if num_keypoints < self.topk: + raise ValueError( + f"topk ({self.topk}) should not be " + f"larger than num_keypoints ({num_keypoints})." + ) + + losses = [] + for idx in range(num_keypoints): + if self.use_target_weight: + target_weight = target_weights[:, idx, None, None] + losses.append( + self.criterion( + output[:, idx] * target_weight, target[:, idx] * target_weight + ) + ) + else: + losses.append(self.criterion(output[:, idx], target[:, idx])) + + losses = [loss.mean(dim=(1, 2)).unsqueeze(dim=1) for loss in losses] + losses = torch.cat(losses, dim=1) + + return self._ohkm(losses) * self.loss_weight diff --git a/sapiens/pose/src/runners/__init__.py b/sapiens/pose/src/runners/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fbce0626930cee12b06fb2bb709db82069c0f869 --- /dev/null +++ b/sapiens/pose/src/runners/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .pose_runner import PoseRunner + +__all__ = ["PoseRunner"] diff --git a/sapiens/pose/src/runners/pose_runner.py b/sapiens/pose/src/runners/pose_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..bf82b0fcc30b6b74f599da39fa4d3154c2dfab12 --- /dev/null +++ b/sapiens/pose/src/runners/pose_runner.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +from sapiens.engine.runners import BaseRunner + + +## left-right flip for pose val and test +class PoseRunner(BaseRunner): + def test(self) -> None: + if self.accelerator.is_main_process: + self.logger.info(f"\033[95mStarting test...\033[0m") + + self.model.eval() + self.evaluator.reset() + for i, data_batch in enumerate(self.val_dataloader): + data_batch = self.data_preprocessor(data_batch) # preprocess + inputs, data_samples = data_batch["inputs"], data_batch["data_samples"] + + with torch.no_grad(): + pred = self.model(inputs) # forward + + if self.val_cfg.get("flip_test", False): + with torch.no_grad(): + pred_flipped = self.model(inputs.flip(-1)) # forward + + flip_indices = data_samples[0]["meta"]["flip_indices"] + pred_flipped = pred_flipped.flip(-1) ## B x K x heatmap_H x heatmap_W + assert len(flip_indices) == pred_flipped.shape[1] ## K + pred_flipped = pred_flipped[:, flip_indices] + pred = (pred + pred_flipped) / 2.0 + + if self.accelerator.is_main_process and i > 0 and i % 100 == 0: + self.logger.info( + f"\033[95mTest: {i}/{len(self.val_dataloader)}: batch_size: {len(data_batch['inputs'])}\033[0m" + ) + self.evaluator.process( + pred, data_samples, accelerator=self.accelerator + ) ## accelerator used to gather and dedup in val + + # metrics eval on main process + metrics = self.evaluator.evaluate( + logger=self.logger, accelerator=self.accelerator + ) + if self.accelerator.is_main_process: + self.logger.info( + f"\033[95mTest: {', '.join([f'{k}: {v:.4f}' for k, v in metrics.items()])}\033[0m" + ) + self.logger.info(f"\033[95mTesting finished โœ”\033[0m") + + # ------------------------------------------------------------------------- + def val(self) -> None: + self.model.eval() + + if self.accelerator.is_main_process: + self.logger.info(f"\033[95mValidating iter {self.iter}\033[0m") + + self.evaluator.reset() + for i, data_batch in enumerate(self.val_dataloader): + data_batch = self.data_preprocessor(data_batch) # preprocess + inputs, data_samples = data_batch["inputs"], data_batch["data_samples"] + + with torch.no_grad(): + pred = self.model(inputs) # forward + + if self.val_cfg.get("flip_test", False): + with torch.no_grad(): + pred_flipped = self.model(inputs.flip(-1)) # forward + + flip_indices = data_samples[0]["meta"]["flip_indices"] + pred_flipped = pred_flipped.flip(-1) ## B x K x heatmap_H x heatmap_W + assert len(flip_indices) == pred_flipped.shape[1] ## K + pred_flipped = pred_flipped[:, flip_indices] + pred = (pred + pred_flipped) / 2.0 + + if self.accelerator.is_main_process and i > 0 and i % 100 == 0: + self.logger.info( + f"\033[95mVal: {i}/{len(self.val_dataloader)}: batch_size: {len(data_batch['inputs'])}\033[0m" + ) + + self.evaluator.process(pred, data_samples, accelerator=self.accelerator) + metric = self.evaluator.evaluate( + logger=self.logger, accelerator=self.accelerator + ) + self.model.train() + return metric diff --git a/sapiens/pose/src/visualizers/__init__.py b/sapiens/pose/src/visualizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1f9b52a93fddd97a3e80c169cc6b619429a02972 --- /dev/null +++ b/sapiens/pose/src/visualizers/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
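The flip-test branch in `test()` and `val()` above is standard horizontal-flip test-time augmentation; as a self-contained restatement (assuming a heatmap model, a batch `inputs` of shape `B x 3 x H x W`, and a `flip_indices` list mapping each keypoint index to its left/right counterpart):

```python
import torch


def flip_test(model, inputs, flip_indices):
    """Average heatmaps predicted from the original and mirrored image."""
    with torch.no_grad():
        pred = model(inputs)                      # B x K x heatmap_H x heatmap_W
        pred_flipped = model(inputs.flip(-1))     # forward pass on the mirrored image
    pred_flipped = pred_flipped.flip(-1)          # undo the spatial mirroring
    pred_flipped = pred_flipped[:, flip_indices]  # swap left/right keypoint channels
    return (pred + pred_flipped) / 2.0
```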
+ +from .pose_visualizer import PoseVisualizer + + +__all__ = [ + "PoseVisualizer", +] diff --git a/sapiens/pose/src/visualizers/pose_visualizer.py b/sapiens/pose/src/visualizers/pose_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..fe0dccae4f41b551d54a8fd8dfdf7d1b6621d804 --- /dev/null +++ b/sapiens/pose/src/visualizers/pose_visualizer.py @@ -0,0 +1,481 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +from pathlib import Path + +import cv2 +import numpy as np +import torch +import torchvision +from sapiens.registry import VISUALIZERS +from torch import nn + +from ..datasets.utils import parse_pose_metainfo + + +@VISUALIZERS.register_module() +class PoseVisualizer(nn.Module): + def __init__( + self, + output_dir: str, + vis_interval: int = 100, + vis_max_samples: int = 4, + vis_image_width: int = 384, + vis_image_height: int = 512, + num_keypoints: int = 308, + scale: int = 4, + line_width: int = 4, + radius: int = 4, + ): + super().__init__() + self.output_dir = Path(output_dir) + self.output_dir.mkdir(parents=True, exist_ok=True) + self.vis_max_samples = vis_max_samples + self.vis_interval = vis_interval + self.vis_image_width = vis_image_width + self.vis_image_height = vis_image_height + self.num_keypoints = num_keypoints + self.scale = scale + self.line_width = line_width + self.radius = radius + + if self.num_keypoints == 308: + self.dataset_meta = parse_pose_metainfo( + dict(from_file="configs/_base_/keypoints308.py") + ) + + self.bbox_color = self.dataset_meta.get("bbox_colors", "green") + self.kpt_color = self.dataset_meta.get("keypoint_colors") + self.link_color = self.dataset_meta.get("skeleton_link_colors") + self.skeleton = self.dataset_meta.get("skeleton_links") + + def add_batch(self, data_batch: dict, logs: dict, step: int): + pred_heatmaps = logs["outputs"] + pred_heatmaps = pred_heatmaps.detach().cpu() # B x K x H x W + + gt_heatmaps = ( + data_batch["data_samples"]["heatmaps"].detach().cpu() + ) # B x K x H x W + + inputs = data_batch["inputs"].detach().cpu() # B x 3 x H x W + + if pred_heatmaps.dtype == torch.bfloat16: + inputs = inputs.float() + pred_heatmaps = pred_heatmaps.float() + + pred_heatmaps = pred_heatmaps.cpu().detach().numpy() ## B x K x H x W + gt_heatmaps = gt_heatmaps.cpu().detach().numpy() ## B x K x H x W + target_weights = ( + data_batch["data_samples"]["keypoint_weights"].squeeze(dim=1).cpu().numpy() + ) ## B x K + + batch_size = min(len(inputs), self.vis_max_samples) + inputs = inputs[:batch_size] + pred_heatmaps = pred_heatmaps[:batch_size] ## B x K x H x W + gt_heatmaps = gt_heatmaps[:batch_size] ## B x K x H x W + target_weights = target_weights[:batch_size] ## B x K + + kps_vis_dir = os.path.join(self.output_dir, "kps") + heatmap_vis_dir = os.path.join(self.output_dir, "heatmap") + + if not os.path.exists(kps_vis_dir): + os.makedirs(kps_vis_dir, exist_ok=True) + + if not os.path.exists(heatmap_vis_dir): + os.makedirs(heatmap_vis_dir, exist_ok=True) + + kps_prefix = os.path.join(kps_vis_dir, "train") + heatmap_prefix = os.path.join(heatmap_vis_dir, "train") + suffix = str(step).zfill(6) + + original_image = inputs / 255.0 ## B x 3 x H x W + + ## heatmap vis for only first 17 kps + self.save_batch_heatmaps( + original_image, + gt_heatmaps[:, :17], + "{}_{}_hm_gt.jpg".format(heatmap_prefix, suffix), + normalize=False, + scale=self.scale, + 
is_rgb=False, + ) + self.save_batch_heatmaps( + original_image, + pred_heatmaps[:, :17], + "{}_{}_hm_pred.jpg".format(heatmap_prefix, suffix), + normalize=False, + scale=self.scale, + is_rgb=False, + ) + self.save_batch_image_with_joints( + 255 * original_image, + gt_heatmaps, + target_weights, + "{}_{}_gt.jpg".format(kps_prefix, suffix), + scale=self.scale, + is_rgb=False, + ) + self.save_batch_image_with_joints( + 255 * original_image, + pred_heatmaps, + np.ones_like(target_weights), + "{}_{}_pred.jpg".format(kps_prefix, suffix), + scale=self.scale, + is_rgb=False, + ) + return + + def save_batch_heatmaps( + self, + batch_image, + batch_heatmaps, + file_name, + normalize=True, + scale=4, + is_rgb=True, + max_num_joints=17, + ): + """ + batch_image: [batch_size, channel, height, width] + batch_heatmaps: ['batch_size, num_joints, height, width] + file_name: saved file name + """ + ## normalize image + if normalize: + batch_image = batch_image.clone() + min_val = float(batch_image.min()) + max_val = float(batch_image.max()) + + batch_image.add_(-min_val).div_(max_val - min_val + 1e-5) + + ## check if type of batch_heatmaps is numpy.ndarray + if isinstance(batch_heatmaps, np.ndarray): + preds, maxvals = get_max_preds(batch_heatmaps) + batch_heatmaps = torch.from_numpy(batch_heatmaps) + else: + preds, maxvals = get_max_preds(batch_heatmaps.detach().cpu().numpy()) + + preds = preds * scale ## scale to original image size + + batch_size = batch_heatmaps.size(0) + num_joints = batch_heatmaps.size(1) + heatmap_height = int(batch_heatmaps.size(2) * scale) + heatmap_width = int(batch_heatmaps.size(3) * scale) + + num_joints = min(max_num_joints, num_joints) + + grid_image = np.zeros( + (batch_size * heatmap_height, (num_joints + 1) * heatmap_width, 3), + dtype=np.uint8, + ) + + body_joint_order = range(max_num_joints) + + for i in range(batch_size): + image = ( + batch_image[i] + .mul(255) + .clamp(0, 255) + .byte() + .permute(1, 2, 0) + .cpu() + .numpy() + ) + heatmaps = batch_heatmaps[i].mul(255).clamp(0, 255).byte().cpu().numpy() + + if is_rgb == True: + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + resized_image = cv2.resize(image, (int(heatmap_width), int(heatmap_height))) + + height_begin = heatmap_height * i + height_end = heatmap_height * (i + 1) + for j in range(num_joints): + joint_index = body_joint_order[j] + + cv2.circle( + resized_image, + (int(preds[i][joint_index][0]), int(preds[i][joint_index][1])), + 1, + [0, 0, 255], + 1, + ) + heatmap = heatmaps[joint_index, :, :] + colored_heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET) + colored_heatmap = cv2.resize( + colored_heatmap, (int(heatmap_width), int(heatmap_height)) + ) + masked_image = colored_heatmap * 0.7 + resized_image * 0.3 + cv2.circle( + masked_image, + (int(preds[i][joint_index][0]), int(preds[i][joint_index][1])), + 1, + [0, 0, 255], + 1, + ) + + width_begin = heatmap_width * (j + 1) + width_end = heatmap_width * (j + 2) + grid_image[height_begin:height_end, width_begin:width_end, :] = ( + masked_image + ) + + grid_image[height_begin:height_end, 0:heatmap_width, :] = resized_image + + ## resize + target_height = batch_size * self.vis_image_height + target_width = (num_joints + 1) * self.vis_image_width + grid_image = cv2.resize(grid_image, (target_width, target_height)) + + cv2.imwrite(file_name, grid_image) + return + + def save_batch_image_with_joints( + self, + batch_image, + batch_heatmaps, + batch_target_weight, + file_name, + is_rgb=True, + scale=4, + nrow=8, + padding=2, + ): + """ + batch_image: 
[batch_size, channel, height, width] + batch_joints: [batch_size, num_joints, 3], + batch_joints_vis: [batch_size, num_joints, 1], + } + """ + + B, C, H, W = batch_image.size() + num_joints = batch_heatmaps.shape[1] + + ## check if type of batch_heatmaps is numpy.ndarray + if isinstance(batch_heatmaps, np.ndarray): + batch_joints, batch_scores = get_max_preds(batch_heatmaps) + else: + batch_joints, batch_scores = get_max_preds( + batch_heatmaps.detach().cpu().numpy() + ) + + batch_joints = ( + batch_joints * scale + ) ## 4 is the ratio of output heatmap and input image + + if isinstance(batch_joints, torch.Tensor): + batch_joints = batch_joints.cpu().numpy() + + if isinstance(batch_target_weight, torch.Tensor): + batch_target_weight = batch_target_weight.cpu().numpy() + batch_target_weight = batch_target_weight.reshape(B, num_joints) ## B x 17 + + grid = [] + + for i in range(B): + image = ( + batch_image[i].permute(1, 2, 0).cpu().numpy() + ) # image_size x image_size x BGR. if is_rgb is False. + image = image.copy() + kps = batch_joints[i] ## 17 x 2 + kps_vis = batch_target_weight[i] + kps_score = batch_scores[i].reshape(-1) + + if is_rgb == False: + image = cv2.cvtColor( + image, cv2.COLOR_BGR2RGB + ) # convert bgr to rgb image + + kp_vis_image = self.draw_instance_kpts( + image, + keypoints=[kps], + keypoints_visible=[kps_vis], + keypoint_scores=[kps_score], + radius=self.radius, + thickness=self.line_width, + kpt_thr=0.3, + skeleton=self.skeleton, + kpt_color=self.kpt_color, + link_color=self.link_color, + ) ## H, W, C, rgb image + + kp_vis_image = cv2.cvtColor( + kp_vis_image, cv2.COLOR_RGB2BGR + ) ## convert rgb to bgr image + + kp_vis_image = kp_vis_image.transpose((2, 0, 1)).astype(np.float32) + kp_vis_image = torch.from_numpy(kp_vis_image.copy()) + grid.append(kp_vis_image) + + grid = torchvision.utils.make_grid(grid, nrow, padding) + ndarr = grid.byte().permute(1, 2, 0).cpu().numpy() + + ## resize + target_height = self.vis_image_height + target_width = ndarr.shape[1] * target_height // ndarr.shape[0] + ndarr = cv2.resize(ndarr, (target_width, target_height)) + + cv2.imwrite(file_name, ndarr) + return + + def draw_instance_kpts( + self, + image: np.ndarray, # RGB uint8 H,W,3 + keypoints, # list[(J,2)] + keypoints_visible, # list[(J,), {0/1}] + keypoint_scores, # list[(J,)] + *, + radius: int = 4, + thickness: int = -1, + color=(255, 0, 0), + kpt_thr: float = 0.3, + skeleton: list | None = None, # [(i,j)] + kpt_color: list | tuple | np.ndarray | None = None, + link_color: list | tuple | np.ndarray | None = None, + show_kpt_idx: bool = False, + ) -> np.ndarray: + img = image.copy() + H, W = img.shape[:2] + + # defaults + if skeleton is None: + skeleton = [] # points only + if kpt_color is None: + kpt_color = color + if link_color is None: + link_color = (0, 255, 0) + + # robust color normalization: supports tuple, list-of-tuples, np.ndarray (N,3) or (3,) + def _as_color_list(c, n): + # torch -> numpy + if hasattr(c, "detach"): + c = c.detach().cpu().numpy() + # numpy -> array + if isinstance(c, np.ndarray): + if c.ndim == 2 and c.shape[1] == 3: # (N,3) palette + return [tuple(int(v) for v in row) for row in c.tolist()] + if c.size == 3: # single (3,) + return [tuple(int(v) for v in c.tolist())] * max(1, n) + # python containers + if isinstance(c, (list, tuple)): + if n and len(c) == n and isinstance(c[0], (list, tuple, np.ndarray)): + out = [] + for cc in c: + cc = np.asarray(cc).reshape(-1) + assert cc.size == 3, "Each color must be length-3" + out.append(tuple(int(v) for v in 
cc.tolist())) + return out + # single triplet + c_arr = np.asarray(c).reshape(-1) + if c_arr.size == 3: + return [tuple(int(v) for v in c_arr.tolist())] * max(1, n) + # fallback: red + return [(255, 0, 0)] * max(1, n) + + J = keypoints[0].shape[0] if keypoints else 0 + kpt_colors = _as_color_list(kpt_color, J) + link_colors = _as_color_list(link_color, len(skeleton)) + + def in_bounds(x, y): + return 0 <= x < W and 0 <= y < H + + for kpts, vis, score in zip(keypoints, keypoints_visible, keypoint_scores): + kpts = np.asarray(kpts, float) + vis = np.asarray(vis).reshape(-1).astype(bool) + score = np.asarray(score).reshape(-1) + + # links (draw in RGB; NO channel flip) + for lk, (i, j) in enumerate(skeleton): + if i >= len(kpts) or j >= len(kpts): + continue + if not (vis[i] and vis[j]): + continue + if score[i] < kpt_thr or score[j] < kpt_thr: + continue + + x1, y1 = map(int, np.round(kpts[i])) + x2, y2 = map(int, np.round(kpts[j])) + if not (in_bounds(x1, y1) and in_bounds(x2, y2)): + continue + + cv2.line( + img, + (x1, y1), + (x2, y2), + link_colors[lk % len(link_colors)], + thickness=max(1, self.line_width), + lineType=cv2.LINE_AA, + ) + + # points + for j_idx, (xy, v, s) in enumerate(zip(kpts, vis, score)): + if not v or s < kpt_thr: + continue + x, y = map(int, np.round(xy)) + if not in_bounds(x, y): + continue + + c = kpt_colors[min(j_idx, len(kpt_colors) - 1)] + cv2.circle( + img, (x, y), radius, c, thickness=thickness, lineType=cv2.LINE_AA + ) + if show_kpt_idx: + cv2.putText( + img, + str(j_idx), + (x + radius, y - radius), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + c, + 1, + cv2.LINE_AA, + ) + return img + + +###------------------helpers----------------------- +def batch_unnormalize_image( + images, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375] +): + normalize = transforms.Normalize(mean=mean, std=std) + images[:, 0, :, :] = (images[:, 0, :, :] * normalize.std[0]) + normalize.mean[0] + images[:, 1, :, :] = (images[:, 1, :, :] * normalize.std[1]) + normalize.mean[1] + images[:, 2, :, :] = (images[:, 2, :, :] * normalize.std[2]) + normalize.mean[2] + return images + + +def get_max_preds(batch_heatmaps): + """ + get predictions from score maps + heatmaps: numpy.ndarray([batch_size, num_joints, height, width]) + """ + assert isinstance(batch_heatmaps, np.ndarray), ( + "batch_heatmaps should be numpy.ndarray" + ) + assert batch_heatmaps.ndim == 4, "batch_images should be 4-ndim" + + batch_size = batch_heatmaps.shape[0] + num_joints = batch_heatmaps.shape[1] + width = batch_heatmaps.shape[3] + heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1)) + idx = np.argmax(heatmaps_reshaped, 2) ## B x 17 + maxvals = np.amax(heatmaps_reshaped, 2) ## B x 17 + + maxvals = maxvals.reshape((batch_size, num_joints, 1)) ## B x 17 x 1 + idx = idx.reshape((batch_size, num_joints, 1)) ## B x 17 x 1 + + preds = np.tile(idx, (1, 1, 2)).astype( + np.float32 + ) ## B x 17 x 2, like repeat in pytorch + + preds[:, :, 0] = (preds[:, :, 0]) % width + preds[:, :, 1] = np.floor((preds[:, :, 1]) / width) + + pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2)) + pred_mask = pred_mask.astype(np.float32) + + preds *= pred_mask + return preds, maxvals diff --git a/sapiens/pose/tools/deployment/pytorch2torchscript.py b/sapiens/pose/tools/deployment/pytorch2torchscript.py new file mode 100644 index 0000000000000000000000000000000000000000..7b71864925379e9e6339a4960f6dc6714da7b396 --- /dev/null +++ b/sapiens/pose/tools/deployment/pytorch2torchscript.py @@ -0,0 +1,135 @@ +# Copyright (c) Meta 
Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +# originally copied from https://www.internalfb.com/code/fbsource/[671aa4920700]/fbcode/xrcia/projects/sapiens/experimental_ghe_import/sapiens2/sapiens/seg/tools/deployment/pytorch2torchscript.py?lines=1-204 + +import argparse +import os + +import torch +import torch._C +import torch.serialization +from sapiens.dense.tools.deployment.pytorch2torchscript import check_torch_version +from sapiens.pose.datasets import parse_pose_metainfo, UDPHeatmap +from sapiens.pose.models import init_model + +torch.manual_seed(3) +TORCH_MINIMUM_VERSION = "1.8.0" + + +def pytorch2torchscript( + model: torch.nn.Module, + input_shape: tuple[int, int, int, int], + device: str, + show_graph: bool = False, + output_file: str = "tmp.pt", + verify: bool = False, +) -> None: + """Export Pytorch model to TorchScript model and verify the outputs are + same between Pytorch and TorchScript. + + Args: + model (nn.Module): Pytorch model we want to export. + input_shape (tuple): Use this input shape to construct + the corresponding dummy input and execute the model. + show_graph (bool): Whether print the computation graph. Default: False. + output_file (string): The path to where we store the + output TorchScript model. Default: `tmp.pt`. + verify (bool): Whether compare the outputs between + Pytorch and TorchScript. Default: False. + """ + + inputs = torch.rand(input_shape).to(device) + + # replace the original forward with forward_dummy + # model.forward = model.forward_dummy + model.eval() + traced_model = torch.jit.trace( + model, + example_inputs=inputs, + check_trace=verify, + ) + + if show_graph: + print(traced_model.graph) + + traced_model.save(output_file) + print(f"Successfully exported TorchScript model: {output_file}") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Convert .pth checkpoint to TorchScript" + ) + parser.add_argument("config", help="test config file path") + parser.add_argument("--checkpoint", help="Checkpoint file") + parser.add_argument( + "--show-graph", action="store_true", help="show TorchScript graph" + ) + parser.add_argument( + "--verify", action="store_true", help="verify the TorchScript model" + ) + parser.add_argument("--output-file", type=str, default="tmp.pt") + parser.add_argument( + "--shape", + type=int, + nargs="+", + default=[1024, 768], + help="input image size (height, width)", + ) + parser.add_argument("--device", default="cuda:0", help="Device used for inference") + args = parser.parse_args() + return args + + +def main() -> None: + args = parse_args() + check_torch_version() + + if len(args.shape) == 1: + input_shape = (1, 3, args.shape[0], args.shape[0]) + elif len(args.shape) == 2: + input_shape = ( + 1, + 3, + ) + tuple(args.shape) + else: + raise ValueError("invalid input shape") + + # build the model, load checkpoint + model = init_model(args.config, args.checkpoint, device=args.device) + + ## add pose metainfo to model + num_keypoints = model.cfg.num_keypoints + if num_keypoints == 308: + model.pose_metainfo = parse_pose_metainfo( + dict(from_file="configs/_base_/keypoints308.py") + ) + + ## add codec to model + codec_type = model.cfg.codec.pop("type") + assert codec_type == "UDPHeatmap", "Only support UDPHeatmap" + model.codec = UDPHeatmap(**model.cfg.codec) + + ## create the output directory if it does not exist + output_dir 
= os.path.dirname(args.output_file) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # convert the PyTorch model to TorchScript model + pytorch2torchscript( + model, + input_shape=input_shape, + device=args.device, + show_graph=args.show_graph, + output_file=args.output_file, + verify=args.verify, + ) + + +if __name__ == "__main__": + main() diff --git a/sapiens/pose/tools/dist_test.sh b/sapiens/pose/tools/dist_test.sh new file mode 100755 index 0000000000000000000000000000000000000000..06fe6c67e46c4ffc2458d561c93465782ba49399 --- /dev/null +++ b/sapiens/pose/tools/dist_test.sh @@ -0,0 +1,23 @@ +CONFIG=$1 +CHECKPOINT=$2 +GPUS=$3 +NNODES=${NNODES:-1} +NODE_RANK=${NODE_RANK:-0} +PORT=${PORT:-29500} +MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH + +export OMP_NUM_THREADS=1 +export MKL_NUM_THREADS=1 + +torchrun \ + --nnodes=$NNODES \ + --node_rank=$NODE_RANK \ + --master_addr=$MASTER_ADDR \ + --nproc_per_node=$GPUS \ + --master_port=$PORT \ + $(dirname "$0")/test.py \ + $CONFIG \ + $CHECKPOINT \ + ${@:4} diff --git a/sapiens/pose/tools/dist_train.sh b/sapiens/pose/tools/dist_train.sh new file mode 100755 index 0000000000000000000000000000000000000000..c5e50b63e5e3208c89e7ea6420da7c1838c33b85 --- /dev/null +++ b/sapiens/pose/tools/dist_train.sh @@ -0,0 +1,21 @@ +CONFIG=$1 +GPUS=$2 +NNODES=${NNODES:-1} +NODE_RANK=${NODE_RANK:-0} +PORT=${PORT:-29500} +MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH + +export OMP_NUM_THREADS=1 +export MKL_NUM_THREADS=1 + +torchrun \ + --nnodes=$NNODES \ + --node_rank=$NODE_RANK \ + --master_addr=$MASTER_ADDR \ + --nproc_per_node=$GPUS \ + --master_port=$PORT \ + $(dirname "$0")/train.py \ + $CONFIG \ + ${@:3} diff --git a/sapiens/pose/tools/test.py b/sapiens/pose/tools/test.py new file mode 100644 index 0000000000000000000000000000000000000000..dc0c4a410854eb62b9c04e823c6b7f3085d8e9c4 --- /dev/null +++ b/sapiens/pose/tools/test.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +import torch + +# pyre-ignore[21]: Cannot find module `sapiens.engine.config` +from sapiens.engine.config import Config, DictAction +from sapiens.engine.runners import * +from sapiens.pose.runners import * + + +def parse_args(argv: list[str] | None = None) -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Train a segmentor") + parser.add_argument("config", help="train config file path") + parser.add_argument("checkpoint", help="checkpoint file") + parser.add_argument("--work-dir", help="the dir to save logs and models") + parser.add_argument( + "--cfg-options", + nargs="+", + action=DictAction, + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file. If the value to " + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. 
key="[(a,b),(c,d)]" ' + "Note that the quotation marks are necessary and that no white space " + "is allowed.", + ) + parser.add_argument("--local_rank", "--local-rank", type=int, default=0) + args = parser.parse_args(argv) + if "LOCAL_RANK" not in os.environ: + os.environ["LOCAL_RANK"] = str(args.local_rank) + + return args + + +def main(argv: list[str] | None = None) -> None: + args = parse_args(argv) + + # load config + cfg = Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + cfg.work_dir = args.work_dir + cfg.load_from = args.checkpoint + + # set train to false + cfg.train_dataloader = None + + # start testing + runner_type = cfg.get("runner_type", "BaseRunner") + runner = eval(runner_type).from_cfg(cfg) + runner.test() + + +if __name__ == "__main__": + main() diff --git a/sapiens/pose/tools/train.py b/sapiens/pose/tools/train.py new file mode 100644 index 0000000000000000000000000000000000000000..49f180c008f71fbbf1ad9d5dba4ee9e9282aba2d --- /dev/null +++ b/sapiens/pose/tools/train.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +import torch +import torchvision + +# pyre-ignore[21]: Cannot find module `sapiens.engine.config` +from sapiens.engine.config import Config, DictAction +from sapiens.engine.runners import * +from sapiens.pose.runners import * + +torch.set_float32_matmul_precision("high") # A100 gpus +torchvision.disable_beta_transforms_warning() # Disable the beta transforms warning + + +def parse_args(argv: list[str] | None = None) -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Train a segmentor") + parser.add_argument("config", help="train config file path") + parser.add_argument("--work-dir", help="the dir to save logs and models") + parser.add_argument( + "--resume", + nargs="?", + type=str, + const="auto", + help="If specify checkpint path, resume from it, while if not " + "specify, try to auto resume from the latest checkpoint " + "in the work directory.", + ) + parser.add_argument( + "--cfg-options", + nargs="+", + action=DictAction, + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file. If the value to " + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. 
key="[(a,b),(c,d)]" ' + "Note that the quotation marks are necessary and that no white space " + "is allowed.", + ) + parser.add_argument("--local_rank", "--local-rank", type=int, default=0) + args = parser.parse_args(argv) + if "LOCAL_RANK" not in os.environ: + os.environ["LOCAL_RANK"] = str(args.local_rank) + + return args + + +def main(argv: list[str] | None = None) -> None: + args = parse_args(argv) + + # load config + cfg = Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + cfg.work_dir = args.work_dir + + # resume training + if args.resume is not None: + cfg.resume = True + cfg.load_from = args.resume + + # start training + runner_type = cfg.get("runner_type", "BaseRunner") + runner = eval(runner_type).from_cfg(cfg) + runner.train() + + +if __name__ == "__main__": + main() diff --git a/sapiens/pose/tools/vis/pose_render_utils.py b/sapiens/pose/tools/vis/pose_render_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3da108bb2fe67ab4c4dccc4412e784410bc53d33 --- /dev/null +++ b/sapiens/pose/tools/vis/pose_render_utils.py @@ -0,0 +1,120 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import cv2 +import numpy as np + + +def visualize_keypoints( + image: np.ndarray, # RGB uint8 H,W,3 + keypoints, # list[(J,2)] + keypoints_visible, # list[(J,), {0/1}] + keypoint_scores, # list[(J,)] + *, + radius: int = 4, + thickness: int = -1, + color=(255, 0, 0), + kpt_thr: float = 0.3, + skeleton: list | None = None, # [(i,j)] + kpt_color: list | tuple | np.ndarray | None = None, + link_color: list | tuple | np.ndarray | None = None, + show_kpt_idx: bool = False, +) -> np.ndarray: + img = image.copy() + H, W = img.shape[:2] + + # defaults + if skeleton is None: + skeleton = [] # points only + if kpt_color is None: + kpt_color = color + if link_color is None: + link_color = (0, 255, 0) + + # robust color normalization: supports tuple, list-of-tuples, np.ndarray (N,3) or (3,) + def _as_color_list(c, n): + # torch -> numpy + if hasattr(c, "detach"): + c = c.detach().cpu().numpy() + # numpy -> array + if isinstance(c, np.ndarray): + if c.ndim == 2 and c.shape[1] == 3: # (N,3) palette + return [tuple(int(v) for v in row) for row in c.tolist()] + if c.size == 3: # single (3,) + return [tuple(int(v) for v in c.tolist())] * max(1, n) + # python containers + if isinstance(c, (list, tuple)): + if n and len(c) == n and isinstance(c[0], (list, tuple, np.ndarray)): + out = [] + for cc in c: + cc = np.asarray(cc).reshape(-1) + assert cc.size == 3, "Each color must be length-3" + out.append(tuple(int(v) for v in cc.tolist())) + return out + # single triplet + c_arr = np.asarray(c).reshape(-1) + if c_arr.size == 3: + return [tuple(int(v) for v in c_arr.tolist())] * max(1, n) + # fallback: red + return [(255, 0, 0)] * max(1, n) + + J = keypoints[0].shape[0] if keypoints else 0 + kpt_colors = _as_color_list(kpt_color, J) + link_colors = _as_color_list(link_color, len(skeleton)) + + def in_bounds(x, y): + return 0 <= x < W and 0 <= y < H + + for kpts, vis, score in zip(keypoints, keypoints_visible, keypoint_scores): + kpts = np.asarray(kpts, float) + vis = np.asarray(vis).reshape(-1).astype(bool) + score = np.asarray(score).reshape(-1) + + # links (draw in RGB; NO channel flip) + for lk, (i, j) in enumerate(skeleton): + if i >= len(kpts) or j >= len(kpts): + continue + if not 
(vis[i] and vis[j]): + continue + if score[i] < kpt_thr or score[j] < kpt_thr: + continue + + x1, y1 = map(int, np.round(kpts[i])) + x2, y2 = map(int, np.round(kpts[j])) + if not (in_bounds(x1, y1) and in_bounds(x2, y2)): + continue + + cv2.line( + img, + (x1, y1), + (x2, y2), + link_colors[lk % len(link_colors)], + thickness=max(1, thickness), + lineType=cv2.LINE_AA, + ) + + # points + for j_idx, (xy, v, s) in enumerate(zip(kpts, vis, score)): + if not v or s < kpt_thr: + continue + x, y = map(int, np.round(xy)) + if not in_bounds(x, y): + continue + + c = kpt_colors[min(j_idx, len(kpt_colors) - 1)] + cv2.circle(img, (x, y), radius, c, thickness=-1, lineType=cv2.LINE_AA) + if show_kpt_idx: + cv2.putText( + img, + str(j_idx), + (x + radius, y - radius), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + c, + 1, + cv2.LINE_AA, + ) + return img diff --git a/sapiens/pose/tools/vis/rtmdet_m_640-8xb32_coco-person.py b/sapiens/pose/tools/vis/rtmdet_m_640-8xb32_coco-person.py new file mode 100644 index 0000000000000000000000000000000000000000..957bee10d76264a4b3b26d29be1a1049af0c31f4 --- /dev/null +++ b/sapiens/pose/tools/vis/rtmdet_m_640-8xb32_coco-person.py @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +_base_ = "mmdet::rtmdet/rtmdet_m_8xb32-300e_coco.py" + +checkpoint = "https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-m_8xb256-rsb-a1-600e_in1k-ecb3bbd9.pth" # noqa + +model = dict( + backbone=dict( + init_cfg=dict(type="Pretrained", prefix="backbone.", checkpoint=checkpoint) + ), + bbox_head=dict(num_classes=1), + test_cfg=dict( + nms_pre=1000, min_bbox_size=0, score_thr=0.05, nms=None, max_per_img=100 + ), +) + +train_dataloader = dict(dataset=dict(metainfo=dict(classes=("person",)))) + +val_dataloader = dict(dataset=dict(metainfo=dict(classes=("person",)))) +test_dataloader = val_dataloader diff --git a/sapiens/pose/tools/vis/vis_pose.py b/sapiens/pose/tools/vis/vis_pose.py new file mode 100644 index 0000000000000000000000000000000000000000..27679b4550e15a842c75daf3e261ab720161802f --- /dev/null +++ b/sapiens/pose/tools/vis/vis_pose.py @@ -0,0 +1,260 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os +import sys +from argparse import ArgumentParser + +# Block mmpretrain: mmdet's reid modules try `import mmpretrain` inside +# try/except ImportError, but mmpretrain's BLIP language_model.py raises +# TypeError (transformers API drift) โ€” escapes the except and kills the process. +# We don't use reid or mmpretrain, so force a clean ImportError. 
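+# (Assigning None to a sys.modules entry makes any later `import mmpretrain` raise
+# ImportError immediately, which mmdet's optional-import try/except handles cleanly.)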
+sys.modules["mmpretrain"] = None + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from sapiens.pose.datasets import parse_pose_metainfo, UDPHeatmap +from sapiens.pose.evaluators import nms +from sapiens.pose.models import init_model +from tqdm import tqdm + +from pose_render_utils import visualize_keypoints + +try: + from mmdet.apis import inference_detector, init_detector + + has_mmdet = True +except (ImportError, ModuleNotFoundError): + has_mmdet = False + + +def mmdet_pipeline(cfg): + from mmdet.datasets import transforms + + if "test_dataloader" not in cfg: + return cfg + pipeline = cfg.test_dataloader.dataset.pipeline + for trans in pipeline: + if trans["type"] in dir(transforms): + trans["type"] = "mmdet." + trans["type"] + return cfg + + +def process_one_image(args, image, detector, model): + image_w, image_h = image.shape[1], image.shape[0] + det_result = inference_detector(detector, image) + pred_instance = det_result.pred_instances.cpu().numpy() + bboxes = np.concatenate( + (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1 + ) + bboxes = bboxes[ + np.logical_and( + pred_instance.labels == 0, ## 0 is the person class + pred_instance.scores > args.bbox_thr, + ) + ] + + bboxes = bboxes[nms(bboxes, args.nms_thr), :4] ## B x 4; x1, y1, x2, y2 + # get bbox from the image size + if bboxes is None or len(bboxes) == 0: + bboxes = np.array([[0, 0, image_w - 1, image_h - 1]], dtype=np.float32) + + inputs_list = [] + data_samples_list = [] + for bbox in bboxes: + data_info = dict(img=image) + data_info["bbox"] = bbox[None] # shape (1, 4) + data_info["bbox_score"] = np.ones(1, dtype=np.float32) # shape (1,) + data = model.pipeline(data_info) + data = model.data_preprocessor(data) + inputs_list.append(data["inputs"]) + data_samples_list.append(data["data_samples"]) + + inputs = torch.cat(inputs_list, dim=0) # B x 3 x H x W + with torch.no_grad(): + pred = model(inputs) # B x 3 x H x W + if model.cfg.val_cfg is not None and model.cfg.val_cfg.get("flip_test", False): + pred_flipped = model(inputs.flip(-1)) # B x 3 x H x W + pred_flipped = pred_flipped.flip(-1) ## B x K x heatmap_H x heatmap_W + flip_indices = model.pose_metainfo["flip_indices"] + assert len(flip_indices) == pred_flipped.shape[1] ## K + pred_flipped = pred_flipped[:, flip_indices] + pred = (pred + pred_flipped) / 2.0 + + # ------------------------------------------ + pred = pred.cpu().numpy() ## B x K x heatmap_H x heatmap_W + keypoints = [] + keypoint_scores = [] + for i, data_samples in enumerate(data_samples_list): + ## kps in crop image + ## keypoints_i is 1 x K x 2 + # keypoint_scores_i is 1 x K + keypoints_i, keypoint_scores_i = model.codec.decode(pred[i]) + input_size = data_samples["meta"]["input_size"] ## 1 x 2, 768 x 1024 + bbox_center = data_samples["meta"]["bbox_center"] ## 1 x 2 + bbox_scale = data_samples["meta"]["bbox_scale"] ## 1 x 2 + + keypoints_i = ( + keypoints_i / input_size * bbox_scale + bbox_center - 0.5 * bbox_scale + ) + keypoints.append(keypoints_i[0]) ## remove fake batch dim + keypoint_scores.append(keypoint_scores_i[0]) ## remove fake batch dim + + return keypoints, keypoint_scores, bboxes + + +# ------------------------------------------------------------------------------- +def main(): + parser = ArgumentParser() + parser.add_argument("det_config", help="Config file for detection") + parser.add_argument("det_checkpoint", help="Checkpoint file for detection") + parser.add_argument("config", help="Config file") + parser.add_argument("checkpoint", 
help="Checkpoint file") + parser.add_argument("--input", help="Input image dir") + parser.add_argument("--output", default=None, help="Path to output dir") + parser.add_argument("--device", default="cuda:0", help="Device used for inference") + parser.add_argument( + "--radius", type=int, default=3, help="Keypoint radius for visualization" + ) + parser.add_argument( + "--thickness", type=int, default=1, help="Link thickness for visualization" + ) + parser.add_argument( + "--kpt-thr", type=float, default=0.3, help="Visualizing keypoint thresholds" + ) + parser.add_argument( + "--bbox-thr", type=float, default=0.3, help="Bounding box score threshold" + ) + parser.add_argument( + "--nms-thr", type=float, default=0.3, help="IoU threshold for bounding box NMS" + ) + parser.add_argument( + "--no-save-json", + action="store_true", + help="Disable saving per-video predictions JSON (saved by default).", + ) + parser.add_argument( + "--predictions-name", + default=None, + help="Override predictions JSON filename (used by helper for per-chunk writes).", + ) + + args = parser.parse_args() + + model = init_model(args.config, args.checkpoint, device=args.device) + os.makedirs(args.output, exist_ok=True) + + ## add pose metainfo to model + num_keypoints = model.cfg.num_keypoints + if num_keypoints == 308: + model.pose_metainfo = parse_pose_metainfo( + dict(from_file="configs/_base_/keypoints308.py") + ) + + ## add codec to model + codec_type = model.cfg.codec.pop("type") + assert codec_type == "UDPHeatmap", "Only support UDPHeatmap" + model.codec = UDPHeatmap(**model.cfg.codec) + + # build detector + detector = init_detector(args.det_config, args.det_checkpoint, device=args.device) + detector.cfg = mmdet_pipeline(detector.cfg) + + # Get image list + if os.path.isdir(args.input): + input_dir = args.input + image_names = [ + name + for name in sorted(os.listdir(input_dir)) + if name.endswith((".jpg", ".png", ".jpeg")) + ] + else: + with open(args.input, "r") as f: + image_paths = [line.strip() for line in f if line.strip()] + image_names = [os.path.basename(path) for path in image_paths] + input_dir = os.path.dirname(image_paths[0]) + + frames_records = [] + image_size = None + num_keypoints_seen = None + + for image_name in tqdm(image_names, total=len(image_names)): + image_path = os.path.join(input_dir, image_name) + image = cv2.imread(image_path) + + try: + keypoints, keypoint_scores, bboxes = process_one_image( + args, image, detector, model + ) + except Exception as e: + print(f"[vis_pose] inference failed on {image_name}: {e}") + continue + + if image_size is None: + image_size = [int(image.shape[0]), int(image.shape[1])] + if num_keypoints_seen is None and len(keypoints) > 0: + num_keypoints_seen = int(np.asarray(keypoints[0]).shape[0]) + + image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + vis_image_rgb = visualize_keypoints( + image=image_rgb, + keypoints=keypoints, + keypoints_visible=np.ones_like(keypoint_scores) > 0, + keypoint_scores=keypoint_scores, + radius=args.radius, + thickness=args.thickness, + kpt_thr=args.kpt_thr, + skeleton=model.pose_metainfo["skeleton_links"], + kpt_color=model.pose_metainfo["keypoint_colors"], + link_color=model.pose_metainfo["skeleton_link_colors"], + ) + vis_image = cv2.cvtColor(vis_image_rgb, cv2.COLOR_RGB2BGR) + save_path = os.path.join(args.output, image_name) + cv2.imwrite(save_path, vis_image) + + if not args.no_save_json: + try: + instances = [] + for kpts, scores, bbox in zip(keypoints, keypoint_scores, bboxes): + instances.append({ + "bbox": [float(v) 
for v in np.asarray(bbox).reshape(-1)[:4]], + "keypoints": np.asarray(kpts, dtype=float).tolist(), + "keypoint_scores": np.asarray(scores, dtype=float).reshape(-1).tolist(), + }) + frames_records.append({ + "image_name": image_name, + "instances": instances, + }) + except Exception as e: + print(f"[vis_pose] json record failed on {image_name}: {e}") + + if not args.no_save_json: + nn = os.path.basename(os.path.normpath(args.output)) + # strip a trailing "_output" suffix so the JSON sidecar name matches the + # video basename (e.g. ".../v3/01/_output/01_predictions.json"). + # `loop.sh` wraps each video output as `